mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-27 11:46:14 -05:00
Compare commits
5 Commits
consistent
...
JakeHillio
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bdd6594bc8 | ||
|
|
152a27ea5d | ||
|
|
db36bd5ac6 | ||
|
|
639243aa09 | ||
|
|
db73c4fd5d |
35
.github/renovate.json
vendored
Normal file
35
.github/renovate.json
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
{
|
||||
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
|
||||
"extends": [
|
||||
"config:recommended"
|
||||
],
|
||||
"dependencyDashboard": true,
|
||||
"customManagers": [
|
||||
{
|
||||
"customType": "regex",
|
||||
"description": "Pin HuggingFace model revisions to commit SHAs",
|
||||
"managerFilePatterns": [
|
||||
"^resources/inference_model_cards/.*\\.toml$",
|
||||
"^resources/image_model_cards/.*\\.toml$"
|
||||
],
|
||||
"matchStrings": [
|
||||
"model_id = \"(?<packageName>[^\"]+)\"(?:\\n|\\r\\n?)revision = \"(?<currentValue>[^\"]+)\""
|
||||
],
|
||||
"datasourceTemplate": "git-refs",
|
||||
"depNameTemplate": "huggingface.co/{{packageName}}",
|
||||
"currentValueTemplate": "main",
|
||||
"versioningTemplate": "git"
|
||||
}
|
||||
],
|
||||
"packageRules": [
|
||||
{
|
||||
"matchDatasources": ["git-refs"],
|
||||
"matchPackageNames": ["huggingface.co/**"],
|
||||
"groupName": null,
|
||||
"automerge": false,
|
||||
"prTitle": "chore: pin {{depName}} to {{newDigest}}",
|
||||
"commitMessageTopic": "{{depName}} revision",
|
||||
"commitMessageExtra": "to {{newDigest}}"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -73,9 +73,11 @@ class GenerationResponse:
|
||||
finish_reason: Optional[str] = ...
|
||||
|
||||
def maybe_quantize_kv_cache(
|
||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||
): # -> None:
|
||||
...
|
||||
prompt_cache: Any,
|
||||
quantized_kv_start: int | None,
|
||||
kv_group_size: int | None,
|
||||
kv_bits: int | None,
|
||||
) -> None: ...
|
||||
def generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
|
||||
@@ -16,7 +16,7 @@ class Cache(Protocol):
|
||||
self, keys: mx.array, values: mx.array
|
||||
) -> tuple[mx.array, mx.array]: ...
|
||||
@property
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v) -> None: ...
|
||||
|
||||
@@ -92,13 +92,14 @@ class _BaseCache(Cache):
|
||||
values: mx.array
|
||||
offset: int
|
||||
@property
|
||||
def state(self) -> tuple[mx.array, mx.array]: ...
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v) -> None: ...
|
||||
@property
|
||||
def meta_state(self) -> Literal[""]: ...
|
||||
@meta_state.setter
|
||||
def meta_state(self, v) -> None: ...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def is_trimmable(self) -> Literal[False]: ...
|
||||
@classmethod
|
||||
def from_state(cls, state, meta_state) -> Self: ...
|
||||
@@ -114,15 +115,13 @@ class ConcatenateKVCache(_BaseCache):
|
||||
def update_and_fetch(self, keys, values): # -> tuple[Any | array, Any | array]:
|
||||
...
|
||||
@property
|
||||
def state(self): # -> tuple[Any | array | None, Any | array | None]:
|
||||
...
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
def is_trimmable(self): # -> Literal[True]:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
|
||||
...
|
||||
|
||||
@@ -132,10 +131,7 @@ class QuantizedKVCache(_BaseCache):
|
||||
def update_and_fetch(self, keys, values): # -> Any:
|
||||
...
|
||||
@property
|
||||
def state(
|
||||
self,
|
||||
): # -> tuple[Any | tuple[array, array, array] | None, Any | tuple[array, array, array] | None] | Any:
|
||||
...
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
@@ -147,8 +143,7 @@ class QuantizedKVCache(_BaseCache):
|
||||
...
|
||||
def is_trimmable(self): # -> Literal[True]:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def make_mask(self, *args, **kwargs): # -> array | Literal['causal'] | None:
|
||||
...
|
||||
|
||||
@@ -160,13 +155,12 @@ class KVCache(_BaseCache):
|
||||
@property
|
||||
def state(
|
||||
self,
|
||||
) -> tuple[array, array]: ...
|
||||
) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v) -> None: ...
|
||||
def is_trimmable(self): # -> Literal[True]:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def to_quantized(
|
||||
self, group_size: int = ..., bits: int = ...
|
||||
) -> QuantizedKVCache: ...
|
||||
@@ -183,8 +177,7 @@ class RotatingKVCache(_BaseCache):
|
||||
@property
|
||||
def state(
|
||||
self,
|
||||
): # -> tuple[Any | array, Any | array] | tuple[Any | array | None, Any | array | None]:
|
||||
...
|
||||
) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
@@ -196,8 +189,7 @@ class RotatingKVCache(_BaseCache):
|
||||
...
|
||||
def is_trimmable(self): # -> bool:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
def to_quantized(
|
||||
self, group_size: int = ..., bits: int = ...
|
||||
) -> QuantizedKVCache: ...
|
||||
@@ -212,8 +204,7 @@ class ArraysCache(_BaseCache):
|
||||
...
|
||||
def __getitem__(self, idx): ...
|
||||
@property
|
||||
def state(self): # -> list[Any | array] | list[array]:
|
||||
...
|
||||
def state(self) -> tuple[mx.array | None, mx.array | None]: ...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
@@ -239,8 +230,7 @@ class ChunkedKVCache(KVCache):
|
||||
...
|
||||
def update_and_fetch(self, keys, values): # -> tuple[array, array]:
|
||||
...
|
||||
def trim(self, n): # -> int:
|
||||
...
|
||||
def trim(self, n: int) -> int: ...
|
||||
@property
|
||||
def meta_state(self): # -> tuple[str, ...]:
|
||||
...
|
||||
@@ -253,10 +243,9 @@ class CacheList(_BaseCache):
|
||||
def __getitem__(self, idx): ...
|
||||
def is_trimmable(self): # -> bool:
|
||||
...
|
||||
def trim(self, n): ...
|
||||
def trim(self, n: int) -> int: ...
|
||||
@property
|
||||
def state(self): # -> list[Any]:
|
||||
...
|
||||
def state(self) -> list[tuple[mx.array | None, mx.array | None]]: ...
|
||||
@state.setter
|
||||
def state(self, v): # -> None:
|
||||
...
|
||||
|
||||
24
Cargo.lock
generated
24
Cargo.lock
generated
@@ -216,6 +216,28 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
|
||||
dependencies = [
|
||||
"async-stream-impl",
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream-impl"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.89"
|
||||
@@ -2759,6 +2781,7 @@ dependencies = [
|
||||
name = "networking"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"delegate",
|
||||
"either",
|
||||
"extend",
|
||||
@@ -2767,6 +2790,7 @@ dependencies = [
|
||||
"keccak-const",
|
||||
"libp2p",
|
||||
"log",
|
||||
"pin-project",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
"util",
|
||||
|
||||
@@ -34,6 +34,7 @@ delegate = "0.13"
|
||||
keccak-const = "0.2"
|
||||
|
||||
# Async dependencies
|
||||
async-stream = "0.3"
|
||||
tokio = "1.46"
|
||||
futures-lite = "2.6.1"
|
||||
futures-timer = "3.0"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/FLUX.1-Kontext-dev-4bit"
|
||||
revision = "4730d16f5f45143bab61f1cbf963d479f205e360"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/FLUX.1-Kontext-dev-8bit"
|
||||
revision = "cbf01164d429932b260d91872d62a7a4fe2634fa"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/FLUX.1-Kontext-dev"
|
||||
revision = "76e13736ad51f8dd8259a336dc087f8d3c7b819e"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev-4bit"
|
||||
revision = "880ebf331481b566f2019a88dace35279c4bcce8"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev-8bit"
|
||||
revision = "bb45e8d78959e51d42a1d424f44a112adabd5cbc"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev"
|
||||
revision = "2e3d8c5ebc737af82d1f4c669f99bacc72b09d6b"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/FLUX.1-dev-4bit"
|
||||
revision = "b96c650a7b2c57484e2df51d1fafcb4cb31b1060"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/FLUX.1-dev-8bit"
|
||||
revision = "aff4dbb9efca28c2cb809b7814a92dc5de0533a4"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/FLUX.1-dev"
|
||||
revision = "d5bf931a451025fd4e152be685d0e05f50324388"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/FLUX.1-schnell-4bit"
|
||||
revision = "9eaa004ace32efb5b45b17f128d493ac614e8985"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/FLUX.1-schnell-8bit"
|
||||
revision = "b829443fca09b0abcca3eb20821c8f54b307d119"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/FLUX.1-schnell"
|
||||
revision = "aedc677102a335e26774bea593b317d64c908a83"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/Qwen-Image-4bit"
|
||||
revision = "b38178287165b1b4ecc206a41f9a55f4423a21b5"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/Qwen-Image-8bit"
|
||||
revision = "617d15c93a317f4b5cebb766d4638775d002b380"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/Qwen-Image-Edit-2509-4bit"
|
||||
revision = "ac25e6566c3a94e52d988b017acde9d70945109a"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/Qwen-Image-Edit-2509-8bit"
|
||||
revision = "463547f285f8b6e5496347724a2b39a6d514ca76"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/Qwen-Image-Edit-2509"
|
||||
revision = "05027449c3ccee1b1c6be3ba85278ae683add9b5"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "exolabs/Qwen-Image"
|
||||
revision = "e7990ef5392a17dc917578b6f4e43aad9ae93e7a"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/DeepSeek-V3.1-4bit"
|
||||
revision = "main"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/DeepSeek-V3.1-8bit"
|
||||
revision = "cd6c63546a6d33a8cc75158dc60d1746787306ac"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/GLM-4.5-Air-8bit"
|
||||
revision = "ca44c769a97034e91466f2a524b6ef2c28eb3c1f"
|
||||
n_layers = 46
|
||||
hidden_size = 4096
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/GLM-4.5-Air-bf16"
|
||||
revision = "1753b3269c9e3cb62ba3cbaf0ab6433d69784592"
|
||||
n_layers = 46
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/GLM-4.7-4bit"
|
||||
revision = "0e9f6c4babaef5d5fd04c9efc8770ef234b6d576"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/GLM-4.7-6bit"
|
||||
revision = "025456be149d69c6b2805914dcc0e4aa6307caf9"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/GLM-4.7-8bit-gs32"
|
||||
revision = "65b39750987b4230754dc0becc66e42b4c9da07b"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-4bit"
|
||||
revision = "1454cffb1a21737e162f508e5bc70be9def89276"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-5bit"
|
||||
revision = "4b35cbe614a5693f0d8978de8718efdbf06d5706"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-6bit"
|
||||
revision = "6a4b4e620a3a7c7759227d8905cf8293ab28bc54"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-8bit"
|
||||
revision = "b3a202c6df57f7297fb351486938952352dcd25a"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/GLM-5-8bit-MXFP8"
|
||||
revision = "aa833c40d178262d4ac8b92965807ef988e9340e"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/GLM-5-MXFP4-Q8"
|
||||
revision = "41c00f3f30615c0759475497e919c1b515b3cdc0"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/GLM-5"
|
||||
revision = "27d3ffdf48b063d00f0c3d49f9f8fab09609c275"
|
||||
n_layers = 78
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
|
||||
revision = "91fb4f9fd1de100104925196d62b8ee06fd2ad60"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Kimi-K2-Thinking"
|
||||
revision = "035a0cdd221ae0dca6b03120e20704a251a7bc9b"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Kimi-K2.5"
|
||||
revision = "351021afd838c866ce1a7374fce51d615773d2a8"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
revision = "08231374eeacb049a0eade7922910865b8fce912"
|
||||
n_layers = 16
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
revision = "7f0dc925e0d0afb0322d96f9255cfddf2ba5636e"
|
||||
n_layers = 28
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
|
||||
revision = "ff054899609078569493def2823f9acd2780c0c9"
|
||||
n_layers = 28
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
|
||||
revision = "de2dfaf56839b7d0e834157d2401dee02726874d"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
|
||||
revision = "c5bfd839cd4cda0e5a39a97e00218d9c56e468af"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
|
||||
revision = "7772c93cf077b642f5503dd8d763a4176d7d406c"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
|
||||
revision = "241a666dad6cb93c8ff213d39a7f34a36bf26db4"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
|
||||
revision = "142d428004044c37c441272c91316251d9aecc58"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
|
||||
revision = "f8311090f9ee47782b6f094984a20c856eb841d6"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/MiniMax-M2.1-3bit"
|
||||
revision = "472cd920149fc1200e6ef2a2efc35db91cf44111"
|
||||
n_layers = 61
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/MiniMax-M2.1-8bit"
|
||||
revision = "3d779130c25f54aa9198da1c845d844de7acc086"
|
||||
n_layers = 61
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/MiniMax-M2.5-4bit"
|
||||
revision = "36fb6facb4697ac2e6c4e88b600cd8601fb62f08"
|
||||
n_layers = 62
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/MiniMax-M2.5-6bit"
|
||||
revision = "6294b58e9eff340c3556dc8aa3ed688e9dc428f8"
|
||||
n_layers = 62
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/MiniMax-M2.5-8bit"
|
||||
revision = "26af8b335da2017182616067c9342940a1c1ae73"
|
||||
n_layers = 62
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-0.6B-4bit"
|
||||
revision = "73e3e38d981303bc594367cd910ea6eb48349da8"
|
||||
n_layers = 28
|
||||
hidden_size = 1024
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-0.6B-8bit"
|
||||
revision = "11de96878523501bcaa86104e3c186de07ff9068"
|
||||
n_layers = 28
|
||||
hidden_size = 1024
|
||||
supports_tensor = false
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
|
||||
revision = "4dbf8a62338880825560dff3f58f2e9f0c56210f"
|
||||
n_layers = 94
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
|
||||
revision = "97042893088decff8468f7729c1076dcad2f251b"
|
||||
n_layers = 94
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
|
||||
revision = "d388dead1515f5e085ef7a0431dd8fadf0886c57"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
|
||||
revision = "7d5b2e500d961076e3c16d6bf957b9c36783b0f5"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
|
||||
revision = "ca8dbf41071f579fbe3260f20bbe1ab896f79031"
|
||||
n_layers = 62
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
|
||||
revision = "b4b2d06d678ac2819da4c41618a36a2dc8eeec03"
|
||||
n_layers = 62
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-4bit"
|
||||
revision = "7b9321eabb85ce79625cac3f61ea691e4ea984b5"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-5bit"
|
||||
revision = "1f3e27b1c376095ebf88a8037807c92784c25d66"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-6bit"
|
||||
revision = "9d12cc36cc6c386ffd04f7c8f0de6ccb29c5927e"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-8bit"
|
||||
revision = "6d3c664dc8539a711783391484fd6784c51fd8fa"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-bf16"
|
||||
revision = "83d523e8883faaea659705840ce3560472286d08"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
|
||||
revision = "d8a069bfa8ae87d3d468412e1034acae19b5892b"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
|
||||
revision = "fd52af0cc2a4a37b60904c4b0251255aa7d3dda2"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
|
||||
revision = "9a2b46347bb170cb2924092175fa21554fe585a9"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
|
||||
revision = "d093dbe8233828ca0cc420f75466133c542a1e96"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Step-3.5-Flash-4bit"
|
||||
revision = "caad23b0411c27d08aa3967822e61e5efcdd175c"
|
||||
n_layers = 45
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Step-3.5-Flash-6bit"
|
||||
revision = "78bbe35b6d9f0a4fa6425416aedea6021e4c40ec"
|
||||
n_layers = 45
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/Step-3.5-Flash-8Bit"
|
||||
revision = "0db40758dfe577e0ce383c3562226c2fbeac1d8e"
|
||||
n_layers = 45
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
|
||||
revision = "81e5ac3ad0af6efb1298a8e8c7a10ed2990c137b"
|
||||
n_layers = 36
|
||||
hidden_size = 2880
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q8"
|
||||
revision = "9f9d50e7b3418526519c2e21306d1c381e9181b2"
|
||||
n_layers = 24
|
||||
hidden_size = 2880
|
||||
supports_tensor = true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
|
||||
revision = "8103891b028a8933068e47751bc2acc10bb59aa2"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# ruff: noqa: E501, F401
|
||||
|
||||
import builtins
|
||||
import enum
|
||||
import typing
|
||||
|
||||
@typing.final
|
||||
@@ -11,29 +10,6 @@ class AllQueuesFullError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdate:
|
||||
@property
|
||||
def update_type(self) -> ConnectionUpdateType:
|
||||
r"""
|
||||
Whether this is a connection or disconnection event
|
||||
"""
|
||||
@property
|
||||
def peer_id(self) -> builtins.str:
|
||||
r"""
|
||||
Identity of the peer that we have connected to or disconnected from.
|
||||
"""
|
||||
@property
|
||||
def remote_ipv4(self) -> builtins.str:
|
||||
r"""
|
||||
Remote connection's IPv4 address.
|
||||
"""
|
||||
@property
|
||||
def remote_tcp_port(self) -> builtins.int:
|
||||
r"""
|
||||
Remote connection's TCP port.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class Keypair:
|
||||
r"""
|
||||
@@ -58,21 +34,15 @@ class Keypair:
|
||||
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class MessageTooLargeError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> MessageTooLargeError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class NetworkingHandle:
|
||||
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
|
||||
async def connection_update_recv(self) -> ConnectionUpdate:
|
||||
r"""
|
||||
Receives the next `ConnectionUpdate` from networking.
|
||||
"""
|
||||
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
|
||||
r"""
|
||||
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
will sleep until a `ConnectionUpdate`s is sent.
|
||||
"""
|
||||
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
|
||||
r"""
|
||||
Subscribe to a `GossipSub` topic.
|
||||
@@ -91,24 +61,7 @@ class NetworkingHandle:
|
||||
|
||||
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
"""
|
||||
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
|
||||
r"""
|
||||
Receives the next message from the `GossipSub` network.
|
||||
"""
|
||||
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
|
||||
r"""
|
||||
Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
will sleep until a message is sent.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class MessageTooLargeError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> MessageTooLargeError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
async def recv(self) -> PyFromSwarm: ...
|
||||
|
||||
@typing.final
|
||||
class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
@@ -116,11 +69,26 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdateType(enum.Enum):
|
||||
r"""
|
||||
Connection or disconnection event discriminant type.
|
||||
"""
|
||||
Connected = ...
|
||||
Disconnected = ...
|
||||
class PyFromSwarm:
|
||||
@typing.final
|
||||
class Connection(PyFromSwarm):
|
||||
__match_args__ = ("peer_id", "connected",)
|
||||
@property
|
||||
def peer_id(self) -> builtins.str: ...
|
||||
@property
|
||||
def connected(self) -> builtins.bool: ...
|
||||
def __new__(cls, peer_id: builtins.str, connected: builtins.bool) -> PyFromSwarm.Connection: ...
|
||||
|
||||
@typing.final
|
||||
class Message(PyFromSwarm):
|
||||
__match_args__ = ("origin", "topic", "data",)
|
||||
@property
|
||||
def origin(self) -> builtins.str: ...
|
||||
@property
|
||||
def topic(self) -> builtins.str: ...
|
||||
@property
|
||||
def data(self) -> bytes: ...
|
||||
def __new__(cls, origin: builtins.str, topic: builtins.str, data: bytes) -> PyFromSwarm.Message: ...
|
||||
|
||||
...
|
||||
|
||||
|
||||
@@ -4,11 +4,12 @@ build-backend = "maturin"
|
||||
|
||||
[project]
|
||||
name = "exo_pyo3_bindings"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" }
|
||||
{ name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" },
|
||||
{ name = "Evan Quiney", email = "evanev7@gmail.com" }
|
||||
]
|
||||
requires-python = ">=3.13"
|
||||
dependencies = []
|
||||
|
||||
@@ -155,6 +155,9 @@ pub(crate) mod ext {
|
||||
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// install logger
|
||||
pyo3_log::init();
|
||||
let mut builder = tokio::runtime::Builder::new_multi_thread();
|
||||
builder.enable_all();
|
||||
pyo3_async_runtimes::tokio::init(builder);
|
||||
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
|
||||
@@ -1,26 +1,24 @@
|
||||
#![allow(
|
||||
clippy::multiple_inherent_impl,
|
||||
clippy::unnecessary_wraps,
|
||||
clippy::unused_self,
|
||||
clippy::needless_pass_by_value
|
||||
)]
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscSenderExt as _};
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::networking::exception::{
|
||||
PyAllQueuesFullError, PyMessageTooLargeError, PyNoPeersSubscribedToTopicError,
|
||||
};
|
||||
use crate::pyclass;
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
use libp2p::swarm::SwarmEvent;
|
||||
use networking::discovery;
|
||||
use networking::swarm::create_swarm;
|
||||
use futures_lite::{Stream, StreamExt as _};
|
||||
use libp2p::gossipsub::PublishError;
|
||||
use networking::swarm::{FromSwarm, ToSwarm, create_swarm};
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
|
||||
use std::net::IpAddr;
|
||||
use pyo3::{Bound, Py, PyAny, PyErr, PyResult, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{
|
||||
gen_methods_from_python, gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pymethods,
|
||||
};
|
||||
use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
|
||||
mod exception {
|
||||
@@ -131,237 +129,45 @@ mod exception {
|
||||
}
|
||||
}
|
||||
|
||||
/// Connection or disconnection event discriminant type.
|
||||
#[gen_stub_pyclass_enum]
|
||||
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum PyConnectionUpdateType {
|
||||
Connected = 0,
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, name = "ConnectionUpdate")]
|
||||
#[derive(Debug, Clone)]
|
||||
struct PyConnectionUpdate {
|
||||
/// Whether this is a connection or disconnection event
|
||||
#[pyo3(get)]
|
||||
update_type: PyConnectionUpdateType,
|
||||
|
||||
/// Identity of the peer that we have connected to or disconnected from.
|
||||
#[pyo3(get)]
|
||||
peer_id: String,
|
||||
|
||||
/// Remote connection's IPv4 address.
|
||||
#[pyo3(get)]
|
||||
remote_ipv4: String,
|
||||
|
||||
/// Remote connection's TCP port.
|
||||
#[pyo3(get)]
|
||||
remote_tcp_port: u16,
|
||||
}
|
||||
|
||||
enum ToTask {
|
||||
GossipsubSubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<PyResult<bool>>,
|
||||
},
|
||||
GossipsubUnsubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<bool>,
|
||||
},
|
||||
GossipsubPublish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
result_tx: oneshot::Sender<PyResult<MessageId>>,
|
||||
},
|
||||
}
|
||||
|
||||
#[allow(clippy::enum_glob_use)]
|
||||
async fn networking_task(
|
||||
mut swarm: networking::swarm::Swarm,
|
||||
mut to_task_rx: mpsc::Receiver<ToTask>,
|
||||
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
|
||||
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
|
||||
) {
|
||||
use SwarmEvent::*;
|
||||
use ToTask::*;
|
||||
use networking::swarm::BehaviourEvent::*;
|
||||
|
||||
log::info!("RUST: networking task started");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
message = to_task_rx.recv() => {
|
||||
// handle closed channel
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming messages
|
||||
match message {
|
||||
GossipsubSubscribe { topic, result_tx } => {
|
||||
// try to subscribe
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.subscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot
|
||||
if let Err(e) = result_tx.send(result.pyerr()) {
|
||||
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubUnsubscribe { topic, result_tx } => {
|
||||
// try to unsubscribe from the topic
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.unsubscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(result) {
|
||||
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubPublish { topic, data, result_tx } => {
|
||||
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
|
||||
let result = swarm.behaviour_mut().gossipsub.publish(
|
||||
IdentTopic::new(topic), data);
|
||||
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
|
||||
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
|
||||
} else if let Err(PublishError::AllQueuesFull(_)) = result {
|
||||
Err(exception::PyAllQueuesFullError::new_err())
|
||||
} else if let Err(PublishError::MessageTooLarge) = result {
|
||||
Err(exception::PyMessageTooLargeError::new_err())
|
||||
} else {
|
||||
result.pyerr()
|
||||
};
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(pyresult) {
|
||||
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// architectural solution to this problem:
|
||||
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
|
||||
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
|
||||
//
|
||||
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
|
||||
// then for actual communication it will dial those peers if need-be
|
||||
swarm_event = swarm.select_next_some() => {
|
||||
match swarm_event {
|
||||
Behaviour(Gossipsub(gossipsub::Event::Message {
|
||||
message: Message {
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => {
|
||||
// topic-ID is just the topic hash!!! (since we used identity hasher)
|
||||
let message = (topic.into_string(), data);
|
||||
|
||||
// send incoming message to channel (or exit if connection closed)
|
||||
if let Err(e) = gossipsub_message_tx.send(message).await {
|
||||
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send connection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Connected,
|
||||
peer_id: peer_id.to_base58(),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send disconnection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Disconnected,
|
||||
peer_id: peer_id.to_base58(),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
e => {
|
||||
log::info!("RUST: other event {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: networking task stopped");
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "NetworkingHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyNetworkingHandle {
|
||||
// channels
|
||||
to_task_tx: Option<mpsc::Sender<ToTask>>,
|
||||
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
|
||||
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
|
||||
pub to_swarm: mpsc::Sender<ToSwarm>,
|
||||
pub swarm: Arc<Mutex<Pin<Box<dyn Stream<Item = FromSwarm> + Send>>>>,
|
||||
}
|
||||
|
||||
impl Drop for PyNetworkingHandle {
|
||||
fn drop(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
#[gen_stub_pyclass_complex_enum]
|
||||
#[pyclass]
|
||||
enum PyFromSwarm {
|
||||
Connection {
|
||||
peer_id: String,
|
||||
connected: bool,
|
||||
},
|
||||
Message {
|
||||
origin: String,
|
||||
topic: String,
|
||||
data: Py<PyBytes>,
|
||||
},
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyNetworkingHandle {
|
||||
fn new(
|
||||
to_task_tx: mpsc::Sender<ToTask>,
|
||||
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
|
||||
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
to_task_tx: Some(to_task_tx),
|
||||
connection_update_rx: Mutex::new(connection_update_rx),
|
||||
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
|
||||
impl From<FromSwarm> for PyFromSwarm {
|
||||
fn from(value: FromSwarm) -> Self {
|
||||
match value {
|
||||
FromSwarm::Discovered { peer_id } => Self::Connection {
|
||||
peer_id: peer_id.to_base58(),
|
||||
connected: true,
|
||||
},
|
||||
FromSwarm::Expired { peer_id } => Self::Connection {
|
||||
peer_id: peer_id.to_base58(),
|
||||
connected: false,
|
||||
},
|
||||
FromSwarm::Message { from, topic, data } => Self::Message {
|
||||
origin: from.to_base58(),
|
||||
topic: topic,
|
||||
data: data.pybytes(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
|
||||
self.to_task_tx
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
@@ -375,97 +181,36 @@ impl PyNetworkingHandle {
|
||||
|
||||
#[new]
|
||||
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channels
|
||||
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
|
||||
// get identity
|
||||
let identity = identity.borrow().0.clone();
|
||||
|
||||
// create networking swarm (within tokio context!! or it crashes)
|
||||
let swarm = get_runtime()
|
||||
.block_on(async { create_swarm(identity) })
|
||||
.pyerr()?;
|
||||
let _guard = pyo3_async_runtimes::tokio::get_runtime().enter();
|
||||
let swarm = create_swarm(identity, from_client).pyerr()?.into_stream();
|
||||
|
||||
// spawn tokio task running the networking logic
|
||||
get_runtime().spawn(async move {
|
||||
networking_task(
|
||||
swarm,
|
||||
to_task_rx,
|
||||
connection_update_tx,
|
||||
gossipsub_message_tx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
Ok(Self::new(
|
||||
to_task_tx,
|
||||
connection_update_rx,
|
||||
gossipsub_message_rx,
|
||||
))
|
||||
Ok(Self {
|
||||
swarm: Arc::new(Mutex::new(swarm)),
|
||||
to_swarm,
|
||||
})
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
fn recv<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
|
||||
let swarm = Arc::clone(&self.swarm);
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
swarm
|
||||
.try_lock()
|
||||
.map_err(|_| PyRuntimeError::new_err("called recv twice concurrently"))?
|
||||
.next()
|
||||
.await
|
||||
.ok_or(PyErr::receiver_channel_closed())
|
||||
.map(PyFromSwarm::from)
|
||||
})
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
|
||||
// ---- Connection update receiver methods ----
|
||||
|
||||
/// Receives the next `ConnectionUpdate` from networking.
|
||||
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
/// will sleep until a `ConnectionUpdate`s is sent.
|
||||
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next `ConnectionUpdate` from networking.
|
||||
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
|
||||
// self.connection_update_rx.blocking_lock().try_recv_py()
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `ConnectionUpdate` channel is empty.
|
||||
// fn connection_update_is_empty(&self) -> bool {
|
||||
// self.connection_update_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `ConnectionUpdate`s in the channel.
|
||||
// fn connection_update_len(&self) -> usize {
|
||||
// self.connection_update_rx.blocking_lock().len()
|
||||
// }
|
||||
|
||||
// ---- Gossipsub management methods ----
|
||||
|
||||
/// Subscribe to a `GossipSub` topic.
|
||||
@@ -475,10 +220,10 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubSubscribe {
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Subscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
result_sender: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -487,6 +232,7 @@ impl PyNetworkingHandle {
|
||||
rx.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
.pyerr()
|
||||
}
|
||||
|
||||
/// Unsubscribes from a `GossipSub` topic.
|
||||
@@ -496,10 +242,10 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to unsubscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubUnsubscribe {
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Unsubscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
result_sender: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -518,11 +264,11 @@ impl PyNetworkingHandle {
|
||||
|
||||
// send off request to subscribe
|
||||
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubPublish {
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Publish {
|
||||
topic,
|
||||
data,
|
||||
result_tx: tx,
|
||||
result_sender: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -531,64 +277,26 @@ impl PyNetworkingHandle {
|
||||
let _ = rx
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())??;
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
.map_err(|e| match e {
|
||||
PublishError::AllQueuesFull(_) => PyAllQueuesFullError::new_err(),
|
||||
PublishError::MessageTooLarge => PyMessageTooLargeError::new_err(),
|
||||
PublishError::NoPeersSubscribedToTopic => {
|
||||
PyNoPeersSubscribedToTopicError::new_err()
|
||||
}
|
||||
e => PyRuntimeError::new_err(e.to_string()),
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Gossipsub message receiver methods ----
|
||||
|
||||
/// Receives the next message from the `GossipSub` network.
|
||||
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
|
||||
self.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
pyo3_stub_gen::inventory::submit! {
|
||||
gen_methods_from_python! {
|
||||
r#"
|
||||
class PyNetworkingHandle:
|
||||
async def recv() -> PyFromSwarm: ...
|
||||
"#
|
||||
}
|
||||
|
||||
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
|
||||
Ok(self
|
||||
.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
.collect())
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next message from the `GossipSub` network.
|
||||
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
|
||||
// Ok(self
|
||||
// .gossipsub_message_rx
|
||||
// .blocking_lock()
|
||||
// .try_recv_py()?
|
||||
// .map(|(t, d)| (t, d.pybytes())))
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `GossipSub` message channel is empty.
|
||||
// fn gossipsub_is_empty(&self) -> bool {
|
||||
// self.gossipsub_message_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `GossipSub` messages in the channel.
|
||||
// fn gossipsub_len(&self) -> usize {
|
||||
// self.gossipsub_message_rx.blocking_lock().len()
|
||||
// }
|
||||
}
|
||||
|
||||
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
@@ -596,10 +304,8 @@ pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<exception::PyAllQueuesFullError>()?;
|
||||
m.add_class::<exception::PyMessageTooLargeError>()?;
|
||||
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyConnectionUpdate>()?;
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyNetworkingHandle>()?;
|
||||
m.add_class::<PyFromSwarm>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -21,9 +21,10 @@ extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
async-stream = { workspace = true }
|
||||
futures-lite = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
@@ -35,3 +36,4 @@ log = { workspace = true }
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
pin-project = "1.1.10"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use futures_lite::StreamExt;
|
||||
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
|
||||
use networking::{discovery, swarm};
|
||||
use tokio::{io, io::AsyncBufReadExt as _, select};
|
||||
use libp2p::identity;
|
||||
use networking::swarm;
|
||||
use networking::swarm::{FromSwarm, ToSwarm};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::{io, io::AsyncBufReadExt as _};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::filter::LevelFilter;
|
||||
|
||||
@@ -11,64 +13,69 @@ async fn main() {
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
|
||||
.try_init();
|
||||
|
||||
let (to_swarm, from_client) = mpsc::channel(20);
|
||||
|
||||
// Configure swarm
|
||||
let mut swarm =
|
||||
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
|
||||
let mut swarm = swarm::create_swarm(identity::Keypair::generate_ed25519(), from_client)
|
||||
.expect("Swarm creation failed")
|
||||
.into_stream();
|
||||
|
||||
// Create a Gossipsub topic & subscribe
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&topic)
|
||||
.expect("Subscribing to topic failed");
|
||||
let (tx, rx) = oneshot::channel();
|
||||
_ = to_swarm
|
||||
.send(ToSwarm::Subscribe {
|
||||
topic: "test-net".to_string(),
|
||||
result_sender: tx,
|
||||
})
|
||||
.await
|
||||
.expect("should send");
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
rx.await
|
||||
.expect("tx not dropped")
|
||||
.expect("subscribe shouldn't fail");
|
||||
loop {
|
||||
if let Ok(Some(line)) = stdin.next_line().await {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
if let Err(e) = to_swarm
|
||||
.send(swarm::ToSwarm::Publish {
|
||||
topic: "test-net".to_string(),
|
||||
data: line.as_bytes().to_vec(),
|
||||
result_sender: tx,
|
||||
})
|
||||
.await
|
||||
{
|
||||
println!("Send error: {e:?}");
|
||||
return;
|
||||
};
|
||||
match rx.await {
|
||||
Ok(Err(e)) => println!("Publish error: {e:?}"),
|
||||
Err(e) => println!("Publish error: {e:?}"),
|
||||
Ok(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
// on gossipsub outgoing
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
// on gossipsub outgoing
|
||||
match swarm.next().await {
|
||||
// on gossipsub incoming
|
||||
Some(FromSwarm::Discovered { peer_id }) => {
|
||||
println!("\n\nconnected to {peer_id}\n\n")
|
||||
}
|
||||
event = swarm.next() => match event {
|
||||
// on gossipsub incoming
|
||||
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
}))) => println!(
|
||||
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
|
||||
// on discovery
|
||||
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) )=> match e {
|
||||
discovery::Event::ConnectionEstablished {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
discovery::Event::ConnectionClosed {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
}
|
||||
|
||||
// ignore outgoing errors: those are normal
|
||||
e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
|
||||
|
||||
// otherwise log any other event
|
||||
e => { log::info!("Other event {e:?}"); }
|
||||
Some(FromSwarm::Expired { peer_id }) => {
|
||||
println!("\n\ndisconnected from {peer_id}\n\n")
|
||||
}
|
||||
Some(FromSwarm::Message { from, topic, data }) => {
|
||||
println!("{topic}/{from}:\n{}", String::from_utf8_lossy(&data))
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use crate::alias;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use libp2p::{SwarmBuilder, identity};
|
||||
use std::pin::Pin;
|
||||
|
||||
pub type Swarm = libp2p::Swarm<Behaviour>;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
use crate::{alias, discovery};
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use futures_lite::{Stream, StreamExt};
|
||||
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
/// The current version of the network: this prevents devices running different versions of the
|
||||
/// software from interacting with each other.
|
||||
@@ -15,8 +17,136 @@ pub type Swarm = libp2p::Swarm<Behaviour>;
|
||||
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
|
||||
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
|
||||
|
||||
// Uses oneshot senders to emulate function calling apis while avoiding requiring unique ownership
|
||||
// of the Swarm.
|
||||
pub enum ToSwarm {
|
||||
Unsubscribe {
|
||||
topic: String,
|
||||
result_sender: oneshot::Sender<bool>,
|
||||
},
|
||||
Subscribe {
|
||||
topic: String,
|
||||
result_sender: oneshot::Sender<Result<bool, gossipsub::SubscriptionError>>,
|
||||
},
|
||||
Publish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
result_sender: oneshot::Sender<Result<gossipsub::MessageId, gossipsub::PublishError>>,
|
||||
},
|
||||
}
|
||||
pub enum FromSwarm {
|
||||
Message {
|
||||
from: PeerId,
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
},
|
||||
Discovered {
|
||||
peer_id: PeerId,
|
||||
},
|
||||
Expired {
|
||||
peer_id: PeerId,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct Swarm {
|
||||
swarm: libp2p::Swarm<Behaviour>,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
}
|
||||
|
||||
impl Swarm {
|
||||
pub fn into_stream(self) -> Pin<Box<dyn Stream<Item = FromSwarm> + Send>> {
|
||||
let Swarm {
|
||||
mut swarm,
|
||||
mut from_client,
|
||||
} = self;
|
||||
let stream = async_stream::stream! {
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = from_client.recv() => {
|
||||
let Some(msg) = msg else { break };
|
||||
on_message(&mut swarm, msg);
|
||||
}
|
||||
event = swarm.next() => {
|
||||
let Some(event) = event else { break };
|
||||
if let Some(item) = filter_swarm_event(event) {
|
||||
yield item;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
Box::pin(stream)
|
||||
}
|
||||
}
|
||||
|
||||
fn on_message(swarm: &mut libp2p::Swarm<Behaviour>, message: ToSwarm) {
|
||||
match message {
|
||||
ToSwarm::Subscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
let result = swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&gossipsub::IdentTopic::new(topic));
|
||||
_ = result_sender.send(result);
|
||||
}
|
||||
ToSwarm::Unsubscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
let result = swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.unsubscribe(&gossipsub::IdentTopic::new(topic));
|
||||
_ = result_sender.send(result);
|
||||
}
|
||||
ToSwarm::Publish {
|
||||
topic,
|
||||
data,
|
||||
result_sender,
|
||||
} => {
|
||||
let result = swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.publish(gossipsub::IdentTopic::new(topic), data);
|
||||
_ = result_sender.send(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_swarm_event(event: SwarmEvent<BehaviourEvent>) -> Option<FromSwarm> {
|
||||
match event {
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
message:
|
||||
gossipsub::Message {
|
||||
source: Some(peer_id),
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => Some(FromSwarm::Message {
|
||||
from: peer_id,
|
||||
topic: topic.into_string(),
|
||||
data,
|
||||
}),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
|
||||
discovery::Event::ConnectionEstablished { peer_id, .. },
|
||||
)) => Some(FromSwarm::Discovered { peer_id }),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(discovery::Event::ConnectionClosed {
|
||||
peer_id,
|
||||
..
|
||||
})) => Some(FromSwarm::Expired { peer_id }),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create and configure a swarm which listens to all ports on OS
|
||||
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
pub fn create_swarm(
|
||||
keypair: identity::Keypair,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
) -> alias::AnyResult<Swarm> {
|
||||
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
|
||||
.with_tokio()
|
||||
.with_other_transport(tcp_transport)?
|
||||
@@ -25,7 +155,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
Ok(swarm)
|
||||
Ok(Swarm { swarm, from_client })
|
||||
}
|
||||
|
||||
mod transport {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from random import random
|
||||
|
||||
import anyio
|
||||
from anyio import current_time
|
||||
@@ -21,13 +20,9 @@ from exo.shared.types.commands import (
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.common import NodeId, SessionId, SystemId
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
EventId,
|
||||
# TODO(evan): just for acks, should delete this ASAP
|
||||
GlobalForwarderEvent,
|
||||
LocalForwarderEvent,
|
||||
NodeDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import (
|
||||
@@ -38,40 +33,28 @@ from exo.shared.types.worker.downloads import (
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.channels import Receiver, Sender
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadCoordinator:
|
||||
node_id: NodeId
|
||||
session_id: SessionId
|
||||
shard_downloader: ShardDownloader
|
||||
download_command_receiver: Receiver[ForwarderDownloadCommand]
|
||||
local_event_sender: Sender[LocalForwarderEvent]
|
||||
|
||||
# ack stuff
|
||||
_global_event_receiver: Receiver[GlobalForwarderEvent]
|
||||
_out_for_delivery: dict[EventId, LocalForwarderEvent] = field(default_factory=dict)
|
||||
|
||||
event_sender: Sender[Event]
|
||||
offline: bool = False
|
||||
|
||||
_system_id: SystemId = field(default_factory=SystemId)
|
||||
|
||||
# Local state
|
||||
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
||||
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
|
||||
|
||||
# Internal event channel for forwarding (initialized in __post_init__)
|
||||
event_sender: Sender[Event] = field(init=False)
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
|
||||
|
||||
# Per-model throttle for download progress events
|
||||
_last_progress_time: dict[ModelId, float] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
self.shard_downloader.on_progress(self._download_progress_callback)
|
||||
|
||||
def _model_dir(self, model_id: ModelId) -> str:
|
||||
@@ -123,10 +106,7 @@ class DownloadCoordinator:
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._clear_ofd)
|
||||
finally:
|
||||
for task in self.active_downloads.values():
|
||||
task.cancel()
|
||||
@@ -134,20 +114,6 @@ class DownloadCoordinator:
|
||||
def shutdown(self) -> None:
|
||||
self._tg.cancel_tasks()
|
||||
|
||||
# directly copied from worker
|
||||
async def _resend_out_for_delivery(self) -> None:
|
||||
# This can also be massively tightened, we should check events are at least a certain age before resending.
|
||||
# Exponential backoff would also certainly help here.
|
||||
while True:
|
||||
await anyio.sleep(1 + random())
|
||||
for event in self._out_for_delivery.copy().values():
|
||||
await self.local_event_sender.send(event)
|
||||
|
||||
async def _clear_ofd(self) -> None:
|
||||
with self._global_event_receiver as events:
|
||||
async for event in events:
|
||||
self._out_for_delivery.pop(event.event.event_id, None)
|
||||
|
||||
async def _command_processor(self) -> None:
|
||||
with self.download_command_receiver as commands:
|
||||
async for cmd in commands:
|
||||
@@ -320,23 +286,6 @@ class DownloadCoordinator:
|
||||
)
|
||||
del self.download_status[model_id]
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
idx = 0
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
fe = LocalForwarderEvent(
|
||||
origin_idx=idx,
|
||||
origin=self._system_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
idx += 1
|
||||
logger.debug(
|
||||
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
|
||||
)
|
||||
await self.local_event_sender.send(fe)
|
||||
self._out_for_delivery[event.event_id] = fe
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
|
||||
@@ -314,9 +314,13 @@ async def fetch_file_list_with_cache(
|
||||
_fetched_file_lists_this_session.add(cache_key)
|
||||
return file_list
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(
|
||||
"Ran into exception when fetching file list from HF."
|
||||
)
|
||||
|
||||
if await aios.path.exists(cache_file):
|
||||
logger.warning(
|
||||
f"No internet and no cached file list for {model_id} - using local file list"
|
||||
f"No cached file list for {model_id} - using local file list"
|
||||
)
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
@@ -692,7 +696,8 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
|
||||
# (iii) Tensor parallel requires all files.
|
||||
return ["*"]
|
||||
try:
|
||||
weight_map = await get_weight_map(str(shard.model_card.model_id))
|
||||
revision = shard.model_card.revision or "main"
|
||||
weight_map = await get_weight_map(str(shard.model_card.model_id), revision)
|
||||
return get_allow_patterns(weight_map, shard)
|
||||
except Exception:
|
||||
logger.error(f"Error getting weight map for {shard.model_card.model_id=}")
|
||||
@@ -726,7 +731,7 @@ async def download_shard(
|
||||
if not skip_download:
|
||||
logger.debug(f"Downloading {shard.model_card.model_id=}")
|
||||
|
||||
revision = "main"
|
||||
revision = shard.model_card.revision or "main"
|
||||
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
|
||||
"/", "--"
|
||||
)
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from exo.download.coordinator import DownloadCoordinator
|
||||
from exo.download.shard_downloader import NoopShardDownloader
|
||||
from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.common import ModelId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
GlobalForwarderEvent,
|
||||
LocalForwarderEvent,
|
||||
NodeDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadPending,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.utils.channels import channel
|
||||
|
||||
# Use the built‑in NoopShardDownloader directly – it already implements the required abstract interface.
|
||||
# No additional subclass is needed for this test.
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_ack_behaviour():
|
||||
# Create channels (type Any for simplicity)
|
||||
_, command_receiver = channel[Any]()
|
||||
local_sender, _ = channel[Any]()
|
||||
global_sender, global_receiver = channel[Any]()
|
||||
|
||||
# Minimal identifiers
|
||||
node_id = NodeId()
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
|
||||
# Create a dummy model card and shard metadata
|
||||
model_id = ModelId("test/model")
|
||||
model_card = ModelCard(
|
||||
model_id=model_id,
|
||||
storage_size=Memory.from_bytes(0),
|
||||
n_layers=1,
|
||||
hidden_size=1,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
shard = PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
end_layer=1,
|
||||
n_layers=1,
|
||||
)
|
||||
|
||||
# Instantiate the coordinator with the dummy downloader
|
||||
coord = DownloadCoordinator(
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
shard_downloader=NoopShardDownloader(),
|
||||
download_command_receiver=command_receiver,
|
||||
local_event_sender=local_sender,
|
||||
_global_event_receiver=global_receiver,
|
||||
)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
# Start the forwarding and ack‑clearing loops
|
||||
tg.start_soon(coord._forward_events) # pyright: ignore[reportPrivateUsage]
|
||||
tg.start_soon(coord._clear_ofd) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Send a pending download progress event via the internal event sender
|
||||
pending = DownloadPending(
|
||||
node_id=node_id,
|
||||
shard_metadata=shard,
|
||||
model_directory="/tmp/model",
|
||||
)
|
||||
await coord.event_sender.send(NodeDownloadProgress(download_progress=pending))
|
||||
# Allow the forwarder to process the event
|
||||
await anyio.sleep(0.1)
|
||||
|
||||
# There should be exactly one entry awaiting ACK
|
||||
assert len(coord._out_for_delivery) == 1 # pyright: ignore[reportPrivateUsage]
|
||||
# Retrieve the stored LocalForwarderEvent
|
||||
stored_fe: LocalForwarderEvent = next(iter(coord._out_for_delivery.values())) # pyright: ignore[reportPrivateUsage]
|
||||
# Simulate receiving a global ack for this event
|
||||
ack = GlobalForwarderEvent(
|
||||
origin_idx=0,
|
||||
origin=node_id,
|
||||
session=session_id,
|
||||
event=stored_fe.event,
|
||||
)
|
||||
await global_sender.send(ack)
|
||||
# Give the clear‑ofd task a moment to process the ack
|
||||
await anyio.sleep(0.1)
|
||||
# The out‑for‑delivery map should now be empty
|
||||
assert len(coord._out_for_delivery) == 0 # pyright: ignore[reportPrivateUsage]
|
||||
# Cancel background tasks
|
||||
tg.cancel_scope.cancel()
|
||||
@@ -15,6 +15,7 @@ from exo.download.coordinator import DownloadCoordinator
|
||||
from exo.download.impl_shard_downloader import exo_shard_downloader
|
||||
from exo.master.api import API # TODO: should API be in master?
|
||||
from exo.master.main import Master
|
||||
from exo.routing.event_router import EventRouter
|
||||
from exo.routing.router import Router, get_node_id_keypair
|
||||
from exo.shared.constants import EXO_LOG
|
||||
from exo.shared.election import Election, ElectionResult
|
||||
@@ -29,6 +30,7 @@ from exo.worker.main import Worker
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
event_router: EventRouter
|
||||
download_coordinator: DownloadCoordinator | None
|
||||
worker: Worker | None
|
||||
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
|
||||
@@ -52,6 +54,12 @@ class Node:
|
||||
await router.register_topic(topics.ELECTION_MESSAGES)
|
||||
await router.register_topic(topics.CONNECTION_MESSAGES)
|
||||
await router.register_topic(topics.DOWNLOAD_COMMANDS)
|
||||
event_router = EventRouter(
|
||||
session_id,
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
external_outbound=router.sender(topics.LOCAL_EVENTS),
|
||||
external_inbound=router.receiver(topics.GLOBAL_EVENTS),
|
||||
)
|
||||
|
||||
logger.info(f"Starting node {node_id}")
|
||||
|
||||
@@ -59,13 +67,10 @@ class Node:
|
||||
if not args.no_downloads:
|
||||
download_coordinator = DownloadCoordinator(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(offline=args.offline),
|
||||
event_sender=event_router.sender(),
|
||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
offline=args.offline,
|
||||
# TODO(evan): remove
|
||||
_global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
)
|
||||
else:
|
||||
download_coordinator = None
|
||||
@@ -73,9 +78,8 @@ class Node:
|
||||
if args.spawn_api:
|
||||
api = API(
|
||||
node_id,
|
||||
session_id,
|
||||
port=args.api_port,
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
event_receiver=event_router.receiver(),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
|
||||
@@ -86,9 +90,8 @@ class Node:
|
||||
if not args.no_worker:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
event_receiver=event_router.receiver(),
|
||||
event_sender=event_router.sender(),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
)
|
||||
@@ -99,6 +102,7 @@ class Node:
|
||||
master = Master(
|
||||
node_id,
|
||||
session_id,
|
||||
event_sender=event_router.sender(),
|
||||
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=router.receiver(topics.COMMANDS),
|
||||
@@ -121,6 +125,7 @@ class Node:
|
||||
|
||||
return cls(
|
||||
router,
|
||||
event_router,
|
||||
download_coordinator,
|
||||
worker,
|
||||
election,
|
||||
@@ -136,6 +141,7 @@ class Node:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.event_router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
tg.start_soon(self.download_coordinator.run)
|
||||
@@ -183,6 +189,7 @@ class Node:
|
||||
self.master = Master(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
event_sender=self.event_router.sender(),
|
||||
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
|
||||
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
|
||||
command_receiver=self.router.receiver(topics.COMMANDS),
|
||||
@@ -206,21 +213,24 @@ class Node:
|
||||
)
|
||||
if result.is_new_master:
|
||||
await anyio.sleep(0)
|
||||
self.event_router.shutdown()
|
||||
self.event_router = EventRouter(
|
||||
result.session_id,
|
||||
self.router.sender(topics.COMMANDS),
|
||||
self.router.receiver(topics.GLOBAL_EVENTS),
|
||||
self.router.sender(topics.LOCAL_EVENTS),
|
||||
)
|
||||
self._tg.start_soon(self.event_router.run)
|
||||
if self.download_coordinator:
|
||||
self.download_coordinator.shutdown()
|
||||
self.download_coordinator = DownloadCoordinator(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
exo_shard_downloader(offline=self.offline),
|
||||
event_sender=self.event_router.sender(),
|
||||
download_command_receiver=self.router.receiver(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
offline=self.offline,
|
||||
# TODO(evan): remove
|
||||
_global_event_receiver=self.router.receiver(
|
||||
topics.GLOBAL_EVENTS
|
||||
),
|
||||
)
|
||||
self._tg.start_soon(self.download_coordinator.run)
|
||||
if self.worker:
|
||||
@@ -228,11 +238,8 @@ class Node:
|
||||
# TODO: add profiling etc to resource monitor
|
||||
self.worker = Worker(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
global_event_receiver=self.router.receiver(
|
||||
topics.GLOBAL_EVENTS
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
event_receiver=self.event_router.receiver(),
|
||||
event_sender=self.event_router.sender(),
|
||||
command_sender=self.router.sender(topics.COMMANDS),
|
||||
download_command_sender=self.router.sender(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
@@ -240,7 +247,7 @@ class Node:
|
||||
)
|
||||
self._tg.start_soon(self.worker.run)
|
||||
if self.api:
|
||||
self.api.reset(result.session_id, result.won_clock)
|
||||
self.api.reset(result.won_clock, self.event_router.receiver())
|
||||
else:
|
||||
if self.api:
|
||||
self.api.unpause(result.won_clock)
|
||||
|
||||
@@ -140,11 +140,10 @@ from exo.shared.types.commands import (
|
||||
TaskFinished,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SystemId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
GlobalForwarderEvent,
|
||||
IndexedEvent,
|
||||
TracesMerged,
|
||||
)
|
||||
@@ -172,7 +171,6 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
_API_EVENT_LOG_DIR = EXO_EVENT_LOG_DIR / "api"
|
||||
@@ -196,10 +194,9 @@ class API:
|
||||
def __init__(
|
||||
self,
|
||||
node_id: NodeId,
|
||||
session_id: SessionId,
|
||||
*,
|
||||
port: int,
|
||||
global_event_receiver: Receiver[GlobalForwarderEvent],
|
||||
event_receiver: Receiver[IndexedEvent],
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
# This lets us pause the API if an election is running
|
||||
@@ -210,11 +207,9 @@ class API:
|
||||
self._system_id = SystemId()
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.event_receiver = event_receiver
|
||||
self.election_receiver = election_receiver
|
||||
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
self.last_completed_election: int = 0
|
||||
self.port = port
|
||||
|
||||
@@ -254,17 +249,18 @@ class API:
|
||||
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
|
||||
self._tg: TaskGroup = TaskGroup()
|
||||
|
||||
def reset(self, new_session_id: SessionId, result_clock: int):
|
||||
def reset(self, result_clock: int, event_receiver: Receiver[IndexedEvent]):
|
||||
logger.info("Resetting API State")
|
||||
self._event_log.close()
|
||||
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
|
||||
self.state = State()
|
||||
self._system_id = SystemId()
|
||||
self.session_id = new_session_id
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self._text_generation_queues = {}
|
||||
self._image_generation_queues = {}
|
||||
self.unpause(result_clock)
|
||||
self.event_receiver.close()
|
||||
self.event_receiver = event_receiver
|
||||
self._tg.start_soon(self._apply_state)
|
||||
|
||||
def unpause(self, result_clock: int):
|
||||
logger.info("Unpausing API")
|
||||
@@ -524,15 +520,15 @@ class API:
|
||||
|
||||
if (
|
||||
model_card.model_id,
|
||||
instance.sharding(),
|
||||
instance.instance_meta(),
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
model_id=model_card.model_id,
|
||||
sharding=instance.sharding(),
|
||||
instance_meta=instance.instance_meta(),
|
||||
sharding=sharding,
|
||||
instance_meta=instance_meta,
|
||||
instance=instance,
|
||||
memory_delta_by_node=memory_delta_by_node or None,
|
||||
error=None,
|
||||
@@ -541,8 +537,8 @@ class API:
|
||||
seen.add(
|
||||
(
|
||||
model_card.model_id,
|
||||
instance.sharding(),
|
||||
instance.instance_meta(),
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
)
|
||||
)
|
||||
@@ -1606,7 +1602,7 @@ class API:
|
||||
finally:
|
||||
self._event_log.close()
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
self.event_receiver.close()
|
||||
|
||||
async def run_api(self, ev: anyio.Event):
|
||||
cfg = Config()
|
||||
@@ -1623,38 +1619,31 @@ class API:
|
||||
)
|
||||
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
if f_event.session != self.session_id:
|
||||
continue
|
||||
if f_event.origin != self.session_id.master_node_id:
|
||||
continue
|
||||
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
self._event_log.append(event)
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
with self.event_receiver as events:
|
||||
async for i_event in events:
|
||||
self._event_log.append(i_event.event)
|
||||
self.state = apply(self.state, i_event)
|
||||
event = i_event.event
|
||||
|
||||
if isinstance(event, ChunkGenerated):
|
||||
if queue := self._image_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
assert isinstance(event.chunk, ImageChunk)
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._image_generation_queues.pop(
|
||||
event.command_id, None
|
||||
)
|
||||
if queue := self._text_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
assert not isinstance(event.chunk, ImageChunk)
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._text_generation_queues.pop(event.command_id, None)
|
||||
if isinstance(event, TracesMerged):
|
||||
self._save_merged_trace(event)
|
||||
if isinstance(event, ChunkGenerated):
|
||||
if queue := self._image_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
assert isinstance(event.chunk, ImageChunk)
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._image_generation_queues.pop(event.command_id, None)
|
||||
if queue := self._text_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
assert not isinstance(event.chunk, ImageChunk)
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._text_generation_queues.pop(event.command_id, None)
|
||||
if isinstance(event, TracesMerged):
|
||||
self._save_merged_trace(event)
|
||||
|
||||
def _save_merged_trace(self, event: TracesMerged) -> None:
|
||||
traces = [
|
||||
|
||||
@@ -60,7 +60,7 @@ from exo.shared.types.tasks import (
|
||||
TextGeneration as TextGenerationTask,
|
||||
)
|
||||
from exo.shared.types.worker.instances import InstanceId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.channels import Receiver, Sender
|
||||
from exo.utils.event_buffer import MultiSourceBuffer
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
@@ -72,25 +72,21 @@ class Master:
|
||||
session_id: SessionId,
|
||||
*,
|
||||
command_receiver: Receiver[ForwarderCommand],
|
||||
event_sender: Sender[Event],
|
||||
local_event_receiver: Receiver[LocalForwarderEvent],
|
||||
global_event_sender: Sender[GlobalForwarderEvent],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
):
|
||||
self.state = State()
|
||||
self._tg: TaskGroup = TaskGroup()
|
||||
self.node_id = node_id
|
||||
self.session_id = session_id
|
||||
self.state = State()
|
||||
self._tg: TaskGroup = TaskGroup()
|
||||
self.command_task_mapping: dict[CommandId, TaskId] = {}
|
||||
self.command_receiver = command_receiver
|
||||
self.local_event_receiver = local_event_receiver
|
||||
self.global_event_sender = global_event_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
send, recv = channel[Event]()
|
||||
self.event_sender: Sender[Event] = send
|
||||
self._loopback_event_receiver: Receiver[Event] = recv
|
||||
self._loopback_event_sender: Sender[LocalForwarderEvent] = (
|
||||
local_event_receiver.clone_sender()
|
||||
)
|
||||
self.event_sender = event_sender
|
||||
self._system_id = SystemId()
|
||||
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
|
||||
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
|
||||
@@ -104,15 +100,12 @@ class Master:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
finally:
|
||||
self._event_log.close()
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Stopping Master")
|
||||
@@ -409,22 +402,6 @@ class Master:
|
||||
self._event_log.append(event)
|
||||
await self._send_event(indexed)
|
||||
|
||||
async def _loopback_processor(self) -> None:
|
||||
# this would ideally not be necessary.
|
||||
# this is WAY less hacky than how I was working around this before
|
||||
local_index = 0
|
||||
with self._loopback_event_receiver as events:
|
||||
async for event in events:
|
||||
await self._loopback_event_sender.send(
|
||||
LocalForwarderEvent(
|
||||
origin=self._system_id,
|
||||
origin_idx=local_index,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
local_index += 1
|
||||
|
||||
# This function is re-entrant, take care!
|
||||
async def _send_event(self, event: IndexedEvent):
|
||||
# Convenience method since this line is ugly
|
||||
|
||||
@@ -17,6 +17,7 @@ from exo.shared.types.commands import (
|
||||
)
|
||||
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
GlobalForwarderEvent,
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
@@ -50,6 +51,22 @@ async def test_master():
|
||||
command_sender, co_receiver = channel[ForwarderCommand]()
|
||||
local_event_sender, le_receiver = channel[LocalForwarderEvent]()
|
||||
fcds, _fcdr = channel[ForwarderDownloadCommand]()
|
||||
ev_send, ev_recv = channel[Event]()
|
||||
|
||||
async def mock_event_router():
|
||||
idx = 0
|
||||
sid = SystemId()
|
||||
with ev_recv as master_events:
|
||||
async for event in master_events:
|
||||
await local_event_sender.send(
|
||||
LocalForwarderEvent(
|
||||
origin=sid,
|
||||
origin_idx=idx,
|
||||
session=session_id,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
idx += 1
|
||||
|
||||
all_events: list[IndexedEvent] = []
|
||||
|
||||
@@ -67,6 +84,7 @@ async def test_master():
|
||||
master = Master(
|
||||
node_id,
|
||||
session_id,
|
||||
event_sender=ev_send,
|
||||
global_event_sender=ge_sender,
|
||||
local_event_receiver=le_receiver,
|
||||
command_receiver=co_receiver,
|
||||
@@ -75,6 +93,7 @@ async def test_master():
|
||||
logger.info("run the master")
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
tg.start_soon(mock_event_router)
|
||||
|
||||
# inject a NodeGatheredInfo event
|
||||
logger.info("inject a NodeGatheredInfo event")
|
||||
@@ -197,4 +216,5 @@ async def test_master():
|
||||
input=[InputMessage(role="user", content="Hello, how are you?")],
|
||||
)
|
||||
|
||||
ev_send.close()
|
||||
await master.shutdown()
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from enum import Enum
|
||||
|
||||
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
|
||||
from exo_pyo3_bindings import PyFromSwarm
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
@@ -8,30 +6,10 @@ from exo.utils.pydantic_ext import CamelCaseModel
|
||||
"""Serialisable types for Connection Updates/Messages"""
|
||||
|
||||
|
||||
class ConnectionMessageType(Enum):
|
||||
Connected = 0
|
||||
Disconnected = 1
|
||||
|
||||
@staticmethod
|
||||
def from_update_type(update_type: ConnectionUpdateType):
|
||||
match update_type:
|
||||
case ConnectionUpdateType.Connected:
|
||||
return ConnectionMessageType.Connected
|
||||
case ConnectionUpdateType.Disconnected:
|
||||
return ConnectionMessageType.Disconnected
|
||||
|
||||
|
||||
class ConnectionMessage(CamelCaseModel):
|
||||
node_id: NodeId
|
||||
connection_type: ConnectionMessageType
|
||||
remote_ipv4: str
|
||||
remote_tcp_port: int
|
||||
connected: bool
|
||||
|
||||
@classmethod
|
||||
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
||||
return cls(
|
||||
node_id=NodeId(update.peer_id),
|
||||
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
||||
remote_ipv4=update.remote_ipv4,
|
||||
remote_tcp_port=update.remote_tcp_port,
|
||||
)
|
||||
def from_update(cls, update: PyFromSwarm.Connection) -> "ConnectionMessage":
|
||||
return cls(node_id=NodeId(update.peer_id), connected=update.connected)
|
||||
|
||||
161
src/exo/routing/event_router.py
Normal file
161
src/exo/routing/event_router.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from dataclasses import dataclass, field
|
||||
from random import random
|
||||
|
||||
import anyio
|
||||
from anyio import BrokenResourceError, ClosedResourceError
|
||||
from anyio.abc import CancelScope
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
|
||||
from exo.shared.types.common import SessionId, SystemId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
EventId,
|
||||
GlobalForwarderEvent,
|
||||
IndexedEvent,
|
||||
LocalForwarderEvent,
|
||||
)
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventRouter:
|
||||
session_id: SessionId
|
||||
command_sender: Sender[ForwarderCommand]
|
||||
external_inbound: Receiver[GlobalForwarderEvent]
|
||||
external_outbound: Sender[LocalForwarderEvent]
|
||||
_system_id: SystemId = field(init=False, default_factory=SystemId)
|
||||
internal_outbound: list[Sender[IndexedEvent]] = field(
|
||||
init=False, default_factory=list
|
||||
)
|
||||
event_buffer: OrderedBuffer[Event] = field(
|
||||
init=False, default_factory=OrderedBuffer
|
||||
)
|
||||
out_for_delivery: dict[EventId, tuple[float, LocalForwarderEvent]] = field(
|
||||
init=False, default_factory=dict
|
||||
)
|
||||
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
|
||||
|
||||
_nack_cancel_scope: CancelScope | None = field(init=False, default=None)
|
||||
_nack_attempts: int = field(init=False, default=0)
|
||||
_nack_base_seconds: float = field(init=False, default=0.5)
|
||||
_nack_cap_seconds: float = field(init=False, default=10.0)
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._run_ext_in)
|
||||
tg.start_soon(self._simple_retry)
|
||||
finally:
|
||||
self.external_outbound.close()
|
||||
for send in self.internal_outbound:
|
||||
send.close()
|
||||
|
||||
# can make this better in future
|
||||
async def _simple_retry(self):
|
||||
while True:
|
||||
await anyio.sleep(1 + random())
|
||||
# list here is a shallow clone for shared mutation
|
||||
for e_id, (time, event) in list(self.out_for_delivery.items()):
|
||||
if anyio.current_time() > time + 5:
|
||||
self.out_for_delivery[e_id] = (anyio.current_time(), event)
|
||||
await self.external_outbound.send(event)
|
||||
|
||||
def sender(self) -> Sender[Event]:
|
||||
send, recv = channel[Event]()
|
||||
if self._tg.is_running():
|
||||
self._tg.start_soon(self._ingest, SystemId(), recv)
|
||||
else:
|
||||
self._tg.queue(self._ingest, SystemId(), recv)
|
||||
return send
|
||||
|
||||
def receiver(self) -> Receiver[IndexedEvent]:
|
||||
send, recv = channel[IndexedEvent]()
|
||||
self.internal_outbound.append(send)
|
||||
return recv
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self._tg.cancel_tasks()
|
||||
|
||||
async def _ingest(self, system_id: SystemId, recv: Receiver[Event]):
|
||||
idx = 0
|
||||
with recv as events:
|
||||
async for event in events:
|
||||
f_ev = LocalForwarderEvent(
|
||||
origin_idx=idx,
|
||||
origin=system_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
idx += 1
|
||||
await self.external_outbound.send(f_ev)
|
||||
self.out_for_delivery[event.event_id] = (anyio.current_time(), f_ev)
|
||||
|
||||
async def _run_ext_in(self):
|
||||
buf = OrderedBuffer[Event]()
|
||||
with self.external_inbound as events:
|
||||
async for event in events:
|
||||
if event.session != self.session_id:
|
||||
continue
|
||||
if event.origin != self.session_id.master_node_id:
|
||||
continue
|
||||
|
||||
buf.ingest(event.origin_idx, event.event)
|
||||
event_id = event.event.event_id
|
||||
if event_id in self.out_for_delivery:
|
||||
self.out_for_delivery.pop(event_id)
|
||||
|
||||
drained = buf.drain_indexed()
|
||||
if drained:
|
||||
self._nack_attempts = 0
|
||||
if self._nack_cancel_scope:
|
||||
self._nack_cancel_scope.cancel()
|
||||
|
||||
if not drained and (
|
||||
self._nack_cancel_scope is None
|
||||
or self._nack_cancel_scope.cancel_called
|
||||
):
|
||||
# Request the next index.
|
||||
self._tg.start_soon(self._nack_request, buf.next_idx_to_release)
|
||||
continue
|
||||
|
||||
for idx, event in drained:
|
||||
to_clear = set[int]()
|
||||
for i, sender in enumerate(self.internal_outbound):
|
||||
try:
|
||||
await sender.send(IndexedEvent(idx=idx, event=event))
|
||||
except (ClosedResourceError, BrokenResourceError):
|
||||
to_clear.add(i)
|
||||
for i in sorted(to_clear, reverse=True):
|
||||
self.internal_outbound.pop(i)
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
# This function is started whenever we receive an event that is out of sequence.
|
||||
# It is cancelled as soon as we receiver an event that is in sequence.
|
||||
|
||||
if since_idx < 0:
|
||||
logger.warning(f"Negative value encountered for nack request {since_idx=}")
|
||||
since_idx = 0
|
||||
|
||||
with CancelScope() as scope:
|
||||
self._nack_cancel_scope = scope
|
||||
delay: float = self._nack_base_seconds * (2.0**self._nack_attempts)
|
||||
delay = min(self._nack_cap_seconds, delay)
|
||||
self._nack_attempts += 1
|
||||
try:
|
||||
await anyio.sleep(delay)
|
||||
logger.info(
|
||||
f"Nack attempt {self._nack_attempts}: Requesting Event Log from {since_idx}"
|
||||
)
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=self._system_id,
|
||||
command=RequestEventLog(since_idx=since_idx),
|
||||
)
|
||||
)
|
||||
finally:
|
||||
if self._nack_cancel_scope is scope:
|
||||
self._nack_cancel_scope = None
|
||||
@@ -17,6 +17,7 @@ from exo_pyo3_bindings import (
|
||||
MessageTooLargeError,
|
||||
NetworkingHandle,
|
||||
NoPeersSubscribedToTopicError,
|
||||
PyFromSwarm,
|
||||
)
|
||||
from filelock import FileLock
|
||||
from loguru import logger
|
||||
@@ -121,7 +122,8 @@ class Router:
|
||||
send = self.networking_receiver.clone_sender()
|
||||
router = TopicRouter[T](topic, send)
|
||||
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
|
||||
await self._networking_subscribe(str(topic.topic))
|
||||
if self._tg.is_running():
|
||||
await self._networking_subscribe(topic.topic)
|
||||
|
||||
def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
|
||||
router = self.topic_routers.get(topic.topic, None)
|
||||
@@ -152,8 +154,10 @@ class Router:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# subscribe to pending topics
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_subscribe(topic)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
finally:
|
||||
@@ -176,41 +180,40 @@ class Router:
|
||||
async def _networking_recv(self):
|
||||
try:
|
||||
while True:
|
||||
topic, data = await self._net.gossipsub_recv()
|
||||
logger.trace(f"Received message on {topic} with payload {data}")
|
||||
if topic not in self.topic_routers:
|
||||
logger.warning(
|
||||
f"Received message on unknown or inactive topic {topic}"
|
||||
)
|
||||
continue
|
||||
|
||||
router = self.topic_routers[topic]
|
||||
await router.publish_bytes(data)
|
||||
from_swarm = await self._net.recv()
|
||||
logger.debug(from_swarm)
|
||||
match from_swarm:
|
||||
case PyFromSwarm.Message(origin, topic, data):
|
||||
logger.trace(
|
||||
f"Received message on {topic} from {origin} with payload {data}"
|
||||
)
|
||||
if topic not in self.topic_routers:
|
||||
logger.warning(
|
||||
f"Received message on unknown or inactive topic {topic}"
|
||||
)
|
||||
continue
|
||||
router = self.topic_routers[topic]
|
||||
await router.publish_bytes(data)
|
||||
case PyFromSwarm.Connection():
|
||||
message = ConnectionMessage.from_update(from_swarm)
|
||||
logger.trace(
|
||||
f"Received message on connection_messages with payload {message}"
|
||||
)
|
||||
if CONNECTION_MESSAGES.topic in self.topic_routers:
|
||||
router = self.topic_routers[CONNECTION_MESSAGES.topic]
|
||||
assert router.topic.model_type == ConnectionMessage
|
||||
router = cast(TopicRouter[ConnectionMessage], router)
|
||||
await router.publish(message)
|
||||
case _:
|
||||
logger.critical(
|
||||
"failed to exhaustively check FromSwarm messages - logic error"
|
||||
)
|
||||
except Exception as exception:
|
||||
logger.opt(exception=exception).error(
|
||||
"Gossipsub receive loop terminated unexpectedly"
|
||||
)
|
||||
raise
|
||||
|
||||
async def _networking_recv_connection_messages(self):
|
||||
try:
|
||||
while True:
|
||||
update = await self._net.connection_update_recv()
|
||||
message = ConnectionMessage.from_update(update)
|
||||
logger.trace(
|
||||
f"Received message on connection_messages with payload {message}"
|
||||
)
|
||||
if CONNECTION_MESSAGES.topic in self.topic_routers:
|
||||
router = self.topic_routers[CONNECTION_MESSAGES.topic]
|
||||
assert router.topic.model_type == ConnectionMessage
|
||||
router = cast(TopicRouter[ConnectionMessage], router)
|
||||
await router.publish(message)
|
||||
except Exception as exception:
|
||||
logger.opt(exception=exception).error(
|
||||
"Connection update receive loop terminated unexpectedly"
|
||||
)
|
||||
raise
|
||||
|
||||
async def _networking_publish(self):
|
||||
with self.networking_receiver as networked_items:
|
||||
async for topic, data in networked_items:
|
||||
|
||||
@@ -5,7 +5,7 @@ import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import tomlkit
|
||||
from anyio import Path, open_file
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub import model_info, repo_info
|
||||
from loguru import logger
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
@@ -79,6 +79,7 @@ class ComponentInfo(CamelCaseModel):
|
||||
|
||||
class ModelCard(CamelCaseModel):
|
||||
model_id: ModelId
|
||||
revision: str | None = None
|
||||
storage_size: Memory
|
||||
n_layers: PositiveInt
|
||||
hidden_size: PositiveInt
|
||||
@@ -127,12 +128,17 @@ class ModelCard(CamelCaseModel):
|
||||
async def fetch_from_hf(model_id: ModelId) -> "ModelCard":
|
||||
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
|
||||
# TODO: failure if files do not exist
|
||||
config_data = await fetch_config_data(model_id)
|
||||
# Fetch repo info to get the latest commit SHA
|
||||
repo = repo_info(model_id, repo_type="model")
|
||||
revision = repo.sha
|
||||
|
||||
config_data = await fetch_config_data(model_id, revision)
|
||||
num_layers = config_data.layer_count
|
||||
mem_size_bytes = await fetch_safetensors_size(model_id)
|
||||
mem_size_bytes = await fetch_safetensors_size(model_id, revision)
|
||||
|
||||
mc = ModelCard(
|
||||
model_id=ModelId(model_id),
|
||||
revision=revision,
|
||||
storage_size=mem_size_bytes,
|
||||
n_layers=num_layers,
|
||||
hidden_size=config_data.hidden_size or 0,
|
||||
@@ -219,7 +225,9 @@ class ConfigData(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
async def fetch_config_data(model_id: ModelId) -> ConfigData:
|
||||
async def fetch_config_data(
|
||||
model_id: ModelId, revision: str | None = None
|
||||
) -> ConfigData:
|
||||
"""Downloads and parses config.json for a model."""
|
||||
from exo.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
@@ -230,7 +238,7 @@ async def fetch_config_data(model_id: ModelId) -> ConfigData:
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
config_path = await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
revision or "main",
|
||||
"config.json",
|
||||
target_dir,
|
||||
lambda curr_bytes, total_bytes, is_renamed: logger.debug(
|
||||
@@ -241,7 +249,9 @@ async def fetch_config_data(model_id: ModelId) -> ConfigData:
|
||||
return ConfigData.model_validate_json(await f.read())
|
||||
|
||||
|
||||
async def fetch_safetensors_size(model_id: ModelId) -> Memory:
|
||||
async def fetch_safetensors_size(
|
||||
model_id: ModelId, revision: str | None = None
|
||||
) -> Memory:
|
||||
"""Gets model size from safetensors index or falls back to HF API."""
|
||||
from exo.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
@@ -253,7 +263,7 @@ async def fetch_safetensors_size(model_id: ModelId) -> Memory:
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
index_path = await download_file_with_retry(
|
||||
model_id,
|
||||
"main",
|
||||
revision or "main",
|
||||
"model.safetensors.index.json",
|
||||
target_dir,
|
||||
lambda curr_bytes, total_bytes, is_renamed: logger.debug(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from anyio import create_task_group, fail_after, move_on_after
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.shared.election import Election, ElectionMessage, ElectionResult
|
||||
from exo.shared.types.commands import ForwarderCommand, TestCommand
|
||||
from exo.shared.types.common import NodeId, SessionId, SystemId
|
||||
@@ -327,14 +327,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
|
||||
tg.start_soon(election.run)
|
||||
|
||||
# Send any connection message object; we close quickly to cancel before result creation
|
||||
await cm_tx.send(
|
||||
ConnectionMessage(
|
||||
node_id=NodeId(),
|
||||
connection_type=ConnectionMessageType.Connected,
|
||||
remote_ipv4="",
|
||||
remote_tcp_port=0,
|
||||
)
|
||||
)
|
||||
await cm_tx.send(ConnectionMessage(node_id=NodeId(), connected=True))
|
||||
|
||||
# Expect a broadcast for the new round at clock=1
|
||||
while True:
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from mlx import core as mx
|
||||
from mlx import nn as nn
|
||||
from mlx_lm.models.cache import (
|
||||
ArraysCache,
|
||||
CacheList,
|
||||
@@ -14,3 +16,16 @@ from mlx_lm.models.cache import (
|
||||
KVCacheType = Sequence[
|
||||
KVCache | RotatingKVCache | QuantizedKVCache | ArraysCache | CacheList
|
||||
]
|
||||
|
||||
|
||||
# Model is a wrapper function to fix the fact that mlx is not strongly typed in the same way that EXO is.
|
||||
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
|
||||
class Model(nn.Module):
|
||||
layers: list[nn.Module]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: KVCacheType | None,
|
||||
input_embeddings: mx.array | None = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
@@ -4,13 +4,7 @@ from pydantic import model_validator
|
||||
|
||||
from exo.shared.models.model_cards import ModelTask
|
||||
from exo.shared.types.common import Host, Id, NodeId
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
Sharding,
|
||||
ShardMetadata,
|
||||
TensorShardMetadata,
|
||||
)
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
@@ -30,40 +24,16 @@ class BaseInstance(TaggedModel):
|
||||
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
|
||||
return self.shard_assignments.runner_to_shard.get(runner_id, None)
|
||||
|
||||
@staticmethod
|
||||
def instance_meta() -> InstanceMeta: ...
|
||||
|
||||
def sharding(self) -> Sharding:
|
||||
if all(
|
||||
isinstance(sm, PipelineShardMetadata)
|
||||
for sm in self.shard_assignments.runner_to_shard.values()
|
||||
):
|
||||
return Sharding.Pipeline
|
||||
if all(
|
||||
isinstance(sm, TensorShardMetadata)
|
||||
for sm in self.shard_assignments.runner_to_shard.values()
|
||||
):
|
||||
return Sharding.Tensor
|
||||
raise ValueError("shard metadata malformed")
|
||||
|
||||
|
||||
class MlxRingInstance(BaseInstance):
|
||||
hosts_by_node: dict[NodeId, list[Host]]
|
||||
ephemeral_port: int
|
||||
|
||||
@staticmethod
|
||||
def instance_meta() -> InstanceMeta:
|
||||
return InstanceMeta.MlxRing
|
||||
|
||||
|
||||
class MlxJacclInstance(BaseInstance):
|
||||
jaccl_devices: list[list[str | None]]
|
||||
jaccl_coordinators: dict[NodeId, str]
|
||||
|
||||
@staticmethod
|
||||
def instance_meta() -> InstanceMeta:
|
||||
return InstanceMeta.MlxJaccl
|
||||
|
||||
|
||||
# TODO: Single node instance
|
||||
Instance = MlxRingInstance | MlxJacclInstance
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import KVCache
|
||||
|
||||
# These are wrapper functions to fix the fact that mlx is not strongly typed in the same way that EXO is.
|
||||
# For example - MLX has no guarantee of the interface that nn.Module will expose. But we need a guarantee that it has a __call__() function
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
layers: list[nn.Module]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: list[KVCache] | None,
|
||||
input_embeddings: mx.array | None = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
@@ -49,6 +49,21 @@ TimeoutCallback = Callable[[], None]
|
||||
LayerLoadedCallback = Callable[[int, int], None] # (layers_loaded, total_layers)
|
||||
|
||||
|
||||
_pending_prefill_sends: list[tuple[mx.array, int, mx.distributed.Group]] = []
|
||||
|
||||
|
||||
def flush_prefill_sends() -> None:
|
||||
for output, dst, group in _pending_prefill_sends:
|
||||
sent = mx.distributed.send(output, dst, group=group)
|
||||
mx.async_eval(sent)
|
||||
_pending_prefill_sends.clear()
|
||||
|
||||
|
||||
def clear_prefill_sends() -> None:
|
||||
# Discard pending sends (e.g. on cancellation).
|
||||
_pending_prefill_sends.clear()
|
||||
|
||||
|
||||
def eval_with_timeout(
|
||||
mlx_item: Any, # pyright: ignore[reportAny]
|
||||
timeout_seconds: float = 60.0,
|
||||
@@ -150,6 +165,7 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
self.group = group
|
||||
self.original_layer_signature = signature(self.original_layer.__call__)
|
||||
self.is_prefill: bool = False
|
||||
self.queue_sends: bool = False
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
cache = self.original_layer_signature.bind_partial(
|
||||
@@ -163,9 +179,14 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
mx.eval(output)
|
||||
|
||||
if self.r != self.s - 1:
|
||||
output = mx.distributed.send(
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if self.queue_sends:
|
||||
_pending_prefill_sends.append(
|
||||
(output, (self.r + 1) % self.s, self.group)
|
||||
)
|
||||
else:
|
||||
output = mx.distributed.send(
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if cache is not None:
|
||||
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
|
||||
# doesn't have .keys directly; access via first sub-cache.
|
||||
@@ -190,6 +211,12 @@ def set_pipeline_prefill(model: nn.Module, is_prefill: bool) -> None:
|
||||
layer.is_prefill = is_prefill
|
||||
|
||||
|
||||
def set_pipeline_queue_sends(model: nn.Module, queue_sends: bool) -> None:
|
||||
for layer in model.layers: # type: ignore
|
||||
if isinstance(layer, PipelineLastLayer):
|
||||
layer.queue_sends = queue_sends
|
||||
|
||||
|
||||
def get_inner_model(model: nn.Module) -> nn.Module:
|
||||
inner = getattr(model, "model", None)
|
||||
if isinstance(inner, nn.Module):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user