mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-18 14:55:13 -05:00
Compare commits
6 Commits
leo/add-gl
...
simplify-i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1397ee38a5 | ||
|
|
a808b93b7c | ||
|
|
6c322ebb72 | ||
|
|
2ebe6216b4 | ||
|
|
f54c80b121 | ||
|
|
48b8f86395 |
13
Cargo.lock
generated
13
Cargo.lock
generated
@@ -890,7 +890,7 @@ dependencies = [
|
||||
"delegate",
|
||||
"env_logger",
|
||||
"extend",
|
||||
"futures",
|
||||
"futures-lite",
|
||||
"libp2p",
|
||||
"log",
|
||||
"networking",
|
||||
@@ -914,6 +914,12 @@ dependencies = [
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||
|
||||
[[package]]
|
||||
name = "ff"
|
||||
version = "0.13.1"
|
||||
@@ -1022,7 +1028,10 @@ version = "2.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
|
||||
dependencies = [
|
||||
"fastrand",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"parking",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
@@ -2753,7 +2762,7 @@ dependencies = [
|
||||
"delegate",
|
||||
"either",
|
||||
"extend",
|
||||
"futures",
|
||||
"futures-lite",
|
||||
"futures-timer",
|
||||
"keccak-const",
|
||||
"libp2p",
|
||||
|
||||
@@ -29,14 +29,13 @@ util = { path = "rust/util" }
|
||||
# Macro dependecies
|
||||
extend = "1.2"
|
||||
delegate = "0.13"
|
||||
pin-project = "1"
|
||||
|
||||
# Utility dependencies
|
||||
keccak-const = "0.2"
|
||||
|
||||
# Async dependencies
|
||||
tokio = "1.46"
|
||||
futures = "0.3"
|
||||
futures-lite = "2.6.1"
|
||||
futures-timer = "3.0"
|
||||
|
||||
# Data structures
|
||||
|
||||
@@ -103,7 +103,7 @@
|
||||
const modelSupportsThinking = $derived(() => {
|
||||
if (!currentModel) return false;
|
||||
const caps = modelCapabilities[currentModel] || [];
|
||||
return caps.includes("thinking") && caps.includes("text");
|
||||
return caps.includes("thinking_toggle") && caps.includes("text");
|
||||
});
|
||||
|
||||
const isEditOnlyWithoutImage = $derived(
|
||||
|
||||
@@ -59,13 +59,14 @@
|
||||
}
|
||||
|
||||
const sizeOptions: ImageGenerationParams["size"][] = [
|
||||
"auto",
|
||||
"512x512",
|
||||
"768x768",
|
||||
"1024x1024",
|
||||
"1024x768",
|
||||
"768x1024",
|
||||
"1024x1365",
|
||||
"1365x1024",
|
||||
"1024x1536",
|
||||
"1536x1024",
|
||||
];
|
||||
|
||||
const qualityOptions: ImageGenerationParams["quality"][] = [
|
||||
@@ -176,92 +177,90 @@
|
||||
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
|
||||
<!-- Basic params row -->
|
||||
<div class="flex items-center gap-3 flex-wrap">
|
||||
<!-- Size (hidden in edit mode - output size comes from input image) -->
|
||||
{#if !isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
<!-- Size -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
>
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
{params.size.toUpperCase()}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
{params.size}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size.toUpperCase()}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Quality -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
@@ -311,7 +310,7 @@
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
|
||||
style="bottom: calc(100vh - {qualityDropdownPosition()
|
||||
.top}px + 4px); left: {qualityDropdownPosition().left}px;"
|
||||
>
|
||||
|
||||
@@ -306,13 +306,14 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
|
||||
export interface ImageGenerationParams {
|
||||
// Basic params
|
||||
size:
|
||||
| "auto"
|
||||
| "512x512"
|
||||
| "768x768"
|
||||
| "1024x1024"
|
||||
| "1024x768"
|
||||
| "768x1024"
|
||||
| "1024x1365"
|
||||
| "1365x1024";
|
||||
| "1024x1536"
|
||||
| "1536x1024";
|
||||
quality: "low" | "medium" | "high";
|
||||
outputFormat: "png" | "jpeg";
|
||||
numImages: number;
|
||||
@@ -336,7 +337,7 @@ export interface EditingImage {
|
||||
}
|
||||
|
||||
const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
|
||||
size: "1024x1024",
|
||||
size: "auto",
|
||||
quality: "medium",
|
||||
outputFormat: "png",
|
||||
numImages: 1,
|
||||
|
||||
@@ -74,7 +74,6 @@
|
||||
perSystem =
|
||||
{ config, self', inputs', pkgs, lib, system, ... }:
|
||||
let
|
||||
fenixToolchain = inputs'.fenix.packages.complete;
|
||||
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
|
||||
in
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "deepseek"
|
||||
quantization = "4bit"
|
||||
base_model = "DeepSeek V3.1"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405874409472
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "deepseek"
|
||||
quantization = "8bit"
|
||||
base_model = "DeepSeek V3.1"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 765577920512
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.5 Air"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 122406567936
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "bf16"
|
||||
base_model = "GLM 4.5 Air"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 229780750336
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "4bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 198556925568
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "6bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 286737579648
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 396963397248
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "4bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 19327352832
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "5bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 22548578304
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "6bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26843545600
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 34359738368
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "kimi"
|
||||
quantization = ""
|
||||
base_model = "Kimi K2"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 706522120192
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "kimi"
|
||||
quantization = ""
|
||||
base_model = "Kimi K2.5"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 662498705408
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "minimax"
|
||||
quantization = "3bit"
|
||||
base_model = "MiniMax M2.1"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 100086644736
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "minimax"
|
||||
quantization = "8bit"
|
||||
base_model = "MiniMax M2.1"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 242986745856
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 0.6B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 342884352
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 0.6B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 698351616
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 235B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 141733920768
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 235B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 268435456000
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 30B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 17612931072
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 30B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33279705088
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 Next 80B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 47080074240
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 Next 80B"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 88814387200
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "4bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 114572190076
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "6bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 159039627774
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "8bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking"]
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 209082699847
|
||||
|
||||
@@ -27,7 +27,7 @@ networking = { workspace = true }
|
||||
# interop
|
||||
pyo3 = { version = "0.27.2", features = [
|
||||
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
|
||||
"nightly", # enables better-supported GIL integration
|
||||
# "nightly", # enables better-supported GIL integration
|
||||
"experimental-async", # async support in #[pyfunction] & #[pymethods]
|
||||
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
@@ -45,11 +45,10 @@ pyo3-log = "0.13.2"
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
|
||||
# async runtime
|
||||
tokio = { workspace = true, features = ["full", "tracing"] }
|
||||
futures = { workspace = true }
|
||||
futures-lite = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
@@ -60,3 +59,4 @@ env_logger = "0.11"
|
||||
|
||||
# Networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
pin-project = "1.1.10"
|
||||
|
||||
@@ -19,7 +19,7 @@ class ConnectionUpdate:
|
||||
Whether this is a connection or disconnection event
|
||||
"""
|
||||
@property
|
||||
def peer_id(self) -> PeerId:
|
||||
def peer_id(self) -> builtins.str:
|
||||
r"""
|
||||
Identity of the peer that we have connected to or disconnected from.
|
||||
"""
|
||||
@@ -40,92 +40,22 @@ class Keypair:
|
||||
Identity keypair of a node.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ed25519() -> Keypair:
|
||||
def generate() -> Keypair:
|
||||
r"""
|
||||
Generate a new Ed25519 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ecdsa() -> Keypair:
|
||||
def from_bytes(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Generate a new ECDSA keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_secp256k1() -> Keypair:
|
||||
r"""
|
||||
Generate a new Secp256k1 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_protobuf_encoding(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
"""
|
||||
@staticmethod
|
||||
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
format (i.e. unencrypted) as defined in [RFC5208].
|
||||
|
||||
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
"""
|
||||
@staticmethod
|
||||
def secp256k1_from_der(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
structure as defined in [RFC5915].
|
||||
|
||||
[RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
"""
|
||||
@staticmethod
|
||||
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
|
||||
def to_protobuf_encoding(self) -> bytes:
|
||||
r"""
|
||||
Encode a private key as protobuf structure.
|
||||
"""
|
||||
def to_peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Convert the `Keypair` into the corresponding `PeerId`.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class Multiaddr:
|
||||
r"""
|
||||
Representation of a Multiaddr.
|
||||
"""
|
||||
@staticmethod
|
||||
def empty() -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress.
|
||||
"""
|
||||
@staticmethod
|
||||
def with_capacity(n: builtins.int) -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress with the given capacity.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its byte slice representation.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_string(string: builtins.str) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its string representation.
|
||||
"""
|
||||
def len(self) -> builtins.int:
|
||||
r"""
|
||||
Return the length in bytes of this multiaddress.
|
||||
"""
|
||||
def is_empty(self) -> builtins.bool:
|
||||
r"""
|
||||
Returns true if the length of this multiaddress is 0.
|
||||
Construct an Ed25519 keypair from secret key bytes
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
Get the secret key bytes underlying the keypair
|
||||
"""
|
||||
def to_string(self) -> builtins.str:
|
||||
def to_node_id(self) -> builtins.str:
|
||||
r"""
|
||||
Convert a Multiaddr to a string.
|
||||
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our NodeId.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
@@ -180,37 +110,6 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class PeerId:
|
||||
r"""
|
||||
Identifier of a peer of the network.
|
||||
|
||||
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
"""
|
||||
@staticmethod
|
||||
def random() -> PeerId:
|
||||
r"""
|
||||
Generates a random peer ID from a cryptographically secure PRNG.
|
||||
|
||||
This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> PeerId:
|
||||
r"""
|
||||
Parses a `PeerId` from bytes.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Returns a raw bytes representation of this `PeerId`.
|
||||
"""
|
||||
def to_base58(self) -> builtins.str:
|
||||
r"""
|
||||
Returns a base-58 encoded string of this `PeerId`.
|
||||
"""
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdateType(enum.Enum):
|
||||
r"""
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
//!
|
||||
|
||||
use pin_project::pin_project;
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
future::Future,
|
||||
@@ -26,8 +25,8 @@ where
|
||||
|
||||
impl<F> Future for AllowThreads<F>
|
||||
where
|
||||
F: Future + Ungil,
|
||||
F::Output: Ungil,
|
||||
F: Future + Send,
|
||||
F::Output: Send,
|
||||
{
|
||||
type Output = F::Output;
|
||||
|
||||
|
||||
47
rust/exo_pyo3_bindings/src/ident.rs
Normal file
47
rust/exo_pyo3_bindings/src/ident.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::identity::Keypair;
|
||||
use pyo3::types::{PyBytes, PyBytesMethods};
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
|
||||
/// Identity keypair of a node.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Keypair", frozen)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyKeypair(pub Keypair);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate() -> Self {
|
||||
Self(Keypair::generate_ed25519())
|
||||
}
|
||||
|
||||
/// Construct an Ed25519 keypair from secret key bytes
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Get the secret key bytes underlying the keypair
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self
|
||||
.0
|
||||
.clone()
|
||||
.try_into_ed25519()
|
||||
.expect("we only use ed25519 keys")
|
||||
.secret()
|
||||
.as_ref()
|
||||
.to_vec();
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
|
||||
/// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our NodeId.
|
||||
fn to_node_id(&self) -> String {
|
||||
self.0.public().to_peer_id().to_base58()
|
||||
}
|
||||
}
|
||||
@@ -4,26 +4,14 @@
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(tuple_trait)]
|
||||
#![feature(unboxed_closures)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
extern crate core;
|
||||
mod allow_threading;
|
||||
pub(crate) mod networking;
|
||||
pub(crate) mod pylibp2p;
|
||||
mod ident;
|
||||
mod networking;
|
||||
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::networking::networking_submodule;
|
||||
use crate::pylibp2p::ident::ident_submodule;
|
||||
use crate::pylibp2p::multiaddr::multiaddr_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::types::PyModuleMethods;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
@@ -32,14 +20,6 @@ pub(crate) mod r#const {
|
||||
pub const MPSC_CHANNEL_SIZE: usize = 1024;
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::marker::Tuple;
|
||||
|
||||
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
|
||||
Fn<Args, Output = Output> + Send + 'static;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use crate::allow_threading::AllowThreads;
|
||||
@@ -179,8 +159,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
// too many importing issues...
|
||||
ident_submodule(m)?;
|
||||
multiaddr_submodule(m)?;
|
||||
m.add_class::<PyKeypair>()?;
|
||||
networking_submodule(m)?;
|
||||
|
||||
// top-level constructs
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::pyclass;
|
||||
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {
|
||||
|
||||
/// Identity of the peer that we have connected to or disconnected from.
|
||||
#[pyo3(get)]
|
||||
peer_id: PyPeerId,
|
||||
peer_id: String,
|
||||
|
||||
/// Remote connection's IPv4 address.
|
||||
#[pyo3(get)]
|
||||
@@ -251,7 +251,7 @@ async fn networking_task(
|
||||
// send connection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Connected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
peer_id: peer_id.to_base58(),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
@@ -272,7 +272,7 @@ async fn networking_task(
|
||||
// send disconnection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Disconnected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
peer_id: peer_id.to_base58(),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::PeerId;
|
||||
use libp2p::identity::Keypair;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
|
||||
/// Identity keypair of a node.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Keypair", frozen)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyKeypair(pub Keypair);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ed25519() -> Self {
|
||||
Self(Keypair::generate_ed25519())
|
||||
}
|
||||
|
||||
/// Generate a new ECDSA keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ecdsa() -> Self {
|
||||
Self(Keypair::generate_ecdsa())
|
||||
}
|
||||
|
||||
/// Generate a new Secp256k1 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_secp256k1() -> Self {
|
||||
Self(Keypair::generate_secp256k1())
|
||||
}
|
||||
|
||||
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
#[staticmethod]
|
||||
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
/// format (i.e. unencrypted) as defined in [RFC5208].
|
||||
///
|
||||
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
#[staticmethod]
|
||||
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
/// structure as defined in [RFC5915].
|
||||
///
|
||||
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
#[staticmethod]
|
||||
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Encode a private key as protobuf structure.
|
||||
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self.0.to_protobuf_encoding().pyerr()?;
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
|
||||
/// Convert the `Keypair` into the corresponding `PeerId`.
|
||||
fn to_peer_id(&self) -> PyPeerId {
|
||||
PyPeerId(self.0.public().to_peer_id())
|
||||
}
|
||||
|
||||
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
|
||||
// #[gen_stub(skip)]
|
||||
// #[new]
|
||||
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
// Self::from_protobuf_encoding(bytes)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
|
||||
// *self = Self::from_protobuf_encoding(state)?;
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
// self.to_protobuf_encoding(py)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
|
||||
// Ok((self.to_protobuf_encoding(py)?,))
|
||||
// }
|
||||
}
|
||||
|
||||
/// Identifier of a peer of the network.
|
||||
///
|
||||
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "PeerId", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyPeerId(pub PeerId);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyPeerId {
|
||||
/// Generates a random peer ID from a cryptographically secure PRNG.
|
||||
///
|
||||
/// This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
#[staticmethod]
|
||||
fn random() -> Self {
|
||||
Self(PeerId::random())
|
||||
}
|
||||
|
||||
/// Parses a `PeerId` from bytes.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Returns a raw bytes representation of this `PeerId`.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_bytes();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Returns a base-58 encoded string of this `PeerId`.
|
||||
fn to_base58(&self) -> String {
|
||||
self.0.to_base58()
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId({})", self.to_base58())
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
self.to_base58()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyKeypair>()?;
|
||||
m.add_class::<PyPeerId>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
//! A module for exposing Rust's libp2p datatypes over Pyo3
|
||||
//!
|
||||
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
|
||||
//! independent identity type of some kind or another. This may require handshaking.
|
||||
//!
|
||||
|
||||
pub mod ident;
|
||||
pub mod multiaddr;
|
||||
@@ -1,81 +0,0 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::Multiaddr;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
use std::str::FromStr as _;
|
||||
|
||||
/// Representation of a Multiaddr.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Multiaddr", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyMultiaddr(pub Multiaddr);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyMultiaddr {
|
||||
/// Create a new, empty multiaddress.
|
||||
#[staticmethod]
|
||||
fn empty() -> Self {
|
||||
Self(Multiaddr::empty())
|
||||
}
|
||||
|
||||
/// Create a new, empty multiaddress with the given capacity.
|
||||
#[staticmethod]
|
||||
fn with_capacity(n: usize) -> Self {
|
||||
Self(Multiaddr::with_capacity(n))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its byte slice representation.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its string representation.
|
||||
#[staticmethod]
|
||||
fn from_string(string: String) -> PyResult<Self> {
|
||||
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
|
||||
}
|
||||
|
||||
/// Return the length in bytes of this multiaddress.
|
||||
fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
/// Returns true if the length of this multiaddress is 0.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
/// Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_vec();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Convert a Multiaddr to a string.
|
||||
fn to_string(&self) -> String {
|
||||
self.0.to_string()
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __repr__(&self) -> String {
|
||||
format!("Multiaddr({})", self.0)
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __str__(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyMultiaddr>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -22,7 +22,7 @@ delegate = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
futures-lite = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use futures::stream::StreamExt as _;
|
||||
use futures_lite::StreamExt;
|
||||
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
|
||||
use networking::{discovery, swarm};
|
||||
use tokio::{io, io::AsyncBufReadExt as _, select};
|
||||
@@ -38,19 +38,19 @@ async fn main() {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
event = swarm.select_next_some() => match event {
|
||||
event = swarm.next() => match event {
|
||||
// on gossipsub incoming
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
}))) => println!(
|
||||
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
|
||||
// on discovery
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
|
||||
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) )=> match e {
|
||||
discovery::Event::ConnectionEstablished {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
@@ -64,7 +64,7 @@ async fn main() {
|
||||
}
|
||||
|
||||
// ignore outgoing errors: those are normal
|
||||
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
|
||||
e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
|
||||
|
||||
// otherwise log any other event
|
||||
e => { log::info!("Other event {e:?}"); }
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
// Copyright 2018 Parity Technologies (UK) Ltd.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a
|
||||
// copy of this software and associated documentation files (the "Software"),
|
||||
// to deal in the Software without restriction, including without limitation
|
||||
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
// and/or sell copies of the Software, and to permit persons to whom the
|
||||
// Software is furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use futures::stream::StreamExt;
|
||||
use libp2p::{
|
||||
gossipsub, mdns, noise,
|
||||
swarm::{NetworkBehaviour, SwarmEvent},
|
||||
tcp, yamux,
|
||||
};
|
||||
use std::error::Error;
|
||||
use std::time::Duration;
|
||||
use tokio::{io, io::AsyncBufReadExt, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
// We create a custom network behaviour that combines Gossipsub and Mdns.
|
||||
#[derive(NetworkBehaviour)]
|
||||
struct MyBehaviour {
|
||||
gossipsub: gossipsub::Behaviour,
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.try_init();
|
||||
|
||||
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
|
||||
.with_tokio()
|
||||
.with_tcp(
|
||||
tcp::Config::default(),
|
||||
noise::Config::new,
|
||||
yamux::Config::default,
|
||||
)?
|
||||
.with_behaviour(|key| {
|
||||
// Set a custom gossipsub configuration
|
||||
let gossipsub_config = gossipsub::ConfigBuilder::default()
|
||||
.heartbeat_interval(Duration::from_secs(10))
|
||||
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
|
||||
.build()
|
||||
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
|
||||
|
||||
// build a gossipsub network behaviour
|
||||
let gossipsub = gossipsub::Behaviour::new(
|
||||
gossipsub::MessageAuthenticity::Signed(key.clone()),
|
||||
gossipsub_config,
|
||||
)?;
|
||||
|
||||
let mdns =
|
||||
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
|
||||
Ok(MyBehaviour { gossipsub, mdns })
|
||||
})?
|
||||
.build();
|
||||
|
||||
println!("Running swarm with identity {}", swarm.local_peer_id());
|
||||
|
||||
// Create a Gossipsub topic
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
// subscribes to our topic
|
||||
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
event = swarm.select_next_some() => match event {
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
"Got message: '{}' with id: {id} from peer: {peer_id}",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
SwarmEvent::NewListenAddr { address, .. } => {
|
||||
println!("Local node is listening on {address}");
|
||||
}
|
||||
e => {
|
||||
println!("Other swarm event: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::ext::MultiaddrExt;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures::FutureExt;
|
||||
use futures_lite::FutureExt;
|
||||
use futures_timer::Delay;
|
||||
use libp2p::core::transport::PortUse;
|
||||
use libp2p::core::{ConnectedPoint, Endpoint};
|
||||
@@ -362,7 +362,7 @@ impl NetworkBehaviour for Behaviour {
|
||||
}
|
||||
|
||||
// retry connecting to all mDNS peers periodically (fails safely if already connected)
|
||||
if self.retry_delay.poll_unpin(cx).is_ready() {
|
||||
if self.retry_delay.poll(cx).is_ready() {
|
||||
for (p, mas) in self.mdns_discovered.clone() {
|
||||
for ma in mas {
|
||||
self.dial(p, ma)
|
||||
|
||||
@@ -31,7 +31,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
mod transport {
|
||||
use crate::alias;
|
||||
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
use futures_lite::{AsyncRead, AsyncWrite};
|
||||
use keccak_const::Sha3_256;
|
||||
use libp2p::core::muxing;
|
||||
use libp2p::core::transport::Boxed;
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ config, self', inputs', pkgs, lib, ... }:
|
||||
{ inputs', pkgs, lib, ... }:
|
||||
let
|
||||
# Fenix nightly toolchain with all components
|
||||
fenixPkgs = inputs'.fenix.packages;
|
||||
rustToolchain = fenixPkgs.complete.withComponents [
|
||||
rustToolchain = inputs'.fenix.packages.stable.withComponents [
|
||||
"cargo"
|
||||
"rustc"
|
||||
"clippy"
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
[toolchain]
|
||||
channel = "nightly"
|
||||
@@ -47,6 +47,7 @@ class DownloadCoordinator:
|
||||
download_command_receiver: Receiver[ForwarderDownloadCommand]
|
||||
local_event_sender: Sender[ForwarderEvent]
|
||||
event_index_counter: Iterator[int]
|
||||
offline: bool = False
|
||||
|
||||
# Local state
|
||||
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
||||
@@ -62,6 +63,8 @@ class DownloadCoordinator:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
if self.offline:
|
||||
self.shard_downloader.set_internet_connection(False)
|
||||
self.shard_downloader.on_progress(self._download_progress_callback)
|
||||
|
||||
def _model_dir(self, model_id: ModelId) -> str:
|
||||
@@ -107,13 +110,17 @@ class DownloadCoordinator:
|
||||
self._last_progress_time[model_id] = current_time()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
self._test_internet_connection()
|
||||
logger.info(
|
||||
f"Starting DownloadCoordinator{' (offline mode)' if self.offline else ''}"
|
||||
)
|
||||
if not self.offline:
|
||||
self._test_internet_connection()
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._check_internet_connection)
|
||||
if not self.offline:
|
||||
tg.start_soon(self._check_internet_connection)
|
||||
|
||||
def _test_internet_connection(self) -> None:
|
||||
try:
|
||||
@@ -202,6 +209,20 @@ class DownloadCoordinator:
|
||||
)
|
||||
return
|
||||
|
||||
if self.offline:
|
||||
logger.warning(
|
||||
f"Offline mode: model {model_id} is not fully available locally, cannot download"
|
||||
)
|
||||
failed = DownloadFailed(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
error_message=f"Model files not found locally in offline mode: {model_id}",
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = failed
|
||||
await self.event_sender.send(NodeDownloadProgress(download_progress=failed))
|
||||
return
|
||||
|
||||
# Start actual download
|
||||
self._start_download_task(shard, initial_progress)
|
||||
|
||||
|
||||
@@ -448,12 +448,13 @@ async def download_file_with_retry(
|
||||
target_dir: Path,
|
||||
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
|
||||
on_connection_lost: Callable[[], None] = lambda: None,
|
||||
skip_internet: bool = False,
|
||||
) -> Path:
|
||||
n_attempts = 3
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _download_file(
|
||||
model_id, revision, path, target_dir, on_progress
|
||||
model_id, revision, path, target_dir, on_progress, skip_internet
|
||||
)
|
||||
except HuggingFaceAuthenticationError:
|
||||
raise
|
||||
@@ -487,10 +488,14 @@ async def _download_file(
|
||||
path: str,
|
||||
target_dir: Path,
|
||||
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
|
||||
skip_internet: bool = False,
|
||||
) -> Path:
|
||||
target_path = target_dir / path
|
||||
|
||||
if await aios.path.exists(target_path):
|
||||
if skip_internet:
|
||||
return target_path
|
||||
|
||||
local_size = (await aios.stat(target_path)).st_size
|
||||
|
||||
# Try to verify against remote, but allow offline operation
|
||||
@@ -510,6 +515,11 @@ async def _download_file(
|
||||
)
|
||||
return target_path
|
||||
|
||||
if skip_internet:
|
||||
raise FileNotFoundError(
|
||||
f"File {path} not found locally and cannot download in offline mode"
|
||||
)
|
||||
|
||||
await aios.makedirs((target_dir / path).parent, exist_ok=True)
|
||||
length, etag = await file_meta(model_id, revision, path)
|
||||
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
|
||||
@@ -814,6 +824,7 @@ async def download_shard(
|
||||
file, curr_bytes, total_bytes, is_renamed
|
||||
),
|
||||
on_connection_lost=on_connection_lost,
|
||||
skip_internet=skip_internet,
|
||||
)
|
||||
|
||||
if not skip_download:
|
||||
|
||||
230
src/exo/download/tests/test_offline_mode.py
Normal file
230
src/exo/download/tests/test_offline_mode.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""Tests for offline/air-gapped mode."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import pytest
|
||||
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
download_file_with_retry,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.worker.downloads import FileListEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_id() -> ModelId:
|
||||
return ModelId("test-org/test-model")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
|
||||
models_dir = tmp_path / "models"
|
||||
await aios.makedirs(models_dir, exist_ok=True)
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
yield models_dir
|
||||
|
||||
|
||||
class TestDownloadFileOffline:
|
||||
"""Tests for _download_file with skip_internet=True."""
|
||||
|
||||
async def test_returns_local_file_without_http_verification(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and file exists locally, return it immediately
|
||||
without making any HTTP calls (no file_meta verification)."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
local_file = target_dir / "model.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"model weights data")
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_file_meta:
|
||||
result = await _download_file(
|
||||
model_id,
|
||||
"main",
|
||||
"model.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_not_called()
|
||||
|
||||
async def test_raises_file_not_found_for_missing_file(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and file does NOT exist locally,
|
||||
raise FileNotFoundError instead of attempting download."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="offline mode"):
|
||||
await _download_file(
|
||||
model_id,
|
||||
"main",
|
||||
"missing_model.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
async def test_returns_local_file_in_subdirectory(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and file exists in a subdirectory,
|
||||
return it without HTTP calls."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
subdir = target_dir / "transformer"
|
||||
await aios.makedirs(subdir, exist_ok=True)
|
||||
|
||||
local_file = subdir / "diffusion_pytorch_model.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"weights")
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_file_meta:
|
||||
result = await _download_file(
|
||||
model_id,
|
||||
"main",
|
||||
"transformer/diffusion_pytorch_model.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_not_called()
|
||||
|
||||
|
||||
class TestDownloadFileWithRetryOffline:
|
||||
"""Tests for download_file_with_retry with skip_internet=True."""
|
||||
|
||||
async def test_propagates_skip_internet_to_download_file(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Verify skip_internet is passed through to _download_file."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
local_file = target_dir / "config.json"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b'{"model_type": "qwen2"}')
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_file_meta:
|
||||
result = await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"config.json",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_not_called()
|
||||
|
||||
async def test_file_not_found_does_not_retry(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""FileNotFoundError from offline mode should not trigger retries."""
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
"nonexistent.safetensors",
|
||||
target_dir,
|
||||
skip_internet=True,
|
||||
)
|
||||
|
||||
|
||||
class TestFetchFileListOffline:
|
||||
"""Tests for fetch_file_list_with_cache with skip_internet=True."""
|
||||
|
||||
async def test_uses_cached_file_list(
|
||||
self, model_id: ModelId, temp_models_dir: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and cache file exists, use it without network."""
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
cache_dir = temp_models_dir / "caches" / model_id.normalize()
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
cached_list = [
|
||||
FileListEntry(type="file", path="model.safetensors", size=1000),
|
||||
FileListEntry(type="file", path="config.json", size=200),
|
||||
]
|
||||
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(
|
||||
TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode()
|
||||
)
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_fetch:
|
||||
result = await fetch_file_list_with_cache(
|
||||
model_id, "main", skip_internet=True
|
||||
)
|
||||
|
||||
assert result == cached_list
|
||||
mock_fetch.assert_not_called()
|
||||
|
||||
async def test_falls_back_to_local_directory_scan(
|
||||
self, model_id: ModelId, temp_models_dir: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and no cache but local files exist,
|
||||
build file list from local directory."""
|
||||
import json
|
||||
|
||||
model_dir = temp_models_dir / model_id.normalize()
|
||||
await aios.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
async with aiofiles.open(model_dir / "config.json", "w") as f:
|
||||
await f.write('{"model_type": "qwen2"}')
|
||||
|
||||
index_data = {
|
||||
"metadata": {},
|
||||
"weight_map": {"model.layers.0.weight": "model.safetensors"},
|
||||
}
|
||||
async with aiofiles.open(model_dir / "model.safetensors.index.json", "w") as f:
|
||||
await f.write(json.dumps(index_data))
|
||||
|
||||
async with aiofiles.open(model_dir / "model.safetensors", "wb") as f:
|
||||
await f.write(b"x" * 500)
|
||||
|
||||
with patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_fetch:
|
||||
result = await fetch_file_list_with_cache(
|
||||
model_id, "main", skip_internet=True
|
||||
)
|
||||
|
||||
mock_fetch.assert_not_called()
|
||||
paths = {entry.path for entry in result}
|
||||
assert "config.json" in paths
|
||||
assert "model.safetensors" in paths
|
||||
|
||||
async def test_raises_when_no_cache_and_no_local_files(
|
||||
self, model_id: ModelId, temp_models_dir: Path
|
||||
) -> None:
|
||||
"""When skip_internet=True and neither cache nor local files exist,
|
||||
raise FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="No internet"):
|
||||
await fetch_file_list_with_cache(model_id, "main", skip_internet=True)
|
||||
@@ -39,12 +39,13 @@ class Node:
|
||||
|
||||
node_id: NodeId
|
||||
event_index_counter: Iterator[int]
|
||||
offline: bool
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, args: "Args") -> "Self":
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
node_id = NodeId(keypair.to_node_id())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
router = Router.create(keypair)
|
||||
await router.register_topic(topics.GLOBAL_EVENTS)
|
||||
@@ -68,6 +69,7 @@ class Node:
|
||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=event_index_counter,
|
||||
offline=args.offline,
|
||||
)
|
||||
else:
|
||||
download_coordinator = None
|
||||
@@ -132,6 +134,7 @@ class Node:
|
||||
api,
|
||||
node_id,
|
||||
event_index_counter,
|
||||
args.offline,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
@@ -222,6 +225,7 @@ class Node:
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=self.event_index_counter,
|
||||
offline=self.offline,
|
||||
)
|
||||
self._tg.start_soon(self.download_coordinator.run)
|
||||
if self.worker:
|
||||
@@ -260,6 +264,9 @@ def main():
|
||||
logger.info("Starting EXO")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
|
||||
if args.offline:
|
||||
logger.info("Running in OFFLINE mode — no internet checks, local models only")
|
||||
|
||||
# Set FAST_SYNCH override env var for runner subprocesses
|
||||
if args.fast_synch is True:
|
||||
os.environ["EXO_FAST_SYNCH"] = "on"
|
||||
@@ -282,6 +289,7 @@ class Args(CamelCaseModel):
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
no_downloads: bool = False
|
||||
offline: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
|
||||
@classmethod
|
||||
@@ -329,6 +337,11 @@ class Args(CamelCaseModel):
|
||||
action="store_true",
|
||||
help="Disable the download coordinator (node won't download models)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offline",
|
||||
action="store_true",
|
||||
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
|
||||
@@ -85,6 +85,7 @@ from exo.shared.types.api import (
|
||||
ImageGenerationTaskParams,
|
||||
ImageListItem,
|
||||
ImageListResponse,
|
||||
ImageSize,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -100,6 +101,7 @@ from exo.shared.types.api import (
|
||||
TraceRankStats,
|
||||
TraceResponse,
|
||||
TraceStatsResponse,
|
||||
normalize_image_size,
|
||||
)
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
@@ -751,9 +753,11 @@ class API:
|
||||
When stream=True and partial_images > 0, returns a StreamingResponse
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
payload = payload.model_copy(
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
update={
|
||||
"model": await self._validate_image_model(ModelId(payload.model)),
|
||||
"advanced_params": _ensure_seed(payload.advanced_params),
|
||||
}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
@@ -1009,12 +1013,13 @@ class API:
|
||||
async def bench_image_generations(
|
||||
self, request: Request, payload: BenchImageGenerationTaskParams
|
||||
) -> BenchImageGenerationResponse:
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
|
||||
payload.stream = False
|
||||
payload.partial_images = 0
|
||||
payload = payload.model_copy(
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
update={
|
||||
"model": await self._validate_image_model(ModelId(payload.model)),
|
||||
"stream": False,
|
||||
"partial_images": 0,
|
||||
"advanced_params": _ensure_seed(payload.advanced_params),
|
||||
}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
@@ -1035,7 +1040,7 @@ class API:
|
||||
prompt: str,
|
||||
model: ModelId,
|
||||
n: int,
|
||||
size: str,
|
||||
size: ImageSize,
|
||||
response_format: Literal["url", "b64_json"],
|
||||
input_fidelity: Literal["low", "high"],
|
||||
stream: bool,
|
||||
@@ -1105,7 +1110,7 @@ class API:
|
||||
prompt: str = Form(...),
|
||||
model: str = Form(...),
|
||||
n: int = Form(1),
|
||||
size: str = Form("1024x1024"),
|
||||
size: str | None = Form(None),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
stream: str = Form("false"),
|
||||
@@ -1131,7 +1136,7 @@ class API:
|
||||
prompt=prompt,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=size,
|
||||
size=normalize_image_size(size),
|
||||
response_format=response_format,
|
||||
input_fidelity=input_fidelity,
|
||||
stream=stream_bool,
|
||||
@@ -1167,7 +1172,7 @@ class API:
|
||||
prompt: str = Form(...),
|
||||
model: str = Form(...),
|
||||
n: int = Form(1),
|
||||
size: str = Form("1024x1024"),
|
||||
size: str | None = Form(None),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
quality: Literal["high", "medium", "low"] = Form("medium"),
|
||||
@@ -1187,7 +1192,7 @@ class API:
|
||||
prompt=prompt,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=size,
|
||||
size=normalize_image_size(size),
|
||||
response_format=response_format,
|
||||
input_fidelity=input_fidelity,
|
||||
stream=False,
|
||||
|
||||
@@ -42,7 +42,7 @@ from exo.utils.channels import channel
|
||||
@pytest.mark.asyncio
|
||||
async def test_master():
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
node_id = NodeId(keypair.to_node_id())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
|
||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||
|
||||
@@ -30,7 +30,7 @@ class ConnectionMessage(CamelCaseModel):
|
||||
@classmethod
|
||||
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
||||
return cls(
|
||||
node_id=NodeId(update.peer_id.to_base58()),
|
||||
node_id=NodeId(update.peer_id),
|
||||
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
||||
remote_ipv4=update.remote_ipv4,
|
||||
remote_tcp_port=update.remote_tcp_port,
|
||||
|
||||
@@ -221,7 +221,7 @@ def get_node_id_keypair(
|
||||
Obtain the :class:`PeerId` by from it.
|
||||
"""
|
||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||
return Keypair.generate_ed25519()
|
||||
return Keypair.generate()
|
||||
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
@@ -235,12 +235,12 @@ def get_node_id_keypair(
|
||||
protobuf_encoded = f.read()
|
||||
|
||||
try: # if decoded successfully, save & return
|
||||
return Keypair.from_protobuf_encoding(protobuf_encoded)
|
||||
return Keypair.from_bytes(protobuf_encoded)
|
||||
except ValueError as e: # on runtime error, assume corrupt file
|
||||
logger.warning(f"Encountered error when trying to get keypair: {e}")
|
||||
|
||||
# if no valid credentials, create new ones and persist
|
||||
with open(path, "w+b") as f:
|
||||
keypair = Keypair.generate_ed25519()
|
||||
f.write(keypair.to_protobuf_encoding())
|
||||
f.write(keypair.to_bytes())
|
||||
return keypair
|
||||
|
||||
@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
|
||||
sem.release()
|
||||
# wait to be told to begin simultaneous read
|
||||
ev.wait()
|
||||
queue.put(get_node_id_keypair().to_protobuf_encoding())
|
||||
queue.put(get_node_id_keypair().to_bytes())
|
||||
|
||||
|
||||
def _get_keypair_concurrent(num_procs: int) -> bytes:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Annotated, Any, Literal
|
||||
from typing import Annotated, Any, Literal, get_args
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
@@ -262,6 +262,27 @@ class DeleteInstanceResponse(BaseModel):
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
ImageSize = Literal[
|
||||
"auto",
|
||||
"512x512",
|
||||
"768x768",
|
||||
"1024x768",
|
||||
"768x1024",
|
||||
"1024x1024",
|
||||
"1024x1536",
|
||||
"1536x1024",
|
||||
]
|
||||
|
||||
|
||||
def normalize_image_size(v: object) -> ImageSize:
|
||||
"""Shared validator for ImageSize fields: maps None → "auto" and rejects invalid values."""
|
||||
if v is None:
|
||||
return "auto"
|
||||
if v not in get_args(ImageSize):
|
||||
raise ValueError(f"Invalid size: {v!r}. Must be one of {get_args(ImageSize)}")
|
||||
return v # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
class AdvancedImageParams(BaseModel):
|
||||
seed: Annotated[int, Field(ge=0)] | None = None
|
||||
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
|
||||
@@ -281,7 +302,7 @@ class ImageGenerationTaskParams(BaseModel):
|
||||
partial_images: int | None = 0
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: str | None = "1024x1024"
|
||||
size: ImageSize = "auto"
|
||||
stream: bool | None = False
|
||||
style: str | None = "vivid"
|
||||
user: str | None = None
|
||||
@@ -289,6 +310,11 @@ class ImageGenerationTaskParams(BaseModel):
|
||||
# Internal flag for benchmark mode - set by API, preserved through serialization
|
||||
bench: bool = False
|
||||
|
||||
@field_validator("size", mode="before")
|
||||
@classmethod
|
||||
def normalize_size(cls, v: object) -> ImageSize:
|
||||
return normalize_image_size(v)
|
||||
|
||||
|
||||
class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
|
||||
bench: bool = True
|
||||
@@ -305,13 +331,18 @@ class ImageEditsTaskParams(BaseModel):
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
output_format: Literal["png", "jpeg", "webp"] = "png"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: str | None = "1024x1024"
|
||||
size: ImageSize = "auto"
|
||||
image_strength: float | None = 0.7
|
||||
stream: bool = False
|
||||
partial_images: int | None = 0
|
||||
advanced_params: AdvancedImageParams | None = None
|
||||
bench: bool = False
|
||||
|
||||
@field_validator("size", mode="before")
|
||||
@classmethod
|
||||
def normalize_size(cls, v: object) -> ImageSize:
|
||||
return normalize_image_size(v)
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
|
||||
if name == "image_data":
|
||||
|
||||
@@ -14,6 +14,7 @@ from exo.shared.types.api import (
|
||||
ImageEditsTaskParams,
|
||||
ImageGenerationStats,
|
||||
ImageGenerationTaskParams,
|
||||
ImageSize,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
@@ -23,9 +24,9 @@ from exo.shared.types.worker.runner_response import (
|
||||
from exo.worker.engines.image.distributed_model import DistributedImageModel
|
||||
|
||||
|
||||
def parse_size(size_str: str | None) -> tuple[int, int]:
|
||||
def parse_size(size_str: ImageSize) -> tuple[int, int]:
|
||||
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
|
||||
if not size_str:
|
||||
if size_str == "auto":
|
||||
return (1024, 1024)
|
||||
|
||||
try:
|
||||
@@ -109,6 +110,9 @@ def generate_image(
|
||||
# Decode base64 image data and save to temp file
|
||||
image_path = Path(tmpdir) / "input.png"
|
||||
image_path.write_bytes(base64.b64decode(task.image_data))
|
||||
if task.size == "auto":
|
||||
with Image.open(image_path) as img:
|
||||
width, height = img.size
|
||||
|
||||
for image_num in range(num_images):
|
||||
# Increment seed for each image to ensure unique results
|
||||
|
||||
@@ -57,7 +57,7 @@ def prefill(
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
group: mx.distributed.Group | None,
|
||||
group: mx.distributed.Group | None = None,
|
||||
) -> tuple[float, int, list[CacheSnapshot]]:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user