mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-30 09:40:46 -05:00
Compare commits
9 Commits
iroh-migra
...
optimize-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
40cbecb5c4 | ||
|
|
fea42473dd | ||
|
|
ca7adcc2a8 | ||
|
|
9d9e24f969 | ||
|
|
b5d424b658 | ||
|
|
b465134012 | ||
|
|
eabdcab978 | ||
|
|
8e9332d6a7 | ||
|
|
4b65d5f896 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -7,6 +7,8 @@ digest.txt
|
||||
# nix
|
||||
.direnv/
|
||||
|
||||
# IDEA (PyCharm)
|
||||
.idea
|
||||
|
||||
# xcode / macos
|
||||
*.xcuserstate
|
||||
|
||||
3738
Cargo.lock
generated
3738
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
73
Cargo.toml
73
Cargo.toml
@@ -1,8 +1,10 @@
|
||||
[workspace]
|
||||
resolver = "3"
|
||||
members = [
|
||||
"rust/exo_pyo3_bindings",
|
||||
"rust/networking",
|
||||
"rust/exo_pyo3_bindings",
|
||||
"rust/system_custodian",
|
||||
"rust/util",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -23,38 +25,63 @@ opt-level = 3
|
||||
[workspace.dependencies]
|
||||
## Crate members as common dependencies
|
||||
networking = { path = "rust/networking" }
|
||||
system_custodian = { path = "rust/system_custodian" }
|
||||
util = { path = "rust/util" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
syn = "2.0"
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0"
|
||||
darling = "0.20"
|
||||
|
||||
# Macro dependecies
|
||||
extend = "1.2"
|
||||
delegate = "0.13"
|
||||
impl-trait-for-tuples = "0.2"
|
||||
clap = "4.5"
|
||||
derive_more = { version = "2.0.1", features = ["display"] }
|
||||
pin-project = "1"
|
||||
|
||||
# Utility dependencies
|
||||
itertools = "0.14"
|
||||
thiserror = "2"
|
||||
internment = "0.8"
|
||||
recursion = "0.5"
|
||||
regex = "1.11"
|
||||
once_cell = "1.21"
|
||||
thread_local = "1.1"
|
||||
bon = "3.4"
|
||||
generativity = "1.1"
|
||||
anyhow = "1.0"
|
||||
keccak-const = "0.2"
|
||||
|
||||
# Functional generics/lenses frameworks
|
||||
frunk_core = "0.4"
|
||||
frunk = "0.4"
|
||||
frunk_utils = "0.2"
|
||||
frunk-enum-core = "0.3"
|
||||
|
||||
# Async dependencies
|
||||
tokio = "1.46"
|
||||
n0-future = "0.3.1"
|
||||
postcard = "1.1.3"
|
||||
n0-error = "0.1.2"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
futures-timer = "3.0"
|
||||
|
||||
# Data structures
|
||||
either = "1.15"
|
||||
ordered-float = "5.0"
|
||||
ahash = "0.8"
|
||||
|
||||
# Tracing/logging
|
||||
log = "0.4"
|
||||
blake3 = "1.8.2"
|
||||
env_logger = "0.11"
|
||||
tracing-subscriber = "0.3.20"
|
||||
|
||||
# networking
|
||||
iroh = "0.95.1"
|
||||
iroh-gossip = "0.95.0"
|
||||
bytes = "1.11.0"
|
||||
|
||||
# pyo3
|
||||
pyo3 = "0.27.1"
|
||||
# pyo3-async-runtimes = "0.27.0"
|
||||
pyo3-log = "0.13.2"
|
||||
pyo3-stub-gen = "0.17.2"
|
||||
|
||||
# util
|
||||
rand = "0.9.2"
|
||||
extend = "1.2"
|
||||
|
||||
[patch.crates-io]
|
||||
netwatch = { git = "https://github.com/Evanev7/net-tools.git", branch="patch-for-exo" }
|
||||
libp2p = "0.56"
|
||||
libp2p-tcp = "0.44"
|
||||
|
||||
[workspace.lints.rust]
|
||||
static_mut_refs = "warn" # Or use "warn" instead of deny
|
||||
incomplete_features = "allow"
|
||||
|
||||
# Clippy's lint category level configurations;
|
||||
# every member crate needs to inherit these by adding
|
||||
|
||||
83
README.md
83
README.md
@@ -61,10 +61,10 @@ Devices running exo automatically discover each other, without needing any manua
|
||||
|
||||
There are two ways to run exo:
|
||||
|
||||
### Run from Source (Mac & Linux)
|
||||
### Run from Source (macOS)
|
||||
|
||||
**Prerequisites:**
|
||||
- [brew](https://github.com/Homebrew/brew) (for simple package management on MacOS)
|
||||
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
|
||||
|
||||
```bash
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
@@ -98,6 +98,62 @@ uv run exo
|
||||
|
||||
This starts the exo dashboard and API at http://localhost:52415/
|
||||
|
||||
### Run from Source (Linux)
|
||||
|
||||
**Prerequisites:**
|
||||
|
||||
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
|
||||
- [node](https://github.com/nodejs/node) (for building the dashboard) - version 18 or higher
|
||||
- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings, nightly for now)
|
||||
|
||||
**Installation methods:**
|
||||
|
||||
**Option 1: Using system package manager (Ubuntu/Debian example):**
|
||||
```bash
|
||||
# Install Node.js and npm
|
||||
sudo apt update
|
||||
sudo apt install nodejs npm
|
||||
|
||||
# Install uv
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Install Rust (using rustup)
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
||||
rustup toolchain install nightly
|
||||
```
|
||||
|
||||
**Option 2: Using Homebrew on Linux (if preferred):**
|
||||
```bash
|
||||
# Install Homebrew on Linux
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
|
||||
# Install dependencies
|
||||
brew install uv node
|
||||
|
||||
# Install Rust (using rustup)
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
||||
rustup toolchain install nightly
|
||||
```
|
||||
|
||||
**Note:** The `macmon` package is macOS-only and not required for Linux.
|
||||
|
||||
Clone the repo, build the dashboard, and run exo:
|
||||
|
||||
```bash
|
||||
# Clone exo
|
||||
git clone https://github.com/exo-explore/exo
|
||||
|
||||
# Build dashboard
|
||||
cd exo/dashboard && npm install && npm run build && cd ..
|
||||
|
||||
# Run exo
|
||||
uv run exo
|
||||
```
|
||||
|
||||
This starts the exo dashboard and API at http://localhost:52415/
|
||||
|
||||
**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.
|
||||
|
||||
### macOS App
|
||||
|
||||
exo ships a macOS app that runs in the background on your Mac.
|
||||
@@ -112,6 +168,29 @@ The app will ask for permission to modify system settings and install a new Netw
|
||||
|
||||
---
|
||||
|
||||
### Enabling RDMA on macOS
|
||||
|
||||
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
|
||||
|
||||
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
|
||||
|
||||
To enable RDMA on macOS, follow these steps:
|
||||
|
||||
1. Shut down your Mac.
|
||||
2. Hold down the power button for 10 seconds until the boot menu appears.
|
||||
3. Select "Options" to enter Recovery mode.
|
||||
4. When the Recovery UI appears, open the Terminal from the Utilities menu.
|
||||
5. In the Terminal, type:
|
||||
```
|
||||
rdma_ctl enable
|
||||
```
|
||||
and press Enter.
|
||||
6. Reboot your Mac.
|
||||
|
||||
After that, RDMA will be enabled in macOS and exo will take care of the rest.
|
||||
|
||||
---
|
||||
|
||||
### Using the API
|
||||
|
||||
If you prefer to interact with exo via the API, here is an example creating an instance of a small model (`mlx-community/Llama-3.2-1B-Instruct-4bit`), sending a chat completions request and deleting the instance.
|
||||
|
||||
1
TODO.md
1
TODO.md
@@ -19,7 +19,6 @@
|
||||
25. Rethink retry logic
|
||||
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
|
||||
27. Log cleanup - per-module log filters and default to DEBUG log levels
|
||||
28. Really need to remove all mlx logic outside of the runner - API has a transitive dependency on engines which imports mlx
|
||||
|
||||
Potential refactors:
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ struct ContentView: View {
|
||||
|
||||
private var topologySection: some View {
|
||||
Group {
|
||||
if let topology = stateService.latestSnapshot?.topologyViewModel(), !topology.nodes.isEmpty {
|
||||
if let topology = stateService.latestSnapshot?.topologyViewModel(localNodeId: stateService.localNodeId), !topology.nodes.isEmpty {
|
||||
TopologyMiniView(topology: topology)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,7 +82,6 @@ struct BugReportService {
|
||||
}
|
||||
|
||||
private func loadCredentials() throws -> AWSConfig {
|
||||
// These credentials are write-only and necessary to receive bug reports from users
|
||||
return AWSConfig(
|
||||
accessKey: "AKIAYEKP5EMXTOBYDGHX",
|
||||
secretKey: "Ep5gIlUZ1o8ssTLQwmyy34yPGfTPEYQ4evE8NdPE",
|
||||
|
||||
@@ -7,6 +7,7 @@ final class ClusterStateService: ObservableObject {
|
||||
@Published private(set) var lastError: String?
|
||||
@Published private(set) var lastActionMessage: String?
|
||||
@Published private(set) var modelOptions: [ModelOption] = []
|
||||
@Published private(set) var localNodeId: String?
|
||||
|
||||
private var timer: Timer?
|
||||
private let decoder: JSONDecoder
|
||||
@@ -29,6 +30,7 @@ final class ClusterStateService: ObservableObject {
|
||||
func startPolling(interval: TimeInterval = 0.5) {
|
||||
stopPolling()
|
||||
Task {
|
||||
await fetchLocalNodeId()
|
||||
await fetchModels()
|
||||
await fetchSnapshot()
|
||||
}
|
||||
@@ -46,9 +48,31 @@ final class ClusterStateService: ObservableObject {
|
||||
latestSnapshot = nil
|
||||
lastError = nil
|
||||
lastActionMessage = nil
|
||||
localNodeId = nil
|
||||
}
|
||||
|
||||
private func fetchLocalNodeId() async {
|
||||
do {
|
||||
let url = baseURL.appendingPathComponent("node_id")
|
||||
var request = URLRequest(url: url)
|
||||
request.cachePolicy = .reloadIgnoringLocalCacheData
|
||||
let (data, response) = try await session.data(for: request)
|
||||
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
|
||||
return
|
||||
}
|
||||
if let nodeId = try? decoder.decode(String.self, from: data) {
|
||||
localNodeId = nodeId
|
||||
}
|
||||
} catch {
|
||||
// Silently ignore - localNodeId will remain nil and retry on next poll
|
||||
}
|
||||
}
|
||||
|
||||
private func fetchSnapshot() async {
|
||||
// Retry fetching local node ID if not yet set
|
||||
if localNodeId == nil {
|
||||
await fetchLocalNodeId()
|
||||
}
|
||||
do {
|
||||
var request = URLRequest(url: endpoint)
|
||||
request.cachePolicy = .reloadIgnoringLocalCacheData
|
||||
|
||||
@@ -85,7 +85,7 @@ struct TopologyViewModel {
|
||||
}
|
||||
|
||||
extension ClusterState {
|
||||
func topologyViewModel() -> TopologyViewModel? {
|
||||
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
|
||||
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
|
||||
let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
|
||||
guard !allNodes.isEmpty else { return nil }
|
||||
@@ -105,6 +105,11 @@ extension ClusterState {
|
||||
orderedNodes = allNodes
|
||||
}
|
||||
|
||||
// Rotate so the local node (from /node_id API) is first
|
||||
if let localId = localNodeId, let index = orderedNodes.firstIndex(where: { $0.id == localId }) {
|
||||
orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
|
||||
}
|
||||
|
||||
let nodeIds = Set(orderedNodes.map(\.id))
|
||||
let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
|
||||
guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
|
||||
@@ -112,10 +117,7 @@ extension ClusterState {
|
||||
} ?? []
|
||||
let edges = Set(edgesArray)
|
||||
|
||||
let topologyRootId = topology?.nodes.first?.nodeId
|
||||
let currentId = orderedNodes.first(where: { $0.id == topologyRootId })?.id ?? orderedNodes.first?.id
|
||||
|
||||
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: currentId)
|
||||
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
48
dashboard/package-lock.json
generated
48
dashboard/package-lock.json
generated
@@ -9,6 +9,8 @@
|
||||
"version": "1.0.0",
|
||||
"dependencies": {
|
||||
"highlight.js": "^11.11.1",
|
||||
"katex": "^0.16.27",
|
||||
"marked": "^17.0.1",
|
||||
"mode-watcher": "^1.1.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
@@ -861,7 +863,6 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -901,7 +902,6 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1518,7 +1518,6 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1528,7 +1527,6 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1941,7 +1939,6 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2254,6 +2251,31 @@
|
||||
"jiti": "lib/jiti-cli.mjs"
|
||||
}
|
||||
},
|
||||
"node_modules/katex": {
|
||||
"version": "0.16.27",
|
||||
"resolved": "https://registry.npmjs.org/katex/-/katex-0.16.27.tgz",
|
||||
"integrity": "sha512-aeQoDkuRWSqQN6nSvVCEFvfXdqo1OQiCmmW1kc9xSdjutPv7BGO7pqY9sQRJpMOGrEdfDgF2TfRXe5eUAD2Waw==",
|
||||
"funding": [
|
||||
"https://opencollective.com/katex",
|
||||
"https://github.com/sponsors/katex"
|
||||
],
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"commander": "^8.3.0"
|
||||
},
|
||||
"bin": {
|
||||
"katex": "cli.js"
|
||||
}
|
||||
},
|
||||
"node_modules/katex/node_modules/commander": {
|
||||
"version": "8.3.0",
|
||||
"resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz",
|
||||
"integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 12"
|
||||
}
|
||||
},
|
||||
"node_modules/kleur": {
|
||||
"version": "4.1.5",
|
||||
"resolved": "https://registry.npmjs.org/kleur/-/kleur-4.1.5.tgz",
|
||||
@@ -2540,6 +2562,18 @@
|
||||
"@jridgewell/sourcemap-codec": "^1.5.5"
|
||||
}
|
||||
},
|
||||
"node_modules/marked": {
|
||||
"version": "17.0.1",
|
||||
"resolved": "https://registry.npmjs.org/marked/-/marked-17.0.1.tgz",
|
||||
"integrity": "sha512-boeBdiS0ghpWcSwoNm/jJBwdpFaMnZWRzjA6SkUMYb40SVaN1x7mmfGKp0jvexGcx+7y2La5zRZsYFZI6Qpypg==",
|
||||
"license": "MIT",
|
||||
"bin": {
|
||||
"marked": "bin/marked.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 20"
|
||||
}
|
||||
},
|
||||
"node_modules/mode-watcher": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/mode-watcher/-/mode-watcher-1.1.0.tgz",
|
||||
@@ -2612,7 +2646,6 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2800,7 +2833,6 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -2945,7 +2977,6 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -2967,7 +2998,6 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -27,7 +27,8 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"highlight.js": "^11.11.1",
|
||||
"katex": "^0.16.27",
|
||||
"marked": "^17.0.1",
|
||||
"mode-watcher": "^1.1.0"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -198,8 +198,10 @@
|
||||
stroke: oklch(0.85 0.18 85 / 0.4);
|
||||
stroke-width: 1.5px;
|
||||
stroke-dasharray: 8, 8;
|
||||
animation: flowAnimation 1s linear infinite;
|
||||
animation: flowAnimation 1.5s linear infinite;
|
||||
filter: drop-shadow(0 0 3px oklch(0.85 0.18 85 / 0.5));
|
||||
/* GPU optimization - hint to browser this element will animate */
|
||||
will-change: stroke-dashoffset;
|
||||
}
|
||||
|
||||
.graph-link-active {
|
||||
@@ -208,6 +210,24 @@
|
||||
filter: drop-shadow(0 0 6px oklch(0.85 0.18 85 / 0.8));
|
||||
}
|
||||
|
||||
/* Reduce motion for users who prefer it - also saves GPU */
|
||||
@media (prefers-reduced-motion: reduce) {
|
||||
.graph-link {
|
||||
animation: none;
|
||||
}
|
||||
|
||||
.shooting-star {
|
||||
animation: none;
|
||||
display: none;
|
||||
}
|
||||
|
||||
.status-pulse,
|
||||
.cursor-blink,
|
||||
.animate-pulse {
|
||||
animation: none;
|
||||
}
|
||||
}
|
||||
|
||||
/* CRT Screen effect for topology */
|
||||
.crt-screen {
|
||||
position: relative;
|
||||
@@ -266,13 +286,15 @@ input:focus, textarea:focus {
|
||||
box-shadow: none;
|
||||
}
|
||||
|
||||
/* Shooting Stars Animation */
|
||||
/* Shooting Stars Animation - GPU optimized */
|
||||
.shooting-stars {
|
||||
position: fixed;
|
||||
inset: 0;
|
||||
overflow: hidden;
|
||||
pointer-events: none;
|
||||
z-index: 0;
|
||||
/* Only render when visible */
|
||||
content-visibility: auto;
|
||||
}
|
||||
|
||||
.shooting-star {
|
||||
@@ -285,6 +307,9 @@ input:focus, textarea:focus {
|
||||
animation: shootingStar var(--duration, 3s) linear infinite;
|
||||
animation-delay: var(--delay, 0s);
|
||||
opacity: 0;
|
||||
/* GPU optimization */
|
||||
will-change: transform, opacity;
|
||||
transform: translateZ(0);
|
||||
}
|
||||
|
||||
.shooting-star::before {
|
||||
@@ -320,3 +345,13 @@ input:focus, textarea:focus {
|
||||
transform: translate(400px, 400px);
|
||||
}
|
||||
}
|
||||
|
||||
/* Pause animations when page is hidden to save resources */
|
||||
:root:has(body[data-page-hidden="true"]) {
|
||||
.shooting-star,
|
||||
.graph-link,
|
||||
.status-pulse,
|
||||
.cursor-blink {
|
||||
animation-play-state: paused;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,89 +8,80 @@
|
||||
regenerateLastResponse
|
||||
} from '$lib/stores/app.svelte';
|
||||
import type { MessageAttachment } from '$lib/stores/app.svelte';
|
||||
import { tick, onDestroy } from 'svelte';
|
||||
import MarkdownContent from './MarkdownContent.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
scrollParent?: HTMLElement | null;
|
||||
}
|
||||
interface Props {
|
||||
class?: string;
|
||||
scrollParent?: HTMLElement | null;
|
||||
}
|
||||
|
||||
let { class: className = '', scrollParent = null }: Props = $props();
|
||||
let { class: className = '', scrollParent = null }: Props = $props();
|
||||
|
||||
const messageList = $derived(messages());
|
||||
const response = $derived(currentResponse());
|
||||
const loading = $derived(isLoading());
|
||||
|
||||
// Ref for scroll anchor at bottom
|
||||
let scrollAnchorRef: HTMLDivElement | undefined = $state();
|
||||
// Scroll management - user controls scroll, show button when not at bottom
|
||||
const SCROLL_THRESHOLD = 100;
|
||||
let showScrollButton = $state(false);
|
||||
let lastMessageCount = 0;
|
||||
let containerRef: HTMLDivElement | undefined = $state();
|
||||
|
||||
// Scroll management
|
||||
const SCROLL_BOTTOM_THRESHOLD = 120;
|
||||
let autoScrollEnabled = true;
|
||||
let currentScrollEl: HTMLElement | null = null;
|
||||
|
||||
function resolveScrollElement(): HTMLElement | null {
|
||||
if (scrollParent) return scrollParent;
|
||||
let node: HTMLElement | null = scrollAnchorRef?.parentElement as HTMLElement | null;
|
||||
while (node) {
|
||||
const isScrollable = node.scrollHeight > node.clientHeight + 1;
|
||||
if (isScrollable) return node;
|
||||
node = node.parentElement;
|
||||
function getScrollContainer(): HTMLElement | null {
|
||||
if (scrollParent) return scrollParent;
|
||||
return containerRef?.parentElement ?? null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function handleScroll() {
|
||||
if (!currentScrollEl) return;
|
||||
const distanceFromBottom = currentScrollEl.scrollHeight - currentScrollEl.scrollTop - currentScrollEl.clientHeight;
|
||||
const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
|
||||
autoScrollEnabled = isNearBottom;
|
||||
}
|
||||
|
||||
function attachScrollListener() {
|
||||
const nextEl = resolveScrollElement();
|
||||
if (currentScrollEl === nextEl) return;
|
||||
if (currentScrollEl) {
|
||||
currentScrollEl.removeEventListener('scroll', handleScroll);
|
||||
function isNearBottom(el: HTMLElement): boolean {
|
||||
return el.scrollHeight - el.scrollTop - el.clientHeight < SCROLL_THRESHOLD;
|
||||
}
|
||||
currentScrollEl = nextEl;
|
||||
if (currentScrollEl) {
|
||||
currentScrollEl.addEventListener('scroll', handleScroll);
|
||||
// Initialize state based on current position
|
||||
handleScroll();
|
||||
}
|
||||
}
|
||||
|
||||
onDestroy(() => {
|
||||
if (currentScrollEl) {
|
||||
currentScrollEl.removeEventListener('scroll', handleScroll);
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
// Re-evaluate scroll container if prop changes or after mount
|
||||
scrollParent;
|
||||
attachScrollListener();
|
||||
});
|
||||
|
||||
// Auto-scroll to bottom when messages change or response updates, but only if user is near bottom
|
||||
$effect(() => {
|
||||
// Track these values to trigger effect
|
||||
const _ = messageList.length;
|
||||
const __ = response;
|
||||
const ___ = loading;
|
||||
|
||||
tick().then(() => {
|
||||
const el = currentScrollEl ?? resolveScrollElement();
|
||||
if (!el || !scrollAnchorRef) return;
|
||||
const distanceFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight;
|
||||
const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
|
||||
if (autoScrollEnabled || isNearBottom) {
|
||||
scrollAnchorRef.scrollIntoView({ behavior: 'smooth', block: 'end' });
|
||||
autoScrollEnabled = true;
|
||||
function scrollToBottom() {
|
||||
const el = getScrollContainer();
|
||||
if (el) {
|
||||
el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
|
||||
}
|
||||
}
|
||||
|
||||
function updateScrollButtonVisibility() {
|
||||
const el = getScrollContainer();
|
||||
if (!el) return;
|
||||
showScrollButton = !isNearBottom(el);
|
||||
}
|
||||
|
||||
// Attach scroll listener
|
||||
$effect(() => {
|
||||
const el = scrollParent ?? containerRef?.parentElement;
|
||||
if (!el) return;
|
||||
|
||||
el.addEventListener('scroll', updateScrollButtonVisibility, { passive: true });
|
||||
// Initial check
|
||||
updateScrollButtonVisibility();
|
||||
return () => el.removeEventListener('scroll', updateScrollButtonVisibility);
|
||||
});
|
||||
|
||||
// Auto-scroll when user sends a new message
|
||||
$effect(() => {
|
||||
const count = messageList.length;
|
||||
if (count > lastMessageCount) {
|
||||
const el = getScrollContainer();
|
||||
if (el) {
|
||||
requestAnimationFrame(() => {
|
||||
el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
|
||||
});
|
||||
}
|
||||
}
|
||||
lastMessageCount = count;
|
||||
});
|
||||
|
||||
// Update scroll button visibility when content changes
|
||||
$effect(() => {
|
||||
// Track response to trigger re-check during streaming
|
||||
const _ = response;
|
||||
|
||||
// Small delay to let DOM update
|
||||
requestAnimationFrame(() => updateScrollButtonVisibility());
|
||||
});
|
||||
});
|
||||
|
||||
// Edit state
|
||||
let editingMessageId = $state<string | null>(null);
|
||||
@@ -231,7 +222,7 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
<div class="flex flex-col gap-4 sm:gap-6 {className}">
|
||||
{#each messageList as message (message.id)}
|
||||
<div class="group flex {message.role === 'user' ? 'justify-end' : 'justify-start'}">
|
||||
<div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'max-w-[95%] sm:max-w-[85%]'}">
|
||||
<div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'w-full max-w-[98%] sm:max-w-[95%]'}">
|
||||
{#if message.role === 'assistant'}
|
||||
<!-- Assistant message header -->
|
||||
<div class="flex items-center gap-1.5 sm:gap-2 mb-1.5 sm:mb-2">
|
||||
@@ -305,7 +296,7 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
{:else}
|
||||
<div class="{message.role === 'user'
|
||||
? 'command-panel rounded-lg rounded-tr-sm inline-block'
|
||||
: 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 inline-block'}">
|
||||
: 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 block w-full'}">
|
||||
|
||||
{#if message.role === 'user'}
|
||||
<!-- User message styling -->
|
||||
@@ -331,7 +322,7 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
{/if}
|
||||
|
||||
{#if message.content}
|
||||
<div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
|
||||
<div class="text-xs text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
|
||||
{message.content}
|
||||
</div>
|
||||
{/if}
|
||||
@@ -360,7 +351,7 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</svg>
|
||||
<span>Thinking...</span>
|
||||
</span>
|
||||
<span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60">
|
||||
<span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60 ml-4">
|
||||
{isThinkingExpanded(message.id) ? 'HIDE' : 'SHOW'}
|
||||
</span>
|
||||
</button>
|
||||
@@ -374,8 +365,8 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
<div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
|
||||
{message.content || (loading ? response : '')}
|
||||
<div class="text-xs text-foreground">
|
||||
<MarkdownContent content={message.content || (loading ? response : '')} />
|
||||
{#if loading && !message.content}
|
||||
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
|
||||
{/if}
|
||||
@@ -457,6 +448,20 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Scroll anchor for auto-scroll -->
|
||||
<div bind:this={scrollAnchorRef}></div>
|
||||
<!-- Invisible element for container reference -->
|
||||
<div bind:this={containerRef}></div>
|
||||
|
||||
<!-- Scroll to bottom button -->
|
||||
{#if showScrollButton}
|
||||
<button
|
||||
type="button"
|
||||
onclick={scrollToBottom}
|
||||
class="sticky bottom-4 left-1/2 -translate-x-1/2 w-10 h-10 rounded-full bg-exo-dark-gray/90 border border-exo-medium-gray/50 flex items-center justify-center text-exo-light-gray hover:text-exo-yellow hover:border-exo-yellow/50 transition-all shadow-lg cursor-pointer z-10"
|
||||
title="Scroll to bottom"
|
||||
>
|
||||
<svg class="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 14l-7 7m0 0l-7-7m7 7V3" />
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -10,7 +10,9 @@ import {
|
||||
clearChat,
|
||||
instances,
|
||||
debugMode,
|
||||
toggleDebugMode
|
||||
toggleDebugMode,
|
||||
topologyOnlyMode,
|
||||
toggleTopologyOnlyMode
|
||||
} from '$lib/stores/app.svelte';
|
||||
|
||||
interface Props {
|
||||
@@ -23,6 +25,7 @@ import {
|
||||
const activeId = $derived(activeConversationId());
|
||||
const instanceData = $derived(instances());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const topologyOnlyEnabled = $derived(topologyOnlyMode());
|
||||
|
||||
let searchQuery = $state('');
|
||||
let editingId = $state<string | null>(null);
|
||||
@@ -424,6 +427,19 @@ const debugEnabled = $derived(debugMode());
|
||||
<div class="text-xs text-white/60 font-mono tracking-wider text-center">
|
||||
{conversationList.length} CONVERSATION{conversationList.length !== 1 ? 'S' : ''}
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onclick={toggleTopologyOnlyMode}
|
||||
class="p-1.5 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
|
||||
title="Toggle topology only mode"
|
||||
>
|
||||
<svg class="w-4 h-4 {topologyOnlyEnabled ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<circle cx="12" cy="5" r="2" fill="currentColor" />
|
||||
<circle cx="5" cy="19" r="2" fill="currentColor" />
|
||||
<circle cx="19" cy="19" r="2" fill="currentColor" />
|
||||
<path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</aside>
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
|
||||
export let showHome = true;
|
||||
export let onHome: (() => void) | null = null;
|
||||
export let showSidebarToggle = false;
|
||||
export let sidebarVisible = true;
|
||||
export let onToggleSidebar: (() => void) | null = null;
|
||||
|
||||
function handleHome(): void {
|
||||
if (onHome) {
|
||||
@@ -14,13 +17,38 @@
|
||||
window.location.hash = '/';
|
||||
}
|
||||
}
|
||||
|
||||
function handleToggleSidebar(): void {
|
||||
if (onToggleSidebar) {
|
||||
onToggleSidebar();
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<header class="relative z-20 flex items-center justify-center px-6 pt-8 pb-4 bg-exo-dark-gray">
|
||||
<!-- Left: Sidebar Toggle -->
|
||||
{#if showSidebarToggle}
|
||||
<div class="absolute left-6 top-1/2 -translate-y-1/2">
|
||||
<button
|
||||
onclick={handleToggleSidebar}
|
||||
class="p-2 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
|
||||
title={sidebarVisible ? 'Hide sidebar' : 'Show sidebar'}
|
||||
>
|
||||
<svg class="w-5 h-5 {sidebarVisible ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
{#if sidebarVisible}
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M11 19l-7-7 7-7m8 14l-7-7 7-7" />
|
||||
{:else}
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M13 5l7 7-7 7M5 5l7 7-7 7" />
|
||||
{/if}
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Center: Logo (clickable to go home) -->
|
||||
<button
|
||||
onclick={handleHome}
|
||||
class="hover:opacity-80 transition-opacity {showHome ? 'cursor-pointer' : 'cursor-default'}"
|
||||
class="bg-transparent border-none outline-none focus:outline-none transition-opacity duration-200 hover:opacity-90 {showHome ? 'cursor-pointer' : 'cursor-default'}"
|
||||
title={showHome ? 'Go to home' : ''}
|
||||
disabled={!showHome}
|
||||
>
|
||||
|
||||
451
dashboard/src/lib/components/MarkdownContent.svelte
Normal file
451
dashboard/src/lib/components/MarkdownContent.svelte
Normal file
@@ -0,0 +1,451 @@
|
||||
<script lang="ts">
|
||||
import { marked } from 'marked';
|
||||
import hljs from 'highlight.js';
|
||||
import katex from 'katex';
|
||||
import 'katex/dist/katex.min.css';
|
||||
import { browser } from '$app/environment';
|
||||
|
||||
interface Props {
|
||||
content: string;
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { content, class: className = '' }: Props = $props();
|
||||
|
||||
let containerRef = $state<HTMLDivElement>();
|
||||
let processedHtml = $state('');
|
||||
|
||||
// Configure marked with syntax highlighting
|
||||
marked.setOptions({
|
||||
gfm: true,
|
||||
breaks: true
|
||||
});
|
||||
|
||||
// Custom renderer for code blocks
|
||||
const renderer = new marked.Renderer();
|
||||
|
||||
renderer.code = function ({ text, lang }: { text: string; lang?: string }) {
|
||||
const language = lang && hljs.getLanguage(lang) ? lang : 'plaintext';
|
||||
const highlighted = hljs.highlight(text, { language }).value;
|
||||
const codeId = `code-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`;
|
||||
|
||||
return `
|
||||
<div class="code-block-wrapper">
|
||||
<div class="code-block-header">
|
||||
<span class="code-language">${language}</span>
|
||||
<button type="button" class="copy-code-btn" data-code="${encodeURIComponent(text)}" title="Copy code">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
|
||||
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
<pre><code class="hljs language-${language}" data-code-id="${codeId}">${highlighted}</code></pre>
|
||||
</div>
|
||||
`;
|
||||
};
|
||||
|
||||
// Inline code
|
||||
renderer.codespan = function ({ text }: { text: string }) {
|
||||
return `<code class="inline-code">${text}</code>`;
|
||||
};
|
||||
|
||||
marked.use({ renderer });
|
||||
|
||||
/**
|
||||
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
|
||||
* Also protect code blocks from LaTeX processing
|
||||
*/
|
||||
function preprocessLaTeX(text: string): string {
|
||||
// Protect code blocks
|
||||
const codeBlocks: string[] = [];
|
||||
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
|
||||
codeBlocks.push(match);
|
||||
return `<<CODE_${codeBlocks.length - 1}>>`;
|
||||
});
|
||||
|
||||
// Convert \(...\) to $...$
|
||||
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
|
||||
|
||||
// Convert \[...\] to $$...$$
|
||||
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
|
||||
|
||||
// Restore code blocks
|
||||
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
|
||||
|
||||
return processed;
|
||||
}
|
||||
|
||||
/**
|
||||
* Render math expressions with KaTeX after HTML is generated
|
||||
*/
|
||||
function renderMath(html: string): string {
|
||||
// Render display math ($$...$$)
|
||||
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
|
||||
try {
|
||||
return katex.renderToString(math.trim(), {
|
||||
displayMode: true,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
} catch {
|
||||
return `<span class="math-error">$$${math}$$</span>`;
|
||||
}
|
||||
});
|
||||
|
||||
// Render inline math ($...$) but avoid matching currency like $5
|
||||
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
|
||||
// Skip if it looks like currency ($ followed by number)
|
||||
if (/^\d/.test(math.trim())) {
|
||||
return match;
|
||||
}
|
||||
try {
|
||||
return katex.renderToString(math.trim(), {
|
||||
displayMode: false,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
} catch {
|
||||
return `<span class="math-error">$${math}$</span>`;
|
||||
}
|
||||
});
|
||||
|
||||
return html;
|
||||
}
|
||||
|
||||
function processMarkdown(text: string): string {
|
||||
try {
|
||||
// Preprocess LaTeX notation
|
||||
const preprocessed = preprocessLaTeX(text);
|
||||
// Parse markdown
|
||||
let html = marked.parse(preprocessed) as string;
|
||||
// Render math expressions
|
||||
html = renderMath(html);
|
||||
return html;
|
||||
} catch (error) {
|
||||
console.error('Markdown processing error:', error);
|
||||
return text.replace(/\n/g, '<br>');
|
||||
}
|
||||
}
|
||||
|
||||
async function handleCopyClick(event: Event) {
|
||||
const target = event.currentTarget as HTMLButtonElement;
|
||||
const encodedCode = target.getAttribute('data-code');
|
||||
if (!encodedCode) return;
|
||||
|
||||
const code = decodeURIComponent(encodedCode);
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(code);
|
||||
// Show copied feedback
|
||||
const originalHtml = target.innerHTML;
|
||||
target.innerHTML = `
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M20 6L9 17l-5-5"/>
|
||||
</svg>
|
||||
`;
|
||||
target.classList.add('copied');
|
||||
setTimeout(() => {
|
||||
target.innerHTML = originalHtml;
|
||||
target.classList.remove('copied');
|
||||
}, 2000);
|
||||
} catch (error) {
|
||||
console.error('Failed to copy:', error);
|
||||
}
|
||||
}
|
||||
|
||||
function setupCopyButtons() {
|
||||
if (!containerRef || !browser) return;
|
||||
|
||||
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
|
||||
for (const button of buttons) {
|
||||
if (button.dataset.listenerBound !== 'true') {
|
||||
button.dataset.listenerBound = 'true';
|
||||
button.addEventListener('click', handleCopyClick);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (content) {
|
||||
processedHtml = processMarkdown(content);
|
||||
} else {
|
||||
processedHtml = '';
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (containerRef && processedHtml) {
|
||||
setupCopyButtons();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div bind:this={containerRef} class="markdown-content {className}">
|
||||
{@html processedHtml}
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.markdown-content {
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
/* Paragraphs */
|
||||
.markdown-content :global(p) {
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(p:last-child) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
/* Headers */
|
||||
.markdown-content :global(h1) {
|
||||
font-size: 1.5rem;
|
||||
font-weight: 700;
|
||||
margin: 1.5rem 0 0.75rem 0;
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
}
|
||||
|
||||
.markdown-content :global(h2) {
|
||||
font-size: 1.25rem;
|
||||
font-weight: 600;
|
||||
margin: 1.25rem 0 0.5rem 0;
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
}
|
||||
|
||||
.markdown-content :global(h3) {
|
||||
font-size: 1.125rem;
|
||||
font-weight: 600;
|
||||
margin: 1rem 0 0.5rem 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(h4),
|
||||
.markdown-content :global(h5),
|
||||
.markdown-content :global(h6) {
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
margin: 0.75rem 0 0.25rem 0;
|
||||
}
|
||||
|
||||
/* Bold and italic */
|
||||
.markdown-content :global(strong) {
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.markdown-content :global(em) {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
/* Inline code */
|
||||
.markdown-content :global(.inline-code) {
|
||||
background: rgba(255, 215, 0, 0.1);
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
padding: 0.125rem 0.375rem;
|
||||
border-radius: 0.25rem;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.875em;
|
||||
}
|
||||
|
||||
/* Links */
|
||||
.markdown-content :global(a) {
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
text-decoration: underline;
|
||||
text-underline-offset: 2px;
|
||||
}
|
||||
|
||||
.markdown-content :global(a:hover) {
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
/* Lists */
|
||||
.markdown-content :global(ul) {
|
||||
list-style-type: disc;
|
||||
margin-left: 1.5rem;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(ol) {
|
||||
list-style-type: decimal;
|
||||
margin-left: 1.5rem;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(li) {
|
||||
margin-bottom: 0.25rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(li::marker) {
|
||||
color: var(--exo-light-gray, #9ca3af);
|
||||
}
|
||||
|
||||
/* Blockquotes */
|
||||
.markdown-content :global(blockquote) {
|
||||
border-left: 3px solid var(--exo-yellow, #ffd700);
|
||||
padding: 0.5rem 1rem;
|
||||
margin: 1rem 0;
|
||||
background: rgba(255, 215, 0, 0.05);
|
||||
border-radius: 0 0.25rem 0.25rem 0;
|
||||
}
|
||||
|
||||
/* Tables */
|
||||
.markdown-content :global(table) {
|
||||
width: 100%;
|
||||
margin: 1rem 0;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(th) {
|
||||
background: rgba(255, 215, 0, 0.1);
|
||||
border: 1px solid rgba(255, 215, 0, 0.2);
|
||||
padding: 0.5rem;
|
||||
text-align: left;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.markdown-content :global(td) {
|
||||
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||
padding: 0.5rem;
|
||||
}
|
||||
|
||||
/* Horizontal rule */
|
||||
.markdown-content :global(hr) {
|
||||
border: none;
|
||||
border-top: 1px solid rgba(255, 255, 255, 0.1);
|
||||
margin: 1.5rem 0;
|
||||
}
|
||||
|
||||
/* Code block wrapper */
|
||||
.markdown-content :global(.code-block-wrapper) {
|
||||
margin: 1rem 0;
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
border: 1px solid rgba(255, 215, 0, 0.2);
|
||||
background: rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
.markdown-content :global(.code-block-header) {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0.5rem 0.75rem;
|
||||
background: rgba(255, 215, 0, 0.05);
|
||||
border-bottom: 1px solid rgba(255, 215, 0, 0.1);
|
||||
}
|
||||
|
||||
.markdown-content :global(.code-language) {
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
font-size: 0.7rem;
|
||||
font-weight: 500;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.1em;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-code-btn) {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 0.25rem;
|
||||
background: transparent;
|
||||
border: none;
|
||||
color: var(--exo-light-gray, #9ca3af);
|
||||
cursor: pointer;
|
||||
transition: color 0.2s;
|
||||
border-radius: 0.25rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-code-btn:hover) {
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-code-btn.copied) {
|
||||
color: #22c55e;
|
||||
}
|
||||
|
||||
.markdown-content :global(.code-block-wrapper pre) {
|
||||
margin: 0;
|
||||
padding: 1rem;
|
||||
overflow-x: auto;
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.markdown-content :global(.code-block-wrapper code) {
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.8125rem;
|
||||
line-height: 1.5;
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
/* Syntax highlighting - dark theme matching EXO style */
|
||||
.markdown-content :global(.hljs) {
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-keyword),
|
||||
.markdown-content :global(.hljs-selector-tag),
|
||||
.markdown-content :global(.hljs-literal),
|
||||
.markdown-content :global(.hljs-section),
|
||||
.markdown-content :global(.hljs-link) {
|
||||
color: #c084fc;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-string),
|
||||
.markdown-content :global(.hljs-title),
|
||||
.markdown-content :global(.hljs-name),
|
||||
.markdown-content :global(.hljs-type),
|
||||
.markdown-content :global(.hljs-attribute),
|
||||
.markdown-content :global(.hljs-symbol),
|
||||
.markdown-content :global(.hljs-bullet),
|
||||
.markdown-content :global(.hljs-addition),
|
||||
.markdown-content :global(.hljs-variable),
|
||||
.markdown-content :global(.hljs-template-tag),
|
||||
.markdown-content :global(.hljs-template-variable) {
|
||||
color: #fbbf24;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-comment),
|
||||
.markdown-content :global(.hljs-quote),
|
||||
.markdown-content :global(.hljs-deletion),
|
||||
.markdown-content :global(.hljs-meta) {
|
||||
color: #6b7280;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-number),
|
||||
.markdown-content :global(.hljs-regexp),
|
||||
.markdown-content :global(.hljs-literal),
|
||||
.markdown-content :global(.hljs-built_in) {
|
||||
color: #34d399;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-function),
|
||||
.markdown-content :global(.hljs-class .hljs-title) {
|
||||
color: #60a5fa;
|
||||
}
|
||||
|
||||
/* KaTeX math styling */
|
||||
.markdown-content :global(.katex) {
|
||||
font-size: 1.1em;
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex-display) {
|
||||
margin: 1rem 0;
|
||||
overflow-x: auto;
|
||||
overflow-y: hidden;
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex-display > .katex) {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-error) {
|
||||
color: #f87171;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.875em;
|
||||
background: rgba(248, 113, 113, 0.1);
|
||||
padding: 0.125rem 0.25rem;
|
||||
border-radius: 0.25rem;
|
||||
}
|
||||
</style>
|
||||
@@ -1,5 +1,6 @@
|
||||
<script lang="ts">
|
||||
import type { DownloadProgress, NodeInfo, PlacementPreview } from '$lib/stores/app.svelte';
|
||||
import type { DownloadProgress, NodeInfo, PlacementPreview, TopologyEdge } from '$lib/stores/app.svelte';
|
||||
import { debugMode, topologyData } from '$lib/stores/app.svelte';
|
||||
|
||||
interface Props {
|
||||
model: { id: string; name?: string; storage_size_megabytes?: number };
|
||||
@@ -206,12 +207,8 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
const centerY = topoHeight / 2;
|
||||
const radius = numNodes === 1 ? 0 : numNodes === 2 ? 45 : Math.min(topoWidth, topoHeight) * 0.32;
|
||||
|
||||
// Use API preview data if available
|
||||
// Only use API preview data - no local estimation
|
||||
const hasApiPreview = apiPreview !== null && apiPreview.error === null && apiPreview.memory_delta_by_node !== null;
|
||||
const canFit = hasApiPreview ? true : (() => {
|
||||
const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
|
||||
return totalAvailable >= estimatedMemory;
|
||||
})();
|
||||
const error = apiPreview?.error ?? null;
|
||||
|
||||
let placementNodes: Array<{
|
||||
@@ -232,135 +229,140 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
modelFillHeight: number;
|
||||
}> = [];
|
||||
|
||||
if (hasApiPreview && apiPreview.memory_delta_by_node) {
|
||||
// Use API placement data
|
||||
const memoryDelta = apiPreview.memory_delta_by_node;
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const deltaBytes = memoryDelta[n.id] ?? 0;
|
||||
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
|
||||
const isUsed = deltaBytes > 0;
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
} else if (apiPreview?.error) {
|
||||
// API returned an error - model can't fit, show all nodes as unused
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB: 0,
|
||||
currentPercent,
|
||||
newPercent: currentPercent,
|
||||
isUsed: false,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: 0
|
||||
};
|
||||
});
|
||||
} else {
|
||||
// Fallback: local estimation based on sharding strategy
|
||||
const memoryNeeded = estimatedMemory;
|
||||
// Use API placement data directly
|
||||
const memoryDelta = apiPreview?.memory_delta_by_node ?? {};
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const deltaBytes = memoryDelta[n.id] ?? 0;
|
||||
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
|
||||
const isUsed = deltaBytes > 0;
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
if (sharding === 'Pipeline') {
|
||||
const memoryPerNode = memoryNeeded / numNodes;
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + memoryPerNode) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB: memoryPerNode,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed: true,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
} else {
|
||||
let remaining = memoryNeeded;
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const allocated = Math.min(remaining, n.availableGB);
|
||||
remaining -= allocated;
|
||||
const isUsed = allocated > 0;
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + allocated) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB: allocated,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
|
||||
const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
|
||||
return { nodes: placementNodes, canFit: hasApiPreview || canFit, totalAvailable, topoWidth, topoHeight, error };
|
||||
return { nodes: placementNodes, canFit: hasApiPreview, totalAvailable, topoWidth, topoHeight, error };
|
||||
});
|
||||
|
||||
const canFit = $derived(apiPreview ? apiPreview.error === null : placementPreview().canFit);
|
||||
const placementError = $derived(apiPreview?.error ?? null);
|
||||
const nodeCount = $derived(nodeList().length);
|
||||
const filterId = $derived(model.id.replace(/[^a-zA-Z0-9]/g, ''));
|
||||
|
||||
// Debug mode state
|
||||
const isDebugMode = $derived(debugMode());
|
||||
const topology = $derived(topologyData());
|
||||
const isRdma = $derived(runtime === 'MlxIbv' || runtime === 'MlxJaccl');
|
||||
|
||||
// Get interface name for an IP from node data
|
||||
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
|
||||
if (!ip || !topology?.nodes) return null;
|
||||
|
||||
// Strip port if present
|
||||
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
|
||||
|
||||
// Check specified node first
|
||||
const node = topology.nodes[nodeId];
|
||||
if (node) {
|
||||
const match = node.network_interfaces?.find((iface) =>
|
||||
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
|
||||
);
|
||||
if (match?.name) return match.name;
|
||||
|
||||
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
|
||||
if (mapped) return mapped;
|
||||
}
|
||||
|
||||
// Fallback: check all nodes
|
||||
for (const [, otherNode] of Object.entries(topology.nodes)) {
|
||||
if (!otherNode) continue;
|
||||
const match = otherNode.network_interfaces?.find((iface) =>
|
||||
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
|
||||
);
|
||||
if (match?.name) return match.name;
|
||||
|
||||
const mapped = otherNode.ip_to_interface?.[cleanIp] || otherNode.ip_to_interface?.[ip];
|
||||
if (mapped) return mapped;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
// Get directional arrow based on node positions
|
||||
function getArrow(fromNode: { x: number; y: number }, toNode: { x: number; y: number }): string {
|
||||
const dx = toNode.x - fromNode.x;
|
||||
const dy = toNode.y - fromNode.y;
|
||||
const absX = Math.abs(dx);
|
||||
const absY = Math.abs(dy);
|
||||
|
||||
if (absX > absY * 2) {
|
||||
return dx > 0 ? '→' : '←';
|
||||
} else if (absY > absX * 2) {
|
||||
return dy > 0 ? '↓' : '↑';
|
||||
} else {
|
||||
if (dx > 0 && dy > 0) return '↘';
|
||||
if (dx > 0 && dy < 0) return '↗';
|
||||
if (dx < 0 && dy > 0) return '↙';
|
||||
return '↖';
|
||||
}
|
||||
}
|
||||
|
||||
// Get connection info for edges between two nodes
|
||||
// Returns exactly one connection per direction (A→B and B→A), preferring non-loopback
|
||||
function getConnectionInfo(nodeId1: string, nodeId2: string): Array<{ ip: string; iface: string | null; from: string; to: string }> {
|
||||
if (!topology?.edges) return [];
|
||||
|
||||
// Collect candidates for each direction
|
||||
const aToBCandidates: Array<{ ip: string; iface: string | null }> = [];
|
||||
const bToACandidates: Array<{ ip: string; iface: string | null }> = [];
|
||||
|
||||
for (const edge of topology.edges) {
|
||||
const ip = edge.sendBackIp || '?';
|
||||
const iface = edge.sendBackInterface || getInterfaceForIp(edge.source, ip);
|
||||
|
||||
if (edge.source === nodeId1 && edge.target === nodeId2) {
|
||||
aToBCandidates.push({ ip, iface });
|
||||
} else if (edge.source === nodeId2 && edge.target === nodeId1) {
|
||||
bToACandidates.push({ ip, iface });
|
||||
}
|
||||
}
|
||||
|
||||
// Pick best (prefer non-loopback)
|
||||
const pickBest = (candidates: Array<{ ip: string; iface: string | null }>) => {
|
||||
if (candidates.length === 0) return null;
|
||||
return candidates.find(c => !c.ip.startsWith('127.')) || candidates[0];
|
||||
};
|
||||
|
||||
const result: Array<{ ip: string; iface: string | null; from: string; to: string }> = [];
|
||||
|
||||
const bestAtoB = pickBest(aToBCandidates);
|
||||
if (bestAtoB) result.push({ ...bestAtoB, from: nodeId1, to: nodeId2 });
|
||||
|
||||
const bestBtoA = pickBest(bToACandidates);
|
||||
if (bestBtoA) result.push({ ...bestBtoA, from: nodeId2, to: nodeId1 });
|
||||
|
||||
return result;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="relative group">
|
||||
@@ -453,6 +455,26 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
|
||||
<!-- Connection lines between nodes (if multiple) -->
|
||||
{#if preview.nodes.length > 1}
|
||||
{@const usedNodes = preview.nodes.filter(n => n.isUsed)}
|
||||
{@const nodePositions = Object.fromEntries(preview.nodes.map(n => [n.id, { x: n.x, y: n.y }]))}
|
||||
{@const allConnections = isDebugMode && usedNodes.length > 1 ? (() => {
|
||||
const conns: Array<{ ip: string; iface: string | null; from: string; to: string; midX: number; midY: number; arrow: string }> = [];
|
||||
for (let i = 0; i < usedNodes.length; i++) {
|
||||
for (let j = i + 1; j < usedNodes.length; j++) {
|
||||
const n1 = usedNodes[i];
|
||||
const n2 = usedNodes[j];
|
||||
const midX = (n1.x + n2.x) / 2;
|
||||
const midY = (n1.y + n2.y) / 2;
|
||||
for (const c of getConnectionInfo(n1.id, n2.id)) {
|
||||
const fromPos = nodePositions[c.from];
|
||||
const toPos = nodePositions[c.to];
|
||||
const arrow = fromPos && toPos ? getArrow(fromPos, toPos) : '→';
|
||||
conns.push({ ...c, midX, midY, arrow });
|
||||
}
|
||||
}
|
||||
}
|
||||
return conns;
|
||||
})() : []}
|
||||
{#each preview.nodes as node, i}
|
||||
{#each preview.nodes.slice(i + 1) as node2}
|
||||
<line
|
||||
@@ -464,6 +486,43 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
/>
|
||||
{/each}
|
||||
{/each}
|
||||
<!-- Debug: Show connection IPs/interfaces in corners -->
|
||||
{#if isDebugMode && allConnections.length > 0}
|
||||
{@const centerX = preview.topoWidth / 2}
|
||||
{@const centerY = preview.topoHeight / 2}
|
||||
{@const quadrants = {
|
||||
topLeft: allConnections.filter(c => c.midX < centerX && c.midY < centerY),
|
||||
topRight: allConnections.filter(c => c.midX >= centerX && c.midY < centerY),
|
||||
bottomLeft: allConnections.filter(c => c.midX < centerX && c.midY >= centerY),
|
||||
bottomRight: allConnections.filter(c => c.midX >= centerX && c.midY >= centerY)
|
||||
}}
|
||||
{@const padding = 4}
|
||||
{@const lineHeight = 8}
|
||||
<!-- Top Left -->
|
||||
{#each quadrants.topLeft as conn, idx}
|
||||
<text x={padding} y={padding + idx * lineHeight} text-anchor="start" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
|
||||
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
<!-- Top Right -->
|
||||
{#each quadrants.topRight as conn, idx}
|
||||
<text x={preview.topoWidth - padding} y={padding + idx * lineHeight} text-anchor="end" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
|
||||
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
<!-- Bottom Left -->
|
||||
{#each quadrants.bottomLeft as conn, idx}
|
||||
<text x={padding} y={preview.topoHeight - padding - (quadrants.bottomLeft.length - 1 - idx) * lineHeight} text-anchor="start" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
|
||||
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
<!-- Bottom Right -->
|
||||
{#each quadrants.bottomRight as conn, idx}
|
||||
<text x={preview.topoWidth - padding} y={preview.topoHeight - padding - (quadrants.bottomRight.length - 1 - idx) * lineHeight} text-anchor="end" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
|
||||
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
{/if}
|
||||
{/if}
|
||||
|
||||
{#each preview.nodes as node}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { onMount, onDestroy } from 'svelte';
|
||||
import { onMount, onDestroy, tick } from 'svelte';
|
||||
import * as d3 from 'd3';
|
||||
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
|
||||
|
||||
@@ -12,11 +12,35 @@ import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.sv
|
||||
|
||||
let svgContainer: SVGSVGElement | undefined = $state();
|
||||
let resizeObserver: ResizeObserver | undefined;
|
||||
|
||||
// Optimization: Track last render state to avoid unnecessary re-renders
|
||||
let lastRenderHash = '';
|
||||
let lastHighlightedNodesHash = '';
|
||||
let lastDimensions = { width: 0, height: 0 };
|
||||
let isRendering = false;
|
||||
let pendingRender = false;
|
||||
|
||||
const isMinimized = $derived(isTopologyMinimized());
|
||||
const data = $derived(topologyData());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
|
||||
// Generate a hash of relevant data to detect actual changes
|
||||
function generateDataHash(topologyData: typeof data, minimized: boolean, debug: boolean): string {
|
||||
if (!topologyData) return 'null';
|
||||
const nodes = topologyData.nodes || {};
|
||||
const edges = topologyData.edges || [];
|
||||
|
||||
// Create a lightweight hash from key properties only
|
||||
const nodeHashes = Object.entries(nodes).map(([id, n]) => {
|
||||
const macmon = n.macmon_info;
|
||||
return `${id}:${n.friendly_name || ''}:${macmon?.memory?.ram_usage || 0}:${macmon?.memory?.ram_total || 0}:${macmon?.temp?.gpu_temp_avg || 0}:${macmon?.gpu_usage?.[1] || 0}:${macmon?.sys_power || 0}`;
|
||||
}).sort().join('|');
|
||||
|
||||
const edgeHash = edges.map(e => `${e.source}-${e.target}`).sort().join(',');
|
||||
|
||||
return `${nodeHashes}::${edgeHash}::${minimized}::${debug}`;
|
||||
}
|
||||
|
||||
function getNodeLabel(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
return node?.friendly_name || nodeId.slice(0, 8);
|
||||
@@ -24,19 +48,36 @@ function getNodeLabel(nodeId: string): string {
|
||||
|
||||
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
|
||||
if (!ip) return { label: '?', missing: true };
|
||||
const node = data?.nodes?.[nodeId];
|
||||
if (!node) return { label: '?', missing: true };
|
||||
|
||||
// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
|
||||
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
|
||||
|
||||
// Helper to check a node's interfaces
|
||||
function checkNode(node: typeof data.nodes[string]): string | null {
|
||||
if (!node) return null;
|
||||
|
||||
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
|
||||
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
|
||||
);
|
||||
if (matchFromInterfaces?.name) {
|
||||
return matchFromInterfaces.name;
|
||||
}
|
||||
|
||||
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
|
||||
(iface.addresses || []).some((addr) => addr === ip)
|
||||
);
|
||||
if (matchFromInterfaces?.name) {
|
||||
return { label: matchFromInterfaces.name, missing: false };
|
||||
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
|
||||
if (mapped && mapped.trim().length > 0) {
|
||||
return mapped;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
const mapped = node.ip_to_interface?.[ip];
|
||||
if (mapped && mapped.trim().length > 0) {
|
||||
return { label: mapped, missing: false };
|
||||
|
||||
// Try specified node first
|
||||
const result = checkNode(data?.nodes?.[nodeId]);
|
||||
if (result) return { label: result, missing: false };
|
||||
|
||||
// Fallback: search all nodes for this IP
|
||||
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
|
||||
const otherResult = checkNode(otherNode);
|
||||
if (otherResult) return { label: otherResult, missing: false };
|
||||
}
|
||||
|
||||
return { label: '?', missing: true };
|
||||
@@ -67,6 +108,7 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
return lines;
|
||||
}
|
||||
|
||||
|
||||
// Apple logo path for MacBook Pro screen
|
||||
const APPLE_LOGO_PATH = "M788.1 340.9c-5.8 4.5-108.2 62.2-108.2 190.5 0 148.4 130.3 200.9 134.2 202.2-.6 3.2-20.7 71.9-68.7 141.9-42.8 61.6-87.5 123.1-155.5 123.1s-85.5-39.5-164-39.5c-76.5 0-103.7 40.8-165.9 40.8s-105.6-57-155.5-127C46.7 790.7 0 663 0 541.8c0-194.4 126.4-297.5 250.8-297.5 66.1 0 121.2 43.4 162.7 43.4 39.5 0 101.1-46 176.3-46 28.5 0 130.9 2.6 198.3 99.2zm-234-181.5c31.1-36.9 53.1-88.1 53.1-139.3 0-7.1-.6-14.3-1.9-20.1-50.6 1.9-110.8 33.7-147.1 75.8-28.5 32.4-55.1 83.6-55.1 135.5 0 7.8 1.3 15.6 1.9 18.1 3.2.6 8.4 1.3 13.6 1.3 45.4 0 102.5-30.4 135.5-71.3z";
|
||||
const LOGO_NATIVE_WIDTH = 814;
|
||||
@@ -238,6 +280,7 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');
|
||||
|
||||
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
|
||||
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
|
||||
edges.forEach(edge => {
|
||||
if (!edge.source || !edge.target || edge.source === edge.target) return;
|
||||
if (!positionById[edge.source] || !positionById[edge.target]) return;
|
||||
@@ -314,110 +357,98 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.attr('marker-end', 'url(#arrowhead)');
|
||||
}
|
||||
|
||||
// Collect debug labels for later positioning at edges
|
||||
if (debugEnabled && entry.connections.length > 0) {
|
||||
const maxBoxes = 6;
|
||||
const fontSize = isMinimized ? 8 : 9;
|
||||
const lineGap = 2;
|
||||
const labelOffsetOut = Math.max(140, minDimension * 0.38);
|
||||
const labelOffsetSide = isMinimized ? 16 : 20;
|
||||
const boxWidth = 170;
|
||||
const maxLineLen = 26;
|
||||
|
||||
const connections = entry.connections.slice(0, maxBoxes);
|
||||
if (entry.connections.length > maxBoxes) {
|
||||
const remaining = entry.connections.length - maxBoxes;
|
||||
connections.push({
|
||||
from: '',
|
||||
to: '',
|
||||
ip: `(+${remaining} more)`,
|
||||
ifaceLabel: '',
|
||||
missingIface: false
|
||||
});
|
||||
}
|
||||
|
||||
let dirX = mx - centerX;
|
||||
let dirY = my - centerY;
|
||||
const dirLen = Math.hypot(dirX, dirY);
|
||||
if (dirLen < 1) {
|
||||
dirX = -uy;
|
||||
dirY = ux;
|
||||
} else {
|
||||
dirX /= dirLen;
|
||||
dirY /= dirLen;
|
||||
}
|
||||
|
||||
const nx = -dirY;
|
||||
const ny = dirX;
|
||||
|
||||
const labelXRaw = mx + dirX * labelOffsetOut + nx * labelOffsetSide;
|
||||
const labelYRaw = my + dirY * labelOffsetOut + ny * labelOffsetSide;
|
||||
const clampPad = Math.min(120, minDimension * 0.12);
|
||||
const labelX = Math.max(clampPad, Math.min(width - clampPad, labelXRaw));
|
||||
const labelY = Math.max(clampPad, Math.min(height - clampPad, labelYRaw));
|
||||
|
||||
const labelGroup = debugLabelsGroup.append('g')
|
||||
.attr('transform', `translate(${labelX}, ${labelY})`);
|
||||
|
||||
const textGroup = labelGroup.append('g');
|
||||
|
||||
connections.forEach((conn, idx) => {
|
||||
const rawLines = conn.from && conn.to
|
||||
? [
|
||||
`${getNodeLabel(conn.from)}→${getNodeLabel(conn.to)}`,
|
||||
`${conn.ip}`,
|
||||
`${conn.ifaceLabel}`
|
||||
]
|
||||
: [conn.ip];
|
||||
|
||||
const wrapped = rawLines.flatMap(line => wrapLine(line, maxLineLen));
|
||||
|
||||
wrapped.forEach((line, lineIdx) => {
|
||||
textGroup.append('text')
|
||||
.attr('x', 0)
|
||||
.attr('y', (idx * (wrapped.length * (fontSize + lineGap))) + lineIdx * (fontSize + lineGap))
|
||||
.attr('text-anchor', 'middle')
|
||||
.attr('dominant-baseline', 'hanging')
|
||||
.attr('font-size', fontSize)
|
||||
.attr('font-family', 'SF Mono, monospace')
|
||||
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.9)')
|
||||
.text(line);
|
||||
});
|
||||
// Determine which side of viewport based on edge midpoint
|
||||
const isLeft = mx < centerX;
|
||||
const isTop = my < safeCenterY;
|
||||
|
||||
// Store for batch rendering after all edges processed
|
||||
if (!debugEdgeLabels) debugEdgeLabels = [];
|
||||
debugEdgeLabels.push({
|
||||
connections: entry.connections,
|
||||
isLeft,
|
||||
isTop,
|
||||
mx,
|
||||
my
|
||||
});
|
||||
|
||||
const bbox = textGroup.node()?.getBBox();
|
||||
if (bbox) {
|
||||
const paddedWidth = Math.max(boxWidth, bbox.width + 14);
|
||||
const boxHeight = bbox.height + 8;
|
||||
const boxMinX = labelX - paddedWidth / 2;
|
||||
const boxMaxX = labelX + paddedWidth / 2;
|
||||
const boxMinY = labelY + bbox.y - 4;
|
||||
const boxMaxY = boxMinY + boxHeight;
|
||||
|
||||
const clampPadDynamic = Math.min(140, minDimension * 0.18);
|
||||
let shiftX = 0;
|
||||
let shiftY = 0;
|
||||
if (boxMinX < clampPadDynamic) shiftX = clampPadDynamic - boxMinX;
|
||||
if (boxMaxX > width - clampPadDynamic) shiftX = (width - clampPadDynamic) - boxMaxX;
|
||||
if (boxMinY < clampPadDynamic) shiftY = clampPadDynamic - boxMinY;
|
||||
if (boxMaxY > height - clampPadDynamic) shiftY = (height - clampPadDynamic) - boxMaxY;
|
||||
|
||||
const finalX = labelX + shiftX;
|
||||
const finalY = labelY + shiftY;
|
||||
labelGroup.attr('transform', `translate(${finalX}, ${finalY})`);
|
||||
|
||||
labelGroup.insert('rect', 'g')
|
||||
.attr('x', -paddedWidth / 2)
|
||||
.attr('y', bbox.y - 4)
|
||||
.attr('width', paddedWidth)
|
||||
.attr('height', boxHeight)
|
||||
.attr('rx', 4)
|
||||
.attr('fill', 'rgba(0,0,0,0.75)')
|
||||
.attr('stroke', 'rgba(255,255,255,0.12)')
|
||||
.attr('stroke-width', 0.6);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Render debug labels at viewport edges/corners
|
||||
if (debugEdgeLabels && debugEdgeLabels.length > 0) {
|
||||
const fontSize = isMinimized ? 10 : 12;
|
||||
const lineHeight = fontSize + 4;
|
||||
const padding = 10;
|
||||
|
||||
// Helper to get arrow based on direction vector
|
||||
function getArrow(fromId: string, toId: string): string {
|
||||
const fromPos = positionById[fromId];
|
||||
const toPos = positionById[toId];
|
||||
if (!fromPos || !toPos) return '→';
|
||||
|
||||
const dirX = toPos.x - fromPos.x;
|
||||
const dirY = toPos.y - fromPos.y;
|
||||
const absX = Math.abs(dirX);
|
||||
const absY = Math.abs(dirY);
|
||||
|
||||
if (absX > absY * 2) {
|
||||
return dirX > 0 ? '→' : '←';
|
||||
} else if (absY > absX * 2) {
|
||||
return dirY > 0 ? '↓' : '↑';
|
||||
} else {
|
||||
if (dirX > 0 && dirY > 0) return '↘';
|
||||
if (dirX > 0 && dirY < 0) return '↗';
|
||||
if (dirX < 0 && dirY > 0) return '↙';
|
||||
return '↖';
|
||||
}
|
||||
}
|
||||
|
||||
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
|
||||
const quadrants: Record<string, typeof debugEdgeLabels> = {
|
||||
topLeft: [],
|
||||
topRight: [],
|
||||
bottomLeft: [],
|
||||
bottomRight: []
|
||||
};
|
||||
|
||||
debugEdgeLabels.forEach(edge => {
|
||||
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
|
||||
quadrants[key].push(edge);
|
||||
});
|
||||
|
||||
// Render each quadrant
|
||||
Object.entries(quadrants).forEach(([quadrant, edges]) => {
|
||||
if (edges.length === 0) return;
|
||||
|
||||
const isLeft = quadrant.includes('Left');
|
||||
const isTop = quadrant.includes('top');
|
||||
|
||||
let baseX = isLeft ? padding : width - padding;
|
||||
let baseY = isTop ? padding : height - padding;
|
||||
const textAnchor = isLeft ? 'start' : 'end';
|
||||
|
||||
let currentY = baseY;
|
||||
|
||||
edges.forEach(edge => {
|
||||
edge.connections.forEach(conn => {
|
||||
const arrow = getArrow(conn.from, conn.to);
|
||||
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;
|
||||
debugLabelsGroup.append('text')
|
||||
.attr('x', baseX)
|
||||
.attr('y', currentY)
|
||||
.attr('text-anchor', textAnchor)
|
||||
.attr('dominant-baseline', isTop ? 'hanging' : 'auto')
|
||||
.attr('font-size', fontSize)
|
||||
.attr('font-family', 'SF Mono, monospace')
|
||||
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.85)')
|
||||
.text(label);
|
||||
currentY += isTop ? lineHeight : -lineHeight;
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Draw nodes
|
||||
const nodesGroup = svg.append('g').attr('class', 'nodes-group');
|
||||
|
||||
@@ -925,16 +956,59 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (data) {
|
||||
// Throttled render function to prevent too-frequent updates
|
||||
function scheduleRender() {
|
||||
if (isRendering) {
|
||||
pendingRender = true;
|
||||
return;
|
||||
}
|
||||
|
||||
isRendering = true;
|
||||
requestAnimationFrame(() => {
|
||||
renderGraph();
|
||||
isRendering = false;
|
||||
|
||||
if (pendingRender) {
|
||||
pendingRender = false;
|
||||
scheduleRender();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (!data || !svgContainer) return;
|
||||
|
||||
// Generate hash of current state
|
||||
const currentHash = generateDataHash(data, isMinimized, debugEnabled);
|
||||
const highlightHash = Array.from(highlightedNodes).sort().join(',');
|
||||
|
||||
// Get current dimensions
|
||||
const rect = svgContainer.getBoundingClientRect();
|
||||
const dimensionsChanged = rect.width !== lastDimensions.width || rect.height !== lastDimensions.height;
|
||||
|
||||
// Only re-render if something actually changed
|
||||
if (currentHash !== lastRenderHash || highlightHash !== lastHighlightedNodesHash || dimensionsChanged) {
|
||||
lastRenderHash = currentHash;
|
||||
lastHighlightedNodesHash = highlightHash;
|
||||
lastDimensions = { width: rect.width, height: rect.height };
|
||||
scheduleRender();
|
||||
}
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
if (svgContainer) {
|
||||
// Use a debounced resize observer to prevent rapid re-renders
|
||||
let resizeTimeout: ReturnType<typeof setTimeout> | null = null;
|
||||
|
||||
resizeObserver = new ResizeObserver(() => {
|
||||
renderGraph();
|
||||
if (resizeTimeout) clearTimeout(resizeTimeout);
|
||||
resizeTimeout = setTimeout(() => {
|
||||
const rect = svgContainer!.getBoundingClientRect();
|
||||
if (rect.width !== lastDimensions.width || rect.height !== lastDimensions.height) {
|
||||
lastDimensions = { width: rect.width, height: rect.height };
|
||||
scheduleRender();
|
||||
}
|
||||
}, 100);
|
||||
});
|
||||
resizeObserver.observe(svgContainer);
|
||||
}
|
||||
@@ -962,10 +1036,20 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
stroke-width: 1px;
|
||||
stroke-dasharray: 4, 4;
|
||||
opacity: 0.8;
|
||||
animation: flowAnimation 0.75s linear infinite;
|
||||
/* Slower animation = less GPU usage */
|
||||
animation: flowAnimation 2s linear infinite;
|
||||
/* GPU optimization */
|
||||
will-change: stroke-dashoffset;
|
||||
}
|
||||
@keyframes flowAnimation {
|
||||
from { stroke-dashoffset: 0; }
|
||||
to { stroke-dashoffset: -10; }
|
||||
}
|
||||
|
||||
/* Respect reduced motion preference */
|
||||
@media (prefers-reduced-motion: reduce) {
|
||||
:global(.graph-link) {
|
||||
animation: none;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -4,4 +4,5 @@ export { default as ChatMessages } from './ChatMessages.svelte';
|
||||
export { default as ChatAttachments } from './ChatAttachments.svelte';
|
||||
export { default as ChatSidebar } from './ChatSidebar.svelte';
|
||||
export { default as ModelCard } from './ModelCard.svelte';
|
||||
export { default as MarkdownContent } from './MarkdownContent.svelte';
|
||||
|
||||
|
||||
@@ -297,6 +297,35 @@ function extractIpFromMultiaddr(ma?: string): string | undefined {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Deep comparison utility for preventing unnecessary state updates
|
||||
function shallowEqual(a: unknown, b: unknown): boolean {
|
||||
if (a === b) return true;
|
||||
if (a === null || b === null) return false;
|
||||
if (typeof a !== 'object' || typeof b !== 'object') return false;
|
||||
|
||||
const aObj = a as Record<string, unknown>;
|
||||
const bObj = b as Record<string, unknown>;
|
||||
const aKeys = Object.keys(aObj);
|
||||
const bKeys = Object.keys(bObj);
|
||||
|
||||
if (aKeys.length !== bKeys.length) return false;
|
||||
|
||||
for (const key of aKeys) {
|
||||
if (aObj[key] !== bObj[key]) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Faster JSON comparison for complex nested objects
|
||||
function jsonEqual(a: unknown, b: unknown): boolean {
|
||||
if (a === b) return true;
|
||||
try {
|
||||
return JSON.stringify(a) === JSON.stringify(b);
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
class AppStore {
|
||||
// Conversation state
|
||||
conversations = $state<Conversation[]>([]);
|
||||
@@ -327,19 +356,49 @@ class AppStore {
|
||||
isTopologyMinimized = $state(false);
|
||||
isSidebarOpen = $state(false); // Hidden by default, shown when in chat mode
|
||||
debugMode = $state(false);
|
||||
topologyOnlyMode = $state(false);
|
||||
chatSidebarVisible = $state(true); // Shown by default
|
||||
|
||||
// Visibility state - used to pause polling when tab is hidden
|
||||
private isPageVisible = true;
|
||||
|
||||
private fetchInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private previewsInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private lastConversationPersistTs = 0;
|
||||
|
||||
// Cache for comparison - prevents unnecessary reactivity
|
||||
private lastTopologyJson = '';
|
||||
private lastInstancesJson = '';
|
||||
private lastRunnersJson = '';
|
||||
private lastDownloadsJson = '';
|
||||
|
||||
constructor() {
|
||||
if (browser) {
|
||||
this.startPolling();
|
||||
this.loadConversationsFromStorage();
|
||||
this.loadDebugModeFromStorage();
|
||||
this.loadTopologyOnlyModeFromStorage();
|
||||
this.loadChatSidebarVisibleFromStorage();
|
||||
this.setupVisibilityListener();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Listen for page visibility changes to pause polling when hidden
|
||||
*/
|
||||
private setupVisibilityListener() {
|
||||
if (typeof document === 'undefined') return;
|
||||
|
||||
document.addEventListener('visibilitychange', () => {
|
||||
this.isPageVisible = document.visibilityState === 'visible';
|
||||
|
||||
if (this.isPageVisible) {
|
||||
// Resume polling when page becomes visible
|
||||
this.fetchState();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Load conversations from localStorage
|
||||
*/
|
||||
@@ -394,6 +453,44 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
private loadTopologyOnlyModeFromStorage() {
|
||||
try {
|
||||
const stored = localStorage.getItem('exo-topology-only-mode');
|
||||
if (stored !== null) {
|
||||
this.topologyOnlyMode = stored === 'true';
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load topology only mode:', error);
|
||||
}
|
||||
}
|
||||
|
||||
private saveTopologyOnlyModeToStorage() {
|
||||
try {
|
||||
localStorage.setItem('exo-topology-only-mode', this.topologyOnlyMode ? 'true' : 'false');
|
||||
} catch (error) {
|
||||
console.error('Failed to save topology only mode:', error);
|
||||
}
|
||||
}
|
||||
|
||||
private loadChatSidebarVisibleFromStorage() {
|
||||
try {
|
||||
const stored = localStorage.getItem('exo-chat-sidebar-visible');
|
||||
if (stored !== null) {
|
||||
this.chatSidebarVisible = stored === 'true';
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load chat sidebar visibility:', error);
|
||||
}
|
||||
}
|
||||
|
||||
private saveChatSidebarVisibleToStorage() {
|
||||
try {
|
||||
localStorage.setItem('exo-chat-sidebar-visible', this.chatSidebarVisible ? 'true' : 'false');
|
||||
} catch (error) {
|
||||
console.error('Failed to save chat sidebar visibility:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new conversation
|
||||
*/
|
||||
@@ -698,9 +795,39 @@ class AppStore {
|
||||
this.saveDebugModeToStorage();
|
||||
}
|
||||
|
||||
getTopologyOnlyMode(): boolean {
|
||||
return this.topologyOnlyMode;
|
||||
}
|
||||
|
||||
setTopologyOnlyMode(enabled: boolean) {
|
||||
this.topologyOnlyMode = enabled;
|
||||
this.saveTopologyOnlyModeToStorage();
|
||||
}
|
||||
|
||||
toggleTopologyOnlyMode() {
|
||||
this.topologyOnlyMode = !this.topologyOnlyMode;
|
||||
this.saveTopologyOnlyModeToStorage();
|
||||
}
|
||||
|
||||
getChatSidebarVisible(): boolean {
|
||||
return this.chatSidebarVisible;
|
||||
}
|
||||
|
||||
setChatSidebarVisible(visible: boolean) {
|
||||
this.chatSidebarVisible = visible;
|
||||
this.saveChatSidebarVisibleToStorage();
|
||||
}
|
||||
|
||||
toggleChatSidebarVisible() {
|
||||
this.chatSidebarVisible = !this.chatSidebarVisible;
|
||||
this.saveChatSidebarVisibleToStorage();
|
||||
}
|
||||
|
||||
startPolling() {
|
||||
this.fetchState();
|
||||
this.fetchInterval = setInterval(() => this.fetchState(), 1000);
|
||||
// Poll every 2 seconds instead of 1 second - reduces CPU/GPU load by 50%
|
||||
// Data comparison ensures we only update when something actually changes
|
||||
this.fetchInterval = setInterval(() => this.fetchState(), 2000);
|
||||
}
|
||||
|
||||
stopPolling() {
|
||||
@@ -712,6 +839,9 @@ class AppStore {
|
||||
}
|
||||
|
||||
async fetchState() {
|
||||
// Skip polling when page is hidden to save resources
|
||||
if (!this.isPageVisible) return;
|
||||
|
||||
try {
|
||||
const response = await fetch('/state');
|
||||
if (!response.ok) {
|
||||
@@ -719,19 +849,44 @@ class AppStore {
|
||||
}
|
||||
const data: RawStateResponse = await response.json();
|
||||
|
||||
// Only update topology if it actually changed (prevents unnecessary D3 re-renders)
|
||||
if (data.topology) {
|
||||
this.topologyData = transformTopology(data.topology, data.nodeProfiles);
|
||||
const newTopology = transformTopology(data.topology, data.nodeProfiles);
|
||||
const newTopologyJson = JSON.stringify(newTopology);
|
||||
if (newTopologyJson !== this.lastTopologyJson) {
|
||||
this.lastTopologyJson = newTopologyJson;
|
||||
this.topologyData = newTopology;
|
||||
}
|
||||
}
|
||||
|
||||
// Only update instances if changed
|
||||
if (data.instances) {
|
||||
this.instances = data.instances;
|
||||
this.refreshConversationModelFromInstances();
|
||||
const newInstancesJson = JSON.stringify(data.instances);
|
||||
if (newInstancesJson !== this.lastInstancesJson) {
|
||||
this.lastInstancesJson = newInstancesJson;
|
||||
this.instances = data.instances;
|
||||
this.refreshConversationModelFromInstances();
|
||||
}
|
||||
}
|
||||
|
||||
// Only update runners if changed
|
||||
if (data.runners) {
|
||||
this.runners = data.runners;
|
||||
const newRunnersJson = JSON.stringify(data.runners);
|
||||
if (newRunnersJson !== this.lastRunnersJson) {
|
||||
this.lastRunnersJson = newRunnersJson;
|
||||
this.runners = data.runners;
|
||||
}
|
||||
}
|
||||
|
||||
// Only update downloads if changed
|
||||
if (data.downloads) {
|
||||
this.downloads = data.downloads;
|
||||
const newDownloadsJson = JSON.stringify(data.downloads);
|
||||
if (newDownloadsJson !== this.lastDownloadsJson) {
|
||||
this.lastDownloadsJson = newDownloadsJson;
|
||||
this.downloads = data.downloads;
|
||||
}
|
||||
}
|
||||
|
||||
this.lastUpdate = Date.now();
|
||||
} catch (error) {
|
||||
console.error('Error fetching state:', error);
|
||||
@@ -888,8 +1043,6 @@ class AppStore {
|
||||
|
||||
if (lastUserIndex === -1) return;
|
||||
|
||||
const lastUserMessage = this.messages[lastUserIndex];
|
||||
|
||||
// Remove any messages after the user message
|
||||
this.messages = this.messages.slice(0, lastUserIndex + 1);
|
||||
|
||||
@@ -930,7 +1083,10 @@ class AppStore {
|
||||
}
|
||||
|
||||
if (!modelToUse) {
|
||||
assistantMessage.content = 'Error: No model available. Please launch an instance first.';
|
||||
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = 'Error: No model available. Please launch an instance first.';
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.updateActiveConversation();
|
||||
return;
|
||||
@@ -948,7 +1104,10 @@ class AppStore {
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
assistantMessage.content = `Error: ${response.status} - ${errorText}`;
|
||||
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = `Error: ${response.status} - ${errorText}`;
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.updateActiveConversation();
|
||||
return;
|
||||
@@ -956,7 +1115,10 @@ class AppStore {
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
assistantMessage.content = 'Error: No response stream available';
|
||||
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = 'Error: No response stream available';
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.updateActiveConversation();
|
||||
return;
|
||||
@@ -984,9 +1146,16 @@ class AppStore {
|
||||
const delta = json.choices?.[0]?.delta?.content;
|
||||
if (delta) {
|
||||
fullContent += delta;
|
||||
const { displayContent } = this.stripThinkingTags(fullContent);
|
||||
const { displayContent, thinkingContent } = this.stripThinkingTags(fullContent);
|
||||
this.currentResponse = displayContent;
|
||||
assistantMessage.content = displayContent;
|
||||
|
||||
// Update the assistant message in place (triggers Svelte reactivity)
|
||||
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = displayContent;
|
||||
this.messages[idx].thinking = thinkingContent || undefined;
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
}
|
||||
} catch {
|
||||
// Skip malformed JSON
|
||||
@@ -995,16 +1164,25 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
const { displayContent } = this.stripThinkingTags(fullContent);
|
||||
assistantMessage.content = displayContent;
|
||||
this.currentResponse = '';
|
||||
this.updateActiveConversation();
|
||||
// Final cleanup of the message
|
||||
const { displayContent, thinkingContent } = this.stripThinkingTags(fullContent);
|
||||
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = displayContent;
|
||||
this.messages[idx].thinking = thinkingContent || undefined;
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
|
||||
} catch (error) {
|
||||
assistantMessage.content = `Error: ${error instanceof Error ? error.message : 'Unknown error'}`;
|
||||
this.updateActiveConversation();
|
||||
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
|
||||
if (idx !== -1) {
|
||||
this.messages[idx].content = `Error: ${error instanceof Error ? error.message : 'Unknown error'}`;
|
||||
}
|
||||
this.persistActiveConversation();
|
||||
} finally {
|
||||
this.isLoading = false;
|
||||
this.currentResponse = '';
|
||||
this.updateActiveConversation();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1364,6 +1542,8 @@ export const lastUpdate = () => appStore.lastUpdate;
|
||||
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
|
||||
export const selectedChatModel = () => appStore.selectedChatModel;
|
||||
export const debugMode = () => appStore.getDebugMode();
|
||||
export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
|
||||
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
|
||||
|
||||
// Actions
|
||||
export const startChat = () => appStore.startChat();
|
||||
@@ -1391,5 +1571,9 @@ export const isSidebarOpen = () => appStore.isSidebarOpen;
|
||||
export const toggleSidebar = () => appStore.toggleSidebar();
|
||||
export const toggleDebugMode = () => appStore.toggleDebugMode();
|
||||
export const setDebugMode = (enabled: boolean) => appStore.setDebugMode(enabled);
|
||||
export const toggleTopologyOnlyMode = () => appStore.toggleTopologyOnlyMode();
|
||||
export const setTopologyOnlyMode = (enabled: boolean) => appStore.setTopologyOnlyMode(enabled);
|
||||
export const toggleChatSidebarVisible = () => appStore.toggleChatSidebarVisible();
|
||||
export const setChatSidebarVisible = (visible: boolean) => appStore.setChatSidebarVisible(visible);
|
||||
export const refreshState = () => appStore.fetchState();
|
||||
|
||||
|
||||
@@ -1,7 +1,25 @@
|
||||
<script lang="ts">
|
||||
import '../app.css';
|
||||
import { onMount } from 'svelte';
|
||||
import { browser } from '$app/environment';
|
||||
|
||||
let { children } = $props();
|
||||
let isPageHidden = $state(false);
|
||||
|
||||
onMount(() => {
|
||||
if (!browser) return;
|
||||
|
||||
// Listen for visibility changes to pause animations when hidden
|
||||
const handleVisibilityChange = () => {
|
||||
isPageHidden = document.visibilityState === 'hidden';
|
||||
};
|
||||
|
||||
document.addEventListener('visibilitychange', handleVisibilityChange);
|
||||
|
||||
return () => {
|
||||
document.removeEventListener('visibilitychange', handleVisibilityChange);
|
||||
};
|
||||
});
|
||||
</script>
|
||||
|
||||
<svelte:head>
|
||||
@@ -9,7 +27,7 @@
|
||||
<meta name="description" content="EXO - Distributed AI Cluster Dashboard" />
|
||||
</svelte:head>
|
||||
|
||||
<div class="min-h-screen bg-background text-foreground">
|
||||
<div class="min-h-screen bg-background text-foreground" data-page-hidden={isPageHidden}>
|
||||
{@render children?.()}
|
||||
</div>
|
||||
|
||||
|
||||
@@ -18,6 +18,10 @@
|
||||
selectedChatModel,
|
||||
debugMode,
|
||||
toggleDebugMode,
|
||||
topologyOnlyMode,
|
||||
toggleTopologyOnlyMode,
|
||||
chatSidebarVisible,
|
||||
toggleChatSidebarVisible,
|
||||
type DownloadProgress,
|
||||
type PlacementPreview
|
||||
} from '$lib/stores/app.svelte';
|
||||
@@ -37,6 +41,8 @@
|
||||
const selectedModelId = $derived(selectedPreviewModelId());
|
||||
const loadingPreviews = $derived(isLoadingPreviews());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const topologyOnlyEnabled = $derived(topologyOnlyMode());
|
||||
const sidebarVisible = $derived(chatSidebarVisible());
|
||||
|
||||
let mounted = $state(false);
|
||||
|
||||
@@ -93,17 +99,35 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
}
|
||||
|
||||
// Compute highlighted nodes from hovered instance or hovered preview
|
||||
// Memoized to avoid creating new Sets on every render
|
||||
let lastHighlightedNodesKey = '';
|
||||
let cachedHighlightedNodes: Set<string> = new Set();
|
||||
|
||||
const highlightedNodes = $derived(() => {
|
||||
// Create a key for the current state to enable memoization
|
||||
const previewKey = Array.from(hoveredPreviewNodes).sort().join(',');
|
||||
const currentKey = `${hoveredInstanceId || 'null'}:${previewKey}`;
|
||||
|
||||
// Return cached value if nothing changed
|
||||
if (currentKey === lastHighlightedNodesKey) {
|
||||
return cachedHighlightedNodes;
|
||||
}
|
||||
|
||||
lastHighlightedNodesKey = currentKey;
|
||||
|
||||
// First check instance hover
|
||||
if (hoveredInstanceId) {
|
||||
const instanceWrapped = instanceData[hoveredInstanceId];
|
||||
return unwrapInstanceNodes(instanceWrapped);
|
||||
cachedHighlightedNodes = unwrapInstanceNodes(instanceWrapped);
|
||||
return cachedHighlightedNodes;
|
||||
}
|
||||
// Then check preview hover
|
||||
if (hoveredPreviewNodes.size > 0) {
|
||||
return hoveredPreviewNodes;
|
||||
cachedHighlightedNodes = hoveredPreviewNodes;
|
||||
return cachedHighlightedNodes;
|
||||
}
|
||||
return new Set<string>();
|
||||
cachedHighlightedNodes = new Set<string>();
|
||||
return cachedHighlightedNodes;
|
||||
});
|
||||
|
||||
// Helper to estimate memory from model ID (mirrors ModelCard logic)
|
||||
@@ -472,6 +496,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
const progress = parseDownloadProgress(downloadPayload);
|
||||
if (progress) {
|
||||
// Sum all values across nodes - each node downloads independently
|
||||
totalBytes += progress.totalBytes;
|
||||
downloadedBytes += progress.downloadedBytes;
|
||||
totalSpeed += progress.speed;
|
||||
@@ -489,13 +514,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
return { isDownloading: false, progress: null, perNode: [] };
|
||||
}
|
||||
|
||||
// ETA = total remaining bytes / total speed across all nodes
|
||||
const remainingBytes = totalBytes - downloadedBytes;
|
||||
const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;
|
||||
|
||||
return {
|
||||
isDownloading: true,
|
||||
progress: {
|
||||
totalBytes,
|
||||
downloadedBytes,
|
||||
speed: totalSpeed,
|
||||
etaMs: totalSpeed > 0 ? ((totalBytes - downloadedBytes) / totalSpeed) * 1000 : 0,
|
||||
etaMs,
|
||||
percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
|
||||
completedFiles,
|
||||
totalFiles,
|
||||
@@ -505,12 +534,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
};
|
||||
}
|
||||
|
||||
// Debug: Log downloads data when it changes
|
||||
$effect(() => {
|
||||
if (downloadsData && Object.keys(downloadsData).length > 0) {
|
||||
console.log('[Download Debug] Current downloads:', downloadsData);
|
||||
}
|
||||
});
|
||||
// Debug: Log downloads data when it changes (disabled in production for performance)
|
||||
// Uncomment for debugging:
|
||||
// $effect(() => {
|
||||
// if (downloadsData && Object.keys(downloadsData).length > 0) {
|
||||
// console.log('[Download Debug] Current downloads:', downloadsData);
|
||||
// }
|
||||
// });
|
||||
|
||||
// Helper to get download status for an instance
|
||||
function getInstanceDownloadStatus(instanceId: string, instanceWrapped: unknown): {
|
||||
@@ -576,6 +606,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
const progress = parseDownloadProgress(downloadPayload);
|
||||
if (progress) {
|
||||
// Sum all values across nodes - each node downloads independently
|
||||
totalBytes += progress.totalBytes;
|
||||
downloadedBytes += progress.downloadedBytes;
|
||||
totalSpeed += progress.speed;
|
||||
@@ -596,13 +627,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
return { isDownloading: false, progress: null, statusText: statusInfo.statusText, perNode: [] };
|
||||
}
|
||||
|
||||
// ETA = total remaining bytes / total speed across all nodes
|
||||
const remainingBytes = totalBytes - downloadedBytes;
|
||||
const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;
|
||||
|
||||
return {
|
||||
isDownloading: true,
|
||||
progress: {
|
||||
totalBytes,
|
||||
downloadedBytes,
|
||||
speed: totalSpeed,
|
||||
etaMs: totalSpeed > 0 ? ((totalBytes - downloadedBytes) / totalSpeed) * 1000 : 0,
|
||||
etaMs,
|
||||
percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
|
||||
completedFiles,
|
||||
totalFiles,
|
||||
@@ -618,10 +653,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
function getStatusColor(statusText: string): string {
|
||||
switch (statusText) {
|
||||
case 'FAILED': return 'text-red-400';
|
||||
case 'SHUTDOWN': return 'text-gray-400';
|
||||
case 'DOWNLOADING': return 'text-blue-400';
|
||||
case 'LOADING':
|
||||
case 'WARMING UP':
|
||||
case 'WAITING': return 'text-yellow-400';
|
||||
case 'WAITING':
|
||||
case 'INITIALIZING': return 'text-yellow-400';
|
||||
case 'RUNNING': return 'text-teal-400';
|
||||
case 'READY':
|
||||
case 'LOADED': return 'text-green-400';
|
||||
@@ -644,12 +681,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
if (!r) return null;
|
||||
const [kind] = getTagged(r);
|
||||
const statusMap: Record<string, string> = {
|
||||
RunnerWaitingForInitialization: 'WaitingForInitialization',
|
||||
RunnerInitializingBackend: 'InitializingBackend',
|
||||
RunnerWaitingForModel: 'WaitingForModel',
|
||||
RunnerLoading: 'Loading',
|
||||
RunnerLoaded: 'Loaded',
|
||||
RunnerWarmingUp: 'WarmingUp',
|
||||
RunnerReady: 'Ready',
|
||||
RunnerRunning: 'Running',
|
||||
RunnerShutdown: 'Shutdown',
|
||||
RunnerFailed: 'Failed',
|
||||
};
|
||||
return kind ? statusMap[kind] || null : null;
|
||||
@@ -660,12 +700,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
if (statuses.length === 0) return { statusText: 'UNKNOWN', statusClass: 'inactive' };
|
||||
if (has('Failed')) return { statusText: 'FAILED', statusClass: 'failed' };
|
||||
if (has('Shutdown')) return { statusText: 'SHUTDOWN', statusClass: 'inactive' };
|
||||
if (has('Loading')) return { statusText: 'LOADING', statusClass: 'starting' };
|
||||
if (has('WarmingUp')) return { statusText: 'WARMING UP', statusClass: 'starting' };
|
||||
if (has('Running')) return { statusText: 'RUNNING', statusClass: 'running' };
|
||||
if (has('Ready')) return { statusText: 'READY', statusClass: 'loaded' };
|
||||
if (has('Loaded')) return { statusText: 'LOADED', statusClass: 'loaded' };
|
||||
if (has('WaitingForModel')) return { statusText: 'WAITING', statusClass: 'starting' };
|
||||
if (has('InitializingBackend')) return { statusText: 'INITIALIZING', statusClass: 'starting' };
|
||||
if (has('WaitingForInitialization')) return { statusText: 'INITIALIZING', statusClass: 'starting' };
|
||||
|
||||
return { statusText: 'RUNNING', statusClass: 'active' };
|
||||
}
|
||||
@@ -1107,16 +1150,47 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="shooting-star" style="top: 50%; left: 40%; --duration: 45s; --delay: 30s;"></div>
|
||||
</div>
|
||||
|
||||
<HeaderNav showHome={chatStarted} onHome={handleGoHome} />
|
||||
{#if !topologyOnlyEnabled}
|
||||
<HeaderNav
|
||||
showHome={chatStarted}
|
||||
onHome={handleGoHome}
|
||||
showSidebarToggle={true}
|
||||
sidebarVisible={sidebarVisible}
|
||||
onToggleSidebar={toggleChatSidebarVisible}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<!-- Main Content -->
|
||||
<main class="flex-1 flex overflow-hidden relative">
|
||||
<!-- Left: Conversation History Sidebar (always visible) -->
|
||||
<!-- Left: Conversation History Sidebar (hidden in topology-only mode or when toggled off) -->
|
||||
{#if !topologyOnlyEnabled && sidebarVisible}
|
||||
<div class="w-80 flex-shrink-0 border-r border-exo-yellow/10">
|
||||
<ChatSidebar class="h-full" />
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if !chatStarted}
|
||||
{#if topologyOnlyEnabled}
|
||||
<!-- TOPOLOGY ONLY MODE: Full-screen topology -->
|
||||
<div class="flex-1 flex flex-col min-h-0 min-w-0 p-4" in:fade={{ duration: 300 }}>
|
||||
<div class="flex-1 relative bg-exo-dark-gray/40 rounded-lg overflow-hidden">
|
||||
<TopologyGraph class="w-full h-full" highlightedNodes={highlightedNodes()} />
|
||||
<!-- Exit topology-only mode button -->
|
||||
<button
|
||||
type="button"
|
||||
onclick={toggleTopologyOnlyMode}
|
||||
class="absolute bottom-4 right-4 p-2 rounded border border-exo-yellow/30 bg-exo-dark-gray/80 hover:border-exo-yellow/50 hover:bg-exo-dark-gray transition-colors cursor-pointer backdrop-blur-sm"
|
||||
title="Exit topology only mode"
|
||||
>
|
||||
<svg class="w-5 h-5 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<circle cx="12" cy="5" r="2" fill="currentColor" />
|
||||
<circle cx="5" cy="19" r="2" fill="currentColor" />
|
||||
<circle cx="19" cy="19" r="2" fill="currentColor" />
|
||||
<path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{:else if !chatStarted}
|
||||
<!-- WELCOME STATE: Topology + Instance Controls (no left sidebar for cleaner look) -->
|
||||
<div class="flex-1 flex overflow-visible relative" in:fade={{ duration: 300 }} out:fade={{ duration: 200 }}>
|
||||
|
||||
@@ -1300,14 +1374,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
{:else}
|
||||
{#each nodeProg.progress.files as f}
|
||||
{@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
|
||||
{@const isFileComplete = filePercent >= 100}
|
||||
<div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
|
||||
<div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
|
||||
<span class="truncate pr-2">{f.name}</span>
|
||||
<span class="text-white/80">{filePercent.toFixed(1)}%</span>
|
||||
<span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
|
||||
</div>
|
||||
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
|
||||
style="width: {filePercent.toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
@@ -1611,13 +1686,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
in:fade={{ duration: 300, delay: 100 }}
|
||||
>
|
||||
<div class="flex-1 overflow-y-auto px-8 py-6" bind:this={chatScrollRef}>
|
||||
<div class="max-w-3xl mx-auto">
|
||||
<div class="max-w-7xl mx-auto">
|
||||
<ChatMessages scrollParent={chatScrollRef} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
|
||||
<div class="max-w-3xl mx-auto">
|
||||
<div class="max-w-7xl mx-auto">
|
||||
<ChatForm placeholder="Ask anything" showModelSelector={true} />
|
||||
</div>
|
||||
</div>
|
||||
@@ -1655,7 +1730,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<!-- Panel Header -->
|
||||
<div class="flex items-center gap-2 mb-4">
|
||||
<div class="w-2 h-2 bg-exo-yellow rounded-full shadow-[0_0_8px_rgba(255,215,0,0.6)] animate-pulse"></div>
|
||||
<h3 class="text-sm text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
|
||||
<h3 class="text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
|
||||
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
|
||||
</div>
|
||||
<div class="space-y-3 max-h-72 overflow-y-auto pr-1">
|
||||
@@ -1701,28 +1776,28 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="flex justify-between items-start mb-2 pl-2">
|
||||
<div class="flex items-center gap-2">
|
||||
<div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
|
||||
<span class="text-exo-light-gray font-mono text-xs tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
|
||||
<span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
|
||||
</div>
|
||||
<button
|
||||
onclick={() => deleteInstance(id)}
|
||||
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400/80 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
|
||||
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
|
||||
>
|
||||
DELETE
|
||||
</button>
|
||||
</div>
|
||||
<div class="pl-2">
|
||||
<div class="text-exo-yellow text-sm font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
|
||||
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
|
||||
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
|
||||
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
|
||||
<a
|
||||
class="inline-flex items-center gap-1 text-[10px] text-white/60 hover:text-exo-yellow transition-colors mt-0.5"
|
||||
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
|
||||
href={`https://huggingface.co/${instanceModelId}`}
|
||||
target="_blank"
|
||||
rel="noreferrer noopener"
|
||||
aria-label="View model on Hugging Face"
|
||||
>
|
||||
<span>Hugging Face</span>
|
||||
<svg class="w-3 h-3" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<svg class="w-3.5 h-3.5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M14 3h7v7"/>
|
||||
<path d="M10 14l11-11"/>
|
||||
<path d="M21 14v6a1 1 0 0 1-1 1h-16a1 1 0 0 1-1-1v-16a1 1 0 0 1 1-1h6"/>
|
||||
@@ -1733,68 +1808,84 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="text-white/60 text-xs font-mono">{instanceInfo.nodeNames.join(', ')}</div>
|
||||
{/if}
|
||||
{#if debugEnabled && instanceConnections.length > 0}
|
||||
<div class="mt-1 space-y-0.5">
|
||||
{#each instanceConnections as conn}
|
||||
<div class="text-[10px] leading-snug font-mono text-white/70">
|
||||
<span>{conn.from} -> {conn.to}: {conn.ip}</span>
|
||||
<span class="{conn.missingIface ? 'text-red-400' : 'text-white/60'}"> ({conn.ifaceLabel})</span>
|
||||
</div>
|
||||
{/each}
|
||||
<div class="mt-2 space-y-1">
|
||||
{#each instanceConnections as conn}
|
||||
<div class="text-[11px] leading-snug font-mono text-white/70">
|
||||
<span>{conn.from} -> {conn.to}: {conn.ip}</span>
|
||||
<span class="{conn.missingIface ? 'text-red-400' : 'text-white/60'}"> ({conn.ifaceLabel})</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Download Progress -->
|
||||
{#if downloadInfo.isDownloading && downloadInfo.progress}
|
||||
<div class="mt-2 space-y-1">
|
||||
<div class="flex justify-between text-xs font-mono">
|
||||
<span class="text-blue-400">{downloadInfo.progress.percentage.toFixed(1)}%</span>
|
||||
<span class="text-exo-light-gray">{formatBytes(downloadInfo.progress.downloadedBytes)}/{formatBytes(downloadInfo.progress.totalBytes)}</span>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Download Progress -->
|
||||
{#if downloadInfo.isDownloading && downloadInfo.progress}
|
||||
<div class="mt-2 space-y-1">
|
||||
<div class="flex justify-between text-sm font-mono">
|
||||
<span class="text-blue-400">{downloadInfo.progress.percentage.toFixed(1)}%</span>
|
||||
<span class="text-exo-light-gray">{formatBytes(downloadInfo.progress.downloadedBytes)}/{formatBytes(downloadInfo.progress.totalBytes)}</span>
|
||||
</div>
|
||||
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
|
||||
style="width: {downloadInfo.progress.percentage}%"
|
||||
></div>
|
||||
</div>
|
||||
<div class="flex justify-between text-xs font-mono text-exo-light-gray">
|
||||
<span>{formatSpeed(downloadInfo.progress.speed)}</span>
|
||||
<span>ETA: {formatEta(downloadInfo.progress.etaMs)}</span>
|
||||
<span>{downloadInfo.progress.completedFiles}/{downloadInfo.progress.totalFiles} files</span>
|
||||
</div>
|
||||
<div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
|
||||
style="width: {downloadInfo.progress.percentage}%"
|
||||
></div>
|
||||
</div>
|
||||
{#if downloadInfo.perNode.length > 0}
|
||||
<div class="mt-2 space-y-1.5 max-h-48 overflow-y-auto pr-1">
|
||||
{#each downloadInfo.perNode as nodeProg}
|
||||
<div class="rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2">
|
||||
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray mb-1">
|
||||
<div class="flex justify-between text-xs font-mono text-exo-light-gray">
|
||||
<span>{formatSpeed(downloadInfo.progress.speed)}</span>
|
||||
<span>ETA: {formatEta(downloadInfo.progress.etaMs)}</span>
|
||||
<span>{downloadInfo.progress.completedFiles}/{downloadInfo.progress.totalFiles} files</span>
|
||||
</div>
|
||||
</div>
|
||||
{#if downloadInfo.perNode.length > 0}
|
||||
<div class="mt-2 space-y-2 max-h-48 overflow-y-auto pr-1">
|
||||
{#each downloadInfo.perNode as nodeProg}
|
||||
{@const nodePercent = Math.min(100, Math.max(0, nodeProg.progress.percentage))}
|
||||
{@const isExpanded = instanceDownloadExpandedNodes.has(nodeProg.nodeId)}
|
||||
<div class="rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2">
|
||||
<button
|
||||
type="button"
|
||||
class="w-full text-left space-y-1.5"
|
||||
onclick={() => toggleInstanceDownloadDetails(nodeProg.nodeId)}
|
||||
>
|
||||
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
|
||||
<span class="text-white/80 truncate pr-2">{nodeProg.nodeName}</span>
|
||||
<span class="text-blue-300">{Math.min(100, Math.max(0, nodeProg.progress.percentage)).toFixed(1)}%</span>
|
||||
<span class="flex items-center gap-1 text-blue-300">
|
||||
{nodePercent.toFixed(1)}%
|
||||
<svg class="w-3 h-3 text-exo-light-gray" viewBox="0 0 20 20" fill="none" stroke="currentColor" stroke-width="2">
|
||||
<path d="M6 8l4 4 4-4" class={isExpanded ? 'transform rotate-180 origin-center transition-transform duration-150' : 'transition-transform duration-150'}></path>
|
||||
</svg>
|
||||
</span>
|
||||
</div>
|
||||
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mb-1.5">
|
||||
<div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-blue-500/80 transition-all duration-300"
|
||||
style="width: {Math.min(100, Math.max(0, nodeProg.progress.percentage)).toFixed(1)}%"
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
|
||||
style="width: {nodePercent.toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray mb-1">
|
||||
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
|
||||
<span>{formatBytes(nodeProg.progress.downloadedBytes)} / {formatBytes(nodeProg.progress.totalBytes)}</span>
|
||||
<span>{formatSpeed(nodeProg.progress.speed)} • ETA {formatEta(nodeProg.progress.etaMs)}</span>
|
||||
</div>
|
||||
{#if nodeProg.progress.files.length > 0}
|
||||
{@const inProgressFiles = nodeProg.progress.files.filter(f => (f.percentage ?? 0) < 100)}
|
||||
{@const completedFiles = nodeProg.progress.files.filter(f => (f.percentage ?? 0) >= 100)}
|
||||
{#if inProgressFiles.length > 0}
|
||||
<div class="space-y-1">
|
||||
{#each inProgressFiles as f}
|
||||
<div class="text-[10px] font-mono text-exo-light-gray/80">
|
||||
<div class="flex items-center justify-between">
|
||||
</button>
|
||||
|
||||
{#if isExpanded}
|
||||
<div class="mt-2 space-y-1.5">
|
||||
{#if nodeProg.progress.files.length === 0}
|
||||
<div class="text-[11px] font-mono text-exo-light-gray/70">No file details reported.</div>
|
||||
{:else}
|
||||
{#each nodeProg.progress.files as f}
|
||||
{@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
|
||||
{@const isFileComplete = filePercent >= 100}
|
||||
<div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
|
||||
<div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
|
||||
<span class="truncate pr-2">{f.name}</span>
|
||||
<span class="text-white/70">{Math.min(100, Math.max(0, f.percentage)).toFixed(1)}%</span>
|
||||
<span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
|
||||
</div>
|
||||
<div class="relative h-1 bg-exo-black/50 rounded-sm overflow-hidden mt-0.5">
|
||||
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70"
|
||||
style="width: {Math.min(100, Math.max(0, f.percentage)).toFixed(1)}%"
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
|
||||
style="width: {filePercent.toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
<div class="flex items-center justify-between text-[10px] text-exo-light-gray/70 mt-0.5">
|
||||
@@ -1803,27 +1894,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
{#if completedFiles.length > 0}
|
||||
<div class="pt-1 space-y-0.5">
|
||||
{#each completedFiles as f}
|
||||
<div class="text-[10px] font-mono text-exo-light-gray/70 flex items-center justify-between">
|
||||
<span class="truncate pr-2">{f.name}</span>
|
||||
<span class="text-white/60">100%</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
<div class="text-sm text-blue-400 font-mono tracking-wider mt-1">DOWNLOADING</div>
|
||||
{:else}
|
||||
<div class="text-sm {getStatusColor(downloadInfo.statusText)} font-mono tracking-wider mt-1">{downloadInfo.statusText}</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
<div class="text-xs text-blue-400 font-mono tracking-wider mt-1">DOWNLOADING</div>
|
||||
{:else}
|
||||
<div class="text-xs {getStatusColor(downloadInfo.statusText)} font-mono tracking-wider mt-1">{downloadInfo.statusText}</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -345,13 +345,19 @@
|
||||
<div class="rounded border border-exo-medium-gray/30 bg-exo-dark-gray/60 p-3 space-y-2">
|
||||
<div class="flex items-center justify-between gap-3">
|
||||
<div class="min-w-0 space-y-0.5">
|
||||
<div class="text-sm font-mono text-white truncate">{model.prettyName ?? model.modelId}</div>
|
||||
<div class="text-[11px] text-exo-light-gray font-mono truncate">
|
||||
{model.modelId}
|
||||
</div>
|
||||
<div class="text-[11px] text-exo-light-gray font-mono">
|
||||
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
|
||||
</div>
|
||||
<div
|
||||
class="text-xs font-mono text-white truncate"
|
||||
title={model.prettyName ?? model.modelId}
|
||||
>{model.prettyName ?? model.modelId}</div>
|
||||
<div
|
||||
class="text-[10px] text-exo-light-gray font-mono truncate"
|
||||
title={model.modelId}
|
||||
>{model.modelId}</div>
|
||||
{#if model.status !== 'completed'}
|
||||
<div class="text-[11px] text-exo-light-gray font-mono">
|
||||
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="text-xs font-mono {pct >= 100 ? 'text-green-400' : pct <= 0 ? 'text-red-400' : 'text-exo-yellow'}">
|
||||
@@ -426,14 +432,14 @@
|
||||
<style>
|
||||
.downloads-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fill, minmax(260px, 1fr));
|
||||
grid-template-columns: repeat(auto-fill, minmax(320px, 1fr));
|
||||
}
|
||||
@media (min-width: 1024px) {
|
||||
.downloads-grid {
|
||||
grid-template-columns: repeat(3, minmax(0, 1fr));
|
||||
}
|
||||
}
|
||||
@media (min-width: 1440px) {
|
||||
@media (min-width: 1600px) {
|
||||
.downloads-grid {
|
||||
grid-template-columns: repeat(4, minmax(0, 1fr));
|
||||
}
|
||||
|
||||
18
flake.nix
18
flake.nix
@@ -31,24 +31,22 @@
|
||||
"aarch64-darwin"
|
||||
"aarch64-linux"
|
||||
];
|
||||
fenixToolchain = system: inputs.fenix.packages.${system}.stable;
|
||||
fenixToolchain = system: inputs.fenix.packages.${system}.complete;
|
||||
in
|
||||
inputs.flake-utils.lib.eachSystem systems (
|
||||
system:
|
||||
let
|
||||
pkgs = import inputs.nixpkgs {
|
||||
inherit system;
|
||||
overlays = [ ];
|
||||
overlays = [ inputs.fenix.overlays.default ];
|
||||
};
|
||||
treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
|
||||
projectRootFile = "flake.nix";
|
||||
programs = {
|
||||
ruff-format.enable = true;
|
||||
ruff-format.excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
|
||||
rustfmt.enable = true;
|
||||
rustfmt.package = (fenixToolchain system).rustfmt;
|
||||
nixpkgs-fmt.enable = true;
|
||||
};
|
||||
programs.ruff-format.enable = true;
|
||||
programs.ruff-format.excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
|
||||
programs.rustfmt.enable = true;
|
||||
programs.rustfmt.package = (fenixToolchain system).rustfmt;
|
||||
programs.nixpkgs-fmt.enable = true;
|
||||
};
|
||||
in
|
||||
{
|
||||
@@ -78,8 +76,6 @@
|
||||
"rustfmt"
|
||||
"rust-src"
|
||||
])
|
||||
cargo-machete
|
||||
bacon
|
||||
rustup # Just here to make RustRover happy
|
||||
|
||||
# NIX
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "exo"
|
||||
version = "0.10.0"
|
||||
version = "0.3.0"
|
||||
description = "Exo"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
@@ -29,10 +29,12 @@ dependencies = [
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"bidict>=0.23.1",
|
||||
"mlx>=0.30.1",
|
||||
"mlx>=0.30.1; sys_platform == 'darwin'",
|
||||
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
|
||||
"mlx-lm>=0.28.3",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
2
rust/clippy.toml
Normal file
2
rust/clippy.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
|
||||
#allowed-duplicate-crates = ["hashbrown"]
|
||||
@@ -5,6 +5,8 @@ edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
path = "src/lib.rs"
|
||||
name = "exo_pyo3_bindings"
|
||||
|
||||
# "cdylib" needed to produce shared library for Python to import
|
||||
@@ -20,24 +22,46 @@ doc = false
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
networking.workspace = true
|
||||
networking = { workspace = true }
|
||||
|
||||
# interop
|
||||
pyo3 = { workspace = true, features = ["experimental-async"] }
|
||||
pyo3-stub-gen.workspace = true
|
||||
# pyo3-async-runtimes = { workspace = true, features = ["attributes", "tokio-runtime", "testing"] }
|
||||
pyo3-log.workspace = true
|
||||
pyo3 = { version = "0.27.1", features = [
|
||||
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
|
||||
"nightly", # enables better-supported GIL integration
|
||||
"experimental-async", # async support in #[pyfunction] & #[pymethods]
|
||||
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
"multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
|
||||
# integrations with other libraries
|
||||
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
"ordered-float", "rust_decimal", "smallvec",
|
||||
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
|
||||
] }
|
||||
pyo3-stub-gen = { version = "0.17.2" }
|
||||
pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] }
|
||||
pyo3-log = "0.13.2"
|
||||
|
||||
# macro dependencies
|
||||
extend.workspace = true
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
|
||||
# async runtime
|
||||
tokio = { workspace = true, features = ["full", "tracing"] }
|
||||
futures = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
postcard = { workspace = true, features = ["use-std"] }
|
||||
rand.workspace = true
|
||||
n0-future.workspace = true
|
||||
once_cell = "1.21.3"
|
||||
thread_local = "1.1.9"
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
|
||||
|
||||
# Tracing
|
||||
@@ -46,9 +70,8 @@ n0-future.workspace = true
|
||||
#console-subscriber = "0.1.5"
|
||||
#tracing-log = "0.2.0"
|
||||
log = { workspace = true }
|
||||
env_logger = { workspace = true }
|
||||
env_logger = "0.11"
|
||||
|
||||
|
||||
# Networking
|
||||
iroh = { workspace = true }
|
||||
iroh-gossip = { workspace = true }
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
from exo_pyo3_bindings import RustNetworkingHandle, Keypair
|
||||
from asyncio import run
|
||||
|
||||
|
||||
async def main():
|
||||
nh = await RustNetworkingHandle.create(Keypair.generate_ed25519(), "mdns_example")
|
||||
recv = await nh.get_connection_receiver()
|
||||
while True:
|
||||
cm = await recv.receive()
|
||||
print(
|
||||
f"Endpoint({cm.endpoint_id}, reachable={list(map(lambda it: it.ip_addr(), cm.current_transport_addrs)) if cm.current_transport_addrs is not None else None})"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run(main())
|
||||
@@ -2,63 +2,220 @@
|
||||
# ruff: noqa: E501, F401
|
||||
|
||||
import builtins
|
||||
import enum
|
||||
import typing
|
||||
|
||||
@typing.final
|
||||
class EndpointId:
|
||||
class AllQueuesFullError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> AllQueuesFullError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class IpAddress:
|
||||
def __str__(self) -> builtins.str: ...
|
||||
def ip_addr(self) -> builtins.str: ...
|
||||
def port(self) -> builtins.int: ...
|
||||
def zone_id(self) -> typing.Optional[builtins.int]: ...
|
||||
class ConnectionUpdate:
|
||||
@property
|
||||
def update_type(self) -> ConnectionUpdateType:
|
||||
r"""
|
||||
Whether this is a connection or disconnection event
|
||||
"""
|
||||
@property
|
||||
def peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Identity of the peer that we have connected to or disconnected from.
|
||||
"""
|
||||
@property
|
||||
def remote_ipv4(self) -> builtins.str:
|
||||
r"""
|
||||
Remote connection's IPv4 address.
|
||||
"""
|
||||
@property
|
||||
def remote_tcp_port(self) -> builtins.int:
|
||||
r"""
|
||||
Remote connection's TCP port.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class Keypair:
|
||||
r"""
|
||||
Identity keypair of a node.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ed25519() -> Keypair:
|
||||
r"""
|
||||
Generate a new Ed25519 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_postcard_encoding(bytes: bytes) -> Keypair:
|
||||
def generate_ecdsa() -> Keypair:
|
||||
r"""
|
||||
Decode a postcard structure into a keypair
|
||||
Generate a new ECDSA keypair.
|
||||
"""
|
||||
def to_postcard_encoding(self) -> bytes:
|
||||
r"""
|
||||
Encode a private key with the postcard format
|
||||
"""
|
||||
def endpoint_id(self) -> EndpointId:
|
||||
r"""
|
||||
Read out the endpoint id corresponding to this keypair
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class RustConnectionMessage:
|
||||
@property
|
||||
def endpoint_id(self) -> EndpointId: ...
|
||||
@property
|
||||
def current_transport_addrs(self) -> typing.Optional[builtins.set[IpAddress]]: ...
|
||||
|
||||
@typing.final
|
||||
class RustConnectionReceiver:
|
||||
async def receive(self) -> RustConnectionMessage: ...
|
||||
|
||||
@typing.final
|
||||
class RustNetworkingHandle:
|
||||
@staticmethod
|
||||
async def create(identity: Keypair, namespace: builtins.str) -> RustNetworkingHandle: ...
|
||||
async def subscribe(self, topic: builtins.str) -> tuple[RustSender, RustReceiver]: ...
|
||||
async def get_connection_receiver(self) -> RustConnectionReceiver: ...
|
||||
def generate_secp256k1() -> Keypair:
|
||||
r"""
|
||||
Generate a new Secp256k1 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_protobuf_encoding(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
"""
|
||||
@staticmethod
|
||||
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
format (i.e. unencrypted) as defined in [RFC5208].
|
||||
|
||||
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
"""
|
||||
@staticmethod
|
||||
def secp256k1_from_der(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
structure as defined in [RFC5915].
|
||||
|
||||
[RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
"""
|
||||
@staticmethod
|
||||
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
|
||||
def to_protobuf_encoding(self) -> bytes:
|
||||
r"""
|
||||
Encode a private key as protobuf structure.
|
||||
"""
|
||||
def to_peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Convert the `Keypair` into the corresponding `PeerId`.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class RustReceiver:
|
||||
async def receive(self) -> bytes: ...
|
||||
class Multiaddr:
|
||||
r"""
|
||||
Representation of a Multiaddr.
|
||||
"""
|
||||
@staticmethod
|
||||
def empty() -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress.
|
||||
"""
|
||||
@staticmethod
|
||||
def with_capacity(n: builtins.int) -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress with the given capacity.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its byte slice representation.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_string(string: builtins.str) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its string representation.
|
||||
"""
|
||||
def len(self) -> builtins.int:
|
||||
r"""
|
||||
Return the length in bytes of this multiaddress.
|
||||
"""
|
||||
def is_empty(self) -> builtins.bool:
|
||||
r"""
|
||||
Returns true if the length of this multiaddress is 0.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
"""
|
||||
def to_string(self) -> builtins.str:
|
||||
r"""
|
||||
Convert a Multiaddr to a string.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class RustSender:
|
||||
async def send(self, message: bytes) -> None: ...
|
||||
class NetworkingHandle:
|
||||
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
|
||||
async def connection_update_recv(self) -> ConnectionUpdate:
|
||||
r"""
|
||||
Receives the next `ConnectionUpdate` from networking.
|
||||
"""
|
||||
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
|
||||
r"""
|
||||
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
will sleep until a `ConnectionUpdate`s is sent.
|
||||
"""
|
||||
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
|
||||
r"""
|
||||
Subscribe to a `GossipSub` topic.
|
||||
|
||||
Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
|
||||
"""
|
||||
async def gossipsub_unsubscribe(self, topic: builtins.str) -> builtins.bool:
|
||||
r"""
|
||||
Unsubscribes from a `GossipSub` topic.
|
||||
|
||||
Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
|
||||
"""
|
||||
async def gossipsub_publish(self, topic: builtins.str, data: bytes) -> None:
|
||||
r"""
|
||||
Publishes a message with multiple topics to the `GossipSub` network.
|
||||
|
||||
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
"""
|
||||
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
|
||||
r"""
|
||||
Receives the next message from the `GossipSub` network.
|
||||
"""
|
||||
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
|
||||
r"""
|
||||
Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
will sleep until a message is sent.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> NoPeersSubscribedToTopicError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class PeerId:
|
||||
r"""
|
||||
Identifier of a peer of the network.
|
||||
|
||||
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
"""
|
||||
@staticmethod
|
||||
def random() -> PeerId:
|
||||
r"""
|
||||
Generates a random peer ID from a cryptographically secure PRNG.
|
||||
|
||||
This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> PeerId:
|
||||
r"""
|
||||
Parses a `PeerId` from bytes.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Returns a raw bytes representation of this `PeerId`.
|
||||
"""
|
||||
def to_base58(self) -> builtins.str:
|
||||
r"""
|
||||
Returns a base-58 encoded string of this `PeerId`.
|
||||
"""
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdateType(enum.Enum):
|
||||
r"""
|
||||
Connection or disconnection event discriminant type.
|
||||
"""
|
||||
Connected = ...
|
||||
Disconnected = ...
|
||||
|
||||
|
||||
@@ -8,8 +8,7 @@ version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" },
|
||||
{ name = "Evan Quiney", email = "evanev7@gmail.com" }
|
||||
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" }
|
||||
]
|
||||
requires-python = ">=3.13"
|
||||
dependencies = []
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
//! SEE: <https://pyo3.rs/v0.27.1/async-await.html#detaching-from-the-interpreter-across-await>
|
||||
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
//!
|
||||
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pin_project::pin_project;
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
future::Future,
|
||||
@@ -8,36 +10,31 @@ use std::{
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
#[pin_project]
|
||||
#[repr(transparent)]
|
||||
pub struct AllowThreads<F>(F);
|
||||
pub(crate) struct AllowThreads<F>(#[pin] F);
|
||||
|
||||
impl<F> AllowThreads<F>
|
||||
where
|
||||
Self: Future,
|
||||
{
|
||||
pub(crate) const fn new(f: F) -> Self {
|
||||
pub fn new(f: F) -> Self {
|
||||
Self(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Future for AllowThreads<F>
|
||||
where
|
||||
F: Future + Unpin + Send,
|
||||
F::Output: Send,
|
||||
F: Future + Ungil,
|
||||
F::Output: Ungil,
|
||||
{
|
||||
type Output = Result<F::Output, PyErr>;
|
||||
type Output = F::Output;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let waker = cx.waker();
|
||||
match Python::try_attach(|py| {
|
||||
py.detach(|| pin!(&mut self.0).poll(&mut Context::from_waker(waker)))
|
||||
}) {
|
||||
Some(Poll::Pending) => Poll::Pending,
|
||||
Some(Poll::Ready(t)) => Poll::Ready(Ok(t)),
|
||||
// TODO: this doesn't actually work - graceful py shutdown handling
|
||||
None => Poll::Ready(Err(PyRuntimeError::new_err(
|
||||
"Python runtime shutdown while awaiting a future",
|
||||
))),
|
||||
}
|
||||
Python::with_gil(|py| {
|
||||
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use pyo3_stub_gen::Result;
|
||||
|
||||
fn main() -> Result<()> {
|
||||
env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init();
|
||||
let stub = exo_pyo3_bindings::stub_info()?;
|
||||
stub.generate()?;
|
||||
Ok(())
|
||||
|
||||
240
rust/exo_pyo3_bindings/src/examples/mod.rs
Normal file
240
rust/exo_pyo3_bindings/src/examples/mod.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
|
||||
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
|
||||
//!
|
||||
//! Pattern examples include:
|
||||
//! - Async task handles: with GC-integrated cleanup
|
||||
//! - Sync/async callbacks from python: with propper eventloop handling
|
||||
//!
|
||||
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
|
||||
//! - Store mutable fields in tokio's `Mutex<T>`
|
||||
//! - For async code: take `&self` and `.lock().await`
|
||||
//! - For sync code: take `&mut self` and `.get_mut()`
|
||||
|
||||
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::BoxFuture;
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::{
|
||||
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
|
||||
fn needs_tokio_runtime() {
|
||||
tokio::runtime::Handle::current();
|
||||
}
|
||||
|
||||
type SyncCallback = Box<dyn Fn() + Send + Sync>;
|
||||
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
|
||||
|
||||
enum AsyncTaskMessage {
|
||||
SyncCallback(SyncCallback),
|
||||
AsyncCallback(AsyncCallback),
|
||||
}
|
||||
|
||||
async fn async_task(
|
||||
sender: mpsc::UnboundedSender<()>,
|
||||
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
|
||||
) {
|
||||
log::info!("RUST: async task started");
|
||||
|
||||
// task state
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(1));
|
||||
|
||||
let mut sync_cbs: Vec<SyncCallback> = vec![];
|
||||
let mut async_cbs: Vec<AsyncCallback> = vec![];
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// handle incoming messages from task-handle
|
||||
message = receiver.recv() => {
|
||||
// handle closed channel by exiting
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming event
|
||||
match message {
|
||||
AsyncTaskMessage::SyncCallback(cb) => {
|
||||
sync_cbs.push(cb);
|
||||
}
|
||||
AsyncTaskMessage::AsyncCallback(cb) => {
|
||||
async_cbs.push(cb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle all other events
|
||||
_ = interval.tick() => {
|
||||
log::info!("RUST: async task tick");
|
||||
|
||||
// call back all sync callbacks
|
||||
for cb in &sync_cbs {
|
||||
cb();
|
||||
}
|
||||
|
||||
// call back all async callbacks
|
||||
for cb in &async_cbs {
|
||||
cb().await;
|
||||
}
|
||||
|
||||
// send event on unbounded channel
|
||||
sender.send(()).expect("handle receiver cannot be closed/dropped");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: async task stopped");
|
||||
}
|
||||
|
||||
// #[gen_stub_pyclass]
|
||||
#[pyclass(name = "AsyncTaskHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyAsyncTaskHandle {
|
||||
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyAsyncTaskHandle {
|
||||
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_mut()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn new(
|
||||
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sender: Some(sender),
|
||||
receiver,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyAsyncTaskHandle {
|
||||
#[new]
|
||||
fn py_new(py: Python<'_>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channel TOWARDS our task
|
||||
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
|
||||
|
||||
// create communication channel FROM our task
|
||||
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
|
||||
|
||||
// perform necessary setup within tokio context - or it crashes
|
||||
let () = get_runtime().block_on(async { needs_tokio_runtime() });
|
||||
|
||||
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
|
||||
_ = get_runtime().spawn_with_scope(py, async move {
|
||||
async_task(t_sender, t_receiver).await;
|
||||
});
|
||||
Ok(Self::new(h_sender, h_receiver))
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_sync_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], None]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
|
||||
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_async_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
|
||||
let c = Python::with_gil(|py| callback.clone_ref(py));
|
||||
async move {
|
||||
if let Some(f) = Python::with_gil(|py| {
|
||||
let coroutine = c.call0(py).write_unraisable_with(py)?;
|
||||
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
|
||||
.write_unraisable_with(py)
|
||||
}) {
|
||||
_ = f.await.write_unraisable();
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn receive_unit(&mut self) -> PyResult<()> {
|
||||
self.receiver
|
||||
.recv()
|
||||
.await
|
||||
.ok_or(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
))
|
||||
}
|
||||
|
||||
fn drain_units(&mut self) -> PyResult<i32> {
|
||||
let mut cnt = 0;
|
||||
loop {
|
||||
match self.receiver.try_recv() {
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
return Err(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
));
|
||||
}
|
||||
Err(TryRecvError::Empty) => return Ok(cnt),
|
||||
Ok(()) => {
|
||||
cnt += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyAsyncTaskHandle>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
use iroh::{EndpointId, SecretKey, endpoint_info::EndpointIdExt as _};
|
||||
use postcard::ser_flavors::StdVec;
|
||||
|
||||
use crate::ext::ResultExt as _;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
use rand::rng;
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Keypair", frozen)]
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PyKeypair(pub(crate) SecretKey);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ed25519() -> Self {
|
||||
Self(SecretKey::generate(&mut rng()))
|
||||
}
|
||||
/// Decode a postcard structure into a keypair
|
||||
#[staticmethod]
|
||||
fn from_postcard_encoding(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(postcard::from_bytes(&bytes).pyerr()?))
|
||||
}
|
||||
/// Encode a private key with the postcard format
|
||||
fn to_postcard_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = postcard::serialize_with_flavor(&self.0, StdVec::new()).pyerr()?;
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
/// Read out the endpoint id corresponding to this keypair
|
||||
fn endpoint_id(&self) -> PyEndpointId {
|
||||
PyEndpointId(self.0.public())
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "EndpointId", frozen)]
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PyEndpointId(pub(crate) EndpointId);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyEndpointId {
|
||||
pub fn __str__(&self) -> String {
|
||||
self.0.to_z32()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EndpointId> for PyEndpointId {
|
||||
fn from(value: EndpointId) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyKeypair>()?;
|
||||
m.add_class::<PyEndpointId>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -4,27 +4,65 @@
|
||||
//!
|
||||
//!
|
||||
|
||||
mod allow_threading;
|
||||
mod identity;
|
||||
mod networking;
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(tuple_trait)]
|
||||
#![feature(unboxed_closures)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
extern crate core;
|
||||
mod allow_threading;
|
||||
mod examples;
|
||||
pub(crate) mod networking;
|
||||
pub(crate) mod pylibp2p;
|
||||
|
||||
use crate::identity::ident_submodule;
|
||||
use crate::networking::networking_submodule;
|
||||
use crate::pylibp2p::ident::ident_submodule;
|
||||
use crate::pylibp2p::multiaddr::multiaddr_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
/// Namespace for all the constants used by this crate.
|
||||
pub(crate) mod r#const {
|
||||
pub const MPSC_CHANNEL_SIZE: usize = 1024;
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
use std::marker::Tuple;
|
||||
|
||||
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
|
||||
Fn<Args, Output = Output> + Send + 'static;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use crate::allow_threading::AllowThreads;
|
||||
use extend::ext;
|
||||
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Py, PyErr, PyResult, Python};
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[ext(pub, name = ByteArrayExt)]
|
||||
impl [u8] {
|
||||
fn pybytes(&self) -> Py<PyBytes> {
|
||||
Python::attach(|py| PyBytes::new(py, self).unbind())
|
||||
Python::with_gil(|py| PyBytes::new(py, self).unbind())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,9 +77,7 @@ pub(crate) mod ext {
|
||||
}
|
||||
|
||||
pub trait FutureExt: Future + Sized {
|
||||
/// SEE: <https://pyo3.rs/v0.27.1/async-await.html#detaching-from-the-interpreter-across-await>
|
||||
/// An [`AllowThreads`] returns a Future with an Err output if python has shutdown while we
|
||||
/// were awaiting something
|
||||
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
fn allow_threads_py(self) -> AllowThreads<Self>
|
||||
where
|
||||
AllowThreads<Self>: Future,
|
||||
@@ -62,7 +98,7 @@ pub(crate) mod ext {
|
||||
#[ext(pub, name = PyResultExt)]
|
||||
impl<T> PyResult<T> {
|
||||
fn write_unraisable(self) -> Option<T> {
|
||||
Python::attach(|py| self.write_unraisable_with(py))
|
||||
Python::with_gil(|py| self.write_unraisable_with(py))
|
||||
}
|
||||
|
||||
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
|
||||
@@ -76,6 +112,85 @@ pub(crate) mod ext {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioRuntimeExt)]
|
||||
impl Runtime {
|
||||
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
|
||||
where
|
||||
F: Future + Send + 'static,
|
||||
F::Output: Send + 'static,
|
||||
{
|
||||
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
|
||||
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioMpscSenderExt)]
|
||||
impl<T> mpsc::Sender<T> {
|
||||
/// Sends a value, waiting until there is capacity.
|
||||
///
|
||||
/// A successful send occurs when it is determined that the other end of the
|
||||
/// channel has not hung up already. An unsuccessful send would be one where
|
||||
/// the corresponding receiver has already been closed.
|
||||
async fn send_py(&self, value: T) -> PyResult<()> {
|
||||
self.send(value)
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = TokioMpscReceiverExt)]
|
||||
impl<T> mpsc::Receiver<T> {
|
||||
/// Receives the next value for this receiver.
|
||||
async fn recv_py(&mut self) -> PyResult<T> {
|
||||
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
|
||||
}
|
||||
|
||||
/// Receives at most `limit` values for this receiver and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
|
||||
// get updates from receiver channel
|
||||
let mut updates = Vec::with_capacity(limit);
|
||||
let received = self.recv_many(&mut updates, limit).await;
|
||||
|
||||
// if we received zero items, then the channel was unexpectedly closed
|
||||
if limit != 0 && received == 0 {
|
||||
return Err(PyErr::receiver_channel_closed());
|
||||
}
|
||||
|
||||
Ok(updates)
|
||||
}
|
||||
|
||||
/// Tries to receive the next value for this receiver.
|
||||
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
|
||||
match self.try_recv() {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(TryRecvError::Empty) => Ok(None),
|
||||
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
use std::marker::Sized;
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct ClonePy<T>(pub Py<T>);
|
||||
|
||||
impl<T> Clone for ClonePy<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Python::with_gil(|py| Self(self.0.clone_ref(py)))
|
||||
}
|
||||
}
|
||||
|
||||
/// A Python module implemented in Rust. The name of this function must match
|
||||
@@ -84,18 +199,18 @@ pub(crate) mod ext {
|
||||
#[pymodule(name = "exo_pyo3_bindings")]
|
||||
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// install logger
|
||||
/*
|
||||
use log::LevelFilter;
|
||||
#[allow(clippy::expect_used)]
|
||||
pyo3_log::Logger::default()
|
||||
.filter(LevelFilter::Warn)
|
||||
.install()
|
||||
.expect("logger install");
|
||||
*/
|
||||
pyo3_log::init();
|
||||
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
// too many importing issues...
|
||||
ident_submodule(m)?;
|
||||
multiaddr_submodule(m)?;
|
||||
networking_submodule(m)?;
|
||||
|
||||
// top-level constructs
|
||||
// TODO: ...
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,194 +1,570 @@
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt as _, ResultExt as _};
|
||||
use crate::identity::{PyEndpointId, PyKeypair};
|
||||
use iroh::SecretKey;
|
||||
use iroh::discovery::EndpointInfo;
|
||||
use iroh::discovery::mdns::DiscoveryEvent;
|
||||
use iroh_gossip::api::{ApiError, Event, GossipReceiver, GossipSender, Message};
|
||||
use n0_future::{Stream, StreamExt as _};
|
||||
use networking::ExoNet;
|
||||
use pyo3::exceptions::{PyRuntimeError, PyStopAsyncIteration};
|
||||
use pyo3::prelude::*;
|
||||
#![allow(
|
||||
clippy::multiple_inherent_impl,
|
||||
clippy::unnecessary_wraps,
|
||||
clippy::unused_self,
|
||||
clippy::needless_pass_by_value
|
||||
)]
|
||||
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
||||
use crate::pyclass;
|
||||
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
use libp2p::swarm::SwarmEvent;
|
||||
use libp2p::{gossipsub, mdns};
|
||||
use networking::discovery;
|
||||
use networking::swarm::create_swarm;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
use std::collections::BTreeSet;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::{Pin, pin};
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::sync::Mutex;
|
||||
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
|
||||
use std::net::IpAddr;
|
||||
use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
use util::ext::VecExt as _;
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
static RUNTIME: LazyLock<Runtime> =
|
||||
LazyLock::new(|| Runtime::new().expect("Failed to create tokio runtime"));
|
||||
mod exception {
|
||||
use pyo3::types::PyTuple;
|
||||
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
|
||||
use pyo3_stub_gen::derive::*;
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "IpAddress")]
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct PyIpAddress {
|
||||
inner: SocketAddr,
|
||||
}
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
|
||||
pub struct PyNoPeersSubscribedToTopicError {}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyIpAddress {
|
||||
pub fn __str__(&self) -> String {
|
||||
self.inner.to_string()
|
||||
impl PyNoPeersSubscribedToTopicError {
|
||||
const MSG: &'static str = "\
|
||||
No peers are currently subscribed to receive messages on this topic. \
|
||||
Wait for peers to subscribe or check your network connectivity.";
|
||||
|
||||
/// Creates a new [ `PyErr` ] of this type.
|
||||
///
|
||||
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
|
||||
pub(crate) fn new_err() -> PyErr {
|
||||
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ip_addr(&self) -> String {
|
||||
self.inner.ip().to_string()
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyNoPeersSubscribedToTopicError {
|
||||
#[new]
|
||||
#[pyo3(signature = (*args))]
|
||||
#[allow(unused_variables)]
|
||||
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId(\"{}\")", Self::MSG)
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
Self::MSG.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn port(&self) -> u16 {
|
||||
self.inner.port()
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, extends=PyException, name="AllQueuesFullError")]
|
||||
pub struct PyAllQueuesFullError {}
|
||||
|
||||
impl PyAllQueuesFullError {
|
||||
const MSG: &'static str =
|
||||
"All libp2p peers are unresponsive, resend the message or reconnect.";
|
||||
|
||||
/// Creates a new [ `PyErr` ] of this type.
|
||||
///
|
||||
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
|
||||
pub(crate) fn new_err() -> PyErr {
|
||||
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn zone_id(&self) -> Option<u32> {
|
||||
match self.inner {
|
||||
SocketAddr::V6(ip) => Some(ip.scope_id()),
|
||||
SocketAddr::V4(_) => None,
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyAllQueuesFullError {
|
||||
#[new]
|
||||
#[pyo3(signature = (*args))]
|
||||
#[allow(unused_variables)]
|
||||
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId(\"{}\")", Self::MSG)
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
Self::MSG.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Connection or disconnection event discriminant type.
|
||||
#[gen_stub_pyclass_enum]
|
||||
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum PyConnectionUpdateType {
|
||||
Connected = 0,
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "RustNetworkingHandle")]
|
||||
pub struct PyNetworkingHandle {
|
||||
net: Arc<ExoNet>,
|
||||
#[pyclass(frozen, name = "ConnectionUpdate")]
|
||||
#[derive(Debug, Clone)]
|
||||
struct PyConnectionUpdate {
|
||||
/// Whether this is a connection or disconnection event
|
||||
#[pyo3(get)]
|
||||
update_type: PyConnectionUpdateType,
|
||||
|
||||
/// Identity of the peer that we have connected to or disconnected from.
|
||||
#[pyo3(get)]
|
||||
peer_id: PyPeerId,
|
||||
|
||||
/// Remote connection's IPv4 address.
|
||||
#[pyo3(get)]
|
||||
remote_ipv4: String,
|
||||
|
||||
/// Remote connection's TCP port.
|
||||
#[pyo3(get)]
|
||||
remote_tcp_port: u16,
|
||||
}
|
||||
|
||||
enum ToTask {
|
||||
GossipsubSubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<PyResult<bool>>,
|
||||
},
|
||||
GossipsubUnsubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<bool>,
|
||||
},
|
||||
GossipsubPublish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
result_tx: oneshot::Sender<PyResult<MessageId>>,
|
||||
},
|
||||
}
|
||||
|
||||
#[allow(clippy::enum_glob_use)]
|
||||
async fn networking_task(
|
||||
mut swarm: networking::swarm::Swarm,
|
||||
mut to_task_rx: mpsc::Receiver<ToTask>,
|
||||
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
|
||||
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
|
||||
) {
|
||||
use SwarmEvent::*;
|
||||
use ToTask::*;
|
||||
use mdns::Event::*;
|
||||
use networking::swarm::BehaviourEvent::*;
|
||||
|
||||
log::info!("RUST: networking task started");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
message = to_task_rx.recv() => {
|
||||
// handle closed channel
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming messages
|
||||
match message {
|
||||
GossipsubSubscribe { topic, result_tx } => {
|
||||
// try to subscribe
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.subscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot
|
||||
if let Err(e) = result_tx.send(result.pyerr()) {
|
||||
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubUnsubscribe { topic, result_tx } => {
|
||||
// try to unsubscribe from the topic
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.unsubscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(result) {
|
||||
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubPublish { topic, data, result_tx } => {
|
||||
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
|
||||
let result = swarm.behaviour_mut().gossipsub.publish(
|
||||
IdentTopic::new(topic), data);
|
||||
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
|
||||
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
|
||||
} else if let Err(PublishError::AllQueuesFull(_)) = result {
|
||||
Err(exception::PyAllQueuesFullError::new_err())
|
||||
} else {
|
||||
result.pyerr()
|
||||
};
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(pyresult) {
|
||||
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// architectural solution to this problem:
|
||||
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
|
||||
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
|
||||
//
|
||||
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
|
||||
// then for actual communication it will dial those peers if need-be
|
||||
swarm_event = swarm.select_next_some() => {
|
||||
match swarm_event {
|
||||
Behaviour(Gossipsub(gossipsub::Event::Message {
|
||||
message: Message {
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => {
|
||||
// topic-ID is just the topic hash!!! (since we used identity hasher)
|
||||
let message = (topic.into_string(), data);
|
||||
|
||||
// send incoming message to channel (or exit if connection closed)
|
||||
if let Err(e) = gossipsub_message_tx.send(message).await {
|
||||
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send connection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Connected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send disconnection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Disconnected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
e => {
|
||||
log::info!("RUST: other event {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: networking task stopped");
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "NetworkingHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyNetworkingHandle {
|
||||
// channels
|
||||
to_task_tx: Option<mpsc::Sender<ToTask>>,
|
||||
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
|
||||
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
|
||||
}
|
||||
|
||||
impl Drop for PyNetworkingHandle {
|
||||
fn drop(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyNetworkingHandle {
|
||||
fn new(
|
||||
to_task_tx: mpsc::Sender<ToTask>,
|
||||
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
|
||||
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
to_task_tx: Some(to_task_tx),
|
||||
connection_update_rx: Mutex::new(connection_update_rx),
|
||||
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
|
||||
}
|
||||
}
|
||||
|
||||
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
|
||||
self.to_task_tx
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyNetworkingHandle {
|
||||
#[staticmethod]
|
||||
pub async fn create(identity: PyKeypair, namespace: String) -> PyResult<Self> {
|
||||
let loc: SecretKey = identity.0.clone();
|
||||
let net = Arc::new(
|
||||
RUNTIME
|
||||
.spawn(async move { ExoNet::init_iroh(loc, &namespace).await })
|
||||
.await
|
||||
// todo: pyerr better
|
||||
.pyerr()?
|
||||
.pyerr()?,
|
||||
);
|
||||
let cloned = Arc::clone(&net);
|
||||
RUNTIME.spawn(async move { cloned.start_auto_dialer().await });
|
||||
// NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()`
|
||||
// immediately beforehand to release the interpreter.
|
||||
// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
|
||||
|
||||
Ok(Self { net })
|
||||
// ---- Lifecycle management methods ----
|
||||
|
||||
#[new]
|
||||
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channels
|
||||
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
|
||||
// get identity
|
||||
let identity = identity.borrow().0.clone();
|
||||
|
||||
// create networking swarm (within tokio context!! or it crashes)
|
||||
let swarm = get_runtime()
|
||||
.block_on(async { create_swarm(identity) })
|
||||
.pyerr()?;
|
||||
|
||||
// spawn tokio task running the networking logic
|
||||
get_runtime().spawn(async move {
|
||||
networking_task(
|
||||
swarm,
|
||||
to_task_rx,
|
||||
connection_update_tx,
|
||||
gossipsub_message_tx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
Ok(Self::new(
|
||||
to_task_tx,
|
||||
connection_update_rx,
|
||||
gossipsub_message_rx,
|
||||
))
|
||||
}
|
||||
|
||||
async fn subscribe(&self, topic: String) -> PyResult<(PySender, PyReceiver)> {
|
||||
let fut = self.net.subscribe(&topic);
|
||||
let (send, recv) = fut.await.pyerr()?;
|
||||
Ok((PySender { inner: send }, PyReceiver { inner: recv }))
|
||||
#[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
}
|
||||
|
||||
async fn get_connection_receiver(&self) -> PyResult<PyConnectionReceiver> {
|
||||
let stream = self.net.connection_info().await;
|
||||
Ok(PyConnectionReceiver {
|
||||
inner: Mutex::new(Box::pin(stream)),
|
||||
})
|
||||
#[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "RustConnectionMessage")]
|
||||
pub struct PyConnectionMessage {
|
||||
#[pyo3(get)]
|
||||
pub endpoint_id: PyEndpointId,
|
||||
#[pyo3(get)]
|
||||
pub current_transport_addrs: Option<BTreeSet<PyIpAddress>>,
|
||||
}
|
||||
// ---- Connection update receiver methods ----
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "RustSender")]
|
||||
struct PySender {
|
||||
inner: GossipSender,
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PySender {
|
||||
async fn send(&mut self, message: Py<PyBytes>) -> PyResult<()> {
|
||||
let bytes = Python::attach(|py| message.as_bytes(py).to_vec());
|
||||
let broadcast_fut = self.inner.broadcast(bytes.into());
|
||||
pin!(broadcast_fut).allow_threads_py().await?.pyerr()
|
||||
/// Receives the next `ConnectionUpdate` from networking.
|
||||
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "RustReceiver")]
|
||||
struct PyReceiver {
|
||||
inner: GossipReceiver,
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyReceiver {
|
||||
async fn receive(&mut self) -> PyResult<Py<PyBytes>> {
|
||||
loop {
|
||||
let next_fut = self.inner.next();
|
||||
match pin!(next_fut).allow_threads_py().await? {
|
||||
// Successful cases
|
||||
Some(Ok(Event::Received(Message { content, .. }))) => {
|
||||
return Ok(content.to_vec().pybytes());
|
||||
}
|
||||
Some(Ok(other)) => log::info!("Dropping gossip event {other:?}"),
|
||||
None => return Err(PyStopAsyncIteration::new_err("")),
|
||||
Some(Err(ApiError::Closed { .. })) => {
|
||||
return Err(PyStopAsyncIteration::new_err(""));
|
||||
}
|
||||
|
||||
// Failure case
|
||||
Some(Err(other)) => {
|
||||
return Err(PyRuntimeError::new_err(other.to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
/// will sleep until a `ConnectionUpdate`s is sent.
|
||||
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "RustConnectionReceiver")]
|
||||
struct PyConnectionReceiver {
|
||||
inner: Mutex<Pin<Box<dyn Stream<Item = DiscoveryEvent> + Send>>>,
|
||||
}
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next `ConnectionUpdate` from networking.
|
||||
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
|
||||
// self.connection_update_rx.blocking_lock().try_recv_py()
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `ConnectionUpdate` channel is empty.
|
||||
// fn connection_update_is_empty(&self) -> bool {
|
||||
// self.connection_update_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `ConnectionUpdate`s in the channel.
|
||||
// fn connection_update_len(&self) -> usize {
|
||||
// self.connection_update_rx.blocking_lock().len()
|
||||
// }
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyConnectionReceiver {
|
||||
async fn receive(&mut self) -> PyResult<PyConnectionMessage> {
|
||||
// Errors on trying to receive twice - which is a dev error. This could just block the
|
||||
// async task, but I want the error to persist
|
||||
let mut lock = self.inner.try_lock().pyerr()?;
|
||||
match lock.next().allow_threads_py().await? {
|
||||
// Successful cases
|
||||
Some(DiscoveryEvent::Discovered {
|
||||
endpoint_info: EndpointInfo { endpoint_id, data },
|
||||
..
|
||||
}) => Ok(PyConnectionMessage {
|
||||
endpoint_id: endpoint_id.into(),
|
||||
current_transport_addrs: Some(
|
||||
data.ip_addrs()
|
||||
.map(|inner| PyIpAddress { inner: *inner })
|
||||
.collect(),
|
||||
),
|
||||
}),
|
||||
Some(DiscoveryEvent::Expired { endpoint_id }) => Ok(PyConnectionMessage {
|
||||
endpoint_id: endpoint_id.into(),
|
||||
current_transport_addrs: None,
|
||||
}),
|
||||
// Failure case
|
||||
None => Err(PyStopAsyncIteration::new_err("")),
|
||||
}
|
||||
// ---- Gossipsub management methods ----
|
||||
|
||||
/// Subscribe to a `GossipSub` topic.
|
||||
///
|
||||
/// Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
|
||||
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubSubscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
|
||||
// wait for response & return any errors
|
||||
rx.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
}
|
||||
|
||||
/// Unsubscribes from a `GossipSub` topic.
|
||||
///
|
||||
/// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
|
||||
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to unsubscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubUnsubscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
|
||||
// wait for response & convert any errors
|
||||
rx.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())
|
||||
}
|
||||
|
||||
/// Publishes a message with multiple topics to the `GossipSub` network.
|
||||
///
|
||||
/// If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubPublish {
|
||||
topic,
|
||||
data,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
|
||||
// wait for response & return any errors => ignore messageID for now!!!
|
||||
let _ = rx
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())??;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---- Gossipsub message receiver methods ----
|
||||
|
||||
/// Receives the next message from the `GossipSub` network.
|
||||
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
|
||||
self.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
}
|
||||
|
||||
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
|
||||
Ok(self
|
||||
.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?
|
||||
.map(|(t, d)| (t, d.pybytes())))
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next message from the `GossipSub` network.
|
||||
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
|
||||
// Ok(self
|
||||
// .gossipsub_message_rx
|
||||
// .blocking_lock()
|
||||
// .try_recv_py()?
|
||||
// .map(|(t, d)| (t, d.pybytes())))
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `GossipSub` message channel is empty.
|
||||
// fn gossipsub_is_empty(&self) -> bool {
|
||||
// self.gossipsub_message_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `GossipSub` messages in the channel.
|
||||
// fn gossipsub_len(&self) -> usize {
|
||||
// self.gossipsub_message_rx.blocking_lock().len()
|
||||
// }
|
||||
}
|
||||
|
||||
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyConnectionMessage>()?;
|
||||
m.add_class::<PyReceiver>()?;
|
||||
m.add_class::<PySender>()?;
|
||||
m.add_class::<PyConnectionReceiver>()?;
|
||||
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
|
||||
m.add_class::<exception::PyAllQueuesFullError>()?;
|
||||
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyConnectionUpdate>()?;
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyNetworkingHandle>()?;
|
||||
|
||||
Ok(())
|
||||
|
||||
159
rust/exo_pyo3_bindings/src/pylibp2p/ident.rs
Normal file
159
rust/exo_pyo3_bindings/src/pylibp2p/ident.rs
Normal file
@@ -0,0 +1,159 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::PeerId;
|
||||
use libp2p::identity::Keypair;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
|
||||
/// Identity keypair of a node.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Keypair", frozen)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyKeypair(pub Keypair);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ed25519() -> Self {
|
||||
Self(Keypair::generate_ed25519())
|
||||
}
|
||||
|
||||
/// Generate a new ECDSA keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ecdsa() -> Self {
|
||||
Self(Keypair::generate_ecdsa())
|
||||
}
|
||||
|
||||
/// Generate a new Secp256k1 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_secp256k1() -> Self {
|
||||
Self(Keypair::generate_secp256k1())
|
||||
}
|
||||
|
||||
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
#[staticmethod]
|
||||
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
/// format (i.e. unencrypted) as defined in [RFC5208].
|
||||
///
|
||||
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
#[staticmethod]
|
||||
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
/// structure as defined in [RFC5915].
|
||||
///
|
||||
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
#[staticmethod]
|
||||
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Encode a private key as protobuf structure.
|
||||
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self.0.to_protobuf_encoding().pyerr()?;
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
|
||||
/// Convert the `Keypair` into the corresponding `PeerId`.
|
||||
fn to_peer_id(&self) -> PyPeerId {
|
||||
PyPeerId(self.0.public().to_peer_id())
|
||||
}
|
||||
|
||||
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
|
||||
// #[gen_stub(skip)]
|
||||
// #[new]
|
||||
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
// Self::from_protobuf_encoding(bytes)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
|
||||
// *self = Self::from_protobuf_encoding(state)?;
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
// self.to_protobuf_encoding(py)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
|
||||
// Ok((self.to_protobuf_encoding(py)?,))
|
||||
// }
|
||||
}
|
||||
|
||||
/// Identifier of a peer of the network.
|
||||
///
|
||||
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "PeerId", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyPeerId(pub PeerId);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyPeerId {
|
||||
/// Generates a random peer ID from a cryptographically secure PRNG.
|
||||
///
|
||||
/// This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
#[staticmethod]
|
||||
fn random() -> Self {
|
||||
Self(PeerId::random())
|
||||
}
|
||||
|
||||
/// Parses a `PeerId` from bytes.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Returns a raw bytes representation of this `PeerId`.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_bytes();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Returns a base-58 encoded string of this `PeerId`.
|
||||
fn to_base58(&self) -> String {
|
||||
self.0.to_base58()
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId({})", self.to_base58())
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
self.to_base58()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyKeypair>()?;
|
||||
m.add_class::<PyPeerId>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
8
rust/exo_pyo3_bindings/src/pylibp2p/mod.rs
Normal file
8
rust/exo_pyo3_bindings/src/pylibp2p/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
//! A module for exposing Rust's libp2p datatypes over Pyo3
|
||||
//!
|
||||
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
|
||||
//! independent identity type of some kind or another. This may require handshaking.
|
||||
//!
|
||||
|
||||
pub mod ident;
|
||||
pub mod multiaddr;
|
||||
81
rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs
Normal file
81
rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::Multiaddr;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
use std::str::FromStr as _;
|
||||
|
||||
/// Representation of a Multiaddr.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Multiaddr", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyMultiaddr(pub Multiaddr);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyMultiaddr {
|
||||
/// Create a new, empty multiaddress.
|
||||
#[staticmethod]
|
||||
fn empty() -> Self {
|
||||
Self(Multiaddr::empty())
|
||||
}
|
||||
|
||||
/// Create a new, empty multiaddress with the given capacity.
|
||||
#[staticmethod]
|
||||
fn with_capacity(n: usize) -> Self {
|
||||
Self(Multiaddr::with_capacity(n))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its byte slice representation.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its string representation.
|
||||
#[staticmethod]
|
||||
fn from_string(string: String) -> PyResult<Self> {
|
||||
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
|
||||
}
|
||||
|
||||
/// Return the length in bytes of this multiaddress.
|
||||
fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
/// Returns true if the length of this multiaddress is 0.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
/// Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_vec();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Convert a Multiaddr to a string.
|
||||
fn to_string(&self) -> String {
|
||||
self.0.to_string()
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __repr__(&self) -> String {
|
||||
format!("Multiaddr({})", self.0)
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __str__(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyMultiaddr>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
54
rust/exo_pyo3_bindings/tests/dummy.rs
Normal file
54
rust/exo_pyo3_bindings/tests/dummy.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use core::mem::drop;
|
||||
use core::option::Option::Some;
|
||||
use core::time::Duration;
|
||||
use tokio;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_drop_channel() {
|
||||
struct Ping;
|
||||
|
||||
let (tx, mut rx) = mpsc::channel::<Ping>(10);
|
||||
|
||||
let _ = tokio::spawn(async move {
|
||||
println!("TASK: entered");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = rx.recv() => {
|
||||
match result {
|
||||
Some(_) => {
|
||||
println!("TASK: pinged");
|
||||
}
|
||||
None => {
|
||||
println!("TASK: closing channel");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => {
|
||||
println!("TASK: heartbeat");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("TASK: exited");
|
||||
});
|
||||
|
||||
let tx2 = tx.clone();
|
||||
|
||||
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
|
||||
|
||||
tx.send(Ping).await.expect("Should not fail");
|
||||
drop(tx);
|
||||
|
||||
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
|
||||
|
||||
tx2.send(Ping).await.expect("Should not fail");
|
||||
drop(tx2);
|
||||
|
||||
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
|
||||
}
|
||||
}
|
||||
@@ -1,47 +1,34 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from exo_pyo3_bindings import (
|
||||
Keypair,
|
||||
RustNetworkingHandle,
|
||||
RustReceiver,
|
||||
RustConnectionReceiver,
|
||||
)
|
||||
from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sleep_on_multiple_items() -> None:
|
||||
print("PYTHON: starting handle")
|
||||
s_h = await RustNetworkingHandle.create(Keypair.generate_ed25519(), "test")
|
||||
r_h = await RustNetworkingHandle.create(Keypair.generate_ed25519(), "test")
|
||||
h = NetworkingHandle(Keypair.generate_ed25519())
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
cm = await r_h.get_connection_receiver()
|
||||
|
||||
_, recv = await r_h.subscribe("topic")
|
||||
send, _ = await s_h.subscribe("topic")
|
||||
|
||||
ct = asyncio.create_task(_await_cons(cm))
|
||||
mt = asyncio.create_task(_await_msg(recv))
|
||||
ct = asyncio.create_task(_await_cons(h))
|
||||
mt = asyncio.create_task(_await_msg(h))
|
||||
|
||||
# sleep for 4 ticks
|
||||
for i in range(4):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await send.send(b"somehting or other")
|
||||
|
||||
await ct
|
||||
await mt
|
||||
try:
|
||||
await h.gossipsub_publish("topic", b"somehting or other")
|
||||
except NoPeersSubscribedToTopicError as e:
|
||||
print("caught it", e)
|
||||
|
||||
|
||||
async def _await_cons(h: RustConnectionReceiver):
|
||||
async def _await_cons(h: NetworkingHandle):
|
||||
while True:
|
||||
c = await h.receive()
|
||||
c = await h.connection_update_recv()
|
||||
print(f"PYTHON: connection update: {c}")
|
||||
|
||||
|
||||
async def _await_msg(r: RustReceiver):
|
||||
async def _await_msg(h: NetworkingHandle):
|
||||
while True:
|
||||
m = await r.receive()
|
||||
m = await h.gossipsub_recv()
|
||||
print(f"PYTHON: message: {m}")
|
||||
|
||||
@@ -1,18 +1,44 @@
|
||||
[package]
|
||||
name = "networking"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[dependencies]
|
||||
blake3 = { workspace = true, features = ["neon", "rayon"] }
|
||||
iroh = { workspace = true, features = ["discovery-local-network"] }
|
||||
iroh-gossip.workspace = true
|
||||
log.workspace = true
|
||||
n0-error.workspace = true
|
||||
n0-future.workspace = true
|
||||
rand.workspace = true
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter"] }
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "networking"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# datastructures
|
||||
either = { workspace = true }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
# tracing/logging
|
||||
log = { workspace = true }
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
@@ -1,85 +1,74 @@
|
||||
#![allow(clippy::expect_used, clippy::unwrap_used, clippy::cargo)]
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use iroh::SecretKey;
|
||||
use iroh_gossip::api::{Event, Message};
|
||||
use n0_future::StreamExt as _;
|
||||
use networking::ExoNet;
|
||||
use tokio::time::sleep;
|
||||
use tokio::{io, io::AsyncBufReadExt as _};
|
||||
use futures::stream::StreamExt as _;
|
||||
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
|
||||
use networking::{discovery, swarm};
|
||||
use tokio::{io, io::AsyncBufReadExt as _, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::filter::LevelFilter;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt()
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
|
||||
.try_init()
|
||||
.expect("logger");
|
||||
.try_init();
|
||||
|
||||
// Configure swarm
|
||||
let net = Arc::new(
|
||||
ExoNet::init_iroh(SecretKey::generate(&mut rand::rng()), "chatroom")
|
||||
.await
|
||||
.expect("iroh init shouldn't fail"),
|
||||
);
|
||||
let innet = Arc::clone(&net);
|
||||
let jh1 = tokio::spawn(async move { innet.start_auto_dialer().await });
|
||||
|
||||
while net.known_peers.lock().await.is_empty() {
|
||||
sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
let mut swarm =
|
||||
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
|
||||
|
||||
// Create a Gossipsub topic & subscribe
|
||||
let (send, mut recv) = net
|
||||
.subscribe("chatting")
|
||||
.await
|
||||
.expect("topic shouldn't fail");
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&topic)
|
||||
.expect("Subscribing to topic failed");
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
let jh2 = tokio::spawn(async move {
|
||||
loop {
|
||||
if let Ok(Some(line)) = stdin.next_line().await
|
||||
&& let Err(e) = send.broadcast(line.into()).await
|
||||
{
|
||||
println!("Publish error: {e:?}");
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
// on gossipsub outgoing
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(Ok(event)) = recv.next().await {
|
||||
match event {
|
||||
event = swarm.select_next_some() => match event {
|
||||
// on gossipsub incoming
|
||||
Event::Received(Message {
|
||||
content,
|
||||
delivered_from,
|
||||
..
|
||||
}) => println!(
|
||||
"\n\nGot message: '{}' with from peer: {delivered_from}\n\n",
|
||||
String::from_utf8_lossy(&content),
|
||||
),
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
|
||||
// on discovery
|
||||
Event::NeighborUp(peer_id) => {
|
||||
println!("\n\nConnected to: {peer_id}\n\n");
|
||||
}
|
||||
Event::NeighborDown(peer_id) => {
|
||||
eprintln!("\n\nDisconnected from: {peer_id}\n\n");
|
||||
}
|
||||
Event::Lagged => {
|
||||
eprintln!("\n\nLagged\n\n");
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
|
||||
discovery::Event::ConnectionEstablished {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
discovery::Event::ConnectionClosed {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
}
|
||||
|
||||
// ignore outgoing errors: those are normal
|
||||
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
|
||||
|
||||
// otherwise log any other event
|
||||
e => { log::info!("Other event {e:?}"); }
|
||||
}
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
jh1.await.unwrap();
|
||||
jh2.await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
127
rust/networking/examples/chatroom_manual.rs
Normal file
127
rust/networking/examples/chatroom_manual.rs
Normal file
@@ -0,0 +1,127 @@
|
||||
// Copyright 2018 Parity Technologies (UK) Ltd.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a
|
||||
// copy of this software and associated documentation files (the "Software"),
|
||||
// to deal in the Software without restriction, including without limitation
|
||||
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
// and/or sell copies of the Software, and to permit persons to whom the
|
||||
// Software is furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use futures::stream::StreamExt;
|
||||
use libp2p::{
|
||||
gossipsub, mdns, noise,
|
||||
swarm::{NetworkBehaviour, SwarmEvent},
|
||||
tcp, yamux,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use std::{error::Error, hash::Hash};
|
||||
use tokio::{io, io::AsyncBufReadExt, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
// We create a custom network behaviour that combines Gossipsub and Mdns.
|
||||
#[derive(NetworkBehaviour)]
|
||||
struct MyBehaviour {
|
||||
gossipsub: gossipsub::Behaviour,
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.try_init();
|
||||
|
||||
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
|
||||
.with_tokio()
|
||||
.with_tcp(
|
||||
tcp::Config::default(),
|
||||
noise::Config::new,
|
||||
yamux::Config::default,
|
||||
)?
|
||||
.with_behaviour(|key| {
|
||||
// Set a custom gossipsub configuration
|
||||
let gossipsub_config = gossipsub::ConfigBuilder::default()
|
||||
.heartbeat_interval(Duration::from_secs(10))
|
||||
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
|
||||
.build()
|
||||
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
|
||||
|
||||
// build a gossipsub network behaviour
|
||||
let gossipsub = gossipsub::Behaviour::new(
|
||||
gossipsub::MessageAuthenticity::Signed(key.clone()),
|
||||
gossipsub_config,
|
||||
)?;
|
||||
|
||||
let mdns =
|
||||
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
|
||||
Ok(MyBehaviour { gossipsub, mdns })
|
||||
})?
|
||||
.build();
|
||||
|
||||
println!("Running swarm with identity {}", swarm.local_peer_id());
|
||||
|
||||
// Create a Gossipsub topic
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
// subscribes to our topic
|
||||
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
event = swarm.select_next_some() => match event {
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
"Got message: '{}' with id: {id} from peer: {peer_id}",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
SwarmEvent::NewListenAddr { address, .. } => {
|
||||
println!("Local node is listening on {address}");
|
||||
}
|
||||
e => {
|
||||
println!("Other swarm event: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
#![allow(clippy::cargo, clippy::unwrap_used)]
|
||||
use iroh::{SecretKey, endpoint_info::EndpointIdExt as _};
|
||||
use n0_future::StreamExt as _;
|
||||
use networking::ExoNet;
|
||||
|
||||
// Launch a mock version of iroh for testing purposes
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.init();
|
||||
|
||||
let key = SecretKey::generate(&mut rand::rng());
|
||||
let dbg_key = key.public().to_z32();
|
||||
println!("Starting with pk: {dbg_key}");
|
||||
let net = ExoNet::init_iroh(key, "").await.unwrap();
|
||||
|
||||
let mut conn_info = net.connection_info().await;
|
||||
|
||||
let task = tokio::task::spawn(async move {
|
||||
println!("Inner task started!");
|
||||
loop {
|
||||
dbg!(conn_info.next().await);
|
||||
}
|
||||
});
|
||||
|
||||
println!("Task started!");
|
||||
|
||||
task.await.unwrap();
|
||||
}
|
||||
44
rust/networking/src/RESEARCH_NOTES.txt
Normal file
44
rust/networking/src/RESEARCH_NOTES.txt
Normal file
@@ -0,0 +1,44 @@
|
||||
https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609
|
||||
^raw sockets distributed -> `<net/ndrv.h>` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html
|
||||
--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets.
|
||||
--> Used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received.
|
||||
--> An NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like the ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools.
|
||||
--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html
|
||||
----> they are conceptually similar to https://scapy.disruptivelabs.in/networking/socket-interface PF_RAW or PF_PACKET
|
||||
|
||||
https://stackoverflow.com/questions/17169298/af-packet-on-osx
|
||||
^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only
|
||||
^it doesn't exist on OS X so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing
|
||||
|
||||
https://www.unix.com/man_page/mojave/4/ip/
|
||||
^OS X manpages for IP
|
||||
|
||||
https://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts
|
||||
^driver kit, system extensions & kexts for macOS
|
||||
|
||||
----
|
||||
|
||||
To set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic but modprobe is an option for manual loading), and then the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network.
|
||||
--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a
|
||||
----> here is a guide on how to set up thunderbolt-ethernet on linux
|
||||
----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for MacOS
|
||||
|
||||
https://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7
|
||||
^GPT discussion about making socket interface
|
||||
|
||||
https://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address,,??
|
||||
https://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2
|
||||
^GPT discussion about accessing TB on MacOS low level interactions
|
||||
|
||||
--------------------------------
|
||||
|
||||
https://www.intel.com/content/www/us/en/support/articles/000098893/software.html
|
||||
^Thunderbolt Share & Thunderbolt Networking Mode => intel's equivalent of thunderbolt bridge
|
||||
|
||||
|
||||
---------------------------------
|
||||
|
||||
https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/
|
||||
-->fake ethernet devices on MacOS -> omg??? we can detect thunderbolt bridge, then bind to it, then re-expose it as fake ethernet??
|
||||
-->ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!!
|
||||
-->https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c
|
||||
383
rust/networking/src/discovery.rs
Normal file
383
rust/networking/src/discovery.rs
Normal file
@@ -0,0 +1,383 @@
|
||||
use crate::ext::MultiaddrExt;
|
||||
use crate::keep_alive;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures::FutureExt;
|
||||
use futures_timer::Delay;
|
||||
use libp2p::core::transport::PortUse;
|
||||
use libp2p::core::{ConnectedPoint, Endpoint};
|
||||
use libp2p::swarm::behaviour::ConnectionEstablished;
|
||||
use libp2p::swarm::dial_opts::DialOpts;
|
||||
use libp2p::swarm::{
|
||||
CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
|
||||
ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
|
||||
THandlerOutEvent, ToSwarm, dummy,
|
||||
};
|
||||
use libp2p::{Multiaddr, PeerId, identity, mdns};
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
use std::convert::Infallible;
|
||||
use std::io;
|
||||
use std::net::IpAddr;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use util::wakerdeque::WakerDeque;
|
||||
|
||||
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
|
||||
|
||||
mod managed {
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
use libp2p::{identity, mdns, ping};
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
|
||||
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
|
||||
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
|
||||
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
|
||||
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
|
||||
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
ping: ping::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
mdns: mdns_behaviour(keypair)?,
|
||||
ping: ping_behaviour(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
|
||||
use mdns::{Config, tokio};
|
||||
|
||||
// mDNS config => enable IPv6
|
||||
let mdns_config = Config {
|
||||
ttl: MDNS_RECORD_TTL,
|
||||
query_interval: MDNS_QUERY_INTERVAL,
|
||||
|
||||
// enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
|
||||
Ok(mdns_behaviour?)
|
||||
}
|
||||
|
||||
fn ping_behaviour() -> ping::Behaviour {
|
||||
ping::Behaviour::new(
|
||||
ping::Config::new()
|
||||
.with_timeout(PING_TIMEOUT)
|
||||
.with_interval(PING_INTERVAL),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Events for when a listening connection is truly established and truly closed.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Event {
|
||||
ConnectionEstablished {
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
},
|
||||
ConnectionClosed {
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
},
|
||||
}
|
||||
|
||||
/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.
|
||||
///
|
||||
/// The behaviour operates as such:
|
||||
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
|
||||
/// to the swarm.
|
||||
/// 1) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
|
||||
/// immediately, and expired but connected peers are disconnected from immediately.
|
||||
/// 2) Every fixed interval: discovered but not connected peers are dialed, and expired but
|
||||
/// connected peers are disconnected from.
|
||||
pub struct Behaviour {
|
||||
// state-tracking for managed behaviors & mDNS-discovered peers
|
||||
managed: managed::Behaviour,
|
||||
mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
|
||||
|
||||
retry_delay: Delay, // retry interval
|
||||
|
||||
// pending events to emmit => waker-backed Deque to control polling
|
||||
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
managed: managed::Behaviour::new(keypair)?,
|
||||
mdns_discovered: HashMap::new(),
|
||||
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
|
||||
pending_events: WakerDeque::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
|
||||
self.pending_events.push_back(ToSwarm::Dial {
|
||||
opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
|
||||
})
|
||||
}
|
||||
|
||||
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
|
||||
// push front to make this IMMEDIATE
|
||||
self.pending_events.push_front(ToSwarm::CloseConnection {
|
||||
peer_id,
|
||||
connection: CloseConnection::One(connection),
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
|
||||
for (p, ma) in peers {
|
||||
self.dial(p, ma.clone()); // always connect
|
||||
|
||||
// get peer's multi-addresses or insert if missing
|
||||
let Some(mas) = self.mdns_discovered.get_mut(&p) else {
|
||||
self.mdns_discovered.insert(p, BTreeSet::from([ma]));
|
||||
continue;
|
||||
};
|
||||
|
||||
// multiaddress should never already be present - else something has gone wrong
|
||||
let is_new_addr = mas.insert(ma);
|
||||
assert!(is_new_addr, "cannot discover a discovered peer");
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
|
||||
for (p, ma) in peers {
|
||||
// at this point, we *must* have the peer
|
||||
let mas = self
|
||||
.mdns_discovered
|
||||
.get_mut(&p)
|
||||
.expect("nonexistent peer cannot expire");
|
||||
|
||||
// at this point, we *must* have the multiaddress
|
||||
let was_present = mas.remove(&ma);
|
||||
assert!(was_present, "nonexistent multiaddress cannot expire");
|
||||
|
||||
// if empty, remove the peer-id entirely
|
||||
if mas.is_empty() {
|
||||
self.mdns_discovered.remove(&p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn on_connection_established(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
) {
|
||||
// send out connected event
|
||||
self.pending_events
|
||||
.push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
|
||||
peer_id,
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_tcp_port,
|
||||
}));
|
||||
}
|
||||
|
||||
fn on_connection_closed(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
remote_ip: IpAddr,
|
||||
remote_tcp_port: u16,
|
||||
) {
|
||||
// send out disconnected event
|
||||
self.pending_events
|
||||
.push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
|
||||
peer_id,
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_tcp_port,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl NetworkBehaviour for Behaviour {
|
||||
type ConnectionHandler =
|
||||
ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
|
||||
type ToSwarm = Event;
|
||||
|
||||
// simply delegate to underlying mDNS behaviour
|
||||
|
||||
delegate! {
|
||||
to self.managed {
|
||||
fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
|
||||
fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_established_inbound_connection(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
peer: PeerId,
|
||||
local_addr: &Multiaddr,
|
||||
remote_addr: &Multiaddr,
|
||||
) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
Ok(ConnectionHandler::select(
|
||||
dummy::ConnectionHandler,
|
||||
self.managed.handle_established_inbound_connection(
|
||||
connection_id,
|
||||
peer,
|
||||
local_addr,
|
||||
remote_addr,
|
||||
)?,
|
||||
))
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_question_mark)]
|
||||
fn handle_established_outbound_connection(
|
||||
&mut self,
|
||||
connection_id: ConnectionId,
|
||||
peer: PeerId,
|
||||
addr: &Multiaddr,
|
||||
role_override: Endpoint,
|
||||
port_use: PortUse,
|
||||
) -> Result<THandler<Self>, ConnectionDenied> {
|
||||
Ok(ConnectionHandler::select(
|
||||
dummy::ConnectionHandler,
|
||||
self.managed.handle_established_outbound_connection(
|
||||
connection_id,
|
||||
peer,
|
||||
addr,
|
||||
role_override,
|
||||
port_use,
|
||||
)?,
|
||||
))
|
||||
}
|
||||
|
||||
fn on_connection_handler_event(
|
||||
&mut self,
|
||||
peer_id: PeerId,
|
||||
connection_id: ConnectionId,
|
||||
event: THandlerOutEvent<Self>,
|
||||
) {
|
||||
match event {
|
||||
Either::Left(ev) => libp2p::core::util::unreachable(ev),
|
||||
Either::Right(ev) => {
|
||||
self.managed
|
||||
.on_connection_handler_event(peer_id, connection_id, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// hook into these methods to drive behavior
|
||||
|
||||
fn on_swarm_event(&mut self, event: FromSwarm) {
|
||||
self.managed.on_swarm_event(event); // let mDNS handle swarm events
|
||||
|
||||
// handle swarm events to update internal state:
|
||||
match event {
|
||||
FromSwarm::ConnectionEstablished(ConnectionEstablished {
|
||||
peer_id,
|
||||
connection_id,
|
||||
endpoint,
|
||||
..
|
||||
}) => {
|
||||
let remote_address = match endpoint {
|
||||
ConnectedPoint::Dialer { address, .. } => address,
|
||||
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
|
||||
};
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection established event which is filtered correctly
|
||||
self.on_connection_established(peer_id, connection_id, ip, port)
|
||||
}
|
||||
}
|
||||
FromSwarm::ConnectionClosed(ConnectionClosed {
|
||||
peer_id,
|
||||
connection_id,
|
||||
endpoint,
|
||||
..
|
||||
}) => {
|
||||
let remote_address = match endpoint {
|
||||
ConnectedPoint::Dialer { address, .. } => address,
|
||||
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
|
||||
};
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
// handle connection closed event which is filtered correctly
|
||||
self.on_connection_closed(peer_id, connection_id, ip, port)
|
||||
}
|
||||
}
|
||||
|
||||
// since we are running TCP/IP transport layer, we are assuming that
|
||||
// no address changes can occur, hence encountering one is a fatal error
|
||||
FromSwarm::AddressChange(a) => {
|
||||
unreachable!("unhandlable: address change encountered: {:?}", a)
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
|
||||
// delegate to managed behaviors for any behaviors they need to perform
|
||||
match self.managed.poll(cx) {
|
||||
Poll::Ready(ToSwarm::GenerateEvent(e)) => {
|
||||
match e {
|
||||
// handle discovered and expired events from mDNS
|
||||
managed::BehaviourEvent::Mdns(e) => match e.clone() {
|
||||
mdns::Event::Discovered(peers) => {
|
||||
self.handle_mdns_discovered(peers);
|
||||
}
|
||||
mdns::Event::Expired(peers) => {
|
||||
self.handle_mdns_expired(peers);
|
||||
}
|
||||
},
|
||||
|
||||
// handle ping events => if error then disconnect
|
||||
managed::BehaviourEvent::Ping(e) => {
|
||||
if let Err(_) = e.result {
|
||||
self.close_connection(e.peer, e.connection.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// since we just consumed an event, we should immediately wake just in case
|
||||
// there are more events to come where that came from
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
|
||||
// forward any other mDNS event to the swarm or its connection handler(s)
|
||||
Poll::Ready(e) => {
|
||||
return Poll::Ready(
|
||||
e.map_out(|_| unreachable!("events returning to swarm already handled"))
|
||||
.map_in(Either::Right),
|
||||
);
|
||||
}
|
||||
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
// retry connecting to all mDNS peers periodically (fails safely if already connected)
|
||||
if self.retry_delay.poll_unpin(cx).is_ready() {
|
||||
for (p, mas) in self.mdns_discovered.clone() {
|
||||
for ma in mas {
|
||||
self.dial(p, ma)
|
||||
}
|
||||
}
|
||||
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
|
||||
}
|
||||
|
||||
// send out any pending events from our own service
|
||||
if let Some(e) = self.pending_events.pop_front(cx) {
|
||||
return Poll::Ready(e.map_in(Either::Left));
|
||||
}
|
||||
|
||||
// wait for pending events
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
44
rust/networking/src/keep_alive.rs
Normal file
44
rust/networking/src/keep_alive.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use delegate::delegate;
|
||||
use libp2p::swarm::handler::ConnectionEvent;
|
||||
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
|
||||
/// the connection alive.
|
||||
#[derive(Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct ConnectionHandler(dummy::ConnectionHandler);
|
||||
|
||||
impl ConnectionHandler {
|
||||
pub fn new() -> Self {
|
||||
ConnectionHandler(dummy::ConnectionHandler)
|
||||
}
|
||||
}
|
||||
|
||||
impl handler::ConnectionHandler for ConnectionHandler {
|
||||
// delegate types and implementation mostly to dummy handler
|
||||
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
|
||||
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
|
||||
type InboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
|
||||
type OutboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
|
||||
type InboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
|
||||
type OutboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
|
||||
|
||||
delegate! {
|
||||
to self.0 {
|
||||
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
|
||||
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
|
||||
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
|
||||
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
|
||||
}
|
||||
}
|
||||
|
||||
// specifically override this to force connection to stay alive
|
||||
fn connection_keep_alive(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
@@ -1,149 +1,64 @@
|
||||
use std::collections::BTreeSet;
|
||||
//! TODO: crate documentation
|
||||
//!
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
use iroh::{
|
||||
Endpoint, EndpointId, SecretKey, TransportAddr,
|
||||
discovery::{
|
||||
Discovery as _, EndpointData, IntoDiscoveryError,
|
||||
mdns::{DiscoveryEvent, MdnsDiscovery},
|
||||
},
|
||||
endpoint::BindError,
|
||||
endpoint_info::EndpointIdExt as _,
|
||||
protocol::Router,
|
||||
};
|
||||
use iroh_gossip::{
|
||||
Gossip, TopicId,
|
||||
api::{ApiError, GossipReceiver, GossipSender},
|
||||
};
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(unboxed_closures)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
use n0_error::{e, stack_error};
|
||||
use n0_future::{Stream, StreamExt as _};
|
||||
use tokio::sync::Mutex;
|
||||
pub mod discovery;
|
||||
pub mod keep_alive;
|
||||
pub mod swarm;
|
||||
|
||||
#[stack_error(derive, add_meta, from_sources)]
|
||||
pub enum ExoError {
|
||||
#[error(transparent)]
|
||||
FailedBinding { source: BindError },
|
||||
/// The gossip topic was closed.
|
||||
#[error(transparent)]
|
||||
FailedCommunication { source: ApiError },
|
||||
#[error("No IP Protocol supported on device")]
|
||||
IPNotSupported { source: IntoDiscoveryError },
|
||||
#[error("No peers found before subscribing")]
|
||||
NoPeers,
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ExoNet {
|
||||
pub alpn: String,
|
||||
pub router: Router,
|
||||
pub gossip: Gossip,
|
||||
pub mdns: MdnsDiscovery,
|
||||
pub known_peers: Mutex<BTreeSet<EndpointId>>,
|
||||
}
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use extend::ext;
|
||||
use libp2p::Multiaddr;
|
||||
use libp2p::multiaddr::Protocol;
|
||||
use std::net::IpAddr;
|
||||
|
||||
impl ExoNet {
|
||||
#[inline]
|
||||
pub async fn init_iroh(sk: SecretKey, namespace: &str) -> Result<Self, ExoError> {
|
||||
let endpoint = Endpoint::empty_builder(iroh::RelayMode::Disabled)
|
||||
.secret_key(sk)
|
||||
.bind()
|
||||
.await?;
|
||||
let mdns = MdnsDiscovery::builder().build(endpoint.id())?;
|
||||
let endpoint_addr = endpoint.addr();
|
||||
|
||||
let bound = endpoint_addr.ip_addrs().map(|it| TransportAddr::Ip(*it));
|
||||
|
||||
log::info!("publishing {endpoint_addr:?} with mdns");
|
||||
mdns.publish(&EndpointData::new(bound));
|
||||
endpoint.discovery().add(mdns.clone());
|
||||
let alpn = format!("/exo_discovery_network/{namespace}");
|
||||
// max msg size 4MB
|
||||
let gossip = Gossip::builder()
|
||||
.max_message_size(4 * 1024 * 1024)
|
||||
.alpn(&alpn)
|
||||
.spawn(endpoint.clone());
|
||||
let router = Router::builder(endpoint)
|
||||
.accept(&alpn, gossip.clone())
|
||||
.spawn();
|
||||
Ok(Self {
|
||||
alpn,
|
||||
router,
|
||||
gossip,
|
||||
mdns,
|
||||
known_peers: Mutex::new(BTreeSet::new()),
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn start_auto_dialer(&self) {
|
||||
let mut recv = self.connection_info().await;
|
||||
|
||||
log::info!(
|
||||
"Starting auto dialer for id {}",
|
||||
self.router.endpoint().id().to_z32()
|
||||
);
|
||||
while let Some(item) = recv.next().await {
|
||||
match item {
|
||||
DiscoveryEvent::Discovered { endpoint_info, .. } => {
|
||||
let id = endpoint_info.endpoint_id;
|
||||
if id == self.router.endpoint().id() {
|
||||
continue;
|
||||
}
|
||||
if !self
|
||||
.known_peers
|
||||
.lock()
|
||||
.await
|
||||
.contains(&endpoint_info.endpoint_id)
|
||||
&& let Ok(conn) = self
|
||||
.router
|
||||
.endpoint()
|
||||
.connect(endpoint_info, self.alpn.as_bytes())
|
||||
.await
|
||||
&& conn.alpn() == self.alpn.as_bytes()
|
||||
{
|
||||
self.known_peers.lock().await.insert(id);
|
||||
match self.gossip.handle_connection(conn).await {
|
||||
Ok(()) => log::info!("Successfully dialled"),
|
||||
Err(_) => log::info!("Failed to dial peer"),
|
||||
}
|
||||
}
|
||||
#[ext(pub, name = MultiaddrExt)]
|
||||
impl Multiaddr {
|
||||
/// If the multiaddress corresponds to a TCP address, extracts it
|
||||
fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> {
|
||||
let mut ps = self.into_iter();
|
||||
let ip = if let Some(p) = ps.next() {
|
||||
match p {
|
||||
Protocol::Ip4(ip) => IpAddr::V4(ip),
|
||||
Protocol::Ip6(ip) => IpAddr::V6(ip),
|
||||
_ => return None,
|
||||
}
|
||||
DiscoveryEvent::Expired { endpoint_id } => {
|
||||
log::info!("Peer expired {}", endpoint_id.to_z32());
|
||||
self.known_peers.lock().await.remove(&endpoint_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
let Some(Protocol::Tcp(port)) = ps.next() else {
|
||||
return None;
|
||||
};
|
||||
Some((ip, port))
|
||||
}
|
||||
log::info!("Auto dialer stopping");
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn connection_info(&self) -> impl Stream<Item = DiscoveryEvent> + Unpin + use<> {
|
||||
self.mdns.subscribe().await
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn subscribe(&self, topic: &str) -> Result<(GossipSender, GossipReceiver), ExoError> {
|
||||
if self.known_peers.lock().await.is_empty() {
|
||||
return Err(e!(ExoError::NoPeers));
|
||||
}
|
||||
Ok(self
|
||||
.gossip
|
||||
.subscribe_and_join(
|
||||
str_to_topic_id(topic),
|
||||
self.known_peers.lock().await.clone().into_iter().collect(),
|
||||
)
|
||||
.await?
|
||||
.split())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[allow(clippy::expect_used)]
|
||||
pub async fn shutdown(&self) {
|
||||
self.router.shutdown().await.expect("router panic");
|
||||
}
|
||||
}
|
||||
|
||||
fn str_to_topic_id(data: &str) -> TopicId {
|
||||
TopicId::from_bytes(*blake3::hash(data.as_bytes()).as_bytes())
|
||||
pub(crate) mod private {
|
||||
#![allow(dead_code)]
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
145
rust/networking/src/swarm.rs
Normal file
145
rust/networking/src/swarm.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
use crate::alias;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use libp2p::{SwarmBuilder, identity};
|
||||
|
||||
pub type Swarm = libp2p::Swarm<Behaviour>;
|
||||
|
||||
/// The current version of the network: this prevents devices running different versions of the
|
||||
/// software from interacting with each other.
|
||||
///
|
||||
/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should
|
||||
/// even be, and how to inject the right version into this config/initialization. E.g. should
|
||||
/// this be passed in as a parameter? What about rapidly changing versions in debug builds?
|
||||
/// this is all VERY very hard to figure out and needs to be mulled over as a team.
|
||||
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
|
||||
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
|
||||
|
||||
/// Create and configure a swarm which listens to all ports on OS
|
||||
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
|
||||
.with_tokio()
|
||||
.with_other_transport(tcp_transport)?
|
||||
.with_behaviour(Behaviour::new)?
|
||||
.build();
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
Ok(swarm)
|
||||
}
|
||||
|
||||
mod transport {
|
||||
use crate::alias;
|
||||
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
use keccak_const::Sha3_256;
|
||||
use libp2p::core::muxing;
|
||||
use libp2p::core::transport::Boxed;
|
||||
use libp2p::pnet::{PnetError, PnetOutput};
|
||||
use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
|
||||
use std::{env, sync::LazyLock};
|
||||
|
||||
/// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].
|
||||
/// See [`pnet_upgrade`] for more.
|
||||
static PNET_PRESHARED_KEY: LazyLock<[u8; 32]> = LazyLock::new(|| {
|
||||
let builder = Sha3_256::new().update(b"exo_discovery_network");
|
||||
|
||||
if let Ok(var) = env::var(OVERRIDE_VERSION_ENV_VAR) {
|
||||
let bytes = var.into_bytes();
|
||||
builder.update(&bytes)
|
||||
} else {
|
||||
builder.update(NETWORK_VERSION)
|
||||
}
|
||||
.finalize()
|
||||
});
|
||||
|
||||
/// Make the Swarm run on a private network, as to not clash with public libp2p nodes and
|
||||
/// also different-versioned instances of this same network.
|
||||
/// This is implemented as an additional "upgrade" ontop of existing [`libp2p::Transport`] layers.
|
||||
async fn pnet_upgrade<TSocket>(
|
||||
socket: TSocket,
|
||||
_: impl Sized,
|
||||
) -> Result<PnetOutput<TSocket>, PnetError>
|
||||
where
|
||||
TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
use pnet::{PnetConfig, PreSharedKey};
|
||||
PnetConfig::new(PreSharedKey::new(*PNET_PRESHARED_KEY))
|
||||
.handshake(socket)
|
||||
.await
|
||||
}
|
||||
|
||||
/// TCP/IP transport layer configuration.
|
||||
pub fn tcp_transport(
|
||||
keypair: &identity::Keypair,
|
||||
) -> alias::AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
|
||||
use libp2p::{
|
||||
core::upgrade::Version,
|
||||
tcp::{Config, tokio},
|
||||
};
|
||||
|
||||
// `TCP_NODELAY` enabled => avoid latency
|
||||
let tcp_config = Config::default().nodelay(true);
|
||||
|
||||
// V1 + lazy flushing => 0-RTT negotiation
|
||||
let upgrade_version = Version::V1Lazy;
|
||||
|
||||
// Noise is faster than TLS + we don't care much for security
|
||||
let noise_config = noise::Config::new(keypair)?;
|
||||
|
||||
// Use default Yamux config for multiplexing
|
||||
let yamux_config = yamux::Config::default();
|
||||
|
||||
// Create new Tokio-driven TCP/IP transport layer
|
||||
let base_transport = tokio::Transport::new(tcp_config)
|
||||
.and_then(pnet_upgrade)
|
||||
.upgrade(upgrade_version)
|
||||
.authenticate(noise_config)
|
||||
.multiplex(yamux_config);
|
||||
|
||||
// Return boxed transport (to flatten complex type)
|
||||
Ok(base_transport.boxed())
|
||||
}
|
||||
}
|
||||
|
||||
mod behaviour {
|
||||
use crate::{alias, discovery};
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
use libp2p::{gossipsub, identity};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Behavior of the Swarm which composes all desired behaviors:
|
||||
/// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
pub discovery: discovery::Behaviour,
|
||||
pub gossipsub: gossipsub::Behaviour,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
|
||||
Ok(Self {
|
||||
discovery: discovery::Behaviour::new(keypair)?,
|
||||
gossipsub: gossipsub_behaviour(keypair),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
|
||||
use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};
|
||||
|
||||
// build a gossipsub network behaviour
|
||||
// => signed message authenticity + strict validation mode means the message-ID is
|
||||
// automatically provided by gossipsub w/out needing to provide custom message-ID function
|
||||
gossipsub::Behaviour::new(
|
||||
MessageAuthenticity::Signed(keypair.clone()),
|
||||
ConfigBuilder::default()
|
||||
.publish_queue_duration(Duration::from_secs(15))
|
||||
.max_transmit_size(1024 * 1024)
|
||||
.validation_mode(ValidationMode::Strict)
|
||||
.build()
|
||||
.expect("the configuration should always be valid"),
|
||||
)
|
||||
.expect("creating gossipsub behavior should always work")
|
||||
}
|
||||
}
|
||||
7
rust/networking/tests/dummy.rs
Normal file
7
rust/networking/tests/dummy.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
// maybe this will hold test in the future...??
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[test]
|
||||
fn does_nothing() {}
|
||||
}
|
||||
2
rust/rust-toolchain.toml
Normal file
2
rust/rust-toolchain.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[toolchain]
|
||||
channel = "nightly"
|
||||
47
rust/system_custodian/Cargo.toml
Normal file
47
rust/system_custodian/Cargo.toml
Normal file
@@ -0,0 +1,47 @@
|
||||
[package]
|
||||
name = "system_custodian"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "system_custodian"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
path = "src/bin/main.rs"
|
||||
name = "system_custodian"
|
||||
doc = false
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# datastructures
|
||||
either = { workspace = true }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
# tracing/logging
|
||||
log = { workspace = true }
|
||||
|
||||
4
rust/system_custodian/src/bin/main.rs
Normal file
4
rust/system_custodian/src/bin/main.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! TODO: documentation
|
||||
//!
|
||||
|
||||
fn main() {}
|
||||
69
rust/system_custodian/src/lib.rs
Normal file
69
rust/system_custodian/src/lib.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
|
||||
//!
|
||||
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
|
||||
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
|
||||
//! etc.) is in an appropriate state to facilitate the running of Exo application.
|
||||
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
|
||||
//! service which Exo application use to _control & query_ it.
|
||||
//!
|
||||
//! # Lifecycle
|
||||
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
|
||||
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
|
||||
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
|
||||
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
|
||||
//! destructive to the user's pre-existing configurations.
|
||||
//!
|
||||
//! # Responsibilities
|
||||
//! TODO: these are purely on MacOS, but change to be more broad
|
||||
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
|
||||
//! 1. duplicate the current network set
|
||||
//! 2. modify existing services to turn on IPv6 if not there
|
||||
//! 3. remove any bridge services & add any missing services that AREN'T bridge
|
||||
//! TODO: In the future:
|
||||
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
|
||||
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
|
||||
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
|
||||
//! logic, this would be the place to spin that up.
|
||||
//!
|
||||
//! Then it will watch the SCDynamicStore for:
|
||||
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
|
||||
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
|
||||
//! interface of any changes
|
||||
//! 2. watch for any __undesirable__ changes to configuration and revert it
|
||||
//!
|
||||
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
|
||||
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
|
||||
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
|
||||
//! 1. all that info should coalesce back to the overall state colleted -> should be queryable
|
||||
//! over D-Bus
|
||||
//! TODO:
|
||||
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
|
||||
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
|
||||
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
|
||||
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
|
||||
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
|
||||
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
|
||||
//! then this would be the place to carry out discovery and propper handshakes with devices
|
||||
//! on the other end of the link.
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
#![feature(type_alias_impl_trait)]
|
||||
#![feature(specialization)]
|
||||
#![feature(unboxed_closures)]
|
||||
#![feature(const_trait_impl)]
|
||||
#![feature(fn_traits)]
|
||||
|
||||
pub(crate) mod private {
|
||||
// sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {}
|
||||
25
rust/util/Cargo.toml
Normal file
25
rust/util/Cargo.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "util"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "util"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
thiserror = { workspace = true }
|
||||
once_cell = { workspace = true }
|
||||
internment = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
bon = { workspace = true }
|
||||
recursion = { workspace = true }
|
||||
53
rust/util/src/lib.rs
Normal file
53
rust/util/src/lib.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
//! TODO: crate documentation
|
||||
//!
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
#![feature(type_alias_impl_trait)]
|
||||
#![feature(specialization)]
|
||||
#![feature(unboxed_closures)]
|
||||
#![feature(const_trait_impl)]
|
||||
#![feature(fn_traits)]
|
||||
|
||||
pub mod nonempty;
|
||||
pub mod wakerdeque;
|
||||
|
||||
pub(crate) mod private {
|
||||
// sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub mod ext {
|
||||
use extend::ext;
|
||||
|
||||
#[ext(pub, name = BoxedSliceExt)]
|
||||
impl<T> Box<[T]> {
|
||||
#[inline]
|
||||
fn map<B, F>(self, f: F) -> Box<[B]>
|
||||
where
|
||||
F: FnMut(T) -> B,
|
||||
{
|
||||
self.into_iter().map(f).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[ext(pub, name = VecExt)]
|
||||
impl<T> Vec<T> {
|
||||
#[inline]
|
||||
fn map<B, F>(self, f: F) -> Vec<B>
|
||||
where
|
||||
F: FnMut(T) -> B,
|
||||
{
|
||||
self.into_iter().map(f).collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
138
rust/util/src/nonempty.rs
Normal file
138
rust/util/src/nonempty.rs
Normal file
@@ -0,0 +1,138 @@
|
||||
use std::slice::SliceIndex;
|
||||
use std::{ops, slice};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("Cannot create to `NonemptyArray` because the supplied slice is empty")]
|
||||
pub struct EmptySliceError;
|
||||
|
||||
/// A pointer to a non-empty fixed-size slice allocated on the heap.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
#[repr(transparent)]
|
||||
pub struct NonemptyArray<T>(Box<[T]>);
|
||||
|
||||
#[allow(clippy::arbitrary_source_item_ordering)]
|
||||
impl<T> NonemptyArray<T> {
|
||||
#[inline]
|
||||
pub fn singleton(value: T) -> Self {
|
||||
Self(Box::new([value]))
|
||||
}
|
||||
|
||||
#[allow(clippy::missing_errors_doc)]
|
||||
#[inline]
|
||||
pub fn try_from_boxed_slice<S: Into<Box<[T]>>>(
|
||||
boxed_slice: S,
|
||||
) -> Result<Self, EmptySliceError> {
|
||||
let boxed_slice = boxed_slice.into();
|
||||
if boxed_slice.is_empty() {
|
||||
Err(EmptySliceError)
|
||||
} else {
|
||||
Ok(Self(boxed_slice))
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn into_boxed_slice(self) -> Box<[T]> {
|
||||
self.0
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn to_vec(&self) -> Vec<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
self.0.to_vec()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub const fn as_slice(&self) -> &[T] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing)]
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn first(&self) -> &T {
|
||||
&self.0[0]
|
||||
}
|
||||
|
||||
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn last(&self) -> &T {
|
||||
&self.0[self.0.len() - 1]
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub fn get<I>(&self, index: I) -> Option<&I::Output>
|
||||
where
|
||||
I: SliceIndex<[T]>,
|
||||
{
|
||||
self.0.get(index)
|
||||
}
|
||||
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
#[must_use]
|
||||
#[inline]
|
||||
pub const fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
#[allow(clippy::iter_without_into_iter)]
|
||||
#[inline]
|
||||
pub fn iter(&self) -> slice::Iter<'_, T> {
|
||||
self.0.iter()
|
||||
}
|
||||
|
||||
#[allow(clippy::iter_without_into_iter)]
|
||||
#[inline]
|
||||
pub fn iter_mut(&mut self) -> slice::IterMut<'_, T> {
|
||||
self.0.iter_mut()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn map<U, F: FnMut(T) -> U>(self, f: F) -> NonemptyArray<U> {
|
||||
NonemptyArray(self.0.into_iter().map(f).collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<NonemptyArray<T>> for Box<[T]> {
|
||||
#[inline]
|
||||
fn from(value: NonemptyArray<T>) -> Self {
|
||||
value.into_boxed_slice()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ops::Index<usize> for NonemptyArray<T> {
|
||||
type Output = T;
|
||||
|
||||
#[inline]
|
||||
fn index(&self, index: usize) -> &Self::Output {
|
||||
self.0.index(index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoIterator for NonemptyArray<T> {
|
||||
type Item = T;
|
||||
type IntoIter = std::vec::IntoIter<T>;
|
||||
|
||||
#[inline]
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.into_boxed_slice().into_vec().into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> IntoIterator for &'a NonemptyArray<T> {
|
||||
type Item = &'a T;
|
||||
type IntoIter = slice::Iter<'a, T>;
|
||||
|
||||
#[inline]
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.iter()
|
||||
}
|
||||
}
|
||||
55
rust/util/src/wakerdeque.rs
Normal file
55
rust/util/src/wakerdeque.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::task::{Context, Waker};
|
||||
|
||||
/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,
|
||||
/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.
|
||||
pub struct WakerDeque<T> {
|
||||
waker: Option<Waker>,
|
||||
deque: VecDeque<T>,
|
||||
}
|
||||
|
||||
impl<T: Debug> Debug for WakerDeque<T> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
self.deque.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> WakerDeque<T> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
waker: None,
|
||||
deque: VecDeque::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, cx: &mut Context<'_>) {
|
||||
self.waker = Some(cx.waker().clone());
|
||||
}
|
||||
|
||||
fn wake(&mut self) {
|
||||
let Some(ref mut w) = self.waker else { return };
|
||||
w.wake_by_ref();
|
||||
self.waker = None;
|
||||
}
|
||||
|
||||
pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {
|
||||
self.update(cx);
|
||||
self.deque.pop_front()
|
||||
}
|
||||
|
||||
pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {
|
||||
self.update(cx);
|
||||
self.deque.pop_back()
|
||||
}
|
||||
|
||||
pub fn push_front(&mut self, value: T) {
|
||||
self.wake();
|
||||
self.deque.push_front(value);
|
||||
}
|
||||
|
||||
pub fn push_back(&mut self, value: T) {
|
||||
self.wake();
|
||||
self.deque.push_back(value);
|
||||
}
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
from importlib.metadata import version
|
||||
|
||||
__version__ = version("exo")
|
||||
|
||||
@@ -39,9 +39,9 @@ class Node:
|
||||
@classmethod
|
||||
async def create(cls, args: "Args") -> "Self":
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(str(keypair.endpoint_id()))
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
router = await Router.create(keypair)
|
||||
router = Router.create(keypair)
|
||||
await router.register_topic(topics.GLOBAL_EVENTS)
|
||||
await router.register_topic(topics.LOCAL_EVENTS)
|
||||
await router.register_topic(topics.COMMANDS)
|
||||
|
||||
@@ -13,6 +13,12 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.typing import ASGIFramework
|
||||
from loguru import logger
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
@@ -21,11 +27,13 @@ from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.models.model_meta import get_model_meta
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteInstanceResponse,
|
||||
FinishReason,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -56,7 +64,7 @@ from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
HIDE_THINKING = False
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
@@ -161,7 +169,9 @@ class API:
|
||||
self.app.delete("/instance/{instance_id}")(self.delete_instance)
|
||||
self.app.get("/models")(self.get_models)
|
||||
self.app.get("/v1/models")(self.get_models)
|
||||
self.app.post("/v1/chat/completions")(self.chat_completions)
|
||||
self.app.post("/v1/chat/completions", response_model=None)(
|
||||
self.chat_completions
|
||||
)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
|
||||
@@ -177,17 +187,32 @@ class API:
|
||||
return CreateInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
model_meta=command.model_meta,
|
||||
)
|
||||
|
||||
async def create_instance(
|
||||
self, payload: CreateInstanceParams
|
||||
) -> CreateInstanceResponse:
|
||||
command = CreateInstance(instance=payload.instance)
|
||||
instance = payload.instance
|
||||
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
|
||||
required_memory = model_meta.storage_size
|
||||
available_memory = self._calculate_total_available_memory()
|
||||
|
||||
if required_memory > available_memory:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Insufficient memory to create instance. Required: {required_memory.in_gb:.1f}GB, Available: {available_memory.in_gb:.1f}GB",
|
||||
)
|
||||
|
||||
command = CreateInstance(
|
||||
instance=instance,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
return CreateInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
model_meta=model_meta,
|
||||
)
|
||||
|
||||
async def get_placement(
|
||||
@@ -352,32 +377,52 @@ class API:
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
thinking = False
|
||||
|
||||
async for chunk in token_chunks:
|
||||
stream.process(chunk.token_id)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield chunk.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield chunk.model_copy(update={"text": delta})
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
if thinking:
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
yield chunk
|
||||
break
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> AsyncGenerator[TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
|
||||
try:
|
||||
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
|
||||
|
||||
is_thinking = False
|
||||
with recv as token_chunks:
|
||||
async for chunk in token_chunks:
|
||||
if HIDE_THINKING:
|
||||
if chunk.text == "<think>":
|
||||
is_thinking = True
|
||||
if chunk.text == "</think>":
|
||||
is_thinking = False
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
if not (is_thinking and HIDE_THINKING):
|
||||
logger.debug(f"chunk_response: {chunk_response}")
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
if parse_gpt_oss:
|
||||
async for chunk in self._process_gpt_oss(token_chunks):
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
else:
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# TODO: TaskCancelled
|
||||
@@ -392,6 +437,59 @@ class API:
|
||||
await self._send(command)
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
logger.debug(f"chunk_response: {chunk_response}")
|
||||
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def _collect_chat_completion(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> ChatCompletionResponse:
|
||||
"""Collect all token chunks for a chat completion and return a single response."""
|
||||
|
||||
text_parts: list[str] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=combined_text,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
|
||||
logger.warning(
|
||||
"TODO: we should send a notification to the user to download the model"
|
||||
@@ -399,10 +497,12 @@ class API:
|
||||
|
||||
async def chat_completions(
|
||||
self, payload: ChatCompletionTaskParams
|
||||
) -> StreamingResponse:
|
||||
"""Handle chat completions with proper streaming response."""
|
||||
) -> ChatCompletionResponse | StreamingResponse:
|
||||
"""Handle chat completions, supporting both streaming and non-streaming responses."""
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
payload.model = model_meta.model_id
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
logger.info(f"{parse_gpt_oss=}")
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == payload.model
|
||||
@@ -417,10 +517,13 @@ class API:
|
||||
request_params=payload,
|
||||
)
|
||||
await self._send(command)
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id, parse_gpt_oss),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
@@ -442,6 +545,8 @@ class API:
|
||||
name=card.name,
|
||||
description=card.description,
|
||||
tags=card.tags,
|
||||
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
|
||||
supports_tensor=card.metadata.supports_tensor,
|
||||
)
|
||||
for card in MODEL_CARDS.values()
|
||||
]
|
||||
@@ -458,7 +563,7 @@ class API:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._applystate)
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
print_startup_banner(self.port)
|
||||
await serve(
|
||||
@@ -470,7 +575,7 @@ class API:
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def _applystate(self):
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
if f_event.origin != self.session_id.master_node_id:
|
||||
|
||||
@@ -95,7 +95,6 @@ class Master:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _command_processor(self) -> None:
|
||||
retry_num = 0
|
||||
with self.command_receiver as commands:
|
||||
async for forwarder_command in commands:
|
||||
try:
|
||||
@@ -187,12 +186,11 @@ class Master:
|
||||
command.finished_command_id
|
||||
]
|
||||
case RequestEventLog():
|
||||
retry_num += 1
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
for i in range(command.since_idx, len(self._event_log)):
|
||||
event = self._event_log[i]
|
||||
event.retry = retry_num
|
||||
await self._send_event(IndexedEvent(idx=i, event=event))
|
||||
await self._send_event(
|
||||
IndexedEvent(idx=i, event=self._event_log[i])
|
||||
)
|
||||
for event in generated_events:
|
||||
await self.event_sender.send(event)
|
||||
except ValueError as e:
|
||||
@@ -234,7 +232,7 @@ class Master:
|
||||
local_event.origin,
|
||||
)
|
||||
for event in self._multi_buffer.drain():
|
||||
logger.trace(f"Master indexing event: {str(event)[:100]}")
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
|
||||
|
||||
@@ -7,20 +7,18 @@ from loguru import logger
|
||||
|
||||
from exo.master.placement_utils import (
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_mlx_ibv_devices_matrix,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_mlx_ring_hosts_by_node,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.routing.connection_message import IpAddress
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import Host
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.topology import NodeInfo
|
||||
@@ -131,17 +129,17 @@ def place_instance(
|
||||
jaccl_coordinators=mlx_jaccl_coordinators,
|
||||
)
|
||||
case InstanceMeta.MlxRing:
|
||||
hosts: list[IpAddress] = get_hosts_from_subgraph(cycle_digraph)
|
||||
ephemeral_port = random_ephemeral_port()
|
||||
hosts_by_node = get_mlx_ring_hosts_by_node(
|
||||
selected_cycle=selected_cycle,
|
||||
cycle_digraph=cycle_digraph,
|
||||
ephemeral_port=ephemeral_port,
|
||||
)
|
||||
target_instances[instance_id] = MlxRingInstance(
|
||||
instance_id=instance_id,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=[
|
||||
Host(
|
||||
ip=str(host),
|
||||
port=random_ephemeral_port(),
|
||||
)
|
||||
for host in hosts
|
||||
],
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
)
|
||||
|
||||
return target_instances
|
||||
|
||||
@@ -4,9 +4,8 @@ from typing import TypeGuard, cast
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.routing.connection_message import IpAddress
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.profiling import NodePerformanceProfile
|
||||
@@ -154,8 +153,7 @@ def get_shard_assignments(
|
||||
)
|
||||
|
||||
|
||||
def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[IpAddress]:
|
||||
# this function is wrong.
|
||||
def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
|
||||
cycles = cycle_digraph.get_cycles()
|
||||
expected_length = len(list(cycle_digraph.list_nodes()))
|
||||
cycles = [cycle for cycle in cycles if len(cycle) == expected_length]
|
||||
@@ -173,20 +171,24 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[IpAddress]:
|
||||
logger.info(f"Using thunderbolt cycle: {get_thunderbolt}")
|
||||
|
||||
cycle = cycles[0]
|
||||
hosts: list[IpAddress] = []
|
||||
hosts: list[Host] = []
|
||||
for i in range(len(cycle)):
|
||||
current_node = cycle[i]
|
||||
next_node = cycle[(i + 1) % len(cycle)]
|
||||
|
||||
for connection in cycle_digraph.list_connections():
|
||||
if (
|
||||
connection.source_id == current_node.node_id
|
||||
and connection.sink_id == next_node.node_id
|
||||
connection.local_node_id == current_node.node_id
|
||||
and connection.send_back_node_id == next_node.node_id
|
||||
):
|
||||
if get_thunderbolt and not connection.is_thunderbolt():
|
||||
continue
|
||||
assert connection.sink_addr is not None
|
||||
hosts.append(connection.sink_addr)
|
||||
assert connection.send_back_multiaddr is not None
|
||||
host = Host(
|
||||
ip=connection.send_back_multiaddr.ip_address,
|
||||
port=connection.send_back_multiaddr.port,
|
||||
)
|
||||
hosts.append(host)
|
||||
break
|
||||
|
||||
return hosts
|
||||
@@ -213,9 +215,11 @@ def get_mlx_ibv_devices_matrix(
|
||||
continue
|
||||
|
||||
# Find the IP J uses to talk to I
|
||||
for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
|
||||
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
|
||||
# This is a local IP on I, which is attached to an interface: find that interface
|
||||
if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
|
||||
if interface_name := _find_rdma_interface_name_for_ip(
|
||||
connection_ip, node_i
|
||||
):
|
||||
matrix[i][j] = interface_name
|
||||
logger.info(
|
||||
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
|
||||
@@ -236,17 +240,17 @@ def _find_connection_ip(
|
||||
node_i: NodeInfo,
|
||||
node_j: NodeInfo,
|
||||
cycle_digraph: Topology,
|
||||
) -> Generator[str]:
|
||||
"""Find all IP addresses that connect node i to node j."""
|
||||
) -> Generator[tuple[str, bool]]:
|
||||
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
|
||||
for connection in cycle_digraph.list_connections():
|
||||
if (
|
||||
connection.source_id == node_i.node_id
|
||||
and connection.sink_id == node_j.node_id
|
||||
connection.local_node_id == node_i.node_id
|
||||
and connection.send_back_node_id == node_j.node_id
|
||||
):
|
||||
yield str(connection.sink_addr)
|
||||
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
|
||||
|
||||
|
||||
def _find_interface_name_for_ip(
|
||||
def _find_rdma_interface_name_for_ip(
|
||||
ip_address: str,
|
||||
node_info: NodeInfo,
|
||||
) -> str | None:
|
||||
@@ -267,6 +271,109 @@ def _find_interface_name_for_ip(
|
||||
return None
|
||||
|
||||
|
||||
def _find_interface_name_for_ip(
|
||||
ip_address: str,
|
||||
node_info: NodeInfo,
|
||||
) -> str | None:
|
||||
"""Find the interface name for an IP address on a node (any interface)."""
|
||||
if node_info.node_profile is None:
|
||||
return None
|
||||
|
||||
for interface in node_info.node_profile.network_interfaces:
|
||||
if interface.ip_address == ip_address:
|
||||
return interface.name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_ip_prioritised(
|
||||
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
|
||||
) -> str | None:
|
||||
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
|
||||
"""Find an IP address between nodes with prioritization.
|
||||
|
||||
Priority order:
|
||||
1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
|
||||
2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
|
||||
3. Non-Thunderbolt connections
|
||||
4. Any other IP address
|
||||
"""
|
||||
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
|
||||
# We expect a unique iface -> ip mapping
|
||||
iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}
|
||||
|
||||
en0_ip = iface_map.get("en0")
|
||||
if en0_ip:
|
||||
return en0_ip
|
||||
|
||||
en1_ip = iface_map.get("en1")
|
||||
if en1_ip:
|
||||
return en1_ip
|
||||
|
||||
non_thunderbolt_ip = next(
|
||||
(ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
|
||||
)
|
||||
|
||||
if non_thunderbolt_ip:
|
||||
return non_thunderbolt_ip
|
||||
|
||||
if ips:
|
||||
return ips[0][0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_mlx_ring_hosts_by_node(
|
||||
selected_cycle: list[NodeInfo],
|
||||
cycle_digraph: Topology,
|
||||
ephemeral_port: int,
|
||||
) -> dict[NodeId, list[Host]]:
|
||||
"""Generate per-node host lists for MLX ring backend.
|
||||
|
||||
Each node gets a list where:
|
||||
- Self position: Host(ip="0.0.0.0", port=ephemeral_port)
|
||||
- Left/right neighbors: actual connection IPs
|
||||
- Non-neighbors: Host(ip="198.51.100.1", port=0) placeholder (RFC 5737 TEST-NET-2)
|
||||
"""
|
||||
world_size = len(selected_cycle)
|
||||
if world_size == 0:
|
||||
return {}
|
||||
|
||||
hosts_by_node: dict[NodeId, list[Host]] = {}
|
||||
|
||||
for rank, node in enumerate(selected_cycle):
|
||||
node_id = node.node_id
|
||||
left_rank = (rank - 1) % world_size
|
||||
right_rank = (rank + 1) % world_size
|
||||
|
||||
hosts_for_node: list[Host] = []
|
||||
|
||||
for idx, other_node in enumerate(selected_cycle):
|
||||
if idx == rank:
|
||||
hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
|
||||
continue
|
||||
|
||||
if idx not in {left_rank, right_rank}:
|
||||
# Placeholder IP from RFC 5737 TEST-NET-2
|
||||
hosts_for_node.append(Host(ip="198.51.100.1", port=0))
|
||||
continue
|
||||
|
||||
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
|
||||
if connection_ip is None:
|
||||
logger.warning(
|
||||
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
|
||||
)
|
||||
raise ValueError(
|
||||
"MLX ring backend requires connectivity between neighbouring nodes"
|
||||
)
|
||||
|
||||
hosts_for_node.append(Host(ip=connection_ip, port=ephemeral_port))
|
||||
|
||||
hosts_by_node[node_id] = hosts_for_node
|
||||
|
||||
return hosts_by_node
|
||||
|
||||
|
||||
def get_mlx_jaccl_coordinators(
|
||||
selected_cycle: list[NodeInfo],
|
||||
coordinator_port: int,
|
||||
@@ -284,7 +391,7 @@ def get_mlx_jaccl_coordinators(
|
||||
if n.node_id == rank_0_node.node_id:
|
||||
return "0.0.0.0"
|
||||
|
||||
for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
|
||||
for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
|
||||
return ip
|
||||
|
||||
logger.warning(
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from ipaddress import ip_address
|
||||
from itertools import count
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NodePerformanceProfile,
|
||||
@@ -9,45 +11,57 @@ from exo.shared.types.profiling import (
|
||||
)
|
||||
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
|
||||
|
||||
ip_octet_iter = count()
|
||||
|
||||
|
||||
def create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
|
||||
if node_id is None:
|
||||
node_id = NodeId()
|
||||
return NodeInfo(
|
||||
node_id=node_id,
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile.from_bytes(
|
||||
ram_total=1000,
|
||||
ram_available=memory,
|
||||
swap_total=1000,
|
||||
swap_available=1000,
|
||||
@pytest.fixture
|
||||
def create_node():
|
||||
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
|
||||
if node_id is None:
|
||||
node_id = NodeId()
|
||||
return NodeInfo(
|
||||
node_id=node_id,
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile.from_bytes(
|
||||
ram_total=1000,
|
||||
ram_available=memory,
|
||||
swap_total=1000,
|
||||
swap_available=1000,
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return _create_node
|
||||
|
||||
|
||||
def create_connection(
|
||||
source_node_id: NodeId,
|
||||
sink_node_id: NodeId,
|
||||
*,
|
||||
ip_octet: int | None = None,
|
||||
) -> Connection:
|
||||
global ip_octet_iter
|
||||
# TODO: this is a hack to get the port for the send_back_multiaddr
|
||||
@pytest.fixture
|
||||
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
|
||||
port_counter = 1235
|
||||
ip_counter = 1
|
||||
|
||||
return Connection(
|
||||
source_id=source_node_id,
|
||||
sink_id=sink_node_id,
|
||||
sink_addr=ip_address(
|
||||
f"169.254.0.{ip_octet if ip_octet is not None else next(ip_octet_iter)}"
|
||||
),
|
||||
connection_profile=ConnectionProfile(
|
||||
throughput=1000, latency=1000, jitter=1000
|
||||
),
|
||||
)
|
||||
def _create_connection(
|
||||
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
|
||||
) -> Connection:
|
||||
nonlocal port_counter
|
||||
nonlocal ip_counter
|
||||
# assign unique ips
|
||||
ip_counter += 1
|
||||
if send_back_port is None:
|
||||
send_back_port = port_counter
|
||||
port_counter += 1
|
||||
return Connection(
|
||||
local_node_id=source_node_id,
|
||||
send_back_node_id=sink_node_id,
|
||||
send_back_multiaddr=Multiaddr(
|
||||
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
|
||||
),
|
||||
connection_profile=ConnectionProfile(
|
||||
throughput=1000, latency=1000, jitter=1000
|
||||
),
|
||||
)
|
||||
|
||||
return _create_connection
|
||||
|
||||
@@ -43,7 +43,7 @@ from exo.utils.channels import channel
|
||||
@pytest.mark.asyncio
|
||||
async def test_master():
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(str(keypair.endpoint_id()))
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
|
||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||
@@ -74,7 +74,7 @@ async def test_master():
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
sender_node_id = NodeId(f"{keypair.to_postcard_encoding()}_sender")
|
||||
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
|
||||
# inject a NodePerformanceProfile event
|
||||
logger.info("inject a NodePerformanceProfile event")
|
||||
await local_event_sender.send(
|
||||
@@ -123,6 +123,8 @@ async def test_master():
|
||||
pretty_name="Llama 3.2 1B",
|
||||
n_layers=16,
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
@@ -140,6 +142,7 @@ async def test_master():
|
||||
origin=node_id,
|
||||
command=(
|
||||
ChatCompletion(
|
||||
command_id=CommandId(),
|
||||
request_params=ChatCompletionTaskParams(
|
||||
model="llama-3.2-1b",
|
||||
messages=[
|
||||
@@ -162,32 +165,38 @@ async def test_master():
|
||||
assert events[2].idx == 2
|
||||
assert isinstance(events[0].event, NodePerformanceMeasured)
|
||||
assert isinstance(events[1].event, InstanceCreated)
|
||||
runner_id = list(
|
||||
events[1].event.instance.shard_assignments.runner_to_shard.keys()
|
||||
)[0]
|
||||
assert events[1].event.instance == MlxRingInstance(
|
||||
instance_id=events[1].event.instance.instance_id,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
runner_to_shard={
|
||||
(runner_id): PipelineShardMetadata(
|
||||
start_layer=0,
|
||||
end_layer=16,
|
||||
created_instance = events[1].event.instance
|
||||
assert isinstance(created_instance, MlxRingInstance)
|
||||
runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
|
||||
# Validate the shard assignments
|
||||
expected_shard_assignments = ShardAssignments(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
runner_to_shard={
|
||||
(runner_id): PipelineShardMetadata(
|
||||
start_layer=0,
|
||||
end_layer=16,
|
||||
n_layers=16,
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
pretty_name="Llama 3.2 1B",
|
||||
n_layers=16,
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
pretty_name="Llama 3.2 1B",
|
||||
n_layers=16,
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
)
|
||||
},
|
||||
node_to_runner={node_id: runner_id},
|
||||
),
|
||||
hosts=[],
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
)
|
||||
},
|
||||
node_to_runner={node_id: runner_id},
|
||||
)
|
||||
assert created_instance.shard_assignments == expected_shard_assignments
|
||||
# For single-node, hosts_by_node should have one entry with self-binding
|
||||
assert len(created_instance.hosts_by_node) == 1
|
||||
assert node_id in created_instance.hosts_by_node
|
||||
assert len(created_instance.hosts_by_node[node_id]) == 1
|
||||
assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
|
||||
assert created_instance.ephemeral_port > 0
|
||||
assert isinstance(events[2].event, TaskCreated)
|
||||
assert events[2].event.task.task_status == TaskStatus.Pending
|
||||
assert isinstance(events[2].event.task, ChatCompletionTask)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from ipaddress import ip_address
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
from loguru import logger
|
||||
@@ -7,7 +7,6 @@ from exo.master.placement import (
|
||||
get_transition_events,
|
||||
place_instance,
|
||||
)
|
||||
from exo.master.tests.conftest import create_connection, create_node
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import PlaceInstance
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
@@ -15,6 +14,7 @@ from exo.shared.types.events import InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
|
||||
from exo.shared.types.topology import Connection, NodeInfo
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
@@ -38,7 +38,8 @@ def instance() -> Instance:
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
|
||||
),
|
||||
hosts=[],
|
||||
hosts_by_node={},
|
||||
ephemeral_port=50000,
|
||||
)
|
||||
|
||||
|
||||
@@ -49,6 +50,8 @@ def model_meta() -> ModelMetadata:
|
||||
storage_size=Memory.from_kb(1000),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
hidden_size=10,
|
||||
supports_tensor=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -76,6 +79,8 @@ def test_get_instance_placements_create_instance(
|
||||
expected_layers: tuple[int, int, int],
|
||||
topology: Topology,
|
||||
model_meta: ModelMetadata,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
model_meta.n_layers = total_layers
|
||||
@@ -90,9 +95,13 @@ def test_get_instance_placements_create_instance(
|
||||
topology.add_node(create_node(available_memory[0], node_id_a))
|
||||
topology.add_node(create_node(available_memory[1], node_id_b))
|
||||
topology.add_node(create_node(available_memory[2], node_id_c))
|
||||
# Add bidirectional connections for ring topology
|
||||
topology.add_connection(create_connection(node_id_a, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_a))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_c))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_a))
|
||||
topology.add_connection(create_connection(node_id_a, node_id_c))
|
||||
|
||||
# act
|
||||
placements = place_instance(cic, topology, {})
|
||||
@@ -121,7 +130,9 @@ def test_get_instance_placements_create_instance(
|
||||
assert shards_sorted[-1].end_layer == total_layers
|
||||
|
||||
|
||||
def test_get_instance_placements_one_node_exact_fit() -> None:
|
||||
def test_get_instance_placements_one_node_exact_fit(
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
) -> None:
|
||||
topology = Topology()
|
||||
node_id = NodeId()
|
||||
topology.add_node(create_node(1000 * 1024, node_id))
|
||||
@@ -131,6 +142,8 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
|
||||
storage_size=Memory.from_kb(1000),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
),
|
||||
)
|
||||
placements = place_instance(cic, topology, {})
|
||||
@@ -144,7 +157,9 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
|
||||
assert len(instance.shard_assignments.runner_to_shard) == 1
|
||||
|
||||
|
||||
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
|
||||
def test_get_instance_placements_one_node_fits_with_extra_memory(
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
) -> None:
|
||||
topology = Topology()
|
||||
node_id = NodeId()
|
||||
topology.add_node(create_node(1001 * 1024, node_id))
|
||||
@@ -154,6 +169,8 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
|
||||
storage_size=Memory.from_kb(1000),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
),
|
||||
)
|
||||
placements = place_instance(cic, topology, {})
|
||||
@@ -167,7 +184,9 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
|
||||
assert len(instance.shard_assignments.runner_to_shard) == 1
|
||||
|
||||
|
||||
def test_get_instance_placements_one_node_not_fit() -> None:
|
||||
def test_get_instance_placements_one_node_not_fit(
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
) -> None:
|
||||
topology = Topology()
|
||||
node_id = NodeId()
|
||||
topology.add_node(create_node(1000 * 1024, node_id))
|
||||
@@ -177,6 +196,8 @@ def test_get_instance_placements_one_node_not_fit() -> None:
|
||||
storage_size=Memory.from_kb(1001),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -226,15 +247,15 @@ def test_get_transition_events_delete_instance(instance: Instance):
|
||||
assert events[0].instance_id == instance_id
|
||||
|
||||
|
||||
def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
def test_placement_selects_cycle_with_most_memory(
|
||||
topology: Topology,
|
||||
model_meta: ModelMetadata,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
|
||||
# neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
|
||||
# non-leaves. Ensure both cycles have sufficient total memory, with the A-B-C
|
||||
# cycle having LESS total memory than D-E-F. The algorithm should still choose
|
||||
# the cycle that contains a leaf node.
|
||||
# Arrange two 3-node cycles with different total memory.
|
||||
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
|
||||
# The algorithm should select the cycle with the most available memory.
|
||||
|
||||
# Model requires more than any single node but fits within a 3-node cycle
|
||||
model_meta.storage_size.in_bytes = 1500
|
||||
@@ -248,11 +269,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
node_id_e = NodeId()
|
||||
node_id_f = NodeId()
|
||||
|
||||
# Extra sink nodes to make D/E/F non-leaf via additional outgoing edges
|
||||
node_id_x = NodeId()
|
||||
node_id_y = NodeId()
|
||||
node_id_z = NodeId()
|
||||
|
||||
# A-B-C cycle total memory = 1600 (< D-E-F total)
|
||||
topology.add_node(create_node(400, node_id_a))
|
||||
topology.add_node(create_node(400, node_id_b))
|
||||
@@ -263,24 +279,20 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
topology.add_node(create_node(600, node_id_e))
|
||||
topology.add_node(create_node(600, node_id_f))
|
||||
|
||||
# Extra nodes with tiny memory so they can't form singleton placements
|
||||
topology.add_node(create_node(10, node_id_x))
|
||||
topology.add_node(create_node(10, node_id_y))
|
||||
topology.add_node(create_node(10, node_id_z))
|
||||
|
||||
# Build directed cycles
|
||||
# Build bidirectional cycles for ring topology
|
||||
topology.add_connection(create_connection(node_id_a, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_a))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_c))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_a))
|
||||
topology.add_connection(create_connection(node_id_a, node_id_c))
|
||||
|
||||
topology.add_connection(create_connection(node_id_d, node_id_e))
|
||||
topology.add_connection(create_connection(node_id_e, node_id_d))
|
||||
topology.add_connection(create_connection(node_id_e, node_id_f))
|
||||
topology.add_connection(create_connection(node_id_f, node_id_e))
|
||||
topology.add_connection(create_connection(node_id_f, node_id_d))
|
||||
|
||||
# Add extra outgoing edges from D/E/F so none of them are leaves
|
||||
topology.add_connection(create_connection(node_id_d, node_id_x))
|
||||
topology.add_connection(create_connection(node_id_e, node_id_y))
|
||||
topology.add_connection(create_connection(node_id_f, node_id_z))
|
||||
topology.add_connection(create_connection(node_id_d, node_id_f))
|
||||
|
||||
cic = place_instance_command(
|
||||
model_meta=model_meta,
|
||||
@@ -289,23 +301,24 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
# Act
|
||||
placements = place_instance(cic, topology, {})
|
||||
|
||||
# Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
|
||||
# D-E-F has more total memory.
|
||||
# Assert: D-E-F cycle should be selected as it has more total memory
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
instance = placements[instance_id]
|
||||
|
||||
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
|
||||
expected_leaf_cycle_nodes = {node_id_a, node_id_b, node_id_c}
|
||||
non_leaf_cycle_nodes = {node_id_d, node_id_e, node_id_f}
|
||||
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
|
||||
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
|
||||
|
||||
assert expected_leaf_cycle_nodes.issubset(assigned_nodes)
|
||||
assert assigned_nodes.isdisjoint(non_leaf_cycle_nodes)
|
||||
assert more_memory_cycle_nodes.issubset(assigned_nodes)
|
||||
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
|
||||
|
||||
|
||||
def test_tensor_rdma_backend_connectivity_matrix(
|
||||
topology: Topology,
|
||||
model_meta: ModelMetadata,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
model_meta.n_layers = 12
|
||||
model_meta.storage_size.in_bytes = 1500
|
||||
@@ -320,7 +333,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
|
||||
ethernet_interface = NetworkInterfaceInfo(
|
||||
name="en0",
|
||||
ip_address=ip_address("192.168.1.100"),
|
||||
ip_address="192.168.1.100",
|
||||
)
|
||||
|
||||
assert node_a.node_profile is not None
|
||||
@@ -335,13 +348,13 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
conn_c_b = create_connection(node_id_c, node_id_b)
|
||||
conn_a_c = create_connection(node_id_a, node_id_c)
|
||||
|
||||
assert conn_a_b.sink_addr is not None
|
||||
assert conn_b_c.sink_addr is not None
|
||||
assert conn_c_a.sink_addr is not None
|
||||
assert conn_a_b.send_back_multiaddr is not None
|
||||
assert conn_b_c.send_back_multiaddr is not None
|
||||
assert conn_c_a.send_back_multiaddr is not None
|
||||
|
||||
assert conn_b_a.sink_addr is not None
|
||||
assert conn_c_b.sink_addr is not None
|
||||
assert conn_a_c.sink_addr is not None
|
||||
assert conn_b_a.send_back_multiaddr is not None
|
||||
assert conn_c_b.send_back_multiaddr is not None
|
||||
assert conn_a_c.send_back_multiaddr is not None
|
||||
|
||||
node_a.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
@@ -351,11 +364,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_c_a.sink_addr,
|
||||
ip_address=conn_c_a.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_b_a.sink_addr,
|
||||
ip_address=conn_b_a.send_back_multiaddr.ip_address,
|
||||
),
|
||||
ethernet_interface,
|
||||
],
|
||||
@@ -369,11 +382,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_c_b.sink_addr,
|
||||
ip_address=conn_c_b.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_a_b.sink_addr,
|
||||
ip_address=conn_a_b.send_back_multiaddr.ip_address,
|
||||
),
|
||||
ethernet_interface,
|
||||
],
|
||||
@@ -387,11 +400,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_a_c.sink_addr,
|
||||
ip_address=conn_a_c.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_b_c.sink_addr,
|
||||
ip_address=conn_b_c.send_back_multiaddr.ip_address,
|
||||
),
|
||||
ethernet_interface,
|
||||
],
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.master.placement_utils import (
|
||||
@@ -7,12 +9,12 @@ from exo.master.placement_utils import (
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.master.tests.conftest import create_connection, create_node
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
|
||||
from exo.shared.types.topology import Connection, NodeInfo
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
|
||||
@@ -24,6 +26,8 @@ def topology() -> Topology:
|
||||
|
||||
def test_filter_cycles_by_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
node1_id = NodeId()
|
||||
@@ -56,6 +60,8 @@ def test_filter_cycles_by_memory(
|
||||
|
||||
def test_filter_cycles_by_insufficient_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
node1_id = NodeId()
|
||||
@@ -84,6 +90,8 @@ def test_filter_cycles_by_insufficient_memory(
|
||||
|
||||
def test_filter_multiple_cycles_by_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
@@ -121,6 +129,8 @@ def test_filter_multiple_cycles_by_memory(
|
||||
|
||||
def test_get_smallest_cycles(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
@@ -159,6 +169,8 @@ def test_get_smallest_cycles(
|
||||
)
|
||||
def test_get_shard_assignments(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
available_memory: tuple[int, int, int],
|
||||
total_layers: int,
|
||||
expected_layers: tuple[int, int, int],
|
||||
@@ -186,6 +198,8 @@ def test_get_shard_assignments(
|
||||
pretty_name="Test Model",
|
||||
n_layers=total_layers,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
)
|
||||
cycles = topology.get_cycles()
|
||||
selected_cycle = cycles[0]
|
||||
@@ -218,6 +232,8 @@ def test_get_shard_assignments(
|
||||
|
||||
def test_get_hosts_from_subgraph(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
|
||||
):
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
@@ -232,10 +248,10 @@ def test_get_hosts_from_subgraph(
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
|
||||
topology.add_connection(create_connection(node_a_id, node_b_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_c_id))
|
||||
topology.add_connection(create_connection(node_c_id, node_a_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_a_id))
|
||||
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
|
||||
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
|
||||
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
|
||||
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
|
||||
|
||||
# act
|
||||
hosts = get_hosts_from_subgraph(topology)
|
||||
@@ -253,6 +269,8 @@ def test_get_hosts_from_subgraph(
|
||||
|
||||
def test_get_mlx_jaccl_coordinators(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
|
||||
):
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
@@ -263,12 +281,12 @@ def test_get_mlx_jaccl_coordinators(
|
||||
node_b = create_node(500 * 1024, node_b_id)
|
||||
node_c = create_node(1000 * 1024, node_c_id)
|
||||
|
||||
conn_a_b = create_connection(node_a_id, node_b_id)
|
||||
conn_b_a = create_connection(node_b_id, node_a_id)
|
||||
conn_b_c = create_connection(node_b_id, node_c_id)
|
||||
conn_c_b = create_connection(node_c_id, node_b_id)
|
||||
conn_c_a = create_connection(node_c_id, node_a_id)
|
||||
conn_a_c = create_connection(node_a_id, node_c_id)
|
||||
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
|
||||
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
|
||||
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
|
||||
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
|
||||
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
|
||||
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
|
||||
|
||||
# Update node profiles with network interfaces before adding to topology
|
||||
assert node_a.node_profile is not None
|
||||
@@ -283,11 +301,11 @@ def test_get_mlx_jaccl_coordinators(
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_a_b.sink_addr.ip,
|
||||
ip_address=conn_a_b.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_a_c.sink_addr.ip,
|
||||
ip_address=conn_a_c.send_back_multiaddr.ip_address,
|
||||
),
|
||||
],
|
||||
system=node_a.node_profile.system,
|
||||
@@ -300,11 +318,11 @@ def test_get_mlx_jaccl_coordinators(
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_b_a.sink_addr.ip,
|
||||
ip_address=conn_b_a.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_b_c.sink_addr.ip,
|
||||
ip_address=conn_b_c.send_back_multiaddr.ip_address,
|
||||
),
|
||||
],
|
||||
system=node_b.node_profile.system,
|
||||
@@ -317,11 +335,11 @@ def test_get_mlx_jaccl_coordinators(
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_c_b.sink_addr.ip,
|
||||
ip_address=conn_c_b.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_c_a.sink_addr.ip,
|
||||
ip_address=conn_c_a.send_back_multiaddr.ip_address,
|
||||
),
|
||||
],
|
||||
system=node_c.node_profile.system,
|
||||
@@ -371,11 +389,11 @@ def test_get_mlx_jaccl_coordinators(
|
||||
|
||||
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
|
||||
# node_b uses the IP from conn_b_a (node_b -> node_a)
|
||||
assert coordinators[node_b_id] == (f"{conn_b_a.sink_addr.ip}:5000"), (
|
||||
"node_b should use the IP from conn_b_a"
|
||||
)
|
||||
assert coordinators[node_b_id] == (
|
||||
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
|
||||
), "node_b should use the IP from conn_b_a"
|
||||
|
||||
# node_c uses the IP from conn_c_a (node_c -> node_a)
|
||||
assert coordinators[node_c_id] == (f"{conn_c_a.sink_addr.ip}:5000"), (
|
||||
"node_c should use the IP from conn_c_a"
|
||||
)
|
||||
assert coordinators[node_c_id] == (
|
||||
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
|
||||
), "node_c should use the IP from conn_c_a"
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from ipaddress import ip_address
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.routing.connection_message import SocketAddress
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryPerformanceProfile,
|
||||
NodePerformanceProfile,
|
||||
@@ -20,9 +18,9 @@ def topology() -> Topology:
|
||||
@pytest.fixture
|
||||
def connection() -> Connection:
|
||||
return Connection(
|
||||
source_id=NodeId(),
|
||||
sink_id=NodeId(),
|
||||
sink_addr=SocketAddress(ip=ip_address("127.0.0.1"), port=1235, zone_id=None),
|
||||
local_node_id=NodeId(),
|
||||
send_back_node_id=NodeId(),
|
||||
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
|
||||
connection_profile=ConnectionProfile(
|
||||
throughput=1000, latency=1000, jitter=1000
|
||||
),
|
||||
@@ -66,8 +64,12 @@ def test_add_connection(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
|
||||
topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
@@ -81,8 +83,12 @@ def test_update_node_profile(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
|
||||
topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
new_node_profile = NodePerformanceProfile(
|
||||
@@ -97,10 +103,12 @@ def test_update_node_profile(
|
||||
)
|
||||
|
||||
# act
|
||||
topology.update_node_profile(connection.source_id, node_profile=new_node_profile)
|
||||
topology.update_node_profile(
|
||||
connection.local_node_id, node_profile=new_node_profile
|
||||
)
|
||||
|
||||
# assert
|
||||
data = topology.get_node_profile(connection.source_id)
|
||||
data = topology.get_node_profile(connection.local_node_id)
|
||||
assert data == new_node_profile
|
||||
|
||||
|
||||
@@ -108,17 +116,21 @@ def test_update_connection_profile(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
|
||||
topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
new_connection_profile = ConnectionProfile(
|
||||
throughput=2000, latency=2000, jitter=2000
|
||||
)
|
||||
connection = Connection(
|
||||
source_id=connection.source_id,
|
||||
sink_id=connection.sink_id,
|
||||
sink_addr=connection.sink_addr,
|
||||
local_node_id=connection.local_node_id,
|
||||
send_back_node_id=connection.send_back_node_id,
|
||||
send_back_multiaddr=connection.send_back_multiaddr,
|
||||
connection_profile=new_connection_profile,
|
||||
)
|
||||
|
||||
@@ -134,8 +146,12 @@ def test_remove_connection_still_connected(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
|
||||
topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
@@ -149,23 +165,31 @@ def test_remove_node_still_connected(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
|
||||
topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
topology.remove_node(connection.source_id)
|
||||
topology.remove_node(connection.local_node_id)
|
||||
|
||||
# assert
|
||||
assert topology.get_node_profile(connection.source_id) is None
|
||||
assert topology.get_node_profile(connection.local_node_id) is None
|
||||
|
||||
|
||||
def test_list_nodes(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
|
||||
topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
@@ -175,6 +199,6 @@ def test_list_nodes(
|
||||
assert len(nodes) == 2
|
||||
assert all(isinstance(node, NodeInfo) for node in nodes)
|
||||
assert {node.node_id for node in nodes} == {
|
||||
connection.source_id,
|
||||
connection.sink_id,
|
||||
connection.local_node_id,
|
||||
connection.send_back_node_id,
|
||||
}
|
||||
|
||||
@@ -1,44 +1,37 @@
|
||||
from ipaddress import IPv4Address, IPv6Address, ip_address
|
||||
from enum import Enum
|
||||
|
||||
from exo_pyo3_bindings import RustConnectionMessage
|
||||
from pydantic import ConfigDict
|
||||
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
"""Serialisable types for Connection Updates/Messages"""
|
||||
|
||||
IpAddress = IPv4Address | IPv6Address
|
||||
|
||||
class ConnectionMessageType(Enum):
|
||||
Connected = 0
|
||||
Disconnected = 1
|
||||
|
||||
class SocketAddress(CamelCaseModel):
|
||||
# could be the python IpAddress type if we're feeling fancy
|
||||
ip: IpAddress
|
||||
port: int
|
||||
zone_id: int | None
|
||||
|
||||
model_config = ConfigDict(
|
||||
frozen=True,
|
||||
)
|
||||
@staticmethod
|
||||
def from_update_type(update_type: ConnectionUpdateType):
|
||||
match update_type:
|
||||
case ConnectionUpdateType.Connected:
|
||||
return ConnectionMessageType.Connected
|
||||
case ConnectionUpdateType.Disconnected:
|
||||
return ConnectionMessageType.Disconnected
|
||||
|
||||
|
||||
class ConnectionMessage(CamelCaseModel):
|
||||
node_id: NodeId
|
||||
ips: set[SocketAddress] | None
|
||||
connection_type: ConnectionMessageType
|
||||
remote_ipv4: str
|
||||
remote_tcp_port: int
|
||||
|
||||
@classmethod
|
||||
def from_rust(cls, message: RustConnectionMessage) -> "ConnectionMessage":
|
||||
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
||||
return cls(
|
||||
node_id=NodeId(str(message.endpoint_id)),
|
||||
ips=None
|
||||
if message.current_transport_addrs is None
|
||||
else set(
|
||||
# TODO: better handle fallible conversion
|
||||
SocketAddress(
|
||||
ip=ip_address(addr.ip_addr()),
|
||||
port=addr.port(),
|
||||
zone_id=addr.zone_id(),
|
||||
)
|
||||
for addr in message.current_transport_addrs
|
||||
),
|
||||
node_id=NodeId(update.peer_id.to_base58()),
|
||||
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
||||
remote_ipv4=update.remote_ipv4,
|
||||
remote_tcp_port=update.remote_tcp_port,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,6 @@ from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
@@ -14,15 +13,14 @@ from anyio import (
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
from exo_pyo3_bindings import (
|
||||
AllQueuesFullError,
|
||||
Keypair,
|
||||
RustNetworkingHandle,
|
||||
RustReceiver,
|
||||
RustSender,
|
||||
NetworkingHandle,
|
||||
NoPeersSubscribedToTopicError,
|
||||
)
|
||||
from filelock import FileLock
|
||||
from loguru import logger
|
||||
|
||||
from exo import __version__
|
||||
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
@@ -39,6 +37,7 @@ class TopicRouter[T: CamelCaseModel]:
|
||||
def __init__(
|
||||
self,
|
||||
topic: TypedTopic[T],
|
||||
networking_sender: Sender[tuple[str, bytes]],
|
||||
max_buffer_size: float = inf,
|
||||
):
|
||||
self.topic: TypedTopic[T] = topic
|
||||
@@ -46,41 +45,26 @@ class TopicRouter[T: CamelCaseModel]:
|
||||
send, recv = channel[T]()
|
||||
self.receiver: Receiver[T] = recv
|
||||
self._sender: Sender[T] = send
|
||||
self.networking_sender: RustSender | None = None
|
||||
self.networking_receiver: RustReceiver | None = None
|
||||
|
||||
self._tg: TaskGroup = create_task_group()
|
||||
self.networking_sender: Sender[tuple[str, bytes]] = networking_sender
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self.receive_loop)
|
||||
|
||||
async def receive_loop(self):
|
||||
logger.debug(f"Topic Router {self.topic} ready to send")
|
||||
with self.receiver as items:
|
||||
async for item in items:
|
||||
# Check if we should send to network
|
||||
if (
|
||||
self.topic.publish_policy is PublishPolicy.Always
|
||||
and self.networking_sender is not None
|
||||
len(self.senders) == 0
|
||||
and self.topic.publish_policy is PublishPolicy.Minimal
|
||||
):
|
||||
await self._send_out(item)
|
||||
continue
|
||||
if self.topic.publish_policy is PublishPolicy.Always:
|
||||
await self._send_out(item)
|
||||
# Then publish to all senders
|
||||
await self.publish(item)
|
||||
logger.debug(f"Shut down Topic Router {self.topic}")
|
||||
|
||||
async def net_receive_loop(self):
|
||||
assert self.networking_receiver is not None
|
||||
while True:
|
||||
item = self.topic.deserialize(await self.networking_receiver.receive())
|
||||
await self.publish(item)
|
||||
|
||||
def subscribe_with(self, net_send: RustSender, net_recv: RustReceiver):
|
||||
self.networking_sender = net_send
|
||||
self.networking_receiver = net_recv
|
||||
self._tg.start_soon(self.net_receive_loop)
|
||||
|
||||
async def shutdown(self):
|
||||
logger.debug(f"Shutting down Topic Router {self.topic}")
|
||||
# Close all the things!
|
||||
for sender in self.senders:
|
||||
sender.close()
|
||||
@@ -101,32 +85,43 @@ class TopicRouter[T: CamelCaseModel]:
|
||||
to_clear.add(sender)
|
||||
self.senders -= to_clear
|
||||
|
||||
async def publish_bytes(self, data: bytes):
|
||||
await self.publish(self.topic.deserialize(data))
|
||||
|
||||
def new_sender(self) -> Sender[T]:
|
||||
return self._sender.clone()
|
||||
|
||||
async def _send_out(self, item: T):
|
||||
assert self.networking_sender is not None
|
||||
logger.trace(f"TopicRouter {self.topic.topic} sending {item}")
|
||||
await self.networking_sender.send(self.topic.serialize(item))
|
||||
await self.networking_sender.send(
|
||||
(str(self.topic.topic), self.topic.serialize(item))
|
||||
)
|
||||
|
||||
|
||||
class Router:
|
||||
@classmethod
|
||||
async def create(cls, identity: Keypair) -> "Router":
|
||||
return cls(handle=await RustNetworkingHandle.create(identity, __version__))
|
||||
def create(cls, identity: Keypair) -> "Router":
|
||||
return cls(handle=NetworkingHandle(identity))
|
||||
|
||||
def __init__(self, handle: RustNetworkingHandle):
|
||||
def __init__(self, handle: NetworkingHandle):
|
||||
self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {}
|
||||
self._unsubbed: list[str] = []
|
||||
self._net: RustNetworkingHandle = handle
|
||||
send, recv = channel[tuple[str, bytes]]()
|
||||
self.networking_receiver: Receiver[tuple[str, bytes]] = recv
|
||||
self._net: NetworkingHandle = handle
|
||||
self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send
|
||||
self._id_count = count()
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
|
||||
assert self._tg is None, "Attempted to register topic after setup time"
|
||||
router = TopicRouter[T](topic)
|
||||
send = self._tmp_networking_sender
|
||||
if send:
|
||||
self._tmp_networking_sender = None
|
||||
else:
|
||||
send = self.networking_receiver.clone_sender()
|
||||
router = TopicRouter[T](topic, send)
|
||||
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
|
||||
self._unsubbed.append(topic.topic)
|
||||
await self._networking_subscribe(str(topic.topic))
|
||||
|
||||
def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
|
||||
router = self.topic_routers.get(topic.topic, None)
|
||||
@@ -156,9 +151,13 @@ class Router:
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
|
||||
async def shutdown(self):
|
||||
logger.debug("Shutting down Router")
|
||||
@@ -166,33 +165,48 @@ class Router:
|
||||
return
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _networking_recv_connection_messages(self):
|
||||
recv = await self._net.get_connection_receiver()
|
||||
async def _networking_subscribe(self, topic: str):
|
||||
logger.info(f"Subscribing to {topic}")
|
||||
await self._net.gossipsub_subscribe(topic)
|
||||
|
||||
async def _networking_unsubscribe(self, topic: str):
|
||||
logger.info(f"Unsubscribing from {topic}")
|
||||
await self._net.gossipsub_unsubscribe(topic)
|
||||
|
||||
async def _networking_recv(self):
|
||||
while True:
|
||||
message = await recv.receive()
|
||||
await anyio.sleep(0.2)
|
||||
topic, data = await self._net.gossipsub_recv()
|
||||
logger.trace(f"Received message on {topic} with payload {data}")
|
||||
if topic not in self.topic_routers:
|
||||
logger.warning(f"Received message on unknown or inactive topic {topic}")
|
||||
continue
|
||||
|
||||
router = self.topic_routers[topic]
|
||||
await router.publish_bytes(data)
|
||||
|
||||
async def _networking_recv_connection_messages(self):
|
||||
while True:
|
||||
update = await self._net.connection_update_recv()
|
||||
message = ConnectionMessage.from_update(update)
|
||||
logger.trace(
|
||||
f"Received message on connection_messages with payload {message}"
|
||||
)
|
||||
to_clear: list[str] = []
|
||||
for topic in self._unsubbed:
|
||||
try:
|
||||
rsend, rrecv = await self._net.subscribe(topic)
|
||||
logger.info(f"Subscribed to peer on {topic}")
|
||||
to_clear.append(topic)
|
||||
self.topic_routers[topic].subscribe_with(rsend, rrecv)
|
||||
# TODO: real error
|
||||
except RuntimeError:
|
||||
pass
|
||||
if to_clear:
|
||||
assert to_clear == self._unsubbed
|
||||
self._unsubbed = [i for i in self._unsubbed if i not in to_clear]
|
||||
|
||||
if CONNECTION_MESSAGES.topic in self.topic_routers:
|
||||
router = self.topic_routers[CONNECTION_MESSAGES.topic]
|
||||
assert router.topic.model_type == ConnectionMessage
|
||||
router = cast(TopicRouter[ConnectionMessage], router)
|
||||
await router.publish(ConnectionMessage.from_rust(message))
|
||||
await router.publish(message)
|
||||
|
||||
async def _networking_publish(self):
|
||||
with self.networking_receiver as networked_items:
|
||||
async for topic, data in networked_items:
|
||||
try:
|
||||
logger.trace(f"Sending message on {topic} with payload {data}")
|
||||
await self._net.gossipsub_publish(topic, data)
|
||||
# As a hack, this also catches AllQueuesFull
|
||||
# Need to fix that ASAP.
|
||||
except (NoPeersSubscribedToTopicError, AllQueuesFullError):
|
||||
pass
|
||||
|
||||
|
||||
def get_node_id_keypair(
|
||||
@@ -211,16 +225,16 @@ def get_node_id_keypair(
|
||||
with open(path, "a+b") as f: # opens in append-mode => starts at EOF
|
||||
# if non-zero EOF, then file exists => use to get node-ID
|
||||
if f.tell() != 0:
|
||||
f.seek(0) # go to start & read postcard-encoded bytes
|
||||
postcard_encoded = f.read()
|
||||
f.seek(0) # go to start & read protobuf-encoded bytes
|
||||
protobuf_encoded = f.read()
|
||||
|
||||
try: # if decoded successfully, save & return
|
||||
return Keypair.from_postcard_encoding(postcard_encoded)
|
||||
return Keypair.from_protobuf_encoding(protobuf_encoded)
|
||||
except ValueError as e: # on runtime error, assume corrupt file
|
||||
logger.warning(f"Encountered error when trying to get keypair: {e}")
|
||||
|
||||
# if no valid credentials, create new ones and persist
|
||||
with open(path, "w+b") as f:
|
||||
keypair = Keypair.generate_ed25519()
|
||||
f.write(keypair.to_postcard_encoding())
|
||||
f.write(keypair.to_protobuf_encoding())
|
||||
return keypair
|
||||
|
||||
@@ -13,6 +13,8 @@ from exo.utils.pydantic_ext import CamelCaseModel
|
||||
class PublishPolicy(str, Enum):
|
||||
Never = "Never"
|
||||
"""Never publish to the network - this is a local message"""
|
||||
Minimal = "Minimal"
|
||||
"""Only publish when there is no local receiver for this type of message"""
|
||||
Always = "Always"
|
||||
"""Always publish to the network"""
|
||||
|
||||
|
||||
@@ -164,38 +164,28 @@ class Election:
|
||||
self._candidates.append(message)
|
||||
|
||||
async def _connection_receiver(self) -> None:
|
||||
current_peers: set[NodeId] = set()
|
||||
with self._cm_receiver as connection_messages:
|
||||
async for first in connection_messages:
|
||||
if first.node_id not in current_peers or first.ips is None:
|
||||
if first.node_id not in current_peers:
|
||||
current_peers.add(first.node_id)
|
||||
if first.ips is None:
|
||||
current_peers.remove(first.node_id)
|
||||
# Delay after connection message for time to symmetrically setup
|
||||
await anyio.sleep(0.2)
|
||||
rest = connection_messages.collect()
|
||||
for msg in rest:
|
||||
if msg.node_id not in current_peers:
|
||||
current_peers.add(first.node_id)
|
||||
if msg.ips is None:
|
||||
current_peers.remove(first.node_id)
|
||||
# Delay after connection message for time to symmetrically setup
|
||||
await anyio.sleep(0.2)
|
||||
rest = connection_messages.collect()
|
||||
|
||||
logger.info(
|
||||
f"Connection messages received: {first} followed by {rest}"
|
||||
)
|
||||
logger.info(f"Current clock: {self.clock}")
|
||||
# These messages are strictly peer to peer
|
||||
self.clock += 1
|
||||
logger.info(f"New clock: {self.clock}")
|
||||
candidates: list[ElectionMessage] = []
|
||||
self._candidates = candidates
|
||||
logger.info("Starting new campaign")
|
||||
assert self._tg is not None
|
||||
self._tg.start_soon(
|
||||
self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT
|
||||
)
|
||||
logger.info("Campaign started")
|
||||
logger.debug(
|
||||
f"Connection messages received: {first} followed by {rest}"
|
||||
)
|
||||
logger.debug(f"Current clock: {self.clock}")
|
||||
# These messages are strictly peer to peer
|
||||
self.clock += 1
|
||||
logger.debug(f"New clock: {self.clock}")
|
||||
assert self._tg is not None
|
||||
candidates: list[ElectionMessage] = []
|
||||
self._candidates = candidates
|
||||
logger.debug("Starting new campaign")
|
||||
self._tg.start_soon(
|
||||
self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT
|
||||
)
|
||||
logger.debug("Campaign started")
|
||||
logger.debug("Connection message added")
|
||||
|
||||
async def _command_counter(self) -> None:
|
||||
with self._co_receiver as commands:
|
||||
|
||||
@@ -51,6 +51,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="DeepSeek V3.1 (4-bit)",
|
||||
storage_size=Memory.from_gb(378),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"deepseek-v3.1-8bit": ModelCard(
|
||||
@@ -64,6 +66,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="DeepSeek V3.1 (8-bit)",
|
||||
storage_size=Memory.from_gb(713),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "deepseek-v3.2": ModelCard(
|
||||
@@ -135,6 +139,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Kimi K2 Instruct (4-bit)",
|
||||
storage_size=Memory.from_gb(578),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"kimi-k2-thinking": ModelCard(
|
||||
@@ -148,6 +154,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Kimi K2 Thinking (4-bit)",
|
||||
storage_size=Memory.from_gb(658),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# llama-3.1
|
||||
@@ -162,6 +170,38 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.1 8B (4-bit)",
|
||||
storage_size=Memory.from_mb(4423),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.1-8b-8bit": ModelCard(
|
||||
short_id="llama-3.1-8b-8bit",
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
name="Llama 3.1 8B (8-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
pretty_name="Llama 3.1 8B (8-bit)",
|
||||
storage_size=Memory.from_mb(8540),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.1-8b-bf16": ModelCard(
|
||||
short_id="llama-3.1-8b-bf16",
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
name="Llama 3.1 8B (BF16)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
pretty_name="Llama 3.1 8B (BF16)",
|
||||
storage_size=Memory.from_mb(16100),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.1-70b": ModelCard(
|
||||
@@ -175,6 +215,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.1 70B (4-bit)",
|
||||
storage_size=Memory.from_mb(38769),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# llama-3.2
|
||||
@@ -189,6 +231,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.2 1B (4-bit)",
|
||||
storage_size=Memory.from_mb(696),
|
||||
n_layers=16,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.2-3b": ModelCard(
|
||||
@@ -202,6 +246,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.2 3B (4-bit)",
|
||||
storage_size=Memory.from_mb(1777),
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.2-3b-8bit": ModelCard(
|
||||
@@ -215,6 +261,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.2 3B (8-bit)",
|
||||
storage_size=Memory.from_mb(3339),
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# llama-3.3
|
||||
@@ -229,6 +277,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.3 70B",
|
||||
storage_size=Memory.from_mb(38769),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.3-70b-8bit": ModelCard(
|
||||
@@ -242,6 +292,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.3 70B (8-bit)",
|
||||
storage_size=Memory.from_mb(73242),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.3-70b-fp16": ModelCard(
|
||||
@@ -255,20 +307,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.3 70B (FP16)",
|
||||
storage_size=Memory.from_mb(137695),
|
||||
n_layers=80,
|
||||
),
|
||||
),
|
||||
# phi-3
|
||||
"phi-3-mini": ModelCard(
|
||||
short_id="phi-3-mini",
|
||||
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
|
||||
name="Phi 3 Mini 128k (4-bit)",
|
||||
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
|
||||
pretty_name="Phi 3 Mini 128k (4-bit)",
|
||||
storage_size=Memory.from_mb(2099),
|
||||
n_layers=32,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# qwen3
|
||||
@@ -283,6 +323,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 0.6B (4-bit)",
|
||||
storage_size=Memory.from_mb(327),
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
),
|
||||
),
|
||||
"qwen3-0.6b-8bit": ModelCard(
|
||||
@@ -296,6 +338,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 0.6B (8-bit)",
|
||||
storage_size=Memory.from_mb(666),
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
),
|
||||
),
|
||||
"qwen3-30b": ModelCard(
|
||||
@@ -309,6 +353,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 30B A3B (4-bit)",
|
||||
storage_size=Memory.from_mb(16797),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-30b-8bit": ModelCard(
|
||||
@@ -322,6 +368,68 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 30B A3B (8-bit)",
|
||||
storage_size=Memory.from_mb(31738),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-4bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-4bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
name="Qwen3 80B A3B (4-bit)",
|
||||
description="""Qwen3 80B""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
pretty_name="Qwen3 80B A3B (4-bit)",
|
||||
storage_size=Memory.from_mb(44800),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-8bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-8bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
name="Qwen3 80B A3B (8-bit)",
|
||||
description="""Qwen3 80B""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
pretty_name="Qwen3 80B A3B (8-bit)",
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-4bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-thinking-4bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
name="Qwen3 80B A3B Thinking (4-bit)",
|
||||
description="""Qwen3 80B Reasoning model""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
pretty_name="Qwen3 80B A3B (4-bit)",
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-8bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-thinking-8bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
name="Qwen3 80B A3B Thinking (8-bit)",
|
||||
description="""Qwen3 80B Reasoning model""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
pretty_name="Qwen3 80B A3B (8-bit)",
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-235b-a22b-4bit": ModelCard(
|
||||
@@ -335,6 +443,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 235B A22B (4-bit)",
|
||||
storage_size=Memory.from_gb(132),
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-235b-a22b-8bit": ModelCard(
|
||||
@@ -348,6 +458,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 235B A22B (8-bit)",
|
||||
storage_size=Memory.from_gb(250),
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-coder-480b-a35b-4bit": ModelCard(
|
||||
@@ -361,6 +473,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
|
||||
storage_size=Memory.from_gb(270),
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-coder-480b-a35b-8bit": ModelCard(
|
||||
@@ -374,77 +488,84 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
|
||||
storage_size=Memory.from_gb(540),
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# granite
|
||||
"granite-3.3-2b": ModelCard(
|
||||
short_id="granite-3.3-2b",
|
||||
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
|
||||
name="Granite 3.3 2B (FP16)",
|
||||
description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
|
||||
# gpt-oss
|
||||
"gpt-oss-120b-MXFP4-Q8": ModelCard(
|
||||
short_id="gpt-oss-120b-MXFP4-Q8",
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
|
||||
pretty_name="Granite 3.3 2B (FP16)",
|
||||
storage_size=Memory.from_mb(4951),
|
||||
n_layers=40,
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
storage_size=Memory.from_kb(68_996_301),
|
||||
n_layers=36,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "granite-3.3-8b": ModelCard(
|
||||
# short_id="granite-3.3-8b",
|
||||
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
|
||||
# name="Granite 3.3 8B",
|
||||
# description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
|
||||
"gpt-oss-20b-4bit": ModelCard(
|
||||
short_id="gpt-oss-20b-4bit",
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
storage_size=Memory.from_kb(11_744_051),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# Needs to be quantized g32 or g16.
|
||||
"glm-4.5-air-8bit": ModelCard(
|
||||
short_id="glm-4.5-air-8bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
name="GLM 4.5 Air 8bit",
|
||||
description="""GLM 4.5 Air 8bit""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
pretty_name="GLM 4.5 Air 8bit",
|
||||
storage_size=Memory.from_gb(114),
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=False,
|
||||
),
|
||||
),
|
||||
"glm-4.5-air-bf16": ModelCard(
|
||||
short_id="glm-4.5-air-bf16",
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
name="GLM 4.5 Air bf16",
|
||||
description="""GLM 4.5 Air bf16""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
pretty_name="GLM 4.5 Air bf16",
|
||||
storage_size=Memory.from_gb(214),
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
|
||||
# short_id="devstral-2-123b-instruct-2512-8bit",
|
||||
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
|
||||
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
|
||||
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
|
||||
# pretty_name="Granite 3.3 8B",
|
||||
# storage_size=Memory.from_kb(15958720),
|
||||
# n_layers=40,
|
||||
# ),
|
||||
# ),
|
||||
# smol-lm
|
||||
# "smol-lm-135m": ModelCard(
|
||||
# short_id="smol-lm-135m",
|
||||
# model_id="mlx-community/SmolLM-135M-4bit",
|
||||
# name="Smol LM 135M",
|
||||
# description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
|
||||
# pretty_name="Smol LM 135M",
|
||||
# storage_size=Memory.from_kb(73940),
|
||||
# n_layers=30,
|
||||
# ),
|
||||
# ),
|
||||
# gpt-oss
|
||||
# "gpt-oss-120b-MXFP4-Q8": ModelCard(
|
||||
# short_id="gpt-oss-120b-MXFP4-Q8",
|
||||
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
# name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
# description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
# pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
# storage_size=Memory.from_kb(68_996_301),
|
||||
# n_layers=36,
|
||||
# hidden_size=2880,
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
# "gpt-oss-20b-4bit": ModelCard(
|
||||
# short_id="gpt-oss-20b-4bit",
|
||||
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
# name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
# description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
# pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
# storage_size=Memory.from_kb(11_744_051),
|
||||
# n_layers=24,
|
||||
# hidden_size=2880,
|
||||
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
|
||||
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
|
||||
# storage_size=Memory.from_kb(133_000_000),
|
||||
# n_layers=88,
|
||||
# hidden_size=12288,
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
|
||||
@@ -6,6 +6,7 @@ from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.worker.download.download_utils import (
|
||||
@@ -25,6 +26,7 @@ class ConfigData(BaseModel):
|
||||
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
|
||||
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
|
||||
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
|
||||
@property
|
||||
def layer_count(self) -> int:
|
||||
@@ -106,10 +108,19 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
|
||||
config_data = await get_config_data(model_id)
|
||||
num_layers = config_data.layer_count
|
||||
mem_size_bytes = await get_safetensors_size(model_id)
|
||||
model_card = next(
|
||||
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
|
||||
None,
|
||||
)
|
||||
|
||||
return ModelMetadata(
|
||||
model_id=ModelId(model_id),
|
||||
pretty_name=model_id,
|
||||
pretty_name=model_card.name if model_card is not None else model_id,
|
||||
storage_size=mem_size_bytes,
|
||||
n_layers=num_layers,
|
||||
hidden_size=config_data.hidden_size or 0,
|
||||
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
|
||||
supports_tensor=model_card.metadata.supports_tensor
|
||||
if model_card is not None
|
||||
else False,
|
||||
)
|
||||
|
||||
@@ -36,6 +36,8 @@ def get_pipeline_shard_metadata(
|
||||
pretty_name=str(model_id),
|
||||
storage_size=Memory.from_mb(100000),
|
||||
n_layers=32,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
),
|
||||
device_rank=device_rank,
|
||||
world_size=world_size,
|
||||
|
||||
@@ -19,7 +19,7 @@ def test_apply_node_download_progress():
|
||||
NodeDownloadProgress(download_progress=event), state
|
||||
)
|
||||
|
||||
assert new_state == State(downloads={NodeId("node-1"): [event]})
|
||||
assert new_state.downloads == {NodeId("node-1"): [event]}
|
||||
|
||||
|
||||
def test_apply_two_node_download_progress():
|
||||
@@ -42,4 +42,4 @@ def test_apply_two_node_download_progress():
|
||||
# TODO: This test is failing. We should support the following:
|
||||
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
|
||||
# 2. Downloading a model, it completes, then downloading a different model on the same node.
|
||||
assert new_state == State(downloads={NodeId("node-1"): [event1, event2]})
|
||||
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from anyio import create_task_group, fail_after, move_on_after
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.shared.election import Election, ElectionMessage, ElectionResult
|
||||
from exo.shared.types.commands import ForwarderCommand, TestCommand
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
@@ -330,7 +330,9 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
|
||||
await cm_tx.send(
|
||||
ConnectionMessage(
|
||||
node_id=NodeId(),
|
||||
ips=set(),
|
||||
connection_type=ConnectionMessageType.Connected,
|
||||
remote_ipv4="",
|
||||
remote_tcp_port=0,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
|
||||
sem.release()
|
||||
# wait to be told to begin simultaneous read
|
||||
ev.wait()
|
||||
queue.put(get_node_id_keypair().to_postcard_encoding())
|
||||
queue.put(get_node_id_keypair().to_protobuf_encoding())
|
||||
|
||||
|
||||
def _get_keypair_concurrent(num_procs: int) -> bytes:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from ipaddress import ip_address
|
||||
|
||||
from exo.routing.connection_message import SocketAddress
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.topology import Connection
|
||||
|
||||
@@ -14,9 +12,9 @@ def test_state_serialization_roundtrip() -> None:
|
||||
node_b = NodeId("node-b")
|
||||
|
||||
connection = Connection(
|
||||
sink_id=node_a,
|
||||
source_id=node_b,
|
||||
sink_addr=SocketAddress(ip=ip_address("127.0.0.1"), port=5354, zone_id=None),
|
||||
local_node_id=node_a,
|
||||
send_back_node_id=node_b,
|
||||
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
|
||||
)
|
||||
|
||||
state = State()
|
||||
|
||||
@@ -81,16 +81,16 @@ class Topology:
|
||||
self,
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
if connection.source_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(NodeInfo(node_id=connection.source_id))
|
||||
if connection.sink_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(NodeInfo(node_id=connection.sink_id))
|
||||
if connection.local_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(NodeInfo(node_id=connection.local_node_id))
|
||||
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
|
||||
|
||||
if connection in self._edge_id_to_rx_id_map:
|
||||
return
|
||||
|
||||
src_id = self._node_id_to_rx_id_map[connection.source_id]
|
||||
sink_id = self._node_id_to_rx_id_map[connection.sink_id]
|
||||
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
|
||||
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
|
||||
|
||||
rx_id = self._graph.add_edge(src_id, sink_id, connection)
|
||||
self._edge_id_to_rx_id_map[connection] = rx_id
|
||||
@@ -132,7 +132,10 @@ class Topology:
|
||||
return
|
||||
|
||||
for connection in self.list_connections():
|
||||
if connection.source_id == node_id or connection.sink_id == node_id:
|
||||
if (
|
||||
connection.local_node_id == node_id
|
||||
or connection.send_back_node_id == node_id
|
||||
):
|
||||
self.remove_connection(connection)
|
||||
|
||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
||||
@@ -185,7 +188,10 @@ class Topology:
|
||||
for rx_idx in rx_idxs:
|
||||
topology.add_node(self._graph[rx_idx])
|
||||
for connection in self.list_connections():
|
||||
if connection.source_id in node_idxs and connection.sink_id in node_idxs:
|
||||
if (
|
||||
connection.local_node_id in node_idxs
|
||||
and connection.send_back_node_id in node_idxs
|
||||
):
|
||||
topology.add_connection(connection)
|
||||
return topology
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
@@ -174,6 +174,7 @@ class DeleteInstanceTaskParams(BaseModel):
|
||||
class CreateInstanceResponse(BaseModel):
|
||||
message: str
|
||||
command_id: CommandId
|
||||
model_meta: ModelMetadata
|
||||
|
||||
|
||||
class DeleteInstanceResponse(BaseModel):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Self
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import ChatCompletionTaskParams
|
||||
@@ -28,17 +26,6 @@ class PlaceInstance(BaseCommand):
|
||||
instance_meta: InstanceMeta
|
||||
min_nodes: int
|
||||
|
||||
# Decision point - I like this syntax better than the typical fixtures,
|
||||
# but it's """bloat"""
|
||||
@classmethod
|
||||
def fixture(cls) -> Self:
|
||||
return cls(
|
||||
model_meta=ModelMetadata.fixture(),
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
min_nodes=1,
|
||||
)
|
||||
|
||||
|
||||
class CreateInstance(BaseCommand):
|
||||
instance: Instance
|
||||
|
||||
@@ -22,8 +22,7 @@ class EventId(Id):
|
||||
class BaseEvent(TaggedModel):
|
||||
event_id: EventId = Field(default_factory=EventId)
|
||||
# Internal, for debugging. Please don't rely on this field for anything!
|
||||
_master_time_stamp: datetime | None = None
|
||||
retry: int | None = None
|
||||
_master_time_stamp: None | datetime = None
|
||||
|
||||
|
||||
class TestEvent(BaseEvent):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Self
|
||||
|
||||
from pydantic import PositiveInt
|
||||
|
||||
from exo.shared.types.common import Id
|
||||
@@ -16,12 +14,5 @@ class ModelMetadata(CamelCaseModel):
|
||||
pretty_name: str
|
||||
storage_size: Memory
|
||||
n_layers: PositiveInt
|
||||
|
||||
@classmethod
|
||||
def fixture(cls) -> Self:
|
||||
return cls(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
pretty_name="Llama 3.2 1B",
|
||||
n_layers=16,
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
)
|
||||
hidden_size: PositiveInt
|
||||
supports_tensor: bool
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import Self
|
||||
|
||||
import psutil
|
||||
|
||||
from exo.routing.connection_message import IpAddress
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
@@ -50,7 +49,7 @@ class SystemPerformanceProfile(CamelCaseModel):
|
||||
|
||||
class NetworkInterfaceInfo(CamelCaseModel):
|
||||
name: str
|
||||
ip_address: IpAddress
|
||||
ip_address: str
|
||||
|
||||
|
||||
class NodePerformanceProfile(CamelCaseModel):
|
||||
|
||||
@@ -40,6 +40,10 @@ class LoadModel(BaseTask): # emitted by Worker
|
||||
pass
|
||||
|
||||
|
||||
class ConnectToGroup(BaseTask): # emitted by Worker
|
||||
pass
|
||||
|
||||
|
||||
class StartWarmup(BaseTask): # emitted by Worker
|
||||
pass
|
||||
|
||||
@@ -57,5 +61,11 @@ class Shutdown(BaseTask): # emitted by Worker
|
||||
|
||||
|
||||
Task = (
|
||||
CreateRunner | DownloadModel | LoadModel | StartWarmup | ChatCompletion | Shutdown
|
||||
CreateRunner
|
||||
| DownloadModel
|
||||
| ConnectToGroup
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| ChatCompletion
|
||||
| Shutdown
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from exo.routing.connection_message import IpAddress
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
@@ -10,17 +10,17 @@ class NodeInfo(CamelCaseModel):
|
||||
|
||||
|
||||
class Connection(CamelCaseModel):
|
||||
source_id: NodeId
|
||||
sink_id: NodeId
|
||||
sink_addr: IpAddress
|
||||
local_node_id: NodeId
|
||||
send_back_node_id: NodeId
|
||||
send_back_multiaddr: Multiaddr
|
||||
connection_profile: ConnectionProfile | None = None
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
self.source_id,
|
||||
self.sink_id,
|
||||
self.sink_addr,
|
||||
self.local_node_id,
|
||||
self.send_back_node_id,
|
||||
self.send_back_multiaddr.address,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -28,10 +28,10 @@ class Connection(CamelCaseModel):
|
||||
if not isinstance(other, Connection):
|
||||
raise ValueError("Cannot compare Connection with non-Connection")
|
||||
return (
|
||||
self.source_id == other.source_id
|
||||
and self.sink_id == other.sink_id
|
||||
and self.sink_addr == other.sink_addr
|
||||
self.local_node_id == other.local_node_id
|
||||
and self.send_back_node_id == other.send_back_node_id
|
||||
and self.send_back_multiaddr == other.send_back_multiaddr
|
||||
)
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
return str(self.sink_addr).startswith("169.254")
|
||||
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
|
||||
|
||||
@@ -25,7 +25,8 @@ class BaseInstance(TaggedModel):
|
||||
|
||||
|
||||
class MlxRingInstance(BaseInstance):
|
||||
hosts: list[Host]
|
||||
hosts_by_node: dict[NodeId, list[Host]]
|
||||
ephemeral_port: int
|
||||
|
||||
|
||||
class MlxJacclInstance(BaseInstance):
|
||||
|
||||
@@ -21,7 +21,15 @@ class BaseRunnerStatus(TaggedModel):
|
||||
return isinstance(self, RunnerRunning)
|
||||
|
||||
|
||||
class RunnerWaitingForModel(BaseRunnerStatus):
|
||||
class RunnerIdle(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class RunnerConnecting(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class RunnerConnected(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
@@ -54,7 +62,9 @@ class RunnerFailed(BaseRunnerStatus):
|
||||
|
||||
|
||||
RunnerStatus = (
|
||||
RunnerWaitingForModel
|
||||
RunnerIdle
|
||||
| RunnerConnecting
|
||||
| RunnerConnected
|
||||
| RunnerLoading
|
||||
| RunnerLoaded
|
||||
| RunnerWarmingUp
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, TopologyEdgeCreated
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.topology import Connection
|
||||
|
||||
|
||||
def check_connections(
|
||||
local_id: NodeId, msg: ConnectionMessage, state: State
|
||||
) -> list[Event]:
|
||||
remote_id = msg.node_id
|
||||
sockets = msg.ips
|
||||
if (
|
||||
not state.topology.contains_node(remote_id)
|
||||
or remote_id not in state.node_profiles
|
||||
):
|
||||
return []
|
||||
|
||||
out: list[Event] = []
|
||||
conns = list(state.topology.list_connections())
|
||||
for iface in state.node_profiles[remote_id].network_interfaces:
|
||||
if sockets is None:
|
||||
continue
|
||||
for sock in sockets:
|
||||
if iface.ip_address == sock.ip:
|
||||
conn = Connection(source_id=local_id, sink_id=remote_id, sink_addr=sock)
|
||||
if state.topology.contains_connection(conn):
|
||||
conns.remove(conn)
|
||||
continue
|
||||
out.append(TopologyEdgeCreated(edge=conn))
|
||||
|
||||
return out
|
||||
@@ -95,7 +95,15 @@ def extract_layer_num(tensor_name: str) -> int | None:
|
||||
|
||||
def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list[str]:
|
||||
default_patterns = set(
|
||||
["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt", "*.jinja"]
|
||||
[
|
||||
"*.json",
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"tiktoken.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
"*.jinja",
|
||||
]
|
||||
)
|
||||
shard_specific_patterns: set[str] = set()
|
||||
if weight_map:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import copy
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable
|
||||
@@ -12,7 +13,7 @@ from exo.shared.types.worker.shards import (
|
||||
from exo.worker.download.download_utils import RepoDownloadProgress
|
||||
|
||||
|
||||
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Shoudl this be a classmethod?
|
||||
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
|
||||
class ShardDownloader(ABC):
|
||||
@abstractmethod
|
||||
async def ensure_shard(
|
||||
@@ -43,34 +44,7 @@ class ShardDownloader(ABC):
|
||||
Yields:
|
||||
tuple[Path, RepoDownloadProgress]: The path and progress of a shard download.
|
||||
"""
|
||||
yield (
|
||||
Path("/tmp/noop_shard"),
|
||||
RepoDownloadProgress(
|
||||
repo_id="noop",
|
||||
repo_revision="noop",
|
||||
shard=PipelineShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("noop"),
|
||||
pretty_name="noope",
|
||||
storage_size=Memory.from_bytes(0),
|
||||
n_layers=1,
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
end_layer=1,
|
||||
n_layers=1,
|
||||
),
|
||||
completed_files=0,
|
||||
total_files=0,
|
||||
downloaded_bytes=Memory.from_bytes(0),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(0),
|
||||
total_bytes=Memory.from_bytes(0),
|
||||
overall_speed=0,
|
||||
overall_eta=timedelta(seconds=0),
|
||||
status="complete",
|
||||
),
|
||||
)
|
||||
yield (Path("/tmp/noop_shard"), NOOP_DOWNLOAD_PROGRESS)
|
||||
|
||||
@abstractmethod
|
||||
async def get_shard_download_status_for_shard(
|
||||
@@ -94,46 +68,41 @@ class NoopShardDownloader(ShardDownloader):
|
||||
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
|
||||
yield (
|
||||
Path("/tmp/noop_shard"),
|
||||
RepoDownloadProgress(
|
||||
repo_id="noop",
|
||||
repo_revision="noop",
|
||||
shard=PipelineShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("noop"),
|
||||
pretty_name="noope",
|
||||
storage_size=Memory.from_bytes(0),
|
||||
n_layers=1,
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
end_layer=1,
|
||||
n_layers=1,
|
||||
),
|
||||
completed_files=0,
|
||||
total_files=0,
|
||||
downloaded_bytes=Memory.from_bytes(0),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(0),
|
||||
total_bytes=Memory.from_bytes(0),
|
||||
overall_speed=0,
|
||||
overall_eta=timedelta(seconds=0),
|
||||
status="complete",
|
||||
),
|
||||
NOOP_DOWNLOAD_PROGRESS,
|
||||
)
|
||||
|
||||
async def get_shard_download_status_for_shard(
|
||||
self, shard: ShardMetadata
|
||||
) -> RepoDownloadProgress:
|
||||
return RepoDownloadProgress(
|
||||
repo_id="noop",
|
||||
repo_revision="noop",
|
||||
shard=shard,
|
||||
completed_files=0,
|
||||
total_files=0,
|
||||
downloaded_bytes=Memory.from_bytes(0),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(0),
|
||||
total_bytes=Memory.from_bytes(0),
|
||||
overall_speed=0,
|
||||
overall_eta=timedelta(seconds=0),
|
||||
status="complete",
|
||||
)
|
||||
dp = copy(NOOP_DOWNLOAD_PROGRESS)
|
||||
dp.shard = shard
|
||||
return dp
|
||||
|
||||
|
||||
NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
|
||||
repo_id="noop",
|
||||
repo_revision="noop",
|
||||
shard=PipelineShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("noop"),
|
||||
pretty_name="noope",
|
||||
storage_size=Memory.from_bytes(0),
|
||||
n_layers=1,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
end_layer=1,
|
||||
n_layers=1,
|
||||
),
|
||||
completed_files=0,
|
||||
total_files=0,
|
||||
downloaded_bytes=Memory.from_bytes(0),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(0),
|
||||
total_bytes=Memory.from_bytes(0),
|
||||
overall_speed=0,
|
||||
overall_eta=timedelta(seconds=0),
|
||||
status="complete",
|
||||
)
|
||||
|
||||
@@ -10,7 +10,6 @@ KEEP_KV_SIZE: int | None = 1600
|
||||
QUANTIZE_MODEL_MODE: str | None = "affine"
|
||||
CACHE_GROUP_SIZE: int = 64
|
||||
KV_CACHE_BITS: int | None = 8
|
||||
TEMPERATURE: float = 1.0
|
||||
|
||||
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
|
||||
TRUST_REMOTE_CODE: bool = True
|
||||
|
||||
@@ -13,7 +13,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
CACHE_GROUP_SIZE,
|
||||
KV_CACHE_BITS,
|
||||
TEMPERATURE,
|
||||
TRUST_REMOTE_CODE,
|
||||
)
|
||||
|
||||
@@ -21,6 +20,8 @@ try:
|
||||
from mlx_lm.tokenizer_utils import load_tokenizer
|
||||
except ImportError:
|
||||
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
|
||||
import contextlib
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.utils import load_model
|
||||
@@ -48,6 +49,7 @@ from exo.worker.engines.mlx.auto_parallel import (
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
Group = mx.distributed.Group
|
||||
# Needed for 8 bit model
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
|
||||
|
||||
@@ -67,7 +69,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
)
|
||||
|
||||
|
||||
def mx_barrier(group: mx.distributed.Group | None = None):
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0),
|
||||
@@ -77,7 +79,7 @@ def mx_barrier(group: mx.distributed.Group | None = None):
|
||||
)
|
||||
|
||||
|
||||
def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None):
|
||||
def broadcast_from_zero(value: int, group: Group | None = None):
|
||||
if group is None:
|
||||
return value
|
||||
|
||||
@@ -99,91 +101,97 @@ class HostList(RootModel[list[str]]):
|
||||
|
||||
def mlx_distributed_init(
|
||||
bound_instance: BoundInstance,
|
||||
) -> mx.distributed.Group:
|
||||
) -> Group:
|
||||
"""
|
||||
Initialize the MLX distributed (runs in thread pool).
|
||||
|
||||
Either hosts or mlx_ibv_devices must be provided:
|
||||
- hosts: traditional host-based connectivity using MLX_HOSTFILE
|
||||
- mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
|
||||
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
|
||||
- strict: if True, raise an error if the distributed backend is not available
|
||||
Initialize MLX distributed.
|
||||
"""
|
||||
rank = bound_instance.bound_shard.device_rank
|
||||
logger.info(f"Starting initialization for rank {rank}")
|
||||
|
||||
# TODO: singleton instances
|
||||
match bound_instance.instance:
|
||||
case MlxRingInstance(hosts=hosts):
|
||||
hostfile = f"./hosts_{rank}.json"
|
||||
hosts_json = HostList.from_hosts(hosts).model_dump_json()
|
||||
coordination_file = None
|
||||
try:
|
||||
# TODO: singleton instances
|
||||
match bound_instance.instance:
|
||||
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
|
||||
coordination_file = (
|
||||
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
|
||||
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
|
||||
|
||||
with open(hostfile, "w") as f:
|
||||
_ = f.write(hosts_json)
|
||||
with open(coordination_file, "w") as f:
|
||||
_ = f.write(hosts_json)
|
||||
|
||||
logger.info(f"rank {rank} hostfile: {hostfile} hosts: {hosts_json}")
|
||||
logger.info(
|
||||
f"rank {rank} hostfile: {coordination_file} hosts: {hosts_json}"
|
||||
)
|
||||
|
||||
os.environ["MLX_HOSTFILE"] = hostfile
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_RING_VERBOSE"] = "1"
|
||||
group = mx.distributed.init(backend="ring", strict=True)
|
||||
os.environ["MLX_HOSTFILE"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_RING_VERBOSE"] = "1"
|
||||
group = mx.distributed.init(backend="ring", strict=True)
|
||||
|
||||
case MlxJacclInstance(
|
||||
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
|
||||
):
|
||||
# Use RDMA connectivity matrix
|
||||
devices_file = f"./hosts_{rank}.json"
|
||||
ibv_devices_json = json.dumps(ibv_devices)
|
||||
case MlxJacclInstance(
|
||||
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
|
||||
):
|
||||
# Use RDMA connectivity matrix
|
||||
coordination_file = (
|
||||
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
ibv_devices_json = json.dumps(ibv_devices)
|
||||
|
||||
with open(devices_file, "w") as f:
|
||||
_ = f.write(ibv_devices_json)
|
||||
with open(coordination_file, "w") as f:
|
||||
_ = f.write(ibv_devices_json)
|
||||
|
||||
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
|
||||
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
|
||||
|
||||
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
|
||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_IBV_DEVICES"] = devices_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
|
||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_IBV_DEVICES"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
|
||||
return group
|
||||
return group
|
||||
finally:
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
if coordination_file:
|
||||
os.remove(coordination_file)
|
||||
|
||||
|
||||
def initialize_mlx(
|
||||
bound_instance: BoundInstance,
|
||||
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
|
||||
"""
|
||||
Initialize the MLX model, tokenizer, and sampler. Runs in the MLX thread.
|
||||
"""
|
||||
) -> Group:
|
||||
# should we unseed it?
|
||||
# TODO: pass in seed from params
|
||||
mx.random.seed(42)
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
assert len(bound_instance.instance.shard_assignments.node_to_runner) > 1, (
|
||||
"Tried to initialize mlx for a single node instance"
|
||||
)
|
||||
return mlx_distributed_init(bound_instance)
|
||||
|
||||
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE)
|
||||
|
||||
def load_mlx_items(
|
||||
bound_instance: BoundInstance, group: Group | None
|
||||
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
|
||||
# TODO: pass temperature
|
||||
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
|
||||
logger.info("Created a sampler")
|
||||
|
||||
if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
|
||||
start_time = time.perf_counter()
|
||||
model, _ = load_model(model_path, strict=True)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
|
||||
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
|
||||
pass
|
||||
# model, config = quantize_model(
|
||||
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
|
||||
# )
|
||||
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
group = mlx_distributed_init(bound_instance)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
|
||||
end_time = time.perf_counter()
|
||||
@@ -193,14 +201,12 @@ def initialize_mlx(
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
|
||||
logger.debug(model)
|
||||
|
||||
return cast(Model, model), tokenizer, sampler
|
||||
|
||||
|
||||
def shard_and_load(
|
||||
shard_metadata: ShardMetadata,
|
||||
group: mx.distributed.Group,
|
||||
group: Group,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
model_path = build_model_path(shard_metadata.model_meta.model_id)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from anyio import CancelScope, create_task_group, current_time, fail_after
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
@@ -23,6 +23,7 @@ from exo.shared.types.events import (
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
)
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
@@ -227,7 +228,7 @@ class Worker:
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.event_sender.send_nowait(
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Running
|
||||
)
|
||||
@@ -256,12 +257,34 @@ class Worker:
|
||||
|
||||
async def _connection_message_event_writer(self):
|
||||
with self.connection_message_receiver as connection_messages:
|
||||
async for _msg in connection_messages:
|
||||
break
|
||||
# TODO: use mdns for partial discovery
|
||||
# for event in check_connections(self.node_id, msg, self.state):
|
||||
# logger.info(f"Worker discovered connection {event}")
|
||||
# await self.event_sender.send(event)
|
||||
async for msg in connection_messages:
|
||||
await self.event_sender.send(
|
||||
self._convert_connection_message_to_event(msg)
|
||||
)
|
||||
|
||||
def _convert_connection_message_to_event(self, msg: ConnectionMessage):
|
||||
match msg.connection_type:
|
||||
case ConnectionMessageType.Connected:
|
||||
return TopologyEdgeCreated(
|
||||
edge=Connection(
|
||||
local_node_id=self.node_id,
|
||||
send_back_node_id=msg.node_id,
|
||||
send_back_multiaddr=Multiaddr(
|
||||
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
case ConnectionMessageType.Disconnected:
|
||||
return TopologyEdgeDeleted(
|
||||
edge=Connection(
|
||||
local_node_id=self.node_id,
|
||||
send_back_node_id=msg.node_id,
|
||||
send_back_multiaddr=Multiaddr(
|
||||
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
@@ -380,7 +403,7 @@ class Worker:
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
logger.trace(
|
||||
logger.debug(
|
||||
f"Worker published event {self.local_event_index}: {str(event)[:100]}"
|
||||
)
|
||||
self.local_event_index += 1
|
||||
@@ -389,21 +412,35 @@ class Worker:
|
||||
|
||||
async def _poll_connection_updates(self):
|
||||
while True:
|
||||
edges = self.state.topology.out_edges(self.node_id)
|
||||
pure_edges = set(edge for _, edge in edges)
|
||||
conns = await check_reachable(self.state.topology)
|
||||
|
||||
for nid, conn in edges:
|
||||
if nid in conns and conn.sink_addr in conns.get(nid, set()):
|
||||
continue
|
||||
|
||||
logger.debug(f"ping failed to discover {conn=}")
|
||||
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
|
||||
# TODO: EdgeDeleted
|
||||
edges = set(self.state.topology.list_connections())
|
||||
conns = await check_reachable(self.state.topology, self.node_id)
|
||||
for nid in conns:
|
||||
for ip in conns[nid]:
|
||||
edge = Connection(sink_id=self.node_id, source_id=nid, sink_addr=ip)
|
||||
if edge not in pure_edges:
|
||||
if "127.0.0.1" in ip or "localhost" in ip:
|
||||
logger.warning(
|
||||
f"Loopback connection should not happen: {ip=} for {nid=}"
|
||||
)
|
||||
|
||||
edge = Connection(
|
||||
local_node_id=self.node_id,
|
||||
send_back_node_id=nid,
|
||||
# nonsense multiaddr
|
||||
send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
|
||||
if "." in ip
|
||||
# nonsense multiaddr
|
||||
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
|
||||
)
|
||||
if edge not in edges:
|
||||
logger.debug(f"ping discovered {edge=}")
|
||||
await self.event_sender.send(TopologyEdgeCreated(edge=edge))
|
||||
|
||||
for nid, conn in self.state.topology.out_edges(self.node_id):
|
||||
if (
|
||||
nid not in conns
|
||||
or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
|
||||
):
|
||||
logger.debug(f"ping failed to discover {conn=}")
|
||||
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
|
||||
|
||||
await anyio.sleep(10)
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Mapping, Sequence
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
LoadModel,
|
||||
@@ -14,17 +15,23 @@ from exo.shared.types.tasks import (
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
RunnerConnecting,
|
||||
RunnerFailed,
|
||||
RunnerId,
|
||||
RunnerIdle,
|
||||
RunnerLoaded,
|
||||
RunnerLoading,
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerStatus,
|
||||
RunnerWaitingForModel,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
@@ -48,6 +55,7 @@ def plan(
|
||||
_kill_runner(runners, all_runners, instances)
|
||||
or _create_runner(node_id, runners, instances)
|
||||
or _model_needs_download(runners, download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners)
|
||||
@@ -106,9 +114,11 @@ def _model_needs_download(
|
||||
download_status: Mapping[ShardMetadata, DownloadProgress],
|
||||
) -> DownloadModel | None:
|
||||
for runner in runners.values():
|
||||
if (
|
||||
isinstance(runner.status, RunnerWaitingForModel)
|
||||
and runner.bound_instance.bound_shard not in download_status
|
||||
if isinstance(runner.status, RunnerIdle) and (
|
||||
not isinstance(
|
||||
download_status.get(runner.bound_instance.bound_shard, None),
|
||||
(DownloadOngoing, DownloadCompleted),
|
||||
)
|
||||
):
|
||||
# We don't invalidate download_status randomly in case a file gets deleted on disk
|
||||
return DownloadModel(
|
||||
@@ -117,14 +127,54 @@ def _model_needs_download(
|
||||
)
|
||||
|
||||
|
||||
""" --- TODO!
|
||||
def _init_backend(
|
||||
def _init_distributed_backend(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
) -> LoadModel | None:
|
||||
for runner in runner.values()
|
||||
pass
|
||||
"""
|
||||
):
|
||||
for runner in runners.values():
|
||||
instance = runner.bound_instance.instance
|
||||
shard_assignments = instance.shard_assignments
|
||||
|
||||
is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
|
||||
if is_single_node_instance:
|
||||
continue
|
||||
|
||||
runner_is_idle = isinstance(runner.status, RunnerIdle)
|
||||
all_runners_connecting = all(
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id),
|
||||
(RunnerConnecting, RunnerIdle),
|
||||
)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
)
|
||||
|
||||
if not (runner_is_idle and all_runners_connecting):
|
||||
continue
|
||||
|
||||
runner_id = runner.bound_instance.bound_runner_id
|
||||
|
||||
shard = runner.bound_instance.bound_shard
|
||||
device_rank = shard.device_rank
|
||||
world_size = shard.world_size
|
||||
|
||||
assert device_rank < world_size
|
||||
assert device_rank >= 0
|
||||
|
||||
accepting_ranks = device_rank < world_size - 1
|
||||
|
||||
# Rank = n-1
|
||||
connecting_rank_ready = device_rank == world_size - 1 and all(
|
||||
isinstance(all_runners.get(global_runner_id, None), RunnerConnecting)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
if global_runner_id != runner_id
|
||||
)
|
||||
|
||||
if not (accepting_ranks or connecting_rank_ready):
|
||||
continue
|
||||
|
||||
return ConnectToGroup(instance_id=instance.instance_id)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _load_model(
|
||||
@@ -136,31 +186,33 @@ def _load_model(
|
||||
instance = runner.bound_instance.instance
|
||||
shard_assignments = instance.shard_assignments
|
||||
|
||||
all_downloads_complete_local = all(
|
||||
all_local_downloads_complete = all(
|
||||
nid in global_download_status
|
||||
and any(
|
||||
isinstance(dp, DownloadCompleted)
|
||||
and dp.shard_metadata == shard_assignments.runner_to_shard[rid]
|
||||
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
|
||||
for dp in global_download_status[nid]
|
||||
)
|
||||
for nid, rid in shard_assignments.node_to_runner.items()
|
||||
for nid in shard_assignments.node_to_runner
|
||||
)
|
||||
if not all_local_downloads_complete:
|
||||
continue
|
||||
|
||||
runner_is_waiting = isinstance(runner.status, RunnerWaitingForModel)
|
||||
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
|
||||
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
|
||||
return LoadModel(instance_id=instance.instance_id)
|
||||
|
||||
all_runners_expecting_model = all(
|
||||
is_runner_waiting = isinstance(runner.status, RunnerConnected)
|
||||
|
||||
all_ready_for_model = all(
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id),
|
||||
(RunnerWaitingForModel, RunnerLoading, RunnerLoaded),
|
||||
all_runners.get(global_runner_id, None),
|
||||
(RunnerConnected, RunnerLoading, RunnerLoaded),
|
||||
)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
)
|
||||
|
||||
if (
|
||||
all_downloads_complete_local
|
||||
and runner_is_waiting
|
||||
and all_runners_expecting_model
|
||||
):
|
||||
if is_runner_waiting and all_ready_for_model:
|
||||
return LoadModel(instance_id=instance.instance_id)
|
||||
|
||||
return None
|
||||
@@ -183,8 +235,9 @@ def _ready_to_warmup(
|
||||
assert device_rank < world_size
|
||||
assert device_rank >= 0
|
||||
|
||||
# Rank != n-1
|
||||
accepting_ranks_ready = device_rank != world_size - 1 and all(
|
||||
# TODO: Ensure these align with MLX distributeds expectations.
|
||||
# Rank < n-1
|
||||
accepting_ranks_ready = device_rank < world_size - 1 and all(
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id, None),
|
||||
(RunnerLoaded, RunnerWarmingUp),
|
||||
@@ -221,6 +274,8 @@ def _pending_tasks(
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
|
||||
@@ -2,16 +2,13 @@ import os
|
||||
|
||||
import loguru
|
||||
|
||||
from exo.shared.types.events import Event
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
|
||||
logger: "loguru.Logger"
|
||||
|
||||
|
||||
if os.getenv("EXO_TESTS") == "1":
|
||||
logger = loguru.logger
|
||||
logger: "loguru.Logger" = loguru.logger
|
||||
|
||||
|
||||
def entrypoint(
|
||||
@@ -30,6 +27,23 @@ def entrypoint(
|
||||
logger = _logger
|
||||
|
||||
# Import main after setting global logger - this lets us just import logger from this module
|
||||
from exo.worker.runner.runner import main
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(
|
||||
f"Runner {bound_instance.bound_runner_id} crashed with critical exception {e}"
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=bound_instance.bound_runner_id,
|
||||
runner_status=RunnerFailed(error_message=str(e)),
|
||||
)
|
||||
)
|
||||
finally:
|
||||
event_sender.close()
|
||||
task_receiver.close()
|
||||
event_sender.join()
|
||||
task_receiver.join()
|
||||
logger.info("bye from the runner")
|
||||
|
||||
@@ -11,6 +11,7 @@ from exo.shared.types.events import (
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
@@ -22,20 +23,23 @@ from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
RunnerConnecting,
|
||||
RunnerFailed,
|
||||
RunnerIdle,
|
||||
RunnerLoaded,
|
||||
RunnerLoading,
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerStatus,
|
||||
RunnerWaitingForModel,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
@@ -63,9 +67,10 @@ def main(
|
||||
model = None
|
||||
tokenizer = None
|
||||
sampler = None
|
||||
group = None
|
||||
|
||||
current_status: RunnerStatus = RunnerWaitingForModel()
|
||||
logger.info("runner waiting for model")
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
@@ -78,9 +83,26 @@ def main(
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
match task:
|
||||
case LoadModel() if isinstance(
|
||||
current_status, (RunnerWaitingForModel, RunnerFailed)
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
):
|
||||
logger.info("runner connecting")
|
||||
current_status = RunnerConnecting()
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
group = initialize_mlx(bound_instance)
|
||||
|
||||
logger.info("runner connected")
|
||||
current_status = RunnerConnected()
|
||||
|
||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||
case LoadModel() if (
|
||||
isinstance(current_status, RunnerConnected)
|
||||
and group is not None
|
||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("runner loading")
|
||||
event_sender.send(
|
||||
@@ -89,15 +111,12 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
model, tokenizer, sampler = initialize_mlx(bound_instance)
|
||||
model, tokenizer, sampler = load_mlx_items(
|
||||
bound_instance, group
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
assert tokenizer
|
||||
@@ -123,11 +142,6 @@ def main(
|
||||
)
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=RunnerReady()
|
||||
)
|
||||
)
|
||||
case ChatCompletion(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
@@ -172,11 +186,6 @@ def main(
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=RunnerReady()
|
||||
)
|
||||
)
|
||||
case Shutdown():
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
@@ -186,12 +195,19 @@ def main(
|
||||
)
|
||||
break
|
||||
case _:
|
||||
raise ValueError("Received task outside of state machine")
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user