mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-07 13:39:19 -05:00
Compare commits
2 Commits
v1.0.61-al
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
077b1bc732 | ||
|
|
4963c33162 |
159
.github/benchmark-dashboard/README.md
vendored
159
.github/benchmark-dashboard/README.md
vendored
@@ -1,159 +0,0 @@
|
||||
# EXO Benchmark Dashboard
|
||||
|
||||
A fully self-contained, browser-based dashboard for tracking EXO benchmark performance over time.
|
||||
|
||||
## Features
|
||||
|
||||
- 📊 **Success Rate Tracking**: Monitor cluster reliability across commits
|
||||
- ⚡ **Response Time Analysis**: Track average request completion times
|
||||
- 🎯 **Throughput Metrics**: Tokens per second visualization
|
||||
- 📈 **Request Distribution**: Success/failure breakdown over time
|
||||
- 🔄 **Auto-Refresh**: Updates every 60 seconds
|
||||
- 📺 **TV-Ready**: Large, clear visualizations perfect for display
|
||||
- 🔐 **Secure**: Credentials stored in browser localStorage only
|
||||
- 🌐 **No Backend**: Directly accesses S3 from the browser
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Option 1: Direct File Access (Simplest)
|
||||
|
||||
Just open the HTML file directly in your browser:
|
||||
|
||||
```bash
|
||||
open .github/benchmark-dashboard/index.html
|
||||
```
|
||||
|
||||
Then click "Configure AWS Credentials" and enter your keys.
|
||||
|
||||
### Option 2: URL Parameters (For Quick Setup)
|
||||
|
||||
```bash
|
||||
# Serve with credentials in URL (they'll be moved to localStorage)
|
||||
open ".github/benchmark-dashboard/index.html?accessKey=YOUR_KEY&secretKey=YOUR_SECRET®ion=us-east-1"
|
||||
```
|
||||
|
||||
The credentials will be saved to localStorage and removed from the URL immediately.
|
||||
|
||||
### Option 3: Simple HTTP Server
|
||||
|
||||
```bash
|
||||
# From repo root
|
||||
python3 -m http.server 8080
|
||||
|
||||
# Then open: http://localhost:8080/.github/benchmark-dashboard/
|
||||
```
|
||||
|
||||
## AWS Credentials
|
||||
|
||||
The dashboard needs read-only access to the `exo-benchmark-results` S3 bucket.
|
||||
|
||||
### Required IAM Permissions
|
||||
|
||||
```json
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"s3:GetObject",
|
||||
"s3:ListBucket"
|
||||
],
|
||||
"Resource": [
|
||||
"arn:aws:s3:::exo-benchmark-results",
|
||||
"arn:aws:s3:::exo-benchmark-results/*"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Security Notes
|
||||
|
||||
- ✅ Credentials stored in browser `localStorage` only
|
||||
- ✅ Never sent to any server (except AWS)
|
||||
- ✅ All S3 access happens client-side
|
||||
- ✅ Use read-only IAM credentials
|
||||
- ⚠️ Don't commit credentials to git
|
||||
- ⚠️ Use a dedicated read-only IAM user
|
||||
|
||||
## TV/Kiosk Mode
|
||||
|
||||
For permanent display on a TV:
|
||||
|
||||
### macOS
|
||||
```bash
|
||||
open -a "Google Chrome" --args --kiosk ".github/benchmark-dashboard/index.html"
|
||||
```
|
||||
|
||||
### Linux
|
||||
```bash
|
||||
chromium-browser --kiosk --app="file://$(pwd)/.github/benchmark-dashboard/index.html"
|
||||
```
|
||||
|
||||
### Auto-start on Boot
|
||||
|
||||
Create a simple startup script:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# /usr/local/bin/start-benchmark-dashboard.sh
|
||||
|
||||
cd /path/to/exo
|
||||
python3 -m http.server 8080 &
|
||||
sleep 2
|
||||
chromium-browser --kiosk http://localhost:8080/.github/benchmark-dashboard/
|
||||
```
|
||||
|
||||
## Data Displayed
|
||||
|
||||
### Summary Cards
|
||||
- **Latest Success Rate**: Most recent benchmark success percentage with trend
|
||||
- **Avg Response Time**: Latest average response time in ms with trend
|
||||
- **Total Benchmarks**: Count of all benchmarks run
|
||||
- **Active Configurations**: Number of unique benchmark configs
|
||||
|
||||
### Charts
|
||||
1. **Success Rate Over Time**: Line chart showing reliability trends
|
||||
2. **Average Response Time**: Performance over time (lower is better)
|
||||
3. **Throughput**: Tokens/second metric (higher is better)
|
||||
4. **Request Distribution**: Stacked bar chart of successes/failures
|
||||
|
||||
## How It Works
|
||||
|
||||
1. **Loads AWS SDK**: Uses AWS SDK for JavaScript (browser version)
|
||||
2. **Lists S3 Objects**: Fetches all files from `s3://exo-benchmark-results/bench/`
|
||||
3. **Downloads Results**: Fetches each JSON result file
|
||||
4. **Parses & Visualizes**: Uses Chart.js to create interactive charts
|
||||
5. **Auto-Refreshes**: Polls S3 every 60 seconds for new results
|
||||
|
||||
## Customization
|
||||
|
||||
To modify the dashboard:
|
||||
|
||||
1. Edit `index.html`
|
||||
2. Adjust `REFRESH_INTERVAL` for different polling frequency
|
||||
3. Modify chart colors/styles in the Chart.js configuration
|
||||
4. Add new metrics by extending the results parsing
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**"AWS credentials not configured"**
|
||||
- Click "Configure AWS Credentials" and enter your keys
|
||||
|
||||
**"Error loading benchmark data"**
|
||||
- Check AWS credentials are correct
|
||||
- Verify S3 bucket name is `exo-benchmark-results`
|
||||
- Ensure IAM user has read permissions
|
||||
- Check browser console for detailed errors
|
||||
|
||||
**"No benchmark results found"**
|
||||
- Wait for benchmark workflows to run
|
||||
- Verify results are being uploaded to S3
|
||||
- Check S3 bucket has files in `bench/` prefix
|
||||
|
||||
**Charts not updating**
|
||||
- Check browser console for errors
|
||||
- Verify network connectivity to S3
|
||||
- Try refreshing the page manually
|
||||
|
||||
1641
.github/benchmark-dashboard/index.html
vendored
1641
.github/benchmark-dashboard/index.html
vendored
File diff suppressed because it is too large
Load Diff
186
.github/configs/README.md
vendored
186
.github/configs/README.md
vendored
@@ -1,186 +0,0 @@
|
||||
# EXO Benchmark Configurations
|
||||
|
||||
This directory contains configuration files for the EXO staged benchmark system.
|
||||
|
||||
## Overview
|
||||
|
||||
The staged benchmark system allows you to run complex, multi-stage load tests against EXO clusters. Each stage can have different characteristics:
|
||||
|
||||
- **Prompt Length**: Number of tokens in the input prompt
|
||||
- **Generation Length**: Maximum tokens to generate in the response
|
||||
- **Time Between Requests**: Delay (in seconds) between firing consecutive requests
|
||||
- **Iterations**: Number of requests to send in this stage
|
||||
|
||||
Requests are **fire-and-forget** - they don't wait for the previous request to complete. This allows you to test overlapping request handling and measure success rates under load.
|
||||
|
||||
## Configuration Files
|
||||
|
||||
### `bench_simple.yaml`
|
||||
A minimal configuration that replicates the behavior of the original `bench.py` script:
|
||||
- Single stage with 1 iteration
|
||||
- Short prompt (~20 tokens)
|
||||
- Generates up to 100 tokens
|
||||
|
||||
This is useful for quick smoke tests.
|
||||
|
||||
### `bench_config.yaml`
|
||||
A comprehensive multi-stage benchmark with:
|
||||
1. **Warmup** (10 requests): Light load with short prompts
|
||||
2. **Medium Load** (20 requests): Moderate load with medium prompts
|
||||
3. **Stress Test** (30 requests): Heavy overlapping requests with long prompts
|
||||
4. **Cooldown** (5 requests): Light load to wind down
|
||||
|
||||
This tests the cluster's behavior under varying load patterns.
|
||||
|
||||
## Configuration Schema
|
||||
|
||||
```yaml
|
||||
# Hardware configuration - maps runner labels to instance counts
|
||||
hardware_plan:
|
||||
M3ULTRA_GPU80_512GB: 4
|
||||
|
||||
# Environment variables to set on each node (optional)
|
||||
environment:
|
||||
OVERRIDE_MEMORY_MB: 512
|
||||
|
||||
# Timeout for instance and runner readiness (seconds)
|
||||
timeout_seconds: 600
|
||||
|
||||
# Model instances to run concurrently
|
||||
model_ids:
|
||||
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
|
||||
# Benchmark stages
|
||||
stages:
|
||||
- name: "stage_name" # Human-readable name for this stage
|
||||
prompt_length: 100 # Target prompt length in tokens
|
||||
generation_length: 200 # Max tokens to generate
|
||||
time_between_requests: 2.0 # Seconds between firing requests
|
||||
iterations: 10 # Number of requests in this stage
|
||||
```
|
||||
|
||||
## Running Benchmarks
|
||||
|
||||
### Via GitHub Actions
|
||||
|
||||
**Automatic (every commit):**
|
||||
- The **`bench`** workflow runs automatically on every push
|
||||
- Uses `bench_simple.yaml` as the default configuration
|
||||
- All settings (hardware plan, timeout, environment variables, models, stages) are defined in the config file
|
||||
|
||||
**Manual (on-demand):**
|
||||
1. Go to **Actions** → **bench** workflow
|
||||
2. Click **Run workflow**
|
||||
3. Configure:
|
||||
- **Config File**: Path to your YAML config (default: `.github/configs/bench_simple.yaml`)
|
||||
- `.github/configs/bench_simple.yaml` for quick tests
|
||||
- `.github/configs/bench_config.yaml` for complex multi-stage tests
|
||||
|
||||
All other settings (hardware plan, timeout, environment variables, models, stages) are read from the specified config file.
|
||||
|
||||
### Via Command Line
|
||||
|
||||
```bash
|
||||
# Start EXO on localhost:8000
|
||||
uv run exo --api-port 8000
|
||||
|
||||
# Run simple benchmark (1 stage, 1 iteration)
|
||||
python3 .github/scripts/bench.py \
|
||||
--api-port 8000 \
|
||||
--config .github/configs/bench_simple.yaml \
|
||||
--expected-nodes 1 \
|
||||
--is-primary true \
|
||||
--timeout-seconds 600
|
||||
|
||||
# Run complex staged benchmark (4 stages, multiple iterations)
|
||||
python3 .github/scripts/bench.py \
|
||||
--api-port 8000 \
|
||||
--config .github/configs/bench_config.yaml \
|
||||
--expected-nodes 1 \
|
||||
--is-primary true \
|
||||
--timeout-seconds 600
|
||||
```
|
||||
|
||||
## Output Metrics
|
||||
|
||||
For each stage, the benchmark reports:
|
||||
|
||||
- **Total Requests**: Number of requests fired
|
||||
- **Successful Requests**: Requests that completed successfully
|
||||
- **Failed Requests**: Requests that encountered errors
|
||||
- **Success Rate**: Percentage of successful requests
|
||||
- **Total Tokens**: Sum of all tokens generated across successful requests
|
||||
- **Avg Tokens/Request**: Average tokens per successful request
|
||||
- **Avg Time/Request**: Average completion time per successful request
|
||||
|
||||
A JSON summary is also printed for easy parsing and storage.
|
||||
|
||||
## Creating Custom Benchmarks
|
||||
|
||||
To create a custom benchmark:
|
||||
|
||||
1. Copy an existing config file (e.g., `bench_config.yaml`)
|
||||
2. Modify the stages to match your test scenario
|
||||
3. Save it in this directory with a descriptive name
|
||||
4. Run it using the workflow or command line
|
||||
|
||||
### Example: Sustained Load Test
|
||||
|
||||
```yaml
|
||||
hardware_plan:
|
||||
M3ULTRA_GPU80_512GB: 2
|
||||
|
||||
environment:
|
||||
OVERRIDE_MEMORY_MB: 1024
|
||||
|
||||
timeout_seconds: 600
|
||||
|
||||
model_ids:
|
||||
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
|
||||
stages:
|
||||
- name: "sustained_load"
|
||||
prompt_length: 200
|
||||
generation_length: 150
|
||||
time_between_requests: 0.5 # Very fast - 2 requests/second
|
||||
iterations: 100 # Run for ~50 seconds
|
||||
```
|
||||
|
||||
### Example: Varying Prompt Sizes
|
||||
|
||||
```yaml
|
||||
hardware_plan:
|
||||
M4PRO_GPU16_24GB: 3
|
||||
|
||||
timeout_seconds: 900
|
||||
|
||||
model_ids:
|
||||
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
|
||||
stages:
|
||||
- name: "tiny_prompts"
|
||||
prompt_length: 10
|
||||
generation_length: 100
|
||||
time_between_requests: 1.0
|
||||
iterations: 10
|
||||
|
||||
- name: "medium_prompts"
|
||||
prompt_length: 200
|
||||
generation_length: 100
|
||||
time_between_requests: 1.0
|
||||
iterations: 10
|
||||
|
||||
- name: "large_prompts"
|
||||
prompt_length: 1000
|
||||
generation_length: 100
|
||||
time_between_requests: 1.0
|
||||
iterations: 10
|
||||
```
|
||||
|
||||
## Tips
|
||||
|
||||
- **Overlapping Requests**: Set `time_between_requests` < expected completion time to test concurrent request handling
|
||||
- **Sequential Requests**: Set `time_between_requests` > expected completion time to ensure requests don't overlap
|
||||
- **Realistic Load**: Model real usage patterns by varying prompt/generation lengths across stages
|
||||
- **Success Rate**: A 100% success rate indicates the cluster handled the load well; lower rates suggest capacity limits
|
||||
|
||||
49
.github/configs/bench_config.yaml
vendored
49
.github/configs/bench_config.yaml
vendored
@@ -1,49 +0,0 @@
|
||||
# EXO Staged Benchmark Configuration
|
||||
# This configuration defines a multi-stage load test for EXO clusters
|
||||
|
||||
# Hardware configuration - maps runner labels to instance counts
|
||||
hardware_plan:
|
||||
M3ULTRA_GPU80_512GB: 4
|
||||
|
||||
# Environment variables to set on each node (optional)
|
||||
environment:
|
||||
OVERRIDE_MEMORY_MB: 512
|
||||
|
||||
# Timeout for instance and runner readiness (seconds)
|
||||
timeout_seconds: 600
|
||||
|
||||
# Multiple instances run concurrently on the cluster
|
||||
model_ids:
|
||||
- "mlx-community/Qwen3-0.6B-4bit"
|
||||
- "mlx-community/Qwen3-0.6B-4bit"
|
||||
|
||||
# Stages run sequentially, each with its own characteristics
|
||||
stages:
|
||||
# Stage 1: Light load with short prompts
|
||||
- name: "warmup"
|
||||
prompt_length: 50 # Number of tokens in prompt
|
||||
generation_length: 100 # Max tokens to generate
|
||||
time_between_requests: 5.0 # Seconds between firing requests
|
||||
iterations: 10 # Number of requests to send in this stage
|
||||
|
||||
# Stage 2: Medium load with medium prompts
|
||||
- name: "medium_load"
|
||||
prompt_length: 200
|
||||
generation_length: 150
|
||||
time_between_requests: 3.0
|
||||
iterations: 20
|
||||
|
||||
# Stage 3: Heavy load with long prompts - requests will overlap
|
||||
- name: "stress_test"
|
||||
prompt_length: 500
|
||||
generation_length: 200
|
||||
time_between_requests: 1.0 # Fast firing - will definitely overlap
|
||||
iterations: 30
|
||||
|
||||
# Stage 4: Cool down with simple prompts
|
||||
- name: "cooldown"
|
||||
prompt_length: 50
|
||||
generation_length: 50
|
||||
time_between_requests: 10.0
|
||||
iterations: 5
|
||||
|
||||
125
.github/configs/bench_simple.yaml
vendored
125
.github/configs/bench_simple.yaml
vendored
@@ -1,125 +0,0 @@
|
||||
# Simple single-shot benchmark
|
||||
# Tests 2 instances concurrently on 2 nodes
|
||||
|
||||
# Hardware configuration - maps runner labels to instance counts
|
||||
hardware_plan:
|
||||
puffin4: 1
|
||||
puffin8: 1
|
||||
|
||||
# Environment variables to set on each node
|
||||
environment:
|
||||
PLACEHOLDER: "placeholder"
|
||||
# OVERRIDE_MEMORY_MB: 50000
|
||||
MLX_METAL_FAST_SYNCH: 1
|
||||
|
||||
# Timeout for instance and runner readiness (seconds)
|
||||
timeout_seconds: 1800
|
||||
|
||||
# Model instances to run concurrently
|
||||
model_ids:
|
||||
# - "mlx-community/DeepSeek-V3.1-8bit"
|
||||
# - "mlx-community/Kimi-K2-Instruct-4bit"
|
||||
- "mlx-community/Kimi-K2-Thinking"
|
||||
# - "mlx-community/Qwen3-235B-A22B-4bit"
|
||||
# - "mlx-community/Llama-3.3-70B-Instruct-4bit"
|
||||
# - "mlx-community/Llama-3.3-70B-Instruct-8bit"
|
||||
# - "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
|
||||
# Sharding strategy: "Pipeline" or "Tensor"
|
||||
sharding: "Tensor"
|
||||
|
||||
# Instance type: "MlxRing" or "MlxIbv"
|
||||
instance_meta: "MlxIbv"
|
||||
|
||||
# If true, run requests sequentially (no overlap); if false, fire-and-forget (default: false)
|
||||
no_overlap: true
|
||||
|
||||
# Benchmark stages
|
||||
# pp: 64, 256, 1024, 2048, 4096, 8192, 16384
|
||||
# g: 64, 512
|
||||
stages:
|
||||
# - name: "simple"
|
||||
# prompt_length: 512
|
||||
# generation_length: 10
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp64_g64"
|
||||
# prompt_length: 64
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp64_g64"
|
||||
# prompt_length: 64
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp64_g512"
|
||||
# prompt_length: 64
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
# - name: "pp256_g64"
|
||||
# prompt_length: 256
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
- name: "pp256_g64"
|
||||
prompt_length: 256
|
||||
generation_length: 64
|
||||
time_between_requests: 2.0
|
||||
iterations: 5
|
||||
# - name: "pp256_g512"
|
||||
# prompt_length: 256
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
# - name: "pp1024_g64"
|
||||
# prompt_length: 1024
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp1024_g512"
|
||||
# prompt_length: 1024
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
# - name: "pp2048_g64"
|
||||
# prompt_length: 2048
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp2048_g512"
|
||||
# prompt_length: 2048
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
# - name: "pp4096_g64"
|
||||
# prompt_length: 4096
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 4
|
||||
# - name: "pp4096_g512"
|
||||
# prompt_length: 4096
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
# - name: "pp8192_g64"
|
||||
# prompt_length: 8192
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp8192_g512"
|
||||
# prompt_length: 8192
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp16384_g64"
|
||||
# prompt_length: 16384
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
# - name: "pp16384_g512"
|
||||
# prompt_length: 16384
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
1399
.github/scripts/bench.py
vendored
1399
.github/scripts/bench.py
vendored
File diff suppressed because it is too large
Load Diff
70
.github/scripts/build_matrix.py
vendored
70
.github/scripts/build_matrix.py
vendored
@@ -1,70 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import os
|
||||
from typing import NotRequired, TypedDict, cast
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
class MatrixEntry(TypedDict):
|
||||
label: str
|
||||
index: int
|
||||
|
||||
|
||||
class MatrixInclude(TypedDict):
|
||||
label: str
|
||||
index: int
|
||||
is_primary: bool
|
||||
expected_nodes: int
|
||||
|
||||
|
||||
class Config(TypedDict):
|
||||
hardware_plan: dict[str, int]
|
||||
timeout_seconds: NotRequired[int]
|
||||
environment: NotRequired[dict[str, str]]
|
||||
|
||||
|
||||
# Read the config file
|
||||
config_file: str = os.environ["CONFIG_FILE"]
|
||||
with open(config_file, "r") as f:
|
||||
config: Config = cast(Config, yaml.safe_load(f))
|
||||
|
||||
# Extract hardware plan from config
|
||||
plan: dict[str, int] = config["hardware_plan"]
|
||||
if not plan:
|
||||
raise ValueError(f"No hardware_plan found in {config_file}")
|
||||
|
||||
# Build matrix entries
|
||||
entries: list[MatrixEntry] = []
|
||||
for label, count in plan.items():
|
||||
for idx in range(count):
|
||||
entries.append({"label": label, "index": idx})
|
||||
|
||||
total_nodes: int = len(entries)
|
||||
matrix: dict[str, list[MatrixInclude]] = {
|
||||
"include": [
|
||||
{
|
||||
"label": e["label"],
|
||||
"index": e["index"],
|
||||
"is_primary": (i == 0),
|
||||
"expected_nodes": total_nodes,
|
||||
}
|
||||
for i, e in enumerate(entries)
|
||||
]
|
||||
}
|
||||
|
||||
# Extract other config values
|
||||
timeout_seconds: int = config.get("timeout_seconds", 600)
|
||||
environment: dict[str, str] = config.get("environment", {})
|
||||
|
||||
# Output to GitHub Actions
|
||||
with open(os.environ["GITHUB_OUTPUT"], "a") as f:
|
||||
f.write(f"matrix={json.dumps(matrix)}\n")
|
||||
f.write(f"config_file={config_file}\n")
|
||||
f.write(f"timeout_seconds={timeout_seconds}\n")
|
||||
f.write(f"environment={json.dumps(environment)}\n")
|
||||
|
||||
print(f"Matrix: {json.dumps(matrix)}")
|
||||
print(f"Config file: {config_file}")
|
||||
print(f"Timeout: {timeout_seconds}")
|
||||
print(f"Environment: {json.dumps(environment)}")
|
||||
156
.github/workflows/BENCH_USAGE.md
vendored
156
.github/workflows/BENCH_USAGE.md
vendored
@@ -1,156 +0,0 @@
|
||||
# Benchmark Workflow Usage
|
||||
|
||||
## Overview
|
||||
|
||||
The `bench_matrix.yml` workflow enables distributed benchmarking of models across multiple self-hosted macOS runners with different hardware configurations.
|
||||
|
||||
## Workflow Inputs
|
||||
|
||||
| Input | Description | Default | Required |
|
||||
|-------|-------------|---------|----------|
|
||||
| `model_id` | Model ID to benchmark | `mlx-community/Llama-3.2-1B-Instruct-4bit` | Yes |
|
||||
| `hardware_plan` | JSON mapping of runner labels to counts | `{"M4PRO_GPU16_24GB": 1}` | Yes |
|
||||
| `prompt` | Benchmark prompt text | `What is the capital of France?` | No |
|
||||
| `timeout_seconds` | Timeout for instance/runner readiness | `600` | No |
|
||||
|
||||
## Hardware Plan Format
|
||||
|
||||
The `hardware_plan` input is a JSON object mapping runner labels to the number of machines:
|
||||
|
||||
```json
|
||||
{
|
||||
"M4PRO_GPU16_24GB": 2,
|
||||
"M3ULTRA_GPU80_512GB": 1
|
||||
}
|
||||
```
|
||||
|
||||
This example would:
|
||||
- Start 2 runners with the `M4PRO_GPU16_24GB` label
|
||||
- Start 1 runner with the `M3ULTRA_GPU80_512GB` label
|
||||
- Total of 3 runners coordinating on a single distributed inference instance
|
||||
|
||||
## How It Works
|
||||
|
||||
1. **Planning Job** (`plan`)
|
||||
- Runs on `ubuntu-latest`
|
||||
- Parses the `hardware_plan` JSON
|
||||
- Generates a dynamic matrix with one entry per runner
|
||||
- Only the first runner (index 0) is marked as `is_primary`
|
||||
|
||||
2. **Benchmark Worker Jobs** (`bench_worker`)
|
||||
- Each job runs on a self-hosted macOS runner with the specified label
|
||||
- All runners start EXO in parallel
|
||||
- The primary runner creates the model instance
|
||||
- All runners wait for their assigned runner to be ready (Loaded/Running status)
|
||||
- The primary runner executes the benchmark and prints results
|
||||
- The primary runner deletes the instance
|
||||
|
||||
## Example Usage
|
||||
|
||||
### Single Machine Benchmark
|
||||
|
||||
```yaml
|
||||
model_id: mlx-community/Llama-3.2-1B-Instruct-4bit
|
||||
hardware_plan: '{"M4PRO_GPU16_24GB": 1}'
|
||||
prompt: What is the capital of France?
|
||||
timeout_seconds: 600
|
||||
```
|
||||
|
||||
### Multi-Machine Distributed Benchmark
|
||||
|
||||
```yaml
|
||||
model_id: mlx-community/Llama-3.2-3B-Instruct-4bit
|
||||
hardware_plan: '{"M4PRO_GPU16_24GB": 2, "M3ULTRA_GPU80_512GB": 1}'
|
||||
prompt: Explain quantum computing in simple terms.
|
||||
timeout_seconds: 900
|
||||
```
|
||||
|
||||
## Benchmark Output
|
||||
|
||||
The primary runner outputs a JSON object with benchmark results:
|
||||
|
||||
```json
|
||||
{
|
||||
"model_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",
|
||||
"instance_id": "abc-123-def",
|
||||
"tokens": 42,
|
||||
"elapsed_s": 2.451,
|
||||
"tps": 17.136
|
||||
}
|
||||
```
|
||||
|
||||
Where:
|
||||
- `tokens`: Number of chunks/tokens generated
|
||||
- `elapsed_s`: Total elapsed time in seconds
|
||||
- `tps`: Tokens per second (tokens / elapsed_s)
|
||||
|
||||
## Runner Requirements
|
||||
|
||||
Each self-hosted runner must:
|
||||
- Be labeled with appropriate hardware tags (e.g., `M4PRO_GPU16_24GB`)
|
||||
- Have the `self-hosted` and `macOS` labels
|
||||
- Have Nix installed with flakes enabled
|
||||
- Have network connectivity to other runners in the same job
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ GitHub Actions Workflow (bench_matrix.yml) │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌────────────────┐ │
|
||||
│ │ Plan Job │ │
|
||||
│ │ (ubuntu) │──┬─► Matrix: [{label, index, primary}] │
|
||||
│ └────────────────┘ │ │
|
||||
│ │ │
|
||||
│ ┌───────────────────▼──────────────────────────────────┐ │
|
||||
│ │ Bench Worker Jobs (Matrix) │ │
|
||||
│ ├──────────────────────────────────────────────────────┤ │
|
||||
│ │ │ │
|
||||
│ │ Runner 0 (Primary) Runner 1 Runner 2 │ │
|
||||
│ │ ┌─────────────┐ ┌─────────────┐ ┌──────────┐ │ │
|
||||
│ │ │ Start EXO │ │ Start EXO │ │ Start EXO│ │ │
|
||||
│ │ │ Create Inst │ │ Wait... │ │ Wait... │ │ │
|
||||
│ │ │ Wait Ready │ │ Wait Ready │ │ Wait... │ │ │
|
||||
│ │ │ Run Bench │ │ (idle) │ │ (idle) │ │ │
|
||||
│ │ │ Print TPS │ │ │ │ │ │ │
|
||||
│ │ │ Delete Inst │ │ │ │ │ │ │
|
||||
│ │ └─────────────┘ └─────────────┘ └──────────┘ │ │
|
||||
│ └───────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### `scripts/bench.py`
|
||||
|
||||
A standalone Python script that:
|
||||
- Creates instance (primary only)
|
||||
- Polls `/state` endpoint until instance and all runners are ready
|
||||
- Executes chat completion with timing (primary only)
|
||||
- Parses SSE stream and counts tokens
|
||||
- Computes TPS metrics
|
||||
- Cleans up instance (primary only)
|
||||
|
||||
### Key Functions
|
||||
|
||||
- `wait_for_instance()`: Polls until instance with model_id appears
|
||||
- `wait_for_runners_ready()`: Polls until expected number of runners reach Loaded/Running status
|
||||
- `run_benchmark()`: Executes chat completion, measures time, counts tokens
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Instance never becomes ready
|
||||
- Check EXO logs in the workflow output
|
||||
- Verify model_id is valid and accessible
|
||||
- Increase `timeout_seconds`
|
||||
|
||||
### Runner mismatch
|
||||
- Ensure hardware_plan counts match available labeled runners
|
||||
- Check runner labels match exactly (case-sensitive)
|
||||
|
||||
### Network issues
|
||||
- Verify runners can communicate on the network
|
||||
- Check firewall rules between runner hosts
|
||||
|
||||
305
.github/workflows/bench.yml
vendored
305
.github/workflows/bench.yml
vendored
@@ -1,305 +0,0 @@
|
||||
name: bench
|
||||
|
||||
on: [push]
|
||||
|
||||
jobs:
|
||||
plan:
|
||||
if: contains(github.event.head_commit.message, '/bench')
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.build.outputs.matrix }}
|
||||
config_file: ${{ steps.build.outputs.config_file }}
|
||||
timeout_seconds: ${{ steps.build.outputs.timeout_seconds }}
|
||||
environment: ${{ steps.build.outputs.environment }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Build matrix from config file
|
||||
id: build
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
CONFIG_FILE='.github/configs/bench_simple.yaml'
|
||||
export CONFIG_FILE
|
||||
echo "Config file: $CONFIG_FILE"
|
||||
python3 .github/scripts/build_matrix.py
|
||||
|
||||
bench_worker:
|
||||
needs: plan
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix: ${{ fromJSON(needs.plan.outputs.matrix) }}
|
||||
name: "bench on ${{ matrix.label }} [${{ matrix.index }}]"
|
||||
runs-on: [self-hosted, macOS, "${{ matrix.label }}"]
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: false
|
||||
|
||||
- name: Configure git user
|
||||
run: |
|
||||
git config --local user.email "github-actions@users.noreply.github.com"
|
||||
git config --local user.name "github-actions bot"
|
||||
shell: bash
|
||||
|
||||
# TODO: this is mega hacky and I'd like a simpler solution.
|
||||
- name: Setup Nix Environment
|
||||
run: |
|
||||
echo "Checking for nix installation..."
|
||||
|
||||
# Check if nix is already available
|
||||
if command -v nix >/dev/null 2>&1; then
|
||||
echo "Nix already in PATH"
|
||||
# Try sourcing profile scripts to set up environment properly
|
||||
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
|
||||
echo "Sourcing multi-user nix-daemon profile script"
|
||||
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
|
||||
elif [ -f "$HOME/.nix-profile/etc/profile.d/nix.sh" ]; then
|
||||
echo "Sourcing single-user nix profile script"
|
||||
source "$HOME/.nix-profile/etc/profile.d/nix.sh"
|
||||
elif [ -f /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh ]; then
|
||||
echo "Sourcing per-user nix profile script"
|
||||
source /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh
|
||||
elif [ -f /etc/profile.d/nix.sh ]; then
|
||||
echo "Sourcing system-wide nix profile script"
|
||||
source /etc/profile.d/nix.sh
|
||||
# Fallback: manually add nix to PATH if binary exists
|
||||
elif [ -f /nix/var/nix/profiles/default/bin/nix ]; then
|
||||
echo "Found nix binary, manually adding to PATH"
|
||||
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
|
||||
elif [ -f "$HOME/.nix-profile/bin/nix" ]; then
|
||||
echo "Found nix binary in user profile, manually adding to PATH"
|
||||
export PATH="$HOME/.nix-profile/bin:$PATH"
|
||||
else
|
||||
echo "Nix not found. Debugging info:"
|
||||
echo "USER: $USER"
|
||||
echo "HOME: $HOME"
|
||||
echo "Current PATH: $PATH"
|
||||
echo ""
|
||||
echo "Checking common Nix locations:"
|
||||
echo " /nix/var/nix/profiles/default/bin/nix:"
|
||||
ls -la /nix/var/nix/profiles/default/bin/nix 2>/dev/null || echo " Not found"
|
||||
echo " /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh:"
|
||||
ls -la /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh 2>/dev/null || echo " Not found"
|
||||
echo " ~/.nix-profile/etc/profile.d/nix.sh:"
|
||||
ls -la "$HOME/.nix-profile/etc/profile.d/nix.sh" 2>/dev/null || echo " Not found"
|
||||
echo " /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh:"
|
||||
ls -la "/nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh" 2>/dev/null || echo " Not found"
|
||||
echo ""
|
||||
echo "/nix directory structure:"
|
||||
ls -la /nix 2>/dev/null || echo " /nix directory not found"
|
||||
echo ""
|
||||
echo "/nix/var:"
|
||||
ls -la /nix/var 2>/dev/null || echo " /nix/var not found"
|
||||
echo ""
|
||||
echo "/nix/store:"
|
||||
ls -la /nix/store 2>/dev/null | head -20 || echo " /nix/store not found"
|
||||
echo ""
|
||||
echo "GitHub Actions runner is running as user '$USER'."
|
||||
echo "If Nix is installed for a different user, either:"
|
||||
echo " 1. Install Nix for user '$USER' (multi-user install recommended)"
|
||||
echo " 2. Configure the runner service to run as the user with Nix installed"
|
||||
echo " 3. Ensure Nix is installed system-wide with proper daemon setup"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify nix is available and persist to GITHUB_ENV
|
||||
if command -v nix >/dev/null 2>&1; then
|
||||
echo "✓ Nix is available"
|
||||
nix --version
|
||||
echo "PATH=$PATH" >> $GITHUB_ENV
|
||||
if [ -n "$NIX_PATH" ]; then
|
||||
echo "NIX_PATH=$NIX_PATH" >> $GITHUB_ENV
|
||||
fi
|
||||
else
|
||||
echo "ERROR: Failed to set up Nix"
|
||||
echo "PATH after setup attempt: $PATH"
|
||||
exit 1
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Setup EXO_HOME and API_PORT
|
||||
run: |
|
||||
EXO_HOME=$(mktemp -d -t exo-e2e-XXXXXXXX)
|
||||
API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
|
||||
EXO_MODELS_DIR="$HOME/.exo/models"
|
||||
EXO_LIBP2P_NAMESPACE="bench-${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"
|
||||
echo "EXO_HOME=$EXO_HOME" >> "$GITHUB_ENV"
|
||||
echo "API_PORT=$API_PORT" >> "$GITHUB_ENV"
|
||||
echo "EXO_MODELS_DIR=$EXO_MODELS_DIR" >> "$GITHUB_ENV"
|
||||
echo "EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE" >> "$GITHUB_ENV"
|
||||
echo "Created EXO_HOME: $EXO_HOME"
|
||||
echo "Generated API_PORT: $API_PORT"
|
||||
echo "Using models from: $EXO_MODELS_DIR"
|
||||
echo "Using libp2p namespace: $EXO_LIBP2P_NAMESPACE"
|
||||
shell: bash
|
||||
|
||||
- name: Configure local MLX if available
|
||||
run: |
|
||||
echo "=== DEBUG: Checking for local MLX configuration ==="
|
||||
MODIFIED=false
|
||||
|
||||
echo "Checking for /Users/Shared/mlx directory..."
|
||||
if [ -d "/Users/Shared/mlx" ]; then
|
||||
echo "✓ Found /Users/Shared/mlx"
|
||||
ls -la /Users/Shared/mlx | head -5
|
||||
echo "Enabling local mlx path in pyproject.toml"
|
||||
sed -i.bak 's|^# mlx = { path = "/Users/Shared/mlx", editable=true }$|mlx = { path = "/Users/Shared/mlx", editable=true }|' pyproject.toml
|
||||
MODIFIED=true
|
||||
else
|
||||
echo "✗ /Users/Shared/mlx not found, will use PyPI version"
|
||||
fi
|
||||
|
||||
echo "Checking for /Users/Shared/mlx-lm directory..."
|
||||
if [ -d "/Users/Shared/mlx-lm" ]; then
|
||||
echo "✓ Found /Users/Shared/mlx-lm"
|
||||
ls -la /Users/Shared/mlx-lm | head -5
|
||||
echo "Enabling local mlx-lm path in pyproject.toml"
|
||||
sed -i.bak 's|^# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }$|mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }|' pyproject.toml
|
||||
MODIFIED=true
|
||||
else
|
||||
echo "✗ /Users/Shared/mlx-lm not found, will use PyPI version"
|
||||
fi
|
||||
|
||||
if [ "$MODIFIED" = true ]; then
|
||||
echo "=== Modified pyproject.toml [tool.uv.sources] section: ==="
|
||||
sed -n '/\[tool\.uv\.sources\]/,/^\[/{/^\[tool\.uv\.sources\]/p; /^\[/!p;}' pyproject.toml
|
||||
echo "=== Regenerating uv.lock with local MLX paths... ==="
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command uv lock --upgrade-package mlx --upgrade-package mlx-lm
|
||||
echo "✓ Lock file regenerated"
|
||||
else
|
||||
echo "⚠ No local MLX directories found, using PyPI packages"
|
||||
fi
|
||||
echo "=== DEBUG: Local MLX configuration complete ==="
|
||||
shell: bash
|
||||
|
||||
- name: Sync dependencies
|
||||
run: |
|
||||
if [ -d "/Users/Shared/test" ]; then
|
||||
pushd /Users/Shared/test
|
||||
uv sync --reinstall
|
||||
popd
|
||||
fi
|
||||
echo "Running just sync to ensure clean dependencies..."
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just sync
|
||||
shell: bash
|
||||
|
||||
- name: Start EXO and run bench script
|
||||
shell: bash
|
||||
env:
|
||||
IS_PRIMARY: ${{ matrix.is_primary }}
|
||||
EXPECTED_NODES: ${{ matrix.expected_nodes }}
|
||||
HARDWARE_LABEL: ${{ matrix.label }}
|
||||
CONFIG_FILE: ${{ needs.plan.outputs.config_file }}
|
||||
TIMEOUT_SECONDS: ${{ needs.plan.outputs.timeout_seconds }}
|
||||
ENVIRONMENT_JSON: ${{ needs.plan.outputs.environment }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Parse environment variables from config
|
||||
ENV_VARS=""
|
||||
if [ -n "$ENVIRONMENT_JSON" ] && [ "$ENVIRONMENT_JSON" != "{}" ]; then
|
||||
ENV_VARS=$(echo "$ENVIRONMENT_JSON" | python3 -c "import sys, json; env = json.load(sys.stdin); print(' '.join([f'{k}={v}' for k, v in env.items()]))")
|
||||
fi
|
||||
|
||||
echo "Starting EXO with API_PORT=${API_PORT} EXO_HOME=${EXO_HOME} EXO_LIBP2P_NAMESPACE=${EXO_LIBP2P_NAMESPACE}"
|
||||
echo "Environment variables from config: $ENV_VARS"
|
||||
LOG_FILE=/tmp/exo.log
|
||||
: > "$LOG_FILE"
|
||||
|
||||
MASTER_FLAG=""
|
||||
if [ "$IS_PRIMARY" = "true" ]; then
|
||||
MASTER_FLAG="-m"
|
||||
fi
|
||||
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
|
||||
"EXO_HOME=$EXO_HOME EXO_MODELS_DIR=$EXO_MODELS_DIR EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE $ENV_VARS PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run exo $MASTER_FLAG --api-port $API_PORT" \
|
||||
>> "$LOG_FILE" 2>&1 &
|
||||
|
||||
EXO_PID=$!
|
||||
echo "Started EXO in background with PID: $EXO_PID"
|
||||
echo "Log file: $LOG_FILE"
|
||||
|
||||
cleanup() {
|
||||
echo '=== EXO log (tail) ==='
|
||||
tail -n 300 "$LOG_FILE" || true
|
||||
if ps -p "$EXO_PID" >/dev/null 2>&1; then
|
||||
echo "Killing EXO (PID $EXO_PID)"
|
||||
kill "$EXO_PID" || true
|
||||
fi
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
for i in $(seq 1 60); do
|
||||
if curl -s "http://localhost:${API_PORT}/state" >/dev/null 2>&1; then
|
||||
echo "EXO API ready"
|
||||
break
|
||||
fi
|
||||
if ! ps -p "$EXO_PID" >/dev/null 2>&1; then
|
||||
echo "EXO terminated early"; sed -n '1,200p' "$LOG_FILE" || true; exit 1
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
RESULTS_FILE="/tmp/bench_results_${GITHUB_RUN_ID}_${GITHUB_RUN_ATTEMPT}_$(date +%s).json"
|
||||
echo "Results will be saved to: $RESULTS_FILE"
|
||||
echo "RESULTS_FILE=$RESULTS_FILE" >> "$GITHUB_ENV"
|
||||
|
||||
echo "Running bench script with config: $CONFIG_FILE, timeout: $TIMEOUT_SECONDS"
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
|
||||
"PYTHONUNBUFFERED=1 uv run --no-project --with pyyaml --with pydantic python .github/scripts/bench.py \
|
||||
--api-port $API_PORT \
|
||||
--config $CONFIG_FILE \
|
||||
--expected-nodes ${EXPECTED_NODES} \
|
||||
--is-primary ${IS_PRIMARY} \
|
||||
--timeout-seconds ${TIMEOUT_SECONDS} \
|
||||
--output $RESULTS_FILE \
|
||||
--git-commit ${GITHUB_SHA} \
|
||||
--hardware-labels ${HARDWARE_LABEL}"
|
||||
|
||||
- name: Install AWS CLI
|
||||
if: always() && env.RESULTS_FILE && matrix.is_primary
|
||||
run: |
|
||||
if ! command -v aws &> /dev/null; then
|
||||
echo "AWS CLI not found, installing..."
|
||||
brew install awscli
|
||||
else
|
||||
echo "AWS CLI already installed"
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Upload results to S3
|
||||
if: always() && env.RESULTS_FILE && matrix.is_primary
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.S3_BENCHMARKS_AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_DEFAULT_REGION: us-east-1
|
||||
run: |
|
||||
echo "Checking for results file: $RESULTS_FILE"
|
||||
echo "Is primary: ${{ matrix.is_primary }}"
|
||||
|
||||
if [ -f "$RESULTS_FILE" ]; then
|
||||
TIMESTAMP=$(date -u +%Y/%m/%d/%H%M%S)
|
||||
S3_KEY="bench/${TIMESTAMP}_${GITHUB_SHA:0:8}_${GITHUB_RUN_ID}.json"
|
||||
echo "Uploading results to s3://exo-benchmark-results/$S3_KEY"
|
||||
|
||||
aws s3 cp "$RESULTS_FILE" "s3://exo-benchmark-results/$S3_KEY" \
|
||||
--content-type application/json \
|
||||
--metadata "commit=${GITHUB_SHA},run_id=${GITHUB_RUN_ID},branch=${GITHUB_REF_NAME}"
|
||||
|
||||
echo "Results uploaded successfully"
|
||||
echo "View at: https://exo-benchmark-results.s3.amazonaws.com/$S3_KEY"
|
||||
else
|
||||
echo "Results file not found at: $RESULTS_FILE"
|
||||
echo "Skipping upload"
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Cleanup EXO_HOME
|
||||
run: |
|
||||
echo "Cleaning up EXO_HOME: $EXO_HOME"
|
||||
rm -rf "$EXO_HOME"
|
||||
shell: bash
|
||||
if: always()
|
||||
@@ -8,7 +8,7 @@
|
||||
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
|
||||
|
||||
<p align="center">
|
||||
<a href="https://discord.gg/72NsF6ux" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
|
||||
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
|
||||
<a href="https://x.com/exolabs" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/twitter/follow/exolabs?style=social" alt="X"></a>
|
||||
<a href="https://www.apache.org/licenses/LICENSE-2.0.html" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/License-Apache2.0-blue.svg" alt="License: Apache-2.0"></a>
|
||||
</p>
|
||||
|
||||
526
bench/exo_bench.py
Normal file
526
bench/exo_bench.py
Normal file
@@ -0,0 +1,526 @@
|
||||
#!/usr/bin/env python3
|
||||
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import http.client
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from statistics import mean
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.types.memory import Memory
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
super().__init__(f"HTTP {status} {reason}: {body_preview}")
|
||||
self.status = status
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def request_json(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if params:
|
||||
path = path + "?" + urlencode(params)
|
||||
|
||||
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
|
||||
try:
|
||||
payload: bytes | None = None
|
||||
hdrs: dict[str, str] = {"Accept": "application/json"}
|
||||
|
||||
if body is not None:
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
hdrs["Content-Type"] = "application/json"
|
||||
if headers:
|
||||
hdrs.update(headers)
|
||||
|
||||
conn.request(method.upper(), path, body=payload, headers=hdrs)
|
||||
resp = conn.getresponse()
|
||||
raw = resp.read()
|
||||
text = raw.decode("utf-8", errors="replace") if raw else ""
|
||||
|
||||
if resp.status >= 400:
|
||||
raise ExoHttpError(resp.status, resp.reason, text[:300])
|
||||
|
||||
if not text:
|
||||
return None
|
||||
return json.loads(text)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.request_json("POST", "/bench/chat/completions", body=payload)
|
||||
|
||||
|
||||
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
|
||||
if len(instance) != 1:
|
||||
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
|
||||
|
||||
tag = next(iter(instance))
|
||||
inner = instance[tag]
|
||||
if not isinstance(inner, dict):
|
||||
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
|
||||
return inner
|
||||
|
||||
|
||||
def instance_id_from_instance(instance: dict[str, Any]) -> str:
|
||||
inner = unwrap_instance(instance)
|
||||
return str(inner["instanceId"])
|
||||
|
||||
|
||||
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
|
||||
inner = unwrap_instance(instance)
|
||||
return len(inner["shardAssignments"]["nodeToRunner"])
|
||||
|
||||
|
||||
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
|
||||
inner = unwrap_instance(instance)
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
return list(runner_to_shard.keys())
|
||||
|
||||
|
||||
def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
|
||||
|
||||
|
||||
def wait_for_instance_gone(
|
||||
client: ExoClient, instance_id: str, timeout: float = 3.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
client.request_json("GET", f"/instance/{instance_id}")
|
||||
time.sleep(0.4)
|
||||
except ExoHttpError as e:
|
||||
if e.status == 404:
|
||||
return
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
|
||||
|
||||
|
||||
def format_peak_memory(b: float) -> str:
|
||||
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
||||
if b < 1024.0:
|
||||
return f"{b:.2f}{unit}"
|
||||
b /= 1024.0
|
||||
raise ValueError("You're using petabytes of memory. Something went wrong...")
|
||||
|
||||
|
||||
def parse_int_list(values: list[str]) -> list[int]:
|
||||
items: list[int] = []
|
||||
for v in values:
|
||||
for part in v.split(","):
|
||||
part = part.strip()
|
||||
if part:
|
||||
items.append(int(part))
|
||||
|
||||
seen: set[int] = set()
|
||||
out: list[int] = []
|
||||
for x in items:
|
||||
if x not in seen:
|
||||
out.append(x)
|
||||
seen.add(x)
|
||||
return out
|
||||
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
data = models.get("data") or []
|
||||
|
||||
for m in data:
|
||||
if m.get("id") == model_arg:
|
||||
short_id = str(m["id"])
|
||||
full_id = str(m.get("hugging_face_id") or m["id"])
|
||||
return short_id, full_id
|
||||
|
||||
for m in data:
|
||||
if m.get("hugging_face_id") == model_arg:
|
||||
short_id = str(m["id"])
|
||||
full_id = str(m["hugging_face_id"])
|
||||
return short_id, full_id
|
||||
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
return ("ring" in s) or ("jaccl" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def sharding_filter(sharding: str, wanted: str) -> bool:
|
||||
s = (sharding or "").lower()
|
||||
if wanted == "both":
|
||||
return ("pipeline" in s) or ("tensor" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def run_one_completion(
|
||||
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
content, pp_tokens = prompt_sizer.build(pp_hint)
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"max_tokens": tg,
|
||||
}
|
||||
|
||||
t0 = time.perf_counter()
|
||||
out = client.post_bench_chat_completions(payload)
|
||||
elapsed = time.perf_counter() - t0
|
||||
|
||||
stats = out.get("generation_stats")
|
||||
|
||||
preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]
|
||||
|
||||
return {
|
||||
"elapsed_s": elapsed,
|
||||
"output_text_preview": preview,
|
||||
"stats": stats,
|
||||
}, pp_tokens
|
||||
|
||||
|
||||
class PromptSizer:
|
||||
def __init__(self, tokenizer: Any, atom: str = "a "):
|
||||
self.tokenizer = tokenizer
|
||||
self.atom = atom
|
||||
self.count_fn = PromptSizer._make_counter(tokenizer)
|
||||
self.base_tokens = self.count_fn("")
|
||||
|
||||
@staticmethod
|
||||
def _make_counter(tokenizer: Any) -> Callable[[str], int]:
|
||||
def count_fn(user_content: str) -> int:
|
||||
messages = [{"role": "user", "content": user_content}]
|
||||
ids = tokenizer.apply_chat_template(
|
||||
messages, tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
return int(len(ids))
|
||||
|
||||
return count_fn
|
||||
|
||||
def build(self, target_prompt_tokens: int) -> tuple[str, int]:
|
||||
target = int(target_prompt_tokens)
|
||||
if target < self.base_tokens:
|
||||
raise RuntimeError(
|
||||
f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
|
||||
)
|
||||
|
||||
content = ""
|
||||
tok = self.count_fn(content)
|
||||
|
||||
while tok < target:
|
||||
content += self.atom
|
||||
tok = self.count_fn(content)
|
||||
|
||||
if tok != target:
|
||||
raise RuntimeError(
|
||||
f"Overshot: got {tok} tokens (target {target}). "
|
||||
f"Pick a different atom (try ' a' or '\\n' or '0 ')."
|
||||
)
|
||||
|
||||
return content, tok
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-bench",
|
||||
description="Benchmark exo model throughput across placement previews.",
|
||||
)
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
|
||||
)
|
||||
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
|
||||
ap.add_argument(
|
||||
"--pp",
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="Prompt-size hints (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--tg",
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="Generation lengths (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--max-nodes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Pipeline jaccl is often pointless, skip by default",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--warmup",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
default="bench/results.json",
|
||||
help="Write raw per-run results JSON to this path.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dry-run", action="store_true", help="List selected placements and exit."
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
tg_list = parse_int_list(args.tg)
|
||||
if not pp_list or not tg_list:
|
||||
logger.error("pp and tg lists must be non-empty")
|
||||
return 2
|
||||
if args.repeat <= 0:
|
||||
logger.error("--repeat must be >= 1")
|
||||
return 2
|
||||
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": short_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
full_model_id,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if tokenizer is None:
|
||||
raise RuntimeError("[exo-bench] tokenizer load failed")
|
||||
|
||||
try:
|
||||
prompt_sizer = PromptSizer(tokenizer)
|
||||
logger.debug(f"[exo-bench] loaded tokenizer: {full_model_id} for prompt sizer")
|
||||
except Exception:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if 0 < n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
logger.error("No valid placements matched your filters.")
|
||||
return 1
|
||||
|
||||
selected.sort(
|
||||
key=lambda p: (
|
||||
str(p.get("instance_meta", "")),
|
||||
str(p.get("sharding", "")),
|
||||
-nodes_used_in_instance(p["instance"]),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
|
||||
logger.info(f"placements: {len(selected)}")
|
||||
for p in selected:
|
||||
logger.info(
|
||||
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
|
||||
)
|
||||
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
instance = preview["instance"]
|
||||
instance_id = instance_id_from_instance(instance)
|
||||
|
||||
sharding = str(preview["sharding"])
|
||||
instance_meta = str(preview["instance_meta"])
|
||||
n_nodes = nodes_used_in_instance(instance)
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info(
|
||||
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
|
||||
)
|
||||
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
try:
|
||||
for i in range(args.warmup):
|
||||
run_one_completion(
|
||||
client, full_model_id, pp_list[0], tg_list[0], prompt_sizer
|
||||
)
|
||||
logger.debug(f" warmup {i + 1}/{args.warmup} done")
|
||||
|
||||
for pp in pp_list:
|
||||
if (
|
||||
pp * n_nodes > 2048
|
||||
and "ring" in instance_meta.lower()
|
||||
and "tensor" in sharding.lower()
|
||||
):
|
||||
model_card = MODEL_CARDS[short_id]
|
||||
if model_card.metadata.storage_size > Memory.from_gb(10):
|
||||
logger.info(
|
||||
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
|
||||
)
|
||||
continue
|
||||
for tg in tg_list:
|
||||
runs: list[dict[str, Any]] = []
|
||||
for r in range(args.repeat):
|
||||
time.sleep(3)
|
||||
try:
|
||||
row, actual_pp_tokens = run_one_completion(
|
||||
client, full_model_id, pp, tg, prompt_sizer
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
continue
|
||||
row.update(
|
||||
{
|
||||
"model_short_id": short_id,
|
||||
"model_id": full_model_id,
|
||||
"placement_sharding": sharding,
|
||||
"placement_instance_meta": instance_meta,
|
||||
"placement_nodes": n_nodes,
|
||||
"instance_id": instance_id,
|
||||
"pp_tokens": actual_pp_tokens,
|
||||
"tg": tg,
|
||||
"repeat_index": r,
|
||||
}
|
||||
)
|
||||
runs.append(row)
|
||||
all_rows.append(row)
|
||||
|
||||
if runs:
|
||||
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
|
||||
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
|
||||
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
|
||||
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
|
||||
peak = mean(
|
||||
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
|
||||
f"prompt_tokens={ptok} gen_tokens={gtok} "
|
||||
f"peak_memory={format_peak_memory(peak)}\n"
|
||||
)
|
||||
time.sleep(2)
|
||||
finally:
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
except ExoHttpError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
wait_for_instance_gone(client, instance_id)
|
||||
logger.debug(f"Deleted instance {instance_id}")
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
if args.json_out:
|
||||
with open(args.json_out, "w", encoding="utf-8") as f:
|
||||
json.dump(all_rows, f, indent=2, ensure_ascii=False)
|
||||
logger.debug(f"\nWrote results JSON: {args.json_out}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -82,7 +82,7 @@ build-backend = "uv_build"
|
||||
###
|
||||
|
||||
[tool.basedpyright]
|
||||
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src"]
|
||||
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src", "bench"]
|
||||
typeCheckingMode = "strict"
|
||||
failOnWarnings = true
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.models.model_meta import get_model_meta
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionResponse,
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
@@ -34,6 +36,7 @@ from exo.shared.types.api import (
|
||||
CreateInstanceResponse,
|
||||
DeleteInstanceResponse,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -172,6 +175,7 @@ class API:
|
||||
self.app.post("/v1/chat/completions", response_model=None)(
|
||||
self.chat_completions
|
||||
)
|
||||
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
|
||||
@@ -490,6 +494,45 @@ class API:
|
||||
],
|
||||
)
|
||||
|
||||
async def _collect_chat_completion_with_stats(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
stats = chunk.stats or stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
resp = BenchChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content=combined_text
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
generation_stats=stats,
|
||||
)
|
||||
return resp
|
||||
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
|
||||
logger.warning(
|
||||
"TODO: we should send a notification to the user to download the model"
|
||||
@@ -525,6 +568,33 @@ class API:
|
||||
|
||||
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionTaskParams
|
||||
) -> BenchChatCompletionResponse:
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
payload.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == payload.model
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(payload.model)
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"No instance found for model {payload.model}"
|
||||
)
|
||||
|
||||
payload.stream = False
|
||||
|
||||
command = ChatCompletion(request_params=payload)
|
||||
await self._send(command)
|
||||
|
||||
response = await self._collect_chat_completion_with_stats(
|
||||
command.command_id,
|
||||
parse_gpt_oss,
|
||||
)
|
||||
return response
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
total_available = Memory()
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
@@ -51,6 +52,10 @@ class ChatCompletionMessage(BaseModel):
|
||||
function_call: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class BenchChatCompletionMessage(ChatCompletionMessage):
|
||||
pass
|
||||
|
||||
|
||||
class TopLogprobItem(BaseModel):
|
||||
token: str
|
||||
logprob: float
|
||||
@@ -113,6 +118,18 @@ class ChatCompletionResponse(BaseModel):
|
||||
service_tier: str | None = None
|
||||
|
||||
|
||||
class GenerationStats(BaseModel):
|
||||
prompt_tps: float
|
||||
generation_tps: float
|
||||
prompt_tokens: int
|
||||
generation_tokens: int
|
||||
peak_memory_usage: Memory
|
||||
|
||||
|
||||
class BenchChatCompletionResponse(ChatCompletionResponse):
|
||||
generation_stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class ChatCompletionTaskParams(BaseModel):
|
||||
model: str
|
||||
frequency_penalty: float | None = None
|
||||
@@ -135,6 +152,10 @@ class ChatCompletionTaskParams(BaseModel):
|
||||
user: str | None = None
|
||||
|
||||
|
||||
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
|
||||
pass
|
||||
|
||||
|
||||
class PlaceInstanceParams(BaseModel):
|
||||
model_id: str
|
||||
sharding: Sharding = Sharding.Pipeline
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
from exo.shared.types.api import GenerationStats
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
@@ -20,6 +21,7 @@ class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.shared.types.api import FinishReason, GenerationStats
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ class GenerationResponse(BaseRunnerResponse):
|
||||
token: int
|
||||
# logprobs: list[float] | None = None # too big. we can change to be top-k
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
|
||||
@@ -6,7 +6,13 @@ from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
# from exo.engines.mlx.cache import KVPrefixCache
|
||||
from exo.shared.types.api import ChatCompletionMessage, FinishReason
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionMessage,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
@@ -72,7 +78,7 @@ def warmup_inference(
|
||||
max_tokens=50,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=65536,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
@@ -80,17 +86,42 @@ def warmup_inference(
|
||||
tokens_generated += 1
|
||||
|
||||
logger.info("Generated ALL warmup tokens")
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
# At least this version is actively incorrect, as it should use mx_barrier(group)
|
||||
mx_barrier()
|
||||
|
||||
return tokens_generated
|
||||
|
||||
|
||||
def ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]:
|
||||
token_ids = [int(t) for t in token_ids]
|
||||
|
||||
def proc(_history: mx.array, logits: mx.array) -> mx.array:
|
||||
for tid in token_ids:
|
||||
logits[..., tid] = -1e9
|
||||
return logits
|
||||
|
||||
return proc
|
||||
|
||||
|
||||
def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
|
||||
eos: list[int] | None = getattr(tokenizer, "eos_token_ids", None)
|
||||
if eos is None:
|
||||
return []
|
||||
return eos
|
||||
|
||||
|
||||
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
task: ChatCompletionTaskParams,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
|
||||
|
||||
# Currently we support chat-completion tasks only.
|
||||
logger.info(f"task_params: {task}")
|
||||
|
||||
@@ -101,6 +132,12 @@ def mlx_generate(
|
||||
|
||||
caches = make_kv_cache(model=model)
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
# Only sample length eos tokens
|
||||
eos_ids = eos_ids_from_tokenizer(tokenizer)
|
||||
logits_processors = [ban_token_ids(eos_ids)]
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
@@ -108,26 +145,40 @@ def mlx_generate(
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
prefill_step_size=65536,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
logger.info(out.text)
|
||||
if out.finish_reason is not None and out.finish_reason not in get_args(
|
||||
FinishReason
|
||||
):
|
||||
# We don't throw here as this failure case is really not all that bad
|
||||
# Just log the error and move on
|
||||
logger.warning(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
|
||||
if out.finish_reason not in get_args(FinishReason):
|
||||
# We don't throw here as this failure case is really not all that bad
|
||||
# Just log the error and move on
|
||||
logger.warning(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
token=out.token,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
break
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
@@ -397,3 +397,13 @@ def set_wired_limit_for_model(model_size: Memory):
|
||||
)
|
||||
mx.set_wired_limit(max_rec_size)
|
||||
logger.info(f"Wired limit set to {max_rec_size}.")
|
||||
|
||||
|
||||
def mlx_cleanup(
|
||||
model: Model | None, tokenizer: TokenizerWrapper | None, group: Group | None
|
||||
) -> None:
|
||||
del model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
@@ -41,6 +41,7 @@ from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_infer
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_mlx_items,
|
||||
mlx_cleanup,
|
||||
mlx_force_oom,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
@@ -179,6 +180,7 @@ def main(
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -190,6 +192,7 @@ def main(
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
mlx_cleanup(model, tokenizer, group)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
networksetup -listallnetworkservices | grep -q '^Thunderbolt Bridge$' \
|
||||
&& echo "Disabling bridge in networksetup" \
|
||||
&& networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
|
||||
networksetup -listallnetworkservices | grep -q '^\*Thunderbolt Bridge$' \
|
||||
&& echo "Bridge disabled in networksetup"
|
||||
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && echo "Removing bridge members in ifconfig" && {
|
||||
ifconfig bridge0 | \
|
||||
awk '/member/ {print $2}' | \
|
||||
xargs -n1 sudo ifconfig bridge0 deletem
|
||||
}
|
||||
ifconfig bridge0 | grep -q 'status: active' && sudo ifconfig bridge0 down
|
||||
ifconfig bridge0 | grep -q 'status: inactive' && echo "Bridge disabled in ifconfig"
|
||||
}
|
||||
|
||||
for iface in $(seq 2 7); do
|
||||
sudo ipconfig set "en$iface" dhcp && echo "enabled dhcp on en$iface" || echo "failed to enable dhcp on en$iface"
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user