Compare commits

..

2 Commits

Author SHA1 Message Date
Sami Khan
3e9eb93f82 exo theme 2026-02-19 04:21:58 +05:00
Sami Khan
ab622f79c3 EXO iOS app 2026-02-18 06:40:07 +05:00
109 changed files with 5673 additions and 687 deletions

View File

@@ -1,46 +0,0 @@
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
from .base import BaseModelArgs
from .deepseek_v32 import Model as DSV32Model
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
index_head_dim: int
index_n_heads: int
index_topk: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: int
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_parameters: Dict[str, Any]
attention_bias: bool
rope_scaling: Dict[str, Any] | None
rope_theta: float | None
class Model(DSV32Model):
def __init__(self, config: ModelArgs) -> None: ...

150
Cargo.lock generated
View File

@@ -141,6 +141,12 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
[[package]]
name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "asn1-rs"
version = "0.7.1"
@@ -298,6 +304,19 @@ version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
[[package]]
name = "bigdecimal"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934"
dependencies = [
"autocfg",
"libm",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "bimap"
version = "0.6.3"
@@ -497,6 +516,15 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f421161cb492475f1661ddc9815a745a1c894592070661180fdec3d4872e9c3"
[[package]]
name = "convert_case"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
dependencies = [
"unicode-segmentation",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@@ -673,6 +701,17 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "delegate"
version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "780eb241654bf097afb00fc5f054a09b687dad862e485fdcf8399bb056565370"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "der"
version = "0.7.10"
@@ -707,6 +746,29 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "derive_more"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b"
dependencies = [
"convert_case",
"proc-macro2",
"quote",
"rustc_version",
"syn 2.0.111",
"unicode-xid",
]
[[package]]
name = "digest"
version = "0.10.7"
@@ -876,16 +938,25 @@ dependencies = [
name = "exo_pyo3_bindings"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"env_logger",
"extend",
"futures-lite",
"futures",
"impl-trait-for-tuples",
"libp2p",
"log",
"networking",
"once_cell",
"pin-project",
"pyo3",
"pyo3-async-runtimes",
"pyo3-log",
"pyo3-stub-gen",
"thiserror 2.0.17",
"thread_local",
"tokio",
"util",
]
[[package]]
@@ -1569,6 +1640,17 @@ dependencies = [
"xmltree",
]
[[package]]
name = "impl-trait-for-tuples"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "indexmap"
version = "2.12.1"
@@ -1747,6 +1829,12 @@ version = "0.2.178"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
[[package]]
name = "libm"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
[[package]]
name = "libp2p"
version = "0.56.0"
@@ -2735,13 +2823,20 @@ dependencies = [
name = "networking"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures-lite",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"libp2p",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
]
[[package]]
@@ -2823,6 +2918,17 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
@@ -3173,14 +3279,28 @@ version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d"
dependencies = [
"bigdecimal",
"either",
"hashbrown 0.16.1",
"indexmap",
"indoc",
"inventory",
"libc",
"lock_api",
"memoffset",
"num-bigint",
"num-complex",
"num-rational",
"num-traits",
"once_cell",
"ordered-float",
"parking_lot",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"rust_decimal",
"smallvec",
"unindent",
]
@@ -3621,6 +3741,16 @@ dependencies = [
"tokio",
]
[[package]]
name = "rust_decimal"
version = "1.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
dependencies = [
"arrayvec",
"num-traits",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
@@ -4485,12 +4615,24 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-segmentation"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "unicode_names2"
version = "1.3.0"
@@ -4571,6 +4713,10 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "util"
version = "0.0.1"
[[package]]
name = "uuid"
version = "1.19.0"

View File

@@ -3,6 +3,7 @@ resolver = "3"
members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/util",
]
[workspace.package]
@@ -23,22 +24,51 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
util = { path = "rust/util" }
# Proc-macro authoring tools
syn = "2.0"
quote = "1.0"
proc-macro2 = "1.0"
darling = "0.20"
# Macro dependecies
extend = "1.2"
delegate = "0.13"
impl-trait-for-tuples = "0.2"
clap = "4.5"
derive_more = { version = "2.0.1", features = ["display"] }
pin-project = "1"
# Utility dependencies
itertools = "0.14"
thiserror = "2"
internment = "0.8"
recursion = "0.5"
regex = "1.11"
once_cell = "1.21"
thread_local = "1.1"
bon = "3.4"
generativity = "1.1"
anyhow = "1.0"
keccak-const = "0.2"
# Functional generics/lenses frameworks
frunk_core = "0.4"
frunk = "0.4"
frunk_utils = "0.2"
frunk-enum-core = "0.3"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-util = "0.3"
futures-timer = "3.0"
# Data structures
either = "1.15"
ordered-float = "5.0"
ahash = "0.8"
# Tracing/logging
log = "0.4"

View File

@@ -72,23 +72,16 @@ There are two ways to run exo:
### Run from Source (macOS)
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
```bash
nix run .#exo
```
**Prerequisites:**
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
- [node](https://github.com/nodejs/node) (for building the dashboard)
```bash
brew install uv macmon node
```

View File

@@ -0,0 +1,628 @@
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 77;
objects = {
/* Begin PBXBuildFile section */
E09D17522F44F359009C51A3 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = E09D17512F44F359009C51A3 /* MLXLLM */; };
E09D17542F44F359009C51A3 /* MLXLMCommon in Frameworks */ = {isa = PBXBuildFile; productRef = E09D17532F44F359009C51A3 /* MLXLMCommon */; };
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
E09D167D2F44CA20009C51A3 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = E09D16672F44CA1E009C51A3 /* Project object */;
proxyType = 1;
remoteGlobalIDString = E09D166E2F44CA1E009C51A3;
remoteInfo = "EXO-iOS";
};
E09D16872F44CA20009C51A3 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = E09D16672F44CA1E009C51A3 /* Project object */;
proxyType = 1;
remoteGlobalIDString = E09D166E2F44CA1E009C51A3;
remoteInfo = "EXO-iOS";
};
/* End PBXContainerItemProxy section */
/* Begin PBXFileReference section */
E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "EXO-iOS.app"; sourceTree = BUILT_PRODUCTS_DIR; };
E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "EXO-iOSTests.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "EXO-iOSUITests.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */
/* Begin PBXFileSystemSynchronizedBuildFileExceptionSet section */
E09D169A2F44CA20009C51A3 /* Exceptions for "EXO-iOS" folder in "EXO-iOS" target */ = {
isa = PBXFileSystemSynchronizedBuildFileExceptionSet;
membershipExceptions = (
Info.plist,
);
target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
};
/* End PBXFileSystemSynchronizedBuildFileExceptionSet section */
/* Begin PBXFileSystemSynchronizedRootGroup section */
E09D16712F44CA1E009C51A3 /* EXO-iOS */ = {
isa = PBXFileSystemSynchronizedRootGroup;
exceptions = (
E09D169A2F44CA20009C51A3 /* Exceptions for "EXO-iOS" folder in "EXO-iOS" target */,
);
path = "EXO-iOS";
sourceTree = "<group>";
};
E09D167F2F44CA20009C51A3 /* EXO-iOSTests */ = {
isa = PBXFileSystemSynchronizedRootGroup;
path = "EXO-iOSTests";
sourceTree = "<group>";
};
E09D16892F44CA20009C51A3 /* EXO-iOSUITests */ = {
isa = PBXFileSystemSynchronizedRootGroup;
path = "EXO-iOSUITests";
sourceTree = "<group>";
};
/* End PBXFileSystemSynchronizedRootGroup section */
/* Begin PBXFrameworksBuildPhase section */
E09D166C2F44CA1E009C51A3 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
E09D17542F44F359009C51A3 /* MLXLMCommon in Frameworks */,
E09D17522F44F359009C51A3 /* MLXLLM in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16792F44CA20009C51A3 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16832F44CA20009C51A3 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
E09D16662F44CA1E009C51A3 = {
isa = PBXGroup;
children = (
E09D16712F44CA1E009C51A3 /* EXO-iOS */,
E09D167F2F44CA20009C51A3 /* EXO-iOSTests */,
E09D16892F44CA20009C51A3 /* EXO-iOSUITests */,
E09D16702F44CA1E009C51A3 /* Products */,
);
sourceTree = "<group>";
};
E09D16702F44CA1E009C51A3 /* Products */ = {
isa = PBXGroup;
children = (
E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */,
E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */,
E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */,
);
name = Products;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXNativeTarget section */
E09D166E2F44CA1E009C51A3 /* EXO-iOS */ = {
isa = PBXNativeTarget;
buildConfigurationList = E09D16902F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOS" */;
buildPhases = (
E09D166B2F44CA1E009C51A3 /* Sources */,
E09D166C2F44CA1E009C51A3 /* Frameworks */,
E09D166D2F44CA1E009C51A3 /* Resources */,
);
buildRules = (
);
dependencies = (
);
fileSystemSynchronizedGroups = (
E09D16712F44CA1E009C51A3 /* EXO-iOS */,
);
name = "EXO-iOS";
packageProductDependencies = (
E09D17512F44F359009C51A3 /* MLXLLM */,
E09D17532F44F359009C51A3 /* MLXLMCommon */,
);
productName = "EXO-iOS";
productReference = E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */;
productType = "com.apple.product-type.application";
};
E09D167B2F44CA20009C51A3 /* EXO-iOSTests */ = {
isa = PBXNativeTarget;
buildConfigurationList = E09D16932F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSTests" */;
buildPhases = (
E09D16782F44CA20009C51A3 /* Sources */,
E09D16792F44CA20009C51A3 /* Frameworks */,
E09D167A2F44CA20009C51A3 /* Resources */,
);
buildRules = (
);
dependencies = (
E09D167E2F44CA20009C51A3 /* PBXTargetDependency */,
);
fileSystemSynchronizedGroups = (
E09D167F2F44CA20009C51A3 /* EXO-iOSTests */,
);
name = "EXO-iOSTests";
packageProductDependencies = (
);
productName = "EXO-iOSTests";
productReference = E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */;
productType = "com.apple.product-type.bundle.unit-test";
};
E09D16852F44CA20009C51A3 /* EXO-iOSUITests */ = {
isa = PBXNativeTarget;
buildConfigurationList = E09D16962F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSUITests" */;
buildPhases = (
E09D16822F44CA20009C51A3 /* Sources */,
E09D16832F44CA20009C51A3 /* Frameworks */,
E09D16842F44CA20009C51A3 /* Resources */,
);
buildRules = (
);
dependencies = (
E09D16882F44CA20009C51A3 /* PBXTargetDependency */,
);
fileSystemSynchronizedGroups = (
E09D16892F44CA20009C51A3 /* EXO-iOSUITests */,
);
name = "EXO-iOSUITests";
packageProductDependencies = (
);
productName = "EXO-iOSUITests";
productReference = E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */;
productType = "com.apple.product-type.bundle.ui-testing";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
E09D16672F44CA1E009C51A3 /* Project object */ = {
isa = PBXProject;
attributes = {
BuildIndependentTargetsInParallel = 1;
LastSwiftUpdateCheck = 2620;
LastUpgradeCheck = 2620;
TargetAttributes = {
E09D166E2F44CA1E009C51A3 = {
CreatedOnToolsVersion = 26.2;
};
E09D167B2F44CA20009C51A3 = {
CreatedOnToolsVersion = 26.2;
TestTargetID = E09D166E2F44CA1E009C51A3;
};
E09D16852F44CA20009C51A3 = {
CreatedOnToolsVersion = 26.2;
TestTargetID = E09D166E2F44CA1E009C51A3;
};
};
};
buildConfigurationList = E09D166A2F44CA1E009C51A3 /* Build configuration list for PBXProject "EXO-iOS" */;
developmentRegion = en;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = E09D16662F44CA1E009C51A3;
minimizedProjectReferenceProxies = 1;
packageReferences = (
E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */,
);
preferredProjectObjectVersion = 77;
productRefGroup = E09D16702F44CA1E009C51A3 /* Products */;
projectDirPath = "";
projectRoot = "";
targets = (
E09D166E2F44CA1E009C51A3 /* EXO-iOS */,
E09D167B2F44CA20009C51A3 /* EXO-iOSTests */,
E09D16852F44CA20009C51A3 /* EXO-iOSUITests */,
);
};
/* End PBXProject section */
/* Begin PBXResourcesBuildPhase section */
E09D166D2F44CA1E009C51A3 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D167A2F44CA20009C51A3 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16842F44CA20009C51A3 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
E09D166B2F44CA1E009C51A3 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16782F44CA20009C51A3 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16822F44CA20009C51A3 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin PBXTargetDependency section */
E09D167E2F44CA20009C51A3 /* PBXTargetDependency */ = {
isa = PBXTargetDependency;
target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
targetProxy = E09D167D2F44CA20009C51A3 /* PBXContainerItemProxy */;
};
E09D16882F44CA20009C51A3 /* PBXTargetDependency */ = {
isa = PBXTargetDependency;
target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
targetProxy = E09D16872F44CA20009C51A3 /* PBXContainerItemProxy */;
};
/* End PBXTargetDependency section */
/* Begin XCBuildConfiguration section */
E09D168E2F44CA20009C51A3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = dwarf;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
};
name = Debug;
};
E09D168F2F44CA20009C51A3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
SDKROOT = iphoneos;
SWIFT_COMPILATION_MODE = wholemodule;
VALIDATE_PRODUCT = YES;
};
name = Release;
};
E09D16912F44CA20009C51A3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = 3M3M67U93M;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_FILE = "EXO-iOS/Info.plist";
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
);
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOS";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = YES;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_DEFAULT_ACTOR_ISOLATION = MainActor;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Debug;
};
E09D16922F44CA20009C51A3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = 3M3M67U93M;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_FILE = "EXO-iOS/Info.plist";
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
);
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOS";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = YES;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_DEFAULT_ACTOR_ISOLATION = MainActor;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Release;
};
E09D16942F44CA20009C51A3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUNDLE_LOADER = "$(TEST_HOST)";
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
GENERATE_INFOPLIST_FILE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSTests";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = NO;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_EMIT_LOC_STRINGS = NO;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
TEST_HOST = "$(BUILT_PRODUCTS_DIR)/EXO-iOS.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO-iOS";
};
name = Debug;
};
E09D16952F44CA20009C51A3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUNDLE_LOADER = "$(TEST_HOST)";
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
GENERATE_INFOPLIST_FILE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSTests";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = NO;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_EMIT_LOC_STRINGS = NO;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
TEST_HOST = "$(BUILT_PRODUCTS_DIR)/EXO-iOS.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO-iOS";
};
name = Release;
};
E09D16972F44CA20009C51A3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
GENERATE_INFOPLIST_FILE = YES;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSUITests";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = NO;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_EMIT_LOC_STRINGS = NO;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
TEST_TARGET_NAME = "EXO-iOS";
};
name = Debug;
};
E09D16982F44CA20009C51A3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
GENERATE_INFOPLIST_FILE = YES;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSUITests";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = NO;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_EMIT_LOC_STRINGS = NO;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
TEST_TARGET_NAME = "EXO-iOS";
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
E09D166A2F44CA1E009C51A3 /* Build configuration list for PBXProject "EXO-iOS" */ = {
isa = XCConfigurationList;
buildConfigurations = (
E09D168E2F44CA20009C51A3 /* Debug */,
E09D168F2F44CA20009C51A3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
E09D16902F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOS" */ = {
isa = XCConfigurationList;
buildConfigurations = (
E09D16912F44CA20009C51A3 /* Debug */,
E09D16922F44CA20009C51A3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
E09D16932F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSTests" */ = {
isa = XCConfigurationList;
buildConfigurations = (
E09D16942F44CA20009C51A3 /* Debug */,
E09D16952F44CA20009C51A3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
E09D16962F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSUITests" */ = {
isa = XCConfigurationList;
buildConfigurations = (
E09D16972F44CA20009C51A3 /* Debug */,
E09D16982F44CA20009C51A3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
/* Begin XCRemoteSwiftPackageReference section */
E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/ml-explore/mlx-swift-lm";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.30.3;
};
};
/* End XCRemoteSwiftPackageReference section */
/* Begin XCSwiftPackageProductDependency section */
E09D17512F44F359009C51A3 /* MLXLLM */ = {
isa = XCSwiftPackageProductDependency;
package = E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
productName = MLXLLM;
};
E09D17532F44F359009C51A3 /* MLXLMCommon */ = {
isa = XCSwiftPackageProductDependency;
package = E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
productName = MLXLMCommon;
};
/* End XCSwiftPackageProductDependency section */
};
rootObject = E09D16672F44CA1E009C51A3 /* Project object */;
}

View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
version = "1.0">
<FileRef
location = "self:">
</FileRef>
</Workspace>

View File

@@ -0,0 +1,60 @@
{
"originHash" : "facc0ac7c70363ea20f6cd1235de91dea6b06f0d00190946045a6c8ae753abc2",
"pins" : [
{
"identity" : "mlx-swift",
"kind" : "remoteSourceControl",
"location" : "https://github.com/ml-explore/mlx-swift",
"state" : {
"revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d",
"version" : "0.30.6"
}
},
{
"identity" : "mlx-swift-lm",
"kind" : "remoteSourceControl",
"location" : "https://github.com/ml-explore/mlx-swift-lm",
"state" : {
"revision" : "360c5052b81cc154b04ee0933597a4ad6db4b8ae",
"version" : "2.30.3"
}
},
{
"identity" : "swift-collections",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-collections.git",
"state" : {
"revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e",
"version" : "1.3.0"
}
},
{
"identity" : "swift-jinja",
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-jinja.git",
"state" : {
"revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0",
"version" : "2.3.1"
}
},
{
"identity" : "swift-numerics",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-numerics",
"state" : {
"revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2",
"version" : "1.1.1"
}
},
{
"identity" : "swift-transformers",
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers",
"state" : {
"revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0",
"version" : "1.1.6"
}
}
],
"version" : 3
}

View File

@@ -0,0 +1,20 @@
{
"colors" : [
{
"color" : {
"color-space" : "srgb",
"components" : {
"alpha" : "1.000",
"blue" : "0x00",
"green" : "0xD7",
"red" : "0xFF"
}
},
"idiom" : "universal"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

View File

@@ -0,0 +1,38 @@
{
"images" : [
{
"filename" : "AppIcon.png",
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
},
{
"appearances" : [
{
"appearance" : "luminosity",
"value" : "dark"
}
],
"filename" : "AppIcon.png",
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
},
{
"appearances" : [
{
"appearance" : "luminosity",
"value" : "tinted"
}
],
"filename" : "AppIcon.png",
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@@ -0,0 +1,6 @@
{
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@@ -0,0 +1,21 @@
{
"images" : [
{
"filename" : "exo-logo.png",
"idiom" : "universal",
"scale" : "1x"
},
{
"idiom" : "universal",
"scale" : "2x"
},
{
"idiom" : "universal",
"scale" : "3x"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>com.apple.developer.kernel.increased-memory-limit</key>
<true/>
</dict>
</plist>

View File

@@ -0,0 +1,67 @@
import SwiftUI
import UIKit
@main
struct EXO_iOSApp: App {
@State private var clusterService = ClusterService()
@State private var discoveryService = DiscoveryService()
@State private var localInferenceService = LocalInferenceService()
@State private var chatService: ChatService?
init() {
let darkGray = UIColor(red: 0x1F / 255.0, green: 0x1F / 255.0, blue: 0x1F / 255.0, alpha: 1)
let yellow = UIColor(red: 0xFF / 255.0, green: 0xD7 / 255.0, blue: 0x00 / 255.0, alpha: 1)
let navAppearance = UINavigationBarAppearance()
navAppearance.configureWithOpaqueBackground()
navAppearance.backgroundColor = darkGray
navAppearance.titleTextAttributes = [
.foregroundColor: yellow,
.font: UIFont.monospacedSystemFont(ofSize: 17, weight: .semibold),
]
navAppearance.largeTitleTextAttributes = [
.foregroundColor: yellow,
.font: UIFont.monospacedSystemFont(ofSize: 34, weight: .bold),
]
UINavigationBar.appearance().standardAppearance = navAppearance
UINavigationBar.appearance().compactAppearance = navAppearance
UINavigationBar.appearance().scrollEdgeAppearance = navAppearance
UINavigationBar.appearance().tintColor = yellow
}
var body: some Scene {
WindowGroup {
if let chatService {
RootView()
.environment(clusterService)
.environment(discoveryService)
.environment(chatService)
.environment(localInferenceService)
.preferredColorScheme(.dark)
.task {
await clusterService.attemptAutoReconnect()
discoveryService.startBrowsing()
await localInferenceService.prepareModel()
}
.onChange(of: discoveryService.discoveredClusters) { _, clusters in
guard !clusterService.isConnected,
case .disconnected = clusterService.connectionState,
let first = clusters.first
else { return }
Task {
await clusterService.connectToDiscoveredCluster(
first, using: discoveryService)
}
}
} else {
Color.exoBlack.onAppear {
chatService = ChatService(
clusterService: clusterService,
localInferenceService: localInferenceService
)
}
}
}
}
}

View File

@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>UIUserInterfaceStyle</key>
<string>Dark</string>
<key>CFBundleDisplayName</key>
<string>EXO</string>
<key>NSLocalNetworkUsageDescription</key>
<string>EXO needs local network access to connect to your EXO cluster.</string>
<key>NSBonjourServices</key>
<array>
<string>_exo._tcp</string>
<string>_p2p._tcp</string>
<string>_p2p._udp</string>
<string>_libp2p._udp</string>
</array>
</dict>
</plist>

View File

@@ -0,0 +1,129 @@
import Foundation
// MARK: - Request
struct ChatCompletionRequest: Encodable {
let model: String
let messages: [ChatCompletionMessageParam]
let stream: Bool
let maxTokens: Int?
let temperature: Double?
enum CodingKeys: String, CodingKey {
case model, messages, stream, temperature
case maxTokens = "max_tokens"
}
}
struct ChatCompletionMessageParam: Encodable {
let role: String
let content: String
}
// MARK: - Streaming Response
struct ChatCompletionChunk: Decodable {
let id: String
let model: String?
let choices: [StreamingChoice]
let usage: ChunkUsage?
init(id: String, model: String?, choices: [StreamingChoice], usage: ChunkUsage?) {
self.id = id
self.model = model
self.choices = choices
self.usage = usage
}
}
struct StreamingChoice: Decodable {
let index: Int
let delta: Delta
let finishReason: String?
enum CodingKeys: String, CodingKey {
case index, delta
case finishReason = "finish_reason"
}
init(index: Int, delta: Delta, finishReason: String?) {
self.index = index
self.delta = delta
self.finishReason = finishReason
}
}
struct Delta: Decodable {
let role: String?
let content: String?
init(role: String?, content: String?) {
self.role = role
self.content = content
}
}
struct ChunkUsage: Decodable {
let promptTokens: Int?
let completionTokens: Int?
let totalTokens: Int?
enum CodingKeys: String, CodingKey {
case promptTokens = "prompt_tokens"
case completionTokens = "completion_tokens"
case totalTokens = "total_tokens"
}
init(promptTokens: Int?, completionTokens: Int?, totalTokens: Int?) {
self.promptTokens = promptTokens
self.completionTokens = completionTokens
self.totalTokens = totalTokens
}
}
// MARK: - Non-Streaming Response
struct ChatCompletionResponse: Decodable {
let id: String
let model: String?
let choices: [ResponseChoice]
}
struct ResponseChoice: Decodable {
let index: Int
let message: ResponseMessage
let finishReason: String?
enum CodingKeys: String, CodingKey {
case index, message
case finishReason = "finish_reason"
}
}
struct ResponseMessage: Decodable {
let role: String?
let content: String?
}
// MARK: - Models List
struct ModelListResponse: Decodable {
let data: [ModelInfo]
}
struct ModelInfo: Decodable, Identifiable {
let id: String
let name: String?
}
// MARK: - Error
struct APIErrorResponse: Decodable {
let error: APIErrorInfo
}
struct APIErrorInfo: Decodable {
let message: String
let type: String?
let code: Int?
}

View File

@@ -0,0 +1,26 @@
import Foundation
struct ChatMessage: Identifiable, Equatable {
let id: UUID
let role: Role
var content: String
let timestamp: Date
var isStreaming: Bool
enum Role: String, Codable {
case user
case assistant
case system
}
init(
id: UUID = UUID(), role: Role, content: String, timestamp: Date = Date(),
isStreaming: Bool = false
) {
self.id = id
self.role = role
self.content = content
self.timestamp = timestamp
self.isStreaming = isStreaming
}
}

View File

@@ -0,0 +1,11 @@
import Foundation
struct ConnectionInfo: Codable, Equatable {
let host: String
let port: Int
let nodeId: String?
var baseURL: URL { URL(string: "http://\(host):\(port)")! }
static let defaultPort = 52415
}

View File

@@ -0,0 +1,34 @@
import Foundation
struct Conversation: Identifiable, Codable, Equatable {
let id: UUID
var title: String
var messages: [StoredMessage]
var modelId: String?
let createdAt: Date
init(
id: UUID = UUID(), title: String = "New Chat", messages: [StoredMessage] = [],
modelId: String? = nil, createdAt: Date = Date()
) {
self.id = id
self.title = title
self.messages = messages
self.modelId = modelId
self.createdAt = createdAt
}
}
struct StoredMessage: Identifiable, Codable, Equatable {
let id: UUID
let role: String
var content: String
let timestamp: Date
init(id: UUID = UUID(), role: String, content: String, timestamp: Date = Date()) {
self.id = id
self.role = role
self.content = content
self.timestamp = timestamp
}
}

View File

@@ -0,0 +1,227 @@
import Foundation
@Observable
@MainActor
final class ChatService {
var conversations: [Conversation] = []
var activeConversationId: UUID?
private(set) var isGenerating: Bool = false
private var currentGenerationTask: Task<Void, Never>?
private let clusterService: ClusterService
private let localInferenceService: LocalInferenceService
var canSendMessage: Bool {
clusterService.isConnected || localInferenceService.isAvailable
}
var activeConversation: Conversation? {
guard let id = activeConversationId else { return nil }
return conversations.first { $0.id == id }
}
var activeMessages: [ChatMessage] {
guard let conversation = activeConversation else { return [] }
return conversation.messages.map { stored in
ChatMessage(
id: stored.id,
role: ChatMessage.Role(rawValue: stored.role) ?? .user,
content: stored.content,
timestamp: stored.timestamp
)
}
}
init(clusterService: ClusterService, localInferenceService: LocalInferenceService) {
self.clusterService = clusterService
self.localInferenceService = localInferenceService
loadConversations()
}
// MARK: - Conversation Management
func createConversation(modelId: String? = nil) {
let conversation = Conversation(
modelId: modelId ?? clusterService.availableModels.first?.id)
conversations.insert(conversation, at: 0)
activeConversationId = conversation.id
saveConversations()
}
func deleteConversation(id: UUID) {
conversations.removeAll { $0.id == id }
if activeConversationId == id {
activeConversationId = conversations.first?.id
}
saveConversations()
}
func setActiveConversation(id: UUID) {
activeConversationId = id
}
func setModelForActiveConversation(_ modelId: String) {
guard let index = conversations.firstIndex(where: { $0.id == activeConversationId }) else {
return
}
conversations[index].modelId = modelId
saveConversations()
}
// MARK: - Messaging
func sendMessage(_ text: String) {
guard !text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { return }
if activeConversation == nil {
createConversation()
}
guard let index = conversations.firstIndex(where: { $0.id == activeConversationId }) else {
return
}
let userMessage = StoredMessage(role: "user", content: text)
conversations[index].messages.append(userMessage)
if conversations[index].title == "New Chat" {
let preview = String(text.prefix(40))
conversations[index].title = preview + (text.count > 40 ? "..." : "")
}
let modelId: String
if clusterService.isConnected {
guard
let clusterId = conversations[index].modelId
?? clusterService.availableModels.first?.id
else {
let errorMessage = StoredMessage(
role: "assistant", content: "No model selected. Please select a model first.")
conversations[index].messages.append(errorMessage)
saveConversations()
return
}
modelId = clusterId
} else if localInferenceService.isAvailable {
modelId = localInferenceService.defaultModelId
} else {
let errorMessage = StoredMessage(
role: "assistant",
content: "Not connected to a cluster and local model is not available.")
conversations[index].messages.append(errorMessage)
saveConversations()
return
}
conversations[index].modelId = modelId
let assistantMessageId = UUID()
let assistantMessage = StoredMessage(
id: assistantMessageId, role: "assistant", content: "", timestamp: Date())
conversations[index].messages.append(assistantMessage)
let messagesForAPI = conversations[index].messages.dropLast().map { stored in
ChatCompletionMessageParam(role: stored.role, content: stored.content)
}
let request = ChatCompletionRequest(
model: modelId,
messages: Array(messagesForAPI),
stream: true,
maxTokens: 4096,
temperature: nil
)
let conversationId = conversations[index].id
isGenerating = true
currentGenerationTask = Task { [weak self] in
guard let self else { return }
await self.performStreaming(
request: request, conversationId: conversationId,
assistantMessageId: assistantMessageId)
}
saveConversations()
}
func cancelGeneration() {
currentGenerationTask?.cancel()
currentGenerationTask = nil
localInferenceService.cancelGeneration()
isGenerating = false
}
// MARK: - Streaming
private func performStreaming(
request: ChatCompletionRequest, conversationId: UUID, assistantMessageId: UUID
) async {
defer {
isGenerating = false
currentGenerationTask = nil
saveConversations()
}
do {
let stream =
clusterService.isConnected
? clusterService.streamChatCompletion(request: request)
: localInferenceService.streamChatCompletion(request: request)
for try await chunk in stream {
guard !Task.isCancelled else { return }
guard let content = chunk.choices.first?.delta.content, !content.isEmpty else {
continue
}
if let convIndex = conversations.firstIndex(where: { $0.id == conversationId }),
let msgIndex = conversations[convIndex].messages.firstIndex(where: {
$0.id == assistantMessageId
})
{
conversations[convIndex].messages[msgIndex].content += content
}
}
} catch {
if !Task.isCancelled {
if let convIndex = conversations.firstIndex(where: { $0.id == conversationId }),
let msgIndex = conversations[convIndex].messages.firstIndex(where: {
$0.id == assistantMessageId
})
{
if conversations[convIndex].messages[msgIndex].content.isEmpty {
conversations[convIndex].messages[msgIndex].content =
"Error: \(error.localizedDescription)"
}
}
}
}
}
// MARK: - Persistence
private static var storageURL: URL {
let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)
.first!
return documents.appendingPathComponent("exo_conversations.json")
}
private func saveConversations() {
do {
let data = try JSONEncoder().encode(conversations)
try data.write(to: Self.storageURL, options: .atomic)
} catch {
// Save failed silently
}
}
private func loadConversations() {
do {
let data = try Data(contentsOf: Self.storageURL)
conversations = try JSONDecoder().decode([Conversation].self, from: data)
activeConversationId = conversations.first?.id
} catch {
conversations = []
}
}
}

View File

@@ -0,0 +1,200 @@
import Foundation
enum ConnectionState: Equatable {
case disconnected
case connecting
case connected(ConnectionInfo)
}
struct ModelOption: Identifiable, Equatable {
let id: String
let displayName: String
}
@Observable
@MainActor
final class ClusterService {
private(set) var connectionState: ConnectionState = .disconnected
private(set) var availableModels: [ModelOption] = []
private(set) var lastError: String?
private let session: URLSession
private let decoder: JSONDecoder
private var pollingTask: Task<Void, Never>?
private static let connectionInfoKey = "exo_last_connection_info"
var isConnected: Bool {
if case .connected = connectionState { return true }
return false
}
var currentConnection: ConnectionInfo? {
if case .connected(let info) = connectionState { return info }
return nil
}
init(session: URLSession = .shared) {
self.session = session
let decoder = JSONDecoder()
self.decoder = decoder
}
// MARK: - Connection
func connect(to info: ConnectionInfo) async {
connectionState = .connecting
lastError = nil
do {
let url = info.baseURL.appendingPathComponent("node_id")
var request = URLRequest(url: url)
request.timeoutInterval = 5
request.cachePolicy = .reloadIgnoringLocalCacheData
let (_, response) = try await session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else {
throw URLError(.badServerResponse)
}
connectionState = .connected(info)
persistConnection(info)
startPolling()
await fetchModels(baseURL: info.baseURL)
} catch {
connectionState = .disconnected
lastError = "Could not connect to \(info.host):\(info.port)"
}
}
func connectToDiscoveredCluster(
_ cluster: DiscoveredCluster, using discoveryService: DiscoveryService
) async {
guard case .disconnected = connectionState else { return }
connectionState = .connecting
lastError = nil
guard let info = await discoveryService.resolve(cluster) else {
connectionState = .disconnected
lastError = "Could not resolve \(cluster.name)"
return
}
connectionState = .disconnected // reset so connect() can proceed
await connect(to: info)
}
func disconnect() {
stopPolling()
connectionState = .disconnected
availableModels = []
lastError = nil
}
func attemptAutoReconnect() async {
guard case .disconnected = connectionState,
let info = loadPersistedConnection()
else { return }
await connect(to: info)
}
// MARK: - Polling
private func startPolling(interval: TimeInterval = 2.0) {
stopPolling()
pollingTask = Task { [weak self] in
while !Task.isCancelled {
try? await Task.sleep(for: .seconds(interval))
guard let self, !Task.isCancelled else { return }
guard let connection = self.currentConnection else { return }
await self.fetchModels(baseURL: connection.baseURL)
}
}
}
private func stopPolling() {
pollingTask?.cancel()
pollingTask = nil
}
// MARK: - API
private func fetchModels(baseURL: URL) async {
do {
let url = baseURL.appendingPathComponent("models")
var request = URLRequest(url: url)
request.cachePolicy = .reloadIgnoringLocalCacheData
let (data, response) = try await session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else { return }
let list = try decoder.decode(ModelListResponse.self, from: data)
availableModels = list.data.map {
ModelOption(id: $0.id, displayName: $0.name ?? $0.id)
}
} catch {
// Models fetch failed silently will retry on next poll
}
}
func streamChatCompletion(request body: ChatCompletionRequest) -> AsyncThrowingStream<
ChatCompletionChunk, Error
> {
AsyncThrowingStream { continuation in
let task = Task { [weak self] in
guard let self, let connection = self.currentConnection else {
continuation.finish(throwing: URLError(.notConnectedToInternet))
return
}
do {
let url = connection.baseURL.appendingPathComponent("v1/chat/completions")
var request = URLRequest(url: url)
request.httpMethod = "POST"
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
request.httpBody = try JSONEncoder().encode(body)
let (bytes, response) = try await self.session.bytes(for: request)
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else {
continuation.finish(throwing: URLError(.badServerResponse))
return
}
let parser = SSEStreamParser<ChatCompletionChunk>(
bytes: bytes, decoder: self.decoder)
for try await chunk in parser {
continuation.yield(chunk)
}
continuation.finish()
} catch {
continuation.finish(throwing: error)
}
}
continuation.onTermination = { _ in
task.cancel()
}
}
}
// MARK: - Persistence
private func persistConnection(_ info: ConnectionInfo) {
if let data = try? JSONEncoder().encode(info) {
UserDefaults.standard.set(data, forKey: Self.connectionInfoKey)
}
}
private func loadPersistedConnection() -> ConnectionInfo? {
guard let data = UserDefaults.standard.data(forKey: Self.connectionInfoKey) else {
return nil
}
return try? JSONDecoder().decode(ConnectionInfo.self, from: data)
}
}

View File

@@ -0,0 +1,123 @@
import Foundation
import Network
import os
struct DiscoveredCluster: Identifiable, Equatable {
let id: String
let name: String
let endpoint: NWEndpoint
static func == (lhs: DiscoveredCluster, rhs: DiscoveredCluster) -> Bool {
lhs.id == rhs.id && lhs.name == rhs.name
}
}
@Observable
@MainActor
final class DiscoveryService {
private(set) var discoveredClusters: [DiscoveredCluster] = []
private(set) var isSearching = false
private var browser: NWBrowser?
func startBrowsing() {
guard browser == nil else { return }
let browser = NWBrowser(for: .bonjour(type: "_exo._tcp", domain: nil), using: .tcp)
browser.stateUpdateHandler = { [weak self] state in
guard let service = self else { return }
Task { @MainActor in
switch state {
case .ready:
service.isSearching = true
case .failed, .cancelled:
service.isSearching = false
default:
break
}
}
}
browser.browseResultsChangedHandler = { [weak self] results, _ in
guard let service = self else { return }
Task { @MainActor in
service.discoveredClusters = results.compactMap { result in
guard case .service(let name, _, _, _) = result.endpoint else {
return nil
}
return DiscoveredCluster(
id: name,
name: name,
endpoint: result.endpoint
)
}
}
}
browser.start(queue: .main)
self.browser = browser
}
func stopBrowsing() {
browser?.cancel()
browser = nil
isSearching = false
discoveredClusters = []
}
/// Resolve a discovered Bonjour endpoint to an IP address and port, then return a ConnectionInfo.
func resolve(_ cluster: DiscoveredCluster) async -> ConnectionInfo? {
await withCheckedContinuation { continuation in
let didResume = OSAllocatedUnfairLock(initialState: false)
let connection = NWConnection(to: cluster.endpoint, using: .tcp)
connection.stateUpdateHandler = { state in
guard
didResume.withLock({
guard !$0 else { return false }
$0 = true
return true
})
else { return }
switch state {
case .ready:
if let innerEndpoint = connection.currentPath?.remoteEndpoint,
case .hostPort(let host, let port) = innerEndpoint
{
var hostString: String
switch host {
case .ipv4(let addr):
hostString = "\(addr)"
case .ipv6(let addr):
hostString = "\(addr)"
case .name(let name, _):
hostString = name
@unknown default:
hostString = "\(host)"
}
// Strip interface scope suffix (e.g. "%en0")
if let pct = hostString.firstIndex(of: "%") {
hostString = String(hostString[..<pct])
}
let info = ConnectionInfo(
host: hostString,
port: Int(port.rawValue),
nodeId: nil
)
connection.cancel()
continuation.resume(returning: info)
} else {
connection.cancel()
continuation.resume(returning: nil)
}
case .failed, .cancelled:
continuation.resume(returning: nil)
default:
// Not a terminal state allow future callbacks
didResume.withLock { $0 = false }
}
}
connection.start(queue: .global(qos: .userInitiated))
}
}
}

View File

@@ -0,0 +1,201 @@
import Foundation
import MLXLLM
import MLXLMCommon
enum LocalModelState: Equatable {
case notDownloaded
case downloading(progress: Double)
case downloaded
case loading
case ready
case generating
case error(String)
}
@Observable
@MainActor
final class LocalInferenceService {
private(set) var modelState: LocalModelState = .notDownloaded
private var modelContainer: ModelContainer?
private var generationTask: Task<Void, Never>?
let defaultModelId = "mlx-community/Qwen3-0.6B-4bit"
private static let modelDownloadedKey = "exo_local_model_downloaded"
var isReady: Bool {
modelState == .ready
}
var isAvailable: Bool {
modelState == .ready || modelState == .generating
}
init() {
if UserDefaults.standard.bool(forKey: Self.modelDownloadedKey) {
modelState = .downloaded
}
}
// MARK: - Model Lifecycle
func prepareModel() async {
guard modelState == .notDownloaded || modelState == .downloaded else { return }
let wasDownloaded = modelState == .downloaded
if !wasDownloaded {
modelState = .downloading(progress: 0)
} else {
modelState = .loading
}
do {
let container = try await loadModelContainer(
id: defaultModelId
) { [weak self] progress in
guard let self else { return }
Task { @MainActor in
if case .downloading = self.modelState {
self.modelState = .downloading(progress: progress.fractionCompleted)
}
}
}
self.modelContainer = container
UserDefaults.standard.set(true, forKey: Self.modelDownloadedKey)
modelState = .ready
} catch {
modelState = .error(error.localizedDescription)
}
}
func unloadModel() {
cancelGeneration()
modelContainer = nil
modelState = .downloaded
}
// MARK: - Generation
func streamChatCompletion(request: ChatCompletionRequest) -> AsyncThrowingStream<
ChatCompletionChunk, Error
> {
AsyncThrowingStream { continuation in
let task = Task { [weak self] in
guard let self else {
continuation.finish(throwing: LocalInferenceError.serviceUnavailable)
return
}
guard let container = self.modelContainer else {
continuation.finish(throwing: LocalInferenceError.modelNotLoaded)
return
}
await MainActor.run {
self.modelState = .generating
}
defer {
Task { @MainActor [weak self] in
if self?.modelState == .generating {
self?.modelState = .ready
}
}
}
let chunkId = "local-\(UUID().uuidString)"
do {
// Build Chat.Message array from the request
var chatMessages: [Chat.Message] = []
for msg in request.messages {
switch msg.role {
case "system":
chatMessages.append(.system(msg.content))
case "assistant":
chatMessages.append(.assistant(msg.content))
default:
chatMessages.append(.user(msg.content))
}
}
// Use ChatSession for streaming generation
let session = ChatSession(
container,
history: chatMessages,
generateParameters: GenerateParameters(
maxTokens: request.maxTokens ?? 4096,
temperature: Float(request.temperature ?? 0.7)
)
)
// Stream with an empty prompt since history already contains the conversation
let stream = session.streamResponse(to: "")
for try await text in stream {
if Task.isCancelled { break }
let chunk = ChatCompletionChunk(
id: chunkId,
model: request.model,
choices: [
StreamingChoice(
index: 0,
delta: Delta(role: nil, content: text),
finishReason: nil
)
],
usage: nil
)
continuation.yield(chunk)
}
// Send final chunk with finish reason
let finalChunk = ChatCompletionChunk(
id: chunkId,
model: request.model,
choices: [
StreamingChoice(
index: 0,
delta: Delta(role: nil, content: nil),
finishReason: "stop"
)
],
usage: nil
)
continuation.yield(finalChunk)
continuation.finish()
} catch {
continuation.finish(throwing: error)
}
}
self.generationTask = task
continuation.onTermination = { _ in
task.cancel()
}
}
}
func cancelGeneration() {
generationTask?.cancel()
generationTask = nil
if modelState == .generating {
modelState = .ready
}
}
}
enum LocalInferenceError: LocalizedError {
case serviceUnavailable
case modelNotLoaded
var errorDescription: String? {
switch self {
case .serviceUnavailable: "Local inference service is unavailable"
case .modelNotLoaded: "Local model is not loaded"
}
}
}

View File

@@ -0,0 +1,50 @@
import Foundation
struct SSEStreamParser<T: Decodable>: AsyncSequence {
typealias Element = T
let bytes: URLSession.AsyncBytes
let decoder: JSONDecoder
init(bytes: URLSession.AsyncBytes, decoder: JSONDecoder = JSONDecoder()) {
self.bytes = bytes
self.decoder = decoder
}
func makeAsyncIterator() -> AsyncIterator {
AsyncIterator(lines: bytes.lines, decoder: decoder)
}
struct AsyncIterator: AsyncIteratorProtocol {
var lines: AsyncLineSequence<URLSession.AsyncBytes>.AsyncIterator
let decoder: JSONDecoder
init(lines: AsyncLineSequence<URLSession.AsyncBytes>, decoder: JSONDecoder) {
self.lines = lines.makeAsyncIterator()
self.decoder = decoder
}
mutating func next() async throws -> T? {
while let line = try await lines.next() {
let trimmed = line.trimmingCharacters(in: .whitespaces)
guard trimmed.hasPrefix("data: ") else { continue }
let payload = String(trimmed.dropFirst(6))
if payload == "[DONE]" {
return nil
}
guard let data = payload.data(using: .utf8) else { continue }
do {
return try decoder.decode(T.self, from: data)
} catch {
continue
}
}
return nil
}
}
}

View File

@@ -0,0 +1,203 @@
import SwiftUI
struct ChatView: View {
@Environment(ClusterService.self) private var clusterService
@Environment(ChatService.self) private var chatService
@Environment(LocalInferenceService.self) private var localInferenceService
@State private var inputText = ""
@State private var showModelSelector = false
var body: some View {
VStack(spacing: 0) {
modelBar
GradientDivider()
messageList
GradientDivider()
inputBar
}
.background(Color.exoBlack)
.sheet(isPresented: $showModelSelector) {
ModelSelectorView(
models: clusterService.availableModels,
selectedModelId: chatService.activeConversation?.modelId
) { modelId in
chatService.setModelForActiveConversation(modelId)
}
.presentationBackground(Color.exoDarkGray)
}
}
// MARK: - Model Bar
private var useLocalModel: Bool {
!clusterService.isConnected && localInferenceService.isAvailable
}
private var modelBar: some View {
Button {
if !useLocalModel {
showModelSelector = true
}
} label: {
HStack {
Image(systemName: useLocalModel ? "iphone" : "cpu")
.font(.exoCaption)
.foregroundStyle(useLocalModel ? Color.exoYellow : Color.exoLightGray)
if useLocalModel {
Text(localInferenceService.defaultModelId)
.font(.exoSubheadline)
.foregroundStyle(Color.exoForeground)
.lineLimit(1)
} else if let modelId = chatService.activeConversation?.modelId {
Text(modelId)
.font(.exoSubheadline)
.foregroundStyle(Color.exoForeground)
.lineLimit(1)
} else {
Text("SELECT MODEL")
.font(.exoSubheadline)
.tracking(1.5)
.foregroundStyle(Color.exoLightGray)
}
Spacer()
if useLocalModel {
Text("ON-DEVICE")
.font(.exoCaption)
.tracking(1)
.foregroundStyle(Color.exoYellow)
.padding(.horizontal, 6)
.padding(.vertical, 2)
.background(Color.exoYellow.opacity(0.15))
.clipShape(Capsule())
} else {
Image(systemName: "chevron.right")
.font(.caption)
.foregroundStyle(Color.exoLightGray)
}
}
.padding(.horizontal)
.padding(.vertical, 10)
.background(Color.exoDarkGray)
}
.tint(.primary)
.disabled(useLocalModel)
}
// MARK: - Messages
private var messageList: some View {
ScrollViewReader { proxy in
ScrollView {
LazyVStack(spacing: 12) {
if chatService.activeMessages.isEmpty {
emptyState
} else {
ForEach(chatService.activeMessages) { message in
MessageBubbleView(message: message)
.id(message.id)
}
}
}
.padding()
}
.background(Color.exoBlack)
.onChange(of: chatService.activeMessages.last?.content) {
if let lastId = chatService.activeMessages.last?.id {
withAnimation(.easeOut(duration: 0.2)) {
proxy.scrollTo(lastId, anchor: .bottom)
}
}
}
}
}
private var emptyState: some View {
VStack(spacing: 16) {
Spacer(minLength: 80)
ZStack {
Circle()
.stroke(Color.exoYellow.opacity(0.15), lineWidth: 1)
.frame(width: 80, height: 80)
Circle()
.stroke(Color.exoYellow.opacity(0.3), lineWidth: 1)
.frame(width: 56, height: 56)
Circle()
.fill(Color.exoYellow.opacity(0.15))
.frame(width: 32, height: 32)
Circle()
.fill(Color.exoYellow)
.frame(width: 8, height: 8)
.shadow(color: Color.exoYellow.opacity(0.6), radius: 6)
}
Text("AWAITING INPUT")
.font(.exoSubheadline)
.tracking(3)
.foregroundStyle(Color.exoLightGray)
Text("Send a message to begin.")
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray.opacity(0.6))
Spacer(minLength: 80)
}
.padding()
}
// MARK: - Input
private var inputBar: some View {
HStack(alignment: .bottom, spacing: 8) {
TextField("Message...", text: $inputText, axis: .vertical)
.font(.exoBody)
.lineLimit(1...6)
.textFieldStyle(.plain)
.padding(10)
.background(Color.exoMediumGray)
.foregroundStyle(Color.exoForeground)
.clipShape(RoundedRectangle(cornerRadius: 8))
if chatService.isGenerating {
Button {
chatService.cancelGeneration()
} label: {
Image(systemName: "stop.circle.fill")
.font(.title2)
.foregroundStyle(Color.exoDestructive)
}
} else {
Button {
let text = inputText
inputText = ""
chatService.sendMessage(text)
} label: {
Text("SEND")
.font(.exoMono(12, weight: .bold))
.tracking(1)
.foregroundStyle(canSend ? Color.exoBlack : Color.exoLightGray)
.padding(.horizontal, 14)
.padding(.vertical, 8)
.background(canSend ? Color.exoYellow : Color.exoMediumGray)
.clipShape(RoundedRectangle(cornerRadius: 8))
}
.disabled(!canSend)
}
}
.padding(.horizontal)
.padding(.vertical, 8)
.background(Color.exoDarkGray)
}
private var canSend: Bool {
!inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
&& (clusterService.isConnected || localInferenceService.isAvailable)
}
}

View File

@@ -0,0 +1,54 @@
import SwiftUI
struct MessageBubbleView: View {
let message: ChatMessage
private var isAssistant: Bool { message.role == .assistant }
var body: some View {
HStack {
if message.role == .user { Spacer(minLength: 48) }
VStack(alignment: isAssistant ? .leading : .trailing, spacing: 6) {
// Header
HStack(spacing: 4) {
if isAssistant {
Circle()
.fill(Color.exoYellow)
.frame(width: 6, height: 6)
.shadow(color: Color.exoYellow.opacity(0.6), radius: 4)
Text("EXO")
.font(.exoMono(10, weight: .bold))
.tracking(1.5)
.foregroundStyle(Color.exoYellow)
} else {
Text("QUERY")
.font(.exoMono(10, weight: .medium))
.tracking(1.5)
.foregroundStyle(Color.exoLightGray)
}
}
// Bubble
HStack(spacing: 0) {
if isAssistant {
RoundedRectangle(cornerRadius: 1)
.fill(Color.exoYellow.opacity(0.5))
.frame(width: 2)
}
Text(message.content + (message.isStreaming ? " \u{258C}" : ""))
.font(.exoBody)
.textSelection(.enabled)
.foregroundStyle(Color.exoForeground)
.padding(.horizontal, 14)
.padding(.vertical, 10)
}
.background(Color.exoDarkGray)
.clipShape(RoundedRectangle(cornerRadius: 8))
}
if isAssistant { Spacer(minLength: 48) }
}
}
}

View File

@@ -0,0 +1,75 @@
import SwiftUI
struct ModelSelectorView: View {
let models: [ModelOption]
let selectedModelId: String?
let onSelect: (String) -> Void
@Environment(\.dismiss) private var dismiss
var body: some View {
NavigationStack {
List {
if models.isEmpty {
emptyContent
} else {
modelsList
}
}
.scrollContentBackground(.hidden)
.background(Color.exoBlack)
.navigationTitle("SELECT MODEL")
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .cancellationAction) {
Button("Cancel") { dismiss() }
.font(.exoSubheadline)
.foregroundStyle(Color.exoYellow)
}
}
}
}
private var emptyContent: some View {
ContentUnavailableView(
"No Models Available",
systemImage: "cpu",
description: Text("Connect to an EXO cluster to see available models.")
)
.foregroundStyle(Color.exoLightGray)
.listRowBackground(Color.exoBlack)
}
private var modelsList: some View {
ForEach(models) { model in
Button {
onSelect(model.id)
dismiss()
} label: {
modelRow(model)
}
.tint(.primary)
.listRowBackground(Color.exoDarkGray)
}
}
private func modelRow(_ model: ModelOption) -> some View {
HStack {
VStack(alignment: .leading, spacing: 2) {
Text(model.displayName)
.font(.exoSubheadline)
.foregroundStyle(Color.exoForeground)
Text(model.id)
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
}
Spacer()
if model.id == selectedModelId {
Image(systemName: "checkmark")
.font(.exoSubheadline)
.foregroundStyle(Color.exoYellow)
}
}
}
}

View File

@@ -0,0 +1,68 @@
import SwiftUI
struct ConnectionStatusBadge: View {
let connectionState: ConnectionState
var localModelState: LocalModelState = .notDownloaded
private var isLocalReady: Bool {
if case .disconnected = connectionState {
return localModelState == .ready || localModelState == .generating
}
return false
}
var body: some View {
HStack(spacing: 6) {
Circle()
.fill(dotColor)
.frame(width: 8, height: 8)
.shadow(color: dotColor.opacity(0.6), radius: 4)
Text(label.uppercased())
.font(.exoMono(10, weight: .medium))
.tracking(1)
.foregroundStyle(Color.exoForeground)
}
.padding(.horizontal, 10)
.padding(.vertical, 5)
.background(backgroundColor)
.clipShape(Capsule())
.overlay(
Capsule()
.stroke(dotColor.opacity(0.3), lineWidth: 1)
)
}
private var dotColor: Color {
if isLocalReady {
return .exoYellow
}
switch connectionState {
case .connected: return .green
case .connecting: return .orange
case .disconnected: return .exoLightGray
}
}
private var label: String {
if isLocalReady {
return "Local"
}
switch connectionState {
case .connected: return "Connected"
case .connecting: return "Connecting"
case .disconnected: return "Disconnected"
}
}
private var backgroundColor: Color {
if isLocalReady {
return Color.exoYellow.opacity(0.1)
}
switch connectionState {
case .connected: return .green.opacity(0.1)
case .connecting: return .orange.opacity(0.1)
case .disconnected: return Color.exoMediumGray.opacity(0.5)
}
}
}

View File

@@ -0,0 +1,136 @@
import SwiftUI
struct RootView: View {
@Environment(ClusterService.self) private var clusterService
@Environment(DiscoveryService.self) private var discoveryService
@Environment(ChatService.self) private var chatService
@Environment(LocalInferenceService.self) private var localInferenceService
@State private var showSettings = false
@State private var showConversations = false
var body: some View {
NavigationStack {
ChatView()
.navigationTitle("EXO")
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .topBarLeading) {
conversationMenuButton
}
ToolbarItem(placement: .principal) {
ConnectionStatusBadge(
connectionState: clusterService.connectionState,
localModelState: localInferenceService.modelState
)
}
ToolbarItem(placement: .topBarTrailing) {
Button {
showSettings = true
} label: {
Image(systemName: "gear")
.foregroundStyle(Color.exoYellow)
}
}
}
}
.tint(Color.exoYellow)
.sheet(isPresented: $showSettings) {
SettingsView()
.environment(discoveryService)
.presentationBackground(Color.exoDarkGray)
}
.sheet(isPresented: $showConversations) {
conversationList
.presentationBackground(Color.exoDarkGray)
}
}
// MARK: - Conversations
private var conversationMenuButton: some View {
HStack(spacing: 12) {
Button {
showConversations = true
} label: {
Image(systemName: "sidebar.left")
.foregroundStyle(Color.exoYellow)
}
Button {
chatService.createConversation()
} label: {
Image(systemName: "square.and.pencil")
.foregroundStyle(Color.exoYellow)
}
}
}
private var conversationList: some View {
NavigationStack {
List {
if chatService.conversations.isEmpty {
Text("No conversations yet")
.font(.exoBody)
.foregroundStyle(Color.exoLightGray)
.listRowBackground(Color.exoDarkGray)
} else {
ForEach(chatService.conversations) { conversation in
let isActive = conversation.id == chatService.activeConversationId
Button {
chatService.setActiveConversation(id: conversation.id)
showConversations = false
} label: {
VStack(alignment: .leading, spacing: 4) {
Text(conversation.title)
.font(.exoSubheadline)
.fontWeight(isActive ? .semibold : .regular)
.foregroundStyle(
isActive ? Color.exoYellow : Color.exoForeground
)
.lineLimit(1)
if let modelId = conversation.modelId {
Text(modelId)
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
.lineLimit(1)
}
}
}
.listRowBackground(
isActive
? Color.exoYellow.opacity(0.1)
: Color.exoDarkGray
)
}
.onDelete { indexSet in
for index in indexSet {
chatService.deleteConversation(id: chatService.conversations[index].id)
}
}
}
}
.scrollContentBackground(.hidden)
.background(Color.exoBlack)
.navigationTitle("Conversations")
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .confirmationAction) {
Button("Done") { showConversations = false }
.font(.exoSubheadline)
.foregroundStyle(Color.exoYellow)
}
ToolbarItem(placement: .topBarLeading) {
Button {
chatService.createConversation()
} label: {
Image(systemName: "plus")
.foregroundStyle(Color.exoYellow)
}
}
}
}
}
}

View File

@@ -0,0 +1,314 @@
import SwiftUI
struct SettingsView: View {
@Environment(ClusterService.self) private var clusterService
@Environment(DiscoveryService.self) private var discoveryService
@Environment(LocalInferenceService.self) private var localInferenceService
@Environment(\.dismiss) private var dismiss
@State private var host: String = ""
@State private var port: String = "52415"
var body: some View {
NavigationStack {
Form {
localModelSection
nearbyClustersSection
connectionSection
statusSection
}
.scrollContentBackground(.hidden)
.background(Color.exoBlack)
.navigationTitle("Settings")
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .confirmationAction) {
Button("Done") { dismiss() }
.font(.exoSubheadline)
.foregroundStyle(Color.exoYellow)
}
}
}
}
// MARK: - Section Headers
private func sectionHeader(_ title: String) -> some View {
Text(title.uppercased())
.font(.exoMono(10, weight: .semibold))
.tracking(2)
.foregroundStyle(Color.exoYellow)
}
// MARK: - Local Model
private var localModelSection: some View {
Section {
HStack {
VStack(alignment: .leading, spacing: 4) {
Text(localInferenceService.defaultModelId)
.font(.exoSubheadline)
.foregroundStyle(Color.exoForeground)
Text(localModelStatusText)
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
}
Spacer()
localModelActionButton
}
.listRowBackground(Color.exoDarkGray)
if case .downloading(let progress) = localInferenceService.modelState {
ProgressView(value: progress)
.tint(Color.exoYellow)
.listRowBackground(Color.exoDarkGray)
}
} header: {
sectionHeader("Local Model")
} footer: {
Text(
"When disconnected from a cluster, messages are processed on-device using this model."
)
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray.opacity(0.7))
}
}
private var localModelStatusText: String {
switch localInferenceService.modelState {
case .notDownloaded: "Not downloaded"
case .downloading(let progress): "Downloading \(Int(progress * 100))%..."
case .downloaded: "Downloaded — not loaded"
case .loading: "Loading into memory..."
case .ready: "Ready"
case .generating: "Generating..."
case .error(let message): "Error: \(message)"
}
}
@ViewBuilder
private var localModelActionButton: some View {
switch localInferenceService.modelState {
case .notDownloaded:
exoButton("Download") {
Task { await localInferenceService.prepareModel() }
}
case .downloading:
ProgressView()
.controlSize(.small)
.tint(Color.exoYellow)
case .downloaded:
exoButton("Load") {
Task { await localInferenceService.prepareModel() }
}
case .loading:
ProgressView()
.controlSize(.small)
.tint(Color.exoYellow)
case .ready, .generating:
exoButton("Unload") {
localInferenceService.unloadModel()
}
case .error:
exoButton("Retry", destructive: true) {
Task { await localInferenceService.prepareModel() }
}
}
}
private func exoButton(_ title: String, destructive: Bool = false, action: @escaping () -> Void)
-> some View
{
let borderColor = destructive ? Color.exoDestructive : Color.exoYellow
return Button(action: action) {
Text(title.uppercased())
.font(.exoMono(11, weight: .semibold))
.tracking(1)
.foregroundStyle(borderColor)
.padding(.horizontal, 10)
.padding(.vertical, 5)
.overlay(
RoundedRectangle(cornerRadius: 6)
.stroke(borderColor, lineWidth: 1)
)
}
}
// MARK: - Nearby Clusters
private var nearbyClustersSection: some View {
Section {
if discoveryService.discoveredClusters.isEmpty {
if discoveryService.isSearching {
HStack {
ProgressView()
.tint(Color.exoYellow)
.padding(.trailing, 8)
Text("Searching for clusters...")
.font(.exoBody)
.foregroundStyle(Color.exoLightGray)
}
.listRowBackground(Color.exoDarkGray)
} else {
Text("No clusters found")
.font(.exoBody)
.foregroundStyle(Color.exoLightGray)
.listRowBackground(Color.exoDarkGray)
}
} else {
ForEach(discoveryService.discoveredClusters) { cluster in
HStack {
VStack(alignment: .leading) {
Text(cluster.name)
.font(.exoBody)
.foregroundStyle(Color.exoForeground)
}
Spacer()
exoButton("Connect") {
Task {
await clusterService.connectToDiscoveredCluster(
cluster, using: discoveryService
)
if clusterService.isConnected {
dismiss()
}
}
}
}
.listRowBackground(Color.exoDarkGray)
}
}
} header: {
sectionHeader("Nearby Clusters")
}
}
// MARK: - Manual Connection
private var connectionSection: some View {
Section {
TextField("IP Address (e.g. 192.168.1.42)", text: $host)
.font(.exoBody)
.keyboardType(.decimalPad)
.textContentType(.URL)
.autocorrectionDisabled()
.foregroundStyle(Color.exoForeground)
.listRowBackground(Color.exoDarkGray)
TextField("Port", text: $port)
.font(.exoBody)
.keyboardType(.numberPad)
.foregroundStyle(Color.exoForeground)
.listRowBackground(Color.exoDarkGray)
Button {
Task {
let portNum = Int(port) ?? ConnectionInfo.defaultPort
let info = ConnectionInfo(host: host, port: portNum, nodeId: nil)
await clusterService.connect(to: info)
if clusterService.isConnected {
dismiss()
}
}
} label: {
Text(clusterService.isConnected ? "RECONNECT" : "CONNECT")
.font(.exoMono(13, weight: .semibold))
.tracking(1.5)
.foregroundStyle(
host.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
? Color.exoLightGray : Color.exoYellow
)
.frame(maxWidth: .infinity)
}
.disabled(host.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
.listRowBackground(Color.exoDarkGray)
} header: {
sectionHeader("Manual Connection")
}
}
// MARK: - Status
private var statusSection: some View {
Section {
if let connection = clusterService.currentConnection {
LabeledContent {
Text(connection.host)
.font(.exoCaption)
.foregroundStyle(Color.exoForeground)
} label: {
Text("Host")
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
}
.listRowBackground(Color.exoDarkGray)
LabeledContent {
Text("\(connection.port)")
.font(.exoCaption)
.foregroundStyle(Color.exoForeground)
} label: {
Text("Port")
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
}
.listRowBackground(Color.exoDarkGray)
if let nodeId = connection.nodeId {
LabeledContent {
Text(String(nodeId.prefix(12)) + "...")
.font(.exoCaption)
.foregroundStyle(Color.exoForeground)
} label: {
Text("Node ID")
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
}
.listRowBackground(Color.exoDarkGray)
}
LabeledContent {
Text("\(clusterService.availableModels.count)")
.font(.exoCaption)
.foregroundStyle(Color.exoForeground)
} label: {
Text("Models")
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
}
.listRowBackground(Color.exoDarkGray)
Button(role: .destructive) {
clusterService.disconnect()
} label: {
Text("DISCONNECT")
.font(.exoMono(13, weight: .semibold))
.tracking(1.5)
.foregroundStyle(Color.exoDestructive)
.frame(maxWidth: .infinity)
}
.listRowBackground(Color.exoDarkGray)
} else {
if let error = clusterService.lastError {
Label {
Text(error)
.font(.exoCaption)
} icon: {
Image(systemName: "exclamationmark.triangle")
}
.foregroundStyle(Color.exoDestructive)
.listRowBackground(Color.exoDarkGray)
} else {
Text("Not connected")
.font(.exoBody)
.foregroundStyle(Color.exoLightGray)
.listRowBackground(Color.exoDarkGray)
}
}
} header: {
sectionHeader("Status")
}
}
}

View File

@@ -0,0 +1,51 @@
import SwiftUI
// MARK: - EXO Color Palette
extension Color {
/// Primary background near-black (#121212)
static let exoBlack = Color(red: 0x12 / 255.0, green: 0x12 / 255.0, blue: 0x12 / 255.0)
/// Card / surface background (#1F1F1F)
static let exoDarkGray = Color(red: 0x1F / 255.0, green: 0x1F / 255.0, blue: 0x1F / 255.0)
/// Input field / elevated surface (#353535)
static let exoMediumGray = Color(red: 0x35 / 255.0, green: 0x35 / 255.0, blue: 0x35 / 255.0)
/// Secondary text (#999999)
static let exoLightGray = Color(red: 0x99 / 255.0, green: 0x99 / 255.0, blue: 0x99 / 255.0)
/// Accent yellow matches dashboard (#FFD700)
static let exoYellow = Color(red: 0xFF / 255.0, green: 0xD7 / 255.0, blue: 0x00 / 255.0)
/// Primary foreground text (#E5E5E5)
static let exoForeground = Color(red: 0xE5 / 255.0, green: 0xE5 / 255.0, blue: 0xE5 / 255.0)
/// Destructive / error (#E74C3C)
static let exoDestructive = Color(red: 0xE7 / 255.0, green: 0x4C / 255.0, blue: 0x3C / 255.0)
}
// MARK: - EXO Typography (SF Mono via .monospaced design)
extension Font {
/// Monospaced font at a given size and weight.
static func exoMono(_ size: CGFloat, weight: Font.Weight = .regular) -> Font {
.system(size: size, weight: weight, design: .monospaced)
}
/// Body text 15pt monospaced
static let exoBody: Font = .system(size: 15, weight: .regular, design: .monospaced)
/// Caption 11pt monospaced
static let exoCaption: Font = .system(size: 11, weight: .regular, design: .monospaced)
/// Subheadline 13pt monospaced medium
static let exoSubheadline: Font = .system(size: 13, weight: .medium, design: .monospaced)
/// Headline 17pt monospaced semibold
static let exoHeadline: Font = .system(size: 17, weight: .semibold, design: .monospaced)
}
// MARK: - Reusable Gradient Divider
struct GradientDivider: View {
var body: some View {
LinearGradient(
colors: [.clear, Color.exoYellow.opacity(0.3), .clear],
startPoint: .leading,
endPoint: .trailing
)
.frame(height: 1)
}
}

View File

@@ -0,0 +1,18 @@
//
// EXO_iOSTests.swift
// EXO-iOSTests
//
// Created by Sami Khan on 2026-02-17.
//
import Testing
@testable import EXO_iOS
struct EXO_iOSTests {
@Test func example() async throws {
// Write your test here and use APIs like `#expect(...)` to check expected conditions.
}
}

View File

@@ -0,0 +1,41 @@
//
// EXO_iOSUITests.swift
// EXO-iOSUITests
//
// Created by Sami Khan on 2026-02-17.
//
import XCTest
final class EXO_iOSUITests: XCTestCase {
override func setUpWithError() throws {
// Put setup code here. This method is called before the invocation of each test method in the class.
// In UI tests it is usually best to stop immediately when a failure occurs.
continueAfterFailure = false
// In UI tests its important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this.
}
override func tearDownWithError() throws {
// Put teardown code here. This method is called after the invocation of each test method in the class.
}
@MainActor
func testExample() throws {
// UI tests must launch the application that they test.
let app = XCUIApplication()
app.launch()
// Use XCTAssert and related functions to verify your tests produce the correct results.
}
@MainActor
func testLaunchPerformance() throws {
// This measures how long it takes to launch your application.
measure(metrics: [XCTApplicationLaunchMetric()]) {
XCUIApplication().launch()
}
}
}

View File

@@ -0,0 +1,33 @@
//
// EXO_iOSUITestsLaunchTests.swift
// EXO-iOSUITests
//
// Created by Sami Khan on 2026-02-17.
//
import XCTest
final class EXO_iOSUITestsLaunchTests: XCTestCase {
override class var runsForEachTargetApplicationUIConfiguration: Bool {
true
}
override func setUpWithError() throws {
continueAfterFailure = false
}
@MainActor
func testLaunch() throws {
let app = XCUIApplication()
app.launch()
// Insert steps here to perform after app launch but before taking a screenshot,
// such as logging into a test account or navigating somewhere in the app
let attachment = XCTAttachment(screenshot: app.screenshot())
attachment.name = "Launch Screen"
attachment.lifetime = .keepAlways
add(attachment)
}
}

View File

@@ -126,37 +126,11 @@ final class ExoProcessController: ObservableObject {
return
}
process.terminationHandler = nil
status = .stopped
guard process.isRunning else {
self.process = nil
return
if process.isRunning {
process.terminate()
}
let proc = process
self.process = nil
Task.detached {
proc.interrupt()
for _ in 0..<50 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
proc.terminate()
}
for _ in 0..<30 {
if !proc.isRunning { return }
try? await Task.sleep(nanoseconds: 100_000_000)
}
if proc.isRunning {
kill(proc.processIdentifier, SIGKILL)
}
}
status = .stopped
}
func restart() {

View File

@@ -103,7 +103,7 @@
const modelSupportsThinking = $derived(() => {
if (!currentModel) return false;
const caps = modelCapabilities[currentModel] || [];
return caps.includes("thinking_toggle") && caps.includes("text");
return caps.includes("thinking") && caps.includes("text");
});
const isEditOnlyWithoutImage = $derived(

View File

@@ -59,14 +59,13 @@
}
const sizeOptions: ImageGenerationParams["size"][] = [
"auto",
"512x512",
"768x768",
"1024x1024",
"1024x768",
"768x1024",
"1024x1536",
"1536x1024",
"1024x1365",
"1365x1024",
];
const qualityOptions: ImageGenerationParams["quality"][] = [
@@ -177,90 +176,92 @@
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
<!-- Basic params row -->
<div class="flex items-center gap-3 flex-wrap">
<!-- Size -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
<!-- Size (hidden in edit mode - output size comes from input image) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
{params.size.toUpperCase()}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size.toUpperCase()}</span>
</button>
{/each}
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{/if}
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
</div>
</div>
{/if}
</div>
{/if}
<!-- Quality -->
<div class="flex items-center gap-1.5">
@@ -310,7 +311,7 @@
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {qualityDropdownPosition()
.top}px + 4px); left: {qualityDropdownPosition().left}px;"
>

View File

@@ -306,14 +306,13 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
export interface ImageGenerationParams {
// Basic params
size:
| "auto"
| "512x512"
| "768x768"
| "1024x1024"
| "1024x768"
| "768x1024"
| "1024x1536"
| "1536x1024";
| "1024x1365"
| "1365x1024";
quality: "low" | "medium" | "high";
outputFormat: "png" | "jpeg";
numImages: number;
@@ -337,7 +336,7 @@ export interface EditingImage {
}
const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
size: "auto",
size: "1024x1024",
quality: "medium",
outputFormat: "png",
numImages: 1,

View File

@@ -115,7 +115,7 @@
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{

View File

@@ -41,16 +41,16 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.7.dev20260218+14841977"; in
version = let v = "0.30.6"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "rltakashige";
repo = "mlx-jaccl-fix-small-recv";
rev = "1484197707f35186ad3bd614357c7c47fdf86ebc";
hash = "sha256-FupCMoK/SF/ldfKuvMSAKECcOP8c+ANgkQlPZttDsLk=";
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
};
patches = [

View File

@@ -17,9 +17,9 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx; sys_platform == 'darwin'",
"mlx==0.30.6; sys_platform == 'darwin'",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.7",
"mlx-lm==0.30.6",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -64,7 +64,6 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }

View File

@@ -58,21 +58,6 @@
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
mlx = ignoreMissing prev.mlx;
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
buildInputs = (old.buildInputs or [ ]) ++ [
final.nvidia-cublas
final.nvidia-cuda-nvrtc
final.nvidia-cudnn-cu13
final.nvidia-nccl-cu13
];
preFixup = ''
addAutoPatchelfSearchPath ${final.nvidia-cublas}
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
'';
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
});
torch = ignoreMissing prev.torch;
triton = ignoreMissing prev.triton;
}
@@ -89,25 +74,14 @@
linuxOverlay
]
);
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
"lib/python3.13/site-packages/mlx*"
"lib/python3.13/site-packages/nvidia*"
];
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
)).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
);
mkPythonScript = name: path: pkgs.writeShellApplication {
inherit name;

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "deepseek"
quantization = "4bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405874409472

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "deepseek"
quantization = "8bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 765577920512

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 122406567936

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 229780750336

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 198556925568

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 286737579648

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 396963397248

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 19327352832

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "5bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 22548578304

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 26843545600

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 34359738368

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-5-8bit-MXFP8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 790517400864

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-5-MXFP4-Q8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "MXFP4-Q8"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405478939008

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-5"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 1487822475264

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 706522120192

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2.5"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 662498705408

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "minimax"
quantization = "3bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 100086644736

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 242986745856

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 342884352

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 698351616

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 141733920768

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 268435456000

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 17612931072

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 33279705088

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 47080074240

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 88814387200

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "4bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 114572190076

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "6bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 159039627774

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "8bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking", "thinking_toggle"]
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 209082699847

2
rust/clippy.toml Normal file
View File

@@ -0,0 +1,2 @@
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
#allowed-duplicate-crates = ["hashbrown"]

View File

@@ -5,6 +5,7 @@ edition = { workspace = true }
publish = false
[lib]
doctest = false
path = "src/lib.rs"
name = "exo_pyo3_bindings"
@@ -24,17 +25,17 @@ workspace = true
networking = { workspace = true }
# interop
pyo3 = { version = "0.27.2", features = [
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
pyo3 = { version = "0.27.1", features = [
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
"nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
"multiple-pymethods", # allows multiple #[pymethods] sections per class
# integrations with other libraries
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
# "ordered-float", "rust_decimal", "smallvec",
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
"ordered-float", "rust_decimal", "smallvec",
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.17.2" }
@@ -43,11 +44,34 @@ pyo3-log = "0.13.2"
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
pin-project = { workspace = true }
# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }
# utility dependencies
once_cell = "1.21.3"
thread_local = "1.1.9"
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
# Tracing
#tracing = "0.1"
#tracing-subscriber = "0.3"
#console-subscriber = "0.1.5"
#tracing-log = "0.2.0"
log = { workspace = true }
env_logger = "0.11"
# Networking
libp2p = { workspace = true, features = ["full"] }

View File

@@ -0,0 +1 @@
TODO: do something here....

View File

@@ -1,4 +1,8 @@
//! See: <https://pyo3.rs/v0.27.2/async-await.html#detaching-from-the-interpreter-across-await>
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
//!
use pin_project::pin_project;
use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
future::Future,
@@ -6,17 +10,31 @@ use std::{
task::{Context, Poll},
};
pub struct AllowThreads<F>(pub(crate) F);
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
#[pin_project]
#[repr(transparent)]
pub(crate) struct AllowThreads<F>(#[pin] F);
impl<F> AllowThreads<F>
where
Self: Future,
{
pub fn new(f: F) -> Self {
Self(f)
}
}
impl<F> Future for AllowThreads<F>
where
F: Future + Unpin + Send,
F::Output: Send,
F: Future + Ungil,
F::Output: Ungil,
{
type Output = F::Output;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let waker = cx.waker();
Python::attach(|py| py.detach(|| pin!(&mut self.0).poll(&mut Context::from_waker(waker))))
Python::with_gil(|py| {
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
})
}
}

View File

@@ -0,0 +1,240 @@
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
//!
//! Pattern examples include:
//! - Async task handles: with GC-integrated cleanup
//! - Sync/async callbacks from python: with propper eventloop handling
//!
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
//! - Store mutable fields in tokio's `Mutex<T>`
//! - For async code: take `&self` and `.lock().await`
//! - For sync code: take `&mut self` and `.get_mut()`
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
use futures::FutureExt as _;
use futures::future::BoxFuture;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::{
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
fn needs_tokio_runtime() {
tokio::runtime::Handle::current();
}
type SyncCallback = Box<dyn Fn() + Send + Sync>;
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
enum AsyncTaskMessage {
SyncCallback(SyncCallback),
AsyncCallback(AsyncCallback),
}
async fn async_task(
sender: mpsc::UnboundedSender<()>,
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
) {
log::info!("RUST: async task started");
// task state
let mut interval = tokio::time::interval(Duration::from_secs(1));
let mut sync_cbs: Vec<SyncCallback> = vec![];
let mut async_cbs: Vec<AsyncCallback> = vec![];
loop {
tokio::select! {
// handle incoming messages from task-handle
message = receiver.recv() => {
// handle closed channel by exiting
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming event
match message {
AsyncTaskMessage::SyncCallback(cb) => {
sync_cbs.push(cb);
}
AsyncTaskMessage::AsyncCallback(cb) => {
async_cbs.push(cb);
}
}
}
// handle all other events
_ = interval.tick() => {
log::info!("RUST: async task tick");
// call back all sync callbacks
for cb in &sync_cbs {
cb();
}
// call back all async callbacks
for cb in &async_cbs {
cb().await;
}
// send event on unbounded channel
sender.send(()).expect("handle receiver cannot be closed/dropped");
}
}
}
log::info!("RUST: async task stopped");
}
// #[gen_stub_pyclass]
#[pyclass(name = "AsyncTaskHandle")]
#[derive(Debug)]
struct PyAsyncTaskHandle {
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
receiver: mpsc::UnboundedReceiver<()>,
}
#[allow(clippy::expect_used)]
impl PyAsyncTaskHandle {
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_mut()
.expect("The sender should only be None after de-initialization.")
}
const fn new(
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
receiver: mpsc::UnboundedReceiver<()>,
) -> Self {
Self {
sender: Some(sender),
receiver,
}
}
}
// #[gen_stub_pymethods]
#[pymethods]
impl PyAsyncTaskHandle {
#[new]
fn py_new(py: Python<'_>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channel TOWARDS our task
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
// create communication channel FROM our task
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
// perform necessary setup within tokio context - or it crashes
let () = get_runtime().block_on(async { needs_tokio_runtime() });
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
_ = get_runtime().spawn_with_scope(py, async move {
async_task(t_sender, t_receiver).await;
});
Ok(Self::new(h_sender, h_receiver))
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_sync_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], None]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
})))
.pyerr()?;
Ok(())
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_async_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
let c = Python::with_gil(|py| callback.clone_ref(py));
async move {
if let Some(f) = Python::with_gil(|py| {
let coroutine = c.call0(py).write_unraisable_with(py)?;
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
.write_unraisable_with(py)
}) {
_ = f.await.write_unraisable();
}
}
.boxed()
})))
.pyerr()?;
Ok(())
}
async fn receive_unit(&mut self) -> PyResult<()> {
self.receiver
.recv()
.await
.ok_or(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
))
}
fn drain_units(&mut self) -> PyResult<i32> {
let mut cnt = 0;
loop {
match self.receiver.try_recv() {
Err(TryRecvError::Disconnected) => {
return Err(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
));
}
Err(TryRecvError::Empty) => return Ok(cnt),
Ok(()) => {
cnt += 1;
continue;
}
}
}
}
// #[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
// #[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
}
}
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyAsyncTaskHandle>()?;
Ok(())
}

View File

@@ -1,5 +1,217 @@
//! TODO: crate documentation
//!
//! this is here as a placeholder documentation
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(tuple_trait)]
#![feature(unboxed_closures)]
// #![feature(stmt_expr_attributes)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
extern crate core;
mod allow_threading;
mod examples;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::prelude::*;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
mod allow_threading;
/// Namespace for all the constants used by this crate.
pub(crate) mod r#const {
pub const MPSC_CHANNEL_SIZE: usize = 1024;
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
use std::marker::Tuple;
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
Fn<Args, Output = Output> + Send + 'static;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use crate::allow_threading::AllowThreads;
use extend::ext;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::marker::Ungil;
use pyo3::types::PyBytes;
use pyo3::{Py, PyErr, PyResult, Python};
use tokio::runtime::Runtime;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::task::JoinHandle;
#[ext(pub, name = ByteArrayExt)]
impl [u8] {
fn pybytes(&self) -> Py<PyBytes> {
Python::with_gil(|py| PyBytes::new(py, self).unbind())
}
}
#[ext(pub, name = ResultExt)]
impl<T, E> Result<T, E>
where
E: ToString,
{
fn pyerr(self) -> PyResult<T> {
self.map_err(|e| PyRuntimeError::new_err(e.to_string()))
}
}
pub trait FutureExt: Future + Sized {
/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
fn allow_threads_py(self) -> AllowThreads<Self>
where
AllowThreads<Self>: Future,
{
AllowThreads::new(self)
}
}
impl<T: Future> FutureExt for T {}
#[ext(pub, name = PyErrExt)]
impl PyErr {
fn receiver_channel_closed() -> Self {
PyConnectionError::new_err("Receiver channel closed unexpectedly")
}
}
#[ext(pub, name = PyResultExt)]
impl<T> PyResult<T> {
fn write_unraisable(self) -> Option<T> {
Python::with_gil(|py| self.write_unraisable_with(py))
}
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
match self {
Ok(v) => Some(v),
Err(e) => {
// write error back to python
e.write_unraisable(py, None);
None
}
}
}
}
#[ext(pub, name = TokioRuntimeExt)]
impl Runtime {
fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
}
}
#[ext(pub, name = TokioMpscSenderExt)]
impl<T> mpsc::Sender<T> {
/// Sends a value, waiting until there is capacity.
///
/// A successful send occurs when it is determined that the other end of the
/// channel has not hung up already. An unsuccessful send would be one where
/// the corresponding receiver has already been closed.
async fn send_py(&self, value: T) -> PyResult<()> {
self.send(value)
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
}
#[ext(pub, name = TokioMpscReceiverExt)]
impl<T> mpsc::Receiver<T> {
/// Receives the next value for this receiver.
async fn recv_py(&mut self) -> PyResult<T> {
self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
}
/// Receives at most `limit` values for this receiver and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
// get updates from receiver channel
let mut updates = Vec::with_capacity(limit);
let received = self.recv_many(&mut updates, limit).await;
// if we received zero items, then the channel was unexpectedly closed
if limit != 0 && received == 0 {
return Err(PyErr::receiver_channel_closed());
}
Ok(updates)
}
/// Tries to receive the next value for this receiver.
fn try_recv_py(&mut self) -> PyResult<Option<T>> {
match self.try_recv() {
Ok(v) => Ok(Some(v)),
Err(TryRecvError::Empty) => Ok(None),
Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
}
}
}
}
pub(crate) mod private {
use std::marker::Sized;
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
#[repr(transparent)]
pub(crate) struct ClonePy<T>(pub Py<T>);
impl<T> Clone for ClonePy<T> {
fn clone(&self) -> Self {
Python::with_gil(|py| Self(self.0.clone_ref(py)))
}
}
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
#[pymodule(name = "exo_pyo3_bindings")]
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// install logger
pyo3_log::init();
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues...
ident_submodule(m)?;
multiaddr_submodule(m)?;
networking_submodule(m)?;
// top-level constructs
// TODO: ...
Ok(())
}
define_stub_info_gatherer!(stub_info);

View File

@@ -0,0 +1,572 @@
#![allow(
clippy::multiple_inherent_impl,
clippy::unnecessary_wraps,
clippy::unused_self,
clippy::needless_pass_by_value
)]
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::pyclass;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use libp2p::{gossipsub, mdns};
use networking::discovery;
use networking::swarm::create_swarm;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
use std::net::IpAddr;
use tokio::sync::{Mutex, mpsc, oneshot};
mod exception {
use pyo3::types::PyTuple;
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
use pyo3_stub_gen::derive::*;
#[gen_stub_pyclass]
#[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
pub struct PyNoPeersSubscribedToTopicError {}
impl PyNoPeersSubscribedToTopicError {
const MSG: &'static str = "\
No peers are currently subscribed to receive messages on this topic. \
Wait for peers to subscribe or check your network connectivity.";
/// Creates a new [ `PyErr` ] of this type.
///
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
pub(crate) fn new_err() -> PyErr {
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
}
}
#[gen_stub_pymethods]
#[pymethods]
impl PyNoPeersSubscribedToTopicError {
#[new]
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("PeerId(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
}
#[gen_stub_pyclass]
#[pyclass(frozen, extends=PyException, name="AllQueuesFullError")]
pub struct PyAllQueuesFullError {}
impl PyAllQueuesFullError {
const MSG: &'static str =
"All libp2p peers are unresponsive, resend the message or reconnect.";
/// Creates a new [ `PyErr` ] of this type.
///
/// [`PyErr`] : https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
pub(crate) fn new_err() -> PyErr {
PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
}
}
#[gen_stub_pymethods]
#[pymethods]
impl PyAllQueuesFullError {
#[new]
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("PeerId(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
}
}
/// Connection or disconnection event discriminant type.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
#[derive(Debug, Clone, PartialEq)]
enum PyConnectionUpdateType {
Connected = 0,
Disconnected,
}
#[gen_stub_pyclass]
#[pyclass(frozen, name = "ConnectionUpdate")]
#[derive(Debug, Clone)]
struct PyConnectionUpdate {
/// Whether this is a connection or disconnection event
#[pyo3(get)]
update_type: PyConnectionUpdateType,
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: PyPeerId,
/// Remote connection's IPv4 address.
#[pyo3(get)]
remote_ipv4: String,
/// Remote connection's TCP port.
#[pyo3(get)]
remote_tcp_port: u16,
}
enum ToTask {
GossipsubSubscribe {
topic: String,
result_tx: oneshot::Sender<PyResult<bool>>,
},
GossipsubUnsubscribe {
topic: String,
result_tx: oneshot::Sender<bool>,
},
GossipsubPublish {
topic: String,
data: Vec<u8>,
result_tx: oneshot::Sender<PyResult<MessageId>>,
},
}
#[allow(clippy::enum_glob_use)]
async fn networking_task(
mut swarm: networking::swarm::Swarm,
mut to_task_rx: mpsc::Receiver<ToTask>,
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
) {
use SwarmEvent::*;
use ToTask::*;
use mdns::Event::*;
use networking::swarm::BehaviourEvent::*;
log::info!("RUST: networking task started");
loop {
tokio::select! {
message = to_task_rx.recv() => {
// handle closed channel
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming messages
match message {
GossipsubSubscribe { topic, result_tx } => {
// try to subscribe
let result = swarm.behaviour_mut()
.gossipsub.subscribe(&IdentTopic::new(topic));
// send response oneshot
if let Err(e) = result_tx.send(result.pyerr()) {
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubUnsubscribe { topic, result_tx } => {
// try to unsubscribe from the topic
let result = swarm.behaviour_mut()
.gossipsub.unsubscribe(&IdentTopic::new(topic));
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(result) {
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubPublish { topic, data, result_tx } => {
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
let result = swarm.behaviour_mut().gossipsub.publish(
IdentTopic::new(topic), data);
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
} else if let Err(PublishError::AllQueuesFull(_)) = result {
Err(exception::PyAllQueuesFullError::new_err())
} else {
result.pyerr()
};
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(pyresult) {
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
continue;
}
}
}
}
// architectural solution to this problem:
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
//
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
// then for actual communication it will dial those peers if need-be
swarm_event = swarm.select_next_some() => {
match swarm_event {
Behaviour(Gossipsub(gossipsub::Event::Message {
message: Message {
topic,
data,
..
},
..
})) => {
// topic-ID is just the topic hash!!! (since we used identity hasher)
let message = (topic.into_string(), data);
// send incoming message to channel (or exit if connection closed)
if let Err(e) = gossipsub_message_tx.send(message).await {
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
continue;
}
};
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
continue;
}
};
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: PyPeerId(peer_id),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
e => {
log::info!("RUST: other event {e:?}");
}
}
}
}
}
log::info!("RUST: networking task stopped");
}
#[gen_stub_pyclass]
#[pyclass(name = "NetworkingHandle")]
#[derive(Debug)]
struct PyNetworkingHandle {
// channels
to_task_tx: Option<mpsc::Sender<ToTask>>,
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
}
impl Drop for PyNetworkingHandle {
fn drop(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
}
#[allow(clippy::expect_used)]
impl PyNetworkingHandle {
fn new(
to_task_tx: mpsc::Sender<ToTask>,
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
) -> Self {
Self {
to_task_tx: Some(to_task_tx),
connection_update_rx: Mutex::new(connection_update_rx),
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
}
}
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
self.to_task_tx
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
}
#[gen_stub_pymethods]
#[pymethods]
impl PyNetworkingHandle {
// NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()`
// immediately beforehand to release the interpreter.
// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
// ---- Lifecycle management methods ----
#[new]
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channels
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
// get identity
let identity = identity.borrow().0.clone();
// create networking swarm (within tokio context!! or it crashes)
let swarm = get_runtime()
.block_on(async { create_swarm(identity) })
.pyerr()?;
// spawn tokio task running the networking logic
get_runtime().spawn(async move {
networking_task(
swarm,
to_task_rx,
connection_update_tx,
gossipsub_message_tx,
)
.await;
});
Ok(Self::new(
to_task_tx,
connection_update_rx,
gossipsub_message_rx,
))
}
#[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
#[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
// ---- Connection update receiver methods ----
/// Receives the next `ConnectionUpdate` from networking.
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
}
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
///
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
/// will sleep until a `ConnectionUpdate`s is sent.
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next `ConnectionUpdate` from networking.
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
// self.connection_update_rx.blocking_lock().try_recv_py()
// }
//
// /// Checks if the `ConnectionUpdate` channel is empty.
// fn connection_update_is_empty(&self) -> bool {
// self.connection_update_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `ConnectionUpdate`s in the channel.
// fn connection_update_len(&self) -> usize {
// self.connection_update_rx.blocking_lock().len()
// }
// ---- Gossipsub management methods ----
/// Subscribe to a `GossipSub` topic.
///
/// Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
self.to_task_tx()
.send_py(ToTask::GossipsubSubscribe {
topic,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & return any errors
rx.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())?
}
/// Unsubscribes from a `GossipSub` topic.
///
/// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
let (tx, rx) = oneshot::channel();
// send off request to unsubscribe
self.to_task_tx()
.send_py(ToTask::GossipsubUnsubscribe {
topic,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & convert any errors
rx.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())
}
/// Publishes a message with multiple topics to the `GossipSub` network.
///
/// If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx()
.send_py(ToTask::GossipsubPublish {
topic,
data,
result_tx: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
// wait for response & return any errors => ignore messageID for now!!!
let _ = rx
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())??;
Ok(())
}
// ---- Gossipsub message receiver methods ----
/// Receives the next message from the `GossipSub` network.
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
self.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
.map(|(t, d)| (t, d.pybytes()))
}
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
Ok(self
.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await?
.into_iter()
.map(|(t, d)| (t, d.pybytes()))
.collect())
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next message from the `GossipSub` network.
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
// Ok(self
// .gossipsub_message_rx
// .blocking_lock()
// .try_recv_py()?
// .map(|(t, d)| (t, d.pybytes())))
// }
//
// /// Checks if the `GossipSub` message channel is empty.
// fn gossipsub_is_empty(&self) -> bool {
// self.gossipsub_message_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `GossipSub` messages in the channel.
// fn gossipsub_len(&self) -> usize {
// self.gossipsub_message_rx.blocking_lock().len()
// }
}
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
m.add_class::<exception::PyAllQueuesFullError>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyConnectionUpdate>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyNetworkingHandle>()?;
Ok(())
}

View File

@@ -0,0 +1,159 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
/// Generate a new Ed25519 keypair.
#[staticmethod]
fn generate_ed25519() -> Self {
Self(Keypair::generate_ed25519())
}
/// Generate a new ECDSA keypair.
#[staticmethod]
fn generate_ecdsa() -> Self {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
}
/// Encode a private key as protobuf structure.
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?;
Ok(PyBytes::new(py, &bytes))
}
/// Convert the `Keypair` into the corresponding `PeerId`.
fn to_peer_id(&self) -> PyPeerId {
PyPeerId(self.0.public().to_peer_id())
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
}
}
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}

View File

@@ -0,0 +1,8 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!
pub mod ident;
pub mod multiaddr;

View File

@@ -0,0 +1,81 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;
/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
/// Create a new, empty multiaddress.
#[staticmethod]
fn empty() -> Self {
Self(Multiaddr::empty())
}
/// Create a new, empty multiaddress with the given capacity.
#[staticmethod]
fn with_capacity(n: usize) -> Self {
Self(Multiaddr::with_capacity(n))
}
/// Parse a `Multiaddr` value from its byte slice representation.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
}
/// Parse a `Multiaddr` value from its string representation.
#[staticmethod]
fn from_string(string: String) -> PyResult<Self> {
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
}
/// Return the length in bytes of this multiaddress.
fn len(&self) -> usize {
self.0.len()
}
/// Returns true if the length of this multiaddress is 0.
fn is_empty(&self) -> bool {
self.0.is_empty()
}
/// Return a copy of this [`Multiaddr`]'s byte representation.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_vec();
PyBytes::new(py, &bytes)
}
/// Convert a Multiaddr to a string.
fn to_string(&self) -> String {
self.0.to_string()
}
#[gen_stub(skip)]
fn __repr__(&self) -> String {
format!("Multiaddr({})", self.0)
}
#[gen_stub(skip)]
fn __str__(&self) -> String {
self.to_string()
}
}
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyMultiaddr>()?;
Ok(())
}

View File

@@ -0,0 +1,54 @@
#[cfg(test)]
mod tests {
use core::mem::drop;
use core::option::Option::Some;
use core::time::Duration;
use tokio;
use tokio::sync::mpsc;
#[tokio::test]
async fn test_drop_channel() {
struct Ping;
let (tx, mut rx) = mpsc::channel::<Ping>(10);
let _ = tokio::spawn(async move {
println!("TASK: entered");
loop {
tokio::select! {
result = rx.recv() => {
match result {
Some(_) => {
println!("TASK: pinged");
}
None => {
println!("TASK: closing channel");
break;
}
}
}
_ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => {
println!("TASK: heartbeat");
}
}
}
println!("TASK: exited");
});
let tx2 = tx.clone();
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
tx.send(Ping).await.expect("Should not fail");
drop(tx);
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
tx2.send(Ping).await.expect("Should not fail");
drop(tx2);
tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
}
}

View File

@@ -0,0 +1,34 @@
import asyncio
import pytest
from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError
@pytest.mark.asyncio
async def test_sleep_on_multiple_items() -> None:
print("PYTHON: starting handle")
h = NetworkingHandle(Keypair.generate_ed25519())
ct = asyncio.create_task(_await_cons(h))
mt = asyncio.create_task(_await_msg(h))
# sleep for 4 ticks
for i in range(4):
await asyncio.sleep(1)
try:
await h.gossipsub_publish("topic", b"somehting or other")
except NoPeersSubscribedToTopicError as e:
print("caught it", e)
async def _await_cons(h: NetworkingHandle):
while True:
c = await h.connection_update_recv()
print(f"PYTHON: connection update: {c}")
async def _await_msg(h: NetworkingHandle):
while True:
m = await h.gossipsub_recv()
print(f"PYTHON: message: {m}")

View File

@@ -5,6 +5,7 @@ edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "networking"
path = "src/lib.rs"
@@ -12,14 +13,27 @@ path = "src/lib.rs"
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures-lite = { workspace = true }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
@@ -27,4 +41,4 @@ keccak-const = { workspace = true }
log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }
libp2p = { workspace = true, features = ["full"] }

View File

@@ -1,3 +1,7 @@
use futures::stream::StreamExt as _;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter;
@@ -6,4 +10,65 @@ async fn main() {
let _ = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.try_init();
// Configure swarm
let mut swarm =
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
// Create a Gossipsub topic & subscribe
let topic = gossipsub::IdentTopic::new("test-net");
swarm
.behaviour_mut()
.gossipsub
.subscribe(&topic)
.expect("Subscribing to topic failed");
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
// Kick it off
loop {
select! {
// on gossipsub outgoing
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
// on gossipsub incoming
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
String::from_utf8_lossy(&message.data),
),
// on discovery
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
discovery::Event::ConnectionEstablished {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
discovery::Event::ConnectionClosed {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
}
// ignore outgoing errors: those are normal
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
// otherwise log any other event
e => { log::info!("Other event {e:?}"); }
}
}
}
}

View File

@@ -0,0 +1,127 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use futures::stream::StreamExt;
use libp2p::{
gossipsub, mdns, noise,
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use std::time::Duration;
use std::{error::Error, hash::Hash};
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;
// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
gossipsub: gossipsub::Behaviour,
mdns: mdns::tokio::Behaviour,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let _ = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.try_init();
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
.with_tokio()
.with_tcp(
tcp::Config::default(),
noise::Config::new,
yamux::Config::default,
)?
.with_behaviour(|key| {
// Set a custom gossipsub configuration
let gossipsub_config = gossipsub::ConfigBuilder::default()
.heartbeat_interval(Duration::from_secs(10))
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
.build()
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
// build a gossipsub network behaviour
let gossipsub = gossipsub::Behaviour::new(
gossipsub::MessageAuthenticity::Signed(key.clone()),
gossipsub_config,
)?;
let mdns =
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
Ok(MyBehaviour { gossipsub, mdns })
})?
.build();
println!("Running swarm with identity {}", swarm.local_peer_id());
// Create a Gossipsub topic
let topic = gossipsub::IdentTopic::new("test-net");
// subscribes to our topic
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
// Kick it off
loop {
select! {
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
}
event = swarm.select_next_some() => match event {
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
for (peer_id, multiaddr) in list {
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
}
},
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
})) => println!(
"Got message: '{}' with id: {id} from peer: {peer_id}",
String::from_utf8_lossy(&message.data),
),
SwarmEvent::NewListenAddr { address, .. } => {
println!("Local node is listening on {address}");
}
e => {
println!("Other swarm event: {:?}", e);
}
}
}
}
}

View File

@@ -0,0 +1,44 @@
https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609
^raw sockets distributed -> `<net/ndrv.h>` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html
--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets.
--> Used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received.
--> An NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like the ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools.
--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html
----> they are conceptually similar to https://scapy.disruptivelabs.in/networking/socket-interface PF_RAW or PF_PACKET
https://stackoverflow.com/questions/17169298/af-packet-on-osx
^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only
^it doesn't exist on OS X so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing
https://www.unix.com/man_page/mojave/4/ip/
^OS X manpages for IP
https://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts
^driver kit, system extensions & kexts for macOS
----
To set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic but modprobe is an option for manual loading), and then the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network.
--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a
----> here is a guide on how to set up thunderbolt-ethernet on linux
----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for MacOS
https://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7
^GPT discussion about making socket interface
https://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address,,??
https://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2
^GPT discussion about accessing TB on MacOS low level interactions
--------------------------------
https://www.intel.com/content/www/us/en/support/articles/000098893/software.html
^Thunderbolt Share & Thunderbolt Networking Mode => intel's equivalent of thunderbolt bridge
---------------------------------
https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/
-->fake ethernet devices on MacOS -> omg??? we can detect thunderbolt bridge, then bind to it, then re-expose it as fake ethernet??
-->ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!!
-->https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c

View File

@@ -0,0 +1,383 @@
use crate::ext::MultiaddrExt;
use crate::keep_alive;
use delegate::delegate;
use either::Either;
use futures::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
use libp2p::swarm::behaviour::ConnectionEstablished;
use libp2p::swarm::dial_opts::DialOpts;
use libp2p::swarm::{
CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
THandlerOutEvent, ToSwarm, dummy,
};
use libp2p::{Multiaddr, PeerId, identity, mdns};
use std::collections::{BTreeSet, HashMap};
use std::convert::Infallible;
use std::io;
use std::net::IpAddr;
use std::task::{Context, Poll};
use std::time::Duration;
use util::wakerdeque::WakerDeque;
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
mod managed {
use libp2p::swarm::NetworkBehaviour;
use libp2p::{identity, mdns, ping};
use std::io;
use std::time::Duration;
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
#[derive(NetworkBehaviour)]
pub struct Behaviour {
mdns: mdns::tokio::Behaviour,
ping: ping::Behaviour,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
Ok(Self {
mdns: mdns_behaviour(keypair)?,
ping: ping_behaviour(),
})
}
}
fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
use mdns::{Config, tokio};
// mDNS config => enable IPv6
let mdns_config = Config {
ttl: MDNS_RECORD_TTL,
query_interval: MDNS_QUERY_INTERVAL,
// enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
..Default::default()
};
let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
Ok(mdns_behaviour?)
}
fn ping_behaviour() -> ping::Behaviour {
ping::Behaviour::new(
ping::Config::new()
.with_timeout(PING_TIMEOUT)
.with_interval(PING_INTERVAL),
)
}
}
/// Events for when a listening connection is truly established and truly closed.
#[derive(Debug, Clone)]
pub enum Event {
ConnectionEstablished {
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
},
ConnectionClosed {
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
},
}
/// Discovery behavior that wraps mDNS to produce truly discovered durable peer-connections.
///
/// The behaviour operates as such:
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
/// to the swarm.
/// 1) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
/// immediately, and expired but connected peers are disconnected from immediately.
/// 2) Every fixed interval: discovered but not connected peers are dialed, and expired but
/// connected peers are disconnected from.
pub struct Behaviour {
// state-tracking for managed behaviors & mDNS-discovered peers
managed: managed::Behaviour,
mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,
retry_delay: Delay, // retry interval
// pending events to emmit => waker-backed Deque to control polling
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
Ok(Self {
managed: managed::Behaviour::new(keypair)?,
mdns_discovered: HashMap::new(),
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
pending_events: WakerDeque::new(),
})
}
fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
self.pending_events.push_back(ToSwarm::Dial {
opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
})
}
fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
// push front to make this IMMEDIATE
self.pending_events.push_front(ToSwarm::CloseConnection {
peer_id,
connection: CloseConnection::One(connection),
})
}
fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
for (p, ma) in peers {
self.dial(p, ma.clone()); // always connect
// get peer's multi-addresses or insert if missing
let Some(mas) = self.mdns_discovered.get_mut(&p) else {
self.mdns_discovered.insert(p, BTreeSet::from([ma]));
continue;
};
// multiaddress should never already be present - else something has gone wrong
let is_new_addr = mas.insert(ma);
assert!(is_new_addr, "cannot discover a discovered peer");
}
}
fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
for (p, ma) in peers {
// at this point, we *must* have the peer
let mas = self
.mdns_discovered
.get_mut(&p)
.expect("nonexistent peer cannot expire");
// at this point, we *must* have the multiaddress
let was_present = mas.remove(&ma);
assert!(was_present, "nonexistent multiaddress cannot expire");
// if empty, remove the peer-id entirely
if mas.is_empty() {
self.mdns_discovered.remove(&p);
}
}
}
fn on_connection_established(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
) {
// send out connected event
self.pending_events
.push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
peer_id,
connection_id,
remote_ip,
remote_tcp_port,
}));
}
fn on_connection_closed(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
remote_ip: IpAddr,
remote_tcp_port: u16,
) {
// send out disconnected event
self.pending_events
.push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
peer_id,
connection_id,
remote_ip,
remote_tcp_port,
}));
}
}
impl NetworkBehaviour for Behaviour {
type ConnectionHandler =
ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
type ToSwarm = Event;
// simply delegate to underlying mDNS behaviour
delegate! {
to self.managed {
fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
}
}
fn handle_established_inbound_connection(
&mut self,
connection_id: ConnectionId,
peer: PeerId,
local_addr: &Multiaddr,
remote_addr: &Multiaddr,
) -> Result<THandler<Self>, ConnectionDenied> {
Ok(ConnectionHandler::select(
dummy::ConnectionHandler,
self.managed.handle_established_inbound_connection(
connection_id,
peer,
local_addr,
remote_addr,
)?,
))
}
#[allow(clippy::needless_question_mark)]
fn handle_established_outbound_connection(
&mut self,
connection_id: ConnectionId,
peer: PeerId,
addr: &Multiaddr,
role_override: Endpoint,
port_use: PortUse,
) -> Result<THandler<Self>, ConnectionDenied> {
Ok(ConnectionHandler::select(
dummy::ConnectionHandler,
self.managed.handle_established_outbound_connection(
connection_id,
peer,
addr,
role_override,
port_use,
)?,
))
}
fn on_connection_handler_event(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
event: THandlerOutEvent<Self>,
) {
match event {
Either::Left(ev) => libp2p::core::util::unreachable(ev),
Either::Right(ev) => {
self.managed
.on_connection_handler_event(peer_id, connection_id, ev)
}
}
}
// hook into these methods to drive behavior
fn on_swarm_event(&mut self, event: FromSwarm) {
self.managed.on_swarm_event(event); // let mDNS handle swarm events
// handle swarm events to update internal state:
match event {
FromSwarm::ConnectionEstablished(ConnectionEstablished {
peer_id,
connection_id,
endpoint,
..
}) => {
let remote_address = match endpoint {
ConnectedPoint::Dialer { address, .. } => address,
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
};
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
// handle connection established event which is filtered correctly
self.on_connection_established(peer_id, connection_id, ip, port)
}
}
FromSwarm::ConnectionClosed(ConnectionClosed {
peer_id,
connection_id,
endpoint,
..
}) => {
let remote_address = match endpoint {
ConnectedPoint::Dialer { address, .. } => address,
ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
};
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
// handle connection closed event which is filtered correctly
self.on_connection_closed(peer_id, connection_id, ip, port)
}
}
// since we are running TCP/IP transport layer, we are assuming that
// no address changes can occur, hence encountering one is a fatal error
FromSwarm::AddressChange(a) => {
unreachable!("unhandlable: address change encountered: {:?}", a)
}
_ => {}
}
}
fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
// delegate to managed behaviors for any behaviors they need to perform
match self.managed.poll(cx) {
Poll::Ready(ToSwarm::GenerateEvent(e)) => {
match e {
// handle discovered and expired events from mDNS
managed::BehaviourEvent::Mdns(e) => match e.clone() {
mdns::Event::Discovered(peers) => {
self.handle_mdns_discovered(peers);
}
mdns::Event::Expired(peers) => {
self.handle_mdns_expired(peers);
}
},
// handle ping events => if error then disconnect
managed::BehaviourEvent::Ping(e) => {
if let Err(_) = e.result {
self.close_connection(e.peer, e.connection.clone())
}
}
}
// since we just consumed an event, we should immediately wake just in case
// there are more events to come where that came from
cx.waker().wake_by_ref();
}
// forward any other mDNS event to the swarm or its connection handler(s)
Poll::Ready(e) => {
return Poll::Ready(
e.map_out(|_| unreachable!("events returning to swarm already handled"))
.map_in(Either::Right),
);
}
Poll::Pending => {}
}
// retry connecting to all mDNS peers periodically (fails safely if already connected)
if self.retry_delay.poll_unpin(cx).is_ready() {
for (p, mas) in self.mdns_discovered.clone() {
for ma in mas {
self.dial(p, ma)
}
}
self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
}
// send out any pending events from our own service
if let Some(e) = self.pending_events.pop_front(cx) {
return Poll::Ready(e.map_in(Either::Left));
}
// wait for pending events
Poll::Pending
}
}

View File

@@ -0,0 +1,44 @@
use delegate::delegate;
use libp2p::swarm::handler::ConnectionEvent;
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
use std::task::{Context, Poll};
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
/// the connection alive.
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnectionHandler(dummy::ConnectionHandler);
impl ConnectionHandler {
pub fn new() -> Self {
ConnectionHandler(dummy::ConnectionHandler)
}
}
impl handler::ConnectionHandler for ConnectionHandler {
// delegate types and implementation mostly to dummy handler
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
type InboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
type OutboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
type InboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
type OutboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
delegate! {
to self.0 {
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
}
}
// specifically override this to force connection to stay alive
fn connection_keep_alive(&self) -> bool {
true
}
}

View File

@@ -0,0 +1,64 @@
//! TODO: crate documentation
//!
//! this is here as a placeholder documentation
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
// #![feature(stmt_expr_attributes)]
// #![feature(unboxed_closures)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
pub mod discovery;
pub mod keep_alive;
pub mod swarm;
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
use extend::ext;
use libp2p::Multiaddr;
use libp2p::multiaddr::Protocol;
use std::net::IpAddr;
#[ext(pub, name = MultiaddrExt)]
impl Multiaddr {
/// If the multiaddress corresponds to a TCP address, extracts it
fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> {
let mut ps = self.into_iter();
let ip = if let Some(p) = ps.next() {
match p {
Protocol::Ip4(ip) => IpAddr::V4(ip),
Protocol::Ip6(ip) => IpAddr::V6(ip),
_ => return None,
}
} else {
return None;
};
let Some(Protocol::Tcp(port)) = ps.next() else {
return None;
};
Some((ip, port))
}
}
}
pub(crate) mod private {
#![allow(dead_code)]
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}

View File

@@ -0,0 +1,143 @@
use crate::alias;
use crate::swarm::transport::tcp_transport;
pub use behaviour::{Behaviour, BehaviourEvent};
use libp2p::{SwarmBuilder, identity};
pub type Swarm = libp2p::Swarm<Behaviour>;
/// The current version of the network: this prevents devices running different versions of the
/// software from interacting with each other.
///
/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should
/// even be, and how to inject the right version into this config/initialization. E.g. should
/// this be passed in as a parameter? What about rapidly changing versions in debug builds?
/// this is all VERY very hard to figure out and needs to be mulled over as a team.
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
/// Create and configure a swarm which listens to all ports on OS
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
.with_tokio()
.with_other_transport(tcp_transport)?
.with_behaviour(Behaviour::new)?
.build();
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
Ok(swarm)
}
mod transport {
use crate::alias;
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
use futures::{AsyncRead, AsyncWrite};
use keccak_const::Sha3_256;
use libp2p::core::muxing;
use libp2p::core::transport::Boxed;
use libp2p::pnet::{PnetError, PnetOutput};
use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
use std::{env, sync::LazyLock};
/// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].
/// See [`pnet_upgrade`] for more.
static PNET_PRESHARED_KEY: LazyLock<[u8; 32]> = LazyLock::new(|| {
let builder = Sha3_256::new().update(b"exo_discovery_network");
if let Ok(var) = env::var(OVERRIDE_VERSION_ENV_VAR) {
let bytes = var.into_bytes();
builder.update(&bytes)
} else {
builder.update(NETWORK_VERSION)
}
.finalize()
});
/// Make the Swarm run on a private network, as to not clash with public libp2p nodes and
/// also different-versioned instances of this same network.
/// This is implemented as an additional "upgrade" ontop of existing [`libp2p::Transport`] layers.
async fn pnet_upgrade<TSocket>(
socket: TSocket,
_: impl Sized,
) -> Result<PnetOutput<TSocket>, PnetError>
where
TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
use pnet::{PnetConfig, PreSharedKey};
PnetConfig::new(PreSharedKey::new(*PNET_PRESHARED_KEY))
.handshake(socket)
.await
}
/// TCP/IP transport layer configuration.
pub fn tcp_transport(
keypair: &identity::Keypair,
) -> alias::AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
use libp2p::{
core::upgrade::Version,
tcp::{Config, tokio},
};
// `TCP_NODELAY` enabled => avoid latency
let tcp_config = Config::default().nodelay(true);
// V1 + lazy flushing => 0-RTT negotiation
let upgrade_version = Version::V1Lazy;
// Noise is faster than TLS + we don't care much for security
let noise_config = noise::Config::new(keypair)?;
// Use default Yamux config for multiplexing
let yamux_config = yamux::Config::default();
// Create new Tokio-driven TCP/IP transport layer
let base_transport = tokio::Transport::new(tcp_config)
.and_then(pnet_upgrade)
.upgrade(upgrade_version)
.authenticate(noise_config)
.multiplex(yamux_config);
// Return boxed transport (to flatten complex type)
Ok(base_transport.boxed())
}
}
mod behaviour {
use crate::{alias, discovery};
use libp2p::swarm::NetworkBehaviour;
use libp2p::{gossipsub, identity};
/// Behavior of the Swarm which composes all desired behaviors:
/// Right now its just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
#[derive(NetworkBehaviour)]
pub struct Behaviour {
pub discovery: discovery::Behaviour,
pub gossipsub: gossipsub::Behaviour,
}
impl Behaviour {
pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
Ok(Self {
discovery: discovery::Behaviour::new(keypair)?,
gossipsub: gossipsub_behaviour(keypair),
})
}
}
fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};
// build a gossipsub network behaviour
// => signed message authenticity + strict validation mode means the message-ID is
// automatically provided by gossipsub w/out needing to provide custom message-ID function
gossipsub::Behaviour::new(
MessageAuthenticity::Signed(keypair.clone()),
ConfigBuilder::default()
.max_transmit_size(1024 * 1024)
.validation_mode(ValidationMode::Strict)
.build()
.expect("the configuration should always be valid"),
)
.expect("creating gossipsub behavior should always work")
}
}

View File

@@ -0,0 +1,7 @@
// maybe this will hold test in the future...??
#[cfg(test)]
mod tests {
#[test]
fn does_nothing() {}
}

15
rust/util/Cargo.toml Normal file
View File

@@ -0,0 +1,15 @@
[package]
name = "util"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "util"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]

1
rust/util/src/lib.rs Normal file
View File

@@ -0,0 +1 @@
pub mod wakerdeque;

View File

@@ -0,0 +1,55 @@
use std::collections::VecDeque;
use std::fmt::{Debug, Formatter};
use std::task::{Context, Waker};
/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,
/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.
pub struct WakerDeque<T> {
waker: Option<Waker>,
deque: VecDeque<T>,
}
impl<T: Debug> Debug for WakerDeque<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.deque.fmt(f)
}
}
impl<T> WakerDeque<T> {
pub fn new() -> Self {
Self {
waker: None,
deque: VecDeque::new(),
}
}
fn update(&mut self, cx: &mut Context<'_>) {
self.waker = Some(cx.waker().clone());
}
fn wake(&mut self) {
let Some(ref mut w) = self.waker else { return };
w.wake_by_ref();
self.waker = None;
}
pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {
self.update(cx);
self.deque.pop_front()
}
pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {
self.update(cx);
self.deque.pop_back()
}
pub fn push_front(&mut self, value: T) {
self.wake();
self.deque.push_front(value);
}
pub fn push_back(&mut self, value: T) {
self.wake();
self.deque.push_back(value);
}
}

View File

@@ -47,7 +47,6 @@ class DownloadCoordinator:
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
offline: bool = False
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
@@ -63,8 +62,6 @@ class DownloadCoordinator:
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
if self.offline:
self.shard_downloader.set_internet_connection(False)
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
@@ -110,17 +107,13 @@ class DownloadCoordinator:
self._last_progress_time[model_id] = current_time()
async def run(self) -> None:
logger.info(
f"Starting DownloadCoordinator{' (offline mode)' if self.offline else ''}"
)
if not self.offline:
self._test_internet_connection()
logger.info("Starting DownloadCoordinator")
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
if not self.offline:
tg.start_soon(self._check_internet_connection)
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
@@ -209,20 +202,6 @@ class DownloadCoordinator:
)
return
if self.offline:
logger.warning(
f"Offline mode: model {model_id} is not fully available locally, cannot download"
)
failed = DownloadFailed(
shard_metadata=shard,
node_id=self.node_id,
error_message=f"Model files not found locally in offline mode: {model_id}",
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(NodeDownloadProgress(download_progress=failed))
return
# Start actual download
self._start_download_task(shard, initial_progress)

View File

@@ -448,13 +448,12 @@ async def download_file_with_retry(
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
on_connection_lost: Callable[[], None] = lambda: None,
skip_internet: bool = False,
) -> Path:
n_attempts = 3
for attempt in range(n_attempts):
try:
return await _download_file(
model_id, revision, path, target_dir, on_progress, skip_internet
model_id, revision, path, target_dir, on_progress
)
except HuggingFaceAuthenticationError:
raise
@@ -488,14 +487,10 @@ async def _download_file(
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
skip_internet: bool = False,
) -> Path:
target_path = target_dir / path
if await aios.path.exists(target_path):
if skip_internet:
return target_path
local_size = (await aios.stat(target_path)).st_size
# Try to verify against remote, but allow offline operation
@@ -515,11 +510,6 @@ async def _download_file(
)
return target_path
if skip_internet:
raise FileNotFoundError(
f"File {path} not found locally and cannot download in offline mode"
)
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -824,7 +814,6 @@ async def download_shard(
file, curr_bytes, total_bytes, is_renamed
),
on_connection_lost=on_connection_lost,
skip_internet=skip_internet,
)
if not skip_download:

View File

@@ -1,230 +0,0 @@
"""Tests for offline/air-gapped mode."""
from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import AsyncMock, patch
import aiofiles
import aiofiles.os as aios
import pytest
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
download_file_with_retry,
fetch_file_list_with_cache,
)
from exo.shared.types.common import ModelId
from exo.shared.types.worker.downloads import FileListEntry
@pytest.fixture
def model_id() -> ModelId:
return ModelId("test-org/test-model")
@pytest.fixture
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
yield models_dir
class TestDownloadFileOffline:
"""Tests for _download_file with skip_internet=True."""
async def test_returns_local_file_without_http_verification(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file exists locally, return it immediately
without making any HTTP calls (no file_meta verification)."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
local_file = target_dir / "model.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"model weights data")
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await _download_file(
model_id,
"main",
"model.safetensors",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
async def test_raises_file_not_found_for_missing_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file does NOT exist locally,
raise FileNotFoundError instead of attempting download."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
with pytest.raises(FileNotFoundError, match="offline mode"):
await _download_file(
model_id,
"main",
"missing_model.safetensors",
target_dir,
skip_internet=True,
)
async def test_returns_local_file_in_subdirectory(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file exists in a subdirectory,
return it without HTTP calls."""
target_dir = tmp_path / "downloads"
subdir = target_dir / "transformer"
await aios.makedirs(subdir, exist_ok=True)
local_file = subdir / "diffusion_pytorch_model.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"weights")
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await _download_file(
model_id,
"main",
"transformer/diffusion_pytorch_model.safetensors",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
class TestDownloadFileWithRetryOffline:
"""Tests for download_file_with_retry with skip_internet=True."""
async def test_propagates_skip_internet_to_download_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Verify skip_internet is passed through to _download_file."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
local_file = target_dir / "config.json"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b'{"model_type": "qwen2"}')
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
async def test_file_not_found_does_not_retry(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""FileNotFoundError from offline mode should not trigger retries."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
with pytest.raises(FileNotFoundError):
await download_file_with_retry(
model_id,
"main",
"nonexistent.safetensors",
target_dir,
skip_internet=True,
)
class TestFetchFileListOffline:
"""Tests for fetch_file_list_with_cache with skip_internet=True."""
async def test_uses_cached_file_list(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and cache file exists, use it without network."""
from pydantic import TypeAdapter
cache_dir = temp_models_dir / "caches" / model_id.normalize()
await aios.makedirs(cache_dir, exist_ok=True)
cached_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
FileListEntry(type="file", path="config.json", size=200),
]
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode()
)
with patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
) as mock_fetch:
result = await fetch_file_list_with_cache(
model_id, "main", skip_internet=True
)
assert result == cached_list
mock_fetch.assert_not_called()
async def test_falls_back_to_local_directory_scan(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and no cache but local files exist,
build file list from local directory."""
import json
model_dir = temp_models_dir / model_id.normalize()
await aios.makedirs(model_dir, exist_ok=True)
async with aiofiles.open(model_dir / "config.json", "w") as f:
await f.write('{"model_type": "qwen2"}')
index_data = {
"metadata": {},
"weight_map": {"model.layers.0.weight": "model.safetensors"},
}
async with aiofiles.open(model_dir / "model.safetensors.index.json", "w") as f:
await f.write(json.dumps(index_data))
async with aiofiles.open(model_dir / "model.safetensors", "wb") as f:
await f.write(b"x" * 500)
with patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
) as mock_fetch:
result = await fetch_file_list_with_cache(
model_id, "main", skip_internet=True
)
mock_fetch.assert_not_called()
paths = {entry.path for entry in result}
assert "config.json" in paths
assert "model.safetensors" in paths
async def test_raises_when_no_cache_and_no_local_files(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and neither cache nor local files exist,
raise FileNotFoundError."""
with pytest.raises(FileNotFoundError, match="No internet"):
await fetch_file_list_with_cache(model_id, "main", skip_internet=True)

View File

@@ -39,7 +39,6 @@ class Node:
node_id: NodeId
event_index_counter: Iterator[int]
offline: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
@@ -69,7 +68,6 @@ class Node:
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
offline=args.offline,
)
else:
download_coordinator = None
@@ -134,13 +132,10 @@ class Node:
api,
node_id,
event_index_counter,
args.offline,
)
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
@@ -152,6 +147,8 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
def shutdown(self):
# if this is our second call to shutdown, just sys.exit
@@ -225,7 +222,6 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
@@ -264,9 +260,6 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
@@ -289,7 +282,6 @@ class Args(CamelCaseModel):
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
offline: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -337,11 +329,6 @@ class Args(CamelCaseModel):
action="store_true",
help="Disable the download coordinator (node won't download models)",
)
parser.add_argument(
"--offline",
action="store_true",
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

Some files were not shown because too many files have changed in this diff Show More