Mirror of https://github.com/exo-explore/exo.git
Synced 2026-02-20 07:46:42 -05:00

Compare commits: 4 commits (feat/bug-r ... meta-insta)
| Author | SHA1 | Date |
|---|---|---|
|  | e5c31e50f3 |  |
|  | aa3f106fb9 |  |
|  | 2e29605194 |  |
|  | cacb456cb2 |  |
Cargo.lock (generated; 13 changed lines)
```diff
@@ -890,7 +890,7 @@ dependencies = [
  "delegate",
  "env_logger",
  "extend",
- "futures",
+ "futures-lite",
  "libp2p",
  "log",
  "networking",
@@ -914,6 +914,12 @@ dependencies = [
  "syn 2.0.111",
 ]

+[[package]]
+name = "fastrand"
+version = "2.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
+
 [[package]]
 name = "ff"
 version = "0.13.1"
@@ -1022,7 +1028,10 @@ version = "2.6.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
 dependencies = [
+ "fastrand",
  "futures-core",
+ "futures-io",
+ "parking",
  "pin-project-lite",
 ]
@@ -2753,7 +2762,7 @@ dependencies = [
  "delegate",
  "either",
  "extend",
- "futures",
+ "futures-lite",
  "futures-timer",
  "keccak-const",
  "libp2p",
```
Workspace Cargo.toml

```diff
@@ -29,14 +29,13 @@ util = { path = "rust/util" }
 # Macro dependecies
 extend = "1.2"
 delegate = "0.13"
-pin-project = "1"

 # Utility dependencies
 keccak-const = "0.2"

 # Async dependencies
 tokio = "1.46"
-futures = "0.3"
+futures-lite = "2.6.1"
 futures-timer = "3.0"

 # Data structures
```
Bug report modal (Svelte component): entire file deleted (@@ -1,188 +0,0 @@). Removed content:

```svelte
<script lang="ts">
  import { fade, fly } from "svelte/transition";
  import { cubicOut } from "svelte/easing";

  interface Props {
    isOpen: boolean;
    onClose: () => void;
  }

  let { isOpen, onClose }: Props = $props();

  let bugReportId = $state<string | null>(null);
  let githubIssueUrl = $state<string | null>(null);
  let isLoading = $state(false);
  let error = $state<string | null>(null);

  async function generateBugReport() {
    isLoading = true;
    error = null;
    try {
      const response = await fetch("/bug-report", { method: "POST" });
      if (!response.ok) {
        error = "Failed to generate bug report. Please try again.";
        return;
      }
      const data = await response.json();
      bugReportId = data.bugReportId;
      githubIssueUrl = data.githubIssueUrl;
    } catch {
      error = "Failed to connect to the server. Please try again.";
    } finally {
      isLoading = false;
    }
  }

  function handleClose() {
    bugReportId = null;
    githubIssueUrl = null;
    error = null;
    isLoading = false;
    onClose();
  }

  // Generate bug report when modal opens
  $effect(() => {
    if (isOpen && !bugReportId && !isLoading) {
      generateBugReport();
    }
  });
</script>

{#if isOpen}
  <!-- Backdrop -->
  <div
    class="fixed inset-0 z-50 bg-black/80 backdrop-blur-sm"
    transition:fade={{ duration: 200 }}
    onclick={handleClose}
    role="presentation"
  ></div>

  <!-- Modal -->
  <div
    class="fixed z-50 top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(90vw,480px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl overflow-hidden flex flex-col"
    transition:fly={{ y: 20, duration: 300, easing: cubicOut }}
    role="dialog"
    aria-modal="true"
    aria-label="Bug Report"
  >
    <!-- Header -->
    <div
      class="flex items-center justify-between px-5 py-4 border-b border-exo-medium-gray/30"
    >
      <div class="flex items-center gap-2">
        <svg
          class="w-5 h-5 text-exo-yellow"
          fill="none"
          viewBox="0 0 24 24"
          stroke="currentColor"
          stroke-width="2"
        >
          <path
            stroke-linecap="round"
            stroke-linejoin="round"
            d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
          />
        </svg>
        <h2 class="text-sm font-mono text-exo-yellow tracking-wider uppercase">
          Report a Bug
        </h2>
      </div>
      <button
        onclick={handleClose}
        class="text-exo-light-gray hover:text-white transition-colors cursor-pointer"
        aria-label="Close"
      >
        <svg
          class="w-5 h-5"
          fill="none"
          viewBox="0 0 24 24"
          stroke="currentColor"
          stroke-width="2"
        >
          <path
            stroke-linecap="round"
            stroke-linejoin="round"
            d="M6 18L18 6M6 6l12 12"
          />
        </svg>
      </button>
    </div>

    <!-- Body -->
    <div class="px-5 py-5 space-y-4">
      {#if isLoading}
        <div class="flex items-center justify-center py-6">
          <div
            class="w-5 h-5 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"
          ></div>
          <span class="ml-3 text-sm text-exo-light-gray font-mono"
            >Generating bug report...</span
          >
        </div>
      {:else if error}
        <div
          class="text-sm text-red-400 font-mono bg-red-400/10 border border-red-400/20 rounded px-4 py-3"
        >
          {error}
        </div>
        <button
          onclick={generateBugReport}
          class="w-full px-4 py-2.5 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded text-sm font-mono text-exo-yellow hover:border-exo-yellow/60 transition-colors cursor-pointer"
        >
          Try Again
        </button>
      {:else if bugReportId && githubIssueUrl}
        <p class="text-sm text-exo-light-gray leading-relaxed">
          Would you like to create a GitHub issue? This would help us track and
          fix the issue for you.
        </p>

        <!-- Bug Report ID -->
        <div
          class="bg-exo-black/50 border border-exo-medium-gray/30 rounded px-4 py-3"
        >
          <div
            class="text-[11px] text-exo-light-gray/60 font-mono tracking-wider uppercase mb-1"
          >
            Bug Report ID
          </div>
          <div class="text-sm text-exo-yellow font-mono tracking-wide">
            {bugReportId}
          </div>
          <div class="text-[11px] text-exo-light-gray/50 font-mono mt-1">
            Include this ID when communicating with the team.
          </div>
        </div>

        <p class="text-xs text-exo-light-gray/60 leading-relaxed">
          No diagnostic data is attached. The issue template contains
          placeholder fields for you to fill in.
        </p>

        <!-- Actions -->
        <div class="flex gap-3 pt-1">
          <a
            href={githubIssueUrl}
            target="_blank"
            rel="noopener noreferrer"
            class="flex-1 flex items-center justify-center gap-2 px-4 py-2.5 bg-exo-yellow/10 border border-exo-yellow/40 rounded text-sm font-mono text-exo-yellow hover:bg-exo-yellow/20 hover:border-exo-yellow/60 transition-colors"
          >
            <svg class="w-4 h-4" viewBox="0 0 16 16" fill="currentColor">
              <path
                d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0016 8c0-4.42-3.58-8-8-8z"
              />
            </svg>
            Create GitHub Issue
          </a>
          <button
            onclick={handleClose}
            class="px-4 py-2.5 border border-exo-medium-gray/40 rounded text-sm font-mono text-exo-light-gray hover:border-exo-medium-gray/60 transition-colors cursor-pointer"
          >
            Close
          </button>
        </div>
      {/if}
    </div>
  </div>
{/if}
```
Nix flake module

```diff
@@ -74,7 +74,6 @@
     perSystem =
       { config, self', inputs', pkgs, lib, system, ... }:
       let
-        fenixToolchain = inputs'.fenix.packages.complete;
         # Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
         pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
       in
```
Clippy configuration: entire file deleted (@@ -1,2 +0,0 @@). Removed content:

```toml
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
#allowed-duplicate-crates = ["hashbrown"]
```
Python bindings crate Cargo.toml

```diff
@@ -27,7 +27,7 @@ networking = { workspace = true }
 # interop
 pyo3 = { version = "0.27.2", features = [
     # "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
-    "nightly", # enables better-supported GIL integration
+    # "nightly", # enables better-supported GIL integration
     "experimental-async", # async support in #[pyfunction] & #[pymethods]
     #"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
     #"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
@@ -45,11 +45,10 @@ pyo3-log = "0.13.2"
 # macro dependencies
 extend = { workspace = true }
 delegate = { workspace = true }
-pin-project = { workspace = true }

 # async runtime
 tokio = { workspace = true, features = ["full", "tracing"] }
-futures = { workspace = true }
+futures-lite = { workspace = true }

 # utility dependencies
 util = { workspace = true }
@@ -60,3 +59,4 @@ env_logger = "0.11"

 # Networking
 libp2p = { workspace = true, features = ["full"] }
+pin-project = "1.1.10"
```
allow_threading module (bindings crate)

```diff
@@ -2,7 +2,6 @@
 //!

 use pin_project::pin_project;
-use pyo3::marker::Ungil;
 use pyo3::prelude::*;
 use std::{
     future::Future,
@@ -26,8 +25,8 @@ where

 impl<F> Future for AllowThreads<F>
 where
-    F: Future + Ungil,
-    F::Output: Ungil,
+    F: Future + Send,
+    F::Output: Send,
 {
     type Output = F::Output;
```
lib.rs (bindings crate)

```diff
@@ -4,25 +4,12 @@
 //!
 //!

-// enable Rust-unstable features for convenience
-#![feature(trait_alias)]
-#![feature(tuple_trait)]
-#![feature(unboxed_closures)]
-// #![feature(stmt_expr_attributes)]
-// #![feature(assert_matches)]
-// #![feature(async_fn_in_dyn_trait)]
-// #![feature(async_for_loop)]
-// #![feature(auto_traits)]
-// #![feature(negative_impls)]
-
-extern crate core;
 mod allow_threading;
-pub(crate) mod networking;
-pub(crate) mod pylibp2p;
+mod ident;
+mod networking;

+use crate::ident::ident_submodule;
 use crate::networking::networking_submodule;
-use crate::pylibp2p::ident::ident_submodule;
-use crate::pylibp2p::multiaddr::multiaddr_submodule;
 use pyo3::prelude::PyModule;
 use pyo3::{Bound, PyResult, pyclass, pymodule};
 use pyo3_stub_gen::define_stub_info_gatherer;
@@ -32,14 +19,6 @@ pub(crate) mod r#const {
     pub const MPSC_CHANNEL_SIZE: usize = 1024;
 }

-/// Namespace for all the type/trait aliases used by this crate.
-pub(crate) mod alias {
-    use std::marker::Tuple;
-
-    pub trait SendFn<Args: Tuple + Send + 'static, Output> =
-        Fn<Args, Output = Output> + Send + 'static;
-}
-
 /// Namespace for crate-wide extension traits/methods
 pub(crate) mod ext {
     use crate::allow_threading::AllowThreads;
@@ -180,7 +159,6 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // work with maturin, where the types generate correctly, in the right folder, without
     // too many importing issues...
     ident_submodule(m)?;
-    multiaddr_submodule(m)?;
     networking_submodule(m)?;

     // top-level constructs
```
networking module (bindings crate)

```diff
@@ -8,8 +8,8 @@
 use crate::r#const::MPSC_CHANNEL_SIZE;
 use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
 use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
+use crate::ident::{PyKeypair, PyPeerId};
 use crate::pyclass;
-use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
 use libp2p::futures::StreamExt as _;
 use libp2p::gossipsub;
 use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
```
pylibp2p module root: entire file deleted (@@ -1,8 +0,0 @@). Removed content:

```rust
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!

pub mod ident;
pub mod multiaddr;
```
pylibp2p multiaddr bindings: entire file deleted (@@ -1,81 +0,0 @@). Removed content:

```rust
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;

/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
    /// Create a new, empty multiaddress.
    #[staticmethod]
    fn empty() -> Self {
        Self(Multiaddr::empty())
    }

    /// Create a new, empty multiaddress with the given capacity.
    #[staticmethod]
    fn with_capacity(n: usize) -> Self {
        Self(Multiaddr::with_capacity(n))
    }

    /// Parse a `Multiaddr` value from its byte slice representation.
    #[staticmethod]
    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
    }

    /// Parse a `Multiaddr` value from its string representation.
    #[staticmethod]
    fn from_string(string: String) -> PyResult<Self> {
        Ok(Self(Multiaddr::from_str(&string).pyerr()?))
    }

    /// Return the length in bytes of this multiaddress.
    fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns true if the length of this multiaddress is 0.
    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Return a copy of this [`Multiaddr`]'s byte representation.
    fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
        let bytes = self.0.to_vec();
        PyBytes::new(py, &bytes)
    }

    /// Convert a Multiaddr to a string.
    fn to_string(&self) -> String {
        self.0.to_string()
    }

    #[gen_stub(skip)]
    fn __repr__(&self) -> String {
        format!("Multiaddr({})", self.0)
    }

    #[gen_stub(skip)]
    fn __str__(&self) -> String {
        self.to_string()
    }
}

pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyMultiaddr>()?;

    Ok(())
}
```
networking crate Cargo.toml

```diff
@@ -22,7 +22,7 @@ delegate = { workspace = true }

 # async
 tokio = { workspace = true, features = ["full"] }
-futures = { workspace = true }
+futures-lite = { workspace = true }
 futures-timer = { workspace = true }

 # utility dependencies
```
Chat example (networking crate)

```diff
@@ -1,4 +1,4 @@
-use futures::stream::StreamExt as _;
+use futures_lite::StreamExt;
 use libp2p::{gossipsub, identity, swarm::SwarmEvent};
 use networking::{discovery, swarm};
 use tokio::{io, io::AsyncBufReadExt as _, select};
@@ -38,19 +38,19 @@ async fn main() {
                     println!("Publish error: {e:?}");
                 }
             }
-            event = swarm.select_next_some() => match event {
+            event = swarm.next() => match event {
                 // on gossipsub incoming
-                SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
+                Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
                     propagation_source: peer_id,
                     message_id: id,
                     message,
-                })) => println!(
+                }))) => println!(
                     "\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
                     String::from_utf8_lossy(&message.data),
                 ),

                 // on discovery
-                SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
+                Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e))) => match e {
                     discovery::Event::ConnectionEstablished {
                         peer_id, connection_id, remote_ip, remote_tcp_port
                     } => {
@@ -64,7 +64,7 @@ async fn main() {
             }

             // ignore outgoing errors: those are normal
-            e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
+            e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }

             // otherwise log any other event
             e => { log::info!("Other event {e:?}"); }
```
Vendored libp2p chat example: entire file deleted (@@ -1,127 +0,0 @@). Removed content:

```rust
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use futures::stream::StreamExt;
use libp2p::{
    gossipsub, mdns, noise,
    swarm::{NetworkBehaviour, SwarmEvent},
    tcp, yamux,
};
use std::error::Error;
use std::time::Duration;
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;

// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
    gossipsub: gossipsub::Behaviour,
    mdns: mdns::tokio::Behaviour,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let _ = tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .try_init();

    let mut swarm = libp2p::SwarmBuilder::with_new_identity()
        .with_tokio()
        .with_tcp(
            tcp::Config::default(),
            noise::Config::new,
            yamux::Config::default,
        )?
        .with_behaviour(|key| {
            // Set a custom gossipsub configuration
            let gossipsub_config = gossipsub::ConfigBuilder::default()
                .heartbeat_interval(Duration::from_secs(10))
                .validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
                .build()
                .map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.

            // build a gossipsub network behaviour
            let gossipsub = gossipsub::Behaviour::new(
                gossipsub::MessageAuthenticity::Signed(key.clone()),
                gossipsub_config,
            )?;

            let mdns =
                mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
            Ok(MyBehaviour { gossipsub, mdns })
        })?
        .build();

    println!("Running swarm with identity {}", swarm.local_peer_id());

    // Create a Gossipsub topic
    let topic = gossipsub::IdentTopic::new("test-net");
    // subscribes to our topic
    swarm.behaviour_mut().gossipsub.subscribe(&topic)?;

    // Read full lines from stdin
    let mut stdin = io::BufReader::new(io::stdin()).lines();

    // Listen on all interfaces and whatever port the OS assigns
    swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;

    println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");

    // Kick it off
    loop {
        select! {
            Ok(Some(line)) = stdin.next_line() => {
                if let Err(e) = swarm
                    .behaviour_mut().gossipsub
                    .publish(topic.clone(), line.as_bytes()) {
                    println!("Publish error: {e:?}");
                }
            }
            event = swarm.select_next_some() => match event {
                SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
                    for (peer_id, multiaddr) in list {
                        println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
                        swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
                    }
                },
                SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
                    for (peer_id, multiaddr) in list {
                        println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
                        swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
                    }
                },
                SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
                    propagation_source: peer_id,
                    message_id: id,
                    message,
                })) => println!(
                    "Got message: '{}' with id: {id} from peer: {peer_id}",
                    String::from_utf8_lossy(&message.data),
                ),
                SwarmEvent::NewListenAddr { address, .. } => {
                    println!("Local node is listening on {address}");
                }
                e => {
                    println!("Other swarm event: {:?}", e);
                }
            }
        }
    }
}
```
Discovery behaviour (networking crate)

```diff
@@ -1,7 +1,7 @@
 use crate::ext::MultiaddrExt;
 use delegate::delegate;
 use either::Either;
-use futures::FutureExt;
+use futures_lite::FutureExt;
 use futures_timer::Delay;
 use libp2p::core::transport::PortUse;
 use libp2p::core::{ConnectedPoint, Endpoint};
@@ -362,7 +362,7 @@ impl NetworkBehaviour for Behaviour {
     }

     // retry connecting to all mDNS peers periodically (fails safely if already connected)
-    if self.retry_delay.poll_unpin(cx).is_ready() {
+    if self.retry_delay.poll(cx).is_ready() {
         for (p, mas) in self.mdns_discovered.clone() {
             for ma in mas {
                 self.dial(p, ma)
```
Swarm transport (networking crate)

```diff
@@ -31,7 +31,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
 mod transport {
     use crate::alias;
     use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
-    use futures::{AsyncRead, AsyncWrite};
+    use futures_lite::{AsyncRead, AsyncWrite};
     use keccak_const::Sha3_256;
     use libp2p::core::muxing;
     use libp2p::core::transport::Boxed;
```
Nix Rust toolchain module

```diff
@@ -1,11 +1,10 @@
 { inputs, ... }:
 {
   perSystem =
-    { config, self', inputs', pkgs, lib, ... }:
+    { inputs', pkgs, lib, ... }:
     let
       # Fenix nightly toolchain with all components
-      fenixPkgs = inputs'.fenix.packages;
-      rustToolchain = fenixPkgs.complete.withComponents [
+      rustToolchain = inputs'.fenix.packages.stable.withComponents [
         "cargo"
         "rustc"
         "clippy"
```
Rust toolchain file: entire file deleted (@@ -1,2 +0,0 @@). Removed content:

```toml
[toolchain]
channel = "nightly"
```
OpenAI Responses streaming (Python)

```diff
@@ -31,6 +31,7 @@ from exo.shared.types.openai_responses import (
     ResponseOutputText,
     ResponsesRequest,
     ResponsesResponse,
+    ResponsesStreamEvent,
     ResponseTextDeltaEvent,
     ResponseTextDoneEvent,
     ResponseUsage,
@@ -38,6 +39,11 @@ from exo.shared.types.openai_responses import (
 from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams


+def _format_sse(event: ResponsesStreamEvent) -> str:
+    """Format a streaming event as an SSE message."""
+    return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
+
+
 def _extract_content(content: str | list[ResponseContentPart]) -> str:
     """Extract plain text from a content field that may be a string or list of parts."""
     if isinstance(content, str):
@@ -219,13 +225,13 @@ async def generate_responses_stream(
     created_event = ResponseCreatedEvent(
         sequence_number=next(seq), response=initial_response
     )
-    yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
+    yield _format_sse(created_event)

     # response.in_progress
     in_progress_event = ResponseInProgressEvent(
         sequence_number=next(seq), response=initial_response
     )
-    yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
+    yield _format_sse(in_progress_event)

     # response.output_item.added
     initial_item = ResponseMessageItem(
@@ -236,7 +242,7 @@ async def generate_responses_stream(
     item_added = ResponseOutputItemAddedEvent(
         sequence_number=next(seq), output_index=0, item=initial_item
     )
-    yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
+    yield _format_sse(item_added)

     # response.content_part.added
     initial_part = ResponseOutputText(text="")
@@ -247,7 +253,7 @@ async def generate_responses_stream(
         content_index=0,
         part=initial_part,
     )
-    yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
+    yield _format_sse(part_added)

     accumulated_text = ""
     function_call_items: list[ResponseFunctionCallItem] = []
@@ -281,7 +287,7 @@ async def generate_responses_stream(
                 output_index=next_output_index,
                 item=fc_item,
             )
-            yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
+            yield _format_sse(fc_added)

             # response.function_call_arguments.delta
             args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -290,7 +296,7 @@ async def generate_responses_stream(
                 output_index=next_output_index,
                 delta=tool.arguments,
             )
-            yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
+            yield _format_sse(args_delta)

             # response.function_call_arguments.done
             args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -300,7 +306,7 @@ async def generate_responses_stream(
                 name=tool.name,
                 arguments=tool.arguments,
             )
-            yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
+            yield _format_sse(args_done)

             # response.output_item.done
             fc_done_item = ResponseFunctionCallItem(
@@ -315,7 +321,7 @@ async def generate_responses_stream(
                 output_index=next_output_index,
                 item=fc_done_item,
             )
-            yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
+            yield _format_sse(fc_item_done)

             function_call_items.append(fc_done_item)
             next_output_index += 1
@@ -331,7 +337,7 @@ async def generate_responses_stream(
                 content_index=0,
                 delta=chunk.text,
             )
-            yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
+            yield _format_sse(delta_event)

     # response.output_text.done
     text_done = ResponseTextDoneEvent(
@@ -341,7 +347,7 @@ async def generate_responses_stream(
         content_index=0,
         text=accumulated_text,
     )
-    yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
+    yield _format_sse(text_done)

     # response.content_part.done
     final_part = ResponseOutputText(text=accumulated_text)
@@ -352,7 +358,7 @@ async def generate_responses_stream(
         content_index=0,
         part=final_part,
     )
-    yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
+    yield _format_sse(part_done)

     # response.output_item.done
     final_message_item = ResponseMessageItem(
@@ -363,7 +369,7 @@ async def generate_responses_stream(
     item_done = ResponseOutputItemDoneEvent(
         sequence_number=next(seq), output_index=0, item=final_message_item
     )
-    yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
+    yield _format_sse(item_done)

     # Create usage from usage data if available
     usage = None
@@ -388,4 +394,4 @@ async def generate_responses_stream(
     completed_event = ResponseCompletedEvent(
         sequence_number=next(seq), response=final_response
     )
-    yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
+    yield _format_sse(completed_event)
```
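The repeated f-string yields above all collapse into the new `_format_sse` helper. As a quick illustration of the wire format it produces, here is a minimal sketch using a plain pydantic stand-in for the real `ResponsesStreamEvent` models (the class below is illustrative, not one of the actual classes from `exo.shared.types.openai_responses`):

```python
from pydantic import BaseModel


class FakeTextDeltaEvent(BaseModel):
    """Stand-in for a ResponsesStreamEvent; only `type` and a payload matter here."""

    type: str = "response.output_text.delta"
    sequence_number: int
    delta: str


def _format_sse(event: BaseModel) -> str:
    """Format a streaming event as an SSE message (same shape as the new helper)."""
    return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"


print(_format_sse(FakeTextDeltaEvent(sequence_number=3, delta="Hi")), end="")
# event: response.output_text.delta
# data: {"type":"response.output_text.delta","sequence_number":3,"delta":"Hi"}
```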
Master (Python)

```diff
@@ -36,6 +36,8 @@ from exo.shared.types.events import (
     IndexedEvent,
     InputChunkReceived,
     InstanceDeleted,
+    JacclSideChannelData,
+    JacclSideChannelGathered,
     NodeGatheredInfo,
     NodeTimedOut,
     TaskCreated,
@@ -60,6 +62,7 @@ from exo.shared.types.tasks import (
     TextGeneration as TextGenerationTask,
 )
 from exo.shared.types.worker.instances import InstanceId
+from exo.shared.types.worker.runners import RunnerId
 from exo.utils.channels import Receiver, Sender, channel
 from exo.utils.event_buffer import MultiSourceBuffer
@@ -94,6 +97,7 @@ class Master:
         self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
         self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
         self._expected_ranks: dict[TaskId, set[int]] = {}
+        self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {}

     async def run(self):
         logger.info("Starting Master")
@@ -407,6 +411,11 @@ class Master:
             self._event_log.append(event)
             await self._send_event(indexed)

+            # After broadcasting JacclSideChannelData, accumulate and
+            # emit gathered result when all runners have contributed.
+            if isinstance(event, JacclSideChannelData):
+                await self._handle_jaccl_side_channel(event)

     async def _loopback_processor(self) -> None:
         # this would ideally not be necessary.
         # this is WAY less hacky than how I was working around this before
@@ -460,3 +469,42 @@ class Master:
                 del self._pending_traces[task_id]
             if task_id in self._expected_ranks:
                 del self._expected_ranks[task_id]
+
+    async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None:
+        """Accumulate SideChannel contributions; when all runners for an instance
+        have submitted for the same sequence, emit JacclSideChannelGathered."""
+        iid = event.instance_id
+        seq = event.sequence
+
+        if iid not in self._jaccl_pending:
+            self._jaccl_pending[iid] = {}
+        if seq not in self._jaccl_pending[iid]:
+            self._jaccl_pending[iid][seq] = {}
+        self._jaccl_pending[iid][seq][event.runner_id] = event.data
+
+        instance = self.state.instances.get(iid)
+        if instance is None:
+            logger.warning(f"JacclSideChannelData for unknown instance {iid}")
+            return
+
+        expected_runners = set(instance.shard_assignments.runner_to_shard.keys())
+        submitted = set(self._jaccl_pending[iid][seq].keys())
+
+        logger.info(
+            f"JACCL side channel: instance={iid} seq={seq} "
+            f"submitted={len(submitted)}/{len(expected_runners)}"
+        )
+
+        if submitted >= expected_runners:
+            gathered = dict(self._jaccl_pending[iid][seq])
+            del self._jaccl_pending[iid][seq]
+            if not self._jaccl_pending[iid]:
+                del self._jaccl_pending[iid]
+
+            await self.event_sender.send(
+                JacclSideChannelGathered(
+                    instance_id=iid,
+                    sequence=seq,
+                    gathered_data=gathered,
+                )
+            )
```
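The new `_handle_jaccl_side_channel` method is a classic gather barrier: buffer contributions per (instance, sequence) and release the full set once every expected runner has reported. A self-contained sketch of that pattern, with illustrative names (plain strings stand in for `RunnerId`s):

```python
def make_gather_barrier(expected: set[str]):
    """Collect per-sequence contributions; return the full dict once complete."""
    pending: dict[int, dict[str, bytes]] = {}

    def submit(seq: int, runner: str, data: bytes) -> dict[str, bytes] | None:
        pending.setdefault(seq, {})[runner] = data
        if set(pending[seq]) >= expected:
            return pending.pop(seq)  # round complete: hand back gathered data
        return None  # still waiting on other runners

    return submit


submit = make_gather_barrier({"r0", "r1"})
assert submit(7, "r0", b"a") is None
assert submit(7, "r1", b"b") == {"r0": b"a", "r1": b"b"}
```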
Event apply (Python)

```diff
@@ -12,6 +12,8 @@ from exo.shared.types.events import (
     InputChunkReceived,
     InstanceCreated,
     InstanceDeleted,
+    JacclSideChannelData,
+    JacclSideChannelGathered,
     NodeDownloadProgress,
     NodeGatheredInfo,
     NodeTimedOut,
@@ -68,6 +70,8 @@ def event_apply(event: Event, state: State) -> State:
         | PrefillProgress()
         | TracesCollected()
         | TracesMerged()
+        | JacclSideChannelData()
+        | JacclSideChannelGathered()
     ):  # Pass-through events that don't modify state
         return state
     case InstanceCreated():
```
Event types (Python)

```diff
@@ -1,7 +1,9 @@
+import base64
+from collections.abc import Mapping
 from datetime import datetime
-from typing import final
+from typing import Annotated, final

-from pydantic import Field
+from pydantic import BeforeValidator, Field, PlainSerializer

 from exo.shared.topology import Connection
 from exo.shared.types.chunks import GenerationChunk, InputImageChunk
@@ -14,6 +16,28 @@ from exo.utils.info_gatherer.info_gatherer import GatheredInfo
 from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel


+def _decode_base64_bytes(v: bytes | str) -> bytes:
+    if isinstance(v, bytes):
+        return v
+    return base64.b64decode(v)
+
+
+def _encode_base64_bytes(v: bytes) -> str:
+    return base64.b64encode(v).decode("ascii")
+
+
+Base64Bytes = Annotated[
+    bytes,
+    BeforeValidator(_decode_base64_bytes),
+    PlainSerializer(_encode_base64_bytes, return_type=str),
+]
+"""bytes that serialize to/from base64 strings in JSON.
+
+Needed because TaggedModel's wrap validator converts JSON→Python validation
+context, which breaks strict-mode bytes deserialization from JSON strings.
+"""
+
+
 class EventId(Id):
     """
     Newtype around `ID`
@@ -139,6 +163,25 @@ class TracesMerged(BaseEvent):
     traces: list[TraceEventData]


+@final
+class JacclSideChannelData(BaseEvent):
+    """A runner's local contribution to a JACCL SideChannel all_gather round."""
+
+    instance_id: InstanceId
+    runner_id: RunnerId
+    sequence: int
+    data: Base64Bytes
+
+
+@final
+class JacclSideChannelGathered(BaseEvent):
+    """Gathered result of a JACCL SideChannel all_gather round."""
+
+    instance_id: InstanceId
+    sequence: int
+    gathered_data: Mapping[RunnerId, Base64Bytes]
+
+
 Event = (
     TestEvent
     | TaskCreated
@@ -160,6 +203,8 @@ Event = (
     | TopologyEdgeDeleted
     | TracesCollected
     | TracesMerged
+    | JacclSideChannelData
+    | JacclSideChannelGathered
 )
```
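The `Base64Bytes` alias makes raw bytes JSON-safe through pydantic's `Annotated` hooks. A reduced round-trip sketch of the same pattern on a plain `BaseModel` (model and field names are illustrative, and this omits the `TaggedModel` wrapper that motivated the change):

```python
import base64
from typing import Annotated

from pydantic import BaseModel, BeforeValidator, PlainSerializer


def _decode(v: bytes | str) -> bytes:
    # Accept raw bytes (Python context) or base64 text (JSON context).
    return v if isinstance(v, bytes) else base64.b64decode(v)


Base64Bytes = Annotated[
    bytes,
    BeforeValidator(_decode),
    PlainSerializer(lambda v: base64.b64encode(v).decode("ascii"), return_type=str),
]


class Payload(BaseModel):
    data: Base64Bytes


p = Payload(data=b"\x00\xff")
s = p.model_dump_json()  # '{"data":"AP8="}': bytes travel as base64 text
assert Payload.model_validate_json(s).data == b"\x00\xff"
```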
MLX runner utils (Python)

```diff
@@ -643,6 +643,11 @@ def mlx_cleanup(


 def mx_any(bool_: bool, group: Group | None) -> bool:
+    """Synchronize a boolean across all distributed nodes.
+
+    Returns True if any node has bool_=True. Uses all_sum so every
+    node participates in the collective — preventing GPU deadlocks.
+    """
     if group is None:
         return bool_
     num_true = mx.distributed.all_sum(
```
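The new docstring explains why `mx_any` is built on `all_sum` rather than a short-circuiting check: every rank must enter the same collective. The reduction itself is simple; here it is over a plain list standing in for the distributed group (illustrative only, not the MLX API):

```python
def mx_any_local(flags: list[bool]) -> bool:
    """Emulate an all_sum-based "any": sum the 0/1 flags, then test > 0.

    On a real group every rank calls the collective and sees the same total,
    so every rank derives the same boolean and none skips the call.
    """
    num_true = sum(1 if f else 0 for f in flags)  # stands in for all_sum
    return num_true > 0


assert mx_any_local([False, True, False]) is True
assert mx_any_local([False, False]) is False
```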
Worker (Python)

```diff
@@ -24,6 +24,7 @@ from exo.shared.types.events import (
     ForwarderEvent,
     IndexedEvent,
     InputChunkReceived,
+    JacclSideChannelGathered,
     NodeGatheredInfo,
     TaskCreated,
     TaskStatusUpdated,
@@ -159,6 +160,15 @@ class Worker:
         for idx, event in indexed_events:
             self.state = apply(self.state, IndexedEvent(idx=idx, event=event))

+            # Dispatch JACCL gathered events to the relevant RunnerSupervisor
+            if isinstance(event, JacclSideChannelGathered):
+                for runner in self.runners.values():
+                    if (
+                        runner.bound_instance.instance.instance_id
+                        == event.instance_id
+                    ):
+                        runner.notify_gathered(event)

             # Buffer input image chunks for image editing
             if isinstance(event, InputChunkReceived):
                 cmd_id = event.command_id
@@ -241,6 +251,11 @@ class Worker:
                         cancelled_task_id=cancelled_task_id, runner_id=runner_id
                     ):
                         await self.runners[runner_id].cancel_task(cancelled_task_id)
+                        await self.event_sender.send(
+                            TaskStatusUpdated(
+                                task_id=task.task_id, task_status=TaskStatus.Complete
+                            )
+                        )
                 case ImageEdits() if task.task_params.total_input_chunks > 0:
                     # Assemble image from chunks and inject into task
                     cmd_id = task.command_id
```
Runner bootstrap (Python)

```diff
@@ -17,6 +17,7 @@ def entrypoint(
     task_receiver: MpReceiver[Task],
     cancel_receiver: MpReceiver[TaskId],
     _logger: "loguru.Logger",
+    pipe_fifo_paths: tuple[str, str] | None = None,
 ) -> None:
     fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
     if fast_synch_override == "on" or (
@@ -30,6 +31,16 @@ def entrypoint(
     else:
         os.environ["MLX_METAL_FAST_SYNCH"] = "0"

+    # Open JACCL FIFOs by path and set env vars for C++ SideChannel.
+    # Named pipes (FIFOs) work across multiprocessing spawn (macOS default).
+    if pipe_fifo_paths is not None:
+        fifo_c2p, fifo_p2c = pipe_fifo_paths
+        # C++ reads gathered data from p2c (PIPE_IN), writes local data to c2p (PIPE_OUT)
+        pipe_in_fd = os.open(fifo_p2c, os.O_RDONLY)
+        pipe_out_fd = os.open(fifo_c2p, os.O_WRONLY)
+        os.environ["MLX_JACCL_PIPE_IN"] = str(pipe_in_fd)
+        os.environ["MLX_JACCL_PIPE_OUT"] = str(pipe_out_fd)

     global logger
     logger = _logger
```
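FIFO paths are passed instead of inherited descriptors because `multiprocessing`'s spawn start method (the macOS default) does not carry raw fds across. One subtlety the code above relies on: `os.open` on a FIFO blocks until the other end is opened too. A small single-process demonstration, with a thread playing the child:

```python
import os
import tempfile
import threading

fifo = os.path.join(tempfile.mkdtemp(prefix="fifo_demo_"), "p2c")
os.mkfifo(fifo)


def child() -> None:
    fd = os.open(fifo, os.O_RDONLY)  # blocks until a writer opens the FIFO
    print(os.read(fd, 5))  # b'hello'
    os.close(fd)


t = threading.Thread(target=child)
t.start()

fd = os.open(fifo, os.O_WRONLY)  # unblocks the reader; same-thread sequential opens would deadlock
os.write(fd, b"hello")
os.close(fd)
t.join()
```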
RunnerSupervisor (Python)

```diff
@@ -1,6 +1,10 @@
 import contextlib
+import os
 import signal
+import struct
+import tempfile
 from dataclasses import dataclass, field
+from functools import partial
 from multiprocessing import Process
 from typing import Self
@@ -14,12 +18,14 @@ from loguru import logger

 from exo.shared.types.events import (
     Event,
+    JacclSideChannelData,
+    JacclSideChannelGathered,
     RunnerStatusUpdated,
     TaskAcknowledged,
     TaskStatusUpdated,
 )
 from exo.shared.types.tasks import Task, TaskId, TaskStatus
-from exo.shared.types.worker.instances import BoundInstance
+from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
 from exo.shared.types.worker.runners import (
     RunnerConnecting,
     RunnerFailed,
@@ -34,6 +40,26 @@ from exo.shared.types.worker.shards import ShardMetadata
 from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
 from exo.worker.runner.bootstrap import entrypoint


+def _pipe_read_exact(fd: int, n: int) -> bytes | None:
+    """Read exactly n bytes from a file descriptor. Returns None on EOF."""
+    data = b""
+    while len(data) < n:
+        chunk = os.read(fd, n - len(data))
+        if not chunk:
+            return None
+        data += chunk
+    return data
+
+
+def _pipe_write_all(fd: int, data: bytes) -> None:
+    """Write all bytes to a file descriptor."""
+    view = memoryview(data)
+    while view:
+        written = os.write(fd, view)
+        view = view[written:]
+
+
 PREFILL_TIMEOUT_SECONDS = 60
 DECODE_TIMEOUT_SECONDS = 5
```
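`_pipe_read_exact` and `_pipe_write_all` are the usual exact-I/O primitives over raw fds, needed because `os.read` and `os.write` may return short. The new `struct` import suggests the relay frames messages with a length prefix; the relay loop itself (`_pipe_relay`) is outside this diff, so the framing below is an assumption layered on the two helpers, not the repository's actual protocol:

```python
import os
import struct


def _pipe_read_exact(fd: int, n: int) -> bytes | None:
    """Read exactly n bytes; None on EOF (same as the helper above)."""
    data = b""
    while len(data) < n:
        chunk = os.read(fd, n - len(data))
        if not chunk:
            return None
        data += chunk
    return data


def _pipe_write_all(fd: int, data: bytes) -> None:
    """Write all bytes, looping over short writes (same as the helper above)."""
    view = memoryview(data)
    while view:
        view = view[os.write(fd, view):]


def send_frame(fd: int, payload: bytes) -> None:
    """Hypothetical framing: 8-byte little-endian length, then the payload."""
    _pipe_write_all(fd, struct.pack("<Q", len(payload)))
    _pipe_write_all(fd, payload)


def recv_frame(fd: int) -> bytes | None:
    """Read one length-prefixed frame; None once the peer closes the pipe."""
    header = _pipe_read_exact(fd, 8)
    if header is None:
        return None
    (length,) = struct.unpack("<Q", header)
    return _pipe_read_exact(fd, length)
```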
@@ -48,10 +74,19 @@ class RunnerSupervisor:
|
|||||||
_task_sender: MpSender[Task]
|
_task_sender: MpSender[Task]
|
||||||
_event_sender: Sender[Event]
|
_event_sender: Sender[Event]
|
||||||
_cancel_sender: MpSender[TaskId]
|
_cancel_sender: MpSender[TaskId]
|
||||||
|
_pipe_read_fd: int | None = None # Python reads runner's pipe output
|
||||||
|
_pipe_write_fd: int | None = None # Python writes gathered data to runner
|
||||||
|
_child_pipe_fds: tuple[int, int] | None = None # fds to close after fork
|
||||||
|
_fifo_dir: str | None = None # Temp dir for FIFO files (for cleanup)
|
||||||
|
_fifo_c2p: str | None = None # FIFO path: C++ writes → Python reads
|
||||||
|
_fifo_p2c: str | None = None # FIFO path: Python writes → C++ reads
|
||||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||||
|
_gathered_waiters: dict[
|
||||||
|
int, tuple[anyio.Event, JacclSideChannelGathered | None]
|
||||||
|
] = field(default_factory=dict, init=False)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
@@ -65,6 +100,23 @@ class RunnerSupervisor:
|
|||||||
task_sender, task_recv = mp_channel[Task]()
|
task_sender, task_recv = mp_channel[Task]()
|
||||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||||
|
|
||||||
|
# For MlxJaccl instances, create named pipes (FIFOs) for SideChannel relay.
|
||||||
|
# Named pipes work across multiprocessing.Process spawn (macOS default).
|
||||||
|
# FIFO c2p: C++ writes local data → Python reads it
|
||||||
|
# FIFO p2c: Python writes gathered data → C++ reads it
|
||||||
|
fifo_dir: str | None = None
|
||||||
|
fifo_c2p: str | None = None
|
||||||
|
fifo_p2c: str | None = None
|
||||||
|
pipe_fifo_paths: tuple[str, str] | None = None
|
||||||
|
|
||||||
|
if isinstance(bound_instance.instance, MlxJacclInstance):
|
||||||
|
fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_")
|
||||||
|
fifo_c2p = os.path.join(fifo_dir, "c2p") # C++ → Python
|
||||||
|
fifo_p2c = os.path.join(fifo_dir, "p2c") # Python → C++
|
||||||
|
os.mkfifo(fifo_c2p)
|
||||||
|
os.mkfifo(fifo_p2c)
|
||||||
|
pipe_fifo_paths = (fifo_c2p, fifo_p2c)
|
||||||
|
|
||||||
runner_process = Process(
|
runner_process = Process(
|
||||||
target=entrypoint,
|
target=entrypoint,
|
||||||
args=(
|
args=(
|
||||||
@@ -73,6 +125,7 @@ class RunnerSupervisor:
                 task_recv,
                 cancel_recv,
                 logger,
+                pipe_fifo_paths,
             ),
             daemon=True,
         )
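Passing the FIFO paths (strings) through args matters under the spawn start method, the macOS default: multiprocessing pickles the argument tuple for the child, and a path can simply be re-opened there, whereas an inherited descriptor number would be meaningless. A sketch of the pattern with a toy entrypoint (hypothetical, not the real runner signature):

    import multiprocessing as mp

    def entrypoint(fifo_paths: tuple[str, str] | None) -> None:
        # The child re-opens the FIFOs by path; no fd inheritance is assumed.
        if fifo_paths is not None:
            c2p, p2c = fifo_paths
            print(f"child would open {c2p} and {p2c}")

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        p = ctx.Process(target=entrypoint, args=(("/tmp/c2p", "/tmp/p2c"),), daemon=True)
        p.start()
        p.join()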
@@ -88,22 +141,58 @@ class RunnerSupervisor:
             _task_sender=task_sender,
             _cancel_sender=cancel_sender,
             _event_sender=event_sender,
+            _fifo_dir=fifo_dir,
+            _fifo_c2p=fifo_c2p,
+            _fifo_p2c=fifo_p2c,
         )
 
         return self
 
     async def run(self):
         self.runner_process.start()
-        await self._forward_events()
+        if self._fifo_c2p is not None and self._fifo_p2c is not None:
+            # Open FIFOs from parent side. These block until child opens the other end,
+            # so we run them in threads concurrently to avoid deadlock.
+            fifo_c2p = self._fifo_c2p
+            fifo_p2c = self._fifo_p2c
+
+            async def open_read() -> None:
+                self._pipe_read_fd = await to_thread.run_sync(
+                    partial(os.open, fifo_c2p, os.O_RDONLY)
+                )
+
+            async def open_write() -> None:
+                self._pipe_write_fd = await to_thread.run_sync(
+                    partial(os.open, fifo_p2c, os.O_WRONLY)
+                )
+
+            async with anyio.create_task_group() as open_tg:
+                open_tg.start_soon(open_read)
+                open_tg.start_soon(open_write)
+
+            logger.info(
+                f"JACCL pipe relay: FIFOs opened (read_fd={self._pipe_read_fd}, write_fd={self._pipe_write_fd})"
+            )
+
+            async with anyio.create_task_group() as tg:
+                tg.start_soon(self._pipe_relay)
+                tg.start_soon(self._forward_events)
+        else:
+            await self._forward_events()
 
     def shutdown(self):
         logger.info("Runner supervisor shutting down")
         self._ev_recv.close()
         self._task_sender.close()
+        try:
+            self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
+            self._cancel_sender.close()
+        except ClosedResourceError:
+            pass
         self._event_sender.close()
-        self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
-        self._cancel_sender.close()
-        self.runner_process.join(5)
+        self._close_pipe_fds()
+        self.runner_process.join(1)
         if not self.runner_process.is_alive():
             logger.info("Runner process successfully terminated")
             return
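anyio.to_thread.run_sync forwards only positional arguments to its target, so the blocking os.open calls above are frozen with functools.partial before being handed to a worker thread. A minimal sketch of the same pattern (ordinary temp file instead of a FIFO):

    import os
    import tempfile
    from functools import partial

    import anyio
    from anyio import to_thread

    async def main() -> None:
        tmp_fd, path = tempfile.mkstemp()
        os.close(tmp_fd)
        # partial freezes the arguments; run_sync executes the call in a
        # worker thread so the event loop is never blocked by the syscall.
        fd = await to_thread.run_sync(partial(os.open, path, os.O_RDONLY))
        os.close(fd)
        os.unlink(path)

    anyio.run(main)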
@@ -140,6 +229,7 @@ class RunnerSupervisor:
         await event.wait()
 
     async def cancel_task(self, task_id: TaskId):
+        """Send a cancellation signal to the runner process."""
         if task_id in self.completed:
             logger.info(f"Unable to cancel {task_id} as it has been completed")
             return
@@ -181,6 +271,110 @@ class RunnerSupervisor:
         for tid in self.pending:
             self.pending[tid].set()
 
+    def _close_pipe_fds(self) -> None:
+        if self._pipe_read_fd is not None:
+            with contextlib.suppress(OSError):
+                os.close(self._pipe_read_fd)
+            self._pipe_read_fd = None
+        if self._pipe_write_fd is not None:
+            with contextlib.suppress(OSError):
+                os.close(self._pipe_write_fd)
+            self._pipe_write_fd = None
+        if self._child_pipe_fds is not None:
+            for fd in self._child_pipe_fds:
+                with contextlib.suppress(OSError):
+                    os.close(fd)
+            self._child_pipe_fds = None
+        # Clean up FIFO files
+        if self._fifo_c2p is not None:
+            with contextlib.suppress(OSError):
+                os.unlink(self._fifo_c2p)
+            self._fifo_c2p = None
+        if self._fifo_p2c is not None:
+            with contextlib.suppress(OSError):
+                os.unlink(self._fifo_p2c)
+            self._fifo_p2c = None
+        if self._fifo_dir is not None:
+            with contextlib.suppress(OSError):
+                os.rmdir(self._fifo_dir)
+            self._fifo_dir = None
+
+    async def _pipe_relay(self) -> None:
+        """Relay JACCL SideChannel all_gather rounds between runner pipes and exo events."""
+        assert self._pipe_read_fd is not None
+        assert self._pipe_write_fd is not None
+        read_fd = self._pipe_read_fd
+        write_fd = self._pipe_write_fd
+        sequence = 0
+
+        try:
+            while True:
+                # 1. Read local data from runner: [uint32 size][size bytes]
+                header = await to_thread.run_sync(partial(_pipe_read_exact, read_fd, 4))
+                if header is None:
+                    logger.info("JACCL pipe relay: runner closed pipe (EOF)")
+                    break
+                data_size: int = struct.unpack("<I", header)[0]  # pyright: ignore[reportAny]
+                local_data = await to_thread.run_sync(
+                    partial(_pipe_read_exact, read_fd, data_size)
+                )
+                if local_data is None:
+                    logger.warning("JACCL pipe relay: EOF reading data payload")
+                    break
+
+                logger.info(
+                    f"JACCL pipe relay: read {data_size} bytes from runner, seq={sequence}"
+                )
+
+                # 2. Emit JacclSideChannelData event
+                waiter = anyio.Event()
+                self._gathered_waiters[sequence] = (waiter, None)
+                await self._event_sender.send(
+                    JacclSideChannelData(
+                        instance_id=self.bound_instance.instance.instance_id,
+                        runner_id=self.bound_instance.bound_runner_id,
+                        sequence=sequence,
+                        data=local_data,
+                    )
+                )
+
+                # 3. Wait for gathered result
+                await waiter.wait()
+                _, gathered_event = self._gathered_waiters.pop(sequence)
+                assert gathered_event is not None
+
+                # 4. Order gathered data by runner rank and concatenate
+                instance = self.bound_instance.instance
+                assert isinstance(instance, MlxJacclInstance)
+                runner_order = list(instance.shard_assignments.runner_to_shard.keys())
+                ordered_data = b"".join(
+                    gathered_event.gathered_data[rid] for rid in runner_order
+                )
+
+                # 5. Write gathered data to runner: [uint32 total_size][total_size bytes]
+                total_size = len(ordered_data)
+                response = struct.pack("<I", total_size) + ordered_data
+                await to_thread.run_sync(partial(_pipe_write_all, write_fd, response))
+
+                logger.info(
+                    f"JACCL pipe relay: wrote {total_size} bytes to runner, seq={sequence}"
+                )
+                sequence += 1
+        except OSError as e:
+            logger.warning(f"JACCL pipe relay: OS error: {e}")
+        except Exception as e:
+            logger.opt(exception=e).error("JACCL pipe relay: unexpected error")
+
+    def notify_gathered(self, event: JacclSideChannelGathered) -> None:
+        """Called by the worker when a JacclSideChannelGathered event arrives."""
+        seq = event.sequence
+        if seq not in self._gathered_waiters:
+            logger.warning(f"JACCL: received gathered event for unknown sequence {seq}")
+            return
+        waiter, _ = self._gathered_waiters[seq]
+        self._gathered_waiters[seq] = (waiter, event)
+        waiter.set()
+
     def __del__(self) -> None:
         if self.runner_process.is_alive():
             logger.warning("RunnerSupervisor was not stopped cleanly.")
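The _gathered_waiters table in the hunk above is a per-sequence rendezvous: _pipe_relay parks on an anyio.Event keyed by sequence number, and notify_gathered stores the payload and then sets the event. The same shape in isolation (toy str payload instead of the exo event classes):

    import anyio

    waiters: dict[int, tuple[anyio.Event, str | None]] = {}

    async def wait_for(seq: int) -> str:
        event = anyio.Event()
        waiters[seq] = (event, None)
        await event.wait()  # parked until deliver() runs for this sequence
        _, result = waiters.pop(seq)
        assert result is not None
        return result

    def deliver(seq: int, result: str) -> None:
        # Store the payload first, then wake the waiter.
        event, _ = waiters[seq]
        waiters[seq] = (event, result)
        event.set()

    async def main() -> None:
        async def producer() -> None:
            await anyio.sleep(0.01)
            deliver(0, "gathered")

        async with anyio.create_task_group() as tg:
            tg.start_soon(producer)
            print(await wait_for(0))  # -> gathered

    anyio.run(main)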