mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-09 06:31:18 -05:00
Compare commits
13 Commits
ciaran/ima
...
test-app
| Author | SHA1 | Date |
|---|---|---|
| | 7dc0907df7 | |
| | c65320acd3 | |
| | b9a78f6f3a | |
| | 8f7f0e893a | |
| | 4759b09d4c | |
| | ca680185f3 | |
| | 383309e24e | |
| | 55463a9806 | |
| | 56af61fac9 | |
| | f76d543d98 | |
| | ea841aca37 | |
| | 077b1bc732 | |
| | 4963c33162 | |
159
.github/benchmark-dashboard/README.md
vendored
@@ -1,159 +0,0 @@
# EXO Benchmark Dashboard

A fully self-contained, browser-based dashboard for tracking EXO benchmark performance over time.

## Features

- 📊 **Success Rate Tracking**: Monitor cluster reliability across commits
- ⚡ **Response Time Analysis**: Track average request completion times
- 🎯 **Throughput Metrics**: Tokens per second visualization
- 📈 **Request Distribution**: Success/failure breakdown over time
- 🔄 **Auto-Refresh**: Updates every 60 seconds
- 📺 **TV-Ready**: Large, clear visualizations perfect for display
- 🔐 **Secure**: Credentials stored in browser localStorage only
- 🌐 **No Backend**: Directly accesses S3 from the browser

## Quick Start

### Option 1: Direct File Access (Simplest)

Just open the HTML file directly in your browser:

```bash
open .github/benchmark-dashboard/index.html
```

Then click "Configure AWS Credentials" and enter your keys.

### Option 2: URL Parameters (For Quick Setup)

```bash
# Open with credentials in the URL (they'll be moved to localStorage)
open ".github/benchmark-dashboard/index.html?accessKey=YOUR_KEY&secretKey=YOUR_SECRET&region=us-east-1"
```

The credentials will be saved to localStorage and removed from the URL immediately.

### Option 3: Simple HTTP Server

```bash
# From repo root
python3 -m http.server 8080

# Then open: http://localhost:8080/.github/benchmark-dashboard/
```

## AWS Credentials

The dashboard needs read-only access to the `exo-benchmark-results` S3 bucket.

### Required IAM Permissions

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "s3:GetObject",
        "s3:ListBucket"
      ],
      "Resource": [
        "arn:aws:s3:::exo-benchmark-results",
        "arn:aws:s3:::exo-benchmark-results/*"
      ]
    }
  ]
}
```

### Security Notes

- ✅ Credentials stored in browser `localStorage` only
- ✅ Never sent to any server (except AWS)
- ✅ All S3 access happens client-side
- ✅ Use read-only IAM credentials
- ⚠️ Don't commit credentials to git
- ⚠️ Use a dedicated read-only IAM user

## TV/Kiosk Mode

For permanent display on a TV:

### macOS
```bash
open -a "Google Chrome" --args --kiosk ".github/benchmark-dashboard/index.html"
```

### Linux
```bash
chromium-browser --kiosk --app="file://$(pwd)/.github/benchmark-dashboard/index.html"
```

### Auto-start on Boot

Create a simple startup script:

```bash
#!/bin/bash
# /usr/local/bin/start-benchmark-dashboard.sh

cd /path/to/exo
python3 -m http.server 8080 &
sleep 2
chromium-browser --kiosk http://localhost:8080/.github/benchmark-dashboard/
```

## Data Displayed

### Summary Cards
- **Latest Success Rate**: Most recent benchmark success percentage with trend
- **Avg Response Time**: Latest average response time in ms with trend
- **Total Benchmarks**: Count of all benchmarks run
- **Active Configurations**: Number of unique benchmark configs

### Charts
1. **Success Rate Over Time**: Line chart showing reliability trends
2. **Average Response Time**: Performance over time (lower is better)
3. **Throughput**: Tokens/second metric (higher is better)
4. **Request Distribution**: Stacked bar chart of successes/failures

## How It Works

1. **Loads AWS SDK**: Uses AWS SDK for JavaScript (browser version)
2. **Lists S3 Objects**: Fetches all files from `s3://exo-benchmark-results/bench/`
3. **Downloads Results**: Fetches each JSON result file
4. **Parses & Visualizes**: Uses Chart.js to create interactive charts
5. **Auto-Refreshes**: Polls S3 every 60 seconds for new results
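
The list, download, and parse steps above can be sketched in Python for readers who want to reproduce the data load outside the browser. This is only an illustration: the dashboard itself uses the AWS SDK for JavaScript, and the `load_results` function, the injected `list_keys` / `fetch_json` callables, the `.json` suffix filter, and the `success_rate` field are all hypothetical names chosen for the example. Only the `bench/` prefix comes from the README.

```python
from typing import Any, Callable, Iterable

PREFIX = "bench/"  # prefix named in the README; everything else here is assumed


def load_results(
    list_keys: Callable[[str], Iterable[str]],
    fetch_json: Callable[[str], dict[str, Any]],
) -> list[dict[str, Any]]:
    """List keys under PREFIX, download each JSON file, return the parsed results."""
    results: list[dict[str, Any]] = []
    for key in sorted(list_keys(PREFIX)):
        if not key.endswith(".json"):
            continue  # skip objects that are not result files
        try:
            results.append(fetch_json(key))
        except ValueError:
            continue  # one malformed file should not break the whole refresh
    return results


# In-memory stand-in for the bucket, so the sketch runs without AWS access.
fake_bucket: dict[str, Any] = {
    "bench/run-001.json": {"success_rate": 1.0},
    "bench/run-002.json": {"success_rate": 0.9},
    "bench/notes.txt": None,
}

results = load_results(
    lambda prefix: [k for k in fake_bucket if k.startswith(prefix)],
    lambda key: fake_bucket[key],
)
print(len(results), "result files loaded")
```

A real refresh loop would call `load_results` once per polling interval (60 seconds in the README) and redraw the charts from the returned list.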

## Customization

To modify the dashboard:

1. Edit `index.html`
2. Adjust `REFRESH_INTERVAL` for different polling frequency
3. Modify chart colors/styles in the Chart.js configuration
4. Add new metrics by extending the results parsing

## Troubleshooting

**"AWS credentials not configured"**
- Click "Configure AWS Credentials" and enter your keys

**"Error loading benchmark data"**
- Check AWS credentials are correct
- Verify S3 bucket name is `exo-benchmark-results`
- Ensure IAM user has read permissions
- Check browser console for detailed errors

**"No benchmark results found"**
- Wait for benchmark workflows to run
- Verify results are being uploaded to S3
- Check S3 bucket has files in `bench/` prefix

**Charts not updating**
- Check browser console for errors
- Verify network connectivity to S3
- Try refreshing the page manually

1641
.github/benchmark-dashboard/index.html
vendored
File diff suppressed because it is too large
186
.github/configs/README.md
vendored
@@ -1,186 +0,0 @@
# EXO Benchmark Configurations

This directory contains configuration files for the EXO staged benchmark system.

## Overview

The staged benchmark system allows you to run complex, multi-stage load tests against EXO clusters. Each stage can have different characteristics:

- **Prompt Length**: Number of tokens in the input prompt
- **Generation Length**: Maximum tokens to generate in the response
- **Time Between Requests**: Delay (in seconds) between firing consecutive requests
- **Iterations**: Number of requests to send in this stage

Requests are **fire-and-forget** by default: each one is sent on schedule without waiting for the previous request to complete. This allows you to test overlapping request handling and measure success rates under load. Setting `no_overlap: true` in the config, as `bench_simple.yaml` does, runs requests sequentially instead.
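
As a rough illustration of the fire-and-forget pattern, the sketch below schedules simulated requests on a fixed interval with `asyncio` and records how many are in flight at once. It is not taken from `bench.py` (whose diff is suppressed here); `fake_request`, `run_stage`, and the fixed request duration are assumptions made for the example.

```python
import asyncio


async def fake_request(index: int, duration_s: float, done: list[int]) -> None:
    """Stand-in for one chat completion call; it just sleeps for duration_s."""
    await asyncio.sleep(duration_s)
    done.append(index)


async def run_stage(
    iterations: int, time_between_requests: float, duration_s: float
) -> tuple[int, list[int]]:
    """Fire one request per interval without awaiting the previous one."""
    done: list[int] = []
    tasks: list[asyncio.Task[None]] = []
    max_in_flight = 0
    for i in range(iterations):
        tasks.append(asyncio.create_task(fake_request(i, duration_s, done)))
        in_flight = sum(1 for t in tasks if not t.done())
        max_in_flight = max(max_in_flight, in_flight)
        await asyncio.sleep(time_between_requests)  # wait for the interval, not the request
    await asyncio.gather(*tasks)  # collect every request before reporting
    return max_in_flight, done


# Requests take 50 ms but are fired every 10 ms, so they overlap.
max_in_flight, done = asyncio.run(run_stage(5, 0.01, 0.05))
print("completed:", len(done), "peak in flight:", max_in_flight)
```

Because the loop sleeps for the interval rather than awaiting each task, a slow request never delays the next one, which is what makes overlapping load possible.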

## Configuration Files

### `bench_simple.yaml`
A minimal configuration that replicates the behavior of the original `bench.py` script:
- Single stage with 1 iteration
- Short prompt (~20 tokens)
- Generates up to 100 tokens

This is useful for quick smoke tests.

### `bench_config.yaml`
A comprehensive multi-stage benchmark with:
1. **Warmup** (10 requests): Light load with short prompts
2. **Medium Load** (20 requests): Moderate load with medium prompts
3. **Stress Test** (30 requests): Heavy overlapping requests with long prompts
4. **Cooldown** (5 requests): Light load to wind down

This tests the cluster's behavior under varying load patterns.

## Configuration Schema

```yaml
# Hardware configuration - maps runner labels to instance counts
hardware_plan:
  M3ULTRA_GPU80_512GB: 4

# Environment variables to set on each node (optional)
environment:
  OVERRIDE_MEMORY_MB: 512

# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600

# Model instances to run concurrently
model_ids:
  - "mlx-community/Llama-3.2-1B-Instruct-4bit"

# Benchmark stages
stages:
  - name: "stage_name"            # Human-readable name for this stage
    prompt_length: 100            # Target prompt length in tokens
    generation_length: 200        # Max tokens to generate
    time_between_requests: 2.0    # Seconds between firing requests
    iterations: 10                # Number of requests in this stage
```
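
One way to sanity-check a config against this schema before running it is sketched below. The `Stage` dataclass and `parse_stages` helper are hypothetical and are not part of `bench.py`; the sketch takes an already-loaded dictionary so it needs no YAML library, and the validation rules (at least one iteration, non-negative interval) are assumptions.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Stage:
    """One entry of the `stages` list, with the five fields shown in the schema."""
    name: str
    prompt_length: int
    generation_length: int
    time_between_requests: float
    iterations: int


def parse_stages(config: dict[str, Any]) -> list[Stage]:
    """Convert raw stage dictionaries into Stage objects, rejecting bad values."""
    stages: list[Stage] = []
    for raw in config.get("stages", []):
        stage = Stage(**raw)  # raises TypeError on missing or unknown keys
        if stage.iterations < 1 or stage.time_between_requests < 0:
            raise ValueError(f"invalid stage {stage.name!r}")
        stages.append(stage)
    if not stages:
        raise ValueError("config defines no stages")
    return stages


# The schema example above, as it would look after loading the YAML.
config = {
    "hardware_plan": {"M3ULTRA_GPU80_512GB": 4},
    "timeout_seconds": 600,
    "model_ids": ["mlx-community/Llama-3.2-1B-Instruct-4bit"],
    "stages": [
        {
            "name": "stage_name",
            "prompt_length": 100,
            "generation_length": 200,
            "time_between_requests": 2.0,
            "iterations": 10,
        }
    ],
}

stages = parse_stages(config)
total_requests = sum(stage.iterations for stage in stages)
print(stages[0].name, total_requests)
```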

## Running Benchmarks

### Via GitHub Actions

**Automatic (every commit):**
- The **`bench`** workflow is triggered on every push; its jobs run only when the head commit message contains `/bench`
- Uses `bench_simple.yaml` as the default configuration
- All settings (hardware plan, timeout, environment variables, models, stages) are defined in the config file

**Manual (on-demand):**
1. Go to **Actions** → **bench** workflow
2. Click **Run workflow**
3. Configure:
   - **Config File**: Path to your YAML config (default: `.github/configs/bench_simple.yaml`)
     - `.github/configs/bench_simple.yaml` for quick tests
     - `.github/configs/bench_config.yaml` for complex multi-stage tests

All other settings (hardware plan, timeout, environment variables, models, stages) are read from the specified config file.

### Via Command Line

```bash
# Start EXO on localhost:8000
uv run exo --api-port 8000

# Run simple benchmark (1 stage, 1 iteration)
python3 .github/scripts/bench.py \
  --api-port 8000 \
  --config .github/configs/bench_simple.yaml \
  --expected-nodes 1 \
  --is-primary true \
  --timeout-seconds 600

# Run complex staged benchmark (4 stages, multiple iterations)
python3 .github/scripts/bench.py \
  --api-port 8000 \
  --config .github/configs/bench_config.yaml \
  --expected-nodes 1 \
  --is-primary true \
  --timeout-seconds 600
```

## Output Metrics

For each stage, the benchmark reports:

- **Total Requests**: Number of requests fired
- **Successful Requests**: Requests that completed successfully
- **Failed Requests**: Requests that encountered errors
- **Success Rate**: Percentage of successful requests
- **Total Tokens**: Sum of all tokens generated across successful requests
- **Avg Tokens/Request**: Average tokens per successful request
- **Avg Time/Request**: Average completion time per successful request

A JSON summary is also printed for easy parsing and storage.
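
The per-stage metrics listed above can be derived from individual request outcomes roughly as follows. This is a minimal sketch: the `RequestResult` type, the `summarize` function, and the JSON key names are invented for illustration and may not match what `bench.py` actually prints.

```python
import json
from dataclasses import dataclass


@dataclass
class RequestResult:
    """Outcome of one request; tokens and elapsed_s only matter when ok is True."""
    ok: bool
    tokens: int = 0
    elapsed_s: float = 0.0


def summarize(results: list[RequestResult]) -> dict[str, float]:
    """Aggregate request outcomes into the metrics named in the README."""
    succeeded = [r for r in results if r.ok]
    n_total = len(results)
    n_ok = len(succeeded)
    total_tokens = sum(r.tokens for r in succeeded)
    total_time = sum(r.elapsed_s for r in succeeded)
    return {
        "total_requests": n_total,
        "successful_requests": n_ok,
        "failed_requests": n_total - n_ok,
        "success_rate": 100.0 * n_ok / n_total if n_total else 0.0,
        "total_tokens": total_tokens,
        # Averages are taken over successful requests only, as the README states.
        "avg_tokens_per_request": total_tokens / n_ok if n_ok else 0.0,
        "avg_time_per_request": total_time / n_ok if n_ok else 0.0,
    }


summary = summarize(
    [RequestResult(True, 100, 2.0), RequestResult(True, 50, 1.0), RequestResult(False)]
)
print(json.dumps(summary, indent=2))
```

Averaging over successful requests only means a stage with many failures can still show healthy per-request numbers, so the success rate should always be read alongside the averages.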

## Creating Custom Benchmarks

To create a custom benchmark:

1. Copy an existing config file (e.g., `bench_config.yaml`)
2. Modify the stages to match your test scenario
3. Save it in this directory with a descriptive name
4. Run it using the workflow or command line

### Example: Sustained Load Test

```yaml
hardware_plan:
  M3ULTRA_GPU80_512GB: 2

environment:
  OVERRIDE_MEMORY_MB: 1024

timeout_seconds: 600

model_ids:
  - "mlx-community/Llama-3.2-1B-Instruct-4bit"

stages:
  - name: "sustained_load"
    prompt_length: 200
    generation_length: 150
    time_between_requests: 0.5    # Very fast - 2 requests/second
    iterations: 100               # Run for ~50 seconds
```

### Example: Varying Prompt Sizes

```yaml
hardware_plan:
  M4PRO_GPU16_24GB: 3

timeout_seconds: 900

model_ids:
  - "mlx-community/Llama-3.2-1B-Instruct-4bit"

stages:
  - name: "tiny_prompts"
    prompt_length: 10
    generation_length: 100
    time_between_requests: 1.0
    iterations: 10

  - name: "medium_prompts"
    prompt_length: 200
    generation_length: 100
    time_between_requests: 1.0
    iterations: 10

  - name: "large_prompts"
    prompt_length: 1000
    generation_length: 100
    time_between_requests: 1.0
    iterations: 10
```

## Tips

- **Overlapping Requests**: Set `time_between_requests` < expected completion time to test concurrent request handling
- **Sequential Requests**: Set `time_between_requests` > expected completion time to ensure requests don't overlap
- **Realistic Load**: Model real usage patterns by varying prompt/generation lengths across stages
- **Success Rate**: A 100% success rate indicates the cluster handled the load well; lower rates suggest capacity limits
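
To pick a `time_between_requests` value, it can help to estimate how many requests will be in flight at once. The sketch below assumes every request takes the same fixed time to complete, which real clusters will not honor under load, so treat the result as a rough lower bound. Both helper names are hypothetical.

```python
import math


def peak_in_flight(completion_time_s: float, time_between_requests: float) -> int:
    """Peak number of overlapping requests for a fixed completion time and interval."""
    if time_between_requests <= 0:
        raise ValueError("time_between_requests must be positive")
    return max(1, math.ceil(completion_time_s / time_between_requests))


def firing_window_s(iterations: int, time_between_requests: float) -> float:
    """Approximate time spent firing requests in one stage."""
    return iterations * time_between_requests


print(peak_in_flight(4.0, 1.0))    # requests overlap
print(peak_in_flight(3.0, 10.0))   # requests stay sequential
print(firing_window_s(100, 0.5))   # the sustained load example above
```

A result of 1 means the stage behaves sequentially; anything higher means the cluster must serve that many requests concurrently.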
49
.github/configs/bench_config.yaml
vendored
@@ -1,49 +0,0 @@
# EXO Staged Benchmark Configuration
# This configuration defines a multi-stage load test for EXO clusters

# Hardware configuration - maps runner labels to instance counts
hardware_plan:
  M3ULTRA_GPU80_512GB: 4

# Environment variables to set on each node (optional)
environment:
  OVERRIDE_MEMORY_MB: 512

# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600

# Multiple instances run concurrently on the cluster
model_ids:
  - "mlx-community/Qwen3-0.6B-4bit"
  - "mlx-community/Qwen3-0.6B-4bit"

# Stages run sequentially, each with its own characteristics
stages:
  # Stage 1: Light load with short prompts
  - name: "warmup"
    prompt_length: 50              # Number of tokens in prompt
    generation_length: 100         # Max tokens to generate
    time_between_requests: 5.0     # Seconds between firing requests
    iterations: 10                 # Number of requests to send in this stage

  # Stage 2: Medium load with medium prompts
  - name: "medium_load"
    prompt_length: 200
    generation_length: 150
    time_between_requests: 3.0
    iterations: 20

  # Stage 3: Heavy load with long prompts - requests will overlap
  - name: "stress_test"
    prompt_length: 500
    generation_length: 200
    time_between_requests: 1.0     # Fast firing - will definitely overlap
    iterations: 30

  # Stage 4: Cool down with simple prompts
  - name: "cooldown"
    prompt_length: 50
    generation_length: 50
    time_between_requests: 10.0
    iterations: 5

125
.github/configs/bench_simple.yaml
vendored
@@ -1,125 +0,0 @@
# Simple single-shot benchmark
# Tests 2 instances concurrently on 2 nodes

# Hardware configuration - maps runner labels to instance counts
hardware_plan:
  puffin4: 1
  puffin8: 1

# Environment variables to set on each node
environment:
  PLACEHOLDER: "placeholder"
  # OVERRIDE_MEMORY_MB: 50000
  MLX_METAL_FAST_SYNCH: 1

# Timeout for instance and runner readiness (seconds)
timeout_seconds: 1800

# Model instances to run concurrently
model_ids:
  # - "mlx-community/DeepSeek-V3.1-8bit"
  # - "mlx-community/Kimi-K2-Instruct-4bit"
  - "mlx-community/Kimi-K2-Thinking"
  # - "mlx-community/Qwen3-235B-A22B-4bit"
  # - "mlx-community/Llama-3.3-70B-Instruct-4bit"
  # - "mlx-community/Llama-3.3-70B-Instruct-8bit"
  # - "mlx-community/Llama-3.2-1B-Instruct-4bit"

# Sharding strategy: "Pipeline" or "Tensor"
sharding: "Tensor"

# Instance type: "MlxRing" or "MlxIbv"
instance_meta: "MlxIbv"

# If true, run requests sequentially (no overlap); if false, fire-and-forget (default: false)
no_overlap: true

# Benchmark stages
# pp: 64, 256, 1024, 2048, 4096, 8192, 16384
# g: 64, 512
stages:
  # - name: "simple"
  #   prompt_length: 512
  #   generation_length: 10
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp64_g64"
  #   prompt_length: 64
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp64_g64"
  #   prompt_length: 64
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp64_g512"
  #   prompt_length: 64
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp256_g64"
  #   prompt_length: 256
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  - name: "pp256_g64"
    prompt_length: 256
    generation_length: 64
    time_between_requests: 2.0
    iterations: 5
  # - name: "pp256_g512"
  #   prompt_length: 256
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp1024_g64"
  #   prompt_length: 1024
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp1024_g512"
  #   prompt_length: 1024
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp2048_g64"
  #   prompt_length: 2048
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp2048_g512"
  #   prompt_length: 2048
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp4096_g64"
  #   prompt_length: 4096
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 4
  # - name: "pp4096_g512"
  #   prompt_length: 4096
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp8192_g64"
  #   prompt_length: 8192
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp8192_g512"
  #   prompt_length: 8192
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp16384_g64"
  #   prompt_length: 16384
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp16384_g512"
  #   prompt_length: 16384
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
1399
.github/scripts/bench.py
vendored
File diff suppressed because it is too large
70
.github/scripts/build_matrix.py
vendored
@@ -1,70 +0,0 @@
#!/usr/bin/env python3
import json
import os
from typing import NotRequired, TypedDict, cast

import yaml


class MatrixEntry(TypedDict):
    label: str
    index: int


class MatrixInclude(TypedDict):
    label: str
    index: int
    is_primary: bool
    expected_nodes: int


class Config(TypedDict):
    hardware_plan: dict[str, int]
    timeout_seconds: NotRequired[int]
    environment: NotRequired[dict[str, str]]


# Read the config file
config_file: str = os.environ["CONFIG_FILE"]
with open(config_file, "r") as f:
    config: Config = cast(Config, yaml.safe_load(f))

# Extract hardware plan from config
plan: dict[str, int] = config["hardware_plan"]
if not plan:
    raise ValueError(f"No hardware_plan found in {config_file}")

# Build matrix entries
entries: list[MatrixEntry] = []
for label, count in plan.items():
    for idx in range(count):
        entries.append({"label": label, "index": idx})

total_nodes: int = len(entries)
matrix: dict[str, list[MatrixInclude]] = {
    "include": [
        {
            "label": e["label"],
            "index": e["index"],
            "is_primary": (i == 0),
            "expected_nodes": total_nodes,
        }
        for i, e in enumerate(entries)
    ]
}

# Extract other config values
timeout_seconds: int = config.get("timeout_seconds", 600)
environment: dict[str, str] = config.get("environment", {})

# Output to GitHub Actions
with open(os.environ["GITHUB_OUTPUT"], "a") as f:
    f.write(f"matrix={json.dumps(matrix)}\n")
    f.write(f"config_file={config_file}\n")
    f.write(f"timeout_seconds={timeout_seconds}\n")
    f.write(f"environment={json.dumps(environment)}\n")

print(f"Matrix: {json.dumps(matrix)}")
print(f"Config file: {config_file}")
print(f"Timeout: {timeout_seconds}")
print(f"Environment: {json.dumps(environment)}")
156
.github/workflows/BENCH_USAGE.md
vendored
@@ -1,156 +0,0 @@
# Benchmark Workflow Usage

## Overview

The `bench_matrix.yml` workflow enables distributed benchmarking of models across multiple self-hosted macOS runners with different hardware configurations.

## Workflow Inputs

| Input | Description | Default | Required |
|-------|-------------|---------|----------|
| `model_id` | Model ID to benchmark | `mlx-community/Llama-3.2-1B-Instruct-4bit` | Yes |
| `hardware_plan` | JSON mapping of runner labels to counts | `{"M4PRO_GPU16_24GB": 1}` | Yes |
| `prompt` | Benchmark prompt text | `What is the capital of France?` | No |
| `timeout_seconds` | Timeout for instance/runner readiness | `600` | No |

## Hardware Plan Format

The `hardware_plan` input is a JSON object mapping runner labels to the number of machines:

```json
{
  "M4PRO_GPU16_24GB": 2,
  "M3ULTRA_GPU80_512GB": 1
}
```

This example would:
- Start 2 runners with the `M4PRO_GPU16_24GB` label
- Start 1 runner with the `M3ULTRA_GPU80_512GB` label
- Total of 3 runners coordinating on a single distributed inference instance
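
The expansion from a hardware plan to one matrix entry per runner can be sketched as below. The field names follow `.github/scripts/build_matrix.py` from this same comparison; whether the `bench_matrix.yml` workflow described here used exactly the same logic is an assumption, and `expand_plan` is a name invented for the example.

```python
import json
from typing import Any


def expand_plan(plan: dict[str, int]) -> dict[str, list[dict[str, Any]]]:
    """Turn {label: count} into a matrix with one entry per runner."""
    entries = [(label, idx) for label, count in plan.items() for idx in range(count)]
    return {
        "include": [
            {
                "label": label,
                "index": idx,
                "is_primary": i == 0,  # only the first runner is primary
                "expected_nodes": len(entries),
            }
            for i, (label, idx) in enumerate(entries)
        ]
    }


plan = json.loads('{"M4PRO_GPU16_24GB": 2, "M3ULTRA_GPU80_512GB": 1}')
matrix = expand_plan(plan)
for entry in matrix["include"]:
    print(entry)
```

Each entry carries `expected_nodes`, so every runner knows how many peers to wait for without any extra coordination step.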

## How It Works

1. **Planning Job** (`plan`)
   - Runs on `ubuntu-latest`
   - Parses the `hardware_plan` JSON
   - Generates a dynamic matrix with one entry per runner
   - Only the first runner (index 0) is marked as `is_primary`

2. **Benchmark Worker Jobs** (`bench_worker`)
   - Each job runs on a self-hosted macOS runner with the specified label
   - All runners start EXO in parallel
   - The primary runner creates the model instance
   - All runners wait for their assigned runner to be ready (Loaded/Running status)
   - The primary runner executes the benchmark and prints results
   - The primary runner deletes the instance

## Example Usage

### Single Machine Benchmark

```yaml
model_id: mlx-community/Llama-3.2-1B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 1}'
prompt: What is the capital of France?
timeout_seconds: 600
```

### Multi-Machine Distributed Benchmark

```yaml
model_id: mlx-community/Llama-3.2-3B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 2, "M3ULTRA_GPU80_512GB": 1}'
prompt: Explain quantum computing in simple terms.
timeout_seconds: 900
```

## Benchmark Output

The primary runner outputs a JSON object with benchmark results:

```json
{
  "model_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",
  "instance_id": "abc-123-def",
  "tokens": 42,
  "elapsed_s": 2.451,
  "tps": 17.136
}
```

Where:
- `tokens`: Number of chunks/tokens generated
- `elapsed_s`: Total elapsed time in seconds
- `tps`: Tokens per second (tokens / elapsed_s)
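
As a quick check of the formula, the sample values above can be plugged in directly. The `tokens_per_second` helper is only an illustration, and rounding to three decimals is an assumption based on the sample output.

```python
def tokens_per_second(tokens: int, elapsed_s: float) -> float:
    """tps = tokens / elapsed_s, as defined in the README."""
    if elapsed_s <= 0:
        raise ValueError("elapsed_s must be positive")
    return tokens / elapsed_s


tps = round(tokens_per_second(42, 2.451), 3)
print(tps)  # 17.136, matching the sample JSON
```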

## Runner Requirements

Each self-hosted runner must:
- Be labeled with appropriate hardware tags (e.g., `M4PRO_GPU16_24GB`)
- Have the `self-hosted` and `macOS` labels
- Have Nix installed with flakes enabled
- Have network connectivity to other runners in the same job

## Architecture

```
┌─────────────────────────────────────────────────────────────┐
│ GitHub Actions Workflow (bench_matrix.yml)                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌────────────────┐                                         │
│  │   Plan Job     │                                         │
│  │   (ubuntu)     │──┬─► Matrix: [{label, index, primary}]  │
│  └────────────────┘  │                                      │
│                      │                                      │
│  ┌───────────────────▼───────────────────────────────────┐  │
│  │ Bench Worker Jobs (Matrix)                            │  │
│  ├───────────────────────────────────────────────────────┤  │
│  │                                                       │  │
│  │  Runner 0 (Primary) Runner 1           Runner 2       │  │
│  │  ┌─────────────┐    ┌─────────────┐    ┌──────────┐   │  │
│  │  │ Start EXO   │    │ Start EXO   │    │ Start EXO│   │  │
│  │  │ Create Inst │    │ Wait...     │    │ Wait...  │   │  │
│  │  │ Wait Ready  │    │ Wait Ready  │    │ Wait...  │   │  │
│  │  │ Run Bench   │    │ (idle)      │    │ (idle)   │   │  │
│  │  │ Print TPS   │    │             │    │          │   │  │
│  │  │ Delete Inst │    │             │    │          │   │  │
│  │  └─────────────┘    └─────────────┘    └──────────┘   │  │
│  └───────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────┘
```

## Implementation Details

### `scripts/bench.py`

A standalone Python script that:
- Creates instance (primary only)
- Polls `/state` endpoint until instance and all runners are ready
- Executes chat completion with timing (primary only)
- Parses SSE stream and counts tokens
- Computes TPS metrics
- Cleans up instance (primary only)
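
The SSE parsing and token counting step might look roughly like the sketch below. It assumes an OpenAI-style streaming format, with `data: {json}` lines carrying `choices[].delta.content` and a final `data: [DONE]` line; that format, and the `count_stream_chunks` name, are assumptions rather than details confirmed by this document.

```python
import json
from typing import Iterable


def count_stream_chunks(lines: Iterable[str]) -> int:
    """Count streamed chunks that carry generated text."""
    count = 0
    for line in lines:
        line = line.strip()
        if not line.startswith("data:"):
            continue  # ignore blank separators, comments, and other SSE fields
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break  # assumed end-of-stream marker
        chunk = json.loads(payload)
        for choice in chunk.get("choices", []):
            if choice.get("delta", {}).get("content"):
                count += 1
    return count


sample_stream = [
    'data: {"choices": [{"delta": {"role": "assistant"}}]}',
    "",
    'data: {"choices": [{"delta": {"content": "Paris"}}]}',
    'data: {"choices": [{"delta": {"content": "."}}]}',
    "data: [DONE]",
]
n_chunks = count_stream_chunks(sample_stream)
print(n_chunks)
```

Counting chunks rather than tokenizing the text matches the README's description of `tokens` as "chunks/tokens", but the two can differ if a server batches several tokens into one chunk.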

### Key Functions

- `wait_for_instance()`: Polls until instance with model_id appears
- `wait_for_runners_ready()`: Polls until expected number of runners reach Loaded/Running status
- `run_benchmark()`: Executes chat completion, measures time, counts tokens
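
The polling idea behind `wait_for_runners_ready()` can be illustrated with a small stand-in. The real function's signature is not shown in this document, so `poll_until_ready`, its parameters, and the `"Loading"` status string are hypothetical; only the `Loaded` and `Running` statuses come from the text above. The status source is injected as a callable so the sketch runs without an EXO cluster.

```python
import time
from typing import Callable

READY_STATUSES = {"Loaded", "Running"}  # statuses named in the README


def poll_until_ready(
    fetch_statuses: Callable[[], list[str]],
    expected: int,
    timeout_s: float,
    poll_interval_s: float = 1.0,
) -> bool:
    """Poll until at least `expected` runners are ready, or the timeout passes."""
    deadline = time.monotonic() + timeout_s
    while True:
        ready = sum(1 for status in fetch_statuses() if status in READY_STATUSES)
        if ready >= expected:
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_interval_s)


# Simulated sequence of state snapshots from two runners.
snapshots = iter([["Loading", "Loading"], ["Loaded", "Loading"], ["Loaded", "Running"]])
became_ready = poll_until_ready(lambda: next(snapshots), expected=2, timeout_s=5.0, poll_interval_s=0.0)
timed_out = poll_until_ready(lambda: ["Loading"], expected=1, timeout_s=0.0, poll_interval_s=0.0)
print(became_ready, timed_out)
```

Using a monotonic clock for the deadline keeps the timeout correct even if the system clock is adjusted while the benchmark waits.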

## Troubleshooting

### Instance never becomes ready
- Check EXO logs in the workflow output
- Verify model_id is valid and accessible
- Increase `timeout_seconds`

### Runner mismatch
- Ensure hardware_plan counts match available labeled runners
- Check runner labels match exactly (case-sensitive)

### Network issues
- Verify runners can communicate on the network
- Check firewall rules between runner hosts

305
.github/workflows/bench.yml
vendored
@@ -1,305 +0,0 @@
|
||||
name: bench
|
||||
|
||||
on: [push]
|
||||
|
||||
jobs:
|
||||
plan:
|
||||
if: contains(github.event.head_commit.message, '/bench')
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.build.outputs.matrix }}
|
||||
config_file: ${{ steps.build.outputs.config_file }}
|
||||
timeout_seconds: ${{ steps.build.outputs.timeout_seconds }}
|
||||
environment: ${{ steps.build.outputs.environment }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Build matrix from config file
|
||||
id: build
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
CONFIG_FILE='.github/configs/bench_simple.yaml'
|
||||
export CONFIG_FILE
|
||||
echo "Config file: $CONFIG_FILE"
|
||||
python3 .github/scripts/build_matrix.py
|
||||
|
||||
bench_worker:
|
||||
needs: plan
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix: ${{ fromJSON(needs.plan.outputs.matrix) }}
|
||||
name: "bench on ${{ matrix.label }} [${{ matrix.index }}]"
|
||||
runs-on: [self-hosted, macOS, "${{ matrix.label }}"]
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: false
|
||||
|
||||
- name: Configure git user
|
||||
run: |
|
||||
git config --local user.email "github-actions@users.noreply.github.com"
|
||||
git config --local user.name "github-actions bot"
|
||||
shell: bash
|
||||
|
||||
# TODO: this is mega hacky and I'd like a simpler solution.
|
||||
- name: Setup Nix Environment
|
||||
run: |
|
||||
echo "Checking for nix installation..."
|
||||
|
||||
# Check if nix is already available
|
||||
if command -v nix >/dev/null 2>&1; then
|
||||
echo "Nix already in PATH"
|
||||
# Try sourcing profile scripts to set up environment properly
|
||||
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
|
||||
echo "Sourcing multi-user nix-daemon profile script"
|
||||
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
|
||||
elif [ -f "$HOME/.nix-profile/etc/profile.d/nix.sh" ]; then
|
||||
echo "Sourcing single-user nix profile script"
|
||||
source "$HOME/.nix-profile/etc/profile.d/nix.sh"
|
||||
elif [ -f /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh ]; then
|
||||
echo "Sourcing per-user nix profile script"
|
||||
source /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh
|
||||
elif [ -f /etc/profile.d/nix.sh ]; then
|
||||
echo "Sourcing system-wide nix profile script"
|
||||
source /etc/profile.d/nix.sh
|
||||
# Fallback: manually add nix to PATH if binary exists
|
||||
elif [ -f /nix/var/nix/profiles/default/bin/nix ]; then
|
||||
echo "Found nix binary, manually adding to PATH"
|
||||
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
|
||||
elif [ -f "$HOME/.nix-profile/bin/nix" ]; then
|
||||
echo "Found nix binary in user profile, manually adding to PATH"
|
||||
export PATH="$HOME/.nix-profile/bin:$PATH"
|
||||
else
|
||||
echo "Nix not found. Debugging info:"
|
||||
echo "USER: $USER"
|
||||
echo "HOME: $HOME"
|
||||
echo "Current PATH: $PATH"
|
||||
echo ""
|
||||
echo "Checking common Nix locations:"
|
||||
echo " /nix/var/nix/profiles/default/bin/nix:"
|
||||
ls -la /nix/var/nix/profiles/default/bin/nix 2>/dev/null || echo " Not found"
|
||||
echo " /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh:"
|
||||
ls -la /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh 2>/dev/null || echo " Not found"
|
||||
echo " ~/.nix-profile/etc/profile.d/nix.sh:"
|
||||
ls -la "$HOME/.nix-profile/etc/profile.d/nix.sh" 2>/dev/null || echo " Not found"
|
||||
echo " /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh:"
|
||||
ls -la "/nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh" 2>/dev/null || echo " Not found"
|
||||
echo ""
|
||||
echo "/nix directory structure:"
|
||||
ls -la /nix 2>/dev/null || echo " /nix directory not found"
|
||||
echo ""
|
||||
echo "/nix/var:"
|
||||
ls -la /nix/var 2>/dev/null || echo " /nix/var not found"
|
||||
echo ""
|
||||
echo "/nix/store:"
|
||||
ls -la /nix/store 2>/dev/null | head -20 || echo " /nix/store not found"
|
||||
echo ""
|
||||
echo "GitHub Actions runner is running as user '$USER'."
|
||||
echo "If Nix is installed for a different user, either:"
|
||||
echo " 1. Install Nix for user '$USER' (multi-user install recommended)"
|
||||
echo " 2. Configure the runner service to run as the user with Nix installed"
|
||||
echo " 3. Ensure Nix is installed system-wide with proper daemon setup"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify nix is available and persist to GITHUB_ENV
|
||||
if command -v nix >/dev/null 2>&1; then
|
||||
echo "✓ Nix is available"
|
||||
nix --version
|
||||
echo "PATH=$PATH" >> $GITHUB_ENV
|
||||
if [ -n "$NIX_PATH" ]; then
|
||||
echo "NIX_PATH=$NIX_PATH" >> $GITHUB_ENV
|
||||
fi
|
||||
else
|
||||
echo "ERROR: Failed to set up Nix"
|
||||
echo "PATH after setup attempt: $PATH"
|
||||
exit 1
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Setup EXO_HOME and API_PORT
|
||||
run: |
|
||||
EXO_HOME=$(mktemp -d -t exo-e2e-XXXXXXXX)
|
||||
API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
|
||||
EXO_MODELS_DIR="$HOME/.exo/models"
|
||||
EXO_LIBP2P_NAMESPACE="bench-${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"
|
||||
echo "EXO_HOME=$EXO_HOME" >> "$GITHUB_ENV"
|
||||
echo "API_PORT=$API_PORT" >> "$GITHUB_ENV"
|
||||
echo "EXO_MODELS_DIR=$EXO_MODELS_DIR" >> "$GITHUB_ENV"
|
||||
echo "EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE" >> "$GITHUB_ENV"
|
||||
echo "Created EXO_HOME: $EXO_HOME"
|
||||
echo "Generated API_PORT: $API_PORT"
|
||||
echo "Using models from: $EXO_MODELS_DIR"
|
||||
echo "Using libp2p namespace: $EXO_LIBP2P_NAMESPACE"
|
||||
shell: bash
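Note on the `API_PORT` expression in the step above: bash's `RANDOM` yields an integer in 0..32767, and the dynamic port range 49152..65535 holds 16384 values, which divides 32768 evenly, so the modulo introduces no bias. A minimal sketch of the same arithmetic follows; `random_ephemeral_port` is a hypothetical helper name that does not appear in the workflow, and, like the original expression, it does not check that the chosen port is actually free.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the workflow): pick a port in the
# dynamic/private range 49152..65535 using bash's built-in RANDOM.
random_ephemeral_port() {
  local low=49152 high=65535
  # RANDOM is 0..32767; the range size (16384) divides 32768 evenly,
  # so every port in the range is equally likely.
  echo $((low + RANDOM % (high - low + 1)))
}
```

A call such as `API_PORT=$(random_ephemeral_port)` would stand in for the inline expression.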
|
||||
|
||||
- name: Configure local MLX if available
|
||||
run: |
|
||||
echo "=== DEBUG: Checking for local MLX configuration ==="
|
||||
MODIFIED=false
|
||||
|
||||
echo "Checking for /Users/Shared/mlx directory..."
|
||||
if [ -d "/Users/Shared/mlx" ]; then
|
||||
echo "✓ Found /Users/Shared/mlx"
|
||||
ls -la /Users/Shared/mlx | head -5
|
||||
echo "Enabling local mlx path in pyproject.toml"
|
||||
sed -i.bak 's|^# mlx = { path = "/Users/Shared/mlx", editable=true }$|mlx = { path = "/Users/Shared/mlx", editable=true }|' pyproject.toml
|
||||
MODIFIED=true
|
||||
else
|
||||
echo "✗ /Users/Shared/mlx not found, will use PyPI version"
|
||||
fi
|
||||
|
||||
echo "Checking for /Users/Shared/mlx-lm directory..."
|
||||
if [ -d "/Users/Shared/mlx-lm" ]; then
|
||||
echo "✓ Found /Users/Shared/mlx-lm"
|
||||
ls -la /Users/Shared/mlx-lm | head -5
|
||||
echo "Enabling local mlx-lm path in pyproject.toml"
|
||||
sed -i.bak 's|^# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }$|mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }|' pyproject.toml
|
||||
MODIFIED=true
|
||||
else
|
||||
echo "✗ /Users/Shared/mlx-lm not found, will use PyPI version"
|
||||
fi
|
||||
|
||||
if [ "$MODIFIED" = true ]; then
|
||||
echo "=== Modified pyproject.toml [tool.uv.sources] section: ==="
|
||||
sed -n '/\[tool\.uv\.sources\]/,/^\[/{/^\[tool\.uv\.sources\]/p; /^\[/!p;}' pyproject.toml
|
||||
echo "=== Regenerating uv.lock with local MLX paths... ==="
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command uv lock --upgrade-package mlx --upgrade-package mlx-lm
|
||||
echo "✓ Lock file regenerated"
|
||||
else
|
||||
echo "⚠ No local MLX directories found, using PyPI packages"
|
||||
fi
|
||||
echo "=== DEBUG: Local MLX configuration complete ==="
|
||||
shell: bash
|
||||
|
||||
- name: Sync dependencies
|
||||
run: |
|
||||
if [ -d "/Users/Shared/test" ]; then
|
||||
pushd /Users/Shared/test
|
||||
uv sync --reinstall
|
||||
popd
|
||||
fi
|
||||
echo "Running just sync to ensure clean dependencies..."
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just sync
|
||||
shell: bash
|
||||
|
||||
- name: Start EXO and run bench script
|
||||
shell: bash
|
||||
env:
|
||||
IS_PRIMARY: ${{ matrix.is_primary }}
|
||||
EXPECTED_NODES: ${{ matrix.expected_nodes }}
|
||||
HARDWARE_LABEL: ${{ matrix.label }}
|
||||
CONFIG_FILE: ${{ needs.plan.outputs.config_file }}
|
||||
TIMEOUT_SECONDS: ${{ needs.plan.outputs.timeout_seconds }}
|
||||
ENVIRONMENT_JSON: ${{ needs.plan.outputs.environment }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Parse environment variables from config
|
||||
ENV_VARS=""
|
||||
if [ -n "$ENVIRONMENT_JSON" ] && [ "$ENVIRONMENT_JSON" != "{}" ]; then
|
||||
ENV_VARS=$(echo "$ENVIRONMENT_JSON" | python3 -c "import sys, json; env = json.load(sys.stdin); print(' '.join([f'{k}={v}' for k, v in env.items()]))")
|
||||
fi
|
||||
|
||||
echo "Starting EXO with API_PORT=${API_PORT} EXO_HOME=${EXO_HOME} EXO_LIBP2P_NAMESPACE=${EXO_LIBP2P_NAMESPACE}"
|
||||
echo "Environment variables from config: $ENV_VARS"
|
||||
LOG_FILE=/tmp/exo.log
|
||||
: > "$LOG_FILE"
|
||||
|
||||
MASTER_FLAG=""
|
||||
if [ "$IS_PRIMARY" = "true" ]; then
|
||||
MASTER_FLAG="-m"
|
||||
fi
|
||||
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
|
||||
"EXO_HOME=$EXO_HOME EXO_MODELS_DIR=$EXO_MODELS_DIR EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE $ENV_VARS PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run exo $MASTER_FLAG --api-port $API_PORT" \
|
||||
>> "$LOG_FILE" 2>&1 &
|
||||
|
||||
EXO_PID=$!
|
||||
echo "Started EXO in background with PID: $EXO_PID"
|
||||
echo "Log file: $LOG_FILE"
|
||||
|
||||
cleanup() {
|
||||
echo '=== EXO log (tail) ==='
|
||||
tail -n 300 "$LOG_FILE" || true
|
||||
if ps -p "$EXO_PID" >/dev/null 2>&1; then
|
||||
echo "Killing EXO (PID $EXO_PID)"
|
||||
kill "$EXO_PID" || true
|
||||
fi
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
for i in $(seq 1 60); do
|
||||
if curl -s "http://localhost:${API_PORT}/state" >/dev/null 2>&1; then
|
||||
echo "EXO API ready"
|
||||
break
|
||||
fi
|
||||
if ! ps -p "$EXO_PID" >/dev/null 2>&1; then
|
||||
echo "EXO terminated early"; sed -n '1,200p' "$LOG_FILE" || true; exit 1
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
RESULTS_FILE="/tmp/bench_results_${GITHUB_RUN_ID}_${GITHUB_RUN_ATTEMPT}_$(date +%s).json"
|
||||
echo "Results will be saved to: $RESULTS_FILE"
|
||||
echo "RESULTS_FILE=$RESULTS_FILE" >> "$GITHUB_ENV"
|
||||
|
||||
echo "Running bench script with config: $CONFIG_FILE, timeout: $TIMEOUT_SECONDS"
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
|
||||
"PYTHONUNBUFFERED=1 uv run --no-project --with pyyaml --with pydantic python .github/scripts/bench.py \
|
||||
--api-port $API_PORT \
|
||||
--config $CONFIG_FILE \
|
||||
--expected-nodes ${EXPECTED_NODES} \
|
||||
--is-primary ${IS_PRIMARY} \
|
||||
--timeout-seconds ${TIMEOUT_SECONDS} \
|
||||
--output $RESULTS_FILE \
|
||||
--git-commit ${GITHUB_SHA} \
|
||||
--hardware-labels ${HARDWARE_LABEL}"
|
||||
|
||||
- name: Install AWS CLI
|
||||
if: always() && env.RESULTS_FILE && matrix.is_primary
|
||||
run: |
|
||||
if ! command -v aws &> /dev/null; then
|
||||
echo "AWS CLI not found, installing..."
|
||||
brew install awscli
|
||||
else
|
||||
echo "AWS CLI already installed"
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Upload results to S3
|
||||
if: always() && env.RESULTS_FILE && matrix.is_primary
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.S3_BENCHMARKS_AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_DEFAULT_REGION: us-east-1
|
||||
run: |
|
||||
echo "Checking for results file: $RESULTS_FILE"
|
||||
echo "Is primary: ${{ matrix.is_primary }}"
|
||||
|
||||
if [ -f "$RESULTS_FILE" ]; then
|
||||
TIMESTAMP=$(date -u +%Y/%m/%d/%H%M%S)
|
||||
S3_KEY="bench/${TIMESTAMP}_${GITHUB_SHA:0:8}_${GITHUB_RUN_ID}.json"
|
||||
echo "Uploading results to s3://exo-benchmark-results/$S3_KEY"
|
||||
|
||||
aws s3 cp "$RESULTS_FILE" "s3://exo-benchmark-results/$S3_KEY" \
|
||||
--content-type application/json \
|
||||
--metadata "commit=${GITHUB_SHA},run_id=${GITHUB_RUN_ID},branch=${GITHUB_REF_NAME}"
|
||||
|
||||
echo "Results uploaded successfully"
|
||||
echo "View at: https://exo-benchmark-results.s3.amazonaws.com/$S3_KEY"
|
||||
else
|
||||
echo "Results file not found at: $RESULTS_FILE"
|
||||
echo "Skipping upload"
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Cleanup EXO_HOME
|
||||
run: |
|
||||
echo "Cleaning up EXO_HOME: $EXO_HOME"
|
||||
rm -rf "$EXO_HOME"
|
||||
shell: bash
|
||||
if: always()
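Note on the readiness loop in the "Start EXO and run bench script" step: as written, the `for i in $(seq 1 60)` loop breaks on success and exits if the process dies, but after 60 unsuccessful attempts it simply falls through to the bench script. A minimal sketch of a polling helper that reports a timeout explicitly is shown below; `wait_for_ready` is a hypothetical name and is not part of the workflow.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the workflow): run a probe command up to
# <attempts> times, sleeping <delay> seconds between tries.
# Returns 0 as soon as the probe succeeds, 1 if it never does.
wait_for_ready() {
  local attempts="$1" delay="$2"
  shift 2
  local i
  for ((i = 1; i <= attempts; i++)); do
    if "$@" >/dev/null 2>&1; then
      return 0
    fi
    sleep "$delay"
  done
  return 1
}
```

Under these assumptions the step could call `wait_for_ready 60 1 curl -sf "http://localhost:${API_PORT}/state" || { echo "EXO API not ready"; exit 1; }`, keeping the separate check that `EXO_PID` is still alive.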
|
||||
32
.github/workflows/build-app.yml
vendored
@@ -18,6 +18,7 @@ jobs:
 SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}
 SPARKLE_S3_BUCKET: ${{ secrets.SPARKLE_S3_BUCKET }}
 SPARKLE_S3_PREFIX: ${{ secrets.SPARKLE_S3_PREFIX }}
+EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT: ${{ secrets.EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT }}
 AWS_REGION: ${{ secrets.AWS_REGION }}
 EXO_BUILD_NUMBER: ${{ github.run_number }}
 EXO_LIBP2P_NAMESPACE: ${{ github.ref_name }}
@@ -47,6 +48,32 @@ jobs:
 fi
 echo "RELEASE_VERSION=$VERSION" >> $GITHUB_ENV
 
+- name: Compute build version from semver
+run: |
+VERSION="$RELEASE_VERSION"
+# Extract major.minor.patch (strip prerelease suffix)
+BASE_VERSION="${VERSION%%-*}"
+MAJOR=$(echo "$BASE_VERSION" | cut -d. -f1)
+MINOR=$(echo "$BASE_VERSION" | cut -d. -f2)
+PATCH=$(echo "$BASE_VERSION" | cut -d. -f3)
+
+# Extract prerelease number (e.g., "alpha.2" -> 2, or 999 for releases)
+if [[ "$VERSION" == *-* ]]; then
+PRERELEASE_PART="${VERSION#*-}"
+PRERELEASE_NUM="${PRERELEASE_PART##*.}"
+# Default to 0 if not a number
+if ! [[ "$PRERELEASE_NUM" =~ ^[0-9]+$ ]]; then
+PRERELEASE_NUM=0
+fi
+else
+PRERELEASE_NUM=999
+fi
+
+# Compute: PRERELEASE + (1000 * PATCH) + (1_000_000 * MINOR) + (1_000_000_000 * MAJOR)
+BUILD_VERSION=$((PRERELEASE_NUM + 1000 * PATCH + 1000000 * MINOR + 1000000000 * MAJOR))
+echo "EXO_BUILD_VERSION=$BUILD_VERSION" >> $GITHUB_ENV
+echo "Computed build version: $BUILD_VERSION from $VERSION"
+
 - name: Ensure tag commit is on main
 if: github.ref_type == 'tag'
 run: |
@@ -162,11 +189,12 @@ jobs:
 -configuration Release \
 -derivedDataPath build \
 MARKETING_VERSION="$RELEASE_VERSION" \
-CURRENT_PROJECT_VERSION="$EXO_BUILD_NUMBER" \
+CURRENT_PROJECT_VERSION="$EXO_BUILD_VERSION" \
 EXO_BUILD_TAG="$RELEASE_VERSION" \
 EXO_BUILD_COMMIT="$GITHUB_SHA" \
 SPARKLE_FEED_URL="$SPARKLE_FEED_URL" \
 SPARKLE_ED25519_PUBLIC="$SPARKLE_ED25519_PUBLIC" \
+EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT="$EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT" \
 CODE_SIGNING_IDENTITY="$SIGNING_IDENTITY" \
 CODE_SIGN_INJECT_BASE_ENTITLEMENTS=YES
 mkdir -p ../../output
@@ -294,5 +322,5 @@ jobs:
|
||||
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}${DMG_NAME}"
|
||||
if [[ "$IS_ALPHA" != "true" ]]; then
|
||||
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
|
||||
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
|
||||
fi
|
||||
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
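Note on the "Compute build version from semver" step in this file: the step packs a semver string into one integer, with the prerelease number in the lowest three digits and a plain release using 999 so that it sorts after every prerelease of the same patch. A minimal sketch of that arithmetic as a standalone function follows. `build_version` is a hypothetical name, and the sketch assumes a bare `MAJOR.MINOR.PATCH[-label.N]` input with no leading `v` and no leading zeros (bash arithmetic would read those as octal).

```bash
#!/usr/bin/env bash
# Hypothetical helper mirroring the workflow's arithmetic:
#   PRERELEASE + 1000 * PATCH + 1_000_000 * MINOR + 1_000_000_000 * MAJOR
build_version() {
  local version="$1"
  local base="${version%%-*}"        # strip any prerelease suffix
  local major minor patch
  IFS=. read -r major minor patch <<<"$base"
  local pre=999                      # plain releases sort after prereleases
  if [[ "$version" == *-* ]]; then
    pre="${version#*-}"              # e.g. "alpha.2"
    pre="${pre##*.}"                 # e.g. "2"
    [[ "$pre" =~ ^[0-9]+$ ]] || pre=0
  fi
  echo $((pre + 1000 * ${patch:-0} + 1000000 * ${minor:-0} + 1000000000 * ${major:-0}))
}

build_version "1.2.3-alpha.2"   # prints 1002003002
build_version "1.2.3"           # prints 1002003999
```

One consequence of the packing: a prerelease number of 1000 or more, or a patch or minor component of 1000 or more, would spill into the next field, so the scheme relies on those values staying below 1000.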
|
||||
|
||||
3
.prettierrc
Normal file
@@ -0,0 +1,3 @@
+{
+"useTabs": true
+}
6
.swift-format
Normal file
@@ -0,0 +1,6 @@
+{
+"version": 1,
+"indentation": {
+"spaces": 4
+}
+}
@@ -8,7 +8,7 @@
 exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
 
 <p align="center">
-<a href="https://discord.gg/72NsF6ux" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
+<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
 <a href="https://x.com/exolabs" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/twitter/follow/exolabs?style=social" alt="X"></a>
 <a href="https://www.apache.org/licenses/LICENSE-2.0.html" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/License-Apache2.0-blue.svg" alt="License: Apache-2.0"></a>
 </p>
@@ -12,6 +12,7 @@ struct ContentView: View {
|
||||
@EnvironmentObject private var controller: ExoProcessController
|
||||
@EnvironmentObject private var stateService: ClusterStateService
|
||||
@EnvironmentObject private var networkStatusService: NetworkStatusService
|
||||
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@EnvironmentObject private var updater: SparkleUpdater
|
||||
@State private var focusedNode: NodeViewModel?
|
||||
@State private var deletingInstanceIDs: Set<String> = []
|
||||
@@ -26,6 +27,9 @@ struct ContentView: View {
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 12) {
|
||||
statusSection
|
||||
if shouldShowLocalNetworkWarning {
|
||||
localNetworkWarningBanner
|
||||
}
|
||||
if shouldShowClusterDetails {
|
||||
Divider()
|
||||
overviewSection
|
||||
@@ -40,6 +44,7 @@ struct ContentView: View {
|
||||
}
|
||||
.animation(.easeInOut(duration: 0.3), value: shouldShowClusterDetails)
|
||||
.animation(.easeInOut(duration: 0.3), value: shouldShowInstances)
|
||||
.animation(.easeInOut(duration: 0.3), value: shouldShowLocalNetworkWarning)
|
||||
.padding()
|
||||
.frame(width: 340)
|
||||
.onAppear {
|
||||
@@ -49,9 +54,62 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
private var shouldShowLocalNetworkWarning: Bool {
|
||||
if case .notWorking = localNetworkChecker.status {
|
||||
return controller.status != .stopped
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
private var localNetworkWarningBanner: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
HStack(spacing: 6) {
|
||||
Image(systemName: "exclamationmark.triangle.fill")
|
||||
.foregroundColor(.orange)
|
||||
Text("Local Network Access Issue")
|
||||
.font(.caption)
|
||||
.fontWeight(.semibold)
|
||||
}
|
||||
Text(
|
||||
"Device discovery won't work. To fix:\n1. Quit EXO\n2. Open System Settings → Privacy & Security → Local Network\n3. Toggle EXO off, then back on\n4. Relaunch EXO"
|
||||
)
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
Button {
|
||||
openLocalNetworkSettings()
|
||||
} label: {
|
||||
Text("Open Settings")
|
||||
.font(.caption2)
|
||||
}
|
||||
.buttonStyle(.bordered)
|
||||
.controlSize(.small)
|
||||
}
|
||||
.padding(8)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.fill(Color.orange.opacity(0.1))
|
||||
)
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.stroke(Color.orange.opacity(0.3), lineWidth: 1)
|
||||
)
|
||||
}
|
||||
|
||||
private func openLocalNetworkSettings() {
|
||||
// Open Privacy & Security settings - Local Network section
|
||||
if let url = URL(
|
||||
string: "x-apple.systempreferences:com.apple.preference.security?Privacy_LocalNetwork")
|
||||
{
|
||||
NSWorkspace.shared.open(url)
|
||||
}
|
||||
}
|
||||
|
||||
private var topologySection: some View {
|
||||
Group {
|
||||
if let topology = stateService.latestSnapshot?.topologyViewModel(localNodeId: stateService.localNodeId), !topology.nodes.isEmpty {
|
||||
if let topology = stateService.latestSnapshot?.topologyViewModel(
|
||||
localNodeId: stateService.localNodeId), !topology.nodes.isEmpty
|
||||
{
|
||||
TopologyMiniView(topology: topology)
|
||||
}
|
||||
}
|
||||
@@ -85,8 +143,10 @@ struct ContentView: View {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
HStack {
|
||||
VStack(alignment: .leading) {
|
||||
Text("\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB")
|
||||
.font(.headline)
|
||||
Text(
|
||||
"\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB"
|
||||
)
|
||||
.font(.headline)
|
||||
Text("Memory")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
@@ -210,7 +270,9 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void) -> some View {
|
||||
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void)
|
||||
-> some View
|
||||
{
|
||||
HoverButton(title: title, tint: tint, trailingSystemImage: nil, action: action)
|
||||
}
|
||||
|
||||
@@ -241,9 +303,12 @@ struct ContentView: View {
|
||||
Button {
|
||||
isExpanded.wrappedValue.toggle()
|
||||
} label: {
|
||||
Label(isExpanded.wrappedValue ? "Hide" : "Show All", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
|
||||
.labelStyle(.titleAndIcon)
|
||||
.contentTransition(.symbolEffect(.replace))
|
||||
Label(
|
||||
isExpanded.wrappedValue ? "Hide" : "Show All",
|
||||
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
|
||||
)
|
||||
.labelStyle(.titleAndIcon)
|
||||
.contentTransition(.symbolEffect(.replace))
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.font(.caption2)
|
||||
@@ -394,6 +459,7 @@ struct ContentView: View {
|
||||
.font(.caption2)
|
||||
.foregroundColor(thunderboltStatusColor)
|
||||
interfaceIpList
|
||||
rdmaStatusView
|
||||
sendBugReportButton
|
||||
.padding(.top, 6)
|
||||
}
|
||||
@@ -403,6 +469,52 @@ struct ContentView: View {
|
||||
.animation(.easeInOut(duration: 0.25), value: showDebugInfo)
|
||||
}
|
||||
|
||||
private var rdmaStatusView: some View {
|
||||
let rdma = networkStatusService.status.rdmaStatus
|
||||
return VStack(alignment: .leading, spacing: 1) {
|
||||
Text("RDMA: \(rdmaStatusText(rdma))")
|
||||
.font(.caption2)
|
||||
.foregroundColor(rdmaStatusColor(rdma))
|
||||
if !rdma.devices.isEmpty {
|
||||
Text(" Devices: \(rdma.devices.joined(separator: ", "))")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
if !rdma.activePorts.isEmpty {
|
||||
Text(" Active Ports:")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
ForEach(rdma.activePorts, id: \.device) { port in
|
||||
Text(" \(port.device) port \(port.port): \(port.state)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.green)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func rdmaStatusText(_ rdma: RDMAStatus) -> String {
|
||||
switch rdma.rdmaCtlEnabled {
|
||||
case .some(true):
|
||||
return "Enabled"
|
||||
case .some(false):
|
||||
return "Disabled"
|
||||
case nil:
|
||||
return rdma.devices.isEmpty ? "Not Available" : "Available"
|
||||
}
|
||||
}
|
||||
|
||||
private func rdmaStatusColor(_ rdma: RDMAStatus) -> Color {
|
||||
switch rdma.rdmaCtlEnabled {
|
||||
case .some(true):
|
||||
return .green
|
||||
case .some(false):
|
||||
return .orange
|
||||
case nil:
|
||||
return rdma.devices.isEmpty ? .secondary : .green
|
||||
}
|
||||
}
|
||||
|
||||
private var sendBugReportButton: some View {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Button {
|
||||
@@ -536,4 +648,3 @@ private struct HoverButton: View {
|
||||
.onHover { isHovering = $0 }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,9 +8,9 @@
|
||||
import AppKit
|
||||
import CoreImage
|
||||
import CoreImage.CIFilterBuiltins
|
||||
import ServiceManagement
|
||||
import Sparkle
|
||||
import SwiftUI
|
||||
import ServiceManagement
|
||||
import UserNotifications
|
||||
import os.log
|
||||
|
||||
@@ -19,6 +19,7 @@ struct EXOApp: App {
|
||||
@StateObject private var controller: ExoProcessController
|
||||
@StateObject private var stateService: ClusterStateService
|
||||
@StateObject private var networkStatusService: NetworkStatusService
|
||||
@StateObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@StateObject private var updater: SparkleUpdater
|
||||
private let terminationObserver: TerminationObserver
|
||||
private let ciContext = CIContext(options: nil)
|
||||
@@ -37,9 +38,13 @@ struct EXOApp: App {
|
||||
_stateService = StateObject(wrappedValue: service)
|
||||
let networkStatus = NetworkStatusService()
|
||||
_networkStatusService = StateObject(wrappedValue: networkStatus)
|
||||
let localNetwork = LocalNetworkChecker()
|
||||
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
|
||||
_updater = StateObject(wrappedValue: updater)
|
||||
enableLaunchAtLoginIfNeeded()
|
||||
NetworkSetupHelper.ensureLaunchDaemonInstalled()
|
||||
// Check local network access BEFORE launching exo
|
||||
localNetwork.check()
|
||||
controller.scheduleLaunch(after: 15)
|
||||
service.startPolling()
|
||||
networkStatus.startPolling()
|
||||
@@ -51,6 +56,7 @@ struct EXOApp: App {
|
||||
.environmentObject(controller)
|
||||
.environmentObject(stateService)
|
||||
.environmentObject(networkStatusService)
|
||||
.environmentObject(localNetworkChecker)
|
||||
.environmentObject(updater)
|
||||
} label: {
|
||||
menuBarIcon
|
||||
@@ -107,7 +113,7 @@ struct EXOApp: App {
|
||||
filter.contrast = 0.9
|
||||
|
||||
guard let output = filter.outputImage,
|
||||
let rendered = ciContext.createCGImage(output, from: output.extent)
|
||||
let rendered = ciContext.createCGImage(output, from: output.extent)
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
@@ -120,7 +126,8 @@ struct EXOApp: App {
|
||||
do {
|
||||
try SMAppService.mainApp.register()
|
||||
} catch {
|
||||
Logger().error("Failed to register EXO for launch at login: \(error.localizedDescription)")
|
||||
Logger().error(
|
||||
"Failed to register EXO for launch at login: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -145,7 +152,7 @@ final class SparkleUpdater: NSObject, ObservableObject {
|
||||
center.requestAuthorization(options: [.alert, .sound]) { _, _ in }
|
||||
controller.updater.automaticallyChecksForUpdates = true
|
||||
controller.updater.automaticallyDownloadsUpdates = false
|
||||
controller.updater.updateCheckInterval = 900 // 15 minutes
|
||||
controller.updater.updateCheckInterval = 900 // 15 minutes
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 5) { [weak controller] in
|
||||
controller?.updater.checkForUpdatesInBackground()
|
||||
}
|
||||
@@ -212,7 +219,8 @@ private final class ExoNotificationDelegate: NSObject, UNUserNotificationCenterD
|
||||
func userNotificationCenter(
|
||||
_ center: UNUserNotificationCenter,
|
||||
willPresent notification: UNNotification,
|
||||
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) -> Void
|
||||
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) ->
|
||||
Void
|
||||
) {
|
||||
completionHandler([.banner, .list, .sound])
|
||||
}
|
||||
|
||||
@@ -31,7 +31,8 @@ final class ExoProcessController: ObservableObject {
|
||||
@Published private(set) var launchCountdownSeconds: Int?
|
||||
@Published var customNamespace: String = {
|
||||
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
|
||||
}() {
|
||||
}()
|
||||
{
|
||||
didSet {
|
||||
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
|
||||
}
|
||||
@@ -221,7 +222,9 @@ final class ExoProcessController: ObservableObject {
|
||||
if let tag = Bundle.main.infoDictionary?["EXOBuildTag"] as? String, !tag.isEmpty {
|
||||
return tag
|
||||
}
|
||||
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String, !short.isEmpty {
|
||||
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String,
|
||||
!short.isEmpty
|
||||
{
|
||||
return short
|
||||
}
|
||||
return "dev"
|
||||
|
||||
@@ -8,5 +8,15 @@
 <string>$(EXO_BUILD_TAG)</string>
 <key>EXOBuildCommit</key>
 <string>$(EXO_BUILD_COMMIT)</string>
+<key>EXOBugReportPresignedUrlEndpoint</key>
+<string>$(EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT)</string>
+<key>NSLocalNetworkUsageDescription</key>
+<string>EXO needs local network access to discover and connect to other devices in your cluster for distributed AI inference.</string>
+<key>NSBonjourServices</key>
+<array>
+<string>_p2p._tcp</string>
+<string>_p2p._udp</string>
+<string>_libp2p._udp</string>
+</array>
 </dict>
 </plist>
@@ -16,10 +16,13 @@ struct ClusterState: Decodable {
|
||||
self.instances = rawInstances.mapValues(\.instance)
|
||||
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
|
||||
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
|
||||
let rawTasks = try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
|
||||
let rawTasks =
|
||||
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
|
||||
self.tasks = rawTasks.compactMapValues(\.task)
|
||||
self.topology = try container.decodeIfPresent(Topology.self, forKey: .topology)
|
||||
let rawDownloads = try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads) ?? [:]
|
||||
let rawDownloads =
|
||||
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
|
||||
?? [:]
|
||||
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
|
||||
}
|
||||
|
||||
@@ -41,7 +44,8 @@ private struct TaggedInstance: Decodable {
|
||||
let payloads = try container.decode([String: ClusterInstancePayload].self)
|
||||
guard let entry = payloads.first else {
|
||||
throw DecodingError.dataCorrupted(
|
||||
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
|
||||
DecodingError.Context(
|
||||
codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
|
||||
)
|
||||
}
|
||||
self.instance = ClusterInstance(
|
||||
@@ -77,7 +81,8 @@ struct RunnerStatusSummary: Decodable {
|
||||
let payloads = try container.decode([String: RunnerStatusDetail].self)
|
||||
guard let entry = payloads.first else {
|
||||
throw DecodingError.dataCorrupted(
|
||||
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
|
||||
DecodingError.Context(
|
||||
codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
|
||||
)
|
||||
}
|
||||
self.status = entry.key
|
||||
@@ -257,7 +262,9 @@ struct ChatCompletionTaskParameters: Decodable, Equatable {
|
||||
|
||||
func promptPreview() -> String? {
|
||||
guard let messages else { return nil }
|
||||
if let userMessage = messages.last(where: { $0.role?.lowercased() == "user" && ($0.content?.isEmpty == false) }) {
|
||||
if let userMessage = messages.last(where: {
|
||||
$0.role?.lowercased() == "user" && ($0.content?.isEmpty == false)
|
||||
}) {
|
||||
return userMessage.content
|
||||
}
|
||||
return messages.last?.content
|
||||
@@ -365,5 +372,3 @@ extension ClusterState {
|
||||
|
||||
func availableModels() -> [ModelOption] { [] }
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
-import CryptoKit
 import Foundation
 
 struct BugReportOutcome: Equatable {
@@ -7,17 +6,17 @@ struct BugReportOutcome: Equatable {
|
||||
}
|
||||
|
||||
enum BugReportError: LocalizedError {
|
||||
case missingCredentials
|
||||
case invalidEndpoint
|
||||
case presignedUrlFailed(String)
|
||||
case uploadFailed(String)
|
||||
case collectFailed(String)
|
||||
|
||||
var errorDescription: String? {
|
||||
switch self {
|
||||
case .missingCredentials:
|
||||
return "Bug report upload credentials are not set."
|
||||
case .invalidEndpoint:
|
||||
return "Bug report endpoint is invalid."
|
||||
case .presignedUrlFailed(let message):
|
||||
return "Failed to get presigned URLs: \(message)"
|
||||
case .uploadFailed(let message):
|
||||
return "Bug report upload failed: \(message)"
|
||||
case .collectFailed(let message):
|
||||
@@ -27,11 +26,13 @@ enum BugReportError: LocalizedError {
 }
 
 struct BugReportService {
-struct AWSConfig {
-let accessKey: String
-let secretKey: String
-let region: String
-let bucket: String
+private struct PresignedUrlsRequest: Codable {
+let keys: [String]
+}
+
+private struct PresignedUrlsResponse: Codable {
+let urls: [String: String]
+let expiresIn: Int?
 }
 
 func sendReport(
@@ -39,9 +40,9 @@ struct BugReportService {
 now: Date = Date(),
 isManual: Bool = false
 ) async throws -> BugReportOutcome {
-let credentials = try loadCredentials()
-let timestamp = ISO8601DateFormatter().string(from: now)
-let prefix = "reports/\(timestamp)/"
+let timestamp = Self.runTimestampString(now)
+let dayPrefix = Self.dayPrefixString(now)
+let prefix = "reports/\(dayPrefix)/\(timestamp)/"
 
 let logData = readLog()
 let ifconfigText = try await captureIfconfig()
@@ -66,28 +67,82 @@ struct BugReportService {
|
||||
("\(prefix)exo.log", logData),
|
||||
("\(prefix)state.json", stateData),
|
||||
("\(prefix)events.json", eventsData),
|
||||
("\(prefix)report.json", reportJSON)
|
||||
("\(prefix)report.json", reportJSON),
|
||||
]
|
||||
|
||||
let uploader = try S3Uploader(config: credentials)
|
||||
for item in uploads {
|
||||
guard let data = item.data else { continue }
|
||||
try await uploader.upload(
|
||||
objectPath: item.path,
|
||||
body: data
|
||||
)
|
||||
let uploadItems: [(key: String, body: Data)] = uploads.compactMap { item in
|
||||
guard let body = item.data else { return nil }
|
||||
return (key: item.path, body: body)
|
||||
}
|
||||
|
||||
return BugReportOutcome(success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
|
||||
guard !uploadItems.isEmpty else {
|
||||
return BugReportOutcome(success: false, message: "No data to upload")
|
||||
}
|
||||
|
||||
let presignedUrls = try await fetchPresignedUploadUrls(keys: uploadItems.map(\.key))
|
||||
for item in uploadItems {
|
||||
guard let urlString = presignedUrls[item.key], let url = URL(string: urlString) else {
|
||||
throw BugReportError.uploadFailed("Missing presigned URL for \(item.key)")
|
||||
}
|
||||
try await uploadToPresignedUrl(url: url, body: item.body)
|
||||
}
|
||||
|
||||
return BugReportOutcome(
|
||||
success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
|
||||
}
|
||||
|
||||
private func loadCredentials() throws -> AWSConfig {
|
||||
return AWSConfig(
|
||||
accessKey: "AKIAYEKP5EMXTOBYDGHX",
|
||||
secretKey: "Ep5gIlUZ1o8ssTLQwmyy34yPGfTPEYQ4evE8NdPE",
|
||||
region: "us-east-1",
|
||||
bucket: "exo-bug-reports"
|
||||
)
|
||||
private static func dayPrefixString(_ date: Date) -> String {
|
||||
var calendar = Calendar(identifier: .gregorian)
|
||||
calendar.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
|
||||
let components = calendar.dateComponents([.year, .month, .day], from: date)
|
||||
let year = components.year ?? 0
|
||||
let month = components.month ?? 0
|
||||
let day = components.day ?? 0
|
||||
return String(format: "%04d/%02d/%02d", year, month, day)
|
||||
}
|
||||
|
||||
private static func runTimestampString(_ date: Date) -> String {
|
||||
let formatter = DateFormatter()
|
||||
formatter.locale = Locale(identifier: "en_US_POSIX")
|
||||
formatter.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
|
||||
formatter.dateFormat = "yyyy-MM-dd'T'HHmmss.SSS'Z'"
|
||||
return formatter.string(from: date)
|
||||
}
|
||||
|
||||
private func fetchPresignedUploadUrls(keys: [String], bundle: Bundle = .main) async throws
|
||||
-> [String: String]
|
||||
{
|
||||
guard
|
||||
let endpointString = bundle.infoDictionary?["EXOBugReportPresignedUrlEndpoint"]
|
||||
as? String
|
||||
else {
|
||||
throw BugReportError.invalidEndpoint
|
||||
}
|
||||
let trimmedEndpointString = endpointString.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
guard !trimmedEndpointString.isEmpty, let endpoint = URL(string: trimmedEndpointString)
|
||||
else {
|
||||
throw BugReportError.invalidEndpoint
|
||||
}
|
||||
|
||||
var request = URLRequest(url: endpoint)
|
||||
request.httpMethod = "POST"
|
||||
request.timeoutInterval = 10
|
||||
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
|
||||
|
||||
let encoder = JSONEncoder()
|
||||
request.httpBody = try encoder.encode(PresignedUrlsRequest(keys: keys))
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
guard let http = response as? HTTPURLResponse else {
|
||||
throw BugReportError.presignedUrlFailed("Non-HTTP response")
|
||||
}
|
||||
guard (200..<300).contains(http.statusCode) else {
|
||||
throw BugReportError.presignedUrlFailed("HTTP status \(http.statusCode)")
|
||||
}
|
||||
|
||||
let decoder = JSONDecoder()
|
||||
let decoded = try decoder.decode(PresignedUrlsResponse.self, from: data)
|
||||
return decoded.urls
|
||||
}
|
||||
|
||||
private func readLog() -> Data? {
|
||||
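`fetchPresignedUploadUrls` posts a list of object keys to the endpoint configured under `EXOBugReportPresignedUrlEndpoint` and reads back a key-to-URL map. The payload types are not visible in this part of the diff, so the sketch below assumes their shape from the two call sites (`PresignedUrlsRequest(keys:)` and `decoded.urls`). The keys and URLs are made up for illustration.

```swift
import Foundation

// Assumed shapes: only `PresignedUrlsRequest(keys:)` and `decoded.urls` appear in the diff.
struct PresignedUrlsRequest: Codable {
    let keys: [String]
}

struct PresignedUrlsResponse: Codable {
    let urls: [String: String]
}

// Hypothetical object keys.
let keys = ["reports/example/report.json", "reports/example/exo.log"]

// Request body, as it would be assigned to `request.httpBody`.
let body = try JSONEncoder().encode(PresignedUrlsRequest(keys: keys))
let roundTrip = try JSONDecoder().decode(PresignedUrlsRequest.self, from: body)

// A made-up response; real presigned URLs would carry signature query parameters.
let responseJSON = """
{"urls": {"reports/example/report.json": "https://example.invalid/upload/1",
          "reports/example/exo.log": "https://example.invalid/upload/2"}}
"""
let decoded = try JSONDecoder().decode(
    PresignedUrlsResponse.self, from: Data(responseJSON.utf8))

for key in keys {
    print(key, "->", decoded.urls[key] ?? "missing")
}
```

Because the server signs the URLs, the app no longer needs to hold AWS credentials itself.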
@@ -100,7 +155,8 @@ struct BugReportService {
|
||||
private func captureIfconfig() async throws -> String {
|
||||
let result = runCommand(["/sbin/ifconfig"])
|
||||
guard result.exitCode == 0 else {
|
||||
- throw BugReportError.collectFailed(result.error.isEmpty ? "ifconfig failed" : result.error)
+ throw BugReportError.collectFailed(
+ result.error.isEmpty ? "ifconfig failed" : result.error)
|
||||
}
|
||||
return result.output
|
||||
}
|
||||
@@ -108,12 +164,23 @@ struct BugReportService {
|
||||
private func readDebugInfo() -> DebugInfo {
|
||||
DebugInfo(
|
||||
thunderboltBridgeDisabled: readThunderboltBridgeDisabled(),
|
||||
- interfaces: readInterfaces()
+ interfaces: readInterfaces(),
+ rdma: readRDMADebugInfo()
|
||||
)
|
||||
}
|
||||
|
||||
private func readRDMADebugInfo() -> DebugInfo.RDMADebugInfo {
|
||||
DebugInfo.RDMADebugInfo(
|
||||
rdmaCtlStatus: safeRunCommand(["/usr/bin/rdma_ctl", "status"]),
|
||||
ibvDevices: safeRunCommand(["/usr/bin/ibv_devices"]),
|
||||
ibvDevinfo: safeRunCommand(["/usr/bin/ibv_devinfo"])
|
||||
)
|
||||
}
|
||||
|
||||
private func readThunderboltBridgeDisabled() -> Bool? {
|
||||
- let result = runCommand(["/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
+ let result = runCommand([
+ "/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge",
+ ])
|
||||
guard result.exitCode == 0 else { return nil }
|
||||
let output = result.output.lowercased()
|
||||
if output.contains("enabled") {
|
||||
@@ -156,7 +223,8 @@ struct BugReportService {
|
||||
request.timeoutInterval = 5
|
||||
do {
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
- guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
+ guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode)
+ else {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
@@ -165,6 +233,36 @@ struct BugReportService {
|
||||
}
|
||||
}
|
||||
|
||||
private func uploadToPresignedUrl(url: URL, body: Data) async throws {
|
||||
let maxAttempts = 2
|
||||
var lastError: Error?
|
||||
|
||||
for attempt in 1...maxAttempts {
|
||||
do {
|
||||
var request = URLRequest(url: url)
|
||||
request.httpMethod = "PUT"
|
||||
request.httpBody = body
|
||||
request.timeoutInterval = 30
|
||||
|
||||
let (_, response) = try await URLSession.shared.data(for: request)
|
||||
guard let http = response as? HTTPURLResponse else {
|
||||
throw BugReportError.uploadFailed("Non-HTTP response")
|
||||
}
|
||||
guard (200..<300).contains(http.statusCode) else {
|
||||
throw BugReportError.uploadFailed("HTTP status \(http.statusCode)")
|
||||
}
|
||||
return
|
||||
} catch {
|
||||
lastError = error
|
||||
if attempt < maxAttempts {
|
||||
try await Task.sleep(nanoseconds: 400_000_000)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw BugReportError.uploadFailed(lastError?.localizedDescription ?? "Unknown error")
|
||||
}
|
||||
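`uploadToPresignedUrl` follows a simple pattern: try, remember the last error, pause briefly, try once more. A minimal synchronous sketch of that pattern is shown below. The helper `retry` and the `AttemptError` type are hypothetical, and `Thread.sleep` stands in for the `Task.sleep` used in the async original.

```swift
import Foundation

// Hypothetical error type used only to simulate a failing attempt.
struct AttemptError: Error {
    let attempt: Int
}

/// Runs `operation` up to `maxAttempts` times, pausing `delay` seconds after each
/// failure except the last, and rethrows the last error if every attempt fails.
func retry<T>(maxAttempts: Int, delay: TimeInterval, _ operation: (Int) throws -> T) throws -> T {
    precondition(maxAttempts >= 1, "need at least one attempt")
    var lastError: Error?
    for attempt in 1...maxAttempts {
        do {
            return try operation(attempt)
        } catch {
            lastError = error
            if attempt < maxAttempts {
                Thread.sleep(forTimeInterval: delay)
            }
        }
    }
    throw lastError ?? AttemptError(attempt: maxAttempts)
}

var calls = 0
let outcome = try retry(maxAttempts: 2, delay: 0.01) { attempt -> String in
    calls += 1
    if attempt == 1 { throw AttemptError(attempt: attempt) }  // simulate a transient failure
    return "uploaded on attempt \(attempt)"
}
print(outcome, "after", calls, "calls")
```

A fixed short delay with two attempts fits a user-facing bug-report upload, where a long backoff would only delay feedback.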
|
||||
private func makeReportJson(
|
||||
timestamp: String,
|
||||
hostName: String,
|
||||
@@ -182,7 +280,7 @@ struct BugReportService {
|
||||
"system": system,
|
||||
"exo_version": exo.version as Any,
|
||||
"exo_commit": exo.commit as Any,
|
||||
- "report_type": isManual ? "manual" : "automated"
+ "report_type": isManual ? "manual" : "automated",
|
||||
]
|
||||
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
|
||||
}
|
||||
@@ -213,10 +311,13 @@ struct BugReportService {
|
||||
let user = safeRunCommand(["/usr/bin/whoami"])
|
||||
let consoleUser = safeRunCommand(["/usr/bin/stat", "-f%Su", "/dev/console"])
|
||||
let uptime = safeRunCommand(["/usr/bin/uptime"])
|
||||
- let diskRoot = safeRunCommand(["/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'"])
+ let diskRoot = safeRunCommand([
+ "/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'",
+ ])
|
||||
|
||||
let interfacesList = safeRunCommand(["/usr/sbin/ipconfig", "getiflist"])
|
||||
- let interfacesAndIPs = interfacesList?
+ let interfacesAndIPs =
+ interfacesList?
|
||||
.split(whereSeparator: { $0 == " " || $0 == "\n" })
|
||||
.compactMap { iface -> [String: Any]? in
|
||||
let name = String(iface)
|
||||
@@ -227,7 +328,8 @@ struct BugReportService {
|
||||
} ?? []
|
||||
|
||||
let wifiSSID: String?
|
||||
- let airportPath = "/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
+ let airportPath =
+ "/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
|
||||
if FileManager.default.isExecutableFile(atPath: airportPath) {
|
||||
wifiSSID = safeRunCommand([airportPath, "-I"]).flatMap(parseWifiSSID)
|
||||
} else {
|
||||
@@ -255,7 +357,7 @@ struct BugReportService {
|
||||
"disk_root": diskRoot as Any,
|
||||
"interfaces_and_ips": interfacesAndIPs,
|
||||
"ipconfig_getiflist": interfacesList as Any,
|
||||
- "wifi_ssid": wifiSSID as Any
+ "wifi_ssid": wifiSSID as Any,
|
||||
]
|
||||
}
|
||||
|
||||
@@ -313,7 +415,8 @@ struct BugReportService {
|
||||
for line in airportOutput.split(separator: "\n") {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
if trimmed.hasPrefix("SSID:") {
|
||||
- return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(in: .whitespaces)
+ return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(
+ in: .whitespaces)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -350,6 +453,7 @@ struct BugReportService {
|
||||
private struct DebugInfo {
|
||||
let thunderboltBridgeDisabled: Bool?
|
||||
let interfaces: [InterfaceStatus]
|
||||
let rdma: RDMADebugInfo
|
||||
|
||||
struct InterfaceStatus {
|
||||
let name: String
|
||||
@@ -358,7 +462,21 @@ private struct DebugInfo {
|
||||
func toDictionary() -> [String: Any] {
|
||||
[
|
||||
"name": name,
|
||||
- "ip": ip as Any
+ "ip": ip as Any,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
struct RDMADebugInfo {
|
||||
let rdmaCtlStatus: String?
|
||||
let ibvDevices: String?
|
||||
let ibvDevinfo: String?
|
||||
|
||||
func toDictionary() -> [String: Any] {
|
||||
[
|
||||
"rdma_ctl_status": rdmaCtlStatus as Any,
|
||||
"ibv_devices": ibvDevices as Any,
|
||||
"ibv_devinfo": ibvDevinfo as Any,
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -366,7 +484,8 @@ private struct DebugInfo {
|
||||
func toDictionary() -> [String: Any] {
|
||||
[
|
||||
"thunderbolt_bridge_disabled": thunderboltBridgeDisabled as Any,
|
||||
- "interfaces": interfaces.map { $0.toDictionary() }
+ "interfaces": interfaces.map { $0.toDictionary() },
+ "rdma": rdma.toDictionary(),
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -376,163 +495,3 @@ private struct CommandResult {
|
||||
let output: String
|
||||
let error: String
|
||||
}
|
||||
|
||||
private struct S3Uploader {
|
||||
let config: BugReportService.AWSConfig
|
||||
|
||||
init(config: BugReportService.AWSConfig) throws {
|
||||
self.config = config
|
||||
}
|
||||
|
||||
func upload(objectPath: String, body: Data) async throws {
|
||||
let host = "\(config.bucket).s3.amazonaws.com"
|
||||
guard let url = URL(string: "https://\(host)/\(objectPath)") else {
|
||||
throw BugReportError.invalidEndpoint
|
||||
}
|
||||
|
||||
let now = Date()
|
||||
let amzDate = awsTimestamp(now)
|
||||
let dateStamp = dateStamp(now)
|
||||
let payloadHash = sha256Hex(body)
|
||||
|
||||
let headers = [
|
||||
"host": host,
|
||||
"x-amz-content-sha256": payloadHash,
|
||||
"x-amz-date": amzDate
|
||||
]
|
||||
|
||||
let canonicalRequest = buildCanonicalRequest(
|
||||
method: "PUT",
|
||||
url: url,
|
||||
headers: headers,
|
||||
payloadHash: payloadHash
|
||||
)
|
||||
|
||||
let stringToSign = buildStringToSign(
|
||||
amzDate: amzDate,
|
||||
dateStamp: dateStamp,
|
||||
canonicalRequestHash: sha256Hex(canonicalRequest.data(using: .utf8) ?? Data())
|
||||
)
|
||||
|
||||
let signingKey = deriveKey(secret: config.secretKey, dateStamp: dateStamp, region: config.region, service: "s3")
|
||||
let signature = hmacHex(key: signingKey, data: Data(stringToSign.utf8))
|
||||
|
||||
let signedHeaders = "host;x-amz-content-sha256;x-amz-date"
|
||||
let authorization = """
|
||||
AWS4-HMAC-SHA256 Credential=\(config.accessKey)/\(dateStamp)/\(config.region)/s3/aws4_request, SignedHeaders=\(signedHeaders), Signature=\(signature)
|
||||
"""
|
||||
|
||||
var request = URLRequest(url: url)
|
||||
request.httpMethod = "PUT"
|
||||
request.httpBody = body
|
||||
request.setValue(headers["x-amz-content-sha256"], forHTTPHeaderField: "x-amz-content-sha256")
|
||||
request.setValue(headers["x-amz-date"], forHTTPHeaderField: "x-amz-date")
|
||||
request.setValue(host, forHTTPHeaderField: "Host")
|
||||
request.setValue(authorization, forHTTPHeaderField: "Authorization")
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
|
||||
let statusText = (response as? HTTPURLResponse)?.statusCode ?? -1
|
||||
_ = data // ignore response body for UX
|
||||
throw BugReportError.uploadFailed("HTTP status \(statusText)")
|
||||
}
|
||||
}
|
||||
|
||||
private func buildCanonicalRequest(
|
||||
method: String,
|
||||
url: URL,
|
||||
headers: [String: String],
|
||||
payloadHash: String
|
||||
) -> String {
|
||||
let canonicalURI = encodePath(url.path)
|
||||
let canonicalQuery = url.query ?? ""
|
||||
let sortedHeaders = headers.sorted { $0.key < $1.key }
|
||||
let canonicalHeaders = sortedHeaders
|
||||
.map { "\($0.key.lowercased()):\($0.value)\n" }
|
||||
.joined()
|
||||
let signedHeaders = sortedHeaders.map { $0.key.lowercased() }.joined(separator: ";")
|
||||
|
||||
return [
|
||||
method,
|
||||
canonicalURI,
|
||||
canonicalQuery,
|
||||
canonicalHeaders,
|
||||
signedHeaders,
|
||||
payloadHash
|
||||
].joined(separator: "\n")
|
||||
}
|
||||
|
||||
private func encodePath(_ path: String) -> String {
|
||||
return path
|
||||
.split(separator: "/")
|
||||
.map { segment in
|
||||
segment.addingPercentEncoding(withAllowedCharacters: Self.rfc3986) ?? String(segment)
|
||||
}
|
||||
.joined(separator: "/")
|
||||
.prependSlashIfNeeded()
|
||||
}
|
||||
|
||||
private func buildStringToSign(
|
||||
amzDate: String,
|
||||
dateStamp: String,
|
||||
canonicalRequestHash: String
|
||||
) -> String {
|
||||
"""
|
||||
AWS4-HMAC-SHA256
|
||||
\(amzDate)
|
||||
\(dateStamp)/\(config.region)/s3/aws4_request
|
||||
\(canonicalRequestHash)
|
||||
"""
|
||||
}
|
||||
|
||||
private func deriveKey(secret: String, dateStamp: String, region: String, service: String) -> Data {
|
||||
let kDate = hmac(key: Data(("AWS4" + secret).utf8), data: Data(dateStamp.utf8))
|
||||
let kRegion = hmac(key: kDate, data: Data(region.utf8))
|
||||
let kService = hmac(key: kRegion, data: Data(service.utf8))
|
||||
return hmac(key: kService, data: Data("aws4_request".utf8))
|
||||
}
|
||||
|
||||
private func hmac(key: Data, data: Data) -> Data {
|
||||
let keySym = SymmetricKey(data: key)
|
||||
let mac = HMAC<SHA256>.authenticationCode(for: data, using: keySym)
|
||||
return Data(mac)
|
||||
}
|
||||
|
||||
private func hmacHex(key: Data, data: Data) -> String {
|
||||
hmac(key: key, data: data).map { String(format: "%02x", $0) }.joined()
|
||||
}
|
||||
|
||||
private func sha256Hex(_ data: Data) -> String {
|
||||
let digest = SHA256.hash(data: data)
|
||||
return digest.compactMap { String(format: "%02x", $0) }.joined()
|
||||
}
|
||||
|
||||
private func awsTimestamp(_ date: Date) -> String {
|
||||
let formatter = DateFormatter()
|
||||
formatter.dateFormat = "yyyyMMdd'T'HHmmss'Z'"
|
||||
formatter.timeZone = TimeZone(abbreviation: "UTC")
|
||||
return formatter.string(from: date)
|
||||
}
|
||||
|
||||
private func dateStamp(_ date: Date) -> String {
|
||||
let formatter = DateFormatter()
|
||||
formatter.dateFormat = "yyyyMMdd"
|
||||
formatter.timeZone = TimeZone(abbreviation: "UTC")
|
||||
return formatter.string(from: date)
|
||||
}
|
||||
|
||||
private static let rfc3986: CharacterSet = {
|
||||
var set = CharacterSet.alphanumerics
|
||||
set.insert(charactersIn: "-._~")
|
||||
return set
|
||||
}()
|
||||
}
|
||||
|
||||
private extension String {
|
||||
func prependSlashIfNeeded() -> String {
|
||||
if hasPrefix("/") {
|
||||
return self
|
||||
}
|
||||
return "/" + self
|
||||
}
|
||||
}
|
||||
|
||||
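The removed `S3Uploader` percent-encoded each path segment on its own before signing, so that `/` stays a separator while spaces and `+` inside a segment are escaped. The sketch below reproduces that step as a standalone function. The allowed set is spelled out as ASCII here for readability, and the sample path is made up.

```swift
import Foundation

// RFC 3986 "unreserved" characters, written out as ASCII for this sketch.
let unreserved = CharacterSet(
    charactersIn: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~")

/// Percent-encodes every path segment separately and keeps "/" as the separator.
func encodePath(_ path: String) -> String {
    let encoded = path
        .split(separator: "/")
        .map { segment -> String in
            let text = String(segment)
            return text.addingPercentEncoding(withAllowedCharacters: unreserved) ?? text
        }
        .joined(separator: "/")
    return encoded.hasPrefix("/") ? encoded : "/" + encoded
}

print(encodePath("/2026/01/09/bug report+1.json"))
// /2026/01/09/bug%20report%2B1.json
```

With presigned URLs the server performs this canonicalisation, which is why the client-side signing code could be deleted.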
@@ -57,7 +57,9 @@ final class ClusterStateService: ObservableObject {
|
||||
var request = URLRequest(url: url)
|
||||
request.cachePolicy = .reloadIgnoringLocalCacheData
|
||||
let (data, response) = try await session.data(for: request)
|
||||
- guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
+ guard let httpResponse = response as? HTTPURLResponse,
+ (200..<300).contains(httpResponse.statusCode)
+ else {
|
||||
return
|
||||
}
|
||||
if let nodeId = try? decoder.decode(String.self, from: data) {
|
||||
@@ -113,7 +115,9 @@ final class ClusterStateService: ObservableObject {
|
||||
}
|
||||
}
|
||||
|
||||
- func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int) async {
+ func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int)
+ async
+ {
|
||||
do {
|
||||
var request = URLRequest(url: baseURL.appendingPathComponent("instance"))
|
||||
request.httpMethod = "POST"
|
||||
@@ -122,7 +126,7 @@ final class ClusterStateService: ObservableObject {
|
||||
"model_id": modelId,
|
||||
"sharding": sharding,
|
||||
"instance_meta": instanceMeta,
|
||||
- "min_nodes": minNodes
+ "min_nodes": minNodes,
|
||||
]
|
||||
request.httpBody = try JSONSerialization.data(withJSONObject: payload, options: [])
|
||||
let (_, response) = try await session.data(for: request)
|
||||
@@ -143,7 +147,9 @@ final class ClusterStateService: ObservableObject {
|
||||
do {
|
||||
let url = baseURL.appendingPathComponent("models")
|
||||
let (data, response) = try await session.data(from: url)
|
||||
- guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
+ guard let httpResponse = response as? HTTPURLResponse,
+ (200..<300).contains(httpResponse.statusCode)
+ else {
|
||||
throw URLError(.badServerResponse)
|
||||
}
|
||||
let list = try decoder.decode(ModelListResponse.self, from: data)
|
||||
|
||||
150
app/EXO/EXO/Services/LocalNetworkChecker.swift
Normal file
@@ -0,0 +1,150 @@
|
||||
import Foundation
|
||||
import Network
|
||||
import os.log
|
||||
|
||||
/// Checks if the app's local network permission is actually functional.
|
||||
///
|
||||
/// macOS local network permission can appear enabled in System Preferences but not
|
||||
/// actually work after a restart. This service detects this by creating a UDP
|
||||
/// connection to the mDNS multicast address (224.0.0.251:5353).
|
||||
@MainActor
|
||||
final class LocalNetworkChecker: ObservableObject {
|
||||
enum Status: Equatable {
|
||||
case unknown
|
||||
case checking
|
||||
case working
|
||||
case notWorking(reason: String)
|
||||
|
||||
var isHealthy: Bool {
|
||||
if case .working = self { return true }
|
||||
return false
|
||||
}
|
||||
|
||||
var displayText: String {
|
||||
switch self {
|
||||
case .unknown:
|
||||
return "Unknown"
|
||||
case .checking:
|
||||
return "Checking..."
|
||||
case .working:
|
||||
return "Working"
|
||||
case .notWorking(let reason):
|
||||
return reason
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
|
||||
|
||||
@Published private(set) var status: Status = .unknown
|
||||
@Published private(set) var lastConnectionState: String = "none"
|
||||
|
||||
private var connection: NWConnection?
|
||||
private var checkTask: Task<Void, Never>?
|
||||
|
||||
/// Checks if local network access is working.
|
||||
func check() {
|
||||
checkTask?.cancel()
|
||||
status = .checking
|
||||
lastConnectionState = "connecting"
|
||||
|
||||
checkTask = Task { [weak self] in
|
||||
guard let self else { return }
|
||||
let result = await self.performCheck()
|
||||
self.status = result
|
||||
Self.logger.info("Local network check complete: \(result.displayText)")
|
||||
}
|
||||
}
|
||||
|
||||
private func performCheck() async -> Status {
|
||||
Self.logger.info("Checking local network access via UDP multicast")
|
||||
|
||||
connection?.cancel()
|
||||
connection = nil
|
||||
|
||||
// mDNS multicast address - same as libp2p uses for peer discovery
|
||||
let host = NWEndpoint.Host("224.0.0.251")
|
||||
let port = NWEndpoint.Port(integerLiteral: 5353)
|
||||
|
||||
let params = NWParameters.udp
|
||||
params.allowLocalEndpointReuse = true
|
||||
|
||||
let conn = NWConnection(host: host, port: port, using: params)
|
||||
connection = conn
|
||||
|
||||
return await withCheckedContinuation { continuation in
|
||||
var hasResumed = false
|
||||
let lock = NSLock()
|
||||
|
||||
let resumeOnce: (Status) -> Void = { status in
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
guard !hasResumed else { return }
|
||||
hasResumed = true
|
||||
continuation.resume(returning: status)
|
||||
}
|
||||
|
||||
conn.stateUpdateHandler = { [weak self] state in
|
||||
let stateStr: String
|
||||
switch state {
|
||||
case .setup: stateStr = "setup"
|
||||
case .preparing: stateStr = "preparing"
|
||||
case .ready: stateStr = "ready"
|
||||
case .waiting(let e): stateStr = "waiting(\(e))"
|
||||
case .failed(let e): stateStr = "failed(\(e))"
|
||||
case .cancelled: stateStr = "cancelled"
|
||||
@unknown default: stateStr = "unknown"
|
||||
}
|
||||
|
||||
Task { @MainActor in
|
||||
self?.lastConnectionState = stateStr
|
||||
}
|
||||
|
||||
switch state {
|
||||
case .ready:
|
||||
resumeOnce(.working)
|
||||
case .waiting(let error):
|
||||
let errorStr = "\(error)"
|
||||
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
|
||||
resumeOnce(.notWorking(reason: "Connection blocked"))
|
||||
}
|
||||
case .failed(let error):
|
||||
let errorStr = "\(error)"
|
||||
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
|
||||
|| errorStr.contains("permission") || errorStr.contains("denied")
|
||||
{
|
||||
resumeOnce(.notWorking(reason: "Permission denied"))
|
||||
} else {
|
||||
resumeOnce(.notWorking(reason: "Failed: \(error.localizedDescription)"))
|
||||
}
|
||||
case .cancelled, .setup, .preparing:
|
||||
break
|
||||
@unknown default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
conn.start(queue: .main)
|
||||
|
||||
Task {
|
||||
try? await Task.sleep(nanoseconds: 3_000_000_000)
|
||||
let state = conn.state
|
||||
switch state {
|
||||
case .ready:
|
||||
resumeOnce(.working)
|
||||
case .waiting, .preparing, .setup:
|
||||
resumeOnce(.notWorking(reason: "Timeout (may be blocked)"))
|
||||
default:
|
||||
resumeOnce(.notWorking(reason: "Timeout"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stop() {
|
||||
checkTask?.cancel()
|
||||
checkTask = nil
|
||||
connection?.cancel()
|
||||
connection = nil
|
||||
}
|
||||
}
|
||||
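`performCheck` can be resolved from two places, the connection state handler and the 3-second timeout task, but a checked continuation must be resumed exactly once. The `resumeOnce` closure guards against a double resume with a lock and a flag. The sketch below isolates that guard in a small class. `OnceGate` is a hypothetical name and not part of the diff.

```swift
import Foundation

/// Lets only the first caller through; every later call is ignored.
final class OnceGate {
    private let lock = NSLock()
    private var fired = false

    func run(_ body: () -> Void) {
        lock.lock()
        defer { lock.unlock() }
        guard !fired else { return }
        fired = true
        body()
    }
}

var results: [String] = []
let gate = OnceGate()

// The state handler reports first, so its result wins.
gate.run { results.append("working") }
// The timeout fires later and is dropped.
gate.run { results.append("timeout") }

print(results)  // ["working"]
```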
@@ -5,61 +5,62 @@ import os.log
|
||||
enum NetworkSetupHelper {
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
|
||||
private static let daemonLabel = "io.exo.networksetup"
|
||||
- private static let scriptDestination = "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
+ private static let scriptDestination =
+ "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
|
||||
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
|
||||
private static let requiredStartInterval: Int = 1791
|
||||
|
||||
private static let setupScript = """
|
||||
#!/usr/bin/env bash
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
set -euo pipefail
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && {
|
||||
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
|
||||
}
|
||||
ifconfig bridge0 destroy 2>/dev/null || true
|
||||
}
|
||||
# Remove bridge0 interface
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && {
|
||||
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
|
||||
}
|
||||
ifconfig bridge0 destroy 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
|
||||
networksetup -listlocations | grep -q exo || {
|
||||
networksetup -createlocation exo
|
||||
}
|
||||
networksetup -listlocations | grep -q exo || {
|
||||
networksetup -createlocation exo
|
||||
}
|
||||
|
||||
networksetup -switchtolocation exo
|
||||
networksetup -listallhardwareports \\
|
||||
| awk -F': ' '/Hardware Port: / {print $2}' \\
|
||||
| while IFS=":" read -r name; do
|
||||
case "$name" in
|
||||
"Ethernet Adapter"*)
|
||||
;;
|
||||
"Thunderbolt Bridge")
|
||||
;;
|
||||
"Thunderbolt "*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "EXO $name" \\
|
||||
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
networksetup -setdhcp "EXO $name"
|
||||
;;
|
||||
*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "$name" \\
|
||||
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
;;
|
||||
esac
|
||||
done
|
||||
networksetup -switchtolocation exo
|
||||
networksetup -listallhardwareports \\
|
||||
| awk -F': ' '/Hardware Port: / {print $2}' \\
|
||||
| while IFS=":" read -r name; do
|
||||
case "$name" in
|
||||
"Ethernet Adapter"*)
|
||||
;;
|
||||
"Thunderbolt Bridge")
|
||||
;;
|
||||
"Thunderbolt "*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "EXO $name" \\
|
||||
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
networksetup -setdhcp "EXO $name"
|
||||
;;
|
||||
*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "$name" \\
|
||||
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
"""
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
"""
|
||||
|
||||
static func ensureLaunchDaemonInstalled() {
|
||||
Task.detached {
|
||||
@@ -70,7 +71,9 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
try await installLaunchDaemon()
|
||||
logger.info("Network setup launch daemon installed and started")
|
||||
} catch {
|
||||
- logger.error("Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)")
+ logger.error(
+ "Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
+ )
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -82,7 +85,8 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
guard scriptExists, plistExists else { return false }
|
||||
guard
|
||||
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
|
||||
- let plist = try? PropertyListSerialization.propertyList(from: data, options: [], format: nil) as? [String: Any]
+ let plist = try? PropertyListSerialization.propertyList(
+ from: data, options: [], format: nil) as? [String: Any]
|
||||
else {
|
||||
return false
|
||||
}
|
||||
@@ -92,7 +96,9 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
else {
|
||||
return false
|
||||
}
|
||||
- if let programArgs = plist["ProgramArguments"] as? [String], programArgs.contains(scriptDestination) == false {
+ if let programArgs = plist["ProgramArguments"] as? [String],
+ programArgs.contains(scriptDestination) == false
+ {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -105,58 +111,59 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
|
||||
private static func makeInstallerScript() -> String {
|
||||
"""
|
||||
set -euo pipefail
|
||||
set -euo pipefail
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
|
||||
mkdir -p "$(dirname "$SCRIPT_DEST")"
|
||||
mkdir -p "$(dirname "$SCRIPT_DEST")"
|
||||
|
||||
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
|
||||
\(setupScript)
|
||||
EOF_SCRIPT
|
||||
chmod 755 "$SCRIPT_DEST"
|
||||
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
|
||||
\(setupScript)
|
||||
EOF_SCRIPT
|
||||
chmod 755 "$SCRIPT_DEST"
|
||||
|
||||
cat > "$PLIST_DEST" <<'EOF_PLIST'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>\(daemonLabel)</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/bin/bash</string>
|
||||
<string>\(scriptDestination)</string>
|
||||
</array>
|
||||
<key>StartInterval</key>
|
||||
<integer>\(requiredStartInterval)</integer>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/var/log/\(daemonLabel).log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/var/log/\(daemonLabel).err.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF_PLIST
|
||||
cat > "$PLIST_DEST" <<'EOF_PLIST'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>\(daemonLabel)</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/bin/bash</string>
|
||||
<string>\(scriptDestination)</string>
|
||||
</array>
|
||||
<key>StartInterval</key>
|
||||
<integer>\(requiredStartInterval)</integer>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/var/log/\(daemonLabel).log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/var/log/\(daemonLabel).err.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF_PLIST
|
||||
|
||||
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
|
||||
launchctl bootstrap system "$PLIST_DEST"
|
||||
launchctl enable system/"$LABEL"
|
||||
launchctl kickstart -k system/"$LABEL"
|
||||
"""
|
||||
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
|
||||
launchctl bootstrap system "$PLIST_DEST"
|
||||
launchctl enable system/"$LABEL"
|
||||
launchctl kickstart -k system/"$LABEL"
|
||||
"""
|
||||
}
|
||||
|
||||
private static func runShellAsAdmin(_ script: String) throws {
|
||||
- let escapedScript = script
+ let escapedScript =
+ script
|
||||
.replacingOccurrences(of: "\\", with: "\\\\")
|
||||
.replacingOccurrences(of: "\"", with: "\\\"")
|
||||
|
||||
let appleScriptSource = """
|
||||
do shell script "\(escapedScript)" with administrator privileges
|
||||
"""
|
||||
do shell script "\(escapedScript)" with administrator privileges
|
||||
"""
|
||||
|
||||
guard let appleScript = NSAppleScript(source: appleScriptSource) else {
|
||||
throw NetworkSetupError.scriptCreationFailed
|
||||
|
||||
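`runShellAsAdmin` embeds the installer script in an AppleScript string literal, so backslashes and double quotes must be escaped first. Order matters: backslashes are doubled before quotes are escaped, otherwise the backslashes introduced for the quotes would be doubled again. A small sketch of that step follows. The function name `escapeForAppleScript` is hypothetical.

```swift
import Foundation

/// Escapes a shell script for use inside an AppleScript double-quoted string.
func escapeForAppleScript(_ script: String) -> String {
    script
        .replacingOccurrences(of: "\\", with: "\\\\")  // 1. double every backslash
        .replacingOccurrences(of: "\"", with: "\\\"")  // 2. then escape double quotes
}

// Raw string literal: the shell command is exactly  echo "a\b"
let shellCommand = #"echo "a\b""#
let appleScriptSource =
    "do shell script \"\(escapeForAppleScript(shellCommand))\" with administrator privileges"

print(appleScriptSource)
// do shell script "echo \"a\\b\"" with administrator privileges
```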
@@ -35,14 +35,34 @@ struct NetworkStatus: Equatable {
|
||||
let thunderboltBridgeState: ThunderboltState?
|
||||
let bridgeInactive: Bool?
|
||||
let interfaceStatuses: [InterfaceIpStatus]
|
||||
let rdmaStatus: RDMAStatus
|
||||
|
||||
static let empty = NetworkStatus(
|
||||
thunderboltBridgeState: nil,
|
||||
bridgeInactive: nil,
|
||||
- interfaceStatuses: []
+ interfaceStatuses: [],
+ rdmaStatus: .empty
|
||||
)
|
||||
}
|
||||
|
||||
struct RDMAStatus: Equatable {
|
||||
let rdmaCtlEnabled: Bool?
|
||||
let devices: [String]
|
||||
let activePorts: [RDMAPort]
|
||||
|
||||
var isAvailable: Bool {
|
||||
rdmaCtlEnabled == true || !devices.isEmpty
|
||||
}
|
||||
|
||||
static let empty = RDMAStatus(rdmaCtlEnabled: nil, devices: [], activePorts: [])
|
||||
}
|
||||
|
||||
struct RDMAPort: Equatable {
|
||||
let device: String
|
||||
let port: String
|
||||
let state: String
|
||||
}
|
||||
|
||||
struct InterfaceIpStatus: Equatable {
|
||||
let interfaceName: String
|
||||
let ipAddress: String?
|
||||
@@ -59,10 +79,79 @@ private struct NetworkStatusFetcher {
|
||||
NetworkStatus(
|
||||
thunderboltBridgeState: readThunderboltBridgeState(),
|
||||
bridgeInactive: readBridgeInactive(),
|
||||
- interfaceStatuses: readInterfaceStatuses()
+ interfaceStatuses: readInterfaceStatuses(),
+ rdmaStatus: readRDMAStatus()
|
||||
)
|
||||
}
|
||||
|
||||
private func readRDMAStatus() -> RDMAStatus {
|
||||
let rdmaCtlEnabled = readRDMACtlEnabled()
|
||||
let devices = readRDMADevices()
|
||||
let activePorts = readRDMAActivePorts()
|
||||
return RDMAStatus(
|
||||
rdmaCtlEnabled: rdmaCtlEnabled, devices: devices, activePorts: activePorts)
|
||||
}
|
||||
|
||||
private func readRDMACtlEnabled() -> Bool? {
|
||||
let result = runCommand(["rdma_ctl", "status"])
|
||||
guard result.exitCode == 0 else { return nil }
|
||||
let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
if output.contains("enabled") {
|
||||
return true
|
||||
}
|
||||
if output.contains("disabled") {
|
||||
return false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
private func readRDMADevices() -> [String] {
|
||||
let result = runCommand(["ibv_devices"])
|
||||
guard result.exitCode == 0 else { return [] }
|
||||
var devices: [String] = []
|
||||
for line in result.output.split(separator: "\n") {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
if trimmed.hasPrefix("---") || trimmed.lowercased().hasPrefix("device")
|
||||
|| trimmed.isEmpty
|
||||
{
|
||||
continue
|
||||
}
|
||||
let parts = trimmed.split(separator: " ", maxSplits: 1)
|
||||
if let deviceName = parts.first {
|
||||
devices.append(String(deviceName))
|
||||
}
|
||||
}
|
||||
return devices
|
||||
}
|
||||
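`readRDMADevices` skips the header and separator lines of the `ibv_devices` table and keeps the first column of every remaining row. The sketch below runs the same filtering on a canned string, so it needs no RDMA hardware. The sample output and device names are assumptions for illustration, not captured from a real machine.

```swift
import Foundation

/// Extracts device names (first column) from `ibv_devices`-style output.
func parseDeviceNames(_ output: String) -> [String] {
    var devices: [String] = []
    for line in output.split(separator: "\n") {
        let trimmed = line.trimmingCharacters(in: .whitespaces)
        // Skip blank lines, the "device ..." header, and the "------" separator.
        if trimmed.isEmpty || trimmed.hasPrefix("---")
            || trimmed.lowercased().hasPrefix("device")
        {
            continue
        }
        if let name = trimmed.split(separator: " ", maxSplits: 1).first {
            devices.append(String(name))
        }
    }
    return devices
}

// Hypothetical sample output.
let sampleDevicesOutput = """
    device                 node GUID
    ------              ----------------
    rdma_en2            0123456789abcdef
    rdma_en3            fedcba9876543210
"""

print(parseDeviceNames(sampleDevicesOutput))  // ["rdma_en2", "rdma_en3"]
```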
|
||||
private func readRDMAActivePorts() -> [RDMAPort] {
|
||||
let result = runCommand(["ibv_devinfo"])
|
||||
guard result.exitCode == 0 else { return [] }
|
||||
var ports: [RDMAPort] = []
|
||||
var currentDevice: String?
|
||||
var currentPort: String?
|
||||
|
||||
for line in result.output.split(separator: "\n") {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
if trimmed.hasPrefix("hca_id:") {
|
||||
currentDevice = trimmed.replacingOccurrences(of: "hca_id:", with: "")
|
||||
.trimmingCharacters(in: .whitespaces)
|
||||
} else if trimmed.hasPrefix("port:") {
|
||||
currentPort = trimmed.replacingOccurrences(of: "port:", with: "")
|
||||
.trimmingCharacters(in: .whitespaces)
|
||||
} else if trimmed.hasPrefix("state:") {
|
||||
let state = trimmed.replacingOccurrences(of: "state:", with: "").trimmingCharacters(
|
||||
in: .whitespaces)
|
||||
if let device = currentDevice, let port = currentPort {
|
||||
if state.lowercased().contains("active") {
|
||||
ports.append(RDMAPort(device: device, port: port, state: state))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ports
|
||||
}
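The port scan above is a small state machine: it remembers the most recent `hca_id:` and `port:` values and emits a record when a `state:` line mentions "active". A minimal Python sketch of the same idea follows. The `RdmaPort` class, the `parse_active_ports` name, and the sample text (including the device names) are assumptions made for illustration; they are not taken from the repository or from real `ibv_devinfo` output.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class RdmaPort:
    # Hypothetical mirror of the Swift RDMAPort value type.
    device: str
    port: str
    state: str


def parse_active_ports(output: str) -> list[RdmaPort]:
    """Collect (device, port, state) records whose state mentions 'active'."""
    ports: list[RdmaPort] = []
    current_device: str | None = None
    current_port: str | None = None
    for line in output.splitlines():
        trimmed = line.strip()
        if trimmed.startswith("hca_id:"):
            current_device = trimmed.removeprefix("hca_id:").strip()
        elif trimmed.startswith("port:"):
            current_port = trimmed.removeprefix("port:").strip()
        elif trimmed.startswith("state:"):
            state = trimmed.removeprefix("state:").strip()
            # Only emit once both a device and a port have been seen.
            if current_device is not None and current_port is not None:
                if "active" in state.lower():
                    ports.append(RdmaPort(current_device, current_port, state))
    return ports


# Illustrative input only: it uses the three keys the parser looks for.
SAMPLE = """\
hca_id: rdma_en2
    port: 1
        state: PORT_ACTIVE (4)
hca_id: rdma_en3
    port: 1
        state: PORT_DOWN (1)
"""

active_ports = parse_active_ports(SAMPLE)
```

One caveat of the substring check, in the sketch and in the Swift code alike: any state string that merely contains "active" (for example a hypothetical "inactive") would also be counted, so a stricter comparison may be worth considering if such states can occur.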
|
||||
|
||||
private func readThunderboltBridgeState() -> ThunderboltState? {
|
||||
let result = runCommand(["networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
|
||||
guard result.exitCode == 0 else {
|
||||
@@ -85,10 +174,11 @@ private struct NetworkStatusFetcher {
|
||||
private func readBridgeInactive() -> Bool? {
|
||||
let result = runCommand(["ifconfig", "bridge0"])
|
||||
guard result.exitCode == 0 else { return nil }
|
||||
guard let statusLine = result.output
|
||||
.components(separatedBy: .newlines)
|
||||
.first(where: { $0.contains("status:") })?
|
||||
.lowercased()
|
||||
guard
|
||||
let statusLine = result.output
|
||||
.components(separatedBy: .newlines)
|
||||
.first(where: { $0.contains("status:") })?
|
||||
.lowercased()
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
@@ -171,4 +261,3 @@ private struct NetworkStatusFetcher {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -107,10 +107,13 @@ extension ClusterState {
|
||||
let nodeToRunner = instance.shardAssignments.nodeToRunner
|
||||
let nodeIds = Array(nodeToRunner.keys)
|
||||
let runnerIds = Array(nodeToRunner.values)
|
||||
let nodeNames = nodeIds.compactMap { nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0 }
|
||||
let nodeNames = nodeIds.compactMap {
|
||||
nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0
|
||||
}
|
||||
let statuses = runnerIds.compactMap { runners[$0]?.status.lowercased() }
|
||||
let downloadProgress = aggregateDownloadProgress(for: nodeIds)
|
||||
let state = InstanceViewModel.State(statuses: statuses, hasActiveDownload: downloadProgress != nil)
|
||||
let state = InstanceViewModel.State(
|
||||
statuses: statuses, hasActiveDownload: downloadProgress != nil)
|
||||
let chatTasks = (chatTasksByInstance[entry.key] ?? [])
|
||||
.sorted(by: { $0.sortPriority < $1.sortPriority })
|
||||
.map { InstanceTaskViewModel(task: $0) }
|
||||
@@ -165,8 +168,8 @@ extension ClusterState {
|
||||
}
|
||||
}
|
||||
|
||||
private extension InstanceViewModel.State {
|
||||
init(statuses: [String], hasActiveDownload: Bool = false) {
|
||||
extension InstanceViewModel.State {
|
||||
fileprivate init(statuses: [String], hasActiveDownload: Bool = false) {
|
||||
if statuses.contains(where: { $0.contains("failed") }) {
|
||||
self = .failed
|
||||
} else if hasActiveDownload || statuses.contains(where: { $0.contains("downloading") }) {
|
||||
@@ -243,4 +246,3 @@ extension InstanceTaskViewModel {
|
||||
self.parameters = task.parameters
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -87,7 +87,9 @@ struct TopologyViewModel {
|
||||
extension ClusterState {
|
||||
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
|
||||
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
|
||||
let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
|
||||
let allNodes = nodeViewModels().filter {
|
||||
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
|
||||
}
|
||||
guard !allNodes.isEmpty else { return nil }
|
||||
|
||||
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
|
||||
@@ -106,18 +108,24 @@ extension ClusterState {
|
||||
}
|
||||
|
||||
// Rotate so the local node (from /node_id API) is first
|
||||
if let localId = localNodeId, let index = orderedNodes.firstIndex(where: { $0.id == localId }) {
|
||||
if let localId = localNodeId,
|
||||
let index = orderedNodes.firstIndex(where: { $0.id == localId })
|
||||
{
|
||||
orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
|
||||
}
|
||||
|
||||
let nodeIds = Set(orderedNodes.map(\.id))
|
||||
let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
|
||||
guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
|
||||
return TopologyEdgeViewModel(sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
|
||||
} ?? []
|
||||
let edgesArray: [TopologyEdgeViewModel] =
|
||||
topology?.connections?.compactMap { connection in
|
||||
guard nodeIds.contains(connection.localNodeId),
|
||||
nodeIds.contains(connection.sendBackNodeId)
|
||||
else { return nil }
|
||||
return TopologyEdgeViewModel(
|
||||
sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
|
||||
} ?? []
|
||||
let edges = Set(edgesArray)
|
||||
|
||||
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
|
||||
return TopologyViewModel(
|
||||
nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
|
||||
}
|
||||
}
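The rotation step in `topologyViewModel` keeps the ring order intact and only changes which node comes first. A minimal Python sketch of that slice-and-concatenate rotation is below; `rotate_to_front` is a hypothetical helper name, not something defined in the repository.

```python
from __future__ import annotations


def rotate_to_front(node_ids: list[str], local_id: str | None) -> list[str]:
    """Rotate the list so local_id is first, preserving cyclic order."""
    if local_id is None or local_id not in node_ids:
        # Nothing to rotate around: return an unchanged copy.
        return list(node_ids)
    index = node_ids.index(local_id)
    return node_ids[index:] + node_ids[:index]


rotated = rotate_to_front(["a", "b", "c", "d"], "c")
```

Because the result is a rotation rather than a sort, neighbours in the original order stay neighbours, which is what a ring layout needs.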
|
||||
|
||||
|
||||
@@ -20,8 +20,8 @@ struct InstanceRowView: View {
|
||||
if let progress = instance.downloadProgress {
|
||||
downloadStatusView(progress: progress)
|
||||
} else {
|
||||
statusChip(label: instance.state.label.uppercased(), color: statusColor)
|
||||
}
|
||||
statusChip(label: instance.state.label.uppercased(), color: statusColor)
|
||||
}
|
||||
}
|
||||
if let progress = instance.downloadProgress {
|
||||
GeometryReader { geometry in
|
||||
@@ -97,7 +97,8 @@ struct InstanceRowView: View {
|
||||
.font(.caption)
|
||||
.fontWeight(.semibold)
|
||||
if let subtitle = task.subtitle,
|
||||
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame {
|
||||
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame
|
||||
{
|
||||
Text(subtitle)
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
@@ -234,9 +235,12 @@ struct InstanceRowView: View {
|
||||
Button {
|
||||
isExpanded.wrappedValue.toggle()
|
||||
} label: {
|
||||
Label(isExpanded.wrappedValue ? "Hide" : "Show", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
|
||||
.labelStyle(.titleAndIcon)
|
||||
.contentTransition(.symbolEffect(.replace))
|
||||
Label(
|
||||
isExpanded.wrappedValue ? "Hide" : "Show",
|
||||
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
|
||||
)
|
||||
.labelStyle(.titleAndIcon)
|
||||
.contentTransition(.symbolEffect(.replace))
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.font(.caption2)
|
||||
@@ -311,7 +315,9 @@ struct InstanceRowView: View {
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
private func detailRow(icon: String? = nil, title: String, value: String, tint: Color = .secondary) -> some View {
|
||||
private func detailRow(
|
||||
icon: String? = nil, title: String, value: String, tint: Color = .secondary
|
||||
) -> some View {
|
||||
HStack(alignment: .firstTextBaseline, spacing: 6) {
|
||||
if let icon {
|
||||
Image(systemName: icon)
|
||||
@@ -329,4 +335,3 @@ struct InstanceRowView: View {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,4 +32,3 @@ struct NodeDetailView: View {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,4 +28,3 @@ struct NodeRowView: View {
|
||||
.padding(.vertical, 4)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -76,30 +76,33 @@ struct TopologyMiniView: View {
|
||||
|
||||
private func connectionLines(in size: CGSize) -> some View {
|
||||
let positions = positionedNodes(in: size)
|
||||
let positionById = Dictionary(uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
|
||||
let positionById = Dictionary(
|
||||
uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
|
||||
return Canvas { context, _ in
|
||||
guard !topology.edges.isEmpty else { return }
|
||||
let nodeRadius: CGFloat = 32
|
||||
let arrowLength: CGFloat = 10
|
||||
let arrowSpread: CGFloat = .pi / 7
|
||||
for edge in topology.edges {
|
||||
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId] else { continue }
|
||||
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId]
|
||||
else { continue }
|
||||
let dx = end.x - start.x
|
||||
let dy = end.y - start.y
|
||||
let distance = max(CGFloat(hypot(dx, dy)), 1)
|
||||
let ux = dx / distance
|
||||
let uy = dy / distance
|
||||
let adjustedStart = CGPoint(x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
|
||||
let adjustedStart = CGPoint(
|
||||
x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
|
||||
let adjustedEnd = CGPoint(x: end.x - ux * nodeRadius, y: end.y - uy * nodeRadius)
|
||||
|
||||
var linePath = Path()
|
||||
linePath.move(to: adjustedStart)
|
||||
linePath.addLine(to: adjustedEnd)
|
||||
context.stroke(
|
||||
context.stroke(
|
||||
linePath,
|
||||
with: .color(.secondary.opacity(0.3)),
|
||||
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
|
||||
)
|
||||
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
|
||||
)
|
||||
|
||||
let angle = atan2(uy, ux)
|
||||
let tip = adjustedEnd
|
||||
@@ -168,5 +171,3 @@ private struct NodeGlyphView: View {
|
||||
.frame(width: 95)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
//
|
||||
|
||||
import Testing
|
||||
|
||||
@testable import EXO
|
||||
|
||||
struct EXOTests {
|
||||
|
||||
526
bench/exo_bench.py
Normal file
@@ -0,0 +1,526 @@
|
||||
#!/usr/bin/env python3
|
||||
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import http.client
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from statistics import mean
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.types.memory import Memory
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
super().__init__(f"HTTP {status} {reason}: {body_preview}")
|
||||
self.status = status
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def request_json(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if params:
|
||||
path = path + "?" + urlencode(params)
|
||||
|
||||
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
|
||||
try:
|
||||
payload: bytes | None = None
|
||||
hdrs: dict[str, str] = {"Accept": "application/json"}
|
||||
|
||||
if body is not None:
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
hdrs["Content-Type"] = "application/json"
|
||||
if headers:
|
||||
hdrs.update(headers)
|
||||
|
||||
conn.request(method.upper(), path, body=payload, headers=hdrs)
|
||||
resp = conn.getresponse()
|
||||
raw = resp.read()
|
||||
text = raw.decode("utf-8", errors="replace") if raw else ""
|
||||
|
||||
if resp.status >= 400:
|
||||
raise ExoHttpError(resp.status, resp.reason, text[:300])
|
||||
|
||||
if not text:
|
||||
return None
|
||||
return json.loads(text)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.request_json("POST", "/bench/chat/completions", body=payload)
|
||||
|
||||
|
||||
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
|
||||
if len(instance) != 1:
|
||||
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
|
||||
|
||||
tag = next(iter(instance))
|
||||
inner = instance[tag]
|
||||
if not isinstance(inner, dict):
|
||||
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
|
||||
return inner
|
||||
|
||||
|
||||
def instance_id_from_instance(instance: dict[str, Any]) -> str:
|
||||
inner = unwrap_instance(instance)
|
||||
return str(inner["instanceId"])
|
||||
|
||||
|
||||
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
|
||||
inner = unwrap_instance(instance)
|
||||
return len(inner["shardAssignments"]["nodeToRunner"])
|
||||
|
||||
|
||||
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
|
||||
inner = unwrap_instance(instance)
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
return list(runner_to_shard.keys())
|
||||
|
||||
|
||||
def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
|
||||
|
||||
|
||||
def wait_for_instance_gone(
|
||||
client: ExoClient, instance_id: str, timeout: float = 3.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
client.request_json("GET", f"/instance/{instance_id}")
|
||||
time.sleep(0.4)
|
||||
except ExoHttpError as e:
|
||||
if e.status == 404:
|
||||
return
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
|
||||
|
||||
|
||||
def format_peak_memory(b: float) -> str:
|
||||
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
||||
if b < 1024.0:
|
||||
return f"{b:.2f}{unit}"
|
||||
b /= 1024.0
|
||||
raise ValueError("You're using petabytes of memory. Something went wrong...")
|
||||
|
||||
|
||||
def parse_int_list(values: list[str]) -> list[int]:
|
||||
items: list[int] = []
|
||||
for v in values:
|
||||
for part in v.split(","):
|
||||
part = part.strip()
|
||||
if part:
|
||||
items.append(int(part))
|
||||
|
||||
seen: set[int] = set()
|
||||
out: list[int] = []
|
||||
for x in items:
|
||||
if x not in seen:
|
||||
out.append(x)
|
||||
seen.add(x)
|
||||
return out
|
||||
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
data = models.get("data") or []
|
||||
|
||||
for m in data:
|
||||
if m.get("id") == model_arg:
|
||||
short_id = str(m["id"])
|
||||
full_id = str(m.get("hugging_face_id") or m["id"])
|
||||
return short_id, full_id
|
||||
|
||||
for m in data:
|
||||
if m.get("hugging_face_id") == model_arg:
|
||||
short_id = str(m["id"])
|
||||
full_id = str(m["hugging_face_id"])
|
||||
return short_id, full_id
|
||||
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
return ("ring" in s) or ("jaccl" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def sharding_filter(sharding: str, wanted: str) -> bool:
|
||||
s = (sharding or "").lower()
|
||||
if wanted == "both":
|
||||
return ("pipeline" in s) or ("tensor" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def run_one_completion(
|
||||
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
content, pp_tokens = prompt_sizer.build(pp_hint)
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"max_tokens": tg,
|
||||
}
|
||||
|
||||
t0 = time.perf_counter()
|
||||
out = client.post_bench_chat_completions(payload)
|
||||
elapsed = time.perf_counter() - t0
|
||||
|
||||
stats = out.get("generation_stats")
|
||||
|
||||
preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]
|
||||
|
||||
return {
|
||||
"elapsed_s": elapsed,
|
||||
"output_text_preview": preview,
|
||||
"stats": stats,
|
||||
}, pp_tokens
|
||||
|
||||
|
||||
class PromptSizer:
|
||||
def __init__(self, tokenizer: Any, atom: str = "a "):
|
||||
self.tokenizer = tokenizer
|
||||
self.atom = atom
|
||||
self.count_fn = PromptSizer._make_counter(tokenizer)
|
||||
self.base_tokens = self.count_fn("")
|
||||
|
||||
@staticmethod
|
||||
def _make_counter(tokenizer: Any) -> Callable[[str], int]:
|
||||
def count_fn(user_content: str) -> int:
|
||||
messages = [{"role": "user", "content": user_content}]
|
||||
ids = tokenizer.apply_chat_template(
|
||||
messages, tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
return int(len(ids))
|
||||
|
||||
return count_fn
|
||||
|
||||
def build(self, target_prompt_tokens: int) -> tuple[str, int]:
|
||||
target = int(target_prompt_tokens)
|
||||
if target < self.base_tokens:
|
||||
raise RuntimeError(
|
||||
f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
|
||||
)
|
||||
|
||||
content = ""
|
||||
tok = self.count_fn(content)
|
||||
|
||||
while tok < target:
|
||||
content += self.atom
|
||||
tok = self.count_fn(content)
|
||||
|
||||
if tok != target:
|
||||
raise RuntimeError(
|
||||
f"Overshot: got {tok} tokens (target {target}). "
|
||||
f"Pick a different atom (try ' a' or '\\n' or '0 ')."
|
||||
)
|
||||
|
||||
return content, tok
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-bench",
|
||||
description="Benchmark exo model throughput across placement previews.",
|
||||
)
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
|
||||
)
|
||||
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
|
||||
ap.add_argument(
|
||||
"--pp",
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="Prompt-size hints (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--tg",
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="Generation lengths (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--max-nodes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Skip pipeline + jaccl placements, which are often pointless (off unless this flag is passed).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--warmup",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
default="bench/results.json",
|
||||
help="Write raw per-run results JSON to this path.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dry-run", action="store_true", help="List selected placements and exit."
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
tg_list = parse_int_list(args.tg)
|
||||
if not pp_list or not tg_list:
|
||||
logger.error("pp and tg lists must be non-empty")
|
||||
return 2
|
||||
if args.repeat <= 0:
|
||||
logger.error("--repeat must be >= 1")
|
||||
return 2
|
||||
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": short_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
full_model_id,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if tokenizer is None:
|
||||
raise RuntimeError("[exo-bench] tokenizer load failed")
|
||||
|
||||
try:
|
||||
prompt_sizer = PromptSizer(tokenizer)
|
||||
logger.debug(f"[exo-bench] loaded tokenizer: {full_model_id} for prompt sizer")
|
||||
except Exception:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip single-node tensor or jaccl placements when "both" is selected; single-node pipeline ring covers that case
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if 0 < n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
logger.error("No valid placements matched your filters.")
|
||||
return 1
|
||||
|
||||
selected.sort(
|
||||
key=lambda p: (
|
||||
str(p.get("instance_meta", "")),
|
||||
str(p.get("sharding", "")),
|
||||
-nodes_used_in_instance(p["instance"]),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
|
||||
logger.info(f"placements: {len(selected)}")
|
||||
for p in selected:
|
||||
logger.info(
|
||||
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
|
||||
)
|
||||
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
instance = preview["instance"]
|
||||
instance_id = instance_id_from_instance(instance)
|
||||
|
||||
sharding = str(preview["sharding"])
|
||||
instance_meta = str(preview["instance_meta"])
|
||||
n_nodes = nodes_used_in_instance(instance)
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info(
|
||||
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
|
||||
)
|
||||
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
try:
|
||||
for i in range(args.warmup):
|
||||
run_one_completion(
|
||||
client, full_model_id, pp_list[0], tg_list[0], prompt_sizer
|
||||
)
|
||||
logger.debug(f" warmup {i + 1}/{args.warmup} done")
|
||||
|
||||
for pp in pp_list:
|
||||
if (
|
||||
pp * n_nodes > 2048
|
||||
and "ring" in instance_meta.lower()
|
||||
and "tensor" in sharding.lower()
|
||||
):
|
||||
model_card = MODEL_CARDS[short_id]
|
||||
if model_card.metadata.storage_size > Memory.from_gb(10):
|
||||
logger.info(
|
||||
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
|
||||
)
|
||||
continue
|
||||
for tg in tg_list:
|
||||
runs: list[dict[str, Any]] = []
|
||||
for r in range(args.repeat):
|
||||
time.sleep(3)
|
||||
try:
|
||||
row, actual_pp_tokens = run_one_completion(
|
||||
client, full_model_id, pp, tg, prompt_sizer
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
continue
|
||||
row.update(
|
||||
{
|
||||
"model_short_id": short_id,
|
||||
"model_id": full_model_id,
|
||||
"placement_sharding": sharding,
|
||||
"placement_instance_meta": instance_meta,
|
||||
"placement_nodes": n_nodes,
|
||||
"instance_id": instance_id,
|
||||
"pp_tokens": actual_pp_tokens,
|
||||
"tg": tg,
|
||||
"repeat_index": r,
|
||||
}
|
||||
)
|
||||
runs.append(row)
|
||||
all_rows.append(row)
|
||||
|
||||
if runs:
|
||||
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
|
||||
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
|
||||
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
|
||||
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
|
||||
peak = mean(
|
||||
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
|
||||
f"prompt_tokens={ptok} gen_tokens={gtok} "
|
||||
f"peak_memory={format_peak_memory(peak)}\n"
|
||||
)
|
||||
time.sleep(2)
|
||||
finally:
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
except ExoHttpError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
wait_for_instance_gone(client, instance_id)
|
||||
logger.debug(f"Deleted instance {instance_id}")
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
if args.json_out:
|
||||
with open(args.json_out, "w", encoding="utf-8") as f:
|
||||
json.dump(all_rows, f, indent=2, ensure_ascii=False)
|
||||
logger.debug(f"\nWrote results JSON: {args.json_out}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
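The `PromptSizer.build` loop in this file grows the prompt one atom at a time until the counted length reaches the target, then fails if it overshot. The sketch below reproduces that loop against a stand-in counter so it can run without a real tokenizer. `toy_count`, `TEMPLATE_OVERHEAD`, and `build_prompt` are assumptions for illustration; a real run would count tokens through the tokenizer's chat template instead.

```python
from __future__ import annotations

from collections.abc import Callable

TEMPLATE_OVERHEAD = 5  # hypothetical fixed token cost of an empty chat template


def toy_count(user_content: str) -> int:
    """Stand-in counter: template overhead plus one token per word."""
    return TEMPLATE_OVERHEAD + len(user_content.split())


def build_prompt(
    target: int, count_fn: Callable[[str], int], atom: str = "a "
) -> tuple[str, int]:
    """Append atoms until count_fn reports exactly `target` tokens."""
    base = count_fn("")
    if target < base:
        raise RuntimeError(f"Target ({target}) is below template overhead ({base}).")

    content = ""
    tokens = base
    while tokens < target:
        content += atom
        tokens = count_fn(content)

    if tokens != target:
        # An atom that adds more than one token can step over the target.
        raise RuntimeError(f"Overshot: got {tokens} tokens (target {target}).")
    return content, tokens


content, tokens = build_prompt(12, toy_count)
```

Each iteration recounts the whole string, so the work grows faster than linearly with the target size. That is probably acceptable for a benchmark harness, but a binary search over the atom count would be one alternative if sizing ever became the slow step.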
|
||||
1
dashboard/src/app.d.ts
vendored
@@ -11,4 +11,3 @@ declare global {
|
||||
}
|
||||
|
||||
export {};
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
export { default as TopologyGraph } from './TopologyGraph.svelte';
|
||||
export { default as ChatForm } from './ChatForm.svelte';
|
||||
export { default as ChatMessages } from './ChatMessages.svelte';
|
||||
export { default as ChatAttachments } from './ChatAttachments.svelte';
|
||||
export { default as ChatSidebar } from './ChatSidebar.svelte';
|
||||
export { default as ModelCard } from './ModelCard.svelte';
|
||||
export { default as MarkdownContent } from './MarkdownContent.svelte';
|
||||
|
||||
export { default as TopologyGraph } from "./TopologyGraph.svelte";
|
||||
export { default as ChatForm } from "./ChatForm.svelte";
|
||||
export { default as ChatMessages } from "./ChatMessages.svelte";
|
||||
export { default as ChatAttachments } from "./ChatAttachments.svelte";
|
||||
export { default as ChatSidebar } from "./ChatSidebar.svelte";
|
||||
export { default as ModelCard } from "./ModelCard.svelte";
|
||||
export { default as MarkdownContent } from "./MarkdownContent.svelte";
|
||||
|
||||
File diff suppressed because it is too large
@@ -13,55 +13,124 @@ export interface ChatUploadedFile {
|
||||
}
|
||||
|
||||
export interface ChatAttachment {
|
||||
type: 'image' | 'text' | 'pdf' | 'audio';
|
||||
type: "image" | "text" | "pdf" | "audio";
|
||||
name: string;
|
||||
content?: string;
|
||||
base64Url?: string;
|
||||
mimeType?: string;
|
||||
}
|
||||
|
||||
export type FileCategory = 'image' | 'text' | 'pdf' | 'audio' | 'unknown';
|
||||
export type FileCategory = "image" | "text" | "pdf" | "audio" | "unknown";
|
||||
|
||||
export const IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.gif', '.webp', '.svg'];
|
||||
export const IMAGE_MIME_TYPES = ['image/jpeg', 'image/png', 'image/gif', 'image/webp', 'image/svg+xml'];
|
||||
export const IMAGE_EXTENSIONS = [
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".gif",
|
||||
".webp",
|
||||
".svg",
|
||||
];
|
||||
export const IMAGE_MIME_TYPES = [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
"image/svg+xml",
|
||||
];
|
||||
|
||||
export const TEXT_EXTENSIONS = [
|
||||
'.txt', '.md', '.json', '.xml', '.yaml', '.yml', '.csv', '.log',
|
||||
'.js', '.ts', '.jsx', '.tsx', '.py', '.java', '.cpp', '.c', '.h',
|
||||
'.css', '.html', '.htm', '.sql', '.sh', '.bat', '.rs', '.go',
|
||||
'.rb', '.php', '.swift', '.kt', '.scala', '.r', '.dart', '.vue', '.svelte'
|
||||
".txt",
|
||||
".md",
|
||||
".json",
|
||||
".xml",
|
||||
".yaml",
|
||||
".yml",
|
||||
".csv",
|
||||
".log",
|
||||
".js",
|
||||
".ts",
|
||||
".jsx",
|
||||
".tsx",
|
||||
".py",
|
||||
".java",
|
||||
".cpp",
|
||||
".c",
|
||||
".h",
|
||||
".css",
|
||||
".html",
|
||||
".htm",
|
||||
".sql",
|
||||
".sh",
|
||||
".bat",
|
||||
".rs",
|
||||
".go",
|
||||
".rb",
|
||||
".php",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".r",
|
||||
".dart",
|
||||
".vue",
|
||||
".svelte",
|
||||
];
|
||||
export const TEXT_MIME_TYPES = [
|
||||
'text/plain', 'text/markdown', 'text/csv', 'text/html', 'text/css',
|
||||
'application/json', 'application/xml', 'text/xml', 'application/javascript',
|
||||
'text/javascript', 'application/typescript'
|
||||
"text/plain",
|
||||
"text/markdown",
|
||||
"text/csv",
|
||||
"text/html",
|
||||
"text/css",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"text/xml",
|
||||
"application/javascript",
|
||||
"text/javascript",
|
||||
"application/typescript",
|
||||
];
|
||||
|
||||
export const PDF_EXTENSIONS = ['.pdf'];
|
||||
export const PDF_MIME_TYPES = ['application/pdf'];
|
||||
export const PDF_EXTENSIONS = [".pdf"];
|
||||
export const PDF_MIME_TYPES = ["application/pdf"];
|
||||
|
||||
export const AUDIO_EXTENSIONS = ['.mp3', '.wav', '.ogg', '.m4a'];
|
||||
export const AUDIO_MIME_TYPES = ['audio/mpeg', 'audio/wav', 'audio/ogg', 'audio/mp4'];
|
||||
export const AUDIO_EXTENSIONS = [".mp3", ".wav", ".ogg", ".m4a"];
|
||||
export const AUDIO_MIME_TYPES = [
|
||||
"audio/mpeg",
|
||||
"audio/wav",
|
||||
"audio/ogg",
|
||||
"audio/mp4",
|
||||
];
|
||||
|
||||
/**
|
||||
* Get file category based on MIME type and extension
|
||||
*/
|
||||
export function getFileCategory(mimeType: string, fileName: string): FileCategory {
|
||||
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf('.'));
|
||||
|
||||
if (IMAGE_MIME_TYPES.includes(mimeType) || IMAGE_EXTENSIONS.includes(extension)) {
|
||||
return 'image';
|
||||
export function getFileCategory(
|
||||
mimeType: string,
|
||||
fileName: string,
|
||||
): FileCategory {
|
||||
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf("."));
|
||||
|
||||
if (
|
||||
IMAGE_MIME_TYPES.includes(mimeType) ||
|
||||
IMAGE_EXTENSIONS.includes(extension)
|
||||
) {
|
||||
return "image";
|
||||
}
|
||||
if (PDF_MIME_TYPES.includes(mimeType) || PDF_EXTENSIONS.includes(extension)) {
|
||||
return 'pdf';
|
||||
return "pdf";
|
||||
}
|
||||
if (AUDIO_MIME_TYPES.includes(mimeType) || AUDIO_EXTENSIONS.includes(extension)) {
|
||||
return 'audio';
|
||||
if (
|
||||
AUDIO_MIME_TYPES.includes(mimeType) ||
|
||||
AUDIO_EXTENSIONS.includes(extension)
|
||||
) {
|
||||
return "audio";
|
||||
}
|
||||
if (TEXT_MIME_TYPES.includes(mimeType) || TEXT_EXTENSIONS.includes(extension) || mimeType.startsWith('text/')) {
|
||||
return 'text';
|
||||
if (
|
||||
TEXT_MIME_TYPES.includes(mimeType) ||
|
||||
TEXT_EXTENSIONS.includes(extension) ||
|
||||
mimeType.startsWith("text/")
|
||||
) {
|
||||
return "text";
|
||||
}
|
||||
return 'unknown';
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -69,36 +138,36 @@ export function getFileCategory(mimeType: string, fileName: string): FileCategor
|
||||
*/
|
||||
export function getAcceptString(categories: FileCategory[]): string {
|
||||
const accepts: string[] = [];
|
||||
|
||||
|
||||
for (const category of categories) {
|
||||
switch (category) {
|
||||
case 'image':
|
||||
case "image":
|
||||
accepts.push(...IMAGE_EXTENSIONS, ...IMAGE_MIME_TYPES);
|
||||
break;
|
||||
case 'text':
|
||||
case "text":
|
||||
accepts.push(...TEXT_EXTENSIONS, ...TEXT_MIME_TYPES);
|
||||
break;
|
||||
case 'pdf':
|
||||
case "pdf":
|
||||
accepts.push(...PDF_EXTENSIONS, ...PDF_MIME_TYPES);
|
||||
break;
|
||||
case 'audio':
|
||||
case "audio":
|
||||
accepts.push(...AUDIO_EXTENSIONS, ...AUDIO_MIME_TYPES);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return accepts.join(',');
|
||||
|
||||
return accepts.join(",");
|
||||
}
|
||||
|
||||
/**
|
||||
* Format file size for display
|
||||
*/
|
||||
export function formatFileSize(bytes: number): string {
|
||||
if (bytes === 0) return '0 B';
|
||||
if (bytes === 0) return "0 B";
|
||||
const k = 1024;
|
||||
const sizes = ['B', 'KB', 'MB', 'GB'];
|
||||
const sizes = ["B", "KB", "MB", "GB"];
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k));
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + " " + sizes[i];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -128,42 +197,44 @@ export function readFileAsText(file: File): Promise<string> {
|
||||
/**
|
||||
* Process uploaded files into ChatUploadedFile format
|
||||
*/
|
||||
export async function processUploadedFiles(files: File[]): Promise<ChatUploadedFile[]> {
|
||||
export async function processUploadedFiles(
|
||||
files: File[],
|
||||
): Promise<ChatUploadedFile[]> {
|
||||
const results: ChatUploadedFile[] = [];
|
||||
|
||||
|
||||
for (const file of files) {
|
||||
const id = Date.now().toString() + Math.random().toString(36).substring(2, 9);
|
||||
const id =
|
||||
Date.now().toString() + Math.random().toString(36).substring(2, 9);
|
||||
const category = getFileCategory(file.type, file.name);
|
||||
|
||||
|
||||
const base: ChatUploadedFile = {
|
||||
id,
|
||||
name: file.name,
|
||||
size: file.size,
|
||||
type: file.type,
|
||||
file
|
||||
file,
|
||||
};
|
||||
|
||||
|
||||
try {
|
||||
if (category === 'image') {
|
||||
if (category === "image") {
|
||||
const preview = await readFileAsDataURL(file);
|
||||
results.push({ ...base, preview });
|
||||
} else if (category === 'text' || category === 'unknown') {
|
||||
} else if (category === "text" || category === "unknown") {
|
||||
const textContent = await readFileAsText(file);
|
||||
results.push({ ...base, textContent });
|
||||
} else if (category === 'pdf') {
|
||||
} else if (category === "pdf") {
|
||||
results.push(base);
|
||||
} else if (category === 'audio') {
|
||||
} else if (category === "audio") {
|
||||
const preview = await readFileAsDataURL(file);
|
||||
results.push({ ...base, preview });
|
||||
} else {
|
||||
results.push(base);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error processing file:', file.name, error);
|
||||
console.error("Error processing file:", file.name, error);
|
||||
results.push(base);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
@@ -332,8 +332,13 @@
|
||||
<div class="text-lg font-mono text-white truncate">{node.nodeName}</div>
|
||||
<div class="text-xs text-exo-light-gray font-mono truncate">{node.nodeId}</div>
|
||||
</div>
|
||||
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0">
|
||||
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> /{node.models.length} models</span>
|
||||
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right">
|
||||
<div>
|
||||
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> / {node.models.length} models</span>
|
||||
</div>
|
||||
<div class="text-exo-light-gray normal-case tracking-normal">
|
||||
{formatBytes(node.models.filter(m => m.status === 'completed').reduce((sum, m) => sum + m.totalBytes, 0))} on disk
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -385,7 +390,7 @@
|
||||
</div>
|
||||
|
||||
<div class="flex items-center justify-between text-xs font-mono text-exo-light-gray">
|
||||
<span>{model.status === 'completed' ? 'Completed' : `${formatSpeed(model.speed)} • ETA ${formatEta(model.etaMs)}`}</span>
|
||||
<span>{model.status === 'completed' ? `Completed (${formatBytes(model.totalBytes)})` : `${formatSpeed(model.speed)} • ETA ${formatEta(model.etaMs)}`}</span>
|
||||
{#if model.status !== 'completed'}
|
||||
<span>{model.files.length} file{model.files.length === 1 ? '' : 's'}</span>
|
||||
{/if}
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import tailwindcss from '@tailwindcss/vite';
|
||||
import { sveltekit } from '@sveltejs/kit/vite';
|
||||
import { defineConfig } from 'vite';
|
||||
import tailwindcss from "@tailwindcss/vite";
|
||||
import { sveltekit } from "@sveltejs/kit/vite";
|
||||
import { defineConfig } from "vite";
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [tailwindcss(), sveltekit()],
|
||||
server: {
|
||||
proxy: {
|
||||
'/v1': 'http://localhost:52415',
|
||||
'/state': 'http://localhost:52415',
|
||||
'/models': 'http://localhost:52415',
|
||||
'/instance': 'http://localhost:52415'
|
||||
}
|
||||
}
|
||||
"/v1": "http://localhost:52415",
|
||||
"/state": "http://localhost:52415",
|
||||
"/models": "http://localhost:52415",
|
||||
"/instance": "http://localhost:52415",
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
21
flake.nix
@@ -42,11 +42,22 @@
|
||||
};
|
||||
treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
|
||||
projectRootFile = "flake.nix";
|
||||
programs.ruff-format.enable = true;
|
||||
programs.ruff-format.excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
|
||||
programs.rustfmt.enable = true;
|
||||
programs.rustfmt.package = (fenixToolchain system).rustfmt;
|
||||
programs.nixpkgs-fmt.enable = true;
|
||||
programs = {
|
||||
nixpkgs-fmt.enable = true;
|
||||
ruff-format = {
|
||||
enable = true;
|
||||
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
|
||||
};
|
||||
rustfmt = {
|
||||
enable = true;
|
||||
package = (fenixToolchain system).rustfmt;
|
||||
};
|
||||
prettier = {
|
||||
enable = true;
|
||||
includes = [ "*.ts" ];
|
||||
};
|
||||
swift-format.enable = true;
|
||||
};
|
||||
};
|
||||
in
|
||||
{
|
||||
|
||||
@@ -82,7 +82,7 @@ build-backend = "uv_build"
|
||||
###
|
||||
|
||||
[tool.basedpyright]
|
||||
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src"]
|
||||
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src", "bench"]
|
||||
typeCheckingMode = "strict"
|
||||
failOnWarnings = true
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.models.model_meta import get_model_meta
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionResponse,
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
@@ -34,6 +36,7 @@ from exo.shared.types.api import (
|
||||
CreateInstanceResponse,
|
||||
DeleteInstanceResponse,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -172,6 +175,7 @@ class API:
|
||||
self.app.post("/v1/chat/completions", response_model=None)(
|
||||
self.chat_completions
|
||||
)
|
||||
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
|
||||
@@ -490,6 +494,45 @@ class API:
|
||||
],
|
||||
)
|
||||
|
||||
async def _collect_chat_completion_with_stats(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
stats = chunk.stats or stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
resp = BenchChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content=combined_text
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
generation_stats=stats,
|
||||
)
|
||||
return resp
|
||||
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
|
||||
logger.warning(
|
||||
"TODO: we should send a notification to the user to download the model"
|
||||
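Review note: the aggregation in `_collect_chat_completion_with_stats` can be read as the following minimal sketch. It assumes a simplified `Chunk` dataclass, which is a hypothetical stand-in and not the repo's `TokenChunk`. The point is that `stats = chunk.stats or stats` keeps the last non-empty stats, even when later chunks carry none.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Chunk:
    # Hypothetical stand-in for the repo's TokenChunk; fields are illustrative.
    text: str
    finish_reason: str | None = None
    stats: dict[str, float] | None = None


def collect(chunks: list[Chunk]) -> tuple[str, str | None, dict[str, float] | None]:
    """Join chunk text and keep the last non-empty stats and finish_reason."""
    parts: list[str] = []
    finish_reason: str | None = None
    stats: dict[str, float] | None = None
    for chunk in chunks:
        parts.append(chunk.text)
        # `or` keeps the previous value when this chunk carries no stats.
        stats = chunk.stats or stats
        if chunk.finish_reason is not None:
            finish_reason = chunk.finish_reason
    return "".join(parts), finish_reason, stats


text, reason, stats = collect(
    [
        Chunk("Hel"),
        Chunk("lo", finish_reason="length", stats={"generation_tps": 42.0}),
    ]
)
```

One caveat of the `or` idiom: any falsy stats value (an empty dict in this sketch) is treated the same as `None`.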
@@ -525,6 +568,33 @@ class API:
|
||||
|
||||
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionTaskParams
|
||||
) -> BenchChatCompletionResponse:
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
payload.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == payload.model
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(payload.model)
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"No instance found for model {payload.model}"
|
||||
)
|
||||
|
||||
payload.stream = False
|
||||
|
||||
command = ChatCompletion(request_params=payload)
|
||||
await self._send(command)
|
||||
|
||||
response = await self._collect_chat_completion_with_stats(
|
||||
command.command_id,
|
||||
parse_gpt_oss,
|
||||
)
|
||||
return response
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
total_available = Memory()
|
||||
|
||||
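Review note: a hedged sketch of how a client might call the new `/bench/chat/completions` route. The port `52415` is taken from the Vite proxy config elsewhere in this diff. The model name is a placeholder, the helper names `build_bench_request` and `extract_tps` are hypothetical, and the request fields are assumed to follow the same shape as the regular chat completions payload.

```python
from __future__ import annotations

import json
import urllib.request

# Assumption: the API listens on the same port that the dashboard proxies to.
BENCH_URL = "http://localhost:52415/bench/chat/completions"


def build_bench_request(model: str, prompt: str, max_tokens: int) -> urllib.request.Request:
    """Build a POST request for the bench endpoint (not sent here)."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "stream": False,  # the handler forces this to False anyway
    }
    return urllib.request.Request(
        BENCH_URL,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def extract_tps(body: dict) -> tuple[float, float] | None:
    """Return (prompt_tps, generation_tps), or None if no stats were attached."""
    stats = body.get("generation_stats")
    if stats is None:
        return None
    return float(stats["prompt_tps"]), float(stats["generation_tps"])


# Usage against a running node (left commented so the sketch has no side effects):
# with urllib.request.urlopen(build_bench_request("placeholder-model", "hi", 64)) as resp:
#     print(extract_tps(json.load(resp)))
```

Since `generation_stats` is optional in `BenchChatCompletionResponse`, the caller should handle the `None` case.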
@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
@@ -51,6 +52,10 @@ class ChatCompletionMessage(BaseModel):
|
||||
function_call: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class BenchChatCompletionMessage(ChatCompletionMessage):
|
||||
pass
|
||||
|
||||
|
||||
class TopLogprobItem(BaseModel):
|
||||
token: str
|
||||
logprob: float
|
||||
@@ -113,6 +118,18 @@ class ChatCompletionResponse(BaseModel):
|
||||
service_tier: str | None = None
|
||||
|
||||
|
||||
class GenerationStats(BaseModel):
|
||||
prompt_tps: float
|
||||
generation_tps: float
|
||||
prompt_tokens: int
|
||||
generation_tokens: int
|
||||
peak_memory_usage: Memory
|
||||
|
||||
|
||||
class BenchChatCompletionResponse(ChatCompletionResponse):
|
||||
generation_stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class ChatCompletionTaskParams(BaseModel):
|
||||
model: str
|
||||
frequency_penalty: float | None = None
|
||||
@@ -135,6 +152,10 @@ class ChatCompletionTaskParams(BaseModel):
|
||||
user: str | None = None
|
||||
|
||||
|
||||
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
|
||||
pass
|
||||
|
||||
|
||||
class PlaceInstanceParams(BaseModel):
|
||||
model_id: str
|
||||
sharding: Sharding = Sharding.Pipeline
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
from exo.shared.types.api import GenerationStats
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
@@ -20,6 +21,7 @@ class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.shared.types.api import FinishReason, GenerationStats
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ class GenerationResponse(BaseRunnerResponse):
|
||||
token: int
|
||||
# logprobs: list[float] | None = None # too big. we can change to be top-k
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
|
||||
@@ -3,10 +3,17 @@ from typing import Any, Callable, Generator, cast, get_args
|
||||
import mlx.core as mx
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
# from exo.engines.mlx.cache import KVPrefixCache
|
||||
from exo.shared.types.api import ChatCompletionMessage, FinishReason
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionMessage,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
@@ -41,7 +48,6 @@ def maybe_quantize_kv_cache(
|
||||
def warmup_inference(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
) -> int:
|
||||
content = "Prompt to warm up the inference engine. Repeat this."
|
||||
|
||||
@@ -64,6 +70,9 @@ def warmup_inference(
|
||||
model=model,
|
||||
)
|
||||
|
||||
# Use a default sampler for warmup
|
||||
sampler = make_sampler(temp=0.7)
|
||||
|
||||
logger.info("Generating warmup tokens")
|
||||
for _r in stream_generate(
|
||||
model=model,
|
||||
@@ -72,7 +81,7 @@ def warmup_inference(
|
||||
max_tokens=50,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=65536,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
@@ -80,20 +89,47 @@ def warmup_inference(
|
||||
tokens_generated += 1
|
||||
|
||||
logger.info("Generated ALL warmup tokens")
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
# At least this version is actively incorrect, as it should use mx_barrier(group)
|
||||
mx_barrier()
|
||||
|
||||
return tokens_generated
|
||||
|
||||
|
||||
def ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]:
|
||||
token_ids = [int(t) for t in token_ids]
|
||||
|
||||
def proc(_history: mx.array, logits: mx.array) -> mx.array:
|
||||
for tid in token_ids:
|
||||
logits[..., tid] = -1e9
|
||||
return logits
|
||||
|
||||
return proc
|
||||
|
||||
|
||||
def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
|
||||
eos: list[int] | None = getattr(tokenizer, "eos_token_ids", None)
|
||||
if eos is None:
|
||||
return []
|
||||
return eos
|
||||
|
||||
|
||||
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
task: ChatCompletionTaskParams,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
|
||||
|
||||
# Currently we support chat-completion tasks only.
|
||||
logger.info(f"task_params: {task}")
|
||||
|
||||
if task.seed is not None:
|
||||
mx.random.seed(task.seed)
|
||||
|
||||
prompt = apply_chat_template(
|
||||
tokenizer=tokenizer,
|
||||
chat_task_data=task,
|
||||
@@ -101,6 +137,17 @@ def mlx_generate(
|
||||
|
||||
caches = make_kv_cache(model=model)
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
# Ban EOS tokens so generation stops only at max_tokens (fixed-length output for benchmarking)
|
||||
eos_ids = eos_ids_from_tokenizer(tokenizer)
|
||||
logits_processors = [ban_token_ids(eos_ids)]
|
||||
|
||||
sampler = make_sampler(
|
||||
temp=task.temperature if task.temperature is not None else 0.7,
|
||||
top_p=task.top_p if task.top_p is not None else 1.0,
|
||||
)
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
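Review note: the effect of `ban_token_ids` can be illustrated without MLX. The sketch below uses plain Python lists in place of `mx.array` and a hypothetical `argmax` helper in place of a real sampler. It only shows the masking idea: once the EOS logit is pushed to a very large negative value, a greedy pick can no longer select it, so generation ends by length rather than by EOS.

```python
NEG_INF = -1e9  # same sentinel value as in the diff


def ban_token_ids(token_ids):
    """Return a logits processor that masks out the given token ids."""
    banned = [int(t) for t in token_ids]

    def proc(_history, logits):
        # Overwrite the banned positions in place, then hand the logits back.
        for tid in banned:
            logits[tid] = NEG_INF
        return logits

    return proc


def argmax(values):
    """Hypothetical greedy sampler: index of the largest logit."""
    return max(range(len(values)), key=values.__getitem__)


logits = [0.1, 5.0, 0.3]  # token 1 plays the role of EOS here
processor = ban_token_ids([1])
picked = argmax(processor([], list(logits)))  # copy so `logits` stays untouched
```

Without the processor the greedy pick would be token 1; with it, the next-best token is chosen instead.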
@@ -108,26 +155,40 @@ def mlx_generate(
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
prefill_step_size=65536,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
logger.info(out.text)
|
||||
if out.finish_reason is not None and out.finish_reason not in get_args(
|
||||
FinishReason
|
||||
):
|
||||
# We don't throw here as this failure case is really not all that bad
|
||||
# Just log the error and move on
|
||||
logger.warning(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
|
||||
if out.finish_reason not in get_args(FinishReason):
|
||||
# We don't throw here as this failure case is really not all that bad
|
||||
# Just log the error and move on
|
||||
logger.warning(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
token=out.token,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
break
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
@@ -3,11 +3,10 @@ import os
|
||||
import resource
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
@@ -176,11 +175,7 @@ def initialize_mlx(
|
||||
|
||||
def load_mlx_items(
|
||||
bound_instance: BoundInstance, group: Group | None
|
||||
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
|
||||
# TODO: pass temperature
|
||||
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
|
||||
logger.info("Created a sampler")
|
||||
|
||||
) -> tuple[Model, TokenizerWrapper]:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
|
||||
@@ -201,7 +196,7 @@ def load_mlx_items(
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
|
||||
return cast(Model, model), tokenizer, sampler
|
||||
return cast(Model, model), tokenizer
|
||||
|
||||
|
||||
def shard_and_load(
|
||||
@@ -397,3 +392,13 @@ def set_wired_limit_for_model(model_size: Memory):
|
||||
)
|
||||
mx.set_wired_limit(max_rec_size)
|
||||
logger.info(f"Wired limit set to {max_rec_size}.")
|
||||
|
||||
|
||||
def mlx_cleanup(
|
||||
model: Model | None, tokenizer: TokenizerWrapper | None, group: Group | None
|
||||
) -> None:
|
||||
del model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
@@ -6,7 +6,7 @@ from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
|
||||
logger: "loguru.Logger" = loguru.logger
|
||||
|
||||
@@ -31,6 +31,8 @@ def entrypoint(
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
except ClosedResourceError:
|
||||
logger.warning("Runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(
|
||||
f"Runner {bound_instance.bound_runner_id} crashed with critical exception {e}"
|
||||
@@ -42,8 +44,10 @@ def entrypoint(
|
||||
)
|
||||
)
|
||||
finally:
|
||||
event_sender.close()
|
||||
task_receiver.close()
|
||||
event_sender.join()
|
||||
task_receiver.join()
|
||||
logger.info("bye from the runner")
|
||||
try:
|
||||
event_sender.close()
|
||||
task_receiver.close()
|
||||
finally:
|
||||
event_sender.join()
|
||||
task_receiver.join()
|
||||
logger.info("bye from the runner")
|
||||
|
||||
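Review note: the nested `try`/`finally` added to `entrypoint` appears intended to guarantee that both channels are joined even if a `close()` call raises. A minimal sketch of that behaviour follows, using a hypothetical `FakeChannel` class rather than the repo's `MpSender`/`MpReceiver`.

```python
class FakeChannel:
    """Hypothetical stand-in for MpSender/MpReceiver that records calls."""

    def __init__(self, name, calls, fail_on_close=False):
        self.name = name
        self.calls = calls
        self.fail_on_close = fail_on_close

    def close(self):
        self.calls.append(f"{self.name}.close")
        if self.fail_on_close:
            raise RuntimeError(f"{self.name} failed to close")

    def join(self):
        self.calls.append(f"{self.name}.join")


def shutdown(sender, receiver):
    try:
        sender.close()
        receiver.close()
    finally:
        # Runs even when one of the close() calls above raises.
        sender.join()
        receiver.join()


calls = []
try:
    shutdown(
        FakeChannel("events", calls, fail_on_close=True),
        FakeChannel("tasks", calls),
    )
except RuntimeError:
    pass  # the error still propagates after the joins have run
```

Note that in this sketch a failing `sender.close()` skips `receiver.close()`; only the joins are guaranteed.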
@@ -1,5 +1,7 @@
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.events import (
|
||||
@@ -36,7 +38,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
@@ -56,180 +58,153 @@ def main(
|
||||
bound_instance.bound_runner_id,
|
||||
bound_instance.bound_shard,
|
||||
)
|
||||
try:
|
||||
logger.info("hello from the runner")
|
||||
if getattr(shard_metadata, "immediate_exception", False):
|
||||
raise Exception("Fake exception - runner failed to spin up.")
|
||||
if timeout := getattr(shard_metadata, "should_timeout", 0):
|
||||
time.sleep(timeout)
|
||||
logger.info("hello from the runner")
|
||||
if getattr(shard_metadata, "immediate_exception", False):
|
||||
raise Exception("Fake exception - runner failed to spin up.")
|
||||
if timeout := getattr(shard_metadata, "should_timeout", 0):
|
||||
time.sleep(timeout)
|
||||
|
||||
setup_start_time = time.time()
|
||||
setup_start_time = time.time()
|
||||
|
||||
model = None
|
||||
tokenizer = None
|
||||
sampler = None
|
||||
group = None
|
||||
model = None
|
||||
tokenizer = None
|
||||
group = None
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Running
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
match task:
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
):
|
||||
logger.info("runner connecting")
|
||||
current_status = RunnerConnecting()
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
group = initialize_mlx(bound_instance)
|
||||
|
||||
logger.info("runner connected")
|
||||
current_status = RunnerConnected()
|
||||
|
||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||
case LoadModel() if (
|
||||
isinstance(current_status, RunnerConnected)
|
||||
and group is not None
|
||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("runner loading")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
model, tokenizer, sampler = load_mlx_items(
|
||||
bound_instance, group
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
assert tokenizer
|
||||
assert sampler
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
sampler=sampler,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ChatCompletion(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert model
|
||||
assert tokenizer
|
||||
assert sampler
|
||||
logger.info(f"received chat request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
assert task_params.messages[0].content is not None
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
for response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
sampler=sampler,
|
||||
task=task_params,
|
||||
):
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
),
|
||||
)
|
||||
)
|
||||
# case TokenizedResponse():
|
||||
# TODO: handle TokenizedResponse here
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
current_status = RunnerShutdown()
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
break
|
||||
except ClosedResourceError:
|
||||
logger.warning("runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(
|
||||
f"Runner {runner_id} crashed with critical exception {e}"
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id,
|
||||
runner_status=RunnerFailed(error_message=str(e)),
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
)
|
||||
finally:
|
||||
event_sender.close()
|
||||
task_receiver.close()
|
||||
event_sender.join()
|
||||
task_receiver.join()
|
||||
logger.info("bye from the runner")
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
match task:
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
):
|
||||
logger.info("runner connecting")
|
||||
current_status = RunnerConnecting()
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
group = initialize_mlx(bound_instance)
|
||||
|
||||
logger.info("runner connected")
|
||||
current_status = RunnerConnected()
|
||||
|
||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||
case LoadModel() if (
|
||||
isinstance(current_status, RunnerConnected) and group is not None
|
||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("runner loading")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
model, tokenizer = load_mlx_items(bound_instance, group)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
assert tokenizer
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
                case ChatCompletion(task_params=task_params, command_id=command_id) if (
                    isinstance(current_status, RunnerReady)
                ):
                    assert model
                    assert tokenizer
                    logger.info(f"received chat request: {str(task)[:500]}")
                    current_status = RunnerRunning()
                    logger.info("runner running")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    assert task_params.messages[0].content is not None
                    _check_for_debug_prompts(task_params.messages[0].content)

                    # Generate responses using the actual MLX generation
                    for response in mlx_generate(
                        model=model,
                        tokenizer=tokenizer,
                        task=task_params,
                    ):
                        match response:
                            case GenerationResponse():
                                if shard_metadata.device_rank == 0:
                                    event_sender.send(
                                        ChunkGenerated(
                                            command_id=command_id,
                                            chunk=TokenChunk(
                                                idx=response.token,
                                                model=shard_metadata.model_meta.model_id,
                                                text=response.text,
                                                token_id=response.token,
                                                finish_reason=response.finish_reason,
                                                stats=response.stats,
                                            ),
                                        )
                                    )
                            # case TokenizedResponse():
                                # TODO: something here ig

                    current_status = RunnerReady()
                    logger.info("runner ready")
                case Shutdown():
                    current_status = RunnerShuttingDown()
                    logger.info("runner shutting down")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    current_status = RunnerShutdown()
                case _:
                    raise ValueError(
                        f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
                    )
            event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
            )
            event_sender.send(
                RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
            )
            if isinstance(current_status, RunnerShutdown):
                del model, tokenizer, group
                mx.clear_cache()
                import gc

                gc.collect()
                break


EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"

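Review note: the guards on each `case` in the hunk above amount to a small state machine. The following is a minimal sketch of the same transitions, assuming plain strings in place of the real `Runner*` status classes and task types; `next_status` is a hypothetical helper, not part of the codebase, and it only models the status a runner settles in once a task has been handled (intermediate states such as loading or warming up are skipped).

```python
def next_status(task: str, status: str, has_group: bool) -> str:
    """Return the status the runner settles in after handling `task`.

    Strings stand in for the real task and status classes (an assumption
    made for illustration only).
    """
    if task == "LoadModel":
        # A connected runner must have a group; an idle runner must not.
        if (status == "RunnerConnected" and has_group) or (
            status == "RunnerIdle" and not has_group
        ):
            return "RunnerLoaded"
    elif task == "StartWarmup" and status == "RunnerLoaded":
        return "RunnerReady"
    elif task == "ChatCompletion" and status == "RunnerReady":
        # The runner passes through a running state, then returns to ready.
        return "RunnerReady"
    elif task == "Shutdown":
        # Shutdown is accepted from any status.
        return "RunnerShutdown"
    # Anything else mirrors the `case _` branch of the hunk.
    raise ValueError(f"Received {task} outside of state machine in status={status!r}")
```

A single-node run would walk `RunnerIdle` to `RunnerLoaded` to `RunnerReady`, while a `StartWarmup` sent to an idle runner falls through to the error branch.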
@@ -111,7 +111,7 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
 def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
     # initialize_mlx returns a "group" equal to 1
     monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
-    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1, 1)))
+    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
     monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
     monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)


@@ -32,6 +32,8 @@ async def check_reachability(
            return NodeId(body) or None
        except OSError:
            return None
        except http.client.HTTPException:
            return None
        finally:
            connection.close()


246
tests/headless_runner.py
Normal file
@@ -0,0 +1,246 @@
import multiprocessing as mp
import socket
import time
import typing

import anyio
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from hypercorn import Config
from hypercorn.asyncio import serve  # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from pydantic import BaseModel

from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event
from exo.shared.types.tasks import (
    ChatCompletion,
    ConnectToGroup,
    LoadModel,
    Shutdown,
    StartWarmup,
    Task,
)
from exo.shared.types.worker.instances import (
    BoundInstance,
    Instance,
    InstanceId,
    MlxJacclInstance,
    MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, mp_channel
from exo.worker.download.impl_shard_downloader import (
    build_full_shard,
    exo_shard_downloader,
)
from exo.worker.runner.bootstrap import entrypoint


class Tests(BaseModel):
    # list[hostname, ip addr]
    devs: list[list[str]]
    model_id: str
    kind: typing.Literal["init", "warmup", "inference"]


hn = socket.gethostname()
mp.set_start_method("spawn", force=True)
logger_setup(None)


async def main():
    logger.info("starting cool server majig")
    logger.info(hn)
    await assert_downloads()
    cfg = Config()
    cfg.bind = "0.0.0.0:52415"
    # nb: shared.logging needs updating if any of this changes
    cfg.accesslog = "-"
    cfg.errorlog = "-"
    cfg.logger_class = InterceptLogger
    app = FastAPI()
    app.post("/ring")(ring_backend)
    app.post("/jaccl")(jaccl_backend)
    shutdown = anyio.Event()
    await serve(
        app,  # type: ignore
        cfg,
        shutdown_trigger=lambda: shutdown.wait(),
    )
    await anyio.sleep_forever()
    # gracefully shutdown the api
    shutdown.set()


async def assert_downloads():
    sd = exo_shard_downloader()
    # await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
    await sd.ensure_shard(await build_full_shard(MODEL_CARDS["llama-3.2-1b"].model_id))


async def ring_backend(test: Tests):
    iid = InstanceId(str(hash(str(test.devs))))
    return await execute_test(test, ring_instance(test, iid))


def ring_instance(test: Tests, iid: InstanceId) -> Instance:
    global hn
    hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
    world_size = len(test.devs)
    for i in range(world_size):
        if hn.startswith(test.devs[i][0]):
            hn = test.devs[i][0]
            if i - 1 >= 0:
                hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
            if i + 1 < len(test.devs):
                hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
            hbn[i] = Host(ip="0.0.0.0", port=52416)
            break

    meta = MODEL_CARDS[test.model_id].metadata
    instance = MlxRingInstance(
        instance_id=iid,
        ephemeral_port=52416,
        hosts_by_node={NodeId(hn): hbn},
        shard_assignments=ShardAssignments(
            model_id=ModelId(test.model_id),
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
                RunnerId(test.devs[i][0]): PipelineShardMetadata(
                    model_meta=meta,
                    device_rank=i,
                    world_size=world_size,
                    start_layer=(meta.n_layers // world_size) * i,
                    end_layer=min(
                        meta.n_layers, (meta.n_layers // world_size) * (i + 1)
                    ),
                    n_layers=min(meta.n_layers, (meta.n_layers // world_size) * (i + 1))
                    - (meta.n_layers // world_size) * i,
                )
                for i in range(world_size)
            },
        ),
    )

    return instance


async def execute_test(test: Tests, instance: Instance):
    world_size = len(test.devs)
    iid = InstanceId(str(hash(str(test.devs))))
    _handle, recv, send = new_runner(instance)
    if world_size > 1:
        send.send(ConnectToGroup(instance_id=iid))
    send.send(LoadModel(instance_id=iid))

    match test.kind:
        case "init":
            pass
        case "warmup":
            send.send(StartWarmup(instance_id=iid))
        case "inference":
            send.send(StartWarmup(instance_id=iid))
            send.send(
                ChatCompletion(
                    task_params=ChatCompletionTaskParams(
                        model=test.model_id,
                        messages=[
                            ChatCompletionMessage(
                                role="system", content="You are a helpful assistant"
                            ),
                            ChatCompletionMessage(
                                role="user", content="What is the capital of France?"
                            ),
                        ],
                    ),
                    command_id=CommandId("yo"),
                    instance_id=iid,
                )
            )

    send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))

    async def map_recv():
        with recv:
            try:
                async for item in recv:
                    yield item.model_dump_json() + "\n"
            except anyio.ClosedResourceError:
                pass

    ret = StreamingResponse(map_recv())
    ret._pls_dont_gc = _handle  # type: ignore
    return ret


async def jaccl_backend(test: Tests):
    iid = InstanceId(str(hash(str(test.devs))))
    return await execute_test(test, jaccl_instance(test, iid))


def jaccl_instance(test: Tests, iid: InstanceId):
    global hn
    meta = MODEL_CARDS[test.model_id].metadata
    world_size = len(test.devs)
    for name, _ in test.devs:
        if hn.startswith(name):
            hn = name
            break

    return MlxJacclInstance(
        instance_id=iid,
        ibv_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
        # rank 0 is always coordinator
        jaccl_coordinators={
            NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
        },
        shard_assignments=ShardAssignments(
            model_id=ModelId(test.model_id),
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
                RunnerId(test.devs[i][0]): TensorShardMetadata(
                    model_meta=meta,
                    device_rank=i,
                    world_size=world_size,
                    start_layer=meta.n_layers,
                    end_layer=meta.n_layers,
                    n_layers=meta.n_layers,
                )
                for i in range(world_size)
            },
        ),
    )


def new_runner(
    instance: Instance,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
    )
    ev_send, ev_recv = mp_channel[Event]()
    task_send, task_recv = mp_channel[Task]()

    runner_process = mp.Process(
        target=entrypoint,
        args=(
            bound_instance,
            ev_send,
            task_recv,
            logger,
        ),
    )
    runner_process._pls_dont_gc = (ev_send, task_recv)  # type: ignore
    runner_process.start()
    time.sleep(0.1)
    return (runner_process, ev_recv, task_send)


if __name__ == "__main__":
    anyio.run(main)
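Review note: the `start_layer` / `end_layer` arithmetic in `ring_instance` uses floor division, so when `n_layers` is not a multiple of `world_size` the trailing layers are assigned to no rank (assuming `end_layer` is exclusive). The sketch below reproduces that arithmetic and contrasts it with a split that hands the remainder to the first ranks. Both `floor_split` and `balanced_split` are hypothetical helpers written for illustration; neither exists in the repository.

```python
def floor_split(n_layers: int, world_size: int) -> list[tuple[int, int]]:
    """Mirror the (start, end) arithmetic used in ring_instance."""
    per_rank = n_layers // world_size
    return [
        (per_rank * i, min(n_layers, per_rank * (i + 1)))
        for i in range(world_size)
    ]


def balanced_split(n_layers: int, world_size: int) -> list[tuple[int, int]]:
    """Cover every layer by giving one extra layer to the first `rem` ranks."""
    per_rank, rem = divmod(n_layers, world_size)
    bounds: list[tuple[int, int]] = []
    start = 0
    for i in range(world_size):
        end = start + per_rank + (1 if i < rem else 0)
        bounds.append((start, end))
        start = end
    return bounds


if __name__ == "__main__":
    print(floor_split(16, 3))     # last range stops at 15, so layer 15 is unassigned
    print(balanced_split(16, 3))  # last range stops at 16, so all layers are covered
```

With an evenly divisible layer count the two functions agree, so the difference only shows up for device counts that do not divide the model's layer count.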
52
tests/start_distributed_test.sh
Executable file
@@ -0,0 +1,52 @@
#!/usr/bin/env bash

set -euo pipefail

# look up the tailscale ip for a hostname
query() {
  tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}

if [[ $# -lt 2 ]]; then
  echo "USAGE: $0 <test kind> <host1> [host2] ..."
  exit 1
fi


kind=$1
shift

test_kinds="ring jaccl"

# match the kind as a whole fixed word so substrings like "in" are rejected
if ! echo "$test_kinds" | grep -qwF -- "$kind"; then
  printf "%s is not a known test kind.\nCurrent test kinds are %s\n" "$kind" "$test_kinds"
  exit 1
fi

hostnames=("$@")
weaved=()
ips=()
for name in "${hostnames[@]}"; do
  ip=$(query "$name")
  ips+=("$ip")
  weaved+=("$name" "$ip")
done

# printf reuses the format for each (name, ip) pair; strip the trailing ", "
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
devs="[${devs_raw%, }]"

# send the same request to every host in parallel
for i in "${!ips[@]}"; do
  {
    req="{
    \"model_id\": \"llama-3.2-1b\",
    \"devs\": ${devs},
    \"kind\": \"inference\"
    }"
    echo "req $req"
    curl -sN \
      -X POST "http://${ips[$i]}:52415/${kind}" \
      -H "Content-Type: application/json" -d "$req" \
      2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed"
  } &
done

wait
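Review note: the script above assembles its JSON body by hand with `printf` and string trimming. As a cross-check of the payload shape that `Tests` in `tests/headless_runner.py` expects, here is a minimal sketch that builds the same body with Python's standard `json` module. `build_request` is a hypothetical helper and the host names and addresses in the usage line are made up.

```python
import json


def build_request(
    hosts: list[tuple[str, str]],
    model_id: str = "llama-3.2-1b",
    kind: str = "inference",
) -> str:
    """Serialize the request body: devs is a list of [hostname, ip] pairs."""
    devs = [[name, ip] for name, ip in hosts]
    return json.dumps({"model_id": model_id, "devs": devs, "kind": kind})


if __name__ == "__main__":
    # Hypothetical hosts, for illustration only.
    print(build_request([("node-a", "100.64.0.1"), ("node-b", "100.64.0.2")]))
```

Letting a JSON library do the quoting avoids malformed bodies if a hostname ever contains a quote or backslash.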
@@ -1,24 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail

networksetup -listallnetworkservices | grep -q '^Thunderbolt Bridge$' \
  && echo "Disabling bridge in networksetup" \
  && networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off

networksetup -listallnetworkservices | grep -q '^\*Thunderbolt Bridge$' \
  && echo "Bridge disabled in networksetup"

ifconfig bridge0 &>/dev/null && {
  ifconfig bridge0 | grep -q 'member' && echo "Removing bridge members in ifconfig" && {
    ifconfig bridge0 | \
      awk '/member/ {print $2}' | \
      xargs -n1 sudo ifconfig bridge0 deletem
  }
  ifconfig bridge0 | grep -q 'status: active' && sudo ifconfig bridge0 down
  ifconfig bridge0 | grep -q 'status: inactive' && echo "Bridge disabled in ifconfig"
}

for iface in $(seq 2 7); do
  sudo ipconfig set "en$iface" dhcp && echo "enabled dhcp on en$iface" || echo "failed to enable dhcp on en$iface"
done