Compare commits

...

3 Commits

Author SHA1 Message Date
Evan
584ecd173b woahg 2026-02-04 14:58:05 +00:00
Evan
394d6c9af2 add resources dir to nix 2026-02-04 14:34:04 +00:00
Alex Cheema
41ed7afb3b feat: add model picker modal with grouped models and HF Hub search (#1369)
## Motivation

Reimplements the model picker modal from #1191 on top of the custom
model support branch. Replaces the inline model dropdown with a
full-featured modal that groups models by base model, supports
filtering, favorites, and HuggingFace Hub search.

## Changes

**Backend:**
- Add `family`, `quantization`, `base_model`, `capabilities` metadata
fields to `ModelCard` and all 40 TOML model cards
- Pass new fields through `ModelListModel` and `get_models()` API
response
- Add `GET /models/search` endpoint using
`huggingface_hub.list_models()`

**Dashboard (7 new files):**
- `ModelPickerModal.svelte` — Main modal with search, family filtering,
HuggingFace Hub tab
- `ModelPickerGroup.svelte` — Expandable model group row with
quantization variants
- `FamilySidebar.svelte` — Vertical sidebar with family icons (All,
Favorites, Hub, model families)
- `FamilyLogos.svelte` — SVG icons for each model family
- `ModelFilterPopover.svelte` — Capability and size range filters
- `HuggingFaceResultItem.svelte` — HF search result item with
download/like counts
- `favorites.svelte.ts` — localStorage-backed favorites store

**Integration:**
- Replace inline dropdown in `+page.svelte` with button that opens
`ModelPickerModal`
- Custom models shown in Hub tab with delete support

**Polish:**
- Real brand logos (Meta, Qwen, DeepSeek, OpenAI, GLM, MiniMax, Kimi,
HuggingFace) from Simple Icons / LobeHub
- Clean SVG stroke icons for capabilities (thinking, code, vision, image
gen)
- Consistent `border-exo-yellow/10` borders, descriptive tooltips
throughout
- Cluster memory (used/total) shown in modal header
- Selected model highlight with checkmark for both single and
multi-variant groups
- Cursor pointer on all interactive elements, fix filter popover
click-outside bug
- Custom models now appear in All tab alongside built-in models

## Bug Fix: Gemma 3 EOS tokens

Also included in this branch: fix for Gemma 3 models generating infinite
`<end_of_turn>` tokens. The tokenizer's `eos_token_ids` was missing
token ID 106 (`<end_of_turn>`), so generation never stopped. The fix
appends this token to the EOS list after loading the tokenizer. Also
handles `eos_token_ids` being a `set` (not just a `list`).

## Why It Works

Model metadata (family, capabilities, etc.) is stored directly in TOML
cards rather than derived from heuristics, ensuring accuracy. The modal
groups models by `base_model` field so quantization variants appear
together. Custom models are separated into the Hub tab since they lack
grouping metadata.

## Test Plan

### Manual Testing
- Open dashboard, click model selector to open modal
- Browse models by family sidebar, search, and filters
- Expand model groups to see quantization variants
- Star favorites and verify persistence across page reloads
- Navigate to Hub tab, search and add models
- Verify error messages shown for invalid model IDs
- Run a Gemma 3 model and verify generation stops at `<end_of_turn>`

### Automated Testing
- `uv run basedpyright` — 0 errors
- `uv run ruff check` — passes
- `nix fmt` — clean
- `uv run pytest src/` — 173 passed
- `cd dashboard && npm run build` — builds successfully

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 05:56:23 -08:00
95 changed files with 2827 additions and 3404 deletions

View File

@@ -142,4 +142,6 @@ jobs:
# Run pytest outside sandbox (needs GPU access for MLX)
export HOME="$RUNNER_TEMP"
export EXO_TESTS=1
EXO_RESOURCES_DIR="$PWD/resources" $TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
export EXO_DASHBOARD_DIR="$PWD/dashboard/"
export EXO_RESOURCES_DIR="$PWD/resources"
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib

1
.gitignore vendored
View File

@@ -31,3 +31,4 @@ dashboard/.svelte-kit/
# host config snapshots
hosts_*.json
.swp

View File

@@ -108,6 +108,7 @@ class TokenizerWrapper:
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
eos_token_ids: list[int] | set[int] | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int
@@ -117,7 +118,7 @@ class TokenizerWrapper:
self,
tokenizer: Any,
detokenizer_class: Any = ...,
eos_token_ids: list[int] | None = ...,
eos_token_ids: list[int] | set[int] | None = ...,
chat_template: Any = ...,
tool_parser: Any = ...,
tool_call_start: str | None = ...,

271
Cargo.lock generated
View File

@@ -141,12 +141,6 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
[[package]]
name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "asn1-rs"
version = "0.7.1"
@@ -304,19 +298,6 @@ version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
[[package]]
name = "bigdecimal"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934"
dependencies = [
"autocfg",
"libm",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "bimap"
version = "0.6.3"
@@ -353,31 +334,6 @@ dependencies = [
"generic-array",
]
[[package]]
name = "bon"
version = "3.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebeb9aaf9329dff6ceb65c689ca3db33dbf15f324909c60e4e5eef5701ce31b1"
dependencies = [
"bon-macros",
"rustversion",
]
[[package]]
name = "bon-macros"
version = "3.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77e9d642a7e3a318e37c2c9427b5a6a48aa1ad55dcd986f3034ab2239045a645"
dependencies = [
"darling",
"ident_case",
"prettyplease",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.111",
]
[[package]]
name = "bs58"
version = "0.5.1"
@@ -541,15 +497,6 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f421161cb492475f1661ddc9815a745a1c894592070661180fdec3d4872e9c3"
[[package]]
name = "convert_case"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
dependencies = [
"unicode-segmentation",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@@ -700,41 +647,6 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "darling"
version = "0.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn 2.0.111",
]
[[package]]
name = "darling_macro"
version = "0.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81"
dependencies = [
"darling_core",
"quote",
"syn 2.0.111",
]
[[package]]
name = "data-encoding"
version = "2.9.0"
@@ -761,17 +673,6 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "delegate"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "780eb241654bf097afb00fc5f054a09b687dad862e485fdcf8399bb056565370"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "der"
version = "0.7.10"
@@ -806,29 +707,6 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "derive_more"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b"
dependencies = [
"convert_case",
"proc-macro2",
"quote",
"rustc_version",
"syn 2.0.111",
"unicode-xid",
]
[[package]]
name = "digest"
version = "0.10.7"
@@ -998,37 +876,23 @@ dependencies = [
name = "exo_pyo3_bindings"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"env_logger",
"extend",
"futures",
"impl-trait-for-tuples",
"futures-lite",
"libp2p",
"log",
"networking",
"once_cell",
"pin-project",
"pyo3",
"pyo3-async-runtimes",
"pyo3-log",
"pyo3-stub-gen",
"thiserror 2.0.17",
"thread_local",
"tokio",
"util",
]
[[package]]
name = "extend"
version = "1.2.0"
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "311a6d2f1f9d60bff73d2c78a0af97ed27f79672f15c238192a5bbb64db56d00"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "ff"
@@ -1138,7 +1002,10 @@ version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
"fastrand",
"futures-core",
"futures-io",
"parking",
"pin-project-lite",
]
@@ -1625,12 +1492,6 @@ dependencies = [
"zerovec",
]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
version = "1.1.0"
@@ -1706,17 +1567,6 @@ dependencies = [
"xmltree",
]
[[package]]
name = "impl-trait-for-tuples"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "indexmap"
version = "2.12.1"
@@ -1745,15 +1595,6 @@ dependencies = [
"generic-array",
]
[[package]]
name = "internment"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "636d4b0f6a39fd684effe2a73f5310df16a3fa7954c26d36833e98f44d1977a2"
dependencies = [
"hashbrown 0.15.5",
]
[[package]]
name = "inventory"
version = "0.3.21"
@@ -1880,12 +1721,6 @@ dependencies = [
"cpufeatures",
]
[[package]]
name = "keccak-const"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57d8d8ce877200136358e0bbff3a77965875db3af755a11e1fa6b1b3e2df13ea"
[[package]]
name = "lalrpop-util"
version = "0.20.2"
@@ -1904,12 +1739,6 @@ version = "0.2.178"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
[[package]]
name = "libm"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
[[package]]
name = "libp2p"
version = "0.56.0"
@@ -2898,20 +2727,10 @@ dependencies = [
name = "networking"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"libp2p",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
]
[[package]]
@@ -2993,17 +2812,6 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
@@ -3307,16 +3115,6 @@ dependencies = [
"zerocopy",
]
[[package]]
name = "prettyplease"
version = "0.2.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
dependencies = [
"proc-macro2",
"syn 2.0.111",
]
[[package]]
name = "primeorder"
version = "0.13.6"
@@ -3364,28 +3162,14 @@ version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d"
dependencies = [
"bigdecimal",
"either",
"hashbrown 0.16.1",
"indexmap",
"indoc",
"inventory",
"libc",
"lock_api",
"memoffset",
"num-bigint",
"num-complex",
"num-rational",
"num-traits",
"once_cell",
"ordered-float",
"parking_lot",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"rust_decimal",
"smallvec",
"unindent",
]
@@ -3740,12 +3524,6 @@ dependencies = [
"yasna",
]
[[package]]
name = "recursion"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dba2197bf7b1d87b4dd460c195f4edeb45a94e82e8054f8d5f317c1f0e93ca1"
[[package]]
name = "redox_syscall"
version = "0.5.18"
@@ -3832,16 +3610,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "rust_decimal"
version = "1.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
dependencies = [
"arrayvec",
"num-traits",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
@@ -4706,24 +4474,12 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-segmentation"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "unicode_names2"
version = "1.3.0"
@@ -4804,19 +4560,6 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "util"
version = "0.0.1"
dependencies = [
"bon",
"derive_more",
"extend",
"internment",
"once_cell",
"recursion",
"thiserror 2.0.17",
]
[[package]]
name = "uuid"
version = "1.19.0"

View File

@@ -3,7 +3,6 @@ resolver = "3"
members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/util",
]
[workspace.package]
@@ -24,62 +23,18 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
util = { path = "rust/util" }
# Proc-macro authoring tools
syn = "2.0"
quote = "1.0"
proc-macro2 = "1.0"
darling = "0.20"
# Macro dependecies
extend = "1.2"
delegate = "0.13"
impl-trait-for-tuples = "0.2"
clap = "4.5"
derive_more = { version = "2.0.1", features = ["display"] }
pin-project = "1"
# Utility dependencies
itertools = "0.14"
thiserror = "2"
internment = "0.8"
recursion = "0.5"
regex = "1.11"
once_cell = "1.21"
thread_local = "1.1"
bon = "3.4"
generativity = "1.1"
anyhow = "1.0"
keccak-const = "0.2"
# Functional generics/lenses frameworks
frunk_core = "0.4"
frunk = "0.4"
frunk_utils = "0.2"
frunk-enum-core = "0.3"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-util = "0.3"
futures-timer = "3.0"
# Data structures
either = "1.15"
ordered-float = "5.0"
ahash = "0.8"
# Tracing/logging
log = "0.4"
# networking
libp2p = "0.56"
libp2p-tcp = "0.44"
[workspace.lints.rust]
static_mut_refs = "warn" # Or use "warn" instead of deny
incomplete_features = "allow"
static_mut_refs = "warn"
# Clippy's lint category level configurations;
# every member crate needs to inherit these by adding
@@ -100,64 +55,3 @@ perf = { level = "warn", priority = -1 }
pedantic = { level = "warn", priority = -1 }
nursery = { level = "warn", priority = -1 }
cargo = { level = "warn", priority = -1 }
# Individual Clippy lints from the `restriction` category
arithmetic_side_effects = "warn"
as_conversions = "warn"
assertions_on_result_states = "warn"
clone_on_ref_ptr = "warn"
decimal_literal_representation = "warn"
default_union_representation = "warn"
deref_by_slicing = "warn"
disallowed_script_idents = "deny"
else_if_without_else = "warn"
empty_enum_variants_with_brackets = "warn"
empty_structs_with_brackets = "warn"
error_impl_error = "warn"
exit = "deny"
expect_used = "warn"
float_cmp_const = "warn"
get_unwrap = "warn"
if_then_some_else_none = "warn"
impl_trait_in_params = "warn"
indexing_slicing = "warn"
infinite_loop = "warn"
let_underscore_must_use = "warn"
let_underscore_untyped = "warn"
lossy_float_literal = "warn"
mem_forget = "warn"
missing_inline_in_public_items = "warn"
multiple_inherent_impl = "warn"
multiple_unsafe_ops_per_block = "warn"
mutex_atomic = "warn"
non_zero_suggestions = "warn"
panic = "warn"
partial_pub_fields = "warn"
pattern_type_mismatch = "warn"
pub_without_shorthand = "warn"
rc_buffer = "warn"
rc_mutex = "warn"
redundant_type_annotations = "warn"
renamed_function_params = "warn"
rest_pat_in_fully_bound_structs = "warn"
same_name_method = "warn"
self_named_module_files = "deny"
semicolon_inside_block = "warn"
shadow_same = "warn"
shadow_unrelated = "warn"
str_to_string = "warn"
string_add = "warn"
string_lit_chars_any = "warn"
string_to_string = "warn"
tests_outside_test_module = "warn"
todo = "warn"
try_err = "warn"
undocumented_unsafe_blocks = "warn"
unnecessary_safety_comment = "warn"
unnecessary_safety_doc = "warn"
unneeded_field_pattern = "warn"
unseparated_literal_suffix = "warn"
unused_result_ok = "warn"
unused_trait_names = "warn"
unwrap_used = "warn"
verbose_file_reads = "warn"

View File

@@ -0,0 +1,73 @@
<script lang="ts">
type FamilyLogoProps = {
family: string;
class?: string;
};
let { family, class: className = "" }: FamilyLogoProps = $props();
</script>
{#if family === "favorites"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
/>
</svg>
{:else if family === "llama" || family === "meta"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M6.915 4.03c-1.968 0-3.683 1.28-4.871 3.113C.704 9.208 0 11.883 0 14.449c0 .706.07 1.369.21 1.973a6.624 6.624 0 0 0 .265.86 5.297 5.297 0 0 0 .371.761c.696 1.159 1.818 1.927 3.593 1.927 1.497 0 2.633-.671 3.965-2.444.76-1.012 1.144-1.626 2.663-4.32l.756-1.339.186-.325c.061.1.121.196.183.3l2.152 3.595c.724 1.21 1.665 2.556 2.47 3.314 1.046.987 1.992 1.22 3.06 1.22 1.075 0 1.876-.355 2.455-.843a3.743 3.743 0 0 0 .81-.973c.542-.939.861-2.127.861-3.745 0-2.72-.681-5.357-2.084-7.45-1.282-1.912-2.957-2.93-4.716-2.93-1.047 0-2.088.467-3.053 1.308-.652.57-1.257 1.29-1.82 2.05-.69-.875-1.335-1.547-1.958-2.056-1.182-.966-2.315-1.303-3.454-1.303zm10.16 2.053c1.147 0 2.188.758 2.992 1.999 1.132 1.748 1.647 4.195 1.647 6.4 0 1.548-.368 2.9-1.839 2.9-.58 0-1.027-.23-1.664-1.004-.496-.601-1.343-1.878-2.832-4.358l-.617-1.028a44.908 44.908 0 0 0-1.255-1.98c.07-.109.141-.224.211-.327 1.12-1.667 2.118-2.602 3.358-2.602zm-10.201.553c1.265 0 2.058.791 2.675 1.446.307.327.737.871 1.234 1.579l-1.02 1.566c-.757 1.163-1.882 3.017-2.837 4.338-1.191 1.649-1.81 1.817-2.486 1.817-.524 0-1.038-.237-1.383-.794-.263-.426-.464-1.13-.464-2.046 0-2.221.63-4.535 1.66-6.088.454-.687.964-1.226 1.533-1.533a2.264 2.264 0 0 1 1.088-.285z"
/>
</svg>
{:else if family === "qwen"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12.604 1.34c.393.69.784 1.382 1.174 2.075a.18.18 0 00.157.091h5.552c.174 0 .322.11.446.327l1.454 2.57c.19.337.24.478.024.837-.26.43-.513.864-.76 1.3l-.367.658c-.106.196-.223.28-.04.512l2.652 4.637c.172.301.111.494-.043.77-.437.785-.882 1.564-1.335 2.34-.159.272-.352.375-.68.37-.777-.016-1.552-.01-2.327.016a.099.099 0 00-.081.05 575.097 575.097 0 01-2.705 4.74c-.169.293-.38.363-.725.364-.997.003-2.002.004-3.017.002a.537.537 0 01-.465-.271l-1.335-2.323a.09.09 0 00-.083-.049H4.982c-.285.03-.553-.001-.805-.092l-1.603-2.77a.543.543 0 01-.002-.54l1.207-2.12a.198.198 0 000-.197 550.951 550.951 0 01-1.875-3.272l-.79-1.395c-.16-.31-.173-.496.095-.965.465-.813.927-1.625 1.387-2.436.132-.234.304-.334.584-.335a338.3 338.3 0 012.589-.001.124.124 0 00.107-.063l2.806-4.895a.488.488 0 01.422-.246c.524-.001 1.053 0 1.583-.006L11.704 1c.341-.003.724.032.9.34zm-3.432.403a.06.06 0 00-.052.03L6.254 6.788a.157.157 0 01-.135.078H3.253c-.056 0-.07.025-.041.074l5.81 10.156c.025.042.013.062-.034.063l-2.795.015a.218.218 0 00-.2.116l-1.32 2.31c-.044.078-.021.118.068.118l5.716.008c.046 0 .08.02.104.061l1.403 2.454c.046.081.092.082.139 0l5.006-8.76.783-1.382a.055.055 0 01.096 0l1.424 2.53a.122.122 0 00.107.062l2.763-.02a.04.04 0 00.035-.02.041.041 0 000-.04l-2.9-5.086a.108.108 0 010-.113l.293-.507 1.12-1.977c.024-.041.012-.062-.035-.062H9.2c-.059 0-.073-.026-.043-.077l1.434-2.505a.107.107 0 000-.114L9.225 1.774a.06.06 0 00-.053-.031zm6.29 8.02c.046 0 .058.02.034.06l-.832 1.465-2.613 4.585a.056.056 0 01-.05.029.058.058 0 01-.05-.029L8.498 9.841c-.02-.034-.01-.052.028-.054l.216-.012 6.722-.012z"
/>
</svg>
{:else if family === "deepseek"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M23.748 4.482c-.254-.124-.364.113-.512.234-.051.039-.094.09-.137.136-.372.397-.806.657-1.373.626-.829-.046-1.537.214-2.163.848-.133-.782-.575-1.248-1.247-1.548-.352-.156-.708-.311-.955-.65-.172-.241-.219-.51-.305-.774-.055-.16-.11-.323-.293-.35-.2-.031-.278.136-.356.276-.313.572-.434 1.202-.422 1.84.027 1.436.633 2.58 1.838 3.393.137.093.172.187.129.323-.082.28-.18.552-.266.833-.055.179-.137.217-.329.14a5.526 5.526 0 01-1.736-1.18c-.857-.828-1.631-1.742-2.597-2.458a11.365 11.365 0 00-.689-.471c-.985-.957.13-1.743.388-1.836.27-.098.093-.432-.779-.428-.872.004-1.67.295-2.687.684a3.055 3.055 0 01-.465.137 9.597 9.597 0 00-2.883-.102c-1.885.21-3.39 1.102-4.497 2.623C.082 8.606-.231 10.684.152 12.85c.403 2.284 1.569 4.175 3.36 5.653 1.858 1.533 3.997 2.284 6.438 2.14 1.482-.085 3.133-.284 4.994-1.86.47.234.962.327 1.78.397.63.059 1.236-.03 1.705-.128.735-.156.684-.837.419-.961-2.155-1.004-1.682-.595-2.113-.926 1.096-1.296 2.746-2.642 3.392-7.003.05-.347.007-.565 0-.845-.004-.17.035-.237.23-.256a4.173 4.173 0 001.545-.475c1.396-.763 1.96-2.015 2.093-3.517.02-.23-.004-.467-.247-.588zM11.581 18c-2.089-1.642-3.102-2.183-3.52-2.16-.392.024-.321.471-.235.763.09.288.207.486.371.739.114.167.192.416-.113.603-.673.416-1.842-.14-1.897-.167-1.361-.802-2.5-1.86-3.301-3.307-.774-1.393-1.224-2.887-1.298-4.482-.02-.386.093-.522.477-.592a4.696 4.696 0 011.529-.039c2.132.312 3.946 1.265 5.468 2.774.868.86 1.525 1.887 2.202 2.891.72 1.066 1.494 2.082 2.48 2.914.348.292.625.514.891.677-.802.09-2.14.11-3.054-.614zm1-6.44a.306.306 0 01.415-.287.302.302 0 01.2.288.306.306 0 01-.31.307.303.303 0 01-.304-.308zm3.11 1.596c-.2.081-.399.151-.59.16a1.245 1.245 0 01-.798-.254c-.274-.23-.47-.358-.552-.758a1.73 1.73 0 01.016-.588c.07-.327-.008-.537-.239-.727-.187-.156-.426-.199-.688-.199a.559.559 0 01-.254-.078c-.11-.054-.2-.19-.114-.358.028-.054.16-.186.192-.21.356-.202.767-.136 1.146.016.352.144.618.408 1.001.782.391.451.462.576.685.914.176.265.336.537.445.848.067.195-.019.354-.25.452z"
/>
</svg>
{:else if family === "openai" || family === "gpt-oss"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M22.2819 9.8211a5.9847 5.9847 0 0 0-.5157-4.9108 6.0462 6.0462 0 0 0-6.5098-2.9A6.0651 6.0651 0 0 0 4.9807 4.1818a5.9847 5.9847 0 0 0-3.9977 2.9 6.0462 6.0462 0 0 0 .7427 7.0966 5.98 5.98 0 0 0 .511 4.9107 6.051 6.051 0 0 0 6.5146 2.9001A5.9847 5.9847 0 0 0 13.2599 24a6.0557 6.0557 0 0 0 5.7718-4.2058 5.9894 5.9894 0 0 0 3.9977-2.9001 6.0557 6.0557 0 0 0-.7475-7.0729zm-9.022 12.6081a4.4755 4.4755 0 0 1-2.8764-1.0408l.1419-.0804 4.7783-2.7582a.7948.7948 0 0 0 .3927-.6813v-6.7369l2.02 1.1686a.071.071 0 0 1 .038.052v5.5826a4.504 4.504 0 0 1-4.4945 4.4944zm-9.6607-4.1254a4.4708 4.4708 0 0 1-.5346-3.0137l.142.0852 4.783 2.7582a.7712.7712 0 0 0 .7806 0l5.8428-3.3685v2.3324a.0804.0804 0 0 1-.0332.0615L9.74 19.9502a4.4992 4.4992 0 0 1-6.1408-1.6464zM2.3408 7.8956a4.485 4.485 0 0 1 2.3655-1.9728V11.6a.7664.7664 0 0 0 .3879.6765l5.8144 3.3543-2.0201 1.1685a.0757.0757 0 0 1-.071 0l-4.8303-2.7865A4.504 4.504 0 0 1 2.3408 7.872zm16.5963 3.8558L13.1038 8.364 15.1192 7.2a.0757.0757 0 0 1 .071 0l4.8303 2.7913a4.4944 4.4944 0 0 1-.6765 8.1042v-5.6772a.79.79 0 0 0-.407-.667zm2.0107-3.0231l-.142-.0852-4.7735-2.7818a.7759.7759 0 0 0-.7854 0L9.409 9.2297V6.8974a.0662.0662 0 0 1 .0284-.0615l4.8303-2.7866a4.4992 4.4992 0 0 1 6.6802 4.66zM8.3065 12.863l-2.02-1.1638a.0804.0804 0 0 1-.038-.0567V6.0742a4.4992 4.4992 0 0 1 7.3757-3.4537l-.142.0805L8.704 5.459a.7948.7948 0 0 0-.3927.6813zm1.0976-2.3654l2.602-1.4998 2.6069 1.4998v2.9994l-2.5974 1.4997-2.6067-1.4997Z"
/>
</svg>
{:else if family === "glm"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M11.991 23.503a.24.24 0 00-.244.248.24.24 0 00.244.249.24.24 0 00.245-.249.24.24 0 00-.22-.247l-.025-.001zM9.671 5.365a1.697 1.697 0 011.099 2.132l-.071.172-.016.04-.018.054c-.07.16-.104.32-.104.498-.035.71.47 1.279 1.186 1.314h.366c1.309.053 2.338 1.173 2.286 2.523-.052 1.332-1.152 2.38-2.478 2.327h-.174c-.715.018-1.274.64-1.239 1.368 0 .124.018.23.053.337.209.373.54.658.96.8.75.23 1.517-.125 1.9-.782l.018-.035c.402-.64 1.17-.96 1.92-.711.854.284 1.378 1.226 1.099 2.167a1.661 1.661 0 01-2.077 1.102 1.711 1.711 0 01-.907-.711l-.017-.035c-.2-.323-.463-.58-.851-.711l-.056-.018a1.646 1.646 0 00-1.954.746 1.66 1.66 0 01-1.065.764 1.677 1.677 0 01-1.989-1.279c-.209-.906.332-1.83 1.257-2.043a1.51 1.51 0 01.296-.035h.018c.68-.071 1.151-.622 1.116-1.333a1.307 1.307 0 00-.227-.693 2.515 2.515 0 01-.366-1.403 2.39 2.39 0 01.366-1.208c.14-.195.21-.444.227-.693.018-.71-.506-1.261-1.186-1.332l-.07-.018a1.43 1.43 0 01-.299-.07l-.05-.019a1.7 1.7 0 01-1.047-2.114 1.68 1.68 0 012.094-1.101zm-5.575 10.11c.26-.264.639-.367.994-.27.355.096.633.379.728.74.095.362-.007.748-.267 1.013-.402.41-1.053.41-1.455 0a1.062 1.062 0 010-1.482zm14.845-.294c.359-.09.738.024.992.297.254.274.344.665.237 1.025-.107.36-.396.634-.756.718-.551.128-1.1-.22-1.23-.781a1.05 1.05 0 01.757-1.26zm-.064-4.39c.314.32.49.753.49 1.206 0 .452-.176.886-.49 1.206-.315.32-.74.5-1.185.5-.444 0-.87-.18-1.184-.5a1.727 1.727 0 010-2.412 1.654 1.654 0 012.369 0zm-11.243.163c.364.484.447 1.128.218 1.691a1.665 1.665 0 01-2.188.923c-.855-.36-1.26-1.358-.907-2.228a1.68 1.68 0 011.33-1.038c.593-.08 1.183.169 1.547.652zm11.545-4.221c.368 0 .708.2.892.524.184.324.184.724 0 1.048a1.026 1.026 0 01-.892.524c-.568 0-1.03-.47-1.03-1.048 0-.579.462-1.048 1.03-1.048zm-14.358 0c.368 0 .707.2.891.524.184.324.184.724 0 1.048a1.026 1.026 0 01-.891.524c-.569 0-1.03-.47-1.03-1.048 0-.579.461-1.048 1.03-1.048zm10.031-1.475c.925 0 1.675.764 1.675 1.706s-.75 1.705-1.675 1.705-1.674-.763-1.674-1.705c0-.942.75-1.706 1.674-1.706zm-2.626-.684c.362-.082.653-.356.761-.718a1.062 1.062 0 00-.238-1.028 1.017 1.017 0 00-.996-.294c-.547.14-.881.7-.752 1.257.13.558.675.907 1.225.783zm0 16.876c.359-.087.644-.36.75-.72a1.062 1.062 0 00-.237-1.019 1.018 1.018 0 00-.985-.301 1.037 1.037 0 00-.762.717c-.108.361-.017.754.239 1.028.245.263.606.377.953.305l.043-.01zM17.19 3.5a.631.631 0 00.628-.64c0-.355-.279-.64-.628-.64a.631.631 0 00-.628.64c0 .355.28.64.628.64zm-10.38 0a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64a.631.631 0 00-.628.64c0 .355.279.64.628.64zm-5.182 7.852a.631.631 0 00-.628.64c0 .354.28.639.628.639a.63.63 0 00.627-.606l.001-.034a.62.62 0 00-.628-.64zm5.182 9.13a.631.631 0 00-.628.64c0 .355.279.64.628.64a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64zm10.38.018a.631.631 0 00-.628.64c0 .355.28.64.628.64a.631.631 0 00.628-.64c0-.355-.279-.64-.628-.64zm5.182-9.148a.631.631 0 00-.628.64c0 .354.279.639.628.639a.631.631 0 00.628-.64c0-.355-.28-.64-.628-.64zm-.384-4.992a.24.24 0 00.244-.249.24.24 0 00-.244-.249.24.24 0 00-.244.249c0 .142.122.249.244.249zM11.991.497a.24.24 0 00.245-.248A.24.24 0 0011.99 0a.24.24 0 00-.244.249c0 .133.108.236.223.247l.021.001zM2.011 6.36a.24.24 0 00.245-.249.24.24 0 00-.244-.249.24.24 0 00-.244.249.24.24 0 00.244.249zm0 11.263a.24.24 0 00-.243.248.24.24 0 00.244.249.24.24 0 00.244-.249.252.252 0 00-.244-.248zm19.995-.018a.24.24 0 00-.245.248.24.24 0 00.245.25.24.24 0 00.244-.25.252.252 0 00-.244-.248z"
/>
</svg>
{:else if family === "minimax"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M16.278 2c1.156 0 2.093.927 2.093 2.07v12.501a.74.74 0 00.744.709.74.74 0 00.743-.709V9.099a2.06 2.06 0 012.071-2.049A2.06 2.06 0 0124 9.1v6.561a.649.649 0 01-.652.645.649.649 0 01-.653-.645V9.1a.762.762 0 00-.766-.758.762.762 0 00-.766.758v7.472a2.037 2.037 0 01-2.048 2.026 2.037 2.037 0 01-2.048-2.026v-12.5a.785.785 0 00-.788-.753.785.785 0 00-.789.752l-.001 15.904A2.037 2.037 0 0113.441 22a2.037 2.037 0 01-2.048-2.026V18.04c0-.356.292-.645.652-.645.36 0 .652.289.652.645v1.934c0 .263.142.506.372.638.23.131.514.131.744 0a.734.734 0 00.372-.638V4.07c0-1.143.937-2.07 2.093-2.07zm-5.674 0c1.156 0 2.093.927 2.093 2.07v11.523a.648.648 0 01-.652.645.648.648 0 01-.652-.645V4.07a.785.785 0 00-.789-.78.785.785 0 00-.789.78v14.013a2.06 2.06 0 01-2.07 2.048 2.06 2.06 0 01-2.071-2.048V9.1a.762.762 0 00-.766-.758.762.762 0 00-.766.758v3.8a2.06 2.06 0 01-2.071 2.049A2.06 2.06 0 010 12.9v-1.378c0-.357.292-.646.652-.646.36 0 .653.29.653.646V12.9c0 .418.343.757.766.757s.766-.339.766-.757V9.099a2.06 2.06 0 012.07-2.048 2.06 2.06 0 012.071 2.048v8.984c0 .419.343.758.767.758.423 0 .766-.339.766-.758V4.07c0-1.143.937-2.07 2.093-2.07z"
/>
</svg>
{:else if family === "kimi"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M19.738 5.776c.163-.209.306-.4.457-.585.07-.087.064-.153-.004-.244-.655-.861-.717-1.817-.34-2.787.283-.73.909-1.072 1.674-1.145.477-.045.945.004 1.379.236.57.305.902.77 1.01 1.412.086.512.07 1.012-.075 1.508-.257.878-.888 1.333-1.753 1.448-.718.096-1.446.108-2.17.157-.056.004-.113 0-.178 0z"
/>
<path
d="M17.962 1.844h-4.326l-3.425 7.81H5.369V1.878H1.5V22h3.87v-8.477h6.824a3.025 3.025 0 002.743-1.75V22h3.87v-8.477a3.87 3.87 0 00-3.588-3.86v-.01h-2.125a3.94 3.94 0 002.323-2.12l2.545-5.689z"
/>
</svg>
{:else if family === "huggingface"}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12.025 1.13c-5.77 0-10.449 4.647-10.449 10.378 0 1.112.178 2.181.503 3.185.064-.222.203-.444.416-.577a.96.96 0 0 1 .524-.15c.293 0 .584.124.84.284.278.173.48.408.71.694.226.282.458.611.684.951v-.014c.017-.324.106-.622.264-.874s.403-.487.762-.543c.3-.047.596.06.787.203s.31.313.4.467c.15.257.212.468.233.542.01.026.653 1.552 1.657 2.54.616.605 1.01 1.223 1.082 1.912.055.537-.096 1.059-.38 1.572.637.121 1.294.187 1.967.187.657 0 1.298-.063 1.921-.178-.287-.517-.44-1.041-.384-1.581.07-.69.465-1.307 1.081-1.913 1.004-.987 1.647-2.513 1.657-2.539.021-.074.083-.285.233-.542.09-.154.208-.323.4-.467a1.08 1.08 0 0 1 .787-.203c.359.056.604.29.762.543s.247.55.265.874v.015c.225-.34.457-.67.683-.952.23-.286.432-.52.71-.694.257-.16.547-.284.84-.285a.97.97 0 0 1 .524.151c.228.143.373.388.43.625l.006.04a10.3 10.3 0 0 0 .534-3.273c0-5.731-4.678-10.378-10.449-10.378M8.327 6.583a1.5 1.5 0 0 1 .713.174 1.487 1.487 0 0 1 .617 2.013c-.183.343-.762-.214-1.102-.094-.38.134-.532.914-.917.71a1.487 1.487 0 0 1 .69-2.803m7.486 0a1.487 1.487 0 0 1 .689 2.803c-.385.204-.536-.576-.916-.71-.34-.12-.92.437-1.103.094a1.487 1.487 0 0 1 .617-2.013 1.5 1.5 0 0 1 .713-.174m-10.68 1.55a.96.96 0 1 1 0 1.921.96.96 0 0 1 0-1.92m13.838 0a.96.96 0 1 1 0 1.92.96.96 0 0 1 0-1.92M8.489 11.458c.588.01 1.965 1.157 3.572 1.164 1.607-.007 2.984-1.155 3.572-1.164.196-.003.305.12.305.454 0 .886-.424 2.328-1.563 3.202-.22-.756-1.396-1.366-1.63-1.32q-.011.001-.02.006l-.044.026-.01.008-.03.024q-.018.017-.035.036l-.032.04a1 1 0 0 0-.058.09l-.014.025q-.049.088-.11.19a1 1 0 0 1-.083.116 1.2 1.2 0 0 1-.173.18q-.035.029-.075.058a1.3 1.3 0 0 1-.251-.243 1 1 0 0 1-.076-.107c-.124-.193-.177-.363-.337-.444-.034-.016-.104-.008-.2.022q-.094.03-.216.087-.06.028-.125.063l-.13.074q-.067.04-.136.086a3 3 0 0 0-.135.096 3 3 0 0 0-.26.219 2 2 0 0 0-.12.121 2 2 0 0 0-.106.128l-.002.002a2 2 0 0 0-.09.132l-.001.001a1.2 1.2 0 0 0-.105.212q-.013.036-.024.073c-1.139-.875-1.563-2.317-1.563-3.203 0-.334.109-.457.305-.454m.836 10.354c.824-1.19.766-2.082-.365-3.194-1.13-1.112-1.789-2.738-1.789-2.738s-.246-.945-.806-.858-.97 1.499.202 2.362c1.173.864-.233 1.45-.685.64-.45-.812-1.683-2.896-2.322-3.295s-1.089-.175-.938.647 2.822 2.813 2.562 3.244-1.176-.506-1.176-.506-2.866-2.567-3.49-1.898.473 1.23 2.037 2.16c1.564.932 1.686 1.178 1.464 1.53s-3.675-2.511-4-1.297c-.323 1.214 3.524 1.567 3.287 2.405-.238.839-2.71-1.587-3.216-.642-.506.946 3.49 2.056 3.522 2.064 1.29.33 4.568 1.028 5.713-.624m5.349 0c-.824-1.19-.766-2.082.365-3.194 1.13-1.112 1.789-2.738 1.789-2.738s.246-.945.806-.858.97 1.499-.202 2.362c-1.173.864.233 1.45.685.64.451-.812 1.683-2.896 2.322-3.295s1.089-.175.938.647-2.822 2.813-2.562 3.244 1.176-.506 1.176-.506 2.866-2.567 3.49-1.898-.473 1.23-2.037 2.16c-1.564.932-1.686 1.178-1.464 1.53s3.675-2.511 4-1.297c.323 1.214-3.524 1.567-3.287 2.405.238.839 2.71-1.587 3.216-.642.506.946-3.49 2.056-3.522 2.064-1.29.33-4.568 1.028-5.713-.624"
/>
</svg>
{:else}
<svg class="w-6 h-6 {className}" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 15l-5-5 1.41-1.41L10 14.17l7.59-7.59L19 8l-9 9z"
/>
</svg>
{/if}

View File

@@ -0,0 +1,142 @@
<script lang="ts">
import FamilyLogos from "./FamilyLogos.svelte";
type FamilySidebarProps = {
families: string[];
selectedFamily: string | null;
hasFavorites: boolean;
onSelect: (family: string | null) => void;
};
let { families, selectedFamily, hasFavorites, onSelect }: FamilySidebarProps =
$props();
// Family display names
const familyNames: Record<string, string> = {
favorites: "Favorites",
huggingface: "Hub",
llama: "Meta",
qwen: "Qwen",
deepseek: "DeepSeek",
"gpt-oss": "OpenAI",
glm: "GLM",
minimax: "MiniMax",
kimi: "Kimi",
};
function getFamilyName(family: string): string {
return (
familyNames[family] || family.charAt(0).toUpperCase() + family.slice(1)
);
}
</script>
<div
class="flex flex-col gap-1 py-2 px-1 border-r border-exo-yellow/10 bg-exo-medium-gray/30 min-w-[64px]"
>
<!-- All models (no filter) -->
<button
type="button"
onclick={() => onSelect(null)}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
null
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title="All models"
>
<svg
class="w-5 h-5 {selectedFamily === null
? 'text-exo-yellow'
: 'text-white/50 group-hover:text-white/70'}"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M4 8h4V4H4v4zm6 12h4v-4h-4v4zm-6 0h4v-4H4v4zm0-6h4v-4H4v4zm6 0h4v-4h-4v4zm6-10v4h4V4h-4zm-6 4h4V4h-4v4zm6 6h4v-4h-4v4zm0 6h4v-4h-4v4z"
/>
</svg>
<span
class="text-[9px] font-mono mt-0.5 {selectedFamily === null
? 'text-exo-yellow'
: 'text-white/40 group-hover:text-white/60'}">All</span
>
</button>
<!-- Favorites (only show if has favorites) -->
{#if hasFavorites}
<button
type="button"
onclick={() => onSelect("favorites")}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
'favorites'
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title="Show favorited models"
>
<FamilyLogos
family="favorites"
class={selectedFamily === "favorites"
? "text-amber-400"
: "text-white/50 group-hover:text-amber-400/70"}
/>
<span
class="text-[9px] font-mono mt-0.5 {selectedFamily === 'favorites'
? 'text-amber-400'
: 'text-white/40 group-hover:text-white/60'}">Faves</span
>
</button>
{/if}
<!-- HuggingFace Hub -->
<button
type="button"
onclick={() => onSelect("huggingface")}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
'huggingface'
? 'bg-orange-500/20 border-l-2 border-orange-400'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title="Browse and add models from Hugging Face"
>
<FamilyLogos
family="huggingface"
class={selectedFamily === "huggingface"
? "text-orange-400"
: "text-white/50 group-hover:text-orange-400/70"}
/>
<span
class="text-[9px] font-mono mt-0.5 {selectedFamily === 'huggingface'
? 'text-orange-400'
: 'text-white/40 group-hover:text-white/60'}">Hub</span
>
</button>
<div class="h-px bg-exo-yellow/10 my-1"></div>
<!-- Model families -->
{#each families as family}
<button
type="button"
onclick={() => onSelect(family)}
class="group flex flex-col items-center justify-center p-2 rounded transition-all duration-200 cursor-pointer {selectedFamily ===
family
? 'bg-exo-yellow/20 border-l-2 border-exo-yellow'
: 'hover:bg-white/5 border-l-2 border-transparent'}"
title={getFamilyName(family)}
>
<FamilyLogos
{family}
class={selectedFamily === family
? "text-exo-yellow"
: "text-white/50 group-hover:text-white/70"}
/>
<span
class="text-[9px] font-mono mt-0.5 truncate max-w-full {selectedFamily ===
family
? 'text-exo-yellow'
: 'text-white/40 group-hover:text-white/60'}"
>
{getFamilyName(family)}
</span>
</button>
{/each}
</div>

View File

@@ -0,0 +1,127 @@
<script lang="ts">
interface HuggingFaceModel {
id: string;
author: string;
downloads: number;
likes: number;
last_modified: string;
tags: string[];
}
type HuggingFaceResultItemProps = {
model: HuggingFaceModel;
isAdded: boolean;
isAdding: boolean;
onAdd: () => void;
onSelect: () => void;
};
let {
model,
isAdded,
isAdding,
onAdd,
onSelect,
}: HuggingFaceResultItemProps = $props();
function formatNumber(num: number): string {
if (num >= 1000000) {
return `${(num / 1000000).toFixed(1)}M`;
} else if (num >= 1000) {
return `${(num / 1000).toFixed(1)}k`;
}
return num.toString();
}
// Extract model name from full ID (e.g., "mlx-community/Llama-3.2-1B" -> "Llama-3.2-1B")
const modelName = $derived(model.id.split("/").pop() || model.id);
</script>
<div
class="flex items-center justify-between gap-3 px-3 py-2.5 hover:bg-white/5 transition-colors border-b border-white/5 last:border-b-0"
>
<div class="flex-1 min-w-0">
<div class="flex items-center gap-2">
<span class="text-sm font-mono text-white truncate" title={model.id}
>{modelName}</span
>
{#if isAdded}
<span
class="px-1.5 py-0.5 text-[10px] font-mono bg-green-500/20 text-green-400 rounded"
>Added</span
>
{/if}
</div>
<div class="flex items-center gap-3 mt-0.5 text-xs text-white/40">
<span class="truncate">{model.author}</span>
<span
class="flex items-center gap-1 shrink-0"
title="Downloads in the last 30 days"
>
<svg
class="w-3 h-3"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"
/>
</svg>
{formatNumber(model.downloads)}
</span>
<span
class="flex items-center gap-1 shrink-0"
title="Community likes on Hugging Face"
>
<svg
class="w-3 h-3"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M4.318 6.318a4.5 4.5 0 000 6.364L12 20.364l7.682-7.682a4.5 4.5 0 00-6.364-6.364L12 7.636l-1.318-1.318a4.5 4.5 0 00-6.364 0z"
/>
</svg>
{formatNumber(model.likes)}
</span>
</div>
</div>
<div class="flex items-center gap-2 shrink-0">
{#if isAdded}
<button
type="button"
onclick={onSelect}
class="px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-exo-yellow/10 text-exo-yellow border border-exo-yellow/30 hover:bg-exo-yellow/20 transition-colors rounded cursor-pointer"
>
Select
</button>
{:else}
<button
type="button"
onclick={onAdd}
disabled={isAdding}
class="px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-orange-500/10 text-orange-400 border border-orange-400/30 hover:bg-orange-500/20 transition-colors rounded cursor-pointer disabled:opacity-50 disabled:cursor-not-allowed"
>
{#if isAdding}
<span class="flex items-center gap-1.5">
<span
class="w-3 h-3 border-2 border-orange-400 border-t-transparent rounded-full animate-spin"
></span>
Adding...
</span>
{:else}
+ Add
{/if}
</button>
{/if}
</div>
</div>

View File

@@ -0,0 +1,182 @@
<script lang="ts">
import { fly } from "svelte/transition";
import { cubicOut } from "svelte/easing";
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
}
type ModelFilterPopoverProps = {
filters: FilterState;
onChange: (filters: FilterState) => void;
onClear: () => void;
onClose: () => void;
};
let { filters, onChange, onClear, onClose }: ModelFilterPopoverProps =
$props();
// Available capabilities
const availableCapabilities = [
{ id: "text", label: "Text" },
{ id: "thinking", label: "Thinking" },
{ id: "code", label: "Code" },
{ id: "vision", label: "Vision" },
];
// Size ranges
const sizeRanges = [
{ label: "< 10GB", min: 0, max: 10 },
{ label: "10-50GB", min: 10, max: 50 },
{ label: "50-200GB", min: 50, max: 200 },
{ label: "> 200GB", min: 200, max: 10000 },
];
function toggleCapability(cap: string) {
const next = filters.capabilities.includes(cap)
? filters.capabilities.filter((c) => c !== cap)
: [...filters.capabilities, cap];
onChange({ ...filters, capabilities: next });
}
function selectSizeRange(range: { min: number; max: number } | null) {
// Toggle off if same range is clicked
if (
filters.sizeRange &&
range &&
filters.sizeRange.min === range.min &&
filters.sizeRange.max === range.max
) {
onChange({ ...filters, sizeRange: null });
} else {
onChange({ ...filters, sizeRange: range });
}
}
function handleClickOutside(e: MouseEvent) {
const target = e.target as HTMLElement;
if (
!target.closest(".filter-popover") &&
!target.closest(".filter-toggle")
) {
onClose();
}
}
</script>
<svelte:window onclick={handleClickOutside} />
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="filter-popover absolute right-0 top-full mt-2 w-64 bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-xl z-10"
transition:fly={{ y: -10, duration: 200, easing: cubicOut }}
onclick={(e) => e.stopPropagation()}
role="dialog"
aria-label="Filter options"
>
<div class="p-3 space-y-4">
<!-- Capabilities -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Capabilities</h4>
<div class="flex flex-wrap gap-1.5">
{#each availableCapabilities as cap}
{@const isSelected = filters.capabilities.includes(cap.id)}
<button
type="button"
class="px-2 py-1 text-xs font-mono rounded transition-colors {isSelected
? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/30'
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
onclick={() => toggleCapability(cap.id)}
>
{#if cap.id === "text"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z"
stroke-linecap="round"
stroke-linejoin="round"
/></svg
>
{:else if cap.id === "thinking"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4"
stroke-linecap="round"
stroke-linejoin="round"
/></svg
>
{:else if cap.id === "code"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M16 18l6-6-6-6M8 6l-6 6 6 6"
stroke-linecap="round"
stroke-linejoin="round"
/></svg
>
{:else if cap.id === "vision"}
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
><path
d="M1 12s4-8 11-8 11 8 11 8-4 8-11 8-11-8-11-8z"
stroke-linecap="round"
stroke-linejoin="round"
/><circle cx="12" cy="12" r="3" /></svg
>
{/if}
<span class="ml-1">{cap.label}</span>
</button>
{/each}
</div>
</div>
<!-- Size range -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Model Size</h4>
<div class="flex flex-wrap gap-1.5">
{#each sizeRanges as range}
{@const isSelected =
filters.sizeRange &&
filters.sizeRange.min === range.min &&
filters.sizeRange.max === range.max}
<button
type="button"
class="px-2 py-1 text-xs font-mono rounded transition-colors {isSelected
? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/30'
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
onclick={() => selectSizeRange(range)}
>
{range.label}
</button>
{/each}
</div>
</div>
<!-- Clear button -->
<button
type="button"
class="w-full py-1.5 text-xs font-mono text-white/50 hover:text-white/70 hover:bg-white/5 rounded transition-colors"
onclick={onClear}
>
Clear all filters
</button>
</div>
</div>

View File

@@ -0,0 +1,324 @@
<script lang="ts">
interface ModelInfo {
id: string;
name?: string;
storage_size_megabytes?: number;
base_model?: string;
quantization?: string;
supports_tensor?: boolean;
capabilities?: string[];
family?: string;
is_custom?: boolean;
}
interface ModelGroup {
id: string;
name: string;
capabilities: string[];
family: string;
variants: ModelInfo[];
smallestVariant: ModelInfo;
hasMultipleVariants: boolean;
}
type ModelPickerGroupProps = {
group: ModelGroup;
isExpanded: boolean;
isFavorite: boolean;
selectedModelId: string | null;
canModelFit: (id: string) => boolean;
onToggleExpand: () => void;
onSelectModel: (modelId: string) => void;
onToggleFavorite: (baseModelId: string) => void;
onShowInfo: (group: ModelGroup) => void;
};
let {
group,
isExpanded,
isFavorite,
selectedModelId,
canModelFit,
onToggleExpand,
onSelectModel,
onToggleFavorite,
onShowInfo,
}: ModelPickerGroupProps = $props();
// Format storage size
function formatSize(mb: number | undefined): string {
if (!mb) return "";
if (mb >= 1024) {
return `${(mb / 1024).toFixed(0)}GB`;
}
return `${mb}MB`;
}
// Check if any variant can fit
const anyVariantFits = $derived(
group.variants.some((v) => canModelFit(v.id)),
);
// Check if this group's model is currently selected (for single-variant groups)
const isMainSelected = $derived(
!group.hasMultipleVariants &&
group.variants.some((v) => v.id === selectedModelId),
);
</script>
<div
class="border-b border-white/5 last:border-b-0 {!anyVariantFits
? 'opacity-50'
: ''}"
>
<!-- Main row -->
<div
class="flex items-center gap-2 px-3 py-2.5 transition-colors {anyVariantFits
? 'hover:bg-white/5 cursor-pointer'
: 'cursor-not-allowed'} {isMainSelected
? 'bg-exo-yellow/10 border-l-2 border-exo-yellow'
: 'border-l-2 border-transparent'}"
onclick={() => {
if (group.hasMultipleVariants) {
onToggleExpand();
} else {
const modelId = group.variants[0]?.id;
if (modelId && canModelFit(modelId)) {
onSelectModel(modelId);
}
}
}}
role="button"
tabindex="0"
onkeydown={(e) => {
if (e.key === "Enter" || e.key === " ") {
e.preventDefault();
if (group.hasMultipleVariants) {
onToggleExpand();
} else {
const modelId = group.variants[0]?.id;
if (modelId && canModelFit(modelId)) {
onSelectModel(modelId);
}
}
}
}}
>
<!-- Expand/collapse chevron (for groups with variants) -->
{#if group.hasMultipleVariants}
<svg
class="w-4 h-4 text-white/40 transition-transform duration-200 flex-shrink-0 {isExpanded
? 'rotate-90'
: ''}"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M8.59 16.59L13.17 12 8.59 7.41 10 6l6 6-6 6-1.41-1.41z" />
</svg>
{:else}
<div class="w-4 flex-shrink-0"></div>
{/if}
<!-- Model name -->
<div class="flex-1 min-w-0">
<div class="flex items-center gap-2">
<span class="font-mono text-sm text-white truncate">
{group.name}
</span>
<!-- Capability icons -->
{#each group.capabilities.filter((c) => c !== "text") as cap}
{#if cap === "thinking"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
title="Supports Thinking"
>
<path
d="M12 2a7 7 0 0 0-7 7c0 2.38 1.19 4.47 3 5.74V17a1 1 0 0 0 1 1h6a1 1 0 0 0 1-1v-2.26c1.81-1.27 3-3.36 3-5.74a7 7 0 0 0-7-7zM9 20h6M10 22h4"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>
{:else if cap === "code"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
title="Supports code generation"
>
<path
d="M16 18l6-6-6-6M8 6l-6 6 6 6"
stroke-linecap="round"
stroke-linejoin="round"
/>
</svg>
{:else if cap === "vision"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
title="Supports image input"
>
<path
d="M1 12s4-8 11-8 11 8 11 8-4 8-11 8-11-8-11-8z"
stroke-linecap="round"
stroke-linejoin="round"
/>
<circle cx="12" cy="12" r="3" />
</svg>
{:else if cap === "image_gen"}
<svg
class="w-3.5 h-3.5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
title="Supports image generation"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<path d="M21 15l-5-5L5 21" />
</svg>
{/if}
{/each}
</div>
</div>
<!-- Size indicator (smallest variant) -->
{#if !group.hasMultipleVariants && group.smallestVariant?.storage_size_megabytes}
<span class="text-xs font-mono text-white/30 flex-shrink-0">
{formatSize(group.smallestVariant.storage_size_megabytes)}
</span>
{/if}
<!-- Variant count -->
{#if group.hasMultipleVariants}
<span class="text-xs font-mono text-white/30 flex-shrink-0">
{group.variants.length} variants
</span>
{/if}
<!-- Check mark if selected (single-variant) -->
{#if isMainSelected}
<svg
class="w-4 h-4 text-exo-yellow flex-shrink-0"
viewBox="0 0 24 24"
fill="currentColor"
>
<path d="M9 16.17L4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41L9 16.17z" />
</svg>
{/if}
<!-- Favorite star -->
<button
type="button"
class="p-1 rounded hover:bg-white/10 transition-colors flex-shrink-0"
onclick={(e) => {
e.stopPropagation();
onToggleFavorite(group.id);
}}
title={isFavorite ? "Remove from favorites" : "Add to favorites"}
>
{#if isFavorite}
<svg
class="w-4 h-4 text-amber-400"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
/>
</svg>
{:else}
<svg
class="w-4 h-4 text-white/30 hover:text-white/50"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M12 2l3.09 6.26L22 9.27l-5 4.87 1.18 6.88L12 17.77l-6.18 3.25L7 14.14 2 9.27l6.91-1.01L12 2z"
/>
</svg>
{/if}
</button>
<!-- Info button -->
<button
type="button"
class="p-1 rounded hover:bg-white/10 transition-colors flex-shrink-0"
onclick={(e) => {
e.stopPropagation();
onShowInfo(group);
}}
title="View model details"
>
<svg
class="w-4 h-4 text-white/30 hover:text-white/50"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm1 15h-2v-6h2v6zm0-8h-2V7h2v2z"
/>
</svg>
</button>
</div>
<!-- Expanded variants -->
{#if isExpanded && group.hasMultipleVariants}
<div class="bg-black/20 border-t border-white/5">
{#each group.variants as variant}
{@const modelCanFit = canModelFit(variant.id)}
{@const isSelected = selectedModelId === variant.id}
<button
type="button"
class="w-full flex items-center gap-3 px-3 py-2 pl-10 hover:bg-white/5 transition-colors text-left {!modelCanFit
? 'opacity-50 cursor-not-allowed'
: 'cursor-pointer'} {isSelected
? 'bg-exo-yellow/10 border-l-2 border-exo-yellow'
: 'border-l-2 border-transparent'}"
disabled={!modelCanFit}
onclick={() => {
if (modelCanFit) {
onSelectModel(variant.id);
}
}}
>
<!-- Quantization badge -->
<span
class="text-xs font-mono px-1.5 py-0.5 rounded bg-white/10 text-white/70 flex-shrink-0"
>
{variant.quantization || "default"}
</span>
<!-- Size -->
<span class="text-xs font-mono text-white/40 flex-1">
{formatSize(variant.storage_size_megabytes)}
</span>
<!-- Check mark if selected -->
{#if isSelected}
<svg
class="w-4 h-4 text-exo-yellow"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M9 16.17L4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41L9 16.17z"
/>
</svg>
{/if}
</button>
{/each}
</div>
{/if}
</div>

View File

@@ -0,0 +1,748 @@
<script lang="ts">
import { fade, fly } from "svelte/transition";
import { cubicOut } from "svelte/easing";
import FamilySidebar from "./FamilySidebar.svelte";
import ModelPickerGroup from "./ModelPickerGroup.svelte";
import ModelFilterPopover from "./ModelFilterPopover.svelte";
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
interface ModelInfo {
id: string;
name?: string;
storage_size_megabytes?: number;
base_model?: string;
quantization?: string;
supports_tensor?: boolean;
capabilities?: string[];
family?: string;
is_custom?: boolean;
tasks?: string[];
hugging_face_id?: string;
}
interface ModelGroup {
id: string;
name: string;
capabilities: string[];
family: string;
variants: ModelInfo[];
smallestVariant: ModelInfo;
hasMultipleVariants: boolean;
}
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
}
interface HuggingFaceModel {
id: string;
author: string;
downloads: number;
likes: number;
last_modified: string;
tags: string[];
}
type ModelPickerModalProps = {
isOpen: boolean;
models: ModelInfo[];
selectedModelId: string | null;
favorites: Set<string>;
existingModelIds: Set<string>;
canModelFit: (modelId: string) => boolean;
onSelect: (modelId: string) => void;
onClose: () => void;
onToggleFavorite: (baseModelId: string) => void;
onAddModel: (modelId: string) => Promise<void>;
onDeleteModel: (modelId: string) => Promise<void>;
totalMemoryGB: number;
usedMemoryGB: number;
};
let {
isOpen,
models,
selectedModelId,
favorites,
existingModelIds,
canModelFit,
onSelect,
onClose,
onToggleFavorite,
onAddModel,
onDeleteModel,
totalMemoryGB,
usedMemoryGB,
}: ModelPickerModalProps = $props();
// Local state
let searchQuery = $state("");
let selectedFamily = $state<string | null>(null);
let expandedGroups = $state<Set<string>>(new Set());
let showFilters = $state(false);
let filters = $state<FilterState>({ capabilities: [], sizeRange: null });
let infoGroup = $state<ModelGroup | null>(null);
// HuggingFace Hub state
let hfSearchQuery = $state("");
let hfSearchResults = $state<HuggingFaceModel[]>([]);
let hfTrendingModels = $state<HuggingFaceModel[]>([]);
let hfIsSearching = $state(false);
let hfIsLoadingTrending = $state(false);
let addingModelId = $state<string | null>(null);
let hfSearchDebounceTimer: ReturnType<typeof setTimeout> | null = null;
let manualModelId = $state("");
let addModelError = $state<string | null>(null);
// Reset state when modal opens
$effect(() => {
if (isOpen) {
searchQuery = "";
selectedFamily = null;
expandedGroups = new Set();
showFilters = false;
hfSearchQuery = "";
hfSearchResults = [];
manualModelId = "";
addModelError = null;
}
});
// Fetch trending models when HuggingFace is selected
$effect(() => {
if (
selectedFamily === "huggingface" &&
hfTrendingModels.length === 0 &&
!hfIsLoadingTrending
) {
fetchTrendingModels();
}
});
async function fetchTrendingModels() {
hfIsLoadingTrending = true;
try {
const response = await fetch("/models/search?query=&limit=20");
if (response.ok) {
hfTrendingModels = await response.json();
}
} catch (error) {
console.error("Failed to fetch trending models:", error);
} finally {
hfIsLoadingTrending = false;
}
}
async function searchHuggingFace(query: string) {
if (query.length < 2) {
hfSearchResults = [];
return;
}
hfIsSearching = true;
try {
const response = await fetch(
`/models/search?query=${encodeURIComponent(query)}&limit=20`,
);
if (response.ok) {
hfSearchResults = await response.json();
} else {
hfSearchResults = [];
}
} catch (error) {
console.error("Failed to search models:", error);
hfSearchResults = [];
} finally {
hfIsSearching = false;
}
}
function handleHfSearchInput(query: string) {
hfSearchQuery = query;
addModelError = null;
if (hfSearchDebounceTimer) {
clearTimeout(hfSearchDebounceTimer);
}
if (query.length >= 2) {
hfSearchDebounceTimer = setTimeout(() => {
searchHuggingFace(query);
}, 300);
} else {
hfSearchResults = [];
}
}
async function handleAddModel(modelId: string) {
addingModelId = modelId;
addModelError = null;
try {
await onAddModel(modelId);
} catch (error) {
addModelError =
error instanceof Error ? error.message : "Failed to add model";
} finally {
addingModelId = null;
}
}
async function handleAddManualModel() {
if (!manualModelId.trim()) return;
await handleAddModel(manualModelId.trim());
if (!addModelError) {
manualModelId = "";
}
}
function handleSelectHfModel(modelId: string) {
onSelect(modelId);
onClose();
}
// Models to display in HuggingFace view
const hfDisplayModels = $derived.by((): HuggingFaceModel[] => {
if (hfSearchQuery.length >= 2) {
return hfSearchResults;
}
return hfTrendingModels;
});
// Group models by base_model
const groupedModels = $derived.by((): ModelGroup[] => {
const groups = new Map<string, ModelGroup>();
for (const model of models) {
const groupId = model.base_model || model.id;
const groupName = model.base_model || model.name || model.id;
if (!groups.has(groupId)) {
groups.set(groupId, {
id: groupId,
name: groupName,
capabilities: model.capabilities || ["text"],
family: model.family || "",
variants: [],
smallestVariant: model,
hasMultipleVariants: false,
});
}
const group = groups.get(groupId)!;
group.variants.push(model);
// Track smallest variant
if (
(model.storage_size_megabytes || 0) <
(group.smallestVariant.storage_size_megabytes || Infinity)
) {
group.smallestVariant = model;
}
// Update capabilities if not set
if (
group.capabilities.length <= 1 &&
model.capabilities &&
model.capabilities.length > 1
) {
group.capabilities = model.capabilities;
}
if (!group.family && model.family) {
group.family = model.family;
}
}
// Sort variants within each group by size
for (const group of groups.values()) {
group.variants.sort(
(a, b) =>
(a.storage_size_megabytes || 0) - (b.storage_size_megabytes || 0),
);
group.hasMultipleVariants = group.variants.length > 1;
}
// Convert to array and sort by smallest variant size (biggest first)
return Array.from(groups.values()).sort((a, b) => {
return (
(b.smallestVariant.storage_size_megabytes || 0) -
(a.smallestVariant.storage_size_megabytes || 0)
);
});
});
// Get unique families
const uniqueFamilies = $derived.by((): string[] => {
const families = new Set<string>();
for (const group of groupedModels) {
if (group.family) {
families.add(group.family);
}
}
const familyOrder = [
"kimi",
"qwen",
"glm",
"minimax",
"deepseek",
"gpt-oss",
"llama",
];
return Array.from(families).sort((a, b) => {
const aIdx = familyOrder.indexOf(a);
const bIdx = familyOrder.indexOf(b);
if (aIdx === -1 && bIdx === -1) return a.localeCompare(b);
if (aIdx === -1) return 1;
if (bIdx === -1) return -1;
return aIdx - bIdx;
});
});
// Filter models based on search, family, and filters
const filteredGroups = $derived.by((): ModelGroup[] => {
let result: ModelGroup[] = [...groupedModels];
// Filter by family
if (selectedFamily === "favorites") {
result = result.filter((g) => favorites.has(g.id));
} else if (selectedFamily && selectedFamily !== "huggingface") {
result = result.filter((g) => g.family === selectedFamily);
}
// Filter by search query
if (searchQuery.trim()) {
const query = searchQuery.toLowerCase().trim();
result = result.filter(
(g) =>
g.name.toLowerCase().includes(query) ||
g.variants.some(
(v) =>
v.id.toLowerCase().includes(query) ||
(v.name || "").toLowerCase().includes(query),
),
);
}
// Filter by capabilities
if (filters.capabilities.length > 0) {
result = result.filter((g) =>
filters.capabilities.every((cap) => g.capabilities.includes(cap)),
);
}
// Filter by size range
if (filters.sizeRange) {
const { min, max } = filters.sizeRange;
result = result.filter((g) => {
const sizeGB = (g.smallestVariant.storage_size_megabytes || 0) / 1024;
return sizeGB >= min && sizeGB <= max;
});
}
// Sort: models that fit first, then by size (largest first)
result.sort((a, b) => {
const aFits = a.variants.some((v) => canModelFit(v.id));
const bFits = b.variants.some((v) => canModelFit(v.id));
if (aFits && !bFits) return -1;
if (!aFits && bFits) return 1;
return (
(b.smallestVariant.storage_size_megabytes || 0) -
(a.smallestVariant.storage_size_megabytes || 0)
);
});
return result;
});
// Check if any favorites exist
const hasFavorites = $derived(favorites.size > 0);
function toggleGroupExpanded(groupId: string) {
const next = new Set(expandedGroups);
if (next.has(groupId)) {
next.delete(groupId);
} else {
next.add(groupId);
}
expandedGroups = next;
}
function handleSelect(modelId: string) {
onSelect(modelId);
onClose();
}
function handleKeydown(e: KeyboardEvent) {
if (e.key === "Escape") {
onClose();
}
}
function handleFiltersChange(newFilters: FilterState) {
filters = newFilters;
}
function clearFilters() {
filters = { capabilities: [], sizeRange: null };
}
const hasActiveFilters = $derived(
filters.capabilities.length > 0 || filters.sizeRange !== null,
);
</script>
<svelte:window onkeydown={handleKeydown} />
{#if isOpen}
<!-- Backdrop -->
<div
class="fixed inset-0 z-50 bg-black/80 backdrop-blur-sm"
transition:fade={{ duration: 200 }}
onclick={onClose}
role="presentation"
></div>
<!-- Modal -->
<div
class="fixed z-50 top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(90vw,600px)] h-[min(80vh,700px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl overflow-hidden flex flex-col"
transition:fly={{ y: 20, duration: 300, easing: cubicOut }}
role="dialog"
aria-modal="true"
aria-label="Select a model"
>
<!-- Header with search -->
<div
class="flex items-center gap-2 p-3 border-b border-exo-yellow/10 bg-exo-medium-gray/30"
>
{#if selectedFamily === "huggingface"}
<!-- HuggingFace search -->
<svg
class="w-5 h-5 text-orange-400/60 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<circle cx="11" cy="11" r="8" />
<path d="M21 21l-4.35-4.35" />
</svg>
<input
type="search"
class="flex-1 bg-transparent border-none outline-none text-sm font-mono text-white placeholder-white/40"
placeholder="Search mlx-community models..."
value={hfSearchQuery}
oninput={(e) => handleHfSearchInput(e.currentTarget.value)}
/>
{#if hfIsSearching}
<div class="flex-shrink-0">
<span
class="w-4 h-4 border-2 border-orange-400 border-t-transparent rounded-full animate-spin block"
></span>
</div>
{/if}
{:else}
<!-- Normal model search -->
<svg
class="w-5 h-5 text-white/40 flex-shrink-0"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<circle cx="11" cy="11" r="8" />
<path d="M21 21l-4.35-4.35" />
</svg>
<input
type="search"
class="flex-1 bg-transparent border-none outline-none text-sm font-mono text-white placeholder-white/40"
placeholder="Search models..."
bind:value={searchQuery}
/>
<!-- Cluster memory -->
<span
class="text-xs font-mono flex-shrink-0"
title="Cluster memory usage"
><span class="text-exo-yellow">{Math.round(usedMemoryGB)}GB</span
><span class="text-white/40">/{Math.round(totalMemoryGB)}GB</span
></span
>
<!-- Filter button -->
<div class="relative filter-toggle">
<button
type="button"
class="p-1.5 rounded hover:bg-white/10 transition-colors {hasActiveFilters
? 'text-exo-yellow'
: 'text-white/50'}"
onclick={() => (showFilters = !showFilters)}
title="Filter by capability or size"
>
<svg class="w-5 h-5" viewBox="0 0 24 24" fill="currentColor">
<path d="M10 18h4v-2h-4v2zM3 6v2h18V6H3zm3 7h12v-2H6v2z" />
</svg>
</button>
{#if showFilters}
<ModelFilterPopover
{filters}
onChange={handleFiltersChange}
onClear={clearFilters}
onClose={() => (showFilters = false)}
/>
{/if}
</div>
{/if}
<!-- Close button -->
<button
type="button"
class="p-1.5 rounded hover:bg-white/10 transition-colors text-white/50 hover:text-white/70"
onclick={onClose}
title="Close model picker"
>
<svg class="w-5 h-5" viewBox="0 0 24 24" fill="currentColor">
<path
d="M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z"
/>
</svg>
</button>
</div>
<!-- Body -->
<div class="flex flex-1 overflow-hidden">
<!-- Family sidebar -->
<FamilySidebar
families={uniqueFamilies}
{selectedFamily}
{hasFavorites}
onSelect={(family) => (selectedFamily = family)}
/>
<!-- Model list -->
<div class="flex-1 overflow-y-auto flex flex-col">
{#if selectedFamily === "huggingface"}
<!-- HuggingFace Hub view -->
<div class="flex-1 flex flex-col min-h-0">
<!-- Section header -->
<div
class="sticky top-0 z-10 px-3 py-2 bg-exo-dark-gray/95 border-b border-exo-yellow/10"
>
<span class="text-xs font-mono text-white/40">
{#if hfSearchQuery.length >= 2}
Search results for "{hfSearchQuery}"
{:else}
Trending on mlx-community
{/if}
</span>
</div>
<!-- Results list -->
<div class="flex-1 overflow-y-auto">
{#if hfIsLoadingTrending && hfTrendingModels.length === 0}
<div
class="flex items-center justify-center py-12 text-white/40"
>
<span
class="w-5 h-5 border-2 border-orange-400 border-t-transparent rounded-full animate-spin mr-2"
></span>
<span class="font-mono text-sm"
>Loading trending models...</span
>
</div>
{:else if hfDisplayModels.length === 0}
<div
class="flex flex-col items-center justify-center py-12 text-white/40"
>
<svg
class="w-10 h-10 mb-2"
viewBox="0 0 24 24"
fill="currentColor"
>
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 13.5c-.83 0-1.5-.67-1.5-1.5s.67-1.5 1.5-1.5 1.5.67 1.5 1.5-.67 1.5-1.5 1.5zm4 0c-.83 0-1.5-.67-1.5-1.5s.67-1.5 1.5-1.5 1.5.67 1.5 1.5-.67 1.5-1.5 1.5zm2-4.5H8c0-2.21 1.79-4 4-4s4 1.79 4 4z"
/>
</svg>
<p class="font-mono text-sm">No models found</p>
{#if hfSearchQuery}
<p class="font-mono text-xs mt-1">
Try a different search term
</p>
{/if}
</div>
{:else}
{#each hfDisplayModels as model}
<HuggingFaceResultItem
{model}
isAdded={existingModelIds.has(model.id)}
isAdding={addingModelId === model.id}
onAdd={() => handleAddModel(model.id)}
onSelect={() => handleSelectHfModel(model.id)}
/>
{/each}
{/if}
</div>
<!-- Manual input footer -->
<div
class="sticky bottom-0 border-t border-exo-yellow/10 bg-exo-dark-gray p-3"
>
{#if addModelError}
<div
class="bg-red-500/10 border border-red-500/30 rounded px-3 py-2 mb-2"
>
<p class="text-red-400 text-xs font-mono break-words">
{addModelError}
</p>
</div>
{/if}
<div class="flex gap-2">
<input
type="text"
class="flex-1 bg-exo-black/60 border border-exo-yellow/30 rounded px-3 py-1.5 text-xs font-mono text-white placeholder-white/30 focus:outline-none focus:border-exo-yellow/50"
placeholder="Or paste model ID directly..."
bind:value={manualModelId}
onkeydown={(e) => {
if (e.key === "Enter") handleAddManualModel();
}}
/>
<button
type="button"
onclick={handleAddManualModel}
disabled={!manualModelId.trim() || addingModelId !== null}
class="px-3 py-1.5 text-xs font-mono tracking-wider uppercase bg-orange-500/10 text-orange-400 border border-orange-400/30 hover:bg-orange-500/20 transition-colors rounded disabled:opacity-50 disabled:cursor-not-allowed"
>
Add
</button>
</div>
</div>
</div>
{:else if filteredGroups.length === 0}
<div
class="flex flex-col items-center justify-center h-full text-white/40 p-8"
>
<svg class="w-12 h-12 mb-3" viewBox="0 0 24 24" fill="currentColor">
<path
d="M12 2C6.48 2 2 6.48 2 12s4.48 10 10 10 10-4.48 10-10S17.52 2 12 2zm-2 15l-5-5 1.41-1.41L10 14.17l7.59-7.59L19 8l-9 9z"
/>
</svg>
<p class="font-mono text-sm">No models found</p>
{#if hasActiveFilters || searchQuery}
<button
type="button"
class="mt-2 text-xs text-exo-yellow hover:underline"
onclick={() => {
searchQuery = "";
clearFilters();
}}
>
Clear filters
</button>
{/if}
</div>
{:else}
{#each filteredGroups as group}
<ModelPickerGroup
{group}
isExpanded={expandedGroups.has(group.id)}
isFavorite={favorites.has(group.id)}
{selectedModelId}
{canModelFit}
onToggleExpand={() => toggleGroupExpanded(group.id)}
onSelectModel={handleSelect}
{onToggleFavorite}
onShowInfo={(g) => (infoGroup = g)}
/>
{/each}
{/if}
</div>
</div>
<!-- Footer with active filters indicator -->
{#if hasActiveFilters}
<div
class="flex items-center gap-2 px-3 py-2 border-t border-exo-yellow/10 bg-exo-medium-gray/20 text-xs font-mono text-white/50"
>
<span>Filters:</span>
{#each filters.capabilities as cap}
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded"
>{cap}</span
>
{/each}
{#if filters.sizeRange}
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded">
{filters.sizeRange.min}GB - {filters.sizeRange.max}GB
</span>
{/if}
<button
type="button"
class="ml-auto text-white/40 hover:text-white/60"
onclick={clearFilters}
>
Clear all
</button>
</div>
{/if}
</div>
<!-- Info modal -->
{#if infoGroup}
<div
class="fixed inset-0 z-[60] bg-black/60"
transition:fade={{ duration: 150 }}
onclick={() => (infoGroup = null)}
role="presentation"
></div>
<div
class="fixed z-[60] top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[min(80vw,400px)] bg-exo-dark-gray border border-exo-yellow/10 rounded-lg shadow-2xl p-4"
transition:fly={{ y: 10, duration: 200, easing: cubicOut }}
role="dialog"
aria-modal="true"
>
<div class="flex items-start justify-between mb-3">
<h3 class="font-mono text-lg text-white">{infoGroup.name}</h3>
<button
type="button"
class="p-1 rounded hover:bg-white/10 transition-colors text-white/50"
onclick={() => (infoGroup = null)}
title="Close model details"
aria-label="Close info dialog"
>
<svg class="w-4 h-4" viewBox="0 0 24 24" fill="currentColor">
<path
d="M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z"
/>
</svg>
</button>
</div>
<div class="space-y-2 text-xs font-mono">
<div class="flex items-center gap-2">
<span class="text-white/40">Family:</span>
<span class="text-white/70">{infoGroup.family || "Unknown"}</span>
</div>
<div class="flex items-center gap-2">
<span class="text-white/40">Capabilities:</span>
<span class="text-white/70">{infoGroup.capabilities.join(", ")}</span>
</div>
<div class="flex items-center gap-2">
<span class="text-white/40">Variants:</span>
<span class="text-white/70">{infoGroup.variants.length}</span>
</div>
{#if infoGroup.variants.length > 0}
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
<span class="text-white/40">Available quantizations:</span>
<div class="flex flex-wrap gap-1 mt-1">
{#each infoGroup.variants as variant}
<span
class="px-1.5 py-0.5 bg-white/10 text-white/60 rounded text-[10px]"
>
{variant.quantization || "default"} ({Math.round(
(variant.storage_size_megabytes || 0) / 1024,
)}GB)
</span>
{/each}
</div>
</div>
{/if}
</div>
</div>
{/if}
{/if}

View File

@@ -6,3 +6,9 @@ export { default as ChatSidebar } from "./ChatSidebar.svelte";
export { default as ModelCard } from "./ModelCard.svelte";
export { default as MarkdownContent } from "./MarkdownContent.svelte";
export { default as ImageParamsPanel } from "./ImageParamsPanel.svelte";
export { default as FamilyLogos } from "./FamilyLogos.svelte";
export { default as FamilySidebar } from "./FamilySidebar.svelte";
export { default as HuggingFaceResultItem } from "./HuggingFaceResultItem.svelte";
export { default as ModelFilterPopover } from "./ModelFilterPopover.svelte";
export { default as ModelPickerGroup } from "./ModelPickerGroup.svelte";
export { default as ModelPickerModal } from "./ModelPickerModal.svelte";

View File

@@ -0,0 +1,97 @@
/**
* FavoritesStore - Manages favorite models with localStorage persistence
*/
import { browser } from "$app/environment";
const FAVORITES_KEY = "exo-favorite-models";
class FavoritesStore {
favorites = $state<Set<string>>(new Set());
constructor() {
if (browser) {
this.loadFromStorage();
}
}
private loadFromStorage() {
try {
const stored = localStorage.getItem(FAVORITES_KEY);
if (stored) {
const parsed = JSON.parse(stored) as string[];
this.favorites = new Set(parsed);
}
} catch (error) {
console.error("Failed to load favorites:", error);
}
}
private saveToStorage() {
try {
const array = Array.from(this.favorites);
localStorage.setItem(FAVORITES_KEY, JSON.stringify(array));
} catch (error) {
console.error("Failed to save favorites:", error);
}
}
add(baseModelId: string) {
const next = new Set(this.favorites);
next.add(baseModelId);
this.favorites = next;
this.saveToStorage();
}
remove(baseModelId: string) {
const next = new Set(this.favorites);
next.delete(baseModelId);
this.favorites = next;
this.saveToStorage();
}
toggle(baseModelId: string) {
if (this.favorites.has(baseModelId)) {
this.remove(baseModelId);
} else {
this.add(baseModelId);
}
}
isFavorite(baseModelId: string): boolean {
return this.favorites.has(baseModelId);
}
getAll(): string[] {
return Array.from(this.favorites);
}
getSet(): Set<string> {
return new Set(this.favorites);
}
hasAny(): boolean {
return this.favorites.size > 0;
}
clearAll() {
this.favorites = new Set();
this.saveToStorage();
}
}
export const favoritesStore = new FavoritesStore();
export const favorites = () => favoritesStore.favorites;
export const hasFavorites = () => favoritesStore.hasAny();
export const isFavorite = (baseModelId: string) =>
favoritesStore.isFavorite(baseModelId);
export const toggleFavorite = (baseModelId: string) =>
favoritesStore.toggle(baseModelId);
export const addFavorite = (baseModelId: string) =>
favoritesStore.add(baseModelId);
export const removeFavorite = (baseModelId: string) =>
favoritesStore.remove(baseModelId);
export const getFavorites = () => favoritesStore.getAll();
export const getFavoritesSet = () => favoritesStore.getSet();
export const clearFavorites = () => favoritesStore.clearAll();

View File

@@ -5,7 +5,13 @@
ChatMessages,
ChatSidebar,
ModelCard,
ModelPickerModal,
} from "$lib/components";
import {
favorites,
toggleFavorite,
getFavoritesSet,
} from "$lib/stores/favorites.svelte";
import {
hasStartedChat,
isTopologyMinimized,
@@ -101,6 +107,10 @@
tasks?: string[];
hugging_face_id?: string;
is_custom?: boolean;
family?: string;
quantization?: string;
base_model?: string;
capabilities?: string[];
}>
>([]);
@@ -212,14 +222,11 @@
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state("");
// Model picker modal state
let isModelPickerOpen = $state(false);
// Custom model add state
let customModelInput = $state("");
let isAddingCustomModel = $state(false);
let customModelError = $state("");
// Favorites state (reactive)
const favoritesSet = $derived(getFavoritesSet());
// Slider dragging state
let isDraggingSlider = $state(false);
@@ -536,41 +543,25 @@
}
}
async function addCustomModel() {
const modelId = customModelInput.trim();
if (!modelId) return;
async function addModelFromPicker(modelId: string) {
const response = await fetch("/models/add", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ model_id: modelId }),
});
isAddingCustomModel = true;
customModelError = "";
try {
const response = await fetch("/models/add", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ model_id: modelId }),
});
if (!response.ok) {
try {
const err = await response.json();
customModelError =
err.detail || `Failed to add model (${response.status})`;
} catch {
customModelError = `Failed to add model (${response.status}: ${response.statusText})`;
}
return;
if (!response.ok) {
let message = `Failed to add model (${response.status}: ${response.statusText})`;
try {
const err = await response.json();
if (err.detail) message = err.detail;
} catch {
// use default message
}
const added = await response.json();
customModelInput = "";
await fetchModels();
selectPreviewModel(added.id);
isModelDropdownOpen = false;
} catch {
customModelError = "Network error";
} finally {
isAddingCustomModel = false;
throw new Error(message);
}
await fetchModels();
}
async function deleteCustomModel(modelId: string) {
@@ -587,6 +578,12 @@
}
}
function handleModelPickerSelect(modelId: string) {
selectPreviewModel(modelId);
saveLaunchDefaults();
isModelPickerOpen = false;
}
async function launchInstance(
modelId: string,
specificPreview?: PlacementPreview | null,
@@ -2417,14 +2414,12 @@
>
</div>
<!-- Model Dropdown (Custom) -->
<div class="flex-shrink-0 mb-3 relative">
<!-- Model Picker Button -->
<div class="flex-shrink-0 mb-3">
<button
type="button"
onclick={() => (isModelDropdownOpen = !isModelDropdownOpen)}
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-2.5 text-sm font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isModelDropdownOpen
? 'border-exo-yellow/70'
: ''}"
onclick={() => (isModelPickerOpen = true)}
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-2.5 text-sm font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 relative"
>
{#if selectedModelId}
{@const foundModel = models.find(
@@ -2432,54 +2427,12 @@
)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
{@const isImageModel = modelSupportsImageGeneration(
foundModel.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
foundModel.id,
)}
<span
class="flex items-center justify-between gap-2 w-full pr-4"
>
<span
class="flex items-center gap-2 text-exo-light-gray truncate"
>
{#if isImageModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect
x="3"
y="3"
width="18"
height="18"
rx="2"
ry="2"
/>
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate"
>{foundModel.name || foundModel.id}</span
>
@@ -2496,193 +2449,24 @@
{:else}
<span class="text-white/50"> SELECT MODEL </span>
{/if}
</button>
<div
class="absolute right-3 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isModelDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-4 h-4 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
{#if isModelDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-40 cursor-default"
onclick={() => (isModelDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel -->
<div
class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-50 max-h-64 overflow-y-auto"
class="absolute right-3 top-1/2 -translate-y-1/2 pointer-events-none"
>
<!-- Search within dropdown -->
<div
class="sticky top-0 bg-exo-dark-gray border-b border-exo-medium-gray/30 p-2 space-y-2"
<svg
class="w-4 h-4 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<input
type="text"
placeholder="Search models..."
bind:value={modelDropdownSearch}
class="w-full bg-exo-dark-gray/60 border border-exo-medium-gray/30 rounded px-2 py-1.5 text-xs font-mono text-white/80 placeholder:text-white/40 focus:outline-none focus:border-exo-yellow/50"
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
<!-- Add custom model -->
<form
onsubmit={(e) => {
e.preventDefault();
addCustomModel();
}}
class="flex gap-1.5"
>
<input
type="text"
placeholder="org/model-name"
bind:value={customModelInput}
class="flex-1 bg-exo-dark-gray/60 border border-exo-medium-gray/30 rounded px-2 py-1.5 text-xs font-mono text-white/80 placeholder:text-white/40 focus:outline-none focus:border-exo-yellow/50"
/>
<button
type="submit"
disabled={!customModelInput.trim() ||
isAddingCustomModel}
class="px-2 py-1.5 text-xs font-mono bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/30 rounded hover:bg-exo-yellow/30 disabled:opacity-40 disabled:cursor-not-allowed"
>
{isAddingCustomModel ? "..." : "+ ADD"}
</button>
</form>
{#if customModelError}
<div class="text-xs text-red-400 font-mono">
{customModelError}
</div>
{/if}
</div>
<!-- Options -->
<div class="py-1">
{#each sortedModels().filter((m) => !modelDropdownSearch || (m.name || m.id)
.toLowerCase()
.includes(modelDropdownSearch.toLowerCase())) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(
model.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
model.id,
)}
<button
type="button"
onclick={() => {
if (modelCanFit) {
selectPreviewModel(model.id);
saveLaunchDefaults();
isModelDropdownOpen = false;
modelDropdownSearch = "";
}
}}
disabled={!modelCanFit}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedModelId ===
model.id
? 'bg-transparent text-exo-yellow cursor-pointer'
: modelCanFit
? 'text-white/80 hover:text-exo-yellow cursor-pointer'
: 'text-white/30 cursor-default'}"
>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
aria-label="Image generation model"
>
<rect
x="3"
y="3"
width="18"
height="18"
rx="2"
ry="2"
/>
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
aria-label="Image editing model"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span class="flex items-center gap-1.5 flex-shrink-0">
<span
class="text-xs {modelCanFit
? 'text-white/50'
: 'text-red-400/60'}"
>
{sizeGB >= 1
? sizeGB.toFixed(0)
: sizeGB.toFixed(1)}GB
</span>
{#if model.is_custom}
<button
type="button"
onclick={(e) => {
e.stopPropagation();
deleteCustomModel(model.id);
}}
class="text-white/30 hover:text-red-400 transition-colors p-0.5"
title="Remove custom model"
>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path d="M18 6L6 18M6 6l12 12" />
</svg>
</button>
{/if}
</span>
</button>
{:else}
<div class="px-3 py-2 text-xs text-white/50 font-mono">
No models found
</div>
{/each}
</div>
</svg>
</div>
{/if}
</button>
</div>
<!-- Configuration Options -->
@@ -3462,3 +3246,22 @@
{/if}
</main>
</div>
<ModelPickerModal
isOpen={isModelPickerOpen}
{models}
{selectedModelId}
favorites={favoritesSet}
existingModelIds={new Set(models.map((m) => m.id))}
canModelFit={(modelId) => {
const model = models.find((m) => m.id === modelId);
return model ? hasEnoughMemory(model) : false;
}}
onSelect={handleModelPickerSelect}
onClose={() => (isModelPickerOpen = false)}
onToggleFavorite={toggleFavorite}
onAddModel={addModelFromPicker}
onDeleteModel={deleteCustomModel}
totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}
usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}
/>

View File

@@ -69,6 +69,7 @@
./dashboard/parts.nix
./rust/parts.nix
./python/parts.nix
./resources/parts.nix
];
perSystem =

View File

@@ -69,7 +69,8 @@
# Create wrapper scripts
for script in exo exo-master exo-worker; do
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
--set DASHBOARD_DIR ${self'.packages.dashboard} \
--set EXO_DASHBOARD_DIR ${self'.packages.dashboard} \
--set EXO_RESOURCES_DIR ${self'.packages.resources} \
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
done
'';

View File

@@ -3,6 +3,10 @@ n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "deepseek"
quantization = "4bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405874409472

View File

@@ -3,6 +3,10 @@ n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "deepseek"
quantization = "8bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 765577920512

View File

@@ -3,6 +3,10 @@ n_layers = 46
hidden_size = 4096
supports_tensor = false
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 122406567936

View File

@@ -3,6 +3,10 @@ n_layers = 46
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 229780750336

View File

@@ -3,6 +3,10 @@ n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 198556925568

View File

@@ -3,6 +3,10 @@ n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 286737579648

View File

@@ -3,6 +3,10 @@ n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 396963397248

View File

@@ -3,6 +3,10 @@ n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 19327352832

View File

@@ -3,6 +3,10 @@ n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "5bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 22548578304

View File

@@ -3,6 +3,10 @@ n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 26843545600

View File

@@ -3,6 +3,10 @@ n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 34359738368

View File

@@ -3,6 +3,10 @@ n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "kimi"
quantization = "4bit"
base_model = "Kimi K2"
capabilities = ["text"]
[storage_size]
in_bytes = 620622774272

View File

@@ -3,6 +3,10 @@ n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 706522120192

View File

@@ -3,6 +3,10 @@ n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2.5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 662498705408

View File

@@ -3,6 +3,10 @@ n_layers = 16
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.2 1B"
capabilities = ["text"]
[storage_size]
in_bytes = 729808896

View File

@@ -3,6 +3,10 @@ n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.2 3B"
capabilities = ["text"]
[storage_size]
in_bytes = 1863319552

View File

@@ -3,6 +3,10 @@ n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "8bit"
base_model = "Llama 3.2 3B"
capabilities = ["text"]
[storage_size]
in_bytes = 3501195264

View File

@@ -3,6 +3,10 @@ n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.3 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 40652242944

View File

@@ -3,6 +3,10 @@ n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "8bit"
base_model = "Llama 3.3 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 76799803392

View File

@@ -3,6 +3,10 @@ n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.1 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 40652242944

View File

@@ -3,6 +3,10 @@ n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "4bit"
base_model = "Llama 3.1 8B"
capabilities = ["text"]
[storage_size]
in_bytes = 4637851648

View File

@@ -3,6 +3,10 @@ n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "8bit"
base_model = "Llama 3.1 8B"
capabilities = ["text"]
[storage_size]
in_bytes = 8954839040

View File

@@ -3,6 +3,10 @@ n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "bf16"
base_model = "Llama 3.1 8B"
capabilities = ["text"]
[storage_size]
in_bytes = 16882073600

View File

@@ -3,6 +3,10 @@ n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "3bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 100086644736

View File

@@ -3,6 +3,10 @@ n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 242986745856

View File

@@ -3,6 +3,10 @@ n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 342884352

View File

@@ -3,6 +3,10 @@ n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 698351616

View File

@@ -3,6 +3,10 @@ n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 141733920768

View File

@@ -3,6 +3,10 @@ n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 268435456000

View File

@@ -3,6 +3,10 @@ n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 17612931072

View File

@@ -3,6 +3,10 @@ n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 33279705088

View File

@@ -3,6 +3,10 @@ n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Coder 480B"
capabilities = ["text", "code"]
[storage_size]
in_bytes = 289910292480

View File

@@ -3,6 +3,10 @@ n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Coder 480B"
capabilities = ["text", "code"]
[storage_size]
in_bytes = 579820584960

View File

@@ -3,6 +3,10 @@ n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text"]
[storage_size]
in_bytes = 46976204800

View File

@@ -3,6 +3,10 @@ n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text"]
[storage_size]
in_bytes = 88814387200

View File

@@ -3,6 +3,10 @@ n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 47080074240

View File

@@ -3,6 +3,10 @@ n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 88814387200

View File

@@ -3,6 +3,10 @@ n_layers = 36
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
family = "gpt-oss"
quantization = "MXFP4-Q8"
base_model = "GPT-OSS 120B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 70652212224

View File

@@ -3,6 +3,10 @@ n_layers = 24
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
family = "gpt-oss"
quantization = "MXFP4-Q8"
base_model = "GPT-OSS 20B"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 12025908224

View File

@@ -3,6 +3,10 @@ n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "fp16"
base_model = "Llama 3.3 70B"
capabilities = ["text"]
[storage_size]
in_bytes = 144383672320

17
resources/parts.nix Normal file
View File

@@ -0,0 +1,17 @@
{ inputs, ... }:
{
perSystem =
{ pkgs, lib, ... }:
let
# Filter source to only include resources directory
resourcesSrc = lib.cleanSourceWith {
src = inputs.self + "/resources";
};
in
{
packages.resources = pkgs.runCommand "exo-resources" { } ''
cp -r ${resourcesSrc} $out
'';
};
}

View File

@@ -1,2 +0,0 @@
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
#allowed-duplicate-crates = ["hashbrown"]

View File

@@ -25,44 +25,25 @@ workspace = true
networking = { workspace = true }
# interop
pyo3 = { version = "0.27.1", features = [
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
"nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
"multiple-pymethods", # allows multiple #[pymethods] sections per class
pyo3 = { version = "0.27.2", features = [
"abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
# "nightly", # enables better-supported GIL integration
"experimental-async" # async support in #[pyfunction] & #[pymethods]
# "experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
# "py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
# integrations with other libraries
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
"ordered-float", "rust_decimal", "smallvec",
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
# "ordered-float", "rust_decimal", "smallvec",
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.17.2" }
pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] }
pyo3-log = "0.13.2"
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
pin-project = { workspace = true }
# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }
# utility dependencies
once_cell = "1.21.3"
thread_local = "1.1.9"
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
# Tracing
#tracing = "0.1"
@@ -75,3 +56,4 @@ env_logger = "0.11"
# Networking
libp2p = { workspace = true, features = ["full"] }
futures-lite = "2.6.1"

View File

@@ -2,220 +2,39 @@
# ruff: noqa: E501, F401
import builtins
import enum
import typing
@typing.final
class AllQueuesFullError(builtins.Exception):
def __new__(cls, *args: typing.Any) -> AllQueuesFullError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdate:
@property
def update_type(self) -> ConnectionUpdateType:
r"""
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> PeerId:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@property
def remote_ipv4(self) -> builtins.str:
r"""
Remote connection's IPv4 address.
"""
@property
def remote_tcp_port(self) -> builtins.int:
r"""
Remote connection's TCP port.
"""
@typing.final
class Keypair:
r"""
Identity keypair of a node.
"""
@staticmethod
def generate_ed25519() -> Keypair:
def generate() -> Keypair:
r"""
Generate a new Ed25519 keypair.
"""
@staticmethod
def generate_ecdsa() -> Keypair:
r"""
Generate a new ECDSA keypair.
"""
@staticmethod
def generate_secp256k1() -> Keypair:
r"""
Generate a new Secp256k1 keypair.
Generate a new ed25519 keypair
"""
@staticmethod
def from_protobuf_encoding(bytes: bytes) -> Keypair:
r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
@staticmethod
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
r"""
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
Encode a private key to a protobuf structure.
"""
def to_string(self) -> builtins.str: ...
@typing.final
class Multiaddr:
r"""
Representation of a Multiaddr.
"""
class PyPeer:
@staticmethod
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n: builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string: builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
"""
def to_bytes(self) -> bytes:
r"""
Return a copy of this [`Multiaddr`]'s byte representation.
"""
def to_string(self) -> builtins.str:
r"""
Convert a Multiaddr to a string.
"""
def new(kp: Keypair, namespace: builtins.str) -> PyPeer: ...
async def subscribe(self, topic: builtins.str) -> None: ...
async def unsubscribe(self, topic: builtins.str) -> None: ...
async def send(self, topic: builtins.str, payload: bytes) -> None: ...
async def run(self) -> None: ...
async def recv(self) -> PySwarmEvent: ...
@typing.final
class NetworkingHandle:
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
async def connection_update_recv(self) -> ConnectionUpdate:
r"""
Receives the next `ConnectionUpdate` from networking.
"""
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
r"""
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
will sleep until a `ConnectionUpdate`s is sent.
"""
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Subscribe to a `GossipSub` topic.
Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
"""
async def gossipsub_unsubscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Unsubscribes from a `GossipSub` topic.
Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
"""
async def gossipsub_publish(self, topic: builtins.str, data: bytes) -> None:
r"""
Publishes a message with multiple topics to the `GossipSub` network.
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
"""
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
r"""
Receives the next message from the `GossipSub` network.
"""
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
r"""
Receives at most `limit` messages from the `GossipSub` network and returns them.
For `limit = 0`, an empty collection of messages will be returned immediately.
For `limit > 0`, if there are no messages in the channel's queue this method
will sleep until a message is sent.
"""
@typing.final
class NoPeersSubscribedToTopicError(builtins.Exception):
def __new__(cls, *args: typing.Any) -> NoPeersSubscribedToTopicError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class PeerId:
r"""
Identifier of a peer of the network.
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
"""
@staticmethod
def random() -> PeerId:
r"""
Generates a random peer ID from a cryptographically secure PRNG.
This is useful for randomly walking on a DHT, or for testing purposes.
"""
@staticmethod
def from_bytes(bytes: bytes) -> PeerId:
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdateType(enum.Enum):
r"""
Connection or disconnection event discriminant type.
"""
Connected = ...
Disconnected = ...
class PySwarmEvent:
def downcast_discovered(self) -> typing.Optional[builtins.str]: ...
def downcast_expired(self) -> typing.Optional[builtins.str]: ...
def downcast_message(self) -> typing.Optional[tuple[builtins.str, builtins.str, bytes]]: ...

View File

@@ -1,8 +1,4 @@
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
//!
use pin_project::pin_project;
use pyo3::marker::Ungil;
//! See: <https://pyo3.rs/v0.27.2/async-await.html#detaching-from-the-interpreter-across-await>
use pyo3::prelude::*;
use std::{
future::Future,
@@ -10,31 +6,17 @@ use std::{
task::{Context, Poll},
};
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
#[pin_project]
#[repr(transparent)]
pub(crate) struct AllowThreads<F>(#[pin] F);
impl<F> AllowThreads<F>
where
Self: Future,
{
pub fn new(f: F) -> Self {
Self(f)
}
}
pub struct AllowThreads<F>(pub(crate) F);
impl<F> Future for AllowThreads<F>
where
F: Future + Ungil,
F::Output: Ungil,
F: Future + Unpin + Send,
F::Output: Send,
{
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let waker = cx.waker();
Python::with_gil(|py| {
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
})
Python::attach(|py| py.detach(|| pin!(&mut self.0).poll(&mut Context::from_waker(waker))))
}
}

View File

@@ -1,240 +0,0 @@
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
//!
//! Pattern examples include:
//! - Async task handles: with GC-integrated cleanup
//! - Sync/async callbacks from python: with propper eventloop handling
//!
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
//! - Store mutable fields in tokio's `Mutex<T>`
//! - For async code: take `&self` and `.lock().await`
//! - For sync code: take `&mut self` and `.get_mut()`
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
use futures::FutureExt as _;
use futures::future::BoxFuture;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::{
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
fn needs_tokio_runtime() {
tokio::runtime::Handle::current();
}
type SyncCallback = Box<dyn Fn() + Send + Sync>;
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
enum AsyncTaskMessage {
SyncCallback(SyncCallback),
AsyncCallback(AsyncCallback),
}
async fn async_task(
sender: mpsc::UnboundedSender<()>,
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
) {
log::info!("RUST: async task started");
// task state
let mut interval = tokio::time::interval(Duration::from_secs(1));
let mut sync_cbs: Vec<SyncCallback> = vec![];
let mut async_cbs: Vec<AsyncCallback> = vec![];
loop {
tokio::select! {
// handle incoming messages from task-handle
message = receiver.recv() => {
// handle closed channel by exiting
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming event
match message {
AsyncTaskMessage::SyncCallback(cb) => {
sync_cbs.push(cb);
}
AsyncTaskMessage::AsyncCallback(cb) => {
async_cbs.push(cb);
}
}
}
// handle all other events
_ = interval.tick() => {
log::info!("RUST: async task tick");
// call back all sync callbacks
for cb in &sync_cbs {
cb();
}
// call back all async callbacks
for cb in &async_cbs {
cb().await;
}
// send event on unbounded channel
sender.send(()).expect("handle receiver cannot be closed/dropped");
}
}
}
log::info!("RUST: async task stopped");
}
// #[gen_stub_pyclass]
#[pyclass(name = "AsyncTaskHandle")]
#[derive(Debug)]
struct PyAsyncTaskHandle {
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
receiver: mpsc::UnboundedReceiver<()>,
}
#[allow(clippy::expect_used)]
impl PyAsyncTaskHandle {
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_mut()
.expect("The sender should only be None after de-initialization.")
}
const fn new(
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
receiver: mpsc::UnboundedReceiver<()>,
) -> Self {
Self {
sender: Some(sender),
receiver,
}
}
}
// #[gen_stub_pymethods]
#[pymethods]
impl PyAsyncTaskHandle {
#[new]
fn py_new(py: Python<'_>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channel TOWARDS our task
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
// create communication channel FROM our task
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
// perform necessary setup within tokio context - or it crashes
let () = get_runtime().block_on(async { needs_tokio_runtime() });
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
_ = get_runtime().spawn_with_scope(py, async move {
async_task(t_sender, t_receiver).await;
});
Ok(Self::new(h_sender, h_receiver))
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_sync_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], None]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
})))
.pyerr()?;
Ok(())
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_async_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
let c = Python::with_gil(|py| callback.clone_ref(py));
async move {
if let Some(f) = Python::with_gil(|py| {
let coroutine = c.call0(py).write_unraisable_with(py)?;
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
.write_unraisable_with(py)
}) {
_ = f.await.write_unraisable();
}
}
.boxed()
})))
.pyerr()?;
Ok(())
}
async fn receive_unit(&mut self) -> PyResult<()> {
self.receiver
.recv()
.await
.ok_or(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
))
}
fn drain_units(&mut self) -> PyResult<i32> {
let mut cnt = 0;
loop {
match self.receiver.try_recv() {
Err(TryRecvError::Disconnected) => {
return Err(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
));
}
Err(TryRecvError::Empty) => return Ok(cnt),
Ok(()) => {
cnt += 1;
continue;
}
}
}
}
// #[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
// #[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
}
}
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyAsyncTaskHandle>()?;
Ok(())
}

View File

@@ -1,216 +1,42 @@
//! TODO: crate documentation
//!
//! this is here as a placeholder documentation
//!
//!
pub(crate) mod allow_threading;
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(tuple_trait)]
#![feature(unboxed_closures)]
// #![feature(stmt_expr_attributes)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
extern crate core;
mod allow_threading;
mod examples;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
pub(crate) mod take_once {
use tokio::sync::Mutex;
pub struct TakeOnce<T>(Mutex<Option<T>>);
impl<T> TakeOnce<T> {
pub fn new(t: T) -> Self {
Self(Mutex::new(Some(t)))
}
pub fn take(&self) -> Option<T> {
match self.0.try_lock() {
Ok(mut o) => o.take(),
Err(_) => None,
}
}
}
}
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::prelude::*;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
/// Namespace for all the constants used by this crate.
pub(crate) mod r#const {
pub const MPSC_CHANNEL_SIZE: usize = 1024;
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
use std::marker::Tuple;
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
Fn<Args, Output = Output> + Send + 'static;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use crate::allow_threading::AllowThreads;
use extend::ext;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::marker::Ungil;
use pyo3::types::PyBytes;
use pyo3::{Py, PyErr, PyResult, Python};
use tokio::runtime::Runtime;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::task::JoinHandle;
#[ext(pub, name = ByteArrayExt)]
impl [u8] {
fn pybytes(&self) -> Py<PyBytes> {
Python::with_gil(|py| PyBytes::new(py, self).unbind())
}
}
#[ext(pub, name = ResultExt)]
impl<T, E> Result<T, E>
where
E: ToString,
{
fn pyerr(self) -> PyResult<T> {
self.map_err(|e| PyRuntimeError::new_err(e.to_string()))
}
}
pub trait FutureExt: Future + Sized {
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
fn allow_threads_py(self) -> AllowThreads<Self>
where
AllowThreads<Self>: Future,
{
AllowThreads::new(self)
}
}
impl<T: Future> FutureExt for T {}
#[ext(pub, name = PyErrExt)]
impl PyErr {
fn receiver_channel_closed() -> Self {
PyConnectionError::new_err("Receiver channel closed unexpectedly")
}
}
#[ext(pub, name = PyResultExt)]
impl<T> PyResult<T> {
fn write_unraisable(self) -> Option<T> {
Python::with_gil(|py| self.write_unraisable_with(py))
}
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
match self {
Ok(v) => Some(v),
Err(e) => {
// write error back to python
e.write_unraisable(py, None);
None
}
}
}
}
#[ext(pub, name = TokioRuntimeExt)]
impl Runtime {
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
}
}
#[ext(pub, name = TokioMpscSenderExt)]
impl<T> mpsc::Sender<T> {
/// Sends a value, waiting until there is capacity.
///
/// A successful send occurs when it is determined that the other end of the
/// channel has not hung up already. An unsuccessful send would be one where
/// the corresponding receiver has already been closed.
async fn send_py(&self, value: T) -> PyResult<()> {
self.send(value)
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
}
#[ext(pub, name = TokioMpscReceiverExt)]
impl<T> mpsc::Receiver<T> {
/// Receives the next value for this receiver.
async fn recv_py(&mut self) -> PyResult<T> {
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
}
/// Receives at most `limit` values for this receiver and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
// get updates from receiver channel
let mut updates = Vec::with_capacity(limit);
let received = self.recv_many(&mut updates, limit).await;
// if we received zero items, then the channel was unexpectedly closed
if limit != 0 && received == 0 {
return Err(PyErr::receiver_channel_closed());
}
Ok(updates)
}
/// Tries to receive the next value for this receiver.
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
match self.try_recv() {
Ok(v) => Ok(Some(v)),
Err(TryRecvError::Empty) => Ok(None),
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
}
}
}
}
pub(crate) mod private {
use std::marker::Sized;
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
#[repr(transparent)]
pub(crate) struct ClonePy<T>(pub Py<T>);
impl<T> Clone for ClonePy<T> {
fn clone(&self) -> Self {
Python::with_gil(|py| Self(self.0.clone_ref(py)))
}
}
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule(name = "exo_pyo3_bindings")]
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
pub fn networking_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// install logger
pyo3_log::init();
// setup runtime
let mut builder = tokio::runtime::Builder::new_multi_thread();
builder.enable_all();
pyo3_async_runtimes::tokio::init(builder);
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
ident_submodule(m)?;
multiaddr_submodule(m)?;
networking_submodule(m)?;
// top-level constructs
// TODO: ...
m.add_class::<networking::PyPeer>()?;
m.add_class::<networking::PyKeypair>()?;
Ok(())
}

View File

@@ -1,571 +1,214 @@
#![allow(
clippy::multiple_inherent_impl,
clippy::unnecessary_wraps,
clippy::unused_self,
clippy::needless_pass_by_value
)]
use crate::allow_threading::AllowThreads;
use crate::take_once::TakeOnce;
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::pyclass;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use libp2p::{gossipsub, mdns};
use networking::discovery;
use networking::swarm::create_swarm;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
use std::net::IpAddr;
use tokio::sync::{Mutex, mpsc, oneshot};
use util::ext::VecExt as _;
use std::pin::pin;
mod exception {
use pyo3::types::PyTuple;
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
use pyo3_stub_gen::derive::*;
#[gen_stub_pyclass]
#[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
pub struct PyNoPeersSubscribedToTopicError {}
impl PyNoPeersSubscribedToTopicError {
const MSG: &'static str = "\
No peers are currently subscribed to receive messages on this topic. \
Wait for peers to subscribe or check your network connectivity.";
/// Creates a new [ `PyErr` ] of this type.
///
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
pub(crate) fn new_err() -> PyErr {
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
}
}
#[gen_stub_pymethods]
#[pymethods]
impl PyNoPeersSubscribedToTopicError {
#[new]
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("PeerId(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
}
#[gen_stub_pyclass]
#[pyclass(frozen, extends=PyException, name="AllQueuesFullError")]
pub struct PyAllQueuesFullError {}
impl PyAllQueuesFullError {
const MSG: &'static str =
"All libp2p peers are unresponsive, resend the message or reconnect.";
/// Creates a new [ `PyErr` ] of this type.
///
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
pub(crate) fn new_err() -> PyErr {
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
}
}
#[gen_stub_pymethods]
#[pymethods]
impl PyAllQueuesFullError {
#[new]
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("PeerId(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
}
}
/// Connection or disconnection event discriminant type.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
#[derive(Debug, Clone, PartialEq)]
enum PyConnectionUpdateType {
Connected = 0,
Disconnected,
}
use futures_lite::FutureExt;
use libp2p::{gossipsub::PublishError, identity::Keypair};
use networking::{FromSwarm, Peer, ToSwarm};
use pyo3::{
coroutine::CancelHandle,
exceptions::{PyConnectionError, PyRuntimeError, PyValueError},
prelude::*,
types::PyBytes,
};
use pyo3_stub_gen::{
derive::{gen_methods_from_python, gen_stub_pyclass, gen_stub_pymethods},
inventory::submit,
};
use tokio::sync::{Mutex, mpsc};
#[gen_stub_pyclass]
#[pyclass(frozen, name = "ConnectionUpdate")]
#[derive(Debug, Clone)]
struct PyConnectionUpdate {
/// Whether this is a connection or disconnection event
#[pyo3(get)]
update_type: PyConnectionUpdateType,
#[pyclass(name = "Keypair", frozen)]
#[derive(Clone)]
pub struct PyKeypair(Keypair);
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: PyPeerId,
#[gen_stub_pymethods]
#[pymethods]
impl PyKeypair {
/// Generate a new ed25519 keypair
#[staticmethod]
fn generate() -> Self {
Self(Keypair::generate_ed25519())
}
/// Remote connection's IPv4 address.
#[pyo3(get)]
remote_ipv4: String,
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: &Bound<'_, PyBytes>) -> Self {
let bytes = Vec::from(bytes.as_bytes());
Self(Keypair::from_protobuf_encoding(&bytes).expect("todo"))
}
/// Remote connection's TCP port.
#[pyo3(get)]
remote_tcp_port: u16,
}
enum ToTask {
GossipsubSubscribe {
topic: String,
result_tx: oneshot::Sender<PyResult<bool>>,
},
GossipsubUnsubscribe {
topic: String,
result_tx: oneshot::Sender<bool>,
},
GossipsubPublish {
topic: String,
data: Vec<u8>,
result_tx: oneshot::Sender<PyResult<MessageId>>,
},
}
#[allow(clippy::enum_glob_use)]
async fn networking_task(
mut swarm: networking::swarm::Swarm,
mut to_task_rx: mpsc::Receiver<ToTask>,
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
) {
use SwarmEvent::*;
use ToTask::*;
use mdns::Event::*;
use networking::swarm::BehaviourEvent::*;
log::info!("RUST: networking task started");
loop {
tokio::select! {
message = to_task_rx.recv() => {
// handle closed channel
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming messages
match message {
GossipsubSubscribe { topic, result_tx } => {
// try to subscribe
let result = swarm.behaviour_mut()
.gossipsub.subscribe(&IdentTopic::new(topic));
// send response oneshot
if let Err(e) = result_tx.send(result.pyerr()) {
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubUnsubscribe { topic, result_tx } => {
// try to unsubscribe from the topic
let result = swarm.behaviour_mut()
.gossipsub.unsubscribe(&IdentTopic::new(topic));
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(result) {
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubPublish { topic, data, result_tx } => {
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
let result = swarm.behaviour_mut().gossipsub.publish(
IdentTopic::new(topic), data);
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
} else if let Err(PublishError::AllQueuesFull(_)) = result {
Err(exception::PyAllQueuesFullError::new_err())
} else {
result.pyerr()
};
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(pyresult) {
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
continue;
}
}
}
}
// architectural solution to this problem:
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
//
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
// then for actual communication it will dial those peers if need-be
swarm_event = swarm.select_next_some() => {
match swarm_event {
Behaviour(Gossipsub(gossipsub::Event::Message {
message: Message {
topic,
data,
..
},
..
})) => {
// topic-ID is just the topic hash!!! (since we used identity hasher)
let message = (topic.into_string(), data);
// send incoming message to channel (or exit if connection closed)
if let Err(e) = gossipsub_message_tx.send(message).await {
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
continue;
}
};
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
continue;
}
};
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
e => {
log::info!("RUST: other event {e:?}");
}
}
}
/// Encode a private key to a protobuf structure.
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
match self.0.to_protobuf_encoding() {
Ok(bytes) => Ok(PyBytes::new(py, &bytes)),
Err(e) => Err(PyValueError::new_err(e.to_string())),
}
}
log::info!("RUST: networking task stopped");
fn to_string(&self) -> String {
self.0.public().to_peer_id().to_base58()
}
}
struct PeerBuilder(
String,
Keypair,
mpsc::Sender<FromSwarm>,
mpsc::Receiver<ToSwarm>,
);
#[gen_stub_pyclass]
#[pyclass(name = "NetworkingHandle")]
#[derive(Debug)]
struct PyNetworkingHandle {
// channels
to_task_tx: Option<mpsc::Sender<ToTask>>,
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
}
impl Drop for PyNetworkingHandle {
fn drop(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
}
#[allow(clippy::expect_used)]
impl PyNetworkingHandle {
fn new(
to_task_tx: mpsc::Sender<ToTask>,
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
) -> Self {
Self {
to_task_tx: Some(to_task_tx),
connection_update_rx: Mutex::new(connection_update_rx),
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
}
}
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
self.to_task_tx
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
#[pyclass]
pub struct PyPeer {
peer: TakeOnce<PeerBuilder>,
to_swarm: mpsc::Sender<ToSwarm>,
from_swarm: Mutex<mpsc::Receiver<FromSwarm>>,
}
#[gen_stub_pymethods]
#[pymethods]
impl PyNetworkingHandle {
// NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()`
// immediately beforehand to release the interpreter.
// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
// ---- Lifecycle management methods ----
#[new]
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channels
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
// get identity
let identity = identity.borrow().0.clone();
// create networking swarm (within tokio context!! or it crashes)
let swarm = get_runtime()
.block_on(async { create_swarm(identity) })
.pyerr()?;
// spawn tokio task running the networking logic
get_runtime().spawn(async move {
networking_task(
swarm,
to_task_rx,
connection_update_tx,
gossipsub_message_tx,
)
.await;
});
Ok(Self::new(
to_task_tx,
connection_update_rx,
gossipsub_message_rx,
))
impl PyPeer {
#[staticmethod]
fn new(kp: PyKeypair, namespace: String) -> PyResult<Self> {
let (to_client, from_swarm) = mpsc::channel(1024);
let (to_swarm, from_client) = mpsc::channel(1024);
Ok(Self {
peer: TakeOnce::new(PeerBuilder(namespace, kp.0, to_client, from_client)),
to_swarm,
from_swarm: Mutex::new(from_swarm),
})
}
#[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
async fn run(&self, #[pyo3(cancel_handle)] mut cancel: CancelHandle) -> PyResult<()> {
let builder = self
.peer
.take()
.ok_or_else(|| PyRuntimeError::new_err("tried to run peer twice"))?;
let jh = pyo3_async_runtimes::tokio::get_runtime()
.spawn(async move {
let mut peer =
Peer::new(builder.0, builder.1, builder.2, builder.3).map_err(|_| {
PyConnectionError::new_err("peer failed to listen on default address")
})?;
peer.run()
.await
.map_err(|()| PyConnectionError::new_err("peer communication closed"))
})
.or(async {
cancel.cancelled().await;
Ok(Ok(()))
});
match AllowThreads(pin!(jh)).await {
Err(e) if e.is_cancelled() => Ok(()),
Err(e) if e.is_panic() => Err(PyRuntimeError::new_err(format!("tokio panic {e}"))),
Err(_) => unreachable!(),
Ok(res) => res,
}
}
async fn subscribe(&self, topic: String) -> PyResult<()> {
self.to_swarm
.send(ToSwarm::Subscribe(topic))
.await
.map_err(|_| PyRuntimeError::new_err("swarm communication closed"))
}
async fn unsubscribe(&self, topic: String) -> PyResult<()> {
self.to_swarm
.send(ToSwarm::Unsubscribe(topic))
.await
.map_err(|_| PyRuntimeError::new_err("swarm communication closed"))
}
async fn send(&self, topic: String, payload: Py<PyBytes>) -> PyResult<()> {
// this function attaches to the python interpreter synchronously to avoid holding the GIL
let bytes = Python::attach(|py| Vec::from(payload.bind(py).as_bytes()));
self.to_swarm
.send(ToSwarm::Message(topic, bytes))
.await
.map_err(|_| PyRuntimeError::new_err("swarm communication closed"))
}
#[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
async fn recv(
&self,
#[pyo3(cancel_handle)] mut cancel: CancelHandle,
) -> PyResult<PySwarmEvent> {
loop {
return match AllowThreads(pin!(
self.from_swarm
.try_lock()
.map_err(|_| PyRuntimeError::new_err("tried to recv twice"))?
.recv()
.or(async {
cancel.cancelled().await;
None
})
))
.await
{
Some(FromSwarm::PublishError(p)) => match p {
PublishError::AllQueuesFull(_) => {
Err(PyConnectionError::new_err("swarm overloaded"))
}
PublishError::MessageTooLarge => {
Err(PyValueError::new_err("message too large"))
}
PublishError::NoPeersSubscribedToTopic => {
continue;
}
// TODO(evan): logs here
_ => continue,
},
None => Err(PyRuntimeError::new_err("swarm communication closed")),
Some(fs) => Ok(PySwarmEvent(fs)),
};
}
}
// ---- Connection update receiver methods ----
/// Receives the next `ConnectionUpdate` from networking.
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
}
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
///
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
/// will sleep until a `ConnectionUpdate`s is sent.
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next `ConnectionUpdate` from networking.
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
// self.connection_update_rx.blocking_lock().try_recv_py()
// }
//
// /// Checks if the `ConnectionUpdate` channel is empty.
// fn connection_update_is_empty(&self) -> bool {
// self.connection_update_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `ConnectionUpdate`s in the channel.
// fn connection_update_len(&self) -> usize {
// self.connection_update_rx.blocking_lock().len()
// }
// ---- Gossipsub management methods ----
/// Subscribe to a `GossipSub` topic.
///
/// Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
self.to_task_tx()
.send_py(ToTask::GossipsubSubscribe {
topic,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & return any errors
rx.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())?
}
/// Unsubscribes from a `GossipSub` topic.
///
/// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
let (tx, rx) = oneshot::channel();
// send off request to unsubscribe
self.to_task_tx()
.send_py(ToTask::GossipsubUnsubscribe {
topic,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & convert any errors
rx.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
/// Publishes a message with multiple topics to the `GossipSub` network.
///
/// If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx()
.send_py(ToTask::GossipsubPublish {
topic,
data,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & return any errors => ignore messageID for now!!!
let _ = rx
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())??;
Ok(())
}
// ---- Gossipsub message receiver methods ----
/// Receives the next message from the `GossipSub` network.
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
self.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
.map(|(t, d)| (t, d.pybytes()))
}
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
Ok(self
.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await?
.map(|(t, d)| (t, d.pybytes())))
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next message from the `GossipSub` network.
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
// Ok(self
// .gossipsub_message_rx
// .blocking_lock()
// .try_recv_py()?
// .map(|(t, d)| (t, d.pybytes())))
// }
//
// /// Checks if the `GossipSub` message channel is empty.
// fn gossipsub_is_empty(&self) -> bool {
// self.gossipsub_message_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `GossipSub` messages in the channel.
// fn gossipsub_len(&self) -> usize {
// self.gossipsub_message_rx.blocking_lock().len()
// }
}
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
m.add_class::<exception::PyAllQueuesFullError>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyConnectionUpdate>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyNetworkingHandle>()?;
Ok(())
// Manually submit the run()/recv() stub because the cancelhandle is poorly understood
submit! {
gen_methods_from_python! {
r#"
class PyPeer:
async def run(self): ...
async def recv(self) -> PySwarmEvent: ...
"#
}
}
#[gen_stub_pyclass]
#[pyclass]
pub struct PySwarmEvent(FromSwarm);
#[gen_stub_pymethods]
#[pymethods]
impl PySwarmEvent {
// probably a better way to do this, but...
fn downcast_discovered(&self) -> Option<String> {
if let FromSwarm::Discovered(peer_id) = self.0 {
Some(peer_id.to_base58())
} else {
None
}
}
fn downcast_expired(&self) -> Option<String> {
if let FromSwarm::Expired(peer_id) = self.0 {
Some(peer_id.to_base58())
} else {
None
}
}
fn downcast_message<'py>(
&self,
py: Python<'py>,
) -> Option<(String, String, Bound<'py, PyBytes>)> {
if let FromSwarm::Message(peer_id, topic, data) = &self.0 {
Some((peer_id.to_base58(), topic.clone(), PyBytes::new(py, data)))
} else {
None
}
}
}

View File

@@ -1,159 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate_ed25519() -> Self {
Self(Keypair::generate_ed25519())
}
/// Generate a new ECDSA keypair.
#[staticmethod]
fn generate_ecdsa() -> Self {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Encode a private key as protobuf structure.
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId`.
fn to_peer_id(&self) -> PyPeerId {
PyPeerId(self.0.public().to_peer_id())
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}

View File

@@ -1,8 +0,0 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!
pub mod ident;
pub mod multiaddr;

View File

@@ -1,81 +0,0 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;
/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
/// Create a new, empty multiaddress.
#[staticmethod]
fn empty() -> Self {
Self(Multiaddr::empty())
}
/// Create a new, empty multiaddress with the given capacity.
#[staticmethod]
fn with_capacity(n: usize) -> Self {
Self(Multiaddr::with_capacity(n))
}
/// Parse a `Multiaddr` value from its byte slice representation.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
}
/// Parse a `Multiaddr` value from its string representation.
#[staticmethod]
fn from_string(string: String) -> PyResult<Self> {
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
}
/// Return the length in bytes of this multiaddress.
fn len(&self) -> usize {
self.0.len()
}
/// Returns true if the length of this multiaddress is 0.
fn is_empty(&self) -> bool {
self.0.is_empty()
}
/// Return a copy of this [`Multiaddr`]'s byte representation.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_vec();
PyBytes::new(py, &bytes)
}
/// Convert a Multiaddr to a string.
fn to_string(&self) -> String {
self.0.to_string()
}
#[gen_stub(skip)]
fn __repr__(&self) -> String {
format!("Multiaddr({})", self.0)
}
#[gen_stub(skip)]
fn __str__(&self) -> String {
self.to_string()
}
}
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyMultiaddr>()?;
Ok(())
}

View File

@@ -1,54 +0,0 @@
#[cfg(test)]
mod tests {
use core::mem::drop;
use core::option::Option::Some;
use core::time::Duration;
use tokio;
use tokio::sync::mpsc;
#[tokio::test]
async fn test_drop_channel() {
struct Ping;
let (tx, mut rx) = mpsc::channel::<Ping>(10);
let _ = tokio::spawn(async move {
println!("TASK: entered");
loop {
tokio::select! {
result = rx.recv() => {
match result {
Some(_) => {
println!("TASK: pinged");
}
None => {
println!("TASK: closing channel");
break;
}
}
}
_ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => {
println!("TASK: heartbeat");
}
}
}
println!("TASK: exited");
});
let tx2 = tx.clone();
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
tx.send(Ping).await.expect("Should not fail");
drop(tx);
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
tx2.send(Ping).await.expect("Should not fail");
drop(tx2);
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
}
}

View File

@@ -13,32 +13,14 @@ path = "src/lib.rs"
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }
libp2p = { workspace = true, features = ["full"] }

View File

@@ -1,6 +1,6 @@
use futures::stream::StreamExt as _;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use libp2p::identity;
use networking::{self, FromSwarm, ToSwarm};
use tokio::sync::mpsc;
use tokio::{io, io::AsyncBufReadExt as _, select};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter;
@@ -12,63 +12,51 @@ async fn main() {
.try_init();
// Configure swarm
let mut swarm =
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
let (to_client, mut from_swarm) = mpsc::channel(20);
let (to_swarm, from_client) = mpsc::channel(20);
let mut peer = networking::Peer::new(
"chatroom!".to_string(),
identity::Keypair::generate_ed25519(),
to_client,
from_client,
)
.expect("listen error");
// Create a Gossipsub topic & subscribe
let topic = gossipsub::IdentTopic::new("test-net");
swarm
.behaviour_mut()
.gossipsub
.subscribe(&topic)
.expect("Subscribing to topic failed");
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
let jh = tokio::spawn(async move { peer.run().await });
_ = to_swarm
.send(ToSwarm::Subscribe("chatting".to_string()))
.await;
// Kick it off
loop {
select! {
// on gossipsub outgoing
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
_ = to_swarm.send(ToSwarm::Message("chatting".to_string(), line.into_bytes())).await;
}
event = swarm.select_next_some() => match event {
event = from_swarm.recv() => match event {
// on gossipsub incoming
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
String::from_utf8_lossy(&message.data),
Some(FromSwarm::Message(peer_id,_, data)) => println!(
"\n\nGot message: '{}' from peer: {peer_id}\n\n",
String::from_utf8_lossy(&data),
),
// on discovery
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
discovery::Event::ConnectionEstablished {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
discovery::Event::ConnectionClosed {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
Some(FromSwarm::Discovered(peer_id)) => {
println!("\n\nConnected to: {peer_id}\n\n");
}
// ignore outgoing errors: those are normal
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
// otherwise log any other event
e => { log::info!("Other event {e:?}"); }
Some(FromSwarm::Expired(peer_id)) => {
println!("\n\nDisconnected from: {peer_id}\n\n");
}
Some(FromSwarm::PublishError(e)) => eprintln!("\n\nError {e:?}\n\n"),
None => break,
}
}
}
_ = jh.await;
}

View File

@@ -1,127 +0,0 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use futures::stream::StreamExt;
use libp2p::{
gossipsub, mdns, noise,
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use std::time::Duration;
use std::{error::Error, hash::Hash};
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;
// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
gossipsub: gossipsub::Behaviour,
mdns: mdns::tokio::Behaviour,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let _ = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.try_init();
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
.with_tokio()
.with_tcp(
tcp::Config::default(),
noise::Config::new,
yamux::Config::default,
)?
.with_behaviour(|key| {
// Set a custom gossipsub configuration
let gossipsub_config = gossipsub::ConfigBuilder::default()
.heartbeat_interval(Duration::from_secs(10))
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
.build()
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
// build a gossipsub network behaviour
let gossipsub = gossipsub::Behaviour::new(
gossipsub::MessageAuthenticity::Signed(key.clone()),
gossipsub_config,
)?;
let mdns =
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
Ok(MyBehaviour { gossipsub, mdns })
})?
.build();
println!("Running swarm with identity {}", swarm.local_peer_id());
// Create a Gossipsub topic
let topic = gossipsub::IdentTopic::new("test-net");
// subscribes to our topic
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
// Kick it off
loop {
select! {
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"Got message: '{}' with id: {id} from peer: {peer_id}",
String::from_utf8_lossy(&message.data),
),
SwarmEvent::NewListenAddr { address, .. } => {
println!("Local node is listening on {address}");
}
e => {
println!("Other swarm event: {:?}", e);
}
}
}
}
}

View File

@@ -1,44 +0,0 @@
https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609
^raw sockets distributed -> `<net/ndrv.h>` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html
--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets.
--> Used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received.
--> An NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like the ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools.
--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html
----> they are conceptually similar to https://scapy.disruptivelabs.in/networking/socket-interface PF_RAW or PF_PACKET
https://stackoverflow.com/questions/17169298/af-packet-on-osx
^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only
^it doesn't exist on OS X so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing
https://www.unix.com/man_page/mojave/4/ip/
^OS X manpages for IP
https://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts
^driver kit, system extensions & kexts for macOS
----
To set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic but modprobe is an option for manual loading), and then the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network.
--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a
----> here is a guide on how to set up thunderbolt-ethernet on linux
----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for MacOS
https://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7
^GPT discussion about making socket interface
https://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address,,??
https://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2
^GPT discussion about accessing TB on MacOS low level interactions
--------------------------------
https://www.intel.com/content/www/us/en/support/articles/000098893/software.html
^Thunderbolt Share & Thunderbolt Networking Mode => intel's equivalent of thunderbolt bridge
---------------------------------
https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/
-->fake ethernet devices on MacOS -> omg??? we can detect thunderbolt bridge, then bind to it, then re-expose it as fake ethernet??
-->ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!!
-->https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c

View File

@@ -1,383 +0,0 @@
use crate::ext::MultiaddrExt;
use crate::keep_alive;
use delegate::delegate;
use either::Either;
use futures::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
use libp2p::swarm::behaviour::ConnectionEstablished;
use libp2p::swarm::dial_opts::DialOpts;
use libp2p::swarm::{
CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
THandlerOutEvent, ToSwarm, dummy,
};
use libp2p::{Multiaddr, PeerId, identity, mdns};
use std::collections::{BTreeSet, HashMap};
use std::convert::Infallible;
use std::io;
use std::net::IpAddr;
use std::task::{Context, Poll};
use std::time::Duration;
use util::wakerdeque::WakerDeque;
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
mod managed {
use libp2p::swarm::NetworkBehaviour;
use libp2p::{identity, mdns, ping};
use std::io;
use std::time::Duration;
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
#[derive(NetworkBehaviour)]
pub struct Behaviour {
mdns: mdns::tokio::Behaviour,
ping: ping::Behaviour,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
Ok(Self {
mdns: mdns_behaviour(keypair)?,
ping: ping_behaviour(),
})
}
}
fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
use mdns::{Config, tokio};
// mDNS config => enable IPv6
let mdns_config = Config {
ttl: MDNS_RECORD_TTL,
query_interval: MDNS_QUERY_INTERVAL,
// enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
..Default::default()
};
let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
Ok(mdns_behaviour?)
}
fn ping_behaviour() -> ping::Behaviour {
ping::Behaviour::new(
ping::Config::new()
.with_timeout(PING_TIMEOUT)
.with_interval(PING_INTERVAL),
)
}
}
/// Events for when a listening connection is truly established and truly closed.
#[derive(Debug, Clone)]
pub enum Event {
ConnectionEstablished {
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
},
ConnectionClosed {
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
},
}
/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.
///
/// The behaviour operates as such:
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
/// to the swarm.
/// 1) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
/// immediately, and expired but connected peers are disconnected from immediately.
/// 2) Every fixed interval: discovered but not connected peers are dialed, and expired but
/// connected peers are disconnected from.
pub struct Behaviour {
// state-tracking for managed behaviors & mDNS-discovered peers
managed: managed::Behaviour,
mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
retry_delay: Delay, // retry interval
// pending events to emmit => waker-backed Deque to control polling
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
Ok(Self {
managed: managed::Behaviour::new(keypair)?,
mdns_discovered: HashMap::new(),
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
pending_events: WakerDeque::new(),
})
}
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
self.pending_events.push_back(ToSwarm::Dial {
opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
})
}
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
// push front to make this IMMEDIATE
self.pending_events.push_front(ToSwarm::CloseConnection {
peer_id,
connection: CloseConnection::One(connection),
})
}
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
for (p, ma) in peers {
self.dial(p, ma.clone()); // always connect
// get peer's multi-addresses or insert if missing
let Some(mas) = self.mdns_discovered.get_mut(&p) else {
self.mdns_discovered.insert(p, BTreeSet::from([ma]));
continue;
};
// multiaddress should never already be present - else something has gone wrong
let is_new_addr = mas.insert(ma);
assert!(is_new_addr, "cannot discover a discovered peer");
}
}
fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
for (p, ma) in peers {
// at this point, we *must* have the peer
let mas = self
.mdns_discovered
.get_mut(&p)
.expect("nonexistent peer cannot expire");
// at this point, we *must* have the multiaddress
let was_present = mas.remove(&ma);
assert!(was_present, "nonexistent multiaddress cannot expire");
// if empty, remove the peer-id entirely
if mas.is_empty() {
self.mdns_discovered.remove(&p);
}
}
}
fn on_connection_established(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
) {
// send out connected event
self.pending_events
.push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
peer_id,
connection_id,
remote_ip,
remote_tcp_port,
}));
}
fn on_connection_closed(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
) {
// send out disconnected event
self.pending_events
.push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
peer_id,
connection_id,
remote_ip,
remote_tcp_port,
}));
}
}
impl NetworkBehaviour for Behaviour {
type ConnectionHandler =
ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
type ToSwarm = Event;
// simply delegate to underlying mDNS behaviour
delegate! {
to self.managed {
fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
}
}
fn handle_established_inbound_connection(
&mut self,
connection_id: ConnectionId,
peer: PeerId,
local_addr: &Multiaddr,
remote_addr: &Multiaddr,
) -> Result<THandler<Self>, ConnectionDenied> {
Ok(ConnectionHandler::select(
dummy::ConnectionHandler,
self.managed.handle_established_inbound_connection(
connection_id,
peer,
local_addr,
remote_addr,
)?,
))
}
#[allow(clippy::needless_question_mark)]
fn handle_established_outbound_connection(
&mut self,
connection_id: ConnectionId,
peer: PeerId,
addr: &Multiaddr,
role_override: Endpoint,
port_use: PortUse,
) -> Result<THandler<Self>, ConnectionDenied> {
Ok(ConnectionHandler::select(
dummy::ConnectionHandler,
self.managed.handle_established_outbound_connection(
connection_id,
peer,
addr,
role_override,
port_use,
)?,
))
}
fn on_connection_handler_event(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
event: THandlerOutEvent<Self>,
) {
match event {
Either::Left(ev) => libp2p::core::util::unreachable(ev),
Either::Right(ev) => {
self.managed
.on_connection_handler_event(peer_id, connection_id, ev)
}
}
}
// hook into these methods to drive behavior
fn on_swarm_event(&mut self, event: FromSwarm) {
self.managed.on_swarm_event(event); // let mDNS handle swarm events
// handle swarm events to update internal state:
match event {
FromSwarm::ConnectionEstablished(ConnectionEstablished {
peer_id,
connection_id,
endpoint,
..
}) => {
let remote_address = match endpoint {
ConnectedPoint::Dialer { address, .. } => address,
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
};
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
// handle connection established event which is filtered correctly
self.on_connection_established(peer_id, connection_id, ip, port)
}
}
FromSwarm::ConnectionClosed(ConnectionClosed {
peer_id,
connection_id,
endpoint,
..
}) => {
let remote_address = match endpoint {
ConnectedPoint::Dialer { address, .. } => address,
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
};
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
// handle connection closed event which is filtered correctly
self.on_connection_closed(peer_id, connection_id, ip, port)
}
}
// since we are running TCP/IP transport layer, we are assuming that
// no address changes can occur, hence encountering one is a fatal error
FromSwarm::AddressChange(a) => {
unreachable!("unhandlable: address change encountered: {:?}", a)
}
_ => {}
}
}
fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
// delegate to managed behaviors for any behaviors they need to perform
match self.managed.poll(cx) {
Poll::Ready(ToSwarm::GenerateEvent(e)) => {
match e {
// handle discovered and expired events from mDNS
managed::BehaviourEvent::Mdns(e) => match e.clone() {
mdns::Event::Discovered(peers) => {
self.handle_mdns_discovered(peers);
}
mdns::Event::Expired(peers) => {
self.handle_mdns_expired(peers);
}
},
// handle ping events => if error then disconnect
managed::BehaviourEvent::Ping(e) => {
if let Err(_) = e.result {
self.close_connection(e.peer, e.connection.clone())
}
}
}
// since we just consumed an event, we should immediately wake just in case
// there are more events to come where that came from
cx.waker().wake_by_ref();
}
// forward any other mDNS event to the swarm or its connection handler(s)
Poll::Ready(e) => {
return Poll::Ready(
e.map_out(|_| unreachable!("events returning to swarm already handled"))
.map_in(Either::Right),
);
}
Poll::Pending => {}
}
// retry connecting to all mDNS peers periodically (fails safely if already connected)
if self.retry_delay.poll_unpin(cx).is_ready() {
for (p, mas) in self.mdns_discovered.clone() {
for ma in mas {
self.dial(p, ma)
}
}
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
}
// send out any pending events from our own service
if let Some(e) = self.pending_events.pop_front(cx) {
return Poll::Ready(e.map_in(Either::Left));
}
// wait for pending events
Poll::Pending
}
}

View File

@@ -1,44 +0,0 @@
use delegate::delegate;
use libp2p::swarm::handler::ConnectionEvent;
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
use std::task::{Context, Poll};
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
/// the connection alive.
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnectionHandler(dummy::ConnectionHandler);
impl ConnectionHandler {
pub fn new() -> Self {
ConnectionHandler(dummy::ConnectionHandler)
}
}
impl handler::ConnectionHandler for ConnectionHandler {
// delegate types and implementation mostly to dummy handler
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
type InboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
type OutboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
type InboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
type OutboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
delegate! {
to self.0 {
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
}
}
// specifically override this to force connection to stay alive
fn connection_keep_alive(&self) -> bool {
true
}
}

View File

@@ -1,64 +1,451 @@
//! TODO: crate documentation
//!
//! this is here as a placeholder documentation
//!
//!
use libp2p::{
Multiaddr, PeerId,
futures::StreamExt,
gossipsub::{self, TopicHash},
identify,
identity::Keypair,
mdns,
swarm::{NetworkBehaviour, SwarmEvent, dial_opts::DialOpts},
};
use std::collections::HashMap;
use tokio::sync::mpsc;
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
// #![feature(stmt_expr_attributes)]
// #![feature(unboxed_closures)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
#[derive(Debug)]
pub struct ListenError;
pub mod discovery;
pub mod keep_alive;
pub mod swarm;
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
pub enum FromSwarm {
PublishError(gossipsub::PublishError),
Discovered(PeerId),
Expired(PeerId),
Message(PeerId, String, Vec<u8>),
}
pub enum ToSwarm {
Message(String, Vec<u8>),
Subscribe(String),
Unsubscribe(String),
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use extend::ext;
use libp2p::Multiaddr;
use libp2p::multiaddr::Protocol;
use std::net::IpAddr;
pub struct Peer {
pub swarm: libp2p::Swarm<Behaviour>,
to_client: mpsc::Sender<FromSwarm>,
from_client: mpsc::Receiver<ToSwarm>,
namespace: String,
known_peers: HashMap<PeerId, Vec<Multiaddr>>,
}
impl Peer {
pub fn new(
namespace: String,
kp: Keypair,
to_client: mpsc::Sender<FromSwarm>,
from_client: mpsc::Receiver<ToSwarm>,
) -> Result<Self, ListenError> {
let mut swarm = libp2p::SwarmBuilder::with_existing_identity(kp)
.with_tokio()
.with_quic()
// TODO(evan) .with_bandwidth_metrics()
.with_behaviour(|kp| Behaviour::new(namespace.clone(), kp))
.expect("invalid swarm behaviour")
.build();
#[ext(pub, name = MultiaddrExt)]
impl Multiaddr {
/// If the multiaddress corresponds to a TCP address, extracts it
fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> {
let mut ps = self.into_iter();
let ip = if let Some(p) = ps.next() {
match p {
Protocol::Ip4(ip) => IpAddr::V4(ip),
Protocol::Ip6(ip) => IpAddr::V6(ip),
_ => return None,
swarm
.listen_on("/ip6/::/udp/0/quic-v1".parse().expect("invalid multiaddr"))
.map_err(|_| ListenError)?;
swarm
.listen_on(
"/ip4/0.0.0.0/udp/0/quic-v1"
.parse()
.expect("invalid multiaddr"),
)
.map_err(|_| ListenError)?;
Ok(Self {
swarm,
to_client,
from_client,
namespace,
known_peers: HashMap::default(),
})
}
pub async fn run(&mut self) -> Result<(), ()> {
loop {
tokio::select! {
event = self.swarm.next() => self.handle_event(event.ok_or(())?).await?,
msg = self.from_client.recv() => self.handle_message(msg.ok_or(())?).await?,
}
}
}
async fn handle_message(&mut self, message: ToSwarm) -> Result<(), ()> {
match message {
ToSwarm::Message(topic, data) => {
if let Err(e) = self
.swarm
.behaviour_mut()
.gossipsub
.publish(TopicHash::from_raw(topic), data)
{
self.to_client
.send(FromSwarm::PublishError(e))
.await
.map_err(|_| ())?;
}
} else {
return None;
};
let Some(Protocol::Tcp(port)) = ps.next() else {
return None;
};
Some((ip, port))
}
ToSwarm::Subscribe(topic) => {
match self
.swarm
.behaviour_mut()
.gossipsub
.subscribe(&gossipsub::IdentTopic::new(topic))
{
Ok(_) => {}
Err(gossipsub::SubscriptionError::NotAllowed) => {
unreachable!("subscription filter hit")
}
Err(gossipsub::SubscriptionError::PublishError(e)) => self
.to_client
.send(FromSwarm::PublishError(e))
.await
.map_err(|_| ())?,
}
}
ToSwarm::Unsubscribe(topic) => {
self.swarm
.behaviour_mut()
.gossipsub
.unsubscribe(&gossipsub::IdentTopic::new(topic));
}
}
Ok(())
}
async fn handle_event(&mut self, event: SwarmEvent<BehaviourEvent>) -> Result<(), ()> {
let SwarmEvent::Behaviour(event) = event else {
return Ok(());
};
match event {
BehaviourEvent::Gossipsub(gossipsub::Event::Message { message, .. }) => {
if let Some(source) = message.source {
self.to_client
.send(FromSwarm::Message(
source,
message.topic.into_string(),
message.data,
))
.await
.map_err(|_| ())?;
}
}
BehaviourEvent::Identify(identify::Event::Received { peer_id, info, .. }) => {
log::debug!(
"identify from {peer_id}: protocol_version='{}' agent_version='{}' (local namespace='{}')",
info.protocol_version, info.agent_version, self.namespace
);
if info.protocol_version == self.namespace {
self.passed_namespace(peer_id);
self.to_client
.send(FromSwarm::Discovered(peer_id))
.await
.map_err(|_| ())?;
} else {
self.failed_namespace(peer_id);
}
}
BehaviourEvent::Mdns(mdns::Event::Discovered(v)) => {
for (peer_id, addr) in v {
self.known_peers.entry(peer_id).or_default().push(addr);
}
for (peer_id, addrs) in &self.known_peers {
// dialopts handles rate limiting, we should check errors if we want to blacklist earlier
let _ = self
.swarm
.dial(DialOpts::peer_id(*peer_id).addresses(addrs.clone()).build());
}
}
BehaviourEvent::Mdns(mdns::Event::Expired(v)) => {
for (peer_id, addr) in v {
let addrs = self.known_peers.entry(peer_id).or_default();
addrs.retain(|a| *a != addr);
if addrs.is_empty() {
self.known_peers.remove(&peer_id);
self.swarm
.behaviour_mut()
.gossipsub
.remove_explicit_peer(&peer_id);
self.to_client
.send(FromSwarm::Expired(peer_id))
.await
.map_err(|_| ())?;
}
}
}
_ => {}
}
Ok(())
}
fn passed_namespace(&mut self, peer_id: PeerId) {
self.swarm
.behaviour_mut()
.gossipsub
.remove_blacklisted_peer(&peer_id);
self.swarm
.behaviour_mut()
.gossipsub
.add_explicit_peer(&peer_id);
}
fn failed_namespace(&mut self, peer_id: PeerId) {
self.swarm
.behaviour_mut()
.gossipsub
.blacklist_peer(&peer_id);
self.swarm
.behaviour_mut()
.gossipsub
.remove_explicit_peer(&peer_id);
}
}
#[derive(NetworkBehaviour)]
pub struct Behaviour {
gossipsub: gossipsub::Behaviour,
mdns: mdns::tokio::Behaviour,
identify: identify::Behaviour,
}
impl Behaviour {
fn new(namespace: String, kp: &Keypair) -> Self {
let mdns = mdns::Behaviour::new(mdns::Config::default(), kp.public().to_peer_id())
.expect("mdns behaviour failed to build");
let identify =
identify::Behaviour::new(identify::Config::new_with_signed_peer_record(namespace, kp));
let gossipsub = gossipsub::Behaviour::new(
gossipsub::MessageAuthenticity::Signed(kp.clone()),
gossipsub::ConfigBuilder::default()
.max_transmit_size(1024 * 1024)
.validation_mode(gossipsub::ValidationMode::Strict)
.build()
.expect("invalid gossipsub configuration"),
)
.expect("gossipsub behaviour failed ot build");
Self {
gossipsub,
mdns,
identify,
}
}
}
pub(crate) mod private {
#![allow(dead_code)]
// generated tests for right this moment, ill revisit these soon
#[cfg(test)]
mod tests {
use super::*;
use libp2p::multiaddr::Protocol;
use tokio::time::{timeout, Duration};
use std::net::Ipv4Addr;
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
fn make_peer(
namespace: &str,
) -> (
Peer,
mpsc::Receiver<FromSwarm>, // events coming out of Peer
mpsc::Sender<ToSwarm>, // commands going into Peer
) {
let kp = Keypair::generate_ed25519();
let (to_client_tx, to_client_rx) = mpsc::channel(64);
let (to_peer_tx, to_peer_rx) = mpsc::channel(64);
let peer = Peer::new(namespace.to_string(), kp, to_client_tx, to_peer_rx)
.expect("Peer::new should succeed in tests");
(peer, to_client_rx, to_peer_tx)
}
async fn next_listen_addr(peer: &mut Peer) -> Multiaddr {
loop {
match peer.swarm.next().await {
Some(SwarmEvent::NewListenAddr { address, .. }) => return address,
Some(_) => {},
None => panic!("swarm stream ended unexpectedly"),
}
}
}
async fn drain_until_discovered(
rx: &mut mpsc::Receiver<FromSwarm>,
who: PeerId,
) -> Result<(), ()> {
// Wait until we see a Discovered(who), ignoring other events.
loop {
match rx.recv().await.ok_or(())? {
FromSwarm::Discovered(pid) if pid == who => return Ok(()),
_ => {},
}
}
}
#[tokio::test]
async fn subscribe_and_unsubscribe_do_not_error() {
let (mut peer, mut events_rx, commands_tx) = make_peer("ns-test");
// Drive the swarm just enough to get at least one listen address event,
// so the background run loop has something initialized.
let _addr = next_listen_addr(&mut peer).await;
// Run the peer loop in the background.
let handle = tokio::spawn(async move {
let _ = peer.run().await;
});
commands_tx
.send(ToSwarm::Subscribe("topic-a".to_string()))
.await
.unwrap();
commands_tx
.send(ToSwarm::Unsubscribe("topic-a".to_string()))
.await
.unwrap();
// We don't *require* any FromSwarm events here; this is mainly a
// smoke test that the message-handling path doesn't panic/hang.
// Still, poll briefly to ensure the task is alive.
let _ = timeout(Duration::from_millis(200), events_rx.recv()).await;
// Shut down: dropping the command sender closes the channel, causing run() to return Err.
drop(commands_tx);
let _ = handle.await;
}
#[tokio::test]
async fn two_peers_same_namespace_can_exchange_a_message() {
let namespace = "ns-e2e";
let topic = "hello-topic";
let (mut p1, mut _p1_events, p1_cmd) = make_peer(namespace);
let (mut p2, mut p2_events, p2_cmd) = make_peer(namespace);
// Learn each peer's listen address before we hand ownership to run() tasks.
let _p1_addr = next_listen_addr(&mut p1).await;
let _p2_addr = next_listen_addr(&mut p2).await;
let p1_id = *p1.swarm.local_peer_id();
let p2_id = *p2.swarm.local_peer_id();
// Dial peer 2 from peer 1 using /p2p/<peerid> encapsulation.
let mut loopback_addr = Multiaddr::empty();
loopback_addr.push(Protocol::Ip4(Ipv4Addr::LOCALHOST));
let p2_dial = loopback_addr
.clone()
.with(Protocol::P2p(p2_id));
p1.swarm.dial(p2_dial).expect("dial should start");
// Start both peers.
let h1 = tokio::spawn(async move { let _ = p1.run().await; });
let h2 = tokio::spawn(async move { let _ = p2.run().await; });
// Subscribe both sides.
p1_cmd
.send(ToSwarm::Subscribe(topic.to_string()))
.await
.unwrap();
p2_cmd
.send(ToSwarm::Subscribe(topic.to_string()))
.await
.unwrap();
// Wait for discovery (Identify exchange) so gossipsub will explicitly peer.
let discovered = timeout(Duration::from_secs(10), drain_until_discovered(&mut p2_events, p1_id)).await;
assert!(
discovered.is_ok(),
"timed out waiting for p2 to discover p1 via identify"
);
// Now publish from p1 and expect p2 to receive it.
let payload = b"ping".to_vec();
p1_cmd
.send(ToSwarm::Message(topic.to_string(), payload.clone()))
.await
.unwrap();
let got = timeout(Duration::from_secs(10), async {
loop {
match p2_events.recv().await {
Some(FromSwarm::Message(from, t, data)) => {
if from == p1_id && t == topic && data == payload {
return true;
}
}
Some(_) => {}
None => return false,
}
}
})
.await
.expect("timed out waiting for message");
assert!(got, "did not receive expected gossipsub message");
// Shutdown.
drop(p1_cmd);
drop(p2_cmd);
let _ = h1.await;
let _ = h2.await;
}
#[tokio::test]
async fn different_namespaces_do_not_emit_discovered() {
let topic = "ns-mismatch-topic";
let (mut p1, mut _p1_events, p1_cmd) = make_peer("ns-a");
let (mut p2, mut p2_events, p2_cmd) = make_peer("ns-b");
let p1_addr = next_listen_addr(&mut p1).await;
let p2_addr = next_listen_addr(&mut p2).await;
let p1_id = *p1.swarm.local_peer_id();
let p2_id = *p2.swarm.local_peer_id();
// Dial so Identify runs, but namespaces differ.
let p2_dial = p2_addr.clone().with(Protocol::P2p(p2_id));
p1.swarm.dial(p2_dial).expect("dial should start");
let h1 = tokio::spawn(async move { let _ = p1.run().await; });
let h2 = tokio::spawn(async move { let _ = p2.run().await; });
// Subscribe both (even if subscribed, discovery should not happen with mismatched namespace).
p1_cmd
.send(ToSwarm::Subscribe(topic.to_string()))
.await
.unwrap();
p2_cmd
.send(ToSwarm::Subscribe(topic.to_string()))
.await
.unwrap();
// Assert p2 does NOT produce Discovered(p1) within a short window.
let no_discovered = timeout(Duration::from_secs(10), async {
loop {
match p2_events.recv().await {
Some(FromSwarm::Discovered(pid)) if pid == p1_id => return false,
Some(_) => {},
None => return true,
}
}
})
.await;
assert!(
no_discovered.is_ok() && no_discovered.unwrap(),
"p2 unexpectedly discovered p1 despite namespace mismatch"
);
// Shutdown.
drop(p1_cmd);
drop(p2_cmd);
let _ = h1.await;
let _ = h2.await;
// Silence unused warnings if your compiler complains about p1_addr/p2_addr being unused
let _ = (p1_addr, p2_addr);
}
}

View File

@@ -1,145 +0,0 @@
use crate::alias;
use crate::swarm::transport::tcp_transport;
pub use behaviour::{Behaviour, BehaviourEvent};
use libp2p::{SwarmBuilder, identity};
pub type Swarm = libp2p::Swarm<Behaviour>;
/// The current version of the network: this prevents devices running different versions of the
/// software from interacting with each other.
///
/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should
/// even be, and how to inject the right version into this config/initialization. E.g. should
/// this be passed in as a parameter? What about rapidly changing versions in debug builds?
/// this is all VERY very hard to figure out and needs to be mulled over as a team.
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
/// Create and configure a swarm which listens to all ports on OS
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
.with_tokio()
.with_other_transport(tcp_transport)?
.with_behaviour(Behaviour::new)?
.build();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
Ok(swarm)
}
mod transport {
use crate::alias;
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
use futures::{AsyncRead, AsyncWrite};
use keccak_const::Sha3_256;
use libp2p::core::muxing;
use libp2p::core::transport::Boxed;
use libp2p::pnet::{PnetError, PnetOutput};
use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
use std::{env, sync::LazyLock};
/// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].
/// See [`pnet_upgrade`] for more.
static PNET_PRESHARED_KEY: LazyLock<[u8; 32]> = LazyLock::new(|| {
let builder = Sha3_256::new().update(b"exo_discovery_network");
if let Ok(var) = env::var(OVERRIDE_VERSION_ENV_VAR) {
let bytes = var.into_bytes();
builder.update(&bytes)
} else {
builder.update(NETWORK_VERSION)
}
.finalize()
});
/// Make the Swarm run on a private network, as to not clash with public libp2p nodes and
/// also different-versioned instances of this same network.
/// This is implemented as an additional "upgrade" ontop of existing [`libp2p::Transport`] layers.
async fn pnet_upgrade<TSocket>(
socket: TSocket,
_: impl Sized,
) -> Result<PnetOutput<TSocket>, PnetError>
where
TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
use pnet::{PnetConfig, PreSharedKey};
PnetConfig::new(PreSharedKey::new(*PNET_PRESHARED_KEY))
.handshake(socket)
.await
}
/// TCP/IP transport layer configuration.
pub fn tcp_transport(
keypair: &identity::Keypair,
) -> alias::AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
use libp2p::{
core::upgrade::Version,
tcp::{Config, tokio},
};
// `TCP_NODELAY` enabled => avoid latency
let tcp_config = Config::default().nodelay(true);
// V1 + lazy flushing => 0-RTT negotiation
let upgrade_version = Version::V1Lazy;
// Noise is faster than TLS + we don't care much for security
let noise_config = noise::Config::new(keypair)?;
// Use default Yamux config for multiplexing
let yamux_config = yamux::Config::default();
// Create new Tokio-driven TCP/IP transport layer
let base_transport = tokio::Transport::new(tcp_config)
.and_then(pnet_upgrade)
.upgrade(upgrade_version)
.authenticate(noise_config)
.multiplex(yamux_config);
// Return boxed transport (to flatten complex type)
Ok(base_transport.boxed())
}
}
mod behaviour {
use crate::{alias, discovery};
use libp2p::swarm::NetworkBehaviour;
use libp2p::{gossipsub, identity};
use std::time::Duration;
/// Behavior of the Swarm which composes all desired behaviors:
/// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
#[derive(NetworkBehaviour)]
pub struct Behaviour {
pub discovery: discovery::Behaviour,
pub gossipsub: gossipsub::Behaviour,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
Ok(Self {
discovery: discovery::Behaviour::new(keypair)?,
gossipsub: gossipsub_behaviour(keypair),
})
}
}
fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};
// build a gossipsub network behaviour
// => signed message authenticity + strict validation mode means the message-ID is
// automatically provided by gossipsub w/out needing to provide custom message-ID function
gossipsub::Behaviour::new(
MessageAuthenticity::Signed(keypair.clone()),
ConfigBuilder::default()
.publish_queue_duration(Duration::from_secs(15))
.max_transmit_size(1024 * 1024)
.validation_mode(ValidationMode::Strict)
.build()
.expect("the configuration should always be valid"),
)
.expect("creating gossipsub behavior should always work")
}
}

View File

@@ -1,7 +0,0 @@
// maybe this will hold test in the future...??
#[cfg(test)]
mod tests {
#[test]
fn does_nothing() {}
}

View File

@@ -1,2 +0,0 @@
[toolchain]
channel = "nightly"

View File

@@ -1,25 +0,0 @@
[package]
name = "util"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "util"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]
# macro dependencies
extend = { workspace = true }
# utility dependencies
thiserror = { workspace = true }
once_cell = { workspace = true }
internment = { workspace = true }
derive_more = { workspace = true }
bon = { workspace = true }
recursion = { workspace = true }

View File

@@ -1,53 +0,0 @@
//! TODO: crate documentation
//!
//! this is here as a placeholder documentation
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub mod nonempty;
pub mod wakerdeque;
pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub mod ext {
use extend::ext;
#[ext(pub, name = BoxedSliceExt)]
impl<T> Box<[T]> {
#[inline]
fn map<B, F>(self, f: F) -> Box<[B]>
where
F: FnMut(T) -> B,
{
self.into_iter().map(f).collect()
}
}
#[ext(pub, name = VecExt)]
impl<T> Vec<T> {
#[inline]
fn map<B, F>(self, f: F) -> Vec<B>
where
F: FnMut(T) -> B,
{
self.into_iter().map(f).collect()
}
}
}

View File

@@ -1,138 +0,0 @@
use std::slice::SliceIndex;
use std::{ops, slice};
use thiserror::Error;
#[derive(Error, Debug)]
#[error("Cannot create to `NonemptyArray` because the supplied slice is empty")]
pub struct EmptySliceError;
/// A pointer to a non-empty fixed-size slice allocated on the heap.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(transparent)]
pub struct NonemptyArray<T>(Box<[T]>);
#[allow(clippy::arbitrary_source_item_ordering)]
impl<T> NonemptyArray<T> {
#[inline]
pub fn singleton(value: T) -> Self {
Self(Box::new([value]))
}
#[allow(clippy::missing_errors_doc)]
#[inline]
pub fn try_from_boxed_slice<S: Into<Box<[T]>>>(
boxed_slice: S,
) -> Result<Self, EmptySliceError> {
let boxed_slice = boxed_slice.into();
if boxed_slice.is_empty() {
Err(EmptySliceError)
} else {
Ok(Self(boxed_slice))
}
}
#[must_use]
#[inline]
pub fn into_boxed_slice(self) -> Box<[T]> {
self.0
}
#[must_use]
#[inline]
pub fn to_vec(&self) -> Vec<T>
where
T: Clone,
{
self.0.to_vec()
}
#[must_use]
#[inline]
pub const fn as_slice(&self) -> &[T] {
&self.0
}
#[allow(clippy::indexing_slicing)]
#[must_use]
#[inline]
pub fn first(&self) -> &T {
&self.0[0]
}
#[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
#[must_use]
#[inline]
pub fn last(&self) -> &T {
&self.0[self.0.len() - 1]
}
#[must_use]
#[inline]
pub fn get<I>(&self, index: I) -> Option<&I::Output>
where
I: SliceIndex<[T]>,
{
self.0.get(index)
}
#[allow(clippy::len_without_is_empty)]
#[must_use]
#[inline]
pub const fn len(&self) -> usize {
self.0.len()
}
#[allow(clippy::iter_without_into_iter)]
#[inline]
pub fn iter(&self) -> slice::Iter<'_, T> {
self.0.iter()
}
#[allow(clippy::iter_without_into_iter)]
#[inline]
pub fn iter_mut(&mut self) -> slice::IterMut<'_, T> {
self.0.iter_mut()
}
#[inline]
#[must_use]
pub fn map<U, F: FnMut(T) -> U>(self, f: F) -> NonemptyArray<U> {
NonemptyArray(self.0.into_iter().map(f).collect())
}
}
impl<T> From<NonemptyArray<T>> for Box<[T]> {
#[inline]
fn from(value: NonemptyArray<T>) -> Self {
value.into_boxed_slice()
}
}
impl<T> ops::Index<usize> for NonemptyArray<T> {
type Output = T;
#[inline]
fn index(&self, index: usize) -> &Self::Output {
self.0.index(index)
}
}
impl<T> IntoIterator for NonemptyArray<T> {
type Item = T;
type IntoIter = std::vec::IntoIter<T>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.into_boxed_slice().into_vec().into_iter()
}
}
impl<'a, T> IntoIterator for &'a NonemptyArray<T> {
type Item = &'a T;
type IntoIter = slice::Iter<'a, T>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}

View File

@@ -1,55 +0,0 @@
use std::collections::VecDeque;
use std::fmt::{Debug, Formatter};
use std::task::{Context, Waker};
/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,
/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.
pub struct WakerDeque<T> {
waker: Option<Waker>,
deque: VecDeque<T>,
}
impl<T: Debug> Debug for WakerDeque<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.deque.fmt(f)
}
}
impl<T> WakerDeque<T> {
pub fn new() -> Self {
Self {
waker: None,
deque: VecDeque::new(),
}
}
fn update(&mut self, cx: &mut Context<'_>) {
self.waker = Some(cx.waker().clone());
}
fn wake(&mut self) {
let Some(ref mut w) = self.waker else { return };
w.wake_by_ref();
self.waker = None;
}
pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {
self.update(cx);
self.deque.pop_front()
}
pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {
self.update(cx);
self.deque.pop_back()
}
pub fn push_front(&mut self, value: T) {
self.wake();
self.deque.push_front(value);
}
pub fn push_back(&mut self, value: T) {
self.wake();
self.deque.push_back(value);
}
}

View File

@@ -1,4 +1,5 @@
import argparse
import importlib.metadata
import itertools
import multiprocessing as mp
import os
@@ -45,9 +46,9 @@ class Node:
@classmethod
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
node_id = NodeId(keypair.to_string())
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
router = Router.create(keypair, namespace=args.namespace)
await router.register_topic(topics.GLOBAL_EVENTS)
await router.register_topic(topics.LOCAL_EVENTS)
await router.register_topic(topics.COMMANDS)
@@ -73,7 +74,7 @@ class Node:
else:
download_coordinator = None
if args.spawn_api:
if not args.no_api:
api = API(
node_id,
session_id,
@@ -253,7 +254,7 @@ def main():
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
logger.info(f"Namespace: {args.namespace}")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
@@ -270,13 +271,13 @@ def main():
class Args(CamelCaseModel):
verbosity: int = 0
force_master: bool = False
spawn_api: bool = False
api_port: PositiveInt = 52415
tb_only: bool = False
verbosity: int
force_master: bool
no_api: bool
api_port: PositiveInt
no_worker: bool = False
no_downloads: bool = False
namespace: str
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -306,14 +307,15 @@ class Args(CamelCaseModel):
)
parser.add_argument(
"--no-api",
action="store_false",
dest="spawn_api",
action="store_true",
help="Disable the API server for this node",
)
parser.add_argument(
"--api-port",
type=int,
dest="api_port",
default=52415,
help="Which port the API server will be available on",
)
parser.add_argument(
"--no-worker",
@@ -324,6 +326,11 @@ class Args(CamelCaseModel):
action="store_true",
help="Disable the download coordinator (node won't download models)",
)
parser.add_argument(
"--namespace",
default=importlib.metadata.version("exo"),
help="Set the EXO namespace to run multiple isolated clusters",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

View File

@@ -74,6 +74,7 @@ from exo.shared.types.api import (
ErrorResponse,
FinishReason,
GenerationStats,
HuggingFaceSearchResult,
ImageData,
ImageEditsTaskParams,
ImageGenerationResponse,
@@ -262,6 +263,7 @@ class API:
self.app.get("/v1/models")(self.get_models)
self.app.post("/models/add")(self.add_custom_model)
self.app.delete("/models/custom/{model_id:path}")(self.delete_custom_model)
self.app.get("/models/search")(self.search_models)
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
@@ -1222,6 +1224,10 @@ class API:
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
is_custom=is_custom_card(card.model_id),
family=card.family,
quantization=card.quantization,
base_model=card.base_model,
capabilities=card.capabilities,
)
for card in await get_model_cards()
]
@@ -1257,6 +1263,30 @@ class API:
{"message": "Model card deleted", "model_id": str(model_id)}
)
async def search_models(
self, query: str = "", limit: int = 20
) -> list[HuggingFaceSearchResult]:
"""Search HuggingFace Hub for mlx-community models."""
from huggingface_hub import list_models
results = list_models(
search=query or None,
author="mlx-community",
sort="downloads",
limit=limit,
)
return [
HuggingFaceSearchResult(
id=m.id,
author=m.author or "",
downloads=m.downloads or 0,
likes=m.likes or 0,
last_modified=str(m.last_modified or ""),
tags=list(m.tags or []),
)
for m in results
]
async def run(self):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"

View File

@@ -369,7 +369,7 @@ class Master:
await self._handle_traces_collected(event)
continue
logger.debug(f"Master indexing event: {str(event)[:100]}")
logger.trace(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)

View File

@@ -41,7 +41,7 @@ from exo.utils.channels import channel
@pytest.mark.asyncio
async def test_master():
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
node_id = NodeId(keypair.to_string())
session_id = SessionId(master_node_id=node_id, election_clock=0)
ge_sender, global_event_receiver = channel[ForwarderEvent]()
@@ -72,7 +72,7 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
sender_node_id = NodeId(f"{keypair.to_string()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(

View File

@@ -1,37 +1,9 @@
from enum import Enum
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
from exo.shared.types.common import NodeId
from exo.utils.pydantic_ext import CamelCaseModel
"""Serialisable types for Connection Updates/Messages"""
class ConnectionMessageType(Enum):
Connected = 0
Disconnected = 1
@staticmethod
def from_update_type(update_type: ConnectionUpdateType):
match update_type:
case ConnectionUpdateType.Connected:
return ConnectionMessageType.Connected
case ConnectionUpdateType.Disconnected:
return ConnectionMessageType.Disconnected
class ConnectionMessage(CamelCaseModel):
node_id: NodeId
connection_type: ConnectionMessageType
remote_ipv4: str
remote_tcp_port: int
@classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
return cls(
node_id=NodeId(update.peer_id.to_base58()),
connection_type=ConnectionMessageType.from_update_type(update.update_type),
remote_ipv4=update.remote_ipv4,
remote_tcp_port=update.remote_tcp_port,
)
expired: bool

View File

@@ -1,5 +1,5 @@
from copy import copy
from itertools import count
from dataclasses import dataclass, field
from math import inf
from os import PathLike
from pathlib import Path
@@ -13,15 +13,14 @@ from anyio import (
)
from anyio.abc import TaskGroup
from exo_pyo3_bindings import (
AllQueuesFullError,
Keypair,
NetworkingHandle,
NoPeersSubscribedToTopicError,
PyPeer,
)
from filelock import FileLock
from loguru import logger
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
from exo.shared.types.common import NodeId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.pydantic_ext import CamelCaseModel
@@ -98,28 +97,32 @@ class TopicRouter[T: CamelCaseModel]:
)
@dataclass
class Router:
@classmethod
def create(cls, identity: Keypair) -> "Router":
return cls(handle=NetworkingHandle(identity))
_peer: PyPeer
topic_routers: dict[str, TopicRouter[CamelCaseModel]] = field(
init=False, default_factory=dict
)
networking_receiver: Receiver[tuple[str, bytes]] = field(init=False)
_tmp_networking_sender: Sender[tuple[str, bytes]] | None = field(init=False)
_tg: TaskGroup | None = None
def __init__(self, handle: NetworkingHandle):
self.topic_routers: dict[str, TopicRouter[CamelCaseModel]] = {}
send, recv = channel[tuple[str, bytes]]()
self.networking_receiver: Receiver[tuple[str, bytes]] = recv
self._net: NetworkingHandle = handle
self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send
self._id_count = count()
self._tg: TaskGroup | None = None
def __post_init__(self):
self._tmp_networking_sender, self.networking_receiver = channel()
@classmethod
def create(cls, identity: Keypair, namespace: str) -> "Router":
return cls(_peer=PyPeer.new(identity, namespace))
async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
assert self._tg is None, "Attempted to register topic after setup time"
send = self._tmp_networking_sender
if send:
self._tmp_networking_sender = None
else:
send = self.networking_receiver.clone_sender()
router = TopicRouter[T](topic, send)
if self._tg is not None:
self._tg.start_soon(router.run)
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
await self._networking_subscribe(str(topic.topic))
@@ -146,14 +149,18 @@ class Router:
async def run(self):
logger.debug("Starting Router")
async def _peer_run():
await self._peer.run()
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
tg.start_soon(_peer_run)
# Router only shuts down if you cancel it.
await sleep_forever()
for topic in self.topic_routers:
@@ -167,46 +174,57 @@ class Router:
async def _networking_subscribe(self, topic: str):
logger.info(f"Subscribing to {topic}")
await self._net.gossipsub_subscribe(topic)
await self._peer.subscribe(topic)
async def _networking_unsubscribe(self, topic: str):
logger.info(f"Unsubscribing from {topic}")
await self._net.gossipsub_unsubscribe(topic)
await self._peer.unsubscribe(topic)
async def _networking_recv(self):
while True:
topic, data = await self._net.gossipsub_recv()
logger.trace(f"Received message on {topic} with payload {data}")
try:
swarm_event = await self._peer.recv()
except ValueError:
logger.error("Message too large for gossipsub, dropped")
continue
except ConnectionError:
logger.error("All peer queues full, network overloaded")
continue
except RuntimeError:
break
cm = None
if (peer_id := swarm_event.downcast_discovered()) is not None:
cm = ConnectionMessage(node_id=NodeId(peer_id), expired=False)
if (peer_id := swarm_event.downcast_expired()) is not None:
cm = ConnectionMessage(node_id=NodeId(peer_id), expired=True)
if cm is not None:
if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(cm)
continue
assert (msg := swarm_event.downcast_message()) is not None
_origin, topic, payload = msg
logger.debug(f"Received message on {topic} with payload {payload}")
if topic not in self.topic_routers:
logger.warning(f"Received message on unknown or inactive topic {topic}")
continue
router = self.topic_routers[topic]
await router.publish_bytes(data)
async def _networking_recv_connection_messages(self):
while True:
update = await self._net.connection_update_recv()
message = ConnectionMessage.from_update(update)
logger.trace(
f"Received message on connection_messages with payload {message}"
)
if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(message)
await router.publish_bytes(payload)
async def _networking_publish(self):
with self.networking_receiver as networked_items:
async for topic, data in networked_items:
try:
logger.trace(f"Sending message on {topic} with payload {data}")
await self._net.gossipsub_publish(topic, data)
# As a hack, this also catches AllQueuesFull
# Need to fix that ASAP.
except (NoPeersSubscribedToTopicError, AllQueuesFullError):
pass
await self._peer.send(topic, data)
except RuntimeError:
break
def get_node_id_keypair(
@@ -217,7 +235,7 @@ def get_node_id_keypair(
Obtain the :class:`PeerId` by from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate_ed25519()
return Keypair.generate()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")

View File

@@ -39,7 +39,7 @@ RESOURCES_DIR = (
)
_DASHBOARD_DIR_ENV = os.environ.get("EXO_DASHBOARD_DIR", None)
DASHBOARD_DIR = (
find_dashboard() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
find_dashboard() if _DASHBOARD_DIR_ENV is None else Path.home() / _DASHBOARD_DIR_ENV
)
# Log files (data/logs or cache)

View File

@@ -78,6 +78,10 @@ class ModelCard(CamelCaseModel):
supports_tensor: bool
tasks: list[ModelTask]
components: list[ComponentInfo] | None = None
family: str = ""
quantization: str = ""
base_model: str = ""
capabilities: list[str] = []
@field_validator("tasks", mode="before")
@classmethod

View File

@@ -1,7 +1,7 @@
import pytest
from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId
@@ -330,9 +330,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
await cm_tx.send(
ConnectionMessage(
node_id=NodeId(),
connection_type=ConnectionMessageType.Connected,
remote_ipv4="",
remote_tcp_port=0,
expired=False,
)
)

View File

@@ -43,6 +43,10 @@ class ModelListModel(BaseModel):
supports_tensor: bool = Field(default=False)
tasks: list[str] = Field(default=[])
is_custom: bool = Field(default=False)
family: str = Field(default="")
quantization: str = Field(default="")
base_model: str = Field(default="")
capabilities: list[str] = Field(default_factory=list)
class ModelList(BaseModel):
@@ -206,6 +210,15 @@ class AddCustomModelParams(BaseModel):
model_id: ModelId
class HuggingFaceSearchResult(BaseModel):
id: str
author: str = ""
downloads: int = 0
likes: int = 0
last_modified: str = ""
tags: list[str] = Field(default_factory=list)
class PlaceInstanceParams(BaseModel):
model_id: ModelId
sharding: Sharding = Sharding.Pipeline

View File

@@ -164,6 +164,12 @@ def _inner_model(model: nn.Module) -> nn.Module:
if isinstance(inner, nn.Module):
return inner
inner = getattr(model, "language_model", None)
if isinstance(inner, nn.Module):
inner_inner = getattr(inner, "model", None)
if isinstance(inner_inner, nn.Module):
return inner_inner
raise ValueError("Model must either have a 'model' or 'transformer' attribute")

View File

@@ -384,6 +384,17 @@ def load_tokenizer_for_model_id(
eos_token_ids=eos_token_ids,
)
if "gemma-3" in model_id_lower:
gemma_3_eos_id = 1
gemma_3_end_of_turn_id = 106
if tokenizer.eos_token_ids is not None:
if gemma_3_end_of_turn_id not in tokenizer.eos_token_ids:
tokenizer.eos_token_ids = list(tokenizer.eos_token_ids) + [
gemma_3_end_of_turn_id
]
else:
tokenizer.eos_token_ids = [gemma_3_eos_id, gemma_3_end_of_turn_id]
return tokenizer

View File

@@ -334,7 +334,7 @@ class Worker:
session=self.session_id,
event=event,
)
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
logger.trace(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe