Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-24 22:12:39 -05:00)

Compare commits: ciaran/ima... → linux-cpu-... (30 commits)
| SHA1 |
|---|
| 7d2e828aba |
| b5319d6b03 |
| b988e08d69 |
| 9bf5979f8a |
| 91944383d3 |
| dcc6872724 |
| dccc2709c5 |
| 20d1246600 |
| 81bad9e01a |
| 7ff67d0a28 |
| c888b13d3f |
| 1f80705b56 |
| b349330404 |
| 812ce47194 |
| 643c6b8d28 |
| 4754f56bd4 |
| 66d01369b4 |
| d20d9e5fc8 |
| e67282282c |
| 54daa9e2db |
| 06125d1503 |
| 505e756872 |
| 4cd3db0f6e |
| 8b137a1e64 |
| 4176c7ec25 |
| dbce607911 |
| 9949b93517 |
| f4feeff077 |
| f529884344 |
| df4c6ce24e |
**.github/benchmark-dashboard/README.md** (new file, 159 lines, vendored)

@@ -0,0 +1,159 @@
# EXO Benchmark Dashboard

A fully self-contained, browser-based dashboard for tracking EXO benchmark performance over time.

## Features

- 📊 **Success Rate Tracking**: Monitor cluster reliability across commits
- ⚡ **Response Time Analysis**: Track average request completion times
- 🎯 **Throughput Metrics**: Tokens per second visualization
- 📈 **Request Distribution**: Success/failure breakdown over time
- 🔄 **Auto-Refresh**: Updates every 60 seconds
- 📺 **TV-Ready**: Large, clear visualizations perfect for display
- 🔐 **Secure**: Credentials stored in browser localStorage only
- 🌐 **No Backend**: Directly accesses S3 from the browser

## Quick Start

### Option 1: Direct File Access (Simplest)

Just open the HTML file directly in your browser:

```bash
open .github/benchmark-dashboard/index.html
```

Then click "Configure AWS Credentials" and enter your keys.

### Option 2: URL Parameters (For Quick Setup)

```bash
# Serve with credentials in URL (they'll be moved to localStorage)
open ".github/benchmark-dashboard/index.html?accessKey=YOUR_KEY&secretKey=YOUR_SECRET&region=us-east-1"
```

The credentials will be saved to localStorage and removed from the URL immediately.

### Option 3: Simple HTTP Server

```bash
# From repo root
python3 -m http.server 8080

# Then open: http://localhost:8080/.github/benchmark-dashboard/
```

## AWS Credentials

The dashboard needs read-only access to the `exo-benchmark-results` S3 bucket.

### Required IAM Permissions

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "s3:GetObject",
        "s3:ListBucket"
      ],
      "Resource": [
        "arn:aws:s3:::exo-benchmark-results",
        "arn:aws:s3:::exo-benchmark-results/*"
      ]
    }
  ]
}
```

### Security Notes

- ✅ Credentials stored in browser `localStorage` only
- ✅ Never sent to any server (except AWS)
- ✅ All S3 access happens client-side
- ✅ Use read-only IAM credentials
- ⚠️ Don't commit credentials to git
- ⚠️ Use a dedicated read-only IAM user

## TV/Kiosk Mode

For permanent display on a TV:

### macOS

```bash
open -a "Google Chrome" --args --kiosk ".github/benchmark-dashboard/index.html"
```

### Linux

```bash
chromium-browser --kiosk --app="file://$(pwd)/.github/benchmark-dashboard/index.html"
```

### Auto-start on Boot

Create a simple startup script:

```bash
#!/bin/bash
# /usr/local/bin/start-benchmark-dashboard.sh

cd /path/to/exo
python3 -m http.server 8080 &
sleep 2
chromium-browser --kiosk http://localhost:8080/.github/benchmark-dashboard/
```

## Data Displayed

### Summary Cards

- **Latest Success Rate**: Most recent benchmark success percentage with trend
- **Avg Response Time**: Latest average response time in ms with trend
- **Total Benchmarks**: Count of all benchmarks run
- **Active Configurations**: Number of unique benchmark configs

### Charts

1. **Success Rate Over Time**: Line chart showing reliability trends
2. **Average Response Time**: Performance over time (lower is better)
3. **Throughput**: Tokens/second metric (higher is better)
4. **Request Distribution**: Stacked bar chart of successes/failures

## How It Works

1. **Loads AWS SDK**: Uses the AWS SDK for JavaScript (browser version)
2. **Lists S3 Objects**: Fetches all files from `s3://exo-benchmark-results/bench/`
3. **Downloads Results**: Fetches each JSON result file
4. **Parses & Visualizes**: Uses Chart.js to create interactive charts
5. **Auto-Refreshes**: Polls S3 every 60 seconds for new results
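The same fetch-and-aggregate flow is easy to reproduce outside the browser. Below is a minimal Python sketch of steps 2 and 3 using boto3; the `success_rate` field read at the end is an assumption for illustration, not a documented result-file schema:

```python
import json

import boto3  # read-only credentials are sufficient

s3 = boto3.client("s3", region_name="us-east-1")
BUCKET = "exo-benchmark-results"

# Step 2: list every result file under the bench/ prefix
paginator = s3.get_paginator("list_objects_v2")
keys = [
    obj["Key"]
    for page in paginator.paginate(Bucket=BUCKET, Prefix="bench/")
    for obj in page.get("Contents", [])
]

# Step 3: download each JSON result file
for key in sorted(keys):
    body = s3.get_object(Bucket=BUCKET, Key=key)["Body"].read()
    result = json.loads(body)
    # Field name below is assumed for illustration; inspect real files first.
    print(key, result.get("success_rate"))
```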
## Customization

To modify the dashboard:

1. Edit `index.html`
2. Adjust `REFRESH_INTERVAL` for a different polling frequency
3. Modify chart colors/styles in the Chart.js configuration
4. Add new metrics by extending the results parsing

## Troubleshooting

**"AWS credentials not configured"**
- Click "Configure AWS Credentials" and enter your keys

**"Error loading benchmark data"**
- Check AWS credentials are correct
- Verify S3 bucket name is `exo-benchmark-results`
- Ensure IAM user has read permissions
- Check browser console for detailed errors

**"No benchmark results found"**
- Wait for benchmark workflows to run
- Verify results are being uploaded to S3
- Check S3 bucket has files in `bench/` prefix

**Charts not updating**
- Check browser console for errors
- Verify network connectivity to S3
- Try refreshing the page manually
**.github/benchmark-dashboard/index.html** (new file, 1641 lines, vendored)

File diff suppressed because it is too large.

**.github/configs/README.md** (new file, 186 lines, vendored)

@@ -0,0 +1,186 @@
# EXO Benchmark Configurations

This directory contains configuration files for the EXO staged benchmark system.

## Overview

The staged benchmark system allows you to run complex, multi-stage load tests against EXO clusters. Each stage can have different characteristics:

- **Prompt Length**: Number of tokens in the input prompt
- **Generation Length**: Maximum tokens to generate in the response
- **Time Between Requests**: Delay (in seconds) between firing consecutive requests
- **Iterations**: Number of requests to send in this stage

Requests are **fire-and-forget** - they don't wait for the previous request to complete. This allows you to test overlapping request handling and measure success rates under load.
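As a rough illustration of that firing pattern, here is a minimal asyncio sketch (not the actual implementation in `.github/scripts/bench.py`; `send_request` is a hypothetical stand-in for one chat-completion request):

```python
import asyncio

async def send_request(i: int) -> bool:
    """Hypothetical placeholder for one chat-completion request."""
    await asyncio.sleep(3.0)  # pretend the request takes ~3s to complete
    return True

async def run_stage(iterations: int, time_between_requests: float) -> None:
    tasks = []
    for i in range(iterations):
        # Fire the request without awaiting it, then sleep; if
        # time_between_requests is shorter than the completion time,
        # requests overlap.
        tasks.append(asyncio.create_task(send_request(i)))
        await asyncio.sleep(time_between_requests)
    results = await asyncio.gather(*tasks, return_exceptions=True)
    ok = sum(1 for r in results if r is True)
    print(f"success rate: {ok}/{iterations}")

asyncio.run(run_stage(iterations=10, time_between_requests=1.0))
```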
## Configuration Files

### `bench_simple.yaml`

A minimal configuration that replicates the behavior of the original `bench.py` script:
- Single stage with 1 iteration
- Short prompt (~20 tokens)
- Generates up to 100 tokens

This is useful for quick smoke tests.

### `bench_config.yaml`

A comprehensive multi-stage benchmark with:
1. **Warmup** (10 requests): Light load with short prompts
2. **Medium Load** (20 requests): Moderate load with medium prompts
3. **Stress Test** (30 requests): Heavy overlapping requests with long prompts
4. **Cooldown** (5 requests): Light load to wind down

This tests the cluster's behavior under varying load patterns.

## Configuration Schema

```yaml
# Hardware configuration - maps runner labels to instance counts
hardware_plan:
  M3ULTRA_GPU80_512GB: 4

# Environment variables to set on each node (optional)
environment:
  OVERRIDE_MEMORY_MB: 512

# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600

# Model instances to run concurrently
model_ids:
  - "mlx-community/Llama-3.2-1B-Instruct-4bit"

# Benchmark stages
stages:
  - name: "stage_name"           # Human-readable name for this stage
    prompt_length: 100           # Target prompt length in tokens
    generation_length: 200       # Max tokens to generate
    time_between_requests: 2.0   # Seconds between firing requests
    iterations: 10               # Number of requests in this stage
```
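Since the bench script runs with `pyyaml` and `pydantic` available, a config can be validated against this schema along the following lines (a sketch, not the repo's actual classes; the `timeout_seconds` default mirrors `build_matrix.py`'s fallback):

```python
import yaml
from pydantic import BaseModel

class Stage(BaseModel):
    name: str
    prompt_length: int
    generation_length: int
    time_between_requests: float
    iterations: int

class BenchConfig(BaseModel):
    hardware_plan: dict[str, int]
    environment: dict[str, str | int] = {}  # YAML may hold ints, e.g. OVERRIDE_MEMORY_MB
    timeout_seconds: int = 600
    model_ids: list[str]
    stages: list[Stage]

with open(".github/configs/bench_config.yaml") as f:
    config = BenchConfig.model_validate(yaml.safe_load(f))

print(f"{len(config.stages)} stages, "
      f"{sum(s.iterations for s in config.stages)} requests total")
```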
## Running Benchmarks

### Via GitHub Actions

**Automatic (every commit):**
- The **`bench`** workflow runs automatically on every push
- Uses `bench_simple.yaml` as the default configuration
- All settings (hardware plan, timeout, environment variables, models, stages) are defined in the config file

**Manual (on-demand):**
1. Go to **Actions** → **bench** workflow
2. Click **Run workflow**
3. Configure:
   - **Config File**: Path to your YAML config (default: `.github/configs/bench_simple.yaml`)
     - `.github/configs/bench_simple.yaml` for quick tests
     - `.github/configs/bench_config.yaml` for complex multi-stage tests

All other settings (hardware plan, timeout, environment variables, models, stages) are read from the specified config file.

### Via Command Line

```bash
# Start EXO on localhost:8000
uv run exo --api-port 8000

# Run simple benchmark (1 stage, 1 iteration)
python3 .github/scripts/bench.py \
  --api-port 8000 \
  --config .github/configs/bench_simple.yaml \
  --expected-nodes 1 \
  --is-primary true \
  --timeout-seconds 600

# Run complex staged benchmark (4 stages, multiple iterations)
python3 .github/scripts/bench.py \
  --api-port 8000 \
  --config .github/configs/bench_config.yaml \
  --expected-nodes 1 \
  --is-primary true \
  --timeout-seconds 600
```

## Output Metrics

For each stage, the benchmark reports:

- **Total Requests**: Number of requests fired
- **Successful Requests**: Requests that completed successfully
- **Failed Requests**: Requests that encountered errors
- **Success Rate**: Percentage of successful requests
- **Total Tokens**: Sum of all tokens generated across successful requests
- **Avg Tokens/Request**: Average tokens per successful request
- **Avg Time/Request**: Average completion time per successful request

A JSON summary is also printed for easy parsing and storage.
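As a sketch of how these numbers relate, given per-request records of (success, tokens, seconds); the record shape here is illustrative, not bench.py's actual data structure:

```python
# Illustrative per-request records: (succeeded, tokens_generated, seconds_taken)
records = [(True, 98, 2.1), (True, 100, 2.4), (False, 0, 0.0), (True, 95, 2.2)]

successes = [r for r in records if r[0]]
total_requests = len(records)
successful_requests = len(successes)
failed_requests = total_requests - successful_requests
success_rate = 100.0 * successful_requests / total_requests
total_tokens = sum(tokens for _, tokens, _ in successes)
avg_tokens_per_request = total_tokens / successful_requests
avg_time_per_request = sum(secs for *_, secs in successes) / successful_requests

print(f"{success_rate:.1f}% success, {avg_tokens_per_request:.1f} tok/req, "
      f"{avg_time_per_request:.2f} s/req")
```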
## Creating Custom Benchmarks

To create a custom benchmark:

1. Copy an existing config file (e.g., `bench_config.yaml`)
2. Modify the stages to match your test scenario
3. Save it in this directory with a descriptive name
4. Run it using the workflow or command line

### Example: Sustained Load Test

```yaml
hardware_plan:
  M3ULTRA_GPU80_512GB: 2

environment:
  OVERRIDE_MEMORY_MB: 1024

timeout_seconds: 600

model_ids:
  - "mlx-community/Llama-3.2-1B-Instruct-4bit"

stages:
  - name: "sustained_load"
    prompt_length: 200
    generation_length: 150
    time_between_requests: 0.5 # Very fast - 2 requests/second
    iterations: 100            # Run for ~50 seconds
```

### Example: Varying Prompt Sizes

```yaml
hardware_plan:
  M4PRO_GPU16_24GB: 3

timeout_seconds: 900

model_ids:
  - "mlx-community/Llama-3.2-1B-Instruct-4bit"

stages:
  - name: "tiny_prompts"
    prompt_length: 10
    generation_length: 100
    time_between_requests: 1.0
    iterations: 10

  - name: "medium_prompts"
    prompt_length: 200
    generation_length: 100
    time_between_requests: 1.0
    iterations: 10

  - name: "large_prompts"
    prompt_length: 1000
    generation_length: 100
    time_between_requests: 1.0
    iterations: 10
```

## Tips

- **Overlapping Requests**: Set `time_between_requests` < expected completion time to test concurrent request handling
- **Sequential Requests**: Set `time_between_requests` > expected completion time to ensure requests don't overlap
- **Realistic Load**: Model real usage patterns by varying prompt/generation lengths across stages
- **Success Rate**: A 100% success rate indicates the cluster handled the load well; lower rates suggest capacity limits
**.github/configs/bench_config.yaml** (new file, 49 lines, vendored)

@@ -0,0 +1,49 @@
```yaml
# EXO Staged Benchmark Configuration
# This configuration defines a multi-stage load test for EXO clusters

# Hardware configuration - maps runner labels to instance counts
hardware_plan:
  M3ULTRA_GPU80_512GB: 4

# Environment variables to set on each node (optional)
environment:
  OVERRIDE_MEMORY_MB: 512

# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600

# Multiple instances run concurrently on the cluster
model_ids:
  - "mlx-community/Qwen3-0.6B-4bit"
  - "mlx-community/Qwen3-0.6B-4bit"

# Stages run sequentially, each with its own characteristics
stages:
  # Stage 1: Light load with short prompts
  - name: "warmup"
    prompt_length: 50          # Number of tokens in prompt
    generation_length: 100     # Max tokens to generate
    time_between_requests: 5.0 # Seconds between firing requests
    iterations: 10             # Number of requests to send in this stage

  # Stage 2: Medium load with medium prompts
  - name: "medium_load"
    prompt_length: 200
    generation_length: 150
    time_between_requests: 3.0
    iterations: 20

  # Stage 3: Heavy load with long prompts - requests will overlap
  - name: "stress_test"
    prompt_length: 500
    generation_length: 200
    time_between_requests: 1.0 # Fast firing - will definitely overlap
    iterations: 30

  # Stage 4: Cool down with simple prompts
  - name: "cooldown"
    prompt_length: 50
    generation_length: 50
    time_between_requests: 10.0
    iterations: 5
```
**.github/configs/bench_simple.yaml** (new file, 125 lines, vendored)

@@ -0,0 +1,125 @@
```yaml
# Simple single-shot benchmark
# Tests 2 instances concurrently on 2 nodes

# Hardware configuration - maps runner labels to instance counts
hardware_plan:
  puffin4: 1
  puffin8: 1

# Environment variables to set on each node
environment:
  PLACEHOLDER: "placeholder"
  # OVERRIDE_MEMORY_MB: 50000
  MLX_METAL_FAST_SYNCH: 1

# Timeout for instance and runner readiness (seconds)
timeout_seconds: 1800

# Model instances to run concurrently
model_ids:
  # - "mlx-community/DeepSeek-V3.1-8bit"
  # - "mlx-community/Kimi-K2-Instruct-4bit"
  - "mlx-community/Kimi-K2-Thinking"
  # - "mlx-community/Qwen3-235B-A22B-4bit"
  # - "mlx-community/Llama-3.3-70B-Instruct-4bit"
  # - "mlx-community/Llama-3.3-70B-Instruct-8bit"
  # - "mlx-community/Llama-3.2-1B-Instruct-4bit"

# Sharding strategy: "Pipeline" or "Tensor"
sharding: "Tensor"

# Instance type: "MlxRing" or "MlxIbv"
instance_meta: "MlxIbv"

# If true, run requests sequentially (no overlap); if false, fire-and-forget (default: false)
no_overlap: true

# Benchmark stages
# pp: 64, 256, 1024, 2048, 4096, 8192, 16384
# g: 64, 512
stages:
  # - name: "simple"
  #   prompt_length: 512
  #   generation_length: 10
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp64_g64"
  #   prompt_length: 64
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp64_g64"
  #   prompt_length: 64
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp64_g512"
  #   prompt_length: 64
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp256_g64"
  #   prompt_length: 256
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  - name: "pp256_g64"
    prompt_length: 256
    generation_length: 64
    time_between_requests: 2.0
    iterations: 5
  # - name: "pp256_g512"
  #   prompt_length: 256
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp1024_g64"
  #   prompt_length: 1024
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp1024_g512"
  #   prompt_length: 1024
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp2048_g64"
  #   prompt_length: 2048
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp2048_g512"
  #   prompt_length: 2048
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp4096_g64"
  #   prompt_length: 4096
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 4
  # - name: "pp4096_g512"
  #   prompt_length: 4096
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp8192_g64"
  #   prompt_length: 8192
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp8192_g512"
  #   prompt_length: 8192
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp16384_g64"
  #   prompt_length: 16384
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp16384_g512"
  #   prompt_length: 16384
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
```
**.github/scripts/bench.py** (new file, 1399 lines, vendored)

File diff suppressed because it is too large.
**.github/scripts/build_matrix.py** (new file, 70 lines, vendored)

@@ -0,0 +1,70 @@
```python
#!/usr/bin/env python3
import json
import os
from typing import NotRequired, TypedDict, cast

import yaml


class MatrixEntry(TypedDict):
    label: str
    index: int


class MatrixInclude(TypedDict):
    label: str
    index: int
    is_primary: bool
    expected_nodes: int


class Config(TypedDict):
    hardware_plan: dict[str, int]
    timeout_seconds: NotRequired[int]
    environment: NotRequired[dict[str, str]]


# Read the config file
config_file: str = os.environ["CONFIG_FILE"]
with open(config_file, "r") as f:
    config: Config = cast(Config, yaml.safe_load(f))

# Extract hardware plan from config
plan: dict[str, int] = config["hardware_plan"]
if not plan:
    raise ValueError(f"No hardware_plan found in {config_file}")

# Build matrix entries
entries: list[MatrixEntry] = []
for label, count in plan.items():
    for idx in range(count):
        entries.append({"label": label, "index": idx})

total_nodes: int = len(entries)
matrix: dict[str, list[MatrixInclude]] = {
    "include": [
        {
            "label": e["label"],
            "index": e["index"],
            "is_primary": (i == 0),
            "expected_nodes": total_nodes,
        }
        for i, e in enumerate(entries)
    ]
}

# Extract other config values
timeout_seconds: int = config.get("timeout_seconds", 600)
environment: dict[str, str] = config.get("environment", {})

# Output to GitHub Actions
with open(os.environ["GITHUB_OUTPUT"], "a") as f:
    f.write(f"matrix={json.dumps(matrix)}\n")
    f.write(f"config_file={config_file}\n")
    f.write(f"timeout_seconds={timeout_seconds}\n")
    f.write(f"environment={json.dumps(environment)}\n")

print(f"Matrix: {json.dumps(matrix)}")
print(f"Config file: {config_file}")
print(f"Timeout: {timeout_seconds}")
print(f"Environment: {json.dumps(environment)}")
```
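For example, with `hardware_plan: {puffin4: 1, puffin8: 1}` (as in `bench_simple.yaml`) the emitted matrix is `{"include": [{"label": "puffin4", "index": 0, "is_primary": true, "expected_nodes": 2}, {"label": "puffin8", "index": 0, "is_primary": false, "expected_nodes": 2}]}`: exactly one worker is primary, and every worker knows the total node count.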
**.github/workflows/BENCH_USAGE.md** (new file, 156 lines, vendored)

@@ -0,0 +1,156 @@
# Benchmark Workflow Usage

## Overview

The `bench_matrix.yml` workflow enables distributed benchmarking of models across multiple self-hosted macOS runners with different hardware configurations.

## Workflow Inputs

| Input | Description | Default | Required |
|-------|-------------|---------|----------|
| `model_id` | Model ID to benchmark | `mlx-community/Llama-3.2-1B-Instruct-4bit` | Yes |
| `hardware_plan` | JSON mapping of runner labels to counts | `{"M4PRO_GPU16_24GB": 1}` | Yes |
| `prompt` | Benchmark prompt text | `What is the capital of France?` | No |
| `timeout_seconds` | Timeout for instance/runner readiness | `600` | No |

## Hardware Plan Format

The `hardware_plan` input is a JSON object mapping runner labels to the number of machines:

```json
{
  "M4PRO_GPU16_24GB": 2,
  "M3ULTRA_GPU80_512GB": 1
}
```

This example would:
- Start 2 runners with the `M4PRO_GPU16_24GB` label
- Start 1 runner with the `M3ULTRA_GPU80_512GB` label
- Total of 3 runners coordinating on a single distributed inference instance

## How It Works

1. **Planning Job** (`plan`)
   - Runs on `ubuntu-latest`
   - Parses the `hardware_plan` JSON
   - Generates a dynamic matrix with one entry per runner
   - Only the first runner (index 0) is marked as `is_primary`

2. **Benchmark Worker Jobs** (`bench_worker`)
   - Each job runs on a self-hosted macOS runner with the specified label
   - All runners start EXO in parallel
   - The primary runner creates the model instance
   - All runners wait for their assigned runner to be ready (Loaded/Running status)
   - The primary runner executes the benchmark and prints results
   - The primary runner deletes the instance

## Example Usage

### Single Machine Benchmark

```yaml
model_id: mlx-community/Llama-3.2-1B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 1}'
prompt: What is the capital of France?
timeout_seconds: 600
```

### Multi-Machine Distributed Benchmark

```yaml
model_id: mlx-community/Llama-3.2-3B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 2, "M3ULTRA_GPU80_512GB": 1}'
prompt: Explain quantum computing in simple terms.
timeout_seconds: 900
```

## Benchmark Output

The primary runner outputs a JSON object with benchmark results:

```json
{
  "model_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",
  "instance_id": "abc-123-def",
  "tokens": 42,
  "elapsed_s": 2.451,
  "tps": 17.136
}
```

Where:
- `tokens`: Number of chunks/tokens generated
- `elapsed_s`: Total elapsed time in seconds
- `tps`: Tokens per second (tokens / elapsed_s)
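A rough sketch of how such a result can be produced from a streaming chat completion follows; the `/v1/chat/completions` path and SSE framing are assumed OpenAI-style conventions, and the helper name is illustrative:

```python
import json
import time
import urllib.request

def run_benchmark_sketch(api_port: int, model_id: str, prompt: str) -> dict:
    """Stream a chat completion, count SSE chunks, compute tokens/second."""
    payload = json.dumps({
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,
    }).encode()
    req = urllib.request.Request(
        f"http://localhost:{api_port}/v1/chat/completions",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    tokens = 0
    start = time.monotonic()
    with urllib.request.urlopen(req) as resp:
        for raw in resp:  # HTTPResponse iterates line by line
            line = raw.decode().strip()
            if line.startswith("data: ") and line != "data: [DONE]":
                tokens += 1  # one SSE chunk counted as one token
    elapsed = time.monotonic() - start
    return {"tokens": tokens, "elapsed_s": round(elapsed, 3),
            "tps": round(tokens / elapsed, 3) if elapsed else 0.0}
```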
## Runner Requirements

Each self-hosted runner must:
- Be labeled with appropriate hardware tags (e.g., `M4PRO_GPU16_24GB`)
- Have the `self-hosted` and `macOS` labels
- Have Nix installed with flakes enabled
- Have network connectivity to other runners in the same job

## Architecture

```
┌─────────────────────────────────────────────────────────────┐
│ GitHub Actions Workflow (bench_matrix.yml)                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│ ┌────────────────┐                                          │
│ │ Plan Job       │                                          │
│ │ (ubuntu)       │──┬─► Matrix: [{label, index, primary}]   │
│ └────────────────┘  │                                       │
│                     │                                       │
│ ┌───────────────────▼──────────────────────────────────┐    │
│ │ Bench Worker Jobs (Matrix)                           │    │
│ ├──────────────────────────────────────────────────────┤    │
│ │                                                      │    │
│ │ Runner 0 (Primary)   Runner 1        Runner 2        │    │
│ │ ┌─────────────┐     ┌─────────────┐ ┌──────────┐     │    │
│ │ │ Start EXO   │     │ Start EXO   │ │ Start EXO│     │    │
│ │ │ Create Inst │     │ Wait...     │ │ Wait...  │     │    │
│ │ │ Wait Ready  │     │ Wait Ready  │ │ Wait...  │     │    │
│ │ │ Run Bench   │     │ (idle)      │ │ (idle)   │     │    │
│ │ │ Print TPS   │     │             │ │          │     │    │
│ │ │ Delete Inst │     │             │ │          │     │    │
│ │ └─────────────┘     └─────────────┘ └──────────┘     │    │
│ └──────────────────────────────────────────────────────┘    │
└─────────────────────────────────────────────────────────────┘
```

## Implementation Details

### `scripts/bench.py`

A standalone Python script that:
- Creates instance (primary only)
- Polls `/state` endpoint until instance and all runners are ready
- Executes chat completion with timing (primary only)
- Parses SSE stream and counts tokens
- Computes TPS metrics
- Cleans up instance (primary only)

### Key Functions

- `wait_for_instance()`: Polls until instance with model_id appears
- `wait_for_runners_ready()`: Polls until expected number of runners reach Loaded/Running status
- `run_benchmark()`: Executes chat completion, measures time, counts tokens
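For illustration, `wait_for_runners_ready()` might look roughly like the poll loop below (a sketch only: `bench.py` itself is suppressed in this diff, and the exact `/state` response shape, a `runners` list with a `status` field, is an assumption):

```python
import json
import time
import urllib.request

def wait_for_runners_ready(api_port: int, expected_nodes: int,
                           timeout_seconds: float) -> None:
    """Poll /state until expected_nodes runners report Loaded/Running."""
    deadline = time.monotonic() + timeout_seconds
    ready = 0
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(
                    f"http://localhost:{api_port}/state") as resp:
                state = json.loads(resp.read())
        except OSError:
            time.sleep(1)
            continue
        # "runners"/"status" field names are assumed for illustration.
        runners = state.get("runners", [])
        ready = sum(1 for r in runners
                    if r.get("status") in ("Loaded", "Running"))
        if ready >= expected_nodes:
            return
        time.sleep(1)
    raise TimeoutError(f"only {ready}/{expected_nodes} runners ready")
```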
## Troubleshooting

### Instance never becomes ready
- Check EXO logs in the workflow output
- Verify model_id is valid and accessible
- Increase `timeout_seconds`

### Runner mismatch
- Ensure hardware_plan counts match available labeled runners
- Check runner labels match exactly (case-sensitive)

### Network issues
- Verify runners can communicate on the network
- Check firewall rules between runner hosts
**.github/workflows/bench.yml** (new file, 305 lines, vendored)

@@ -0,0 +1,305 @@
```yaml
name: bench

on: [push]

jobs:
  plan:
    if: contains(github.event.head_commit.message, '/bench')
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.build.outputs.matrix }}
      config_file: ${{ steps.build.outputs.config_file }}
      timeout_seconds: ${{ steps.build.outputs.timeout_seconds }}
      environment: ${{ steps.build.outputs.environment }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Build matrix from config file
        id: build
        shell: bash
        run: |
          set -euo pipefail
          CONFIG_FILE='.github/configs/bench_simple.yaml'
          export CONFIG_FILE
          echo "Config file: $CONFIG_FILE"
          python3 .github/scripts/build_matrix.py

  bench_worker:
    needs: plan
    strategy:
      fail-fast: false
      matrix: ${{ fromJSON(needs.plan.outputs.matrix) }}
    name: "bench on ${{ matrix.label }} [${{ matrix.index }}]"
    runs-on: [self-hosted, macOS, "${{ matrix.label }}"]
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          lfs: false

      - name: Configure git user
        run: |
          git config --local user.email "github-actions@users.noreply.github.com"
          git config --local user.name "github-actions bot"
        shell: bash

      # TODO: this is mega hacky and I'd like a simpler solution.
      - name: Setup Nix Environment
        run: |
          echo "Checking for nix installation..."

          # Check if nix is already available
          if command -v nix >/dev/null 2>&1; then
            echo "Nix already in PATH"
          # Try sourcing profile scripts to set up environment properly
          elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
            echo "Sourcing multi-user nix-daemon profile script"
            source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
          elif [ -f "$HOME/.nix-profile/etc/profile.d/nix.sh" ]; then
            echo "Sourcing single-user nix profile script"
            source "$HOME/.nix-profile/etc/profile.d/nix.sh"
          elif [ -f /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh ]; then
            echo "Sourcing per-user nix profile script"
            source /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh
          elif [ -f /etc/profile.d/nix.sh ]; then
            echo "Sourcing system-wide nix profile script"
            source /etc/profile.d/nix.sh
          # Fallback: manually add nix to PATH if binary exists
          elif [ -f /nix/var/nix/profiles/default/bin/nix ]; then
            echo "Found nix binary, manually adding to PATH"
            export PATH="/nix/var/nix/profiles/default/bin:$PATH"
          elif [ -f "$HOME/.nix-profile/bin/nix" ]; then
            echo "Found nix binary in user profile, manually adding to PATH"
            export PATH="$HOME/.nix-profile/bin:$PATH"
          else
            echo "Nix not found. Debugging info:"
            echo "USER: $USER"
            echo "HOME: $HOME"
            echo "Current PATH: $PATH"
            echo ""
            echo "Checking common Nix locations:"
            echo "  /nix/var/nix/profiles/default/bin/nix:"
            ls -la /nix/var/nix/profiles/default/bin/nix 2>/dev/null || echo "    Not found"
            echo "  /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh:"
            ls -la /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh 2>/dev/null || echo "    Not found"
            echo "  ~/.nix-profile/etc/profile.d/nix.sh:"
            ls -la "$HOME/.nix-profile/etc/profile.d/nix.sh" 2>/dev/null || echo "    Not found"
            echo "  /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh:"
            ls -la "/nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh" 2>/dev/null || echo "    Not found"
            echo ""
            echo "/nix directory structure:"
            ls -la /nix 2>/dev/null || echo "  /nix directory not found"
            echo ""
            echo "/nix/var:"
            ls -la /nix/var 2>/dev/null || echo "  /nix/var not found"
            echo ""
            echo "/nix/store:"
            ls -la /nix/store 2>/dev/null | head -20 || echo "  /nix/store not found"
            echo ""
            echo "GitHub Actions runner is running as user '$USER'."
            echo "If Nix is installed for a different user, either:"
            echo "  1. Install Nix for user '$USER' (multi-user install recommended)"
            echo "  2. Configure the runner service to run as the user with Nix installed"
            echo "  3. Ensure Nix is installed system-wide with proper daemon setup"
            exit 1
          fi

          # Verify nix is available and persist to GITHUB_ENV
          if command -v nix >/dev/null 2>&1; then
            echo "✓ Nix is available"
            nix --version
            echo "PATH=$PATH" >> $GITHUB_ENV
            if [ -n "$NIX_PATH" ]; then
              echo "NIX_PATH=$NIX_PATH" >> $GITHUB_ENV
            fi
          else
            echo "ERROR: Failed to set up Nix"
            echo "PATH after setup attempt: $PATH"
            exit 1
          fi
        shell: bash

      - name: Setup EXO_HOME and API_PORT
        run: |
          EXO_HOME=$(mktemp -d -t exo-e2e-XXXXXXXX)
          API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
          EXO_MODELS_DIR="$HOME/.exo/models"
          EXO_LIBP2P_NAMESPACE="bench-${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"
          echo "EXO_HOME=$EXO_HOME" >> "$GITHUB_ENV"
          echo "API_PORT=$API_PORT" >> "$GITHUB_ENV"
          echo "EXO_MODELS_DIR=$EXO_MODELS_DIR" >> "$GITHUB_ENV"
          echo "EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE" >> "$GITHUB_ENV"
          echo "Created EXO_HOME: $EXO_HOME"
          echo "Generated API_PORT: $API_PORT"
          echo "Using models from: $EXO_MODELS_DIR"
          echo "Using libp2p namespace: $EXO_LIBP2P_NAMESPACE"
        shell: bash

      - name: Configure local MLX if available
        run: |
          echo "=== DEBUG: Checking for local MLX configuration ==="
          MODIFIED=false

          echo "Checking for /Users/Shared/mlx directory..."
          if [ -d "/Users/Shared/mlx" ]; then
            echo "✓ Found /Users/Shared/mlx"
            ls -la /Users/Shared/mlx | head -5
            echo "Enabling local mlx path in pyproject.toml"
            sed -i.bak 's|^# mlx = { path = "/Users/Shared/mlx", editable=true }$|mlx = { path = "/Users/Shared/mlx", editable=true }|' pyproject.toml
            MODIFIED=true
          else
            echo "✗ /Users/Shared/mlx not found, will use PyPI version"
          fi

          echo "Checking for /Users/Shared/mlx-lm directory..."
          if [ -d "/Users/Shared/mlx-lm" ]; then
            echo "✓ Found /Users/Shared/mlx-lm"
            ls -la /Users/Shared/mlx-lm | head -5
            echo "Enabling local mlx-lm path in pyproject.toml"
            sed -i.bak 's|^# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }$|mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }|' pyproject.toml
            MODIFIED=true
          else
            echo "✗ /Users/Shared/mlx-lm not found, will use PyPI version"
          fi

          if [ "$MODIFIED" = true ]; then
            echo "=== Modified pyproject.toml [tool.uv.sources] section: ==="
            sed -n '/\[tool\.uv\.sources\]/,/^\[/{/^\[tool\.uv\.sources\]/p; /^\[/!p;}' pyproject.toml
            echo "=== Regenerating uv.lock with local MLX paths... ==="
            nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command uv lock --upgrade-package mlx --upgrade-package mlx-lm
            echo "✓ Lock file regenerated"
          else
            echo "⚠ No local MLX directories found, using PyPI packages"
          fi
          echo "=== DEBUG: Local MLX configuration complete ==="
        shell: bash

      - name: Sync dependencies
        run: |
          if [ -d "/Users/Shared/test" ]; then
            pushd /Users/Shared/test
            uv sync --reinstall
            popd
          fi
          echo "Running just sync to ensure clean dependencies..."
          nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just sync
        shell: bash

      - name: Start EXO and run bench script
        shell: bash
        env:
          IS_PRIMARY: ${{ matrix.is_primary }}
          EXPECTED_NODES: ${{ matrix.expected_nodes }}
          HARDWARE_LABEL: ${{ matrix.label }}
          CONFIG_FILE: ${{ needs.plan.outputs.config_file }}
          TIMEOUT_SECONDS: ${{ needs.plan.outputs.timeout_seconds }}
          ENVIRONMENT_JSON: ${{ needs.plan.outputs.environment }}
        run: |
          set -euo pipefail

          # Parse environment variables from config
          ENV_VARS=""
          if [ -n "$ENVIRONMENT_JSON" ] && [ "$ENVIRONMENT_JSON" != "{}" ]; then
            ENV_VARS=$(echo "$ENVIRONMENT_JSON" | python3 -c "import sys, json; env = json.load(sys.stdin); print(' '.join([f'{k}={v}' for k, v in env.items()]))")
          fi

          echo "Starting EXO with API_PORT=${API_PORT} EXO_HOME=${EXO_HOME} EXO_LIBP2P_NAMESPACE=${EXO_LIBP2P_NAMESPACE}"
          echo "Environment variables from config: $ENV_VARS"
          LOG_FILE=/tmp/exo.log
          : > "$LOG_FILE"

          MASTER_FLAG=""
          if [ "$IS_PRIMARY" = "true" ]; then
            MASTER_FLAG="-m"
          fi

          nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
            "EXO_HOME=$EXO_HOME EXO_MODELS_DIR=$EXO_MODELS_DIR EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE $ENV_VARS PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run exo $MASTER_FLAG --api-port $API_PORT" \
            >> "$LOG_FILE" 2>&1 &

          EXO_PID=$!
          echo "Started EXO in background with PID: $EXO_PID"
          echo "Log file: $LOG_FILE"

          cleanup() {
            echo '=== EXO log (tail) ==='
            tail -n 300 "$LOG_FILE" || true
            if ps -p "$EXO_PID" >/dev/null 2>&1; then
              echo "Killing EXO (PID $EXO_PID)"
              kill "$EXO_PID" || true
            fi
          }
          trap cleanup EXIT

          for i in $(seq 1 60); do
            if curl -s "http://localhost:${API_PORT}/state" >/dev/null 2>&1; then
              echo "EXO API ready"
              break
            fi
            if ! ps -p "$EXO_PID" >/dev/null 2>&1; then
              echo "EXO terminated early"; sed -n '1,200p' "$LOG_FILE" || true; exit 1
            fi
            sleep 1
          done

          RESULTS_FILE="/tmp/bench_results_${GITHUB_RUN_ID}_${GITHUB_RUN_ATTEMPT}_$(date +%s).json"
          echo "Results will be saved to: $RESULTS_FILE"
          echo "RESULTS_FILE=$RESULTS_FILE" >> "$GITHUB_ENV"

          echo "Running bench script with config: $CONFIG_FILE, timeout: $TIMEOUT_SECONDS"
          nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
            "PYTHONUNBUFFERED=1 uv run --no-project --with pyyaml --with pydantic python .github/scripts/bench.py \
              --api-port $API_PORT \
              --config $CONFIG_FILE \
              --expected-nodes ${EXPECTED_NODES} \
              --is-primary ${IS_PRIMARY} \
              --timeout-seconds ${TIMEOUT_SECONDS} \
              --output $RESULTS_FILE \
              --git-commit ${GITHUB_SHA} \
              --hardware-labels ${HARDWARE_LABEL}"

      - name: Install AWS CLI
        if: always() && env.RESULTS_FILE && matrix.is_primary
        run: |
          if ! command -v aws &> /dev/null; then
            echo "AWS CLI not found, installing..."
            brew install awscli
          else
            echo "AWS CLI already installed"
          fi
        shell: bash

      - name: Upload results to S3
        if: always() && env.RESULTS_FILE && matrix.is_primary
        env:
          AWS_ACCESS_KEY_ID: ${{ secrets.S3_BENCHMARKS_AWS_ACCESS_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
          AWS_DEFAULT_REGION: us-east-1
        run: |
          echo "Checking for results file: $RESULTS_FILE"
          echo "Is primary: ${{ matrix.is_primary }}"

          if [ -f "$RESULTS_FILE" ]; then
            TIMESTAMP=$(date -u +%Y/%m/%d/%H%M%S)
            S3_KEY="bench/${TIMESTAMP}_${GITHUB_SHA:0:8}_${GITHUB_RUN_ID}.json"
            echo "Uploading results to s3://exo-benchmark-results/$S3_KEY"

            aws s3 cp "$RESULTS_FILE" "s3://exo-benchmark-results/$S3_KEY" \
              --content-type application/json \
              --metadata "commit=${GITHUB_SHA},run_id=${GITHUB_RUN_ID},branch=${GITHUB_REF_NAME}"

            echo "Results uploaded successfully"
            echo "View at: https://exo-benchmark-results.s3.amazonaws.com/$S3_KEY"
          else
            echo "Results file not found at: $RESULTS_FILE"
            echo "Skipping upload"
          fi
        shell: bash

      - name: Cleanup EXO_HOME
        run: |
          echo "Cleaning up EXO_HOME: $EXO_HOME"
          rm -rf "$EXO_HOME"
        shell: bash
        if: always()
```
**.github/workflows/build-app.yml** (modified, 158 lines, vendored)
```diff
@@ -1,18 +1,6 @@
 name: Build EXO macOS DMG
 
-# Release workflow:
-# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
-# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
-# 3. This workflow builds, signs, and notarizes the DMG
-# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
-# 5. DMG and appcast.xml are uploaded to S3
-# 6. The draft GitHub Release is published with the DMG attached
-#
-# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
-# If no draft exists, a release is auto-created with generated notes.
-
 on:
-  workflow_dispatch:
   push:
     tags:
       - "v*"
@@ -22,17 +10,14 @@ on:
 jobs:
   build-macos-app:
     runs-on: "macos-26"
-    permissions:
-      contents: write
     env:
-      SPARKLE_VERSION: 2.9.0-beta.1
+      SPARKLE_VERSION: 2.8.1
       SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
       SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
       SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
       SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}
       SPARKLE_S3_BUCKET: ${{ secrets.SPARKLE_S3_BUCKET }}
       SPARKLE_S3_PREFIX: ${{ secrets.SPARKLE_S3_PREFIX }}
-      EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT: ${{ secrets.EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT }}
       AWS_REGION: ${{ secrets.AWS_REGION }}
       EXO_BUILD_NUMBER: ${{ github.run_number }}
       EXO_LIBP2P_NAMESPACE: ${{ github.ref_name }}
@@ -49,7 +34,7 @@ jobs:
 
       - name: Derive release version from tag
         run: |
-          if [[ "$GITHUB_REF_NAME" == "test-app" || "${{ github.event_name }}" == "workflow_dispatch" ]]; then
+          if [[ "$GITHUB_REF_NAME" == "test-app" ]]; then
            VERSION="0.0.0-alpha.0"
            echo "IS_ALPHA=true" >> $GITHUB_ENV
          else
@@ -62,32 +47,6 @@ jobs:
          fi
          echo "RELEASE_VERSION=$VERSION" >> $GITHUB_ENV
 
-      - name: Compute build version from semver
-        run: |
-          VERSION="$RELEASE_VERSION"
-          # Extract major.minor.patch (strip prerelease suffix)
-          BASE_VERSION="${VERSION%%-*}"
-          MAJOR=$(echo "$BASE_VERSION" | cut -d. -f1)
-          MINOR=$(echo "$BASE_VERSION" | cut -d. -f2)
-          PATCH=$(echo "$BASE_VERSION" | cut -d. -f3)
-
-          # Extract prerelease number (e.g., "alpha.2" -> 2, or 999 for releases)
-          if [[ "$VERSION" == *-* ]]; then
-            PRERELEASE_PART="${VERSION#*-}"
-            PRERELEASE_NUM="${PRERELEASE_PART##*.}"
-            # Default to 0 if not a number
-            if ! [[ "$PRERELEASE_NUM" =~ ^[0-9]+$ ]]; then
-              PRERELEASE_NUM=0
-            fi
-          else
-            PRERELEASE_NUM=999
-          fi
-
-          # Compute: PRERELEASE + (1000 * PATCH) + (1_000_000 * MINOR) + (1_000_000_000 * MAJOR)
-          BUILD_VERSION=$((PRERELEASE_NUM + 1000 * PATCH + 1000000 * MINOR + 1000000000 * MAJOR))
-          echo "EXO_BUILD_VERSION=$BUILD_VERSION" >> $GITHUB_ENV
-          echo "Computed build version: $BUILD_VERSION from $VERSION"
-
       - name: Ensure tag commit is on main
         if: github.ref_type == 'tag'
         run: |
@@ -100,52 +59,6 @@ jobs:
             exit 1
           fi
 
-      - name: Fetch and validate release notes
-        if: github.ref_type == 'tag'
-        env:
-          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: |
-          # Find draft release by name using gh release list (more reliable with default token)
-          echo "Looking for draft release named '$GITHUB_REF_NAME'..."
-          DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
-
-          if [[ -z "$DRAFT_EXISTS" ]]; then
-            if [[ "$IS_ALPHA" == "true" ]]; then
-              echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
-              echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
-              exit 0
-            fi
-            echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
-            echo "Please create a draft release with release notes before pushing the tag."
-            exit 1
-          fi
-
-          # Fetch full release details via API to get body and ID
-          echo "Found draft release, fetching details..."
-          RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
-
-          # Extract release notes
-          NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
-          if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
-            if [[ "$IS_ALPHA" == "true" ]]; then
-              echo "Draft release has no notes (optional for alphas)"
-              echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
-              exit 0
-            fi
-            echo "ERROR: Draft release exists but has no release notes"
-            echo "Please add release notes to the draft release before pushing the tag."
-            exit 1
-          fi
-
-          # Save release ID for later publishing
-          RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
-          echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
-          echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
-
-          echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
-          echo "$NOTES" > /tmp/release_notes.md
-          echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
-
       # ============================================================
       # Install dependencies
       # ============================================================
@@ -172,22 +85,11 @@ jobs:
           uv python install
           uv sync --locked
 
-      - name: Install Nix
-        uses: cachix/install-nix-action@v31
-        with:
-          nix_path: nixpkgs=channel:nixos-unstable
-
-      - name: Configure Cachix
-        uses: cachix/cachix-action@v14
-        with:
-          name: exo
-          authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
-
       - name: Build dashboard
         run: |
-          DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)
-          mkdir -p dashboard/build
-          cp -r "$DASHBOARD_OUT"/* dashboard/build/
+          cd dashboard
+          npm ci
+          npm run build
 
       - name: Install Sparkle CLI
         run: |
@@ -260,12 +162,11 @@ jobs:
             -configuration Release \
             -derivedDataPath build \
             MARKETING_VERSION="$RELEASE_VERSION" \
-            CURRENT_PROJECT_VERSION="$EXO_BUILD_VERSION" \
+            CURRENT_PROJECT_VERSION="$EXO_BUILD_NUMBER" \
             EXO_BUILD_TAG="$RELEASE_VERSION" \
             EXO_BUILD_COMMIT="$GITHUB_SHA" \
             SPARKLE_FEED_URL="$SPARKLE_FEED_URL" \
             SPARKLE_ED25519_PUBLIC="$SPARKLE_ED25519_PUBLIC" \
-            EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT="$EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT" \
             CODE_SIGNING_IDENTITY="$SIGNING_IDENTITY" \
             CODE_SIGN_INJECT_BASE_ENTITLEMENTS=YES
           mkdir -p ../../output
@@ -363,28 +264,6 @@ jobs:
             $CHANNEL_FLAG \
             .
 
-      - name: Inject release notes into appcast
-        if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
-        env:
-          RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
-        run: |
-          # Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
-          export NOTES=$(cat "$RELEASE_NOTES_FILE")
-
-          # Insert description after the enclosure tag for this version
-          awk '
-            /<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
-              print
-              print "      <description sparkle:format=\"markdown\"><![CDATA["
-              print ENVIRON["NOTES"]
-              print "      ]]></description>"
-              next
-            }
-            { print }
-          ' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
-
-          echo "Injected markdown release notes for version $RELEASE_VERSION"
-
       # ============================================================
       # Upload artifacts
       # ============================================================
@@ -415,28 +294,5 @@ jobs:
           aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}${DMG_NAME}"
           if [[ "$IS_ALPHA" != "true" ]]; then
             aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
-            aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
-          fi
-
-      - name: Publish GitHub Release
-        if: github.ref_type == 'tag'
-        env:
-          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: |
-          DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
-
-          if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
-            # Update the draft release with the tag and upload DMG
-            gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
-              -f tag_name="$GITHUB_REF_NAME" \
-              -F draft=false
-            gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
-            echo "Published release $GITHUB_REF_NAME with DMG attached"
-          else
-            # Alpha without draft release - create one with auto-generated notes
-            gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
-              --title "$GITHUB_REF_NAME" \
-              --generate-notes \
-              --prerelease
-            echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
           fi
+          aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
```
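For reference, the removed "Compute build version from semver" step packed a semver string into one monotonically increasing integer. A minimal Python rendering of that arithmetic (a hypothetical helper, not part of the repo):

```python
def build_version(version: str) -> int:
    """Encode e.g. "1.2.3-alpha.4" as 4 + 1000*3 + 1_000_000*2 + 1_000_000_000*1."""
    base, _, prerelease = version.partition("-")
    major, minor, patch = (int(x) for x in base.split("."))
    if prerelease:
        tail = prerelease.rsplit(".", 1)[-1]
        prerelease_num = int(tail) if tail.isdigit() else 0  # default to 0
    else:
        prerelease_num = 999  # full releases sort above any prerelease
    return prerelease_num + 1000 * patch + 1_000_000 * minor + 1_000_000_000 * major

assert build_version("1.2.3-alpha.4") == 1_002_003_004
assert build_version("1.2.3") > build_version("1.2.3-alpha.4")
```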
**.github/workflows/pipeline.yml** (modified, 117 lines, vendored)
@@ -20,12 +20,6 @@ jobs:
  with:
  nix_path: nixpkgs=channel:nixos-unstable

- - uses: cachix/cachix-action@v14
- name: Configure Cachix
- with:
- name: exo
- authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
-
  - name: Configure git user
  run: |
  git config --local user.email "github-actions@users.noreply.github.com"
@@ -94,19 +88,9 @@ jobs:

  - uses: ./.github/actions/typecheck

- nix:
- name: Build and check (${{ matrix.system }})
- runs-on: ${{ matrix.runner }}
- strategy:
- fail-fast: false
- matrix:
- include:
- - runner: macos-26
- system: aarch64-darwin
- - runner: ubuntu-latest
- system: x86_64-linux
- - runner: ubuntu-24.04-arm
- system: aarch64-linux
+ nix-flake-check:
+ name: Check Nix flake
+ runs-on: ubuntu-latest
  steps:
  - name: Checkout repository
  uses: actions/checkout@v4
@@ -117,20 +101,83 @@ jobs:
  with:
  nix_path: nixpkgs=channel:nixos-unstable

- - uses: cachix/cachix-action@v14
- name: Configure Cachix
- with:
- name: exo
- authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
-
- - name: Build all Nix outputs
- run: |
- nix flake show --json | jq -r '
- [
- (.packages."${{ matrix.system }}" // {} | keys[] | ".#packages.${{ matrix.system }}.\(.)"),
- (.devShells."${{ matrix.system }}" // {} | keys[] | ".#devShells.${{ matrix.system }}.\(.)")
- ] | .[]
- ' | xargs nix build
-
  - name: Run nix flake check
- run: nix flake check
+ run: |
+ nix flake check
+ shell: bash
+
+ # ci:
+ # needs: typecheck
+ # runs-on: ubuntu-latest
+ # permissions:
+ # contents: read
+ # env:
+ # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ # steps:
+ # - name: Checkout repository
+ # uses: actions/checkout@v4
+ # with:
+ # fetch-depth: 0
+ # token: ${{ secrets.GITHUB_TOKEN }}
+ # lfs: true
+ #
+ # - name: Configure git user
+ # run: |
+ # git config --local user.email "github-actions@users.noreply.github.com"
+ # git config --local user.name "github-actions bot"
+ # shell: bash
+ #
+ # - name: Pull LFS files
+ # run: |
+ # echo "Pulling Git LFS files..."
+ # git lfs pull
+ # shell: bash
+ #
+ # - name: Setup EXO_HOME and API_PORT
+ # run: |
+ # EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX)
+ # # Generate random port (macOS compatible method)
+ # API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
+ # echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
+ # echo "API_PORT=$API_PORT" >> $GITHUB_ENV
+ # echo "Created EXO_HOME: $EXO_HOME"
+ # echo "Generated API_PORT: $API_PORT"
+ # shell: bash
+ #
+ # - name: Setup Nix Environment
+ # run: |
+ # echo "Checking for nix installation..."
+ #
+ # # Check if nix binary exists directly
+ # if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
+ # echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
+ # export PATH="/nix/var/nix/profiles/default/bin:$PATH"
+ # echo "PATH=$PATH" >> $GITHUB_ENV
+ # nix --version
+ # elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
+ # echo "Found nix profile script, sourcing..."
+ # source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
+ # nix --version
+ # elif command -v nix >/dev/null 2>&1; then
+ # echo "Nix already in PATH"
+ # nix --version
+ # else
+ # echo "Nix not found. Debugging info:"
+ # echo "Contents of /nix/var/nix/profiles/default/:"
+ # ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
+ # echo "Contents of /nix/var/nix/profiles/default/bin/:"
+ # ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
+ # exit 1
+ # fi
+ # shell: bash
+ #
+ # - uses: ./.github/actions/lint-check
+ #
+ # - uses: ./.github/actions/unit-test
+ #
+ # - name: Cleanup EXO_HOME
+ # run: |
+ # echo "Cleaning up EXO_HOME: $EXO_HOME"
+ # rm -rf "$EXO_HOME"
+ # shell: bash
+ # if: always()
3
.gitignore
vendored
@@ -7,8 +7,6 @@ digest.txt
  # nix
  .direnv/

- # IDEA (PyCharm)
- .idea
-
  # xcode / macos
  *.xcuserstate
@@ -16,7 +14,6 @@ digest.txt
  *.xcuserdatad/
  **/.DS_Store
  app/EXO/build/
- dist/


  # rust
@@ -1,7 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import os
-
- if "TOKENIZERS_PARALLELISM" not in os.environ: ...

@@ -1,3 +0,0 @@
- """
- This type stub file was generated by pyright.
- """

@@ -1,47 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import PIL.Image
- import tqdm
- from typing import Protocol
- from mflux.models.common.config.config import Config
-
- class BeforeLoopCallback(Protocol):
-     def call_before_loop(
-         self,
-         seed: int,
-         prompt: str,
-         latents: mx.array,
-         config: Config,
-         canny_image: PIL.Image.Image | None = ...,
-         depth_image: PIL.Image.Image | None = ...,
-     ) -> None: ...
-
- class InLoopCallback(Protocol):
-     def call_in_loop(
-         self,
-         t: int,
-         seed: int,
-         prompt: str,
-         latents: mx.array,
-         config: Config,
-         time_steps: tqdm,
-     ) -> None: ...
-
- class AfterLoopCallback(Protocol):
-     def call_after_loop(
-         self, seed: int, prompt: str, latents: mx.array, config: Config
-     ) -> None: ...
-
- class InterruptCallback(Protocol):
-     def call_interrupt(
-         self,
-         t: int,
-         seed: int,
-         prompt: str,
-         latents: mx.array,
-         config: Config,
-         time_steps: tqdm,
-     ) -> None: ...

@@ -1,24 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from typing import TYPE_CHECKING
- from mflux.callbacks.callback import (
-     AfterLoopCallback,
-     BeforeLoopCallback,
-     InLoopCallback,
-     InterruptCallback,
- )
- from mflux.callbacks.generation_context import GenerationContext
- from mflux.models.common.config.config import Config
-
- if TYPE_CHECKING: ...
-
- class CallbackRegistry:
-     def __init__(self) -> None: ...
-     def register(self, callback) -> None: ...
-     def start(self, seed: int, prompt: str, config: Config) -> GenerationContext: ...
-     def before_loop_callbacks(self) -> list[BeforeLoopCallback]: ...
-     def in_loop_callbacks(self) -> list[InLoopCallback]: ...
-     def after_loop_callbacks(self) -> list[AfterLoopCallback]: ...
-     def interrupt_callbacks(self) -> list[InterruptCallback]: ...

@@ -1,29 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import PIL.Image
- import tqdm
- from typing import TYPE_CHECKING
- from mflux.callbacks.callback_registry import CallbackRegistry
- from mflux.models.common.config.config import Config
-
- if TYPE_CHECKING: ...
-
- class GenerationContext:
-     def __init__(
-         self, registry: CallbackRegistry, seed: int, prompt: str, config: Config
-     ) -> None: ...
-     def before_loop(
-         self,
-         latents: mx.array,
-         *,
-         canny_image: PIL.Image.Image | None = ...,
-         depth_image: PIL.Image.Image | None = ...,
-     ) -> None: ...
-     def in_loop(self, t: int, latents: mx.array, time_steps: tqdm = ...) -> None: ...
-     def after_loop(self, latents: mx.array) -> None: ...
-     def interruption(
-         self, t: int, latents: mx.array, time_steps: tqdm = ...
-     ) -> None: ...

@@ -1,3 +0,0 @@
- """
- This type stub file was generated by pyright.
- """

@@ -1,22 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import os
-
- BATTERY_PERCENTAGE_STOP_LIMIT = ...
- CONTROLNET_STRENGTH = ...
- DEFAULT_DEV_FILL_GUIDANCE = ...
- DEFAULT_DEPTH_GUIDANCE = ...
- DIMENSION_STEP_PIXELS = ...
- GUIDANCE_SCALE = ...
- GUIDANCE_SCALE_KONTEXT = ...
- IMAGE_STRENGTH = ...
- MODEL_CHOICES = ...
- MODEL_INFERENCE_STEPS = ...
- QUANTIZE_CHOICES = ...
- if os.environ.get("MFLUX_CACHE_DIR"):
-     MFLUX_CACHE_DIR = ...
- else:
-     MFLUX_CACHE_DIR = ...
- MFLUX_LORA_CACHE_DIR = ...

@@ -1,3 +0,0 @@
- """
- This type stub file was generated by pyright.
- """

@@ -1,3 +0,0 @@
- """
- This type stub file was generated by pyright.
- """

@@ -1,3 +0,0 @@
- """
- This type stub file was generated by pyright.
- """

@@ -1,8 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from mflux.models.common.config.config import Config
- from mflux.models.common.config.model_config import ModelConfig
-
- __all__ = ["Config", "ModelConfig"]
@@ -1,66 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from pathlib import Path
- from typing import Any
- from tqdm import tqdm
- from mflux.models.common.config.model_config import ModelConfig
-
- logger = ...
-
- class Config:
-     def __init__(
-         self,
-         model_config: ModelConfig,
-         num_inference_steps: int = ...,
-         height: int = ...,
-         width: int = ...,
-         guidance: float = ...,
-         image_path: Path | str | None = ...,
-         image_strength: float | None = ...,
-         depth_image_path: Path | str | None = ...,
-         redux_image_paths: list[Path | str] | None = ...,
-         redux_image_strengths: list[float] | None = ...,
-         masked_image_path: Path | str | None = ...,
-         controlnet_strength: float | None = ...,
-         scheduler: str = ...,
-     ) -> None: ...
-     @property
-     def height(self) -> int: ...
-     @property
-     def width(self) -> int: ...
-     @width.setter
-     def width(self, value):  # -> None:
-         ...
-     @property
-     def image_seq_len(self) -> int: ...
-     @property
-     def guidance(self) -> float: ...
-     @property
-     def num_inference_steps(self) -> int: ...
-     @property
-     def precision(self) -> mx.Dtype: ...
-     @property
-     def num_train_steps(self) -> int: ...
-     @property
-     def image_path(self) -> Path | None: ...
-     @property
-     def image_strength(self) -> float | None: ...
-     @property
-     def depth_image_path(self) -> Path | None: ...
-     @property
-     def redux_image_paths(self) -> list[Path] | None: ...
-     @property
-     def redux_image_strengths(self) -> list[float] | None: ...
-     @property
-     def masked_image_path(self) -> Path | None: ...
-     @property
-     def init_time_step(self) -> int: ...
-     @property
-     def time_steps(self) -> tqdm: ...
-     @property
-     def controlnet_strength(self) -> float | None: ...
-     @property
-     def scheduler(self) -> Any: ...

@@ -1,86 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from functools import lru_cache
- from typing import Literal
-
- class ModelConfig:
-     precision: mx.Dtype = ...
-     def __init__(
-         self,
-         priority: int,
-         aliases: list[str],
-         model_name: str,
-         base_model: str | None,
-         controlnet_model: str | None,
-         custom_transformer_model: str | None,
-         num_train_steps: int | None,
-         max_sequence_length: int | None,
-         supports_guidance: bool | None,
-         requires_sigma_shift: bool | None,
-         transformer_overrides: dict | None = ...,
-     ) -> None: ...
-     @staticmethod
-     @lru_cache
-     def dev() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def schnell() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def dev_kontext() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def dev_fill() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def dev_redux() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def dev_depth() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def dev_controlnet_canny() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def schnell_controlnet_canny() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def dev_controlnet_upscaler() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def dev_fill_catvton() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def krea_dev() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def flux2_klein_4b() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def flux2_klein_9b() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def qwen_image() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def qwen_image_edit() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def fibo() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def z_image_turbo() -> ModelConfig: ...
-     @staticmethod
-     @lru_cache
-     def seedvr2_3b() -> ModelConfig: ...
-     def x_embedder_input_dim(self) -> int: ...
-     def is_canny(self) -> bool: ...
-     @staticmethod
-     def from_name(
-         model_name: str, base_model: Literal["dev", "schnell", "krea-dev"] | None = ...
-     ) -> ModelConfig: ...
-
- AVAILABLE_MODELS = ...
@@ -1,7 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- """
- This type stub file was generated by pyright.
- """

@@ -1,49 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from pathlib import Path
- from typing import TYPE_CHECKING, TypeAlias
- from mlx import nn
- from mflux.models.common.vae.tiling_config import TilingConfig
- from mflux.models.fibo.latent_creator.fibo_latent_creator import FiboLatentCreator
- from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
- from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
- from mflux.models.z_image.latent_creator.z_image_latent_creator import (
-     ZImageLatentCreator,
- )
-
- if TYPE_CHECKING:
-     LatentCreatorType: TypeAlias = type[
-         FiboLatentCreator | FluxLatentCreator | QwenLatentCreator | ZImageLatentCreator
-     ]
-
- class Img2Img:
-     def __init__(
-         self,
-         vae: nn.Module,
-         latent_creator: LatentCreatorType,
-         sigmas: mx.array,
-         init_time_step: int,
-         image_path: str | Path | None,
-         tiling_config: TilingConfig | None = ...,
-     ) -> None: ...
-
- class LatentCreator:
-     @staticmethod
-     def create_for_txt2img_or_img2img(
-         seed: int, height: int, width: int, img2img: Img2Img
-     ) -> mx.array: ...
-     @staticmethod
-     def encode_image(
-         vae: nn.Module,
-         image_path: str | Path,
-         height: int,
-         width: int,
-         tiling_config: TilingConfig | None = ...,
-     ) -> mx.array: ...
-     @staticmethod
-     def add_noise_by_interpolation(
-         clean: mx.array, noise: mx.array, sigma: float
-     ) -> mx.array: ...

@@ -1,3 +0,0 @@
- """
- This type stub file was generated by pyright.
- """

@@ -1,13 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from mlx import nn
- from mflux.models.common.lora.layer.linear_lora_layer import LoRALinear
-
- class FusedLoRALinear(nn.Module):
-     def __init__(
-         self, base_linear: nn.Linear | nn.QuantizedLinear, loras: list[LoRALinear]
-     ) -> None: ...
-     def __call__(self, x):  # -> array:
-         ...

@@ -1,22 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from mlx import nn
-
- class LoRALinear(nn.Module):
-     @staticmethod
-     def from_linear(
-         linear: nn.Linear | nn.QuantizedLinear, r: int = ..., scale: float = ...
-     ):  # -> LoRALinear:
-         ...
-     def __init__(
-         self,
-         input_dims: int,
-         output_dims: int,
-         r: int = ...,
-         scale: float = ...,
-         bias: bool = ...,
-     ) -> None: ...
-     def __call__(self, x):  # -> array:
-         ...

@@ -1,26 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import mlx.nn as nn
- from collections.abc import Callable
- from dataclasses import dataclass
- from mflux.models.common.lora.mapping.lora_mapping import LoRATarget
-
- @dataclass
- class PatternMatch:
-     source_pattern: str
-     target_path: str
-     matrix_name: str
-     transpose: bool
-     transform: Callable[[mx.array], mx.array] | None = ...
-
- class LoRALoader:
-     @staticmethod
-     def load_and_apply_lora(
-         lora_mapping: list[LoRATarget],
-         transformer: nn.Module,
-         lora_paths: list[str] | None = ...,
-         lora_scales: list[float] | None = ...,
-     ) -> tuple[list[str], list[float]]: ...
@@ -1,21 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from collections.abc import Callable
- from dataclasses import dataclass
- from typing import List, Protocol
-
- @dataclass
- class LoRATarget:
-     model_path: str
-     possible_up_patterns: List[str]
-     possible_down_patterns: List[str]
-     possible_alpha_patterns: List[str] = ...
-     up_transform: Callable[[mx.array], mx.array] | None = ...
-     down_transform: Callable[[mx.array], mx.array] | None = ...
-
- class LoRAMapping(Protocol):
-     @staticmethod
-     def get_mapping() -> List[LoRATarget]: ...

@@ -1,9 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.nn as nn
-
- class LoRASaver:
-     @staticmethod
-     def bake_and_strip_lora(module: nn.Module) -> nn.Module: ...

@@ -1,35 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
-
- class LoraTransforms:
-     @staticmethod
-     def split_q_up(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def split_k_up(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def split_v_up(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def split_q_down(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def split_k_down(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def split_v_down(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def split_single_q_up(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def split_single_k_up(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def split_single_v_up(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def split_single_mlp_up(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def split_single_q_down(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def split_single_k_down(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def split_single_v_down(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def split_single_mlp_down(tensor: mx.array) -> mx.array: ...

@@ -1,17 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from mflux.models.common.resolution.config_resolution import ConfigResolution
- from mflux.models.common.resolution.lora_resolution import LoraResolution
- from mflux.models.common.resolution.path_resolution import PathResolution
- from mflux.models.common.resolution.quantization_resolution import (
-     QuantizationResolution,
- )
-
- __all__ = [
-     "ConfigResolution",
-     "LoraResolution",
-     "PathResolution",
-     "QuantizationResolution",
- ]

@@ -1,39 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from enum import Enum
- from typing import NamedTuple
-
- class QuantizationAction(Enum):
-     NONE = ...
-     STORED = ...
-     REQUESTED = ...
-
- class PathAction(Enum):
-     LOCAL = ...
-     HUGGINGFACE_CACHED = ...
-     HUGGINGFACE = ...
-     ERROR = ...
-
- class LoraAction(Enum):
-     LOCAL = ...
-     REGISTRY = ...
-     HUGGINGFACE_COLLECTION_CACHED = ...
-     HUGGINGFACE_COLLECTION = ...
-     HUGGINGFACE_REPO_CACHED = ...
-     HUGGINGFACE_REPO = ...
-     ERROR = ...
-
- class ConfigAction(Enum):
-     EXACT_MATCH = ...
-     EXPLICIT_BASE = ...
-     INFER_SUBSTRING = ...
-     ERROR = ...
-
- class Rule(NamedTuple):
-     priority: int
-     name: str
-     check: str
-     action: QuantizationAction | PathAction | LoraAction | ConfigAction
-     ...

@@ -1,14 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from typing import TYPE_CHECKING
- from mflux.models.common.config.model_config import ModelConfig
-
- if TYPE_CHECKING: ...
- logger = ...
-
- class ConfigResolution:
-     RULES = ...
-     @staticmethod
-     def resolve(model_name: str, base_model: str | None = ...) -> ModelConfig: ...

@@ -1,21 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from pathlib import Path
-
- logger = ...
-
- class LoraResolution:
-     RULES = ...
-     _registry: dict[str, Path] = ...
-     @staticmethod
-     def resolve(path: str) -> str: ...
-     @staticmethod
-     def resolve_paths(paths: list[str] | None) -> list[str]: ...
-     @staticmethod
-     def resolve_scales(scales: list[float] | None, num_paths: int) -> list[float]: ...
-     @staticmethod
-     def get_registry() -> dict[str, Path]: ...
-     @staticmethod
-     def discover_files(library_paths: list[Path]) -> dict[str, Path]: ...

@@ -1,12 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from pathlib import Path
-
- logger = ...
-
- class PathResolution:
-     RULES = ...
-     @staticmethod
-     def resolve(path: str | None, patterns: list[str] | None = ...) -> Path | None: ...

@@ -1,12 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- logger = ...
-
- class QuantizationResolution:
-     RULES = ...
-     @staticmethod
-     def resolve(
-         stored: int | None, requested: int | None
-     ) -> tuple[int | None, str | None]: ...
@@ -1,26 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from .flow_match_euler_discrete_scheduler import FlowMatchEulerDiscreteScheduler
- from .linear_scheduler import LinearScheduler
- from .seedvr2_euler_scheduler import SeedVR2EulerScheduler
-
- __all__ = [
-     "LinearScheduler",
-     "FlowMatchEulerDiscreteScheduler",
-     "SeedVR2EulerScheduler",
- ]
-
- class SchedulerModuleNotFound(ValueError): ...
- class SchedulerClassNotFound(ValueError): ...
- class InvalidSchedulerType(TypeError): ...
-
- SCHEDULER_REGISTRY = ...
-
- def register_contrib(scheduler_object, scheduler_name=...):  # -> None:
-     ...
- def try_import_external_scheduler(
-     scheduler_object_path: str,
- ):  # -> type[BaseScheduler]:
-     ...

@@ -1,16 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from abc import ABC, abstractmethod
-
- class BaseScheduler(ABC):
-     @property
-     @abstractmethod
-     def sigmas(self) -> mx.array: ...
-     @abstractmethod
-     def step(
-         self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
-     ) -> mx.array: ...
-     def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...

@@ -1,26 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from typing import TYPE_CHECKING
- from mflux.models.common.config.config import Config
- from mflux.models.common.schedulers.base_scheduler import BaseScheduler
-
- if TYPE_CHECKING: ...
-
- class FlowMatchEulerDiscreteScheduler(BaseScheduler):
-     def __init__(self, config: Config) -> None: ...
-     @property
-     def sigmas(self) -> mx.array: ...
-     @property
-     def timesteps(self) -> mx.array: ...
-     def set_image_seq_len(self, image_seq_len: int) -> None: ...
-     @staticmethod
-     def get_timesteps_and_sigmas(
-         image_seq_len: int, num_inference_steps: int, num_train_timesteps: int = ...
-     ) -> tuple[mx.array, mx.array]: ...
-     def step(
-         self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
-     ) -> mx.array: ...
-     def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...

@@ -1,20 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from typing import TYPE_CHECKING
- from mflux.models.common.config.config import Config
- from mflux.models.common.schedulers.base_scheduler import BaseScheduler
-
- if TYPE_CHECKING: ...
-
- class LinearScheduler(BaseScheduler):
-     def __init__(self, config: Config) -> None: ...
-     @property
-     def sigmas(self) -> mx.array: ...
-     @property
-     def timesteps(self) -> mx.array: ...
-     def step(
-         self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
-     ) -> mx.array: ...

@@ -1,20 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from typing import TYPE_CHECKING
- from mflux.models.common.config.config import Config
- from mflux.models.common.schedulers.base_scheduler import BaseScheduler
-
- if TYPE_CHECKING: ...
-
- class SeedVR2EulerScheduler(BaseScheduler):
-     def __init__(self, config: Config) -> None: ...
-     @property
-     def timesteps(self) -> mx.array: ...
-     @property
-     def sigmas(self) -> mx.array: ...
-     def step(
-         self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
-     ) -> mx.array: ...

@@ -1,24 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from mflux.models.common.tokenizer.tokenizer import (
-     BaseTokenizer,
-     LanguageTokenizer,
-     Tokenizer,
-     VisionLanguageTokenizer,
- )
- from mflux.models.common.tokenizer.tokenizer_loader import TokenizerLoader
- from mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput
-
- """
- This type stub file was generated by pyright.
- """
- __all__ = [
-     "Tokenizer",
-     "BaseTokenizer",
-     "LanguageTokenizer",
-     "VisionLanguageTokenizer",
-     "TokenizerLoader",
-     "TokenizerOutput",
- ]
@@ -1,74 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from abc import ABC, abstractmethod
- from typing import Protocol, runtime_checkable
- from PIL import Image
- from transformers import PreTrainedTokenizer
- from mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput
-
- """
- This type stub file was generated by pyright.
- """
-
- @runtime_checkable
- class Tokenizer(Protocol):
-     tokenizer: PreTrainedTokenizer
-     def tokenize(
-         self,
-         prompt: str | list[str],
-         images: list[Image.Image] | None = ...,
-         max_length: int | None = ...,
-         **kwargs,
-     ) -> TokenizerOutput: ...
-
- class BaseTokenizer(ABC):
-     def __init__(
-         self, tokenizer: PreTrainedTokenizer, max_length: int = ...
-     ) -> None: ...
-     @abstractmethod
-     def tokenize(
-         self,
-         prompt: str | list[str],
-         images: list[Image.Image] | None = ...,
-         max_length: int | None = ...,
-         **kwargs,
-     ) -> TokenizerOutput: ...
-
- class LanguageTokenizer(BaseTokenizer):
-     def __init__(
-         self,
-         tokenizer: PreTrainedTokenizer,
-         max_length: int = ...,
-         padding: str = ...,
-         return_attention_mask: bool = ...,
-         template: str | None = ...,
-         use_chat_template: bool = ...,
-         chat_template_kwargs: dict | None = ...,
-         add_special_tokens: bool = ...,
-     ) -> None: ...
-     def tokenize(
-         self,
-         prompt: str | list[str],
-         images: list[Image.Image] | None = ...,
-         max_length: int | None = ...,
-         **kwargs,
-     ) -> TokenizerOutput: ...
-
- class VisionLanguageTokenizer(BaseTokenizer):
-     def __init__(
-         self,
-         tokenizer: PreTrainedTokenizer,
-         processor,
-         max_length: int = ...,
-         template: str | None = ...,
-         image_token: str = ...,
-     ) -> None: ...
-     def tokenize(
-         self,
-         prompt: str | list[str],
-         images: list[Image.Image] | None = ...,
-         max_length: int | None = ...,
-         **kwargs,
-     ) -> TokenizerOutput: ...

@@ -1,22 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from typing import TYPE_CHECKING
- from mflux.models.common.tokenizer.tokenizer import BaseTokenizer
- from mflux.models.common.weights.loading.weight_definition import TokenizerDefinition
-
- """
- This type stub file was generated by pyright.
- """
- if TYPE_CHECKING: ...
-
- class TokenizerLoader:
-     @staticmethod
-     def load(definition: TokenizerDefinition, model_path: str) -> BaseTokenizer: ...
-     @staticmethod
-     def load_all(
-         definitions: list[TokenizerDefinition],
-         model_path: str,
-         max_length_overrides: dict[str, int] | None = ...,
-     ) -> dict[str, BaseTokenizer]: ...

@@ -1,17 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from dataclasses import dataclass
-
- """
- This type stub file was generated by pyright.
- """
-
- @dataclass
- class TokenizerOutput:
-     input_ids: mx.array
-     attention_mask: mx.array
-     pixel_values: mx.array | None = ...
-     image_grid_thw: mx.array | None = ...

@@ -1,8 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from mflux.models.common.vae.tiling_config import TilingConfig
- from mflux.models.common.vae.vae_tiler import VAETiler
-
- __all__ = ["TilingConfig", "VAETiler"]

@@ -1,13 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from dataclasses import dataclass
-
- @dataclass(frozen=True, slots=True)
- class TilingConfig:
-     vae_decode_tiles_per_dim: int | None = ...
-     vae_decode_overlap: int = ...
-     vae_encode_tiled: bool = ...
-     vae_encode_tile_size: int = ...
-     vae_encode_tile_overlap: int = ...

@@ -1,27 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from typing import Callable
-
- class VAETiler:
-     @staticmethod
-     def encode_image_tiled(
-         *,
-         image: mx.array,
-         encode_fn: Callable[[mx.array], mx.array],
-         latent_channels: int,
-         tile_size: tuple[int, int] = ...,
-         tile_overlap: tuple[int, int] = ...,
-         spatial_scale: int = ...,
-     ) -> mx.array: ...
-     @staticmethod
-     def decode_image_tiled(
-         *,
-         latent: mx.array,
-         decode_fn: Callable[[mx.array], mx.array],
-         tile_size: tuple[int, int] = ...,
-         tile_overlap: tuple[int, int] = ...,
-         spatial_scale: int = ...,
-     ) -> mx.array: ...

@@ -1,17 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from mlx import nn
- from mflux.models.common.vae.tiling_config import TilingConfig
-
- class VAEUtil:
-     @staticmethod
-     def encode(
-         vae: nn.Module, image: mx.array, tiling_config: TilingConfig | None = ...
-     ) -> mx.array: ...
-     @staticmethod
-     def decode(
-         vae: nn.Module, latent: mx.array, tiling_config: TilingConfig | None = ...
-     ) -> mx.array: ...
@@ -1,18 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from mflux.models.common.weights.loading.loaded_weights import LoadedWeights, MetaData
- from mflux.models.common.weights.loading.weight_applier import WeightApplier
- from mflux.models.common.weights.loading.weight_definition import ComponentDefinition
- from mflux.models.common.weights.loading.weight_loader import WeightLoader
- from mflux.models.common.weights.saving.model_saver import ModelSaver
-
- __all__ = [
-     "ComponentDefinition",
-     "LoadedWeights",
-     "MetaData",
-     "ModelSaver",
-     "WeightApplier",
-     "WeightLoader",
- ]

@@ -1,18 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from dataclasses import dataclass
-
- @dataclass
- class MetaData:
-     quantization_level: int | None = ...
-     mflux_version: str | None = ...
-
- @dataclass
- class LoadedWeights:
-     components: dict[str, dict]
-     meta_data: MetaData
-     def __getattr__(self, name: str) -> dict | None: ...
-     def num_transformer_blocks(self, component_name: str = ...) -> int: ...
-     def num_single_transformer_blocks(self, component_name: str = ...) -> int: ...

@@ -1,30 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.nn as nn
- from typing import TYPE_CHECKING
- from mflux.models.common.weights.loading.loaded_weights import LoadedWeights
- from mflux.models.common.weights.loading.weight_definition import (
-     ComponentDefinition,
-     WeightDefinitionType,
- )
-
- if TYPE_CHECKING: ...
-
- class WeightApplier:
-     @staticmethod
-     def apply_and_quantize_single(
-         weights: LoadedWeights,
-         model: nn.Module,
-         component: ComponentDefinition,
-         quantize_arg: int | None,
-         quantization_predicate=...,
-     ) -> int | None: ...
-     @staticmethod
-     def apply_and_quantize(
-         weights: LoadedWeights,
-         models: dict[str, nn.Module],
-         quantize_arg: int | None,
-         weight_definition: WeightDefinitionType,
-     ) -> int | None: ...

@@ -1,73 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from dataclasses import dataclass
- from typing import Callable, List, TYPE_CHECKING, TypeAlias
- from mflux.models.common.weights.mapping.weight_mapping import WeightTarget
- from mflux.models.common.tokenizer.tokenizer import BaseTokenizer
- from mflux.models.depth_pro.weights.depth_pro_weight_definition import (
-     DepthProWeightDefinition,
- )
- from mflux.models.fibo.weights.fibo_weight_definition import FIBOWeightDefinition
- from mflux.models.fibo_vlm.weights.fibo_vlm_weight_definition import (
-     FIBOVLMWeightDefinition,
- )
- from mflux.models.flux.weights.flux_weight_definition import FluxWeightDefinition
- from mflux.models.qwen.weights.qwen_weight_definition import QwenWeightDefinition
- from mflux.models.seedvr2.weights.seedvr2_weight_definition import (
-     SeedVR2WeightDefinition,
- )
- from mflux.models.z_image.weights.z_image_weight_definition import (
-     ZImageWeightDefinition,
- )
-
- """
- This type stub file was generated by pyright.
- """
- if TYPE_CHECKING:
-     WeightDefinitionType: TypeAlias = type[
-         FluxWeightDefinition
-         | FIBOWeightDefinition
-         | FIBOVLMWeightDefinition
-         | QwenWeightDefinition
-         | ZImageWeightDefinition
-         | SeedVR2WeightDefinition
-         | DepthProWeightDefinition
-     ]
-
- @dataclass
- class ComponentDefinition:
-     name: str
-     hf_subdir: str
-     mapping_getter: Callable[[], List[WeightTarget]] | None = ...
-     model_attr: str | None = ...
-     num_blocks: int | None = ...
-     num_layers: int | None = ...
-     loading_mode: str = ...
-     precision: mx.Dtype | None = ...
-     skip_quantization: bool = ...
-     bulk_transform: Callable[[mx.array], mx.array] | None = ...
-     weight_subkey: str | None = ...
-     download_url: str | None = ...
-     weight_prefix_filters: List[str] | None = ...
-     weight_files: List[str] | None = ...
-
- @dataclass
- class TokenizerDefinition:
-     name: str
-     hf_subdir: str
-     tokenizer_class: str = ...
-     fallback_subdirs: List[str] | None = ...
-     download_patterns: List[str] | None = ...
-     encoder_class: type[BaseTokenizer] | None = ...
-     max_length: int = ...
-     padding: str = ...
-     template: str | None = ...
-     use_chat_template: bool = ...
-     chat_template_kwargs: dict | None = ...
-     add_special_tokens: bool = ...
-     processor_class: type | None = ...
-     image_token: str = ...
-     chat_template: str | None = ...

@@ -1,23 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from typing import TYPE_CHECKING
- from mflux.models.common.weights.loading.loaded_weights import LoadedWeights
- from mflux.models.common.weights.loading.weight_definition import (
-     ComponentDefinition,
-     WeightDefinitionType,
- )
-
- if TYPE_CHECKING: ...
- logger = ...
-
- class WeightLoader:
-     @staticmethod
-     def load_single(
-         component: ComponentDefinition, repo_id: str, file_pattern: str = ...
-     ) -> LoadedWeights: ...
-     @staticmethod
-     def load(
-         weight_definition: WeightDefinitionType, model_path: str | None = ...
-     ) -> LoadedWeights: ...

@@ -1,16 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from typing import Dict, List, Optional
- from mflux.models.common.weights.mapping.weight_mapping import WeightTarget
-
- class WeightMapper:
-     @staticmethod
-     def apply_mapping(
-         hf_weights: Dict[str, mx.array],
-         mapping: List[WeightTarget],
-         num_blocks: Optional[int] = ...,
-         num_layers: Optional[int] = ...,
-     ) -> Dict: ...

@@ -1,23 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from dataclasses import dataclass
- from typing import Callable, List, Optional, Protocol
-
- """
- This type stub file was generated by pyright.
- """
-
- @dataclass
- class WeightTarget:
-     to_pattern: str
-     from_pattern: List[str]
-     transform: Optional[Callable[[mx.array], mx.array]] = ...
-     required: bool = ...
-     max_blocks: Optional[int] = ...
-
- class WeightMapping(Protocol):
-     @staticmethod
-     def get_mapping() -> List[WeightTarget]: ...
@@ -1,17 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
-
- class WeightTransforms:
-     @staticmethod
-     def reshape_gamma_to_1d(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def transpose_patch_embed(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def transpose_conv3d_weight(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def transpose_conv2d_weight(tensor: mx.array) -> mx.array: ...
-     @staticmethod
-     def transpose_conv_transpose2d_weight(tensor: mx.array) -> mx.array: ...

@@ -1,14 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from typing import Any, TYPE_CHECKING
- from mflux.models.common.weights.loading.weight_definition import WeightDefinitionType
-
- if TYPE_CHECKING: ...
-
- class ModelSaver:
-     @staticmethod
-     def save_model(
-         model: Any, bits: int, base_path: str, weight_definition: WeightDefinitionType
-     ) -> None: ...

@@ -1,9 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- from mflux.models.depth_pro.model.depth_pro_model import DepthProModel
-
- class DepthProInitializer:
-     @staticmethod
-     def init(model: DepthProModel, quantize: int | None = ...) -> None: ...

@@ -1,10 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import mlx.nn as nn
-
- class FeatureFusionBlock2d(nn.Module):
-     def __init__(self, num_features: int, deconv: bool = ...) -> None: ...
-     def __call__(self, x0: mx.array, x1: mx.array | None = ...) -> mx.array: ...

@@ -1,17 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import mlx.nn as nn
-
- class MultiresConvDecoder(nn.Module):
-     def __init__(self) -> None: ...
-     def __call__(
-         self,
-         x0_latent: mx.array,
-         x1_latent: mx.array,
-         x0_features: mx.array,
-         x1_features: mx.array,
-         x_global_features: mx.array,
-     ) -> mx.array: ...

@@ -1,10 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import mlx.nn as nn
-
- class ResidualBlock(nn.Module):
-     def __init__(self, num_features: int) -> None: ...
-     def __call__(self, x: mx.array) -> mx.array: ...

@@ -1,20 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from dataclasses import dataclass
- from pathlib import Path
- from PIL import Image
-
- @dataclass
- class DepthResult:
-     depth_image: Image.Image
-     depth_array: mx.array
-     min_depth: float
-     max_depth: float
-     ...
-
- class DepthPro:
-     def __init__(self, quantize: int | None = ...) -> None: ...
-     def create_depth_map(self, image_path: str | Path) -> DepthResult: ...

@@ -1,12 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import mlx.nn as nn
-
- class DepthProModel(nn.Module):
-     def __init__(self) -> None: ...
-     def __call__(
-         self, x0: mx.array, x1: mx.array, x2: mx.array
-     ) -> tuple[mx.array, mx.array]: ...

@@ -1,15 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import mlx.nn as nn
-
- class DepthProUtil:
-     @staticmethod
-     def split(x: mx.array, overlap_ratio: float = ...) -> mx.array: ...
-     @staticmethod
-     def interpolate(x: mx.array, size=..., scale_factor=...):  # -> array:
-         ...
-     @staticmethod
-     def apply_conv(x: mx.array, conv_module: nn.Module) -> mx.array: ...

@@ -1,12 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- from mlx import nn
-
- class Attention(nn.Module):
-     def __init__(
-         self, dim: int = ..., head_dim: int = ..., num_heads: int = ...
-     ) -> None: ...
-     def __call__(self, x: mx.array) -> mx.array: ...

@@ -1,10 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import mlx.nn as nn
-
- class DinoVisionTransformer(nn.Module):
-     def __init__(self) -> None: ...
-     def __call__(self, x: mx.array) -> tuple[mx.array, mx.array, mx.array]: ...

@@ -1,10 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import mlx.nn as nn
-
- class LayerScale(nn.Module):
-     def __init__(self, dims: int, init_values: float = ...) -> None: ...
-     def __call__(self, x: mx.array) -> mx.array: ...

@@ -1,10 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import mlx.nn as nn
-
- class MLP(nn.Module):
-     def __init__(self) -> None: ...
-     def __call__(self, x: mx.array) -> mx.array: ...

@@ -1,10 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import mlx.nn as nn
-
- class PatchEmbed(nn.Module):
-     def __init__(self) -> None: ...
-     def __call__(self, x: mx.array) -> mx.array: ...

@@ -1,10 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import mlx.nn as nn
-
- class TransformerBlock(nn.Module):
-     def __init__(self) -> None: ...
-     def __call__(self, x: mx.array) -> mx.array: ...

@@ -1,12 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import mlx.nn as nn
-
- class DepthProEncoder(nn.Module):
-     def __init__(self) -> None: ...
-     def __call__(
-         self, x0: mx.array, x1: mx.array, x2: mx.array
-     ) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...

@@ -1,16 +0,0 @@
- """
- This type stub file was generated by pyright.
- """
-
- import mlx.core as mx
- import mlx.nn as nn
-
- class UpSampleBlock(nn.Module):
-     def __init__(
-         self,
-         dim_in: int = ...,
-         dim_int: int = ...,
-         dim_out: int = ...,
-         upsample_layers: int = ...,
-     ) -> None: ...
-     def __call__(self, x: mx.array) -> mx.array: ...
@@ -1,10 +0,0 @@
|
|||||||
"""
|
|
||||||
This type stub file was generated by pyright.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn as nn
|
|
||||||
|
|
||||||
class FOVHead(nn.Module):
|
|
||||||
def __init__(self) -> None: ...
|
|
||||||
def __call__(self, x: mx.array) -> mx.array: ...
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
"""
|
|
||||||
This type stub file was generated by pyright.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List
|
|
||||||
from mflux.models.common.weights.loading.weight_definition import (
|
|
||||||
ComponentDefinition,
|
|
||||||
TokenizerDefinition,
|
|
||||||
)
|
|
||||||
|
|
||||||
"""
|
|
||||||
This type stub file was generated by pyright.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class DepthProWeightDefinition:
|
|
||||||
@staticmethod
|
|
||||||
def get_components() -> List[ComponentDefinition]: ...
|
|
||||||
@staticmethod
|
|
||||||
def get_tokenizers() -> List[TokenizerDefinition]: ...
|
|
||||||
@staticmethod
|
|
||||||
def get_download_patterns() -> List[str]: ...
|
|
||||||
@staticmethod
|
|
||||||
def quantization_predicate(path: str, module) -> bool: ...
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
"""
|
|
||||||
This type stub file was generated by pyright.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List
|
|
||||||
from mflux.models.common.weights.mapping.weight_mapping import (
|
|
||||||
WeightMapping,
|
|
||||||
WeightTarget,
|
|
||||||
)
|
|
||||||
|
|
||||||
class DepthProWeightMapping(WeightMapping):
|
|
||||||
@staticmethod
|
|
||||||
def get_mapping() -> List[WeightTarget]: ...
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
"""
|
|
||||||
This type stub file was generated by pyright.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
|
|
||||||
class FiboLatentCreator:
|
|
||||||
@staticmethod
|
|
||||||
def create_noise(seed: int, height: int, width: int) -> mx.array: ...
|
|
||||||
@staticmethod
|
|
||||||
def pack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...
|
|
||||||
@staticmethod
|
|
||||||
def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...
|
|
||||||
@@ -1,23 +0,0 @@
"""
This type stub file was generated by pyright.
"""

from typing import List
from mflux.models.common.weights.loading.weight_definition import (
    ComponentDefinition,
    TokenizerDefinition,
)

"""
This type stub file was generated by pyright.
"""

class FIBOWeightDefinition:
    @staticmethod
    def get_components() -> List[ComponentDefinition]: ...
    @staticmethod
    def get_tokenizers() -> List[TokenizerDefinition]: ...
    @staticmethod
    def get_download_patterns() -> List[str]: ...
    @staticmethod
    def quantization_predicate(path: str, module) -> bool: ...
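The recurring `quantization_predicate(path, module)` hook decides, per module, whether quantization should apply. A plausible sketch of such a predicate (the divisibility rule and the path filter are assumptions for illustration, not mflux's actual policy):

```python
import mlx.nn as nn

def quantization_predicate(path: str, module) -> bool:
    # Only quantize Linear layers; MLX group quantization needs the input
    # width to divide evenly into groups, and embedding-like modules are
    # kept at full precision here by path name.
    if not isinstance(module, nn.Linear):
        return False
    if "embed" in path:
        return False
    return module.weight.shape[1] % 64 == 0
```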
@@ -1,17 +0,0 @@
"""
This type stub file was generated by pyright.
"""

from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
    WeightMapping,
    WeightTarget,
)

class FIBOWeightMapping(WeightMapping):
    @staticmethod
    def get_transformer_mapping() -> List[WeightTarget]: ...
    @staticmethod
    def get_text_encoder_mapping() -> List[WeightTarget]: ...
    @staticmethod
    def get_vae_mapping() -> List[WeightTarget]: ...
@@ -1,8 +0,0 @@
"""
This type stub file was generated by pyright.
"""

from mflux.models.qwen.tokenizer.qwen_image_processor import QwenImageProcessor

class Qwen2VLImageProcessor(QwenImageProcessor):
    def __init__(self) -> None: ...
@@ -1,28 +0,0 @@
"""
This type stub file was generated by pyright.
"""

from typing import Optional, Union
from PIL import Image

class Qwen2VLProcessor:
    def __init__(self, tokenizer) -> None: ...
    def apply_chat_template(
        self,
        messages,
        tokenize: bool = ...,
        add_generation_prompt: bool = ...,
        return_tensors: Optional[str] = ...,
        return_dict: bool = ...,
        **kwargs,
    ): # -> dict[Any, Any]:
        ...
    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = ...,
        images: Optional[Union[Image.Image, list[Image.Image]]] = ...,
        padding: bool = ...,
        return_tensors: Optional[str] = ...,
        **kwargs,
    ): # -> dict[Any, Any]:
        ...
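The processor stub mirrors the Hugging Face multimodal-processor pattern: `apply_chat_template` renders chat messages into a prompt string, and `__call__` combines text and images into model inputs. A toy stand-in showing the expected call flow (the template format here is invented for illustration; the real class delegates to a tokenizer and image processor):

```python
class ToyChatProcessor:
    def apply_chat_template(self, messages, tokenize=False,
                            add_generation_prompt=True, **kwargs):
        # Flatten role/content messages into a single prompt string.
        parts = [f"<|{m['role']}|>\n{m['content']}" for m in messages]
        if add_generation_prompt:
            parts.append("<|assistant|>\n")
        return "\n".join(parts)

messages = [{"role": "user", "content": "Describe this image."}]
print(ToyChatProcessor().apply_chat_template(messages))
```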
@@ -1,24 +0,0 @@
"""
This type stub file was generated by pyright.
"""

from typing import List
from mflux.models.common.weights.loading.weight_definition import (
    ComponentDefinition,
    TokenizerDefinition,
)

"""
This type stub file was generated by pyright.
"""

QWEN2VL_CHAT_TEMPLATE = ...

class FIBOVLMWeightDefinition:
    @staticmethod
    def get_components() -> List[ComponentDefinition]: ...
    @staticmethod
    def get_tokenizers() -> List[TokenizerDefinition]: ...
    @staticmethod
    def get_download_patterns() -> List[str]: ...
    @staticmethod
    def quantization_predicate(path: str, module) -> bool: ...
@@ -1,15 +0,0 @@
"""
This type stub file was generated by pyright.
"""

from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
    WeightMapping,
    WeightTarget,
)

class FIBOVLMWeightMapping(WeightMapping):
    @staticmethod
    def get_vlm_decoder_mapping(num_layers: int = ...) -> List[WeightTarget]: ...
    @staticmethod
    def get_vlm_visual_mapping(depth: int = ...) -> List[WeightTarget]: ...
@@ -1,3 +0,0 @@
"""
This type stub file was generated by pyright.
"""
@@ -1,3 +0,0 @@
"""
This type stub file was generated by pyright.
"""
@@ -1,53 +0,0 @@
"""
This type stub file was generated by pyright.
"""

from mflux.models.common.config import ModelConfig

class FluxInitializer:
    @staticmethod
    def init(
        model,
        model_config: ModelConfig,
        quantize: int | None,
        model_path: str | None = ...,
        lora_paths: list[str] | None = ...,
        lora_scales: list[float] | None = ...,
        custom_transformer=...,
    ) -> None: ...
    @staticmethod
    def init_depth(
        model,
        model_config: ModelConfig,
        quantize: int | None,
        model_path: str | None = ...,
        lora_paths: list[str] | None = ...,
        lora_scales: list[float] | None = ...,
    ) -> None: ...
    @staticmethod
    def init_redux(
        model,
        model_config: ModelConfig,
        quantize: int | None,
        model_path: str | None = ...,
        lora_paths: list[str] | None = ...,
        lora_scales: list[float] | None = ...,
    ) -> None: ...
    @staticmethod
    def init_controlnet(
        model,
        model_config: ModelConfig,
        quantize: int | None,
        model_path: str | None = ...,
        lora_paths: list[str] | None = ...,
        lora_scales: list[float] | None = ...,
    ) -> None: ...
    @staticmethod
    def init_concept(
        model,
        model_config: ModelConfig,
        quantize: int | None,
        model_path: str | None = ...,
        lora_paths: list[str] | None = ...,
        lora_scales: list[float] | None = ...,
    ) -> None: ...
@@ -1,7 +0,0 @@
"""
This type stub file was generated by pyright.
"""

"""
This type stub file was generated by pyright.
"""
@@ -1,19 +0,0 @@
"""
This type stub file was generated by pyright.
"""

import mlx.core as mx

"""
This type stub file was generated by pyright.
"""

class FluxLatentCreator:
    @staticmethod
    def create_noise(seed: int, height: int, width: int) -> mx.array: ...
    @staticmethod
    def pack_latents(
        latents: mx.array, height: int, width: int, num_channels_latents: int = ...
    ) -> mx.array: ...
    @staticmethod
    def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...
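`unpack_latents` is the inverse of the packing sketched earlier for `FiboLatentCreator`: tokens are unfolded back into a (B, C, H, W) latent grid before VAE decoding. A sketch under the same layout assumptions (8x downsample, 2x2 patches):

```python
import mlx.core as mx

def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array:
    # (B, H/2 * W/2, 4C) -> (B, C, H, W) at latent resolution.
    b, _, cd = latents.shape
    h, w = height // 8, width // 8
    c = cd // 4
    x = latents.reshape(b, h // 2, w // 2, c, 2, 2)
    x = x.transpose(0, 3, 1, 4, 2, 5)  # (B, C, H/2, 2, W/2, 2)
    return x.reshape(b, c, h, w)
```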
@@ -1,7 +0,0 @@
"""
This type stub file was generated by pyright.
"""

"""
This type stub file was generated by pyright.
"""
@@ -1,10 +0,0 @@
"""
This type stub file was generated by pyright.
"""

import mlx.core as mx
from mlx import nn

class CLIPEmbeddings(nn.Module):
    def __init__(self, dims: int) -> None: ...
    def __call__(self, tokens: mx.array) -> mx.array: ...
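CLIP text embeddings are conventionally a token lookup plus learned positional embeddings. A sketch of what the stub above likely wraps, using CLIP-L defaults (vocabulary 49408, context length 77) as assumptions:

```python
import mlx.core as mx
from mlx import nn

class CLIPEmbeddingsSketch(nn.Module):
    def __init__(self, dims: int, vocab_size: int = 49408,
                 max_positions: int = 77) -> None:
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dims)
        self.position_embedding = nn.Embedding(max_positions, dims)

    def __call__(self, tokens: mx.array) -> mx.array:
        # Add a learned position vector to every token embedding.
        positions = mx.arange(tokens.shape[1])
        return self.token_embedding(tokens) + self.position_embedding(positions)
```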
@@ -1,14 +0,0 @@
"""
This type stub file was generated by pyright.
"""

import mlx.core as mx
from mlx import nn

"""
This type stub file was generated by pyright.
"""

class CLIPEncoder(nn.Module):
    def __init__(self) -> None: ...
    def __call__(self, tokens: mx.array) -> mx.array: ...
@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""

import mlx.core as mx
from mlx import nn

class CLIPEncoderLayer(nn.Module):
    def __init__(self, layer: int) -> None: ...
    def __call__(
        self, hidden_states: mx.array, causal_attention_mask: mx.array
    ) -> mx.array: ...
@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""

import mlx.core as mx
from mlx import nn

class CLIPMLP(nn.Module):
    def __init__(self) -> None: ...
    def __call__(self, hidden_states: mx.array) -> mx.array: ...
    @staticmethod
    def quick_gelu(input_array: mx.array) -> mx.array: ...
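`quick_gelu` is the activation the original CLIP used: a sigmoid-based approximation of GELU. The formula below is standard; only the MLX phrasing is this sketch's:

```python
import mlx.core as mx

def quick_gelu(input_array: mx.array) -> mx.array:
    # QuickGELU: x * sigmoid(1.702 * x), a cheap approximation of exact GELU.
    return input_array * mx.sigmoid(1.702 * input_array)
```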
@@ -1,18 +0,0 @@
"""
This type stub file was generated by pyright.
"""

import mlx.core as mx
from mlx import nn

class CLIPSdpaAttention(nn.Module):
    head_dimension = ...
    batch_size = ...
    num_heads = ...
    def __init__(self) -> None: ...
    def __call__(
        self, hidden_states: mx.array, causal_attention_mask: mx.array
    ) -> mx.array: ...
    @staticmethod
    def reshape_and_transpose(x, batch_size, num_heads, head_dim): # -> array:
        ...
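`reshape_and_transpose` is almost certainly the standard multi-head split applied before scaled-dot-product attention: fold the head count out of the hidden dimension, then move the head axis ahead of the sequence axis. A sketch of that transform (the exact layout in mflux is not verified here):

```python
import mlx.core as mx

def reshape_and_transpose(x: mx.array, batch_size: int, num_heads: int,
                          head_dim: int) -> mx.array:
    # (B, L, H * D) -> (B, H, L, D) so attention runs per head.
    seq_len = x.shape[1]
    x = x.reshape(batch_size, seq_len, num_heads, head_dim)
    return x.transpose(0, 2, 1, 3)
```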
@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""

import mlx.core as mx
from mlx import nn

class CLIPTextModel(nn.Module):
    def __init__(self, dims: int, num_encoder_layers: int) -> None: ...
    def __call__(self, tokens: mx.array) -> tuple[mx.array, mx.array]: ...
    @staticmethod
    def create_causal_attention_mask(input_shape: tuple) -> mx.array: ...
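The causal attention mask for a CLIP text encoder is the usual upper-triangular -inf matrix, so token i can attend only to positions <= i. A minimal MLX sketch of that construction:

```python
import mlx.core as mx

def create_causal_attention_mask(input_shape: tuple) -> mx.array:
    # -inf above the diagonal, 0 elsewhere; added to the attention logits.
    _, seq_len = input_shape
    return mx.triu(mx.full((seq_len, seq_len), float("-inf")), k=1)
```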
@@ -1,12 +0,0 @@
"""
This type stub file was generated by pyright.
"""

import mlx.core as mx
from mlx import nn

class EncoderCLIP(nn.Module):
    def __init__(self, num_encoder_layers: int) -> None: ...
    def __call__(
        self, tokens: mx.array, causal_attention_mask: mx.array
    ) -> mx.array: ...
Some files were not shown because too many files have changed in this diff.