Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-14 09:00:07 -05:00)

Compare commits: ciaran/ima...evan/glm47 (33 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | ff8709e4c2 |  |
|  | c6e42d4918 |  |
|  | bdb43e1dbb |  |
|  | e4a01e2b0e |  |
|  | 1200a7db64 |  |
|  | 47ceb54bc1 |  |
|  | f8112fdf25 |  |
|  | e388f59480 |  |
|  | e5e74e1eef |  |
|  | b968d6f0a0 |  |
|  | 3bfffd9b4f |  |
|  | 007eb80029 |  |
|  | 8d7b6789b3 |  |
|  | 3c5b7ea670 |  |
|  | b74a610537 |  |
|  | 18c4e49f91 |  |
|  | d85b5d3781 |  |
|  | caafc48693 |  |
|  | cca8c9984a |  |
|  | d1e88def42 |  |
|  | 59e7594e34 |  |
|  | c65320acd3 |  |
|  | b9a78f6f3a |  |
|  | 8f7f0e893a |  |
|  | 4759b09d4c |  |
|  | ca680185f3 |  |
|  | 383309e24e |  |
|  | 55463a9806 |  |
|  | 56af61fac9 |  |
|  | f76d543d98 |  |
|  | ea841aca37 |  |
|  | 077b1bc732 |  |
|  | 4963c33162 |  |
.github/benchmark-dashboard/README.md (vendored, 159 lines)
@@ -1,159 +0,0 @@
# EXO Benchmark Dashboard

A fully self-contained, browser-based dashboard for tracking EXO benchmark performance over time.

## Features

- 📊 **Success Rate Tracking**: Monitor cluster reliability across commits
- ⚡ **Response Time Analysis**: Track average request completion times
- 🎯 **Throughput Metrics**: Tokens per second visualization
- 📈 **Request Distribution**: Success/failure breakdown over time
- 🔄 **Auto-Refresh**: Updates every 60 seconds
- 📺 **TV-Ready**: Large, clear visualizations perfect for display
- 🔐 **Secure**: Credentials stored in browser localStorage only
- 🌐 **No Backend**: Directly accesses S3 from the browser

## Quick Start

### Option 1: Direct File Access (Simplest)

Just open the HTML file directly in your browser:

```bash
open .github/benchmark-dashboard/index.html
```

Then click "Configure AWS Credentials" and enter your keys.

### Option 2: URL Parameters (For Quick Setup)

```bash
# Serve with credentials in URL (they'll be moved to localStorage)
open ".github/benchmark-dashboard/index.html?accessKey=YOUR_KEY&secretKey=YOUR_SECRET&region=us-east-1"
```

The credentials will be saved to localStorage and removed from the URL immediately.

### Option 3: Simple HTTP Server

```bash
# From repo root
python3 -m http.server 8080

# Then open: http://localhost:8080/.github/benchmark-dashboard/
```

## AWS Credentials

The dashboard needs read-only access to the `exo-benchmark-results` S3 bucket.

### Required IAM Permissions

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "s3:GetObject",
        "s3:ListBucket"
      ],
      "Resource": [
        "arn:aws:s3:::exo-benchmark-results",
        "arn:aws:s3:::exo-benchmark-results/*"
      ]
    }
  ]
}
```

### Security Notes

- ✅ Credentials stored in browser `localStorage` only
- ✅ Never sent to any server (except AWS)
- ✅ All S3 access happens client-side
- ✅ Use read-only IAM credentials
- ⚠️ Don't commit credentials to git
- ⚠️ Use a dedicated read-only IAM user

## TV/Kiosk Mode

For permanent display on a TV:

### macOS

```bash
open -a "Google Chrome" --args --kiosk ".github/benchmark-dashboard/index.html"
```

### Linux

```bash
chromium-browser --kiosk --app="file://$(pwd)/.github/benchmark-dashboard/index.html"
```

### Auto-start on Boot

Create a simple startup script:

```bash
#!/bin/bash
# /usr/local/bin/start-benchmark-dashboard.sh

cd /path/to/exo
python3 -m http.server 8080 &
sleep 2
chromium-browser --kiosk http://localhost:8080/.github/benchmark-dashboard/
```

## Data Displayed

### Summary Cards

- **Latest Success Rate**: Most recent benchmark success percentage with trend
- **Avg Response Time**: Latest average response time in ms with trend
- **Total Benchmarks**: Count of all benchmarks run
- **Active Configurations**: Number of unique benchmark configs

### Charts

1. **Success Rate Over Time**: Line chart showing reliability trends
2. **Average Response Time**: Performance over time (lower is better)
3. **Throughput**: Tokens/second metric (higher is better)
4. **Request Distribution**: Stacked bar chart of successes/failures

## How It Works

1. **Loads AWS SDK**: Uses AWS SDK for JavaScript (browser version)
2. **Lists S3 Objects**: Fetches all files from `s3://exo-benchmark-results/bench/`
3. **Downloads Results**: Fetches each JSON result file
4. **Parses & Visualizes**: Uses Chart.js to create interactive charts
5. **Auto-Refreshes**: Polls S3 every 60 seconds for new results
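For reference, here is a minimal Python sketch of the same read-only flow, assuming `boto3` and read access to the bucket named above. The dashboard itself does this in the browser with the AWS SDK for JavaScript and Chart.js; the helper name and the polling loop here are purely illustrative.

```python
# Illustrative server-side equivalent of the dashboard's polling loop (not part of the repo).
import json
import time

import boto3

BUCKET = "exo-benchmark-results"
PREFIX = "bench/"
REFRESH_INTERVAL = 60  # seconds, mirrors the dashboard's auto-refresh

s3 = boto3.client("s3")


def fetch_results() -> list[dict]:
    """List every result object under bench/ and parse each one as JSON."""
    results = []
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=BUCKET, Prefix=PREFIX):
        for obj in page.get("Contents", []):
            body = s3.get_object(Bucket=BUCKET, Key=obj["Key"])["Body"].read()
            results.append(json.loads(body))
    return results


if __name__ == "__main__":
    while True:
        print(f"Loaded {len(fetch_results())} benchmark results")
        time.sleep(REFRESH_INTERVAL)
```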
## Customization

To modify the dashboard:

1. Edit `index.html`
2. Adjust `REFRESH_INTERVAL` for different polling frequency
3. Modify chart colors/styles in the Chart.js configuration
4. Add new metrics by extending the results parsing

## Troubleshooting

**"AWS credentials not configured"**

- Click "Configure AWS Credentials" and enter your keys

**"Error loading benchmark data"**

- Check AWS credentials are correct
- Verify S3 bucket name is `exo-benchmark-results`
- Ensure IAM user has read permissions
- Check browser console for detailed errors

**"No benchmark results found"**

- Wait for benchmark workflows to run
- Verify results are being uploaded to S3
- Check S3 bucket has files in `bench/` prefix

**Charts not updating**

- Check browser console for errors
- Verify network connectivity to S3
- Try refreshing the page manually
.github/benchmark-dashboard/index.html (vendored, 1641 lines)
File diff suppressed because it is too large.
.github/configs/README.md (vendored, 186 lines)
@@ -1,186 +0,0 @@
# EXO Benchmark Configurations

This directory contains configuration files for the EXO staged benchmark system.

## Overview

The staged benchmark system allows you to run complex, multi-stage load tests against EXO clusters. Each stage can have different characteristics:

- **Prompt Length**: Number of tokens in the input prompt
- **Generation Length**: Maximum tokens to generate in the response
- **Time Between Requests**: Delay (in seconds) between firing consecutive requests
- **Iterations**: Number of requests to send in this stage

Requests are **fire-and-forget** - they don't wait for the previous request to complete. This allows you to test overlapping request handling and measure success rates under load.
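To make the fire-and-forget pattern concrete, here is a minimal sketch of firing one stage's requests without awaiting earlier ones. It assumes an OpenAI-style `/v1/chat/completions` endpoint and the `httpx` library; it is illustrative only, not the actual `bench.py` implementation.

```python
# Illustrative only: fire a stage's requests without waiting, then gather results.
import asyncio

import httpx


async def run_stage(api_port: int, prompt_length: int, generation_length: int,
                    time_between_requests: float, iterations: int) -> float:
    url = f"http://localhost:{api_port}/v1/chat/completions"  # assumed endpoint

    async with httpx.AsyncClient(timeout=None) as client:
        async def one_request() -> bool:
            payload = {
                "messages": [{"role": "user", "content": "x " * prompt_length}],  # rough filler prompt
                "max_tokens": generation_length,
            }
            try:
                resp = await client.post(url, json=payload)
                return resp.status_code == 200
            except httpx.HTTPError:
                return False

        tasks = []
        for _ in range(iterations):
            tasks.append(asyncio.create_task(one_request()))  # fire and move on
            await asyncio.sleep(time_between_requests)
        results = await asyncio.gather(*tasks)

    return sum(results) / len(results)  # success rate for the stage
```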
## Configuration Files

### `bench_simple.yaml`

A minimal configuration that replicates the behavior of the original `bench.py` script:

- Single stage with 1 iteration
- Short prompt (~20 tokens)
- Generates up to 100 tokens

This is useful for quick smoke tests.

### `bench_config.yaml`

A comprehensive multi-stage benchmark with:

1. **Warmup** (10 requests): Light load with short prompts
2. **Medium Load** (20 requests): Moderate load with medium prompts
3. **Stress Test** (30 requests): Heavy overlapping requests with long prompts
4. **Cooldown** (5 requests): Light load to wind down

This tests the cluster's behavior under varying load patterns.

## Configuration Schema

```yaml
# Hardware configuration - maps runner labels to instance counts
hardware_plan:
  M3ULTRA_GPU80_512GB: 4

# Environment variables to set on each node (optional)
environment:
  OVERRIDE_MEMORY_MB: 512

# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600

# Model instances to run concurrently
model_ids:
  - "mlx-community/Llama-3.2-1B-Instruct-4bit"

# Benchmark stages
stages:
  - name: "stage_name"            # Human-readable name for this stage
    prompt_length: 100            # Target prompt length in tokens
    generation_length: 200        # Max tokens to generate
    time_between_requests: 2.0    # Seconds between firing requests
    iterations: 10                # Number of requests in this stage
```

## Running Benchmarks

### Via GitHub Actions

**Automatic (every commit):**

- The **`bench`** workflow runs automatically on every push
- Uses `bench_simple.yaml` as the default configuration
- All settings (hardware plan, timeout, environment variables, models, stages) are defined in the config file

**Manual (on-demand):**

1. Go to **Actions** → **bench** workflow
2. Click **Run workflow**
3. Configure:
   - **Config File**: Path to your YAML config (default: `.github/configs/bench_simple.yaml`)
     - `.github/configs/bench_simple.yaml` for quick tests
     - `.github/configs/bench_config.yaml` for complex multi-stage tests

All other settings (hardware plan, timeout, environment variables, models, stages) are read from the specified config file.

### Via Command Line

```bash
# Start EXO on localhost:8000
uv run exo --api-port 8000

# Run simple benchmark (1 stage, 1 iteration)
python3 .github/scripts/bench.py \
  --api-port 8000 \
  --config .github/configs/bench_simple.yaml \
  --expected-nodes 1 \
  --is-primary true \
  --timeout-seconds 600

# Run complex staged benchmark (4 stages, multiple iterations)
python3 .github/scripts/bench.py \
  --api-port 8000 \
  --config .github/configs/bench_config.yaml \
  --expected-nodes 1 \
  --is-primary true \
  --timeout-seconds 600
```

## Output Metrics

For each stage, the benchmark reports:

- **Total Requests**: Number of requests fired
- **Successful Requests**: Requests that completed successfully
- **Failed Requests**: Requests that encountered errors
- **Success Rate**: Percentage of successful requests
- **Total Tokens**: Sum of all tokens generated across successful requests
- **Avg Tokens/Request**: Average tokens per successful request
- **Avg Time/Request**: Average completion time per successful request

A JSON summary is also printed for easy parsing and storage.
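Roughly, the per-stage aggregation amounts to the following sketch. The field names are assumptions for illustration, not the exact keys `bench.py` emits.

```python
# Sketch of per-stage aggregation; `results` holds one
# (succeeded, tokens, seconds) tuple per fired request.
def summarize(results: list[tuple[bool, int, float]]) -> dict[str, float]:
    total = len(results)
    ok = [r for r in results if r[0]]
    total_tokens = sum(tokens for _, tokens, _ in ok)
    return {
        "total_requests": total,
        "successful_requests": len(ok),
        "failed_requests": total - len(ok),
        "success_rate": 100.0 * len(ok) / total if total else 0.0,
        "total_tokens": total_tokens,
        "avg_tokens_per_request": total_tokens / len(ok) if ok else 0.0,
        "avg_time_per_request": sum(s for _, _, s in ok) / len(ok) if ok else 0.0,
    }
```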
## Creating Custom Benchmarks

To create a custom benchmark:

1. Copy an existing config file (e.g., `bench_config.yaml`)
2. Modify the stages to match your test scenario
3. Save it in this directory with a descriptive name
4. Run it using the workflow or command line

### Example: Sustained Load Test

```yaml
hardware_plan:
  M3ULTRA_GPU80_512GB: 2

environment:
  OVERRIDE_MEMORY_MB: 1024

timeout_seconds: 600

model_ids:
  - "mlx-community/Llama-3.2-1B-Instruct-4bit"

stages:
  - name: "sustained_load"
    prompt_length: 200
    generation_length: 150
    time_between_requests: 0.5   # Very fast - 2 requests/second
    iterations: 100              # Run for ~50 seconds
```

### Example: Varying Prompt Sizes

```yaml
hardware_plan:
  M4PRO_GPU16_24GB: 3

timeout_seconds: 900

model_ids:
  - "mlx-community/Llama-3.2-1B-Instruct-4bit"

stages:
  - name: "tiny_prompts"
    prompt_length: 10
    generation_length: 100
    time_between_requests: 1.0
    iterations: 10

  - name: "medium_prompts"
    prompt_length: 200
    generation_length: 100
    time_between_requests: 1.0
    iterations: 10

  - name: "large_prompts"
    prompt_length: 1000
    generation_length: 100
    time_between_requests: 1.0
    iterations: 10
```

## Tips

- **Overlapping Requests**: Set `time_between_requests` < expected completion time to test concurrent request handling
- **Sequential Requests**: Set `time_between_requests` > expected completion time to ensure requests don't overlap
- **Realistic Load**: Model real usage patterns by varying prompt/generation lengths across stages
- **Success Rate**: A 100% success rate indicates the cluster handled the load well; lower rates suggest capacity limits
.github/configs/bench_config.yaml (vendored, 49 lines)
@@ -1,49 +0,0 @@
# EXO Staged Benchmark Configuration
# This configuration defines a multi-stage load test for EXO clusters

# Hardware configuration - maps runner labels to instance counts
hardware_plan:
  M3ULTRA_GPU80_512GB: 4

# Environment variables to set on each node (optional)
environment:
  OVERRIDE_MEMORY_MB: 512

# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600

# Multiple instances run concurrently on the cluster
model_ids:
  - "mlx-community/Qwen3-0.6B-4bit"
  - "mlx-community/Qwen3-0.6B-4bit"

# Stages run sequentially, each with its own characteristics
stages:
  # Stage 1: Light load with short prompts
  - name: "warmup"
    prompt_length: 50            # Number of tokens in prompt
    generation_length: 100       # Max tokens to generate
    time_between_requests: 5.0   # Seconds between firing requests
    iterations: 10               # Number of requests to send in this stage

  # Stage 2: Medium load with medium prompts
  - name: "medium_load"
    prompt_length: 200
    generation_length: 150
    time_between_requests: 3.0
    iterations: 20

  # Stage 3: Heavy load with long prompts - requests will overlap
  - name: "stress_test"
    prompt_length: 500
    generation_length: 200
    time_between_requests: 1.0   # Fast firing - will definitely overlap
    iterations: 30

  # Stage 4: Cool down with simple prompts
  - name: "cooldown"
    prompt_length: 50
    generation_length: 50
    time_between_requests: 10.0
    iterations: 5
.github/configs/bench_simple.yaml (vendored, 125 lines)
@@ -1,125 +0,0 @@
# Simple single-shot benchmark
# Tests 2 instances concurrently on 2 nodes

# Hardware configuration - maps runner labels to instance counts
hardware_plan:
  puffin4: 1
  puffin8: 1

# Environment variables to set on each node
environment:
  PLACEHOLDER: "placeholder"
  # OVERRIDE_MEMORY_MB: 50000
  MLX_METAL_FAST_SYNCH: 1

# Timeout for instance and runner readiness (seconds)
timeout_seconds: 1800

# Model instances to run concurrently
model_ids:
  # - "mlx-community/DeepSeek-V3.1-8bit"
  # - "mlx-community/Kimi-K2-Instruct-4bit"
  - "mlx-community/Kimi-K2-Thinking"
  # - "mlx-community/Qwen3-235B-A22B-4bit"
  # - "mlx-community/Llama-3.3-70B-Instruct-4bit"
  # - "mlx-community/Llama-3.3-70B-Instruct-8bit"
  # - "mlx-community/Llama-3.2-1B-Instruct-4bit"

# Sharding strategy: "Pipeline" or "Tensor"
sharding: "Tensor"

# Instance type: "MlxRing" or "MlxIbv"
instance_meta: "MlxIbv"

# If true, run requests sequentially (no overlap); if false, fire-and-forget (default: false)
no_overlap: true

# Benchmark stages
# pp: 64, 256, 1024, 2048, 4096, 8192, 16384
# g: 64, 512
stages:
  # - name: "simple"
  #   prompt_length: 512
  #   generation_length: 10
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp64_g64"
  #   prompt_length: 64
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp64_g64"
  #   prompt_length: 64
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp64_g512"
  #   prompt_length: 64
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp256_g64"
  #   prompt_length: 256
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  - name: "pp256_g64"
    prompt_length: 256
    generation_length: 64
    time_between_requests: 2.0
    iterations: 5
  # - name: "pp256_g512"
  #   prompt_length: 256
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp1024_g64"
  #   prompt_length: 1024
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp1024_g512"
  #   prompt_length: 1024
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp2048_g64"
  #   prompt_length: 2048
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp2048_g512"
  #   prompt_length: 2048
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp4096_g64"
  #   prompt_length: 4096
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 4
  # - name: "pp4096_g512"
  #   prompt_length: 4096
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp8192_g64"
  #   prompt_length: 8192
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp8192_g512"
  #   prompt_length: 8192
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp16384_g64"
  #   prompt_length: 16384
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp16384_g512"
  #   prompt_length: 16384
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
.github/scripts/bench.py (vendored, 1399 lines)
File diff suppressed because it is too large.
.github/scripts/build_matrix.py (vendored, 70 lines)
@@ -1,70 +0,0 @@
#!/usr/bin/env python3
import json
import os
from typing import NotRequired, TypedDict, cast

import yaml


class MatrixEntry(TypedDict):
    label: str
    index: int


class MatrixInclude(TypedDict):
    label: str
    index: int
    is_primary: bool
    expected_nodes: int


class Config(TypedDict):
    hardware_plan: dict[str, int]
    timeout_seconds: NotRequired[int]
    environment: NotRequired[dict[str, str]]


# Read the config file
config_file: str = os.environ["CONFIG_FILE"]
with open(config_file, "r") as f:
    config: Config = cast(Config, yaml.safe_load(f))

# Extract hardware plan from config
plan: dict[str, int] = config["hardware_plan"]
if not plan:
    raise ValueError(f"No hardware_plan found in {config_file}")

# Build matrix entries
entries: list[MatrixEntry] = []
for label, count in plan.items():
    for idx in range(count):
        entries.append({"label": label, "index": idx})

total_nodes: int = len(entries)
matrix: dict[str, list[MatrixInclude]] = {
    "include": [
        {
            "label": e["label"],
            "index": e["index"],
            "is_primary": (i == 0),
            "expected_nodes": total_nodes,
        }
        for i, e in enumerate(entries)
    ]
}

# Extract other config values
timeout_seconds: int = config.get("timeout_seconds", 600)
environment: dict[str, str] = config.get("environment", {})

# Output to GitHub Actions
with open(os.environ["GITHUB_OUTPUT"], "a") as f:
    f.write(f"matrix={json.dumps(matrix)}\n")
    f.write(f"config_file={config_file}\n")
    f.write(f"timeout_seconds={timeout_seconds}\n")
    f.write(f"environment={json.dumps(environment)}\n")

print(f"Matrix: {json.dumps(matrix)}")
print(f"Config file: {config_file}")
print(f"Timeout: {timeout_seconds}")
print(f"Environment: {json.dumps(environment)}")
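For orientation, a hardware plan of `{"M4PRO_GPU16_24GB": 2, "M3ULTRA_GPU80_512GB": 1}` run through the script above would produce a matrix of the following shape (worked out by hand from the code, not captured workflow output):

```python
# Expected matrix for hardware_plan {"M4PRO_GPU16_24GB": 2, "M3ULTRA_GPU80_512GB": 1}:
# only the first entry is primary, and every entry sees expected_nodes == 3.
matrix = {
    "include": [
        {"label": "M4PRO_GPU16_24GB", "index": 0, "is_primary": True, "expected_nodes": 3},
        {"label": "M4PRO_GPU16_24GB", "index": 1, "is_primary": False, "expected_nodes": 3},
        {"label": "M3ULTRA_GPU80_512GB", "index": 0, "is_primary": False, "expected_nodes": 3},
    ]
}
```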
.github/workflows/BENCH_USAGE.md (vendored, 156 lines)
@@ -1,156 +0,0 @@
# Benchmark Workflow Usage

## Overview

The `bench_matrix.yml` workflow enables distributed benchmarking of models across multiple self-hosted macOS runners with different hardware configurations.

## Workflow Inputs

| Input | Description | Default | Required |
|-------|-------------|---------|----------|
| `model_id` | Model ID to benchmark | `mlx-community/Llama-3.2-1B-Instruct-4bit` | Yes |
| `hardware_plan` | JSON mapping of runner labels to counts | `{"M4PRO_GPU16_24GB": 1}` | Yes |
| `prompt` | Benchmark prompt text | `What is the capital of France?` | No |
| `timeout_seconds` | Timeout for instance/runner readiness | `600` | No |

## Hardware Plan Format

The `hardware_plan` input is a JSON object mapping runner labels to the number of machines:

```json
{
  "M4PRO_GPU16_24GB": 2,
  "M3ULTRA_GPU80_512GB": 1
}
```

This example would:

- Start 2 runners with the `M4PRO_GPU16_24GB` label
- Start 1 runner with the `M3ULTRA_GPU80_512GB` label
- Total of 3 runners coordinating on a single distributed inference instance

## How It Works

1. **Planning Job** (`plan`)
   - Runs on `ubuntu-latest`
   - Parses the `hardware_plan` JSON
   - Generates a dynamic matrix with one entry per runner
   - Only the first runner (index 0) is marked as `is_primary`

2. **Benchmark Worker Jobs** (`bench_worker`)
   - Each job runs on a self-hosted macOS runner with the specified label
   - All runners start EXO in parallel
   - The primary runner creates the model instance
   - All runners wait for their assigned runner to be ready (Loaded/Running status)
   - The primary runner executes the benchmark and prints results
   - The primary runner deletes the instance

## Example Usage

### Single Machine Benchmark

```yaml
model_id: mlx-community/Llama-3.2-1B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 1}'
prompt: What is the capital of France?
timeout_seconds: 600
```

### Multi-Machine Distributed Benchmark

```yaml
model_id: mlx-community/Llama-3.2-3B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 2, "M3ULTRA_GPU80_512GB": 1}'
prompt: Explain quantum computing in simple terms.
timeout_seconds: 900
```

## Benchmark Output

The primary runner outputs a JSON object with benchmark results:

```json
{
  "model_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",
  "instance_id": "abc-123-def",
  "tokens": 42,
  "elapsed_s": 2.451,
  "tps": 17.136
}
```

Where:

- `tokens`: Number of chunks/tokens generated
- `elapsed_s`: Total elapsed time in seconds
- `tps`: Tokens per second (tokens / elapsed_s)
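Using the sample output above, the throughput figure is just the ratio of the other two fields:

```python
tokens = 42
elapsed_s = 2.451
tps = tokens / elapsed_s  # ~17.136 tokens per second, matching the sample JSON
```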
## Runner Requirements

Each self-hosted runner must:

- Be labeled with appropriate hardware tags (e.g., `M4PRO_GPU16_24GB`)
- Have the `self-hosted` and `macOS` labels
- Have Nix installed with flakes enabled
- Have network connectivity to other runners in the same job

## Architecture

```
┌─────────────────────────────────────────────────────────────┐
│ GitHub Actions Workflow (bench_matrix.yml)                  │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│ ┌────────────────┐                                          │
│ │ Plan Job       │                                          │
│ │ (ubuntu)       │──┬─► Matrix: [{label, index, primary}]   │
│ └────────────────┘  │                                       │
│                     │                                       │
│ ┌───────────────────▼──────────────────────────────────┐    │
│ │ Bench Worker Jobs (Matrix)                           │    │
│ ├──────────────────────────────────────────────────────┤    │
│ │                                                      │    │
│ │ Runner 0 (Primary)   Runner 1        Runner 2        │    │
│ │ ┌─────────────┐   ┌─────────────┐  ┌──────────┐      │    │
│ │ │ Start EXO   │   │ Start EXO   │  │ Start EXO│      │    │
│ │ │ Create Inst │   │ Wait...     │  │ Wait...  │      │    │
│ │ │ Wait Ready  │   │ Wait Ready  │  │ Wait...  │      │    │
│ │ │ Run Bench   │   │ (idle)      │  │ (idle)   │      │    │
│ │ │ Print TPS   │   │             │  │          │      │    │
│ │ │ Delete Inst │   │             │  │          │      │    │
│ │ └─────────────┘   └─────────────┘  └──────────┘      │    │
│ └───────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────┘
```

## Implementation Details

### `scripts/bench.py`

A standalone Python script that:

- Creates instance (primary only)
- Polls `/state` endpoint until instance and all runners are ready
- Executes chat completion with timing (primary only)
- Parses SSE stream and counts tokens
- Computes TPS metrics
- Cleans up instance (primary only)

### Key Functions

- `wait_for_instance()`: Polls until instance with model_id appears
- `wait_for_runners_ready()`: Polls until expected number of runners reach Loaded/Running status
- `run_benchmark()`: Executes chat completion, measures time, counts tokens

## Troubleshooting

### Instance never becomes ready

- Check EXO logs in the workflow output
- Verify model_id is valid and accessible
- Increase `timeout_seconds`

### Runner mismatch

- Ensure hardware_plan counts match available labeled runners
- Check runner labels match exactly (case-sensitive)

### Network issues

- Verify runners can communicate on the network
- Check firewall rules between runner hosts
.github/workflows/bench.yml (vendored, 305 lines)
@@ -1,305 +0,0 @@
name: bench

on: [push]

jobs:
  plan:
    if: contains(github.event.head_commit.message, '/bench')
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.build.outputs.matrix }}
      config_file: ${{ steps.build.outputs.config_file }}
      timeout_seconds: ${{ steps.build.outputs.timeout_seconds }}
      environment: ${{ steps.build.outputs.environment }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Build matrix from config file
        id: build
        shell: bash
        run: |
          set -euo pipefail
          CONFIG_FILE='.github/configs/bench_simple.yaml'
          export CONFIG_FILE
          echo "Config file: $CONFIG_FILE"
          python3 .github/scripts/build_matrix.py

  bench_worker:
    needs: plan
    strategy:
      fail-fast: false
      matrix: ${{ fromJSON(needs.plan.outputs.matrix) }}
    name: "bench on ${{ matrix.label }} [${{ matrix.index }}]"
    runs-on: [self-hosted, macOS, "${{ matrix.label }}"]
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          lfs: false

      - name: Configure git user
        run: |
          git config --local user.email "github-actions@users.noreply.github.com"
          git config --local user.name "github-actions bot"
        shell: bash

      # TODO: this is mega hacky and I'd like a simpler solution.
      - name: Setup Nix Environment
        run: |
          echo "Checking for nix installation..."

          # Check if nix is already available
          if command -v nix >/dev/null 2>&1; then
            echo "Nix already in PATH"
          # Try sourcing profile scripts to set up environment properly
          elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
            echo "Sourcing multi-user nix-daemon profile script"
            source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
          elif [ -f "$HOME/.nix-profile/etc/profile.d/nix.sh" ]; then
            echo "Sourcing single-user nix profile script"
            source "$HOME/.nix-profile/etc/profile.d/nix.sh"
          elif [ -f /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh ]; then
            echo "Sourcing per-user nix profile script"
            source /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh
          elif [ -f /etc/profile.d/nix.sh ]; then
            echo "Sourcing system-wide nix profile script"
            source /etc/profile.d/nix.sh
          # Fallback: manually add nix to PATH if binary exists
          elif [ -f /nix/var/nix/profiles/default/bin/nix ]; then
            echo "Found nix binary, manually adding to PATH"
            export PATH="/nix/var/nix/profiles/default/bin:$PATH"
          elif [ -f "$HOME/.nix-profile/bin/nix" ]; then
            echo "Found nix binary in user profile, manually adding to PATH"
            export PATH="$HOME/.nix-profile/bin:$PATH"
          else
            echo "Nix not found. Debugging info:"
            echo "USER: $USER"
            echo "HOME: $HOME"
            echo "Current PATH: $PATH"
            echo ""
            echo "Checking common Nix locations:"
            echo "  /nix/var/nix/profiles/default/bin/nix:"
            ls -la /nix/var/nix/profiles/default/bin/nix 2>/dev/null || echo "  Not found"
            echo "  /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh:"
            ls -la /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh 2>/dev/null || echo "  Not found"
            echo "  ~/.nix-profile/etc/profile.d/nix.sh:"
            ls -la "$HOME/.nix-profile/etc/profile.d/nix.sh" 2>/dev/null || echo "  Not found"
            echo "  /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh:"
            ls -la "/nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh" 2>/dev/null || echo "  Not found"
            echo ""
            echo "/nix directory structure:"
            ls -la /nix 2>/dev/null || echo "  /nix directory not found"
            echo ""
            echo "/nix/var:"
            ls -la /nix/var 2>/dev/null || echo "  /nix/var not found"
            echo ""
            echo "/nix/store:"
            ls -la /nix/store 2>/dev/null | head -20 || echo "  /nix/store not found"
            echo ""
            echo "GitHub Actions runner is running as user '$USER'."
            echo "If Nix is installed for a different user, either:"
            echo "  1. Install Nix for user '$USER' (multi-user install recommended)"
            echo "  2. Configure the runner service to run as the user with Nix installed"
            echo "  3. Ensure Nix is installed system-wide with proper daemon setup"
            exit 1
          fi

          # Verify nix is available and persist to GITHUB_ENV
          if command -v nix >/dev/null 2>&1; then
            echo "✓ Nix is available"
            nix --version
            echo "PATH=$PATH" >> $GITHUB_ENV
            if [ -n "$NIX_PATH" ]; then
              echo "NIX_PATH=$NIX_PATH" >> $GITHUB_ENV
            fi
          else
            echo "ERROR: Failed to set up Nix"
            echo "PATH after setup attempt: $PATH"
            exit 1
          fi
        shell: bash

      - name: Setup EXO_HOME and API_PORT
        run: |
          EXO_HOME=$(mktemp -d -t exo-e2e-XXXXXXXX)
          API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
          EXO_MODELS_DIR="$HOME/.exo/models"
          EXO_LIBP2P_NAMESPACE="bench-${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"
          echo "EXO_HOME=$EXO_HOME" >> "$GITHUB_ENV"
          echo "API_PORT=$API_PORT" >> "$GITHUB_ENV"
          echo "EXO_MODELS_DIR=$EXO_MODELS_DIR" >> "$GITHUB_ENV"
          echo "EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE" >> "$GITHUB_ENV"
          echo "Created EXO_HOME: $EXO_HOME"
          echo "Generated API_PORT: $API_PORT"
          echo "Using models from: $EXO_MODELS_DIR"
          echo "Using libp2p namespace: $EXO_LIBP2P_NAMESPACE"
        shell: bash

      - name: Configure local MLX if available
        run: |
          echo "=== DEBUG: Checking for local MLX configuration ==="
          MODIFIED=false

          echo "Checking for /Users/Shared/mlx directory..."
          if [ -d "/Users/Shared/mlx" ]; then
            echo "✓ Found /Users/Shared/mlx"
            ls -la /Users/Shared/mlx | head -5
            echo "Enabling local mlx path in pyproject.toml"
            sed -i.bak 's|^# mlx = { path = "/Users/Shared/mlx", editable=true }$|mlx = { path = "/Users/Shared/mlx", editable=true }|' pyproject.toml
            MODIFIED=true
          else
            echo "✗ /Users/Shared/mlx not found, will use PyPI version"
          fi

          echo "Checking for /Users/Shared/mlx-lm directory..."
          if [ -d "/Users/Shared/mlx-lm" ]; then
            echo "✓ Found /Users/Shared/mlx-lm"
            ls -la /Users/Shared/mlx-lm | head -5
            echo "Enabling local mlx-lm path in pyproject.toml"
            sed -i.bak 's|^# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }$|mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }|' pyproject.toml
            MODIFIED=true
          else
            echo "✗ /Users/Shared/mlx-lm not found, will use PyPI version"
          fi

          if [ "$MODIFIED" = true ]; then
            echo "=== Modified pyproject.toml [tool.uv.sources] section: ==="
            sed -n '/\[tool\.uv\.sources\]/,/^\[/{/^\[tool\.uv\.sources\]/p; /^\[/!p;}' pyproject.toml
            echo "=== Regenerating uv.lock with local MLX paths... ==="
            nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command uv lock --upgrade-package mlx --upgrade-package mlx-lm
            echo "✓ Lock file regenerated"
          else
            echo "⚠ No local MLX directories found, using PyPI packages"
          fi
          echo "=== DEBUG: Local MLX configuration complete ==="
        shell: bash

      - name: Sync dependencies
        run: |
          if [ -d "/Users/Shared/test" ]; then
            pushd /Users/Shared/test
            uv sync --reinstall
            popd
          fi
          echo "Running just sync to ensure clean dependencies..."
          nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just sync
        shell: bash

      - name: Start EXO and run bench script
        shell: bash
        env:
          IS_PRIMARY: ${{ matrix.is_primary }}
          EXPECTED_NODES: ${{ matrix.expected_nodes }}
          HARDWARE_LABEL: ${{ matrix.label }}
          CONFIG_FILE: ${{ needs.plan.outputs.config_file }}
          TIMEOUT_SECONDS: ${{ needs.plan.outputs.timeout_seconds }}
          ENVIRONMENT_JSON: ${{ needs.plan.outputs.environment }}
        run: |
          set -euo pipefail

          # Parse environment variables from config
          ENV_VARS=""
          if [ -n "$ENVIRONMENT_JSON" ] && [ "$ENVIRONMENT_JSON" != "{}" ]; then
            ENV_VARS=$(echo "$ENVIRONMENT_JSON" | python3 -c "import sys, json; env = json.load(sys.stdin); print(' '.join([f'{k}={v}' for k, v in env.items()]))")
          fi

          echo "Starting EXO with API_PORT=${API_PORT} EXO_HOME=${EXO_HOME} EXO_LIBP2P_NAMESPACE=${EXO_LIBP2P_NAMESPACE}"
          echo "Environment variables from config: $ENV_VARS"
          LOG_FILE=/tmp/exo.log
          : > "$LOG_FILE"

          MASTER_FLAG=""
          if [ "$IS_PRIMARY" = "true" ]; then
            MASTER_FLAG="-m"
          fi

          nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
            "EXO_HOME=$EXO_HOME EXO_MODELS_DIR=$EXO_MODELS_DIR EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE $ENV_VARS PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run exo $MASTER_FLAG --api-port $API_PORT" \
            >> "$LOG_FILE" 2>&1 &

          EXO_PID=$!
          echo "Started EXO in background with PID: $EXO_PID"
          echo "Log file: $LOG_FILE"

          cleanup() {
            echo '=== EXO log (tail) ==='
            tail -n 300 "$LOG_FILE" || true
            if ps -p "$EXO_PID" >/dev/null 2>&1; then
              echo "Killing EXO (PID $EXO_PID)"
              kill "$EXO_PID" || true
            fi
          }
          trap cleanup EXIT

          for i in $(seq 1 60); do
            if curl -s "http://localhost:${API_PORT}/state" >/dev/null 2>&1; then
              echo "EXO API ready"
              break
            fi
            if ! ps -p "$EXO_PID" >/dev/null 2>&1; then
              echo "EXO terminated early"; sed -n '1,200p' "$LOG_FILE" || true; exit 1
            fi
            sleep 1
          done

          RESULTS_FILE="/tmp/bench_results_${GITHUB_RUN_ID}_${GITHUB_RUN_ATTEMPT}_$(date +%s).json"
          echo "Results will be saved to: $RESULTS_FILE"
          echo "RESULTS_FILE=$RESULTS_FILE" >> "$GITHUB_ENV"

          echo "Running bench script with config: $CONFIG_FILE, timeout: $TIMEOUT_SECONDS"
          nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
            "PYTHONUNBUFFERED=1 uv run --no-project --with pyyaml --with pydantic python .github/scripts/bench.py \
              --api-port $API_PORT \
              --config $CONFIG_FILE \
              --expected-nodes ${EXPECTED_NODES} \
              --is-primary ${IS_PRIMARY} \
              --timeout-seconds ${TIMEOUT_SECONDS} \
              --output $RESULTS_FILE \
              --git-commit ${GITHUB_SHA} \
              --hardware-labels ${HARDWARE_LABEL}"

      - name: Install AWS CLI
        if: always() && env.RESULTS_FILE && matrix.is_primary
        run: |
          if ! command -v aws &> /dev/null; then
            echo "AWS CLI not found, installing..."
            brew install awscli
          else
            echo "AWS CLI already installed"
          fi
        shell: bash

      - name: Upload results to S3
        if: always() && env.RESULTS_FILE && matrix.is_primary
        env:
          AWS_ACCESS_KEY_ID: ${{ secrets.S3_BENCHMARKS_AWS_ACCESS_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
          AWS_DEFAULT_REGION: us-east-1
        run: |
          echo "Checking for results file: $RESULTS_FILE"
          echo "Is primary: ${{ matrix.is_primary }}"

          if [ -f "$RESULTS_FILE" ]; then
            TIMESTAMP=$(date -u +%Y/%m/%d/%H%M%S)
            S3_KEY="bench/${TIMESTAMP}_${GITHUB_SHA:0:8}_${GITHUB_RUN_ID}.json"
            echo "Uploading results to s3://exo-benchmark-results/$S3_KEY"

            aws s3 cp "$RESULTS_FILE" "s3://exo-benchmark-results/$S3_KEY" \
              --content-type application/json \
              --metadata "commit=${GITHUB_SHA},run_id=${GITHUB_RUN_ID},branch=${GITHUB_REF_NAME}"

            echo "Results uploaded successfully"
            echo "View at: https://exo-benchmark-results.s3.amazonaws.com/$S3_KEY"
          else
            echo "Results file not found at: $RESULTS_FILE"
            echo "Skipping upload"
          fi
        shell: bash

      - name: Cleanup EXO_HOME
        run: |
          echo "Cleaning up EXO_HOME: $EXO_HOME"
          rm -rf "$EXO_HOME"
        shell: bash
        if: always()
.github/workflows/build-app.yml (vendored, 35 lines changed)
@@ -1,6 +1,7 @@
name: Build EXO macOS DMG

on:
  workflow_dispatch:
  push:
    tags:
      - "v*"
@@ -18,6 +19,7 @@ jobs:
      SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}
      SPARKLE_S3_BUCKET: ${{ secrets.SPARKLE_S3_BUCKET }}
      SPARKLE_S3_PREFIX: ${{ secrets.SPARKLE_S3_PREFIX }}
      EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT: ${{ secrets.EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT }}
      AWS_REGION: ${{ secrets.AWS_REGION }}
      EXO_BUILD_NUMBER: ${{ github.run_number }}
      EXO_LIBP2P_NAMESPACE: ${{ github.ref_name }}
@@ -34,7 +36,7 @@ jobs:

    - name: Derive release version from tag
      run: |
        if [[ "$GITHUB_REF_NAME" == "test-app" ]]; then
        if [[ "$GITHUB_REF_NAME" == "test-app" || "${{ github.event_name }}" == "workflow_dispatch" ]]; then
          VERSION="0.0.0-alpha.0"
          echo "IS_ALPHA=true" >> $GITHUB_ENV
        else
@@ -47,6 +49,32 @@
        fi
        echo "RELEASE_VERSION=$VERSION" >> $GITHUB_ENV

    - name: Compute build version from semver
      run: |
        VERSION="$RELEASE_VERSION"
        # Extract major.minor.patch (strip prerelease suffix)
        BASE_VERSION="${VERSION%%-*}"
        MAJOR=$(echo "$BASE_VERSION" | cut -d. -f1)
        MINOR=$(echo "$BASE_VERSION" | cut -d. -f2)
        PATCH=$(echo "$BASE_VERSION" | cut -d. -f3)

        # Extract prerelease number (e.g., "alpha.2" -> 2, or 999 for releases)
        if [[ "$VERSION" == *-* ]]; then
          PRERELEASE_PART="${VERSION#*-}"
          PRERELEASE_NUM="${PRERELEASE_PART##*.}"
          # Default to 0 if not a number
          if ! [[ "$PRERELEASE_NUM" =~ ^[0-9]+$ ]]; then
            PRERELEASE_NUM=0
          fi
        else
          PRERELEASE_NUM=999
        fi

        # Compute: PRERELEASE + (1000 * PATCH) + (1_000_000 * MINOR) + (1_000_000_000 * MAJOR)
        BUILD_VERSION=$((PRERELEASE_NUM + 1000 * PATCH + 1000000 * MINOR + 1000000000 * MAJOR))
        echo "EXO_BUILD_VERSION=$BUILD_VERSION" >> $GITHUB_ENV
        echo "Computed build version: $BUILD_VERSION from $VERSION"

    - name: Ensure tag commit is on main
      if: github.ref_type == 'tag'
      run: |
@@ -162,11 +190,12 @@ jobs:
          -configuration Release \
          -derivedDataPath build \
          MARKETING_VERSION="$RELEASE_VERSION" \
          CURRENT_PROJECT_VERSION="$EXO_BUILD_NUMBER" \
          CURRENT_PROJECT_VERSION="$EXO_BUILD_VERSION" \
          EXO_BUILD_TAG="$RELEASE_VERSION" \
          EXO_BUILD_COMMIT="$GITHUB_SHA" \
          SPARKLE_FEED_URL="$SPARKLE_FEED_URL" \
          SPARKLE_ED25519_PUBLIC="$SPARKLE_ED25519_PUBLIC" \
          EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT="$EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT" \
          CODE_SIGNING_IDENTITY="$SIGNING_IDENTITY" \
          CODE_SIGN_INJECT_BASE_ENTITLEMENTS=YES
        mkdir -p ../../output
@@ -294,5 +323,5 @@ jobs:
        aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}${DMG_NAME}"
        if [[ "$IS_ALPHA" != "true" ]]; then
          aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
          aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
        fi
        aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
.github/workflows/pipeline.yml (vendored, 117 lines changed)
@@ -20,6 +20,12 @@ jobs:
        with:
          nix_path: nixpkgs=channel:nixos-unstable

      - uses: cachix/cachix-action@v14
        name: Configure Cachix
        with:
          name: exo
          authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"

      - name: Configure git user
        run: |
          git config --local user.email "github-actions@users.noreply.github.com"
@@ -88,9 +94,19 @@

      - uses: ./.github/actions/typecheck

  nix-flake-check:
    name: Check Nix flake
    runs-on: ubuntu-latest
  nix:
    name: Build and check (${{ matrix.system }})
    runs-on: ${{ matrix.runner }}
    strategy:
      fail-fast: false
      matrix:
        include:
          - runner: macos-26
            system: aarch64-darwin
          - runner: ubuntu-latest
            system: x86_64-linux
          - runner: ubuntu-24.04-arm
            system: aarch64-linux
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
@@ -101,83 +117,20 @@
        with:
          nix_path: nixpkgs=channel:nixos-unstable

      - name: Run nix flake check
        run: |
          nix flake check
        shell: bash
      - uses: cachix/cachix-action@v14
        name: Configure Cachix
        with:
          name: exo
          authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"

      # ci:
      #   needs: typecheck
      #   runs-on: ubuntu-latest
      #   permissions:
      #     contents: read
      #   env:
      #     GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      #   steps:
      #     - name: Checkout repository
      #       uses: actions/checkout@v4
      #       with:
      #         fetch-depth: 0
      #         token: ${{ secrets.GITHUB_TOKEN }}
      #         lfs: true
      #
      #     - name: Configure git user
      #       run: |
      #         git config --local user.email "github-actions@users.noreply.github.com"
      #         git config --local user.name "github-actions bot"
      #       shell: bash
      #
      #     - name: Pull LFS files
      #       run: |
      #         echo "Pulling Git LFS files..."
      #         git lfs pull
      #       shell: bash
      #
      #     - name: Setup EXO_HOME and API_PORT
      #       run: |
      #         EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX)
      #         # Generate random port (macOS compatible method)
      #         API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
      #         echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
      #         echo "API_PORT=$API_PORT" >> $GITHUB_ENV
      #         echo "Created EXO_HOME: $EXO_HOME"
      #         echo "Generated API_PORT: $API_PORT"
      #       shell: bash
      #
      #     - name: Setup Nix Environment
      #       run: |
      #         echo "Checking for nix installation..."
      #
      #         # Check if nix binary exists directly
      #         if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
      #           echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
      #           export PATH="/nix/var/nix/profiles/default/bin:$PATH"
      #           echo "PATH=$PATH" >> $GITHUB_ENV
      #           nix --version
      #         elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
      #           echo "Found nix profile script, sourcing..."
      #           source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
      #           nix --version
      #         elif command -v nix >/dev/null 2>&1; then
      #           echo "Nix already in PATH"
      #           nix --version
      #         else
      #           echo "Nix not found. Debugging info:"
      #           echo "Contents of /nix/var/nix/profiles/default/:"
      #           ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
      #           echo "Contents of /nix/var/nix/profiles/default/bin/:"
      #           ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
      #           exit 1
      #         fi
      #       shell: bash
      #
      #     - uses: ./.github/actions/lint-check
      #
      #     - uses: ./.github/actions/unit-test
      #
      #     - name: Cleanup EXO_HOME
      #       run: |
      #         echo "Cleaning up EXO_HOME: $EXO_HOME"
      #         rm -rf "$EXO_HOME"
      #       shell: bash
      #       if: always()
      - name: Build all Nix outputs
        run: |
          nix flake show --json | jq -r '
            [
              (.packages."${{ matrix.system }}" // {} | keys[] | ".#packages.${{ matrix.system }}.\(.)"),
              (.devShells."${{ matrix.system }}" // {} | keys[] | ".#devShells.${{ matrix.system }}.\(.)")
            ] | .[]
          ' | xargs nix build

      - name: Run nix flake check
        run: nix flake check
.gitignore (vendored, 1 line changed)
@@ -16,6 +16,7 @@ digest.txt
*.xcuserdatad/
**/.DS_Store
app/EXO/build/
dist/

# rust
.mlx_typings/mlx_lm/models/deepseek_v3.pyi (new file, 156 lines)
@@ -0,0 +1,156 @@
"""Type stubs for mlx_lm.models.deepseek_v3"""

from dataclasses import dataclass
from typing import Any, Dict, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .switch_layers import SwitchGLU

@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    vocab_size: int
    hidden_size: int
    intermediate_size: int
    moe_intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    num_key_value_heads: int
    n_shared_experts: Optional[int]
    n_routed_experts: Optional[int]
    routed_scaling_factor: float
    kv_lora_rank: int
    q_lora_rank: Optional[int]
    qk_rope_head_dim: int
    v_head_dim: int
    qk_nope_head_dim: int
    topk_method: str
    scoring_func: str
    norm_topk_prob: bool
    n_group: int
    topk_group: int
    num_experts_per_tok: int
    moe_layer_freq: int
    first_k_dense_replace: int
    max_position_embeddings: int
    rms_norm_eps: float
    rope_theta: float
    rope_scaling: Optional[Dict[str, Any]]
    attention_bias: bool

class DeepseekV3Attention(nn.Module):
    config: ModelArgs
    hidden_size: int
    num_heads: int
    max_position_embeddings: int
    rope_theta: float
    q_lora_rank: Optional[int]
    qk_rope_head_dim: int
    kv_lora_rank: int
    v_head_dim: int
    qk_nope_head_dim: int
    q_head_dim: int
    scale: float
    q_proj: nn.Linear
    q_a_proj: nn.Linear
    q_a_layernorm: nn.RMSNorm
    q_b_proj: nn.Linear
    kv_a_proj_with_mqa: nn.Linear
    kv_a_layernorm: nn.RMSNorm
    kv_b_proj: nn.Linear
    o_proj: nn.Linear
    rope: Any

    def __init__(self, config: ModelArgs) -> None: ...
    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array: ...

class DeepseekV3MLP(nn.Module):
    config: ModelArgs
    hidden_size: int
    intermediate_size: int
    gate_proj: nn.Linear
    up_proj: nn.Linear
    down_proj: nn.Linear

    def __init__(
        self,
        config: ModelArgs,
        hidden_size: Optional[int] = None,
        intermediate_size: Optional[int] = None,
    ) -> None: ...
    def __call__(self, x: mx.array) -> mx.array: ...

class MoEGate(nn.Module):
    config: ModelArgs
    top_k: int
    norm_topk_prob: bool
    n_routed_experts: Optional[int]
    routed_scaling_factor: float
    n_group: int
    topk_group: int
    weight: mx.array
    e_score_correction_bias: mx.array

    def __init__(self, config: ModelArgs) -> None: ...
    def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...

class DeepseekV3MoE(nn.Module):
    config: ModelArgs
    num_experts_per_tok: int
    switch_mlp: SwitchGLU
    gate: MoEGate
    shared_experts: DeepseekV3MLP
    sharding_group: Optional[mx.distributed.Group]

    def __init__(self, config: ModelArgs) -> None: ...
    def __call__(self, x: mx.array) -> mx.array: ...

class DeepseekV3DecoderLayer(nn.Module):
    self_attn: DeepseekV3Attention
    mlp: DeepseekV3MLP | DeepseekV3MoE
    input_layernorm: nn.RMSNorm
    post_attention_layernorm: nn.RMSNorm

    def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array: ...

class DeepseekV3Model(nn.Module):
    vocab_size: int
    embed_tokens: nn.Embedding
    layers: list[DeepseekV3DecoderLayer]
    norm: nn.RMSNorm

    def __init__(self, config: ModelArgs) -> None: ...
    def __call__(
        self,
        x: mx.array,
        cache: Optional[Any] = None,
    ) -> mx.array: ...

class Model(nn.Module):
    model_type: str
    model: DeepseekV3Model
    lm_head: nn.Linear

    def __init__(self, config: ModelArgs) -> None: ...
    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
    ) -> mx.array: ...
    def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
    @property
    def layers(self) -> list[DeepseekV3DecoderLayer]: ...
@@ -57,6 +57,11 @@ class SwiGLU(nn.Module):
    def __call__(self, x, gate): ...

class SwitchGLU(nn.Module):
    gate_proj: SwitchLinear
    up_proj: SwitchLinear
    down_proj: SwitchLinear
    activation: SwiGLU

    def __init__(
        self,
        input_dims: int,

@@ -4,6 +4,7 @@ This type stub file was generated by pyright.

from functools import partial
from pathlib import Path
from typing import Any

from transformers import PreTrainedTokenizerFast

@@ -103,37 +104,55 @@ class TokenizerWrapper:
    Accessing any attribute other than the ``detokenizer`` is forwarded to the
    huggingface tokenizer.
    """
    def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
    def add_eos_token(self, token: str):  # -> None:
        ...
    @property
    def has_thinking(self):  # -> bool:
        ...
    @property
    def think_start(self):  # -> str | None:
        ...
    @property
    def think_end(self):  # -> str | None:
        ...
    @property
    def has_tool_calling(self):  # -> bool:
        ...
    @property
    def tool_call_start(self):  # -> str | None:
        ...
    @property
    def tool_call_end(self):  # -> str | None:
        ...
    @property
    def detokenizer(self):  # -> NaiveStreamingDetokenizer:
        """
        Get a stateful streaming detokenizer.
        """

    def __getattr__(self, attr):  # -> set[Any] | Any:
        ...
    def __setattr__(self, attr, value):  # -> None:
        ...
    _tokenizer: PreTrainedTokenizerFast
    eos_token_id: int | None
    eos_token: str | None
    bos_token_id: int | None
    bos_token: str | None
    vocab_size: int
    all_special_tokens: list[str]

    def __init__(
        self,
        tokenizer: Any,
        detokenizer_class: Any = ...,
        eos_token_ids: list[int] | None = ...,
        chat_template: Any = ...,
        tool_parser: Any = ...,
        tool_call_start: str | None = ...,
        tool_call_end: str | None = ...,
    ) -> None: ...
    def encode(self, text: str, **kwargs: Any) -> list[int]: ...
    def decode(self, token_ids: list[int], **kwargs: Any) -> str: ...
    def apply_chat_template(
        self,
        messages: list[dict[str, Any]],
        tokenize: bool = False,
        add_generation_prompt: bool = False,
        tools: Any = None,
        **kwargs: Any,
    ) -> str: ...
    def get_vocab(self) -> dict[str, int]: ...
    def add_eos_token(self, token: str) -> None: ...
    @property
    def has_thinking(self) -> bool: ...
    @property
    def think_start(self) -> str | None: ...
    @property
    def think_end(self) -> str | None: ...
    @property
    def has_tool_calling(self) -> bool: ...
    @property
    def tool_call_start(self) -> str | None: ...
    @property
    def tool_call_end(self) -> str | None: ...
    @property
    def detokenizer(self) -> NaiveStreamingDetokenizer:
        """Get a stateful streaming detokenizer."""

    def __getattr__(self, attr: str) -> Any: ...
    def __setattr__(self, attr: str, value: Any) -> None: ...

class NewlineTokenizer(PreTrainedTokenizerFast):
    """A tokenizer that replaces newlines with <n> and <n> with new line."""
@@ -146,18 +165,11 @@ class NewlineTokenizer(PreTrainedTokenizerFast):
    def batch_decode(self, *args, **kwargs):  # -> list[str]:
        ...

def load_tokenizer(
def load(
    model_path: Path,
    tokenizer_config_extra=...,
    return_tokenizer=...,
    eos_token_ids=...,
) -> (
    TokenizerWrapper
    | type[SPMStreamingDetokenizer]
    | partial[SPMStreamingDetokenizer]
    | type[BPEStreamingDetokenizer]
    | type[NaiveStreamingDetokenizer]
):
    tokenizer_config_extra: dict[str, Any] | None = None,
    eos_token_ids: list[int] | int | None = None,
) -> TokenizerWrapper:
    """Load a huggingface tokenizer and try to infer the type of streaming
    detokenizer to use.

@@ -165,4 +177,7 @@ def load_tokenizer(
    a Hugging Face repo ID.
    """

def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...
# Alias for backward compatibility
load_tokenizer = load

def no_bos_or_eos(sequence: list[int], bos: int, eos: int) -> list[int]: ...
3 .prettierrc Normal file
@@ -0,0 +1,3 @@
{
	"useTabs": true
}
6 .swift-format Normal file
@@ -0,0 +1,6 @@
{
    "version": 1,
    "indentation": {
        "spaces": 4
    }
}
96 AGENTS.md Normal file
@@ -0,0 +1,96 @@
# AGENTS.md

This file provides guidance to AI coding agents when working with code in this repository.

## Project Overview

exo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and libp2p for peer-to-peer networking.

## Build & Run Commands

```bash
# Build the dashboard (required before running exo)
cd dashboard && npm install && npm run build && cd ..

# Run exo (starts both master and worker with API at http://localhost:52415)
uv run exo

# Run with verbose logging
uv run exo -v  # or -vv for more verbose

# Run tests (excludes slow tests by default)
uv run pytest

# Run all tests including slow tests
uv run pytest -m ""

# Run a specific test file
uv run pytest src/exo/shared/tests/test_election.py

# Run a specific test function
uv run pytest src/exo/shared/tests/test_election.py::test_function_name

# Type checking (strict mode)
uv run basedpyright

# Linting
uv run ruff check

# Format code (using nix)
nix fmt
```

## Architecture

### Node Composition
A single exo `Node` (src/exo/main.py) runs multiple components:
- **Router**: libp2p-based pub/sub messaging via Rust bindings (exo_pyo3_bindings)
- **Worker**: Handles inference tasks, downloads models, manages runner processes
- **Master**: Coordinates cluster state, places model instances across nodes
- **Election**: Bully algorithm for master election
- **API**: FastAPI server for OpenAI-compatible chat completions

### Message Flow
Components communicate via typed pub/sub topics (src/exo/routing/topics.py):
- `GLOBAL_EVENTS`: Master broadcasts indexed events to all workers
- `LOCAL_EVENTS`: Workers send events to master for indexing
- `COMMANDS`: Workers/API send commands to master
- `ELECTION_MESSAGES`: Election protocol messages
- `CONNECTION_MESSAGES`: libp2p connection updates

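To make that flow concrete, here is a minimal, self-contained sketch of topic-based pub/sub in the spirit of the list above; the `InMemoryRouter`, `Command`, and `IndexedEvent` names are illustrative assumptions, not the actual exo types or router API.

```python
# Illustrative sketch only: an in-memory stand-in for the libp2p router and the
# typed topics described above. The real topics live in src/exo/routing/topics.py.
from dataclasses import dataclass
from typing import Any, Callable

@dataclass(frozen=True)
class Command:
    kind: str
    payload: dict[str, Any]

@dataclass(frozen=True)
class IndexedEvent:
    index: int
    payload: dict[str, Any]

class InMemoryRouter:
    def __init__(self) -> None:
        self._subscribers: dict[str, list[Callable[[Any], None]]] = {}

    def subscribe(self, topic: str, handler: Callable[[Any], None]) -> None:
        self._subscribers.setdefault(topic, []).append(handler)

    def publish(self, topic: str, message: Any) -> None:
        for handler in self._subscribers.get(topic, []):
            handler(message)

router = InMemoryRouter()
# A worker (or the API) sends a command on COMMANDS; the master would index the
# resulting event and broadcast it on GLOBAL_EVENTS for every worker to apply.
router.subscribe("COMMANDS", lambda command: print("master received:", command))
router.subscribe("GLOBAL_EVENTS", lambda event: print("worker applies:", event))
router.publish("COMMANDS", Command(kind="launch_instance", payload={"model_id": "llama-3"}))
router.publish("GLOBAL_EVENTS", IndexedEvent(index=0, payload={"instance_launched": "llama-3"}))
```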
### Event Sourcing
The system uses event sourcing for state management:
- `State` (src/exo/shared/types/state.py): Immutable state object
- `apply()` (src/exo/shared/apply.py): Pure function that applies events to state
- Master indexes events and broadcasts; workers apply indexed events

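A minimal sketch of that apply loop (the `State` and event shapes here are assumptions for illustration; the real models are the Pydantic types listed in the next section):

```python
# Illustrative sketch only: a pure apply(state, event) -> state reducer.
# The real State lives in src/exo/shared/types/state.py, apply() in src/exo/shared/apply.py.
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class State:
    instances: tuple[str, ...] = ()

@dataclass(frozen=True)
class InstanceCreated:
    instance_id: str

def apply(state: State, event: InstanceCreated) -> State:
    # Pure function: returns a new immutable State rather than mutating in place.
    return replace(state, instances=state.instances + (event.instance_id,))

# The master indexes events and broadcasts them; every worker replays the same
# indexed sequence through apply() and converges on the same State.
state = State()
for event in (InstanceCreated("instance-a"), InstanceCreated("instance-b")):
    state = apply(state, event)
print(state.instances)  # ('instance-a', 'instance-b')
```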
### Key Type Hierarchy
- `src/exo/shared/types/`: Pydantic models for all shared types
  - `events.py`: Event types (discriminated union)
  - `commands.py`: Command types
  - `tasks.py`: Task types for worker execution
  - `state.py`: Cluster state model

### Rust Components
Rust code in `rust/` provides:
- `networking`: libp2p networking (gossipsub, peer discovery)
- `exo_pyo3_bindings`: PyO3 bindings exposing Rust to Python
- `system_custodian`: System-level operations

### Dashboard
Svelte 5 + TypeScript frontend in `dashboard/`. Build output goes to `dashboard/build/` and is served by the API.

## Code Style Requirements

From .cursorrules:
- Strict, exhaustive typing - never bypass the type-checker
- Use `Literal[...]` for enum-like sets, `typing.NewType` for primitives
- Pydantic models with `frozen=True` and `strict=True`
- Pure functions with injectable effect handlers for side-effects
- Descriptive names - no abbreviations or 3-letter acronyms
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable

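A short sketch of what these rules look like in practice (illustrative model, not code from the repo; the `NodeId` and `RunnerStatus` names are made up for the example):

```python
# Illustrative only: strict, frozen Pydantic model, Literal for an enum-like
# set, NewType for a primitive, and @final to forbid subclassing.
from typing import Literal, NewType, final
from pydantic import BaseModel, ConfigDict

NodeId = NewType("NodeId", str)

@final
class RunnerStatus(BaseModel):
    model_config = ConfigDict(frozen=True, strict=True)

    node_id: NodeId
    phase: Literal["downloading", "loading", "ready", "failed"]

status = RunnerStatus(node_id=NodeId("node-1"), phase="ready")
print(status.phase)
```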
## Testing

Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
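For instance, with `asyncio_mode = "auto"` an async test function needs no extra decorator (an illustrative test, not one from the repo):

```python
# Illustrative only: pytest-asyncio auto mode collects bare async test functions.
import asyncio

async def test_sleep_does_not_block_event_loop():
    before = asyncio.get_running_loop().time()
    await asyncio.sleep(0.01)
    assert asyncio.get_running_loop().time() - before >= 0.01
```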
41 MISSED_THINGS.md Normal file
@@ -0,0 +1,41 @@
# Missed things
[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py
[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank=n-1. It should be rank!=0 then rank=0 (this matches the auto_parallel implementation. NOTE: we use a different convention to mlx-lm; our terminal rank is rank=n-1 whereas mlx-lm's is rank=0, hence I can see why this was changed wrongly).
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] No mx_barrier in generate.py mlx_generate at the end.
[] Cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] Sharding changed: "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] Same as above: "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if the node is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections, I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up; probably create an issue... the blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does, but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only excepts ValueError. This was silently failing, leading to un-debuggable errors (we had a KeyError that was happening). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and checks whether this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[] logger.warning("You have likely selected ibv for a single node instance; falling back to MlxRing") was changed to debug. That will spam this warning since it happens every time we query instance previews.
[X] In placement_utils.py, get_mlx_jaccl_coordinators, we no longer prioritise the Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).



[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[X] We put cache limit back in utils_mlx.py.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and checks whether this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[X] In placement_utils.py, get_mlx_jaccl_coordinators, we no longer prioritise the Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).

25 README.md
@@ -8,7 +8,7 @@
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).

<p align="center">
<a href="https://discord.gg/72NsF6ux" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
<a href="https://x.com/exolabs" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/twitter/follow/exolabs?style=social" alt="X"></a>
<a href="https://www.apache.org/licenses/LICENSE-2.0.html" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/License-Apache2.0-blue.svg" alt="License: Apache-2.0"></a>
</p>
@@ -166,6 +166,24 @@ Download the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-

The app will ask for permission to modify system settings and install a new Network profile. Improvements to this are being worked on.

#### Uninstalling the macOS App

The recommended way to uninstall is through the app itself: click the menu bar icon → Advanced → Uninstall. This cleanly removes all system components.

If you've already deleted the app, you can run the standalone uninstaller script:

```bash
sudo ./app/EXO/uninstall-exo.sh
```

This removes:
- Network setup LaunchDaemon
- Network configuration script
- Log files
- The "exo" network location

**Note:** You'll need to manually remove EXO from Login Items in System Settings → General → Login Items.

---

### Enabling RDMA on macOS
@@ -287,7 +305,10 @@ curl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID
- List all models: `curl http://localhost:52415/models`
- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`

For further details, see API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
For further details, see:

- API basic documentation in [docs/api.md](docs/api.md).
- API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).

---

@@ -12,20 +12,25 @@ struct ContentView: View {
|
||||
@EnvironmentObject private var controller: ExoProcessController
|
||||
@EnvironmentObject private var stateService: ClusterStateService
|
||||
@EnvironmentObject private var networkStatusService: NetworkStatusService
|
||||
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@EnvironmentObject private var updater: SparkleUpdater
|
||||
@State private var focusedNode: NodeViewModel?
|
||||
@State private var deletingInstanceIDs: Set<String> = []
|
||||
@State private var showAllNodes = false
|
||||
@State private var showAllInstances = false
|
||||
@State private var showAdvanced = false
|
||||
@State private var showDebugInfo = false
|
||||
@State private var bugReportInFlight = false
|
||||
@State private var bugReportMessage: String?
|
||||
@State private var showAdvancedOptions = false
|
||||
@State private var uninstallInProgress = false
|
||||
@State private var pendingNamespace: String = ""
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 12) {
|
||||
statusSection
|
||||
if shouldShowLocalNetworkWarning {
|
||||
localNetworkWarningBanner
|
||||
}
|
||||
if shouldShowClusterDetails {
|
||||
Divider()
|
||||
overviewSection
|
||||
@@ -40,6 +45,7 @@ struct ContentView: View {
|
||||
}
|
||||
.animation(.easeInOut(duration: 0.3), value: shouldShowClusterDetails)
|
||||
.animation(.easeInOut(duration: 0.3), value: shouldShowInstances)
|
||||
.animation(.easeInOut(duration: 0.3), value: shouldShowLocalNetworkWarning)
|
||||
.padding()
|
||||
.frame(width: 340)
|
||||
.onAppear {
|
||||
@@ -49,9 +55,62 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
private var shouldShowLocalNetworkWarning: Bool {
|
||||
if case .notWorking = localNetworkChecker.status {
|
||||
return controller.status != .stopped
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
private var localNetworkWarningBanner: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
HStack(spacing: 6) {
|
||||
Image(systemName: "exclamationmark.triangle.fill")
|
||||
.foregroundColor(.orange)
|
||||
Text("Local Network Access Issue")
|
||||
.font(.caption)
|
||||
.fontWeight(.semibold)
|
||||
}
|
||||
Text(
|
||||
"Device discovery won't work. To fix:\n1. Quit EXO\n2. Open System Settings → Privacy & Security → Local Network\n3. Toggle EXO off, then back on\n4. Relaunch EXO"
|
||||
)
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
Button {
|
||||
openLocalNetworkSettings()
|
||||
} label: {
|
||||
Text("Open Settings")
|
||||
.font(.caption2)
|
||||
}
|
||||
.buttonStyle(.bordered)
|
||||
.controlSize(.small)
|
||||
}
|
||||
.padding(8)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.fill(Color.orange.opacity(0.1))
|
||||
)
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.stroke(Color.orange.opacity(0.3), lineWidth: 1)
|
||||
)
|
||||
}
|
||||
|
||||
private func openLocalNetworkSettings() {
|
||||
// Open Privacy & Security settings - Local Network section
|
||||
if let url = URL(
|
||||
string: "x-apple.systempreferences:com.apple.preference.security?Privacy_LocalNetwork")
|
||||
{
|
||||
NSWorkspace.shared.open(url)
|
||||
}
|
||||
}
|
||||
|
||||
private var topologySection: some View {
|
||||
Group {
|
||||
if let topology = stateService.latestSnapshot?.topologyViewModel(localNodeId: stateService.localNodeId), !topology.nodes.isEmpty {
|
||||
if let topology = stateService.latestSnapshot?.topologyViewModel(
|
||||
localNodeId: stateService.localNodeId), !topology.nodes.isEmpty
|
||||
{
|
||||
TopologyMiniView(topology: topology)
|
||||
}
|
||||
}
|
||||
@@ -85,8 +144,10 @@ struct ContentView: View {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
HStack {
|
||||
VStack(alignment: .leading) {
|
||||
Text("\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB")
|
||||
.font(.headline)
|
||||
Text(
|
||||
"\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB"
|
||||
)
|
||||
.font(.headline)
|
||||
Text("Memory")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
@@ -195,13 +256,7 @@ struct ContentView: View {
|
||||
Divider()
|
||||
.padding(.vertical, 4)
|
||||
}
|
||||
controlButton(title: "Check for Updates") {
|
||||
updater.checkForUpdates()
|
||||
}
|
||||
.padding(.bottom, 8)
|
||||
advancedOptionsSection
|
||||
.padding(.bottom, 8)
|
||||
debugSection
|
||||
advancedSection
|
||||
.padding(.bottom, 8)
|
||||
controlButton(title: "Quit", tint: .secondary) {
|
||||
controller.stop()
|
||||
@@ -210,7 +265,57 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void) -> some View {
|
||||
private var advancedSection: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
HStack {
|
||||
Text("Advanced")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
Spacer()
|
||||
collapseButton(isExpanded: $showAdvanced)
|
||||
}
|
||||
.animation(nil, value: showAdvanced)
|
||||
if showAdvanced {
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text("Cluster Namespace")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
HStack {
|
||||
TextField("optional", text: $pendingNamespace)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
.font(.caption2)
|
||||
.onAppear {
|
||||
pendingNamespace = controller.customNamespace
|
||||
}
|
||||
Button("Save & Restart") {
|
||||
controller.customNamespace = pendingNamespace
|
||||
if controller.status == .running || controller.status == .starting {
|
||||
controller.restart()
|
||||
}
|
||||
}
|
||||
.font(.caption2)
|
||||
.disabled(pendingNamespace == controller.customNamespace)
|
||||
}
|
||||
}
|
||||
HoverButton(title: "Check for Updates", small: true) {
|
||||
updater.checkForUpdates()
|
||||
}
|
||||
debugSection
|
||||
HoverButton(title: "Uninstall", tint: .red, small: true) {
|
||||
showUninstallConfirmationAlert()
|
||||
}
|
||||
.disabled(uninstallInProgress)
|
||||
}
|
||||
.transition(.opacity)
|
||||
}
|
||||
}
|
||||
.animation(.easeInOut(duration: 0.25), value: showAdvanced)
|
||||
}
|
||||
|
||||
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void)
|
||||
-> some View
|
||||
{
|
||||
HoverButton(title: title, tint: tint, trailingSystemImage: nil, action: action)
|
||||
}
|
||||
|
||||
@@ -241,9 +346,12 @@ struct ContentView: View {
|
||||
Button {
|
||||
isExpanded.wrappedValue.toggle()
|
||||
} label: {
|
||||
Label(isExpanded.wrappedValue ? "Hide" : "Show All", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
|
||||
.labelStyle(.titleAndIcon)
|
||||
.contentTransition(.symbolEffect(.replace))
|
||||
Label(
|
||||
isExpanded.wrappedValue ? "Hide" : "Show All",
|
||||
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
|
||||
)
|
||||
.labelStyle(.titleAndIcon)
|
||||
.contentTransition(.symbolEffect(.replace))
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.font(.caption2)
|
||||
@@ -331,57 +439,16 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
private var advancedOptionsSection: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
HStack {
|
||||
Text("Advanced Options")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
Spacer()
|
||||
collapseButton(isExpanded: $showAdvancedOptions)
|
||||
}
|
||||
.animation(nil, value: showAdvancedOptions)
|
||||
if showAdvancedOptions {
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text("Cluster Namespace")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
HStack {
|
||||
TextField("optional", text: $pendingNamespace)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
.font(.caption2)
|
||||
.onAppear {
|
||||
pendingNamespace = controller.customNamespace
|
||||
}
|
||||
Button("Save & Restart") {
|
||||
controller.customNamespace = pendingNamespace
|
||||
if controller.status == .running || controller.status == .starting {
|
||||
controller.restart()
|
||||
}
|
||||
}
|
||||
.font(.caption2)
|
||||
.disabled(pendingNamespace == controller.customNamespace)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
.transition(.opacity)
|
||||
}
|
||||
}
|
||||
.animation(.easeInOut(duration: 0.25), value: showAdvancedOptions)
|
||||
}
|
||||
|
||||
private var debugSection: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
HStack {
|
||||
Text("Debug Info")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
Spacer()
|
||||
collapseButton(isExpanded: $showDebugInfo)
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
HoverButton(
|
||||
title: "Debug Info",
|
||||
tint: .primary,
|
||||
trailingSystemImage: showDebugInfo ? "chevron.up" : "chevron.down",
|
||||
small: true
|
||||
) {
|
||||
showDebugInfo.toggle()
|
||||
}
|
||||
.animation(nil, value: showDebugInfo)
|
||||
if showDebugInfo {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text("Version: \(buildTag)")
|
||||
@@ -394,15 +461,63 @@ struct ContentView: View {
|
||||
.font(.caption2)
|
||||
.foregroundColor(thunderboltStatusColor)
|
||||
interfaceIpList
|
||||
rdmaStatusView
|
||||
sendBugReportButton
|
||||
.padding(.top, 6)
|
||||
}
|
||||
.padding(.leading, 8)
|
||||
.transition(.opacity)
|
||||
}
|
||||
}
|
||||
.animation(.easeInOut(duration: 0.25), value: showDebugInfo)
|
||||
}
|
||||
|
||||
private var rdmaStatusView: some View {
|
||||
let rdma = networkStatusService.status.rdmaStatus
|
||||
return VStack(alignment: .leading, spacing: 1) {
|
||||
Text("RDMA: \(rdmaStatusText(rdma))")
|
||||
.font(.caption2)
|
||||
.foregroundColor(rdmaStatusColor(rdma))
|
||||
if !rdma.devices.isEmpty {
|
||||
Text(" Devices: \(rdma.devices.joined(separator: ", "))")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
if !rdma.activePorts.isEmpty {
|
||||
Text(" Active Ports:")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
ForEach(rdma.activePorts, id: \.device) { port in
|
||||
Text(" \(port.device) port \(port.port): \(port.state)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.green)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func rdmaStatusText(_ rdma: RDMAStatus) -> String {
|
||||
switch rdma.rdmaCtlEnabled {
|
||||
case .some(true):
|
||||
return "Enabled"
|
||||
case .some(false):
|
||||
return "Disabled"
|
||||
case nil:
|
||||
return rdma.devices.isEmpty ? "Not Available" : "Available"
|
||||
}
|
||||
}
|
||||
|
||||
private func rdmaStatusColor(_ rdma: RDMAStatus) -> Color {
|
||||
switch rdma.rdmaCtlEnabled {
|
||||
case .some(true):
|
||||
return .green
|
||||
case .some(false):
|
||||
return .orange
|
||||
case nil:
|
||||
return rdma.devices.isEmpty ? .secondary : .green
|
||||
}
|
||||
}
|
||||
|
||||
private var sendBugReportButton: some View {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Button {
|
||||
@@ -492,6 +607,88 @@ struct ContentView: View {
|
||||
bugReportInFlight = false
|
||||
}
|
||||
|
||||
private func showUninstallConfirmationAlert() {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Uninstall EXO"
|
||||
alert.informativeText = """
|
||||
This will remove EXO and all its system components:
|
||||
|
||||
• Network configuration daemon
|
||||
• Launch at login registration
|
||||
• EXO network location
|
||||
|
||||
The app will be moved to Trash.
|
||||
"""
|
||||
alert.alertStyle = .warning
|
||||
alert.addButton(withTitle: "Uninstall")
|
||||
alert.addButton(withTitle: "Cancel")
|
||||
|
||||
// Style the Uninstall button as destructive
|
||||
if let uninstallButton = alert.buttons.first {
|
||||
uninstallButton.hasDestructiveAction = true
|
||||
}
|
||||
|
||||
let response = alert.runModal()
|
||||
if response == .alertFirstButtonReturn {
|
||||
performUninstall()
|
||||
}
|
||||
}
|
||||
|
||||
private func performUninstall() {
|
||||
uninstallInProgress = true
|
||||
|
||||
// Stop EXO process first
|
||||
controller.cancelPendingLaunch()
|
||||
controller.stop()
|
||||
stateService.stopPolling()
|
||||
|
||||
// Run the privileged uninstall on a background thread
|
||||
// Using .utility QoS to avoid priority inversion with NSAppleScript's subprocess
|
||||
DispatchQueue.global(qos: .utility).async {
|
||||
do {
|
||||
// Remove network setup daemon and components (requires admin privileges)
|
||||
try NetworkSetupHelper.uninstall()
|
||||
|
||||
DispatchQueue.main.async {
|
||||
// Unregister from launch at login
|
||||
LaunchAtLoginHelper.disable()
|
||||
|
||||
// Move app to trash
|
||||
self.moveAppToTrash()
|
||||
|
||||
// Quit the app
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) {
|
||||
NSApplication.shared.terminate(nil)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
DispatchQueue.main.async {
|
||||
self.showErrorAlert(message: error.localizedDescription)
|
||||
self.uninstallInProgress = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func showErrorAlert(message: String) {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Uninstall Failed"
|
||||
alert.informativeText = message
|
||||
alert.alertStyle = .critical
|
||||
alert.addButton(withTitle: "OK")
|
||||
alert.runModal()
|
||||
}
|
||||
|
||||
private func moveAppToTrash() {
|
||||
guard let appURL = Bundle.main.bundleURL as URL? else { return }
|
||||
do {
|
||||
try FileManager.default.trashItem(at: appURL, resultingItemURL: nil)
|
||||
} catch {
|
||||
// If we can't trash the app, that's OK - user can do it manually
|
||||
// The important system components have already been cleaned up
|
||||
}
|
||||
}
|
||||
|
||||
private var buildTag: String {
|
||||
Bundle.main.infoDictionary?["EXOBuildTag"] as? String ?? "unknown"
|
||||
}
|
||||
@@ -505,14 +702,27 @@ private struct HoverButton: View {
|
||||
let title: String
|
||||
let tint: Color
|
||||
let trailingSystemImage: String?
|
||||
let small: Bool
|
||||
let action: () -> Void
|
||||
|
||||
init(
|
||||
title: String, tint: Color = .primary, trailingSystemImage: String? = nil,
|
||||
small: Bool = false, action: @escaping () -> Void
|
||||
) {
|
||||
self.title = title
|
||||
self.tint = tint
|
||||
self.trailingSystemImage = trailingSystemImage
|
||||
self.small = small
|
||||
self.action = action
|
||||
}
|
||||
|
||||
@State private var isHovering = false
|
||||
|
||||
var body: some View {
|
||||
Button(action: action) {
|
||||
HStack {
|
||||
Text(title)
|
||||
.font(small ? .caption : nil)
|
||||
Spacer()
|
||||
if let systemName = trailingSystemImage {
|
||||
Image(systemName: systemName)
|
||||
@@ -520,8 +730,8 @@ private struct HoverButton: View {
|
||||
}
|
||||
}
|
||||
.frame(maxWidth: .infinity, alignment: .leading)
|
||||
.padding(.vertical, 6)
|
||||
.padding(.horizontal, 8)
|
||||
.padding(.vertical, small ? 4 : 6)
|
||||
.padding(.horizontal, small ? 6 : 8)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 6)
|
||||
.fill(
|
||||
@@ -536,4 +746,3 @@ private struct HoverButton: View {
|
||||
.onHover { isHovering = $0 }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,9 +8,9 @@
|
||||
import AppKit
|
||||
import CoreImage
|
||||
import CoreImage.CIFilterBuiltins
|
||||
import ServiceManagement
|
||||
import Sparkle
|
||||
import SwiftUI
|
||||
import ServiceManagement
|
||||
import UserNotifications
|
||||
import os.log
|
||||
|
||||
@@ -19,6 +19,7 @@ struct EXOApp: App {
|
||||
@StateObject private var controller: ExoProcessController
|
||||
@StateObject private var stateService: ClusterStateService
|
||||
@StateObject private var networkStatusService: NetworkStatusService
|
||||
@StateObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@StateObject private var updater: SparkleUpdater
|
||||
private let terminationObserver: TerminationObserver
|
||||
private let ciContext = CIContext(options: nil)
|
||||
@@ -37,9 +38,13 @@ struct EXOApp: App {
|
||||
_stateService = StateObject(wrappedValue: service)
|
||||
let networkStatus = NetworkStatusService()
|
||||
_networkStatusService = StateObject(wrappedValue: networkStatus)
|
||||
let localNetwork = LocalNetworkChecker()
|
||||
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
|
||||
_updater = StateObject(wrappedValue: updater)
|
||||
enableLaunchAtLoginIfNeeded()
|
||||
NetworkSetupHelper.ensureLaunchDaemonInstalled()
|
||||
// Check local network access BEFORE launching exo
|
||||
localNetwork.check()
|
||||
controller.scheduleLaunch(after: 15)
|
||||
service.startPolling()
|
||||
networkStatus.startPolling()
|
||||
@@ -51,6 +56,7 @@ struct EXOApp: App {
|
||||
.environmentObject(controller)
|
||||
.environmentObject(stateService)
|
||||
.environmentObject(networkStatusService)
|
||||
.environmentObject(localNetworkChecker)
|
||||
.environmentObject(updater)
|
||||
} label: {
|
||||
menuBarIcon
|
||||
@@ -107,7 +113,7 @@ struct EXOApp: App {
|
||||
filter.contrast = 0.9
|
||||
|
||||
guard let output = filter.outputImage,
|
||||
let rendered = ciContext.createCGImage(output, from: output.extent)
|
||||
let rendered = ciContext.createCGImage(output, from: output.extent)
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
@@ -120,7 +126,26 @@ struct EXOApp: App {
|
||||
do {
|
||||
try SMAppService.mainApp.register()
|
||||
} catch {
|
||||
Logger().error("Failed to register EXO for launch at login: \(error.localizedDescription)")
|
||||
Logger().error(
|
||||
"Failed to register EXO for launch at login: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for managing EXO's launch-at-login registration
|
||||
enum LaunchAtLoginHelper {
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LaunchAtLogin")
|
||||
|
||||
/// Unregisters EXO from launching at login
|
||||
static func disable() {
|
||||
guard SMAppService.mainApp.status == .enabled else { return }
|
||||
do {
|
||||
try SMAppService.mainApp.unregister()
|
||||
logger.info("Unregistered EXO from launch at login")
|
||||
} catch {
|
||||
logger.error(
|
||||
"Failed to unregister EXO from launch at login: \(error.localizedDescription, privacy: .public)"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -145,7 +170,7 @@ final class SparkleUpdater: NSObject, ObservableObject {
|
||||
center.requestAuthorization(options: [.alert, .sound]) { _, _ in }
|
||||
controller.updater.automaticallyChecksForUpdates = true
|
||||
controller.updater.automaticallyDownloadsUpdates = false
|
||||
controller.updater.updateCheckInterval = 900 // 15 minutes
|
||||
controller.updater.updateCheckInterval = 900 // 15 minutes
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 5) { [weak controller] in
|
||||
controller?.updater.checkForUpdatesInBackground()
|
||||
}
|
||||
@@ -212,7 +237,8 @@ private final class ExoNotificationDelegate: NSObject, UNUserNotificationCenterD
|
||||
func userNotificationCenter(
|
||||
_ center: UNUserNotificationCenter,
|
||||
willPresent notification: UNNotification,
|
||||
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) -> Void
|
||||
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) ->
|
||||
Void
|
||||
) {
|
||||
completionHandler([.banner, .list, .sound])
|
||||
}
|
||||
|
||||
@@ -31,7 +31,8 @@ final class ExoProcessController: ObservableObject {
|
||||
@Published private(set) var launchCountdownSeconds: Int?
|
||||
@Published var customNamespace: String = {
|
||||
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
|
||||
}() {
|
||||
}()
|
||||
{
|
||||
didSet {
|
||||
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
|
||||
}
|
||||
@@ -221,7 +222,9 @@ final class ExoProcessController: ObservableObject {
|
||||
if let tag = Bundle.main.infoDictionary?["EXOBuildTag"] as? String, !tag.isEmpty {
|
||||
return tag
|
||||
}
|
||||
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String, !short.isEmpty {
|
||||
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String,
|
||||
!short.isEmpty
|
||||
{
|
||||
return short
|
||||
}
|
||||
return "dev"
|
||||
|
||||
@@ -8,5 +8,15 @@
|
||||
<string>$(EXO_BUILD_TAG)</string>
|
||||
<key>EXOBuildCommit</key>
|
||||
<string>$(EXO_BUILD_COMMIT)</string>
|
||||
<key>EXOBugReportPresignedUrlEndpoint</key>
|
||||
<string>$(EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT)</string>
|
||||
<key>NSLocalNetworkUsageDescription</key>
|
||||
<string>EXO needs local network access to discover and connect to other devices in your cluster for distributed AI inference.</string>
|
||||
<key>NSBonjourServices</key>
|
||||
<array>
|
||||
<string>_p2p._tcp</string>
|
||||
<string>_p2p._udp</string>
|
||||
<string>_libp2p._udp</string>
|
||||
</array>
|
||||
</dict>
|
||||
</plist>
|
||||
|
||||
@@ -16,10 +16,13 @@ struct ClusterState: Decodable {
|
||||
self.instances = rawInstances.mapValues(\.instance)
|
||||
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
|
||||
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
|
||||
let rawTasks = try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
|
||||
let rawTasks =
|
||||
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
|
||||
self.tasks = rawTasks.compactMapValues(\.task)
|
||||
self.topology = try container.decodeIfPresent(Topology.self, forKey: .topology)
|
||||
let rawDownloads = try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads) ?? [:]
|
||||
let rawDownloads =
|
||||
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
|
||||
?? [:]
|
||||
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
|
||||
}
|
||||
|
||||
@@ -41,7 +44,8 @@ private struct TaggedInstance: Decodable {
|
||||
let payloads = try container.decode([String: ClusterInstancePayload].self)
|
||||
guard let entry = payloads.first else {
|
||||
throw DecodingError.dataCorrupted(
|
||||
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
|
||||
DecodingError.Context(
|
||||
codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
|
||||
)
|
||||
}
|
||||
self.instance = ClusterInstance(
|
||||
@@ -77,7 +81,8 @@ struct RunnerStatusSummary: Decodable {
|
||||
let payloads = try container.decode([String: RunnerStatusDetail].self)
|
||||
guard let entry = payloads.first else {
|
||||
throw DecodingError.dataCorrupted(
|
||||
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
|
||||
DecodingError.Context(
|
||||
codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
|
||||
)
|
||||
}
|
||||
self.status = entry.key
|
||||
@@ -257,7 +262,9 @@ struct ChatCompletionTaskParameters: Decodable, Equatable {
|
||||
|
||||
func promptPreview() -> String? {
|
||||
guard let messages else { return nil }
|
||||
if let userMessage = messages.last(where: { $0.role?.lowercased() == "user" && ($0.content?.isEmpty == false) }) {
|
||||
if let userMessage = messages.last(where: {
|
||||
$0.role?.lowercased() == "user" && ($0.content?.isEmpty == false)
|
||||
}) {
|
||||
return userMessage.content
|
||||
}
|
||||
return messages.last?.content
|
||||
@@ -365,5 +372,3 @@ extension ClusterState {
|
||||
|
||||
func availableModels() -> [ModelOption] { [] }
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import CryptoKit
|
||||
import Foundation
|
||||
|
||||
struct BugReportOutcome: Equatable {
|
||||
@@ -7,17 +6,17 @@ struct BugReportOutcome: Equatable {
|
||||
}
|
||||
|
||||
enum BugReportError: LocalizedError {
|
||||
case missingCredentials
|
||||
case invalidEndpoint
|
||||
case presignedUrlFailed(String)
|
||||
case uploadFailed(String)
|
||||
case collectFailed(String)
|
||||
|
||||
var errorDescription: String? {
|
||||
switch self {
|
||||
case .missingCredentials:
|
||||
return "Bug report upload credentials are not set."
|
||||
case .invalidEndpoint:
|
||||
return "Bug report endpoint is invalid."
|
||||
case .presignedUrlFailed(let message):
|
||||
return "Failed to get presigned URLs: \(message)"
|
||||
case .uploadFailed(let message):
|
||||
return "Bug report upload failed: \(message)"
|
||||
case .collectFailed(let message):
|
||||
@@ -27,11 +26,13 @@ enum BugReportError: LocalizedError {
|
||||
}
|
||||
|
||||
struct BugReportService {
|
||||
struct AWSConfig {
|
||||
let accessKey: String
|
||||
let secretKey: String
|
||||
let region: String
|
||||
let bucket: String
|
||||
private struct PresignedUrlsRequest: Codable {
|
||||
let keys: [String]
|
||||
}
|
||||
|
||||
private struct PresignedUrlsResponse: Codable {
|
||||
let urls: [String: String]
|
||||
let expiresIn: Int?
|
||||
}
|
||||
|
||||
func sendReport(
|
||||
@@ -39,9 +40,9 @@ struct BugReportService {
|
||||
now: Date = Date(),
|
||||
isManual: Bool = false
|
||||
) async throws -> BugReportOutcome {
|
||||
let credentials = try loadCredentials()
|
||||
let timestamp = ISO8601DateFormatter().string(from: now)
|
||||
let prefix = "reports/\(timestamp)/"
|
||||
let timestamp = Self.runTimestampString(now)
|
||||
let dayPrefix = Self.dayPrefixString(now)
|
||||
let prefix = "reports/\(dayPrefix)/\(timestamp)/"
|
||||
|
||||
let logData = readLog()
|
||||
let ifconfigText = try await captureIfconfig()
|
||||
@@ -66,28 +67,82 @@ struct BugReportService {
|
||||
("\(prefix)exo.log", logData),
|
||||
("\(prefix)state.json", stateData),
|
||||
("\(prefix)events.json", eventsData),
|
||||
("\(prefix)report.json", reportJSON)
|
||||
("\(prefix)report.json", reportJSON),
|
||||
]
|
||||
|
||||
let uploader = try S3Uploader(config: credentials)
|
||||
for item in uploads {
|
||||
guard let data = item.data else { continue }
|
||||
try await uploader.upload(
|
||||
objectPath: item.path,
|
||||
body: data
|
||||
)
|
||||
let uploadItems: [(key: String, body: Data)] = uploads.compactMap { item in
|
||||
guard let body = item.data else { return nil }
|
||||
return (key: item.path, body: body)
|
||||
}
|
||||
|
||||
return BugReportOutcome(success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
|
||||
guard !uploadItems.isEmpty else {
|
||||
return BugReportOutcome(success: false, message: "No data to upload")
|
||||
}
|
||||
|
||||
let presignedUrls = try await fetchPresignedUploadUrls(keys: uploadItems.map(\.key))
|
||||
for item in uploadItems {
|
||||
guard let urlString = presignedUrls[item.key], let url = URL(string: urlString) else {
|
||||
throw BugReportError.uploadFailed("Missing presigned URL for \(item.key)")
|
||||
}
|
||||
try await uploadToPresignedUrl(url: url, body: item.body)
|
||||
}
|
||||
|
||||
return BugReportOutcome(
|
||||
success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
|
||||
}
|
||||
|
||||
private func loadCredentials() throws -> AWSConfig {
|
||||
return AWSConfig(
|
||||
accessKey: "AKIAYEKP5EMXTOBYDGHX",
|
||||
secretKey: "Ep5gIlUZ1o8ssTLQwmyy34yPGfTPEYQ4evE8NdPE",
|
||||
region: "us-east-1",
|
||||
bucket: "exo-bug-reports"
|
||||
)
|
||||
private static func dayPrefixString(_ date: Date) -> String {
|
||||
var calendar = Calendar(identifier: .gregorian)
|
||||
calendar.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
|
||||
let components = calendar.dateComponents([.year, .month, .day], from: date)
|
||||
let year = components.year ?? 0
|
||||
let month = components.month ?? 0
|
||||
let day = components.day ?? 0
|
||||
return String(format: "%04d/%02d/%02d", year, month, day)
|
||||
}
|
||||
|
||||
private static func runTimestampString(_ date: Date) -> String {
|
||||
let formatter = DateFormatter()
|
||||
formatter.locale = Locale(identifier: "en_US_POSIX")
|
||||
formatter.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
|
||||
formatter.dateFormat = "yyyy-MM-dd'T'HHmmss.SSS'Z'"
|
||||
return formatter.string(from: date)
|
||||
}
|
||||
|
||||
private func fetchPresignedUploadUrls(keys: [String], bundle: Bundle = .main) async throws
|
||||
-> [String: String]
|
||||
{
|
||||
guard
|
||||
let endpointString = bundle.infoDictionary?["EXOBugReportPresignedUrlEndpoint"]
|
||||
as? String
|
||||
else {
|
||||
throw BugReportError.invalidEndpoint
|
||||
}
|
||||
let trimmedEndpointString = endpointString.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
guard !trimmedEndpointString.isEmpty, let endpoint = URL(string: trimmedEndpointString)
|
||||
else {
|
||||
throw BugReportError.invalidEndpoint
|
||||
}
|
||||
|
||||
var request = URLRequest(url: endpoint)
|
||||
request.httpMethod = "POST"
|
||||
request.timeoutInterval = 10
|
||||
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
|
||||
|
||||
let encoder = JSONEncoder()
|
||||
request.httpBody = try encoder.encode(PresignedUrlsRequest(keys: keys))
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
guard let http = response as? HTTPURLResponse else {
|
||||
throw BugReportError.presignedUrlFailed("Non-HTTP response")
|
||||
}
|
||||
guard (200..<300).contains(http.statusCode) else {
|
||||
throw BugReportError.presignedUrlFailed("HTTP status \(http.statusCode)")
|
||||
}
|
||||
|
||||
let decoder = JSONDecoder()
|
||||
let decoded = try decoder.decode(PresignedUrlsResponse.self, from: data)
|
||||
return decoded.urls
|
||||
}
|
||||
|
||||
private func readLog() -> Data? {
|
||||
@@ -100,7 +155,8 @@ struct BugReportService {
|
||||
private func captureIfconfig() async throws -> String {
|
||||
let result = runCommand(["/sbin/ifconfig"])
|
||||
guard result.exitCode == 0 else {
|
||||
throw BugReportError.collectFailed(result.error.isEmpty ? "ifconfig failed" : result.error)
|
||||
throw BugReportError.collectFailed(
|
||||
result.error.isEmpty ? "ifconfig failed" : result.error)
|
||||
}
|
||||
return result.output
|
||||
}
|
||||
@@ -108,12 +164,23 @@ struct BugReportService {
|
||||
private func readDebugInfo() -> DebugInfo {
|
||||
DebugInfo(
|
||||
thunderboltBridgeDisabled: readThunderboltBridgeDisabled(),
|
||||
interfaces: readInterfaces()
|
||||
interfaces: readInterfaces(),
|
||||
rdma: readRDMADebugInfo()
|
||||
)
|
||||
}
|
||||
|
||||
private func readRDMADebugInfo() -> DebugInfo.RDMADebugInfo {
|
||||
DebugInfo.RDMADebugInfo(
|
||||
rdmaCtlStatus: safeRunCommand(["/usr/bin/rdma_ctl", "status"]),
|
||||
ibvDevices: safeRunCommand(["/usr/bin/ibv_devices"]),
|
||||
ibvDevinfo: safeRunCommand(["/usr/bin/ibv_devinfo"])
|
||||
)
|
||||
}
|
||||
|
||||
private func readThunderboltBridgeDisabled() -> Bool? {
|
||||
let result = runCommand(["/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
|
||||
let result = runCommand([
|
||||
"/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge",
|
||||
])
|
||||
guard result.exitCode == 0 else { return nil }
|
||||
let output = result.output.lowercased()
|
||||
if output.contains("enabled") {
|
||||
@@ -156,7 +223,8 @@ struct BugReportService {
|
||||
request.timeoutInterval = 5
|
||||
do {
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
|
||||
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode)
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
@@ -165,6 +233,36 @@ struct BugReportService {
|
||||
}
|
||||
}
|
||||
|
||||
private func uploadToPresignedUrl(url: URL, body: Data) async throws {
|
||||
let maxAttempts = 2
|
||||
var lastError: Error?
|
||||
|
||||
for attempt in 1...maxAttempts {
|
||||
do {
|
||||
var request = URLRequest(url: url)
|
||||
request.httpMethod = "PUT"
|
||||
request.httpBody = body
|
||||
request.timeoutInterval = 30
|
||||
|
||||
let (_, response) = try await URLSession.shared.data(for: request)
|
||||
guard let http = response as? HTTPURLResponse else {
|
||||
throw BugReportError.uploadFailed("Non-HTTP response")
|
||||
}
|
||||
guard (200..<300).contains(http.statusCode) else {
|
||||
throw BugReportError.uploadFailed("HTTP status \(http.statusCode)")
|
||||
}
|
||||
return
|
||||
} catch {
|
||||
lastError = error
|
||||
if attempt < maxAttempts {
|
||||
try await Task.sleep(nanoseconds: 400_000_000)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw BugReportError.uploadFailed(lastError?.localizedDescription ?? "Unknown error")
|
||||
}
|
||||
|
||||
private func makeReportJson(
|
||||
timestamp: String,
|
||||
hostName: String,
|
||||
@@ -182,7 +280,7 @@ struct BugReportService {
|
||||
"system": system,
|
||||
"exo_version": exo.version as Any,
|
||||
"exo_commit": exo.commit as Any,
|
||||
"report_type": isManual ? "manual" : "automated"
|
||||
"report_type": isManual ? "manual" : "automated",
|
||||
]
|
||||
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
|
||||
}
|
||||
@@ -213,10 +311,13 @@ struct BugReportService {
|
||||
let user = safeRunCommand(["/usr/bin/whoami"])
|
||||
let consoleUser = safeRunCommand(["/usr/bin/stat", "-f%Su", "/dev/console"])
|
||||
let uptime = safeRunCommand(["/usr/bin/uptime"])
|
||||
let diskRoot = safeRunCommand(["/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'"])
|
||||
let diskRoot = safeRunCommand([
|
||||
"/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'",
|
||||
])
|
||||
|
||||
let interfacesList = safeRunCommand(["/usr/sbin/ipconfig", "getiflist"])
|
||||
let interfacesAndIPs = interfacesList?
|
||||
let interfacesAndIPs =
|
||||
interfacesList?
|
||||
.split(whereSeparator: { $0 == " " || $0 == "\n" })
|
||||
.compactMap { iface -> [String: Any]? in
|
||||
let name = String(iface)
|
||||
@@ -227,7 +328,8 @@ struct BugReportService {
|
||||
} ?? []
|
||||
|
||||
let wifiSSID: String?
|
||||
let airportPath = "/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
|
||||
let airportPath =
|
||||
"/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
|
||||
if FileManager.default.isExecutableFile(atPath: airportPath) {
|
||||
wifiSSID = safeRunCommand([airportPath, "-I"]).flatMap(parseWifiSSID)
|
||||
} else {
|
||||
@@ -255,7 +357,7 @@ struct BugReportService {
|
||||
"disk_root": diskRoot as Any,
|
||||
"interfaces_and_ips": interfacesAndIPs,
|
||||
"ipconfig_getiflist": interfacesList as Any,
|
||||
"wifi_ssid": wifiSSID as Any
|
||||
"wifi_ssid": wifiSSID as Any,
|
||||
]
|
||||
}
|
||||
|
||||
@@ -313,7 +415,8 @@ struct BugReportService {
|
||||
for line in airportOutput.split(separator: "\n") {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
if trimmed.hasPrefix("SSID:") {
|
||||
return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(in: .whitespaces)
|
||||
return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(
|
||||
in: .whitespaces)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -350,6 +453,7 @@ struct BugReportService {
|
||||
private struct DebugInfo {
|
||||
let thunderboltBridgeDisabled: Bool?
|
||||
let interfaces: [InterfaceStatus]
|
||||
let rdma: RDMADebugInfo
|
||||
|
||||
struct InterfaceStatus {
|
||||
let name: String
|
||||
@@ -358,7 +462,21 @@ private struct DebugInfo {
|
||||
func toDictionary() -> [String: Any] {
|
||||
[
|
||||
"name": name,
|
||||
"ip": ip as Any
|
||||
"ip": ip as Any,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
struct RDMADebugInfo {
|
||||
let rdmaCtlStatus: String?
|
||||
let ibvDevices: String?
|
||||
let ibvDevinfo: String?
|
||||
|
||||
func toDictionary() -> [String: Any] {
|
||||
[
|
||||
"rdma_ctl_status": rdmaCtlStatus as Any,
|
||||
"ibv_devices": ibvDevices as Any,
|
||||
"ibv_devinfo": ibvDevinfo as Any,
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -366,7 +484,8 @@ private struct DebugInfo {
|
||||
func toDictionary() -> [String: Any] {
|
||||
[
|
||||
"thunderbolt_bridge_disabled": thunderboltBridgeDisabled as Any,
|
||||
"interfaces": interfaces.map { $0.toDictionary() }
|
||||
"interfaces": interfaces.map { $0.toDictionary() },
|
||||
"rdma": rdma.toDictionary(),
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -376,163 +495,3 @@ private struct CommandResult {
|
||||
let output: String
|
||||
let error: String
|
||||
}
|
||||
|
||||
private struct S3Uploader {
|
||||
let config: BugReportService.AWSConfig
|
||||
|
||||
init(config: BugReportService.AWSConfig) throws {
|
||||
self.config = config
|
||||
}
|
||||
|
||||
func upload(objectPath: String, body: Data) async throws {
|
||||
let host = "\(config.bucket).s3.amazonaws.com"
|
||||
guard let url = URL(string: "https://\(host)/\(objectPath)") else {
|
||||
throw BugReportError.invalidEndpoint
|
||||
}
|
||||
|
||||
let now = Date()
|
||||
let amzDate = awsTimestamp(now)
|
||||
let dateStamp = dateStamp(now)
|
||||
let payloadHash = sha256Hex(body)
|
||||
|
||||
let headers = [
|
||||
"host": host,
|
||||
"x-amz-content-sha256": payloadHash,
|
||||
"x-amz-date": amzDate
|
||||
]
|
||||
|
||||
let canonicalRequest = buildCanonicalRequest(
|
||||
method: "PUT",
|
||||
url: url,
|
||||
headers: headers,
|
||||
payloadHash: payloadHash
|
||||
)
|
||||
|
||||
let stringToSign = buildStringToSign(
|
||||
amzDate: amzDate,
|
||||
dateStamp: dateStamp,
|
||||
canonicalRequestHash: sha256Hex(canonicalRequest.data(using: .utf8) ?? Data())
|
||||
)
|
||||
|
||||
let signingKey = deriveKey(secret: config.secretKey, dateStamp: dateStamp, region: config.region, service: "s3")
|
||||
let signature = hmacHex(key: signingKey, data: Data(stringToSign.utf8))
|
||||
|
||||
let signedHeaders = "host;x-amz-content-sha256;x-amz-date"
|
||||
let authorization = """
|
||||
AWS4-HMAC-SHA256 Credential=\(config.accessKey)/\(dateStamp)/\(config.region)/s3/aws4_request, SignedHeaders=\(signedHeaders), Signature=\(signature)
|
||||
"""
|
||||
|
||||
var request = URLRequest(url: url)
|
||||
request.httpMethod = "PUT"
|
||||
request.httpBody = body
|
||||
request.setValue(headers["x-amz-content-sha256"], forHTTPHeaderField: "x-amz-content-sha256")
|
||||
request.setValue(headers["x-amz-date"], forHTTPHeaderField: "x-amz-date")
|
||||
request.setValue(host, forHTTPHeaderField: "Host")
|
||||
request.setValue(authorization, forHTTPHeaderField: "Authorization")
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
|
||||
let statusText = (response as? HTTPURLResponse)?.statusCode ?? -1
|
||||
_ = data // ignore response body for UX
|
||||
throw BugReportError.uploadFailed("HTTP status \(statusText)")
|
||||
}
|
||||
}
|
||||
|
||||
private func buildCanonicalRequest(
|
||||
method: String,
|
||||
url: URL,
|
||||
headers: [String: String],
|
||||
payloadHash: String
|
||||
) -> String {
|
||||
let canonicalURI = encodePath(url.path)
|
||||
let canonicalQuery = url.query ?? ""
|
||||
let sortedHeaders = headers.sorted { $0.key < $1.key }
|
||||
let canonicalHeaders = sortedHeaders
|
||||
.map { "\($0.key.lowercased()):\($0.value)\n" }
|
||||
.joined()
|
||||
let signedHeaders = sortedHeaders.map { $0.key.lowercased() }.joined(separator: ";")
|
||||
|
||||
return [
|
||||
method,
|
||||
canonicalURI,
|
||||
canonicalQuery,
|
||||
canonicalHeaders,
|
||||
signedHeaders,
|
||||
payloadHash
|
||||
].joined(separator: "\n")
|
||||
}
|
||||
|
||||
private func encodePath(_ path: String) -> String {
|
||||
return path
|
||||
.split(separator: "/")
|
||||
.map { segment in
|
||||
segment.addingPercentEncoding(withAllowedCharacters: Self.rfc3986) ?? String(segment)
|
||||
}
|
||||
.joined(separator: "/")
|
||||
.prependSlashIfNeeded()
|
||||
}
|
||||
|
||||
private func buildStringToSign(
|
||||
amzDate: String,
|
||||
dateStamp: String,
|
||||
canonicalRequestHash: String
|
||||
) -> String {
|
||||
"""
|
||||
AWS4-HMAC-SHA256
|
||||
\(amzDate)
|
||||
\(dateStamp)/\(config.region)/s3/aws4_request
|
||||
\(canonicalRequestHash)
|
||||
"""
|
||||
}
|
||||
|
||||
private func deriveKey(secret: String, dateStamp: String, region: String, service: String) -> Data {
|
||||
let kDate = hmac(key: Data(("AWS4" + secret).utf8), data: Data(dateStamp.utf8))
|
||||
let kRegion = hmac(key: kDate, data: Data(region.utf8))
|
||||
let kService = hmac(key: kRegion, data: Data(service.utf8))
|
||||
return hmac(key: kService, data: Data("aws4_request".utf8))
|
||||
}
|
||||
|
||||
private func hmac(key: Data, data: Data) -> Data {
|
||||
let keySym = SymmetricKey(data: key)
|
||||
let mac = HMAC<SHA256>.authenticationCode(for: data, using: keySym)
|
||||
return Data(mac)
|
||||
}
|
||||
|
||||
private func hmacHex(key: Data, data: Data) -> String {
|
||||
hmac(key: key, data: data).map { String(format: "%02x", $0) }.joined()
|
||||
}
|
||||
|
||||
private func sha256Hex(_ data: Data) -> String {
|
||||
let digest = SHA256.hash(data: data)
|
||||
return digest.compactMap { String(format: "%02x", $0) }.joined()
|
||||
}
|
||||
|
||||
private func awsTimestamp(_ date: Date) -> String {
|
||||
let formatter = DateFormatter()
|
||||
formatter.dateFormat = "yyyyMMdd'T'HHmmss'Z'"
|
||||
formatter.timeZone = TimeZone(abbreviation: "UTC")
|
||||
return formatter.string(from: date)
|
||||
}
|
||||
|
||||
private func dateStamp(_ date: Date) -> String {
|
||||
let formatter = DateFormatter()
|
||||
formatter.dateFormat = "yyyyMMdd"
|
||||
formatter.timeZone = TimeZone(abbreviation: "UTC")
|
||||
return formatter.string(from: date)
|
||||
}
|
||||
|
||||
private static let rfc3986: CharacterSet = {
|
||||
var set = CharacterSet.alphanumerics
|
||||
set.insert(charactersIn: "-._~")
|
||||
return set
|
||||
}()
|
||||
}
|
||||
|
||||
private extension String {
|
||||
func prependSlashIfNeeded() -> String {
|
||||
if hasPrefix("/") {
|
||||
return self
|
||||
}
|
||||
return "/" + self
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,7 +57,9 @@ final class ClusterStateService: ObservableObject {
|
||||
var request = URLRequest(url: url)
|
||||
request.cachePolicy = .reloadIgnoringLocalCacheData
|
||||
let (data, response) = try await session.data(for: request)
|
||||
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
(200..<300).contains(httpResponse.statusCode)
|
||||
else {
|
||||
return
|
||||
}
|
||||
if let nodeId = try? decoder.decode(String.self, from: data) {
|
||||
@@ -113,7 +115,9 @@ final class ClusterStateService: ObservableObject {
|
||||
}
|
||||
}
|
||||
|
||||
func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int) async {
|
||||
func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int)
|
||||
async
|
||||
{
|
||||
do {
|
||||
var request = URLRequest(url: baseURL.appendingPathComponent("instance"))
|
||||
request.httpMethod = "POST"
|
||||
@@ -122,7 +126,7 @@ final class ClusterStateService: ObservableObject {
|
||||
"model_id": modelId,
|
||||
"sharding": sharding,
|
||||
"instance_meta": instanceMeta,
|
||||
"min_nodes": minNodes
|
||||
"min_nodes": minNodes,
|
||||
]
|
||||
request.httpBody = try JSONSerialization.data(withJSONObject: payload, options: [])
|
||||
let (_, response) = try await session.data(for: request)
|
||||
@@ -143,7 +147,9 @@ final class ClusterStateService: ObservableObject {
|
||||
do {
|
||||
let url = baseURL.appendingPathComponent("models")
|
||||
let (data, response) = try await session.data(from: url)
|
||||
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
(200..<300).contains(httpResponse.statusCode)
|
||||
else {
|
||||
throw URLError(.badServerResponse)
|
||||
}
|
||||
let list = try decoder.decode(ModelListResponse.self, from: data)
|
||||
|
||||
150 app/EXO/EXO/Services/LocalNetworkChecker.swift Normal file
@@ -0,0 +1,150 @@
|
||||
import Foundation
|
||||
import Network
|
||||
import os.log
|
||||
|
||||
/// Checks if the app's local network permission is actually functional.
|
||||
///
|
||||
/// macOS local network permission can appear enabled in System Preferences but not
|
||||
/// actually work after a restart. This service detects this by creating a UDP
|
||||
/// connection to the mDNS multicast address (224.0.0.251:5353).
|
||||
@MainActor
|
||||
final class LocalNetworkChecker: ObservableObject {
|
||||
enum Status: Equatable {
|
||||
case unknown
|
||||
case checking
|
||||
case working
|
||||
case notWorking(reason: String)
|
||||
|
||||
var isHealthy: Bool {
|
||||
if case .working = self { return true }
|
||||
return false
|
||||
}
|
||||
|
||||
var displayText: String {
|
||||
switch self {
|
||||
case .unknown:
|
||||
return "Unknown"
|
||||
case .checking:
|
||||
return "Checking..."
|
||||
case .working:
|
||||
return "Working"
|
||||
case .notWorking(let reason):
|
||||
return reason
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
|
||||
|
||||
@Published private(set) var status: Status = .unknown
|
||||
@Published private(set) var lastConnectionState: String = "none"
|
||||
|
||||
private var connection: NWConnection?
|
||||
private var checkTask: Task<Void, Never>?
|
||||
|
||||
/// Checks if local network access is working.
|
||||
func check() {
|
||||
checkTask?.cancel()
|
||||
status = .checking
|
||||
lastConnectionState = "connecting"
|
||||
|
||||
checkTask = Task { [weak self] in
|
||||
guard let self else { return }
|
||||
let result = await self.performCheck()
|
||||
self.status = result
|
||||
Self.logger.info("Local network check complete: \(result.displayText)")
|
||||
}
|
||||
}
|
||||
|
||||
private func performCheck() async -> Status {
|
||||
Self.logger.info("Checking local network access via UDP multicast")
|
||||
|
||||
connection?.cancel()
|
||||
connection = nil
|
||||
|
||||
// mDNS multicast address - same as libp2p uses for peer discovery
|
||||
let host = NWEndpoint.Host("224.0.0.251")
|
||||
let port = NWEndpoint.Port(integerLiteral: 5353)
|
||||
|
||||
let params = NWParameters.udp
|
||||
params.allowLocalEndpointReuse = true
|
||||
|
||||
let conn = NWConnection(host: host, port: port, using: params)
|
||||
connection = conn
|
||||
|
||||
return await withCheckedContinuation { continuation in
|
||||
var hasResumed = false
|
||||
let lock = NSLock()
|
||||
|
||||
let resumeOnce: (Status) -> Void = { status in
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
guard !hasResumed else { return }
|
||||
hasResumed = true
|
||||
continuation.resume(returning: status)
|
||||
}
|
||||
|
||||
conn.stateUpdateHandler = { [weak self] state in
|
||||
let stateStr: String
|
||||
switch state {
|
||||
case .setup: stateStr = "setup"
|
||||
case .preparing: stateStr = "preparing"
|
||||
case .ready: stateStr = "ready"
|
||||
case .waiting(let e): stateStr = "waiting(\(e))"
|
||||
case .failed(let e): stateStr = "failed(\(e))"
|
||||
case .cancelled: stateStr = "cancelled"
|
||||
@unknown default: stateStr = "unknown"
|
||||
}
|
||||
|
||||
Task { @MainActor in
|
||||
self?.lastConnectionState = stateStr
|
||||
}
|
||||
|
||||
switch state {
|
||||
case .ready:
|
||||
resumeOnce(.working)
|
||||
case .waiting(let error):
|
||||
let errorStr = "\(error)"
|
||||
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
|
||||
resumeOnce(.notWorking(reason: "Connection blocked"))
|
||||
}
|
||||
case .failed(let error):
|
||||
let errorStr = "\(error)"
|
||||
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
|
||||
|| errorStr.contains("permission") || errorStr.contains("denied")
|
||||
{
|
||||
resumeOnce(.notWorking(reason: "Permission denied"))
|
||||
} else {
|
||||
resumeOnce(.notWorking(reason: "Failed: \(error.localizedDescription)"))
|
||||
}
|
||||
case .cancelled, .setup, .preparing:
|
||||
break
|
||||
@unknown default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
conn.start(queue: .main)
|
||||
|
||||
Task {
|
||||
try? await Task.sleep(nanoseconds: 3_000_000_000)
|
||||
let state = conn.state
|
||||
switch state {
|
||||
case .ready:
|
||||
resumeOnce(.working)
|
||||
case .waiting, .preparing, .setup:
|
||||
resumeOnce(.notWorking(reason: "Timeout (may be blocked)"))
|
||||
default:
|
||||
resumeOnce(.notWorking(reason: "Timeout"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stop() {
|
||||
checkTask?.cancel()
|
||||
checkTask = nil
|
||||
connection?.cancel()
|
||||
connection = nil
|
||||
}
|
||||
}
|
||||
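The checker above reduces the macOS local-network question to one observable fact: can the process open a UDP path to the mDNS multicast group (224.0.0.251:5353) without the OS refusing it. Below is a rough standalone sketch of the same probe in Python; this is an illustration only, not part of the app — the app uses NWConnection, and a script run from Terminal is subject to Terminal's permission context rather than EXO's.

import socket

# Probe the mDNS multicast group the Swift checker targets.
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.settimeout(3.0)
try:
    sock.sendto(b"", ("224.0.0.251", 5353))  # an empty datagram is enough
    print("send succeeded: local network traffic is going out")
except OSError as exc:
    # A refused send is what the Swift code maps to "Permission denied"
    # (it looks for EHOSTUNREACH / errno 65 in the failure description).
    print(f"send failed ({exc}): local network access may be blocked")
finally:
    sock.close()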
@@ -5,64 +5,66 @@ import os.log
|
||||
enum NetworkSetupHelper {
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
|
||||
private static let daemonLabel = "io.exo.networksetup"
|
||||
private static let scriptDestination = "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
|
||||
private static let scriptDestination =
|
||||
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
|
||||
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
|
||||
private static let requiredStartInterval: Int = 1791
|
||||
|
||||
private static let setupScript = """
|
||||
#!/usr/bin/env bash
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
set -euo pipefail
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && {
|
||||
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
|
||||
}
|
||||
ifconfig bridge0 destroy 2>/dev/null || true
|
||||
}
|
||||
# Remove bridge0 interface
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && {
|
||||
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
|
||||
}
|
||||
ifconfig bridge0 destroy 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
|
||||
networksetup -listlocations | grep -q exo || {
|
||||
networksetup -createlocation exo
|
||||
}
|
||||
networksetup -listlocations | grep -q exo || {
|
||||
networksetup -createlocation exo
|
||||
}
|
||||
|
||||
networksetup -switchtolocation exo
|
||||
networksetup -listallhardwareports \\
|
||||
| awk -F': ' '/Hardware Port: / {print $2}' \\
|
||||
| while IFS=":" read -r name; do
|
||||
case "$name" in
|
||||
"Ethernet Adapter"*)
|
||||
;;
|
||||
"Thunderbolt Bridge")
|
||||
;;
|
||||
"Thunderbolt "*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "EXO $name" \\
|
||||
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
networksetup -setdhcp "EXO $name"
|
||||
;;
|
||||
*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "$name" \\
|
||||
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
;;
|
||||
esac
|
||||
done
|
||||
networksetup -switchtolocation exo
|
||||
networksetup -listallhardwareports \\
|
||||
| awk -F': ' '/Hardware Port: / {print $2}' \\
|
||||
| while IFS=":" read -r name; do
|
||||
case "$name" in
|
||||
"Ethernet Adapter"*)
|
||||
;;
|
||||
"Thunderbolt Bridge")
|
||||
;;
|
||||
"Thunderbolt "*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "EXO $name" \\
|
||||
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
networksetup -setdhcp "EXO $name"
|
||||
;;
|
||||
*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "$name" \\
|
||||
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
"""
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
"""
|
||||
|
||||
static func ensureLaunchDaemonInstalled() {
|
||||
Task.detached {
|
||||
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
|
||||
Task.detached(priority: .utility) {
|
||||
do {
|
||||
if daemonAlreadyInstalled() {
|
||||
return
|
||||
@@ -70,11 +72,70 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
try await installLaunchDaemon()
|
||||
logger.info("Network setup launch daemon installed and started")
|
||||
} catch {
|
||||
logger.error("Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)")
|
||||
logger.error(
|
||||
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes all EXO network setup components from the system.
|
||||
/// This includes the LaunchDaemon, scripts, logs, and network location.
|
||||
/// Requires admin privileges.
|
||||
static func uninstall() throws {
|
||||
let uninstallScript = makeUninstallScript()
|
||||
try runShellAsAdmin(uninstallScript)
|
||||
logger.info("EXO network setup components removed successfully")
|
||||
}
|
||||
|
||||
/// Checks if there are any EXO network components installed that need cleanup
|
||||
static func hasInstalledComponents() -> Bool {
|
||||
let manager = FileManager.default
|
||||
let scriptExists = manager.fileExists(atPath: scriptDestination)
|
||||
let plistExists = manager.fileExists(atPath: plistDestination)
|
||||
return scriptExists || plistExists
|
||||
}
|
||||
|
||||
private static func makeUninstallScript() -> String {
|
||||
"""
|
||||
set -euo pipefail
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
LOG_OUT="/var/log/\(daemonLabel).log"
|
||||
LOG_ERR="/var/log/\(daemonLabel).err.log"
|
||||
|
||||
# Unload the LaunchDaemon if running
|
||||
launchctl bootout system/"$LABEL" 2>/dev/null || true
|
||||
|
||||
# Remove LaunchDaemon plist
|
||||
rm -f "$PLIST_DEST"
|
||||
|
||||
# Remove the script and parent directory if empty
|
||||
rm -f "$SCRIPT_DEST"
|
||||
rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true
|
||||
|
||||
# Remove log files
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
|
||||
# Switch back to Automatic network location
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
|
||||
# Delete the exo network location if it exists
|
||||
networksetup -listlocations | grep -q '^exo$' && {
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
} || true
|
||||
|
||||
# Re-enable Thunderbolt Bridge if it exists
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
} || true
|
||||
|
||||
echo "EXO network components removed successfully"
|
||||
"""
|
||||
}
|
||||
|
||||
private static func daemonAlreadyInstalled() -> Bool {
|
||||
let manager = FileManager.default
|
||||
let scriptExists = manager.fileExists(atPath: scriptDestination)
|
||||
@@ -82,7 +143,8 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
guard scriptExists, plistExists else { return false }
|
||||
guard
|
||||
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
|
||||
let plist = try? PropertyListSerialization.propertyList(from: data, options: [], format: nil) as? [String: Any]
|
||||
let plist = try? PropertyListSerialization.propertyList(
|
||||
from: data, options: [], format: nil) as? [String: Any]
|
||||
else {
|
||||
return false
|
||||
}
|
||||
@@ -92,7 +154,9 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
else {
|
||||
return false
|
||||
}
|
||||
if let programArgs = plist["ProgramArguments"] as? [String], programArgs.contains(scriptDestination) == false {
|
||||
if let programArgs = plist["ProgramArguments"] as? [String],
|
||||
programArgs.contains(scriptDestination) == false
|
||||
{
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -105,58 +169,59 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
|
||||
private static func makeInstallerScript() -> String {
|
||||
"""
|
||||
set -euo pipefail
|
||||
set -euo pipefail
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
|
||||
mkdir -p "$(dirname "$SCRIPT_DEST")"
|
||||
mkdir -p "$(dirname "$SCRIPT_DEST")"
|
||||
|
||||
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
|
||||
\(setupScript)
|
||||
EOF_SCRIPT
|
||||
chmod 755 "$SCRIPT_DEST"
|
||||
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
|
||||
\(setupScript)
|
||||
EOF_SCRIPT
|
||||
chmod 755 "$SCRIPT_DEST"
|
||||
|
||||
cat > "$PLIST_DEST" <<'EOF_PLIST'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>\(daemonLabel)</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/bin/bash</string>
|
||||
<string>\(scriptDestination)</string>
|
||||
</array>
|
||||
<key>StartInterval</key>
|
||||
<integer>\(requiredStartInterval)</integer>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/var/log/\(daemonLabel).log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/var/log/\(daemonLabel).err.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF_PLIST
|
||||
cat > "$PLIST_DEST" <<'EOF_PLIST'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>\(daemonLabel)</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/bin/bash</string>
|
||||
<string>\(scriptDestination)</string>
|
||||
</array>
|
||||
<key>StartInterval</key>
|
||||
<integer>\(requiredStartInterval)</integer>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/var/log/\(daemonLabel).log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/var/log/\(daemonLabel).err.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF_PLIST
|
||||
|
||||
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
|
||||
launchctl bootstrap system "$PLIST_DEST"
|
||||
launchctl enable system/"$LABEL"
|
||||
launchctl kickstart -k system/"$LABEL"
|
||||
"""
|
||||
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
|
||||
launchctl bootstrap system "$PLIST_DEST"
|
||||
launchctl enable system/"$LABEL"
|
||||
launchctl kickstart -k system/"$LABEL"
|
||||
"""
|
||||
}
|
||||
|
||||
private static func runShellAsAdmin(_ script: String) throws {
|
||||
let escapedScript = script
|
||||
let escapedScript =
|
||||
script
|
||||
.replacingOccurrences(of: "\\", with: "\\\\")
|
||||
.replacingOccurrences(of: "\"", with: "\\\"")
|
||||
|
||||
let appleScriptSource = """
|
||||
do shell script "\(escapedScript)" with administrator privileges
|
||||
"""
|
||||
do shell script "\(escapedScript)" with administrator privileges
|
||||
"""
|
||||
|
||||
guard let appleScript = NSAppleScript(source: appleScriptSource) else {
|
||||
throw NetworkSetupError.scriptCreationFailed
|
||||
|
||||
@@ -35,14 +35,34 @@ struct NetworkStatus: Equatable {
|
||||
let thunderboltBridgeState: ThunderboltState?
|
||||
let bridgeInactive: Bool?
|
||||
let interfaceStatuses: [InterfaceIpStatus]
|
||||
let rdmaStatus: RDMAStatus
|
||||
|
||||
static let empty = NetworkStatus(
|
||||
thunderboltBridgeState: nil,
|
||||
bridgeInactive: nil,
|
||||
interfaceStatuses: []
|
||||
interfaceStatuses: [],
|
||||
rdmaStatus: .empty
|
||||
)
|
||||
}
|
||||
|
||||
struct RDMAStatus: Equatable {
|
||||
let rdmaCtlEnabled: Bool?
|
||||
let devices: [String]
|
||||
let activePorts: [RDMAPort]
|
||||
|
||||
var isAvailable: Bool {
|
||||
rdmaCtlEnabled == true || !devices.isEmpty
|
||||
}
|
||||
|
||||
static let empty = RDMAStatus(rdmaCtlEnabled: nil, devices: [], activePorts: [])
|
||||
}
|
||||
|
||||
struct RDMAPort: Equatable {
|
||||
let device: String
|
||||
let port: String
|
||||
let state: String
|
||||
}
|
||||
|
||||
struct InterfaceIpStatus: Equatable {
|
||||
let interfaceName: String
|
||||
let ipAddress: String?
|
||||
@@ -59,10 +79,79 @@ private struct NetworkStatusFetcher {
|
||||
NetworkStatus(
|
||||
thunderboltBridgeState: readThunderboltBridgeState(),
|
||||
bridgeInactive: readBridgeInactive(),
|
||||
interfaceStatuses: readInterfaceStatuses()
|
||||
interfaceStatuses: readInterfaceStatuses(),
|
||||
rdmaStatus: readRDMAStatus()
|
||||
)
|
||||
}
|
||||
|
||||
private func readRDMAStatus() -> RDMAStatus {
|
||||
let rdmaCtlEnabled = readRDMACtlEnabled()
|
||||
let devices = readRDMADevices()
|
||||
let activePorts = readRDMAActivePorts()
|
||||
return RDMAStatus(
|
||||
rdmaCtlEnabled: rdmaCtlEnabled, devices: devices, activePorts: activePorts)
|
||||
}
|
||||
|
||||
private func readRDMACtlEnabled() -> Bool? {
|
||||
let result = runCommand(["rdma_ctl", "status"])
|
||||
guard result.exitCode == 0 else { return nil }
|
||||
let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
if output.contains("enabled") {
|
||||
return true
|
||||
}
|
||||
if output.contains("disabled") {
|
||||
return false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
private func readRDMADevices() -> [String] {
|
||||
let result = runCommand(["ibv_devices"])
|
||||
guard result.exitCode == 0 else { return [] }
|
||||
var devices: [String] = []
|
||||
for line in result.output.split(separator: "\n") {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
if trimmed.hasPrefix("---") || trimmed.lowercased().hasPrefix("device")
|
||||
|| trimmed.isEmpty
|
||||
{
|
||||
continue
|
||||
}
|
||||
let parts = trimmed.split(separator: " ", maxSplits: 1)
|
||||
if let deviceName = parts.first {
|
||||
devices.append(String(deviceName))
|
||||
}
|
||||
}
|
||||
return devices
|
||||
}
|
||||
|
||||
private func readRDMAActivePorts() -> [RDMAPort] {
|
||||
let result = runCommand(["ibv_devinfo"])
|
||||
guard result.exitCode == 0 else { return [] }
|
||||
var ports: [RDMAPort] = []
|
||||
var currentDevice: String?
|
||||
var currentPort: String?
|
||||
|
||||
for line in result.output.split(separator: "\n") {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
if trimmed.hasPrefix("hca_id:") {
|
||||
currentDevice = trimmed.replacingOccurrences(of: "hca_id:", with: "")
|
||||
.trimmingCharacters(in: .whitespaces)
|
||||
} else if trimmed.hasPrefix("port:") {
|
||||
currentPort = trimmed.replacingOccurrences(of: "port:", with: "")
|
||||
.trimmingCharacters(in: .whitespaces)
|
||||
} else if trimmed.hasPrefix("state:") {
|
||||
let state = trimmed.replacingOccurrences(of: "state:", with: "").trimmingCharacters(
|
||||
in: .whitespaces)
|
||||
if let device = currentDevice, let port = currentPort {
|
||||
if state.lowercased().contains("active") {
|
||||
ports.append(RDMAPort(device: device, port: port, state: state))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ports
|
||||
}
|
||||
|
||||
private func readThunderboltBridgeState() -> ThunderboltState? {
|
||||
let result = runCommand(["networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
|
||||
guard result.exitCode == 0 else {
|
||||
@@ -85,10 +174,11 @@ private struct NetworkStatusFetcher {
|
||||
private func readBridgeInactive() -> Bool? {
|
||||
let result = runCommand(["ifconfig", "bridge0"])
|
||||
guard result.exitCode == 0 else { return nil }
|
||||
guard let statusLine = result.output
|
||||
.components(separatedBy: .newlines)
|
||||
.first(where: { $0.contains("status:") })?
|
||||
.lowercased()
|
||||
guard
|
||||
let statusLine = result.output
|
||||
.components(separatedBy: .newlines)
|
||||
.first(where: { $0.contains("status:") })?
|
||||
.lowercased()
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
@@ -171,4 +261,3 @@ private struct NetworkStatusFetcher {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ struct InstanceViewModel: Identifiable, Equatable {
|
||||
case waiting
|
||||
case failed
|
||||
case idle
|
||||
case unknown
|
||||
case preparing
|
||||
|
||||
var label: String {
|
||||
switch self {
|
||||
@@ -68,7 +68,7 @@ struct InstanceViewModel: Identifiable, Equatable {
|
||||
case .waiting: return "Waiting"
|
||||
case .failed: return "Failed"
|
||||
case .idle: return "Idle"
|
||||
case .unknown: return "Unknown"
|
||||
case .preparing: return "Preparing"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -107,10 +107,13 @@ extension ClusterState {
|
||||
let nodeToRunner = instance.shardAssignments.nodeToRunner
|
||||
let nodeIds = Array(nodeToRunner.keys)
|
||||
let runnerIds = Array(nodeToRunner.values)
|
||||
let nodeNames = nodeIds.compactMap { nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0 }
|
||||
let nodeNames = nodeIds.compactMap {
|
||||
nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0
|
||||
}
|
||||
let statuses = runnerIds.compactMap { runners[$0]?.status.lowercased() }
|
||||
let downloadProgress = aggregateDownloadProgress(for: nodeIds)
|
||||
let state = InstanceViewModel.State(statuses: statuses, hasActiveDownload: downloadProgress != nil)
|
||||
let state = InstanceViewModel.State(
|
||||
statuses: statuses, hasActiveDownload: downloadProgress != nil)
|
||||
let chatTasks = (chatTasksByInstance[entry.key] ?? [])
|
||||
.sorted(by: { $0.sortPriority < $1.sortPriority })
|
||||
.map { InstanceTaskViewModel(task: $0) }
|
||||
@@ -165,8 +168,8 @@ extension ClusterState {
|
||||
}
|
||||
}
|
||||
|
||||
private extension InstanceViewModel.State {
|
||||
init(statuses: [String], hasActiveDownload: Bool = false) {
|
||||
extension InstanceViewModel.State {
|
||||
fileprivate init(statuses: [String], hasActiveDownload: Bool = false) {
|
||||
if statuses.contains(where: { $0.contains("failed") }) {
|
||||
self = .failed
|
||||
} else if hasActiveDownload || statuses.contains(where: { $0.contains("downloading") }) {
|
||||
@@ -182,7 +185,7 @@ private extension InstanceViewModel.State {
|
||||
} else if statuses.isEmpty {
|
||||
self = .idle
|
||||
} else {
|
||||
self = .unknown
|
||||
self = .preparing
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -243,4 +246,3 @@ extension InstanceTaskViewModel {
|
||||
self.parameters = task.parameters
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -87,7 +87,9 @@ struct TopologyViewModel {
|
||||
extension ClusterState {
|
||||
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
|
||||
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
|
||||
let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
|
||||
let allNodes = nodeViewModels().filter {
|
||||
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
|
||||
}
|
||||
guard !allNodes.isEmpty else { return nil }
|
||||
|
||||
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
|
||||
@@ -106,18 +108,24 @@ extension ClusterState {
|
||||
}
|
||||
|
||||
// Rotate so the local node (from /node_id API) is first
|
||||
if let localId = localNodeId, let index = orderedNodes.firstIndex(where: { $0.id == localId }) {
|
||||
if let localId = localNodeId,
|
||||
let index = orderedNodes.firstIndex(where: { $0.id == localId })
|
||||
{
|
||||
orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
|
||||
}
|
||||
|
||||
let nodeIds = Set(orderedNodes.map(\.id))
|
||||
let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
|
||||
guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
|
||||
return TopologyEdgeViewModel(sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
|
||||
} ?? []
|
||||
let edgesArray: [TopologyEdgeViewModel] =
|
||||
topology?.connections?.compactMap { connection in
|
||||
guard nodeIds.contains(connection.localNodeId),
|
||||
nodeIds.contains(connection.sendBackNodeId)
|
||||
else { return nil }
|
||||
return TopologyEdgeViewModel(
|
||||
sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
|
||||
} ?? []
|
||||
let edges = Set(edgesArray)
|
||||
|
||||
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
|
||||
return TopologyViewModel(
|
||||
nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,8 +20,8 @@ struct InstanceRowView: View {
|
||||
if let progress = instance.downloadProgress {
|
||||
downloadStatusView(progress: progress)
|
||||
} else {
|
||||
statusChip(label: instance.state.label.uppercased(), color: statusColor)
|
||||
}
|
||||
statusChip(label: instance.state.label.uppercased(), color: statusColor)
|
||||
}
|
||||
}
|
||||
if let progress = instance.downloadProgress {
|
||||
GeometryReader { geometry in
|
||||
@@ -83,7 +83,7 @@ struct InstanceRowView: View {
|
||||
case .ready: return .teal
|
||||
case .waiting, .idle: return .gray
|
||||
case .failed: return .red
|
||||
case .unknown: return .secondary
|
||||
case .preparing: return .secondary
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,7 +97,8 @@ struct InstanceRowView: View {
|
||||
.font(.caption)
|
||||
.fontWeight(.semibold)
|
||||
if let subtitle = task.subtitle,
|
||||
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame {
|
||||
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame
|
||||
{
|
||||
Text(subtitle)
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
@@ -234,9 +235,12 @@ struct InstanceRowView: View {
|
||||
Button {
|
||||
isExpanded.wrappedValue.toggle()
|
||||
} label: {
|
||||
Label(isExpanded.wrappedValue ? "Hide" : "Show", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
|
||||
.labelStyle(.titleAndIcon)
|
||||
.contentTransition(.symbolEffect(.replace))
|
||||
Label(
|
||||
isExpanded.wrappedValue ? "Hide" : "Show",
|
||||
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
|
||||
)
|
||||
.labelStyle(.titleAndIcon)
|
||||
.contentTransition(.symbolEffect(.replace))
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.font(.caption2)
|
||||
@@ -311,7 +315,9 @@ struct InstanceRowView: View {
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
private func detailRow(icon: String? = nil, title: String, value: String, tint: Color = .secondary) -> some View {
|
||||
private func detailRow(
|
||||
icon: String? = nil, title: String, value: String, tint: Color = .secondary
|
||||
) -> some View {
|
||||
HStack(alignment: .firstTextBaseline, spacing: 6) {
|
||||
if let icon {
|
||||
Image(systemName: icon)
|
||||
@@ -329,4 +335,3 @@ struct InstanceRowView: View {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,4 +32,3 @@ struct NodeDetailView: View {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,4 +28,3 @@ struct NodeRowView: View {
|
||||
.padding(.vertical, 4)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -76,30 +76,33 @@ struct TopologyMiniView: View {
|
||||
|
||||
private func connectionLines(in size: CGSize) -> some View {
|
||||
let positions = positionedNodes(in: size)
|
||||
let positionById = Dictionary(uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
|
||||
let positionById = Dictionary(
|
||||
uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
|
||||
return Canvas { context, _ in
|
||||
guard !topology.edges.isEmpty else { return }
|
||||
let nodeRadius: CGFloat = 32
|
||||
let arrowLength: CGFloat = 10
|
||||
let arrowSpread: CGFloat = .pi / 7
|
||||
for edge in topology.edges {
|
||||
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId] else { continue }
|
||||
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId]
|
||||
else { continue }
|
||||
let dx = end.x - start.x
|
||||
let dy = end.y - start.y
|
||||
let distance = max(CGFloat(hypot(dx, dy)), 1)
|
||||
let ux = dx / distance
|
||||
let uy = dy / distance
|
||||
let adjustedStart = CGPoint(x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
|
||||
let adjustedStart = CGPoint(
|
||||
x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
|
||||
let adjustedEnd = CGPoint(x: end.x - ux * nodeRadius, y: end.y - uy * nodeRadius)
|
||||
|
||||
var linePath = Path()
|
||||
linePath.move(to: adjustedStart)
|
||||
linePath.addLine(to: adjustedEnd)
|
||||
context.stroke(
|
||||
context.stroke(
|
||||
linePath,
|
||||
with: .color(.secondary.opacity(0.3)),
|
||||
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
|
||||
)
|
||||
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
|
||||
)
|
||||
|
||||
let angle = atan2(uy, ux)
|
||||
let tip = adjustedEnd
|
||||
@@ -168,5 +171,3 @@ private struct NodeGlyphView: View {
|
||||
.frame(width: 95)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
//
|
||||
|
||||
import Testing
|
||||
|
||||
@testable import EXO
|
||||
|
||||
struct EXOTests {
|
||||
|
||||
154  app/EXO/uninstall-exo.sh  Executable file
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# EXO Uninstaller Script
|
||||
#
|
||||
# This script removes all EXO system components that persist after deleting the app.
|
||||
# Run with: sudo ./uninstall-exo.sh
|
||||
#
|
||||
# Components removed:
|
||||
# - LaunchDaemon: /Library/LaunchDaemons/io.exo.networksetup.plist
|
||||
# - Network script: /Library/Application Support/EXO/
|
||||
# - Log files: /var/log/io.exo.networksetup.*
|
||||
# - Network location: "exo"
|
||||
# - Launch at login registration
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
LABEL="io.exo.networksetup"
|
||||
SCRIPT_DEST="/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
|
||||
PLIST_DEST="/Library/LaunchDaemons/io.exo.networksetup.plist"
|
||||
LOG_OUT="/var/log/${LABEL}.log"
|
||||
LOG_ERR="/var/log/${LABEL}.err.log"
|
||||
APP_BUNDLE_ID="io.exo.EXO"
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
echo_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
echo_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
echo_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Check if running as root
|
||||
if [[ $EUID -ne 0 ]]; then
|
||||
echo_error "This script must be run as root (use sudo)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " EXO Uninstaller"
|
||||
echo "========================================"
|
||||
echo ""
|
||||
|
||||
# Unload the LaunchDaemon if running
|
||||
echo_info "Stopping network setup daemon..."
|
||||
if launchctl list | grep -q "$LABEL"; then
|
||||
launchctl bootout system/"$LABEL" 2>/dev/null || true
|
||||
echo_info "Daemon stopped"
|
||||
else
|
||||
echo_warn "Daemon was not running"
|
||||
fi
|
||||
|
||||
# Remove LaunchDaemon plist
|
||||
if [[ -f "$PLIST_DEST" ]]; then
|
||||
rm -f "$PLIST_DEST"
|
||||
echo_info "Removed LaunchDaemon plist"
|
||||
else
|
||||
echo_warn "LaunchDaemon plist not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Remove the script and parent directory
|
||||
if [[ -f "$SCRIPT_DEST" ]]; then
|
||||
rm -f "$SCRIPT_DEST"
|
||||
echo_info "Removed network setup script"
|
||||
else
|
||||
echo_warn "Network setup script not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Remove EXO directory if empty
|
||||
if [[ -d "/Library/Application Support/EXO" ]]; then
|
||||
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
|
||||
echo_info "Removed EXO support directory" || \
|
||||
echo_warn "EXO support directory not empty, leaving in place"
|
||||
fi
|
||||
|
||||
# Remove log files
|
||||
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
echo_info "Removed log files"
|
||||
else
|
||||
echo_warn "Log files not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Switch back to Automatic network location
|
||||
echo_info "Restoring network configuration..."
|
||||
if networksetup -listlocations | grep -q "^Automatic$"; then
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
echo_info "Switched to Automatic network location"
|
||||
else
|
||||
echo_warn "Automatic network location not found"
|
||||
fi
|
||||
|
||||
# Delete the exo network location if it exists
|
||||
if networksetup -listlocations | grep -q "^exo$"; then
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
echo_info "Deleted 'exo' network location"
|
||||
else
|
||||
echo_warn "'exo' network location not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Re-enable Thunderbolt Bridge if it exists
|
||||
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
echo_info "Re-enabled Thunderbolt Bridge"
|
||||
fi
|
||||
|
||||
# Note about launch at login registration
|
||||
# SMAppService-based login items cannot be removed from a shell script.
|
||||
# They can only be unregistered from within the app itself or manually via System Settings.
|
||||
echo_warn "Launch at login must be removed manually:"
|
||||
echo_warn " System Settings → General → Login Items → Remove EXO"
|
||||
|
||||
# Check if EXO.app exists in common locations
|
||||
APP_FOUND=false
|
||||
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
|
||||
if [[ -d "$app_path" ]]; then
|
||||
if [[ "$APP_FOUND" == false ]]; then
|
||||
echo ""
|
||||
APP_FOUND=true
|
||||
fi
|
||||
echo_warn "EXO.app found at: $app_path"
|
||||
echo_warn "You may want to move it to Trash manually."
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo_info "EXO uninstall complete!"
|
||||
echo "========================================"
|
||||
echo ""
|
||||
echo "The following have been removed:"
|
||||
echo " • Network setup LaunchDaemon"
|
||||
echo " • Network configuration script"
|
||||
echo " • Log files"
|
||||
echo " • 'exo' network location"
|
||||
echo ""
|
||||
echo "Your network has been restored to use the 'Automatic' location."
|
||||
echo "Thunderbolt Bridge has been re-enabled (if present)."
|
||||
echo ""
|
||||
echo "Manual step required:"
|
||||
echo " Remove EXO from Login Items in System Settings → General → Login Items"
|
||||
echo ""
|
||||
|
||||
526  bench/exo_bench.py  Normal file
@@ -0,0 +1,526 @@
|
||||
#!/usr/bin/env python3
|
||||
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import http.client
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from statistics import mean
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.types.memory import Memory
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
super().__init__(f"HTTP {status} {reason}: {body_preview}")
|
||||
self.status = status
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def request_json(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if params:
|
||||
path = path + "?" + urlencode(params)
|
||||
|
||||
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
|
||||
try:
|
||||
payload: bytes | None = None
|
||||
hdrs: dict[str, str] = {"Accept": "application/json"}
|
||||
|
||||
if body is not None:
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
hdrs["Content-Type"] = "application/json"
|
||||
if headers:
|
||||
hdrs.update(headers)
|
||||
|
||||
conn.request(method.upper(), path, body=payload, headers=hdrs)
|
||||
resp = conn.getresponse()
|
||||
raw = resp.read()
|
||||
text = raw.decode("utf-8", errors="replace") if raw else ""
|
||||
|
||||
if resp.status >= 400:
|
||||
raise ExoHttpError(resp.status, resp.reason, text[:300])
|
||||
|
||||
if not text:
|
||||
return None
|
||||
return json.loads(text)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.request_json("POST", "/bench/chat/completions", body=payload)
|
||||
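# Illustrative usage of ExoClient outside the benchmark loop (assumes an exo
# node is reachable on localhost:52415, the default port parsed in main()):
#   client = ExoClient("localhost", 52415)
#   state = client.request_json("GET", "/state")    # cluster snapshot
#   models = client.request_json("GET", "/models")  # model cards, see resolve_model_short_id()
#   print(list((state or {}).get("instances", {}).keys()))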
|
||||
|
||||
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
|
||||
if len(instance) != 1:
|
||||
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
|
||||
|
||||
tag = next(iter(instance))
|
||||
inner = instance[tag]
|
||||
if not isinstance(inner, dict):
|
||||
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
|
||||
return inner
|
||||
|
||||
|
||||
def instance_id_from_instance(instance: dict[str, Any]) -> str:
|
||||
inner = unwrap_instance(instance)
|
||||
return str(inner["instanceId"])
|
||||
|
||||
|
||||
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
|
||||
inner = unwrap_instance(instance)
|
||||
return len(inner["shardAssignments"]["nodeToRunner"])
|
||||
|
||||
|
||||
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
|
||||
inner = unwrap_instance(instance)
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
return list(runner_to_shard.keys())
|
||||
|
||||
|
||||
def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
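# The accessors above assume /state JSON shaped roughly as sketched below
# (field names taken from the helpers; values abbreviated, tag name varies):
#   {
#     "instances": {"<instance_id>": {"<Tag>": {
#         "instanceId": "...",
#         "shardAssignments": {"nodeToRunner": {...}, "runnerToShard": {...}}}}},
#     "runners": {"<runner_id>": {"RunnerReady": {...}}}
#   }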
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
|
||||
|
||||
|
||||
def wait_for_instance_gone(
|
||||
client: ExoClient, instance_id: str, timeout: float = 3.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
client.request_json("GET", f"/instance/{instance_id}")
|
||||
time.sleep(0.4)
|
||||
except ExoHttpError as e:
|
||||
if e.status == 404:
|
||||
return
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
|
||||
|
||||
|
||||
def format_peak_memory(b: float) -> str:
|
||||
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
||||
if b < 1024.0:
|
||||
return f"{b:.2f}{unit}"
|
||||
b /= 1024.0
|
||||
raise ValueError("You're using petabytes of memory. Something went wrong...")
|
||||
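# e.g. format_peak_memory(3.2 * 1024**3) == "3.20GB": the value is divided by
# 1024 until it drops below one unit, and anything beyond terabytes raises.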
|
||||
|
||||
def parse_int_list(values: list[str]) -> list[int]:
|
||||
items: list[int] = []
|
||||
for v in values:
|
||||
for part in v.split(","):
|
||||
part = part.strip()
|
||||
if part:
|
||||
items.append(int(part))
|
||||
|
||||
seen: set[int] = set()
|
||||
out: list[int] = []
|
||||
for x in items:
|
||||
if x not in seen:
|
||||
out.append(x)
|
||||
seen.add(x)
|
||||
return out
|
||||
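# e.g. parse_int_list(["128,256", "256", "1024"]) == [128, 256, 1024]:
# repeated flags and comma-separated values are flattened, deduplicated,
# and kept in first-seen order.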
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
data = models.get("data") or []
|
||||
|
||||
for m in data:
|
||||
if m.get("id") == model_arg:
|
||||
short_id = str(m["id"])
|
||||
full_id = str(m.get("hugging_face_id") or m["id"])
|
||||
return short_id, full_id
|
||||
|
||||
for m in data:
|
||||
if m.get("hugging_face_id") == model_arg:
|
||||
short_id = str(m["id"])
|
||||
full_id = str(m["hugging_face_id"])
|
||||
return short_id, full_id
|
||||
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
return ("ring" in s) or ("jaccl" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def sharding_filter(sharding: str, wanted: str) -> bool:
|
||||
s = (sharding or "").lower()
|
||||
if wanted == "both":
|
||||
return ("pipeline" in s) or ("tensor" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def run_one_completion(
|
||||
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
content, pp_tokens = prompt_sizer.build(pp_hint)
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"max_tokens": tg,
|
||||
}
|
||||
|
||||
t0 = time.perf_counter()
|
||||
out = client.post_bench_chat_completions(payload)
|
||||
elapsed = time.perf_counter() - t0
|
||||
|
||||
stats = out.get("generation_stats")
|
||||
|
||||
preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]
|
||||
|
||||
return {
|
||||
"elapsed_s": elapsed,
|
||||
"output_text_preview": preview,
|
||||
"stats": stats,
|
||||
}, pp_tokens
|
||||
|
||||
|
||||
class PromptSizer:
|
||||
def __init__(self, tokenizer: Any, atom: str = "a "):
|
||||
self.tokenizer = tokenizer
|
||||
self.atom = atom
|
||||
self.count_fn = PromptSizer._make_counter(tokenizer)
|
||||
self.base_tokens = self.count_fn("")
|
||||
|
||||
@staticmethod
|
||||
def _make_counter(tokenizer: Any) -> Callable[[str], int]:
|
||||
def count_fn(user_content: str) -> int:
|
||||
messages = [{"role": "user", "content": user_content}]
|
||||
ids = tokenizer.apply_chat_template(
|
||||
messages, tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
return int(len(ids))
|
||||
|
||||
return count_fn
|
||||
|
||||
def build(self, target_prompt_tokens: int) -> tuple[str, int]:
|
||||
target = int(target_prompt_tokens)
|
||||
if target < self.base_tokens:
|
||||
raise RuntimeError(
|
||||
f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
|
||||
)
|
||||
|
||||
content = ""
|
||||
tok = self.count_fn(content)
|
||||
|
||||
while tok < target:
|
||||
content += self.atom
|
||||
tok = self.count_fn(content)
|
||||
|
||||
if tok != target:
|
||||
raise RuntimeError(
|
||||
f"Overshot: got {tok} tokens (target {target}). "
|
||||
f"Pick a different atom (try ' a' or '\\n' or '0 ')."
|
||||
)
|
||||
|
||||
return content, tok
|
||||
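# Illustrative sizing run (the model id here is arbitrary and its tokenizer
# must be downloadable), mirroring what run_one_completion() does per --pp:
#   tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=True)
#   sizer = PromptSizer(tok)
#   content, n_tokens = sizer.build(64)
#   assert n_tokens == 64  # build() raises instead of returning an overshoot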
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-bench",
|
||||
description="Benchmark exo model throughput across placement previews.",
|
||||
)
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
|
||||
)
|
||||
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
|
||||
ap.add_argument(
|
||||
"--pp",
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="Prompt-size hints (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--tg",
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="Generation lengths (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--max-nodes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Pipeline jaccl is often pointless, skip by default",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--warmup",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
default="bench/results.json",
|
||||
help="Write raw per-run results JSON to this path.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dry-run", action="store_true", help="List selected placements and exit."
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
tg_list = parse_int_list(args.tg)
|
||||
if not pp_list or not tg_list:
|
||||
logger.error("pp and tg lists must be non-empty")
|
||||
return 2
|
||||
if args.repeat <= 0:
|
||||
logger.error("--repeat must be >= 1")
|
||||
return 2
|
||||
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": short_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
full_model_id,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if tokenizer is None:
|
||||
raise RuntimeError("[exo-bench] tokenizer load failed")
|
||||
|
||||
try:
|
||||
prompt_sizer = PromptSizer(tokenizer)
|
||||
logger.debug(f"[exo-bench] loaded tokenizer: {full_model_id} for prompt sizer")
|
||||
except Exception:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip single-node tensor or jaccl placements: single-node pipeline ring already covers that case
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if 0 < n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
logger.error("No valid placements matched your filters.")
|
||||
return 1
|
||||
|
||||
selected.sort(
|
||||
key=lambda p: (
|
||||
str(p.get("instance_meta", "")),
|
||||
str(p.get("sharding", "")),
|
||||
-nodes_used_in_instance(p["instance"]),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
|
||||
logger.info(f"placements: {len(selected)}")
|
||||
for p in selected:
|
||||
logger.info(
|
||||
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
|
||||
)
|
||||
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
instance = preview["instance"]
|
||||
instance_id = instance_id_from_instance(instance)
|
||||
|
||||
sharding = str(preview["sharding"])
|
||||
instance_meta = str(preview["instance_meta"])
|
||||
n_nodes = nodes_used_in_instance(instance)
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info(
|
||||
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
|
||||
)
|
||||
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
try:
|
||||
for i in range(args.warmup):
|
||||
run_one_completion(
|
||||
client, full_model_id, pp_list[0], tg_list[0], prompt_sizer
|
||||
)
|
||||
logger.debug(f" warmup {i + 1}/{args.warmup} done")
|
||||
|
||||
for pp in pp_list:
|
||||
if (
|
||||
pp * n_nodes > 2048
|
||||
and "ring" in instance_meta.lower()
|
||||
and "tensor" in sharding.lower()
|
||||
):
|
||||
model_card = MODEL_CARDS[short_id]
|
||||
if model_card.metadata.storage_size > Memory.from_gb(10):
|
||||
logger.info(
|
||||
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
|
||||
)
|
||||
continue
|
||||
for tg in tg_list:
|
||||
runs: list[dict[str, Any]] = []
|
||||
for r in range(args.repeat):
|
||||
time.sleep(3)
|
||||
try:
|
||||
row, actual_pp_tokens = run_one_completion(
|
||||
client, full_model_id, pp, tg, prompt_sizer
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
continue
|
||||
row.update(
|
||||
{
|
||||
"model_short_id": short_id,
|
||||
"model_id": full_model_id,
|
||||
"placement_sharding": sharding,
|
||||
"placement_instance_meta": instance_meta,
|
||||
"placement_nodes": n_nodes,
|
||||
"instance_id": instance_id,
|
||||
"pp_tokens": actual_pp_tokens,
|
||||
"tg": tg,
|
||||
"repeat_index": r,
|
||||
}
|
||||
)
|
||||
runs.append(row)
|
||||
all_rows.append(row)
|
||||
|
||||
if runs:
|
||||
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
|
||||
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
|
||||
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
|
||||
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
|
||||
peak = mean(
|
||||
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
|
||||
f"prompt_tokens={ptok} gen_tokens={gtok} "
|
||||
f"peak_memory={format_peak_memory(peak)}\n"
|
||||
)
|
||||
time.sleep(2)
|
||||
finally:
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
except ExoHttpError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
wait_for_instance_gone(client, instance_id)
|
||||
logger.debug(f"Deleted instance {instance_id}")
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
if args.json_out:
|
||||
with open(args.json_out, "w", encoding="utf-8") as f:
|
||||
json.dump(all_rows, f, indent=2, ensure_ascii=False)
|
||||
logger.debug(f"\nWrote results JSON: {args.json_out}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
1  dashboard/src/app.d.ts  vendored
@@ -11,4 +11,3 @@ declare global {
|
||||
}
|
||||
|
||||
export {};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { isLoading, sendMessage, generateImage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
|
||||
import { isLoading, sendMessage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
|
||||
import ChatAttachments from './ChatAttachments.svelte';
|
||||
import type { ChatUploadedFile } from '$lib/types/files';
|
||||
import { processUploadedFiles, getAcceptString } from '$lib/types/files';
|
||||
@@ -10,7 +10,6 @@
|
||||
showHelperText?: boolean;
|
||||
autofocus?: boolean;
|
||||
showModelSelector?: boolean;
|
||||
modelTasks?: Record<string, string[]>;
|
||||
}
|
||||
|
||||
let {
|
||||
@@ -18,8 +17,7 @@
|
||||
placeholder = 'Ask anything',
|
||||
showHelperText = false,
|
||||
autofocus = true,
|
||||
showModelSelector = false,
|
||||
modelTasks = {}
|
||||
showModelSelector = false
|
||||
}: Props = $props();
|
||||
|
||||
let message = $state('');
|
||||
@@ -50,29 +48,13 @@
|
||||
// Accept all supported file types
|
||||
const acceptString = getAcceptString(['image', 'text', 'pdf']);
|
||||
|
||||
// Check if a model supports image generation
|
||||
function modelSupportsImageGeneration(modelId: string): boolean {
|
||||
const tasks = modelTasks[modelId] || [];
|
||||
return tasks.includes('TextToImage') || tasks.includes('ImageToImage');
|
||||
}
|
||||
|
||||
// Check if the currently selected model supports image generation
|
||||
const isImageModel = $derived(() => {
|
||||
if (!currentModel) return false;
|
||||
return modelSupportsImageGeneration(currentModel);
|
||||
});
|
||||
|
||||
// Extract available models from running instances
|
||||
const availableModels = $derived(() => {
|
||||
const models: Array<{id: string, label: string, isImageModel: boolean}> = [];
|
||||
const models: Array<{id: string, label: string}> = [];
|
||||
for (const [, instance] of Object.entries(instanceData)) {
|
||||
const modelId = getInstanceModelId(instance);
|
||||
if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
|
||||
models.push({
|
||||
id: modelId,
|
||||
label: modelId.split('/').pop() || modelId,
|
||||
isImageModel: modelSupportsImageGeneration(modelId)
|
||||
});
|
||||
models.push({ id: modelId, label: modelId.split('/').pop() || modelId });
|
||||
}
|
||||
}
|
||||
return models;
|
||||
@@ -178,12 +160,7 @@
|
||||
uploadedFiles = [];
|
||||
resetTextareaHeight();
|
||||
|
||||
// Use image generation for image models
|
||||
if (isImageModel() && content) {
|
||||
generateImage(content);
|
||||
} else {
|
||||
sendMessage(content, files);
|
||||
}
|
||||
sendMessage(content, files);
|
||||
|
||||
// Refocus the textarea after sending
|
||||
setTimeout(() => textareaRef?.focus(), 10);
|
||||
@@ -320,14 +297,7 @@
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
{#if model.isImageModel}
|
||||
<svg class="w-3.5 h-3.5 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
|
||||
<circle cx="8.5" cy="8.5" r="1.5"/>
|
||||
<polyline points="21 15 16 10 5 21"/>
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate flex-1">{model.label}</span>
|
||||
<span class="truncate">{model.label}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
@@ -387,7 +357,7 @@
|
||||
onkeydown={handleKeydown}
|
||||
oninput={handleInput}
|
||||
onpaste={handlePaste}
|
||||
placeholder={isImageModel() ? 'Describe the image you want to generate...' : placeholder}
|
||||
{placeholder}
|
||||
disabled={loading}
|
||||
rows={1}
|
||||
class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"
|
||||
@@ -401,23 +371,14 @@
|
||||
{!canSend || loading
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={isImageModel() ? "Generate image" : "Send message"}
|
||||
aria-label="Send message"
|
||||
>
|
||||
{#if loading}
|
||||
<span class="inline-flex items-center gap-1 sm:gap-2">
|
||||
<span class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"></span>
|
||||
<span class="hidden sm:inline">{isImageModel() ? 'GENERATING' : 'PROCESSING'}</span>
|
||||
<span class="hidden sm:inline">PROCESSING</span>
|
||||
<span class="sm:hidden">...</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
|
||||
<circle cx="8.5" cy="8.5" r="1.5"/>
|
||||
<polyline points="21 15 16 10 5 21"/>
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
|
||||
@@ -365,58 +365,10 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Generated Images -->
|
||||
{#if message.attachments?.some(a => a.type === 'generated-image')}
|
||||
<div class="mb-3">
|
||||
{#each message.attachments.filter(a => a.type === 'generated-image') as attachment}
|
||||
<div class="relative group/img inline-block">
|
||||
<img
|
||||
src={attachment.preview}
|
||||
alt=""
|
||||
class="max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20"
|
||||
/>
|
||||
<!-- Download button overlay -->
|
||||
<button
|
||||
type="button"
|
||||
class="absolute top-2 right-2 p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow opacity-0 group-hover/img:opacity-100 transition-opacity hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
|
||||
onclick={() => {
|
||||
if (attachment.preview) {
|
||||
const link = document.createElement('a');
|
||||
link.href = attachment.preview;
|
||||
link.download = `generated-image-${Date.now()}.png`;
|
||||
link.click();
|
||||
}
|
||||
}}
|
||||
title="Download image"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="text-xs text-foreground">
|
||||
{#if message.content === 'Generating image...'}
|
||||
<div class="flex items-center gap-3 text-exo-yellow">
|
||||
<div class="relative">
|
||||
<div class="w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"></div>
|
||||
<svg class="absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
|
||||
<circle cx="8.5" cy="8.5" r="1.5"/>
|
||||
<polyline points="21 15 16 10 5 21"/>
|
||||
</svg>
|
||||
</div>
|
||||
<span class="font-mono tracking-wider uppercase text-sm">Generating image...</span>
|
||||
</div>
|
||||
{:else if message.content || (loading && !message.attachments?.some(a => a.type === 'generated-image'))}
|
||||
<MarkdownContent content={message.content || (loading ? response : '')} />
|
||||
{#if loading && !message.content}
|
||||
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
|
||||
{/if}
|
||||
<MarkdownContent content={message.content || (loading ? response : '')} />
|
||||
{#if loading && !message.content}
|
||||
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
export { default as TopologyGraph } from './TopologyGraph.svelte';
|
||||
export { default as ChatForm } from './ChatForm.svelte';
|
||||
export { default as ChatMessages } from './ChatMessages.svelte';
|
||||
export { default as ChatAttachments } from './ChatAttachments.svelte';
|
||||
export { default as ChatSidebar } from './ChatSidebar.svelte';
|
||||
export { default as ModelCard } from './ModelCard.svelte';
|
||||
export { default as MarkdownContent } from './MarkdownContent.svelte';
|
||||
|
||||
export { default as TopologyGraph } from "./TopologyGraph.svelte";
|
||||
export { default as ChatForm } from "./ChatForm.svelte";
|
||||
export { default as ChatMessages } from "./ChatMessages.svelte";
|
||||
export { default as ChatAttachments } from "./ChatAttachments.svelte";
|
||||
export { default as ChatSidebar } from "./ChatSidebar.svelte";
|
||||
export { default as ModelCard } from "./ModelCard.svelte";
|
||||
export { default as MarkdownContent } from "./MarkdownContent.svelte";
|
||||
|
||||
File diff suppressed because it is too large
@@ -13,55 +13,124 @@ export interface ChatUploadedFile {
|
||||
}
|
||||
|
||||
export interface ChatAttachment {
|
||||
type: 'image' | 'text' | 'pdf' | 'audio';
|
||||
type: "image" | "text" | "pdf" | "audio";
|
||||
name: string;
|
||||
content?: string;
|
||||
base64Url?: string;
|
||||
mimeType?: string;
|
||||
}
|
||||
|
||||
export type FileCategory = 'image' | 'text' | 'pdf' | 'audio' | 'unknown';
|
||||
export type FileCategory = "image" | "text" | "pdf" | "audio" | "unknown";
|
||||
|
||||
export const IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.gif', '.webp', '.svg'];
|
||||
export const IMAGE_MIME_TYPES = ['image/jpeg', 'image/png', 'image/gif', 'image/webp', 'image/svg+xml'];
|
||||
export const IMAGE_EXTENSIONS = [
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".gif",
|
||||
".webp",
|
||||
".svg",
|
||||
];
|
||||
export const IMAGE_MIME_TYPES = [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
"image/svg+xml",
|
||||
];
|
||||
|
||||
export const TEXT_EXTENSIONS = [
|
||||
'.txt', '.md', '.json', '.xml', '.yaml', '.yml', '.csv', '.log',
|
||||
'.js', '.ts', '.jsx', '.tsx', '.py', '.java', '.cpp', '.c', '.h',
|
||||
'.css', '.html', '.htm', '.sql', '.sh', '.bat', '.rs', '.go',
|
||||
'.rb', '.php', '.swift', '.kt', '.scala', '.r', '.dart', '.vue', '.svelte'
|
||||
".txt",
|
||||
".md",
|
||||
".json",
|
||||
".xml",
|
||||
".yaml",
|
||||
".yml",
|
||||
".csv",
|
||||
".log",
|
||||
".js",
|
||||
".ts",
|
||||
".jsx",
|
||||
".tsx",
|
||||
".py",
|
||||
".java",
|
||||
".cpp",
|
||||
".c",
|
||||
".h",
|
||||
".css",
|
||||
".html",
|
||||
".htm",
|
||||
".sql",
|
||||
".sh",
|
||||
".bat",
|
||||
".rs",
|
||||
".go",
|
||||
".rb",
|
||||
".php",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".r",
|
||||
".dart",
|
||||
".vue",
|
||||
".svelte",
|
||||
];
|
||||
export const TEXT_MIME_TYPES = [
|
||||
'text/plain', 'text/markdown', 'text/csv', 'text/html', 'text/css',
|
||||
'application/json', 'application/xml', 'text/xml', 'application/javascript',
|
||||
'text/javascript', 'application/typescript'
|
||||
"text/plain",
|
||||
"text/markdown",
|
||||
"text/csv",
|
||||
"text/html",
|
||||
"text/css",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"text/xml",
|
||||
"application/javascript",
|
||||
"text/javascript",
|
||||
"application/typescript",
|
||||
];
|
||||
|
||||
export const PDF_EXTENSIONS = ['.pdf'];
|
||||
export const PDF_MIME_TYPES = ['application/pdf'];
|
||||
export const PDF_EXTENSIONS = [".pdf"];
|
||||
export const PDF_MIME_TYPES = ["application/pdf"];
|
||||
|
||||
export const AUDIO_EXTENSIONS = ['.mp3', '.wav', '.ogg', '.m4a'];
|
||||
export const AUDIO_MIME_TYPES = ['audio/mpeg', 'audio/wav', 'audio/ogg', 'audio/mp4'];
|
||||
export const AUDIO_EXTENSIONS = [".mp3", ".wav", ".ogg", ".m4a"];
|
||||
export const AUDIO_MIME_TYPES = [
|
||||
"audio/mpeg",
|
||||
"audio/wav",
|
||||
"audio/ogg",
|
||||
"audio/mp4",
|
||||
];
|
||||
|
||||
/**
|
||||
* Get file category based on MIME type and extension
|
||||
*/
|
||||
export function getFileCategory(mimeType: string, fileName: string): FileCategory {
|
||||
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf('.'));
|
||||
|
||||
if (IMAGE_MIME_TYPES.includes(mimeType) || IMAGE_EXTENSIONS.includes(extension)) {
|
||||
return 'image';
|
||||
export function getFileCategory(
|
||||
mimeType: string,
|
||||
fileName: string,
|
||||
): FileCategory {
|
||||
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf("."));
|
||||
|
||||
if (
|
||||
IMAGE_MIME_TYPES.includes(mimeType) ||
|
||||
IMAGE_EXTENSIONS.includes(extension)
|
||||
) {
|
||||
return "image";
|
||||
}
|
||||
if (PDF_MIME_TYPES.includes(mimeType) || PDF_EXTENSIONS.includes(extension)) {
|
||||
return 'pdf';
|
||||
return "pdf";
|
||||
}
|
||||
if (AUDIO_MIME_TYPES.includes(mimeType) || AUDIO_EXTENSIONS.includes(extension)) {
|
||||
return 'audio';
|
||||
if (
|
||||
AUDIO_MIME_TYPES.includes(mimeType) ||
|
||||
AUDIO_EXTENSIONS.includes(extension)
|
||||
) {
|
||||
return "audio";
|
||||
}
|
||||
if (TEXT_MIME_TYPES.includes(mimeType) || TEXT_EXTENSIONS.includes(extension) || mimeType.startsWith('text/')) {
|
||||
return 'text';
|
||||
if (
|
||||
TEXT_MIME_TYPES.includes(mimeType) ||
|
||||
TEXT_EXTENSIONS.includes(extension) ||
|
||||
mimeType.startsWith("text/")
|
||||
) {
|
||||
return "text";
|
||||
}
|
||||
return 'unknown';
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -69,36 +138,36 @@ export function getFileCategory(mimeType: string, fileName: string): FileCategor
|
||||
*/
|
||||
export function getAcceptString(categories: FileCategory[]): string {
|
||||
const accepts: string[] = [];
|
||||
|
||||
|
||||
for (const category of categories) {
|
||||
switch (category) {
|
||||
case 'image':
|
||||
case "image":
|
||||
accepts.push(...IMAGE_EXTENSIONS, ...IMAGE_MIME_TYPES);
|
||||
break;
|
||||
case 'text':
|
||||
case "text":
|
||||
accepts.push(...TEXT_EXTENSIONS, ...TEXT_MIME_TYPES);
|
||||
break;
|
||||
case 'pdf':
|
||||
case "pdf":
|
||||
accepts.push(...PDF_EXTENSIONS, ...PDF_MIME_TYPES);
|
||||
break;
|
||||
case 'audio':
|
||||
case "audio":
|
||||
accepts.push(...AUDIO_EXTENSIONS, ...AUDIO_MIME_TYPES);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return accepts.join(',');
|
||||
|
||||
return accepts.join(",");
|
||||
}
|
||||
|
||||
/**
|
||||
* Format file size for display
|
||||
*/
|
||||
export function formatFileSize(bytes: number): string {
|
||||
if (bytes === 0) return '0 B';
|
||||
if (bytes === 0) return "0 B";
|
||||
const k = 1024;
|
||||
const sizes = ['B', 'KB', 'MB', 'GB'];
|
||||
const sizes = ["B", "KB", "MB", "GB"];
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k));
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + " " + sizes[i];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -128,42 +197,44 @@ export function readFileAsText(file: File): Promise<string> {
|
||||
/**
|
||||
* Process uploaded files into ChatUploadedFile format
|
||||
*/
|
||||
export async function processUploadedFiles(files: File[]): Promise<ChatUploadedFile[]> {
|
||||
export async function processUploadedFiles(
|
||||
files: File[],
|
||||
): Promise<ChatUploadedFile[]> {
|
||||
const results: ChatUploadedFile[] = [];
|
||||
|
||||
|
||||
for (const file of files) {
|
||||
const id = Date.now().toString() + Math.random().toString(36).substring(2, 9);
|
||||
const id =
|
||||
Date.now().toString() + Math.random().toString(36).substring(2, 9);
|
||||
const category = getFileCategory(file.type, file.name);
|
||||
|
||||
|
||||
const base: ChatUploadedFile = {
|
||||
id,
|
||||
name: file.name,
|
||||
size: file.size,
|
||||
type: file.type,
|
||||
file
|
||||
file,
|
||||
};
|
||||
|
||||
|
||||
try {
|
||||
if (category === 'image') {
|
||||
if (category === "image") {
|
||||
const preview = await readFileAsDataURL(file);
|
||||
results.push({ ...base, preview });
|
||||
} else if (category === 'text' || category === 'unknown') {
|
||||
} else if (category === "text" || category === "unknown") {
|
||||
const textContent = await readFileAsText(file);
|
||||
results.push({ ...base, textContent });
|
||||
} else if (category === 'pdf') {
|
||||
} else if (category === "pdf") {
|
||||
results.push(base);
|
||||
} else if (category === 'audio') {
|
||||
} else if (category === "audio") {
|
||||
const preview = await readFileAsDataURL(file);
|
||||
results.push({ ...base, preview });
|
||||
} else {
|
||||
results.push(base);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error processing file:', file.name, error);
|
||||
console.error("Error processing file:", file.name, error);
|
||||
results.push(base);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
@@ -47,30 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
|
||||
let mounted = $state(false);
|
||||
|
||||
// Instance launch state
|
||||
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number, tasks?: string[], hugging_face_id?: string}>>([]);
|
||||
|
||||
// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
|
||||
const modelTasks = $derived(() => {
|
||||
const tasks: Record<string, string[]> = {};
|
||||
for (const model of models) {
|
||||
if (model.tasks && model.tasks.length > 0) {
|
||||
// Map by short ID
|
||||
tasks[model.id] = model.tasks;
|
||||
// Also map by hugging_face_id from the API response
|
||||
if (model.hugging_face_id) {
|
||||
tasks[model.hugging_face_id] = model.tasks;
|
||||
}
|
||||
}
|
||||
}
|
||||
return tasks;
|
||||
});
|
||||
|
||||
// Helper to check if a model supports image generation
|
||||
function modelSupportsImageGeneration(modelId: string): boolean {
|
||||
const model = models.find(m => m.id === modelId || m.hugging_face_id === modelId);
|
||||
if (!model?.tasks) return false;
|
||||
return model.tasks.includes('TextToImage') || model.tasks.includes('ImageToImage');
|
||||
}
|
||||
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
|
||||
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
|
||||
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
|
||||
|
||||
@@ -616,7 +593,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
// Unwrap the instance
|
||||
const [instanceTag, instance] = getTagged(instanceWrapped);
|
||||
if (!instance || typeof instance !== 'object') {
|
||||
return { isDownloading: false, progress: null, statusText: 'UNKNOWN', perNode: [] };
|
||||
return { isDownloading: false, progress: null, statusText: 'PREPARING', perNode: [] };
|
||||
}
|
||||
|
||||
const inst = instance as { shardAssignments?: { nodeToRunner?: Record<string, string>; runnerToShard?: Record<string, unknown>; modelId?: string } };
|
||||
@@ -729,7 +706,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
function deriveInstanceStatus(instanceWrapped: unknown): { statusText: string; statusClass: string } {
|
||||
const [, instance] = getTagged(instanceWrapped);
|
||||
if (!instance || typeof instance !== 'object') {
|
||||
return { statusText: 'UNKNOWN', statusClass: 'inactive' };
|
||||
return { statusText: 'PREPARING', statusClass: 'inactive' };
|
||||
}
|
||||
|
||||
const inst = instance as { shardAssignments?: { runnerToShard?: Record<string, unknown> } };
|
||||
@@ -758,7 +735,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
const has = (s: string) => statuses.includes(s);
|
||||
|
||||
if (statuses.length === 0) return { statusText: 'UNKNOWN', statusClass: 'inactive' };
|
||||
if (statuses.length === 0) return { statusText: 'PREPARING', statusClass: 'inactive' };
|
||||
if (has('Failed')) return { statusText: 'FAILED', statusClass: 'failed' };
|
||||
if (has('Shutdown')) return { statusText: 'SHUTDOWN', statusClass: 'inactive' };
|
||||
if (has('Loading')) return { statusText: 'LOADING', statusClass: 'starting' };
|
||||
@@ -1273,7 +1250,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
placeholder="Ask anything"
|
||||
showHelperText={false}
|
||||
showModelSelector={true}
|
||||
modelTasks={modelTasks()}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@@ -1291,9 +1267,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
<div
|
||||
bind:this={instancesContainerRef}
|
||||
class="max-h-72 space-y-3 overflow-y-auto"
|
||||
class="max-h-72 xl:max-h-96 space-y-3 overflow-y-auto overflow-x-hidden py-px"
|
||||
>
|
||||
{#each Object.entries(instanceData) as [id, instance]}
|
||||
{@const downloadInfo = getInstanceDownloadStatus(id, instance)}
|
||||
@@ -1495,18 +1471,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
{@const foundModel = models.find(m => m.id === selectedModelId)}
|
||||
{#if foundModel}
|
||||
{@const sizeGB = getModelSizeGB(foundModel)}
|
||||
{@const isImageModel = modelSupportsImageGeneration(foundModel.id)}
|
||||
<span class="flex items-center justify-between gap-2 w-full pr-4">
|
||||
<span class="flex items-center gap-2 text-exo-light-gray truncate">
|
||||
{#if isImageModel}
|
||||
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
|
||||
<circle cx="8.5" cy="8.5" r="1.5"/>
|
||||
<polyline points="21 15 16 10 5 21"/>
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate">{foundModel.name || foundModel.id}</span>
|
||||
</span>
|
||||
<span class="flex items-center justify-between gap-2 w-full pr-4">
|
||||
<span class="text-exo-light-gray truncate">{foundModel.name || foundModel.id}</span>
|
||||
<span class="text-white/50 text-xs flex-shrink-0">{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB</span>
|
||||
</span>
|
||||
{:else}
|
||||
@@ -1551,7 +1517,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
) as model}
|
||||
{@const sizeGB = getModelSizeGB(model)}
|
||||
{@const modelCanFit = hasEnoughMemory(model)}
|
||||
{@const isImageModel = modelSupportsImageGeneration(model.id)}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => {
|
||||
@@ -1571,16 +1536,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
: 'text-white/30 cursor-default'
|
||||
}"
|
||||
>
|
||||
<span class="flex items-center gap-2 truncate flex-1">
|
||||
{#if isImageModel}
|
||||
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
|
||||
<circle cx="8.5" cy="8.5" r="1.5"/>
|
||||
<polyline points="21 15 16 10 5 21"/>
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate">{model.name || model.id}</span>
|
||||
</span>
|
||||
<span class="truncate">{model.name || model.id}</span>
|
||||
<span class="flex-shrink-0 text-xs {modelCanFit ? 'text-white/50' : 'text-red-400/60'}">
|
||||
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
|
||||
</span>
|
||||
@@ -1777,7 +1733,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
|
||||
<div class="max-w-7xl mx-auto">
|
||||
<ChatForm placeholder="Ask anything" showModelSelector={true} modelTasks={modelTasks()} />
|
||||
<ChatForm placeholder="Ask anything" showModelSelector={true} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -1817,7 +1773,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<h3 class="text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
|
||||
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
|
||||
</div>
|
||||
<div class="space-y-3 max-h-72 overflow-y-auto pr-1">
|
||||
<div class="space-y-3 max-h-72 xl:max-h-96 overflow-y-auto overflow-x-hidden py-px pr-1">
|
||||
{#each Object.entries(instanceData) as [id, instance]}
|
||||
{@const downloadInfo = getInstanceDownloadStatus(id, instance)}
|
||||
{@const statusText = downloadInfo.statusText}
|
||||
|
||||
@@ -199,7 +199,13 @@
|
||||
const rawProgress = (downloadPayload as Record<string, unknown>).download_progress
|
||||
?? (downloadPayload as Record<string, unknown>).downloadProgress
|
||||
?? {};
|
||||
const totalBytes = getBytes((rawProgress as Record<string, unknown>).total_bytes ?? (rawProgress as Record<string, unknown>).totalBytes);
|
||||
// For DownloadCompleted, total_bytes is at top level; for DownloadOngoing, it's inside download_progress
|
||||
const totalBytes = getBytes(
|
||||
(downloadPayload as Record<string, unknown>).total_bytes
|
||||
?? (downloadPayload as Record<string, unknown>).totalBytes
|
||||
?? (rawProgress as Record<string, unknown>).total_bytes
|
||||
?? (rawProgress as Record<string, unknown>).totalBytes
|
||||
);
|
||||
const downloadedBytes = getBytes((rawProgress as Record<string, unknown>).downloaded_bytes ?? (rawProgress as Record<string, unknown>).downloadedBytes);
|
||||
const speed = (rawProgress as Record<string, unknown>).speed as number ?? 0;
|
||||
const etaMs = (rawProgress as Record<string, unknown>).eta_ms as number ?? (rawProgress as Record<string, unknown>).etaMs as number ?? 0;
|
||||
@@ -332,8 +338,13 @@
|
||||
<div class="text-lg font-mono text-white truncate">{node.nodeName}</div>
|
||||
<div class="text-xs text-exo-light-gray font-mono truncate">{node.nodeId}</div>
|
||||
</div>
|
||||
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0">
|
||||
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> /{node.models.length} models</span>
|
||||
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right">
|
||||
<div>
|
||||
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> / {node.models.length} models</span>
|
||||
</div>
|
||||
<div class="text-exo-light-gray normal-case tracking-normal">
|
||||
{formatBytes(node.models.filter(m => m.status === 'completed').reduce((sum, m) => sum + m.totalBytes, 0))} on disk
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -385,7 +396,7 @@
|
||||
</div>
|
||||
|
||||
<div class="flex items-center justify-between text-xs font-mono text-exo-light-gray">
|
||||
<span>{model.status === 'completed' ? 'Completed' : `${formatSpeed(model.speed)} • ETA ${formatEta(model.etaMs)}`}</span>
|
||||
<span>{model.status === 'completed' ? `Completed (${formatBytes(model.totalBytes)})` : `${formatSpeed(model.speed)} • ETA ${formatEta(model.etaMs)}`}</span>
|
||||
{#if model.status !== 'completed'}
|
||||
<span>{model.files.length} file{model.files.length === 1 ? '' : 's'}</span>
|
||||
{/if}
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import tailwindcss from '@tailwindcss/vite';
|
||||
import { sveltekit } from '@sveltejs/kit/vite';
|
||||
import { defineConfig } from 'vite';
|
||||
import tailwindcss from "@tailwindcss/vite";
|
||||
import { sveltekit } from "@sveltejs/kit/vite";
|
||||
import { defineConfig } from "vite";
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [tailwindcss(), sveltekit()],
|
||||
server: {
|
||||
proxy: {
|
||||
'/v1': 'http://localhost:52415',
|
||||
'/state': 'http://localhost:52415',
|
||||
'/models': 'http://localhost:52415',
|
||||
'/instance': 'http://localhost:52415'
|
||||
}
|
||||
}
|
||||
"/v1": "http://localhost:52415",
|
||||
"/state": "http://localhost:52415",
|
||||
"/models": "http://localhost:52415",
|
||||
"/instance": "http://localhost:52415",
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
212
docs/api.md
Normal file
@@ -0,0 +1,212 @@
# EXO API – Technical Reference

This document describes the REST API exposed by the **EXO** service, as implemented in:

`src/exo/master/api.py`

The API is used to manage model instances in the cluster, inspect cluster state, and perform inference using an OpenAI-compatible interface.

Base URL example:

```
http://localhost:52415
```

## 1. General / Meta Endpoints

### Get Master Node ID

**GET** `/node_id`

Returns the identifier of the current master node.

**Response (example):**

```json
{
  "node_id": "node-1234"
}
```

### Get Cluster State

**GET** `/state`

Returns the current state of the cluster, including nodes and active instances.

**Response:**
JSON object describing topology, nodes, and instances.

### Get Events

**GET** `/events`

Returns the list of internal events recorded by the master (mainly for debugging and observability).

**Response:**
Array of event objects.
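
Below is a minimal sketch of polling these read-only endpoints from Python. It assumes the default base URL shown above and uses the third-party `requests` library; the exact shape of the returned JSON depends on your cluster.

```python
import requests

BASE_URL = "http://localhost:52415"  # assumed default EXO master address

# Identify the master node, then dump cluster state and recent events.
node_id = requests.get(f"{BASE_URL}/node_id", timeout=5).json()
state = requests.get(f"{BASE_URL}/state", timeout=5).json()
events = requests.get(f"{BASE_URL}/events", timeout=5).json()

print("master:", node_id)
print("instances:", list(state.get("instances", {})))  # key name may vary by version
print("event count:", len(events))
```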

## 2. Model Instance Management

### Create Instance

**POST** `/instance`

Creates a new model instance in the cluster.

**Request body (example):**

```json
{
  "instance": {
    "model_id": "llama-3.2-1b",
    "placement": { }
  }
}
```

**Response:**
JSON description of the created instance.
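
A minimal sketch of the create/delete lifecycle, assuming the default base URL and the request shape shown above. The `instance_id` field name in the response is an assumption; inspect the actual create response (or `GET /state`) to see the identifier your cluster returns.

```python
import requests

BASE_URL = "http://localhost:52415"  # assumed default EXO master address

# Create an instance for a model, then tear it down again.
created = requests.post(
    f"{BASE_URL}/instance",
    json={"instance": {"model_id": "llama-3.2-1b", "placement": {}}},
    timeout=30,
).json()
print("created:", created)

instance_id = created.get("instance_id")  # assumed field name
if instance_id is not None:
    requests.delete(
        f"{BASE_URL}/instance/{instance_id}", timeout=30
    ).raise_for_status()
```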

### Delete Instance

**DELETE** `/instance/{instance_id}`

Deletes an existing instance by ID.

**Path parameters:**

* `instance_id`: string, ID of the instance to delete

**Response:**
Status / confirmation JSON.

### Get Instance

**GET** `/instance/{instance_id}`

Returns details of a specific instance.

**Path parameters:**

* `instance_id`: string

**Response:**
JSON description of the instance.

### Preview Placements

**GET** `/instance/previews?model_id=...`

Returns possible placement previews for a given model.

**Query parameters:**

* `model_id`: string, required

**Response:**
Array of placement preview objects.

### Compute Placement

**GET** `/instance/placement`

Computes a placement for a potential instance without creating it.

**Query parameters (typical):**

* `model_id`: string
* `sharding`: string or config
* `instance_meta`: JSON-encoded metadata
* `min_nodes`: integer

**Response:**
JSON object describing the proposed placement / instance configuration.

### Place Instance (Dry Operation)

**POST** `/place_instance`

Performs a placement operation for an instance (planning step), without necessarily creating it.

**Request body:**
JSON describing the instance to be placed.

**Response:**
Placement result.
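
To plan before committing resources, the preview endpoint can be queried first. A minimal sketch, assuming the default base URL; query parameters beyond `model_id` are optional and their accepted values depend on the build.

```python
import requests

BASE_URL = "http://localhost:52415"  # assumed default EXO master address

# Ask the master which placements are possible for a model
# before actually creating an instance.
previews = requests.get(
    f"{BASE_URL}/instance/previews",
    params={"model_id": "llama-3.2-1b"},
    timeout=10,
).json()

for preview in previews:
    print(preview)
```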

## 3. Models

### List Models

**GET** `/models`
**GET** `/v1/models` (alias)

Returns the list of available models and their metadata.

**Response:**
Array of model descriptors.

## 4. Inference / Chat Completions

### OpenAI-Compatible Chat Completions

**POST** `/v1/chat/completions`

Executes a chat completion request using an OpenAI-compatible schema. Supports streaming and non-streaming modes.

**Request body (example):**

```json
{
  "model": "llama-3.2-1b",
  "messages": [
    { "role": "system", "content": "You are a helpful assistant." },
    { "role": "user", "content": "Hello" }
  ],
  "stream": false
}
```

**Response:**
OpenAI-compatible chat completion response.
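
Because the endpoint follows the OpenAI schema, the official `openai` Python client can be pointed at EXO by overriding the base URL. A minimal sketch; the API key value is a placeholder assumption, and the model name must match one of your running instances.

```python
from openai import OpenAI

# Point the standard OpenAI client at the EXO master.
client = OpenAI(base_url="http://localhost:52415/v1", api_key="not-needed")

# Non-streaming request
completion = client.chat.completions.create(
    model="llama-3.2-1b",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ],
)
print(completion.choices[0].message.content)

# Streaming request: tokens arrive as they are generated
stream = client.chat.completions.create(
    model="llama-3.2-1b",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
```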

### Benchmarked Chat Completions

**POST** `/bench/chat/completions`

Same as `/v1/chat/completions`, but also returns performance and generation statistics.

**Request body:**
Same schema as `/v1/chat/completions`.

**Response:**
Chat completion plus benchmarking metrics.
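
The bench variant is not part of the OpenAI schema, so it is easiest to call directly. A sketch assuming the default base URL; the response carries the usual `choices` plus a `generation_stats` object whose exact fields depend on the EXO version.

```python
import requests

BASE_URL = "http://localhost:52415"  # assumed default EXO master address

resp = requests.post(
    f"{BASE_URL}/bench/chat/completions",
    json={
        "model": "llama-3.2-1b",
        "messages": [{"role": "user", "content": "Hello"}],
    },
    timeout=120,
).json()

print(resp["choices"][0]["message"]["content"])
print(resp.get("generation_stats"))
```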

## 5. Complete Endpoint Summary

```
GET     /node_id
GET     /state
GET     /events

POST    /instance
GET     /instance/{instance_id}
DELETE  /instance/{instance_id}

GET     /instance/previews
GET     /instance/placement
POST    /place_instance

GET     /models
GET     /v1/models

POST    /v1/chat/completions
POST    /bench/chat/completions
```

## 6. Notes

* The `/v1/chat/completions` endpoint is compatible with the OpenAI API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
* The instance placement endpoints allow you to plan and preview cluster allocations before actually creating instances.
* The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.
74
flake.lock
generated
@@ -8,11 +8,11 @@
|
||||
"rust-analyzer-src": "rust-analyzer-src"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1761893049,
|
||||
"narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
|
||||
"lastModified": 1768287139,
|
||||
"narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
|
||||
"rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -21,25 +21,43 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils": {
|
||||
"flake-parts": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
"nixpkgs-lib": [
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1731533236,
|
||||
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||
"lastModified": 1768135262,
|
||||
"narHash": "sha256-PVvu7OqHBGWN16zSi6tEmPwwHQ4rLPU9Plvs8/1TUBY=",
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"rev": "80daad04eddbbf5a4d883996a73f3f542fa437ac",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1768127708,
|
||||
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs-swift": {
|
||||
"locked": {
|
||||
"lastModified": 1761672384,
|
||||
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
|
||||
@@ -50,27 +68,28 @@
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"fenix": "fenix",
|
||||
"flake-utils": "flake-utils",
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-swift": "nixpkgs-swift",
|
||||
"treefmt-nix": "treefmt-nix"
|
||||
}
|
||||
},
|
||||
"rust-analyzer-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1761849405,
|
||||
"narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
|
||||
"lastModified": 1768224240,
|
||||
"narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
|
||||
"owner": "rust-lang",
|
||||
"repo": "rust-analyzer",
|
||||
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
|
||||
"rev": "725349602e525df37f377701e001fe8aab807878",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -80,21 +99,6 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"treefmt-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
@@ -102,11 +106,11 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1762938485,
|
||||
"narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
|
||||
"lastModified": 1768158989,
|
||||
"narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
|
||||
"owner": "numtide",
|
||||
"repo": "treefmt-nix",
|
||||
"rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
|
||||
"rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
198
flake.nix
@@ -3,118 +3,136 @@
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
# Provides Rust dev-env integration:
|
||||
|
||||
flake-parts = {
|
||||
url = "github:hercules-ci/flake-parts";
|
||||
inputs.nixpkgs-lib.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
fenix = {
|
||||
url = "github:nix-community/fenix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
# Provides formatting infrastructure:
|
||||
|
||||
treefmt-nix = {
|
||||
url = "github:numtide/treefmt-nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
|
||||
};
|
||||
|
||||
# TODO: figure out caching story
|
||||
# nixConfig = {
|
||||
# # nix community cachix
|
||||
# extra-trusted-public-keys = "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=";
|
||||
# extra-substituters = "https://nix-community.cachix.org";
|
||||
# };
|
||||
nixConfig = {
|
||||
extra-trusted-public-keys = "exo.cachix.org-1:okq7hl624TBeAR3kV+g39dUFSiaZgLRkLsFBCuJ2NZI=";
|
||||
extra-substituters = "https://exo.cachix.org";
|
||||
};
|
||||
|
||||
outputs =
|
||||
inputs:
|
||||
let
|
||||
inputs.flake-parts.lib.mkFlake { inherit inputs; } {
|
||||
systems = [
|
||||
"x86_64-linux"
|
||||
"aarch64-darwin"
|
||||
"aarch64-linux"
|
||||
];
|
||||
fenixToolchain = system: inputs.fenix.packages.${system}.complete;
|
||||
in
|
||||
inputs.flake-utils.lib.eachSystem systems (
|
||||
system:
|
||||
let
|
||||
pkgs = import inputs.nixpkgs {
|
||||
inherit system;
|
||||
overlays = [ inputs.fenix.overlays.default ];
|
||||
};
|
||||
treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
|
||||
projectRootFile = "flake.nix";
|
||||
programs.ruff-format.enable = true;
|
||||
programs.ruff-format.excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
|
||||
programs.rustfmt.enable = true;
|
||||
programs.rustfmt.package = (fenixToolchain system).rustfmt;
|
||||
programs.nixpkgs-fmt.enable = true;
|
||||
};
|
||||
in
|
||||
{
|
||||
formatter = treefmtEval.config.build.wrapper;
|
||||
checks.formatting = treefmtEval.config.build.check inputs.self;
|
||||
checks.lint = pkgs.runCommand "lint-check" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
|
||||
devShells.default = pkgs.mkShell {
|
||||
packages =
|
||||
with pkgs;
|
||||
[
|
||||
# PYTHON
|
||||
python313
|
||||
uv
|
||||
ruff
|
||||
basedpyright
|
||||
imports = [
|
||||
inputs.treefmt-nix.flakeModule
|
||||
];
|
||||
|
||||
# RUST
|
||||
((fenixToolchain system).withComponents [
|
||||
"cargo"
|
||||
"rustc"
|
||||
"clippy"
|
||||
"rustfmt"
|
||||
"rust-src"
|
||||
])
|
||||
rustup # Just here to make RustRover happy
|
||||
perSystem =
|
||||
{ config, inputs', pkgs, lib, system, ... }:
|
||||
let
|
||||
fenixToolchain = inputs'.fenix.packages.complete;
|
||||
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
|
||||
in
|
||||
{
|
||||
treefmt = {
|
||||
projectRootFile = "flake.nix";
|
||||
programs = {
|
||||
nixpkgs-fmt.enable = true;
|
||||
ruff-format = {
|
||||
enable = true;
|
||||
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
|
||||
};
|
||||
rustfmt = {
|
||||
enable = true;
|
||||
package = fenixToolchain.rustfmt;
|
||||
};
|
||||
prettier = {
|
||||
enable = true;
|
||||
includes = [ "*.ts" ];
|
||||
};
|
||||
swift-format = {
|
||||
enable = true;
|
||||
package = pkgsSwift.swiftPackages.swift-format;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
# NIX
|
||||
nixpkgs-fmt
|
||||
|
||||
# SVELTE
|
||||
nodejs
|
||||
|
||||
# MISC
|
||||
just
|
||||
jq
|
||||
]
|
||||
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
|
||||
# IFCONFIG
|
||||
unixtools.ifconfig
|
||||
|
||||
# Build dependencies for Linux
|
||||
pkg-config
|
||||
openssl
|
||||
])
|
||||
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
|
||||
# MACMON
|
||||
macmon
|
||||
]);
|
||||
|
||||
shellHook = ''
|
||||
# PYTHON
|
||||
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
|
||||
${pkgs.lib.optionalString pkgs.stdenv.isLinux ''
|
||||
# Build environment for Linux
|
||||
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
|
||||
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
|
||||
''}
|
||||
echo
|
||||
echo "🍎🍎 Run 'just <recipe>' to get started"
|
||||
just --list
|
||||
checks.lint = pkgs.runCommand "lint-check" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
|
||||
devShells.default = with pkgs; pkgs.mkShell {
|
||||
packages =
|
||||
[
|
||||
# FORMATTING
|
||||
config.treefmt.build.wrapper
|
||||
|
||||
# PYTHON
|
||||
python313
|
||||
uv
|
||||
ruff
|
||||
basedpyright
|
||||
|
||||
# RUST
|
||||
(fenixToolchain.withComponents [
|
||||
"cargo"
|
||||
"rustc"
|
||||
"clippy"
|
||||
"rustfmt"
|
||||
"rust-src"
|
||||
])
|
||||
rustup # Just here to make RustRover happy
|
||||
|
||||
# NIX
|
||||
nixpkgs-fmt
|
||||
|
||||
# SVELTE
|
||||
nodejs
|
||||
|
||||
# MISC
|
||||
just
|
||||
jq
|
||||
]
|
||||
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
|
||||
# IFCONFIG
|
||||
unixtools.ifconfig
|
||||
|
||||
# Build dependencies for Linux
|
||||
pkg-config
|
||||
openssl
|
||||
])
|
||||
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
|
||||
# MACMON
|
||||
macmon
|
||||
]);
|
||||
|
||||
shellHook = ''
|
||||
# PYTHON
|
||||
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
|
||||
${lib.optionalString pkgs.stdenv.isLinux ''
|
||||
# Build environment for Linux
|
||||
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
|
||||
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
|
||||
''}
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -8,35 +8,21 @@ dependencies = [
|
||||
"aiofiles>=24.1.0",
|
||||
"aiohttp>=3.12.14",
|
||||
"types-aiofiles>=24.1.0.20250708",
|
||||
"typeguard>=4.4.4",
|
||||
"pydantic>=2.11.7",
|
||||
"base58>=2.1.1",
|
||||
"cryptography>=45.0.5",
|
||||
"fastapi>=0.116.1",
|
||||
"filelock>=3.18.0",
|
||||
"aiosqlite>=0.21.0",
|
||||
"networkx>=3.5",
|
||||
"protobuf>=6.32.0",
|
||||
"rich>=14.1.0",
|
||||
"rustworkx>=0.17.1",
|
||||
"sqlmodel>=0.0.24",
|
||||
"sqlalchemy[asyncio]>=2.0.43",
|
||||
"greenlet>=3.2.4",
|
||||
"huggingface-hub>=0.33.4",
|
||||
"psutil>=7.0.0",
|
||||
"loguru>=0.7.3",
|
||||
"textual>=5.3.0",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"bidict>=0.23.1",
|
||||
"mlx>=0.30.1; sys_platform == 'darwin'",
|
||||
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
|
||||
"mlx-lm>=0.28.3",
|
||||
"mlx==0.30.1; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.1; sys_platform == 'linux'",
|
||||
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
"pillow>=11.0,<12.0", # compatibility with mflux
|
||||
"mflux>=0.12.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -47,6 +33,7 @@ exo = "exo.main:main"
|
||||
# dependencies only required for development
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"basedpyright>=1.29.0",
|
||||
"pyinstaller>=6.17.0",
|
||||
"pytest>=8.4.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
@@ -84,7 +71,7 @@ build-backend = "uv_build"
|
||||
###
|
||||
|
||||
[tool.basedpyright]
|
||||
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src"]
|
||||
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src", "bench"]
|
||||
typeCheckingMode = "strict"
|
||||
failOnWarnings = true
|
||||
|
||||
@@ -112,6 +99,7 @@ root = "src"
|
||||
|
||||
# supported platforms for this project
|
||||
[tool.uv]
|
||||
prerelease = "allow"
|
||||
environments = [
|
||||
"sys_platform == 'darwin'",
|
||||
"sys_platform == 'linux'",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import resource
|
||||
import signal
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Self
|
||||
@@ -195,6 +196,8 @@ class Node:
|
||||
|
||||
def main():
|
||||
args = Args.parse()
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
# TODO: Refactor the current verbosity system
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Literal, cast
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
from anyio import create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@@ -24,12 +22,13 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.models.model_meta import get_model_meta
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionResponse,
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
@@ -37,10 +36,7 @@ from exo.shared.types.api import (
|
||||
CreateInstanceResponse,
|
||||
DeleteInstanceResponse,
|
||||
FinishReason,
|
||||
ImageData,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationResponse,
|
||||
ImageGenerationTaskParams,
|
||||
GenerationStats,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -48,17 +44,14 @@ from exo.shared.types.api import (
|
||||
PlacementPreviewResponse,
|
||||
StreamingChoiceResponse,
|
||||
)
|
||||
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
@@ -94,23 +87,12 @@ def chunk_to_response(
|
||||
)
|
||||
|
||||
|
||||
def get_model_card(model_id: str) -> ModelCard | None:
|
||||
async def resolve_model_meta(model_id: str) -> ModelMetadata:
|
||||
if model_id in MODEL_CARDS:
|
||||
model_card = MODEL_CARDS[model_id]
|
||||
return model_card
|
||||
|
||||
for _, model_card in MODEL_CARDS.items():
|
||||
if model_id == model_card.model_id:
|
||||
return model_card
|
||||
|
||||
|
||||
async def resolve_model_meta(model_id: str) -> ModelMetadata:
|
||||
model_card = get_model_card(model_id)
|
||||
|
||||
if model_card is not None:
|
||||
return model_card.metadata
|
||||
|
||||
return await get_model_meta(model_id)
|
||||
else:
|
||||
return await get_model_meta(model_id)
|
||||
|
||||
|
||||
class API:
|
||||
@@ -154,7 +136,6 @@ class API:
|
||||
)
|
||||
|
||||
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
|
||||
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
def reset(self, new_session_id: SessionId, result_clock: int):
|
||||
@@ -163,7 +144,6 @@ class API:
|
||||
self.session_id = new_session_id
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self._chat_completion_queues = {}
|
||||
self._image_generation_queues = {}
|
||||
self.unpause(result_clock)
|
||||
|
||||
def unpause(self, result_clock: int):
|
||||
@@ -195,10 +175,7 @@ class API:
|
||||
self.app.post("/v1/chat/completions", response_model=None)(
|
||||
self.chat_completions
|
||||
)
|
||||
self.app.post("/v1/images/generations", response_model=None)(
|
||||
self.image_generations
|
||||
)
|
||||
self.app.post("/v1/images/edits", response_model=None)(self.image_edits)
|
||||
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
|
||||
@@ -517,6 +494,45 @@ class API:
|
||||
],
|
||||
)
|
||||
|
||||
async def _collect_chat_completion_with_stats(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
stats = chunk.stats or stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
resp = BenchChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content=combined_text
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
generation_stats=stats,
|
||||
)
|
||||
return resp
|
||||
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
|
||||
logger.warning(
|
||||
"TODO: we should send a notification to the user to download the model"
|
||||
@@ -552,15 +568,11 @@ class API:
|
||||
|
||||
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
|
||||
|
||||
async def image_generations(
|
||||
self, payload: ImageGenerationTaskParams
|
||||
) -> ImageGenerationResponse | StreamingResponse:
|
||||
"""Handle image generation requests.
|
||||
|
||||
When stream=True and partial_images > 0, returns a StreamingResponse
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionTaskParams
|
||||
) -> BenchChatCompletionResponse:
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
payload.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
@@ -572,304 +584,16 @@ class API:
|
||||
status_code=404, detail=f"No instance found for model {payload.model}"
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
request_params=payload,
|
||||
)
|
||||
payload.stream = False
|
||||
|
||||
command = ChatCompletion(request_params=payload)
|
||||
await self._send(command)
|
||||
|
||||
# Check if streaming is requested
|
||||
if payload.stream and payload.partial_images and payload.partial_images > 0:
|
||||
return StreamingResponse(
|
||||
self._generate_image_stream(
|
||||
command_id=command.command_id,
|
||||
num_images=payload.n or 1,
|
||||
response_format=payload.response_format or "b64_json",
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
# Non-streaming: collect all image chunks
|
||||
return await self._collect_image_generation(
|
||||
command_id=command.command_id,
|
||||
num_images=payload.n or 1,
|
||||
response_format=payload.response_format or "b64_json",
|
||||
response = await self._collect_chat_completion_with_stats(
|
||||
command.command_id,
|
||||
parse_gpt_oss,
|
||||
)
|
||||
|
||||
async def _generate_image_stream(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
num_images: int,
|
||||
response_format: str,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate SSE stream of partial and final images."""
|
||||
# Track chunks: {(image_index, is_partial): {chunk_index: data}}
|
||||
image_chunks: dict[tuple[int, bool], dict[int, str]] = {}
|
||||
image_total_chunks: dict[tuple[int, bool], int] = {}
|
||||
image_metadata: dict[tuple[int, bool], tuple[int | None, int | None]] = {}
|
||||
images_complete = 0
|
||||
|
||||
try:
|
||||
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
|
||||
|
||||
with recv as chunks:
|
||||
async for chunk in chunks:
|
||||
key = (chunk.image_index, chunk.is_partial)
|
||||
|
||||
if key not in image_chunks:
|
||||
image_chunks[key] = {}
|
||||
image_total_chunks[key] = chunk.total_chunks
|
||||
image_metadata[key] = (
|
||||
chunk.partial_index,
|
||||
chunk.total_partials,
|
||||
)
|
||||
|
||||
image_chunks[key][chunk.chunk_index] = chunk.data
|
||||
|
||||
# Check if this image is complete
|
||||
if len(image_chunks[key]) == image_total_chunks[key]:
|
||||
full_data = "".join(
|
||||
image_chunks[key][i] for i in range(len(image_chunks[key]))
|
||||
)
|
||||
|
||||
partial_idx, total_partials = image_metadata[key]
|
||||
|
||||
if chunk.is_partial:
|
||||
# Yield partial image event
|
||||
event_data = {
|
||||
"type": "partial",
|
||||
"partial_index": partial_idx,
|
||||
"total_partials": total_partials,
|
||||
"data": {
|
||||
"b64_json": full_data
|
||||
if response_format == "b64_json"
|
||||
else None,
|
||||
},
|
||||
}
|
||||
yield f"data: {json.dumps(event_data)}\n\n"
|
||||
else:
|
||||
# Final image
|
||||
event_data = {
|
||||
"type": "final",
|
||||
"image_index": chunk.image_index,
|
||||
"data": {
|
||||
"b64_json": full_data
|
||||
if response_format == "b64_json"
|
||||
else None,
|
||||
},
|
||||
}
|
||||
yield f"data: {json.dumps(event_data)}\n\n"
|
||||
images_complete += 1
|
||||
|
||||
if images_complete >= num_images:
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
|
||||
# Clean up completed image chunks
|
||||
del image_chunks[key]
|
||||
del image_total_chunks[key]
|
||||
del image_metadata[key]
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
if command_id in self._image_generation_queues:
|
||||
del self._image_generation_queues[command_id]
|
||||
|
||||
async def _collect_image_generation(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
num_images: int,
|
||||
response_format: str,
|
||||
) -> ImageGenerationResponse:
|
||||
"""Collect all image chunks (non-streaming) and return a single response."""
|
||||
# Track chunks per image: {image_index: {chunk_index: data}}
|
||||
# Only track non-partial (final) images
|
||||
image_chunks: dict[int, dict[int, str]] = {}
|
||||
image_total_chunks: dict[int, int] = {}
|
||||
images_complete = 0
|
||||
|
||||
try:
|
||||
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
|
||||
|
||||
while images_complete < num_images:
|
||||
with recv as chunks:
|
||||
async for chunk in chunks:
|
||||
# Skip partial images in non-streaming mode
|
||||
if chunk.is_partial:
|
||||
continue
|
||||
|
||||
if chunk.image_index not in image_chunks:
|
||||
image_chunks[chunk.image_index] = {}
|
||||
image_total_chunks[chunk.image_index] = chunk.total_chunks
|
||||
|
||||
image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
|
||||
|
||||
# Check if this image is complete
|
||||
if (
|
||||
len(image_chunks[chunk.image_index])
|
||||
== image_total_chunks[chunk.image_index]
|
||||
):
|
||||
images_complete += 1
|
||||
|
||||
if images_complete >= num_images:
|
||||
break
|
||||
|
||||
# Reassemble images in order
|
||||
images: list[ImageData] = []
|
||||
            for image_idx in range(num_images):
                chunks_dict = image_chunks[image_idx]
                full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
                images.append(
                    ImageData(
                        b64_json=full_data if response_format == "b64_json" else None,
                        url=None,  # URL format not implemented yet
                    )
                )

            return ImageGenerationResponse(data=images)
        except anyio.get_cancelled_exc_class():
            raise
        finally:
            await self._send(TaskFinished(finished_command_id=command_id))
            if command_id in self._image_generation_queues:
                del self._image_generation_queues[command_id]

    async def image_edits(
        self,
        image: UploadFile = File(...),
        prompt: str = Form(...),
        model: str = Form(...),
        n: int = Form(1),
        size: str = Form("1024x1024"),
        response_format: Literal["url", "b64_json"] = Form("b64_json"),
        input_fidelity: Literal["low", "high"] = Form("low"),
        stream: bool = Form(False),
        partial_images: int = Form(0),
    ) -> ImageGenerationResponse | StreamingResponse:
        """Handle image editing requests (img2img)."""
        model_meta = await resolve_model_meta(model)
        resolved_model = model_meta.model_id

        if not any(
            instance.shard_assignments.model_id == resolved_model
            for instance in self.state.instances.values()
        ):
            await self._trigger_notify_user_to_download_model(resolved_model)
            raise HTTPException(
                status_code=404, detail=f"No instance found for model {resolved_model}"
            )

        # Read and base64 encode the uploaded image
        image_content = await image.read()
        image_data = base64.b64encode(image_content).decode("utf-8")

        # Map input_fidelity to image_strength
        image_strength = 0.7 if input_fidelity == "high" else 0.3

        # Split image into chunks to stay under gossipsub message size limit
        data_chunks = [
            image_data[i : i + EXO_MAX_CHUNK_SIZE]
            for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
        ]
        total_chunks = len(data_chunks)

        # Create command first to get command_id
        command = ImageEdits(
            request_params=ImageEditsInternalParams(
                image_data="",  # Empty - will be assembled at worker from chunks
                total_input_chunks=total_chunks,
                prompt=prompt,
                model=resolved_model,
                n=n,
                size=size,
                response_format=response_format,
                image_strength=image_strength,
                stream=stream,
                partial_images=partial_images,
            ),
        )

        # Send input chunks BEFORE the command
        logger.info(
            f"Sending input image: {len(image_data)} bytes in {total_chunks} chunks"
        )
        for chunk_index, chunk_data in enumerate(data_chunks):
            await self._send(
                SendInputChunk(
                    chunk=InputImageChunk(
                        idx=chunk_index,
                        model=resolved_model,
                        command_id=command.command_id,
                        data=chunk_data,
                        chunk_index=chunk_index,
                        total_chunks=total_chunks,
                    )
                )
            )

        # Now send the main command
        await self._send(command)

        num_images = n

        # Check if streaming is requested
        if stream and partial_images and partial_images > 0:
            return StreamingResponse(
                self._generate_image_stream(
                    command_id=command.command_id,
                    num_images=num_images,
                    response_format=response_format,
                ),
                media_type="text/event-stream",
            )

        # Track chunks per image: {image_index: {chunk_index: data}}
        image_chunks: dict[int, dict[int, str]] = {}
        image_total_chunks: dict[int, int] = {}
        images_complete = 0

        try:
            self._image_generation_queues[command.command_id], recv = channel[
                ImageChunk
            ]()

            while images_complete < num_images:
                with recv as chunks:
                    async for chunk in chunks:
                        if chunk.image_index not in image_chunks:
                            image_chunks[chunk.image_index] = {}
                            image_total_chunks[chunk.image_index] = chunk.total_chunks

                        image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data

                        if (
                            len(image_chunks[chunk.image_index])
                            == image_total_chunks[chunk.image_index]
                        ):
                            images_complete += 1

                        if images_complete >= num_images:
                            break

            images: list[ImageData] = []
            for image_idx in range(num_images):
                chunks_dict = image_chunks[image_idx]
                full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
                images.append(
                    ImageData(
                        b64_json=full_data if response_format == "b64_json" else None,
                        url=None,  # URL format not implemented yet
                    )
                )

            return ImageGenerationResponse(data=images)
        except anyio.get_cancelled_exc_class():
            raise
        finally:
            # Send TaskFinished command
            await self._send(TaskFinished(finished_command_id=command.command_id))
            del self._image_generation_queues[command.command_id]

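For reference, a minimal client-side sketch of how the handler above could be exercised. It assumes the method is mounted at the OpenAI-compatible `/v1/images/edits` route and that the API listens on `localhost:52415`; both the route and the port are assumptions, not something this diff states.

```python
# Hypothetical client for the image_edits handler above (route/port assumed).
import base64
import requests


def edit_image(path: str, prompt: str, model: str) -> bytes:
    with open(path, "rb") as f:
        resp = requests.post(
            "http://localhost:52415/v1/images/edits",
            files={"image": f},
            data={
                "prompt": prompt,
                "model": model,
                "n": 1,
                "size": "1024x1024",
                "response_format": "b64_json",
                "input_fidelity": "high",
            },
            timeout=600,
        )
    resp.raise_for_status()
    # The response body mirrors ImageGenerationResponse: a list of ImageData entries.
    return base64.b64decode(resp.json()["data"][0]["b64_json"])
```

The multipart form fields mirror the `Form(...)` parameters of `image_edits`, so any OpenAI-style images client that can post multipart data should work the same way.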
    def _calculate_total_available_memory(self) -> Memory:
        """Calculate total available memory across all nodes in bytes."""
@@ -893,7 +617,6 @@ class API:
                tags=card.tags,
                storage_size_megabytes=int(card.metadata.storage_size.in_mb),
                supports_tensor=card.metadata.supports_tensor,
                tasks=[task.value for task in card.tasks],
            )
            for card in MODEL_CARDS.values()
        ]
@@ -931,17 +654,14 @@ class API:
        for idx, event in self.event_buffer.drain_indexed():
            self._event_log.append(event)
            self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
            if isinstance(event, ChunkGenerated):
                if event.command_id in self._chat_completion_queues:
                    assert isinstance(event.chunk, TokenChunk)
                    await self._chat_completion_queues[event.command_id].send(
                        event.chunk
                    )
                elif event.command_id in self._image_generation_queues:
                    assert isinstance(event.chunk, ImageChunk)
                    await self._image_generation_queues[event.command_id].send(
                        event.chunk
                    )
            if (
                isinstance(event, ChunkGenerated)
                and event.command_id in self._chat_completion_queues
            ):
                assert isinstance(event.chunk, TokenChunk)
                await self._chat_completion_queues[event.command_id].send(
                    event.chunk
                )

    async def _pause_on_new_election(self):
        with self.election_receiver as ems:
@@ -2,7 +2,6 @@ from datetime import datetime, timedelta, timezone
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi.routing import request_response
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.placement import (
|
||||
@@ -12,17 +11,13 @@ from exo.master.placement import (
|
||||
place_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
SendInputChunk,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
)
|
||||
@@ -31,7 +26,6 @@ from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
InstanceDeleted,
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
@@ -41,12 +35,6 @@ from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion as ChatCompletionTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ImageEdits as ImageEditsTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ImageGeneration as ImageGenerationTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
@@ -111,7 +99,6 @@ class Master:
|
||||
async for forwarder_command in commands:
|
||||
try:
|
||||
logger.info(f"Executing command: {forwarder_command.command}")
|
||||
|
||||
generated_events: list[Event] = []
|
||||
command = forwarder_command.command
|
||||
match command:
|
||||
@@ -159,92 +146,6 @@ class Master:
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case ImageGeneration():
|
||||
instance_task_counts: dict[InstanceId, int] = {}
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
for task in self.state.tasks.values()
|
||||
if task.instance_id == instance.instance_id
|
||||
)
|
||||
instance_task_counts[instance.instance_id] = (
|
||||
task_count
|
||||
)
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
instance_task_counts.keys(),
|
||||
key=lambda instance_id: instance_task_counts[
|
||||
instance_id
|
||||
],
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageGenerationTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case ImageEdits():
|
||||
instance_task_counts: dict[InstanceId, int] = {}
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
for task in self.state.tasks.values()
|
||||
if task.instance_id == instance.instance_id
|
||||
)
|
||||
instance_task_counts[instance.instance_id] = (
|
||||
task_count
|
||||
)
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
instance_task_counts.keys(),
|
||||
key=lambda instance_id: instance_task_counts[
|
||||
instance_id
|
||||
],
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageEditsTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
@@ -272,13 +173,6 @@ class Master:
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
generated_events.append(
|
||||
InputChunkReceived(
|
||||
command_id=chunk.command_id,
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
|
||||
@@ -9,7 +9,6 @@ from exo.shared.types.events import (
    ChunkGenerated,
    Event,
    IndexedEvent,
    InputChunkReceived,
    InstanceCreated,
    InstanceDeleted,
    NodeCreated,
@@ -41,8 +40,8 @@ def event_apply(event: Event, state: State) -> State:
    """Apply an event to state."""
    match event:
        case (
            TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
        ):  # Pass-through events that don't modify state
            TestEvent() | ChunkGenerated() | TaskAcknowledged()
        ):  # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored
            return state
        case InstanceCreated():
            return apply_instance_created(event, state)

@@ -44,5 +44,3 @@ LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"

EXO_MAX_CHUNK_SIZE = 512 * 1024
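`EXO_MAX_CHUNK_SIZE` caps each gossipsub message at 512 KiB of base64 text, which is why `image_edits` splits the upload and the worker re-joins chunks by index. A minimal sketch of that round trip; `split_b64` and `join_chunks` are illustrative helpers, not names from the codebase:

```python
# Sketch of the chunking round trip implied by EXO_MAX_CHUNK_SIZE (helpers are hypothetical).
EXO_MAX_CHUNK_SIZE = 512 * 1024  # characters of base64 text per gossipsub message


def split_b64(image_b64: str) -> list[str]:
    # Sender side: fixed-size slices, indexed by position in the list.
    return [
        image_b64[i : i + EXO_MAX_CHUNK_SIZE]
        for i in range(0, len(image_b64), EXO_MAX_CHUNK_SIZE)
    ]


def join_chunks(chunks: dict[int, str]) -> str:
    # Receiver side: chunks may arrive out of order, so reassemble by chunk_index.
    return "".join(chunks[i] for i in range(len(chunks)))


payload = "A" * (EXO_MAX_CHUNK_SIZE * 2 + 123)
parts = split_b64(payload)
assert join_chunks({i: p for i, p in enumerate(parts)}) == payload
```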
@@ -1,5 +1,5 @@
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ComponentInfo, ModelId, ModelMetadata, ModelTask
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ class ModelCard(CamelCaseModel):
|
||||
model_id: ModelId
|
||||
name: str
|
||||
description: str
|
||||
tasks: list[ModelTask]
|
||||
tags: list[str]
|
||||
metadata: ModelMetadata
|
||||
|
||||
@@ -46,7 +45,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
name="DeepSeek V3.1 (4-bit)",
|
||||
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
@@ -62,7 +60,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
|
||||
name="DeepSeek V3.1 (8-bit)",
|
||||
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
|
||||
@@ -85,6 +82,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# storage_size=Memory.from_kb(754706307),
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-v3.2-4bit": ModelCard(
|
||||
@@ -99,6 +97,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!!
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
# deepseek r1
|
||||
@@ -136,7 +135,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
|
||||
name="Kimi K2 Instruct (4-bit)",
|
||||
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
|
||||
@@ -152,7 +150,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
|
||||
name="Kimi K2 Thinking (4-bit)",
|
||||
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
|
||||
@@ -169,7 +166,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
name="Llama 3.1 8B (4-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
@@ -185,7 +181,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
name="Llama 3.1 8B (8-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
@@ -201,7 +196,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
name="Llama 3.1 8B (BF16)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
@@ -217,7 +211,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
|
||||
name="Llama 3.1 70B (4-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
|
||||
@@ -234,7 +227,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
|
||||
name="Llama 3.2 1B (4-bit)",
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
|
||||
@@ -250,7 +242,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
|
||||
name="Llama 3.2 3B (4-bit)",
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
|
||||
@@ -266,7 +257,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
|
||||
name="Llama 3.2 3B (8-bit)",
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
|
||||
@@ -283,7 +273,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
|
||||
name="Llama 3.3 70B (4-bit)",
|
||||
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
|
||||
@@ -299,7 +288,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
|
||||
name="Llama 3.3 70B (8-bit)",
|
||||
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
|
||||
@@ -315,7 +303,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
|
||||
name="Llama 3.3 70B (FP16)",
|
||||
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
|
||||
@@ -332,7 +319,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
|
||||
name="Qwen3 0.6B (4-bit)",
|
||||
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
|
||||
@@ -348,7 +334,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
|
||||
name="Qwen3 0.6B (8-bit)",
|
||||
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
|
||||
@@ -364,7 +349,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
|
||||
name="Qwen3 30B A3B (4-bit)",
|
||||
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
|
||||
@@ -380,7 +364,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
|
||||
name="Qwen3 30B A3B (8-bit)",
|
||||
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
|
||||
@@ -396,7 +379,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
name="Qwen3 80B A3B (4-bit)",
|
||||
description="""Qwen3 80B""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
@@ -412,7 +394,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
name="Qwen3 80B A3B (8-bit)",
|
||||
description="""Qwen3 80B""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
@@ -428,7 +409,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
name="Qwen3 80B A3B Thinking (4-bit)",
|
||||
description="""Qwen3 80B Reasoning model""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
@@ -444,7 +424,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
name="Qwen3 80B A3B Thinking (8-bit)",
|
||||
description="""Qwen3 80B Reasoning model""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
@@ -460,7 +439,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
|
||||
name="Qwen3 235B A22B (4-bit)",
|
||||
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
|
||||
@@ -476,7 +454,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
|
||||
name="Qwen3 235B A22B (8-bit)",
|
||||
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
|
||||
@@ -492,7 +469,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
|
||||
name="Qwen3 Coder 480B A35B (4-bit)",
|
||||
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
|
||||
@@ -508,7 +484,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
|
||||
name="Qwen3 Coder 480B A35B (8-bit)",
|
||||
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
|
||||
@@ -525,7 +500,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
@@ -541,7 +515,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
@@ -558,7 +531,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
name="GLM 4.5 Air 8bit",
|
||||
description="""GLM 4.5 Air 8bit""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
@@ -574,7 +546,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
name="GLM 4.5 Air bf16",
|
||||
description="""GLM 4.5 Air bf16""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
@@ -585,6 +556,81 @@ MODEL_CARDS: dict[str, ModelCard] = {
            supports_tensor=True,
        ),
    ),
    "glm-4.7-4bit": ModelCard(
        short_id="glm-4.7-4bit",
        model_id=ModelId("mlx-community/GLM-4.7-4bit"),
        name="GLM 4.7 4bit",
        description="GLM 4.7 4bit",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.7-4bit"),
            pretty_name="GLM 4.7 4bit",
            storage_size=Memory.from_bytes(198556925568),
            n_layers=91,
            hidden_size=5120,
            supports_tensor=True,
        ),
    ),
    "glm-4.7-6bit": ModelCard(
        short_id="glm-4.7-6bit",
        model_id=ModelId("mlx-community/GLM-4.7-6bit"),
        name="GLM 4.7 6bit",
        description="GLM 4.7 6bit",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.7-6bit"),
            pretty_name="GLM 4.7 6bit",
            storage_size=Memory.from_bytes(286737579648),
            n_layers=91,
            hidden_size=5120,
            supports_tensor=True,
        ),
    ),
    "glm-4.7-8bit-gs32": ModelCard(
        short_id="glm-4.7-8bit-gs32",
        model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
        name="GLM 4.7 8bit (gs32)",
        description="GLM 4.7 8bit (gs32)",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
            pretty_name="GLM 4.7 8bit (gs32)",
            storage_size=Memory.from_bytes(396963397248),
            n_layers=91,
            hidden_size=5120,
            supports_tensor=True,
        ),
    ),
    "minimax-m2.1-8bit": ModelCard(
        short_id="minimax-m2.1-8bit",
        model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
        name="MiniMax M2.1 8bit",
        description="MiniMax M2.1 8bit",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
            pretty_name="MiniMax M2.1 8bit",
            storage_size=Memory.from_bytes(242986745856),
            n_layers=61,
            hidden_size=3072,
            supports_tensor=True,
        ),
    ),
    "minimax-m2.1-3bit": ModelCard(
        short_id="minimax-m2.1-3bit",
        model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
        name="MiniMax M2.1 3bit",
        description="MiniMax M2.1 3bit",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
            pretty_name="MiniMax M2.1 3bit",
            storage_size=Memory.from_bytes(100086644736),
            n_layers=61,
            hidden_size=3072,
            supports_tensor=True,
        ),
    ),
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
|
||||
# short_id="devstral-2-123b-instruct-2512-8bit",
|
||||
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
|
||||
@@ -600,188 +646,4 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
"flux1-schnell": ModelCard(
|
||||
short_id="flux1-schnell",
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
name="FLUX.1 [schnell]",
|
||||
description="""FLUX.1 [schnell] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
|
||||
tasks=[ModelTask.TextToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
pretty_name="FLUX.1 [schnell]",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(23782357120), # + 9524621312),
|
||||
n_layers=57, # sharded layers
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
short_id="flux1-dev",
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
name="FLUX.1 [dev]",
|
||||
description="""FLUX.1 [dev] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
|
||||
tasks=[ModelTask.TextToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
pretty_name="FLUX.1 [dev]",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57, # sharded layers
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
short_id="qwen-image",
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
name="Qwen Image",
|
||||
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
|
||||
tasks=[ModelTask.TextToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
pretty_name="Qwen Image",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
short_id="qwen-image-edit-2509",
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
name="Qwen Image Edit 2509",
|
||||
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
pretty_name="Qwen Image Edit 2509",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ from exo.shared.apply import apply_node_download_progress
|
||||
from exo.shared.tests.conftest import get_pipeline_shard_metadata
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import NodeDownloadProgress
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
|
||||
@@ -13,6 +14,7 @@ def test_apply_node_download_progress():
|
||||
event = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard1,
|
||||
total_bytes=Memory(),
|
||||
)
|
||||
|
||||
new_state = apply_node_download_progress(
|
||||
@@ -28,10 +30,12 @@ def test_apply_two_node_download_progress():
|
||||
event1 = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard1,
|
||||
total_bytes=Memory(),
|
||||
)
|
||||
event2 = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard2,
|
||||
total_bytes=Memory(),
|
||||
)
|
||||
state = State(downloads={NodeId("node-1"): [event1]})
|
||||
|
||||
|
||||
@@ -1,12 +1,11 @@
import time
from collections.abc import Generator
from typing import Any, Literal

from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault

from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
@@ -29,7 +28,6 @@ class ModelListModel(BaseModel):
    tags: list[str] = Field(default=[])
    storage_size_megabytes: int = Field(default=0)
    supports_tensor: bool = Field(default=False)
    tasks: list[str] = Field(default=[])


class ModelList(BaseModel):
@@ -54,6 +52,10 @@ class ChatCompletionMessage(BaseModel):
    function_call: dict[str, Any] | None = None


class BenchChatCompletionMessage(ChatCompletionMessage):
    pass


class TopLogprobItem(BaseModel):
    token: str
    logprob: float
@@ -116,6 +118,18 @@ class ChatCompletionResponse(BaseModel):
    service_tier: str | None = None


class GenerationStats(BaseModel):
    prompt_tps: float
    generation_tps: float
    prompt_tokens: int
    generation_tokens: int
    peak_memory_usage: Memory


class BenchChatCompletionResponse(ChatCompletionResponse):
    generation_stats: GenerationStats | None = None


class ChatCompletionTaskParams(BaseModel):
    model: str
    frequency_penalty: float | None = None
@@ -138,6 +152,10 @@ class ChatCompletionTaskParams(BaseModel):
    user: str | None = None


class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
    pass


class PlaceInstanceParams(BaseModel):
    model_id: str
    sharding: Sharding = Sharding.Pipeline
@@ -184,75 +202,3 @@ class DeleteInstanceResponse(BaseModel):
    message: str
    command_id: CommandId
    instance_id: InstanceId


class ImageGenerationTaskParams(BaseModel):
    prompt: str
    # background: str | None = None
    model: str
    # moderation: str | None = None
    n: int | None = 1
    # output_compression: int | None = None
    output_format: Literal["png", "jpeg", "webp"] = "png"
    partial_images: int | None = 0
    quality: Literal["high", "medium", "low"] | None = "medium"
    response_format: Literal["url", "b64_json"] | None = "b64_json"
    size: str | None = "1024x1024"
    stream: bool | None = False
    # style: str | None = "vivid"
    # user: str | None = None


class ImageEditsTaskParams(BaseModel):
    image: UploadFile
    prompt: str
    input_fidelity: float = 0.7
    model: str
    n: int | None = 1
    quality: Literal["high", "medium", "low"] | None = "medium"
    output_format: Literal["png", "jpeg", "webp"] = "png"
    response_format: Literal["url", "b64_json"] | None = "b64_json"
    size: str | None = "1024x1024"
    # user: str | None = None


class ImageEditsInternalParams(BaseModel):
    """Serializable version of ImageEditsTaskParams for distributed task execution."""

    image_data: str = ""  # Base64-encoded image (empty when using chunked transfer)
    total_input_chunks: int = 0
    prompt: str
    model: str
    n: int | None = 1
    quality: Literal["high", "medium", "low"] | None = "medium"
    output_format: Literal["png", "jpeg", "webp"] = "png"
    response_format: Literal["url", "b64_json"] | None = "b64_json"
    size: str | None = "1024x1024"
    image_strength: float = 0.7
    stream: bool = False
    partial_images: int | None = 0

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        for name, value in super().__repr_args__():
            if name == "image_data":
                yield name, f"<{len(self.image_data)} chars>"
            elif name is not None:
                yield name, value


class ImageData(BaseModel):
    b64_json: str | None = None
    url: str | None = None
    revised_prompt: str | None = None

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        for name, value in super().__repr_args__():
            if name == "b64_json" and value is not None:
                yield name, f"<{len(value)} chars>"
            elif name is not None:
                yield name, value


class ImageGenerationResponse(BaseModel):
    created: int = Field(default_factory=lambda: int(time.time()))
    data: list[ImageData]
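The `__repr_args__` overrides above keep multi-megabyte base64 payloads out of logs by substituting a length placeholder for the raw data. A quick illustration with a made-up payload (the exact pydantic repr formatting may differ slightly):

```python
# Illustration only: the payload is synthetic and the printed form is approximate.
data = ImageData(b64_json="A" * 1_000_000, url=None)
print(repr(data))
# Expected shape: ImageData(b64_json='<1000000 chars>', url=None, revised_prompt=None)
```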
@@ -1,11 +1,9 @@
from collections.abc import Generator
from enum import Enum
from typing import Any

from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel

from .api import FinishReason
from .common import CommandId
from .models import ModelId


@@ -23,37 +21,11 @@ class TokenChunk(BaseChunk):
    text: str
    token_id: int
    finish_reason: FinishReason | None = None
    stats: GenerationStats | None = None


class ImageChunk(BaseChunk):
    data: str
    chunk_index: int
    total_chunks: int
    image_index: int
    is_partial: bool = False
    partial_index: int | None = None
    total_partials: int | None = None

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        for name, value in super().__repr_args__():
            if name == "data":
                yield name, f"<{len(self.data)} chars>"
            elif name is not None:
                yield name, value


class InputImageChunk(BaseChunk):
    command_id: CommandId
    data: str
    chunk_index: int
    total_chunks: int

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        for name, value in super().__repr_args__():
            if name == "data":
                yield name, f"<{len(self.data)} chars>"
            elif name is not None:
                yield name, value
    data: bytes


GenerationChunk = TokenChunk | ImageChunk
@@ -1,11 +1,6 @@
from pydantic import Field

from exo.shared.types.api import (
    ChatCompletionTaskParams,
    ImageEditsInternalParams,
    ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -25,14 +20,6 @@ class ChatCompletion(BaseCommand):
    request_params: ChatCompletionTaskParams


class ImageGeneration(BaseCommand):
    request_params: ImageGenerationTaskParams


class ImageEdits(BaseCommand):
    request_params: ImageEditsInternalParams


class PlaceInstance(BaseCommand):
    model_meta: ModelMetadata
    sharding: Sharding
@@ -52,12 +39,6 @@ class TaskFinished(BaseCommand):
    finished_command_id: CommandId


class SendInputChunk(BaseCommand):
    """Command to send an input image chunk (converted to event by master)."""

    chunk: InputImageChunk


class RequestEventLog(BaseCommand):
    since_idx: int

@@ -66,13 +47,10 @@ Command = (
    TestCommand
    | RequestEventLog
    | ChatCompletion
    | ImageGeneration
    | ImageEdits
    | PlaceInstance
    | CreateInstance
    | DeleteInstance
    | TaskFinished
    | SendInputChunk
)

@@ -3,7 +3,7 @@ from datetime import datetime
from pydantic import Field

from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
@@ -106,11 +106,6 @@ class ChunkGenerated(BaseEvent):
    chunk: GenerationChunk


class InputChunkReceived(BaseEvent):
    command_id: CommandId
    chunk: InputImageChunk


class TopologyEdgeCreated(BaseEvent):
    edge: Connection

@@ -136,7 +131,6 @@ Event = (
    | NodeMemoryMeasured
    | NodeDownloadProgress
    | ChunkGenerated
    | InputChunkReceived
    | TopologyEdgeCreated
    | TopologyEdgeDeleted
)
@@ -1,5 +1,3 @@
from enum import Enum

from pydantic import PositiveInt

from exo.shared.types.common import Id
@@ -11,21 +9,6 @@ class ModelId(Id):
    pass


class ModelTask(str, Enum):
    TextGeneration = "TextGeneration"
    TextToImage = "TextToImage"
    ImageToImage = "ImageToImage"


class ComponentInfo(CamelCaseModel):
    component_name: str
    component_path: str
    storage_size: Memory
    n_layers: PositiveInt | None
    can_shard: bool
    safetensors_index_filename: str | None


class ModelMetadata(CamelCaseModel):
    model_id: ModelId
    pretty_name: str
@@ -33,4 +16,3 @@ class ModelMetadata(CamelCaseModel):
    n_layers: PositiveInt
    hidden_size: PositiveInt
    supports_tensor: bool
    components: list[ComponentInfo] | None = None
@@ -2,11 +2,7 @@ from enum import Enum
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
from exo.shared.types.api import ChatCompletionTaskParams
|
||||
from exo.shared.types.common import CommandId, Id
|
||||
from exo.shared.types.worker.instances import BoundInstance, InstanceId
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
@@ -60,22 +56,6 @@ class ChatCompletion(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ImageGeneration(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageGenerationTaskParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ImageEdits(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageEditsInternalParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class Shutdown(BaseTask): # emitted by Worker
|
||||
runner_id: RunnerId
|
||||
|
||||
@@ -87,7 +67,5 @@ Task = (
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| ChatCompletion
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| Shutdown
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ class DownloadPending(BaseDownloadProgress):


class DownloadCompleted(BaseDownloadProgress):
    pass
    total_bytes: Memory


class DownloadFailed(BaseDownloadProgress):
@@ -1,7 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.shared.types.api import FinishReason, GenerationStats
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
@@ -18,32 +15,7 @@ class GenerationResponse(BaseRunnerResponse):
|
||||
token: int
|
||||
# logprobs: list[float] | None = None # too big. we can change to be top-k
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseRunnerResponse):
|
||||
image_data: bytes
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "image_data":
|
||||
yield name, f"<{len(self.image_data)} bytes>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
|
||||
|
||||
class PartialImageResponse(BaseRunnerResponse):
|
||||
image_data: bytes
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
partial_index: int
|
||||
total_partials: int
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "image_data":
|
||||
yield name, f"<{len(self.image_data)} bytes>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
|
||||
@@ -9,7 +9,6 @@ from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal
|
||||
from urllib.parse import urljoin
|
||||
from huggingface_hub._snapshot_download import snapshot_download
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
@@ -442,31 +441,12 @@ def calculate_repo_progress(
|
||||
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
|
||||
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
index_files_dir = snapshot_download(
|
||||
repo_id=repo_id, local_dir=target_dir, allow_patterns="*.safetensors.index.json"
|
||||
index_file = await download_file_with_retry(
|
||||
repo_id, revision, "model.safetensors.index.json", target_dir
|
||||
)
|
||||
|
||||
index_files = list(Path(index_files_dir).glob("**/*.safetensors.index.json"))
|
||||
|
||||
weight_map: dict[str, str] = {}
|
||||
|
||||
for index_file in index_files:
|
||||
relative_dir = index_file.parent.relative_to(index_files_dir)
|
||||
|
||||
async with aiofiles.open(index_file, "r") as f:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
|
||||
|
||||
if relative_dir != Path("."):
|
||||
prefixed_weight_map = {
|
||||
f"{relative_dir}/{key}": str(relative_dir / value)
|
||||
for key, value in index_data.weight_map.items()
|
||||
}
|
||||
weight_map = weight_map | prefixed_weight_map
|
||||
else:
|
||||
weight_map = weight_map | index_data.weight_map
|
||||
|
||||
return weight_map
|
||||
async with aiofiles.open(index_file, "r") as f:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
|
||||
return index_data.weight_map
|
||||
|
||||
|
||||
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
|
||||
@@ -571,6 +551,8 @@ async def download_shard(
|
||||
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
|
||||
|
||||
all_start_time = time.time()
|
||||
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
|
||||
# Update: <- This does not seem to be the case. Yay?
|
||||
file_list = await fetch_file_list_with_cache(
|
||||
str(shard.model_meta.model_id), revision, recursive=True
|
||||
)
|
||||
|
||||
@@ -100,68 +100,26 @@ def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"tiktoken.model",
|
||||
"*/spiece.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
"*.jinja",
|
||||
]
|
||||
)
|
||||
shard_specific_patterns: set[str] = set()
|
||||
|
||||
if shard.model_meta.components is not None:
|
||||
shardable_component = next(
|
||||
(c for c in shard.model_meta.components if c.can_shard), None
|
||||
if weight_map:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
layer_num = extract_layer_num(tensor_name)
|
||||
if (
|
||||
layer_num is not None
|
||||
and shard.start_layer <= layer_num <= shard.end_layer
|
||||
):
|
||||
shard_specific_patterns.add(filename)
|
||||
layer_independent_files = set(
|
||||
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
|
||||
)
|
||||
|
||||
if weight_map and shardable_component:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
# Strip component prefix from tensor name (added by weight map namespacing)
|
||||
# E.g., "transformer/blocks.0.weight" -> "blocks.0.weight"
|
||||
if "/" in tensor_name:
|
||||
_, tensor_name_no_prefix = tensor_name.split("/", 1)
|
||||
else:
|
||||
tensor_name_no_prefix = tensor_name
|
||||
|
||||
# Determine which component this file belongs to from filename
|
||||
component_path = Path(filename).parts[0] if "/" in filename else None
|
||||
|
||||
if component_path == shardable_component.component_path.rstrip("/"):
|
||||
layer_num = extract_layer_num(tensor_name_no_prefix)
|
||||
if (
|
||||
layer_num is not None
|
||||
and shard.start_layer <= layer_num < shard.end_layer
|
||||
):
|
||||
shard_specific_patterns.add(filename)
|
||||
|
||||
if shard.is_first_layer or shard.is_last_layer:
|
||||
shard_specific_patterns.add(filename)
|
||||
else:
|
||||
shard_specific_patterns.add(filename)
|
||||
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
|
||||
# TODO(ciaran): temporary - Include all files from non-shardable components that have no index file
|
||||
for component in shard.model_meta.components:
|
||||
if not component.can_shard and component.safetensors_index_filename is None:
|
||||
component_pattern = f"{component.component_path.rstrip('/')}/*"
|
||||
shard_specific_patterns.add(component_pattern)
|
||||
shard_specific_patterns.update(layer_independent_files)
|
||||
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
|
||||
else:
|
||||
if weight_map:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
layer_num = extract_layer_num(tensor_name)
|
||||
if (
|
||||
layer_num is not None
|
||||
and shard.start_layer <= layer_num < shard.end_layer
|
||||
):
|
||||
shard_specific_patterns.add(filename)
|
||||
layer_independent_files = set(
|
||||
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
|
||||
)
|
||||
shard_specific_patterns.update(layer_independent_files)
|
||||
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
logger.info(f"get_allow_patterns {shard=} {shard_specific_patterns=}")
|
||||
return list(default_patterns | shard_specific_patterns)
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
from exo.worker.engines.image.base import ImageGenerator
|
||||
from exo.worker.engines.image.distributed_model import initialize_image_model
|
||||
from exo.worker.engines.image.generate import generate_image, warmup_image_generator
|
||||
|
||||
__all__ = [
|
||||
"ImageGenerator",
|
||||
"generate_image",
|
||||
"initialize_image_model",
|
||||
"warmup_image_generator",
|
||||
]
|
||||
@@ -1,50 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ImageGenerator(Protocol):
|
||||
@property
|
||||
def rank(self) -> int: ...
|
||||
|
||||
@property
|
||||
def is_first_stage(self) -> bool: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
quality: Literal["low", "medium", "high"],
|
||||
seed: int,
|
||||
image_path: Path | None = None,
|
||||
partial_images: int = 0,
|
||||
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
|
||||
"""Generate an image from a text prompt, or edit an existing image.
|
||||
|
||||
For distributed inference, only the last stage returns images.
|
||||
Other stages yield nothing after participating in the pipeline.
|
||||
|
||||
When partial_images > 0, yields intermediate images during diffusion
|
||||
as tuples of (image, partial_index, total_partials), then yields
|
||||
the final image.
|
||||
|
||||
When partial_images = 0 (default), only yields the final image.
|
||||
|
||||
Args:
|
||||
prompt: Text description of the image to generate
|
||||
height: Image height in pixels
|
||||
width: Image width in pixels
|
||||
quality: Generation quality level
|
||||
seed: Random seed for reproducibility
|
||||
image_path: Optional path to input image for image editing
|
||||
partial_images: Number of intermediate images to yield (0 for none)
|
||||
|
||||
Yields:
|
||||
Intermediate images as (Image, partial_index, total_partials) tuples
|
||||
Final PIL Image (last stage) or nothing (other stages)
|
||||
"""
|
||||
...
|
||||
@@ -1,74 +0,0 @@
from enum import Enum
from math import ceil

from pydantic import BaseModel


class BlockType(Enum):
    JOINT = "joint" # Separate image/text streams
    SINGLE = "single" # Concatenated streams


class TransformerBlockConfig(BaseModel):
    model_config = {"frozen": True}

    block_type: BlockType
    count: int
    has_separate_text_output: bool # True for joint blocks that output text separately


class ImageModelConfig(BaseModel):
    model_config = {"frozen": True}

    # Model identification
    model_family: str # "flux", "fibo", "qwen"
    model_variant: str # "schnell", "dev", etc.

    # Architecture parameters
    hidden_dim: int
    num_heads: int
    head_dim: int

    # Block configuration - ordered sequence of block types
    block_configs: tuple[TransformerBlockConfig, ...]

    # Tokenization parameters
    patch_size: int # 2 for Flux/Qwen
    vae_scale_factor: int # 8 for Flux, 16 for others

    # Inference parameters
    default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
    num_sync_steps_factor: float # Fraction of steps for sync phase

    # Feature flags
    uses_attention_mask: bool # True for Fibo

    # CFG (Classifier-Free Guidance) parameters
    guidance_scale: float | None = None # None or <= 1.0 disables CFG

    @property
    def total_blocks(self) -> int:
        """Total number of transformer blocks."""
        return sum(bc.count for bc in self.block_configs)

    @property
    def joint_block_count(self) -> int:
        """Number of joint transformer blocks."""
        return sum(
            bc.count for bc in self.block_configs if bc.block_type == BlockType.JOINT
        )

    @property
    def single_block_count(self) -> int:
        """Number of single transformer blocks."""
        return sum(
            bc.count for bc in self.block_configs if bc.block_type == BlockType.SINGLE
        )

    def get_steps_for_quality(self, quality: str) -> int:
        """Get inference steps for a quality level."""
        return self.default_steps[quality]

    def get_num_sync_steps(self, quality: str) -> int:
        """Get number of synchronous steps based on quality."""
        return ceil(self.default_steps[quality] * self.num_sync_steps_factor)
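As a quick orientation to the frozen config above, the sketch below builds a hypothetical two-block configuration and exercises the derived properties. The numbers mirror the Flux configs that appear later in this diff; the "demo" variant name is made up for illustration.

```python
from math import ceil

# Assumes the classes above are importable from exo.worker.engines.image.config
from exo.worker.engines.image.config import BlockType, ImageModelConfig, TransformerBlockConfig

demo = ImageModelConfig(
    model_family="flux",
    model_variant="demo",  # hypothetical variant, for illustration only
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(block_type=BlockType.JOINT, count=19, has_separate_text_output=True),
        TransformerBlockConfig(block_type=BlockType.SINGLE, count=38, has_separate_text_output=False),
    ),
    patch_size=2,
    vae_scale_factor=8,
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,
    uses_attention_mask=False,
)

assert demo.total_blocks == 57
assert demo.joint_block_count == 19 and demo.single_block_count == 38
# get_num_sync_steps rounds up: ceil(25 * 0.125) == 4 synchronous steps at "medium"
assert demo.get_num_sync_steps("medium") == ceil(25 * 0.125) == 4
```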
@@ -1,228 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.config import Config
|
||||
from PIL import Image
|
||||
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.download.download_utils import build_model_path
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models import (
|
||||
create_adapter_for_model,
|
||||
get_config_for_model,
|
||||
)
|
||||
from exo.worker.engines.image.models.base import BaseModelAdapter
|
||||
from exo.worker.engines.image.pipeline import DiffusionRunner
|
||||
from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
class DistributedImageModel:
|
||||
__slots__ = (
|
||||
"_config",
|
||||
"_adapter",
|
||||
"_group",
|
||||
"_shard_metadata",
|
||||
"_runner",
|
||||
)
|
||||
|
||||
_config: ImageModelConfig
|
||||
_adapter: BaseModelAdapter
|
||||
_group: Optional[mx.distributed.Group]
|
||||
_shard_metadata: PipelineShardMetadata
|
||||
_runner: DiffusionRunner
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
shard_metadata: PipelineShardMetadata,
|
||||
group: Optional[mx.distributed.Group] = None,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
# Get model config and create adapter (adapter owns the model)
|
||||
config = get_config_for_model(model_id)
|
||||
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
|
||||
|
||||
if group is not None:
|
||||
adapter.slice_transformer_blocks(
|
||||
start_layer=shard_metadata.start_layer,
|
||||
end_layer=shard_metadata.end_layer,
|
||||
total_joint_blocks=config.joint_block_count,
|
||||
total_single_blocks=config.single_block_count,
|
||||
)
|
||||
|
||||
# Create diffusion runner (handles both single-node and distributed modes)
|
||||
num_sync_steps = config.get_num_sync_steps("medium") if group else 0
|
||||
runner = DiffusionRunner(
|
||||
config=config,
|
||||
adapter=adapter,
|
||||
group=group,
|
||||
shard_metadata=shard_metadata,
|
||||
num_sync_steps=num_sync_steps,
|
||||
)
|
||||
|
||||
if group is not None:
|
||||
logger.info("Initialized distributed diffusion runner")
|
||||
|
||||
mx.eval(adapter.model.parameters())
|
||||
|
||||
# TODO(ciaran): Do we need this?
|
||||
mx.eval(adapter.model)
|
||||
|
||||
# Synchronize processes before generation to avoid timeout
|
||||
mx_barrier(group)
|
||||
logger.info(f"Transformer sharded for rank {group.rank()}")
|
||||
else:
|
||||
logger.info("Single-node initialization")
|
||||
|
||||
object.__setattr__(self, "_config", config)
|
||||
object.__setattr__(self, "_adapter", adapter)
|
||||
object.__setattr__(self, "_group", group)
|
||||
object.__setattr__(self, "_shard_metadata", shard_metadata)
|
||||
object.__setattr__(self, "_runner", runner)
|
||||
|
||||
@classmethod
|
||||
def from_bound_instance(
|
||||
cls, bound_instance: BoundInstance
|
||||
) -> "DistributedImageModel":
|
||||
model_id = bound_instance.bound_shard.model_meta.model_id
|
||||
model_path = build_model_path(model_id)
|
||||
|
||||
shard_metadata = bound_instance.bound_shard
|
||||
if not isinstance(shard_metadata, PipelineShardMetadata):
|
||||
raise ValueError("Expected PipelineShardMetadata for image generation")
|
||||
|
||||
is_distributed = (
|
||||
len(bound_instance.instance.shard_assignments.node_to_runner) > 1
|
||||
)
|
||||
|
||||
if is_distributed:
|
||||
logger.info("Starting distributed init for image model")
|
||||
group = mlx_distributed_init(bound_instance)
|
||||
else:
|
||||
group = None
|
||||
|
||||
return cls(
|
||||
model_id=model_id,
|
||||
local_path=model_path,
|
||||
shard_metadata=shard_metadata,
|
||||
group=group,
|
||||
)
|
||||
|
||||
@property
|
||||
def model(self) -> Any:
|
||||
"""Return the underlying mflux model via the adapter."""
|
||||
return self._adapter.model
|
||||
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def adapter(self) -> BaseModelAdapter:
|
||||
return self._adapter
|
||||
|
||||
@property
|
||||
def group(self) -> Optional[mx.distributed.Group]:
|
||||
return self._group
|
||||
|
||||
@property
|
||||
def shard_metadata(self) -> PipelineShardMetadata:
|
||||
return self._shard_metadata
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
return self._shard_metadata.device_rank
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return self._shard_metadata.world_size
|
||||
|
||||
@property
|
||||
def is_first_stage(self) -> bool:
|
||||
return self._shard_metadata.device_rank == 0
|
||||
|
||||
@property
|
||||
def is_last_stage(self) -> bool:
|
||||
return self._shard_metadata.device_rank == self._shard_metadata.world_size - 1
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return self._shard_metadata.world_size > 1
|
||||
|
||||
@property
|
||||
def runner(self) -> DiffusionRunner:
|
||||
return self._runner
|
||||
|
||||
# Delegate attribute access to the underlying model via the adapter.
|
||||
# Guarded with TYPE_CHECKING to prevent type checker complaints
|
||||
# while still providing full delegation at runtime.
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._adapter.model, name)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in (
|
||||
"_config",
|
||||
"_adapter",
|
||||
"_group",
|
||||
"_shard_metadata",
|
||||
"_runner",
|
||||
):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
setattr(self._adapter.model, name, value)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
quality: Literal["low", "medium", "high"] = "medium",
|
||||
seed: int = 2,
|
||||
image_path: Path | None = None,
|
||||
partial_images: int = 0,
|
||||
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
|
||||
# Determine number of inference steps based on quality
|
||||
steps = self._config.get_steps_for_quality(quality)
|
||||
|
||||
# For edit mode: compute dimensions from input image
|
||||
# This also stores image_paths in the adapter for encode_prompt()
|
||||
if image_path is not None:
|
||||
computed_dims = self._adapter.set_image_dimensions(image_path)
|
||||
if computed_dims is not None:
|
||||
# Override user-provided dimensions with computed ones
|
||||
width, height = computed_dims
|
||||
|
||||
config = Config(
|
||||
num_inference_steps=steps,
|
||||
height=height,
|
||||
width=width,
|
||||
image_path=image_path,
|
||||
)
|
||||
|
||||
# Generate images via the runner
|
||||
for result in self._runner.generate_image(
|
||||
settings=config,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
partial_images=partial_images,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (GeneratedImage, partial_index, total_partials)
|
||||
generated_image, partial_idx, total_partials = result
|
||||
yield (generated_image.image, partial_idx, total_partials)
|
||||
else:
|
||||
# Final image: GeneratedImage
|
||||
logger.info("generated image")
|
||||
yield result.image
|
||||
|
||||
|
||||
def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:
|
||||
"""Initialize DistributedImageModel from a BoundInstance."""
|
||||
return DistributedImageModel.from_bound_instance(bound_instance)
|
||||
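The wrapper class above routes unknown attribute access through to the wrapped mflux model while pinning its own slots via `object.__setattr__`. A stripped-down sketch of that delegation pattern, independent of mlx and exo (all names here are illustrative):

```python
from typing import Any


class Wrapped:
    temperature = 0.5  # stand-in for arbitrary attributes on the inner model


class Delegator:
    __slots__ = ("_inner",)

    def __init__(self, inner: Any) -> None:
        object.__setattr__(self, "_inner", inner)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails, so the slot attribute is unaffected.
        return getattr(self._inner, name)

    def __setattr__(self, name: str, value: Any) -> None:
        if name in self.__slots__:
            object.__setattr__(self, name, value)
        else:
            setattr(self._inner, name, value)  # forward everything else


d = Delegator(Wrapped())
assert d.temperature == 0.5
d.temperature = 0.9  # lands on the wrapped object, not on the delegator
assert Wrapped.temperature == 0.5 and d._inner.temperature == 0.9
```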
@@ -1,120 +0,0 @@
|
||||
import base64
|
||||
import io
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Generator, Literal
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from exo.shared.types.api import ImageEditsInternalParams, ImageGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
ImageGenerationResponse,
|
||||
PartialImageResponse,
|
||||
)
|
||||
from exo.worker.engines.image.base import ImageGenerator
|
||||
|
||||
|
||||
def parse_size(size_str: str | None) -> tuple[int, int]:
|
||||
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
|
||||
if not size_str or size_str == "auto":
|
||||
size_str = "1024x1024"
|
||||
|
||||
try:
|
||||
parts = size_str.split("x")
|
||||
if len(parts) == 2:
|
||||
width, height = int(parts[0]), int(parts[1])
|
||||
return (width, height)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
# Default fallback
|
||||
return (1024, 1024)
|
||||
|
||||
|
||||
def warmup_image_generator(model: ImageGenerator) -> Image.Image | None:
|
||||
"""Warmup the image generator with a small image."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Create a small dummy image for warmup (needed for edit models)
|
||||
dummy_image = Image.new("RGB", (256, 256), color=(128, 128, 128))
|
||||
dummy_path = Path(tmpdir) / "warmup.png"
|
||||
dummy_image.save(dummy_path)
|
||||
|
||||
for result in model.generate(
|
||||
prompt="Warmup",
|
||||
height=256,
|
||||
width=256,
|
||||
quality="low",
|
||||
seed=2,
|
||||
image_path=dummy_path,
|
||||
):
|
||||
if not isinstance(result, tuple):
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def generate_image(
model: ImageGenerator,
task: ImageGenerationTaskParams | ImageEditsInternalParams,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
"""Generate image(s), optionally yielding partial results.

When partial_images > 0 or stream=True, yields PartialImageResponse for
intermediate images, then ImageGenerationResponse for the final image.

Yields:
    PartialImageResponse for intermediate images (if partial_images > 0)
    ImageGenerationResponse for the final complete image
"""
width, height = parse_size(task.size)
|
||||
quality: Literal["low", "medium", "high"] = task.quality or "medium"
|
||||
seed = 2 # TODO(ciaran): Randomise when not testing anymore
|
||||
|
||||
# Handle streaming params for both generation and edit tasks
|
||||
partial_images = task.partial_images or (3 if task.stream else 0)
|
||||
|
||||
image_path: Path | None = None
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
if isinstance(task, ImageEditsInternalParams):
|
||||
# Decode base64 image data and save to temp file
|
||||
image_path = Path(tmpdir) / "input.png"
|
||||
image_path.write_bytes(base64.b64decode(task.image_data))
|
||||
|
||||
# Iterate over generator results
|
||||
for result in model.generate(
|
||||
prompt=task.prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
quality=quality,
|
||||
seed=seed,
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
image, partial_idx, total_partials = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield PartialImageResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
partial_index=partial_idx,
|
||||
total_partials=total_partials,
|
||||
)
|
||||
else:
|
||||
# Final image
|
||||
image = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield ImageGenerationResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
)
|
||||
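The `parse_size` helper above tolerates `None`, `"auto"`, and malformed strings by falling back to 1024x1024. A few illustrative expectations; the import path is an assumption on my part, not taken from the repository layout:

```python
from exo.worker.engines.image.image_utils import parse_size  # module path is an assumption

assert parse_size("512x768") == (512, 768)
assert parse_size("auto") == (1024, 1024)
assert parse_size(None) == (1024, 1024)
assert parse_size("not-a-size") == (1024, 1024)  # malformed input falls through to the default
```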
@@ -1,84 +0,0 @@
from pathlib import Path
from typing import Callable

from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.flux import (
    FLUX_DEV_CONFIG,
    FLUX_SCHNELL_CONFIG,
    FluxModelAdapter,
)
from exo.worker.engines.image.models.qwen import (
    QWEN_IMAGE_CONFIG,
    QWEN_IMAGE_EDIT_CONFIG,
    QwenEditModelAdapter,
    QwenModelAdapter,
)
from exo.worker.engines.image.pipeline.adapter import ModelAdapter

__all__: list[str] = []

# Type alias for adapter factory functions
# Factory takes (config, model_id, local_path, quantize) and returns a ModelAdapter
AdapterFactory = Callable[[ImageModelConfig, str, Path, int | None], ModelAdapter]

# Registry maps model_family string to adapter factory
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
    "flux": FluxModelAdapter,
    "qwen-edit": QwenEditModelAdapter,
    "qwen": QwenModelAdapter,
}

# Config registry: maps model ID patterns to configs
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
    "flux.1-schnell": FLUX_SCHNELL_CONFIG,
    "flux.1-dev": FLUX_DEV_CONFIG,
    "qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
    "qwen-image": QWEN_IMAGE_CONFIG,
}


def get_config_for_model(model_id: str) -> ImageModelConfig:
    """Get configuration for a model ID.

    Args:
        model_id: The model identifier (e.g., "black-forest-labs/FLUX.1-schnell")

    Returns:
        The model configuration

    Raises:
        ValueError: If no configuration found for model ID
    """
    model_id_lower = model_id.lower()

    for pattern, config in _CONFIG_REGISTRY.items():
        if pattern in model_id_lower:
            return config

    raise ValueError(f"No configuration found for model: {model_id}")


def create_adapter_for_model(
    config: ImageModelConfig,
    model_id: str,
    local_path: Path,
    quantize: int | None = None,
) -> ModelAdapter:
    """Create a model adapter for the given configuration.

    Args:
        config: The model configuration
        model_id: The model identifier
        local_path: Path to the model weights
        quantize: Optional quantization bits

    Returns:
        A ModelAdapter instance

    Raises:
        ValueError: If no adapter found for model family
    """
    factory = _ADAPTER_REGISTRY.get(config.model_family)
    if factory is None:
        raise ValueError(f"No adapter found for model family: {config.model_family}")
    return factory(config, model_id, local_path, quantize)
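Because `get_config_for_model` does substring matching over `_CONFIG_REGISTRY` in insertion order, the more specific pattern has to be registered first, as the inline comment notes. A small self-contained sketch of why that ordering matters (illustrative strings only, not the real configs):

```python
configs = {
    "qwen-image-edit": "EDIT_CONFIG",
    "qwen-image": "BASE_CONFIG",
}


def lookup(model_id: str) -> str:
    model_id = model_id.lower()
    for pattern, config in configs.items():  # dicts preserve insertion order (Python 3.7+)
        if pattern in model_id:
            return config
    raise ValueError(model_id)


# An id containing "qwen-image-edit" matches both patterns; the first, more specific entry wins.
assert lookup("Qwen-Image-Edit-Demo") == "EDIT_CONFIG"
assert lookup("Qwen-Image") == "BASE_CONFIG"
```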
@@ -1,103 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
from mflux.models.common.latent_creator.latent_creator import Img2Img, LatentCreator
|
||||
from mflux.utils.array_util import ArrayUtil
|
||||
from mflux.utils.image_util import ImageUtil
|
||||
|
||||
|
||||
class BaseModelAdapter(ABC):
"""Base class for model adapters with shared utilities.

Provides common implementations for latent creation and decoding.
Subclasses implement model-specific prompt encoding and noise computation.
"""

|
||||
def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
|
||||
"""Create initial latents. Uses model-specific latent creator."""
|
||||
return LatentCreator.create_for_txt2img_or_img2img(
|
||||
seed=seed,
|
||||
height=runtime_config.height,
|
||||
width=runtime_config.width,
|
||||
img2img=Img2Img(
|
||||
vae=self.model.vae,
|
||||
latent_creator=self._get_latent_creator(),
|
||||
sigmas=runtime_config.scheduler.sigmas,
|
||||
init_time_step=runtime_config.init_time_step,
|
||||
image_path=runtime_config.image_path,
|
||||
),
|
||||
)
|
||||
|
||||
def decode_latents(
|
||||
self,
|
||||
latents: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
seed: int,
|
||||
prompt: str,
|
||||
) -> Any:
|
||||
"""Decode latents to image. Shared implementation."""
|
||||
latents = ArrayUtil.unpack_latents(
|
||||
latents=latents,
|
||||
height=runtime_config.height,
|
||||
width=runtime_config.width,
|
||||
)
|
||||
decoded = self.model.vae.decode(latents)
|
||||
return ImageUtil.to_image(
|
||||
decoded_latents=decoded,
|
||||
config=runtime_config,
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
quantization=self.model.bits,
|
||||
lora_paths=self.model.lora_paths,
|
||||
lora_scales=self.model.lora_scales,
|
||||
image_path=runtime_config.image_path,
|
||||
image_strength=runtime_config.image_strength,
|
||||
generation_time=0,
|
||||
)
|
||||
|
||||
# Abstract methods - subclasses must implement
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model(self) -> Any:
|
||||
"""Return the underlying mflux model."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_latent_creator(self) -> type:
|
||||
"""Return the latent creator class for this model."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
total_joint_blocks: int,
|
||||
total_single_blocks: int,
|
||||
):
|
||||
"""Remove transformer blocks outside the assigned range.
|
||||
|
||||
This should be called BEFORE mx.eval() to avoid loading unused weights
|
||||
in distributed mode.
|
||||
|
||||
Args:
|
||||
start_layer: First layer index (inclusive) assigned to this node
|
||||
end_layer: Last layer index (exclusive) assigned to this node
|
||||
total_joint_blocks: Total number of joint blocks in the model
|
||||
total_single_blocks: Total number of single blocks in the model
|
||||
"""
|
||||
...
|
||||
|
||||
def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
|
||||
"""Default implementation: no dimension computation needed.
|
||||
|
||||
Override in edit adapters to compute dimensions from input image.
|
||||
|
||||
Returns:
|
||||
None (use user-specified dimensions)
|
||||
"""
|
||||
return None
|
||||
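The abstract base above fixes latent creation and decoding and leaves the model handle, latent-creator choice, and block slicing to subclasses (the Flux and Qwen adapters later in this diff). A minimal, hypothetical subclass shape to make the template-method contract explicit; `FakeModel` and `FakeLatentCreator` are stand-ins, not real mflux classes:

```python
from typing import Any

# Assumes the BaseModelAdapter defined above is importable.


class FakeLatentCreator:  # placeholder for e.g. FluxLatentCreator
    pass


class FakeModel:  # placeholder for an mflux model object
    vae = None


class MinimalAdapter(BaseModelAdapter):
    def __init__(self) -> None:
        self._model = FakeModel()

    @property
    def model(self) -> Any:
        return self._model

    def _get_latent_creator(self) -> type:
        return FakeLatentCreator

    def slice_transformer_blocks(
        self, start_layer: int, end_layer: int, total_joint_blocks: int, total_single_blocks: int
    ):
        pass  # a real adapter drops the blocks outside [start_layer, end_layer)
```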
@@ -1,11 +0,0 @@
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
from exo.worker.engines.image.models.flux.config import (
    FLUX_DEV_CONFIG,
    FLUX_SCHNELL_CONFIG,
)

__all__ = [
    "FluxModelAdapter",
    "FLUX_DEV_CONFIG",
    "FLUX_SCHNELL_CONFIG",
]
@@ -1,680 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.model_config import ModelConfig
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
|
||||
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
|
||||
from mflux.models.flux.model.flux_transformer.common.attention_utils import (
|
||||
AttentionUtils,
|
||||
)
|
||||
from mflux.models.flux.model.flux_transformer.joint_transformer_block import (
|
||||
JointTransformerBlock,
|
||||
)
|
||||
from mflux.models.flux.model.flux_transformer.transformer import Transformer
|
||||
from mflux.models.flux.variants.txt2img.flux import Flux1
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import BaseModelAdapter
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
JointBlockInterface,
|
||||
SingleBlockInterface,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
class FluxPromptData:
|
||||
"""Container for Flux prompt encoding results."""
|
||||
|
||||
def __init__(self, prompt_embeds: mx.array, pooled_prompt_embeds: mx.array):
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self._pooled_prompt_embeds = pooled_prompt_embeds
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
return self._pooled_prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array | None:
|
||||
"""Flux does not use CFG."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array | None:
|
||||
"""Flux does not use CFG."""
|
||||
return None
|
||||
|
||||
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
|
||||
"""Flux has no extra forward kwargs."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
"""Flux does not use conditioning latents."""
|
||||
return None
|
||||
|
||||
|
||||
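FluxPromptData above and QwenPromptData later in this diff expose the same surface, so the implied prompt-data protocol can be written down roughly as follows. This Protocol is my reconstruction from the two containers, not a class taken from the repository:

```python
from typing import Any, Protocol

import mlx.core as mx


class PromptDataLike(Protocol):
    """Reconstructed shape of the prompt containers consumed by the diffusion runner."""

    @property
    def prompt_embeds(self) -> mx.array: ...
    @property
    def pooled_prompt_embeds(self) -> mx.array: ...
    @property
    def negative_prompt_embeds(self) -> mx.array | None: ...
    @property
    def negative_pooled_prompt_embeds(self) -> mx.array | None: ...
    @property
    def conditioning_latents(self) -> mx.array | None: ...
    def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]: ...
```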
class FluxModelAdapter(BaseModelAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
self._config = config
|
||||
self._model = Flux1(
|
||||
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
|
||||
local_path=str(local_path),
|
||||
quantize=quantize,
|
||||
)
|
||||
self._transformer = self._model.transformer
|
||||
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def model(self) -> Flux1:
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def transformer(self) -> Transformer:
|
||||
return self._transformer
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
return self._transformer.x_embedder.weight.shape[0]
|
||||
|
||||
def _get_latent_creator(self) -> type:
|
||||
return FluxLatentCreator
|
||||
|
||||
def encode_prompt(self, prompt: str) -> FluxPromptData:
|
||||
"""Encode prompt into FluxPromptData."""
|
||||
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_cache=self._model.prompt_cache,
|
||||
t5_tokenizer=self._model.t5_tokenizer,
|
||||
clip_tokenizer=self._model.clip_tokenizer,
|
||||
t5_text_encoder=self._model.t5_text_encoder,
|
||||
clip_text_encoder=self._model.clip_text_encoder,
|
||||
)
|
||||
return FluxPromptData(
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
@property
|
||||
def needs_cfg(self) -> bool:
|
||||
return False
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
raise NotImplementedError("Flux does not use classifier-free guidance")
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
embedded_hidden = self._transformer.x_embedder(hidden_states)
|
||||
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
|
||||
return embedded_hidden, embedded_encoder
|
||||
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: RuntimeConfig,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None, # Ignored by Flux
|
||||
) -> mx.array:
|
||||
if pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"pooled_prompt_embeds is required for Flux text embeddings"
|
||||
)
|
||||
|
||||
# hidden_states is ignored - Flux uses pooled_prompt_embeds instead
|
||||
return Transformer.compute_text_embeddings(
|
||||
t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
|
||||
)
|
||||
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
**kwargs: Any,
|
||||
) -> mx.array:
|
||||
kontext_image_ids = kwargs.get("kontext_image_ids")
|
||||
return Transformer.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
self._transformer.pos_embed,
|
||||
runtime_config,
|
||||
kontext_image_ids,
|
||||
)
|
||||
|
||||
def apply_joint_block(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: Any, # mx.array for Flux
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
if mode == BlockWrapperMode.CACHING:
|
||||
return self._apply_joint_block_caching(
|
||||
block=block,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
)
|
||||
else:
|
||||
assert patch_start is not None and patch_end is not None
|
||||
assert kv_cache is not None
|
||||
return self._apply_joint_block_patched(
|
||||
block=block,
|
||||
patch_hidden=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
)
|
||||
|
||||
def apply_single_block(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
) -> mx.array:
|
||||
if mode == BlockWrapperMode.CACHING:
|
||||
return self._apply_single_block_caching(
|
||||
block=block,
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
)
|
||||
else:
|
||||
assert patch_start is not None and patch_end is not None
|
||||
assert kv_cache is not None
|
||||
return self._apply_single_block_patched(
|
||||
block=block,
|
||||
patch_hidden=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
)
|
||||
|
||||
def final_projection(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
) -> mx.array:
|
||||
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
|
||||
return self._transformer.proj_out(hidden_states)
|
||||
|
||||
def get_joint_blocks(self) -> list[JointBlockInterface]:
|
||||
return cast(
|
||||
list[JointBlockInterface], list(self._transformer.transformer_blocks)
|
||||
)
|
||||
|
||||
def get_single_blocks(self) -> list[SingleBlockInterface]:
|
||||
return cast(
|
||||
list[SingleBlockInterface],
|
||||
list(self._transformer.single_transformer_blocks),
|
||||
)
|
||||
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
total_joint_blocks: int,
|
||||
total_single_blocks: int,
|
||||
) -> None:
|
||||
if end_layer <= total_joint_blocks:
|
||||
# All assigned are joint blocks
|
||||
joint_start, joint_end = start_layer, end_layer
|
||||
single_start, single_end = 0, 0
|
||||
elif start_layer >= total_joint_blocks:
|
||||
# All assigned are single blocks
|
||||
joint_start, joint_end = 0, 0
|
||||
single_start = start_layer - total_joint_blocks
|
||||
single_end = end_layer - total_joint_blocks
|
||||
else:
|
||||
# Spans both joint and single
|
||||
joint_start, joint_end = start_layer, total_joint_blocks
|
||||
single_start = 0
|
||||
single_end = end_layer - total_joint_blocks
|
||||
|
||||
all_joint = list(self._transformer.transformer_blocks)
|
||||
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
|
||||
|
||||
all_single = list(self._transformer.single_transformer_blocks)
|
||||
self._transformer.single_transformer_blocks = all_single[
|
||||
single_start:single_end
|
||||
]
|
||||
|
||||
def merge_streams(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
) -> mx.array:
|
||||
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
|
||||
|
||||
def _apply_joint_block_caching(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
text_seq_len: int,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
num_img_tokens = hidden_states.shape[1]
|
||||
batch_size = hidden_states.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dimension
|
||||
|
||||
# 1. Compute norms
|
||||
norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
|
||||
block.norm1_context(
|
||||
hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Compute Q, K, V for full image
|
||||
img_query, img_key, img_value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_hidden,
|
||||
to_q=attn.to_q,
|
||||
to_k=attn.to_k,
|
||||
to_v=attn.to_v,
|
||||
norm_q=attn.norm_q,
|
||||
norm_k=attn.norm_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 3. Compute Q, K, V for text
|
||||
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_encoder,
|
||||
to_q=attn.add_q_proj,
|
||||
to_k=attn.add_k_proj,
|
||||
to_v=attn.add_v_proj,
|
||||
norm_q=attn.norm_added_q,
|
||||
norm_k=attn.norm_added_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 4. Concatenate Q, K, V: [text, image]
|
||||
query = mx.concatenate([txt_query, img_query], axis=2)
|
||||
key = mx.concatenate([txt_key, img_key], axis=2)
|
||||
value = mx.concatenate([txt_value, img_value], axis=2)
|
||||
|
||||
# 5. Apply RoPE
|
||||
query, key = AttentionUtils.apply_rope(
|
||||
xq=query, xk=key, freqs_cis=rotary_embeddings
|
||||
)
|
||||
|
||||
# 6. Store IMAGE K/V in cache for async pipeline
|
||||
if kv_cache is not None:
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=0,
|
||||
patch_end=num_img_tokens,
|
||||
key=key[:, :, text_seq_len:, :],
|
||||
value=value[:, :, text_seq_len:, :],
|
||||
)
|
||||
|
||||
# 7. Compute full attention
|
||||
attn_output = AttentionUtils.compute_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 8. Extract and project outputs
|
||||
context_attn_output = attn_output[:, :text_seq_len, :]
|
||||
attn_output = attn_output[:, text_seq_len:, :]
|
||||
|
||||
attn_output = attn.to_out[0](attn_output)
|
||||
context_attn_output = attn.to_add_out(context_attn_output)
|
||||
|
||||
# 9. Apply norm and feed forward
|
||||
hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
|
||||
hidden_states=hidden_states,
|
||||
attn_output=attn_output,
|
||||
gate_mlp=gate_mlp,
|
||||
gate_msa=gate_msa,
|
||||
scale_mlp=scale_mlp,
|
||||
shift_mlp=shift_mlp,
|
||||
norm_layer=block.norm2,
|
||||
ff_layer=block.ff,
|
||||
)
|
||||
encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
|
||||
hidden_states=encoder_hidden_states,
|
||||
attn_output=context_attn_output,
|
||||
gate_mlp=c_gate_mlp,
|
||||
gate_msa=c_gate_msa,
|
||||
scale_mlp=c_scale_mlp,
|
||||
shift_mlp=c_shift_mlp,
|
||||
norm_layer=block.norm2_context,
|
||||
ff_layer=block.ff_context,
|
||||
)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
def _apply_joint_block_patched(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
patch_hidden: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache,
|
||||
text_seq_len: int,
|
||||
patch_start: int,
|
||||
patch_end: int,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
batch_size = patch_hidden.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dimension
|
||||
|
||||
# 1. Compute norms
|
||||
norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
|
||||
hidden_states=patch_hidden,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
|
||||
block.norm1_context(
|
||||
hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Compute Q, K, V for image patch
|
||||
img_query, img_key, img_value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_hidden,
|
||||
to_q=attn.to_q,
|
||||
to_k=attn.to_k,
|
||||
to_v=attn.to_v,
|
||||
norm_q=attn.norm_q,
|
||||
norm_k=attn.norm_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 3. Compute Q, K, V for text
|
||||
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_encoder,
|
||||
to_q=attn.add_q_proj,
|
||||
to_k=attn.add_k_proj,
|
||||
to_v=attn.add_v_proj,
|
||||
norm_q=attn.norm_added_q,
|
||||
norm_k=attn.norm_added_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 4. Concatenate Q, K, V for patch: [text, patch]
|
||||
query = mx.concatenate([txt_query, img_query], axis=2)
|
||||
patch_key = mx.concatenate([txt_key, img_key], axis=2)
|
||||
patch_value = mx.concatenate([txt_value, img_value], axis=2)
|
||||
|
||||
# 5. Extract RoPE for [text + current_patch]
|
||||
text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
|
||||
patch_img_rope = rotary_embeddings[
|
||||
:, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
|
||||
]
|
||||
patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
|
||||
|
||||
# 6. Apply RoPE
|
||||
query, patch_key = AttentionUtils.apply_rope(
|
||||
xq=query, xk=patch_key, freqs_cis=patch_rope
|
||||
)
|
||||
|
||||
# 7. Update cache with this patch's IMAGE K/V
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
key=patch_key[:, :, text_seq_len:, :],
|
||||
value=patch_value[:, :, text_seq_len:, :],
|
||||
)
|
||||
|
||||
# 8. Get full K, V from cache
|
||||
full_key, full_value = kv_cache.get_full_kv(
|
||||
text_key=patch_key[:, :, :text_seq_len, :],
|
||||
text_value=patch_value[:, :, :text_seq_len, :],
|
||||
)
|
||||
|
||||
# 9. Compute attention
|
||||
attn_output = AttentionUtils.compute_attention(
|
||||
query=query,
|
||||
key=full_key,
|
||||
value=full_value,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 10. Extract and project outputs
|
||||
context_attn_output = attn_output[:, :text_seq_len, :]
|
||||
hidden_attn_output = attn_output[:, text_seq_len:, :]
|
||||
|
||||
hidden_attn_output = attn.to_out[0](hidden_attn_output)
|
||||
context_attn_output = attn.to_add_out(context_attn_output)
|
||||
|
||||
# 11. Apply norm and feed forward
|
||||
patch_hidden = JointTransformerBlock.apply_norm_and_feed_forward(
|
||||
hidden_states=patch_hidden,
|
||||
attn_output=hidden_attn_output,
|
||||
gate_mlp=gate_mlp,
|
||||
gate_msa=gate_msa,
|
||||
scale_mlp=scale_mlp,
|
||||
shift_mlp=shift_mlp,
|
||||
norm_layer=block.norm2,
|
||||
ff_layer=block.ff,
|
||||
)
|
||||
encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
|
||||
hidden_states=encoder_hidden_states,
|
||||
attn_output=context_attn_output,
|
||||
gate_mlp=c_gate_mlp,
|
||||
gate_msa=c_gate_msa,
|
||||
scale_mlp=c_scale_mlp,
|
||||
shift_mlp=c_shift_mlp,
|
||||
norm_layer=block.norm2_context,
|
||||
ff_layer=block.ff_context,
|
||||
)
|
||||
|
||||
return encoder_hidden_states, patch_hidden
|
||||
|
||||
def _apply_single_block_caching(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
text_seq_len: int,
|
||||
) -> mx.array:
|
||||
total_seq_len = hidden_states.shape[1]
|
||||
num_img_tokens = total_seq_len - text_seq_len
|
||||
batch_size = hidden_states.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dimension
|
||||
|
||||
# Residual connection
|
||||
residual = hidden_states
|
||||
|
||||
# 1. Compute norm
|
||||
norm_hidden, gate = block.norm(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
|
||||
# 2. Compute Q, K, V
|
||||
query, key, value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_hidden,
|
||||
to_q=attn.to_q,
|
||||
to_k=attn.to_k,
|
||||
to_v=attn.to_v,
|
||||
norm_q=attn.norm_q,
|
||||
norm_k=attn.norm_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 3. Apply RoPE
|
||||
query, key = AttentionUtils.apply_rope(
|
||||
xq=query, xk=key, freqs_cis=rotary_embeddings
|
||||
)
|
||||
|
||||
# 4. Store IMAGE K/V in cache
|
||||
if kv_cache is not None:
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=0,
|
||||
patch_end=num_img_tokens,
|
||||
key=key[:, :, text_seq_len:, :],
|
||||
value=value[:, :, text_seq_len:, :],
|
||||
)
|
||||
|
||||
# 5. Compute attention
|
||||
attn_output = AttentionUtils.compute_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 6. Apply feed forward and projection
|
||||
hidden_states = block._apply_feed_forward_and_projection(
|
||||
norm_hidden_states=norm_hidden,
|
||||
attn_output=attn_output,
|
||||
gate=gate,
|
||||
)
|
||||
|
||||
return residual + hidden_states
|
||||
|
||||
def _apply_single_block_patched(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
patch_hidden: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache,
|
||||
text_seq_len: int,
|
||||
patch_start: int,
|
||||
patch_end: int,
|
||||
) -> mx.array:
|
||||
batch_size = patch_hidden.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dimension
|
||||
|
||||
# Residual connection
|
||||
residual = patch_hidden
|
||||
|
||||
# 1. Compute norm
|
||||
norm_hidden, gate = block.norm(
|
||||
hidden_states=patch_hidden,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
|
||||
# 2. Compute Q, K, V
|
||||
query, key, value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_hidden,
|
||||
to_q=attn.to_q,
|
||||
to_k=attn.to_k,
|
||||
to_v=attn.to_v,
|
||||
norm_q=attn.norm_q,
|
||||
norm_k=attn.norm_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 3. Extract RoPE for [text + current_patch]
|
||||
text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
|
||||
patch_img_rope = rotary_embeddings[
|
||||
:, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
|
||||
]
|
||||
patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
|
||||
|
||||
# 4. Apply RoPE
|
||||
query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=patch_rope)
|
||||
|
||||
# 5. Update cache with this patch's IMAGE K/V
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
key=key[:, :, text_seq_len:, :],
|
||||
value=value[:, :, text_seq_len:, :],
|
||||
)
|
||||
|
||||
# 6. Get full K, V from cache
|
||||
full_key, full_value = kv_cache.get_full_kv(
|
||||
text_key=key[:, :, :text_seq_len, :],
|
||||
text_value=value[:, :, :text_seq_len, :],
|
||||
)
|
||||
|
||||
# 7. Compute attention
|
||||
attn_output = AttentionUtils.compute_attention(
|
||||
query=query,
|
||||
key=full_key,
|
||||
value=full_value,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 8. Apply feed forward and projection
|
||||
hidden_states = block._apply_feed_forward_and_projection(
|
||||
norm_hidden_states=norm_hidden,
|
||||
attn_output=attn_output,
|
||||
gate=gate,
|
||||
)
|
||||
|
||||
return residual + hidden_states
|
||||
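The two `_patched` code paths above lean on an `ImagePatchKVCache` that accumulates image keys/values patch by patch and then serves the full `[text + image]` K/V for attention. A toy NumPy re-statement of that contract; shapes and names are illustrative, the real cache is the mlx implementation referenced in the imports:

```python
import numpy as np


class ToyPatchKVCache:
    """Toy stand-in for ImagePatchKVCache: stores image K/V per patch, serves full K/V."""

    def __init__(self, batch: int, heads: int, img_tokens: int, head_dim: int) -> None:
        shape = (batch, heads, img_tokens, head_dim)
        self.key = np.zeros(shape, dtype=np.float32)
        self.value = np.zeros(shape, dtype=np.float32)

    def update_image_patch(self, patch_start: int, patch_end: int, key, value) -> None:
        self.key[:, :, patch_start:patch_end, :] = key
        self.value[:, :, patch_start:patch_end, :] = value

    def get_full_kv(self, text_key, text_value):
        # A patch query still attends over [text + full image] keys/values.
        full_key = np.concatenate([text_key, self.key], axis=2)
        full_value = np.concatenate([text_value, self.value], axis=2)
        return full_key, full_value


cache = ToyPatchKVCache(batch=1, heads=2, img_tokens=8, head_dim=4)
cache.update_image_patch(0, 4, np.ones((1, 2, 4, 4)), np.ones((1, 2, 4, 4)))
k, v = cache.get_full_kv(np.zeros((1, 2, 3, 4)), np.zeros((1, 2, 3, 4)))
assert k.shape == (1, 2, 3 + 8, 4)
```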
@@ -1,48 +0,0 @@
from exo.worker.engines.image.config import (
    BlockType,
    ImageModelConfig,
    TransformerBlockConfig,
)

FLUX_SCHNELL_CONFIG = ImageModelConfig(
    model_family="flux",
    model_variant="schnell",
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=19, has_separate_text_output=True
        ),
        TransformerBlockConfig(
            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
        ),
    ),
    patch_size=2,
    vae_scale_factor=8,
    default_steps={"low": 1, "medium": 2, "high": 4},
    num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
    uses_attention_mask=False,
)


FLUX_DEV_CONFIG = ImageModelConfig(
    model_family="flux",
    model_variant="dev",
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=19, has_separate_text_output=True
        ),
        TransformerBlockConfig(
            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
        ),
    ),
    patch_size=2,
    vae_scale_factor=8,
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
    uses_attention_mask=False,
)
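With 19 joint and 38 single blocks, a pipeline layer range is split at index 19, which is what `FluxModelAdapter.slice_transformer_blocks` above implements. The arithmetic, worked through for a hypothetical two-node split, plus the sync-step rounding for these configs:

```python
from math import ceil

TOTAL_JOINT, TOTAL_SINGLE = 19, 38  # from the Flux configs above


def split_layers(start_layer: int, end_layer: int) -> tuple[tuple[int, int], tuple[int, int]]:
    """Map a [start, end) layer range onto (joint_slice, single_slice), mirroring the adapter."""
    if end_layer <= TOTAL_JOINT:
        return (start_layer, end_layer), (0, 0)
    if start_layer >= TOTAL_JOINT:
        return (0, 0), (start_layer - TOTAL_JOINT, end_layer - TOTAL_JOINT)
    return (start_layer, TOTAL_JOINT), (0, end_layer - TOTAL_JOINT)


# Hypothetical split of 57 blocks across two nodes: layers [0, 29) and [29, 57).
assert split_layers(0, 29) == ((0, 19), (0, 10))   # all joint blocks + first 10 single blocks
assert split_layers(29, 57) == ((0, 0), (10, 38))  # remaining 28 single blocks

# Sync-step rounding for the configs above:
assert ceil(2 * 0.5) == 1     # schnell, "medium"
assert ceil(25 * 0.125) == 4  # dev, "medium" (the inline "~3" comment is approximate)
```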
@@ -1,13 +0,0 @@
from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter
from exo.worker.engines.image.models.qwen.config import (
    QWEN_IMAGE_CONFIG,
    QWEN_IMAGE_EDIT_CONFIG,
)
from exo.worker.engines.image.models.qwen.edit_adapter import QwenEditModelAdapter

__all__ = [
    "QwenModelAdapter",
    "QwenEditModelAdapter",
    "QWEN_IMAGE_CONFIG",
    "QWEN_IMAGE_EDIT_CONFIG",
]
@@ -1,519 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.model_config import ModelConfig
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
|
||||
from mflux.models.qwen.model.qwen_text_encoder.qwen_prompt_encoder import (
|
||||
QwenPromptEncoder,
|
||||
)
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
|
||||
QwenTransformerBlock,
|
||||
)
|
||||
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import BaseModelAdapter
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
JointBlockInterface,
|
||||
SingleBlockInterface,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
class QwenPromptData:
|
||||
"""Container for Qwen prompt encoding results.
|
||||
|
||||
Implements PromptData protocol with additional Qwen-specific attributes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
prompt_mask: mx.array,
|
||||
negative_prompt_embeds: mx.array,
|
||||
negative_prompt_mask: mx.array,
|
||||
):
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self.prompt_mask = prompt_mask
|
||||
self._negative_prompt_embeds = negative_prompt_embeds
|
||||
self.negative_prompt_mask = negative_prompt_mask
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
"""Text embeddings from encoder."""
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
|
||||
return self._prompt_embeds # Use prompt_embeds as placeholder
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array:
|
||||
"""Negative prompt embeddings for CFG."""
|
||||
return self._negative_prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Placeholder - Qwen doesn't use pooled embeds."""
|
||||
return self._negative_prompt_embeds
|
||||
|
||||
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
|
||||
"""Return encoder_hidden_states_mask for the appropriate prompt."""
|
||||
if positive:
|
||||
return {"encoder_hidden_states_mask": self.prompt_mask}
|
||||
else:
|
||||
return {"encoder_hidden_states_mask": self.negative_prompt_mask}
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
"""Standard Qwen does not use conditioning latents."""
|
||||
return None
|
||||
|
||||
|
||||
class QwenModelAdapter(BaseModelAdapter):
"""Adapter for Qwen-Image model.

Key differences from Flux:
- Single text encoder (vs dual T5+CLIP)
- 60 joint-style blocks, no single blocks
- 3D RoPE returning ((img_cos, img_sin), (txt_cos, txt_sin))
- Norm-preserving CFG with negative prompts
- Uses attention mask for variable-length text
"""
|
||||
def __init__(
|
||||
self,
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
self._config = config
|
||||
self._model = QwenImage(
|
||||
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
|
||||
local_path=str(local_path),
|
||||
quantize=quantize,
|
||||
)
|
||||
self._transformer = self._model.transformer
|
||||
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def model(self) -> QwenImage:
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def transformer(self) -> QwenTransformer:
|
||||
return self._transformer
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
return self._transformer.inner_dim
|
||||
|
||||
def _get_latent_creator(self) -> type:
|
||||
return QwenLatentCreator
|
||||
|
||||
def encode_prompt(self, prompt: str) -> QwenPromptData:
|
||||
"""Encode prompt into QwenPromptData.
|
||||
|
||||
Qwen uses classifier-free guidance with explicit negative prompts.
|
||||
Returns a QwenPromptData container with all 4 tensors.
|
||||
"""
|
||||
# TODO(ciaran): empty string as default negative prompt
|
||||
negative_prompt = ""
|
||||
|
||||
prompt_embeds, prompt_mask, neg_embeds, neg_mask = (
|
||||
QwenPromptEncoder.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_cache=self._model.prompt_cache,
|
||||
qwen_tokenizer=self._model.qwen_tokenizer,
|
||||
qwen_text_encoder=self._model.text_encoder,
|
||||
)
|
||||
)
|
||||
|
||||
return QwenPromptData(
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_mask=prompt_mask,
|
||||
negative_prompt_embeds=neg_embeds,
|
||||
negative_prompt_mask=neg_mask,
|
||||
)
|
||||
|
||||
@property
|
||||
def needs_cfg(self) -> bool:
|
||||
gs = self._config.guidance_scale
|
||||
return gs is not None and gs > 1.0
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
return self._model.compute_guided_noise(
|
||||
noise=noise_positive,
|
||||
noise_negative=noise_negative,
|
||||
guidance=guidance_scale,
|
||||
)
|
||||
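# Note (illustrative, not from the original file): in its textbook form, classifier-free
# guidance combines the conditional and unconditional noise predictions as
#     guided = noise_negative + guidance_scale * (noise_positive - noise_negative)
# Qwen's compute_guided_noise applies a norm-preserving variant of this (see the class
# docstring above), so the formula here is only the standard baseline that apply_guidance
# delegates from, not the exact implementation.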
|
||||
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Compute image and text embeddings."""
|
||||
# Image embedding
|
||||
embedded_hidden = self._transformer.img_in(hidden_states)
|
||||
# Text embedding: first normalize, then project
|
||||
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
|
||||
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
|
||||
return embedded_hidden, embedded_encoder
|
||||
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: RuntimeConfig,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
"""Compute time/text embeddings.
|
||||
|
||||
For Qwen, the time_text_embed only uses hidden_states for:
|
||||
- batch_size (shape[0])
|
||||
- dtype
|
||||
|
||||
This allows us to pass any tensor (latents, prompt_embeds) as a fallback
|
||||
when embedded hidden_states are not yet available.
|
||||
"""
|
||||
# Use hidden_states if provided, otherwise fall back to pooled_prompt_embeds
|
||||
# (which for Qwen is the same as prompt_embeds)
|
||||
ref_tensor = (
|
||||
hidden_states if hidden_states is not None else pooled_prompt_embeds
|
||||
)
|
||||
if ref_tensor is None:
|
||||
raise ValueError(
|
||||
"Either hidden_states or pooled_prompt_embeds is required "
|
||||
"for Qwen text embeddings"
|
||||
)
|
||||
|
||||
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
|
||||
batch_size = ref_tensor.shape[0]
|
||||
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
|
||||
return self._transformer.time_text_embed(timestep, ref_tensor)
|
||||
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Compute 3D rotary embeddings for Qwen.
|
||||
|
||||
Qwen uses video-aware 3D RoPE with separate embeddings for image and text.
|
||||
|
||||
Returns:
|
||||
tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]:
|
||||
((img_cos, img_sin), (txt_cos, txt_sin))
|
||||
"""
|
||||
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
|
||||
cond_image_grid = kwargs.get("cond_image_grid")
|
||||
|
||||
if encoder_hidden_states_mask is None:
|
||||
raise ValueError(
|
||||
"encoder_hidden_states_mask is required for Qwen RoPE computation"
|
||||
)
|
||||
|
||||
return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
pos_embed=self._transformer.pos_embed,
|
||||
config=runtime_config,
|
||||
cond_image_grid=cond_image_grid,
|
||||
)
|
||||
|
||||
def apply_joint_block(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: Any, # tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]] for Qwen
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply Qwen joint block.
|
||||
|
||||
For caching mode, we run the full block and optionally populate the KV cache.
|
||||
For patched mode, we use the cached KV values (not yet implemented).
|
||||
"""
|
||||
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
|
||||
block_idx = kwargs.get("block_idx")
|
||||
|
||||
if mode == BlockWrapperMode.CACHING:
|
||||
return self._apply_joint_block_caching(
|
||||
block=block,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
else:
|
||||
# mode == BlockWrapperMode.PATCHED
|
||||
assert patch_start is not None and patch_end is not None
|
||||
assert kv_cache is not None
|
||||
return self._apply_joint_block_patched(
|
||||
block=block,
|
||||
patch_hidden=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
def apply_single_block(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
) -> mx.array:
|
||||
"""Qwen has no single blocks."""
|
||||
raise NotImplementedError("Qwen does not have single blocks")
|
||||
|
||||
def final_projection(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply final normalization and projection."""
|
||||
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
|
||||
return self._transformer.proj_out(hidden_states)
|
||||
|
||||
def get_joint_blocks(self) -> list[JointBlockInterface]:
|
||||
"""Return all 60 transformer blocks."""
|
||||
return cast(
|
||||
list[JointBlockInterface], list(self._transformer.transformer_blocks)
|
||||
)
|
||||
|
||||
def get_single_blocks(self) -> list[SingleBlockInterface]:
|
||||
"""Qwen has no single blocks."""
|
||||
return []
|
||||
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
total_joint_blocks: int,
|
||||
total_single_blocks: int,
|
||||
) -> None:
|
||||
all_blocks = list(self._transformer.transformer_blocks)
|
||||
assigned_blocks = all_blocks[start_layer:end_layer]
|
||||
self._transformer.transformer_blocks = assigned_blocks
|
||||
|
||||
def merge_streams(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
) -> mx.array:
|
||||
"""Merge image and text streams.
|
||||
|
||||
For Qwen, this is called before final projection.
|
||||
The streams remain separate through all blocks.
|
||||
"""
|
||||
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
|
||||
|
||||
def _apply_joint_block_caching(
|
||||
self,
|
||||
block: Any, # QwenTransformerBlock
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
text_seq_len: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
block_idx: int | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply joint block in caching mode (full attention, optionally populate cache).
|
||||
|
||||
Delegates to the QwenTransformerBlock's forward pass.
|
||||
"""
|
||||
# Call the block directly - it handles all the modulation and attention internally
|
||||
return block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_emb=rotary_embeddings,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
def _apply_joint_block_patched(
|
||||
self,
|
||||
block: Any, # QwenTransformerBlock
|
||||
patch_hidden: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
|
||||
kv_cache: ImagePatchKVCache,
|
||||
text_seq_len: int,
|
||||
patch_start: int,
|
||||
patch_end: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
block_idx: int | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
batch_size = patch_hidden.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dim
|
||||
|
||||
# 1. Compute modulation parameters
|
||||
img_mod_params = block.img_mod_linear(block.img_mod_silu(text_embeddings))
|
||||
txt_mod_params = block.txt_mod_linear(block.txt_mod_silu(text_embeddings))
|
||||
|
||||
img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
|
||||
txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
|
||||
|
||||
# 2. Apply normalization and modulation
|
||||
img_normed = block.img_norm1(patch_hidden)
|
||||
img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)
|
||||
|
||||
txt_normed = block.txt_norm1(encoder_hidden_states)
|
||||
txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)
|
||||
|
||||
# 3. Compute Q, K, V for image patch
|
||||
img_query = attn.to_q(img_modulated)
|
||||
img_key = attn.to_k(img_modulated)
|
||||
img_value = attn.to_v(img_modulated)
|
||||
|
||||
# 4. Compute Q, K, V for text
|
||||
txt_query = attn.add_q_proj(txt_modulated)
|
||||
txt_key = attn.add_k_proj(txt_modulated)
|
||||
txt_value = attn.add_v_proj(txt_modulated)
|
||||
|
||||
# 5. Reshape to [B, S, H, D]
|
||||
patch_len = patch_hidden.shape[1]
|
||||
img_query = mx.reshape(img_query, (batch_size, patch_len, num_heads, head_dim))
|
||||
img_key = mx.reshape(img_key, (batch_size, patch_len, num_heads, head_dim))
|
||||
img_value = mx.reshape(img_value, (batch_size, patch_len, num_heads, head_dim))
|
||||
|
||||
txt_query = mx.reshape(
|
||||
txt_query, (batch_size, text_seq_len, num_heads, head_dim)
|
||||
)
|
||||
txt_key = mx.reshape(txt_key, (batch_size, text_seq_len, num_heads, head_dim))
|
||||
txt_value = mx.reshape(
|
||||
txt_value, (batch_size, text_seq_len, num_heads, head_dim)
|
||||
)
|
||||
|
||||
# 6. Apply RMSNorm to Q, K
|
||||
img_query = attn.norm_q(img_query)
|
||||
img_key = attn.norm_k(img_key)
|
||||
txt_query = attn.norm_added_q(txt_query)
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
# 7. Extract RoPE for patch: slice image RoPE, keep full text RoPE
|
||||
(img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
|
||||
patch_img_cos = img_cos[patch_start:patch_end]
|
||||
patch_img_sin = img_sin[patch_start:patch_end]
|
||||
|
||||
# 8. Apply RoPE to Q, K
|
||||
img_query = QwenAttention._apply_rope_qwen(
|
||||
img_query, patch_img_cos, patch_img_sin
|
||||
)
|
||||
img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
|
||||
txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
|
||||
txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
|
||||
|
||||
# 9. Transpose to [B, H, S, D] for cache operations
|
||||
img_key_bhsd = mx.transpose(img_key, (0, 2, 1, 3))
|
||||
img_value_bhsd = mx.transpose(img_value, (0, 2, 1, 3))
|
||||
|
||||
# 10. Update cache with this patch's IMAGE K/V
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
key=img_key_bhsd,
|
||||
value=img_value_bhsd,
|
||||
)
|
||||
|
||||
# 11. Get full K, V from cache (text + full image)
|
||||
txt_key_bhsd = mx.transpose(txt_key, (0, 2, 1, 3))
|
||||
txt_value_bhsd = mx.transpose(txt_value, (0, 2, 1, 3))
|
||||
full_key, full_value = kv_cache.get_full_kv(
|
||||
text_key=txt_key_bhsd,
|
||||
text_value=txt_value_bhsd,
|
||||
)
|
||||
|
||||
# 12. Build query: [text, patch]
|
||||
joint_query = mx.concatenate([txt_query, img_query], axis=1)
|
||||
|
||||
# 13. Build attention mask for [text + patch] query attending to [text + full_image] KV
|
||||
mask = QwenAttention._convert_mask_for_qwen(
|
||||
mask=encoder_hidden_states_mask,
|
||||
joint_seq_len=full_key.shape[2], # text + full_image
|
||||
txt_seq_len=text_seq_len,
|
||||
)
|
||||
|
||||
# 14. Compute attention
|
||||
hidden_states = attn._compute_attention_qwen(
|
||||
query=joint_query,
|
||||
key=mx.transpose(full_key, (0, 2, 1, 3)), # Back to [B, S, H, D]
|
||||
value=mx.transpose(full_value, (0, 2, 1, 3)),
|
||||
mask=mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
# 15. Extract text and image attention outputs
|
||||
txt_attn_output = hidden_states[:, :text_seq_len, :]
|
||||
img_attn_output = hidden_states[:, text_seq_len:, :]
|
||||
|
||||
# 16. Project outputs
|
||||
img_attn_output = attn.attn_to_out[0](img_attn_output)
|
||||
txt_attn_output = attn.to_add_out(txt_attn_output)
|
||||
|
||||
# 17. Apply residual + gate for attention
|
||||
patch_hidden = patch_hidden + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
|
||||
# 18. Apply feed-forward for image
|
||||
img_normed2 = block.img_norm2(patch_hidden)
|
||||
img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
|
||||
img_normed2, img_mod2
|
||||
)
|
||||
img_mlp_output = block.img_ff(img_modulated2)
|
||||
patch_hidden = patch_hidden + img_gate2 * img_mlp_output
|
||||
|
||||
# 19. Apply feed-forward for text
|
||||
txt_normed2 = block.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
|
||||
txt_normed2, txt_mod2
|
||||
)
|
||||
txt_mlp_output = block.txt_ff(txt_modulated2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
|
||||
|
||||
return encoder_hidden_states, patch_hidden
|
||||
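The numbered steps above implement the patched (PipeFusion-style) attention path: only the current image patch gets fresh Q/K/V, that K/V overwrites its slice of the image cache, and the joint query (text + patch) attends to fresh text K/V concatenated with the full, partly stale image cache. A minimal shape-level sketch of that flow with plain MLX arrays and illustrative sizes (not the real mflux modules):

```python
import mlx.core as mx

B, H, D = 1, 24, 128            # batch, heads, head_dim (illustrative)
TXT, IMG = 77, 4096             # text tokens, total image tokens
P_START, P_END = 0, 1024        # this patch's slice of the image tokens
P = P_END - P_START

# Persistent image-only cache, [B, H, IMG, D]
key_cache = mx.zeros((B, H, IMG, D))
value_cache = mx.zeros((B, H, IMG, D))

# Fresh projections for this step (random stand-ins for the real Q/K/V)
txt_q, txt_k, txt_v = (mx.random.normal((B, H, TXT, D)) for _ in range(3))
patch_q, patch_k, patch_v = (mx.random.normal((B, H, P, D)) for _ in range(3))

# Overwrite only this patch's slice of the image cache (step 10)
key_cache[:, :, P_START:P_END, :] = patch_k
value_cache[:, :, P_START:P_END, :] = patch_v

# Query = [text, patch]; K/V = fresh text + full cached image (steps 11-12)
joint_q = mx.concatenate([txt_q, patch_q], axis=2)
full_k = mx.concatenate([txt_k, key_cache], axis=2)
full_v = mx.concatenate([txt_v, value_cache], axis=2)

out = mx.fast.scaled_dot_product_attention(joint_q, full_k, full_v, scale=D**-0.5)
print(out.shape)  # (1, 24, 77 + 1024, 128): text stream first, then the patch
```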
@@ -1,49 +0,0 @@
from exo.worker.engines.image.config import (
    BlockType,
    ImageModelConfig,
    TransformerBlockConfig,
)

# Qwen-Image has 60 joint-style blocks (no single blocks)
# Architecture: 24 heads * 128 dim = 3072 hidden dim
# VAE uses scale factor of 16 (vs Flux's 8)
QWEN_IMAGE_CONFIG = ImageModelConfig(
    model_family="qwen",
    model_variant="image",
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=60, has_separate_text_output=True
        ),
        # Qwen has no single blocks - all blocks process image and text separately
    ),
    patch_size=2,
    vae_scale_factor=16,
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,  # ~3 sync steps for medium (25 steps)
    uses_attention_mask=True,  # Qwen uses encoder_hidden_states_mask
    guidance_scale=None,  # Set to None or < 1.0 to disable CFG
)

# Qwen-Image-Edit uses the same architecture but a different processing pipeline
# Uses vision-language encoding and conditioning latents
QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
    model_family="qwen-edit",
    model_variant="image-edit",
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=60, has_separate_text_output=True
        ),
    ),
    patch_size=2,
    vae_scale_factor=16,
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,
    uses_attention_mask=True,
    guidance_scale=None,
)
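For reference, the numbers in these configs are internally consistent: hidden_dim equals num_heads * head_dim, and the sync-step factor maps each quality preset to a small number of full-image timesteps before switching to patched mode. A quick sanity check; the rounding below is an assumption about how num_sync_steps_factor is consumed, since the consumer lives elsewhere in the engine code:

```python
num_heads, head_dim = 24, 128
assert num_heads * head_dim == 3072  # hidden_dim

default_steps = {"low": 10, "medium": 25, "high": 50}
factor = 0.125
for quality, steps in default_steps.items():
    # assumed: sync steps ~= round(factor * total steps), at least 1
    print(quality, max(1, round(factor * steps)))
# low 1, medium 3, high 6
```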
@@ -1,671 +0,0 @@
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
|
||||
QwenTransformerBlock,
|
||||
)
|
||||
from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit
|
||||
from mflux.models.qwen.variants.edit.utils.qwen_edit_util import QwenEditUtil
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import BaseModelAdapter
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
JointBlockInterface,
|
||||
SingleBlockInterface,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
class QwenEditPromptData:
|
||||
"""Container for Qwen edit prompt encoding results.
|
||||
|
||||
Includes vision-language encoded embeddings and edit-specific conditioning.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
prompt_mask: mx.array,
|
||||
negative_prompt_embeds: mx.array,
|
||||
negative_prompt_mask: mx.array,
|
||||
conditioning_latents: mx.array,
|
||||
qwen_image_ids: mx.array,
|
||||
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]],
|
||||
):
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self.prompt_mask = prompt_mask
|
||||
self._negative_prompt_embeds = negative_prompt_embeds
|
||||
self.negative_prompt_mask = negative_prompt_mask
|
||||
self._conditioning_latents = conditioning_latents
|
||||
self._qwen_image_ids = qwen_image_ids
|
||||
self._cond_image_grid = cond_image_grid
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
"""Text embeddings from vision-language encoder."""
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array:
|
||||
"""Negative prompt embeddings for CFG."""
|
||||
return self._negative_prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Placeholder - Qwen doesn't use pooled embeds."""
|
||||
return self._negative_prompt_embeds
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array:
|
||||
"""Static image conditioning latents to concatenate with generated latents."""
|
||||
return self._conditioning_latents
|
||||
|
||||
@property
|
||||
def qwen_image_ids(self) -> mx.array:
|
||||
"""Spatial position IDs for conditioning images."""
|
||||
return self._qwen_image_ids
|
||||
|
||||
@property
|
||||
def cond_image_grid(self) -> tuple[int, int, int] | list[tuple[int, int, int]]:
|
||||
"""Conditioning image grid dimensions."""
|
||||
return self._cond_image_grid
|
||||
|
||||
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
|
||||
"""Return encoder_hidden_states_mask and edit-specific params."""
|
||||
if positive:
|
||||
return {
|
||||
"encoder_hidden_states_mask": self.prompt_mask,
|
||||
"qwen_image_ids": self._qwen_image_ids,
|
||||
"cond_image_grid": self._cond_image_grid,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"encoder_hidden_states_mask": self.negative_prompt_mask,
|
||||
"qwen_image_ids": self._qwen_image_ids,
|
||||
"cond_image_grid": self._cond_image_grid,
|
||||
}
|
||||
|
||||
@property
|
||||
def is_edit_mode(self) -> bool:
|
||||
"""Indicates this is edit mode with conditioning latents."""
|
||||
return True
|
||||
|
||||
|
||||
class QwenEditModelAdapter(BaseModelAdapter):
|
||||
"""Adapter for Qwen-Image-Edit model.
|
||||
|
||||
Key differences from standard QwenModelAdapter:
|
||||
- Uses QwenImageEdit model with vision-language components
|
||||
- Encodes prompts WITH input images via VL tokenizer/encoder
|
||||
- Creates conditioning latents from input images
|
||||
- Supports image editing with concatenated latents during diffusion
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
self._config = config
|
||||
self._model = QwenImageEdit(
|
||||
quantize=quantize,
|
||||
local_path=str(local_path),
|
||||
)
|
||||
self._transformer = self._model.transformer
|
||||
|
||||
# Store dimensions and image paths (set via set_image_dimensions)
|
||||
self._vl_width: int | None = None
|
||||
self._vl_height: int | None = None
|
||||
self._vae_width: int | None = None
|
||||
self._vae_height: int | None = None
|
||||
self._image_paths: list[str] | None = None
|
||||
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def model(self) -> QwenImageEdit:
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def transformer(self) -> QwenTransformer:
|
||||
return self._transformer
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
return self._transformer.inner_dim
|
||||
|
||||
def _get_latent_creator(self) -> type:
|
||||
return QwenLatentCreator
|
||||
|
||||
def _compute_dimensions_from_image(
|
||||
self, image_path: Path
|
||||
) -> tuple[int, int, int, int, int, int]:
|
||||
"""Compute VL and VAE dimensions from input image.
|
||||
|
||||
Returns:
|
||||
(vl_width, vl_height, vae_width, vae_height, output_width, output_height)
|
||||
"""
|
||||
from mflux.utils.image_util import ImageUtil
|
||||
|
||||
pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
|
||||
image_size = pil_image.size
|
||||
|
||||
# Vision-language dimensions (384x384 target area)
|
||||
condition_image_size = 384 * 384
|
||||
condition_ratio = image_size[0] / image_size[1]
|
||||
vl_width = math.sqrt(condition_image_size * condition_ratio)
|
||||
vl_height = vl_width / condition_ratio
|
||||
vl_width = round(vl_width / 32) * 32
|
||||
vl_height = round(vl_height / 32) * 32
|
||||
|
||||
# VAE dimensions (1024x1024 target area)
|
||||
vae_image_size = 1024 * 1024
|
||||
vae_ratio = image_size[0] / image_size[1]
|
||||
vae_width = math.sqrt(vae_image_size * vae_ratio)
|
||||
vae_height = vae_width / vae_ratio
|
||||
vae_width = round(vae_width / 32) * 32
|
||||
vae_height = round(vae_height / 32) * 32
|
||||
|
||||
# Output dimensions from input image aspect ratio
|
||||
target_area = 1024 * 1024
|
||||
ratio = image_size[0] / image_size[1]
|
||||
output_width = math.sqrt(target_area * ratio)
|
||||
output_height = output_width / ratio
|
||||
output_width = round(output_width / 32) * 32
|
||||
output_height = round(output_height / 32) * 32
|
||||
|
||||
# Ensure multiple of 16 for VAE
|
||||
vae_scale_factor = 8
|
||||
multiple_of = vae_scale_factor * 2
|
||||
output_width = output_width // multiple_of * multiple_of
|
||||
output_height = output_height // multiple_of * multiple_of
|
||||
|
||||
return (
|
||||
int(vl_width),
|
||||
int(vl_height),
|
||||
int(vae_width),
|
||||
int(vae_height),
|
||||
int(output_width),
|
||||
int(output_height),
|
||||
)
|
||||
|
||||
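The dimension helper above maps an arbitrary input image onto fixed-area grids (384x384 for the VL encoder, 1024x1024 for the VAE and the output), preserving aspect ratio and snapping to multiples of 32. A worked example for a hypothetical 1600x900 input, using the same arithmetic as the method:

```python
import math

def fit_area(w: int, h: int, area: int, multiple: int = 32) -> tuple[int, int]:
    # Scale (w, h) to the target area at the same aspect ratio, snap to `multiple`
    ratio = w / h
    new_w = math.sqrt(area * ratio)
    new_h = new_w / ratio
    return round(new_w / multiple) * multiple, round(new_h / multiple) * multiple

print(fit_area(1600, 900, 384 * 384))    # VL grid: (512, 288)
print(fit_area(1600, 900, 1024 * 1024))  # VAE/output grid: (1376, 768)
# The method additionally floors the output grid to a multiple of 16;
# 1376 and 768 already satisfy that.
```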
def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
|
||||
"""Create initial noise latents (pure noise for edit mode)."""
|
||||
return QwenLatentCreator.create_noise(
|
||||
seed=seed,
|
||||
height=runtime_config.height,
|
||||
width=runtime_config.width,
|
||||
)
|
||||
|
||||
def encode_prompt(self, prompt: str) -> QwenEditPromptData:
|
||||
"""Encode prompt with input images using vision-language encoder.
|
||||
|
||||
Uses stored image_paths from set_image_dimensions() for VL encoding.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt for editing
|
||||
|
||||
Returns:
|
||||
QwenEditPromptData with VL embeddings and conditioning latents
|
||||
"""
|
||||
# Ensure image_paths and dimensions were set via set_image_dimensions()
|
||||
if self._image_paths is None:
|
||||
raise RuntimeError(
|
||||
"set_image_dimensions() must be called before encode_prompt() "
|
||||
"for QwenEditModelAdapter"
|
||||
)
|
||||
|
||||
negative_prompt = ""
|
||||
image_paths = self._image_paths
|
||||
|
||||
# Use stored dimensions (computed from input image)
|
||||
vl_width = self._vl_width
|
||||
vl_height = self._vl_height
|
||||
vae_width = self._vae_width
|
||||
vae_height = self._vae_height
|
||||
|
||||
# Encode prompts with images via vision-language components
|
||||
tokenizer = self._model.qwen_vl_tokenizer
|
||||
pos_input_ids, pos_attention_mask, pos_pixel_values, pos_image_grid_thw = (
|
||||
tokenizer.tokenize_with_image(
|
||||
prompt, image_paths, vl_width=vl_width, vl_height=vl_height
|
||||
)
|
||||
)
|
||||
|
||||
pos_hidden_states = self._model.qwen_vl_encoder(
|
||||
input_ids=pos_input_ids,
|
||||
attention_mask=pos_attention_mask,
|
||||
pixel_values=pos_pixel_values,
|
||||
image_grid_thw=pos_image_grid_thw,
|
||||
)
|
||||
mx.eval(pos_hidden_states[0])
|
||||
mx.eval(pos_hidden_states[1])
|
||||
|
||||
# Encode negative prompt with images
|
||||
neg_input_ids, neg_attention_mask, neg_pixel_values, neg_image_grid_thw = (
|
||||
tokenizer.tokenize_with_image(
|
||||
negative_prompt, image_paths, vl_width=vl_width, vl_height=vl_height
|
||||
)
|
||||
)
|
||||
|
||||
neg_hidden_states = self._model.qwen_vl_encoder(
|
||||
input_ids=neg_input_ids,
|
||||
attention_mask=neg_attention_mask,
|
||||
pixel_values=neg_pixel_values,
|
||||
image_grid_thw=neg_image_grid_thw,
|
||||
)
|
||||
mx.eval(neg_hidden_states[0])
|
||||
mx.eval(neg_hidden_states[1])
|
||||
|
||||
# Create conditioning latents from input images
|
||||
# Ensure dimensions are set (should have been set via set_image_dimensions)
|
||||
assert vl_width is not None and vl_height is not None
|
||||
assert vae_width is not None and vae_height is not None
|
||||
|
||||
(
|
||||
conditioning_latents,
|
||||
qwen_image_ids,
|
||||
cond_h_patches,
|
||||
cond_w_patches,
|
||||
num_images,
|
||||
) = QwenEditUtil.create_image_conditioning_latents(
|
||||
vae=self._model.vae,
|
||||
height=vae_height,
|
||||
width=vae_width,
|
||||
image_paths=image_paths,
|
||||
vl_width=vl_width,
|
||||
vl_height=vl_height,
|
||||
)
|
||||
|
||||
# Build cond_image_grid
|
||||
if num_images > 1:
|
||||
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] = [
|
||||
(1, cond_h_patches, cond_w_patches) for _ in range(num_images)
|
||||
]
|
||||
else:
|
||||
cond_image_grid = (1, cond_h_patches, cond_w_patches)
|
||||
|
||||
return QwenEditPromptData(
|
||||
prompt_embeds=pos_hidden_states[0].astype(mx.float16),
|
||||
prompt_mask=pos_hidden_states[1].astype(mx.float16),
|
||||
negative_prompt_embeds=neg_hidden_states[0].astype(mx.float16),
|
||||
negative_prompt_mask=neg_hidden_states[1].astype(mx.float16),
|
||||
conditioning_latents=conditioning_latents,
|
||||
qwen_image_ids=qwen_image_ids,
|
||||
cond_image_grid=cond_image_grid,
|
||||
)
|
||||
|
||||
def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
|
||||
"""Compute and store dimensions from input image.
|
||||
|
||||
Also stores image_paths for use in encode_prompt().
|
||||
|
||||
Returns:
|
||||
(output_width, output_height) for runtime config
|
||||
"""
|
||||
vl_w, vl_h, vae_w, vae_h, out_w, out_h = self._compute_dimensions_from_image(
|
||||
image_path
|
||||
)
|
||||
self._vl_width = vl_w
|
||||
self._vl_height = vl_h
|
||||
self._vae_width = vae_w
|
||||
self._vae_height = vae_h
|
||||
self._image_paths = [str(image_path)]
|
||||
return out_w, out_h
|
||||
|
||||
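Because encode_prompt() reads the dimensions and image paths stored by set_image_dimensions(), the caller has to invoke them in that order (the RuntimeError above enforces this). A hedged usage sketch; the model_id and paths are placeholders, not real values:

```python
from pathlib import Path

adapter = QwenEditModelAdapter(
    config=QWEN_IMAGE_EDIT_CONFIG,           # from the config module above
    model_id="qwen-image-edit",              # placeholder identifier
    local_path=Path("/models/qwen-edit"),    # placeholder path
)
width, height = adapter.set_image_dimensions(Path("input.png"))  # must come first
prompt_data = adapter.encode_prompt("make the sky stormy")       # uses stored paths
```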
@property
|
||||
def needs_cfg(self) -> bool:
|
||||
gs = self._config.guidance_scale
|
||||
return gs is not None and gs > 1.0
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
|
||||
|
||||
return QwenImage.compute_guided_noise(
|
||||
noise=noise_positive,
|
||||
noise_negative=noise_negative,
|
||||
guidance=guidance_scale,
|
||||
)
|
||||
|
||||
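apply_guidance() delegates to mflux's QwenImage.compute_guided_noise, so the exact math lives there. Classifier-free guidance is conventionally the negative prediction pushed toward the positive one; the one-liner below expresses that convention and is an assumption about what compute_guided_noise does, not a quote of it:

```python
def cfg_combine(noise_pos, noise_neg, g: float):
    # standard classifier-free guidance; assumed equivalent to compute_guided_noise
    return noise_neg + g * (noise_pos - noise_neg)
```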
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Compute image and text embeddings."""
|
||||
embedded_hidden = self._transformer.img_in(hidden_states)
|
||||
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
|
||||
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
|
||||
return embedded_hidden, embedded_encoder
|
||||
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: RuntimeConfig,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
"""Compute time/text embeddings."""
|
||||
ref_tensor = (
|
||||
hidden_states if hidden_states is not None else pooled_prompt_embeds
|
||||
)
|
||||
if ref_tensor is None:
|
||||
raise ValueError(
|
||||
"Either hidden_states or pooled_prompt_embeds is required "
|
||||
"for Qwen text embeddings"
|
||||
)
|
||||
|
||||
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
|
||||
batch_size = ref_tensor.shape[0]
|
||||
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
|
||||
return self._transformer.time_text_embed(timestep, ref_tensor)
|
||||
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Compute 3D rotary embeddings for Qwen edit."""
|
||||
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
|
||||
cond_image_grid = kwargs.get("cond_image_grid")
|
||||
|
||||
if encoder_hidden_states_mask is None:
|
||||
raise ValueError(
|
||||
"encoder_hidden_states_mask is required for Qwen RoPE computation"
|
||||
)
|
||||
|
||||
return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
pos_embed=self._transformer.pos_embed,
|
||||
config=runtime_config,
|
||||
cond_image_grid=cond_image_grid,
|
||||
)
|
||||
|
||||
def apply_joint_block(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: Any,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply Qwen joint block."""
|
||||
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
|
||||
block_idx = kwargs.get("block_idx")
|
||||
|
||||
if mode == BlockWrapperMode.CACHING:
|
||||
return self._apply_joint_block_caching(
|
||||
block=block,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
else:
|
||||
assert patch_start is not None and patch_end is not None
|
||||
assert kv_cache is not None
|
||||
return self._apply_joint_block_patched(
|
||||
block=block,
|
||||
patch_hidden=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
def apply_single_block(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
) -> mx.array:
|
||||
"""Qwen has no single blocks."""
|
||||
raise NotImplementedError("Qwen does not have single blocks")
|
||||
|
||||
def final_projection(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply final normalization and projection."""
|
||||
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
|
||||
return self._transformer.proj_out(hidden_states)
|
||||
|
||||
def get_joint_blocks(self) -> list[JointBlockInterface]:
|
||||
"""Return all 60 transformer blocks."""
|
||||
return cast(
|
||||
list[JointBlockInterface], list(self._transformer.transformer_blocks)
|
||||
)
|
||||
|
||||
def get_single_blocks(self) -> list[SingleBlockInterface]:
|
||||
"""Qwen has no single blocks."""
|
||||
return []
|
||||
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
total_joint_blocks: int,
|
||||
total_single_blocks: int,
|
||||
) -> None:
|
||||
all_blocks = list(self._transformer.transformer_blocks)
|
||||
assigned_blocks = all_blocks[start_layer:end_layer]
|
||||
self._transformer.transformer_blocks = assigned_blocks
|
||||
|
||||
def merge_streams(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
) -> mx.array:
|
||||
"""Merge image and text streams."""
|
||||
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
|
||||
|
||||
def _apply_joint_block_caching(
|
||||
self,
|
||||
block: Any,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
text_seq_len: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
block_idx: int | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply joint block in caching mode."""
|
||||
return block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_emb=rotary_embeddings,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
def _apply_joint_block_patched(
|
||||
self,
|
||||
block: Any,
|
||||
patch_hidden: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
|
||||
kv_cache: ImagePatchKVCache,
|
||||
text_seq_len: int,
|
||||
patch_start: int,
|
||||
patch_end: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
block_idx: int | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
batch_size = patch_hidden.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dim
|
||||
|
||||
# Modulation parameters
|
||||
img_mod_params = block.img_mod_linear(block.img_mod_silu(text_embeddings))
|
||||
txt_mod_params = block.txt_mod_linear(block.txt_mod_silu(text_embeddings))
|
||||
|
||||
img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
|
||||
txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
|
||||
|
||||
# Normalization and modulation
|
||||
img_normed = block.img_norm1(patch_hidden)
|
||||
img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)
|
||||
|
||||
txt_normed = block.txt_norm1(encoder_hidden_states)
|
||||
txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)
|
||||
|
||||
# Q, K, V for image patch
|
||||
img_query = attn.to_q(img_modulated)
|
||||
img_key = attn.to_k(img_modulated)
|
||||
img_value = attn.to_v(img_modulated)
|
||||
|
||||
# Q, K, V for text
|
||||
txt_query = attn.add_q_proj(txt_modulated)
|
||||
txt_key = attn.add_k_proj(txt_modulated)
|
||||
txt_value = attn.add_v_proj(txt_modulated)
|
||||
|
||||
# Reshape to [B, S, H, D]
|
||||
patch_len = patch_hidden.shape[1]
|
||||
img_query = mx.reshape(img_query, (batch_size, patch_len, num_heads, head_dim))
|
||||
img_key = mx.reshape(img_key, (batch_size, patch_len, num_heads, head_dim))
|
||||
img_value = mx.reshape(img_value, (batch_size, patch_len, num_heads, head_dim))
|
||||
|
||||
txt_query = mx.reshape(
|
||||
txt_query, (batch_size, text_seq_len, num_heads, head_dim)
|
||||
)
|
||||
txt_key = mx.reshape(txt_key, (batch_size, text_seq_len, num_heads, head_dim))
|
||||
txt_value = mx.reshape(
|
||||
txt_value, (batch_size, text_seq_len, num_heads, head_dim)
|
||||
)
|
||||
|
||||
# RMSNorm to Q, K
|
||||
img_query = attn.norm_q(img_query)
|
||||
img_key = attn.norm_k(img_key)
|
||||
txt_query = attn.norm_added_q(txt_query)
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
# Extract RoPE for patch
|
||||
(img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
|
||||
patch_img_cos = img_cos[patch_start:patch_end]
|
||||
patch_img_sin = img_sin[patch_start:patch_end]
|
||||
|
||||
# Apply RoPE
|
||||
img_query = QwenAttention._apply_rope_qwen(
|
||||
img_query, patch_img_cos, patch_img_sin
|
||||
)
|
||||
img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
|
||||
txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
|
||||
txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
|
||||
|
||||
# Transpose to [B, H, S, D]
|
||||
img_key_bhsd = mx.transpose(img_key, (0, 2, 1, 3))
|
||||
img_value_bhsd = mx.transpose(img_value, (0, 2, 1, 3))
|
||||
|
||||
# Update cache
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
key=img_key_bhsd,
|
||||
value=img_value_bhsd,
|
||||
)
|
||||
|
||||
# Get full K, V from cache
|
||||
txt_key_bhsd = mx.transpose(txt_key, (0, 2, 1, 3))
|
||||
txt_value_bhsd = mx.transpose(txt_value, (0, 2, 1, 3))
|
||||
full_key, full_value = kv_cache.get_full_kv(
|
||||
text_key=txt_key_bhsd,
|
||||
text_value=txt_value_bhsd,
|
||||
)
|
||||
|
||||
# Build query
|
||||
joint_query = mx.concatenate([txt_query, img_query], axis=1)
|
||||
|
||||
# Build attention mask
|
||||
mask = QwenAttention._convert_mask_for_qwen(
|
||||
mask=encoder_hidden_states_mask,
|
||||
joint_seq_len=full_key.shape[2],
|
||||
txt_seq_len=text_seq_len,
|
||||
)
|
||||
|
||||
# Compute attention
|
||||
hidden_states = attn._compute_attention_qwen(
|
||||
query=joint_query,
|
||||
key=mx.transpose(full_key, (0, 2, 1, 3)),
|
||||
value=mx.transpose(full_value, (0, 2, 1, 3)),
|
||||
mask=mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
# Extract outputs
|
||||
txt_attn_output = hidden_states[:, :text_seq_len, :]
|
||||
img_attn_output = hidden_states[:, text_seq_len:, :]
|
||||
|
||||
# Project
|
||||
img_attn_output = attn.attn_to_out[0](img_attn_output)
|
||||
txt_attn_output = attn.to_add_out(txt_attn_output)
|
||||
|
||||
# Residual + gate
|
||||
patch_hidden = patch_hidden + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
|
||||
# Feed-forward for image
|
||||
img_normed2 = block.img_norm2(patch_hidden)
|
||||
img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
|
||||
img_normed2, img_mod2
|
||||
)
|
||||
img_mlp_output = block.img_ff(img_modulated2)
|
||||
patch_hidden = patch_hidden + img_gate2 * img_mlp_output
|
||||
|
||||
# Feed-forward for text
|
||||
txt_normed2 = block.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
|
||||
txt_normed2, txt_mod2
|
||||
)
|
||||
txt_mlp_output = block.txt_ff(txt_modulated2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
|
||||
|
||||
return encoder_hidden_states, patch_hidden
|
||||
@@ -1,23 +0,0 @@
from exo.worker.engines.image.pipeline.adapter import (
    BlockWrapperMode,
    JointBlockInterface,
    ModelAdapter,
    SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
    JointBlockWrapper,
    SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
from exo.worker.engines.image.pipeline.runner import DiffusionRunner

__all__ = [
    "BlockWrapperMode",
    "DiffusionRunner",
    "ImagePatchKVCache",
    "JointBlockInterface",
    "JointBlockWrapper",
    "ModelAdapter",
    "SingleBlockInterface",
    "SingleBlockWrapper",
]
@@ -1,402 +0,0 @@
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
class AttentionInterface(Protocol):
|
||||
num_heads: int
|
||||
head_dimension: int
|
||||
to_q: Any
|
||||
to_k: Any
|
||||
to_v: Any
|
||||
norm_q: Any
|
||||
norm_k: Any
|
||||
to_out: list[Any]
|
||||
|
||||
|
||||
class JointAttentionInterface(AttentionInterface, Protocol):
|
||||
add_q_proj: Any
|
||||
add_k_proj: Any
|
||||
add_v_proj: Any
|
||||
norm_added_q: Any
|
||||
norm_added_k: Any
|
||||
to_add_out: Any
|
||||
|
||||
|
||||
class JointBlockInterface(Protocol):
|
||||
attn: JointAttentionInterface
|
||||
norm1: Any # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
|
||||
norm1_context: (
|
||||
Any # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
|
||||
)
|
||||
norm2: Any
|
||||
norm2_context: Any
|
||||
ff: Any
|
||||
ff_context: Any
|
||||
|
||||
|
||||
class SingleBlockInterface(Protocol):
|
||||
attn: AttentionInterface
|
||||
norm: Any # Callable module: (hidden_states, text_embeddings) -> tuple[2 arrays]
|
||||
|
||||
def _apply_feed_forward_and_projection(
|
||||
self, norm_hidden_states: mx.array, attn_output: mx.array, gate: mx.array
|
||||
) -> mx.array:
|
||||
"""Apply feed forward network and projection."""
|
||||
...
|
||||
|
||||
|
||||
class BlockWrapperMode(Enum):
|
||||
CACHING = "caching" # Sync mode: compute full attention, populate cache
|
||||
PATCHED = "patched" # Async mode: compute patch attention, use cached KV
|
||||
|
||||
|
||||
class PromptData(Protocol):
|
||||
"""Protocol for encoded prompt data.
|
||||
|
||||
All adapters must return prompt data that conforms to this protocol.
|
||||
Model-specific prompt data classes can add additional attributes
|
||||
(e.g., attention masks for Qwen).
|
||||
"""
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
"""Text embeddings from encoder."""
|
||||
...
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Pooled text embeddings (for Flux) or placeholder (for Qwen)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array | None:
|
||||
"""Negative prompt embeddings for CFG (None if not using CFG)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array | None:
|
||||
"""Negative pooled embeddings for CFG (None if not using CFG)."""
|
||||
...
|
||||
|
||||
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
|
||||
"""Return model-specific kwargs for forward pass.
|
||||
|
||||
Args:
|
||||
positive: If True, return kwargs for positive prompt pass.
|
||||
If False, return kwargs for negative prompt pass.
|
||||
|
||||
Returns:
|
||||
Dict of extra kwargs (e.g., {"encoder_hidden_states_mask": ...} for Qwen)
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
"""Conditioning latents for edit mode.
|
||||
|
||||
Returns:
|
||||
Conditioning latents array for image editing, None for standard generation.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
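Any adapter-specific container satisfies PromptData as long as it exposes these members. A minimal hypothetical implementation for a model with no CFG and no extra forward kwargs might look like this (the class name and field defaults are illustrative, not part of the codebase):

```python
from dataclasses import dataclass
from typing import Any

import mlx.core as mx


@dataclass
class SimplePromptData:
    """Hypothetical PromptData for a model with no CFG and no extra kwargs."""

    prompt_embeds: mx.array
    pooled_prompt_embeds: mx.array
    negative_prompt_embeds: mx.array | None = None
    negative_pooled_prompt_embeds: mx.array | None = None
    conditioning_latents: mx.array | None = None

    def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
        return {}
```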
class ModelAdapter(Protocol):
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
"""Return the model configuration."""
|
||||
...
|
||||
|
||||
@property
|
||||
def model(self) -> Any:
|
||||
"""Return the underlying mflux model instance (e.g., Flux1, Fibo, Qwen)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def transformer(self) -> Any:
|
||||
"""Return the transformer component of the model."""
|
||||
...
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
"""Return the hidden dimension of the transformer."""
|
||||
...
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Compute x_embedder and context_embedder outputs.
|
||||
|
||||
Args:
|
||||
hidden_states: Input latent states
|
||||
prompt_embeds: Text embeddings from encoder
|
||||
|
||||
Returns:
|
||||
Tuple of (embedded_hidden_states, embedded_encoder_states)
|
||||
"""
|
||||
...
|
||||
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: RuntimeConfig,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
"""Compute time/text embeddings for conditioning.
|
||||
|
||||
Args:
|
||||
t: Current timestep
|
||||
runtime_config: Runtime configuration
|
||||
pooled_prompt_embeds: Pooled text embeddings (used by Flux)
|
||||
hidden_states: Image hidden states
|
||||
|
||||
Returns:
|
||||
Text embeddings tensor
|
||||
"""
|
||||
...
|
||||
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Compute rotary position embeddings.
|
||||
|
||||
Args:
|
||||
prompt_embeds: Text embeddings
|
||||
runtime_config: Runtime configuration
|
||||
**kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask for Qwen)
|
||||
|
||||
Returns:
|
||||
Flux: mx.array
|
||||
Qwen: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]
|
||||
"""
|
||||
...
|
||||
|
||||
def apply_joint_block(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: Any, # Format varies: mx.array (Flux) or nested tuple (Qwen)
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: "BlockWrapperMode",
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply a joint transformer block.
|
||||
|
||||
Args:
|
||||
block: The joint transformer block
|
||||
hidden_states: Image hidden states
|
||||
encoder_hidden_states: Text hidden states
|
||||
text_embeddings: Conditioning embeddings
|
||||
rotary_embeddings: Rotary position embeddings (format varies by model)
|
||||
kv_cache: KV cache (None if not using cache)
|
||||
mode: CACHING or PATCHED mode
|
||||
text_seq_len: Text sequence length
|
||||
patch_start: Start index for patched mode
|
||||
patch_end: End index for patched mode
|
||||
**kwargs: Additional model-specific arguments (e.g., encoder_hidden_states_mask,
|
||||
block_idx for Qwen)
|
||||
|
||||
Returns:
|
||||
Tuple of (encoder_hidden_states, hidden_states)
|
||||
"""
|
||||
...
|
||||
|
||||
def apply_single_block(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: "BlockWrapperMode",
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
) -> mx.array:
|
||||
"""Apply a single transformer block.
|
||||
|
||||
Args:
|
||||
block: The single transformer block
|
||||
hidden_states: Concatenated [text + image] hidden states
|
||||
text_embeddings: Conditioning embeddings
|
||||
rotary_embeddings: Rotary position embeddings
|
||||
kv_cache: KV cache (None if not using cache)
|
||||
mode: CACHING or PATCHED mode
|
||||
text_seq_len: Text sequence length
|
||||
patch_start: Start index for patched mode
|
||||
patch_end: End index for patched mode
|
||||
|
||||
Returns:
|
||||
Output hidden states
|
||||
"""
|
||||
...
|
||||
|
||||
def final_projection(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply final norm and projection.
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states (image only, text already removed)
|
||||
text_embeddings: Conditioning embeddings
|
||||
|
||||
Returns:
|
||||
Projected output
|
||||
"""
|
||||
...
|
||||
|
||||
def get_joint_blocks(self) -> list[JointBlockInterface]:
|
||||
"""Get the list of joint transformer blocks from the model."""
|
||||
...
|
||||
|
||||
def get_single_blocks(self) -> list[SingleBlockInterface]:
|
||||
"""Get the list of single transformer blocks from the model."""
|
||||
...
|
||||
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
total_joint_blocks: int,
|
||||
total_single_blocks: int,
|
||||
) -> None:
|
||||
"""Remove transformer blocks outside the assigned range.
|
||||
|
||||
This should be called BEFORE mx.eval() to avoid loading unused weights
|
||||
in distributed mode.
|
||||
|
||||
Args:
|
||||
start_layer: First layer index (inclusive) assigned to this node
|
||||
end_layer: Last layer index (exclusive) assigned to this node
|
||||
total_joint_blocks: Total number of joint blocks in the model
|
||||
total_single_blocks: Total number of single blocks in the model
|
||||
"""
|
||||
...
|
||||
|
||||
def merge_streams(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
) -> mx.array:
|
||||
"""Merge image and text streams for transition to single blocks.
|
||||
|
||||
This is called at the transition point from joint blocks (which process
|
||||
image and text separately) to single blocks (which process them
|
||||
together). Override to customize the merge strategy.
|
||||
|
||||
Args:
|
||||
hidden_states: Image hidden states
|
||||
encoder_hidden_states: Text hidden states
|
||||
|
||||
Returns:
|
||||
Merged hidden states (default: concatenate [text, image])
|
||||
"""
|
||||
...
|
||||
|
||||
def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
|
||||
"""Create initial noise latents for generation.
|
||||
|
||||
Args:
|
||||
seed: Random seed
|
||||
runtime_config: Runtime configuration with dimensions
|
||||
|
||||
Returns:
|
||||
Initial latent tensor
|
||||
"""
|
||||
...
|
||||
|
||||
def encode_prompt(self, prompt: str) -> PromptData:
|
||||
"""Encode prompt into model-specific prompt data.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt
|
||||
|
||||
Returns:
|
||||
PromptData containing embeddings (and model-specific extras)
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def needs_cfg(self) -> bool:
|
||||
"""Whether this model uses classifier-free guidance.
|
||||
|
||||
Returns:
|
||||
True if model requires two forward passes with guidance (e.g., Qwen)
|
||||
False if model uses a single forward pass (e.g., Flux)
|
||||
"""
|
||||
...
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
"""Apply classifier-free guidance to combine positive/negative predictions.
|
||||
|
||||
Only called when needs_cfg is True.
|
||||
|
||||
Args:
|
||||
noise_positive: Noise prediction from positive prompt
|
||||
noise_negative: Noise prediction from negative prompt
|
||||
guidance_scale: Guidance strength
|
||||
|
||||
Returns:
|
||||
Guided noise prediction
|
||||
"""
|
||||
...
|
||||
|
||||
def decode_latents(
|
||||
self,
|
||||
latents: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
seed: int,
|
||||
prompt: str,
|
||||
) -> Any:
|
||||
"""Decode latents to final image.
|
||||
|
||||
Args:
|
||||
latents: Final denoised latents
|
||||
runtime_config: Runtime configuration
|
||||
seed: Random seed (for metadata)
|
||||
prompt: Text prompt (for metadata)
|
||||
|
||||
Returns:
|
||||
GeneratedImage result
|
||||
"""
|
||||
...
|
||||
|
||||
def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
|
||||
"""Compute and store dimensions from input image for edit mode.
|
||||
|
||||
For edit adapters: computes dimensions from input image aspect ratio,
|
||||
stores image paths internally for encode_prompt(), returns (width, height).
|
||||
|
||||
For standard adapters: returns None (use user-specified dimensions).
|
||||
|
||||
Args:
|
||||
image_path: Path to the input image
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height) if dimensions were computed, None otherwise.
|
||||
"""
|
||||
...
|
||||
@@ -1,146 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
JointBlockInterface,
|
||||
ModelAdapter,
|
||||
SingleBlockInterface,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
class JointBlockWrapper:
|
||||
"""Unified wrapper for joint transformer blocks.
|
||||
|
||||
Handles both CACHING (sync) and PATCHED (async) modes by delegating
|
||||
to the model adapter for model-specific attention computation.
|
||||
|
||||
The wrapper is created once at initialization and reused across calls.
|
||||
Mode and KV cache are passed at call time to support switching between
|
||||
sync and async pipelines.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
adapter: ModelAdapter,
|
||||
):
|
||||
"""Initialize the joint block wrapper.
|
||||
|
||||
Args:
|
||||
block: The joint transformer block to wrap
|
||||
adapter: Model adapter for model-specific operations
|
||||
"""
|
||||
self.block = block
|
||||
self.adapter = adapter
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
text_seq_len: int,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply the joint block.
|
||||
|
||||
Args:
|
||||
hidden_states: Image hidden states (full or patch depending on mode)
|
||||
encoder_hidden_states: Text hidden states
|
||||
text_embeddings: Conditioning embeddings
|
||||
rotary_embeddings: Rotary position embeddings
|
||||
text_seq_len: Text sequence length
|
||||
kv_cache: KV cache for storing/retrieving image K/V (None if not using cache)
|
||||
mode: CACHING (populate cache) or PATCHED (use cached K/V)
|
||||
patch_start: Start index for patched mode (required if mode=PATCHED)
|
||||
patch_end: End index for patched mode (required if mode=PATCHED)
|
||||
**kwargs: Additional model-specific arguments (e.g., encoder_hidden_states_mask,
|
||||
block_idx for Qwen)
|
||||
|
||||
Returns:
|
||||
Tuple of (encoder_hidden_states, hidden_states)
|
||||
"""
|
||||
return self.adapter.apply_joint_block(
|
||||
block=self.block,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
mode=mode,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class SingleBlockWrapper:
|
||||
"""Unified wrapper for single transformer blocks.
|
||||
|
||||
Handles both CACHING (sync) and PATCHED (async) modes by delegating
|
||||
to the model adapter for model-specific attention computation.
|
||||
|
||||
The wrapper is created once at initialization and reused across calls.
|
||||
Mode and KV cache are passed at call time to support switching between
|
||||
sync and async pipelines.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
adapter: ModelAdapter,
|
||||
):
|
||||
"""Initialize the single block wrapper.
|
||||
|
||||
Args:
|
||||
block: The single transformer block to wrap
|
||||
adapter: Model adapter for model-specific operations
|
||||
"""
|
||||
self.block = block
|
||||
self.adapter = adapter
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
text_seq_len: int,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
) -> mx.array:
|
||||
"""Apply the single block.
|
||||
|
||||
Args:
|
||||
hidden_states: [text + image] hidden states (full or patch depending on mode)
|
||||
text_embeddings: Conditioning embeddings
|
||||
rotary_embeddings: Rotary position embeddings
|
||||
text_seq_len: Text sequence length
|
||||
kv_cache: KV cache for storing/retrieving image K/V (None if not using cache)
|
||||
mode: CACHING (populate cache) or PATCHED (use cached K/V)
|
||||
patch_start: Start index for patched mode (required if mode=PATCHED)
|
||||
patch_end: End index for patched mode (required if mode=PATCHED)
|
||||
|
||||
Returns:
|
||||
Output hidden states
|
||||
"""
|
||||
return self.adapter.apply_single_block(
|
||||
block=self.block,
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
mode=mode,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
)
|
||||
@@ -1,72 +0,0 @@
import mlx.core as mx


class ImagePatchKVCache:
    """KV cache that stores only IMAGE K/V with patch-level updates.

    Only caches image K/V since:
    - Text K/V is always computed fresh (same for all patches)
    - Only the image portion needs stale/fresh cache management across patches
    """

    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        image_seq_len: int,
        head_dim: int,
        dtype: mx.Dtype = mx.float32,
    ):
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.image_seq_len = image_seq_len
        self.head_dim = head_dim
        self._dtype = dtype

        self.key_cache = mx.zeros(
            (batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
        )
        self.value_cache = mx.zeros(
            (batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
        )

    def update_image_patch(
        self, patch_start: int, patch_end: int, key: mx.array, value: mx.array
    ) -> None:
        """Update cache with fresh K/V for an image patch slice.

        Args:
            patch_start: Start token index within image portion (0-indexed)
            patch_end: End token index within image portion
            key: Fresh key tensor [batch, heads, patch_seq_len, head_dim]
            value: Fresh value tensor [batch, heads, patch_seq_len, head_dim]
        """
        self.key_cache[:, :, patch_start:patch_end, :] = key
        self.value_cache[:, :, patch_start:patch_end, :] = value

    def get_full_kv(
        self, text_key: mx.array, text_value: mx.array
    ) -> tuple[mx.array, mx.array]:
        """Return full K/V by concatenating fresh text K/V with cached image K/V.

        Args:
            text_key: Fresh text key tensor [batch, heads, text_seq_len, head_dim]
            text_value: Fresh text value tensor [batch, heads, text_seq_len, head_dim]

        Returns:
            Tuple of (full_key, full_value) with shape [batch, heads, text+image, head_dim]
        """
        full_key = mx.concatenate([text_key, self.key_cache], axis=2)
        full_value = mx.concatenate([text_value, self.value_cache], axis=2)
        return full_key, full_value

    def reset(self) -> None:
        """Reset cache to zeros."""
        self.key_cache = mx.zeros(
            (self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
            dtype=self._dtype,
        )
        self.value_cache = mx.zeros(
            (self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
            dtype=self._dtype,
        )
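A small shape-level usage sketch of the cache; the sizes are illustrative, not taken from any particular model run:

```python
import mlx.core as mx

cache = ImagePatchKVCache(batch_size=1, num_heads=24, image_seq_len=4096, head_dim=128)

# Fresh K/V for the patch covering image tokens [0, 1024)
patch_k = mx.random.normal((1, 24, 1024, 128))
patch_v = mx.random.normal((1, 24, 1024, 128))
cache.update_image_patch(0, 1024, patch_k, patch_v)

# Text K/V is always recomputed; the cache supplies the image half
txt_k = mx.random.normal((1, 24, 77, 128))
txt_v = mx.random.normal((1, 24, 77, 128))
full_k, full_v = cache.get_full_kv(txt_k, txt_v)
print(full_k.shape)  # (1, 24, 77 + 4096, 128)
```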
@@ -1,975 +0,0 @@
|
||||
from math import ceil
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.callbacks.callbacks import Callbacks
|
||||
from mflux.config.config import Config
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
from mflux.utils.exceptions import StopImageGenerationException
|
||||
from tqdm import tqdm
|
||||
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
ModelAdapter,
|
||||
PromptData,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.block_wrapper import (
|
||||
JointBlockWrapper,
|
||||
SingleBlockWrapper,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
def calculate_patch_heights(latent_height: int, num_patches: int):
|
||||
patch_height = ceil(latent_height / num_patches)
|
||||
|
||||
actual_num_patches = ceil(latent_height / patch_height)
|
||||
patch_heights = [patch_height] * (actual_num_patches - 1)
|
||||
|
||||
last_height = latent_height - patch_height * (actual_num_patches - 1)
|
||||
patch_heights.append(last_height)
|
||||
|
||||
return patch_heights, actual_num_patches
|
||||
|
||||
|
||||
def calculate_token_indices(patch_heights: list[int], latent_width: int):
|
||||
tokens_per_row = latent_width
|
||||
|
||||
token_ranges = []
|
||||
cumulative_height = 0
|
||||
|
||||
for h in patch_heights:
|
||||
start_token = tokens_per_row * cumulative_height
|
||||
end_token = tokens_per_row * (cumulative_height + h)
|
||||
|
||||
token_ranges.append((start_token, end_token))
|
||||
cumulative_height += h
|
||||
|
||||
return token_ranges
|
||||
|
||||
|
||||
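A quick worked example of the two helpers above: splitting a 63-row latent into 4 patches gives row counts of 16, 16, 16, 15, and the token ranges are those row counts multiplied by the latent width.

```python
patch_heights, n = calculate_patch_heights(latent_height=63, num_patches=4)
print(patch_heights, n)  # [16, 16, 16, 15] 4

token_ranges = calculate_token_indices(patch_heights, latent_width=64)
print(token_ranges)      # [(0, 1024), (1024, 2048), (2048, 3072), (3072, 4032)]
```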
class DiffusionRunner:
|
||||
"""Orchestrates the diffusion loop for image generation.
|
||||
|
||||
This class owns the entire diffusion process, handling both single-node
|
||||
and distributed (PipeFusion) modes.
|
||||
|
||||
In distributed mode, it implements PipeFusion with:
|
||||
- Sync pipeline for initial timesteps (full image, all devices in lockstep)
|
||||
- Async pipeline for later timesteps (patches processed independently)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ImageModelConfig,
|
||||
adapter: ModelAdapter,
|
||||
group: Optional[mx.distributed.Group],
|
||||
shard_metadata: PipelineShardMetadata,
|
||||
num_sync_steps: int = 1,
|
||||
num_patches: Optional[int] = None,
|
||||
):
|
||||
"""Initialize the diffusion runner.
|
||||
|
||||
Args:
|
||||
config: Model configuration (architecture, block counts, etc.)
|
||||
adapter: Model adapter for model-specific operations
|
||||
group: MLX distributed group (None for single-node mode)
|
||||
shard_metadata: Pipeline shard metadata with layer assignments
|
||||
num_sync_steps: Number of synchronous timesteps before async mode
|
||||
num_patches: Number of patches for async mode (defaults to world_size)
|
||||
"""
|
||||
self.config = config
|
||||
self.adapter = adapter
|
||||
self.group = group
|
||||
|
||||
# Handle single-node vs distributed mode
|
||||
if group is None:
|
||||
self.rank = 0
|
||||
self.world_size = 1
|
||||
self.next_rank = 0
|
||||
self.prev_rank = 0
|
||||
self.start_layer = 0
|
||||
self.end_layer = config.total_blocks
|
||||
else:
|
||||
self.rank = shard_metadata.device_rank
|
||||
self.world_size = shard_metadata.world_size
|
||||
self.next_rank = (self.rank + 1) % self.world_size
|
||||
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
|
||||
self.start_layer = shard_metadata.start_layer
|
||||
self.end_layer = shard_metadata.end_layer
|
||||
|
||||
self.num_sync_steps = num_sync_steps
|
||||
self.num_patches = num_patches if num_patches else max(1, self.world_size)
|
||||
|
||||
# Persistent KV caches (initialized on first async timestep, reused across timesteps)
|
||||
self.joint_kv_caches: list[ImagePatchKVCache] | None = None
|
||||
self.single_kv_caches: list[ImagePatchKVCache] | None = None
|
||||
|
||||
# Get block counts from config (model-agnostic)
|
||||
self.total_joint = config.joint_block_count
|
||||
self.total_single = config.single_block_count
|
||||
self.total_layers = config.total_blocks
|
||||
|
||||
self._compute_assigned_blocks()
|
||||
|
||||
    def _compute_assigned_blocks(self) -> None:
        """Determine which joint/single blocks this stage owns."""
        start = self.start_layer
        end = self.end_layer

        if end <= self.total_joint:
            # All assigned blocks are joint blocks
            self.joint_start = start
            self.joint_end = end
            self.single_start = 0
            self.single_end = 0
        elif start >= self.total_joint:
            # All assigned blocks are single blocks
            self.joint_start = 0
            self.joint_end = 0
            self.single_start = start - self.total_joint
            self.single_end = end - self.total_joint
        else:
            # Stage spans joint→single transition
            self.joint_start = start
            self.joint_end = self.total_joint
            self.single_start = 0
            self.single_end = end - self.total_joint

        self.has_joint_blocks = self.joint_end > self.joint_start
        self.has_single_blocks = self.single_end > self.single_start

        self.owns_concat_stage = self.has_joint_blocks and (
            self.has_single_blocks or self.end_layer == self.total_joint
        )

        joint_blocks = self.adapter.get_joint_blocks()
        single_blocks = self.adapter.get_single_blocks()

        # Wrap blocks at initialization (reused across all calls)
        self.joint_block_wrappers = [
            JointBlockWrapper(block=block, adapter=self.adapter)
            for block in joint_blocks
        ]
        self.single_block_wrappers = [
            SingleBlockWrapper(block=block, adapter=self.adapter)
            for block in single_blocks
        ]

    @property
    def is_first_stage(self) -> bool:
        return self.rank == 0

    @property
    def is_last_stage(self) -> bool:
        return self.rank == self.world_size - 1

    @property
    def is_distributed(self) -> bool:
        return self.group is not None

    def _calculate_capture_steps(
        self,
        partial_images: int,
        init_time_step: int,
        num_inference_steps: int,
    ) -> set[int]:
        """Calculate which timesteps should produce partial images.

        Evenly spaces `partial_images` captures across the diffusion loop.
        Does NOT include the final timestep (that's the complete image).

        Args:
            partial_images: Number of partial images to capture
            init_time_step: Starting timestep (for img2img this may not be 0)
            num_inference_steps: Total inference steps

        Returns:
            Set of timestep indices to capture
        """
        if partial_images <= 0:
            return set()

        total_steps = num_inference_steps - init_time_step
        if total_steps <= 1:
            return set()

        if partial_images >= total_steps - 1:
            # Capture every step except final
            return set(range(init_time_step, num_inference_steps - 1))

        # Evenly space partial captures
        step_interval = total_steps / (partial_images + 1)
        capture_steps: set[int] = set()
        for i in range(1, partial_images + 1):
            step_idx = int(init_time_step + i * step_interval)
            # Ensure we don't capture the final step
            if step_idx < num_inference_steps - 1:
                capture_steps.add(step_idx)

        return capture_steps

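    # Worked example (added for clarity, not in the original source): with
    # init_time_step=0, num_inference_steps=20 and partial_images=3, the interval
    # is 20 / 4 = 5.0, so captures land at steps {5, 10, 15}; the final step 19
    # is left for the complete image.
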
    def generate_image(
        self,
        settings: Config,
        prompt: str,
        seed: int,
        partial_images: int = 0,
    ):
        """Primary entry point for image generation.

        Orchestrates the full generation flow:
        1. Create runtime config
        2. Create initial latents
        3. Encode prompt
        4. Run diffusion loop (yielding partials if requested)
        5. Decode to image

        When partial_images > 0, yields (GeneratedImage, partial_index, total_partials)
        tuples for intermediate images, then yields the final GeneratedImage.

        Args:
            settings: Generation config (steps, height, width)
            prompt: Text prompt
            seed: Random seed
            partial_images: Number of intermediate images to yield (0 for none)

        Yields:
            Partial images as (GeneratedImage, partial_index, total_partials) tuples
            Final GeneratedImage
        """
        runtime_config = RuntimeConfig(settings, self.adapter.model.model_config)
        latents = self.adapter.create_latents(seed, runtime_config)
        prompt_data = self.adapter.encode_prompt(prompt)

        # Calculate which steps to capture
        capture_steps = self._calculate_capture_steps(
            partial_images=partial_images,
            init_time_step=runtime_config.init_time_step,
            num_inference_steps=runtime_config.num_inference_steps,
        )

        # Run diffusion loop - may yield partial latents
        diffusion_gen = self._run_diffusion_loop(
            latents=latents,
            prompt_data=prompt_data,
            runtime_config=runtime_config,
            seed=seed,
            prompt=prompt,
            capture_steps=capture_steps,
        )

        # Process partial yields and get final latents
        partial_index = 0
        total_partials = len(capture_steps)

        if capture_steps:
            # Generator mode - iterate to get partials and final latents
            try:
                while True:
                    partial_latents, _step = next(diffusion_gen)
                    if self.is_last_stage:
                        partial_image = self.adapter.decode_latents(
                            partial_latents, runtime_config, seed, prompt
                        )
                        yield (partial_image, partial_index, total_partials)
                        partial_index += 1
            except StopIteration as e:
                latents = e.value
        else:
            # No partials - just consume generator to get final latents
            try:
                while True:
                    next(diffusion_gen)
            except StopIteration as e:
                latents = e.value

        # Yield final image (only on last stage)
        if self.is_last_stage:
            yield self.adapter.decode_latents(latents, runtime_config, seed, prompt)

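    # Minimal consumption sketch (illustrative only; `runner` and `settings` are
    # assumed to be an already-constructed DiffusionRunner and Config, and only
    # the last pipeline stage yields anything):
    #
    #     for item in runner.generate_image(settings, "a red fox", seed=0, partial_images=2):
    #         if isinstance(item, tuple):
    #             partial, idx, total = item   # intermediate preview
    #         else:
    #             final_image = item           # finished GeneratedImage
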
    def _run_diffusion_loop(
        self,
        latents: mx.array,
        prompt_data: PromptData,
        runtime_config: RuntimeConfig,
        seed: int,
        prompt: str,
        capture_steps: set[int] | None = None,
    ):
        """Execute the diffusion loop, optionally yielding at capture steps.

        When capture_steps is provided and non-empty, this becomes a generator
        that yields (latents, step_index) tuples at the specified timesteps.
        Only the last stage yields (others have incomplete latents).

        Args:
            latents: Initial noise latents
            prompt_data: Encoded prompt data
            runtime_config: RuntimeConfig with scheduler, steps, dimensions
            seed: Random seed (for callbacks)
            prompt: Text prompt (for callbacks)
            capture_steps: Set of timestep indices to capture (None = no captures)

        Yields:
            (latents, step_index) tuples at capture steps (last stage only)

        Returns:
            Final denoised latents ready for VAE decoding
        """
        if capture_steps is None:
            capture_steps = set()

        time_steps = tqdm(range(runtime_config.num_inference_steps))

        # Call subscribers for beginning of loop
        Callbacks.before_loop(
            seed=seed,
            prompt=prompt,
            latents=latents,
            config=runtime_config,
        )

        for t in time_steps:
            try:
                latents = self._diffusion_step(
                    t=t,
                    config=runtime_config,
                    latents=latents,
                    prompt_data=prompt_data,
                )

                # Call subscribers in-loop
                Callbacks.in_loop(
                    t=t,
                    seed=seed,
                    prompt=prompt,
                    latents=latents,
                    config=runtime_config,
                    time_steps=time_steps,
                )

                mx.eval(latents)

                # Yield partial latents at capture steps (only on last stage)
                if t in capture_steps and self.is_last_stage:
                    yield (latents, t)

            except KeyboardInterrupt:  # noqa: PERF203
                Callbacks.interruption(
                    t=t,
                    seed=seed,
                    prompt=prompt,
                    latents=latents,
                    config=runtime_config,
                    time_steps=time_steps,
                )
                raise StopImageGenerationException(
                    f"Stopping image generation at step {t + 1}/{len(time_steps)}"
                ) from None

        # Call subscribers after loop
        Callbacks.after_loop(
            seed=seed,
            prompt=prompt,
            latents=latents,
            config=runtime_config,
        )

        return latents

    def _forward_pass(
        self,
        latents: mx.array,
        prompt_embeds: mx.array,
        pooled_prompt_embeds: mx.array,
        kwargs: dict[str, Any],
    ) -> mx.array:
        """Run a single forward pass through the transformer.

        This is the internal method called by adapters via compute_step_noise.
        Returns noise prediction without applying scheduler step.

        For edit mode, concatenates conditioning latents with generated latents
        before the transformer, and extracts only the generated portion after.

        Args:
            latents: Input latents (already scaled by caller)
            prompt_embeds: Text embeddings
            pooled_prompt_embeds: Pooled text embeddings (Flux) or placeholder (Qwen)
            kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask, t)

        Returns:
            Noise prediction tensor
        """
        t = kwargs.get("t", 0)
        config = kwargs.get("config")
        if config is None:
            raise ValueError("config must be provided in kwargs")
        scaled_latents = config.scheduler.scale_model_input(latents, t)

        # For edit mode: concatenate with conditioning latents
        conditioning_latents = kwargs.get("conditioning_latents")
        original_latent_tokens = scaled_latents.shape[1]
        if conditioning_latents is not None:
            scaled_latents = mx.concatenate(
                [scaled_latents, conditioning_latents], axis=1
            )

        hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
            scaled_latents, prompt_embeds
        )
        text_embeddings = self.adapter.compute_text_embeddings(
            t, config, pooled_prompt_embeds, hidden_states=hidden_states
        )
        rotary_embeddings = self.adapter.compute_rotary_embeddings(
            prompt_embeds, config, **kwargs
        )

        text_seq_len = prompt_embeds.shape[1]

        # Run through all joint blocks
        for block_idx, wrapper in enumerate(self.joint_block_wrappers):
            encoder_hidden_states, hidden_states = wrapper(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                text_seq_len=text_seq_len,
                kv_cache=None,
                mode=BlockWrapperMode.CACHING,
                block_idx=block_idx,
                **kwargs,
            )

        # Merge streams
        if self.joint_block_wrappers:
            hidden_states = self.adapter.merge_streams(
                hidden_states, encoder_hidden_states
            )

        # Run through single blocks
        for wrapper in self.single_block_wrappers:
            hidden_states = wrapper(
                hidden_states=hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                text_seq_len=text_seq_len,
                kv_cache=None,
                mode=BlockWrapperMode.CACHING,
            )

        # Extract image portion and project
        hidden_states = hidden_states[:, text_seq_len:, ...]

        # For edit mode: extract only the generated portion (exclude conditioning latents)
        if conditioning_latents is not None:
            hidden_states = hidden_states[:, :original_latent_tokens, ...]

        return self.adapter.final_projection(hidden_states, text_embeddings)

    def _diffusion_step(
        self,
        t: int,
        config: RuntimeConfig,
        latents: mx.array,
        prompt_data: PromptData,
    ) -> mx.array:
        """Execute a single diffusion step.

        Routes to single-node, sync pipeline, or async pipeline based on
        configuration and current timestep.
        """
        if self.group is None:
            return self._single_node_step(t, config, latents, prompt_data)
        elif t < config.init_time_step + self.num_sync_steps:
            return self._sync_pipeline(
                t,
                config,
                latents,
                prompt_data,
            )
        else:
            return self._async_pipeline_step(
                t,
                config,
                latents,
                prompt_data,
            )

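    # Routing example (illustrative, not in the original source): with
    # init_time_step=0 and num_sync_steps=1, a distributed run executes the sync
    # pipeline only at t=0 (warming the patch KV caches) and the async PipeFusion
    # pipeline for every later timestep; single-node runs never take either path.
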
    def _single_node_step(
        self,
        t: int,
        config: RuntimeConfig,
        latents: mx.array,
        prompt_data: PromptData,
    ) -> mx.array:
        """Execute a single diffusion step on a single node (no distribution)."""
        base_kwargs = {"t": t, "config": config}

        # For edit mode: include conditioning latents
        if prompt_data.conditioning_latents is not None:
            base_kwargs["conditioning_latents"] = prompt_data.conditioning_latents

        if self.adapter.needs_cfg:
            # Two forward passes + guidance for CFG models (e.g., Qwen)
            pos_kwargs = {
                **base_kwargs,
                **prompt_data.get_extra_forward_kwargs(positive=True),
            }
            noise_pos = self._forward_pass(
                latents,
                prompt_data.prompt_embeds,
                prompt_data.pooled_prompt_embeds,
                pos_kwargs,
            )

            neg_kwargs = {
                **base_kwargs,
                **prompt_data.get_extra_forward_kwargs(positive=False),
            }
            noise_neg = self._forward_pass(
                latents,
                prompt_data.negative_prompt_embeds,
                prompt_data.negative_pooled_prompt_embeds,
                neg_kwargs,
            )

            assert self.config.guidance_scale is not None
            noise = self.adapter.apply_guidance(
                noise_pos, noise_neg, guidance_scale=self.config.guidance_scale
            )
        else:
            # Single forward pass for non-CFG models (e.g., Flux)
            kwargs = {**base_kwargs, **prompt_data.get_extra_forward_kwargs()}
            noise = self._forward_pass(
                latents,
                prompt_data.prompt_embeds,
                prompt_data.pooled_prompt_embeds,
                kwargs,
            )

        return config.scheduler.step(model_output=noise, timestep=t, sample=latents)

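    # Note (added for clarity): apply_guidance is adapter-specific, but
    # classifier-free guidance is conventionally combined as
    #     noise = noise_neg + guidance_scale * (noise_pos - noise_neg)
    # which is the formula the CFG branch above is assumed to follow.
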
    def _initialize_kv_caches(
        self,
        batch_size: int,
        num_img_tokens: int,
        dtype: mx.Dtype,
    ) -> None:
        """Initialize KV caches for both sync and async pipelines.

        Note: Caches only store IMAGE K/V, not text K/V. Text K/V is always
        computed fresh and doesn't need caching (it's the same for all patches).
        """
        self.joint_kv_caches = [
            ImagePatchKVCache(
                batch_size=batch_size,
                num_heads=self.config.num_heads,
                image_seq_len=num_img_tokens,
                head_dim=self.config.head_dim,
                dtype=dtype,
            )
            for _ in range(len(self.joint_block_wrappers))
        ]
        self.single_kv_caches = [
            ImagePatchKVCache(
                batch_size=batch_size,
                num_heads=self.config.num_heads,
                image_seq_len=num_img_tokens,
                head_dim=self.config.head_dim,
                dtype=dtype,
            )
            for _ in range(len(self.single_block_wrappers))
        ]

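    # Added note (inferred from the constructor arguments above): each
    # ImagePatchKVCache is sized for the full image (image_seq_len=num_img_tokens),
    # not a single patch, presumably so that async PATCHED passes can write only
    # their (patch_start, patch_end) slice while attending over cached K/V for the
    # rest of the image populated during the sync warmup step.
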
    def _create_patches(
        self,
        latents: mx.array,
        config: RuntimeConfig,
    ) -> tuple[list[mx.array], list[tuple[int, int]]]:
        """Split latents into patches for async pipeline."""
        # Use 16 to match FluxLatentCreator.create_noise formula
        latent_height = config.height // 16
        latent_width = config.width // 16

        patch_heights, _ = calculate_patch_heights(latent_height, self.num_patches)
        token_indices = calculate_token_indices(patch_heights, latent_width)

        # Split latents into patches
        patch_latents = [latents[:, start:end, :] for start, end in token_indices]

        return patch_latents, token_indices

    def _sync_pipeline(
        self,
        t: int,
        config: RuntimeConfig,
        hidden_states: mx.array,
        prompt_data: PromptData,
        kontext_image_ids: mx.array | None = None,
    ) -> mx.array:
        prev_latents = hidden_states

        # Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
        prompt_embeds = prompt_data.prompt_embeds
        pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
        extra_kwargs = prompt_data.get_extra_forward_kwargs()

        hidden_states = config.scheduler.scale_model_input(hidden_states, t)

        # For edit mode: handle conditioning latents
        # All stages need to know the total token count for correct recv templates
        conditioning_latents = prompt_data.conditioning_latents
        original_latent_tokens = hidden_states.shape[1]
        if conditioning_latents is not None:
            num_img_tokens = original_latent_tokens + conditioning_latents.shape[1]
        else:
            num_img_tokens = original_latent_tokens

        # First stage: concatenate conditioning latents before embedding
        if self.is_first_stage and conditioning_latents is not None:
            hidden_states = mx.concatenate(
                [hidden_states, conditioning_latents], axis=1
            )

        # === PHASE 1: Embeddings ===
        if self.is_first_stage:
            hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
                hidden_states, prompt_embeds
            )

        # All stages need these for their blocks
        text_embeddings = self.adapter.compute_text_embeddings(
            t, config, pooled_prompt_embeds
        )
        image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
            prompt_embeds,
            config,
            kontext_image_ids=kontext_image_ids,
            **extra_kwargs,
        )

        # === Initialize KV caches to populate during sync for async warmstart ===
        batch_size = prev_latents.shape[0]
        text_seq_len = prompt_embeds.shape[1]
        hidden_dim = self.adapter.hidden_dim

        if t == config.init_time_step:
            self._initialize_kv_caches(
                batch_size=batch_size,
                num_img_tokens=num_img_tokens,
                dtype=prev_latents.dtype,
            )

        # === PHASE 2: Joint Blocks with Communication and Caching ===
        if self.has_joint_blocks:
            # Receive from previous stage (if not first stage)
            if not self.is_first_stage:
                recv_template = mx.zeros(
                    (batch_size, num_img_tokens, hidden_dim), dtype=prev_latents.dtype
                )
                hidden_states = mx.distributed.recv_like(
                    recv_template, self.prev_rank, group=self.group
                )
                enc_template = mx.zeros(
                    (batch_size, text_seq_len, hidden_dim), dtype=prev_latents.dtype
                )
                encoder_hidden_states = mx.distributed.recv_like(
                    enc_template, self.prev_rank, group=self.group
                )
                mx.eval(hidden_states, encoder_hidden_states)

            # Run assigned joint blocks with caching mode
            for block_idx, wrapper in enumerate(self.joint_block_wrappers):
                encoder_hidden_states, hidden_states = wrapper(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                    text_seq_len=text_seq_len,
                    kv_cache=self.joint_kv_caches[block_idx],
                    mode=BlockWrapperMode.CACHING,
                    **extra_kwargs,
                )

        # === PHASE 3: Joint→Single Transition ===
        if self.owns_concat_stage:
            # Merge encoder and hidden states using adapter hook
            concatenated = self.adapter.merge_streams(
                hidden_states, encoder_hidden_states
            )

            if self.has_single_blocks or self.is_last_stage:
                # Keep locally: either for single blocks or final projection
                hidden_states = concatenated
            else:
                # Send concatenated state to next stage (which has single blocks)
                mx.eval(
                    mx.distributed.send(concatenated, self.next_rank, group=self.group)
                )

        elif self.has_joint_blocks and not self.is_last_stage:
            # Send joint block outputs to next stage (which has more joint blocks)
            mx.eval(
                mx.distributed.send(hidden_states, self.next_rank, group=self.group),
                mx.distributed.send(
                    encoder_hidden_states, self.next_rank, group=self.group
                ),
            )

        # === PHASE 4: Single Blocks with Communication and Caching ===
        if self.has_single_blocks:
            # Receive from previous stage if we didn't do concatenation
            if not self.owns_concat_stage and not self.is_first_stage:
                recv_template = mx.zeros(
                    (batch_size, text_seq_len + num_img_tokens, hidden_dim),
                    dtype=prev_latents.dtype,
                )
                hidden_states = mx.distributed.recv_like(
                    recv_template, self.prev_rank, group=self.group
                )
                mx.eval(hidden_states)

            # Run assigned single blocks with caching mode
            for block_idx, wrapper in enumerate(self.single_block_wrappers):
                hidden_states = wrapper(
                    hidden_states=hidden_states,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                    text_seq_len=text_seq_len,
                    kv_cache=self.single_kv_caches[block_idx],
                    mode=BlockWrapperMode.CACHING,
                )

            # Send to next stage if not last
            if not self.is_last_stage:
                mx.eval(
                    mx.distributed.send(hidden_states, self.next_rank, group=self.group)
                )

        # === PHASE 5: Last Stage - Final Projection + Scheduler ===
        # Extract image portion (remove text embeddings prefix)
        hidden_states = hidden_states[:, text_seq_len:, ...]

        # For edit mode: extract only the generated portion (exclude conditioning latents)
        if conditioning_latents is not None:
            hidden_states = hidden_states[:, :original_latent_tokens, ...]

        if self.is_last_stage:
            hidden_states = self.adapter.final_projection(
                hidden_states, text_embeddings
            )

            hidden_states = config.scheduler.step(
                model_output=hidden_states,
                timestep=t,
                sample=prev_latents,
            )

            if not self.is_first_stage:
                mx.eval(mx.distributed.send(hidden_states, 0, group=self.group))

        elif self.is_first_stage:
            hidden_states = mx.distributed.recv_like(
                prev_latents, src=self.world_size - 1, group=self.group
            )

            mx.eval(hidden_states)

        else:
            # For shape correctness
            hidden_states = prev_latents

        return hidden_states

    def _async_pipeline_step(
        self,
        t: int,
        config: RuntimeConfig,
        latents: mx.array,
        prompt_data: PromptData,
        kontext_image_ids: mx.array | None = None,
    ) -> mx.array:
        patch_latents, token_indices = self._create_patches(latents, config)

        patch_latents = self._async_pipeline(
            t,
            config,
            patch_latents,
            token_indices,
            prompt_data,
            kontext_image_ids,
        )

        return mx.concatenate(patch_latents, axis=1)

    def _async_pipeline(
        self,
        t: int,
        config: RuntimeConfig,
        patch_latents: list[mx.array],
        token_indices: list[tuple[int, int]],
        prompt_data: PromptData,
        kontext_image_ids: mx.array | None = None,
    ) -> list[mx.array]:
        """Execute async pipeline for all patches."""
        assert self.joint_kv_caches is not None
        assert self.single_kv_caches is not None

        # Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
        prompt_embeds = prompt_data.prompt_embeds
        pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
        extra_kwargs = prompt_data.get_extra_forward_kwargs()

        text_embeddings = self.adapter.compute_text_embeddings(
            t, config, pooled_prompt_embeds
        )
        image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
            prompt_embeds,
            config,
            kontext_image_ids=kontext_image_ids,
            **extra_kwargs,
        )

        batch_size = patch_latents[0].shape[0]
        text_seq_len = prompt_embeds.shape[1]
        hidden_dim = self.adapter.hidden_dim

        for patch_idx, patch in enumerate(patch_latents):
            patch_prev = patch

            start_token, end_token = token_indices[patch_idx]

            if self.has_joint_blocks:
                if (
                    not self.is_first_stage
                    or t != config.init_time_step + self.num_sync_steps
                ):
                    if self.is_first_stage:
                        # First stage receives latent-space from last stage (scheduler output)
                        recv_template = patch
                    else:
                        # Other stages receive hidden-space from previous stage
                        patch_len = patch.shape[1]
                        recv_template = mx.zeros(
                            (batch_size, patch_len, hidden_dim),
                            dtype=patch.dtype,
                        )
                    patch = mx.distributed.recv_like(
                        recv_template, src=self.prev_rank, group=self.group
                    )
                    mx.eval(patch)
                    patch_latents[patch_idx] = patch

                if not self.is_first_stage and patch_idx == 0:
                    enc_template = mx.zeros(
                        (batch_size, text_seq_len, hidden_dim),
                        dtype=patch_latents[0].dtype,
                    )
                    encoder_hidden_states = mx.distributed.recv_like(
                        enc_template, src=self.prev_rank, group=self.group
                    )
                    mx.eval(encoder_hidden_states)

                if self.is_first_stage:
                    patch, encoder_hidden_states = self.adapter.compute_embeddings(
                        patch, prompt_embeds
                    )

                # Run assigned joint blocks with patched mode
                for block_idx, wrapper in enumerate(self.joint_block_wrappers):
                    encoder_hidden_states, patch = wrapper(
                        hidden_states=patch,
                        encoder_hidden_states=encoder_hidden_states,
                        text_embeddings=text_embeddings,
                        rotary_embeddings=image_rotary_embeddings,
                        text_seq_len=text_seq_len,
                        kv_cache=self.joint_kv_caches[block_idx],
                        mode=BlockWrapperMode.PATCHED,
                        patch_start=start_token,
                        patch_end=end_token,
                        **extra_kwargs,
                    )

            if self.owns_concat_stage:
                patch_concat = self.adapter.merge_streams(patch, encoder_hidden_states)

                if self.has_single_blocks or self.is_last_stage:
                    # Keep locally: either for single blocks or final projection
                    patch = patch_concat
                else:
                    mx.eval(
                        mx.distributed.send(
                            patch_concat, self.next_rank, group=self.group
                        )
                    )

            elif self.has_joint_blocks and not self.is_last_stage:
                mx.eval(mx.distributed.send(patch, self.next_rank, group=self.group))

                if patch_idx == 0:
                    mx.eval(
                        mx.distributed.send(
                            encoder_hidden_states, self.next_rank, group=self.group
                        )
                    )

            if self.has_single_blocks:
                if not self.owns_concat_stage and not self.is_first_stage:
                    recv_template = mx.zeros(
                        [
                            batch_size,
                            text_seq_len + patch_latents[patch_idx].shape[1],
                            hidden_dim,
                        ],
                        dtype=patch_latents[0].dtype,
                    )

                    patch = mx.distributed.recv_like(
                        recv_template, src=self.prev_rank, group=self.group
                    )
                    mx.eval(patch)
                    patch_latents[patch_idx] = patch

                # Run assigned single blocks with patched mode
                for block_idx, wrapper in enumerate(self.single_block_wrappers):
                    patch = wrapper(
                        hidden_states=patch,
                        text_embeddings=text_embeddings,
                        rotary_embeddings=image_rotary_embeddings,
                        text_seq_len=text_seq_len,
                        kv_cache=self.single_kv_caches[block_idx],
                        mode=BlockWrapperMode.PATCHED,
                        patch_start=start_token,
                        patch_end=end_token,
                    )

                if not self.is_last_stage:
                    mx.eval(
                        mx.distributed.send(patch, self.next_rank, group=self.group)
                    )

            if self.is_last_stage:
                patch_img_only = patch[:, text_seq_len:, :]

                patch_img_only = self.adapter.final_projection(
                    patch_img_only, text_embeddings
                )

                patch = config.scheduler.step(
                    model_output=patch_img_only,
                    timestep=t,
                    sample=patch_prev,
                )

                if not self.is_first_stage and t != config.num_inference_steps - 1:
                    mx.eval(
                        mx.distributed.send(patch, self.next_rank, group=self.group)
                    )

            patch_latents[patch_idx] = patch

        return patch_latents

@@ -10,18 +10,23 @@ from mlx.nn.layers.distributed import (
|
||||
shard_linear,
|
||||
sum_gradients,
|
||||
)
|
||||
from mlx_lm.models.cache import (
|
||||
_BaseCache, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
|
||||
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
|
||||
from mlx_lm.models.glm4_moe import MoE
|
||||
from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
|
||||
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
)
|
||||
from exo.shared.logging import logger
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
|
||||
|
||||
class _LayerCallable(Protocol):
|
||||
@@ -91,8 +96,6 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
x, *args, **kwargs
|
||||
).arguments.get("cache", None)
|
||||
|
||||
assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore
|
||||
|
||||
output: mx.array = self.original_layer(x, *args, **kwargs)
|
||||
|
||||
if self.r != self.s - 1:
|
||||
@@ -100,10 +103,8 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if cache is not None:
|
||||
# This change happened upstream - check out mlx github somewhere??
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
# TODO(ciaran): This is overkill
|
||||
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
|
||||
return output
|
||||
|
||||
@@ -133,24 +134,6 @@ def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
|
||||
return layers
|
||||
|
||||
|
||||
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
inner_model_instance = _inner_model(model)
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
inner_model_instance.layers = layers
|
||||
|
||||
# Update DeepSeek V3 specific parameters when layers are shrunk
|
||||
if isinstance(model, DeepseekV3Model) and hasattr(
|
||||
inner_model_instance, "num_layers"
|
||||
):
|
||||
inner_model_instance.start_idx = 0
|
||||
inner_model_instance.end_idx = len(layers)
|
||||
inner_model_instance.num_layers = len(layers)
|
||||
elif hasattr(inner_model_instance, "h"):
|
||||
inner_model_instance.h = layers
|
||||
else:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
|
||||
def pipeline_auto_parallel(
|
||||
model: nn.Module,
|
||||
group: mx.distributed.Group,
|
||||
@@ -166,8 +149,7 @@ def pipeline_auto_parallel(
|
||||
"""
|
||||
inner_model_instance: nn.Module = _inner_model(model)
|
||||
|
||||
# Handle both model.layers and model.h cases
|
||||
layers: list[_LayerCallable] = _get_layers(inner_model_instance)
|
||||
layers = _get_layers(inner_model_instance)
|
||||
|
||||
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
|
||||
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
|
||||
@@ -181,6 +163,17 @@ def pipeline_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
if isinstance(inner_model_instance, GptOssMoeModel):
|
||||
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
|
||||
start_layer:end_layer
|
||||
]
|
||||
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
|
||||
"sliding_attention"
|
||||
)
|
||||
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
|
||||
"full_attention"
|
||||
)
|
||||
|
||||
_set_layers(model, layers)
|
||||
|
||||
assert isinstance(layers, list), (
|
||||
@@ -205,18 +198,44 @@ def tensor_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
segments: int = 1
|
||||
|
||||
def _all_to_sharded(path: str, weight: mx.array):
|
||||
if path.endswith("bias"):
|
||||
logger.info(f"Sharding bias for {path} - all to sharded")
|
||||
return weight.ndim - 1, segments
|
||||
return max(weight.ndim - 2, 0), segments
|
||||
|
||||
all_to_sharded_linear_in_place = partial(
|
||||
shard_inplace,
|
||||
sharding="all-to-sharded",
|
||||
group=group,
|
||||
)
|
||||
sharded_to_all_linear_in_place = partial(
|
||||
shard_inplace,
|
||||
sharding="sharded-to-all",
|
||||
sharding=_all_to_sharded, # type: ignore
|
||||
group=group,
|
||||
)
|
||||
|
||||
if isinstance(model, LlamaModel):
|
||||
n = group.size()
|
||||
|
||||
def _sharded_to_all(path: str, weight: mx.array):
|
||||
if path.endswith("bias"):
|
||||
logger.info(f"Sharding bias for {path} - sharded to all")
|
||||
weight /= n
|
||||
return None
|
||||
return -1, segments
|
||||
|
||||
sharded_to_all_linear_in_place = partial(
|
||||
shard_inplace,
|
||||
sharding=_sharded_to_all, # type: ignore
|
||||
group=group,
|
||||
)
|
||||
|
||||
if hasattr(model, "shard"):
|
||||
try:
|
||||
model.shard(group) # type: ignore
|
||||
return model
|
||||
except (AttributeError, TypeError, NameError):
|
||||
pass
|
||||
|
||||
if isinstance(model, (LlamaModel, Ministral3Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -224,7 +243,8 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, DeepseekV3Model):
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -232,7 +252,7 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, Qwen3MoeModel):
|
||||
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
tensor_parallel_sharding_strategy = QwenShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -240,6 +260,15 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, GptOssModel):
|
||||
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
@@ -285,13 +314,38 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
return model
|
||||
|
||||
|
||||
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
inner_model_instance = _inner_model(model)
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
inner_model_instance.layers = layers
|
||||
|
||||
# Update DeepSeek V3 specific parameters when layers are shrunk
|
||||
if isinstance(
|
||||
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
|
||||
) and hasattr(inner_model_instance, "num_layers"):
|
||||
logger.info(
|
||||
f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
|
||||
)
|
||||
inner_model_instance.start_idx = 0
|
||||
inner_model_instance.end_idx = len(layers)
|
||||
inner_model_instance.num_layers = len(layers)
|
||||
elif isinstance(model, Qwen3MoeModel):
|
||||
logger.info(
|
||||
f"Setting num_hidden_layers to {len(layers)} for model {model.model.__class__.__name__}"
|
||||
)
|
||||
inner_model_instance.num_hidden_layers = len(layers)
|
||||
elif hasattr(inner_model_instance, "h"):
|
||||
inner_model_instance.h = layers
|
||||
else:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
|
||||
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
for layer in model.layers:
|
||||
# Shard the self attention
|
||||
if layer.self_attn.q_lora_rank is None: # pyright: ignore[reportUnnecessaryComparison]
|
||||
# Unfortunately, q_lora_rank can be None despite typing hints.
|
||||
if layer.self_attn.q_lora_rank is None:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
@@ -306,7 +360,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Shard the MLP
|
||||
if isinstance(layer.mlp, DeepseekV3MLP):
|
||||
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
@@ -354,7 +408,9 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
|
||||
if isinstance(
|
||||
layer.mlp, (Qwen3MoeSparseMoeBlock, MoE, Qwen3NextSparseMoeBlock)
|
||||
):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
@@ -382,3 +438,50 @@ class ShardedQwenMoE(CustomMlxLayer):
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(GptOssMoeModel, model)
|
||||
|
||||
for layer in model.layers:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
layer.self_attn.num_key_value_groups = (
|
||||
layer.self_attn.num_attention_heads
|
||||
// layer.self_attn.num_key_value_heads
|
||||
)
|
||||
|
||||
layer.self_attn.sinks = layer.self_attn.sinks[
|
||||
layer.self_attn.num_attention_heads
|
||||
* self.group.rank() : layer.self_attn.num_attention_heads
|
||||
* (self.group.rank() + 1)
|
||||
]
|
||||
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
|
||||
|
||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ShardedGptOssMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: nn.Module):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
from typing import Any, Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
# from exo.engines.mlx.cache import KVPrefixCache
|
||||
from exo.shared.types.api import ChatCompletionMessage, FinishReason
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionMessage,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
@@ -41,7 +48,6 @@ def maybe_quantize_kv_cache(
|
||||
def warmup_inference(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
) -> int:
|
||||
content = "Prompt to warm up the inference engine. Repeat this."
|
||||
|
||||
@@ -64,6 +70,9 @@ def warmup_inference(
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Use a default sampler for warmup
|
||||
sampler = make_sampler(temp=0.7)
|
||||
|
||||
logger.info("Generating warmup tokens")
|
||||
for _r in stream_generate(
|
||||
model=model,
|
||||
@@ -72,7 +81,7 @@ def warmup_inference(
|
||||
max_tokens=50,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=65536,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
@@ -80,20 +89,47 @@ def warmup_inference(
|
||||
tokens_generated += 1
|
||||
|
||||
logger.info("Generated ALL warmup tokens")
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
# At least this version is actively incorrect, as it should use mx_barrier(group)
|
||||
mx_barrier()
|
||||
|
||||
return tokens_generated
|
||||
|
||||
|
||||
def ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]:
|
||||
token_ids = [int(t) for t in token_ids]
|
||||
|
||||
def proc(_history: mx.array, logits: mx.array) -> mx.array:
|
||||
for tid in token_ids:
|
||||
logits[..., tid] = -1e9
|
||||
return logits
|
||||
|
||||
return proc
|
||||
|
||||
|
||||
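# Usage sketch (illustrative, not part of the original diff): in bench mode the
# tokenizer's EOS ids are banned so generation always reaches max_tokens, e.g.
#     processors = [ban_token_ids(eos_ids_from_tokenizer(tokenizer))]
#     # then passed to stream_generate(..., logits_processors=processors)
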
def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
|
||||
eos: list[int] | None = getattr(tokenizer, "eos_token_ids", None)
|
||||
if eos is None:
|
||||
return []
|
||||
return eos
|
||||
|
||||
|
||||
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
task: ChatCompletionTaskParams,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
|
||||
|
||||
# Currently we support chat-completion tasks only.
|
||||
logger.info(f"task_params: {task}")
|
||||
|
||||
if task.seed is not None:
|
||||
mx.random.seed(task.seed)
|
||||
|
||||
prompt = apply_chat_template(
|
||||
tokenizer=tokenizer,
|
||||
chat_task_data=task,
|
||||
@@ -101,6 +137,17 @@ def mlx_generate(
|
||||
|
||||
caches = make_kv_cache(model=model)
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
# Bench mode: ban EOS tokens so generation always runs to the requested length
|
||||
eos_ids = eos_ids_from_tokenizer(tokenizer)
|
||||
logits_processors = [ban_token_ids(eos_ids)]
|
||||
|
||||
sampler = make_sampler(
|
||||
temp=task.temperature if task.temperature is not None else 0.7,
|
||||
top_p=task.top_p if task.top_p is not None else 1.0,
|
||||
)
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
@@ -108,26 +155,40 @@ def mlx_generate(
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
prefill_step_size=65536,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
logger.info(out.text)
|
||||
if out.finish_reason is not None and out.finish_reason not in get_args(
|
||||
FinishReason
|
||||
):
|
||||
# We don't throw here as this failure case is really not all that bad
|
||||
# Just log the error and move on
|
||||
logger.warning(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
|
||||
if out.finish_reason not in get_args(FinishReason):
|
||||
# We don't throw here as this failure case is really not all that bad
|
||||
# Just log the error and move on
|
||||
logger.warning(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
token=out.token,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
break
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
@@ -1,13 +1,25 @@
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast
|
||||
from typing import Any, cast
|
||||
|
||||
# Monkey-patch for transformers 5.x compatibility
|
||||
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
|
||||
# which was moved in transformers 5.0.0rc2
|
||||
try:
|
||||
import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
|
||||
from transformers.convert_slow_tokenizer import bytes_to_unicode
|
||||
|
||||
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
|
||||
gpt2_tokenization.bytes_to_unicode = bytes_to_unicode # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
pass # transformers < 5.0 or bytes_to_unicode not available
|
||||
|
||||
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
@@ -19,7 +31,7 @@ from exo.worker.engines.mlx.constants import (
|
||||
try:
|
||||
from mlx_lm.tokenizer_utils import load_tokenizer
|
||||
except ImportError:
|
||||
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
|
||||
from mlx_lm.tokenizer_utils import load as load_tokenizer
|
||||
import contextlib
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -176,11 +188,7 @@ def initialize_mlx(
|
||||
|
||||
def load_mlx_items(
|
||||
bound_instance: BoundInstance, group: Group | None
|
||||
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
|
||||
# TODO: pass temperature
|
||||
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
|
||||
logger.info("Created a sampler")
|
||||
|
||||
) -> tuple[Model, TokenizerWrapper]:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
|
||||
@@ -201,7 +209,7 @@ def load_mlx_items(
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
|
||||
return cast(Model, model), tokenizer, sampler
|
||||
return cast(Model, model), tokenizer
|
||||
|
||||
|
||||
def shard_and_load(
|
||||
@@ -257,26 +265,70 @@ def shard_and_load(
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata):
|
||||
# TODO: Let's move away from this custom logic to mlx_lm.load()
|
||||
if "kimi-k2" in shard_metadata.model_meta.model_id.lower():
|
||||
eos_token_ids = [163586]
|
||||
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
|
||||
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
|
||||
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
|
||||
|
||||
elif "glm" in shard_metadata.model_meta.model_id.lower():
|
||||
eos_token_ids = [151336, 151329, 151338]
|
||||
|
||||
else:
|
||||
eos_token_ids = None
|
||||
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
|
||||
"""
|
||||
Get the EOS token IDs for a model based on its ID.
|
||||
|
||||
tokenizer = cast(
|
||||
TokenizerWrapper,
|
||||
load_tokenizer(
|
||||
model_path,
|
||||
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
|
||||
eos_token_ids=eos_token_ids,
|
||||
),
|
||||
Some models require explicit EOS token configuration that isn't in their
|
||||
tokenizer config. This function returns the known EOS token IDs for such models.
|
||||
|
||||
Args:
|
||||
model_id: The HuggingFace model ID
|
||||
|
||||
Returns:
|
||||
List of EOS token IDs, or None if the model uses standard tokenizer config
|
||||
"""
|
||||
model_id_lower = model_id.lower()
|
||||
if "kimi-k2" in model_id_lower:
|
||||
return [163586]
|
||||
elif "glm" in model_id_lower:
|
||||
return [151336, 151329, 151338]
|
||||
return None
|
||||
|
||||
|
||||
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
|
||||
"""
|
||||
Load tokenizer for a model given its ID and local path.
|
||||
|
||||
This is the core tokenizer loading logic, handling special cases for different
|
||||
model families (Kimi, GLM, etc.) and transformers 5.x compatibility.
|
||||
|
||||
Args:
|
||||
model_id: The HuggingFace model ID (e.g., "moonshotai/Kimi-K2-Instruct")
|
||||
model_path: Local path where the model/tokenizer files are stored
|
||||
|
||||
Returns:
|
||||
TokenizerWrapper instance configured for the model
|
||||
"""
|
||||
model_id_lower = model_id.lower()
|
||||
eos_token_ids = get_eos_token_ids_for_model(model_id)
|
||||
|
||||
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
|
||||
if "kimi-k2" in model_id_lower:
|
||||
sys.path.insert(0, str(model_path))
|
||||
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
|
||||
|
||||
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
|
||||
|
||||
# Patch encode to use internal tiktoken model directly
|
||||
# transformers 5.x has a bug in the encode->pad path for slow tokenizers
|
||||
def _patched_encode(text: str, **_kwargs: object) -> list[int]:
|
||||
# Pass allowed_special="all" to handle special tokens like <|im_user|>
|
||||
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
|
||||
hf_tokenizer.encode = _patched_encode
|
||||
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
|
||||
|
||||
tokenizer = load_tokenizer(
|
||||
model_path,
|
||||
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
|
||||
eos_token_ids=eos_token_ids,
|
||||
)
|
||||
assert isinstance(tokenizer, TokenizerWrapper)
|
||||
|
||||
return tokenizer
|
||||
|
||||
@@ -289,15 +341,15 @@ def apply_chat_template(
|
||||
messages = chat_task_data.messages
|
||||
|
||||
formatted_messages: list[dict[str, Any]] = []
|
||||
for _, message in enumerate(messages):
|
||||
for message in messages:
|
||||
if isinstance(message.content, ChatCompletionMessageText):
|
||||
message.content = message.content.text
|
||||
if isinstance(message.content, list):
|
||||
if len(message.content) != 1:
|
||||
logger.warning("Received malformed prompt")
|
||||
if len(message.content) == 0:
|
||||
logger.warning("Received prompt with no content, skipping")
|
||||
continue
|
||||
|
||||
message.content = message.content[0].text
|
||||
message.content = "\n".join(c.text for c in message.content).strip()
|
||||
if message.content is None and message.thinking is None:
|
||||
continue
|
||||
|
||||
@@ -306,13 +358,14 @@ def apply_chat_template(
|
||||
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
|
||||
)
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template( # type: ignore
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=chat_task_data.tools,
|
||||
)
|
||||
|
||||
return prompt # type: ignore
|
||||
return prompt
|
||||
|
||||
|
||||
class NullKVCache(KVCache):
|
||||
@@ -397,3 +450,13 @@ def set_wired_limit_for_model(model_size: Memory):
|
||||
)
|
||||
mx.set_wired_limit(max_rec_size)
|
||||
logger.info(f"Wired limit set to {max_rec_size}.")
|
||||
|
||||
|
||||
def mlx_cleanup(
|
||||
model: Model | None, tokenizer: TokenizerWrapper | None, group: Group | None
|
||||
) -> None:
|
||||
del model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
@@ -8,15 +8,13 @@ from loguru import logger
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.types.api import ImageEditsInternalParams
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
EventId,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
NodeDownloadProgress,
|
||||
NodeMemoryMeasured,
|
||||
NodePerformanceMeasured,
|
||||
@@ -32,7 +30,6 @@ from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
Shutdown,
|
||||
Task,
|
||||
TaskStatus,
|
||||
@@ -98,10 +95,6 @@ class Worker:
|
||||
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
|
||||
# Buffer for input image chunks (for image editing)
|
||||
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
|
||||
self.input_chunk_counts: dict[CommandId, int] = {}
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Worker")
|
||||
|
||||
@@ -180,17 +173,6 @@ class Worker:
|
||||
for idx, event in indexed_events:
|
||||
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
|
||||
|
||||
# Buffer input image chunks for image editing
|
||||
if isinstance(event, InputChunkReceived):
|
||||
cmd_id = event.command_id
|
||||
if cmd_id not in self.input_chunk_buffer:
|
||||
self.input_chunk_buffer[cmd_id] = {}
|
||||
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
|
||||
|
||||
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
|
||||
event.chunk.data
|
||||
)
|
||||
|
||||
async def plan_step(self):
|
||||
while True:
|
||||
await anyio.sleep(0.1)
|
||||
@@ -203,8 +185,6 @@ class Worker:
|
||||
self.state.instances,
|
||||
self.state.runners,
|
||||
self.state.tasks,
|
||||
self.input_chunk_buffer,
|
||||
self.input_chunk_counts,
|
||||
)
|
||||
if task is None:
|
||||
continue
|
||||
@@ -237,7 +217,9 @@ class Worker:
|
||||
)
|
||||
if initial_progress.status == "complete":
|
||||
progress = DownloadCompleted(
|
||||
shard_metadata=shard, node_id=self.node_id
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
)
|
||||
self.download_status[shard.model_meta.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
@@ -266,42 +248,6 @@ class Worker:
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
chunks = self.input_chunk_buffer.get(cmd_id, {})
|
||||
assembled = "".join(chunks[i] for i in range(len(chunks)))
|
||||
logger.info(
|
||||
f"Assembled input image from {len(chunks)} chunks, "
|
||||
f"total size: {len(assembled)} bytes"
|
||||
)
|
||||
# Create modified task with assembled image data
|
||||
modified_task = ImageEdits(
|
||||
task_id=task.task_id,
|
||||
command_id=task.command_id,
|
||||
instance_id=task.instance_id,
|
||||
task_status=task.task_status,
|
||||
task_params=ImageEditsInternalParams(
|
||||
image_data=assembled,
|
||||
total_input_chunks=task.task_params.total_input_chunks,
|
||||
prompt=task.task_params.prompt,
|
||||
model=task.task_params.model,
|
||||
n=task.task_params.n,
|
||||
quality=task.task_params.quality,
|
||||
output_format=task.task_params.output_format,
|
||||
response_format=task.task_params.response_format,
|
||||
size=task.task_params.size,
|
||||
image_strength=task.task_params.image_strength,
|
||||
),
|
||||
)
|
||||
# Cleanup buffers
|
||||
if cmd_id in self.input_chunk_buffer:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
case task:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
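The branch removed above reassembled a base64 image payload from indexed chunks before handing the ImageEdits task to a runner. A minimal, self-contained sketch of that split-and-reassemble round trip (the size constant and helper names are assumptions standing in for EXO_MAX_CHUNK_SIZE and the worker's buffers):

```python
CHUNK_SIZE = 512 * 1024  # stand-in for EXO_MAX_CHUNK_SIZE

def split_chunks(data: str, size: int = CHUNK_SIZE) -> dict[int, str]:
    # Chunk index -> chunk payload, mirroring how the inner
    # input_chunk_buffer dict is keyed per command.
    return {
        i: data[i * size : (i + 1) * size]
        for i in range((len(data) + size - 1) // size)
    }

def assemble(chunks: dict[int, str]) -> str:
    # Chunks may arrive in any order; join strictly by ascending index.
    return "".join(chunks[i] for i in range(len(chunks)))

payload = "x" * (2 * CHUNK_SIZE + 123)
assert assemble(split_chunks(payload)) == payload
```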
|
||||
|
||||
@@ -420,7 +366,11 @@ class Worker:
|
||||
nonlocal self
|
||||
nonlocal last_progress_time
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
|
||||
status = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
self.download_status[shard.model_meta.model_id] = status
|
||||
# Footgun!
|
||||
self.event_sender.send_nowait(
|
||||
@@ -513,7 +463,9 @@ class Worker:
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
|
||||
@@ -2,15 +2,13 @@
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
@@ -51,8 +49,6 @@ def plan(
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus], # all global
|
||||
tasks: Mapping[TaskId, Task],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
input_chunk_counts: Mapping[CommandId, int] | None = None,
|
||||
) -> Task | None:
|
||||
# Python short circuiting OR logic should evaluate these sequentially.
|
||||
return (
|
||||
@@ -62,7 +58,7 @@ def plan(
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
|
||||
or _pending_tasks(runners, tasks, all_runners)
|
||||
)
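The plan() return above relies on Python's short-circuiting `or`: each helper returns a Task or None, and the first non-None result wins in priority order. A tiny sketch of the same pattern with placeholder step functions (not exo's real planner helpers):

```python
def _create_runner() -> str | None:
    return None  # nothing to do this tick

def _load_model() -> str | None:
    return "load-model"

def _pending_tasks() -> str | None:
    return "chat-completion"

def plan() -> str | None:
    # Evaluated left to right; stops at the first non-None value.
    return _create_runner() or _load_model() or _pending_tasks()

assert plan() == "load-model"  # _pending_tasks() is never called here
```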
|
||||
|
||||
|
||||
@@ -266,28 +262,14 @@ def _pending_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions
|
||||
# TODO(ciaran): do this better!
|
||||
if (
|
||||
not isinstance(task, ChatCompletion)
|
||||
and not isinstance(task, ImageGeneration)
|
||||
and not isinstance(task, ImageEdits)
|
||||
):
|
||||
if not isinstance(task, ChatCompletion):
|
||||
continue
|
||||
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
|
||||
continue
|
||||
|
||||
# For ImageEdits tasks, verify all input chunks have been received
|
||||
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
|
||||
cmd_id = task.command_id
|
||||
expected = task.task_params.total_input_chunks
|
||||
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
|
||||
if received < expected:
|
||||
continue # Wait for all chunks to arrive
|
||||
|
||||
for runner in runners.values():
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
@@ -6,7 +6,7 @@ from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
|
||||
logger: "loguru.Logger" = loguru.logger
|
||||
|
||||
@@ -31,6 +31,8 @@ def entrypoint(
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
except ClosedResourceError:
|
||||
logger.warning("Runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(
|
||||
f"Runner {bound_instance.bound_runner_id} crashed with critical exception {e}"
|
||||
@@ -42,8 +44,10 @@ def entrypoint(
|
||||
)
|
||||
)
|
||||
finally:
|
||||
event_sender.close()
|
||||
task_receiver.close()
|
||||
event_sender.join()
|
||||
task_receiver.join()
|
||||
logger.info("bye from the runner")
|
||||
try:
|
||||
event_sender.close()
|
||||
task_receiver.close()
|
||||
finally:
|
||||
event_sender.join()
|
||||
task_receiver.join()
|
||||
logger.info("bye from the runner")
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import base64
|
||||
import time
|
||||
|
||||
from exo.master.api import get_model_card
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
|
||||
import mlx.core as mx
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.chunks import ImageChunk, TokenChunk
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -12,12 +11,9 @@ from exo.shared.types.events import (
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.models import ModelTask
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
@@ -27,8 +23,6 @@ from exo.shared.types.tasks import (
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
ImageGenerationResponse,
|
||||
PartialImageResponse,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
@@ -44,14 +38,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
from exo.worker.engines.image import (
|
||||
ImageGenerator,
|
||||
generate_image,
|
||||
initialize_image_model,
|
||||
warmup_image_generator,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
@@ -71,368 +58,153 @@ def main(
|
||||
bound_instance.bound_runner_id,
|
||||
bound_instance.bound_shard,
|
||||
)
|
||||
try:
|
||||
logger.info("hello from the runner")
|
||||
if getattr(shard_metadata, "immediate_exception", False):
|
||||
raise Exception("Fake exception - runner failed to spin up.")
|
||||
if timeout := getattr(shard_metadata, "should_timeout", 0):
|
||||
time.sleep(timeout)
|
||||
logger.info("hello from the runner")
|
||||
if getattr(shard_metadata, "immediate_exception", False):
|
||||
raise Exception("Fake exception - runner failed to spin up.")
|
||||
if timeout := getattr(shard_metadata, "should_timeout", 0):
|
||||
time.sleep(timeout)
|
||||
|
||||
setup_start_time = time.time()
|
||||
setup_start_time = time.time()
|
||||
|
||||
model = None
|
||||
tokenizer = None
|
||||
sampler = None
|
||||
group = None
|
||||
model = None
|
||||
tokenizer = None
|
||||
group = None
|
||||
|
||||
model_card = get_model_card(shard_metadata.model_meta.model_id)
|
||||
assert model_card
|
||||
model_tasks = model_card.tasks
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Running
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
match task:
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
):
|
||||
logger.info("runner connecting")
|
||||
current_status = RunnerConnecting()
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
group = initialize_mlx(bound_instance)
|
||||
|
||||
logger.info("runner connected")
|
||||
current_status = RunnerConnected()
|
||||
|
||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||
case LoadModel() if (
|
||||
isinstance(current_status, RunnerConnected)
|
||||
and group is not None
|
||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("runner loading")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
# TODO(ciaran): switch
|
||||
if ModelTask.TextGeneration in model_tasks:
|
||||
model, tokenizer, sampler = load_mlx_items(
|
||||
bound_instance, group
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in model_tasks
|
||||
or ModelTask.ImageToImage in model_tasks
|
||||
):
|
||||
model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {model_card.tasks}"
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in model_tasks:
|
||||
assert model and isinstance(model, Model)
|
||||
assert tokenizer
|
||||
assert sampler
|
||||
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
sampler=sampler,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in model_tasks
|
||||
or ModelTask.ImageToImage in model_tasks
|
||||
):
|
||||
assert isinstance(model, ImageGenerator)
|
||||
image = warmup_image_generator(model=model)
|
||||
if image is not None:
|
||||
logger.info(
|
||||
f"warmed up by generating {image.size} image"
|
||||
)
|
||||
else:
|
||||
logger.info("warmup completed (non-primary node)")
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ChatCompletion(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert model and isinstance(model, Model)
|
||||
assert tokenizer
|
||||
assert sampler
|
||||
logger.info(f"received chat request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
assert task_params.messages[0].content is not None
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
for response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
sampler=sampler,
|
||||
task=task_params,
|
||||
):
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
),
|
||||
)
|
||||
)
|
||||
# case TokenizedResponse():
|
||||
# TODO: something here ig
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ImageGeneration(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert isinstance(model, ImageGenerator)
|
||||
logger.info(
|
||||
f"received image generation request: {str(task)[:500]}"
|
||||
)
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
# Generate images using the image generation backend
|
||||
# Track image_index for final images only
|
||||
image_index = 0
|
||||
for response in generate_image(
|
||||
model=model,
|
||||
task=task_params,
|
||||
):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
encoded_data = base64.b64encode(
|
||||
response.image_data
|
||||
).decode("utf-8")
|
||||
# Split into chunks to stay under gossipsub 1MB limit
|
||||
data_chunks = [
|
||||
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(
|
||||
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
|
||||
)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}: {len(encoded_data)} bytes in {total_chunks} chunks"
|
||||
)
|
||||
for chunk_index, chunk_data in enumerate(
|
||||
data_chunks
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=chunk_index,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
image_index=response.partial_index,
|
||||
is_partial=True,
|
||||
partial_index=response.partial_index,
|
||||
total_partials=response.total_partials,
|
||||
),
|
||||
)
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
encoded_data = base64.b64encode(
|
||||
response.image_data
|
||||
).decode("utf-8")
|
||||
# Split into chunks to stay under gossipsub 1MB limit
|
||||
data_chunks = [
|
||||
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(
|
||||
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
|
||||
)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
logger.info(
|
||||
f"sending final ImageChunk: {len(encoded_data)} bytes in {total_chunks} chunks"
|
||||
)
|
||||
for chunk_index, chunk_data in enumerate(
|
||||
data_chunks
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=chunk_index,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
image_index=image_index,
|
||||
is_partial=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
image_index += 1
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert isinstance(model, ImageGenerator)
|
||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
image_index = 0
|
||||
for response in generate_image(
|
||||
model=model,
|
||||
task=task_params,
|
||||
):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
):
|
||||
match response:
|
||||
case ImageGenerationResponse():
|
||||
encoded_data = base64.b64encode(
|
||||
response.image_data
|
||||
).decode("utf-8")
|
||||
# Split into chunks to stay under gossipsub 1MB limit
|
||||
data_chunks = [
|
||||
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(
|
||||
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
|
||||
)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
logger.info(
|
||||
f"sending ImageChunk: {len(encoded_data)} bytes in {total_chunks} chunks"
|
||||
)
|
||||
for chunk_index, chunk_data in enumerate(
|
||||
data_chunks
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=chunk_index,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
image_index=image_index,
|
||||
is_partial=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
image_index += 1
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
current_status = RunnerShutdown()
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
break
|
||||
except ClosedResourceError:
|
||||
logger.warning("runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(
|
||||
f"Runner {runner_id} crashed with critical exception {e}"
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id,
|
||||
runner_status=RunnerFailed(error_message=str(e)),
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
)
|
||||
finally:
|
||||
event_sender.close()
|
||||
task_receiver.close()
|
||||
event_sender.join()
|
||||
task_receiver.join()
|
||||
logger.info("bye from the runner")
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
match task:
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
):
|
||||
logger.info("runner connecting")
|
||||
current_status = RunnerConnecting()
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
group = initialize_mlx(bound_instance)
|
||||
|
||||
logger.info("runner connected")
|
||||
current_status = RunnerConnected()
|
||||
|
||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||
case LoadModel() if (
|
||||
isinstance(current_status, RunnerConnected) and group is not None
|
||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("runner loading")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
model, tokenizer = load_mlx_items(bound_instance, group)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
assert tokenizer
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ChatCompletion(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert model
|
||||
assert tokenizer
|
||||
logger.info(f"received chat request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
assert task_params.messages[0].content is not None
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
for response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
):
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
# case TokenizedResponse():
|
||||
# TODO: something here ig
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
current_status = RunnerShutdown()
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
break
|
||||
|
||||
|
||||
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
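The runner loop above is effectively a state machine: each task is only accepted when the current RunnerStatus guard matches, and the status advances before and after handling. A condensed, hedged sketch of that structural pattern (statuses and task names are simplified placeholders, not exo's types):

```python
from dataclasses import dataclass

@dataclass
class Idle: ...
@dataclass
class Loaded: ...
@dataclass
class Ready: ...

def step(status: object, task: str) -> object:
    # Guards mirror the `case X() if isinstance(current_status, Y)` shape above.
    match task:
        case "load_model" if isinstance(status, Idle):
            return Loaded()
        case "warmup" if isinstance(status, Loaded):
            return Ready()
        case "chat" if isinstance(status, Ready):
            return Ready()  # back to ready after generating
        case _:
            raise ValueError(f"{task!r} not allowed in {status!r}")

status: object = Idle()
for task in ("load_model", "warmup", "chat"):
    status = step(status, task)
assert isinstance(status, Ready)
```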
|
||||
|
||||
386 src/exo/worker/tests/unittests/test_mlx/test_tokenizers.py (new file)
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
Unit tests for tokenizer loading and functionality across all supported models.
|
||||
|
||||
This test downloads only tokenizer-related files (not full model weights) to verify
|
||||
that tokenizers can be loaded and used correctly for encoding/decoding.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
|
||||
from exo.worker.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
get_eos_token_ids_for_model,
|
||||
load_tokenizer_for_model_id,
|
||||
)
|
||||
|
||||
# Files needed for tokenizer functionality
|
||||
TOKENIZER_FILE_PATTERNS = [
|
||||
"tokenizer.json",
|
||||
"tokenizer_config.json",
|
||||
"special_tokens_map.json",
|
||||
"vocab.json",
|
||||
"vocab.txt",
|
||||
"merges.txt",
|
||||
"tiktoken.model",
|
||||
"added_tokens.json",
|
||||
"tokenizer.model",
|
||||
"tokenization_*.py", # Custom tokenizer implementations
|
||||
]
|
||||
|
||||
|
||||
def is_tokenizer_file(filename: str) -> bool:
|
||||
"""Check if a file is needed for tokenizer functionality."""
|
||||
for pattern in TOKENIZER_FILE_PATTERNS:
|
||||
if "*" in pattern:
|
||||
prefix = pattern.split("*")[0]
|
||||
suffix = pattern.split("*")[1]
|
||||
if filename.startswith(prefix) and filename.endswith(suffix):
|
||||
return True
|
||||
elif filename == pattern:
|
||||
return True
|
||||
return False
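For clarity, a few example inputs showing how the exact-match and wildcard branches above behave (illustrative only, based on the patterns listed in TOKENIZER_FILE_PATTERNS):

```python
# Exact names match verbatim; "tokenization_*.py" matches by prefix and suffix.
assert is_tokenizer_file("tokenizer.json")
assert is_tokenizer_file("tokenization_kimi.py")
assert not is_tokenizer_file("model-00001-of-00004.safetensors")
```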
|
||||
|
||||
|
||||
async def download_tokenizer_files(model_id: str) -> Path:
|
||||
"""Download only the tokenizer-related files for a model."""
|
||||
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
|
||||
|
||||
tokenizer_files = [f for f in file_list if is_tokenizer_file(f.path)]
|
||||
|
||||
if not tokenizer_files:
|
||||
pytest.skip(f"No tokenizer files found for {model_id}")
|
||||
|
||||
for file_entry in tokenizer_files:
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
await download_file_with_retry(
|
||||
model_id, "main", file_entry.path, target_dir
|
||||
)
|
||||
|
||||
return target_dir
|
||||
|
||||
|
||||
# Get a sample of models to test (one per family to keep tests fast)
|
||||
def get_test_models() -> list[tuple[str, ModelCard]]:
|
||||
"""Get a representative sample of models to test."""
|
||||
# Pick one model from each family to test
|
||||
families: dict[str, tuple[str, ModelCard]] = {}
|
||||
for short_id, card in MODEL_CARDS.items():
|
||||
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
|
||||
parts = short_id.split("-")
|
||||
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
|
||||
|
||||
if family not in families:
|
||||
families[family] = (short_id, card)
|
||||
|
||||
return list(families.values())
|
||||
|
||||
|
||||
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
"""Create event loop for async tests."""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"short_id,model_card",
|
||||
TEST_MODELS,
|
||||
ids=[m[0] for m in TEST_MODELS],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
|
||||
"""Test that tokenizer can encode and decode text correctly."""
|
||||
model_id = str(model_card.model_id)
|
||||
|
||||
# Download tokenizer files
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
# Verify required files exist
|
||||
has_tokenizer = (
|
||||
(model_path / "tokenizer.json").exists()
|
||||
or (model_path / "tokenizer_config.json").exists()
|
||||
or (model_path / "tiktoken.model").exists()
|
||||
or (model_path / "tokenizer.model").exists()
|
||||
)
|
||||
if not has_tokenizer:
|
||||
pytest.skip(f"Required tokenizer files not found for {model_id}")
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
||||
|
||||
# Test basic encoding
|
||||
test_text = "Hello, world!"
|
||||
encoded = tokenizer.encode(test_text)
|
||||
assert isinstance(encoded, list), f"encode() should return a list for {model_id}"
|
||||
assert len(encoded) > 0, f"encode() should return non-empty list for {model_id}"
|
||||
assert all(isinstance(t, int) for t in encoded), (
|
||||
f"All tokens should be integers for {model_id}"
|
||||
)
|
||||
|
||||
# Test decoding
|
||||
decoded = tokenizer.decode(encoded)
|
||||
assert isinstance(decoded, str), f"decode() should return a string for {model_id}"
|
||||
assert test_text in decoded or decoded.strip() == test_text.strip(), (
|
||||
f"decode(encode(x)) should preserve text for {model_id}: got {decoded!r}"
|
||||
)
|
||||
|
||||
# Test with longer text
|
||||
long_text = "The quick brown fox jumps over the lazy dog. " * 10
|
||||
long_encoded = tokenizer.encode(long_text)
|
||||
assert len(long_encoded) > len(encoded), (
|
||||
f"Longer text should produce more tokens for {model_id}"
|
||||
)
|
||||
|
||||
# Test empty string
|
||||
empty_encoded = tokenizer.encode("")
|
||||
assert isinstance(empty_encoded, list), (
|
||||
f"encode('') should return a list for {model_id}"
|
||||
)
|
||||
|
||||
# Test special characters
|
||||
special_text = 'Hello!\n\tWorld? <test> & "quotes"'
|
||||
special_encoded = tokenizer.encode(special_text)
|
||||
assert len(special_encoded) > 0, f"Special chars should encode for {model_id}"
|
||||
|
||||
# Test unicode
|
||||
unicode_text = "Hello 世界 🌍"
|
||||
unicode_encoded = tokenizer.encode(unicode_text)
|
||||
assert len(unicode_encoded) > 0, f"Unicode should encode for {model_id}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"short_id,model_card",
|
||||
TEST_MODELS,
|
||||
ids=[m[0] for m in TEST_MODELS],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_has_required_attributes(
|
||||
short_id: str, model_card: ModelCard
|
||||
) -> None:
|
||||
"""Test that tokenizer has required attributes for inference."""
|
||||
model_id = str(model_card.model_id)
|
||||
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
has_tokenizer = (
|
||||
(model_path / "tokenizer.json").exists()
|
||||
or (model_path / "tokenizer_config.json").exists()
|
||||
or (model_path / "tiktoken.model").exists()
|
||||
or (model_path / "tokenizer.model").exists()
|
||||
)
|
||||
if not has_tokenizer:
|
||||
pytest.skip(f"Required tokenizer files not found for {model_id}")
|
||||
|
||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
||||
eos_token_ids = get_eos_token_ids_for_model(model_id)
|
||||
|
||||
# Check for vocabulary size
|
||||
empty_vocab: dict[str, int] = {}
|
||||
vocab_size: int = getattr(tokenizer, "vocab_size", None) or len(
|
||||
getattr(tokenizer, "get_vocab", lambda: empty_vocab)()
|
||||
)
|
||||
assert vocab_size > 0, f"Tokenizer should have vocab_size > 0 for {model_id}"
|
||||
|
||||
# Check for EOS token (either from tokenizer or explicitly provided)
|
||||
has_eos = (
|
||||
eos_token_ids is not None
|
||||
or getattr(tokenizer, "eos_token_id", None) is not None
|
||||
or getattr(tokenizer, "eos_token", None) is not None
|
||||
)
|
||||
assert has_eos, f"Tokenizer should have EOS token for {model_id}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"short_id,model_card",
|
||||
TEST_MODELS,
|
||||
ids=[m[0] for m in TEST_MODELS],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
|
||||
"""Test that tokenizer can encode text containing special tokens.
|
||||
|
||||
This is critical because the actual inference path uses prompts with
|
||||
special tokens from chat templates. If special tokens aren't handled
|
||||
correctly, encoding will fail.
|
||||
"""
|
||||
model_id = str(model_card.model_id)
|
||||
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
has_tokenizer = (
|
||||
(model_path / "tokenizer.json").exists()
|
||||
or (model_path / "tokenizer_config.json").exists()
|
||||
or (model_path / "tiktoken.model").exists()
|
||||
or (model_path / "tokenizer.model").exists()
|
||||
)
|
||||
assert has_tokenizer, f"Required tokenizer files not found for {model_id}"
|
||||
|
||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
||||
|
||||
# Get special tokens from the tokenizer
|
||||
special_tokens: list[str] = []
|
||||
|
||||
# Try to get special tokens from various sources
|
||||
if hasattr(tokenizer, "all_special_tokens"):
|
||||
special_tokens.extend(tokenizer.all_special_tokens)
|
||||
elif hasattr(tokenizer, "_tokenizer") and hasattr(
|
||||
tokenizer._tokenizer,
|
||||
"all_special_tokens",
|
||||
):
|
||||
special_tokens.extend(tokenizer._tokenizer.all_special_tokens)
|
||||
|
||||
# Also check for common special token attributes
|
||||
for attr in [
|
||||
"bos_token",
|
||||
"eos_token",
|
||||
"pad_token",
|
||||
"unk_token",
|
||||
"sep_token",
|
||||
"cls_token",
|
||||
]:
|
||||
token = getattr(tokenizer, attr, None)
|
||||
if token is None and hasattr(tokenizer, "_tokenizer"):
|
||||
token = getattr(tokenizer._tokenizer, attr, None)
|
||||
if token and isinstance(token, str) and token not in special_tokens:
|
||||
special_tokens.append(token)
|
||||
|
||||
# If we found special tokens, test encoding text that contains them
|
||||
if special_tokens:
|
||||
# Create text with special tokens interspersed
|
||||
test_with_special = f"{special_tokens[0]}Hello world"
|
||||
if len(special_tokens) > 1:
|
||||
test_with_special += f"{special_tokens[1]}"
|
||||
|
||||
encoded = tokenizer.encode(test_with_special)
|
||||
assert isinstance(encoded, list), (
|
||||
f"encode() with special tokens should return list for {model_id}"
|
||||
)
|
||||
assert len(encoded) > 0, (
|
||||
f"encode() with special tokens should return non-empty list for {model_id}"
|
||||
)
|
||||
assert all(isinstance(t, int) for t in encoded), (
|
||||
f"All tokens should be integers for {model_id}"
|
||||
)
|
||||
|
||||
# Verify we can decode
|
||||
decoded = tokenizer.decode(encoded)
|
||||
assert isinstance(decoded, str), f"decode() should return string for {model_id}"
|
||||
|
||||
# Test with angle-bracket tokens (common format for special tokens)
|
||||
# These should not raise errors even if they're not actual special tokens
|
||||
angle_bracket_text = "<|test|>Hello<|end|>"
|
||||
encoded = tokenizer.encode(angle_bracket_text)
|
||||
assert isinstance(encoded, list), (
|
||||
f"encode() with angle brackets should return list for {model_id}"
|
||||
)
|
||||
assert len(encoded) > 0, (
|
||||
f"encode() with angle brackets should be non-empty for {model_id}"
|
||||
)
|
||||
|
||||
|
||||
# Specifically test Kimi tokenizer since it has special handling
|
||||
@pytest.mark.asyncio
|
||||
async def test_kimi_tokenizer_specifically():
|
||||
"""Test Kimi tokenizer with its specific patches and quirks."""
|
||||
kimi_models = [
|
||||
(short_id, card)
|
||||
for short_id, card in MODEL_CARDS.items()
|
||||
if "kimi" in short_id.lower()
|
||||
]
|
||||
|
||||
if not kimi_models:
|
||||
pytest.skip("No Kimi models found in MODEL_CARDS")
|
||||
|
||||
_, model_card = kimi_models[0]
|
||||
model_id = str(model_card.model_id)
|
||||
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
# Ensure the custom tokenizer file exists
|
||||
if not (model_path / "tokenization_kimi.py").exists():
|
||||
pytest.skip("tokenization_kimi.py not found")
|
||||
|
||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
||||
eos_token_ids = get_eos_token_ids_for_model(model_id)
|
||||
|
||||
# Test encode/decode cycle
|
||||
test_text = "Hello, world!"
|
||||
encoded = tokenizer.encode(test_text)
|
||||
decoded = tokenizer.decode(encoded)
|
||||
|
||||
assert len(encoded) > 0, "Kimi tokenizer should encode text"
|
||||
assert isinstance(decoded, str), "Kimi tokenizer should decode to string"
|
||||
|
||||
# Test that the patched encode works (returns list of ints)
|
||||
assert all(isinstance(t, int) for t in encoded), "Tokens should be integers"
|
||||
|
||||
# Test encoding text with special tokens (like from chat templates)
|
||||
# This is critical - the warmup inference uses prompts with special tokens
|
||||
special_token_text = "<|im_user|>user<|im_middle|>Hello<|im_end|><|im_assistant|>"
|
||||
special_encoded = tokenizer.encode(special_token_text)
|
||||
assert len(special_encoded) > 0, "Kimi tokenizer should handle special tokens"
|
||||
assert all(isinstance(t, int) for t in special_encoded), (
|
||||
"Special token encoding should return integers"
|
||||
)
|
||||
|
||||
# Verify EOS token is set
|
||||
assert eos_token_ids == [163586], "Kimi EOS token should be [163586]"
|
||||
|
||||
|
||||
# Test GLM tokenizer since it also has special handling
|
||||
@pytest.mark.asyncio
|
||||
async def test_glm_tokenizer_specifically():
|
||||
"""Test GLM tokenizer with its specific EOS tokens."""
|
||||
glm_models = [
|
||||
(short_id, card)
|
||||
for short_id, card in MODEL_CARDS.items()
|
||||
if "glm" in short_id.lower()
|
||||
]
|
||||
|
||||
if not glm_models:
|
||||
pytest.skip("No GLM models found in MODEL_CARDS")
|
||||
|
||||
_, model_card = glm_models[0]
|
||||
model_id = str(model_card.model_id)
|
||||
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
has_tokenizer = (model_path / "tokenizer.json").exists() or (
|
||||
model_path / "tokenizer_config.json"
|
||||
).exists()
|
||||
if not has_tokenizer:
|
||||
pytest.skip("GLM tokenizer files not found")
|
||||
|
||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
||||
eos_token_ids = get_eos_token_ids_for_model(model_id)
|
||||
|
||||
# Test encode/decode
|
||||
test_text = "Hello, world!"
|
||||
encoded = tokenizer.encode(test_text)
|
||||
decoded = tokenizer.decode(encoded)
|
||||
|
||||
assert len(encoded) > 0, "GLM tokenizer should encode text"
|
||||
assert isinstance(decoded, str), "GLM tokenizer should decode to string"
|
||||
|
||||
# Verify EOS tokens
|
||||
assert eos_token_ids == [
|
||||
151336,
|
||||
151329,
|
||||
151338,
|
||||
], "GLM EOS tokens should be correct"
|
||||
@@ -1,5 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
@@ -94,13 +95,23 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
|
||||
|
||||
# Local node has already marked its shard as downloaded (not actually used by _load_model)
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
|
||||
# Global view has completed downloads for both nodes
|
||||
global_download_status = {
|
||||
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
|
||||
NODE_B: [DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)],
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
@@ -140,7 +151,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
|
||||
|
||||
# Local status claims the shard is downloaded already
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
|
||||
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
|
||||
@@ -192,10 +205,16 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
|
||||
# Only NODE_A's shard is recorded as downloaded globally
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
global_download_status = {
|
||||
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [], # NODE_B has no downloads completed yet
|
||||
}
|
||||
|
||||
@@ -212,9 +231,15 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
assert result is None
|
||||
|
||||
global_download_status = {
|
||||
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [
|
||||
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
|
||||
)
|
||||
], # NODE_B has no downloads completed yet
|
||||
}
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
    # initialize_mlx returns a "group" equal to 1
    monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1, 1)))
    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
@@ -32,6 +32,8 @@ async def check_reachability(
        return NodeId(body) or None
    except OSError:
        return None
    except http.client.HTTPException:
        return None
    finally:
        connection.close()
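The added except clause above means a malformed HTTP response is treated the same as a socket error: the node is simply reported unreachable. A hedged sketch of that probe shape using only the standard library (the endpoint path and port are made up for illustration):

```python
import http.client

def check_reachability(host: str, port: int = 52415, timeout: float = 2.0) -> str | None:
    connection = http.client.HTTPConnection(host, port, timeout=timeout)
    try:
        connection.request("GET", "/node_id")  # hypothetical endpoint
        body = connection.getresponse().read().decode()
        return body or None
    except OSError:
        return None  # connection refused, timeout, DNS failure, ...
    except http.client.HTTPException:
        return None  # malformed or truncated HTTP response
    finally:
        connection.close()
```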
Some files were not shown because too many files have changed in this diff.