Compare commits

..

5 Commits

Author SHA1 Message Date
Evan
4c9bc26c1a add system ids 2026-02-18 21:33:48 +00:00
Alex Cheema
6c322ebb72 feat: only show thinking toggle for models that support it (#1497)
## Summary
- Adds `thinking_toggle` capability to 26 model cards that support
toggling thinking mode on/off
- GPT-OSS models (20b, 120b) excluded — they always think and don't
support toggling
- Dashboard UI updated to check for `thinking_toggle` capability before
showing the toggle button

## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed
- [x] `nix fmt` — 0 files changed
- [x] `uv run pytest` — 188 passed, 0 failed
- [x] Security review passed (no secrets, eval/exec, innerHTML, or dep
changes)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 17:05:00 +00:00
vskiwi
2ebe6216b4 feat: add explicit --offline mode for air-gapped clusters (#1525)
## Motivation

Closes #1510

There is currently no reliable way to run exo on an air-gapped or offline cluster where models are pre-staged on local disks. The two existing mechanisms — `--no-downloads` and `HF_HUB_OFFLINE=1` — each cover only a subset of the problem:

1. **`--no-downloads` blocks model loading**: When passed, `DownloadCoordinator` is not created. No `NodeDownloadProgress` events are ever emitted, so `_model_needs_download()` in `plan.py` perpetually returns `DownloadModel`, short-circuiting `_load_model()` and preventing the model from ever being loaded.

2. **`HF_HUB_OFFLINE=1` doesn't cover exo's aiohttp code**: exo's download pipeline primarily uses raw `aiohttp` for HTTP operations (file list fetching, file downloads, HEAD verification), not the `huggingface_hub` library. These calls will attempt connections and time out on air-gapped networks.

3. **`skip_internet` is not propagated to `download_file_with_retry()`**: Even when `internet_connection = False`, the `_download_file()` function still makes HTTP HEAD calls via `file_meta()` to verify local files and unconditionally attempts downloads for missing files.

## Changes

### `src/exo/main.py`
- Add `--offline` flag to `Args` with env var detection (`EXO_OFFLINE=1`, `HF_HUB_OFFLINE=1`)
- Pass `offline` to `DownloadCoordinator` at creation and re-creation (election loop)

### `src/exo/download/coordinator.py`
- Add `offline: bool = False` field
- In offline mode: set `internet_connection = False` immediately in `__post_init__`, skip `_test_internet_connection()` ping (avoids 3s timeout), skip `_check_internet_connection` periodic loop
- In `_start_download()`: if model is not fully available locally, emit `DownloadFailed` with clear message instead of starting a download task

### `src/exo/download/download_utils.py`
- Add `skip_internet: bool` parameter to `download_file_with_retry()` and `_download_file()`
- When `skip_internet=True` in `_download_file()`: return local file immediately without HTTP HEAD verification; raise `FileNotFoundError` for missing files
- Propagate `skip_internet` from `download_shard()` to `download_file_with_retry()`

### `src/exo/download/tests/test_offline_mode.py` (new)
- 8 tests covering `_download_file`, `download_file_with_retry`, and `fetch_file_list_with_cache` in offline mode

## Why It Works

Unlike `--no-downloads` which disables `DownloadCoordinator` entirely, `--offline` keeps the coordinator running in a restricted mode. The existing `_emit_existing_download_progress()` disk scanner still runs every 60 seconds, emitting `DownloadCompleted` events for pre-staged models. These events flow through the event-sourcing pipeline and populate `state.downloads`, which unblocks `_model_needs_download()` in `plan.py` — no changes to the planning logic required.

```
--offline flag
  → DownloadCoordinator (offline mode)
    → Skip 1.1.1.1 ping, internet_connection = False
    → _emit_existing_download_progress scans disk
      → Emits DownloadCompleted for pre-staged models
        → _model_needs_download sees DownloadCompleted
          → _load_model proceeds normally
```

## Test Plan

### Automated Testing
- `ruff check` — passes
- 8 new tests in `test_offline_mode.py` — all pass
- 11 existing download tests in `test_download_verification.py` — all pass (no regressions)

### Manual Testing
1. Pre-stage a model on disk (e.g., `~/.exo/models/mlx-community--Qwen3-0.6B-4bit/`)
2. Start exo with `--offline` (or `EXO_OFFLINE=1`)
3. Place an instance via API or dashboard
4. Verify: model loads into memory and inference works without any network calls

### Environment
- macOS (Apple Silicon), multi-node cluster with Thunderbolt interconnect
- Models pre-staged via rsync / NFS mount
2026-02-18 16:18:09 +00:00
ciaranbor
f54c80b121 Ciaran/image edit api (#1500)
## Motivation

- Image editing previously ignored input image dimensions, always
defaulting to 1024x1024
- Size dropdown was hidden in edit mode, giving users no control over
output dimensions
- Portrait/landscape presets used non-standard aspect ratios (1024x1365
/ 1365x1024)

## Changes

- Added "auto" size option that uses input image dimensions for edits,
defaults to 1024x1024 for generation
- Introduced ImageSize Literal type and normalize_image_size() validator
(replaces raw str size fields)
  - Updated portrait/landscape presets to standard 1024x1536 / 1536x1024
  - Made size selector visible in edit mode (previously hidden)
  - Default size changed from "1024x1024" to "auto"

## Why It Works

- "auto" reads actual input image dimensions via PIL at generation time,
so edits preserve the original aspect ratio
- Pydantic field_validator on both ImageGenerationTaskParams and
ImageEditsTaskParams normalizes None → "auto", keeping the API
backward-compatible

## Test Plan

### Manual Testing

- Verify image edits output at the input image's native resolution when
size is "auto"
- Verify size dropdown appears and works in both generate and edit modes
2026-02-18 16:05:39 +00:00
rltakashige
48b8f86395 Add support for GLM 5 (#1526)
## Motivation

Add GLM 5 support in favor of #1513 

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-02-18 14:04:06 +00:00
44 changed files with 499 additions and 189 deletions

View File

@@ -103,7 +103,7 @@
const modelSupportsThinking = $derived(() => {
if (!currentModel) return false;
const caps = modelCapabilities[currentModel] || [];
return caps.includes("thinking") && caps.includes("text");
return caps.includes("thinking_toggle") && caps.includes("text");
});
const isEditOnlyWithoutImage = $derived(

View File

@@ -59,13 +59,14 @@
}
const sizeOptions: ImageGenerationParams["size"][] = [
"auto",
"512x512",
"768x768",
"1024x1024",
"1024x768",
"768x1024",
"1024x1365",
"1365x1024",
"1024x1536",
"1536x1024",
];
const qualityOptions: ImageGenerationParams["quality"][] = [
@@ -176,92 +177,90 @@
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
<!-- Basic params row -->
<div class="flex items-center gap-3 flex-wrap">
<!-- Size (hidden in edit mode - output size comes from input image) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
<!-- Size -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
{params.size.toUpperCase()}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size.toUpperCase()}</span>
</button>
{/each}
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
</div>
</div>
{/if}
</div>
{/if}
{/if}
</div>
<!-- Quality -->
<div class="flex items-center gap-1.5">
@@ -311,7 +310,7 @@
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto overflow-x-hidden min-w-max"
style="bottom: calc(100vh - {qualityDropdownPosition()
.top}px + 4px); left: {qualityDropdownPosition().left}px;"
>

View File

@@ -306,13 +306,14 @@ const IMAGE_PARAMS_STORAGE_KEY = "exo-image-generation-params";
export interface ImageGenerationParams {
// Basic params
size:
| "auto"
| "512x512"
| "768x768"
| "1024x1024"
| "1024x768"
| "768x1024"
| "1024x1365"
| "1365x1024";
| "1024x1536"
| "1536x1024";
quality: "low" | "medium" | "high";
outputFormat: "png" | "jpeg";
numImages: number;
@@ -336,7 +337,7 @@ export interface EditingImage {
}
const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
size: "1024x1024",
size: "auto",
quality: "medium",
outputFormat: "png",
numImages: 1,

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "deepseek"
quantization = "4bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 405874409472

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "deepseek"
quantization = "8bit"
base_model = "DeepSeek V3.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 765577920512

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 122406567936

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "bf16"
base_model = "GLM 4.5 Air"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 229780750336

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 198556925568

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 286737579648

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 396963397248

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "4bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 19327352832

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "5bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 22548578304

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "6bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 26843545600

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "glm"
quantization = "8bit"
base_model = "GLM 4.7 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 34359738368

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 706522120192

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "kimi"
quantization = ""
base_model = "Kimi K2.5"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 662498705408

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "minimax"
quantization = "3bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 100086644736

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "minimax"
quantization = "8bit"
base_model = "MiniMax M2.1"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 242986745856

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 342884352

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 0.6B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 698351616

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 141733920768

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 235B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 268435456000

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 17612931072

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 30B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 33279705088

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "4bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 47080074240

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "qwen"
quantization = "8bit"
base_model = "Qwen3 Next 80B"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 88814387200

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "4bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 114572190076

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "6bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 159039627774

View File

@@ -6,7 +6,7 @@ tasks = ["TextGeneration"]
family = "step"
quantization = "8bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
capabilities = ["text", "thinking", "thinking_toggle"]
[storage_size]
in_bytes = 209082699847

View File

@@ -1,7 +1,6 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
import anyio
from anyio import current_time
@@ -22,7 +21,7 @@ from exo.shared.types.commands import (
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
ForwarderEvent,
@@ -46,7 +45,8 @@ class DownloadCoordinator:
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
offline: bool = False
_system_id: SystemId = field(default_factory=SystemId)
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
@@ -62,6 +62,8 @@ class DownloadCoordinator:
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
if self.offline:
self.shard_downloader.set_internet_connection(False)
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
@@ -107,13 +109,17 @@ class DownloadCoordinator:
self._last_progress_time[model_id] = current_time()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
self._test_internet_connection()
logger.info(
f"Starting DownloadCoordinator{' (offline mode)' if self.offline else ''}"
)
if not self.offline:
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._check_internet_connection)
if not self.offline:
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
@@ -202,6 +208,20 @@ class DownloadCoordinator:
)
return
if self.offline:
logger.warning(
f"Offline mode: model {model_id} is not fully available locally, cannot download"
)
failed = DownloadFailed(
shard_metadata=shard,
node_id=self.node_id,
error_message=f"Model files not found locally in offline mode: {model_id}",
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(NodeDownloadProgress(download_progress=failed))
return
# Start actual download
self._start_download_task(shard, initial_progress)
@@ -274,15 +294,16 @@ class DownloadCoordinator:
del self.download_status[model_id]
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self.node_id,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)

View File

@@ -448,12 +448,13 @@ async def download_file_with_retry(
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
on_connection_lost: Callable[[], None] = lambda: None,
skip_internet: bool = False,
) -> Path:
n_attempts = 3
for attempt in range(n_attempts):
try:
return await _download_file(
model_id, revision, path, target_dir, on_progress
model_id, revision, path, target_dir, on_progress, skip_internet
)
except HuggingFaceAuthenticationError:
raise
@@ -487,10 +488,14 @@ async def _download_file(
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
skip_internet: bool = False,
) -> Path:
target_path = target_dir / path
if await aios.path.exists(target_path):
if skip_internet:
return target_path
local_size = (await aios.stat(target_path)).st_size
# Try to verify against remote, but allow offline operation
@@ -510,6 +515,11 @@ async def _download_file(
)
return target_path
if skip_internet:
raise FileNotFoundError(
f"File {path} not found locally and cannot download in offline mode"
)
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -814,6 +824,7 @@ async def download_shard(
file, curr_bytes, total_bytes, is_renamed
),
on_connection_lost=on_connection_lost,
skip_internet=skip_internet,
)
if not skip_download:

View File

@@ -0,0 +1,230 @@
"""Tests for offline/air-gapped mode."""
from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import AsyncMock, patch
import aiofiles
import aiofiles.os as aios
import pytest
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
download_file_with_retry,
fetch_file_list_with_cache,
)
from exo.shared.types.common import ModelId
from exo.shared.types.worker.downloads import FileListEntry
@pytest.fixture
def model_id() -> ModelId:
return ModelId("test-org/test-model")
@pytest.fixture
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
yield models_dir
class TestDownloadFileOffline:
"""Tests for _download_file with skip_internet=True."""
async def test_returns_local_file_without_http_verification(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file exists locally, return it immediately
without making any HTTP calls (no file_meta verification)."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
local_file = target_dir / "model.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"model weights data")
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await _download_file(
model_id,
"main",
"model.safetensors",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
async def test_raises_file_not_found_for_missing_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file does NOT exist locally,
raise FileNotFoundError instead of attempting download."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
with pytest.raises(FileNotFoundError, match="offline mode"):
await _download_file(
model_id,
"main",
"missing_model.safetensors",
target_dir,
skip_internet=True,
)
async def test_returns_local_file_in_subdirectory(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""When skip_internet=True and file exists in a subdirectory,
return it without HTTP calls."""
target_dir = tmp_path / "downloads"
subdir = target_dir / "transformer"
await aios.makedirs(subdir, exist_ok=True)
local_file = subdir / "diffusion_pytorch_model.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"weights")
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await _download_file(
model_id,
"main",
"transformer/diffusion_pytorch_model.safetensors",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
class TestDownloadFileWithRetryOffline:
"""Tests for download_file_with_retry with skip_internet=True."""
async def test_propagates_skip_internet_to_download_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Verify skip_internet is passed through to _download_file."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
local_file = target_dir / "config.json"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b'{"model_type": "qwen2"}')
with patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
) as mock_file_meta:
result = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
skip_internet=True,
)
assert result == local_file
mock_file_meta.assert_not_called()
async def test_file_not_found_does_not_retry(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""FileNotFoundError from offline mode should not trigger retries."""
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
with pytest.raises(FileNotFoundError):
await download_file_with_retry(
model_id,
"main",
"nonexistent.safetensors",
target_dir,
skip_internet=True,
)
class TestFetchFileListOffline:
"""Tests for fetch_file_list_with_cache with skip_internet=True."""
async def test_uses_cached_file_list(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and cache file exists, use it without network."""
from pydantic import TypeAdapter
cache_dir = temp_models_dir / "caches" / model_id.normalize()
await aios.makedirs(cache_dir, exist_ok=True)
cached_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
FileListEntry(type="file", path="config.json", size=200),
]
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(cached_list).decode()
)
with patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
) as mock_fetch:
result = await fetch_file_list_with_cache(
model_id, "main", skip_internet=True
)
assert result == cached_list
mock_fetch.assert_not_called()
async def test_falls_back_to_local_directory_scan(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and no cache but local files exist,
build file list from local directory."""
import json
model_dir = temp_models_dir / model_id.normalize()
await aios.makedirs(model_dir, exist_ok=True)
async with aiofiles.open(model_dir / "config.json", "w") as f:
await f.write('{"model_type": "qwen2"}')
index_data = {
"metadata": {},
"weight_map": {"model.layers.0.weight": "model.safetensors"},
}
async with aiofiles.open(model_dir / "model.safetensors.index.json", "w") as f:
await f.write(json.dumps(index_data))
async with aiofiles.open(model_dir / "model.safetensors", "wb") as f:
await f.write(b"x" * 500)
with patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
) as mock_fetch:
result = await fetch_file_list_with_cache(
model_id, "main", skip_internet=True
)
mock_fetch.assert_not_called()
paths = {entry.path for entry in result}
assert "config.json" in paths
assert "model.safetensors" in paths
async def test_raises_when_no_cache_and_no_local_files(
self, model_id: ModelId, temp_models_dir: Path
) -> None:
"""When skip_internet=True and neither cache nor local files exist,
raise FileNotFoundError."""
with pytest.raises(FileNotFoundError, match="No internet"):
await fetch_file_list_with_cache(model_id, "main", skip_internet=True)

View File

@@ -1,11 +1,10 @@
import argparse
import itertools
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Iterator, Self
from typing import Self
import anyio
from anyio.abc import TaskGroup
@@ -38,11 +37,11 @@ class Node:
api: API | None
node_id: NodeId
event_index_counter: Iterator[int]
offline: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
async def create(cls, args: "Args") -> "Self":
async def create(cls, args: "Args") -> Self:
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
@@ -56,9 +55,6 @@ class Node:
logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
@@ -67,7 +63,7 @@ class Node:
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
offline=args.offline,
)
else:
download_coordinator = None
@@ -93,7 +89,6 @@ class Node:
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
)
else:
worker = None
@@ -131,7 +126,7 @@ class Node:
master,
api,
node_id,
event_index_counter,
args.offline,
)
async def run(self):
@@ -210,7 +205,6 @@ class Node:
if result.is_new_master:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
@@ -221,7 +215,7 @@ class Node:
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
@@ -238,7 +232,6 @@ class Node:
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.api:
@@ -260,6 +253,9 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
if args.offline:
logger.info("Running in OFFLINE mode — no internet checks, local models only")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
@@ -282,6 +278,7 @@ class Args(CamelCaseModel):
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
offline: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -329,6 +326,11 @@ class Args(CamelCaseModel):
action="store_true",
help="Disable the download coordinator (node won't download models)",
)
parser.add_argument(
"--offline",
action="store_true",
help="Run in offline/air-gapped mode: skip internet checks, use only pre-staged local models",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

View File

@@ -85,6 +85,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
ImageListItem,
ImageListResponse,
ImageSize,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -100,6 +101,7 @@ from exo.shared.types.api import (
TraceRankStats,
TraceResponse,
TraceStatsResponse,
normalize_image_size,
)
from exo.shared.types.chunks import (
ErrorChunk,
@@ -129,7 +131,7 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -181,6 +183,7 @@ class API:
) -> None:
self.state = State()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self._system_id = SystemId()
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
@@ -231,6 +234,7 @@ class API:
self._event_log.close()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self.state = State()
self._system_id = SystemId()
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._text_generation_queues = {}
@@ -544,7 +548,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
ForwarderCommand(origin=self._system_id, command=command)
)
raise
finally:
@@ -751,9 +755,11 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(ModelId(payload.model))
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
update={
"model": await self._validate_image_model(ModelId(payload.model)),
"advanced_params": _ensure_seed(payload.advanced_params),
}
)
command = ImageGeneration(
@@ -887,7 +893,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
ForwarderCommand(origin=self._system_id, command=command)
)
raise
finally:
@@ -973,7 +979,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
ForwarderCommand(origin=self._system_id, command=command)
)
raise
finally:
@@ -1009,12 +1015,13 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.stream = False
payload.partial_images = 0
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
update={
"model": await self._validate_image_model(ModelId(payload.model)),
"stream": False,
"partial_images": 0,
"advanced_params": _ensure_seed(payload.advanced_params),
}
)
command = ImageGeneration(
@@ -1035,7 +1042,7 @@ class API:
prompt: str,
model: ModelId,
n: int,
size: str,
size: ImageSize,
response_format: Literal["url", "b64_json"],
input_fidelity: Literal["low", "high"],
stream: bool,
@@ -1105,7 +1112,7 @@ class API:
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
size: str | None = Form(None),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
stream: str = Form("false"),
@@ -1131,7 +1138,7 @@ class API:
prompt=prompt,
model=ModelId(model),
n=n,
size=size,
size=normalize_image_size(size),
response_format=response_format,
input_fidelity=input_fidelity,
stream=stream_bool,
@@ -1167,7 +1174,7 @@ class API:
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
size: str | None = Form(None),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
quality: Literal["high", "medium", "low"] = Form("medium"),
@@ -1187,7 +1194,7 @@ class API:
prompt=prompt,
model=ModelId(model),
n=n,
size=size,
size=normalize_image_size(size),
response_format=response_format,
input_fidelity=input_fidelity,
stream=False,
@@ -1403,7 +1410,7 @@ class API:
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.origin != self.session_id.master_node_id:
if f_event.session != self.session_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
for idx, event in self.event_buffer.drain_indexed():
@@ -1467,12 +1474,12 @@ class API:
while self.paused:
await self.paused_ev.wait()
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
ForwarderCommand(origin=self._system_id, command=command)
)
async def _send_download(self, command: DownloadCommand):
await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self.node_id, command=command)
ForwarderDownloadCommand(origin=self._system_id, command=command)
)
async def start_download(

View File

@@ -29,7 +29,7 @@ from exo.shared.types.commands import (
TestCommand,
TextGeneration,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
ForwarderEvent,
@@ -90,7 +90,8 @@ class Master:
self._loopback_event_sender: Sender[ForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
self._system_id = SystemId()
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
@@ -288,7 +289,7 @@ class Master:
):
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id, command=cmd
origin=self._system_id, command=cmd
)
)
generated_events.extend(transition_events)
@@ -415,7 +416,7 @@ class Master:
async for event in events:
await self._loopback_event_sender.send(
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
origin=self._system_id,
origin_idx=local_index,
session=self.session_id,
event=event,
@@ -428,7 +429,7 @@ class Master:
# Convenience method since this line is ugly
await self.global_event_sender.send(
ForwarderEvent(
origin=self.node_id,
origin=self._system_id,
origin_idx=event.idx,
session=self.session_id,
event=event.event,

View File

@@ -15,7 +15,7 @@ from exo.shared.types.commands import (
PlaceInstance,
TextGeneration,
)
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
@@ -75,13 +75,12 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(
ForwarderEvent(
origin_idx=0,
origin=sender_node_id,
origin=SystemId("Worker"),
session=session_id,
event=(
NodeGatheredInfo(
@@ -108,7 +107,7 @@ async def test_master():
logger.info("inject a CreateInstance Command")
await command_sender.send(
ForwarderCommand(
origin=node_id,
origin=SystemId("API"),
command=(
PlaceInstance(
command_id=CommandId(),
@@ -133,7 +132,7 @@ async def test_master():
logger.info("inject a TextGeneration Command")
await command_sender.send(
ForwarderCommand(
origin=node_id,
origin=SystemId("API"),
command=(
TextGeneration(
command_id=CommandId(),

View File

@@ -4,7 +4,7 @@ from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.utils.channels import channel
# ======= #
@@ -384,7 +384,7 @@ async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
# Pump local commands so our commands_seen is high before the round starts
for _ in range(50):
await co_tx.send(
ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand())
ForwarderCommand(origin=SystemId("SOMEONE"), command=TestCommand())
)
# Trigger a round at clock=1 with a peer of equal seniority but fewer commands

View File

@@ -1,9 +1,9 @@
import time
from collections.abc import Generator
from typing import Annotated, Any, Literal
from typing import Annotated, Any, Literal, get_args
from uuid import uuid4
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
@@ -262,6 +262,27 @@ class DeleteInstanceResponse(BaseModel):
instance_id: InstanceId
ImageSize = Literal[
"auto",
"512x512",
"768x768",
"1024x768",
"768x1024",
"1024x1024",
"1024x1536",
"1536x1024",
]
def normalize_image_size(v: object) -> ImageSize:
"""Shared validator for ImageSize fields: maps None → "auto" and rejects invalid values."""
if v is None:
return "auto"
if v not in get_args(ImageSize):
raise ValueError(f"Invalid size: {v!r}. Must be one of {get_args(ImageSize)}")
return v # pyright: ignore[reportReturnType]
class AdvancedImageParams(BaseModel):
seed: Annotated[int, Field(ge=0)] | None = None
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
@@ -281,7 +302,7 @@ class ImageGenerationTaskParams(BaseModel):
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
size: ImageSize = "auto"
stream: bool | None = False
style: str | None = "vivid"
user: str | None = None
@@ -289,6 +310,11 @@ class ImageGenerationTaskParams(BaseModel):
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
@field_validator("size", mode="before")
@classmethod
def normalize_size(cls, v: object) -> ImageSize:
return normalize_image_size(v)
class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
bench: bool = True
@@ -305,13 +331,18 @@ class ImageEditsTaskParams(BaseModel):
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
size: ImageSize = "auto"
image_strength: float | None = 0.7
stream: bool = False
partial_images: int | None = 0
advanced_params: AdvancedImageParams | None = None
bench: bool = False
@field_validator("size", mode="before")
@classmethod
def normalize_size(cls, v: object) -> ImageSize:
return normalize_image_size(v)
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
if name == "image_data":

View File

@@ -6,7 +6,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, NodeId, SystemId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -100,10 +100,10 @@ Command = (
class ForwarderCommand(CamelCaseModel):
origin: NodeId
origin: SystemId
command: Command
class ForwarderDownloadCommand(CamelCaseModel):
origin: NodeId
origin: SystemId
command: DownloadCommand

View File

@@ -25,6 +25,10 @@ class NodeId(Id):
pass
class SystemId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -166,6 +166,6 @@ class ForwarderEvent(CamelCaseModel):
"""An event the forwarder will serialize and send over the network"""
origin_idx: int = Field(ge=0)
origin: NodeId
origin: SystemId
session: SessionId
event: Event

View File

@@ -14,6 +14,7 @@ from exo.shared.types.api import (
ImageEditsTaskParams,
ImageGenerationStats,
ImageGenerationTaskParams,
ImageSize,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.runner_response import (
@@ -23,9 +24,9 @@ from exo.shared.types.worker.runner_response import (
from exo.worker.engines.image.distributed_model import DistributedImageModel
def parse_size(size_str: str | None) -> tuple[int, int]:
def parse_size(size_str: ImageSize) -> tuple[int, int]:
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
if not size_str:
if size_str == "auto":
return (1024, 1024)
try:
@@ -109,6 +110,9 @@ def generate_image(
# Decode base64 image data and save to temp file
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
if task.size == "auto":
with Image.open(image_path) as img:
width, height = img.size
for image_num in range(num_images):
# Increment seed for each image to ensure unique results

View File

@@ -57,7 +57,7 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
group: mx.distributed.Group | None,
group: mx.distributed.Group | None = None,
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.

View File

@@ -1,7 +1,6 @@
from collections import defaultdict
from datetime import datetime, timezone
from random import random
from typing import Iterator
import anyio
from anyio import CancelScope, create_task_group, fail_after
@@ -17,7 +16,7 @@ from exo.shared.types.commands import (
RequestEventLog,
StartDownload,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
EventId,
@@ -64,14 +63,12 @@ class Worker:
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.event_index_counter = event_index_counter
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.event_buffer = OrderedBuffer[Event]()
@@ -86,6 +83,8 @@ class Worker:
self._nack_base_seconds: float = 0.5
self._nack_cap_seconds: float = 10.0
self._system_id = SystemId()
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
@@ -132,7 +131,7 @@ class Worker:
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.origin != self.session_id.master_node_id:
if f_event.session != self.session_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
event_id = f_event.event.event_id
@@ -212,7 +211,7 @@ class Worker:
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id,
origin=self._system_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
@@ -312,7 +311,7 @@ class Worker:
)
await self.command_sender.send(
ForwarderCommand(
origin=self.node_id,
origin=self._system_id,
command=RequestEventLog(since_idx=since_idx),
)
)
@@ -339,15 +338,16 @@ class Worker:
return runner
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self.node_id,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe