Compare commits

..

1 Commits

Author SHA1 Message Date
Alex Cheema
24418e3788 fix: import ResponsesStreamEvent and DRY up SSE formatting in responses adapter
ResponsesStreamEvent was defined in openai_responses.py but never imported
or used anywhere. Import it in the responses adapter and add a _format_sse()
helper that uses the event's .type field, replacing 13 hardcoded SSE format
strings with a single typed function.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 09:14:46 -08:00
39 changed files with 765 additions and 681 deletions

View File

@@ -1,46 +0,0 @@
"""Type stubs for mlx_lm.models.glm_moe_dsa"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
from .base import BaseModelArgs
from .deepseek_v32 import Model as DSV32Model
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
index_head_dim: int
index_n_heads: int
index_topk: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: int
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_parameters: Dict[str, Any]
attention_bias: bool
rope_scaling: Dict[str, Any] | None
rope_theta: float | None
class Model(DSV32Model):
def __init__(self, config: ModelArgs) -> None: ...

123
Cargo.lock generated
View File

@@ -141,6 +141,12 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
[[package]]
name = "arrayvec"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "asn1-rs"
version = "0.7.1"
@@ -298,6 +304,19 @@ version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba"
[[package]]
name = "bigdecimal"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "560f42649de9fa436b73517378a147ec21f6c997a546581df4b4b31677828934"
dependencies = [
"autocfg",
"libm",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "bimap"
version = "0.6.3"
@@ -497,6 +516,15 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f421161cb492475f1661ddc9815a745a1c894592070661180fdec3d4872e9c3"
[[package]]
name = "convert_case"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
dependencies = [
"unicode-segmentation",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@@ -718,6 +746,29 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "derive_more"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b"
dependencies = [
"convert_case",
"proc-macro2",
"quote",
"rustc_version",
"syn 2.0.111",
"unicode-xid",
]
[[package]]
name = "digest"
version = "0.10.7"
@@ -888,17 +939,22 @@ name = "exo_pyo3_bindings"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"env_logger",
"extend",
"futures",
"impl-trait-for-tuples",
"libp2p",
"log",
"networking",
"once_cell",
"pin-project",
"pyo3",
"pyo3-async-runtimes",
"pyo3-log",
"pyo3-stub-gen",
"thiserror 2.0.17",
"thread_local",
"tokio",
"util",
]
@@ -1584,6 +1640,17 @@ dependencies = [
"xmltree",
]
[[package]]
name = "impl-trait-for-tuples"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "indexmap"
version = "2.12.1"
@@ -1762,6 +1829,12 @@ version = "0.2.178"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
[[package]]
name = "libm"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de"
[[package]]
name = "libp2p"
version = "0.56.0"
@@ -2751,13 +2824,16 @@ name = "networking"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"libp2p",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
@@ -2842,6 +2918,17 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
@@ -3192,14 +3279,28 @@ version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab53c047fcd1a1d2a8820fe84f05d6be69e9526be40cb03b73f86b6b03e6d87d"
dependencies = [
"bigdecimal",
"either",
"hashbrown 0.16.1",
"indexmap",
"indoc",
"inventory",
"libc",
"lock_api",
"memoffset",
"num-bigint",
"num-complex",
"num-rational",
"num-traits",
"once_cell",
"ordered-float",
"parking_lot",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"rust_decimal",
"smallvec",
"unindent",
]
@@ -3640,6 +3741,16 @@ dependencies = [
"tokio",
]
[[package]]
name = "rust_decimal"
version = "1.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35affe401787a9bd846712274d97654355d21b2a2c092a3139aabe31e9022282"
dependencies = [
"arrayvec",
"num-traits",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
@@ -4504,12 +4615,24 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-segmentation"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "unicode_names2"
version = "1.3.0"

View File

@@ -26,21 +26,49 @@ opt-level = 3
networking = { path = "rust/networking" }
util = { path = "rust/util" }
# Proc-macro authoring tools
syn = "2.0"
quote = "1.0"
proc-macro2 = "1.0"
darling = "0.20"
# Macro dependecies
extend = "1.2"
delegate = "0.13"
impl-trait-for-tuples = "0.2"
clap = "4.5"
derive_more = { version = "2.0.1", features = ["display"] }
pin-project = "1"
# Utility dependencies
itertools = "0.14"
thiserror = "2"
internment = "0.8"
recursion = "0.5"
regex = "1.11"
once_cell = "1.21"
thread_local = "1.1"
bon = "3.4"
generativity = "1.1"
anyhow = "1.0"
keccak-const = "0.2"
# Functional generics/lenses frameworks
frunk_core = "0.4"
frunk = "0.4"
frunk_utils = "0.2"
frunk-enum-core = "0.3"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-util = "0.3"
futures-timer = "3.0"
# Data structures
either = "1.15"
ordered-float = "5.0"
ahash = "0.8"
# Tracing/logging
log = "0.4"

View File

@@ -72,23 +72,16 @@ There are two ways to run exo:
### Run from Source (macOS)
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
```bash
nix run .#exo
```
**Prerequisites:**
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
- [node](https://github.com/nodejs/node) (for building the dashboard)
```bash
brew install uv macmon node
```

View File

@@ -59,14 +59,13 @@
}
const sizeOptions: ImageGenerationParams["size"][] = [
"auto",
"512x512",
"768x768",
"1024x1024",
"1024x768",
"768x1024",
"1024x1536",
"1536x1024",
"1024x1365",
"1365x1024",
];
const qualityOptions: ImageGenerationParams["quality"][] = [
@@ -177,90 +176,92 @@
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
<!-- Basic params row -->
<div class="flex items-center gap-3 flex-wrap">
<!-- Size -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
<!-- Size (hidden in edit mode - output size comes from input image) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
{params.size.toUpperCase()}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size.toUpperCase()}</span>
</button>
{/each}
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{/if}
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
</div>
</div>
{/if}
</div>
{/if}
<!-- Quality -->
<div class="flex items-center gap-1.5">
@@ -310,7 +311,7 @@
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {qualityDropdownPosition()
.top}px + 4px); left: {qualityDropdownPosition().left}px;"
>

View File

@@ -306,14 +306,13 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
export interface ImageGenerationParams {
// Basic params
size:
| "auto"
| "512x512"
| "768x768"
| "1024x1024"
| "1024x768"
| "768x1024"
| "1024x1536"
| "1536x1024";
| "1024x1365"
| "1365x1024";
quality: "low" | "medium" | "high";
outputFormat: "png" | "jpeg";
numImages: number;
@@ -337,7 +336,7 @@ export interface EditingImage {
}
const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
size: "auto",
size: "1024x1024",
quality: "medium",
outputFormat: "png",
numImages: 1,

View File

@@ -115,7 +115,7 @@
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{

View File

@@ -41,16 +41,16 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.7.dev20260218+14841977"; in
version = let v = "0.30.6"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "rltakashige";
repo = "mlx-jaccl-fix-small-recv";
rev = "1484197707f35186ad3bd614357c7c47fdf86ebc";
hash = "sha256-FupCMoK/SF/ldfKuvMSAKECcOP8c+ANgkQlPZttDsLk=";
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
};
patches = [

View File

@@ -6,6 +6,8 @@ readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"aiofiles>=24.1.0",
"aiohttp>=3.12.14",
"types-aiofiles>=24.1.0.20250708",
"pydantic>=2.11.7",
"fastapi>=0.116.1",
"filelock>=3.18.0",
@@ -15,9 +17,9 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx; sys_platform == 'darwin'",
"mlx==0.30.6; sys_platform == 'darwin'",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.7",
"mlx-lm==0.30.6",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -62,7 +64,6 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }

View File

@@ -58,21 +58,6 @@
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
mlx = ignoreMissing prev.mlx;
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
buildInputs = (old.buildInputs or [ ]) ++ [
final.nvidia-cublas
final.nvidia-cuda-nvrtc
final.nvidia-cudnn-cu13
final.nvidia-nccl-cu13
];
preFixup = ''
addAutoPatchelfSearchPath ${final.nvidia-cublas}
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
'';
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
});
torch = ignoreMissing prev.torch;
triton = ignoreMissing prev.triton;
}
@@ -89,25 +74,14 @@
linuxOverlay
]
);
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
"lib/python3.13/site-packages/mlx*"
"lib/python3.13/site-packages/nvidia*"
];
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
)).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
);
mkPythonScript = name: path: pkgs.writeShellApplication {
inherit name;

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-5-8bit-MXFP8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 790517400864

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-5-MXFP4-Q8"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "MXFP4-Q8"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 405478939008

View File

@@ -1,12 +0,0 @@
model_id = "mlx-community/GLM-5"
n_layers = 78
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM-5"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 1487822475264

View File

@@ -25,17 +25,17 @@ workspace = true
networking = { workspace = true }
# interop
pyo3 = { version = "0.27.2", features = [
# "abi3-py313", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.13
pyo3 = { version = "0.27.1", features = [
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
"nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
#"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
#"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
# "multiple-pymethods", # allows multiple #[pymethods] sections per class
"multiple-pymethods", # allows multiple #[pymethods] sections per class
# integrations with other libraries
# "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
# "ordered-float", "rust_decimal", "smallvec",
"arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
"ordered-float", "rust_decimal", "smallvec",
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.17.2" }
@@ -45,6 +45,8 @@ pyo3-log = "0.13.2"
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
pin-project = { workspace = true }
# async runtime
@@ -52,11 +54,24 @@ tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }
# utility dependencies
once_cell = "1.21.3"
thread_local = "1.1.9"
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
# Tracing
#tracing = "0.1"
#tracing-subscriber = "0.3"
#console-subscriber = "0.1.5"
#tracing-log = "0.2.0"
log = { workspace = true }
env_logger = "0.11"
# Networking
libp2p = { workspace = true, features = ["full"] }

View File

@@ -6,7 +6,7 @@ use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
future::Future,
pin::Pin,
pin::{Pin, pin},
task::{Context, Poll},
};
@@ -33,6 +33,8 @@ where
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let waker = cx.waker();
Python::attach(|py| py.detach(|| self.project().0.poll(&mut Context::from_waker(waker))))
Python::with_gil(|py| {
py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
})
}
}

View File

@@ -0,0 +1,240 @@
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
//! re-create from scratch, but too inhomogenous to create an abstraction/wrapper around.
//!
//! Pattern examples include:
//! - Async task handles: with GC-integrated cleanup
//! - Sync/async callbacks from python: with propper eventloop handling
//!
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
//! - Store mutable fields in tokio's `Mutex<T>`
//! - For async code: take `&self` and `.lock().await`
//! - For sync code: take `&mut self` and `.get_mut()`
use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
use futures::FutureExt as _;
use futures::future::BoxFuture;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::{
Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
fn needs_tokio_runtime() {
tokio::runtime::Handle::current();
}
type SyncCallback = Box<dyn Fn() + Send + Sync>;
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
enum AsyncTaskMessage {
SyncCallback(SyncCallback),
AsyncCallback(AsyncCallback),
}
async fn async_task(
sender: mpsc::UnboundedSender<()>,
mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
) {
log::info!("RUST: async task started");
// task state
let mut interval = tokio::time::interval(Duration::from_secs(1));
let mut sync_cbs: Vec<SyncCallback> = vec![];
let mut async_cbs: Vec<AsyncCallback> = vec![];
loop {
tokio::select! {
// handle incoming messages from task-handle
message = receiver.recv() => {
// handle closed channel by exiting
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming event
match message {
AsyncTaskMessage::SyncCallback(cb) => {
sync_cbs.push(cb);
}
AsyncTaskMessage::AsyncCallback(cb) => {
async_cbs.push(cb);
}
}
}
// handle all other events
_ = interval.tick() => {
log::info!("RUST: async task tick");
// call back all sync callbacks
for cb in &sync_cbs {
cb();
}
// call back all async callbacks
for cb in &async_cbs {
cb().await;
}
// send event on unbounded channel
sender.send(()).expect("handle receiver cannot be closed/dropped");
}
}
}
log::info!("RUST: async task stopped");
}
// #[gen_stub_pyclass]
#[pyclass(name = "AsyncTaskHandle")]
#[derive(Debug)]
struct PyAsyncTaskHandle {
sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
receiver: mpsc::UnboundedReceiver<()>,
}
#[allow(clippy::expect_used)]
impl PyAsyncTaskHandle {
const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
self.sender
.as_mut()
.expect("The sender should only be None after de-initialization.")
}
const fn new(
sender: mpsc::UnboundedSender<AsyncTaskMessage>,
receiver: mpsc::UnboundedReceiver<()>,
) -> Self {
Self {
sender: Some(sender),
receiver,
}
}
}
// #[gen_stub_pymethods]
#[pymethods]
impl PyAsyncTaskHandle {
#[new]
fn py_new(py: Python<'_>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channel TOWARDS our task
let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();
// create communication channel FROM our task
let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();
// perform necessary setup within tokio context - or it crashes
let () = get_runtime().block_on(async { needs_tokio_runtime() });
// spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
_ = get_runtime().spawn_with_scope(py, async move {
async_task(t_sender, t_receiver).await;
});
Ok(Self::new(h_sender, h_receiver))
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_sync_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], None]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::SyncCallback(Box::new(move || {
_ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
})))
.pyerr()?;
Ok(())
}
/// NOTE: exceptions in callbacks are silently ignored until end of execution
fn add_async_callback(
&self,
// #[gen_stub(override_type(
// type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
// imports=("collections.abc")
// ))]
callback: Py<PyAny>,
) -> PyResult<()> {
// blocking call to async method -> can do non-blocking if needed
self.sender()
.send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
let c = Python::with_gil(|py| callback.clone_ref(py));
async move {
if let Some(f) = Python::with_gil(|py| {
let coroutine = c.call0(py).write_unraisable_with(py)?;
pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
.write_unraisable_with(py)
}) {
_ = f.await.write_unraisable();
}
}
.boxed()
})))
.pyerr()?;
Ok(())
}
async fn receive_unit(&mut self) -> PyResult<()> {
self.receiver
.recv()
.await
.ok_or(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
))
}
fn drain_units(&mut self) -> PyResult<i32> {
let mut cnt = 0;
loop {
match self.receiver.try_recv() {
Err(TryRecvError::Disconnected) => {
return Err(PyErr::new::<PyRuntimeError, _>(
"cannot receive unit on closed channel",
));
}
Err(TryRecvError::Empty) => return Ok(cnt),
Ok(()) => {
cnt += 1;
continue;
}
}
}
}
// #[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
}
// #[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
}
}
pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyAsyncTaskHandle>()?;
Ok(())
}

View File

@@ -17,6 +17,7 @@
extern crate core;
mod allow_threading;
mod examples;
pub(crate) mod networking;
pub(crate) mod pylibp2p;
@@ -24,6 +25,7 @@ use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::prelude::*;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;
@@ -34,10 +36,14 @@ pub(crate) mod r#const {
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
use std::error::Error;
use std::marker::Tuple;
pub trait SendFn<Args: Tuple + Send + 'static, Output> =
Fn<Args, Output = Output> + Send + 'static;
pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
pub type AnyResult<T> = Result<T, AnyError>;
}
/// Namespace for crate-wide extension traits/methods
@@ -45,6 +51,7 @@ pub(crate) mod ext {
use crate::allow_threading::AllowThreads;
use extend::ext;
use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
use pyo3::marker::Ungil;
use pyo3::types::PyBytes;
use pyo3::{Py, PyErr, PyResult, Python};
use tokio::runtime::Runtime;
@@ -55,7 +62,7 @@ pub(crate) mod ext {
#[ext(pub, name = ByteArrayExt)]
impl [u8] {
fn pybytes(&self) -> Py<PyBytes> {
Python::attach(|py| PyBytes::new(py, self).unbind())
Python::with_gil(|py| PyBytes::new(py, self).unbind())
}
}
@@ -91,7 +98,7 @@ pub(crate) mod ext {
#[ext(pub, name = PyResultExt)]
impl<T> PyResult<T> {
fn write_unraisable(self) -> Option<T> {
Python::attach(|py| self.write_unraisable_with(py))
Python::with_gil(|py| self.write_unraisable_with(py))
}
fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {
@@ -168,6 +175,24 @@ pub(crate) mod ext {
}
}
pub(crate) mod private {
use std::marker::Sized;
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
#[repr(transparent)]
pub(crate) struct ClonePy<T>(pub Py<T>);
impl<T> Clone for ClonePy<T> {
fn clone(&self) -> Self {
Python::with_gil(|py| Self(self.0.clone_ref(py)))
}
}
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.

View File

@@ -11,9 +11,9 @@ use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt a
use crate::pyclass;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use libp2p::{gossipsub, mdns};
use networking::discovery;
use networking::swarm::create_swarm;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
@@ -25,7 +25,7 @@ use tokio::sync::{Mutex, mpsc, oneshot};
mod exception {
use pyo3::types::PyTuple;
use pyo3::{exceptions::PyException, prelude::*};
use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
use pyo3_stub_gen::derive::*;
#[gen_stub_pyclass]
@@ -155,6 +155,7 @@ async fn networking_task(
) {
use SwarmEvent::*;
use ToTask::*;
use mdns::Event::*;
use networking::swarm::BehaviourEvent::*;
log::info!("RUST: networking task started");
@@ -484,7 +485,7 @@ impl PyNetworkingHandle {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx()
.send_py(ToTask::GossipsubPublish {
topic,

View File

@@ -19,6 +19,8 @@ either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
@@ -27,6 +29,11 @@ futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
@@ -34,4 +41,4 @@ keccak-const = { workspace = true }
log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }
libp2p = { workspace = true, features = ["full"] }

View File

@@ -24,8 +24,8 @@ use libp2p::{
swarm::{NetworkBehaviour, SwarmEvent},
tcp, yamux,
};
use std::error::Error;
use std::time::Duration;
use std::{error::Error, hash::Hash};
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;

View File

@@ -1,4 +1,5 @@
use crate::ext::MultiaddrExt;
use crate::keep_alive;
use delegate::delegate;
use either::Either;
use futures::FutureExt;

View File

@@ -0,0 +1,44 @@
use delegate::delegate;
use libp2p::swarm::handler::ConnectionEvent;
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
use std::task::{Context, Poll};
/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
/// the connection alive.
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnectionHandler(dummy::ConnectionHandler);
impl ConnectionHandler {
pub fn new() -> Self {
ConnectionHandler(dummy::ConnectionHandler)
}
}
impl handler::ConnectionHandler for ConnectionHandler {
// delegate types and implementation mostly to dummy handler
type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
type InboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
type OutboundProtocol =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
type InboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
type OutboundOpenInfo =
<dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;
delegate! {
to self.0 {
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
}
}
// specifically override this to force connection to stay alive
fn connection_keep_alive(&self) -> bool {
true
}
}

View File

@@ -3,7 +3,19 @@
//! this is here as a placeholder documentation
//!
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
// #![feature(stmt_expr_attributes)]
// #![feature(unboxed_closures)]
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]
pub mod discovery;
pub mod keep_alive;
pub mod swarm;
/// Namespace for all the type/trait aliases used by this crate.
@@ -42,3 +54,11 @@ pub(crate) mod ext {
}
}
}
pub(crate) mod private {
#![allow(dead_code)]
/// Sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}

View File

@@ -47,7 +47,6 @@ class DownloadCoordinator:
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
offline: bool = False
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
@@ -63,8 +62,6 @@ class DownloadCoordinator:
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
if self.offline:
self.shard_downloader.set_internet_connection(False)
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
@@ -110,17 +107,13 @@ class DownloadCoordinator:
self._last_progress_time[model_id] = current_time()
async def run(self) -> None:
logger.info(
f"Starting DownloadCoordinator{' (offline mode)' if self.offline else ''}"
)
if not self.offline:
self._test_internet_connection()
logger.info("Starting DownloadCoordinator")
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
if not self.offline:
tg.start_soon(self._check_internet_connection)
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
@@ -209,20 +202,6 @@ class DownloadCoordinator:
)
return
if self.offline:
logger.warning(
f"Offline mode: model {model_id} is not fully available locally, cannot download"
)
failed = DownloadFailed(
shard_metadata=shard,
node_id=self.node_id,
error_message=f"Model files not found locally in offline mode: {model_id}",
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(NodeDownloadProgress(download_progress=failed))
return
# Start actual download
self._start_download_task(shard, initial_progress)

View File

@@ -8,13 +8,13 @@ import traceback
from collections.abc import Awaitable
from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal, cast
from typing import Callable, Literal
from urllib.parse import urljoin
import aiofiles
import aiofiles.os as aios
import aiohttp
import certifi
import httpx
from huggingface_hub import (
snapshot_download, # pyright: ignore[reportUnknownVariableType]
)
@@ -330,17 +330,17 @@ async def _fetch_file_list(
headers = await get_download_headers()
async with (
create_http_session(timeout_profile="short") as session,
session.get(url, headers=headers) as response,
):
response = await session.get(url, headers=headers)
if response.status_code in [401, 403]:
msg = await _build_auth_error_message(response.status_code, model_id)
if response.status in [401, 403]:
msg = await _build_auth_error_message(response.status, model_id)
raise HuggingFaceAuthenticationError(msg)
elif response.status_code == 429:
elif response.status == 429:
raise HuggingFaceRateLimitError(
f"Couldn't download {model_id} because of HuggingFace rate limit."
)
elif response.status_code == 200:
data_json = response.text
elif response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
files: list[FileListEntry] = []
for item in data:
@@ -353,7 +353,7 @@ async def _fetch_file_list(
files.extend(subfiles)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status_code}")
raise Exception(f"Failed to fetch file list: {response.status}")
async def get_download_headers() -> dict[str, str]:
@@ -361,29 +361,34 @@ async def get_download_headers() -> dict[str, str]:
def create_http_session(
auto_decompress: bool = False,
timeout_profile: Literal["short", "long"] = "long",
) -> httpx.AsyncClient:
) -> aiohttp.ClientSession:
if timeout_profile == "short":
total_timeout = 30
connect_timeout = 10
read_timeout = 30
sock_read_timeout = 30
sock_connect_timeout = 10
else:
total_timeout = 1800
connect_timeout = 60
read_timeout = 60
sock_read_timeout = 60
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(
cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
)
connector = aiohttp.TCPConnector(ssl=ssl_context)
# default here is to load env vars
return httpx.AsyncClient(
verify=ssl_context,
timeout=httpx.Timeout(
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
connect=connect_timeout,
read=read_timeout,
write=total_timeout,
pool=total_timeout,
sock_read=sock_read_timeout,
sock_connect=sock_connect_timeout,
),
)
@@ -410,28 +415,26 @@ async def file_meta(
headers = await get_download_headers()
async with (
create_http_session(timeout_profile="short") as session,
session.stream("HEAD", url, headers=headers) as r,
session.head(url, headers=headers) as r,
):
if r.status_code == 307:
if r.status == 307:
# On redirect, only trust Hugging Face's x-linked-* headers.
x_linked_size = cast(str | None, r.headers.get("x-linked-size"))
x_linked_etag = cast(str | None, r.headers.get("x-linked-etag"))
x_linked_size = r.headers.get("x-linked-size")
x_linked_etag = r.headers.get("x-linked-etag")
if x_linked_size and x_linked_etag:
content_length = int(x_linked_size)
etag = trim_etag(x_linked_etag)
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = cast(str | None, r.headers.get("location"))
redirected_location = r.headers.get("location")
return await file_meta(model_id, revision, path, redirected_location)
if r.status_code in [401, 403]:
msg = await _build_auth_error_message(r.status_code, model_id)
if r.status in [401, 403]:
msg = await _build_auth_error_message(r.status, model_id)
raise HuggingFaceAuthenticationError(msg)
content_length = cast(
str | None,
r.headers.get("x-linked-size") or r.headers.get("content-length"),
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
content_length = 0 if content_length is None else int(content_length)
etag = cast(str | None, r.headers.get("x-linked-etag") or r.headers.get("etag"))
etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
assert content_length > 0, f"No content length for {url}"
assert etag is not None, f"No remote hash for {url}"
etag = trim_etag(etag)
@@ -445,13 +448,12 @@ async def download_file_with_retry(
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
on_connection_lost: Callable[[], None] = lambda: None,
skip_internet: bool = False,
) -> Path:
n_attempts = 3
for attempt in range(n_attempts):
try:
return await _download_file(
model_id, revision, path, target_dir, on_progress, skip_internet
model_id, revision, path, target_dir, on_progress
)
except HuggingFaceAuthenticationError:
raise
@@ -485,14 +487,10 @@ async def _download_file(
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
skip_internet: bool = False,
) -> Path:
target_path = target_dir / path
if await aios.path.exists(target_path):
if skip_internet:
return target_path
local_size = (await aios.stat(target_path)).st_size
# Try to verify against remote, but allow offline operation
@@ -512,11 +510,6 @@ async def _download_file(
)
return target_path
if skip_internet:
raise FileNotFoundError(
f"File {path} not found locally and cannot download in offline mode"
)
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -534,20 +527,20 @@ async def _download_file(
n_read = resume_byte_pos or 0
async with (
create_http_session(timeout_profile="long") as session,
session.stream("GET", url, headers=headers, follow_redirects=True) as r,
session.get(url, headers=headers) as r,
):
if r.status_code == 404:
if r.status == 404:
raise FileNotFoundError(f"File not found: {url}")
if r.status_code in [401, 403]:
msg = await _build_auth_error_message(r.status_code, model_id)
if r.status in [401, 403]:
msg = await _build_auth_error_message(r.status, model_id)
raise HuggingFaceAuthenticationError(msg)
assert r.status_code in [200, 206], (
f"Failed to download {path} from {url}: {r.status_code}"
assert r.status in [200, 206], (
f"Failed to download {path} from {url}: {r.status}"
)
async with aiofiles.open(
partial_path, "ab" if resume_byte_pos else "wb"
) as f:
async for chunk in r.aiter_bytes(8 * 1024 * 1024):
while chunk := await r.content.read(8 * 1024 * 1024):
n_read = n_read + (await f.write(chunk))
on_progress(n_read, length, False)
@@ -821,7 +814,6 @@ async def download_shard(
file, curr_bytes, total_bytes, is_renamed
),
on_connection_lost=on_connection_lost,
skip_internet=skip_internet,
)
if not skip_download:

View File

@@ -189,7 +189,6 @@ class ResumableShardDownloader(ShardDownloader):
try:
yield await task
except Exception as e:
task.cancel()
logger.warning(f"Error downloading shard: {type(e).__name__}")
async def get_shard_download_status_for_shard(

View File

@@ -1,230 +0,0 @@
"""Tests for offline/air-gapped mode."""
from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import AsyncMock, patch
import aiofiles
import aiofiles.os as aios
import pytest
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
download_file_with_retry,
fetch_file_list_with_cache,
)
from exo.shared.types.common import ModelId
from exo.shared.types.worker.downloads import FileListEntry
@pytest.fixture
def model_id() -> ModelId:
return ModelId("test-org/test-model")
@pytest.fixture
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
yield models_dir
class TestDownloadFileOffline:
"""Tests for _download_file with skip_internet=True."""
async def test_returns_local_file_without_http_verification(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file exists locally, return it immediately
without making any HTTP calls (no file_meta verification)."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
local_file = target_dir / "model.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"model weights data")
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await _download_file(
model_id,
"main",
"model.safetensors",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
async def test_raises_file_not_found_for_missing_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file does NOT exist locally,
raise FileNotFoundError instead of attempting download."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
with pytest.raises(FileNotFoundError, match="offline mode"):
await _download_file(
model_id,
"main",
"missing_model.safetensors",
target_dir,
skip_internet=True,
)
async def test_returns_local_file_in_subdirectory(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file exists in a subdirectory,
return it without HTTP calls."""
target_dir = tmp_path / "downloads"
subdir = target_dir / "transformer"
await aios.makedirs(subdir, exist_ok=True)
local_file = subdir / "diffusion_pytorch_model.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"weights")
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await _download_file(
model_id,
"main",
"transformer/diffusion_pytorch_model.safetensors",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
class TestDownloadFileWithRetryOffline:
"""Tests for download_file_with_retry with skip_internet=True."""
async def test_propagates_skip_internet_to_download_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Verify skip_internet is passed through to _download_file."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
local_file = target_dir / "config.json"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b'{"model_type": "qwen2"}')
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
async def test_file_not_found_does_not_retry(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""FileNotFoundError from offline mode should not trigger retries."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
with pytest.raises(FileNotFoundError):
await download_file_with_retry(
model_id,
"main",
"nonexistent.safetensors",
target_dir,
skip_internet=True,
)
class TestFetchFileListOffline:
"""Tests for fetch_file_list_with_cache with skip_internet=True."""
async def test_uses_cached_file_list(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and cache file exists, use it without network."""
from pydantic import TypeAdapter
cache_dir = temp_models_dir / "caches" / model_id.normalize()
await aios.makedirs(cache_dir, exist_ok=True)
cached_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
FileListEntry(type="file", path="config.json", size=200),
]
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode()
)
with patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
) as mock_fetch:
result = await fetch_file_list_with_cache(
model_id, "main", skip_internet=True
)
assert result == cached_list
mock_fetch.assert_not_called()
async def test_falls_back_to_local_directory_scan(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and no cache but local files exist,
build file list from local directory."""
import json
model_dir = temp_models_dir / model_id.normalize()
await aios.makedirs(model_dir, exist_ok=True)
async with aiofiles.open(model_dir / "config.json", "w") as f:
await f.write('{"model_type": "qwen2"}')
index_data = {
"metadata": {},
"weight_map": {"model.layers.0.weight": "model.safetensors"},
}
async with aiofiles.open(model_dir / "model.safetensors.index.json", "w") as f:
await f.write(json.dumps(index_data))
async with aiofiles.open(model_dir / "model.safetensors", "wb") as f:
await f.write(b"x" * 500)
with patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
) as mock_fetch:
result = await fetch_file_list_with_cache(
model_id, "main", skip_internet=True
)
mock_fetch.assert_not_called()
paths = {entry.path for entry in result}
assert "config.json" in paths
assert "model.safetensors" in paths
async def test_raises_when_no_cache_and_no_local_files(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and neither cache nor local files exist,
raise FileNotFoundError."""
with pytest.raises(FileNotFoundError, match="No internet"):
await fetch_file_list_with_cache(model_id, "main", skip_internet=True)

View File

@@ -39,7 +39,6 @@ class Node:
node_id: NodeId
event_index_counter: Iterator[int]
offline: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
@@ -69,7 +68,6 @@ class Node:
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
offline=args.offline,
)
else:
download_coordinator = None
@@ -134,7 +132,6 @@ class Node:
api,
node_id,
event_index_counter,
args.offline,
)
async def run(self):
@@ -225,7 +222,6 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
@@ -264,9 +260,6 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
@@ -289,7 +282,6 @@ class Args(CamelCaseModel):
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
offline: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -337,11 +329,6 @@ class Args(CamelCaseModel):
action="store_true",
help="Disable the download coordinator (node won't download models)",
)
parser.add_argument(
"--offline",
action="store_true",
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

View File

@@ -26,6 +26,7 @@ from exo.shared.types.openai_responses import (
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponsesStreamEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
@@ -33,6 +34,11 @@ from exo.shared.types.openai_responses import (
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _format_sse(event: ResponsesStreamEvent) -> str:
"""Format a streaming event as an SSE message."""
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
def _extract_content(content: str | list[ResponseContentPart]) -> str:
"""Extract plain text from a content field that may be a string or list of parts."""
if isinstance(content, str):
@@ -207,13 +213,13 @@ async def generate_responses_stream(
created_event = ResponseCreatedEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
yield _format_sse(created_event)
# response.in_progress
in_progress_event = ResponseInProgressEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
yield _format_sse(in_progress_event)
# response.output_item.added
initial_item = ResponseMessageItem(
@@ -224,7 +230,7 @@ async def generate_responses_stream(
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq), output_index=0, item=initial_item
)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
yield _format_sse(item_added)
# response.content_part.added
initial_part = ResponseOutputText(text="")
@@ -235,7 +241,7 @@ async def generate_responses_stream(
content_index=0,
part=initial_part,
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
yield _format_sse(part_added)
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
@@ -266,7 +272,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_item,
)
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
yield _format_sse(fc_added)
# response.function_call_arguments.delta
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -275,7 +281,7 @@ async def generate_responses_stream(
output_index=next_output_index,
delta=tool.arguments,
)
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
yield _format_sse(args_delta)
# response.function_call_arguments.done
args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -285,7 +291,7 @@ async def generate_responses_stream(
name=tool.name,
arguments=tool.arguments,
)
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
yield _format_sse(args_done)
# response.output_item.done
fc_done_item = ResponseFunctionCallItem(
@@ -300,7 +306,7 @@ async def generate_responses_stream(
output_index=next_output_index,
item=fc_done_item,
)
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
yield _format_sse(fc_item_done)
function_call_items.append(fc_done_item)
next_output_index += 1
@@ -316,7 +322,7 @@ async def generate_responses_stream(
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
yield _format_sse(delta_event)
# response.output_text.done
text_done = ResponseTextDoneEvent(
@@ -326,7 +332,7 @@ async def generate_responses_stream(
content_index=0,
text=accumulated_text,
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
yield _format_sse(text_done)
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
@@ -337,7 +343,7 @@ async def generate_responses_stream(
content_index=0,
part=final_part,
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
yield _format_sse(part_done)
# response.output_item.done
final_message_item = ResponseMessageItem(
@@ -348,7 +354,7 @@ async def generate_responses_stream(
item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
yield _format_sse(item_done)
# Create usage from usage data if available
usage = None
@@ -373,4 +379,4 @@ async def generate_responses_stream(
completed_event = ResponseCompletedEvent(
sequence_number=next(seq), response=final_response
)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
yield _format_sse(completed_event)

View File

@@ -85,7 +85,6 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
ImageListItem,
ImageListResponse,
ImageSize,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -101,7 +100,6 @@ from exo.shared.types.api import (
TraceRankStats,
TraceResponse,
TraceStatsResponse,
normalize_image_size,
)
from exo.shared.types.chunks import (
ErrorChunk,
@@ -753,11 +751,9 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(ModelId(payload.model))
payload = payload.model_copy(
update={
"model": await self._validate_image_model(ModelId(payload.model)),
"advanced_params": _ensure_seed(payload.advanced_params),
}
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
@@ -1013,13 +1009,12 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.stream = False
payload.partial_images = 0
payload = payload.model_copy(
update={
"model": await self._validate_image_model(ModelId(payload.model)),
"stream": False,
"partial_images": 0,
"advanced_params": _ensure_seed(payload.advanced_params),
}
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
@@ -1040,7 +1035,7 @@ class API:
prompt: str,
model: ModelId,
n: int,
size: ImageSize,
size: str,
response_format: Literal["url", "b64_json"],
input_fidelity: Literal["low", "high"],
stream: bool,
@@ -1110,7 +1105,7 @@ class API:
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str | None = Form(None),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
stream: str = Form("false"),
@@ -1136,7 +1131,7 @@ class API:
prompt=prompt,
model=ModelId(model),
n=n,
size=normalize_image_size(size),
size=size,
response_format=response_format,
input_fidelity=input_fidelity,
stream=stream_bool,
@@ -1172,7 +1167,7 @@ class API:
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str | None = Form(None),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
quality: Literal["high", "medium", "low"] = Form("medium"),
@@ -1192,7 +1187,7 @@ class API:
prompt=prompt,
model=ModelId(model),
n=n,
size=normalize_image_size(size),
size=size,
response_format=response_format,
input_fidelity=input_fidelity,
stream=False,

View File

@@ -44,8 +44,7 @@ async def _refresh_card_cache():
async for toml_file in path.rglob("*.toml"):
try:
card = await ModelCard.load_from_path(toml_file)
if card.model_id not in _card_cache:
_card_cache[card.model_id] = card
_card_cache[card.model_id] = card
except (ValidationError, TOMLKitError):
pass
@@ -183,7 +182,6 @@ class ConfigData(BaseModel):
def supports_tensor(self) -> bool:
return self.architectures in [
["Glm4MoeLiteForCausalLM"],
["GlmMoeDsaForCausalLM"],
["DeepseekV32ForCausalLM"],
["DeepseekV3ForCausalLM"],
["Qwen3NextForCausalLM"],

View File

@@ -1,9 +1,9 @@
import time
from collections.abc import Generator
from typing import Annotated, Any, Literal, get_args
from typing import Annotated, Any, Literal
from uuid import uuid4
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
@@ -262,27 +262,6 @@ class DeleteInstanceResponse(BaseModel):
instance_id: InstanceId
ImageSize = Literal[
"auto",
"512x512",
"768x768",
"1024x768",
"768x1024",
"1024x1024",
"1024x1536",
"1536x1024",
]
def normalize_image_size(v: object) -> ImageSize:
"""Shared validator for ImageSize fields: maps None → "auto" and rejects invalid values."""
if v is None:
return "auto"
if v not in get_args(ImageSize):
raise ValueError(f"Invalid size: {v!r}. Must be one of {get_args(ImageSize)}")
return v # pyright: ignore[reportReturnType]
class AdvancedImageParams(BaseModel):
seed: Annotated[int, Field(ge=0)] | None = None
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
@@ -302,7 +281,7 @@ class ImageGenerationTaskParams(BaseModel):
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: ImageSize = "auto"
size: str | None = "1024x1024"
stream: bool | None = False
style: str | None = "vivid"
user: str | None = None
@@ -310,11 +289,6 @@ class ImageGenerationTaskParams(BaseModel):
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
@field_validator("size", mode="before")
@classmethod
def normalize_size(cls, v: object) -> ImageSize:
return normalize_image_size(v)
class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
bench: bool = True
@@ -331,18 +305,13 @@ class ImageEditsTaskParams(BaseModel):
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: ImageSize = "auto"
size: str | None = "1024x1024"
image_strength: float | None = 0.7
stream: bool = False
partial_images: int | None = 0
advanced_params: AdvancedImageParams | None = None
bench: bool = False
@field_validator("size", mode="before")
@classmethod
def normalize_size(cls, v: object) -> ImageSize:
return normalize_image_size(v)
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
if name == "image_data":

View File

@@ -1,8 +1,10 @@
import time
from collections.abc import Hashable
from typing import Generic, TypeVar
K = TypeVar("K")
class KeyedBackoff[K: Hashable]:
class KeyedBackoff(Generic[K]):
"""Tracks exponential backoff state per key."""
def __init__(self, base: float = 0.5, cap: float = 10.0):

View File

@@ -14,7 +14,6 @@ from exo.shared.types.api import (
ImageEditsTaskParams,
ImageGenerationStats,
ImageGenerationTaskParams,
ImageSize,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.runner_response import (
@@ -24,9 +23,9 @@ from exo.shared.types.worker.runner_response import (
from exo.worker.engines.image.distributed_model import DistributedImageModel
def parse_size(size_str: ImageSize) -> tuple[int, int]:
def parse_size(size_str: str | None) -> tuple[int, int]:
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
if size_str == "auto":
if not size_str:
return (1024, 1024)
try:
@@ -110,9 +109,6 @@ def generate_image(
# Decode base64 image data and save to temp file
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
if task.size == "auto":
with Image.open(image_path) as img:
width, height = img.size
for image_num in range(num_images):
# Increment seed for each image to ensure unique results

View File

@@ -163,14 +163,11 @@ class PipelineLastLayer(CustomMlxLayer):
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
# CacheList (used by MLA models like DeepSeekV32, GLM MoE DSA)
# doesn't have .keys directly; access via first sub-cache.
_cache = cache[0] if hasattr(cache, "caches") else cache # type: ignore
_cache.keys = mx.depends(_cache.keys, output) # type: ignore
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
if self.is_prefill:
mx.eval(output)
if cache is not None:
mx.eval(_cache.keys) # type: ignore
mx.eval(cache.keys) # type: ignore
if not self.is_prefill:
output = mx.distributed.all_gather(output, group=self.group)[
@@ -310,9 +307,7 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None:
last = cache[-1] # type: ignore
dep_cache = last[0] if hasattr(last, "caches") else last # type: ignore
dep_cache.keys = mx.depends(dep_cache.keys, logits) # type: ignore
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
return logits
@@ -338,9 +333,7 @@ def patch_tensor_model[T](model: T) -> T:
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
last = cache[-1] # pyright: ignore[reportAny]
dep_cache = last[0] if hasattr(last, "caches") else last # pyright: ignore[reportAny]
dep_cache.keys = mx.depends(dep_cache.keys, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
return logits
@@ -554,12 +547,10 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
@@ -590,18 +581,12 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
# Shard the MoE.
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
else:
if getattr(layer.mlp, "shared_experts", None) is not None:
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.mlp.shared_experts.down_proj
)
self.all_to_sharded_linear_in_place(
layer.mlp.shared_experts.up_proj
)
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
@@ -794,7 +779,8 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
# Shard the MoE.
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.gate_proj
)
@@ -907,7 +893,8 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
# Shard the MoE.
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
if isinstance(layer.mlp, (Qwen3MoeSparseMoeBlock, Qwen3NextSparseMoeBlock)):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)

View File

@@ -57,7 +57,6 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
group: mx.distributed.Group | None = None,
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.
@@ -87,9 +86,6 @@ def prefill(
set_pipeline_prefill(model, is_prefill=True)
mx_barrier(group)
logger.info("Starting prefill")
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
@@ -309,9 +305,16 @@ def mlx_generate(
)
max_stop_len = max((len(s) for s in stop_sequences), default=0)
mx_barrier(group)
logger.info("Starting prefill")
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
model, tokenizer, sampler, prompt_tokens[:-1], caches, group
model,
tokenizer,
sampler,
prompt_tokens[:-1],
caches,
)
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None
@@ -328,7 +331,6 @@ def mlx_generate(
think_start = tokenizer.think_start
think_end = tokenizer.think_end
logger.info("Starting decode")
mx_barrier(group)
for completion_tokens, out in enumerate(

View File

@@ -285,12 +285,10 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm-5" in model_id_lower or "glm-4.7" in model_id_lower:
# For GLM-5 and GLM-4.7
elif "glm-4.7-flash" in model_id_lower:
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
return [154820, 154827, 154829]
elif "glm" in model_id_lower:
# For GLM-4.5 and older
return [151336, 151329, 151338]
return None

View File

@@ -191,7 +191,7 @@ class RunnerSupervisor:
logger.info("Checking runner's status")
if self.runner_process.is_alive():
logger.info("Runner was found to be alive, attempting to join process")
await to_thread.run_sync(self.runner_process.join, 5)
await to_thread.run_sync(self.runner_process.join, 1)
rc = self.runner_process.exitcode
logger.info(f"RunnerSupervisor exited with exit code {rc}")
if rc == 0:

52
uv.lock generated
View File

@@ -367,6 +367,7 @@ version = "0.3.0"
source = { editable = "." }
dependencies = [
{ name = "aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "aiohttp", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "exo-pyo3-bindings", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "fastapi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -376,8 +377,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260218+14841977", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -388,6 +389,7 @@ dependencies = [
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "zstandard", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
@@ -404,6 +406,7 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "aiofiles", specifier = ">=24.1.0" },
{ name = "aiohttp", specifier = ">=3.12.14" },
{ name = "anyio", specifier = "==4.11.0" },
{ name = "exo-pyo3-bindings", editable = "rust/exo_pyo3_bindings" },
{ name = "fastapi", specifier = ">=0.116.1" },
@@ -413,9 +416,9 @@ requires-dist = [
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = "==0.15.5" },
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
{ name = "mlx-lm", specifier = "==0.30.7" },
{ name = "mlx-lm", specifier = "==0.30.6" },
{ name = "msgspec", specifier = ">=0.19.0" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "pillow", specifier = ">=11.0,<12.0" },
@@ -425,6 +428,7 @@ requires-dist = [
{ name = "rustworkx", specifier = ">=0.17.1" },
{ name = "tiktoken", specifier = ">=0.12.0" },
{ name = "tomlkit", specifier = ">=0.14.0" },
{ name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
{ name = "zstandard", specifier = ">=0.23.0" },
]
@@ -1016,8 +1020,8 @@ dependencies = [
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260218+14841977", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1044,12 +1048,18 @@ wheels = [
name = "mlx"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"sys_platform == 'linux'",
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/ae/5b/e460e144a34d5529e010056cccf50b538d56ed001473bc6b246018fd58cb/mlx-0.30.6-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:ed86f8bffc174c2f259ca589ea25464c96cf69d1bb457074a2bf2ef53737e54f", size = 573515, upload-time = "2026-02-06T03:45:23.405Z" },
{ url = "https://files.pythonhosted.org/packages/60/25/69833fefb9a3fef30b56792b1bcd022496c4fea83e45411d289b77ef7546/mlx-0.30.6-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:c52294958269e20f300639a17c1900ca8fc737d859ddda737f9811e94bd040e5", size = 573516, upload-time = "2026-02-06T03:45:24.618Z" },
{ url = "https://files.pythonhosted.org/packages/9c/6a/7e7fbeebc5cb51b6a5eba96b263a6298707bcbdc059f4b0b73e088bc3dea/mlx-0.30.6-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:b5b6636f7c49a4d86d8ec82643b972f45a144a7a9f3a967b27b2e6e22cf71e6a", size = 573592, upload-time = "2026-02-06T03:45:25.928Z" },
{ url = "https://files.pythonhosted.org/packages/93/06/280f6f2ba80520a7109730425eda0d966658793aa0d02d8be8d351f75253/mlx-0.30.6-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:67e6c9e30a9faeacc209917ef5523177cf9b086914b6b5d83ff886e4294b727d", size = 622011, upload-time = "2026-02-06T03:45:28.165Z" },
{ url = "https://files.pythonhosted.org/packages/fe/35/f872afbee9c079cc69924d9e9c46f5663adb7da58cba3511db082dd307c1/mlx-0.30.6-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:47db8b16fcb6f6c5a47c0bdb24ed377b41237017ac93aa6cb6aa206c9bdf82e4", size = 663650, upload-time = "2026-02-06T03:45:30.315Z" },
{ url = "https://files.pythonhosted.org/packages/60/23/361dc7a5797634e4d7e9bdd6564c6b28f9b1246672632def2f91bf066b18/mlx-0.30.6-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:78804a89dcff4a838f7c2da72392fe87a523e95122a3c840e53df019122aad45", size = 575028, upload-time = "2026-02-06T03:45:31.549Z" },
{ url = "https://files.pythonhosted.org/packages/a8/69/1854484d414171586814dfbe8def95f75c4ea2c7341ba13ba8ee675f7c62/mlx-0.30.6-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:ec13584ab069665cc7ad34a05494d9291cd623aef6ae96be48875fc87cfc25d6", size = 575026, upload-time = "2026-02-06T03:45:33.072Z" },
{ url = "https://files.pythonhosted.org/packages/6b/b8/3adbc441924209a7e4c568308b2a0b54bd09aee6a68db5bae85304791e54/mlx-0.30.6-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:b2c5e8a090a753ef99a1380a4d059c983083f36198864f6df9faaf1223d083df", size = 575041, upload-time = "2026-02-06T03:45:34.814Z" },
{ url = "https://files.pythonhosted.org/packages/3f/54/9d9e06804fb2088202a2cdf60458e00b221f71420bea285720b60f9e82b5/mlx-0.30.6-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:9ceddede4af0de31d1f6b3099f70e5469d60cd7c546975dedbdbeab3519cab3f", size = 624002, upload-time = "2026-02-06T03:45:36Z" },
{ url = "https://files.pythonhosted.org/packages/42/92/3140a15a50cb1f9267a6552171e1dfa577861de53e093124bc43707f2a0e/mlx-0.30.6-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:4a6ffd2d16728cf95f63a1b555d7c2eaeea686a0e6b73228bd265411cb5d77a4", size = 663569, upload-time = "2026-02-06T03:45:37.242Z" },
]
@@ -1062,14 +1072,6 @@ cuda13 = [
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx"
version = "0.30.7.dev20260218+14841977"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }
resolution-markers = [
"sys_platform == 'darwin'",
]
[[package]]
name = "mlx-cpu"
version = "0.30.6"
@@ -1096,20 +1098,30 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.30.7"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260218+14841977", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#1484197707f35186ad3bd614357c7c47fdf86ebc" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/66/0d/56542e2ae13ec6f542d3977d7cff89a205d4f6c5122e0ce23f33265f61c9/mlx_lm-0.30.7.tar.gz", hash = "sha256:e5f31ac58d9f2381f28e1ba639ff903e64f7cff1bdc245c0bc97f72264be329c", size = 275764, upload-time = "2026-02-12T18:41:11.86Z" }
sdist = { url = "https://files.pythonhosted.org/packages/76/cb/815deddc8699b1f694d7e1f9cbed52934c03a8b49432c8add72932bb2f0b/mlx_lm-0.30.6.tar.gz", hash = "sha256:807e042d7040268f1b19190b7eaefd8b2efbff5590a65460974ad4225b91dda1", size = 271733, upload-time = "2026-02-04T21:27:45.741Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/17/a41c798a3d9cbdc47f39c6db5bba4c2cd199203ead26bf911cb03b644070/mlx_lm-0.30.7-py3-none-any.whl", hash = "sha256:17442a4bf01c4c2d3bca1e647712fe44f19890c3f1eadc8589d389e57b44b9bf", size = 386591, upload-time = "2026-02-12T18:41:10.236Z" },
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
]
[[package]]
name = "mlx-metal"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f3/85/44406b521f920248fad621334d4dc15e77660a494edf890e7cbee33bf38d/mlx_metal-0.30.6-py3-none-macosx_14_0_arm64.whl", hash = "sha256:ea6d0c973def9a5b4f652cc77036237db3f88c9d0af63701d76b5fddde99b820", size = 38437818, upload-time = "2026-02-06T03:44:56.19Z" },
{ url = "https://files.pythonhosted.org/packages/d0/cb/10a516995f7d0c154b0d7e633c54b51e96977a86a355105b6474cfcbe0d0/mlx_metal-0.30.6-py3-none-macosx_15_0_arm64.whl", hash = "sha256:0f8cb94634d07e06a372d6ad9a090f38a18bab1ff19a140aede60eacf707bb94", size = 38433701, upload-time = "2026-02-06T03:44:59.678Z" },
{ url = "https://files.pythonhosted.org/packages/4c/7d/70cb272f7373c334709f210ed8420511fc9d64d05a7a646c0b3b94c29c04/mlx_metal-0.30.6-py3-none-macosx_26_0_arm64.whl", hash = "sha256:d761ae26304f2c4b454eeea7f612a56919d9e5e57dbb1dc0788f8e34aa6f41c2", size = 47718448, upload-time = "2026-02-06T03:45:03.133Z" },
]
[[package]]