mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-18 23:06:23 -05:00
Compare commits
2 Commits
move-messa
...
sami/iOS-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e9eb93f82 | ||
|
|
ab622f79c3 |
@@ -200,7 +200,7 @@ class Module(dict):
|
||||
) -> mx.MX_ARRAY_TREE: # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:
|
||||
"""Return the submodules that do not contain other modules."""
|
||||
|
||||
def update(self, parameters: dict[str, Any], strict: bool = ...) -> Module:
|
||||
def update(self, parameters: dict, strict: bool = ...) -> Module:
|
||||
"""Replace the parameters of this Module with the provided ones in the
|
||||
dict of dicts and lists.
|
||||
|
||||
|
||||
@@ -7,10 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from mlx.core import MX_ARRAY_TREE
|
||||
|
||||
def tree_map(
|
||||
fn: Callable[..., Any],
|
||||
tree: Any,
|
||||
*rest: Any,
|
||||
is_leaf: Callable[..., bool] | None = ...,
|
||||
fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = ...
|
||||
) -> Any:
|
||||
"""Applies ``fn`` to the leaves of the Python tree ``tree`` and
|
||||
returns a new collection with the results.
|
||||
@@ -47,11 +44,11 @@ def tree_map(
|
||||
"""
|
||||
|
||||
def tree_map_with_path(
|
||||
fn: Callable[..., Any],
|
||||
fn: Callable,
|
||||
tree: Any,
|
||||
*rest: Any,
|
||||
is_leaf: Callable[..., bool] | None = ...,
|
||||
path: str | None = ...,
|
||||
is_leaf: Optional[Callable] = ...,
|
||||
path: Optional[Any] = ...,
|
||||
) -> Any:
|
||||
"""Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
|
||||
returns a new collection with the results.
|
||||
@@ -83,9 +80,9 @@ def tree_map_with_path(
|
||||
def tree_flatten(
|
||||
tree: Any,
|
||||
prefix: str = ...,
|
||||
is_leaf: Callable[..., bool] | None = ...,
|
||||
destination: list[tuple[str, Any]] | dict[str, Any] | None = ...,
|
||||
) -> list[tuple[str, Any]] | dict[str, Any]:
|
||||
is_leaf: Optional[Callable] = ...,
|
||||
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = ...,
|
||||
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
|
||||
"""Flattens a Python tree to a list of key, value tuples.
|
||||
|
||||
The keys are using the dot notation to define trees of arbitrary depth and
|
||||
@@ -121,7 +118,7 @@ def tree_flatten(
|
||||
the Python tree.
|
||||
"""
|
||||
|
||||
def tree_unflatten(tree: list[tuple[str, Any]] | dict[str, Any]) -> Any:
|
||||
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
|
||||
"""Recreate a Python tree from its flat representation.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .deepseek_v32 import Model as DSV32Model
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
index_head_dim: int
|
||||
index_n_heads: int
|
||||
index_topk: int
|
||||
intermediate_size: int
|
||||
moe_intermediate_size: int
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
n_shared_experts: Optional[int]
|
||||
n_routed_experts: Optional[int]
|
||||
routed_scaling_factor: float
|
||||
kv_lora_rank: int
|
||||
q_lora_rank: int
|
||||
qk_rope_head_dim: int
|
||||
v_head_dim: int
|
||||
qk_nope_head_dim: int
|
||||
topk_method: str
|
||||
scoring_func: str
|
||||
norm_topk_prob: bool
|
||||
n_group: int
|
||||
topk_group: int
|
||||
num_experts_per_tok: int
|
||||
moe_layer_freq: int
|
||||
first_k_dense_replace: int
|
||||
max_position_embeddings: int
|
||||
rms_norm_eps: float
|
||||
rope_parameters: Dict[str, Any]
|
||||
attention_bias: bool
|
||||
rope_scaling: Dict[str, Any] | None
|
||||
rope_theta: float | None
|
||||
|
||||
class Model(DSV32Model):
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
137
Cargo.lock
generated
137
Cargo.lock
generated
@@ -141,6 +141,12 @@ version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
|
||||
|
||||
[[package]]
|
||||
name = "arrayvec"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs"
|
||||
version = "0.7.1"
|
||||
@@ -298,6 +304,19 @@ version = "1.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
|
||||
|
||||
[[package]]
|
||||
name = "bigdecimal"
|
||||
version = "0.4.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bimap"
|
||||
version = "0.6.3"
|
||||
@@ -497,6 +516,15 @@ version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2f421161cb492475f1661ddc9815a745a1c894592070661180fdec3d4872e9c3"
|
||||
|
||||
[[package]]
|
||||
name = "convert_case"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
|
||||
dependencies = [
|
||||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.4"
|
||||
@@ -718,6 +746,29 @@ dependencies = [
|
||||
"powerfmt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618"
|
||||
dependencies = [
|
||||
"derive_more-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more-impl"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b"
|
||||
dependencies = [
|
||||
"convert_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustc_version",
|
||||
"syn 2.0.111",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
@@ -888,17 +939,22 @@ name = "exo_pyo3_bindings"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"env_logger",
|
||||
"extend",
|
||||
"futures-lite",
|
||||
"futures",
|
||||
"impl-trait-for-tuples",
|
||||
"libp2p",
|
||||
"log",
|
||||
"networking",
|
||||
"once_cell",
|
||||
"pin-project",
|
||||
"pyo3",
|
||||
"pyo3-async-runtimes",
|
||||
"pyo3-log",
|
||||
"pyo3-stub-gen",
|
||||
"thiserror 2.0.17",
|
||||
"thread_local",
|
||||
"tokio",
|
||||
"util",
|
||||
]
|
||||
@@ -914,12 +970,6 @@ dependencies = [
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||
|
||||
[[package]]
|
||||
name = "ff"
|
||||
version = "0.13.1"
|
||||
@@ -1028,10 +1078,7 @@ version = "2.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
|
||||
dependencies = [
|
||||
"fastrand",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"parking",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
@@ -1593,6 +1640,17 @@ dependencies = [
|
||||
"xmltree",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "impl-trait-for-tuples"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.12.1"
|
||||
@@ -1771,6 +1829,12 @@ version = "0.2.178"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
|
||||
|
||||
[[package]]
|
||||
name = "libp2p"
|
||||
version = "0.56.0"
|
||||
@@ -2760,14 +2824,16 @@ name = "networking"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"either",
|
||||
"extend",
|
||||
"futures-lite",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"impl-trait-for-tuples",
|
||||
"keccak-const",
|
||||
"libp2p",
|
||||
"log",
|
||||
"pin-project",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
"util",
|
||||
@@ -2852,6 +2918,17 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
@@ -3202,14 +3279,28 @@ version = "0.27.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d"
|
||||
dependencies = [
|
||||
"bigdecimal",
|
||||
"either",
|
||||
"hashbrown 0.16.1",
|
||||
"indexmap",
|
||||
"indoc",
|
||||
"inventory",
|
||||
"libc",
|
||||
"lock_api",
|
||||
"memoffset",
|
||||
"num-bigint",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"ordered-float",
|
||||
"parking_lot",
|
||||
"portable-atomic",
|
||||
"pyo3-build-config",
|
||||
"pyo3-ffi",
|
||||
"pyo3-macros",
|
||||
"rust_decimal",
|
||||
"smallvec",
|
||||
"unindent",
|
||||
]
|
||||
|
||||
@@ -3650,6 +3741,16 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust_decimal"
|
||||
version = "1.39.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "1.1.0"
|
||||
@@ -4514,12 +4615,24 @@ version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-segmentation"
|
||||
version = "1.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
|
||||
|
||||
[[package]]
|
||||
name = "unicode_names2"
|
||||
version = "1.3.0"
|
||||
|
||||
31
Cargo.toml
31
Cargo.toml
@@ -26,20 +26,49 @@ opt-level = 3
|
||||
networking = { path = "rust/networking" }
|
||||
util = { path = "rust/util" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
syn = "2.0"
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0"
|
||||
darling = "0.20"
|
||||
|
||||
# Macro dependecies
|
||||
extend = "1.2"
|
||||
delegate = "0.13"
|
||||
impl-trait-for-tuples = "0.2"
|
||||
clap = "4.5"
|
||||
derive_more = { version = "2.0.1", features = ["display"] }
|
||||
pin-project = "1"
|
||||
|
||||
# Utility dependencies
|
||||
itertools = "0.14"
|
||||
thiserror = "2"
|
||||
internment = "0.8"
|
||||
recursion = "0.5"
|
||||
regex = "1.11"
|
||||
once_cell = "1.21"
|
||||
thread_local = "1.1"
|
||||
bon = "3.4"
|
||||
generativity = "1.1"
|
||||
anyhow = "1.0"
|
||||
keccak-const = "0.2"
|
||||
|
||||
# Functional generics/lenses frameworks
|
||||
frunk_core = "0.4"
|
||||
frunk = "0.4"
|
||||
frunk_utils = "0.2"
|
||||
frunk-enum-core = "0.3"
|
||||
|
||||
# Async dependencies
|
||||
tokio = "1.46"
|
||||
futures-lite = "2.6.1"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
futures-timer = "3.0"
|
||||
|
||||
# Data structures
|
||||
either = "1.15"
|
||||
ordered-float = "5.0"
|
||||
ahash = "0.8"
|
||||
|
||||
# Tracing/logging
|
||||
log = "0.4"
|
||||
|
||||
11
README.md
11
README.md
@@ -72,23 +72,16 @@ There are two ways to run exo:
|
||||
|
||||
### Run from Source (macOS)
|
||||
|
||||
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
|
||||
|
||||
```bash
|
||||
nix run .#exo
|
||||
```
|
||||
|
||||
**Prerequisites:**
|
||||
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
|
||||
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
|
||||
|
||||
|
||||
```bash
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
```
|
||||
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
|
||||
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
|
||||
- [node](https://github.com/nodejs/node) (for building the dashboard)
|
||||
|
||||
|
||||
```bash
|
||||
brew install uv macmon node
|
||||
```
|
||||
|
||||
628
app/EXO-iOS/EXO-iOS.xcodeproj/project.pbxproj
Normal file
628
app/EXO-iOS/EXO-iOS.xcodeproj/project.pbxproj
Normal file
@@ -0,0 +1,628 @@
|
||||
// !$*UTF8*$!
|
||||
{
|
||||
archiveVersion = 1;
|
||||
classes = {
|
||||
};
|
||||
objectVersion = 77;
|
||||
objects = {
|
||||
|
||||
/* Begin PBXBuildFile section */
|
||||
E09D17522F44F359009C51A3 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = E09D17512F44F359009C51A3 /* MLXLLM */; };
|
||||
E09D17542F44F359009C51A3 /* MLXLMCommon in Frameworks */ = {isa = PBXBuildFile; productRef = E09D17532F44F359009C51A3 /* MLXLMCommon */; };
|
||||
/* End PBXBuildFile section */
|
||||
|
||||
/* Begin PBXContainerItemProxy section */
|
||||
E09D167D2F44CA20009C51A3 /* PBXContainerItemProxy */ = {
|
||||
isa = PBXContainerItemProxy;
|
||||
containerPortal = E09D16672F44CA1E009C51A3 /* Project object */;
|
||||
proxyType = 1;
|
||||
remoteGlobalIDString = E09D166E2F44CA1E009C51A3;
|
||||
remoteInfo = "EXO-iOS";
|
||||
};
|
||||
E09D16872F44CA20009C51A3 /* PBXContainerItemProxy */ = {
|
||||
isa = PBXContainerItemProxy;
|
||||
containerPortal = E09D16672F44CA1E009C51A3 /* Project object */;
|
||||
proxyType = 1;
|
||||
remoteGlobalIDString = E09D166E2F44CA1E009C51A3;
|
||||
remoteInfo = "EXO-iOS";
|
||||
};
|
||||
/* End PBXContainerItemProxy section */
|
||||
|
||||
/* Begin PBXFileReference section */
|
||||
E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "EXO-iOS.app"; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "EXO-iOSTests.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "EXO-iOSUITests.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
/* End PBXFileReference section */
|
||||
|
||||
/* Begin PBXFileSystemSynchronizedBuildFileExceptionSet section */
|
||||
E09D169A2F44CA20009C51A3 /* Exceptions for "EXO-iOS" folder in "EXO-iOS" target */ = {
|
||||
isa = PBXFileSystemSynchronizedBuildFileExceptionSet;
|
||||
membershipExceptions = (
|
||||
Info.plist,
|
||||
);
|
||||
target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
|
||||
};
|
||||
/* End PBXFileSystemSynchronizedBuildFileExceptionSet section */
|
||||
|
||||
/* Begin PBXFileSystemSynchronizedRootGroup section */
|
||||
E09D16712F44CA1E009C51A3 /* EXO-iOS */ = {
|
||||
isa = PBXFileSystemSynchronizedRootGroup;
|
||||
exceptions = (
|
||||
E09D169A2F44CA20009C51A3 /* Exceptions for "EXO-iOS" folder in "EXO-iOS" target */,
|
||||
);
|
||||
path = "EXO-iOS";
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
E09D167F2F44CA20009C51A3 /* EXO-iOSTests */ = {
|
||||
isa = PBXFileSystemSynchronizedRootGroup;
|
||||
path = "EXO-iOSTests";
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
E09D16892F44CA20009C51A3 /* EXO-iOSUITests */ = {
|
||||
isa = PBXFileSystemSynchronizedRootGroup;
|
||||
path = "EXO-iOSUITests";
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
/* End PBXFileSystemSynchronizedRootGroup section */
|
||||
|
||||
/* Begin PBXFrameworksBuildPhase section */
|
||||
E09D166C2F44CA1E009C51A3 /* Frameworks */ = {
|
||||
isa = PBXFrameworksBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
E09D17542F44F359009C51A3 /* MLXLMCommon in Frameworks */,
|
||||
E09D17522F44F359009C51A3 /* MLXLLM in Frameworks */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
E09D16792F44CA20009C51A3 /* Frameworks */ = {
|
||||
isa = PBXFrameworksBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
E09D16832F44CA20009C51A3 /* Frameworks */ = {
|
||||
isa = PBXFrameworksBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXFrameworksBuildPhase section */
|
||||
|
||||
/* Begin PBXGroup section */
|
||||
E09D16662F44CA1E009C51A3 = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
E09D16712F44CA1E009C51A3 /* EXO-iOS */,
|
||||
E09D167F2F44CA20009C51A3 /* EXO-iOSTests */,
|
||||
E09D16892F44CA20009C51A3 /* EXO-iOSUITests */,
|
||||
E09D16702F44CA1E009C51A3 /* Products */,
|
||||
);
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
E09D16702F44CA1E009C51A3 /* Products */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */,
|
||||
E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */,
|
||||
E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */,
|
||||
);
|
||||
name = Products;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
/* End PBXGroup section */
|
||||
|
||||
/* Begin PBXNativeTarget section */
|
||||
E09D166E2F44CA1E009C51A3 /* EXO-iOS */ = {
|
||||
isa = PBXNativeTarget;
|
||||
buildConfigurationList = E09D16902F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOS" */;
|
||||
buildPhases = (
|
||||
E09D166B2F44CA1E009C51A3 /* Sources */,
|
||||
E09D166C2F44CA1E009C51A3 /* Frameworks */,
|
||||
E09D166D2F44CA1E009C51A3 /* Resources */,
|
||||
);
|
||||
buildRules = (
|
||||
);
|
||||
dependencies = (
|
||||
);
|
||||
fileSystemSynchronizedGroups = (
|
||||
E09D16712F44CA1E009C51A3 /* EXO-iOS */,
|
||||
);
|
||||
name = "EXO-iOS";
|
||||
packageProductDependencies = (
|
||||
E09D17512F44F359009C51A3 /* MLXLLM */,
|
||||
E09D17532F44F359009C51A3 /* MLXLMCommon */,
|
||||
);
|
||||
productName = "EXO-iOS";
|
||||
productReference = E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */;
|
||||
productType = "com.apple.product-type.application";
|
||||
};
|
||||
E09D167B2F44CA20009C51A3 /* EXO-iOSTests */ = {
|
||||
isa = PBXNativeTarget;
|
||||
buildConfigurationList = E09D16932F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSTests" */;
|
||||
buildPhases = (
|
||||
E09D16782F44CA20009C51A3 /* Sources */,
|
||||
E09D16792F44CA20009C51A3 /* Frameworks */,
|
||||
E09D167A2F44CA20009C51A3 /* Resources */,
|
||||
);
|
||||
buildRules = (
|
||||
);
|
||||
dependencies = (
|
||||
E09D167E2F44CA20009C51A3 /* PBXTargetDependency */,
|
||||
);
|
||||
fileSystemSynchronizedGroups = (
|
||||
E09D167F2F44CA20009C51A3 /* EXO-iOSTests */,
|
||||
);
|
||||
name = "EXO-iOSTests";
|
||||
packageProductDependencies = (
|
||||
);
|
||||
productName = "EXO-iOSTests";
|
||||
productReference = E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */;
|
||||
productType = "com.apple.product-type.bundle.unit-test";
|
||||
};
|
||||
E09D16852F44CA20009C51A3 /* EXO-iOSUITests */ = {
|
||||
isa = PBXNativeTarget;
|
||||
buildConfigurationList = E09D16962F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSUITests" */;
|
||||
buildPhases = (
|
||||
E09D16822F44CA20009C51A3 /* Sources */,
|
||||
E09D16832F44CA20009C51A3 /* Frameworks */,
|
||||
E09D16842F44CA20009C51A3 /* Resources */,
|
||||
);
|
||||
buildRules = (
|
||||
);
|
||||
dependencies = (
|
||||
E09D16882F44CA20009C51A3 /* PBXTargetDependency */,
|
||||
);
|
||||
fileSystemSynchronizedGroups = (
|
||||
E09D16892F44CA20009C51A3 /* EXO-iOSUITests */,
|
||||
);
|
||||
name = "EXO-iOSUITests";
|
||||
packageProductDependencies = (
|
||||
);
|
||||
productName = "EXO-iOSUITests";
|
||||
productReference = E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */;
|
||||
productType = "com.apple.product-type.bundle.ui-testing";
|
||||
};
|
||||
/* End PBXNativeTarget section */
|
||||
|
||||
/* Begin PBXProject section */
|
||||
E09D16672F44CA1E009C51A3 /* Project object */ = {
|
||||
isa = PBXProject;
|
||||
attributes = {
|
||||
BuildIndependentTargetsInParallel = 1;
|
||||
LastSwiftUpdateCheck = 2620;
|
||||
LastUpgradeCheck = 2620;
|
||||
TargetAttributes = {
|
||||
E09D166E2F44CA1E009C51A3 = {
|
||||
CreatedOnToolsVersion = 26.2;
|
||||
};
|
||||
E09D167B2F44CA20009C51A3 = {
|
||||
CreatedOnToolsVersion = 26.2;
|
||||
TestTargetID = E09D166E2F44CA1E009C51A3;
|
||||
};
|
||||
E09D16852F44CA20009C51A3 = {
|
||||
CreatedOnToolsVersion = 26.2;
|
||||
TestTargetID = E09D166E2F44CA1E009C51A3;
|
||||
};
|
||||
};
|
||||
};
|
||||
buildConfigurationList = E09D166A2F44CA1E009C51A3 /* Build configuration list for PBXProject "EXO-iOS" */;
|
||||
developmentRegion = en;
|
||||
hasScannedForEncodings = 0;
|
||||
knownRegions = (
|
||||
en,
|
||||
Base,
|
||||
);
|
||||
mainGroup = E09D16662F44CA1E009C51A3;
|
||||
minimizedProjectReferenceProxies = 1;
|
||||
packageReferences = (
|
||||
E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */,
|
||||
);
|
||||
preferredProjectObjectVersion = 77;
|
||||
productRefGroup = E09D16702F44CA1E009C51A3 /* Products */;
|
||||
projectDirPath = "";
|
||||
projectRoot = "";
|
||||
targets = (
|
||||
E09D166E2F44CA1E009C51A3 /* EXO-iOS */,
|
||||
E09D167B2F44CA20009C51A3 /* EXO-iOSTests */,
|
||||
E09D16852F44CA20009C51A3 /* EXO-iOSUITests */,
|
||||
);
|
||||
};
|
||||
/* End PBXProject section */
|
||||
|
||||
/* Begin PBXResourcesBuildPhase section */
|
||||
E09D166D2F44CA1E009C51A3 /* Resources */ = {
|
||||
isa = PBXResourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
E09D167A2F44CA20009C51A3 /* Resources */ = {
|
||||
isa = PBXResourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
E09D16842F44CA20009C51A3 /* Resources */ = {
|
||||
isa = PBXResourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXResourcesBuildPhase section */
|
||||
|
||||
/* Begin PBXSourcesBuildPhase section */
|
||||
E09D166B2F44CA1E009C51A3 /* Sources */ = {
|
||||
isa = PBXSourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
E09D16782F44CA20009C51A3 /* Sources */ = {
|
||||
isa = PBXSourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
E09D16822F44CA20009C51A3 /* Sources */ = {
|
||||
isa = PBXSourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXSourcesBuildPhase section */
|
||||
|
||||
/* Begin PBXTargetDependency section */
|
||||
E09D167E2F44CA20009C51A3 /* PBXTargetDependency */ = {
|
||||
isa = PBXTargetDependency;
|
||||
target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
|
||||
targetProxy = E09D167D2F44CA20009C51A3 /* PBXContainerItemProxy */;
|
||||
};
|
||||
E09D16882F44CA20009C51A3 /* PBXTargetDependency */ = {
|
||||
isa = PBXTargetDependency;
|
||||
target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
|
||||
targetProxy = E09D16872F44CA20009C51A3 /* PBXContainerItemProxy */;
|
||||
};
|
||||
/* End PBXTargetDependency section */
|
||||
|
||||
/* Begin XCBuildConfiguration section */
|
||||
E09D168E2F44CA20009C51A3 /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CLANG_ENABLE_OBJC_ARC = YES;
|
||||
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||
CLANG_WARN_COMMA = YES;
|
||||
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||
CLANG_WARN_EMPTY_BODY = YES;
|
||||
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||
CLANG_WARN_INT_CONVERSION = YES;
|
||||
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
|
||||
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||
COPY_PHASE_STRIP = NO;
|
||||
DEBUG_INFORMATION_FORMAT = dwarf;
|
||||
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||
ENABLE_TESTABILITY = YES;
|
||||
ENABLE_USER_SCRIPT_SANDBOXING = YES;
|
||||
GCC_C_LANGUAGE_STANDARD = gnu17;
|
||||
GCC_DYNAMIC_NO_PIC = NO;
|
||||
GCC_NO_COMMON_BLOCKS = YES;
|
||||
GCC_OPTIMIZATION_LEVEL = 0;
|
||||
GCC_PREPROCESSOR_DEFINITIONS = (
|
||||
"DEBUG=1",
|
||||
"$(inherited)",
|
||||
);
|
||||
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
|
||||
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
|
||||
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
|
||||
MTL_FAST_MATH = YES;
|
||||
ONLY_ACTIVE_ARCH = YES;
|
||||
SDKROOT = iphoneos;
|
||||
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
E09D168F2F44CA20009C51A3 /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CLANG_ENABLE_OBJC_ARC = YES;
|
||||
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||
CLANG_WARN_COMMA = YES;
|
||||
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||
CLANG_WARN_EMPTY_BODY = YES;
|
||||
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||
CLANG_WARN_INT_CONVERSION = YES;
|
||||
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
|
||||
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||
COPY_PHASE_STRIP = NO;
|
||||
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
|
||||
ENABLE_NS_ASSERTIONS = NO;
|
||||
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||
ENABLE_USER_SCRIPT_SANDBOXING = YES;
|
||||
GCC_C_LANGUAGE_STANDARD = gnu17;
|
||||
GCC_NO_COMMON_BLOCKS = YES;
|
||||
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
|
||||
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
|
||||
MTL_ENABLE_DEBUG_INFO = NO;
|
||||
MTL_FAST_MATH = YES;
|
||||
SDKROOT = iphoneos;
|
||||
SWIFT_COMPILATION_MODE = wholemodule;
|
||||
VALIDATE_PRODUCT = YES;
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
E09D16912F44CA20009C51A3 /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEVELOPMENT_TEAM = 3M3M67U93M;
|
||||
ENABLE_PREVIEWS = YES;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
INFOPLIST_FILE = "EXO-iOS/Info.plist";
|
||||
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
|
||||
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
|
||||
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
|
||||
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
|
||||
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
|
||||
LD_RUNPATH_SEARCH_PATHS = (
|
||||
"$(inherited)",
|
||||
"@executable_path/Frameworks",
|
||||
);
|
||||
MARKETING_VERSION = 1.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOS";
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
STRING_CATALOG_GENERATE_SYMBOLS = YES;
|
||||
SWIFT_APPROACHABLE_CONCURRENCY = YES;
|
||||
SWIFT_DEFAULT_ACTOR_ISOLATION = MainActor;
|
||||
SWIFT_EMIT_LOC_STRINGS = YES;
|
||||
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
E09D16922F44CA20009C51A3 /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEVELOPMENT_TEAM = 3M3M67U93M;
|
||||
ENABLE_PREVIEWS = YES;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
INFOPLIST_FILE = "EXO-iOS/Info.plist";
|
||||
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
|
||||
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
|
||||
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
|
||||
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
|
||||
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
|
||||
LD_RUNPATH_SEARCH_PATHS = (
|
||||
"$(inherited)",
|
||||
"@executable_path/Frameworks",
|
||||
);
|
||||
MARKETING_VERSION = 1.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOS";
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
STRING_CATALOG_GENERATE_SYMBOLS = YES;
|
||||
SWIFT_APPROACHABLE_CONCURRENCY = YES;
|
||||
SWIFT_DEFAULT_ACTOR_ISOLATION = MainActor;
|
||||
SWIFT_EMIT_LOC_STRINGS = YES;
|
||||
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
E09D16942F44CA20009C51A3 /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
BUNDLE_LOADER = "$(TEST_HOST)";
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
|
||||
MARKETING_VERSION = 1.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSTests";
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
STRING_CATALOG_GENERATE_SYMBOLS = NO;
|
||||
SWIFT_APPROACHABLE_CONCURRENCY = YES;
|
||||
SWIFT_EMIT_LOC_STRINGS = NO;
|
||||
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
TEST_HOST = "$(BUILT_PRODUCTS_DIR)/EXO-iOS.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO-iOS";
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
E09D16952F44CA20009C51A3 /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
BUNDLE_LOADER = "$(TEST_HOST)";
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
|
||||
MARKETING_VERSION = 1.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSTests";
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
STRING_CATALOG_GENERATE_SYMBOLS = NO;
|
||||
SWIFT_APPROACHABLE_CONCURRENCY = YES;
|
||||
SWIFT_EMIT_LOC_STRINGS = NO;
|
||||
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
TEST_HOST = "$(BUILT_PRODUCTS_DIR)/EXO-iOS.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO-iOS";
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
E09D16972F44CA20009C51A3 /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
MARKETING_VERSION = 1.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSUITests";
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
STRING_CATALOG_GENERATE_SYMBOLS = NO;
|
||||
SWIFT_APPROACHABLE_CONCURRENCY = YES;
|
||||
SWIFT_EMIT_LOC_STRINGS = NO;
|
||||
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
TEST_TARGET_NAME = "EXO-iOS";
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
E09D16982F44CA20009C51A3 /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
MARKETING_VERSION = 1.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSUITests";
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
STRING_CATALOG_GENERATE_SYMBOLS = NO;
|
||||
SWIFT_APPROACHABLE_CONCURRENCY = YES;
|
||||
SWIFT_EMIT_LOC_STRINGS = NO;
|
||||
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
TEST_TARGET_NAME = "EXO-iOS";
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
/* End XCBuildConfiguration section */
|
||||
|
||||
/* Begin XCConfigurationList section */
|
||||
E09D166A2F44CA1E009C51A3 /* Build configuration list for PBXProject "EXO-iOS" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
E09D168E2F44CA20009C51A3 /* Debug */,
|
||||
E09D168F2F44CA20009C51A3 /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
E09D16902F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOS" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
E09D16912F44CA20009C51A3 /* Debug */,
|
||||
E09D16922F44CA20009C51A3 /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
E09D16932F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSTests" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
E09D16942F44CA20009C51A3 /* Debug */,
|
||||
E09D16952F44CA20009C51A3 /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
E09D16962F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSUITests" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
E09D16972F44CA20009C51A3 /* Debug */,
|
||||
E09D16982F44CA20009C51A3 /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
/* End XCConfigurationList section */
|
||||
|
||||
/* Begin XCRemoteSwiftPackageReference section */
|
||||
E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */ = {
|
||||
isa = XCRemoteSwiftPackageReference;
|
||||
repositoryURL = "https://github.com/ml-explore/mlx-swift-lm";
|
||||
requirement = {
|
||||
kind = upToNextMajorVersion;
|
||||
minimumVersion = 2.30.3;
|
||||
};
|
||||
};
|
||||
/* End XCRemoteSwiftPackageReference section */
|
||||
|
||||
/* Begin XCSwiftPackageProductDependency section */
|
||||
E09D17512F44F359009C51A3 /* MLXLLM */ = {
|
||||
isa = XCSwiftPackageProductDependency;
|
||||
package = E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
|
||||
productName = MLXLLM;
|
||||
};
|
||||
E09D17532F44F359009C51A3 /* MLXLMCommon */ = {
|
||||
isa = XCSwiftPackageProductDependency;
|
||||
package = E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
|
||||
productName = MLXLMCommon;
|
||||
};
|
||||
/* End XCSwiftPackageProductDependency section */
|
||||
};
|
||||
rootObject = E09D16672F44CA1E009C51A3 /* Project object */;
|
||||
}
|
||||
7
app/EXO-iOS/EXO-iOS.xcodeproj/project.xcworkspace/contents.xcworkspacedata
generated
Normal file
7
app/EXO-iOS/EXO-iOS.xcodeproj/project.xcworkspace/contents.xcworkspacedata
generated
Normal file
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Workspace
|
||||
version = "1.0">
|
||||
<FileRef
|
||||
location = "self:">
|
||||
</FileRef>
|
||||
</Workspace>
|
||||
@@ -0,0 +1,60 @@
|
||||
{
|
||||
"originHash" : "facc0ac7c70363ea20f6cd1235de91dea6b06f0d00190946045a6c8ae753abc2",
|
||||
"pins" : [
|
||||
{
|
||||
"identity" : "mlx-swift",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/ml-explore/mlx-swift",
|
||||
"state" : {
|
||||
"revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d",
|
||||
"version" : "0.30.6"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "mlx-swift-lm",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/ml-explore/mlx-swift-lm",
|
||||
"state" : {
|
||||
"revision" : "360c5052b81cc154b04ee0933597a4ad6db4b8ae",
|
||||
"version" : "2.30.3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-collections",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-collections.git",
|
||||
"state" : {
|
||||
"revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e",
|
||||
"version" : "1.3.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-jinja",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/huggingface/swift-jinja.git",
|
||||
"state" : {
|
||||
"revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0",
|
||||
"version" : "2.3.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-numerics",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-numerics",
|
||||
"state" : {
|
||||
"revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2",
|
||||
"version" : "1.1.1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-transformers",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/huggingface/swift-transformers",
|
||||
"state" : {
|
||||
"revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0",
|
||||
"version" : "1.1.6"
|
||||
}
|
||||
}
|
||||
],
|
||||
"version" : 3
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"colors" : [
|
||||
{
|
||||
"color" : {
|
||||
"color-space" : "srgb",
|
||||
"components" : {
|
||||
"alpha" : "1.000",
|
||||
"blue" : "0x00",
|
||||
"green" : "0xD7",
|
||||
"red" : "0xFF"
|
||||
}
|
||||
},
|
||||
"idiom" : "universal"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 10 KiB |
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"images" : [
|
||||
{
|
||||
"filename" : "AppIcon.png",
|
||||
"idiom" : "universal",
|
||||
"platform" : "ios",
|
||||
"size" : "1024x1024"
|
||||
},
|
||||
{
|
||||
"appearances" : [
|
||||
{
|
||||
"appearance" : "luminosity",
|
||||
"value" : "dark"
|
||||
}
|
||||
],
|
||||
"filename" : "AppIcon.png",
|
||||
"idiom" : "universal",
|
||||
"platform" : "ios",
|
||||
"size" : "1024x1024"
|
||||
},
|
||||
{
|
||||
"appearances" : [
|
||||
{
|
||||
"appearance" : "luminosity",
|
||||
"value" : "tinted"
|
||||
}
|
||||
],
|
||||
"filename" : "AppIcon.png",
|
||||
"idiom" : "universal",
|
||||
"platform" : "ios",
|
||||
"size" : "1024x1024"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
6
app/EXO-iOS/EXO-iOS/Assets.xcassets/Contents.json
Normal file
6
app/EXO-iOS/EXO-iOS/Assets.xcassets/Contents.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
21
app/EXO-iOS/EXO-iOS/Assets.xcassets/ExoLogo.imageset/Contents.json
vendored
Normal file
21
app/EXO-iOS/EXO-iOS/Assets.xcassets/ExoLogo.imageset/Contents.json
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"images" : [
|
||||
{
|
||||
"filename" : "exo-logo.png",
|
||||
"idiom" : "universal",
|
||||
"scale" : "1x"
|
||||
},
|
||||
{
|
||||
"idiom" : "universal",
|
||||
"scale" : "2x"
|
||||
},
|
||||
{
|
||||
"idiom" : "universal",
|
||||
"scale" : "3x"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
BIN
app/EXO-iOS/EXO-iOS/Assets.xcassets/ExoLogo.imageset/exo-logo.png
vendored
Normal file
BIN
app/EXO-iOS/EXO-iOS/Assets.xcassets/ExoLogo.imageset/exo-logo.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.6 KiB |
8
app/EXO-iOS/EXO-iOS/EXO-iOS.entitlements
Normal file
8
app/EXO-iOS/EXO-iOS/EXO-iOS.entitlements
Normal file
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>com.apple.developer.kernel.increased-memory-limit</key>
|
||||
<true/>
|
||||
</dict>
|
||||
</plist>
|
||||
67
app/EXO-iOS/EXO-iOS/EXO_iOSApp.swift
Normal file
67
app/EXO-iOS/EXO-iOS/EXO_iOSApp.swift
Normal file
@@ -0,0 +1,67 @@
|
||||
import SwiftUI
|
||||
import UIKit
|
||||
|
||||
@main
|
||||
struct EXO_iOSApp: App {
|
||||
@State private var clusterService = ClusterService()
|
||||
@State private var discoveryService = DiscoveryService()
|
||||
@State private var localInferenceService = LocalInferenceService()
|
||||
@State private var chatService: ChatService?
|
||||
|
||||
init() {
|
||||
let darkGray = UIColor(red: 0x1F / 255.0, green: 0x1F / 255.0, blue: 0x1F / 255.0, alpha: 1)
|
||||
let yellow = UIColor(red: 0xFF / 255.0, green: 0xD7 / 255.0, blue: 0x00 / 255.0, alpha: 1)
|
||||
|
||||
let navAppearance = UINavigationBarAppearance()
|
||||
navAppearance.configureWithOpaqueBackground()
|
||||
navAppearance.backgroundColor = darkGray
|
||||
navAppearance.titleTextAttributes = [
|
||||
.foregroundColor: yellow,
|
||||
.font: UIFont.monospacedSystemFont(ofSize: 17, weight: .semibold),
|
||||
]
|
||||
navAppearance.largeTitleTextAttributes = [
|
||||
.foregroundColor: yellow,
|
||||
.font: UIFont.monospacedSystemFont(ofSize: 34, weight: .bold),
|
||||
]
|
||||
|
||||
UINavigationBar.appearance().standardAppearance = navAppearance
|
||||
UINavigationBar.appearance().compactAppearance = navAppearance
|
||||
UINavigationBar.appearance().scrollEdgeAppearance = navAppearance
|
||||
UINavigationBar.appearance().tintColor = yellow
|
||||
}
|
||||
|
||||
var body: some Scene {
|
||||
WindowGroup {
|
||||
if let chatService {
|
||||
RootView()
|
||||
.environment(clusterService)
|
||||
.environment(discoveryService)
|
||||
.environment(chatService)
|
||||
.environment(localInferenceService)
|
||||
.preferredColorScheme(.dark)
|
||||
.task {
|
||||
await clusterService.attemptAutoReconnect()
|
||||
discoveryService.startBrowsing()
|
||||
await localInferenceService.prepareModel()
|
||||
}
|
||||
.onChange(of: discoveryService.discoveredClusters) { _, clusters in
|
||||
guard !clusterService.isConnected,
|
||||
case .disconnected = clusterService.connectionState,
|
||||
let first = clusters.first
|
||||
else { return }
|
||||
Task {
|
||||
await clusterService.connectToDiscoveredCluster(
|
||||
first, using: discoveryService)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Color.exoBlack.onAppear {
|
||||
chatService = ChatService(
|
||||
clusterService: clusterService,
|
||||
localInferenceService: localInferenceService
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
19
app/EXO-iOS/EXO-iOS/Info.plist
Normal file
19
app/EXO-iOS/EXO-iOS/Info.plist
Normal file
@@ -0,0 +1,19 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>UIUserInterfaceStyle</key>
|
||||
<string>Dark</string>
|
||||
<key>CFBundleDisplayName</key>
|
||||
<string>EXO</string>
|
||||
<key>NSLocalNetworkUsageDescription</key>
|
||||
<string>EXO needs local network access to connect to your EXO cluster.</string>
|
||||
<key>NSBonjourServices</key>
|
||||
<array>
|
||||
<string>_exo._tcp</string>
|
||||
<string>_p2p._tcp</string>
|
||||
<string>_p2p._udp</string>
|
||||
<string>_libp2p._udp</string>
|
||||
</array>
|
||||
</dict>
|
||||
</plist>
|
||||
129
app/EXO-iOS/EXO-iOS/Models/ChatCompletionTypes.swift
Normal file
129
app/EXO-iOS/EXO-iOS/Models/ChatCompletionTypes.swift
Normal file
@@ -0,0 +1,129 @@
|
||||
import Foundation
|
||||
|
||||
// MARK: - Request
|
||||
|
||||
struct ChatCompletionRequest: Encodable {
|
||||
let model: String
|
||||
let messages: [ChatCompletionMessageParam]
|
||||
let stream: Bool
|
||||
let maxTokens: Int?
|
||||
let temperature: Double?
|
||||
|
||||
enum CodingKeys: String, CodingKey {
|
||||
case model, messages, stream, temperature
|
||||
case maxTokens = "max_tokens"
|
||||
}
|
||||
}
|
||||
|
||||
struct ChatCompletionMessageParam: Encodable {
|
||||
let role: String
|
||||
let content: String
|
||||
}
|
||||
|
||||
// MARK: - Streaming Response
|
||||
|
||||
struct ChatCompletionChunk: Decodable {
|
||||
let id: String
|
||||
let model: String?
|
||||
let choices: [StreamingChoice]
|
||||
let usage: ChunkUsage?
|
||||
|
||||
init(id: String, model: String?, choices: [StreamingChoice], usage: ChunkUsage?) {
|
||||
self.id = id
|
||||
self.model = model
|
||||
self.choices = choices
|
||||
self.usage = usage
|
||||
}
|
||||
}
|
||||
|
||||
struct StreamingChoice: Decodable {
|
||||
let index: Int
|
||||
let delta: Delta
|
||||
let finishReason: String?
|
||||
|
||||
enum CodingKeys: String, CodingKey {
|
||||
case index, delta
|
||||
case finishReason = "finish_reason"
|
||||
}
|
||||
|
||||
init(index: Int, delta: Delta, finishReason: String?) {
|
||||
self.index = index
|
||||
self.delta = delta
|
||||
self.finishReason = finishReason
|
||||
}
|
||||
}
|
||||
|
||||
struct Delta: Decodable {
|
||||
let role: String?
|
||||
let content: String?
|
||||
|
||||
init(role: String?, content: String?) {
|
||||
self.role = role
|
||||
self.content = content
|
||||
}
|
||||
}
|
||||
|
||||
struct ChunkUsage: Decodable {
|
||||
let promptTokens: Int?
|
||||
let completionTokens: Int?
|
||||
let totalTokens: Int?
|
||||
|
||||
enum CodingKeys: String, CodingKey {
|
||||
case promptTokens = "prompt_tokens"
|
||||
case completionTokens = "completion_tokens"
|
||||
case totalTokens = "total_tokens"
|
||||
}
|
||||
|
||||
init(promptTokens: Int?, completionTokens: Int?, totalTokens: Int?) {
|
||||
self.promptTokens = promptTokens
|
||||
self.completionTokens = completionTokens
|
||||
self.totalTokens = totalTokens
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Non-Streaming Response
|
||||
|
||||
struct ChatCompletionResponse: Decodable {
|
||||
let id: String
|
||||
let model: String?
|
||||
let choices: [ResponseChoice]
|
||||
}
|
||||
|
||||
struct ResponseChoice: Decodable {
|
||||
let index: Int
|
||||
let message: ResponseMessage
|
||||
let finishReason: String?
|
||||
|
||||
enum CodingKeys: String, CodingKey {
|
||||
case index, message
|
||||
case finishReason = "finish_reason"
|
||||
}
|
||||
}
|
||||
|
||||
struct ResponseMessage: Decodable {
|
||||
let role: String?
|
||||
let content: String?
|
||||
}
|
||||
|
||||
// MARK: - Models List
|
||||
|
||||
struct ModelListResponse: Decodable {
|
||||
let data: [ModelInfo]
|
||||
}
|
||||
|
||||
struct ModelInfo: Decodable, Identifiable {
|
||||
let id: String
|
||||
let name: String?
|
||||
}
|
||||
|
||||
// MARK: - Error
|
||||
|
||||
struct APIErrorResponse: Decodable {
|
||||
let error: APIErrorInfo
|
||||
}
|
||||
|
||||
struct APIErrorInfo: Decodable {
|
||||
let message: String
|
||||
let type: String?
|
||||
let code: Int?
|
||||
}
|
||||
26
app/EXO-iOS/EXO-iOS/Models/ChatMessage.swift
Normal file
26
app/EXO-iOS/EXO-iOS/Models/ChatMessage.swift
Normal file
@@ -0,0 +1,26 @@
|
||||
import Foundation
|
||||
|
||||
struct ChatMessage: Identifiable, Equatable {
|
||||
let id: UUID
|
||||
let role: Role
|
||||
var content: String
|
||||
let timestamp: Date
|
||||
var isStreaming: Bool
|
||||
|
||||
enum Role: String, Codable {
|
||||
case user
|
||||
case assistant
|
||||
case system
|
||||
}
|
||||
|
||||
init(
|
||||
id: UUID = UUID(), role: Role, content: String, timestamp: Date = Date(),
|
||||
isStreaming: Bool = false
|
||||
) {
|
||||
self.id = id
|
||||
self.role = role
|
||||
self.content = content
|
||||
self.timestamp = timestamp
|
||||
self.isStreaming = isStreaming
|
||||
}
|
||||
}
|
||||
11
app/EXO-iOS/EXO-iOS/Models/ConnectionInfo.swift
Normal file
11
app/EXO-iOS/EXO-iOS/Models/ConnectionInfo.swift
Normal file
@@ -0,0 +1,11 @@
|
||||
import Foundation
|
||||
|
||||
struct ConnectionInfo: Codable, Equatable {
|
||||
let host: String
|
||||
let port: Int
|
||||
let nodeId: String?
|
||||
|
||||
var baseURL: URL { URL(string: "http://\(host):\(port)")! }
|
||||
|
||||
static let defaultPort = 52415
|
||||
}
|
||||
34
app/EXO-iOS/EXO-iOS/Models/Conversation.swift
Normal file
34
app/EXO-iOS/EXO-iOS/Models/Conversation.swift
Normal file
@@ -0,0 +1,34 @@
|
||||
import Foundation
|
||||
|
||||
struct Conversation: Identifiable, Codable, Equatable {
|
||||
let id: UUID
|
||||
var title: String
|
||||
var messages: [StoredMessage]
|
||||
var modelId: String?
|
||||
let createdAt: Date
|
||||
|
||||
init(
|
||||
id: UUID = UUID(), title: String = "New Chat", messages: [StoredMessage] = [],
|
||||
modelId: String? = nil, createdAt: Date = Date()
|
||||
) {
|
||||
self.id = id
|
||||
self.title = title
|
||||
self.messages = messages
|
||||
self.modelId = modelId
|
||||
self.createdAt = createdAt
|
||||
}
|
||||
}
|
||||
|
||||
struct StoredMessage: Identifiable, Codable, Equatable {
|
||||
let id: UUID
|
||||
let role: String
|
||||
var content: String
|
||||
let timestamp: Date
|
||||
|
||||
init(id: UUID = UUID(), role: String, content: String, timestamp: Date = Date()) {
|
||||
self.id = id
|
||||
self.role = role
|
||||
self.content = content
|
||||
self.timestamp = timestamp
|
||||
}
|
||||
}
|
||||
227
app/EXO-iOS/EXO-iOS/Services/ChatService.swift
Normal file
227
app/EXO-iOS/EXO-iOS/Services/ChatService.swift
Normal file
@@ -0,0 +1,227 @@
|
||||
import Foundation
|
||||
|
||||
@Observable
|
||||
@MainActor
|
||||
final class ChatService {
|
||||
var conversations: [Conversation] = []
|
||||
var activeConversationId: UUID?
|
||||
private(set) var isGenerating: Bool = false
|
||||
private var currentGenerationTask: Task<Void, Never>?
|
||||
|
||||
private let clusterService: ClusterService
|
||||
private let localInferenceService: LocalInferenceService
|
||||
|
||||
var canSendMessage: Bool {
|
||||
clusterService.isConnected || localInferenceService.isAvailable
|
||||
}
|
||||
|
||||
var activeConversation: Conversation? {
|
||||
guard let id = activeConversationId else { return nil }
|
||||
return conversations.first { $0.id == id }
|
||||
}
|
||||
|
||||
var activeMessages: [ChatMessage] {
|
||||
guard let conversation = activeConversation else { return [] }
|
||||
return conversation.messages.map { stored in
|
||||
ChatMessage(
|
||||
id: stored.id,
|
||||
role: ChatMessage.Role(rawValue: stored.role) ?? .user,
|
||||
content: stored.content,
|
||||
timestamp: stored.timestamp
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
init(clusterService: ClusterService, localInferenceService: LocalInferenceService) {
|
||||
self.clusterService = clusterService
|
||||
self.localInferenceService = localInferenceService
|
||||
loadConversations()
|
||||
}
|
||||
|
||||
// MARK: - Conversation Management
|
||||
|
||||
func createConversation(modelId: String? = nil) {
|
||||
let conversation = Conversation(
|
||||
modelId: modelId ?? clusterService.availableModels.first?.id)
|
||||
conversations.insert(conversation, at: 0)
|
||||
activeConversationId = conversation.id
|
||||
saveConversations()
|
||||
}
|
||||
|
||||
func deleteConversation(id: UUID) {
|
||||
conversations.removeAll { $0.id == id }
|
||||
if activeConversationId == id {
|
||||
activeConversationId = conversations.first?.id
|
||||
}
|
||||
saveConversations()
|
||||
}
|
||||
|
||||
func setActiveConversation(id: UUID) {
|
||||
activeConversationId = id
|
||||
}
|
||||
|
||||
func setModelForActiveConversation(_ modelId: String) {
|
||||
guard let index = conversations.firstIndex(where: { $0.id == activeConversationId }) else {
|
||||
return
|
||||
}
|
||||
conversations[index].modelId = modelId
|
||||
saveConversations()
|
||||
}
|
||||
|
||||
// MARK: - Messaging
|
||||
|
||||
func sendMessage(_ text: String) {
|
||||
guard !text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { return }
|
||||
|
||||
if activeConversation == nil {
|
||||
createConversation()
|
||||
}
|
||||
|
||||
guard let index = conversations.firstIndex(where: { $0.id == activeConversationId }) else {
|
||||
return
|
||||
}
|
||||
|
||||
let userMessage = StoredMessage(role: "user", content: text)
|
||||
conversations[index].messages.append(userMessage)
|
||||
|
||||
if conversations[index].title == "New Chat" {
|
||||
let preview = String(text.prefix(40))
|
||||
conversations[index].title = preview + (text.count > 40 ? "..." : "")
|
||||
}
|
||||
|
||||
let modelId: String
|
||||
if clusterService.isConnected {
|
||||
guard
|
||||
let clusterId = conversations[index].modelId
|
||||
?? clusterService.availableModels.first?.id
|
||||
else {
|
||||
let errorMessage = StoredMessage(
|
||||
role: "assistant", content: "No model selected. Please select a model first.")
|
||||
conversations[index].messages.append(errorMessage)
|
||||
saveConversations()
|
||||
return
|
||||
}
|
||||
modelId = clusterId
|
||||
} else if localInferenceService.isAvailable {
|
||||
modelId = localInferenceService.defaultModelId
|
||||
} else {
|
||||
let errorMessage = StoredMessage(
|
||||
role: "assistant",
|
||||
content: "Not connected to a cluster and local model is not available.")
|
||||
conversations[index].messages.append(errorMessage)
|
||||
saveConversations()
|
||||
return
|
||||
}
|
||||
|
||||
conversations[index].modelId = modelId
|
||||
|
||||
let assistantMessageId = UUID()
|
||||
let assistantMessage = StoredMessage(
|
||||
id: assistantMessageId, role: "assistant", content: "", timestamp: Date())
|
||||
conversations[index].messages.append(assistantMessage)
|
||||
|
||||
let messagesForAPI = conversations[index].messages.dropLast().map { stored in
|
||||
ChatCompletionMessageParam(role: stored.role, content: stored.content)
|
||||
}
|
||||
|
||||
let request = ChatCompletionRequest(
|
||||
model: modelId,
|
||||
messages: Array(messagesForAPI),
|
||||
stream: true,
|
||||
maxTokens: 4096,
|
||||
temperature: nil
|
||||
)
|
||||
|
||||
let conversationId = conversations[index].id
|
||||
|
||||
isGenerating = true
|
||||
currentGenerationTask = Task { [weak self] in
|
||||
guard let self else { return }
|
||||
await self.performStreaming(
|
||||
request: request, conversationId: conversationId,
|
||||
assistantMessageId: assistantMessageId)
|
||||
}
|
||||
|
||||
saveConversations()
|
||||
}
|
||||
|
||||
func cancelGeneration() {
|
||||
currentGenerationTask?.cancel()
|
||||
currentGenerationTask = nil
|
||||
localInferenceService.cancelGeneration()
|
||||
isGenerating = false
|
||||
}
|
||||
|
||||
// MARK: - Streaming
|
||||
|
||||
private func performStreaming(
|
||||
request: ChatCompletionRequest, conversationId: UUID, assistantMessageId: UUID
|
||||
) async {
|
||||
defer {
|
||||
isGenerating = false
|
||||
currentGenerationTask = nil
|
||||
saveConversations()
|
||||
}
|
||||
|
||||
do {
|
||||
let stream =
|
||||
clusterService.isConnected
|
||||
? clusterService.streamChatCompletion(request: request)
|
||||
: localInferenceService.streamChatCompletion(request: request)
|
||||
for try await chunk in stream {
|
||||
guard !Task.isCancelled else { return }
|
||||
guard let content = chunk.choices.first?.delta.content, !content.isEmpty else {
|
||||
continue
|
||||
}
|
||||
|
||||
if let convIndex = conversations.firstIndex(where: { $0.id == conversationId }),
|
||||
let msgIndex = conversations[convIndex].messages.firstIndex(where: {
|
||||
$0.id == assistantMessageId
|
||||
})
|
||||
{
|
||||
conversations[convIndex].messages[msgIndex].content += content
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
if !Task.isCancelled {
|
||||
if let convIndex = conversations.firstIndex(where: { $0.id == conversationId }),
|
||||
let msgIndex = conversations[convIndex].messages.firstIndex(where: {
|
||||
$0.id == assistantMessageId
|
||||
})
|
||||
{
|
||||
if conversations[convIndex].messages[msgIndex].content.isEmpty {
|
||||
conversations[convIndex].messages[msgIndex].content =
|
||||
"Error: \(error.localizedDescription)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Persistence
|
||||
|
||||
private static var storageURL: URL {
|
||||
let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)
|
||||
.first!
|
||||
return documents.appendingPathComponent("exo_conversations.json")
|
||||
}
|
||||
|
||||
private func saveConversations() {
|
||||
do {
|
||||
let data = try JSONEncoder().encode(conversations)
|
||||
try data.write(to: Self.storageURL, options: .atomic)
|
||||
} catch {
|
||||
// Save failed silently
|
||||
}
|
||||
}
|
||||
|
||||
private func loadConversations() {
|
||||
do {
|
||||
let data = try Data(contentsOf: Self.storageURL)
|
||||
conversations = try JSONDecoder().decode([Conversation].self, from: data)
|
||||
activeConversationId = conversations.first?.id
|
||||
} catch {
|
||||
conversations = []
|
||||
}
|
||||
}
|
||||
}
|
||||
200
app/EXO-iOS/EXO-iOS/Services/ClusterService.swift
Normal file
200
app/EXO-iOS/EXO-iOS/Services/ClusterService.swift
Normal file
@@ -0,0 +1,200 @@
|
||||
import Foundation
|
||||
|
||||
enum ConnectionState: Equatable {
|
||||
case disconnected
|
||||
case connecting
|
||||
case connected(ConnectionInfo)
|
||||
}
|
||||
|
||||
struct ModelOption: Identifiable, Equatable {
|
||||
let id: String
|
||||
let displayName: String
|
||||
}
|
||||
|
||||
@Observable
|
||||
@MainActor
|
||||
final class ClusterService {
|
||||
private(set) var connectionState: ConnectionState = .disconnected
|
||||
private(set) var availableModels: [ModelOption] = []
|
||||
private(set) var lastError: String?
|
||||
|
||||
private let session: URLSession
|
||||
private let decoder: JSONDecoder
|
||||
private var pollingTask: Task<Void, Never>?
|
||||
|
||||
private static let connectionInfoKey = "exo_last_connection_info"
|
||||
|
||||
var isConnected: Bool {
|
||||
if case .connected = connectionState { return true }
|
||||
return false
|
||||
}
|
||||
|
||||
var currentConnection: ConnectionInfo? {
|
||||
if case .connected(let info) = connectionState { return info }
|
||||
return nil
|
||||
}
|
||||
|
||||
init(session: URLSession = .shared) {
|
||||
self.session = session
|
||||
let decoder = JSONDecoder()
|
||||
self.decoder = decoder
|
||||
}
|
||||
|
||||
// MARK: - Connection
|
||||
|
||||
func connect(to info: ConnectionInfo) async {
|
||||
connectionState = .connecting
|
||||
lastError = nil
|
||||
|
||||
do {
|
||||
let url = info.baseURL.appendingPathComponent("node_id")
|
||||
var request = URLRequest(url: url)
|
||||
request.timeoutInterval = 5
|
||||
request.cachePolicy = .reloadIgnoringLocalCacheData
|
||||
let (_, response) = try await session.data(for: request)
|
||||
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
(200..<300).contains(httpResponse.statusCode)
|
||||
else {
|
||||
throw URLError(.badServerResponse)
|
||||
}
|
||||
|
||||
connectionState = .connected(info)
|
||||
persistConnection(info)
|
||||
startPolling()
|
||||
await fetchModels(baseURL: info.baseURL)
|
||||
} catch {
|
||||
connectionState = .disconnected
|
||||
lastError = "Could not connect to \(info.host):\(info.port)"
|
||||
}
|
||||
}
|
||||
|
||||
func connectToDiscoveredCluster(
|
||||
_ cluster: DiscoveredCluster, using discoveryService: DiscoveryService
|
||||
) async {
|
||||
guard case .disconnected = connectionState else { return }
|
||||
connectionState = .connecting
|
||||
lastError = nil
|
||||
|
||||
guard let info = await discoveryService.resolve(cluster) else {
|
||||
connectionState = .disconnected
|
||||
lastError = "Could not resolve \(cluster.name)"
|
||||
return
|
||||
}
|
||||
connectionState = .disconnected // reset so connect() can proceed
|
||||
await connect(to: info)
|
||||
}
|
||||
|
||||
func disconnect() {
|
||||
stopPolling()
|
||||
connectionState = .disconnected
|
||||
availableModels = []
|
||||
lastError = nil
|
||||
}
|
||||
|
||||
func attemptAutoReconnect() async {
|
||||
guard case .disconnected = connectionState,
|
||||
let info = loadPersistedConnection()
|
||||
else { return }
|
||||
await connect(to: info)
|
||||
}
|
||||
|
||||
// MARK: - Polling
|
||||
|
||||
private func startPolling(interval: TimeInterval = 2.0) {
|
||||
stopPolling()
|
||||
pollingTask = Task { [weak self] in
|
||||
while !Task.isCancelled {
|
||||
try? await Task.sleep(for: .seconds(interval))
|
||||
guard let self, !Task.isCancelled else { return }
|
||||
guard let connection = self.currentConnection else { return }
|
||||
await self.fetchModels(baseURL: connection.baseURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func stopPolling() {
|
||||
pollingTask?.cancel()
|
||||
pollingTask = nil
|
||||
}
|
||||
|
||||
// MARK: - API
|
||||
|
||||
private func fetchModels(baseURL: URL) async {
|
||||
do {
|
||||
let url = baseURL.appendingPathComponent("models")
|
||||
var request = URLRequest(url: url)
|
||||
request.cachePolicy = .reloadIgnoringLocalCacheData
|
||||
let (data, response) = try await session.data(for: request)
|
||||
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
(200..<300).contains(httpResponse.statusCode)
|
||||
else { return }
|
||||
|
||||
let list = try decoder.decode(ModelListResponse.self, from: data)
|
||||
availableModels = list.data.map {
|
||||
ModelOption(id: $0.id, displayName: $0.name ?? $0.id)
|
||||
}
|
||||
} catch {
|
||||
// Models fetch failed silently — will retry on next poll
|
||||
}
|
||||
}
|
||||
|
||||
func streamChatCompletion(request body: ChatCompletionRequest) -> AsyncThrowingStream<
|
||||
ChatCompletionChunk, Error
|
||||
> {
|
||||
AsyncThrowingStream { continuation in
|
||||
let task = Task { [weak self] in
|
||||
guard let self, let connection = self.currentConnection else {
|
||||
continuation.finish(throwing: URLError(.notConnectedToInternet))
|
||||
return
|
||||
}
|
||||
|
||||
do {
|
||||
let url = connection.baseURL.appendingPathComponent("v1/chat/completions")
|
||||
var request = URLRequest(url: url)
|
||||
request.httpMethod = "POST"
|
||||
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
|
||||
request.httpBody = try JSONEncoder().encode(body)
|
||||
|
||||
let (bytes, response) = try await self.session.bytes(for: request)
|
||||
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
(200..<300).contains(httpResponse.statusCode)
|
||||
else {
|
||||
continuation.finish(throwing: URLError(.badServerResponse))
|
||||
return
|
||||
}
|
||||
|
||||
let parser = SSEStreamParser<ChatCompletionChunk>(
|
||||
bytes: bytes, decoder: self.decoder)
|
||||
for try await chunk in parser {
|
||||
continuation.yield(chunk)
|
||||
}
|
||||
continuation.finish()
|
||||
} catch {
|
||||
continuation.finish(throwing: error)
|
||||
}
|
||||
}
|
||||
|
||||
continuation.onTermination = { _ in
|
||||
task.cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Persistence
|
||||
|
||||
private func persistConnection(_ info: ConnectionInfo) {
|
||||
if let data = try? JSONEncoder().encode(info) {
|
||||
UserDefaults.standard.set(data, forKey: Self.connectionInfoKey)
|
||||
}
|
||||
}
|
||||
|
||||
private func loadPersistedConnection() -> ConnectionInfo? {
|
||||
guard let data = UserDefaults.standard.data(forKey: Self.connectionInfoKey) else {
|
||||
return nil
|
||||
}
|
||||
return try? JSONDecoder().decode(ConnectionInfo.self, from: data)
|
||||
}
|
||||
}
|
||||
123
app/EXO-iOS/EXO-iOS/Services/DiscoveryService.swift
Normal file
123
app/EXO-iOS/EXO-iOS/Services/DiscoveryService.swift
Normal file
@@ -0,0 +1,123 @@
|
||||
import Foundation
|
||||
import Network
|
||||
import os
|
||||
|
||||
struct DiscoveredCluster: Identifiable, Equatable {
|
||||
let id: String
|
||||
let name: String
|
||||
let endpoint: NWEndpoint
|
||||
|
||||
static func == (lhs: DiscoveredCluster, rhs: DiscoveredCluster) -> Bool {
|
||||
lhs.id == rhs.id && lhs.name == rhs.name
|
||||
}
|
||||
}
|
||||
|
||||
@Observable
|
||||
@MainActor
|
||||
final class DiscoveryService {
|
||||
private(set) var discoveredClusters: [DiscoveredCluster] = []
|
||||
private(set) var isSearching = false
|
||||
|
||||
private var browser: NWBrowser?
|
||||
|
||||
func startBrowsing() {
|
||||
guard browser == nil else { return }
|
||||
|
||||
let browser = NWBrowser(for: .bonjour(type: "_exo._tcp", domain: nil), using: .tcp)
|
||||
|
||||
browser.stateUpdateHandler = { [weak self] state in
|
||||
guard let service = self else { return }
|
||||
Task { @MainActor in
|
||||
switch state {
|
||||
case .ready:
|
||||
service.isSearching = true
|
||||
case .failed, .cancelled:
|
||||
service.isSearching = false
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
browser.browseResultsChangedHandler = { [weak self] results, _ in
|
||||
guard let service = self else { return }
|
||||
Task { @MainActor in
|
||||
service.discoveredClusters = results.compactMap { result in
|
||||
guard case .service(let name, _, _, _) = result.endpoint else {
|
||||
return nil
|
||||
}
|
||||
return DiscoveredCluster(
|
||||
id: name,
|
||||
name: name,
|
||||
endpoint: result.endpoint
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
browser.start(queue: .main)
|
||||
self.browser = browser
|
||||
}
|
||||
|
||||
func stopBrowsing() {
|
||||
browser?.cancel()
|
||||
browser = nil
|
||||
isSearching = false
|
||||
discoveredClusters = []
|
||||
}
|
||||
|
||||
/// Resolve a discovered Bonjour endpoint to an IP address and port, then return a ConnectionInfo.
|
||||
func resolve(_ cluster: DiscoveredCluster) async -> ConnectionInfo? {
|
||||
await withCheckedContinuation { continuation in
|
||||
let didResume = OSAllocatedUnfairLock(initialState: false)
|
||||
let connection = NWConnection(to: cluster.endpoint, using: .tcp)
|
||||
connection.stateUpdateHandler = { state in
|
||||
guard
|
||||
didResume.withLock({
|
||||
guard !$0 else { return false }
|
||||
$0 = true
|
||||
return true
|
||||
})
|
||||
else { return }
|
||||
switch state {
|
||||
case .ready:
|
||||
if let innerEndpoint = connection.currentPath?.remoteEndpoint,
|
||||
case .hostPort(let host, let port) = innerEndpoint
|
||||
{
|
||||
var hostString: String
|
||||
switch host {
|
||||
case .ipv4(let addr):
|
||||
hostString = "\(addr)"
|
||||
case .ipv6(let addr):
|
||||
hostString = "\(addr)"
|
||||
case .name(let name, _):
|
||||
hostString = name
|
||||
@unknown default:
|
||||
hostString = "\(host)"
|
||||
}
|
||||
// Strip interface scope suffix (e.g. "%en0")
|
||||
if let pct = hostString.firstIndex(of: "%") {
|
||||
hostString = String(hostString[..<pct])
|
||||
}
|
||||
let info = ConnectionInfo(
|
||||
host: hostString,
|
||||
port: Int(port.rawValue),
|
||||
nodeId: nil
|
||||
)
|
||||
connection.cancel()
|
||||
continuation.resume(returning: info)
|
||||
} else {
|
||||
connection.cancel()
|
||||
continuation.resume(returning: nil)
|
||||
}
|
||||
case .failed, .cancelled:
|
||||
continuation.resume(returning: nil)
|
||||
default:
|
||||
// Not a terminal state — allow future callbacks
|
||||
didResume.withLock { $0 = false }
|
||||
}
|
||||
}
|
||||
connection.start(queue: .global(qos: .userInitiated))
|
||||
}
|
||||
}
|
||||
}
|
||||
201
app/EXO-iOS/EXO-iOS/Services/LocalInferenceService.swift
Normal file
201
app/EXO-iOS/EXO-iOS/Services/LocalInferenceService.swift
Normal file
@@ -0,0 +1,201 @@
|
||||
import Foundation
|
||||
import MLXLLM
|
||||
import MLXLMCommon
|
||||
|
||||
enum LocalModelState: Equatable {
|
||||
case notDownloaded
|
||||
case downloading(progress: Double)
|
||||
case downloaded
|
||||
case loading
|
||||
case ready
|
||||
case generating
|
||||
case error(String)
|
||||
}
|
||||
|
||||
@Observable
|
||||
@MainActor
|
||||
final class LocalInferenceService {
|
||||
private(set) var modelState: LocalModelState = .notDownloaded
|
||||
private var modelContainer: ModelContainer?
|
||||
private var generationTask: Task<Void, Never>?
|
||||
|
||||
let defaultModelId = "mlx-community/Qwen3-0.6B-4bit"
|
||||
|
||||
private static let modelDownloadedKey = "exo_local_model_downloaded"
|
||||
|
||||
var isReady: Bool {
|
||||
modelState == .ready
|
||||
}
|
||||
|
||||
var isAvailable: Bool {
|
||||
modelState == .ready || modelState == .generating
|
||||
}
|
||||
|
||||
init() {
|
||||
if UserDefaults.standard.bool(forKey: Self.modelDownloadedKey) {
|
||||
modelState = .downloaded
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Model Lifecycle
|
||||
|
||||
func prepareModel() async {
|
||||
guard modelState == .notDownloaded || modelState == .downloaded else { return }
|
||||
|
||||
let wasDownloaded = modelState == .downloaded
|
||||
|
||||
if !wasDownloaded {
|
||||
modelState = .downloading(progress: 0)
|
||||
} else {
|
||||
modelState = .loading
|
||||
}
|
||||
|
||||
do {
|
||||
let container = try await loadModelContainer(
|
||||
id: defaultModelId
|
||||
) { [weak self] progress in
|
||||
guard let self else { return }
|
||||
Task { @MainActor in
|
||||
if case .downloading = self.modelState {
|
||||
self.modelState = .downloading(progress: progress.fractionCompleted)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.modelContainer = container
|
||||
UserDefaults.standard.set(true, forKey: Self.modelDownloadedKey)
|
||||
modelState = .ready
|
||||
} catch {
|
||||
modelState = .error(error.localizedDescription)
|
||||
}
|
||||
}
|
||||
|
||||
func unloadModel() {
|
||||
cancelGeneration()
|
||||
modelContainer = nil
|
||||
modelState = .downloaded
|
||||
}
|
||||
|
||||
// MARK: - Generation
|
||||
|
||||
func streamChatCompletion(request: ChatCompletionRequest) -> AsyncThrowingStream<
|
||||
ChatCompletionChunk, Error
|
||||
> {
|
||||
AsyncThrowingStream { continuation in
|
||||
let task = Task { [weak self] in
|
||||
guard let self else {
|
||||
continuation.finish(throwing: LocalInferenceError.serviceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
guard let container = self.modelContainer else {
|
||||
continuation.finish(throwing: LocalInferenceError.modelNotLoaded)
|
||||
return
|
||||
}
|
||||
|
||||
await MainActor.run {
|
||||
self.modelState = .generating
|
||||
}
|
||||
|
||||
defer {
|
||||
Task { @MainActor [weak self] in
|
||||
if self?.modelState == .generating {
|
||||
self?.modelState = .ready
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let chunkId = "local-\(UUID().uuidString)"
|
||||
|
||||
do {
|
||||
// Build Chat.Message array from the request
|
||||
var chatMessages: [Chat.Message] = []
|
||||
for msg in request.messages {
|
||||
switch msg.role {
|
||||
case "system":
|
||||
chatMessages.append(.system(msg.content))
|
||||
case "assistant":
|
||||
chatMessages.append(.assistant(msg.content))
|
||||
default:
|
||||
chatMessages.append(.user(msg.content))
|
||||
}
|
||||
}
|
||||
|
||||
// Use ChatSession for streaming generation
|
||||
let session = ChatSession(
|
||||
container,
|
||||
history: chatMessages,
|
||||
generateParameters: GenerateParameters(
|
||||
maxTokens: request.maxTokens ?? 4096,
|
||||
temperature: Float(request.temperature ?? 0.7)
|
||||
)
|
||||
)
|
||||
|
||||
// Stream with an empty prompt since history already contains the conversation
|
||||
let stream = session.streamResponse(to: "")
|
||||
for try await text in stream {
|
||||
if Task.isCancelled { break }
|
||||
|
||||
let chunk = ChatCompletionChunk(
|
||||
id: chunkId,
|
||||
model: request.model,
|
||||
choices: [
|
||||
StreamingChoice(
|
||||
index: 0,
|
||||
delta: Delta(role: nil, content: text),
|
||||
finishReason: nil
|
||||
)
|
||||
],
|
||||
usage: nil
|
||||
)
|
||||
continuation.yield(chunk)
|
||||
}
|
||||
|
||||
// Send final chunk with finish reason
|
||||
let finalChunk = ChatCompletionChunk(
|
||||
id: chunkId,
|
||||
model: request.model,
|
||||
choices: [
|
||||
StreamingChoice(
|
||||
index: 0,
|
||||
delta: Delta(role: nil, content: nil),
|
||||
finishReason: "stop"
|
||||
)
|
||||
],
|
||||
usage: nil
|
||||
)
|
||||
continuation.yield(finalChunk)
|
||||
continuation.finish()
|
||||
} catch {
|
||||
continuation.finish(throwing: error)
|
||||
}
|
||||
}
|
||||
|
||||
self.generationTask = task
|
||||
|
||||
continuation.onTermination = { _ in
|
||||
task.cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cancelGeneration() {
|
||||
generationTask?.cancel()
|
||||
generationTask = nil
|
||||
if modelState == .generating {
|
||||
modelState = .ready
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum LocalInferenceError: LocalizedError {
|
||||
case serviceUnavailable
|
||||
case modelNotLoaded
|
||||
|
||||
var errorDescription: String? {
|
||||
switch self {
|
||||
case .serviceUnavailable: "Local inference service is unavailable"
|
||||
case .modelNotLoaded: "Local model is not loaded"
|
||||
}
|
||||
}
|
||||
}
|
||||
50
app/EXO-iOS/EXO-iOS/Services/SSEStreamParser.swift
Normal file
50
app/EXO-iOS/EXO-iOS/Services/SSEStreamParser.swift
Normal file
@@ -0,0 +1,50 @@
|
||||
import Foundation
|
||||
|
||||
struct SSEStreamParser<T: Decodable>: AsyncSequence {
|
||||
typealias Element = T
|
||||
|
||||
let bytes: URLSession.AsyncBytes
|
||||
let decoder: JSONDecoder
|
||||
|
||||
init(bytes: URLSession.AsyncBytes, decoder: JSONDecoder = JSONDecoder()) {
|
||||
self.bytes = bytes
|
||||
self.decoder = decoder
|
||||
}
|
||||
|
||||
func makeAsyncIterator() -> AsyncIterator {
|
||||
AsyncIterator(lines: bytes.lines, decoder: decoder)
|
||||
}
|
||||
|
||||
struct AsyncIterator: AsyncIteratorProtocol {
|
||||
var lines: AsyncLineSequence<URLSession.AsyncBytes>.AsyncIterator
|
||||
let decoder: JSONDecoder
|
||||
|
||||
init(lines: AsyncLineSequence<URLSession.AsyncBytes>, decoder: JSONDecoder) {
|
||||
self.lines = lines.makeAsyncIterator()
|
||||
self.decoder = decoder
|
||||
}
|
||||
|
||||
mutating func next() async throws -> T? {
|
||||
while let line = try await lines.next() {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
|
||||
guard trimmed.hasPrefix("data: ") else { continue }
|
||||
|
||||
let payload = String(trimmed.dropFirst(6))
|
||||
|
||||
if payload == "[DONE]" {
|
||||
return nil
|
||||
}
|
||||
|
||||
guard let data = payload.data(using: .utf8) else { continue }
|
||||
|
||||
do {
|
||||
return try decoder.decode(T.self, from: data)
|
||||
} catch {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
203
app/EXO-iOS/EXO-iOS/Views/Chat/ChatView.swift
Normal file
203
app/EXO-iOS/EXO-iOS/Views/Chat/ChatView.swift
Normal file
@@ -0,0 +1,203 @@
|
||||
import SwiftUI
|
||||
|
||||
struct ChatView: View {
|
||||
@Environment(ClusterService.self) private var clusterService
|
||||
@Environment(ChatService.self) private var chatService
|
||||
@Environment(LocalInferenceService.self) private var localInferenceService
|
||||
@State private var inputText = ""
|
||||
@State private var showModelSelector = false
|
||||
|
||||
var body: some View {
|
||||
VStack(spacing: 0) {
|
||||
modelBar
|
||||
|
||||
GradientDivider()
|
||||
|
||||
messageList
|
||||
|
||||
GradientDivider()
|
||||
|
||||
inputBar
|
||||
}
|
||||
.background(Color.exoBlack)
|
||||
.sheet(isPresented: $showModelSelector) {
|
||||
ModelSelectorView(
|
||||
models: clusterService.availableModels,
|
||||
selectedModelId: chatService.activeConversation?.modelId
|
||||
) { modelId in
|
||||
chatService.setModelForActiveConversation(modelId)
|
||||
}
|
||||
.presentationBackground(Color.exoDarkGray)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Model Bar
|
||||
|
||||
private var useLocalModel: Bool {
|
||||
!clusterService.isConnected && localInferenceService.isAvailable
|
||||
}
|
||||
|
||||
private var modelBar: some View {
|
||||
Button {
|
||||
if !useLocalModel {
|
||||
showModelSelector = true
|
||||
}
|
||||
} label: {
|
||||
HStack {
|
||||
Image(systemName: useLocalModel ? "iphone" : "cpu")
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(useLocalModel ? Color.exoYellow : Color.exoLightGray)
|
||||
|
||||
if useLocalModel {
|
||||
Text(localInferenceService.defaultModelId)
|
||||
.font(.exoSubheadline)
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
.lineLimit(1)
|
||||
} else if let modelId = chatService.activeConversation?.modelId {
|
||||
Text(modelId)
|
||||
.font(.exoSubheadline)
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
.lineLimit(1)
|
||||
} else {
|
||||
Text("SELECT MODEL")
|
||||
.font(.exoSubheadline)
|
||||
.tracking(1.5)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
}
|
||||
|
||||
Spacer()
|
||||
|
||||
if useLocalModel {
|
||||
Text("ON-DEVICE")
|
||||
.font(.exoCaption)
|
||||
.tracking(1)
|
||||
.foregroundStyle(Color.exoYellow)
|
||||
.padding(.horizontal, 6)
|
||||
.padding(.vertical, 2)
|
||||
.background(Color.exoYellow.opacity(0.15))
|
||||
.clipShape(Capsule())
|
||||
} else {
|
||||
Image(systemName: "chevron.right")
|
||||
.font(.caption)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
}
|
||||
}
|
||||
.padding(.horizontal)
|
||||
.padding(.vertical, 10)
|
||||
.background(Color.exoDarkGray)
|
||||
}
|
||||
.tint(.primary)
|
||||
.disabled(useLocalModel)
|
||||
}
|
||||
|
||||
// MARK: - Messages
|
||||
|
||||
private var messageList: some View {
|
||||
ScrollViewReader { proxy in
|
||||
ScrollView {
|
||||
LazyVStack(spacing: 12) {
|
||||
if chatService.activeMessages.isEmpty {
|
||||
emptyState
|
||||
} else {
|
||||
ForEach(chatService.activeMessages) { message in
|
||||
MessageBubbleView(message: message)
|
||||
.id(message.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
.padding()
|
||||
}
|
||||
.background(Color.exoBlack)
|
||||
.onChange(of: chatService.activeMessages.last?.content) {
|
||||
if let lastId = chatService.activeMessages.last?.id {
|
||||
withAnimation(.easeOut(duration: 0.2)) {
|
||||
proxy.scrollTo(lastId, anchor: .bottom)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var emptyState: some View {
|
||||
VStack(spacing: 16) {
|
||||
Spacer(minLength: 80)
|
||||
|
||||
ZStack {
|
||||
Circle()
|
||||
.stroke(Color.exoYellow.opacity(0.15), lineWidth: 1)
|
||||
.frame(width: 80, height: 80)
|
||||
Circle()
|
||||
.stroke(Color.exoYellow.opacity(0.3), lineWidth: 1)
|
||||
.frame(width: 56, height: 56)
|
||||
Circle()
|
||||
.fill(Color.exoYellow.opacity(0.15))
|
||||
.frame(width: 32, height: 32)
|
||||
Circle()
|
||||
.fill(Color.exoYellow)
|
||||
.frame(width: 8, height: 8)
|
||||
.shadow(color: Color.exoYellow.opacity(0.6), radius: 6)
|
||||
}
|
||||
|
||||
Text("AWAITING INPUT")
|
||||
.font(.exoSubheadline)
|
||||
.tracking(3)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
|
||||
Text("Send a message to begin.")
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(Color.exoLightGray.opacity(0.6))
|
||||
|
||||
Spacer(minLength: 80)
|
||||
}
|
||||
.padding()
|
||||
}
|
||||
|
||||
// MARK: - Input
|
||||
|
||||
private var inputBar: some View {
|
||||
HStack(alignment: .bottom, spacing: 8) {
|
||||
TextField("Message...", text: $inputText, axis: .vertical)
|
||||
.font(.exoBody)
|
||||
.lineLimit(1...6)
|
||||
.textFieldStyle(.plain)
|
||||
.padding(10)
|
||||
.background(Color.exoMediumGray)
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
.clipShape(RoundedRectangle(cornerRadius: 8))
|
||||
|
||||
if chatService.isGenerating {
|
||||
Button {
|
||||
chatService.cancelGeneration()
|
||||
} label: {
|
||||
Image(systemName: "stop.circle.fill")
|
||||
.font(.title2)
|
||||
.foregroundStyle(Color.exoDestructive)
|
||||
}
|
||||
} else {
|
||||
Button {
|
||||
let text = inputText
|
||||
inputText = ""
|
||||
chatService.sendMessage(text)
|
||||
} label: {
|
||||
Text("SEND")
|
||||
.font(.exoMono(12, weight: .bold))
|
||||
.tracking(1)
|
||||
.foregroundStyle(canSend ? Color.exoBlack : Color.exoLightGray)
|
||||
.padding(.horizontal, 14)
|
||||
.padding(.vertical, 8)
|
||||
.background(canSend ? Color.exoYellow : Color.exoMediumGray)
|
||||
.clipShape(RoundedRectangle(cornerRadius: 8))
|
||||
}
|
||||
.disabled(!canSend)
|
||||
}
|
||||
}
|
||||
.padding(.horizontal)
|
||||
.padding(.vertical, 8)
|
||||
.background(Color.exoDarkGray)
|
||||
}
|
||||
|
||||
private var canSend: Bool {
|
||||
!inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
|
||||
&& (clusterService.isConnected || localInferenceService.isAvailable)
|
||||
}
|
||||
}
|
||||
54
app/EXO-iOS/EXO-iOS/Views/Chat/MessageBubbleView.swift
Normal file
54
app/EXO-iOS/EXO-iOS/Views/Chat/MessageBubbleView.swift
Normal file
@@ -0,0 +1,54 @@
|
||||
import SwiftUI
|
||||
|
||||
struct MessageBubbleView: View {
|
||||
let message: ChatMessage
|
||||
|
||||
private var isAssistant: Bool { message.role == .assistant }
|
||||
|
||||
var body: some View {
|
||||
HStack {
|
||||
if message.role == .user { Spacer(minLength: 48) }
|
||||
|
||||
VStack(alignment: isAssistant ? .leading : .trailing, spacing: 6) {
|
||||
// Header
|
||||
HStack(spacing: 4) {
|
||||
if isAssistant {
|
||||
Circle()
|
||||
.fill(Color.exoYellow)
|
||||
.frame(width: 6, height: 6)
|
||||
.shadow(color: Color.exoYellow.opacity(0.6), radius: 4)
|
||||
Text("EXO")
|
||||
.font(.exoMono(10, weight: .bold))
|
||||
.tracking(1.5)
|
||||
.foregroundStyle(Color.exoYellow)
|
||||
} else {
|
||||
Text("QUERY")
|
||||
.font(.exoMono(10, weight: .medium))
|
||||
.tracking(1.5)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
}
|
||||
}
|
||||
|
||||
// Bubble
|
||||
HStack(spacing: 0) {
|
||||
if isAssistant {
|
||||
RoundedRectangle(cornerRadius: 1)
|
||||
.fill(Color.exoYellow.opacity(0.5))
|
||||
.frame(width: 2)
|
||||
}
|
||||
|
||||
Text(message.content + (message.isStreaming ? " \u{258C}" : ""))
|
||||
.font(.exoBody)
|
||||
.textSelection(.enabled)
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
.padding(.horizontal, 14)
|
||||
.padding(.vertical, 10)
|
||||
}
|
||||
.background(Color.exoDarkGray)
|
||||
.clipShape(RoundedRectangle(cornerRadius: 8))
|
||||
}
|
||||
|
||||
if isAssistant { Spacer(minLength: 48) }
|
||||
}
|
||||
}
|
||||
}
|
||||
75
app/EXO-iOS/EXO-iOS/Views/Chat/ModelSelectorView.swift
Normal file
75
app/EXO-iOS/EXO-iOS/Views/Chat/ModelSelectorView.swift
Normal file
@@ -0,0 +1,75 @@
|
||||
import SwiftUI
|
||||
|
||||
struct ModelSelectorView: View {
|
||||
let models: [ModelOption]
|
||||
let selectedModelId: String?
|
||||
let onSelect: (String) -> Void
|
||||
@Environment(\.dismiss) private var dismiss
|
||||
|
||||
var body: some View {
|
||||
NavigationStack {
|
||||
List {
|
||||
if models.isEmpty {
|
||||
emptyContent
|
||||
} else {
|
||||
modelsList
|
||||
}
|
||||
}
|
||||
.scrollContentBackground(.hidden)
|
||||
.background(Color.exoBlack)
|
||||
.navigationTitle("SELECT MODEL")
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
.toolbar {
|
||||
ToolbarItem(placement: .cancellationAction) {
|
||||
Button("Cancel") { dismiss() }
|
||||
.font(.exoSubheadline)
|
||||
.foregroundStyle(Color.exoYellow)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var emptyContent: some View {
|
||||
ContentUnavailableView(
|
||||
"No Models Available",
|
||||
systemImage: "cpu",
|
||||
description: Text("Connect to an EXO cluster to see available models.")
|
||||
)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
.listRowBackground(Color.exoBlack)
|
||||
}
|
||||
|
||||
private var modelsList: some View {
|
||||
ForEach(models) { model in
|
||||
Button {
|
||||
onSelect(model.id)
|
||||
dismiss()
|
||||
} label: {
|
||||
modelRow(model)
|
||||
}
|
||||
.tint(.primary)
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
}
|
||||
}
|
||||
|
||||
private func modelRow(_ model: ModelOption) -> some View {
|
||||
HStack {
|
||||
VStack(alignment: .leading, spacing: 2) {
|
||||
Text(model.displayName)
|
||||
.font(.exoSubheadline)
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
Text(model.id)
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
}
|
||||
|
||||
Spacer()
|
||||
|
||||
if model.id == selectedModelId {
|
||||
Image(systemName: "checkmark")
|
||||
.font(.exoSubheadline)
|
||||
.foregroundStyle(Color.exoYellow)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
import SwiftUI
|
||||
|
||||
struct ConnectionStatusBadge: View {
|
||||
let connectionState: ConnectionState
|
||||
var localModelState: LocalModelState = .notDownloaded
|
||||
|
||||
private var isLocalReady: Bool {
|
||||
if case .disconnected = connectionState {
|
||||
return localModelState == .ready || localModelState == .generating
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var body: some View {
|
||||
HStack(spacing: 6) {
|
||||
Circle()
|
||||
.fill(dotColor)
|
||||
.frame(width: 8, height: 8)
|
||||
.shadow(color: dotColor.opacity(0.6), radius: 4)
|
||||
|
||||
Text(label.uppercased())
|
||||
.font(.exoMono(10, weight: .medium))
|
||||
.tracking(1)
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
}
|
||||
.padding(.horizontal, 10)
|
||||
.padding(.vertical, 5)
|
||||
.background(backgroundColor)
|
||||
.clipShape(Capsule())
|
||||
.overlay(
|
||||
Capsule()
|
||||
.stroke(dotColor.opacity(0.3), lineWidth: 1)
|
||||
)
|
||||
}
|
||||
|
||||
private var dotColor: Color {
|
||||
if isLocalReady {
|
||||
return .exoYellow
|
||||
}
|
||||
switch connectionState {
|
||||
case .connected: return .green
|
||||
case .connecting: return .orange
|
||||
case .disconnected: return .exoLightGray
|
||||
}
|
||||
}
|
||||
|
||||
private var label: String {
|
||||
if isLocalReady {
|
||||
return "Local"
|
||||
}
|
||||
switch connectionState {
|
||||
case .connected: return "Connected"
|
||||
case .connecting: return "Connecting"
|
||||
case .disconnected: return "Disconnected"
|
||||
}
|
||||
}
|
||||
|
||||
private var backgroundColor: Color {
|
||||
if isLocalReady {
|
||||
return Color.exoYellow.opacity(0.1)
|
||||
}
|
||||
switch connectionState {
|
||||
case .connected: return .green.opacity(0.1)
|
||||
case .connecting: return .orange.opacity(0.1)
|
||||
case .disconnected: return Color.exoMediumGray.opacity(0.5)
|
||||
}
|
||||
}
|
||||
}
|
||||
136
app/EXO-iOS/EXO-iOS/Views/RootView.swift
Normal file
136
app/EXO-iOS/EXO-iOS/Views/RootView.swift
Normal file
@@ -0,0 +1,136 @@
|
||||
import SwiftUI
|
||||
|
||||
struct RootView: View {
|
||||
@Environment(ClusterService.self) private var clusterService
|
||||
@Environment(DiscoveryService.self) private var discoveryService
|
||||
@Environment(ChatService.self) private var chatService
|
||||
@Environment(LocalInferenceService.self) private var localInferenceService
|
||||
@State private var showSettings = false
|
||||
@State private var showConversations = false
|
||||
|
||||
var body: some View {
|
||||
NavigationStack {
|
||||
ChatView()
|
||||
.navigationTitle("EXO")
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
.toolbar {
|
||||
ToolbarItem(placement: .topBarLeading) {
|
||||
conversationMenuButton
|
||||
}
|
||||
|
||||
ToolbarItem(placement: .principal) {
|
||||
ConnectionStatusBadge(
|
||||
connectionState: clusterService.connectionState,
|
||||
localModelState: localInferenceService.modelState
|
||||
)
|
||||
}
|
||||
|
||||
ToolbarItem(placement: .topBarTrailing) {
|
||||
Button {
|
||||
showSettings = true
|
||||
} label: {
|
||||
Image(systemName: "gear")
|
||||
.foregroundStyle(Color.exoYellow)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.tint(Color.exoYellow)
|
||||
.sheet(isPresented: $showSettings) {
|
||||
SettingsView()
|
||||
.environment(discoveryService)
|
||||
.presentationBackground(Color.exoDarkGray)
|
||||
}
|
||||
.sheet(isPresented: $showConversations) {
|
||||
conversationList
|
||||
.presentationBackground(Color.exoDarkGray)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Conversations
|
||||
|
||||
private var conversationMenuButton: some View {
|
||||
HStack(spacing: 12) {
|
||||
Button {
|
||||
showConversations = true
|
||||
} label: {
|
||||
Image(systemName: "sidebar.left")
|
||||
.foregroundStyle(Color.exoYellow)
|
||||
}
|
||||
|
||||
Button {
|
||||
chatService.createConversation()
|
||||
} label: {
|
||||
Image(systemName: "square.and.pencil")
|
||||
.foregroundStyle(Color.exoYellow)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var conversationList: some View {
|
||||
NavigationStack {
|
||||
List {
|
||||
if chatService.conversations.isEmpty {
|
||||
Text("No conversations yet")
|
||||
.font(.exoBody)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
} else {
|
||||
ForEach(chatService.conversations) { conversation in
|
||||
let isActive = conversation.id == chatService.activeConversationId
|
||||
Button {
|
||||
chatService.setActiveConversation(id: conversation.id)
|
||||
showConversations = false
|
||||
} label: {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text(conversation.title)
|
||||
.font(.exoSubheadline)
|
||||
.fontWeight(isActive ? .semibold : .regular)
|
||||
.foregroundStyle(
|
||||
isActive ? Color.exoYellow : Color.exoForeground
|
||||
)
|
||||
.lineLimit(1)
|
||||
|
||||
if let modelId = conversation.modelId {
|
||||
Text(modelId)
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
.lineLimit(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
.listRowBackground(
|
||||
isActive
|
||||
? Color.exoYellow.opacity(0.1)
|
||||
: Color.exoDarkGray
|
||||
)
|
||||
}
|
||||
.onDelete { indexSet in
|
||||
for index in indexSet {
|
||||
chatService.deleteConversation(id: chatService.conversations[index].id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.scrollContentBackground(.hidden)
|
||||
.background(Color.exoBlack)
|
||||
.navigationTitle("Conversations")
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
.toolbar {
|
||||
ToolbarItem(placement: .confirmationAction) {
|
||||
Button("Done") { showConversations = false }
|
||||
.font(.exoSubheadline)
|
||||
.foregroundStyle(Color.exoYellow)
|
||||
}
|
||||
ToolbarItem(placement: .topBarLeading) {
|
||||
Button {
|
||||
chatService.createConversation()
|
||||
} label: {
|
||||
Image(systemName: "plus")
|
||||
.foregroundStyle(Color.exoYellow)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
314
app/EXO-iOS/EXO-iOS/Views/Settings/SettingsView.swift
Normal file
314
app/EXO-iOS/EXO-iOS/Views/Settings/SettingsView.swift
Normal file
@@ -0,0 +1,314 @@
|
||||
import SwiftUI
|
||||
|
||||
struct SettingsView: View {
|
||||
@Environment(ClusterService.self) private var clusterService
|
||||
@Environment(DiscoveryService.self) private var discoveryService
|
||||
@Environment(LocalInferenceService.self) private var localInferenceService
|
||||
@Environment(\.dismiss) private var dismiss
|
||||
@State private var host: String = ""
|
||||
@State private var port: String = "52415"
|
||||
|
||||
var body: some View {
|
||||
NavigationStack {
|
||||
Form {
|
||||
localModelSection
|
||||
nearbyClustersSection
|
||||
connectionSection
|
||||
statusSection
|
||||
}
|
||||
.scrollContentBackground(.hidden)
|
||||
.background(Color.exoBlack)
|
||||
.navigationTitle("Settings")
|
||||
.navigationBarTitleDisplayMode(.inline)
|
||||
.toolbar {
|
||||
ToolbarItem(placement: .confirmationAction) {
|
||||
Button("Done") { dismiss() }
|
||||
.font(.exoSubheadline)
|
||||
.foregroundStyle(Color.exoYellow)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Section Headers
|
||||
|
||||
private func sectionHeader(_ title: String) -> some View {
|
||||
Text(title.uppercased())
|
||||
.font(.exoMono(10, weight: .semibold))
|
||||
.tracking(2)
|
||||
.foregroundStyle(Color.exoYellow)
|
||||
}
|
||||
|
||||
// MARK: - Local Model
|
||||
|
||||
private var localModelSection: some View {
|
||||
Section {
|
||||
HStack {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text(localInferenceService.defaultModelId)
|
||||
.font(.exoSubheadline)
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
|
||||
Text(localModelStatusText)
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
}
|
||||
|
||||
Spacer()
|
||||
|
||||
localModelActionButton
|
||||
}
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
|
||||
if case .downloading(let progress) = localInferenceService.modelState {
|
||||
ProgressView(value: progress)
|
||||
.tint(Color.exoYellow)
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
}
|
||||
} header: {
|
||||
sectionHeader("Local Model")
|
||||
} footer: {
|
||||
Text(
|
||||
"When disconnected from a cluster, messages are processed on-device using this model."
|
||||
)
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(Color.exoLightGray.opacity(0.7))
|
||||
}
|
||||
}
|
||||
|
||||
private var localModelStatusText: String {
|
||||
switch localInferenceService.modelState {
|
||||
case .notDownloaded: "Not downloaded"
|
||||
case .downloading(let progress): "Downloading \(Int(progress * 100))%..."
|
||||
case .downloaded: "Downloaded — not loaded"
|
||||
case .loading: "Loading into memory..."
|
||||
case .ready: "Ready"
|
||||
case .generating: "Generating..."
|
||||
case .error(let message): "Error: \(message)"
|
||||
}
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
private var localModelActionButton: some View {
|
||||
switch localInferenceService.modelState {
|
||||
case .notDownloaded:
|
||||
exoButton("Download") {
|
||||
Task { await localInferenceService.prepareModel() }
|
||||
}
|
||||
case .downloading:
|
||||
ProgressView()
|
||||
.controlSize(.small)
|
||||
.tint(Color.exoYellow)
|
||||
case .downloaded:
|
||||
exoButton("Load") {
|
||||
Task { await localInferenceService.prepareModel() }
|
||||
}
|
||||
case .loading:
|
||||
ProgressView()
|
||||
.controlSize(.small)
|
||||
.tint(Color.exoYellow)
|
||||
case .ready, .generating:
|
||||
exoButton("Unload") {
|
||||
localInferenceService.unloadModel()
|
||||
}
|
||||
case .error:
|
||||
exoButton("Retry", destructive: true) {
|
||||
Task { await localInferenceService.prepareModel() }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func exoButton(_ title: String, destructive: Bool = false, action: @escaping () -> Void)
|
||||
-> some View
|
||||
{
|
||||
let borderColor = destructive ? Color.exoDestructive : Color.exoYellow
|
||||
return Button(action: action) {
|
||||
Text(title.uppercased())
|
||||
.font(.exoMono(11, weight: .semibold))
|
||||
.tracking(1)
|
||||
.foregroundStyle(borderColor)
|
||||
.padding(.horizontal, 10)
|
||||
.padding(.vertical, 5)
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 6)
|
||||
.stroke(borderColor, lineWidth: 1)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Nearby Clusters
|
||||
|
||||
private var nearbyClustersSection: some View {
|
||||
Section {
|
||||
if discoveryService.discoveredClusters.isEmpty {
|
||||
if discoveryService.isSearching {
|
||||
HStack {
|
||||
ProgressView()
|
||||
.tint(Color.exoYellow)
|
||||
.padding(.trailing, 8)
|
||||
Text("Searching for clusters...")
|
||||
.font(.exoBody)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
}
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
} else {
|
||||
Text("No clusters found")
|
||||
.font(.exoBody)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
}
|
||||
} else {
|
||||
ForEach(discoveryService.discoveredClusters) { cluster in
|
||||
HStack {
|
||||
VStack(alignment: .leading) {
|
||||
Text(cluster.name)
|
||||
.font(.exoBody)
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
}
|
||||
Spacer()
|
||||
exoButton("Connect") {
|
||||
Task {
|
||||
await clusterService.connectToDiscoveredCluster(
|
||||
cluster, using: discoveryService
|
||||
)
|
||||
if clusterService.isConnected {
|
||||
dismiss()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
}
|
||||
}
|
||||
} header: {
|
||||
sectionHeader("Nearby Clusters")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Manual Connection
|
||||
|
||||
private var connectionSection: some View {
|
||||
Section {
|
||||
TextField("IP Address (e.g. 192.168.1.42)", text: $host)
|
||||
.font(.exoBody)
|
||||
.keyboardType(.decimalPad)
|
||||
.textContentType(.URL)
|
||||
.autocorrectionDisabled()
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
|
||||
TextField("Port", text: $port)
|
||||
.font(.exoBody)
|
||||
.keyboardType(.numberPad)
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
|
||||
Button {
|
||||
Task {
|
||||
let portNum = Int(port) ?? ConnectionInfo.defaultPort
|
||||
let info = ConnectionInfo(host: host, port: portNum, nodeId: nil)
|
||||
await clusterService.connect(to: info)
|
||||
if clusterService.isConnected {
|
||||
dismiss()
|
||||
}
|
||||
}
|
||||
} label: {
|
||||
Text(clusterService.isConnected ? "RECONNECT" : "CONNECT")
|
||||
.font(.exoMono(13, weight: .semibold))
|
||||
.tracking(1.5)
|
||||
.foregroundStyle(
|
||||
host.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
|
||||
? Color.exoLightGray : Color.exoYellow
|
||||
)
|
||||
.frame(maxWidth: .infinity)
|
||||
}
|
||||
.disabled(host.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
} header: {
|
||||
sectionHeader("Manual Connection")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Status
|
||||
|
||||
private var statusSection: some View {
|
||||
Section {
|
||||
if let connection = clusterService.currentConnection {
|
||||
LabeledContent {
|
||||
Text(connection.host)
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
} label: {
|
||||
Text("Host")
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
}
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
|
||||
LabeledContent {
|
||||
Text("\(connection.port)")
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
} label: {
|
||||
Text("Port")
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
}
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
|
||||
if let nodeId = connection.nodeId {
|
||||
LabeledContent {
|
||||
Text(String(nodeId.prefix(12)) + "...")
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
} label: {
|
||||
Text("Node ID")
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
}
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
}
|
||||
|
||||
LabeledContent {
|
||||
Text("\(clusterService.availableModels.count)")
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(Color.exoForeground)
|
||||
} label: {
|
||||
Text("Models")
|
||||
.font(.exoCaption)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
}
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
|
||||
Button(role: .destructive) {
|
||||
clusterService.disconnect()
|
||||
} label: {
|
||||
Text("DISCONNECT")
|
||||
.font(.exoMono(13, weight: .semibold))
|
||||
.tracking(1.5)
|
||||
.foregroundStyle(Color.exoDestructive)
|
||||
.frame(maxWidth: .infinity)
|
||||
}
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
} else {
|
||||
if let error = clusterService.lastError {
|
||||
Label {
|
||||
Text(error)
|
||||
.font(.exoCaption)
|
||||
} icon: {
|
||||
Image(systemName: "exclamationmark.triangle")
|
||||
}
|
||||
.foregroundStyle(Color.exoDestructive)
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
} else {
|
||||
Text("Not connected")
|
||||
.font(.exoBody)
|
||||
.foregroundStyle(Color.exoLightGray)
|
||||
.listRowBackground(Color.exoDarkGray)
|
||||
}
|
||||
}
|
||||
} header: {
|
||||
sectionHeader("Status")
|
||||
}
|
||||
}
|
||||
}
|
||||
51
app/EXO-iOS/EXO-iOS/Views/Theme/EXOTheme.swift
Normal file
51
app/EXO-iOS/EXO-iOS/Views/Theme/EXOTheme.swift
Normal file
@@ -0,0 +1,51 @@
|
||||
import SwiftUI
|
||||
|
||||
// MARK: - EXO Color Palette
|
||||
|
||||
extension Color {
|
||||
/// Primary background — near-black (#121212)
|
||||
static let exoBlack = Color(red: 0x12 / 255.0, green: 0x12 / 255.0, blue: 0x12 / 255.0)
|
||||
/// Card / surface background (#1F1F1F)
|
||||
static let exoDarkGray = Color(red: 0x1F / 255.0, green: 0x1F / 255.0, blue: 0x1F / 255.0)
|
||||
/// Input field / elevated surface (#353535)
|
||||
static let exoMediumGray = Color(red: 0x35 / 255.0, green: 0x35 / 255.0, blue: 0x35 / 255.0)
|
||||
/// Secondary text (#999999)
|
||||
static let exoLightGray = Color(red: 0x99 / 255.0, green: 0x99 / 255.0, blue: 0x99 / 255.0)
|
||||
/// Accent yellow — matches dashboard (#FFD700)
|
||||
static let exoYellow = Color(red: 0xFF / 255.0, green: 0xD7 / 255.0, blue: 0x00 / 255.0)
|
||||
/// Primary foreground text (#E5E5E5)
|
||||
static let exoForeground = Color(red: 0xE5 / 255.0, green: 0xE5 / 255.0, blue: 0xE5 / 255.0)
|
||||
/// Destructive / error (#E74C3C)
|
||||
static let exoDestructive = Color(red: 0xE7 / 255.0, green: 0x4C / 255.0, blue: 0x3C / 255.0)
|
||||
}
|
||||
|
||||
// MARK: - EXO Typography (SF Mono via .monospaced design)
|
||||
|
||||
extension Font {
|
||||
/// Monospaced font at a given size and weight.
|
||||
static func exoMono(_ size: CGFloat, weight: Font.Weight = .regular) -> Font {
|
||||
.system(size: size, weight: weight, design: .monospaced)
|
||||
}
|
||||
|
||||
/// Body text — 15pt monospaced
|
||||
static let exoBody: Font = .system(size: 15, weight: .regular, design: .monospaced)
|
||||
/// Caption — 11pt monospaced
|
||||
static let exoCaption: Font = .system(size: 11, weight: .regular, design: .monospaced)
|
||||
/// Subheadline — 13pt monospaced medium
|
||||
static let exoSubheadline: Font = .system(size: 13, weight: .medium, design: .monospaced)
|
||||
/// Headline — 17pt monospaced semibold
|
||||
static let exoHeadline: Font = .system(size: 17, weight: .semibold, design: .monospaced)
|
||||
}
|
||||
|
||||
// MARK: - Reusable Gradient Divider
|
||||
|
||||
struct GradientDivider: View {
|
||||
var body: some View {
|
||||
LinearGradient(
|
||||
colors: [.clear, Color.exoYellow.opacity(0.3), .clear],
|
||||
startPoint: .leading,
|
||||
endPoint: .trailing
|
||||
)
|
||||
.frame(height: 1)
|
||||
}
|
||||
}
|
||||
18
app/EXO-iOS/EXO-iOSTests/EXO_iOSTests.swift
Normal file
18
app/EXO-iOS/EXO-iOSTests/EXO_iOSTests.swift
Normal file
@@ -0,0 +1,18 @@
|
||||
//
|
||||
// EXO_iOSTests.swift
|
||||
// EXO-iOSTests
|
||||
//
|
||||
// Created by Sami Khan on 2026-02-17.
|
||||
//
|
||||
|
||||
import Testing
|
||||
|
||||
@testable import EXO_iOS
|
||||
|
||||
struct EXO_iOSTests {
|
||||
|
||||
@Test func example() async throws {
|
||||
// Write your test here and use APIs like `#expect(...)` to check expected conditions.
|
||||
}
|
||||
|
||||
}
|
||||
41
app/EXO-iOS/EXO-iOSUITests/EXO_iOSUITests.swift
Normal file
41
app/EXO-iOS/EXO-iOSUITests/EXO_iOSUITests.swift
Normal file
@@ -0,0 +1,41 @@
|
||||
//
|
||||
// EXO_iOSUITests.swift
|
||||
// EXO-iOSUITests
|
||||
//
|
||||
// Created by Sami Khan on 2026-02-17.
|
||||
//
|
||||
|
||||
import XCTest
|
||||
|
||||
final class EXO_iOSUITests: XCTestCase {
|
||||
|
||||
override func setUpWithError() throws {
|
||||
// Put setup code here. This method is called before the invocation of each test method in the class.
|
||||
|
||||
// In UI tests it is usually best to stop immediately when a failure occurs.
|
||||
continueAfterFailure = false
|
||||
|
||||
// In UI tests it’s important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this.
|
||||
}
|
||||
|
||||
override func tearDownWithError() throws {
|
||||
// Put teardown code here. This method is called after the invocation of each test method in the class.
|
||||
}
|
||||
|
||||
@MainActor
|
||||
func testExample() throws {
|
||||
// UI tests must launch the application that they test.
|
||||
let app = XCUIApplication()
|
||||
app.launch()
|
||||
|
||||
// Use XCTAssert and related functions to verify your tests produce the correct results.
|
||||
}
|
||||
|
||||
@MainActor
|
||||
func testLaunchPerformance() throws {
|
||||
// This measures how long it takes to launch your application.
|
||||
measure(metrics: [XCTApplicationLaunchMetric()]) {
|
||||
XCUIApplication().launch()
|
||||
}
|
||||
}
|
||||
}
|
||||
33
app/EXO-iOS/EXO-iOSUITests/EXO_iOSUITestsLaunchTests.swift
Normal file
33
app/EXO-iOS/EXO-iOSUITests/EXO_iOSUITestsLaunchTests.swift
Normal file
@@ -0,0 +1,33 @@
|
||||
//
|
||||
// EXO_iOSUITestsLaunchTests.swift
|
||||
// EXO-iOSUITests
|
||||
//
|
||||
// Created by Sami Khan on 2026-02-17.
|
||||
//
|
||||
|
||||
import XCTest
|
||||
|
||||
final class EXO_iOSUITestsLaunchTests: XCTestCase {
|
||||
|
||||
override class var runsForEachTargetApplicationUIConfiguration: Bool {
|
||||
true
|
||||
}
|
||||
|
||||
override func setUpWithError() throws {
|
||||
continueAfterFailure = false
|
||||
}
|
||||
|
||||
@MainActor
|
||||
func testLaunch() throws {
|
||||
let app = XCUIApplication()
|
||||
app.launch()
|
||||
|
||||
// Insert steps here to perform after app launch but before taking a screenshot,
|
||||
// such as logging into a test account or navigating somewhere in the app
|
||||
|
||||
let attachment = XCTAttachment(screenshot: app.screenshot())
|
||||
attachment.name = "Launch Screen"
|
||||
attachment.lifetime = .keepAlways
|
||||
add(attachment)
|
||||
}
|
||||
}
|
||||
@@ -126,37 +126,11 @@ final class ExoProcessController: ObservableObject {
|
||||
return
|
||||
}
|
||||
process.terminationHandler = nil
|
||||
status = .stopped
|
||||
|
||||
guard process.isRunning else {
|
||||
self.process = nil
|
||||
return
|
||||
if process.isRunning {
|
||||
process.terminate()
|
||||
}
|
||||
|
||||
let proc = process
|
||||
self.process = nil
|
||||
|
||||
Task.detached {
|
||||
proc.interrupt()
|
||||
|
||||
for _ in 0..<50 {
|
||||
if !proc.isRunning { return }
|
||||
try? await Task.sleep(nanoseconds: 100_000_000)
|
||||
}
|
||||
|
||||
if proc.isRunning {
|
||||
proc.terminate()
|
||||
}
|
||||
|
||||
for _ in 0..<30 {
|
||||
if !proc.isRunning { return }
|
||||
try? await Task.sleep(nanoseconds: 100_000_000)
|
||||
}
|
||||
|
||||
if proc.isRunning {
|
||||
kill(proc.processIdentifier, SIGKILL)
|
||||
}
|
||||
}
|
||||
status = .stopped
|
||||
}
|
||||
|
||||
func restart() {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,47 +1,29 @@
|
||||
# type: ignore
|
||||
#!/usr/bin/env python3
|
||||
"""Tool-calling eval for exo's OpenAI-compatible API.
|
||||
|
||||
Tests whether models correctly:
|
||||
- Trigger tool calls when appropriate
|
||||
- Return valid JSON arguments matching function schemas
|
||||
- Handle multi-turn tool use (call -> result -> final answer)
|
||||
- Avoid calling tools when unnecessary
|
||||
|
||||
Start exo with a model first, then run:
|
||||
uv run python tool_call_eval.py --model <model-id>
|
||||
uv run python tool_call_eval.py --model <model-id> --host 10.0.0.5 --port 52415
|
||||
uv run python tool_call_eval.py --model <model-id> --repeat 3
|
||||
uv run python tool_call_eval.py --model <model-id> --scenarios weather_simple calculator_multi_turn
|
||||
"""
|
||||
|
||||
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import http.client
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from harness import (
|
||||
ExoClient,
|
||||
ExoHttpError,
|
||||
add_common_instance_args,
|
||||
instance_id_from_instance,
|
||||
nodes_used_in_instance,
|
||||
resolve_model_short_id,
|
||||
settle_and_fetch_placements,
|
||||
wait_for_instance_gone,
|
||||
wait_for_instance_ready,
|
||||
)
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Backoff constants for cluster settling retry
|
||||
_SETTLE_INITIAL_BACKOFF_S = 1.0
|
||||
_SETTLE_MAX_BACKOFF_S = 60.0
|
||||
_SETTLE_BACKOFF_MULTIPLIER = 2.0
|
||||
|
||||
# Monkey-patch for transformers 5.x compatibility
|
||||
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
|
||||
# which was moved in transformers 5.0.0rc2
|
||||
@@ -121,6 +103,154 @@ def load_tokenizer_for_bench(model_id: str) -> Any:
|
||||
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
super().__init__(f"HTTP {status} {reason}: {body_preview}")
|
||||
self.status = status
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def request_json(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if params:
|
||||
path = path + "?" + urlencode(params)
|
||||
|
||||
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
|
||||
try:
|
||||
payload: bytes | None = None
|
||||
hdrs: dict[str, str] = {"Accept": "application/json"}
|
||||
|
||||
if body is not None:
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
hdrs["Content-Type"] = "application/json"
|
||||
if headers:
|
||||
hdrs.update(headers)
|
||||
|
||||
conn.request(method.upper(), path, body=payload, headers=hdrs)
|
||||
resp = conn.getresponse()
|
||||
raw = resp.read()
|
||||
text = raw.decode("utf-8", errors="replace") if raw else ""
|
||||
|
||||
if resp.status >= 400:
|
||||
raise ExoHttpError(resp.status, resp.reason, text[:300])
|
||||
|
||||
if not text:
|
||||
return None
|
||||
return json.loads(text)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.request_json("POST", "/bench/chat/completions", body=payload)
|
||||
|
||||
|
||||
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
|
||||
if len(instance) != 1:
|
||||
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
|
||||
|
||||
tag = next(iter(instance))
|
||||
inner = instance[tag]
|
||||
if not isinstance(inner, dict):
|
||||
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
|
||||
return inner
|
||||
|
||||
|
||||
def instance_id_from_instance(instance: dict[str, Any]) -> str:
|
||||
inner = unwrap_instance(instance)
|
||||
return str(inner["instanceId"])
|
||||
|
||||
|
||||
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
|
||||
inner = unwrap_instance(instance)
|
||||
return len(inner["shardAssignments"]["nodeToRunner"])
|
||||
|
||||
|
||||
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
|
||||
inner = unwrap_instance(instance)
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
return list(runner_to_shard.keys())
|
||||
|
||||
|
||||
def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def runner_failed(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerFailed" in runner
|
||||
|
||||
|
||||
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
|
||||
if "RunnerFailed" in runner:
|
||||
return runner["RunnerFailed"].get("errorMessage")
|
||||
return None
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
instance_existed = False
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
if instance_existed:
|
||||
# Instance was deleted after being created - likely due to runner failure
|
||||
raise RuntimeError(
|
||||
f"Instance {instance_id} was deleted (runner may have failed)"
|
||||
)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance_existed = True
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
# Check for failed runners first
|
||||
for rid in runner_ids:
|
||||
runner = runners.get(rid, {})
|
||||
if runner_failed(runner):
|
||||
error_msg = get_runner_failed_message(runner) or "Unknown error"
|
||||
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
|
||||
|
||||
|
||||
def wait_for_instance_gone(
|
||||
client: ExoClient, instance_id: str, timeout: float = 3.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
client.request_json("GET", f"/instance/{instance_id}")
|
||||
time.sleep(0.4)
|
||||
except ExoHttpError as e:
|
||||
if e.status == 404:
|
||||
return
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
|
||||
|
||||
|
||||
def format_peak_memory(b: float) -> str:
|
||||
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
||||
if b < 1024.0:
|
||||
@@ -139,6 +269,184 @@ def parse_int_list(values: list[str]) -> list[int]:
|
||||
return items
|
||||
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
data = models.get("data") or []
|
||||
|
||||
for m in data:
|
||||
if m.get("name").lower() == model_arg.lower():
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m.get("hugging_face_id") or m["name"])
|
||||
return short_id, full_id
|
||||
|
||||
for m in data:
|
||||
if m.get("hugging_face_id") == model_arg:
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m["hugging_face_id"])
|
||||
return short_id, full_id
|
||||
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def run_planning_phase(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
preview: dict[str, Any],
|
||||
danger_delete: bool,
|
||||
timeout: float,
|
||||
settle_deadline: float | None,
|
||||
) -> None:
|
||||
"""Check disk space and ensure model is downloaded before benchmarking."""
|
||||
# Get model size from /models
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
model_bytes = 0
|
||||
for m in models.get("data", []):
|
||||
if m.get("hugging_face_id") == full_model_id:
|
||||
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
|
||||
break
|
||||
|
||||
if not model_bytes:
|
||||
logger.warning(
|
||||
f"Could not determine size for {full_model_id}, skipping disk check"
|
||||
)
|
||||
return
|
||||
|
||||
# Get nodes from preview
|
||||
inner = unwrap_instance(preview["instance"])
|
||||
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
|
||||
for node_id in node_ids:
|
||||
node_downloads = downloads.get(node_id, [])
|
||||
|
||||
# Check if model already downloaded on this node
|
||||
already_downloaded = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
for p in node_downloads
|
||||
)
|
||||
if already_downloaded:
|
||||
continue
|
||||
|
||||
# Wait for disk info if settle_deadline is set
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.info(
|
||||
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
state = client.request_json("GET", "/state")
|
||||
node_disk = state.get("nodeDisk", {})
|
||||
disk_info = node_disk.get(node_id, {})
|
||||
|
||||
if not disk_info:
|
||||
logger.warning(f"No disk info for {node_id}, skipping space check")
|
||||
continue
|
||||
|
||||
avail = disk_info.get("available", {}).get("inBytes", 0)
|
||||
if avail >= model_bytes:
|
||||
continue
|
||||
|
||||
if not danger_delete:
|
||||
raise RuntimeError(
|
||||
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
|
||||
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
|
||||
)
|
||||
|
||||
# Delete from smallest to largest
|
||||
completed = [
|
||||
(
|
||||
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
],
|
||||
p["DownloadCompleted"]["totalBytes"]["inBytes"],
|
||||
)
|
||||
for p in node_downloads
|
||||
if "DownloadCompleted" in p
|
||||
]
|
||||
for del_model, size in sorted(completed, key=lambda x: x[1]):
|
||||
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
|
||||
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
|
||||
avail += size
|
||||
if avail >= model_bytes:
|
||||
break
|
||||
|
||||
if avail < model_bytes:
|
||||
raise RuntimeError(f"Could not free enough space on {node_id}")
|
||||
|
||||
# Start downloads (idempotent)
|
||||
for node_id in node_ids:
|
||||
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
|
||||
shard = runner_to_shard[runner_id]
|
||||
client.request_json(
|
||||
"POST",
|
||||
"/download/start",
|
||||
body={
|
||||
"targetNodeId": node_id,
|
||||
"shardMetadata": shard,
|
||||
},
|
||||
)
|
||||
logger.info(f"Started download on {node_id}")
|
||||
|
||||
# Wait for downloads
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
downloads = state.get("downloads", {})
|
||||
all_done = True
|
||||
for node_id in node_ids:
|
||||
done = any(
|
||||
"DownloadCompleted" in p
|
||||
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
|
||||
"modelCard"
|
||||
]["modelId"]
|
||||
== full_model_id
|
||||
for p in downloads.get(node_id, [])
|
||||
)
|
||||
failed = [
|
||||
p["DownloadFailed"]["errorMessage"]
|
||||
for p in downloads.get(node_id, [])
|
||||
if "DownloadFailed" in p
|
||||
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
]
|
||||
== full_model_id
|
||||
]
|
||||
if failed:
|
||||
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
|
||||
if not done:
|
||||
all_done = False
|
||||
if all_done:
|
||||
return
|
||||
time.sleep(1)
|
||||
|
||||
raise TimeoutError("Downloads did not complete in time")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
return ("ring" in s) or ("jaccl" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def sharding_filter(sharding: str, wanted: str) -> bool:
|
||||
s = (sharding or "").lower()
|
||||
if wanted == "both":
|
||||
return ("pipeline" in s) or ("tensor" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def run_one_completion(
|
||||
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
@@ -230,12 +538,76 @@ class PromptSizer:
|
||||
return content, tok
|
||||
|
||||
|
||||
def fetch_and_filter_placements(
|
||||
client: ExoClient, full_model_id: str, args: argparse.Namespace
|
||||
) -> list[dict[str, Any]]:
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-bench",
|
||||
description="Benchmark exo model throughput across placement previews.",
|
||||
)
|
||||
add_common_instance_args(ap)
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
|
||||
)
|
||||
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
|
||||
ap.add_argument(
|
||||
"--pp",
|
||||
nargs="+",
|
||||
@@ -248,6 +620,34 @@ def main() -> int:
|
||||
required=True,
|
||||
help="Generation lengths (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--max-nodes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Only consider placements using >= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Skip pipeline+jaccl placements, as it's often pointless.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-tensor-ring",
|
||||
action="store_true",
|
||||
help="Skip tensor+ring placements, as it's so slow.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
|
||||
)
|
||||
@@ -257,6 +657,9 @@ def main() -> int:
|
||||
default=0,
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
default="bench/results.json",
|
||||
@@ -271,6 +674,17 @@ def main() -> int:
|
||||
action="store_true",
|
||||
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--settle-timeout",
|
||||
type=float,
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--danger-delete-downloads",
|
||||
action="store_true",
|
||||
help="Delete existing models from smallest to largest to make room for benchmark model.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
@@ -305,10 +719,24 @@ def main() -> int:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
selected = settle_and_fetch_placements(
|
||||
client, full_model_id, args, settle_timeout=args.settle_timeout
|
||||
settle_deadline = (
|
||||
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
|
||||
)
|
||||
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected and settle_deadline:
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
while not selected and time.monotonic() < settle_deadline:
|
||||
remaining = settle_deadline - time.monotonic()
|
||||
logger.warning(
|
||||
f"No valid placements yet (cluster may still be settling). "
|
||||
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected:
|
||||
logger.error("No valid placements matched your filters.")
|
||||
return 1
|
||||
@@ -332,6 +760,16 @@ def main() -> int:
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
logger.info("Planning phase: checking downloads...")
|
||||
run_planning_phase(
|
||||
client,
|
||||
full_model_id,
|
||||
selected[0],
|
||||
args.danger_delete_downloads,
|
||||
args.timeout,
|
||||
settle_deadline,
|
||||
)
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
|
||||
327
bench/harness.py
327
bench/harness.py
@@ -1,327 +0,0 @@
|
||||
# type: ignore
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import http.client
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from loguru import logger
|
||||
|
||||
_SETTLE_INITIAL_BACKOFF_S = 1.0
|
||||
_SETTLE_MAX_BACKOFF_S = 60.0
|
||||
_SETTLE_BACKOFF_MULTIPLIER = 2.0
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
super().__init__(f"HTTP {status} {reason}: {body_preview}")
|
||||
self.status = status
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def request_json(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if params:
|
||||
path = path + "?" + urlencode(params)
|
||||
|
||||
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
|
||||
try:
|
||||
payload: bytes | None = None
|
||||
hdrs: dict[str, str] = {"Accept": "application/json"}
|
||||
|
||||
if body is not None:
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
hdrs["Content-Type"] = "application/json"
|
||||
if headers:
|
||||
hdrs.update(headers)
|
||||
|
||||
conn.request(method.upper(), path, body=payload, headers=hdrs)
|
||||
resp = conn.getresponse()
|
||||
raw = resp.read()
|
||||
text = raw.decode("utf-8", errors="replace") if raw else ""
|
||||
|
||||
if resp.status >= 400:
|
||||
raise ExoHttpError(resp.status, resp.reason, text[:300])
|
||||
|
||||
if not text:
|
||||
return None
|
||||
return json.loads(text)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.request_json("POST", "/bench/chat/completions", body=payload)
|
||||
|
||||
|
||||
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
|
||||
if len(instance) != 1:
|
||||
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
|
||||
|
||||
tag = next(iter(instance))
|
||||
inner = instance[tag]
|
||||
if not isinstance(inner, dict):
|
||||
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
|
||||
return inner
|
||||
|
||||
|
||||
def instance_id_from_instance(instance: dict[str, Any]) -> str:
|
||||
inner = unwrap_instance(instance)
|
||||
return str(inner["instanceId"])
|
||||
|
||||
|
||||
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
|
||||
inner = unwrap_instance(instance)
|
||||
return len(inner["shardAssignments"]["nodeToRunner"])
|
||||
|
||||
|
||||
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
|
||||
inner = unwrap_instance(instance)
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
return list(runner_to_shard.keys())
|
||||
|
||||
|
||||
def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def runner_failed(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerFailed" in runner
|
||||
|
||||
|
||||
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
|
||||
if "RunnerFailed" in runner:
|
||||
return runner["RunnerFailed"].get("errorMessage")
|
||||
return None
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
instance_existed = False
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
if instance_existed:
|
||||
# Instance was deleted after being created - likely due to runner failure
|
||||
raise RuntimeError(
|
||||
f"Instance {instance_id} was deleted (runner may have failed)"
|
||||
)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance_existed = True
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
# Check for failed runners first
|
||||
for rid in runner_ids:
|
||||
runner = runners.get(rid, {})
|
||||
if runner_failed(runner):
|
||||
error_msg = get_runner_failed_message(runner) or "Unknown error"
|
||||
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
|
||||
|
||||
|
||||
def wait_for_instance_gone(
|
||||
client: ExoClient, instance_id: str, timeout: float = 3.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
client.request_json("GET", f"/instance/{instance_id}")
|
||||
time.sleep(0.4)
|
||||
except ExoHttpError as e:
|
||||
if e.status == 404:
|
||||
return
|
||||
raise
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
|
||||
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
data = models.get("data") or []
|
||||
|
||||
for m in data:
|
||||
if (m.get("name") or "").lower() == model_arg.lower():
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m.get("hugging_face_id") or m["name"])
|
||||
return short_id, full_id
|
||||
|
||||
for m in data:
|
||||
if m.get("hugging_face_id") == model_arg:
|
||||
short_id = str(m["name"])
|
||||
full_id = str(m["hugging_face_id"])
|
||||
return short_id, full_id
|
||||
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
return ("ring" in s) or ("jaccl" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def sharding_filter(sharding: str, wanted: str) -> bool:
|
||||
s = (sharding or "").lower()
|
||||
if wanted == "both":
|
||||
return ("pipeline" in s) or ("tensor" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def fetch_and_filter_placements(
|
||||
client: ExoClient, full_model_id: str, args: argparse.Namespace
|
||||
) -> list[dict[str, Any]]:
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_tensor_ring
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "ring" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def settle_and_fetch_placements(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
args: argparse.Namespace,
|
||||
settle_timeout: float = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
if not selected and settle_timeout > 0:
|
||||
backoff = _SETTLE_INITIAL_BACKOFF_S
|
||||
deadline = time.monotonic() + settle_timeout
|
||||
while not selected and time.monotonic() < deadline:
|
||||
remaining = deadline - time.monotonic()
|
||||
logger.warning(
|
||||
f"No valid placements yet (cluster may still be settling). "
|
||||
f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
|
||||
)
|
||||
time.sleep(min(backoff, remaining))
|
||||
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
|
||||
selected = fetch_and_filter_placements(client, full_model_id, args)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
|
||||
)
|
||||
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
|
||||
ap.add_argument(
|
||||
"--max-nodes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Only consider placements using >= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Skip pipeline+jaccl placements, as it's often pointless.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-tensor-ring",
|
||||
action="store_true",
|
||||
help="Skip tensor+ring placements, as it's so slow.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--settle-timeout",
|
||||
type=float,
|
||||
default=0,
|
||||
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
|
||||
)
|
||||
@@ -4,7 +4,6 @@ version = "0.1.0"
|
||||
description = "Benchmarking tool for exo distributed inference"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"httpx>=0.27.0",
|
||||
"loguru>=0.7.3",
|
||||
"transformers>=5.0.0",
|
||||
"huggingface-hub>=0.33.4",
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
# Tool definitions — each becomes an OpenAI function tool.
|
||||
# All scenarios get all tools unless they specify a `tools` list.
|
||||
|
||||
[tools.get_current_weather]
|
||||
description = "Get the current weather in a given location"
|
||||
required = ["location"]
|
||||
|
||||
[tools.get_current_weather.properties.location]
|
||||
type = "string"
|
||||
description = "City and state, e.g. San Francisco, CA"
|
||||
|
||||
[tools.get_current_weather.properties.unit]
|
||||
type = "string"
|
||||
enum = ["celsius", "fahrenheit"]
|
||||
description = "Temperature unit"
|
||||
|
||||
[tools.calculate]
|
||||
description = "Evaluate a mathematical expression and return the numeric result"
|
||||
required = ["expression"]
|
||||
|
||||
[tools.calculate.properties.expression]
|
||||
type = "string"
|
||||
description = "The math expression to evaluate, e.g. '2 + 3 * 4'"
|
||||
|
||||
[tools.search_products]
|
||||
description = "Search for products in a catalog by query, category, and price"
|
||||
required = ["query"]
|
||||
|
||||
[tools.search_products.properties.query]
|
||||
type = "string"
|
||||
description = "Search query string"
|
||||
|
||||
[tools.search_products.properties.category]
|
||||
type = "string"
|
||||
enum = ["electronics", "clothing", "food", "books"]
|
||||
description = "Product category to filter by"
|
||||
|
||||
[tools.search_products.properties.max_price]
|
||||
type = "number"
|
||||
description = "Maximum price in USD"
|
||||
|
||||
# -- Should call a tool --
|
||||
|
||||
[[scenarios]]
|
||||
name = "weather_simple"
|
||||
description = "Basic weather query -> get_current_weather"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What's the weather like in Tokyo right now?"
|
||||
|
||||
[[scenarios]]
|
||||
name = "calculator_simple"
|
||||
description = "Math question -> calculate"
|
||||
expect_tool_call = true
|
||||
expected_function = "calculate"
|
||||
required_arg_keys = ["expression"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Use the calculator to compute 3847 * 926 + 17293"
|
||||
|
||||
[[scenarios]]
|
||||
name = "search_with_filters"
|
||||
description = "Product search with category and price filter"
|
||||
expect_tool_call = true
|
||||
expected_function = "search_products"
|
||||
required_arg_keys = ["query"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Find me electronics under $50"
|
||||
|
||||
# -- Multi-turn: tool call then follow-up --
|
||||
|
||||
[[scenarios]]
|
||||
name = "weather_multi_turn"
|
||||
description = "Weather query -> tool result -> natural language summary"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[scenarios.tool_result]
|
||||
temperature = "18C"
|
||||
condition = "partly cloudy"
|
||||
humidity = "65%"
|
||||
wind = "12 km/h NW"
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What's the weather in Paris?"
|
||||
|
||||
[[scenarios]]
|
||||
name = "calculator_multi_turn"
|
||||
description = "Math query -> tool result -> model reports the answer"
|
||||
expect_tool_call = true
|
||||
expected_function = "calculate"
|
||||
required_arg_keys = ["expression"]
|
||||
|
||||
[scenarios.tool_result]
|
||||
result = 491682
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Use the calculator to compute 1847 * 263 + 5921"
|
||||
|
||||
[[scenarios]]
|
||||
name = "search_multi_turn"
|
||||
description = "Search query -> tool result -> model summarizes products"
|
||||
expect_tool_call = true
|
||||
expected_function = "search_products"
|
||||
required_arg_keys = ["query"]
|
||||
|
||||
[[scenarios.tool_result.results]]
|
||||
name = "Hands-On Machine Learning"
|
||||
price = 45.99
|
||||
rating = 4.8
|
||||
|
||||
[[scenarios.tool_result.results]]
|
||||
name = "Deep Learning with Python"
|
||||
price = 39.99
|
||||
rating = 4.6
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Search for books about machine learning"
|
||||
|
||||
# -- Sequential tool calls --
|
||||
|
||||
[[scenarios]]
|
||||
name = "chained_tool_calls_same"
|
||||
description = "Thinking + weather(Tokyo) -> result -> model must call weather(London)"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Compare the weather in Tokyo and London."
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "I'll check both cities. Let me start with Tokyo."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_1"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Tokyo" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_1"
|
||||
content = '{"temperature": "25C", "condition": "sunny"}'
|
||||
|
||||
[[scenarios]]
|
||||
name = "chained_tool_calls_different"
|
||||
description = "Thinking + weather(Berlin) -> result -> model must call calculator"
|
||||
expect_tool_call = true
|
||||
expected_function = "calculate"
|
||||
required_arg_keys = ["expression"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What's the weather in Berlin, and also use the calculator to compute 4819 * 37 + 291."
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "I'll handle both. Let me check Berlin's weather first."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_2"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Berlin" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_2"
|
||||
content = '{"temperature": "12C", "condition": "rainy"}'
|
||||
|
||||
[[scenarios]]
|
||||
name = "chained_tool_calls_three"
|
||||
description = "Two prior thinking+tool calls -> results -> model must make a third"
|
||||
expect_tool_call = true
|
||||
expected_function = "get_current_weather"
|
||||
required_arg_keys = ["location"]
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Compare weather in Tokyo, Paris, and London."
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "I'll check all three cities. Starting with Tokyo."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_3"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Tokyo" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_3"
|
||||
content = '{"temperature": "25C", "condition": "sunny"}'
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "assistant"
|
||||
content = "Got Tokyo. Now checking Paris."
|
||||
|
||||
[[scenarios.messages.tool_calls]]
|
||||
id = "call_4"
|
||||
name = "get_current_weather"
|
||||
arguments = { location = "Paris" }
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "tool"
|
||||
tool_call_id = "call_4"
|
||||
content = '{"temperature": "18C", "condition": "cloudy"}'
|
||||
|
||||
# -- Should NOT call a tool --
|
||||
|
||||
[[scenarios]]
|
||||
name = "no_tool_joke"
|
||||
description = "Joke request should NOT trigger any tool"
|
||||
expect_tool_call = false
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "Tell me a funny joke about cats."
|
||||
|
||||
[[scenarios]]
|
||||
name = "no_tool_factual"
|
||||
description = "Factual question answerable from training data"
|
||||
expect_tool_call = false
|
||||
|
||||
[[scenarios.messages]]
|
||||
role = "user"
|
||||
content = "What is the capital of Japan?"
|
||||
@@ -103,7 +103,7 @@
|
||||
const modelSupportsThinking = $derived(() => {
|
||||
if (!currentModel) return false;
|
||||
const caps = modelCapabilities[currentModel] || [];
|
||||
return caps.includes("thinking_toggle") && caps.includes("text");
|
||||
return caps.includes("thinking") && caps.includes("text");
|
||||
});
|
||||
|
||||
const isEditOnlyWithoutImage = $derived(
|
||||
|
||||
@@ -59,14 +59,13 @@
|
||||
}
|
||||
|
||||
const sizeOptions: ImageGenerationParams["size"][] = [
|
||||
"auto",
|
||||
"512x512",
|
||||
"768x768",
|
||||
"1024x1024",
|
||||
"1024x768",
|
||||
"768x1024",
|
||||
"1024x1536",
|
||||
"1536x1024",
|
||||
"1024x1365",
|
||||
"1365x1024",
|
||||
];
|
||||
|
||||
const qualityOptions: ImageGenerationParams["quality"][] = [
|
||||
@@ -177,90 +176,92 @@
|
||||
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
|
||||
<!-- Basic params row -->
|
||||
<div class="flex items-center gap-3 flex-wrap">
|
||||
<!-- Size -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
<!-- Size (hidden in edit mode - output size comes from input image) -->
|
||||
{#if !isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
{params.size.toUpperCase()}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size.toUpperCase()}</span>
|
||||
</button>
|
||||
{/each}
|
||||
{params.size}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Quality -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
@@ -310,7 +311,7 @@
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {qualityDropdownPosition()
|
||||
.top}px + 4px); left: {qualityDropdownPosition().left}px;"
|
||||
>
|
||||
|
||||
@@ -306,14 +306,13 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
|
||||
export interface ImageGenerationParams {
|
||||
// Basic params
|
||||
size:
|
||||
| "auto"
|
||||
| "512x512"
|
||||
| "768x768"
|
||||
| "1024x1024"
|
||||
| "1024x768"
|
||||
| "768x1024"
|
||||
| "1024x1536"
|
||||
| "1536x1024";
|
||||
| "1024x1365"
|
||||
| "1365x1024";
|
||||
quality: "low" | "medium" | "high";
|
||||
outputFormat: "png" | "jpeg";
|
||||
numImages: number;
|
||||
@@ -337,7 +336,7 @@ export interface EditingImage {
|
||||
}
|
||||
|
||||
const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
|
||||
size: "auto",
|
||||
size: "1024x1024",
|
||||
quality: "medium",
|
||||
outputFormat: "png",
|
||||
numImages: 1,
|
||||
|
||||
@@ -74,6 +74,7 @@
|
||||
perSystem =
|
||||
{ config, self', inputs', pkgs, lib, system, ... }:
|
||||
let
|
||||
fenixToolchain = inputs'.fenix.packages.complete;
|
||||
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
|
||||
in
|
||||
@@ -114,7 +115,7 @@
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
|
||||
let
|
||||
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
|
||||
uvLockMlxVersion = mlxPackage.version;
|
||||
in
|
||||
{
|
||||
|
||||
10
nix/mlx.nix
10
nix/mlx.nix
@@ -41,16 +41,16 @@ let
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.7.dev20260218+14841977"; in
|
||||
version = let v = "0.30.6"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "rltakashige";
|
||||
repo = "mlx-jaccl-fix-small-recv";
|
||||
rev = "1484197707f35186ad3bd614357c7c47fdf86ebc";
|
||||
hash = "sha256-FupCMoK/SF/ldfKuvMSAKECcOP8c+ANgkQlPZttDsLk=";
|
||||
owner = "ml-explore";
|
||||
repo = "mlx";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
|
||||
@@ -17,9 +17,9 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx; sys_platform == 'darwin'",
|
||||
"mlx==0.30.6; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.7",
|
||||
"mlx-lm==0.30.6",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
@@ -64,7 +64,6 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
|
||||
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
|
||||
@@ -58,21 +58,6 @@
|
||||
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
|
||||
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
|
||||
mlx = ignoreMissing prev.mlx;
|
||||
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
|
||||
buildInputs = (old.buildInputs or [ ]) ++ [
|
||||
final.nvidia-cublas
|
||||
final.nvidia-cuda-nvrtc
|
||||
final.nvidia-cudnn-cu13
|
||||
final.nvidia-nccl-cu13
|
||||
];
|
||||
preFixup = ''
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cublas}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
|
||||
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
|
||||
'';
|
||||
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
|
||||
});
|
||||
torch = ignoreMissing prev.torch;
|
||||
triton = ignoreMissing prev.triton;
|
||||
}
|
||||
@@ -89,25 +74,14 @@
|
||||
linuxOverlay
|
||||
]
|
||||
);
|
||||
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
|
||||
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
|
||||
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
|
||||
"lib/python3.13/site-packages/mlx*"
|
||||
"lib/python3.13/site-packages/nvidia*"
|
||||
];
|
||||
|
||||
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
|
||||
venvIgnoreCollisions = venvCollisionPaths;
|
||||
};
|
||||
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
|
||||
|
||||
# Virtual environment with dev dependencies for testing
|
||||
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
workspace.deps.default // {
|
||||
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
|
||||
}
|
||||
)).overrideAttrs {
|
||||
venvIgnoreCollisions = venvCollisionPaths;
|
||||
};
|
||||
);
|
||||
|
||||
mkPythonScript = name: path: pkgs.writeShellApplication {
|
||||
inherit name;
|
||||
@@ -158,7 +132,6 @@
|
||||
exo-test-env = testVenv;
|
||||
} // {
|
||||
exo-bench = mkBenchScript "exo-bench" (inputs.self + /bench/exo_bench.py);
|
||||
exo-eval-tool-calls = mkBenchScript "exo-eval-tool-calls" (inputs.self + /bench/eval_tool_calls.py);
|
||||
exo-get-all-models-on-cluster = mkSimplePythonScript "exo-get-all-models-on-cluster" (inputs.self + /tests/get_all_models_on_cluster.py);
|
||||
};
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "deepseek"
|
||||
quantization = "4bit"
|
||||
base_model = "DeepSeek V3.1"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405874409472
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "deepseek"
|
||||
quantization = "8bit"
|
||||
base_model = "DeepSeek V3.1"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 765577920512
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.5 Air"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 122406567936
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "bf16"
|
||||
base_model = "GLM 4.5 Air"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 229780750336
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "4bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 198556925568
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "6bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 286737579648
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.7"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 396963397248
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "4bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 19327352832
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "5bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 22548578304
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "6bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26843545600
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM 4.7 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 34359738368
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/GLM-5-8bit-MXFP8"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "8bit"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 790517400864
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/GLM-5-MXFP4-Q8"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "MXFP4-Q8"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405478939008
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/GLM-5"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "glm"
|
||||
quantization = "bf16"
|
||||
base_model = "GLM-5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 1487822475264
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "kimi"
|
||||
quantization = ""
|
||||
base_model = "Kimi K2"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 706522120192
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "kimi"
|
||||
quantization = ""
|
||||
base_model = "Kimi K2.5"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 662498705408
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "minimax"
|
||||
quantization = "3bit"
|
||||
base_model = "MiniMax M2.1"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 100086644736
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "minimax"
|
||||
quantization = "8bit"
|
||||
base_model = "MiniMax M2.1"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 242986745856
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/MiniMax-M2.5-4bit"
|
||||
n_layers = 62
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "minimax"
|
||||
quantization = "4bit"
|
||||
base_model = "MiniMax M2.5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 128666664960
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/MiniMax-M2.5-6bit"
|
||||
n_layers = 62
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "minimax"
|
||||
quantization = "6bit"
|
||||
base_model = "MiniMax M2.5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 185826705408
|
||||
@@ -1,12 +0,0 @@
|
||||
model_id = "mlx-community/MiniMax-M2.5-8bit"
|
||||
n_layers = 62
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
family = "minimax"
|
||||
quantization = "8bit"
|
||||
base_model = "MiniMax M2.5"
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 242986745856
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 0.6B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 342884352
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 0.6B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 698351616
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 235B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 141733920768
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 235B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 268435456000
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 30B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 17612931072
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 30B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33279705088
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "4bit"
|
||||
base_model = "Qwen3 Next 80B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 47080074240
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "qwen"
|
||||
quantization = "8bit"
|
||||
base_model = "Qwen3 Next 80B"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 88814387200
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "4bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 114572190076
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "6bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 159039627774
|
||||
|
||||
@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
|
||||
family = "step"
|
||||
quantization = "8bit"
|
||||
base_model = "Step 3.5 Flash"
|
||||
capabilities = ["text", "thinking", "thinking_toggle"]
|
||||
capabilities = ["text", "thinking"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 209082699847
|
||||
|
||||
2
rust/clippy.toml
Normal file
2
rust/clippy.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
|
||||
#allowed-duplicate-crates = ["hashbrown"]
|
||||
@@ -25,17 +25,17 @@ workspace = true
|
||||
networking = { workspace = true }
|
||||
|
||||
# interop
|
||||
pyo3 = { version = "0.27.2", features = [
|
||||
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
|
||||
# "nightly", # enables better-supported GIL integration
|
||||
pyo3 = { version = "0.27.1", features = [
|
||||
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
|
||||
"nightly", # enables better-supported GIL integration
|
||||
"experimental-async", # async support in #[pyfunction] & #[pymethods]
|
||||
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
|
||||
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
|
||||
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
"multiple-pymethods", # allows multiple #[pymethods] sections per class
|
||||
|
||||
# integrations with other libraries
|
||||
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
# "ordered-float", "rust_decimal", "smallvec",
|
||||
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
|
||||
"ordered-float", "rust_decimal", "smallvec",
|
||||
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
|
||||
] }
|
||||
pyo3-stub-gen = { version = "0.17.2" }
|
||||
@@ -45,18 +45,33 @@ pyo3-log = "0.13.2"
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
|
||||
# async runtime
|
||||
tokio = { workspace = true, features = ["full", "tracing"] }
|
||||
futures-lite = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
once_cell = "1.21.3"
|
||||
thread_local = "1.1.9"
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
|
||||
|
||||
# Tracing
|
||||
#tracing = "0.1"
|
||||
#tracing-subscriber = "0.3"
|
||||
#console-subscriber = "0.1.5"
|
||||
#tracing-log = "0.2.0"
|
||||
log = { workspace = true }
|
||||
env_logger = "0.11"
|
||||
|
||||
|
||||
# Networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
pin-project = "1.1.10"
|
||||
|
||||
@@ -19,7 +19,7 @@ class ConnectionUpdate:
|
||||
Whether this is a connection or disconnection event
|
||||
"""
|
||||
@property
|
||||
def peer_id(self) -> builtins.str:
|
||||
def peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Identity of the peer that we have connected to or disconnected from.
|
||||
"""
|
||||
@@ -40,22 +40,92 @@ class Keypair:
|
||||
Identity keypair of a node.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate() -> Keypair:
|
||||
def generate_ed25519() -> Keypair:
|
||||
r"""
|
||||
Generate a new Ed25519 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> Keypair:
|
||||
def generate_ecdsa() -> Keypair:
|
||||
r"""
|
||||
Construct an Ed25519 keypair from secret key bytes
|
||||
Generate a new ECDSA keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_secp256k1() -> Keypair:
|
||||
r"""
|
||||
Generate a new Secp256k1 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_protobuf_encoding(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
"""
|
||||
@staticmethod
|
||||
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
format (i.e. unencrypted) as defined in [RFC5208].
|
||||
|
||||
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
"""
|
||||
@staticmethod
|
||||
def secp256k1_from_der(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
structure as defined in [RFC5915].
|
||||
|
||||
[RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
"""
|
||||
@staticmethod
|
||||
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
|
||||
def to_protobuf_encoding(self) -> bytes:
|
||||
r"""
|
||||
Encode a private key as protobuf structure.
|
||||
"""
|
||||
def to_peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Convert the `Keypair` into the corresponding `PeerId`.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class Multiaddr:
|
||||
r"""
|
||||
Representation of a Multiaddr.
|
||||
"""
|
||||
@staticmethod
|
||||
def empty() -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress.
|
||||
"""
|
||||
@staticmethod
|
||||
def with_capacity(n: builtins.int) -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress with the given capacity.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its byte slice representation.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_string(string: builtins.str) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its string representation.
|
||||
"""
|
||||
def len(self) -> builtins.int:
|
||||
r"""
|
||||
Return the length in bytes of this multiaddress.
|
||||
"""
|
||||
def is_empty(self) -> builtins.bool:
|
||||
r"""
|
||||
Returns true if the length of this multiaddress is 0.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Get the secret key bytes underlying the keypair
|
||||
Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
"""
|
||||
def to_node_id(self) -> builtins.str:
|
||||
def to_string(self) -> builtins.str:
|
||||
r"""
|
||||
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our NodeId.
|
||||
Convert a Multiaddr to a string.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
@@ -110,6 +180,37 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class PeerId:
|
||||
r"""
|
||||
Identifier of a peer of the network.
|
||||
|
||||
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
"""
|
||||
@staticmethod
|
||||
def random() -> PeerId:
|
||||
r"""
|
||||
Generates a random peer ID from a cryptographically secure PRNG.
|
||||
|
||||
This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> PeerId:
|
||||
r"""
|
||||
Parses a `PeerId` from bytes.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Returns a raw bytes representation of this `PeerId`.
|
||||
"""
|
||||
def to_base58(self) -> builtins.str:
|
||||
r"""
|
||||
Returns a base-58 encoded string of this `PeerId`.
|
||||
"""
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdateType(enum.Enum):
|
||||
r"""
|
||||
|
||||
@@ -2,10 +2,11 @@
|
||||
//!
|
||||
|
||||
use pin_project::pin_project;
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::prelude::*;
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
pin::{Pin, pin},
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
@@ -25,13 +26,15 @@ where
|
||||
|
||||
impl<F> Future for AllowThreads<F>
|
||||
where
|
||||
F: Future + Send,
|
||||
F::Output: Send,
|
||||
F: Future + Ungil,
|
||||
F::Output: Ungil,
|
||||
{
|
||||
type Output = F::Output;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let waker = cx.waker();
|
||||
Python::attach(|py| py.detach(|| self.project().0.poll(&mut Context::from_waker(waker))))
|
||||
Python::with_gil(|py| {
|
||||
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
240
rust/exo_pyo3_bindings/src/examples/mod.rs
Normal file
240
rust/exo_pyo3_bindings/src/examples/mod.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
|
||||
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
|
||||
//!
|
||||
//! Pattern examples include:
|
||||
//! - Async task handles: with GC-integrated cleanup
|
||||
//! - Sync/async callbacks from python: with propper eventloop handling
|
||||
//!
|
||||
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
|
||||
//! - Store mutable fields in tokio's `Mutex<T>`
|
||||
//! - For async code: take `&self` and `.lock().await`
|
||||
//! - For sync code: take `&mut self` and `.get_mut()`
|
||||
|
||||
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::BoxFuture;
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::{
|
||||
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
|
||||
fn needs_tokio_runtime() {
|
||||
tokio::runtime::Handle::current();
|
||||
}
|
||||
|
||||
type SyncCallback = Box<dyn Fn() + Send + Sync>;
|
||||
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
|
||||
|
||||
enum AsyncTaskMessage {
|
||||
SyncCallback(SyncCallback),
|
||||
AsyncCallback(AsyncCallback),
|
||||
}
|
||||
|
||||
async fn async_task(
|
||||
sender: mpsc::UnboundedSender<()>,
|
||||
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
|
||||
) {
|
||||
log::info!("RUST: async task started");
|
||||
|
||||
// task state
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(1));
|
||||
|
||||
let mut sync_cbs: Vec<SyncCallback> = vec![];
|
||||
let mut async_cbs: Vec<AsyncCallback> = vec![];
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// handle incoming messages from task-handle
|
||||
message = receiver.recv() => {
|
||||
// handle closed channel by exiting
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming event
|
||||
match message {
|
||||
AsyncTaskMessage::SyncCallback(cb) => {
|
||||
sync_cbs.push(cb);
|
||||
}
|
||||
AsyncTaskMessage::AsyncCallback(cb) => {
|
||||
async_cbs.push(cb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle all other events
|
||||
_ = interval.tick() => {
|
||||
log::info!("RUST: async task tick");
|
||||
|
||||
// call back all sync callbacks
|
||||
for cb in &sync_cbs {
|
||||
cb();
|
||||
}
|
||||
|
||||
// call back all async callbacks
|
||||
for cb in &async_cbs {
|
||||
cb().await;
|
||||
}
|
||||
|
||||
// send event on unbounded channel
|
||||
sender.send(()).expect("handle receiver cannot be closed/dropped");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: async task stopped");
|
||||
}
|
||||
|
||||
// #[gen_stub_pyclass]
|
||||
#[pyclass(name = "AsyncTaskHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyAsyncTaskHandle {
|
||||
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyAsyncTaskHandle {
|
||||
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
|
||||
self.sender
|
||||
.as_mut()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
|
||||
const fn new(
|
||||
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
|
||||
receiver: mpsc::UnboundedReceiver<()>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sender: Some(sender),
|
||||
receiver,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyAsyncTaskHandle {
|
||||
#[new]
|
||||
fn py_new(py: Python<'_>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channel TOWARDS our task
|
||||
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
|
||||
|
||||
// create communication channel FROM our task
|
||||
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
|
||||
|
||||
// perform necessary setup within tokio context - or it crashes
|
||||
let () = get_runtime().block_on(async { needs_tokio_runtime() });
|
||||
|
||||
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
|
||||
_ = get_runtime().spawn_with_scope(py, async move {
|
||||
async_task(t_sender, t_receiver).await;
|
||||
});
|
||||
Ok(Self::new(h_sender, h_receiver))
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_sync_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], None]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
|
||||
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// NOTE: exceptions in callbacks are silently ignored until end of execution
|
||||
fn add_async_callback(
|
||||
&self,
|
||||
// #[gen_stub(override_type(
|
||||
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
|
||||
// imports=("collections.abc")
|
||||
// ))]
|
||||
callback: Py<PyAny>,
|
||||
) -> PyResult<()> {
|
||||
// blocking call to async method -> can do non-blocking if needed
|
||||
self.sender()
|
||||
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
|
||||
let c = Python::with_gil(|py| callback.clone_ref(py));
|
||||
async move {
|
||||
if let Some(f) = Python::with_gil(|py| {
|
||||
let coroutine = c.call0(py).write_unraisable_with(py)?;
|
||||
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
|
||||
.write_unraisable_with(py)
|
||||
}) {
|
||||
_ = f.await.write_unraisable();
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
})))
|
||||
.pyerr()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn receive_unit(&mut self) -> PyResult<()> {
|
||||
self.receiver
|
||||
.recv()
|
||||
.await
|
||||
.ok_or(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
))
|
||||
}
|
||||
|
||||
fn drain_units(&mut self) -> PyResult<i32> {
|
||||
let mut cnt = 0;
|
||||
loop {
|
||||
match self.receiver.try_recv() {
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
return Err(PyErr::new::<PyRuntimeError, _>(
|
||||
"cannot receive unit on closed channel",
|
||||
));
|
||||
}
|
||||
Err(TryRecvError::Empty) => return Ok(cnt),
|
||||
Ok(()) => {
|
||||
cnt += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
}
|
||||
|
||||
// #[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyAsyncTaskHandle>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::identity::Keypair;
|
||||
use pyo3::types::{PyBytes, PyBytesMethods};
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
|
||||
/// Identity keypair of a node.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Keypair", frozen)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyKeypair(pub Keypair);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate() -> Self {
|
||||
Self(Keypair::generate_ed25519())
|
||||
}
|
||||
|
||||
/// Construct an Ed25519 keypair from secret key bytes
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Get the secret key bytes underlying the keypair
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self
|
||||
.0
|
||||
.clone()
|
||||
.try_into_ed25519()
|
||||
.expect("we only use ed25519 keys")
|
||||
.secret()
|
||||
.as_ref()
|
||||
.to_vec();
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
|
||||
/// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our NodeId.
|
||||
fn to_node_id(&self) -> String {
|
||||
self.0.public().to_peer_id().to_base58()
|
||||
}
|
||||
}
|
||||
@@ -4,14 +4,28 @@
|
||||
//!
|
||||
//!
|
||||
|
||||
mod allow_threading;
|
||||
mod ident;
|
||||
mod networking;
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(tuple_trait)]
|
||||
#![feature(unboxed_closures)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
extern crate core;
|
||||
mod allow_threading;
|
||||
mod examples;
|
||||
pub(crate) mod networking;
|
||||
pub(crate) mod pylibp2p;
|
||||
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::networking::networking_submodule;
|
||||
use crate::pylibp2p::ident::ident_submodule;
|
||||
use crate::pylibp2p::multiaddr::multiaddr_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::types::PyModuleMethods;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
@@ -20,11 +34,24 @@ pub(crate) mod r#const {
|
||||
pub const MPSC_CHANNEL_SIZE: usize = 1024;
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {
|
||||
use std::error::Error;
|
||||
use std::marker::Tuple;
|
||||
|
||||
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
|
||||
Fn<Args, Output = Output> + Send + 'static;
|
||||
|
||||
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
|
||||
pub type AnyResult<T> = Result<T, AnyError>;
|
||||
}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {
|
||||
use crate::allow_threading::AllowThreads;
|
||||
use extend::ext;
|
||||
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
|
||||
use pyo3::marker::Ungil;
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Py, PyErr, PyResult, Python};
|
||||
use tokio::runtime::Runtime;
|
||||
@@ -35,7 +62,7 @@ pub(crate) mod ext {
|
||||
#[ext(pub, name = ByteArrayExt)]
|
||||
impl [u8] {
|
||||
fn pybytes(&self) -> Py<PyBytes> {
|
||||
Python::attach(|py| PyBytes::new(py, self).unbind())
|
||||
Python::with_gil(|py| PyBytes::new(py, self).unbind())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +98,7 @@ pub(crate) mod ext {
|
||||
#[ext(pub, name = PyResultExt)]
|
||||
impl<T> PyResult<T> {
|
||||
fn write_unraisable(self) -> Option<T> {
|
||||
Python::attach(|py| self.write_unraisable_with(py))
|
||||
Python::with_gil(|py| self.write_unraisable_with(py))
|
||||
}
|
||||
|
||||
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
|
||||
@@ -148,6 +175,24 @@ pub(crate) mod ext {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
use std::marker::Sized;
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
|
||||
#[repr(transparent)]
|
||||
pub(crate) struct ClonePy<T>(pub Py<T>);
|
||||
|
||||
impl<T> Clone for ClonePy<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Python::with_gil(|py| Self(self.0.clone_ref(py)))
|
||||
}
|
||||
}
|
||||
|
||||
/// A Python module implemented in Rust. The name of this function must match
|
||||
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
|
||||
/// import the module.
|
||||
@@ -155,14 +200,12 @@ pub(crate) mod ext {
|
||||
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// install logger
|
||||
pyo3_log::init();
|
||||
let mut builder = tokio::runtime::Builder::new_multi_thread();
|
||||
builder.enable_all();
|
||||
pyo3_async_runtimes::tokio::init(builder);
|
||||
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
// too many importing issues...
|
||||
m.add_class::<PyKeypair>()?;
|
||||
ident_submodule(m)?;
|
||||
multiaddr_submodule(m)?;
|
||||
networking_submodule(m)?;
|
||||
|
||||
// top-level constructs
|
||||
|
||||
@@ -1,24 +1,31 @@
|
||||
#![allow(
|
||||
clippy::multiple_inherent_impl,
|
||||
clippy::unnecessary_wraps,
|
||||
clippy::unused_self,
|
||||
clippy::needless_pass_by_value
|
||||
)]
|
||||
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscSenderExt as _};
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::networking::exception::{PyAllQueuesFullError, PyNoPeersSubscribedToTopicError};
|
||||
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
||||
use crate::pyclass;
|
||||
use futures_lite::StreamExt as _;
|
||||
use libp2p::gossipsub::PublishError;
|
||||
use networking::swarm::{FromSwarm, Swarm, ToSwarm, create_swarm};
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
use libp2p::swarm::SwarmEvent;
|
||||
use libp2p::{gossipsub, mdns};
|
||||
use networking::discovery;
|
||||
use networking::swarm::create_swarm;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, Py, PyErr, PyResult, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{
|
||||
gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pyclass_enum, gen_stub_pymethods,
|
||||
};
|
||||
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
|
||||
use std::net::IpAddr;
|
||||
use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
|
||||
mod exception {
|
||||
use pyo3::types::PyTuple;
|
||||
use pyo3::{exceptions::PyException, prelude::*};
|
||||
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
|
||||
use pyo3_stub_gen::derive::*;
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
@@ -112,7 +119,7 @@ struct PyConnectionUpdate {
|
||||
|
||||
/// Identity of the peer that we have connected to or disconnected from.
|
||||
#[pyo3(get)]
|
||||
peer_id: String,
|
||||
peer_id: PyPeerId,
|
||||
|
||||
/// Remote connection's IPv4 address.
|
||||
#[pyo3(get)]
|
||||
@@ -123,45 +130,206 @@ struct PyConnectionUpdate {
|
||||
remote_tcp_port: u16,
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "NetworkingHandle")]
|
||||
struct PyNetworkingHandle {
|
||||
// channels
|
||||
pub to_swarm: mpsc::Sender<ToSwarm>,
|
||||
pub swarm: Mutex<Swarm>,
|
||||
enum ToTask {
|
||||
GossipsubSubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<PyResult<bool>>,
|
||||
},
|
||||
GossipsubUnsubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<bool>,
|
||||
},
|
||||
GossipsubPublish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
result_tx: oneshot::Sender<PyResult<MessageId>>,
|
||||
},
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass_complex_enum]
|
||||
#[pyclass]
|
||||
enum PyFromSwarm {
|
||||
Connection {
|
||||
peer_id: String,
|
||||
connected: bool,
|
||||
},
|
||||
Message {
|
||||
origin: String,
|
||||
topic: String,
|
||||
data: Py<PyBytes>,
|
||||
},
|
||||
}
|
||||
impl From<FromSwarm> for PyFromSwarm {
|
||||
fn from(value: FromSwarm) -> Self {
|
||||
match value {
|
||||
FromSwarm::Discovered { peer_id } => Self::Connection {
|
||||
peer_id: peer_id.to_base58(),
|
||||
connected: true,
|
||||
},
|
||||
FromSwarm::Expired { peer_id } => Self::Connection {
|
||||
peer_id: peer_id.to_base58(),
|
||||
connected: false,
|
||||
},
|
||||
FromSwarm::Message { from, topic, data } => Self::Message {
|
||||
origin: from.to_base58(),
|
||||
topic: topic,
|
||||
data: data.pybytes(),
|
||||
},
|
||||
#[allow(clippy::enum_glob_use)]
|
||||
async fn networking_task(
|
||||
mut swarm: networking::swarm::Swarm,
|
||||
mut to_task_rx: mpsc::Receiver<ToTask>,
|
||||
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
|
||||
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
|
||||
) {
|
||||
use SwarmEvent::*;
|
||||
use ToTask::*;
|
||||
use mdns::Event::*;
|
||||
use networking::swarm::BehaviourEvent::*;
|
||||
|
||||
log::info!("RUST: networking task started");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
message = to_task_rx.recv() => {
|
||||
// handle closed channel
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming messages
|
||||
match message {
|
||||
GossipsubSubscribe { topic, result_tx } => {
|
||||
// try to subscribe
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.subscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot
|
||||
if let Err(e) = result_tx.send(result.pyerr()) {
|
||||
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubUnsubscribe { topic, result_tx } => {
|
||||
// try to unsubscribe from the topic
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.unsubscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(result) {
|
||||
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubPublish { topic, data, result_tx } => {
|
||||
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
|
||||
let result = swarm.behaviour_mut().gossipsub.publish(
|
||||
IdentTopic::new(topic), data);
|
||||
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
|
||||
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
|
||||
} else if let Err(PublishError::AllQueuesFull(_)) = result {
|
||||
Err(exception::PyAllQueuesFullError::new_err())
|
||||
} else {
|
||||
result.pyerr()
|
||||
};
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(pyresult) {
|
||||
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// architectural solution to this problem:
|
||||
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
|
||||
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
|
||||
//
|
||||
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
|
||||
// then for actual communication it will dial those peers if need-be
|
||||
swarm_event = swarm.select_next_some() => {
|
||||
match swarm_event {
|
||||
Behaviour(Gossipsub(gossipsub::Event::Message {
|
||||
message: Message {
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => {
|
||||
// topic-ID is just the topic hash!!! (since we used identity hasher)
|
||||
let message = (topic.into_string(), data);
|
||||
|
||||
// send incoming message to channel (or exit if connection closed)
|
||||
if let Err(e) = gossipsub_message_tx.send(message).await {
|
||||
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send connection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Connected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send disconnection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Disconnected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
e => {
|
||||
log::info!("RUST: other event {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: networking task stopped");
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "NetworkingHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyNetworkingHandle {
|
||||
// channels
|
||||
to_task_tx: Option<mpsc::Sender<ToTask>>,
|
||||
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
|
||||
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
|
||||
}
|
||||
|
||||
impl Drop for PyNetworkingHandle {
|
||||
fn drop(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyNetworkingHandle {
|
||||
fn new(
|
||||
to_task_tx: mpsc::Sender<ToTask>,
|
||||
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
|
||||
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
to_task_tx: Some(to_task_tx),
|
||||
connection_update_rx: Mutex::new(connection_update_rx),
|
||||
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
|
||||
}
|
||||
}
|
||||
|
||||
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
|
||||
self.to_task_tx
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
@@ -175,33 +343,97 @@ impl PyNetworkingHandle {
|
||||
|
||||
#[new]
|
||||
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channels
|
||||
let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
|
||||
// get identity
|
||||
let identity = identity.borrow().0.clone();
|
||||
|
||||
// create networking swarm (within tokio context!! or it crashes)
|
||||
let _guard = pyo3_async_runtimes::tokio::get_runtime().enter();
|
||||
let swarm = { create_swarm(identity, from_client).pyerr()? };
|
||||
let swarm = get_runtime()
|
||||
.block_on(async { create_swarm(identity) })
|
||||
.pyerr()?;
|
||||
|
||||
Ok(Self {
|
||||
swarm: Mutex::new(swarm),
|
||||
to_swarm,
|
||||
})
|
||||
// spawn tokio task running the networking logic
|
||||
get_runtime().spawn(async move {
|
||||
networking_task(
|
||||
swarm,
|
||||
to_task_rx,
|
||||
connection_update_tx,
|
||||
gossipsub_message_tx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
Ok(Self::new(
|
||||
to_task_tx,
|
||||
connection_update_rx,
|
||||
gossipsub_message_rx,
|
||||
))
|
||||
}
|
||||
|
||||
async fn recv(&self) -> PyResult<PyFromSwarm> {
|
||||
self.swarm
|
||||
.try_lock()
|
||||
.expect("tried to recv from swarm twice concurrently")
|
||||
.next()
|
||||
.allow_threads_py()
|
||||
#[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
|
||||
// ---- Connection update receiver methods ----
|
||||
|
||||
/// Receives the next `ConnectionUpdate` from networking.
|
||||
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.ok_or(PyErr::receiver_channel_closed())
|
||||
.map(Into::into)
|
||||
}
|
||||
|
||||
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
/// will sleep until a `ConnectionUpdate`s is sent.
|
||||
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next `ConnectionUpdate` from networking.
|
||||
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
|
||||
// self.connection_update_rx.blocking_lock().try_recv_py()
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `ConnectionUpdate` channel is empty.
|
||||
// fn connection_update_is_empty(&self) -> bool {
|
||||
// self.connection_update_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `ConnectionUpdate`s in the channel.
|
||||
// fn connection_update_len(&self) -> usize {
|
||||
// self.connection_update_rx.blocking_lock().len()
|
||||
// }
|
||||
|
||||
// ---- Gossipsub management methods ----
|
||||
|
||||
/// Subscribe to a `GossipSub` topic.
|
||||
@@ -211,10 +443,10 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Subscribe {
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubSubscribe {
|
||||
topic,
|
||||
result_sender: tx,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -223,7 +455,6 @@ impl PyNetworkingHandle {
|
||||
rx.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
.pyerr()
|
||||
}
|
||||
|
||||
/// Unsubscribes from a `GossipSub` topic.
|
||||
@@ -233,10 +464,10 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to unsubscribe
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Unsubscribe {
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubUnsubscribe {
|
||||
topic,
|
||||
result_sender: tx,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -254,12 +485,12 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Publish {
|
||||
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubPublish {
|
||||
topic,
|
||||
data,
|
||||
result_sender: tx,
|
||||
result_tx: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -268,14 +499,64 @@ impl PyNetworkingHandle {
|
||||
let _ = rx
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
.map_err(|e| match e {
|
||||
PublishError::AllQueuesFull(_) => PyAllQueuesFullError::new_err(),
|
||||
PublishError::MessageTooLarge => PyNoPeersSubscribedToTopicError::new_err(),
|
||||
e => PyRuntimeError::new_err(e.to_string()),
|
||||
})?;
|
||||
.map_err(|_| PyErr::receiver_channel_closed())??;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---- Gossipsub message receiver methods ----
|
||||
|
||||
/// Receives the next message from the `GossipSub` network.
|
||||
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
|
||||
self.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
}
|
||||
|
||||
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
|
||||
Ok(self
|
||||
.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
.collect())
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next message from the `GossipSub` network.
|
||||
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
|
||||
// Ok(self
|
||||
// .gossipsub_message_rx
|
||||
// .blocking_lock()
|
||||
// .try_recv_py()?
|
||||
// .map(|(t, d)| (t, d.pybytes())))
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `GossipSub` message channel is empty.
|
||||
// fn gossipsub_is_empty(&self) -> bool {
|
||||
// self.gossipsub_message_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `GossipSub` messages in the channel.
|
||||
// fn gossipsub_len(&self) -> usize {
|
||||
// self.gossipsub_message_rx.blocking_lock().len()
|
||||
// }
|
||||
}
|
||||
|
||||
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
|
||||
159
rust/exo_pyo3_bindings/src/pylibp2p/ident.rs
Normal file
159
rust/exo_pyo3_bindings/src/pylibp2p/ident.rs
Normal file
@@ -0,0 +1,159 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::PeerId;
|
||||
use libp2p::identity::Keypair;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
|
||||
/// Identity keypair of a node.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Keypair", frozen)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyKeypair(pub Keypair);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ed25519() -> Self {
|
||||
Self(Keypair::generate_ed25519())
|
||||
}
|
||||
|
||||
/// Generate a new ECDSA keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ecdsa() -> Self {
|
||||
Self(Keypair::generate_ecdsa())
|
||||
}
|
||||
|
||||
/// Generate a new Secp256k1 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_secp256k1() -> Self {
|
||||
Self(Keypair::generate_secp256k1())
|
||||
}
|
||||
|
||||
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
#[staticmethod]
|
||||
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
/// format (i.e. unencrypted) as defined in [RFC5208].
|
||||
///
|
||||
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
#[staticmethod]
|
||||
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
/// structure as defined in [RFC5915].
|
||||
///
|
||||
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
#[staticmethod]
|
||||
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Encode a private key as protobuf structure.
|
||||
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self.0.to_protobuf_encoding().pyerr()?;
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
|
||||
/// Convert the `Keypair` into the corresponding `PeerId`.
|
||||
fn to_peer_id(&self) -> PyPeerId {
|
||||
PyPeerId(self.0.public().to_peer_id())
|
||||
}
|
||||
|
||||
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
|
||||
// #[gen_stub(skip)]
|
||||
// #[new]
|
||||
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
// Self::from_protobuf_encoding(bytes)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
|
||||
// *self = Self::from_protobuf_encoding(state)?;
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
// self.to_protobuf_encoding(py)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
|
||||
// Ok((self.to_protobuf_encoding(py)?,))
|
||||
// }
|
||||
}
|
||||
|
||||
/// Identifier of a peer of the network.
|
||||
///
|
||||
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "PeerId", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyPeerId(pub PeerId);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyPeerId {
|
||||
/// Generates a random peer ID from a cryptographically secure PRNG.
|
||||
///
|
||||
/// This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
#[staticmethod]
|
||||
fn random() -> Self {
|
||||
Self(PeerId::random())
|
||||
}
|
||||
|
||||
/// Parses a `PeerId` from bytes.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Returns a raw bytes representation of this `PeerId`.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_bytes();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Returns a base-58 encoded string of this `PeerId`.
|
||||
fn to_base58(&self) -> String {
|
||||
self.0.to_base58()
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId({})", self.to_base58())
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
self.to_base58()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyKeypair>()?;
|
||||
m.add_class::<PyPeerId>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
8
rust/exo_pyo3_bindings/src/pylibp2p/mod.rs
Normal file
8
rust/exo_pyo3_bindings/src/pylibp2p/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
//! A module for exposing Rust's libp2p datatypes over Pyo3
|
||||
//!
|
||||
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
|
||||
//! independent identity type of some kind or another. This may require handshaking.
|
||||
//!
|
||||
|
||||
pub mod ident;
|
||||
pub mod multiaddr;
|
||||
81
rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs
Normal file
81
rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::Multiaddr;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
use std::str::FromStr as _;
|
||||
|
||||
/// Representation of a Multiaddr.
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "Multiaddr", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyMultiaddr(pub Multiaddr);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyMultiaddr {
|
||||
/// Create a new, empty multiaddress.
|
||||
#[staticmethod]
|
||||
fn empty() -> Self {
|
||||
Self(Multiaddr::empty())
|
||||
}
|
||||
|
||||
/// Create a new, empty multiaddress with the given capacity.
|
||||
#[staticmethod]
|
||||
fn with_capacity(n: usize) -> Self {
|
||||
Self(Multiaddr::with_capacity(n))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its byte slice representation.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Parse a `Multiaddr` value from its string representation.
|
||||
#[staticmethod]
|
||||
fn from_string(string: String) -> PyResult<Self> {
|
||||
Ok(Self(Multiaddr::from_str(&string).pyerr()?))
|
||||
}
|
||||
|
||||
/// Return the length in bytes of this multiaddress.
|
||||
fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
/// Returns true if the length of this multiaddress is 0.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
/// Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_vec();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Convert a Multiaddr to a string.
|
||||
fn to_string(&self) -> String {
|
||||
self.0.to_string()
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __repr__(&self) -> String {
|
||||
format!("Multiaddr({})", self.0)
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __str__(&self) -> String {
|
||||
self.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyMultiaddr>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -19,14 +19,21 @@ either = { workspace = true }
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures-lite = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
@@ -34,5 +41,4 @@ keccak-const = { workspace = true }
|
||||
log = { workspace = true }
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
pin-project = "1.1.10"
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
@@ -1,8 +1,6 @@
|
||||
use futures_lite::StreamExt;
|
||||
use libp2p::identity;
|
||||
use networking::swarm;
|
||||
use networking::swarm::{FromSwarm, ToSwarm};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use futures::stream::StreamExt as _;
|
||||
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
|
||||
use networking::{discovery, swarm};
|
||||
use tokio::{io, io::AsyncBufReadExt as _, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::filter::LevelFilter;
|
||||
@@ -13,23 +11,17 @@ async fn main() {
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
|
||||
.try_init();
|
||||
|
||||
let (to_swarm, from_client) = mpsc::channel(20);
|
||||
|
||||
// Configure swarm
|
||||
let mut swarm = swarm::create_swarm(identity::Keypair::generate_ed25519(), from_client)
|
||||
.expect("Swarm creation failed");
|
||||
let mut swarm =
|
||||
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
|
||||
|
||||
// Create a Gossipsub topic & subscribe
|
||||
let (tx, rx) = oneshot::channel();
|
||||
_ = to_swarm
|
||||
.send(ToSwarm::Subscribe {
|
||||
topic: "test-net".to_string(),
|
||||
result_sender: tx,
|
||||
})
|
||||
.await
|
||||
.expect("should send");
|
||||
|
||||
let mut fused = futures_lite::future::fuse(rx);
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&topic)
|
||||
.expect("Subscribing to topic failed");
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
@@ -40,23 +32,43 @@ async fn main() {
|
||||
select! {
|
||||
// on gossipsub outgoing
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
if let Err(e) = to_swarm.send(swarm::ToSwarm::Publish { topic: "test-net".to_string(), data: line.as_bytes().to_vec(), result_sender: tx }).await {
|
||||
println!("Send error: {e:?}");
|
||||
return
|
||||
};
|
||||
if let Err(e) = rx.await {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
},
|
||||
event = swarm.next() => match event {
|
||||
}
|
||||
event = swarm.select_next_some() => match event {
|
||||
// on gossipsub incoming
|
||||
Some(FromSwarm::Discovered { peer_id }) => { println!("\n\nconnected to {peer_id}\n\n") },
|
||||
Some(FromSwarm::Expired { peer_id }) => { println!("\n\ndisconnected from {peer_id}\n\n") },
|
||||
Some(FromSwarm::Message { from, topic, data }) => { println!("{topic}/{from}:\n{}", String::from_utf8_lossy(&data)) },
|
||||
None => {},
|
||||
},
|
||||
f = &mut fused => {assert!(f.expect("should recv").expect("should subscribe"))},
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
|
||||
// on discovery
|
||||
SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
|
||||
discovery::Event::ConnectionEstablished {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
discovery::Event::ConnectionClosed {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
}
|
||||
|
||||
// ignore outgoing errors: those are normal
|
||||
e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }
|
||||
|
||||
// otherwise log any other event
|
||||
e => { log::info!("Other event {e:?}"); }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
127
rust/networking/examples/chatroom_manual.rs
Normal file
127
rust/networking/examples/chatroom_manual.rs
Normal file
@@ -0,0 +1,127 @@
|
||||
// Copyright 2018 Parity Technologies (UK) Ltd.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a
|
||||
// copy of this software and associated documentation files (the "Software"),
|
||||
// to deal in the Software without restriction, including without limitation
|
||||
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
// and/or sell copies of the Software, and to permit persons to whom the
|
||||
// Software is furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use futures::stream::StreamExt;
|
||||
use libp2p::{
|
||||
gossipsub, mdns, noise,
|
||||
swarm::{NetworkBehaviour, SwarmEvent},
|
||||
tcp, yamux,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use std::{error::Error, hash::Hash};
|
||||
use tokio::{io, io::AsyncBufReadExt, select};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
// We create a custom network behaviour that combines Gossipsub and Mdns.
|
||||
#[derive(NetworkBehaviour)]
|
||||
struct MyBehaviour {
|
||||
gossipsub: gossipsub::Behaviour,
|
||||
mdns: mdns::tokio::Behaviour,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.try_init();
|
||||
|
||||
let mut swarm = libp2p::SwarmBuilder::with_new_identity()
|
||||
.with_tokio()
|
||||
.with_tcp(
|
||||
tcp::Config::default(),
|
||||
noise::Config::new,
|
||||
yamux::Config::default,
|
||||
)?
|
||||
.with_behaviour(|key| {
|
||||
// Set a custom gossipsub configuration
|
||||
let gossipsub_config = gossipsub::ConfigBuilder::default()
|
||||
.heartbeat_interval(Duration::from_secs(10))
|
||||
.validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
|
||||
.build()
|
||||
.map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.
|
||||
|
||||
// build a gossipsub network behaviour
|
||||
let gossipsub = gossipsub::Behaviour::new(
|
||||
gossipsub::MessageAuthenticity::Signed(key.clone()),
|
||||
gossipsub_config,
|
||||
)?;
|
||||
|
||||
let mdns =
|
||||
mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
|
||||
Ok(MyBehaviour { gossipsub, mdns })
|
||||
})?
|
||||
.build();
|
||||
|
||||
println!("Running swarm with identity {}", swarm.local_peer_id());
|
||||
|
||||
// Create a Gossipsub topic
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
// subscribes to our topic
|
||||
swarm.behaviour_mut().gossipsub.subscribe(&topic)?;
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
}
|
||||
event = swarm.select_next_some() => match event {
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
|
||||
for (peer_id, multiaddr) in list {
|
||||
println!("mDNS discover peer has expired: {peer_id} on {multiaddr}");
|
||||
swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
|
||||
}
|
||||
},
|
||||
SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
})) => println!(
|
||||
"Got message: '{}' with id: {id} from peer: {peer_id}",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
SwarmEvent::NewListenAddr { address, .. } => {
|
||||
println!("Local node is listening on {address}");
|
||||
}
|
||||
e => {
|
||||
println!("Other swarm event: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
use crate::ext::MultiaddrExt;
|
||||
use crate::keep_alive;
|
||||
use delegate::delegate;
|
||||
use either::Either;
|
||||
use futures_lite::FutureExt;
|
||||
use futures::FutureExt;
|
||||
use futures_timer::Delay;
|
||||
use libp2p::core::transport::PortUse;
|
||||
use libp2p::core::{ConnectedPoint, Endpoint};
|
||||
@@ -362,7 +363,7 @@ impl NetworkBehaviour for Behaviour {
|
||||
}
|
||||
|
||||
// retry connecting to all mDNS peers periodically (fails safely if already connected)
|
||||
if self.retry_delay.poll(cx).is_ready() {
|
||||
if self.retry_delay.poll_unpin(cx).is_ready() {
|
||||
for (p, mas) in self.mdns_discovered.clone() {
|
||||
for ma in mas {
|
||||
self.dial(p, ma)
|
||||
|
||||
44
rust/networking/src/keep_alive.rs
Normal file
44
rust/networking/src/keep_alive.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use delegate::delegate;
|
||||
use libp2p::swarm::handler::ConnectionEvent;
|
||||
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
|
||||
/// the connection alive.
|
||||
#[derive(Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct ConnectionHandler(dummy::ConnectionHandler);
|
||||
|
||||
impl ConnectionHandler {
|
||||
pub fn new() -> Self {
|
||||
ConnectionHandler(dummy::ConnectionHandler)
|
||||
}
|
||||
}
|
||||
|
||||
impl handler::ConnectionHandler for ConnectionHandler {
|
||||
// delegate types and implementation mostly to dummy handler
|
||||
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
|
||||
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
|
||||
type InboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
|
||||
type OutboundProtocol =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
|
||||
type InboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
|
||||
type OutboundOpenInfo =
|
||||
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
|
||||
|
||||
delegate! {
|
||||
to self.0 {
|
||||
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
|
||||
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
|
||||
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
|
||||
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
|
||||
}
|
||||
}
|
||||
|
||||
// specifically override this to force connection to stay alive
|
||||
fn connection_keep_alive(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,19 @@
|
||||
//! this is here as a placeholder documentation
|
||||
//!
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(unboxed_closures)]
|
||||
// #![feature(assert_matches)]
|
||||
// #![feature(async_fn_in_dyn_trait)]
|
||||
// #![feature(async_for_loop)]
|
||||
// #![feature(auto_traits)]
|
||||
// #![feature(negative_impls)]
|
||||
|
||||
pub mod discovery;
|
||||
pub mod keep_alive;
|
||||
pub mod swarm;
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
@@ -42,3 +54,11 @@ pub(crate) mod ext {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod private {
|
||||
#![allow(dead_code)]
|
||||
|
||||
/// Sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
use std::task::Poll;
|
||||
|
||||
use crate::alias;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
use crate::{alias, discovery};
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use futures_lite::Stream;
|
||||
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use libp2p::{SwarmBuilder, identity};
|
||||
|
||||
pub type Swarm = libp2p::Swarm<Behaviour>;
|
||||
|
||||
/// The current version of the network: this prevents devices running different versions of the
|
||||
/// software from interacting with each other.
|
||||
@@ -17,133 +15,8 @@ use tokio::sync::{mpsc, oneshot};
|
||||
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
|
||||
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
|
||||
|
||||
pub enum ToSwarm {
|
||||
Unsubscribe {
|
||||
topic: String,
|
||||
result_sender: oneshot::Sender<bool>,
|
||||
},
|
||||
Subscribe {
|
||||
topic: String,
|
||||
result_sender: oneshot::Sender<Result<bool, gossipsub::SubscriptionError>>,
|
||||
},
|
||||
Publish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
result_sender: oneshot::Sender<Result<gossipsub::MessageId, gossipsub::PublishError>>,
|
||||
},
|
||||
}
|
||||
pub enum FromSwarm {
|
||||
Message {
|
||||
from: PeerId,
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
},
|
||||
Discovered {
|
||||
peer_id: PeerId,
|
||||
},
|
||||
Expired {
|
||||
peer_id: PeerId,
|
||||
},
|
||||
}
|
||||
#[pin_project::pin_project]
|
||||
pub struct Swarm {
|
||||
#[pin]
|
||||
inner: libp2p::Swarm<Behaviour>,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
}
|
||||
impl Swarm {
|
||||
fn on_message(&mut self, message: ToSwarm) {
|
||||
match message {
|
||||
ToSwarm::Subscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
// try to subscribe
|
||||
let result = self
|
||||
.inner
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&gossipsub::IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot
|
||||
_ = result_sender.send(result)
|
||||
}
|
||||
ToSwarm::Unsubscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
// try to unsubscribe from the topic
|
||||
let result = self
|
||||
.inner
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.unsubscribe(&gossipsub::IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
_ = result_sender.send(result)
|
||||
}
|
||||
ToSwarm::Publish {
|
||||
topic,
|
||||
data,
|
||||
result_sender,
|
||||
} => {
|
||||
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
|
||||
let result = self
|
||||
.inner
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.publish(gossipsub::IdentTopic::new(topic), data);
|
||||
// send response oneshot (or exit if connection closed)
|
||||
_ = result_sender.send(result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Stream for Swarm {
|
||||
type Item = FromSwarm;
|
||||
fn poll_next(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
match self.from_client.poll_recv(cx) {
|
||||
Poll::Ready(Some(msg)) => self.on_message(msg),
|
||||
Poll::Ready(None) => return Poll::Ready(None),
|
||||
Poll::Pending => {}
|
||||
}
|
||||
match self.project().inner.poll_next(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Ready(Some(swarm_event)) => match swarm_event {
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
message:
|
||||
gossipsub::Message {
|
||||
source: Some(peer_id),
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => Poll::Ready(Some(FromSwarm::Message {
|
||||
from: peer_id,
|
||||
topic: topic.into_string(),
|
||||
data,
|
||||
})),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
|
||||
discovery::Event::ConnectionEstablished { peer_id, .. },
|
||||
)) => Poll::Ready(Some(FromSwarm::Discovered { peer_id })),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
|
||||
discovery::Event::ConnectionClosed { peer_id, .. },
|
||||
)) => Poll::Ready(Some(FromSwarm::Expired { peer_id })),
|
||||
_ => Poll::Pending,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Create and configure a swarm which listens to all ports on OS
|
||||
pub fn create_swarm(
|
||||
keypair: identity::Keypair,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
) -> alias::AnyResult<Swarm> {
|
||||
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
|
||||
.with_tokio()
|
||||
.with_other_transport(tcp_transport)?
|
||||
@@ -152,16 +25,13 @@ pub fn create_swarm(
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
Ok(Swarm {
|
||||
inner: swarm,
|
||||
from_client,
|
||||
})
|
||||
Ok(swarm)
|
||||
}
|
||||
|
||||
mod transport {
|
||||
use crate::alias;
|
||||
use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
|
||||
use futures_lite::{AsyncRead, AsyncWrite};
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
use keccak_const::Sha3_256;
|
||||
use libp2p::core::muxing;
|
||||
use libp2p::core::transport::Boxed;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user