Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-17 02:18:47 -05:00)

Compare commits: iroh-migra ... v1.0.63 (70 commits)
Commits (SHA1):
5e28664c41, ae0a804ccb, 07cf2c1aa1, 83c5285a80, 39ee2bf7bd, 991adfbd6f, 4b3de6b984, c8de3b90ea,
6e6567a802, a735dad667, aaf4e36bc3, 3e623ccf0d, c22dad8a7d, 4bc4d50685, e0aab46fd8, 82ba42bae9,
3671528fa4, e6434ec446, bdb43e1dbb, e4a01e2b0e, 1200a7db64, 47ceb54bc1, f8112fdf25, e388f59480,
e5e74e1eef, b968d6f0a0, 3bfffd9b4f, 007eb80029, 8d7b6789b3, 3c5b7ea670, b74a610537, 18c4e49f91,
d85b5d3781, caafc48693, cca8c9984a, d1e88def42, 59e7594e34, c65320acd3, b9a78f6f3a, 8f7f0e893a,
4759b09d4c, ca680185f3, 383309e24e, 55463a9806, 56af61fac9, f76d543d98, ea841aca37, 077b1bc732,
4963c33162, 4f6fcd9e93, 839b67f318, 47b8e0ce12, 17f9b583a4, 844bcc7ce6, c1be5184b2, 1ec550dff1,
283c0e39e4, 35be4c55c3, 31d4cd8409, 8a6da58404, 16e2bfd3b3, ade3ee7ec5, fea42473dd, ca7adcc2a8,
9d9e24f969, b5d424b658, b465134012, eabdcab978, 8e9332d6a7, 4b65d5f896
.github/benchmark-dashboard/README.md (vendored, 159 lines)
@@ -1,159 +0,0 @@
# EXO Benchmark Dashboard

A fully self-contained, browser-based dashboard for tracking EXO benchmark performance over time.

## Features

- 📊 **Success Rate Tracking**: Monitor cluster reliability across commits
- ⚡ **Response Time Analysis**: Track average request completion times
- 🎯 **Throughput Metrics**: Tokens per second visualization
- 📈 **Request Distribution**: Success/failure breakdown over time
- 🔄 **Auto-Refresh**: Updates every 60 seconds
- 📺 **TV-Ready**: Large, clear visualizations perfect for display
- 🔐 **Secure**: Credentials stored in browser localStorage only
- 🌐 **No Backend**: Directly accesses S3 from the browser

## Quick Start

### Option 1: Direct File Access (Simplest)

Just open the HTML file directly in your browser:

```bash
open .github/benchmark-dashboard/index.html
```

Then click "Configure AWS Credentials" and enter your keys.

### Option 2: URL Parameters (For Quick Setup)

```bash
# Serve with credentials in URL (they'll be moved to localStorage)
open ".github/benchmark-dashboard/index.html?accessKey=YOUR_KEY&secretKey=YOUR_SECRET&region=us-east-1"
```

The credentials will be saved to localStorage and removed from the URL immediately.

### Option 3: Simple HTTP Server

```bash
# From repo root
python3 -m http.server 8080

# Then open: http://localhost:8080/.github/benchmark-dashboard/
```

## AWS Credentials

The dashboard needs read-only access to the `exo-benchmark-results` S3 bucket.

### Required IAM Permissions

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "s3:GetObject",
        "s3:ListBucket"
      ],
      "Resource": [
        "arn:aws:s3:::exo-benchmark-results",
        "arn:aws:s3:::exo-benchmark-results/*"
      ]
    }
  ]
}
```

### Security Notes

- ✅ Credentials stored in browser `localStorage` only
- ✅ Never sent to any server (except AWS)
- ✅ All S3 access happens client-side
- ✅ Use read-only IAM credentials
- ⚠️ Don't commit credentials to git
- ⚠️ Use a dedicated read-only IAM user

## TV/Kiosk Mode

For permanent display on a TV:

### macOS
```bash
open -a "Google Chrome" --args --kiosk ".github/benchmark-dashboard/index.html"
```

### Linux
```bash
chromium-browser --kiosk --app="file://$(pwd)/.github/benchmark-dashboard/index.html"
```

### Auto-start on Boot

Create a simple startup script:

```bash
#!/bin/bash
# /usr/local/bin/start-benchmark-dashboard.sh

cd /path/to/exo
python3 -m http.server 8080 &
sleep 2
chromium-browser --kiosk http://localhost:8080/.github/benchmark-dashboard/
```

## Data Displayed

### Summary Cards
- **Latest Success Rate**: Most recent benchmark success percentage with trend
- **Avg Response Time**: Latest average response time in ms with trend
- **Total Benchmarks**: Count of all benchmarks run
- **Active Configurations**: Number of unique benchmark configs

### Charts
1. **Success Rate Over Time**: Line chart showing reliability trends
2. **Average Response Time**: Performance over time (lower is better)
3. **Throughput**: Tokens/second metric (higher is better)
4. **Request Distribution**: Stacked bar chart of successes/failures

## How It Works

1. **Loads AWS SDK**: Uses AWS SDK for JavaScript (browser version)
2. **Lists S3 Objects**: Fetches all files from `s3://exo-benchmark-results/bench/`
3. **Downloads Results**: Fetches each JSON result file
4. **Parses & Visualizes**: Uses Chart.js to create interactive charts
5. **Auto-Refreshes**: Polls S3 every 60 seconds for new results
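The same fetch flow (steps 2 and 3) can be sanity-checked from the command line. The sketch below is a minimal Python equivalent using boto3; it only assumes the bucket and `bench/` prefix named above, with read-only credentials supplied via the usual AWS environment variables or profile.

```python
import json

import boto3  # assumes credentials via the usual AWS env vars or a profile

BUCKET = "exo-benchmark-results"
PREFIX = "bench/"

s3 = boto3.client("s3", region_name="us-east-1")

# Step 2: list every result object under the bench/ prefix (paginated).
paginator = s3.get_paginator("list_objects_v2")
keys = [
    obj["Key"]
    for page in paginator.paginate(Bucket=BUCKET, Prefix=PREFIX)
    for obj in page.get("Contents", [])
]

# Step 3: download and parse each JSON result file.
results = [
    json.loads(s3.get_object(Bucket=BUCKET, Key=key)["Body"].read())
    for key in keys
]
print(f"Loaded {len(results)} benchmark results")
```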
## Customization

To modify the dashboard:

1. Edit `index.html`
2. Adjust `REFRESH_INTERVAL` for different polling frequency
3. Modify chart colors/styles in the Chart.js configuration
4. Add new metrics by extending the results parsing

## Troubleshooting

**"AWS credentials not configured"**
- Click "Configure AWS Credentials" and enter your keys

**"Error loading benchmark data"**
- Check AWS credentials are correct
- Verify S3 bucket name is `exo-benchmark-results`
- Ensure IAM user has read permissions
- Check browser console for detailed errors

**"No benchmark results found"**
- Wait for benchmark workflows to run
- Verify results are being uploaded to S3
- Check S3 bucket has files in `bench/` prefix

**Charts not updating**
- Check browser console for errors
- Verify network connectivity to S3
- Try refreshing the page manually
.github/benchmark-dashboard/index.html (vendored, 1641 lines)
File diff suppressed because it is too large.
.github/configs/README.md (vendored, 186 lines)
@@ -1,186 +0,0 @@
# EXO Benchmark Configurations

This directory contains configuration files for the EXO staged benchmark system.

## Overview

The staged benchmark system allows you to run complex, multi-stage load tests against EXO clusters. Each stage can have different characteristics:

- **Prompt Length**: Number of tokens in the input prompt
- **Generation Length**: Maximum tokens to generate in the response
- **Time Between Requests**: Delay (in seconds) between firing consecutive requests
- **Iterations**: Number of requests to send in this stage

Requests are **fire-and-forget** - they don't wait for the previous request to complete. This allows you to test overlapping request handling and measure success rates under load.
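The dispatch pattern is easy to picture as code. Below is a minimal asyncio sketch of one stage, not the actual `bench.py` implementation; the endpoint path and request body are assumptions standing in for whatever `bench.py` actually sends.

```python
import asyncio

import aiohttp  # any async HTTP client works; bench.py's internals may differ


async def fire(session: aiohttp.ClientSession, prompt: str, max_tokens: int) -> bool:
    # Hypothetical OpenAI-style endpoint; the real path/body used by bench.py may differ.
    async with session.post(
        "http://localhost:8000/v1/chat/completions",
        json={"messages": [{"role": "user", "content": prompt}], "max_tokens": max_tokens},
    ) as resp:
        await resp.read()
        return resp.status == 200


async def run_stage(prompt_length: int, generation_length: int,
                    time_between_requests: float, iterations: int) -> None:
    async with aiohttp.ClientSession() as session:
        tasks = []
        for _ in range(iterations):
            # Fire-and-forget: schedule the request and keep going without awaiting it.
            prompt = "word " * prompt_length  # crude way to approximate a token count
            tasks.append(asyncio.create_task(fire(session, prompt, generation_length)))
            await asyncio.sleep(time_between_requests)
        # Only at the end of the stage do we collect outcomes for the success rate.
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        ok = sum(1 for o in outcomes if o is True)
        print(f"success rate: {ok}/{iterations}")


asyncio.run(run_stage(prompt_length=50, generation_length=100,
                      time_between_requests=5.0, iterations=10))
```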
## Configuration Files

### `bench_simple.yaml`

A minimal configuration that replicates the behavior of the original `bench.py` script:
- Single stage with 1 iteration
- Short prompt (~20 tokens)
- Generates up to 100 tokens

This is useful for quick smoke tests.

### `bench_config.yaml`

A comprehensive multi-stage benchmark with:
1. **Warmup** (10 requests): Light load with short prompts
2. **Medium Load** (20 requests): Moderate load with medium prompts
3. **Stress Test** (30 requests): Heavy overlapping requests with long prompts
4. **Cooldown** (5 requests): Light load to wind down

This tests the cluster's behavior under varying load patterns.

## Configuration Schema

```yaml
# Hardware configuration - maps runner labels to instance counts
hardware_plan:
  M3ULTRA_GPU80_512GB: 4

# Environment variables to set on each node (optional)
environment:
  OVERRIDE_MEMORY_MB: 512

# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600

# Model instances to run concurrently
model_ids:
  - "mlx-community/Llama-3.2-1B-Instruct-4bit"

# Benchmark stages
stages:
  - name: "stage_name"           # Human-readable name for this stage
    prompt_length: 100           # Target prompt length in tokens
    generation_length: 200       # Max tokens to generate
    time_between_requests: 2.0   # Seconds between firing requests
    iterations: 10               # Number of requests in this stage
```

## Running Benchmarks

### Via GitHub Actions

**Automatic (every commit):**
- The **`bench`** workflow runs automatically on every push
- Uses `bench_simple.yaml` as the default configuration
- All settings (hardware plan, timeout, environment variables, models, stages) are defined in the config file

**Manual (on-demand):**
1. Go to **Actions** → **bench** workflow
2. Click **Run workflow**
3. Configure:
   - **Config File**: Path to your YAML config (default: `.github/configs/bench_simple.yaml`)
     - `.github/configs/bench_simple.yaml` for quick tests
     - `.github/configs/bench_config.yaml` for complex multi-stage tests

All other settings (hardware plan, timeout, environment variables, models, stages) are read from the specified config file.

### Via Command Line

```bash
# Start EXO on localhost:8000
uv run exo --api-port 8000

# Run simple benchmark (1 stage, 1 iteration)
python3 .github/scripts/bench.py \
  --api-port 8000 \
  --config .github/configs/bench_simple.yaml \
  --expected-nodes 1 \
  --is-primary true \
  --timeout-seconds 600

# Run complex staged benchmark (4 stages, multiple iterations)
python3 .github/scripts/bench.py \
  --api-port 8000 \
  --config .github/configs/bench_config.yaml \
  --expected-nodes 1 \
  --is-primary true \
  --timeout-seconds 600
```

## Output Metrics

For each stage, the benchmark reports:

- **Total Requests**: Number of requests fired
- **Successful Requests**: Requests that completed successfully
- **Failed Requests**: Requests that encountered errors
- **Success Rate**: Percentage of successful requests
- **Total Tokens**: Sum of all tokens generated across successful requests
- **Avg Tokens/Request**: Average tokens per successful request
- **Avg Time/Request**: Average completion time per successful request

A JSON summary is also printed for easy parsing and storage.
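Because the summary is plain JSON, post-processing takes only a few lines. The field names in this sketch are illustrative only, inferred from the metric names above; check the actual `bench.py` output for the real keys.

```python
import json
import sys

# Illustrative keys only; the real summary emitted by bench.py may differ.
summary = json.load(sys.stdin)
for stage in summary["stages"]:
    total = stage["total_requests"]
    ok = stage["successful_requests"]
    rate = 100.0 * ok / total if total else 0.0
    print(f"{stage['name']}: {ok}/{total} succeeded ({rate:.1f}%)")
```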
## Creating Custom Benchmarks

To create a custom benchmark:

1. Copy an existing config file (e.g., `bench_config.yaml`)
2. Modify the stages to match your test scenario
3. Save it in this directory with a descriptive name
4. Run it using the workflow or command line

### Example: Sustained Load Test

```yaml
hardware_plan:
  M3ULTRA_GPU80_512GB: 2

environment:
  OVERRIDE_MEMORY_MB: 1024

timeout_seconds: 600

model_ids:
  - "mlx-community/Llama-3.2-1B-Instruct-4bit"

stages:
  - name: "sustained_load"
    prompt_length: 200
    generation_length: 150
    time_between_requests: 0.5 # Very fast - 2 requests/second
    iterations: 100 # Run for ~50 seconds
```

### Example: Varying Prompt Sizes

```yaml
hardware_plan:
  M4PRO_GPU16_24GB: 3

timeout_seconds: 900

model_ids:
  - "mlx-community/Llama-3.2-1B-Instruct-4bit"

stages:
  - name: "tiny_prompts"
    prompt_length: 10
    generation_length: 100
    time_between_requests: 1.0
    iterations: 10

  - name: "medium_prompts"
    prompt_length: 200
    generation_length: 100
    time_between_requests: 1.0
    iterations: 10

  - name: "large_prompts"
    prompt_length: 1000
    generation_length: 100
    time_between_requests: 1.0
    iterations: 10
```

## Tips

- **Overlapping Requests**: Set `time_between_requests` < expected completion time to test concurrent request handling
- **Sequential Requests**: Set `time_between_requests` > expected completion time to ensure requests don't overlap
- **Realistic Load**: Model real usage patterns by varying prompt/generation lengths across stages
- **Success Rate**: A 100% success rate indicates the cluster handled the load well; lower rates suggest capacity limits
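As a back-of-the-envelope check for the first two tips: at steady state, the number of in-flight requests is roughly the per-request completion time divided by `time_between_requests`. A tiny sketch (the 5-second completion time is an assumed figure, not a measured one):

```python
# Approximate steady-state concurrency during a stage, assuming each
# request takes about `completion_s` seconds end to end.
def concurrency(completion_s: float, time_between_requests: float) -> float:
    return completion_s / time_between_requests

print(concurrency(5.0, 1.0))   # stress_test-style firing: ~5 overlapping requests
print(concurrency(5.0, 10.0))  # cooldown-style firing: ~0.5, i.e. no overlap
```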
.github/configs/bench_config.yaml (vendored, 49 lines)
@@ -1,49 +0,0 @@
# EXO Staged Benchmark Configuration
# This configuration defines a multi-stage load test for EXO clusters

# Hardware configuration - maps runner labels to instance counts
hardware_plan:
  M3ULTRA_GPU80_512GB: 4

# Environment variables to set on each node (optional)
environment:
  OVERRIDE_MEMORY_MB: 512

# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600

# Multiple instances run concurrently on the cluster
model_ids:
  - "mlx-community/Qwen3-0.6B-4bit"
  - "mlx-community/Qwen3-0.6B-4bit"

# Stages run sequentially, each with its own characteristics
stages:
  # Stage 1: Light load with short prompts
  - name: "warmup"
    prompt_length: 50          # Number of tokens in prompt
    generation_length: 100     # Max tokens to generate
    time_between_requests: 5.0 # Seconds between firing requests
    iterations: 10             # Number of requests to send in this stage

  # Stage 2: Medium load with medium prompts
  - name: "medium_load"
    prompt_length: 200
    generation_length: 150
    time_between_requests: 3.0
    iterations: 20

  # Stage 3: Heavy load with long prompts - requests will overlap
  - name: "stress_test"
    prompt_length: 500
    generation_length: 200
    time_between_requests: 1.0 # Fast firing - will definitely overlap
    iterations: 30

  # Stage 4: Cool down with simple prompts
  - name: "cooldown"
    prompt_length: 50
    generation_length: 50
    time_between_requests: 10.0
    iterations: 5
.github/configs/bench_simple.yaml (vendored, 125 lines)
@@ -1,125 +0,0 @@
# Simple single-shot benchmark
# Tests 2 instances concurrently on 2 nodes

# Hardware configuration - maps runner labels to instance counts
hardware_plan:
  puffin4: 1
  puffin8: 1

# Environment variables to set on each node
environment:
  PLACEHOLDER: "placeholder"
  # OVERRIDE_MEMORY_MB: 50000
  MLX_METAL_FAST_SYNCH: 1

# Timeout for instance and runner readiness (seconds)
timeout_seconds: 1800

# Model instances to run concurrently
model_ids:
  # - "mlx-community/DeepSeek-V3.1-8bit"
  # - "mlx-community/Kimi-K2-Instruct-4bit"
  - "mlx-community/Kimi-K2-Thinking"
  # - "mlx-community/Qwen3-235B-A22B-4bit"
  # - "mlx-community/Llama-3.3-70B-Instruct-4bit"
  # - "mlx-community/Llama-3.3-70B-Instruct-8bit"
  # - "mlx-community/Llama-3.2-1B-Instruct-4bit"

# Sharding strategy: "Pipeline" or "Tensor"
sharding: "Tensor"

# Instance type: "MlxRing" or "MlxIbv"
instance_meta: "MlxIbv"

# If true, run requests sequentially (no overlap); if false, fire-and-forget (default: false)
no_overlap: true

# Benchmark stages
# pp: 64, 256, 1024, 2048, 4096, 8192, 16384
# g: 64, 512
stages:
  # - name: "simple"
  #   prompt_length: 512
  #   generation_length: 10
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp64_g64"
  #   prompt_length: 64
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp64_g64"
  #   prompt_length: 64
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp64_g512"
  #   prompt_length: 64
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp256_g64"
  #   prompt_length: 256
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  - name: "pp256_g64"
    prompt_length: 256
    generation_length: 64
    time_between_requests: 2.0
    iterations: 5
  # - name: "pp256_g512"
  #   prompt_length: 256
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp1024_g64"
  #   prompt_length: 1024
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp1024_g512"
  #   prompt_length: 1024
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp2048_g64"
  #   prompt_length: 2048
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp2048_g512"
  #   prompt_length: 2048
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp4096_g64"
  #   prompt_length: 4096
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 4
  # - name: "pp4096_g512"
  #   prompt_length: 4096
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp8192_g64"
  #   prompt_length: 8192
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp8192_g512"
  #   prompt_length: 8192
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 5
  # - name: "pp16384_g64"
  #   prompt_length: 16384
  #   generation_length: 64
  #   time_between_requests: 2.0
  #   iterations: 10
  # - name: "pp16384_g512"
  #   prompt_length: 16384
  #   generation_length: 512
  #   time_between_requests: 2.0
  #   iterations: 10
.github/scripts/bench.py (vendored, 1399 lines)
File diff suppressed because it is too large.
.github/scripts/build_matrix.py (vendored, 70 lines)
@@ -1,70 +0,0 @@
#!/usr/bin/env python3
import json
import os
from typing import NotRequired, TypedDict, cast

import yaml


class MatrixEntry(TypedDict):
    label: str
    index: int


class MatrixInclude(TypedDict):
    label: str
    index: int
    is_primary: bool
    expected_nodes: int


class Config(TypedDict):
    hardware_plan: dict[str, int]
    timeout_seconds: NotRequired[int]
    environment: NotRequired[dict[str, str]]


# Read the config file
config_file: str = os.environ["CONFIG_FILE"]
with open(config_file, "r") as f:
    config: Config = cast(Config, yaml.safe_load(f))

# Extract hardware plan from config
plan: dict[str, int] = config["hardware_plan"]
if not plan:
    raise ValueError(f"No hardware_plan found in {config_file}")

# Build matrix entries
entries: list[MatrixEntry] = []
for label, count in plan.items():
    for idx in range(count):
        entries.append({"label": label, "index": idx})

total_nodes: int = len(entries)
matrix: dict[str, list[MatrixInclude]] = {
    "include": [
        {
            "label": e["label"],
            "index": e["index"],
            "is_primary": (i == 0),
            "expected_nodes": total_nodes,
        }
        for i, e in enumerate(entries)
    ]
}

# Extract other config values
timeout_seconds: int = config.get("timeout_seconds", 600)
environment: dict[str, str] = config.get("environment", {})

# Output to GitHub Actions
with open(os.environ["GITHUB_OUTPUT"], "a") as f:
    f.write(f"matrix={json.dumps(matrix)}\n")
    f.write(f"config_file={config_file}\n")
    f.write(f"timeout_seconds={timeout_seconds}\n")
    f.write(f"environment={json.dumps(environment)}\n")

print(f"Matrix: {json.dumps(matrix)}")
print(f"Config file: {config_file}")
print(f"Timeout: {timeout_seconds}")
print(f"Environment: {json.dumps(environment)}")
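To make the expansion concrete, this is what the logic above produces for a two-label plan (it mirrors the comprehension in the script; the labels are the example ones used elsewhere in this repo):

```python
plan = {"M4PRO_GPU16_24GB": 2, "M3ULTRA_GPU80_512GB": 1}
entries = [{"label": label, "index": idx}
           for label, count in plan.items() for idx in range(count)]
include = [
    {"label": e["label"], "index": e["index"],
     "is_primary": i == 0, "expected_nodes": len(entries)}
    for i, e in enumerate(entries)
]
# Three jobs total; only the first is primary:
# {'label': 'M4PRO_GPU16_24GB',    'index': 0, 'is_primary': True,  'expected_nodes': 3}
# {'label': 'M4PRO_GPU16_24GB',    'index': 1, 'is_primary': False, 'expected_nodes': 3}
# {'label': 'M3ULTRA_GPU80_512GB', 'index': 0, 'is_primary': False, 'expected_nodes': 3}
```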
.github/workflows/BENCH_USAGE.md (vendored, 156 lines)
@@ -1,156 +0,0 @@
# Benchmark Workflow Usage

## Overview

The `bench_matrix.yml` workflow enables distributed benchmarking of models across multiple self-hosted macOS runners with different hardware configurations.

## Workflow Inputs

| Input | Description | Default | Required |
|-------|-------------|---------|----------|
| `model_id` | Model ID to benchmark | `mlx-community/Llama-3.2-1B-Instruct-4bit` | Yes |
| `hardware_plan` | JSON mapping of runner labels to counts | `{"M4PRO_GPU16_24GB": 1}` | Yes |
| `prompt` | Benchmark prompt text | `What is the capital of France?` | No |
| `timeout_seconds` | Timeout for instance/runner readiness | `600` | No |

## Hardware Plan Format

The `hardware_plan` input is a JSON object mapping runner labels to the number of machines:

```json
{
  "M4PRO_GPU16_24GB": 2,
  "M3ULTRA_GPU80_512GB": 1
}
```

This example would:
- Start 2 runners with the `M4PRO_GPU16_24GB` label
- Start 1 runner with the `M3ULTRA_GPU80_512GB` label
- Total of 3 runners coordinating on a single distributed inference instance

## How It Works

1. **Planning Job** (`plan`)
   - Runs on `ubuntu-latest`
   - Parses the `hardware_plan` JSON
   - Generates a dynamic matrix with one entry per runner
   - Only the first runner (index 0) is marked as `is_primary`

2. **Benchmark Worker Jobs** (`bench_worker`)
   - Each job runs on a self-hosted macOS runner with the specified label
   - All runners start EXO in parallel
   - The primary runner creates the model instance
   - All runners wait for their assigned runner to be ready (Loaded/Running status)
   - The primary runner executes the benchmark and prints results
   - The primary runner deletes the instance

## Example Usage

### Single Machine Benchmark

```yaml
model_id: mlx-community/Llama-3.2-1B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 1}'
prompt: What is the capital of France?
timeout_seconds: 600
```

### Multi-Machine Distributed Benchmark

```yaml
model_id: mlx-community/Llama-3.2-3B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 2, "M3ULTRA_GPU80_512GB": 1}'
prompt: Explain quantum computing in simple terms.
timeout_seconds: 900
```

## Benchmark Output

The primary runner outputs a JSON object with benchmark results:

```json
{
  "model_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",
  "instance_id": "abc-123-def",
  "tokens": 42,
  "elapsed_s": 2.451,
  "tps": 17.136
}
```

Where:
- `tokens`: Number of chunks/tokens generated
- `elapsed_s`: Total elapsed time in seconds
- `tps`: Tokens per second (tokens / elapsed_s)

## Runner Requirements

Each self-hosted runner must:
- Be labeled with appropriate hardware tags (e.g., `M4PRO_GPU16_24GB`)
- Have the `self-hosted` and `macOS` labels
- Have Nix installed with flakes enabled
- Have network connectivity to other runners in the same job

## Architecture

```
┌─────────────────────────────────────────────────────────────┐
│ GitHub Actions Workflow (bench_matrix.yml)                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│ ┌────────────────┐                                          │
│ │ Plan Job       │                                          │
│ │ (ubuntu)       │──┬─► Matrix: [{label, index, primary}]   │
│ └────────────────┘  │                                       │
│                     │                                       │
│ ┌───────────────────▼──────────────────────────────────┐    │
│ │ Bench Worker Jobs (Matrix)                           │    │
│ ├──────────────────────────────────────────────────────┤    │
│ │                                                      │    │
│ │ Runner 0 (Primary)  Runner 1        Runner 2         │    │
│ │ ┌─────────────┐     ┌─────────────┐ ┌──────────┐     │    │
│ │ │ Start EXO   │     │ Start EXO   │ │ Start EXO│     │    │
│ │ │ Create Inst │     │ Wait...     │ │ Wait...  │     │    │
│ │ │ Wait Ready  │     │ Wait Ready  │ │ Wait...  │     │    │
│ │ │ Run Bench   │     │ (idle)      │ │ (idle)   │     │    │
│ │ │ Print TPS   │     │             │ │          │     │    │
│ │ │ Delete Inst │     │             │ │          │     │    │
│ │ └─────────────┘     └─────────────┘ └──────────┘     │    │
│ └──────────────────────────────────────────────────────┘    │
└─────────────────────────────────────────────────────────────┘
```

## Implementation Details

### `scripts/bench.py`

A standalone Python script that:
- Creates instance (primary only)
- Polls `/state` endpoint until instance and all runners are ready
- Executes chat completion with timing (primary only)
- Parses SSE stream and counts tokens
- Computes TPS metrics
- Cleans up instance (primary only)

### Key Functions

- `wait_for_instance()`: Polls until instance with model_id appears
- `wait_for_runners_ready()`: Polls until expected number of runners reach Loaded/Running status
- `run_benchmark()`: Executes chat completion, measures time, counts tokens
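All three helpers share the same poll-until-ready shape. Here is a minimal sketch of that pattern, not the real implementations (which additionally match instances by `model_id` and count Loaded/Running runner statuses):

```python
import json
import time
import urllib.request


def wait_until(api_port: int, ready, timeout_seconds: int = 600, poll: float = 2.0):
    """Poll GET /state until ready(state) is true, or raise on timeout."""
    deadline = time.monotonic() + timeout_seconds
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"http://localhost:{api_port}/state") as resp:
                state = json.load(resp)
            if ready(state):
                return state
        except OSError:
            pass  # API not up yet; keep polling
        time.sleep(poll)
    raise TimeoutError("cluster did not become ready in time")
```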
## Troubleshooting

### Instance never becomes ready
- Check EXO logs in the workflow output
- Verify model_id is valid and accessible
- Increase `timeout_seconds`

### Runner mismatch
- Ensure hardware_plan counts match available labeled runners
- Check runner labels match exactly (case-sensitive)

### Network issues
- Verify runners can communicate on the network
- Check firewall rules between runner hosts
.github/workflows/bench.yml (vendored, 305 lines)
@@ -1,305 +0,0 @@
name: bench

on: [push]

jobs:
  plan:
    if: contains(github.event.head_commit.message, '/bench')
    runs-on: ubuntu-latest
    outputs:
      matrix: ${{ steps.build.outputs.matrix }}
      config_file: ${{ steps.build.outputs.config_file }}
      timeout_seconds: ${{ steps.build.outputs.timeout_seconds }}
      environment: ${{ steps.build.outputs.environment }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Build matrix from config file
        id: build
        shell: bash
        run: |
          set -euo pipefail
          CONFIG_FILE='.github/configs/bench_simple.yaml'
          export CONFIG_FILE
          echo "Config file: $CONFIG_FILE"
          python3 .github/scripts/build_matrix.py

  bench_worker:
    needs: plan
    strategy:
      fail-fast: false
      matrix: ${{ fromJSON(needs.plan.outputs.matrix) }}
    name: "bench on ${{ matrix.label }} [${{ matrix.index }}]"
    runs-on: [self-hosted, macOS, "${{ matrix.label }}"]
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          lfs: false

      - name: Configure git user
        run: |
          git config --local user.email "github-actions@users.noreply.github.com"
          git config --local user.name "github-actions bot"
        shell: bash

      # TODO: this is mega hacky and I'd like a simpler solution.
      - name: Setup Nix Environment
        run: |
          echo "Checking for nix installation..."

          # Check if nix is already available
          if command -v nix >/dev/null 2>&1; then
            echo "Nix already in PATH"
          # Try sourcing profile scripts to set up environment properly
          elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
            echo "Sourcing multi-user nix-daemon profile script"
            source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
          elif [ -f "$HOME/.nix-profile/etc/profile.d/nix.sh" ]; then
            echo "Sourcing single-user nix profile script"
            source "$HOME/.nix-profile/etc/profile.d/nix.sh"
          elif [ -f /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh ]; then
            echo "Sourcing per-user nix profile script"
            source /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh
          elif [ -f /etc/profile.d/nix.sh ]; then
            echo "Sourcing system-wide nix profile script"
            source /etc/profile.d/nix.sh
          # Fallback: manually add nix to PATH if binary exists
          elif [ -f /nix/var/nix/profiles/default/bin/nix ]; then
            echo "Found nix binary, manually adding to PATH"
            export PATH="/nix/var/nix/profiles/default/bin:$PATH"
          elif [ -f "$HOME/.nix-profile/bin/nix" ]; then
            echo "Found nix binary in user profile, manually adding to PATH"
            export PATH="$HOME/.nix-profile/bin:$PATH"
          else
            echo "Nix not found. Debugging info:"
            echo "USER: $USER"
            echo "HOME: $HOME"
            echo "Current PATH: $PATH"
            echo ""
            echo "Checking common Nix locations:"
            echo "  /nix/var/nix/profiles/default/bin/nix:"
            ls -la /nix/var/nix/profiles/default/bin/nix 2>/dev/null || echo "  Not found"
            echo "  /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh:"
            ls -la /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh 2>/dev/null || echo "  Not found"
            echo "  ~/.nix-profile/etc/profile.d/nix.sh:"
            ls -la "$HOME/.nix-profile/etc/profile.d/nix.sh" 2>/dev/null || echo "  Not found"
            echo "  /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh:"
            ls -la "/nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh" 2>/dev/null || echo "  Not found"
            echo ""
            echo "/nix directory structure:"
            ls -la /nix 2>/dev/null || echo "  /nix directory not found"
            echo ""
            echo "/nix/var:"
            ls -la /nix/var 2>/dev/null || echo "  /nix/var not found"
            echo ""
            echo "/nix/store:"
            ls -la /nix/store 2>/dev/null | head -20 || echo "  /nix/store not found"
            echo ""
            echo "GitHub Actions runner is running as user '$USER'."
            echo "If Nix is installed for a different user, either:"
            echo "  1. Install Nix for user '$USER' (multi-user install recommended)"
            echo "  2. Configure the runner service to run as the user with Nix installed"
            echo "  3. Ensure Nix is installed system-wide with proper daemon setup"
            exit 1
          fi

          # Verify nix is available and persist to GITHUB_ENV
          if command -v nix >/dev/null 2>&1; then
            echo "✓ Nix is available"
            nix --version
            echo "PATH=$PATH" >> $GITHUB_ENV
            if [ -n "$NIX_PATH" ]; then
              echo "NIX_PATH=$NIX_PATH" >> $GITHUB_ENV
            fi
          else
            echo "ERROR: Failed to set up Nix"
            echo "PATH after setup attempt: $PATH"
            exit 1
          fi
        shell: bash

      - name: Setup EXO_HOME and API_PORT
        run: |
          EXO_HOME=$(mktemp -d -t exo-e2e-XXXXXXXX)
          API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
          EXO_MODELS_DIR="$HOME/.exo/models"
          EXO_LIBP2P_NAMESPACE="bench-${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"
          echo "EXO_HOME=$EXO_HOME" >> "$GITHUB_ENV"
          echo "API_PORT=$API_PORT" >> "$GITHUB_ENV"
          echo "EXO_MODELS_DIR=$EXO_MODELS_DIR" >> "$GITHUB_ENV"
          echo "EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE" >> "$GITHUB_ENV"
          echo "Created EXO_HOME: $EXO_HOME"
          echo "Generated API_PORT: $API_PORT"
          echo "Using models from: $EXO_MODELS_DIR"
          echo "Using libp2p namespace: $EXO_LIBP2P_NAMESPACE"
        shell: bash

      - name: Configure local MLX if available
        run: |
          echo "=== DEBUG: Checking for local MLX configuration ==="
          MODIFIED=false

          echo "Checking for /Users/Shared/mlx directory..."
          if [ -d "/Users/Shared/mlx" ]; then
            echo "✓ Found /Users/Shared/mlx"
            ls -la /Users/Shared/mlx | head -5
            echo "Enabling local mlx path in pyproject.toml"
            sed -i.bak 's|^# mlx = { path = "/Users/Shared/mlx", editable=true }$|mlx = { path = "/Users/Shared/mlx", editable=true }|' pyproject.toml
            MODIFIED=true
          else
            echo "✗ /Users/Shared/mlx not found, will use PyPI version"
          fi

          echo "Checking for /Users/Shared/mlx-lm directory..."
          if [ -d "/Users/Shared/mlx-lm" ]; then
            echo "✓ Found /Users/Shared/mlx-lm"
            ls -la /Users/Shared/mlx-lm | head -5
            echo "Enabling local mlx-lm path in pyproject.toml"
            sed -i.bak 's|^# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }$|mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }|' pyproject.toml
            MODIFIED=true
          else
            echo "✗ /Users/Shared/mlx-lm not found, will use PyPI version"
          fi

          if [ "$MODIFIED" = true ]; then
            echo "=== Modified pyproject.toml [tool.uv.sources] section: ==="
            sed -n '/\[tool\.uv\.sources\]/,/^\[/{/^\[tool\.uv\.sources\]/p; /^\[/!p;}' pyproject.toml
            echo "=== Regenerating uv.lock with local MLX paths... ==="
            nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command uv lock --upgrade-package mlx --upgrade-package mlx-lm
            echo "✓ Lock file regenerated"
          else
            echo "⚠ No local MLX directories found, using PyPI packages"
          fi
          echo "=== DEBUG: Local MLX configuration complete ==="
        shell: bash

      - name: Sync dependencies
        run: |
          if [ -d "/Users/Shared/test" ]; then
            pushd /Users/Shared/test
            uv sync --reinstall
            popd
          fi
          echo "Running just sync to ensure clean dependencies..."
          nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just sync
        shell: bash

      - name: Start EXO and run bench script
        shell: bash
        env:
          IS_PRIMARY: ${{ matrix.is_primary }}
          EXPECTED_NODES: ${{ matrix.expected_nodes }}
          HARDWARE_LABEL: ${{ matrix.label }}
          CONFIG_FILE: ${{ needs.plan.outputs.config_file }}
          TIMEOUT_SECONDS: ${{ needs.plan.outputs.timeout_seconds }}
          ENVIRONMENT_JSON: ${{ needs.plan.outputs.environment }}
        run: |
          set -euo pipefail

          # Parse environment variables from config
          ENV_VARS=""
          if [ -n "$ENVIRONMENT_JSON" ] && [ "$ENVIRONMENT_JSON" != "{}" ]; then
            ENV_VARS=$(echo "$ENVIRONMENT_JSON" | python3 -c "import sys, json; env = json.load(sys.stdin); print(' '.join([f'{k}={v}' for k, v in env.items()]))")
          fi

          echo "Starting EXO with API_PORT=${API_PORT} EXO_HOME=${EXO_HOME} EXO_LIBP2P_NAMESPACE=${EXO_LIBP2P_NAMESPACE}"
          echo "Environment variables from config: $ENV_VARS"
          LOG_FILE=/tmp/exo.log
          : > "$LOG_FILE"

          MASTER_FLAG=""
          if [ "$IS_PRIMARY" = "true" ]; then
            MASTER_FLAG="-m"
          fi

          nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
            "EXO_HOME=$EXO_HOME EXO_MODELS_DIR=$EXO_MODELS_DIR EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE $ENV_VARS PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run exo $MASTER_FLAG --api-port $API_PORT" \
            >> "$LOG_FILE" 2>&1 &

          EXO_PID=$!
          echo "Started EXO in background with PID: $EXO_PID"
          echo "Log file: $LOG_FILE"

          cleanup() {
            echo '=== EXO log (tail) ==='
            tail -n 300 "$LOG_FILE" || true
            if ps -p "$EXO_PID" >/dev/null 2>&1; then
              echo "Killing EXO (PID $EXO_PID)"
              kill "$EXO_PID" || true
            fi
          }
          trap cleanup EXIT

          for i in $(seq 1 60); do
            if curl -s "http://localhost:${API_PORT}/state" >/dev/null 2>&1; then
              echo "EXO API ready"
              break
            fi
            if ! ps -p "$EXO_PID" >/dev/null 2>&1; then
              echo "EXO terminated early"; sed -n '1,200p' "$LOG_FILE" || true; exit 1
            fi
            sleep 1
          done

          RESULTS_FILE="/tmp/bench_results_${GITHUB_RUN_ID}_${GITHUB_RUN_ATTEMPT}_$(date +%s).json"
          echo "Results will be saved to: $RESULTS_FILE"
          echo "RESULTS_FILE=$RESULTS_FILE" >> "$GITHUB_ENV"

          echo "Running bench script with config: $CONFIG_FILE, timeout: $TIMEOUT_SECONDS"
          nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
            "PYTHONUNBUFFERED=1 uv run --no-project --with pyyaml --with pydantic python .github/scripts/bench.py \
              --api-port $API_PORT \
              --config $CONFIG_FILE \
              --expected-nodes ${EXPECTED_NODES} \
              --is-primary ${IS_PRIMARY} \
              --timeout-seconds ${TIMEOUT_SECONDS} \
              --output $RESULTS_FILE \
              --git-commit ${GITHUB_SHA} \
              --hardware-labels ${HARDWARE_LABEL}"

      - name: Install AWS CLI
        if: always() && env.RESULTS_FILE && matrix.is_primary
        run: |
          if ! command -v aws &> /dev/null; then
            echo "AWS CLI not found, installing..."
            brew install awscli
          else
            echo "AWS CLI already installed"
          fi
        shell: bash

      - name: Upload results to S3
        if: always() && env.RESULTS_FILE && matrix.is_primary
        env:
          AWS_ACCESS_KEY_ID: ${{ secrets.S3_BENCHMARKS_AWS_ACCESS_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
          AWS_DEFAULT_REGION: us-east-1
        run: |
          echo "Checking for results file: $RESULTS_FILE"
          echo "Is primary: ${{ matrix.is_primary }}"

          if [ -f "$RESULTS_FILE" ]; then
            TIMESTAMP=$(date -u +%Y/%m/%d/%H%M%S)
            S3_KEY="bench/${TIMESTAMP}_${GITHUB_SHA:0:8}_${GITHUB_RUN_ID}.json"
            echo "Uploading results to s3://exo-benchmark-results/$S3_KEY"

            aws s3 cp "$RESULTS_FILE" "s3://exo-benchmark-results/$S3_KEY" \
              --content-type application/json \
              --metadata "commit=${GITHUB_SHA},run_id=${GITHUB_RUN_ID},branch=${GITHUB_REF_NAME}"

            echo "Results uploaded successfully"
            echo "View at: https://exo-benchmark-results.s3.amazonaws.com/$S3_KEY"
          else
            echo "Results file not found at: $RESULTS_FILE"
            echo "Skipping upload"
          fi
        shell: bash

      - name: Cleanup EXO_HOME
        run: |
          echo "Cleaning up EXO_HOME: $EXO_HOME"
          rm -rf "$EXO_HOME"
        shell: bash
        if: always()
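For reference, the upload step's S3 object keys follow the shape the dashboard expects under `bench/`. A Python rendering of the same key construction (the SHA and run ID values here are made up for illustration):

```python
from datetime import datetime, timezone


def s3_key(sha: str, run_id: str, now: datetime) -> str:
    # Mirrors: TIMESTAMP=$(date -u +%Y/%m/%d/%H%M%S); bench/${TIMESTAMP}_${SHA:0:8}_${RUN_ID}.json
    return f"bench/{now:%Y/%m/%d/%H%M%S}_{sha[:8]}_{run_id}.json"


print(s3_key("5e28664c41deadbeef", "1234567890",
             datetime(2026, 1, 17, 2, 18, 47, tzinfo=timezone.utc)))
# bench/2026/01/17/021847_5e28664c_1234567890.json
```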
.github/workflows/build-app.yml (vendored, 158 lines changed)
@@ -1,6 +1,18 @@
name: Build EXO macOS DMG

# Release workflow:
# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
# 3. This workflow builds, signs, and notarizes the DMG
# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
# 5. DMG and appcast.xml are uploaded to S3
# 6. The draft GitHub Release is published with the DMG attached
#
# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
# If no draft exists, a release is auto-created with generated notes.

on:
  workflow_dispatch:
  push:
    tags:
      - "v*"

@@ -10,14 +22,17 @@ on:
jobs:
  build-macos-app:
    runs-on: "macos-26"
    permissions:
      contents: write
    env:
      SPARKLE_VERSION: 2.8.1
      SPARKLE_VERSION: 2.9.0-beta.1
      SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
      SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
      SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
      SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}
      SPARKLE_S3_BUCKET: ${{ secrets.SPARKLE_S3_BUCKET }}
      SPARKLE_S3_PREFIX: ${{ secrets.SPARKLE_S3_PREFIX }}
      EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT: ${{ secrets.EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT }}
      AWS_REGION: ${{ secrets.AWS_REGION }}
      EXO_BUILD_NUMBER: ${{ github.run_number }}
      EXO_LIBP2P_NAMESPACE: ${{ github.ref_name }}

@@ -34,7 +49,7 @@ jobs:

      - name: Derive release version from tag
        run: |
          if [[ "$GITHUB_REF_NAME" == "test-app" ]]; then
          if [[ "$GITHUB_REF_NAME" == "test-app" || "${{ github.event_name }}" == "workflow_dispatch" ]]; then
            VERSION="0.0.0-alpha.0"
            echo "IS_ALPHA=true" >> $GITHUB_ENV
          else

@@ -47,6 +62,32 @@ jobs:
          fi
          echo "RELEASE_VERSION=$VERSION" >> $GITHUB_ENV

      - name: Compute build version from semver
        run: |
          VERSION="$RELEASE_VERSION"
          # Extract major.minor.patch (strip prerelease suffix)
          BASE_VERSION="${VERSION%%-*}"
          MAJOR=$(echo "$BASE_VERSION" | cut -d. -f1)
          MINOR=$(echo "$BASE_VERSION" | cut -d. -f2)
          PATCH=$(echo "$BASE_VERSION" | cut -d. -f3)

          # Extract prerelease number (e.g., "alpha.2" -> 2, or 999 for releases)
          if [[ "$VERSION" == *-* ]]; then
            PRERELEASE_PART="${VERSION#*-}"
            PRERELEASE_NUM="${PRERELEASE_PART##*.}"
            # Default to 0 if not a number
            if ! [[ "$PRERELEASE_NUM" =~ ^[0-9]+$ ]]; then
              PRERELEASE_NUM=0
            fi
          else
            PRERELEASE_NUM=999
          fi

          # Compute: PRERELEASE + (1000 * PATCH) + (1_000_000 * MINOR) + (1_000_000_000 * MAJOR)
          BUILD_VERSION=$((PRERELEASE_NUM + 1000 * PATCH + 1000000 * MINOR + 1000000000 * MAJOR))
          echo "EXO_BUILD_VERSION=$BUILD_VERSION" >> $GITHUB_ENV
          echo "Computed build version: $BUILD_VERSION from $VERSION"

      - name: Ensure tag commit is on main
        if: github.ref_type == 'tag'
        run: |

@@ -59,6 +100,52 @@ jobs:
            exit 1
          fi

      - name: Fetch and validate release notes
        if: github.ref_type == 'tag'
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          # Find draft release by name using gh release list (more reliable with default token)
          echo "Looking for draft release named '$GITHUB_REF_NAME'..."
          DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")

          if [[ -z "$DRAFT_EXISTS" ]]; then
            if [[ "$IS_ALPHA" == "true" ]]; then
              echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
              echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
              exit 0
            fi
            echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
            echo "Please create a draft release with release notes before pushing the tag."
            exit 1
          fi

          # Fetch full release details via API to get body and ID
          echo "Found draft release, fetching details..."
          RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")

          # Extract release notes
          NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
          if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
            if [[ "$IS_ALPHA" == "true" ]]; then
              echo "Draft release has no notes (optional for alphas)"
              echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
              exit 0
            fi
            echo "ERROR: Draft release exists but has no release notes"
            echo "Please add release notes to the draft release before pushing the tag."
            exit 1
          fi

          # Save release ID for later publishing
          RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
          echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
          echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV

          echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
          echo "$NOTES" > /tmp/release_notes.md
          echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV

      # ============================================================
      # Install dependencies
      # ============================================================

@@ -85,11 +172,22 @@ jobs:
          uv python install
          uv sync --locked

      - name: Install Nix
        uses: cachix/install-nix-action@v31
        with:
          nix_path: nixpkgs=channel:nixos-unstable

      - name: Configure Cachix
        uses: cachix/cachix-action@v14
        with:
          name: exo
          authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"

      - name: Build dashboard
        run: |
          cd dashboard
          npm ci
          npm run build
          DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)
          mkdir -p dashboard/build
          cp -r "$DASHBOARD_OUT"/* dashboard/build/

      - name: Install Sparkle CLI
        run: |

@@ -162,11 +260,12 @@ jobs:
            -configuration Release \
            -derivedDataPath build \
            MARKETING_VERSION="$RELEASE_VERSION" \
            CURRENT_PROJECT_VERSION="$EXO_BUILD_NUMBER" \
            CURRENT_PROJECT_VERSION="$EXO_BUILD_VERSION" \
            EXO_BUILD_TAG="$RELEASE_VERSION" \
            EXO_BUILD_COMMIT="$GITHUB_SHA" \
            SPARKLE_FEED_URL="$SPARKLE_FEED_URL" \
            SPARKLE_ED25519_PUBLIC="$SPARKLE_ED25519_PUBLIC" \
            EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT="$EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT" \
            CODE_SIGNING_IDENTITY="$SIGNING_IDENTITY" \
            CODE_SIGN_INJECT_BASE_ENTITLEMENTS=YES
          mkdir -p ../../output

@@ -264,6 +363,28 @@ jobs:
            $CHANNEL_FLAG \
            .

      - name: Inject release notes into appcast
        if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
        env:
          RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
        run: |
          # Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
          export NOTES=$(cat "$RELEASE_NOTES_FILE")

          # Insert description after the enclosure tag for this version
          awk '
            /<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
              print
              print " <description sparkle:format=\"markdown\"><![CDATA["
              print ENVIRON["NOTES"]
              print " ]]></description>"
              next
            }
            { print }
          ' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml

          echo "Injected markdown release notes for version $RELEASE_VERSION"

      # ============================================================
      # Upload artifacts
      # ============================================================

@@ -294,5 +415,28 @@ jobs:
          aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}${DMG_NAME}"
          if [[ "$IS_ALPHA" != "true" ]]; then
            aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
            aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
          fi

      - name: Publish GitHub Release
        if: github.ref_type == 'tag'
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"

          if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
            # Update the draft release with the tag and upload DMG
            gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
              -f tag_name="$GITHUB_REF_NAME" \
              -F draft=false
            gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
            echo "Published release $GITHUB_REF_NAME with DMG attached"
          else
            # Alpha without draft release - create one with auto-generated notes
            gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
              --title "$GITHUB_REF_NAME" \
              --generate-notes \
              --prerelease
            echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
          fi
          aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
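The "Compute build version from semver" step above packs the four version components into one integer so that newer versions always compare higher. A quick Python rendering of the same arithmetic, using this comparison's v1.0.63 tag as the worked example:

```python
def build_version(major: int, minor: int, patch: int, prerelease: int = 999) -> int:
    # Same formula as the shell step:
    # PRERELEASE + 1000*PATCH + 1_000_000*MINOR + 1_000_000_000*MAJOR
    return prerelease + 1_000 * patch + 1_000_000 * minor + 1_000_000_000 * major


assert build_version(1, 0, 63) == 1_000_063_999               # v1.0.63 (full release -> 999)
assert build_version(1, 0, 63, prerelease=1) == 1_000_063_001 # v1.0.63-alpha.1
assert build_version(1, 0, 63, 1) < build_version(1, 0, 63)   # alphas sort below the release
```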
.github/workflows/pipeline.yml (vendored, 117 lines changed)
@@ -20,6 +20,12 @@ jobs:
        with:
          nix_path: nixpkgs=channel:nixos-unstable

      - uses: cachix/cachix-action@v14
        name: Configure Cachix
        with:
          name: exo
          authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"

      - name: Configure git user
        run: |
          git config --local user.email "github-actions@users.noreply.github.com"

@@ -88,9 +94,19 @@ jobs:

      - uses: ./.github/actions/typecheck

  nix-flake-check:
    name: Check Nix flake
    runs-on: ubuntu-latest
  nix:
    name: Build and check (${{ matrix.system }})
    runs-on: ${{ matrix.runner }}
    strategy:
      fail-fast: false
      matrix:
        include:
          - runner: macos-26
            system: aarch64-darwin
          - runner: ubuntu-latest
            system: x86_64-linux
          - runner: ubuntu-24.04-arm
            system: aarch64-linux
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

@@ -101,83 +117,20 @@ jobs:
        with:
          nix_path: nixpkgs=channel:nixos-unstable

      - name: Run nix flake check
        run: |
          nix flake check
        shell: bash
      - uses: cachix/cachix-action@v14
        name: Configure Cachix
        with:
          name: exo
          authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"

      # ci:
      #   needs: typecheck
      #   runs-on: ubuntu-latest
      #   permissions:
      #     contents: read
      #   env:
      #     GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      #   steps:
      #     - name: Checkout repository
      #       uses: actions/checkout@v4
      #       with:
      #         fetch-depth: 0
      #         token: ${{ secrets.GITHUB_TOKEN }}
      #         lfs: true
      #
      #     - name: Configure git user
      #       run: |
      #         git config --local user.email "github-actions@users.noreply.github.com"
      #         git config --local user.name "github-actions bot"
      #       shell: bash
      #
      #     - name: Pull LFS files
      #       run: |
      #         echo "Pulling Git LFS files..."
      #         git lfs pull
      #       shell: bash
      #
      #     - name: Setup EXO_HOME and API_PORT
      #       run: |
      #         EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX)
      #         # Generate random port (macOS compatible method)
      #         API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
      #         echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
      #         echo "API_PORT=$API_PORT" >> $GITHUB_ENV
      #         echo "Created EXO_HOME: $EXO_HOME"
      #         echo "Generated API_PORT: $API_PORT"
      #       shell: bash
      #
      #     - name: Setup Nix Environment
      #       run: |
      #         echo "Checking for nix installation..."
      #
      #         # Check if nix binary exists directly
      #         if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
      #           echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
      #           export PATH="/nix/var/nix/profiles/default/bin:$PATH"
      #           echo "PATH=$PATH" >> $GITHUB_ENV
      #           nix --version
      #         elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
      #           echo "Found nix profile script, sourcing..."
      #           source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
      #           nix --version
      #         elif command -v nix >/dev/null 2>&1; then
      #           echo "Nix already in PATH"
      #           nix --version
      #         else
      #           echo "Nix not found. Debugging info:"
      #           echo "Contents of /nix/var/nix/profiles/default/:"
      #           ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
      #           echo "Contents of /nix/var/nix/profiles/default/bin/:"
      #           ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
      #           exit 1
      #         fi
      #       shell: bash
      #
      #     - uses: ./.github/actions/lint-check
      #
      #     - uses: ./.github/actions/unit-test
      #
      #     - name: Cleanup EXO_HOME
      #       run: |
      #         echo "Cleaning up EXO_HOME: $EXO_HOME"
      #         rm -rf "$EXO_HOME"
      #       shell: bash
      #       if: always()
      - name: Build all Nix outputs
        run: |
          nix flake show --json | jq -r '
            [
              (.packages."${{ matrix.system }}" // {} | keys[] | ".#packages.${{ matrix.system }}.\(.)"),
              (.devShells."${{ matrix.system }}" // {} | keys[] | ".#devShells.${{ matrix.system }}.\(.)")
            ] | .[]
          ' | xargs nix build

      - name: Run nix flake check
        run: nix flake check
.gitignore (vendored, 3 lines changed)
@@ -7,6 +7,8 @@ digest.txt
# nix
.direnv/

# IDEA (PyCharm)
.idea

# xcode / macos
*.xcuserstate

@@ -14,6 +16,7 @@ digest.txt
*.xcuserdatad/
**/.DS_Store
app/EXO/build/
dist/


# rust
.mlx_typings/mlx_lm/models/deepseek_v3.pyi (new file, 156 lines)
@@ -0,0 +1,156 @@
|
||||
"""Type stubs for mlx_lm.models.deepseek_v3"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
moe_intermediate_size: int
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
n_shared_experts: Optional[int]
|
||||
n_routed_experts: Optional[int]
|
||||
routed_scaling_factor: float
|
||||
kv_lora_rank: int
|
||||
q_lora_rank: Optional[int]
|
||||
qk_rope_head_dim: int
|
||||
v_head_dim: int
|
||||
qk_nope_head_dim: int
|
||||
topk_method: str
|
||||
scoring_func: str
|
||||
norm_topk_prob: bool
|
||||
n_group: int
|
||||
topk_group: int
|
||||
num_experts_per_tok: int
|
||||
moe_layer_freq: int
|
||||
first_k_dense_replace: int
|
||||
max_position_embeddings: int
|
||||
rms_norm_eps: float
|
||||
rope_theta: float
|
||||
rope_scaling: Optional[Dict[str, Any]]
|
||||
attention_bias: bool
|
||||
|
||||
class DeepseekV3Attention(nn.Module):
|
||||
config: ModelArgs
|
||||
hidden_size: int
|
||||
num_heads: int
|
||||
max_position_embeddings: int
|
||||
rope_theta: float
|
||||
q_lora_rank: Optional[int]
|
||||
qk_rope_head_dim: int
|
||||
kv_lora_rank: int
|
||||
v_head_dim: int
|
||||
qk_nope_head_dim: int
|
||||
q_head_dim: int
|
||||
scale: float
|
||||
q_proj: nn.Linear
|
||||
q_a_proj: nn.Linear
|
||||
q_a_layernorm: nn.RMSNorm
|
||||
q_b_proj: nn.Linear
|
||||
kv_a_proj_with_mqa: nn.Linear
|
||||
kv_a_layernorm: nn.RMSNorm
|
||||
kv_b_proj: nn.Linear
|
||||
o_proj: nn.Linear
|
||||
rope: Any
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class DeepseekV3MLP(nn.Module):
|
||||
config: ModelArgs
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
gate_proj: nn.Linear
|
||||
up_proj: nn.Linear
|
||||
down_proj: nn.Linear
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelArgs,
|
||||
hidden_size: Optional[int] = None,
|
||||
intermediate_size: Optional[int] = None,
|
||||
) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
config: ModelArgs
|
||||
top_k: int
|
||||
norm_topk_prob: bool
|
||||
n_routed_experts: Optional[int]
|
||||
routed_scaling_factor: float
|
||||
n_group: int
|
||||
topk_group: int
|
||||
weight: mx.array
|
||||
e_score_correction_bias: mx.array
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
|
||||
config: ModelArgs
|
||||
num_experts_per_tok: int
|
||||
switch_mlp: SwitchGLU
|
||||
gate: MoEGate
|
||||
shared_experts: DeepseekV3MLP
|
||||
sharding_group: Optional[mx.distributed.Group]
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class DeepseekV3DecoderLayer(nn.Module):
|
||||
self_attn: DeepseekV3Attention
|
||||
mlp: DeepseekV3MLP | DeepseekV3MoE
|
||||
input_layernorm: nn.RMSNorm
|
||||
post_attention_layernorm: nn.RMSNorm
|
||||
|
||||
def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class DeepseekV3Model(nn.Module):
|
||||
vocab_size: int
|
||||
embed_tokens: nn.Embedding
|
||||
layers: list[DeepseekV3DecoderLayer]
|
||||
norm: nn.RMSNorm
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Model(nn.Module):
|
||||
model_type: str
|
||||
model: DeepseekV3Model
|
||||
lm_head: nn.Linear
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
|
||||
@property
|
||||
def layers(self) -> list[DeepseekV3DecoderLayer]: ...
|
||||
@@ -57,6 +57,11 @@ class SwiGLU(nn.Module):
    def __call__(self, x, gate): ...

class SwitchGLU(nn.Module):
    gate_proj: SwitchLinear
    up_proj: SwitchLinear
    down_proj: SwitchLinear
    activation: SwiGLU

    def __init__(
        self,
        input_dims: int,
@@ -4,6 +4,7 @@ This type stub file was generated by pyright.

from functools import partial
from pathlib import Path
from typing import Any

from transformers import PreTrainedTokenizerFast

@@ -103,37 +104,55 @@ class TokenizerWrapper:
    Accessing any attribute other than the ``detokenizer`` is forwarded to the
    huggingface tokenizer.
    """
    def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
    def add_eos_token(self, token: str): # -> None:
        ...
    @property
    def has_thinking(self): # -> bool:
        ...
    @property
    def think_start(self): # -> str | None:
        ...
    @property
    def think_end(self): # -> str | None:
        ...
    @property
    def has_tool_calling(self): # -> bool:
        ...
    @property
    def tool_call_start(self): # -> str | None:
        ...
    @property
    def tool_call_end(self): # -> str | None:
        ...
    @property
    def detokenizer(self): # -> NaiveStreamingDetokenizer:
        """
        Get a stateful streaming detokenizer.
        """

    def __getattr__(self, attr): # -> set[Any] | Any:
        ...
    def __setattr__(self, attr, value): # -> None:
        ...
    _tokenizer: PreTrainedTokenizerFast
    eos_token_id: int | None
    eos_token: str | None
    bos_token_id: int | None
    bos_token: str | None
    vocab_size: int
    all_special_tokens: list[str]

    def __init__(
        self,
        tokenizer: Any,
        detokenizer_class: Any = ...,
        eos_token_ids: list[int] | None = ...,
        chat_template: Any = ...,
        tool_parser: Any = ...,
        tool_call_start: str | None = ...,
        tool_call_end: str | None = ...,
    ) -> None: ...
    def encode(self, text: str, **kwargs: Any) -> list[int]: ...
    def decode(self, token_ids: list[int], **kwargs: Any) -> str: ...
    def apply_chat_template(
        self,
        messages: list[dict[str, Any]],
        tokenize: bool = False,
        add_generation_prompt: bool = False,
        tools: Any = None,
        **kwargs: Any,
    ) -> str: ...
    def get_vocab(self) -> dict[str, int]: ...
    def add_eos_token(self, token: str) -> None: ...
    @property
    def has_thinking(self) -> bool: ...
    @property
    def think_start(self) -> str | None: ...
    @property
    def think_end(self) -> str | None: ...
    @property
    def has_tool_calling(self) -> bool: ...
    @property
    def tool_call_start(self) -> str | None: ...
    @property
    def tool_call_end(self) -> str | None: ...
    @property
    def detokenizer(self) -> NaiveStreamingDetokenizer:
        """Get a stateful streaming detokenizer."""

    def __getattr__(self, attr: str) -> Any: ...
    def __setattr__(self, attr: str, value: Any) -> None: ...

class NewlineTokenizer(PreTrainedTokenizerFast):
    """A tokenizer that replaces newlines with <n> and <n> with new line."""
@@ -146,18 +165,11 @@ class NewlineTokenizer(PreTrainedTokenizerFast):
    def batch_decode(self, *args, **kwargs): # -> list[str]:
        ...

def load_tokenizer(
def load(
    model_path: Path,
    tokenizer_config_extra=...,
    return_tokenizer=...,
    eos_token_ids=...,
) -> (
    TokenizerWrapper
    | type[SPMStreamingDetokenizer]
    | partial[SPMStreamingDetokenizer]
    | type[BPEStreamingDetokenizer]
    | type[NaiveStreamingDetokenizer]
):
    tokenizer_config_extra: dict[str, Any] | None = None,
    eos_token_ids: list[int] | int | None = None,
) -> TokenizerWrapper:
    """Load a huggingface tokenizer and try to infer the type of streaming
    detokenizer to use.

@@ -165,4 +177,7 @@ def load_tokenizer(
    a Hugging Face repo ID.
    """

def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...
# Alias for backward compatibility
load_tokenizer = load

def no_bos_or_eos(sequence: list[int], bos: int, eos: int) -> list[int]: ...
3 .prettierrc Normal file
@@ -0,0 +1,3 @@
{
"useTabs": true
}
6 .swift-format Normal file
@@ -0,0 +1,6 @@
{
"version": 1,
"indentation": {
"spaces": 4
}
}
96 AGENTS.md Normal file
@@ -0,0 +1,96 @@
# AGENTS.md

This file provides guidance to AI coding agents when working with code in this repository.

## Project Overview

exo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and libp2p for peer-to-peer networking.

## Build & Run Commands

```bash
# Build the dashboard (required before running exo)
cd dashboard && npm install && npm run build && cd ..

# Run exo (starts both master and worker with API at http://localhost:52415)
uv run exo

# Run with verbose logging
uv run exo -v # or -vv for more verbose

# Run tests (excludes slow tests by default)
uv run pytest

# Run all tests including slow tests
uv run pytest -m ""

# Run a specific test file
uv run pytest src/exo/shared/tests/test_election.py

# Run a specific test function
uv run pytest src/exo/shared/tests/test_election.py::test_function_name

# Type checking (strict mode)
uv run basedpyright

# Linting
uv run ruff check

# Format code (using nix)
nix fmt
```

## Architecture

### Node Composition
A single exo `Node` (src/exo/main.py) runs multiple components:
- **Router**: libp2p-based pub/sub messaging via Rust bindings (exo_pyo3_bindings)
- **Worker**: Handles inference tasks, downloads models, manages runner processes
- **Master**: Coordinates cluster state, places model instances across nodes
- **Election**: Bully algorithm for master election
- **API**: FastAPI server for OpenAI-compatible chat completions
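For intuition, here is a toy sketch of the bully rule (not exo's actual election code; node IDs, the `Peer` type, and direct liveness checks are invented for illustration - the real protocol exchanges `ELECTION_MESSAGES` over pub/sub):

```python
# Toy bully election: the highest-ID responsive node wins mastership.
from dataclasses import dataclass


@dataclass(frozen=True)
class Peer:
    node_id: int
    alive: bool


def elect_master(self_id: int, peers: list[Peer]) -> int:
    """Return the node ID that should act as master."""
    higher_alive = [p.node_id for p in peers if p.node_id > self_id and p.alive]
    if not higher_alive:
        return self_id  # nobody outranks us: declare victory
    return max(higher_alive)  # defer to the strongest live peer


assert elect_master(3, [Peer(1, True), Peer(5, False), Peer(4, True)]) == 4
```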
### Message Flow
|
||||
Components communicate via typed pub/sub topics (src/exo/routing/topics.py):
|
||||
- `GLOBAL_EVENTS`: Master broadcasts indexed events to all workers
|
||||
- `LOCAL_EVENTS`: Workers send events to master for indexing
|
||||
- `COMMANDS`: Workers/API send commands to master
|
||||
- `ELECTION_MESSAGES`: Election protocol messages
|
||||
- `CONNECTION_MESSAGES`: libp2p connection updates
|
||||
|
||||
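The topic abstraction can be pictured as one strongly-typed channel per message kind. The sketch below assumes nothing about exo's real Router API (which goes through the Rust libp2p bindings); it only illustrates the typed pub/sub idea:

```python
# Illustrative typed pub/sub topic; not exo's actual Router API.
import asyncio
from typing import Generic, TypeVar

T = TypeVar("T")


class Topic(Generic[T]):
    """A fan-out channel delivering each published message to every subscriber."""

    def __init__(self) -> None:
        self._subscribers: list[asyncio.Queue[T]] = []

    def subscribe(self) -> asyncio.Queue[T]:
        queue: asyncio.Queue[T] = asyncio.Queue()
        self._subscribers.append(queue)
        return queue

    def publish(self, message: T) -> None:
        for queue in self._subscribers:
            queue.put_nowait(message)


# One topic per message kind keeps producers and consumers type-checked.
COMMANDS: Topic[str] = Topic()
```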
### Event Sourcing
The system uses event sourcing for state management:
- `State` (src/exo/shared/types/state.py): Immutable state object
- `apply()` (src/exo/shared/apply.py): Pure function that applies events to state
- Master indexes events and broadcasts; workers apply indexed events
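A minimal sketch of that loop (the event type and `State` fields here are invented; the real ones live in `src/exo/shared/`):

```python
# Hypothetical event-sourcing step: apply() is pure and returns a new State.
from pydantic import BaseModel, ConfigDict


class NodeJoined(BaseModel):
    model_config = ConfigDict(frozen=True, strict=True)
    node_id: str


class State(BaseModel):
    model_config = ConfigDict(frozen=True, strict=True)
    nodes: tuple[str, ...] = ()


def apply(state: State, event: NodeJoined) -> State:
    # No mutation: a frozen model is copied with the event folded in.
    return state.model_copy(update={"nodes": (*state.nodes, event.node_id)})


state = apply(State(), NodeJoined(node_id="mac-studio-1"))
assert state.nodes == ("mac-studio-1",)
```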
### Key Type Hierarchy
- `src/exo/shared/types/`: Pydantic models for all shared types
  - `events.py`: Event types (discriminated union)
  - `commands.py`: Command types
  - `tasks.py`: Task types for worker execution
  - `state.py`: Cluster state model

### Rust Components
Rust code in `rust/` provides:
- `networking`: libp2p networking (gossipsub, peer discovery)
- `exo_pyo3_bindings`: PyO3 bindings exposing Rust to Python
- `system_custodian`: System-level operations

### Dashboard
Svelte 5 + TypeScript frontend in `dashboard/`. Build output goes to `dashboard/build/` and is served by the API.

## Code Style Requirements

From .cursorrules:
- Strict, exhaustive typing - never bypass the type-checker
- Use `Literal[...]` for enum-like sets, `typing.NewType` for primitives
- Pydantic models with `frozen=True` and `strict=True`
- Pure functions with injectable effect handlers for side-effects
- Descriptive names - no abbreviations or 3-letter acronyms
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
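Put together, those rules look roughly like this (the model and field names are invented for illustration):

```python
# Style illustration: strict frozen Pydantic model, Literal for enum-like
# sets, NewType for primitives, @final on concrete classes.
from typing import Literal, NewType, final

from pydantic import BaseModel, ConfigDict

NodeIdentifier = NewType("NodeIdentifier", str)


@final
class RunnerStatus(BaseModel):
    model_config = ConfigDict(frozen=True, strict=True)

    node_identifier: NodeIdentifier
    phase: Literal["starting", "running", "stopped"]
```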
## Testing

Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
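With `asyncio_mode = "auto"`, bare `async def` test functions are collected without an explicit marker; a minimal sketch (the test itself is hypothetical):

```python
# pytest-asyncio in auto mode needs no @pytest.mark.asyncio decorator.
import asyncio


async def test_sleep_yields_control() -> None:
    await asyncio.sleep(0)  # suspends once, proving the event loop is running
```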
3719 Cargo.lock generated
File diff suppressed because it is too large
71 Cargo.toml
@@ -1,8 +1,9 @@
[workspace]
resolver = "3"
members = [
"rust/exo_pyo3_bindings",
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/util",
]

[workspace.package]
@@ -23,38 +24,62 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
util = { path = "rust/util" }

# Proc-macro authoring tools
syn = "2.0"
quote = "1.0"
proc-macro2 = "1.0"
darling = "0.20"

# Macro dependecies
extend = "1.2"
delegate = "0.13"
impl-trait-for-tuples = "0.2"
clap = "4.5"
derive_more = { version = "2.0.1", features = ["display"] }
pin-project = "1"

# Utility dependencies
itertools = "0.14"
thiserror = "2"
internment = "0.8"
recursion = "0.5"
regex = "1.11"
once_cell = "1.21"
thread_local = "1.1"
bon = "3.4"
generativity = "1.1"
anyhow = "1.0"
keccak-const = "0.2"

# Functional generics/lenses frameworks
frunk_core = "0.4"
frunk = "0.4"
frunk_utils = "0.2"
frunk-enum-core = "0.3"

# Async dependencies
tokio = "1.46"
n0-future = "0.3.1"
postcard = "1.1.3"
n0-error = "0.1.2"
futures = "0.3"
futures-util = "0.3"
futures-timer = "3.0"

# Data structures
either = "1.15"
ordered-float = "5.0"
ahash = "0.8"

# Tracing/logging
log = "0.4"
blake3 = "1.8.2"
env_logger = "0.11"
tracing-subscriber = "0.3.20"

# networking
iroh = "0.95.1"
iroh-gossip = "0.95.0"
bytes = "1.11.0"

# pyo3
pyo3 = "0.27.1"
# pyo3-async-runtimes = "0.27.0"
pyo3-log = "0.13.2"
pyo3-stub-gen = "0.17.2"

# util
rand = "0.9.2"
extend = "1.2"

[patch.crates-io]
netwatch = { git = "https://github.com/Evanev7/net-tools.git", branch="patch-for-exo" }
libp2p = "0.56"
libp2p-tcp = "0.44"

[workspace.lints.rust]
static_mut_refs = "warn" # Or use "warn" instead of deny
incomplete_features = "allow"

# Clippy's lint category level configurations;
# every member crate needs to inherit these by adding
41 MISSED_THINGS.md Normal file
@@ -0,0 +1,41 @@
# Missed things
[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py
[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank=n-1. It should be rank!=0 then rank=0 (this matches the auto_parallel implementation. NOTE: we use a different convention to mlx-lm; our terminal rank is rank=n-1 whereas mlx-lm's is rank=0, hence I can see why this was changed wrongly).
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] No mx_barrier in generate.py mlx_generate at the end.
[] Cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] Sharding changed: "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] Same as above: "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if the node is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections, I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up; probably create an issue... the blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does, but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only excepts ValueError. This was silently failing, leading to un-debuggable errors (we had a KeyError that was happening). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and checks if this tensor-parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[] logger.warning("You have likely selected ibv for a single node instance; falling back to MlxRing") was changed to debug. That will spam this warning since it happens every time we query instance previews.
[X] In placement_utils.py, get_mlx_jaccl_coordinators, we no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).
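The `_command_processor` note above is worth a sketch, because the failure mode is subtle: with `except ValueError`, any other exception (the KeyError in this case) escapes the loop body and silently kills the whole processor task. All names below are illustrative:

```python
# Sketch of the narrow-except pitfall in a command-processing loop.
import asyncio
import logging

logger = logging.getLogger(__name__)

HANDLERS: dict[str, str] = {"place_instance": "ok"}


async def command_processor(queue: asyncio.Queue[str]) -> None:
    while True:
        command = await queue.get()
        try:
            result = HANDLERS[command]  # raises KeyError for unknown commands
            logger.info("handled %s -> %s", command, result)
        except Exception:  # was: except ValueError, which let KeyError
            # propagate and kill the loop with no log at the call site
            logger.exception("command failed: %r", command)
```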
108 README.md
@@ -8,7 +8,7 @@
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).

<p align="center">
<a href="https://discord.gg/72NsF6ux" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
<a href="https://x.com/exolabs" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/twitter/follow/exolabs?style=social" alt="X"></a>
<a href="https://www.apache.org/licenses/LICENSE-2.0.html" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/License-Apache2.0-blue.svg" alt="License: Apache-2.0"></a>
</p>
@@ -61,10 +61,10 @@ Devices running exo automatically discover each other, without needing any manua

There are two ways to run exo:

### Run from Source (Mac & Linux)
### Run from Source (macOS)

**Prerequisites:**
- [brew](https://github.com/Homebrew/brew) (for simple package management on MacOS)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)

```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
@@ -98,6 +98,62 @@ uv run exo

This starts the exo dashboard and API at http://localhost:52415/

### Run from Source (Linux)

**Prerequisites:**

- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [node](https://github.com/nodejs/node) (for building the dashboard) - version 18 or higher
- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings, nightly for now)

**Installation methods:**

**Option 1: Using system package manager (Ubuntu/Debian example):**
```bash
# Install Node.js and npm
sudo apt update
sudo apt install nodejs npm

# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh

# Install Rust (using rustup)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```

**Option 2: Using Homebrew on Linux (if preferred):**
```bash
# Install Homebrew on Linux
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"

# Install dependencies
brew install uv node

# Install Rust (using rustup)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```

**Note:** The `macmon` package is macOS-only and not required for Linux.

Clone the repo, build the dashboard, and run exo:

```bash
# Clone exo
git clone https://github.com/exo-explore/exo

# Build dashboard
cd exo/dashboard && npm install && npm run build && cd ..

# Run exo
uv run exo
```

This starts the exo dashboard and API at http://localhost:52415/

**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.

### macOS App

exo ships a macOS app that runs in the background on your Mac.
@@ -110,6 +166,47 @@ Download the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-

The app will ask for permission to modify system settings and install a new Network profile. Improvements to this are being worked on.

#### Uninstalling the macOS App

The recommended way to uninstall is through the app itself: click the menu bar icon → Advanced → Uninstall. This cleanly removes all system components.

If you've already deleted the app, you can run the standalone uninstaller script:

```bash
sudo ./app/EXO/uninstall-exo.sh
```

This removes:
- Network setup LaunchDaemon
- Network configuration script
- Log files
- The "exo" network location

**Note:** You'll need to manually remove EXO from Login Items in System Settings → General → Login Items.

---

### Enabling RDMA on macOS

RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).

Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.

To enable RDMA on macOS, follow these steps:

1. Shut down your Mac.
2. Hold down the power button for 10 seconds until the boot menu appears.
3. Select "Options" to enter Recovery mode.
4. When the Recovery UI appears, open the Terminal from the Utilities menu.
5. In the Terminal, type:
   ```
   rdma_ctl enable
   ```
   and press Enter.
6. Reboot your Mac.

After that, RDMA will be enabled in macOS and exo will take care of the rest.

---

### Using the API
@@ -208,7 +305,10 @@ curl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID
- List all models: `curl http://localhost:52415/models`
- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`

For further details, see API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
For further details, see:

- API basic documentation in [docs/api.md](docs/api.md).
- API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
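A minimal Python client against the chat-completions endpoint looks like this (the `/v1/chat/completions` path follows the OpenAI convention; the model ID is a placeholder - list real ones via `GET /models`):

```python
# Query a local exo cluster through its OpenAI-compatible API.
import json
import urllib.request

payload = {
    "model": "YOUR_MODEL_ID",  # placeholder: pick one from GET /models
    "messages": [{"role": "user", "content": "Hello from exo!"}],
}
request = urllib.request.Request(
    "http://localhost:52415/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    reply = json.load(response)
    print(reply["choices"][0]["message"]["content"])
```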
---
1 TODO.md
@@ -19,7 +19,6 @@
25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Really need to remove all mlx logic outside of the runner - API has a transitive dependency on engines which imports mlx

Potential refactors:
@@ -585,7 +585,7 @@
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.8.1;
minimumVersion = 2.9.0-beta.1;
};
};
/* End XCRemoteSwiftPackageReference section */

@@ -6,8 +6,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/sparkle-project/Sparkle.git",
"state" : {
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
"version" : "2.8.1"
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
"version" : "2.9.0-beta.1"
}
}
],
@@ -12,18 +12,25 @@ struct ContentView: View {
@EnvironmentObject private var controller: ExoProcessController
@EnvironmentObject private var stateService: ClusterStateService
@EnvironmentObject private var networkStatusService: NetworkStatusService
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
@EnvironmentObject private var updater: SparkleUpdater
@State private var focusedNode: NodeViewModel?
@State private var deletingInstanceIDs: Set<String> = []
@State private var showAllNodes = false
@State private var showAllInstances = false
@State private var showAdvanced = false
@State private var showDebugInfo = false
@State private var bugReportInFlight = false
@State private var bugReportMessage: String?
@State private var uninstallInProgress = false
@State private var pendingNamespace: String = ""

var body: some View {
VStack(alignment: .leading, spacing: 12) {
statusSection
if shouldShowLocalNetworkWarning {
localNetworkWarningBanner
}
if shouldShowClusterDetails {
Divider()
overviewSection
@@ -38,6 +45,7 @@ struct ContentView: View {
}
.animation(.easeInOut(duration: 0.3), value: shouldShowClusterDetails)
.animation(.easeInOut(duration: 0.3), value: shouldShowInstances)
.animation(.easeInOut(duration: 0.3), value: shouldShowLocalNetworkWarning)
.padding()
.frame(width: 340)
.onAppear {
@@ -47,9 +55,67 @@ struct ContentView: View {
}
}

private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}
return false
}

private var localNetworkWarningBanner: some View {
VStack(alignment: .leading, spacing: 6) {
HStack(spacing: 6) {
Image(systemName: "exclamationmark.triangle.fill")
.foregroundColor(.orange)
Text("Local Network Access Issue")
.font(.caption)
.fontWeight(.semibold)
}
Text(
"Device discovery won't work. To fix:\n1. Quit EXO\n2. Open System Settings → Privacy & Security → Local Network\n3. Toggle EXO off, then back on\n4. Relaunch EXO"
)
.font(.caption2)
.foregroundColor(.secondary)
.fixedSize(horizontal: false, vertical: true)
Button {
openLocalNetworkSettings()
} label: {
Text("Open Settings")
.font(.caption2)
}
.buttonStyle(.bordered)
.controlSize(.small)
}
.padding(8)
.background(
RoundedRectangle(cornerRadius: 8)
.fill(Color.orange.opacity(0.1))
)
.overlay(
RoundedRectangle(cornerRadius: 8)
.stroke(Color.orange.opacity(0.3), lineWidth: 1)
)
}

private func openLocalNetworkSettings() {
// Open Privacy & Security settings - Local Network section
if let url = URL(
string: "x-apple.systempreferences:com.apple.preference.security?Privacy_LocalNetwork")
{
NSWorkspace.shared.open(url)
}
}

private var topologySection: some View {
Group {
if let topology = stateService.latestSnapshot?.topologyViewModel(), !topology.nodes.isEmpty {
if let topology = stateService.latestSnapshot?.topologyViewModel(
localNodeId: stateService.localNodeId), !topology.nodes.isEmpty
{
TopologyMiniView(topology: topology)
}
}
@@ -83,8 +149,10 @@ struct ContentView: View {
VStack(alignment: .leading, spacing: 4) {
HStack {
VStack(alignment: .leading) {
Text("\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB")
.font(.headline)
Text(
"\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB"
)
.font(.headline)
Text("Memory")
.font(.caption)
.foregroundColor(.secondary)
@@ -193,11 +261,7 @@ struct ContentView: View {
Divider()
.padding(.vertical, 4)
}
controlButton(title: "Check for Updates") {
updater.checkForUpdates()
}
.padding(.bottom, 8)
debugSection
advancedSection
.padding(.bottom, 8)
controlButton(title: "Quit", tint: .secondary) {
controller.stop()
@@ -206,7 +270,57 @@ struct ContentView: View {
}
}

private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void) -> some View {
private var advancedSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Advanced")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showAdvanced)
}
.animation(nil, value: showAdvanced)
if showAdvanced {
VStack(alignment: .leading, spacing: 8) {
VStack(alignment: .leading, spacing: 4) {
Text("Cluster Namespace")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
TextField("optional", text: $pendingNamespace)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingNamespace = controller.customNamespace
}
Button("Save & Restart") {
controller.customNamespace = pendingNamespace
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingNamespace == controller.customNamespace)
}
}
HoverButton(title: "Check for Updates", small: true) {
updater.checkForUpdates()
}
debugSection
HoverButton(title: "Uninstall", tint: .red, small: true) {
showUninstallConfirmationAlert()
}
.disabled(uninstallInProgress)
}
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showAdvanced)
}

private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void)
-> some View
{
HoverButton(title: title, tint: tint, trailingSystemImage: nil, action: action)
}

@@ -237,9 +351,12 @@ struct ContentView: View {
Button {
isExpanded.wrappedValue.toggle()
} label: {
Label(isExpanded.wrappedValue ? "Hide" : "Show All", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
Label(
isExpanded.wrappedValue ? "Hide" : "Show All",
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
)
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
}
.buttonStyle(.plain)
.font(.caption2)
@@ -328,15 +445,15 @@ struct ContentView: View {
}

private var debugSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Debug Info")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showDebugInfo)
VStack(alignment: .leading, spacing: 4) {
HoverButton(
title: "Debug Info",
tint: .primary,
trailingSystemImage: showDebugInfo ? "chevron.up" : "chevron.down",
small: true
) {
showDebugInfo.toggle()
}
.animation(nil, value: showDebugInfo)
if showDebugInfo {
VStack(alignment: .leading, spacing: 4) {
Text("Version: \(buildTag)")
@@ -349,15 +466,63 @@ struct ContentView: View {
.font(.caption2)
.foregroundColor(thunderboltStatusColor)
interfaceIpList
rdmaStatusView
sendBugReportButton
.padding(.top, 6)
}
.padding(.leading, 8)
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showDebugInfo)
}

private var rdmaStatusView: some View {
let rdma = networkStatusService.status.rdmaStatus
return VStack(alignment: .leading, spacing: 1) {
Text("RDMA: \(rdmaStatusText(rdma))")
.font(.caption2)
.foregroundColor(rdmaStatusColor(rdma))
if !rdma.devices.isEmpty {
Text(" Devices: \(rdma.devices.joined(separator: ", "))")
.font(.caption2)
.foregroundColor(.secondary)
}
if !rdma.activePorts.isEmpty {
Text(" Active Ports:")
.font(.caption2)
.foregroundColor(.secondary)
ForEach(rdma.activePorts, id: \.device) { port in
Text(" \(port.device) port \(port.port): \(port.state)")
.font(.caption2)
.foregroundColor(.green)
}
}
}
}

private func rdmaStatusText(_ rdma: RDMAStatus) -> String {
switch rdma.rdmaCtlEnabled {
case .some(true):
return "Enabled"
case .some(false):
return "Disabled"
case nil:
return rdma.devices.isEmpty ? "Not Available" : "Available"
}
}

private func rdmaStatusColor(_ rdma: RDMAStatus) -> Color {
switch rdma.rdmaCtlEnabled {
case .some(true):
return .green
case .some(false):
return .orange
case nil:
return rdma.devices.isEmpty ? .secondary : .green
}
}

private var sendBugReportButton: some View {
VStack(alignment: .leading, spacing: 4) {
Button {
@@ -447,6 +612,88 @@ struct ContentView: View {
bugReportInFlight = false
}

private func showUninstallConfirmationAlert() {
let alert = NSAlert()
alert.messageText = "Uninstall EXO"
alert.informativeText = """
This will remove EXO and all its system components:

• Network configuration daemon
• Launch at login registration
• EXO network location

The app will be moved to Trash.
"""
alert.alertStyle = .warning
alert.addButton(withTitle: "Uninstall")
alert.addButton(withTitle: "Cancel")

// Style the Uninstall button as destructive
if let uninstallButton = alert.buttons.first {
uninstallButton.hasDestructiveAction = true
}

let response = alert.runModal()
if response == .alertFirstButtonReturn {
performUninstall()
}
}

private func performUninstall() {
uninstallInProgress = true

// Stop EXO process first
controller.cancelPendingLaunch()
controller.stop()
stateService.stopPolling()

// Run the privileged uninstall on a background thread
// Using .utility QoS to avoid priority inversion with NSAppleScript's subprocess
DispatchQueue.global(qos: .utility).async {
do {
// Remove network setup daemon and components (requires admin privileges)
try NetworkSetupHelper.uninstall()

DispatchQueue.main.async {
// Unregister from launch at login
LaunchAtLoginHelper.disable()

// Move app to trash
self.moveAppToTrash()

// Quit the app
DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) {
NSApplication.shared.terminate(nil)
}
}
} catch {
DispatchQueue.main.async {
self.showErrorAlert(message: error.localizedDescription)
self.uninstallInProgress = false
}
}
}
}

private func showErrorAlert(message: String) {
let alert = NSAlert()
alert.messageText = "Uninstall Failed"
alert.informativeText = message
alert.alertStyle = .critical
alert.addButton(withTitle: "OK")
alert.runModal()
}

private func moveAppToTrash() {
guard let appURL = Bundle.main.bundleURL as URL? else { return }
do {
try FileManager.default.trashItem(at: appURL, resultingItemURL: nil)
} catch {
// If we can't trash the app, that's OK - user can do it manually
// The important system components have already been cleaned up
}
}

private var buildTag: String {
Bundle.main.infoDictionary?["EXOBuildTag"] as? String ?? "unknown"
}
@@ -460,14 +707,27 @@ private struct HoverButton: View {
let title: String
let tint: Color
let trailingSystemImage: String?
let small: Bool
let action: () -> Void

init(
title: String, tint: Color = .primary, trailingSystemImage: String? = nil,
small: Bool = false, action: @escaping () -> Void
) {
self.title = title
self.tint = tint
self.trailingSystemImage = trailingSystemImage
self.small = small
self.action = action
}

@State private var isHovering = false

var body: some View {
Button(action: action) {
HStack {
Text(title)
.font(small ? .caption : nil)
Spacer()
if let systemName = trailingSystemImage {
Image(systemName: systemName)
@@ -475,8 +735,8 @@ private struct HoverButton: View {
}
}
.frame(maxWidth: .infinity, alignment: .leading)
.padding(.vertical, 6)
.padding(.horizontal, 8)
.padding(.vertical, small ? 4 : 6)
.padding(.horizontal, small ? 6 : 8)
.background(
RoundedRectangle(cornerRadius: 6)
.fill(
@@ -491,4 +751,3 @@ private struct HoverButton: View {
.onHover { isHovering = $0 }
}
}
@@ -8,9 +8,9 @@
import AppKit
import CoreImage
import CoreImage.CIFilterBuiltins
import ServiceManagement
import Sparkle
import SwiftUI
import ServiceManagement
import UserNotifications
import os.log

@@ -19,6 +19,7 @@ struct EXOApp: App {
@StateObject private var controller: ExoProcessController
@StateObject private var stateService: ClusterStateService
@StateObject private var networkStatusService: NetworkStatusService
@StateObject private var localNetworkChecker: LocalNetworkChecker
@StateObject private var updater: SparkleUpdater
private let terminationObserver: TerminationObserver
private let ciContext = CIContext(options: nil)
@@ -37,9 +38,13 @@ struct EXOApp: App {
_stateService = StateObject(wrappedValue: service)
let networkStatus = NetworkStatusService()
_networkStatusService = StateObject(wrappedValue: networkStatus)
let localNetwork = LocalNetworkChecker()
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
_updater = StateObject(wrappedValue: updater)
enableLaunchAtLoginIfNeeded()
NetworkSetupHelper.ensureLaunchDaemonInstalled()
// Check local network access BEFORE launching exo
localNetwork.check()
controller.scheduleLaunch(after: 15)
service.startPolling()
networkStatus.startPolling()
@@ -51,6 +56,7 @@ struct EXOApp: App {
.environmentObject(controller)
.environmentObject(stateService)
.environmentObject(networkStatusService)
.environmentObject(localNetworkChecker)
.environmentObject(updater)
} label: {
menuBarIcon
@@ -107,7 +113,7 @@ struct EXOApp: App {
filter.contrast = 0.9

guard let output = filter.outputImage,
let rendered = ciContext.createCGImage(output, from: output.extent)
let rendered = ciContext.createCGImage(output, from: output.extent)
else {
return nil
}
@@ -120,7 +126,26 @@ struct EXOApp: App {
do {
try SMAppService.mainApp.register()
} catch {
Logger().error("Failed to register EXO for launch at login: \(error.localizedDescription)")
Logger().error(
"Failed to register EXO for launch at login: \(error.localizedDescription)")
}
}
}

/// Helper for managing EXO's launch-at-login registration
enum LaunchAtLoginHelper {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LaunchAtLogin")

/// Unregisters EXO from launching at login
static func disable() {
guard SMAppService.mainApp.status == .enabled else { return }
do {
try SMAppService.mainApp.unregister()
logger.info("Unregistered EXO from launch at login")
} catch {
logger.error(
"Failed to unregister EXO from launch at login: \(error.localizedDescription, privacy: .public)"
)
}
}
}
@@ -145,7 +170,7 @@ final class SparkleUpdater: NSObject, ObservableObject {
center.requestAuthorization(options: [.alert, .sound]) { _, _ in }
controller.updater.automaticallyChecksForUpdates = true
controller.updater.automaticallyDownloadsUpdates = false
controller.updater.updateCheckInterval = 900 // 15 minutes
controller.updater.updateCheckInterval = 900 // 15 minutes
DispatchQueue.main.asyncAfter(deadline: .now() + 5) { [weak controller] in
controller?.updater.checkForUpdatesInBackground()
}
@@ -212,7 +237,8 @@ private final class ExoNotificationDelegate: NSObject, UNUserNotificationCenterD
func userNotificationCenter(
_ center: UNUserNotificationCenter,
willPresent notification: UNNotification,
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) -> Void
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) ->
Void
) {
completionHandler([.banner, .list, .sound])
}
@@ -2,6 +2,8 @@ import AppKit
import Combine
import Foundation

private let customNamespaceKey = "EXOCustomNamespace"

@MainActor
final class ExoProcessController: ObservableObject {
enum Status: Equatable {
@@ -27,6 +29,14 @@ final class ExoProcessController: ObservableObject {
@Published private(set) var status: Status = .stopped
@Published private(set) var lastError: String?
@Published private(set) var launchCountdownSeconds: Int?
@Published var customNamespace: String = {
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
}()
{
didSet {
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
}
}

private var process: Process?
private var runtimeDirectoryURL: URL?
@@ -180,7 +190,7 @@ final class ExoProcessController: ObservableObject {
private func makeEnvironment(for runtimeURL: URL) -> [String: String] {
var environment = ProcessInfo.processInfo.environment
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
environment["EXO_LIBP2P_NAMESPACE"] = buildTag()
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()

var paths: [String] = []
if let existing = environment["PATH"], !existing.isEmpty {
@@ -212,11 +222,19 @@ final class ExoProcessController: ObservableObject {
if let tag = Bundle.main.infoDictionary?["EXOBuildTag"] as? String, !tag.isEmpty {
return tag
}
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String, !short.isEmpty {
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String,
!short.isEmpty
{
return short
}
return "dev"
}

private func computeNamespace() -> String {
let base = buildTag()
let custom = customNamespace.trimmingCharacters(in: .whitespaces)
return custom.isEmpty ? base : custom
}
}

struct RuntimeError: LocalizedError {
@@ -8,5 +8,15 @@
<string>$(EXO_BUILD_TAG)</string>
<key>EXOBuildCommit</key>
<string>$(EXO_BUILD_COMMIT)</string>
<key>EXOBugReportPresignedUrlEndpoint</key>
<string>$(EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT)</string>
<key>NSLocalNetworkUsageDescription</key>
<string>EXO needs local network access to discover and connect to other devices in your cluster for distributed AI inference.</string>
<key>NSBonjourServices</key>
<array>
<string>_p2p._tcp</string>
<string>_p2p._udp</string>
<string>_libp2p._udp</string>
</array>
</dict>
</plist>
@@ -16,10 +16,13 @@ struct ClusterState: Decodable {
self.instances = rawInstances.mapValues(\.instance)
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
let rawTasks = try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
let rawTasks =
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
self.tasks = rawTasks.compactMapValues(\.task)
self.topology = try container.decodeIfPresent(Topology.self, forKey: .topology)
let rawDownloads = try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads) ?? [:]
let rawDownloads =
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
?? [:]
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
}

@@ -41,7 +44,8 @@ private struct TaggedInstance: Decodable {
let payloads = try container.decode([String: ClusterInstancePayload].self)
guard let entry = payloads.first else {
throw DecodingError.dataCorrupted(
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
DecodingError.Context(
codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
)
}
self.instance = ClusterInstance(
@@ -77,7 +81,8 @@ struct RunnerStatusSummary: Decodable {
let payloads = try container.decode([String: RunnerStatusDetail].self)
guard let entry = payloads.first else {
throw DecodingError.dataCorrupted(
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
DecodingError.Context(
codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
)
}
self.status = entry.key
@@ -257,7 +262,9 @@ struct ChatCompletionTaskParameters: Decodable, Equatable {

func promptPreview() -> String? {
guard let messages else { return nil }
if let userMessage = messages.last(where: { $0.role?.lowercased() == "user" && ($0.content?.isEmpty == false) }) {
if let userMessage = messages.last(where: {
$0.role?.lowercased() == "user" && ($0.content?.isEmpty == false)
}) {
return userMessage.content
}
return messages.last?.content
@@ -365,5 +372,3 @@ extension ClusterState {

func availableModels() -> [ModelOption] { [] }
}
@@ -1,4 +1,3 @@
|
||||
import CryptoKit
|
||||
import Foundation
|
||||
|
||||
struct BugReportOutcome: Equatable {
|
||||
@@ -7,17 +6,17 @@ struct BugReportOutcome: Equatable {
|
||||
}
|
||||
|
||||
enum BugReportError: LocalizedError {
|
||||
case missingCredentials
|
||||
case invalidEndpoint
|
||||
case presignedUrlFailed(String)
|
||||
case uploadFailed(String)
|
||||
case collectFailed(String)
|
||||
|
||||
var errorDescription: String? {
|
||||
switch self {
|
||||
case .missingCredentials:
|
||||
return "Bug report upload credentials are not set."
|
||||
case .invalidEndpoint:
|
||||
return "Bug report endpoint is invalid."
|
||||
case .presignedUrlFailed(let message):
|
||||
return "Failed to get presigned URLs: \(message)"
|
||||
case .uploadFailed(let message):
|
||||
return "Bug report upload failed: \(message)"
|
||||
case .collectFailed(let message):
|
||||
@@ -27,11 +26,13 @@ enum BugReportError: LocalizedError {
|
||||
}
|
||||
|
||||
struct BugReportService {
|
||||
struct AWSConfig {
|
||||
let accessKey: String
|
||||
let secretKey: String
|
||||
let region: String
|
||||
let bucket: String
|
||||
private struct PresignedUrlsRequest: Codable {
|
||||
let keys: [String]
|
||||
}
|
||||
|
||||
private struct PresignedUrlsResponse: Codable {
|
||||
let urls: [String: String]
|
||||
let expiresIn: Int?
|
||||
}
|
||||
|
||||
func sendReport(
|
||||
@@ -39,9 +40,9 @@ struct BugReportService {
|
||||
now: Date = Date(),
|
||||
isManual: Bool = false
|
||||
) async throws -> BugReportOutcome {
|
||||
let credentials = try loadCredentials()
|
||||
let timestamp = ISO8601DateFormatter().string(from: now)
|
||||
let prefix = "reports/\(timestamp)/"
|
||||
let timestamp = Self.runTimestampString(now)
|
||||
let dayPrefix = Self.dayPrefixString(now)
|
||||
let prefix = "reports/\(dayPrefix)/\(timestamp)/"
|
||||
|
||||
let logData = readLog()
|
||||
let ifconfigText = try await captureIfconfig()
|
||||
@@ -66,29 +67,82 @@ struct BugReportService {
|
||||
("\(prefix)exo.log", logData),
|
||||
("\(prefix)state.json", stateData),
|
||||
("\(prefix)events.json", eventsData),
|
||||
("\(prefix)report.json", reportJSON)
|
||||
("\(prefix)report.json", reportJSON),
|
||||
]
|
||||
|
||||
let uploader = try S3Uploader(config: credentials)
|
||||
for item in uploads {
|
||||
guard let data = item.data else { continue }
|
||||
try await uploader.upload(
|
||||
objectPath: item.path,
|
||||
body: data
|
||||
)
|
||||
let uploadItems: [(key: String, body: Data)] = uploads.compactMap { item in
|
||||
guard let body = item.data else { return nil }
|
||||
return (key: item.path, body: body)
|
||||
}
|
||||
|
||||
return BugReportOutcome(success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
|
||||
guard !uploadItems.isEmpty else {
|
||||
return BugReportOutcome(success: false, message: "No data to upload")
|
||||
}
|
||||
|
||||
let presignedUrls = try await fetchPresignedUploadUrls(keys: uploadItems.map(\.key))
|
||||
for item in uploadItems {
|
||||
guard let urlString = presignedUrls[item.key], let url = URL(string: urlString) else {
|
||||
throw BugReportError.uploadFailed("Missing presigned URL for \(item.key)")
|
||||
}
|
||||
try await uploadToPresignedUrl(url: url, body: item.body)
|
||||
}
|
||||
|
||||
return BugReportOutcome(
|
||||
success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
|
||||
}
|
||||
|
||||
private func loadCredentials() throws -> AWSConfig {
|
||||
// These credentials are write-only and necessary to receive bug reports from users
|
||||
return AWSConfig(
|
||||
accessKey: "AKIAYEKP5EMXTOBYDGHX",
|
||||
secretKey: "Ep5gIlUZ1o8ssTLQwmyy34yPGfTPEYQ4evE8NdPE",
|
||||
region: "us-east-1",
|
||||
bucket: "exo-bug-reports"
|
||||
)
|
||||
private static func dayPrefixString(_ date: Date) -> String {
|
||||
var calendar = Calendar(identifier: .gregorian)
|
||||
calendar.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
|
||||
let components = calendar.dateComponents([.year, .month, .day], from: date)
|
||||
let year = components.year ?? 0
|
||||
let month = components.month ?? 0
|
||||
let day = components.day ?? 0
|
||||
return String(format: "%04d/%02d/%02d", year, month, day)
|
||||
}
|
||||
|
||||
private static func runTimestampString(_ date: Date) -> String {
|
||||
let formatter = DateFormatter()
|
||||
formatter.locale = Locale(identifier: "en_US_POSIX")
|
||||
formatter.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
|
||||
formatter.dateFormat = "yyyy-MM-dd'T'HHmmss.SSS'Z'"
|
||||
return formatter.string(from: date)
|
||||
}
|
||||
|
||||
private func fetchPresignedUploadUrls(keys: [String], bundle: Bundle = .main) async throws
|
||||
-> [String: String]
|
||||
{
|
||||
guard
|
||||
let endpointString = bundle.infoDictionary?["EXOBugReportPresignedUrlEndpoint"]
|
||||
as? String
|
||||
else {
|
||||
throw BugReportError.invalidEndpoint
|
||||
}
|
||||
let trimmedEndpointString = endpointString.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
guard !trimmedEndpointString.isEmpty, let endpoint = URL(string: trimmedEndpointString)
|
||||
else {
|
||||
throw BugReportError.invalidEndpoint
|
||||
}
|
||||
|
||||
var request = URLRequest(url: endpoint)
|
||||
request.httpMethod = "POST"
|
||||
request.timeoutInterval = 10
|
||||
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
|
||||
|
||||
let encoder = JSONEncoder()
|
||||
request.httpBody = try encoder.encode(PresignedUrlsRequest(keys: keys))
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
guard let http = response as? HTTPURLResponse else {
|
||||
throw BugReportError.presignedUrlFailed("Non-HTTP response")
|
||||
}
|
||||
guard (200..<300).contains(http.statusCode) else {
|
||||
throw BugReportError.presignedUrlFailed("HTTP status \(http.statusCode)")
|
||||
}
|
||||
|
||||
let decoder = JSONDecoder()
|
||||
let decoded = try decoder.decode(PresignedUrlsResponse.self, from: data)
|
||||
return decoded.urls
|
||||
}
|
||||

  private func readLog() -> Data? {
@@ -101,7 +155,8 @@ struct BugReportService {
  private func captureIfconfig() async throws -> String {
    let result = runCommand(["/sbin/ifconfig"])
    guard result.exitCode == 0 else {
      throw BugReportError.collectFailed(result.error.isEmpty ? "ifconfig failed" : result.error)
      throw BugReportError.collectFailed(
        result.error.isEmpty ? "ifconfig failed" : result.error)
    }
    return result.output
  }
@@ -109,12 +164,23 @@ struct BugReportService {
  private func readDebugInfo() -> DebugInfo {
    DebugInfo(
      thunderboltBridgeDisabled: readThunderboltBridgeDisabled(),
      interfaces: readInterfaces()
      interfaces: readInterfaces(),
      rdma: readRDMADebugInfo()
    )
  }

  private func readRDMADebugInfo() -> DebugInfo.RDMADebugInfo {
    DebugInfo.RDMADebugInfo(
      rdmaCtlStatus: safeRunCommand(["/usr/bin/rdma_ctl", "status"]),
      ibvDevices: safeRunCommand(["/usr/bin/ibv_devices"]),
      ibvDevinfo: safeRunCommand(["/usr/bin/ibv_devinfo"])
    )
  }

  private func readThunderboltBridgeDisabled() -> Bool? {
    let result = runCommand(["/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
    let result = runCommand([
      "/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge",
    ])
    guard result.exitCode == 0 else { return nil }
    let output = result.output.lowercased()
    if output.contains("enabled") {
@@ -157,7 +223,8 @@ struct BugReportService {
    request.timeoutInterval = 5
    do {
      let (data, response) = try await URLSession.shared.data(for: request)
      guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
      guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode)
      else {
        return nil
      }
      return data
@@ -166,6 +233,36 @@ struct BugReportService {
    }
  }

  private func uploadToPresignedUrl(url: URL, body: Data) async throws {
    let maxAttempts = 2
    var lastError: Error?

    for attempt in 1...maxAttempts {
      do {
        var request = URLRequest(url: url)
        request.httpMethod = "PUT"
        request.httpBody = body
        request.timeoutInterval = 30

        let (_, response) = try await URLSession.shared.data(for: request)
        guard let http = response as? HTTPURLResponse else {
          throw BugReportError.uploadFailed("Non-HTTP response")
        }
        guard (200..<300).contains(http.statusCode) else {
          throw BugReportError.uploadFailed("HTTP status \(http.statusCode)")
        }
        return
      } catch {
        lastError = error
        if attempt < maxAttempts {
          try await Task.sleep(nanoseconds: 400_000_000)
        }
      }
    }

    throw BugReportError.uploadFailed(lastError?.localizedDescription ?? "Unknown error")
  }
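The upload above retries with a fixed 400 ms delay between its two attempts. As an illustration only (not code from this repository), the same policy can be factored into a reusable helper:

import Foundation

// Sketch: generic fixed-delay retry for an async throwing operation.
func withRetries<T>(
  attempts: Int = 2,
  delayNanoseconds: UInt64 = 400_000_000,
  _ operation: () async throws -> T
) async throws -> T {
  var lastError: Error?
  for attempt in 1...attempts {
    do {
      return try await operation()
    } catch {
      lastError = error
      if attempt < attempts {
        try await Task.sleep(nanoseconds: delayNanoseconds)
      }
    }
  }
  throw lastError ?? URLError(.unknown)
}
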

  private func makeReportJson(
    timestamp: String,
    hostName: String,
@@ -183,7 +280,7 @@ struct BugReportService {
      "system": system,
      "exo_version": exo.version as Any,
      "exo_commit": exo.commit as Any,
      "report_type": isManual ? "manual" : "automated"
      "report_type": isManual ? "manual" : "automated",
    ]
    return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
  }
@@ -214,10 +311,13 @@ struct BugReportService {
    let user = safeRunCommand(["/usr/bin/whoami"])
    let consoleUser = safeRunCommand(["/usr/bin/stat", "-f%Su", "/dev/console"])
    let uptime = safeRunCommand(["/usr/bin/uptime"])
    let diskRoot = safeRunCommand(["/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'"])
    let diskRoot = safeRunCommand([
      "/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'",
    ])

    let interfacesList = safeRunCommand(["/usr/sbin/ipconfig", "getiflist"])
    let interfacesAndIPs = interfacesList?
    let interfacesAndIPs =
      interfacesList?
      .split(whereSeparator: { $0 == " " || $0 == "\n" })
      .compactMap { iface -> [String: Any]? in
        let name = String(iface)
@@ -228,7 +328,8 @@ struct BugReportService {
      } ?? []

    let wifiSSID: String?
    let airportPath = "/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
    let airportPath =
      "/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
    if FileManager.default.isExecutableFile(atPath: airportPath) {
      wifiSSID = safeRunCommand([airportPath, "-I"]).flatMap(parseWifiSSID)
    } else {
@@ -256,7 +357,7 @@ struct BugReportService {
      "disk_root": diskRoot as Any,
      "interfaces_and_ips": interfacesAndIPs,
      "ipconfig_getiflist": interfacesList as Any,
      "wifi_ssid": wifiSSID as Any
      "wifi_ssid": wifiSSID as Any,
    ]
  }

@@ -314,7 +415,8 @@ struct BugReportService {
    for line in airportOutput.split(separator: "\n") {
      let trimmed = line.trimmingCharacters(in: .whitespaces)
      if trimmed.hasPrefix("SSID:") {
        return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(in: .whitespaces)
        return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(
          in: .whitespaces)
      }
    }
    return nil
@@ -351,6 +453,7 @@ struct BugReportService {
private struct DebugInfo {
  let thunderboltBridgeDisabled: Bool?
  let interfaces: [InterfaceStatus]
  let rdma: RDMADebugInfo

  struct InterfaceStatus {
    let name: String
@@ -359,7 +462,21 @@ private struct DebugInfo {
    func toDictionary() -> [String: Any] {
      [
        "name": name,
        "ip": ip as Any
        "ip": ip as Any,
      ]
    }
  }

  struct RDMADebugInfo {
    let rdmaCtlStatus: String?
    let ibvDevices: String?
    let ibvDevinfo: String?

    func toDictionary() -> [String: Any] {
      [
        "rdma_ctl_status": rdmaCtlStatus as Any,
        "ibv_devices": ibvDevices as Any,
        "ibv_devinfo": ibvDevinfo as Any,
      ]
    }
  }
@@ -367,7 +484,8 @@ private struct DebugInfo {
  func toDictionary() -> [String: Any] {
    [
      "thunderbolt_bridge_disabled": thunderboltBridgeDisabled as Any,
      "interfaces": interfaces.map { $0.toDictionary() }
      "interfaces": interfaces.map { $0.toDictionary() },
      "rdma": rdma.toDictionary(),
    ]
  }
}
@@ -377,163 +495,3 @@ private struct CommandResult {
  let output: String
  let error: String
}

private struct S3Uploader {
  let config: BugReportService.AWSConfig

  init(config: BugReportService.AWSConfig) throws {
    self.config = config
  }

  func upload(objectPath: String, body: Data) async throws {
    let host = "\(config.bucket).s3.amazonaws.com"
    guard let url = URL(string: "https://\(host)/\(objectPath)") else {
      throw BugReportError.invalidEndpoint
    }

    let now = Date()
    let amzDate = awsTimestamp(now)
    let dateStamp = dateStamp(now)
    let payloadHash = sha256Hex(body)

    let headers = [
      "host": host,
      "x-amz-content-sha256": payloadHash,
      "x-amz-date": amzDate
    ]

    let canonicalRequest = buildCanonicalRequest(
      method: "PUT",
      url: url,
      headers: headers,
      payloadHash: payloadHash
    )

    let stringToSign = buildStringToSign(
      amzDate: amzDate,
      dateStamp: dateStamp,
      canonicalRequestHash: sha256Hex(canonicalRequest.data(using: .utf8) ?? Data())
    )

    let signingKey = deriveKey(secret: config.secretKey, dateStamp: dateStamp, region: config.region, service: "s3")
    let signature = hmacHex(key: signingKey, data: Data(stringToSign.utf8))

    let signedHeaders = "host;x-amz-content-sha256;x-amz-date"
    let authorization = """
      AWS4-HMAC-SHA256 Credential=\(config.accessKey)/\(dateStamp)/\(config.region)/s3/aws4_request, SignedHeaders=\(signedHeaders), Signature=\(signature)
      """

    var request = URLRequest(url: url)
    request.httpMethod = "PUT"
    request.httpBody = body
    request.setValue(headers["x-amz-content-sha256"], forHTTPHeaderField: "x-amz-content-sha256")
    request.setValue(headers["x-amz-date"], forHTTPHeaderField: "x-amz-date")
    request.setValue(host, forHTTPHeaderField: "Host")
    request.setValue(authorization, forHTTPHeaderField: "Authorization")

    let (data, response) = try await URLSession.shared.data(for: request)
    guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
      let statusText = (response as? HTTPURLResponse)?.statusCode ?? -1
      _ = data // ignore response body for UX
      throw BugReportError.uploadFailed("HTTP status \(statusText)")
    }
  }

  private func buildCanonicalRequest(
    method: String,
    url: URL,
    headers: [String: String],
    payloadHash: String
  ) -> String {
    let canonicalURI = encodePath(url.path)
    let canonicalQuery = url.query ?? ""
    let sortedHeaders = headers.sorted { $0.key < $1.key }
    let canonicalHeaders = sortedHeaders
      .map { "\($0.key.lowercased()):\($0.value)\n" }
      .joined()
    let signedHeaders = sortedHeaders.map { $0.key.lowercased() }.joined(separator: ";")

    return [
      method,
      canonicalURI,
      canonicalQuery,
      canonicalHeaders,
      signedHeaders,
      payloadHash
    ].joined(separator: "\n")
  }

  private func encodePath(_ path: String) -> String {
    return path
      .split(separator: "/")
      .map { segment in
        segment.addingPercentEncoding(withAllowedCharacters: Self.rfc3986) ?? String(segment)
      }
      .joined(separator: "/")
      .prependSlashIfNeeded()
  }

  private func buildStringToSign(
    amzDate: String,
    dateStamp: String,
    canonicalRequestHash: String
  ) -> String {
    """
    AWS4-HMAC-SHA256
    \(amzDate)
    \(dateStamp)/\(config.region)/s3/aws4_request
    \(canonicalRequestHash)
    """
  }

  private func deriveKey(secret: String, dateStamp: String, region: String, service: String) -> Data {
    let kDate = hmac(key: Data(("AWS4" + secret).utf8), data: Data(dateStamp.utf8))
    let kRegion = hmac(key: kDate, data: Data(region.utf8))
    let kService = hmac(key: kRegion, data: Data(service.utf8))
    return hmac(key: kService, data: Data("aws4_request".utf8))
  }

  private func hmac(key: Data, data: Data) -> Data {
    let keySym = SymmetricKey(data: key)
    let mac = HMAC<SHA256>.authenticationCode(for: data, using: keySym)
    return Data(mac)
  }

  private func hmacHex(key: Data, data: Data) -> String {
    hmac(key: key, data: data).map { String(format: "%02x", $0) }.joined()
  }

  private func sha256Hex(_ data: Data) -> String {
    let digest = SHA256.hash(data: data)
    return digest.compactMap { String(format: "%02x", $0) }.joined()
  }

  private func awsTimestamp(_ date: Date) -> String {
    let formatter = DateFormatter()
    formatter.dateFormat = "yyyyMMdd'T'HHmmss'Z'"
    formatter.timeZone = TimeZone(abbreviation: "UTC")
    return formatter.string(from: date)
  }

  private func dateStamp(_ date: Date) -> String {
    let formatter = DateFormatter()
    formatter.dateFormat = "yyyyMMdd"
    formatter.timeZone = TimeZone(abbreviation: "UTC")
    return formatter.string(from: date)
  }

  private static let rfc3986: CharacterSet = {
    var set = CharacterSet.alphanumerics
    set.insert(charactersIn: "-._~")
    return set
  }()
}

private extension String {
  func prependSlashIfNeeded() -> String {
    if hasPrefix("/") {
      return self
    }
    return "/" + self
  }
}
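A quick way to sanity-check the hashing primitives in S3Uploader: SHA-256 of an empty input is a published test vector, so sha256Hex applied to empty Data must produce the constant below (a verification sketch, not repository code):

import CryptoKit
import Foundation

// Known SHA-256 test vector for the empty message.
let emptyDigest = SHA256.hash(data: Data())
  .map { String(format: "%02x", $0) }
  .joined()
assert(emptyDigest == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
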

@@ -7,6 +7,7 @@ final class ClusterStateService: ObservableObject {
  @Published private(set) var lastError: String?
  @Published private(set) var lastActionMessage: String?
  @Published private(set) var modelOptions: [ModelOption] = []
  @Published private(set) var localNodeId: String?

  private var timer: Timer?
  private let decoder: JSONDecoder
@@ -29,6 +30,7 @@ final class ClusterStateService: ObservableObject {
  func startPolling(interval: TimeInterval = 0.5) {
    stopPolling()
    Task {
      await fetchLocalNodeId()
      await fetchModels()
      await fetchSnapshot()
    }
@@ -46,9 +48,33 @@ final class ClusterStateService: ObservableObject {
    latestSnapshot = nil
    lastError = nil
    lastActionMessage = nil
    localNodeId = nil
  }

  private func fetchLocalNodeId() async {
    do {
      let url = baseURL.appendingPathComponent("node_id")
      var request = URLRequest(url: url)
      request.cachePolicy = .reloadIgnoringLocalCacheData
      let (data, response) = try await session.data(for: request)
      guard let httpResponse = response as? HTTPURLResponse,
        (200..<300).contains(httpResponse.statusCode)
      else {
        return
      }
      if let nodeId = try? decoder.decode(String.self, from: data) {
        localNodeId = nodeId
      }
    } catch {
      // Silently ignore - localNodeId will remain nil and retry on next poll
    }
  }

  private func fetchSnapshot() async {
    // Retry fetching local node ID if not yet set
    if localNodeId == nil {
      await fetchLocalNodeId()
    }
    do {
      var request = URLRequest(url: endpoint)
      request.cachePolicy = .reloadIgnoringLocalCacheData
@@ -89,7 +115,9 @@ final class ClusterStateService: ObservableObject {
    }
  }

  func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int) async {
  func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int)
    async
  {
    do {
      var request = URLRequest(url: baseURL.appendingPathComponent("instance"))
      request.httpMethod = "POST"
@@ -98,7 +126,7 @@ final class ClusterStateService: ObservableObject {
        "model_id": modelId,
        "sharding": sharding,
        "instance_meta": instanceMeta,
        "min_nodes": minNodes
        "min_nodes": minNodes,
      ]
      request.httpBody = try JSONSerialization.data(withJSONObject: payload, options: [])
      let (_, response) = try await session.data(for: request)
@@ -119,7 +147,9 @@ final class ClusterStateService: ObservableObject {
    do {
      let url = baseURL.appendingPathComponent("models")
      let (data, response) = try await session.data(from: url)
      guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
      guard let httpResponse = response as? HTTPURLResponse,
        (200..<300).contains(httpResponse.statusCode)
      else {
        throw URLError(.badServerResponse)
      }
      let list = try decoder.decode(ModelListResponse.self, from: data)

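For reference, the body that launchInstance sends to POST /instance serializes to JSON along these lines (field values here are illustrative, not taken from the source):

// Illustrative payload mirroring the dictionary built in launchInstance.
let examplePayload: [String: Any] = [
  "model_id": "llama-3.2-1b",  // example model id
  "sharding": "pipeline",      // example sharding mode
  "instance_meta": "ring",     // example placement hint
  "min_nodes": 2,
]
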
149
app/EXO/EXO/Services/LocalNetworkChecker.swift
Normal file
@@ -0,0 +1,149 @@
import Foundation
import Network
import os.log

/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
@MainActor
final class LocalNetworkChecker: ObservableObject {
  enum Status: Equatable {
    case unknown
    case checking
    case working
    case notWorking(reason: String)

    var isHealthy: Bool {
      if case .working = self { return true }
      return false
    }

    var displayText: String {
      switch self {
      case .unknown:
        return "Unknown"
      case .checking:
        return "Checking..."
      case .working:
        return "Working"
      case .notWorking(let reason):
        return reason
      }
    }
  }

  private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
  private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"

  @Published private(set) var status: Status = .unknown

  private var connection: NWConnection?
  private var checkTask: Task<Void, Never>?

  /// Whether we've completed at least one check (stored in UserDefaults)
  private var hasCompletedInitialCheck: Bool {
    get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
    set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
  }

  /// Checks if local network access is working.
  func check() {
    checkTask?.cancel()
    status = .checking

    // Use longer timeout on first launch to allow time for permission prompt
    let isFirstCheck = !hasCompletedInitialCheck
    let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000

    checkTask = Task { [weak self] in
      guard let self else { return }

      Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
      let result = await self.checkConnectivity(timeout: timeout)
      self.status = result
      self.hasCompletedInitialCheck = true

      Self.logger.info("Local network check complete: \(result.displayText)")
    }
  }

  /// Checks connectivity using NWConnection to mDNS multicast.
  /// The connection attempt triggers the permission prompt if not yet shown.
  private func checkConnectivity(timeout: UInt64) async -> Status {
    connection?.cancel()
    connection = nil

    // mDNS multicast address - same as libp2p uses for peer discovery
    let host = NWEndpoint.Host("224.0.0.251")
    let port = NWEndpoint.Port(integerLiteral: 5353)

    let params = NWParameters.udp
    params.allowLocalEndpointReuse = true

    let conn = NWConnection(host: host, port: port, using: params)
    connection = conn

    return await withCheckedContinuation { continuation in
      var hasResumed = false
      let lock = NSLock()

      let resumeOnce: (Status) -> Void = { status in
        lock.lock()
        defer { lock.unlock() }
        guard !hasResumed else { return }
        hasResumed = true
        continuation.resume(returning: status)
      }

      conn.stateUpdateHandler = { state in
        switch state {
        case .ready:
          resumeOnce(.working)
        case .waiting(let error):
          let errorStr = "\(error)"
          if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
            resumeOnce(.notWorking(reason: "Connection blocked"))
          }
          // Otherwise keep waiting - might be showing permission prompt
        case .failed(let error):
          let errorStr = "\(error)"
          if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
            || errorStr.contains("permission") || errorStr.contains("denied")
          {
            resumeOnce(.notWorking(reason: "Permission denied"))
          } else {
            resumeOnce(.notWorking(reason: "Failed: \(error.localizedDescription)"))
          }
        case .cancelled, .setup, .preparing:
          break
        @unknown default:
          break
        }
      }

      conn.start(queue: .main)

      Task {
        try? await Task.sleep(nanoseconds: timeout)
        let state = conn.state
        switch state {
        case .ready:
          resumeOnce(.working)
        case .waiting, .preparing, .setup:
          resumeOnce(.notWorking(reason: "Timeout (may be blocked)"))
        default:
          resumeOnce(.notWorking(reason: "Timeout"))
        }
      }
    }
  }

  func stop() {
    checkTask?.cancel()
    checkTask = nil
    connection?.cancel()
    connection = nil
  }
}
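A sketch of how a view might consume this checker; the actual UI wiring is not shown in this diff, so the view name here is hypothetical:

import SwiftUI

// Hypothetical consumer of LocalNetworkChecker.
struct NetworkHealthBadge: View {
  @StateObject private var checker = LocalNetworkChecker()

  var body: some View {
    Label(
      checker.status.displayText,
      systemImage: checker.status.isHealthy ? "checkmark.circle" : "exclamationmark.triangle"
    )
    .onAppear { checker.check() }
    .onDisappear { checker.stop() }
  }
}
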
@@ -5,64 +5,66 @@ import os.log
enum NetworkSetupHelper {
  private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
  private static let daemonLabel = "io.exo.networksetup"
  private static let scriptDestination = "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
  private static let scriptDestination =
    "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
  private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
  private static let requiredStartInterval: Int = 1791

  private static let setupScript = """
#!/usr/bin/env bash
    #!/usr/bin/env bash

set -euo pipefail
    set -euo pipefail

PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
    PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"

# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
  ifconfig bridge0 | grep -q 'member' && {
    ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
  }
  ifconfig bridge0 destroy 2>/dev/null || true
}
    # Remove bridge0 interface
    ifconfig bridge0 &>/dev/null && {
      ifconfig bridge0 | grep -q 'member' && {
        ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
      }
      ifconfig bridge0 destroy 2>/dev/null || true
    }

# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
    # Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
    /usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true

networksetup -listlocations | grep -q exo || {
  networksetup -createlocation exo
}
    networksetup -listlocations | grep -q exo || {
      networksetup -createlocation exo
    }

networksetup -switchtolocation exo
networksetup -listallhardwareports \\
  | awk -F': ' '/Hardware Port: / {print $2}' \\
  | while IFS=":" read -r name; do
    case "$name" in
      "Ethernet Adapter"*)
        ;;
      "Thunderbolt Bridge")
        ;;
      "Thunderbolt "*)
        networksetup -listallnetworkservices \\
          | grep -q "EXO $name" \\
          || networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
          || continue
        networksetup -setdhcp "EXO $name"
        ;;
      *)
        networksetup -listallnetworkservices \\
          | grep -q "$name" \\
          || networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
          || continue
        ;;
    esac
  done
    networksetup -switchtolocation exo
    networksetup -listallhardwareports \\
      | awk -F': ' '/Hardware Port: / {print $2}' \\
      | while IFS=":" read -r name; do
        case "$name" in
          "Ethernet Adapter"*)
            ;;
          "Thunderbolt Bridge")
            ;;
          "Thunderbolt "*)
            networksetup -listallnetworkservices \\
              | grep -q "EXO $name" \\
              || networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
              || continue
            networksetup -setdhcp "EXO $name"
            ;;
          *)
            networksetup -listallnetworkservices \\
              | grep -q "$name" \\
              || networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
              || continue
            ;;
        esac
      done

networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
  networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
    networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
      networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
    } || true
    """

  static func ensureLaunchDaemonInstalled() {
    Task.detached {
    // Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
    Task.detached(priority: .utility) {
      do {
        if daemonAlreadyInstalled() {
          return
@@ -70,11 +72,70 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
        try await installLaunchDaemon()
        logger.info("Network setup launch daemon installed and started")
      } catch {
        logger.error("Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)")
        logger.error(
          "Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
        )
      }
    }
  }

  /// Removes all EXO network setup components from the system.
  /// This includes the LaunchDaemon, scripts, logs, and network location.
  /// Requires admin privileges.
  static func uninstall() throws {
    let uninstallScript = makeUninstallScript()
    try runShellAsAdmin(uninstallScript)
    logger.info("EXO network setup components removed successfully")
  }

  /// Checks if there are any EXO network components installed that need cleanup
  static func hasInstalledComponents() -> Bool {
    let manager = FileManager.default
    let scriptExists = manager.fileExists(atPath: scriptDestination)
    let plistExists = manager.fileExists(atPath: plistDestination)
    return scriptExists || plistExists
  }

  private static func makeUninstallScript() -> String {
    """
    set -euo pipefail

    LABEL="\(daemonLabel)"
    SCRIPT_DEST="\(scriptDestination)"
    PLIST_DEST="\(plistDestination)"
    LOG_OUT="/var/log/\(daemonLabel).log"
    LOG_ERR="/var/log/\(daemonLabel).err.log"

    # Unload the LaunchDaemon if running
    launchctl bootout system/"$LABEL" 2>/dev/null || true

    # Remove LaunchDaemon plist
    rm -f "$PLIST_DEST"

    # Remove the script and parent directory if empty
    rm -f "$SCRIPT_DEST"
    rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true

    # Remove log files
    rm -f "$LOG_OUT" "$LOG_ERR"

    # Switch back to Automatic network location
    networksetup -switchtolocation Automatic 2>/dev/null || true

    # Delete the exo network location if it exists
    networksetup -listlocations | grep -q '^exo$' && {
      networksetup -deletelocation exo 2>/dev/null || true
    } || true

    # Re-enable Thunderbolt Bridge if it exists
    networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
      networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
    } || true

    echo "EXO network components removed successfully"
    """
  }

  private static func daemonAlreadyInstalled() -> Bool {
    let manager = FileManager.default
    let scriptExists = manager.fileExists(atPath: scriptDestination)
@@ -82,7 +143,8 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
    guard scriptExists, plistExists else { return false }
    guard
      let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
      let plist = try? PropertyListSerialization.propertyList(from: data, options: [], format: nil) as? [String: Any]
      let plist = try? PropertyListSerialization.propertyList(
        from: data, options: [], format: nil) as? [String: Any]
    else {
      return false
    }
@@ -92,7 +154,9 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
    else {
      return false
    }
    if let programArgs = plist["ProgramArguments"] as? [String], programArgs.contains(scriptDestination) == false {
    if let programArgs = plist["ProgramArguments"] as? [String],
      programArgs.contains(scriptDestination) == false
    {
      return false
    }
    return true
@@ -105,58 +169,59 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {

  private static func makeInstallerScript() -> String {
    """
set -euo pipefail
    set -euo pipefail

LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
    LABEL="\(daemonLabel)"
    SCRIPT_DEST="\(scriptDestination)"
    PLIST_DEST="\(plistDestination)"

mkdir -p "$(dirname "$SCRIPT_DEST")"
    mkdir -p "$(dirname "$SCRIPT_DEST")"

cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
    cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
    \(setupScript)
    EOF_SCRIPT
    chmod 755 "$SCRIPT_DEST"

cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
  <key>Label</key>
  <string>\(daemonLabel)</string>
  <key>ProgramArguments</key>
  <array>
    <string>/bin/bash</string>
    <string>\(scriptDestination)</string>
  </array>
  <key>StartInterval</key>
  <integer>\(requiredStartInterval)</integer>
  <key>RunAtLoad</key>
  <true/>
  <key>StandardOutPath</key>
  <string>/var/log/\(daemonLabel).log</string>
  <key>StandardErrorPath</key>
  <string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
    cat > "$PLIST_DEST" <<'EOF_PLIST'
    <?xml version="1.0" encoding="UTF-8"?>
    <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
    <plist version="1.0">
    <dict>
      <key>Label</key>
      <string>\(daemonLabel)</string>
      <key>ProgramArguments</key>
      <array>
        <string>/bin/bash</string>
        <string>\(scriptDestination)</string>
      </array>
      <key>StartInterval</key>
      <integer>\(requiredStartInterval)</integer>
      <key>RunAtLoad</key>
      <true/>
      <key>StandardOutPath</key>
      <string>/var/log/\(daemonLabel).log</string>
      <key>StandardErrorPath</key>
      <string>/var/log/\(daemonLabel).err.log</string>
    </dict>
    </plist>
    EOF_PLIST

launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
    launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
    launchctl bootstrap system "$PLIST_DEST"
    launchctl enable system/"$LABEL"
    launchctl kickstart -k system/"$LABEL"
    """
  }

  private static func runShellAsAdmin(_ script: String) throws {
    let escapedScript = script
    let escapedScript =
      script
      .replacingOccurrences(of: "\\", with: "\\\\")
      .replacingOccurrences(of: "\"", with: "\\\"")

    let appleScriptSource = """
do shell script "\(escapedScript)" with administrator privileges
"""
      do shell script "\(escapedScript)" with administrator privileges
      """

    guard let appleScript = NSAppleScript(source: appleScriptSource) else {
      throw NetworkSetupError.scriptCreationFailed

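The hunk cuts off before runShellAsAdmin executes the script; the usual NSAppleScript continuation looks roughly like this (a sketch, and the error-case name is a placeholder for the project's real one):

// Assumed continuation: execute and surface any AppleScript error.
var errorInfo: NSDictionary?
appleScript.executeAndReturnError(&errorInfo)
if let errorInfo {
  // NetworkSetupError.executionFailed is hypothetical, not the project's actual case name.
  throw NetworkSetupError.executionFailed("\(errorInfo)")
}
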
@@ -35,14 +35,34 @@ struct NetworkStatus: Equatable {
  let thunderboltBridgeState: ThunderboltState?
  let bridgeInactive: Bool?
  let interfaceStatuses: [InterfaceIpStatus]
  let rdmaStatus: RDMAStatus

  static let empty = NetworkStatus(
    thunderboltBridgeState: nil,
    bridgeInactive: nil,
    interfaceStatuses: []
    interfaceStatuses: [],
    rdmaStatus: .empty
  )
}

struct RDMAStatus: Equatable {
  let rdmaCtlEnabled: Bool?
  let devices: [String]
  let activePorts: [RDMAPort]

  var isAvailable: Bool {
    rdmaCtlEnabled == true || !devices.isEmpty
  }

  static let empty = RDMAStatus(rdmaCtlEnabled: nil, devices: [], activePorts: [])
}

struct RDMAPort: Equatable {
  let device: String
  let port: String
  let state: String
}

struct InterfaceIpStatus: Equatable {
  let interfaceName: String
  let ipAddress: String?
@@ -59,10 +79,79 @@ private struct NetworkStatusFetcher {
    NetworkStatus(
      thunderboltBridgeState: readThunderboltBridgeState(),
      bridgeInactive: readBridgeInactive(),
      interfaceStatuses: readInterfaceStatuses()
      interfaceStatuses: readInterfaceStatuses(),
      rdmaStatus: readRDMAStatus()
    )
  }

  private func readRDMAStatus() -> RDMAStatus {
    let rdmaCtlEnabled = readRDMACtlEnabled()
    let devices = readRDMADevices()
    let activePorts = readRDMAActivePorts()
    return RDMAStatus(
      rdmaCtlEnabled: rdmaCtlEnabled, devices: devices, activePorts: activePorts)
  }

  private func readRDMACtlEnabled() -> Bool? {
    let result = runCommand(["rdma_ctl", "status"])
    guard result.exitCode == 0 else { return nil }
    let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
    if output.contains("enabled") {
      return true
    }
    if output.contains("disabled") {
      return false
    }
    return nil
  }

  private func readRDMADevices() -> [String] {
    let result = runCommand(["ibv_devices"])
    guard result.exitCode == 0 else { return [] }
    var devices: [String] = []
    for line in result.output.split(separator: "\n") {
      let trimmed = line.trimmingCharacters(in: .whitespaces)
      if trimmed.hasPrefix("---") || trimmed.lowercased().hasPrefix("device")
        || trimmed.isEmpty
      {
        continue
      }
      let parts = trimmed.split(separator: " ", maxSplits: 1)
      if let deviceName = parts.first {
        devices.append(String(deviceName))
      }
    }
    return devices
  }

  private func readRDMAActivePorts() -> [RDMAPort] {
    let result = runCommand(["ibv_devinfo"])
    guard result.exitCode == 0 else { return [] }
    var ports: [RDMAPort] = []
    var currentDevice: String?
    var currentPort: String?

    for line in result.output.split(separator: "\n") {
      let trimmed = line.trimmingCharacters(in: .whitespaces)
      if trimmed.hasPrefix("hca_id:") {
        currentDevice = trimmed.replacingOccurrences(of: "hca_id:", with: "")
          .trimmingCharacters(in: .whitespaces)
      } else if trimmed.hasPrefix("port:") {
        currentPort = trimmed.replacingOccurrences(of: "port:", with: "")
          .trimmingCharacters(in: .whitespaces)
      } else if trimmed.hasPrefix("state:") {
        let state = trimmed.replacingOccurrences(of: "state:", with: "").trimmingCharacters(
          in: .whitespaces)
        if let device = currentDevice, let port = currentPort {
          if state.lowercased().contains("active") {
            ports.append(RDMAPort(device: device, port: port, state: state))
          }
        }
      }
    }
    return ports
  }
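To make the parser above concrete: ibv_devinfo prints per-device blocks like the abbreviated sample below (values illustrative), from which readRDMAActivePorts keeps only ports whose state contains "active".

// Abbreviated ibv_devinfo output shape the parser walks (illustrative):
//
//   hca_id: rdma0
//           port:   1
//                   state:  PORT_ACTIVE (4)
//
// For this input it would append
// RDMAPort(device: "rdma0", port: "1", state: "PORT_ACTIVE (4)").
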

  private func readThunderboltBridgeState() -> ThunderboltState? {
    let result = runCommand(["networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
    guard result.exitCode == 0 else {
@@ -85,10 +174,11 @@ private struct NetworkStatusFetcher {
  private func readBridgeInactive() -> Bool? {
    let result = runCommand(["ifconfig", "bridge0"])
    guard result.exitCode == 0 else { return nil }
    guard let statusLine = result.output
      .components(separatedBy: .newlines)
      .first(where: { $0.contains("status:") })?
      .lowercased()
    guard
      let statusLine = result.output
        .components(separatedBy: .newlines)
        .first(where: { $0.contains("status:") })?
        .lowercased()
    else {
      return nil
    }
@@ -171,4 +261,3 @@ private struct NetworkStatusFetcher {
    )
  }
}

@@ -57,7 +57,7 @@ struct InstanceViewModel: Identifiable, Equatable {
    case waiting
    case failed
    case idle
    case unknown
    case preparing

    var label: String {
      switch self {
@@ -68,7 +68,7 @@ struct InstanceViewModel: Identifiable, Equatable {
      case .waiting: return "Waiting"
      case .failed: return "Failed"
      case .idle: return "Idle"
      case .unknown: return "Unknown"
      case .preparing: return "Preparing"
      }
    }
  }
@@ -107,10 +107,13 @@ extension ClusterState {
      let nodeToRunner = instance.shardAssignments.nodeToRunner
      let nodeIds = Array(nodeToRunner.keys)
      let runnerIds = Array(nodeToRunner.values)
      let nodeNames = nodeIds.compactMap { nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0 }
      let nodeNames = nodeIds.compactMap {
        nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0
      }
      let statuses = runnerIds.compactMap { runners[$0]?.status.lowercased() }
      let downloadProgress = aggregateDownloadProgress(for: nodeIds)
      let state = InstanceViewModel.State(statuses: statuses, hasActiveDownload: downloadProgress != nil)
      let state = InstanceViewModel.State(
        statuses: statuses, hasActiveDownload: downloadProgress != nil)
      let chatTasks = (chatTasksByInstance[entry.key] ?? [])
        .sorted(by: { $0.sortPriority < $1.sortPriority })
        .map { InstanceTaskViewModel(task: $0) }
@@ -165,8 +168,8 @@ extension ClusterState {
  }
}

private extension InstanceViewModel.State {
  init(statuses: [String], hasActiveDownload: Bool = false) {
extension InstanceViewModel.State {
  fileprivate init(statuses: [String], hasActiveDownload: Bool = false) {
    if statuses.contains(where: { $0.contains("failed") }) {
      self = .failed
    } else if hasActiveDownload || statuses.contains(where: { $0.contains("downloading") }) {
@@ -182,7 +185,7 @@ private extension InstanceViewModel.State {
    } else if statuses.isEmpty {
      self = .idle
    } else {
      self = .unknown
      self = .preparing
    }
  }
}
@@ -243,4 +246,3 @@ extension InstanceTaskViewModel {
    self.parameters = task.parameters
  }
}

@@ -85,9 +85,11 @@ struct TopologyViewModel {
}

extension ClusterState {
  func topologyViewModel() -> TopologyViewModel? {
  func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
    let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
    let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
    let allNodes = nodeViewModels().filter {
      topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
    }
    guard !allNodes.isEmpty else { return nil }

    let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
@@ -105,17 +107,25 @@ extension ClusterState {
      orderedNodes = allNodes
    }

    // Rotate so the local node (from /node_id API) is first
    if let localId = localNodeId,
      let index = orderedNodes.firstIndex(where: { $0.id == localId })
    {
      orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
    }

    let nodeIds = Set(orderedNodes.map(\.id))
    let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
      guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
      return TopologyEdgeViewModel(sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
    } ?? []
    let edgesArray: [TopologyEdgeViewModel] =
      topology?.connections?.compactMap { connection in
        guard nodeIds.contains(connection.localNodeId),
          nodeIds.contains(connection.sendBackNodeId)
        else { return nil }
        return TopologyEdgeViewModel(
          sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
      } ?? []
    let edges = Set(edgesArray)

    let topologyRootId = topology?.nodes.first?.nodeId
    let currentId = orderedNodes.first(where: { $0.id == topologyRootId })?.id ?? orderedNodes.first?.id

    return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: currentId)
    return TopologyViewModel(
      nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
  }
}

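The rotation in the hunk above is a split-and-concatenate on the ordered array; concretely (illustrative ids, not from the source):

// Illustrative: rotate ["a", "b", "c", "d"] so local node "c" leads.
let nodes = ["a", "b", "c", "d"]
if let index = nodes.firstIndex(of: "c") {
  let rotated = Array(nodes[index...]) + Array(nodes[..<index])
  // rotated == ["c", "d", "a", "b"]
}
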
@@ -20,8 +20,8 @@ struct InstanceRowView: View {
        if let progress = instance.downloadProgress {
          downloadStatusView(progress: progress)
        } else {
          statusChip(label: instance.state.label.uppercased(), color: statusColor)
        }
        statusChip(label: instance.state.label.uppercased(), color: statusColor)
      }
    }
    if let progress = instance.downloadProgress {
      GeometryReader { geometry in
@@ -83,7 +83,7 @@ struct InstanceRowView: View {
    case .ready: return .teal
    case .waiting, .idle: return .gray
    case .failed: return .red
    case .unknown: return .secondary
    case .preparing: return .secondary
    }
  }

@@ -97,7 +97,8 @@ struct InstanceRowView: View {
        .font(.caption)
        .fontWeight(.semibold)
      if let subtitle = task.subtitle,
        subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame {
        subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame
      {
        Text(subtitle)
          .font(.caption2)
          .foregroundColor(.secondary)
@@ -234,9 +235,12 @@ struct InstanceRowView: View {
      Button {
        isExpanded.wrappedValue.toggle()
      } label: {
        Label(isExpanded.wrappedValue ? "Hide" : "Show", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
          .labelStyle(.titleAndIcon)
          .contentTransition(.symbolEffect(.replace))
        Label(
          isExpanded.wrappedValue ? "Hide" : "Show",
          systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
        )
        .labelStyle(.titleAndIcon)
        .contentTransition(.symbolEffect(.replace))
      }
      .buttonStyle(.plain)
      .font(.caption2)
@@ -311,7 +315,9 @@ struct InstanceRowView: View {
  }

  @ViewBuilder
  private func detailRow(icon: String? = nil, title: String, value: String, tint: Color = .secondary) -> some View {
  private func detailRow(
    icon: String? = nil, title: String, value: String, tint: Color = .secondary
  ) -> some View {
    HStack(alignment: .firstTextBaseline, spacing: 6) {
      if let icon {
        Image(systemName: icon)
@@ -329,4 +335,3 @@ struct InstanceRowView: View {
    }
  }
}

@@ -32,4 +32,3 @@ struct NodeDetailView: View {
    }
  }
}

@@ -28,4 +28,3 @@ struct NodeRowView: View {
    .padding(.vertical, 4)
  }
}

@@ -76,30 +76,33 @@ struct TopologyMiniView: View {

  private func connectionLines(in size: CGSize) -> some View {
    let positions = positionedNodes(in: size)
    let positionById = Dictionary(uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
    let positionById = Dictionary(
      uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
    return Canvas { context, _ in
      guard !topology.edges.isEmpty else { return }
      let nodeRadius: CGFloat = 32
      let arrowLength: CGFloat = 10
      let arrowSpread: CGFloat = .pi / 7
      for edge in topology.edges {
        guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId] else { continue }
        guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId]
        else { continue }
        let dx = end.x - start.x
        let dy = end.y - start.y
        let distance = max(CGFloat(hypot(dx, dy)), 1)
        let ux = dx / distance
        let uy = dy / distance
        let adjustedStart = CGPoint(x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
        let adjustedStart = CGPoint(
          x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
        let adjustedEnd = CGPoint(x: end.x - ux * nodeRadius, y: end.y - uy * nodeRadius)

        var linePath = Path()
        linePath.move(to: adjustedStart)
        linePath.addLine(to: adjustedEnd)
        context.stroke(
        context.stroke(
          linePath,
          with: .color(.secondary.opacity(0.3)),
          style: StrokeStyle(lineWidth: 1, dash: [4, 4])
        )
          style: StrokeStyle(lineWidth: 1, dash: [4, 4])
        )

        let angle = atan2(uy, ux)
        let tip = adjustedEnd
@@ -168,5 +171,3 @@ private struct NodeGlyphView: View {
      .frame(width: 95)
  }
}

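The arrowhead drawing continues past this hunk; from the angle and tip already computed, the two barb points follow by stepping back arrowLength along directions offset by plus or minus arrowSpread (a sketch of the geometry, not the file's exact code):

import CoreGraphics
import Foundation

// Illustrative arrowhead barbs for an edge at `angle` ending at `tip`.
let angle: CGFloat = .pi / 4        // example edge direction
let tip = CGPoint(x: 100, y: 100)   // example adjusted end point
let arrowLength: CGFloat = 10
let arrowSpread: CGFloat = .pi / 7
let leftBarb = CGPoint(
  x: tip.x - arrowLength * cos(angle - arrowSpread),
  y: tip.y - arrowLength * sin(angle - arrowSpread))
let rightBarb = CGPoint(
  x: tip.x - arrowLength * cos(angle + arrowSpread),
  y: tip.y - arrowLength * sin(angle + arrowSpread))
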
@@ -6,6 +6,7 @@
//

import Testing

@testable import EXO

struct EXOTests {

154
app/EXO/uninstall-exo.sh
Executable file
@@ -0,0 +1,154 @@
#!/usr/bin/env bash
#
# EXO Uninstaller Script
#
# This script removes all EXO system components that persist after deleting the app.
# Run with: sudo ./uninstall-exo.sh
#
# Components removed:
#   - LaunchDaemon: /Library/LaunchDaemons/io.exo.networksetup.plist
#   - Network script: /Library/Application Support/EXO/
#   - Log files: /var/log/io.exo.networksetup.*
#   - Network location: "exo"
#   - Launch at login registration
#

set -euo pipefail

LABEL="io.exo.networksetup"
SCRIPT_DEST="/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
PLIST_DEST="/Library/LaunchDaemons/io.exo.networksetup.plist"
LOG_OUT="/var/log/${LABEL}.log"
LOG_ERR="/var/log/${LABEL}.err.log"
APP_BUNDLE_ID="io.exo.EXO"

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

echo_info() {
  echo -e "${GREEN}[INFO]${NC} $1"
}

echo_warn() {
  echo -e "${YELLOW}[WARN]${NC} $1"
}

echo_error() {
  echo -e "${RED}[ERROR]${NC} $1"
}

# Check if running as root
if [[ $EUID -ne 0 ]]; then
  echo_error "This script must be run as root (use sudo)"
  exit 1
fi

echo ""
echo "========================================"
echo "  EXO Uninstaller"
echo "========================================"
echo ""

# Unload the LaunchDaemon if running
echo_info "Stopping network setup daemon..."
if launchctl list | grep -q "$LABEL"; then
  launchctl bootout system/"$LABEL" 2>/dev/null || true
  echo_info "Daemon stopped"
else
  echo_warn "Daemon was not running"
fi

# Remove LaunchDaemon plist
if [[ -f "$PLIST_DEST" ]]; then
  rm -f "$PLIST_DEST"
  echo_info "Removed LaunchDaemon plist"
else
  echo_warn "LaunchDaemon plist not found (already removed?)"
fi

# Remove the script and parent directory
if [[ -f "$SCRIPT_DEST" ]]; then
  rm -f "$SCRIPT_DEST"
  echo_info "Removed network setup script"
else
  echo_warn "Network setup script not found (already removed?)"
fi

# Remove EXO directory if empty
if [[ -d "/Library/Application Support/EXO" ]]; then
  rmdir "/Library/Application Support/EXO" 2>/dev/null && \
    echo_info "Removed EXO support directory" || \
    echo_warn "EXO support directory not empty, leaving in place"
fi

# Remove log files
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
  rm -f "$LOG_OUT" "$LOG_ERR"
  echo_info "Removed log files"
else
  echo_warn "Log files not found (already removed?)"
fi

# Switch back to Automatic network location
echo_info "Restoring network configuration..."
if networksetup -listlocations | grep -q "^Automatic$"; then
  networksetup -switchtolocation Automatic 2>/dev/null || true
  echo_info "Switched to Automatic network location"
else
  echo_warn "Automatic network location not found"
fi

# Delete the exo network location if it exists
if networksetup -listlocations | grep -q "^exo$"; then
  networksetup -deletelocation exo 2>/dev/null || true
  echo_info "Deleted 'exo' network location"
else
  echo_warn "'exo' network location not found (already removed?)"
fi

# Re-enable Thunderbolt Bridge if it exists
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
  networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
  echo_info "Re-enabled Thunderbolt Bridge"
fi

# Note about launch at login registration
# SMAppService-based login items cannot be removed from a shell script.
# They can only be unregistered from within the app itself or manually via System Settings.
echo_warn "Launch at login must be removed manually:"
echo_warn "  System Settings → General → Login Items → Remove EXO"

# Check if EXO.app exists in common locations
APP_FOUND=false
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
  if [[ -d "$app_path" ]]; then
    if [[ "$APP_FOUND" == false ]]; then
      echo ""
      APP_FOUND=true
    fi
    echo_warn "EXO.app found at: $app_path"
    echo_warn "You may want to move it to Trash manually."
  fi
done

echo ""
echo "========================================"
echo_info "EXO uninstall complete!"
echo "========================================"
echo ""
echo "The following have been removed:"
echo "  • Network setup LaunchDaemon"
echo "  • Network configuration script"
echo "  • Log files"
echo "  • 'exo' network location"
echo ""
echo "Your network has been restored to use the 'Automatic' location."
echo "Thunderbolt Bridge has been re-enabled (if present)."
echo ""
echo "Manual step required:"
echo "  Remove EXO from Login Items in System Settings → General → Login Items"
echo ""

529
bench/exo_bench.py
Normal file
@@ -0,0 +1,529 @@
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
from __future__ import annotations

import argparse
import http.client
import json
import os
import time
from collections.abc import Callable
from statistics import mean
from typing import Any
from urllib.parse import urlencode

from loguru import logger
from transformers import AutoTokenizer

from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory


class ExoHttpError(RuntimeError):
    def __init__(self, status: int, reason: str, body_preview: str):
        super().__init__(f"HTTP {status} {reason}: {body_preview}")
        self.status = status


class ExoClient:
    def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
        self.host = host
        self.port = port
        self.timeout_s = timeout_s

    def request_json(
        self,
        method: str,
        path: str,
        params: dict[str, Any] | None = None,
        body: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
    ) -> Any:
        if not path.startswith("/"):
            path = "/" + path
        if params:
            path = path + "?" + urlencode(params)

        conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
        try:
            payload: bytes | None = None
            hdrs: dict[str, str] = {"Accept": "application/json"}

            if body is not None:
                payload = json.dumps(body).encode("utf-8")
                hdrs["Content-Type"] = "application/json"
            if headers:
                hdrs.update(headers)

            conn.request(method.upper(), path, body=payload, headers=hdrs)
            resp = conn.getresponse()
            raw = resp.read()
            text = raw.decode("utf-8", errors="replace") if raw else ""

            if resp.status >= 400:
                raise ExoHttpError(resp.status, resp.reason, text[:300])

            if not text:
                return None
            return json.loads(text)
        finally:
            conn.close()

    def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
        return self.request_json("POST", "/bench/chat/completions", body=payload)


def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
    if len(instance) != 1:
        raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")

    tag = next(iter(instance))
    inner = instance[tag]
    if not isinstance(inner, dict):
        raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
    return inner


def instance_id_from_instance(instance: dict[str, Any]) -> str:
    inner = unwrap_instance(instance)
    return str(inner["instanceId"])


def nodes_used_in_instance(instance: dict[str, Any]) -> int:
    inner = unwrap_instance(instance)
    return len(inner["shardAssignments"]["nodeToRunner"])


def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
    inner = unwrap_instance(instance)
    runner_to_shard = inner["shardAssignments"]["runnerToShard"]
    return list(runner_to_shard.keys())


def runner_ready(runner: dict[str, Any]) -> bool:
    return "RunnerReady" in runner


def wait_for_instance_ready(
    client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
    start_time = time.time()
    while time.time() - start_time < timeout:
        state = client.request_json("GET", "/state")
        instances = state.get("instances", {})

        if instance_id not in instances:
            time.sleep(0.1)
            continue

        instance = instances[instance_id]
        runner_ids = runner_ids_from_instance(instance)
        runners = state.get("runners", {})

        if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
            return

        time.sleep(0.1)

    raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")


def wait_for_instance_gone(
    client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            client.request_json("GET", f"/instance/{instance_id}")
            time.sleep(0.4)
        except ExoHttpError as e:
            if e.status == 404:
                return

    raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")


def format_peak_memory(b: float) -> str:
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if b < 1024.0:
            return f"{b:.2f}{unit}"
        b /= 1024.0
    raise ValueError("You're using petabytes of memory. Something went wrong...")


def parse_int_list(values: list[str]) -> list[int]:
    items: list[int] = []
    for v in values:
        for part in v.split(","):
            part = part.strip()
            if part:
                items.append(int(part))

    seen: set[int] = set()
    out: list[int] = []
    for x in items:
        if x not in seen:
            out.append(x)
            seen.add(x)
    return out


def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
    models = client.request_json("GET", "/models") or {}
    data = models.get("data") or []

    for m in data:
        if m.get("id") == model_arg:
            short_id = str(m["id"])
            full_id = str(m.get("hugging_face_id") or m["id"])
            return short_id, full_id

    for m in data:
        if m.get("hugging_face_id") == model_arg:
            short_id = str(m["id"])
            full_id = str(m["hugging_face_id"])
            return short_id, full_id

    raise ValueError(f"Model not found in /models: {model_arg}")


def placement_filter(instance_meta: str, wanted: str) -> bool:
    s = (instance_meta or "").lower()
    if wanted == "both":
        return ("ring" in s) or ("jaccl" in s)
    return wanted in s


def sharding_filter(sharding: str, wanted: str) -> bool:
    s = (sharding or "").lower()
    if wanted == "both":
        return ("pipeline" in s) or ("tensor" in s)
    return wanted in s


def run_one_completion(
    client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
) -> tuple[dict[str, Any], int]:
    content, pp_tokens = prompt_sizer.build(pp_hint)
    payload: dict[str, Any] = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": tg,
    }

    t0 = time.perf_counter()
    out = client.post_bench_chat_completions(payload)
    elapsed = time.perf_counter() - t0

    stats = out.get("generation_stats")

    preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]

    return {
        "elapsed_s": elapsed,
        "output_text_preview": preview,
        "stats": stats,
    }, pp_tokens


class PromptSizer:
    def __init__(self, tokenizer: Any, atom: str = "a "):
        self.tokenizer = tokenizer
        self.atom = atom
        self.count_fn = PromptSizer._make_counter(tokenizer)
        self.base_tokens = self.count_fn("")

    @staticmethod
    def _make_counter(tokenizer: Any) -> Callable[[str], int]:
        def count_fn(user_content: str) -> int:
            messages = [{"role": "user", "content": user_content}]
            ids = tokenizer.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=True
            )
            # Fix for transformers 5.x
            if hasattr(ids, "input_ids"):
                ids = ids.input_ids
            return int(len(ids))

        return count_fn

    def build(self, target_prompt_tokens: int) -> tuple[str, int]:
        target = int(target_prompt_tokens)
        if target < self.base_tokens:
            raise RuntimeError(
                f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
            )

        content = ""
        tok = self.count_fn(content)

        while tok < target:
            content += self.atom
            tok = self.count_fn(content)

        if tok != target:
            raise RuntimeError(
                f"Overshot: got {tok} tokens (target {target}). "
                f"Pick a different atom (try ' a' or '\\n' or '0 ')."
            )

        return content, tok


def main() -> int:
    ap = argparse.ArgumentParser(
|
||||
prog="exo-bench",
|
||||
description="Benchmark exo model throughput across placement previews.",
|
||||
)
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
|
||||
)
|
||||
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
|
||||
ap.add_argument(
|
||||
"--pp",
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="Prompt-size hints (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--tg",
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="Generation lengths (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--max-nodes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Pipeline jaccl is often pointless, skip by default",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--warmup",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
default="bench/results.json",
|
||||
help="Write raw per-run results JSON to this path.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dry-run", action="store_true", help="List selected placements and exit."
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
tg_list = parse_int_list(args.tg)
|
||||
if not pp_list or not tg_list:
|
||||
logger.error("pp and tg lists must be non-empty")
|
||||
return 2
|
||||
if args.repeat <= 0:
|
||||
logger.error("--repeat must be >= 1")
|
||||
return 2
|
||||
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": short_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
full_model_id,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if tokenizer is None:
|
||||
raise RuntimeError("[exo-bench] tokenizer load failed")
|
||||
|
||||
try:
|
||||
prompt_sizer = PromptSizer(tokenizer)
|
||||
logger.debug(f"[exo-bench] loaded tokenizer: {full_model_id} for prompt sizer")
|
||||
except Exception:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if 0 < n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
logger.error("No valid placements matched your filters.")
|
||||
return 1
|
||||
|
||||
selected.sort(
|
||||
key=lambda p: (
|
||||
str(p.get("instance_meta", "")),
|
||||
str(p.get("sharding", "")),
|
||||
-nodes_used_in_instance(p["instance"]),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
|
||||
logger.info(f"placements: {len(selected)}")
|
||||
for p in selected:
|
||||
logger.info(
|
||||
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
|
||||
)
|
||||
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
instance = preview["instance"]
|
||||
instance_id = instance_id_from_instance(instance)
|
||||
|
||||
sharding = str(preview["sharding"])
|
||||
instance_meta = str(preview["instance_meta"])
|
||||
n_nodes = nodes_used_in_instance(instance)
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info(
|
||||
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
|
||||
)
|
||||
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
try:
|
||||
for i in range(args.warmup):
|
||||
run_one_completion(
|
||||
client, full_model_id, pp_list[0], tg_list[0], prompt_sizer
|
||||
)
|
||||
logger.debug(f" warmup {i + 1}/{args.warmup} done")
|
||||
|
||||
for pp in pp_list:
|
||||
if (
|
||||
pp * n_nodes > 2048
|
||||
and "ring" in instance_meta.lower()
|
||||
and "tensor" in sharding.lower()
|
||||
):
|
||||
model_card = MODEL_CARDS[short_id]
|
||||
if model_card.metadata.storage_size > Memory.from_gb(10):
|
||||
logger.info(
|
||||
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
|
||||
)
|
||||
continue
|
||||
for tg in tg_list:
|
||||
runs: list[dict[str, Any]] = []
|
||||
for r in range(args.repeat):
|
||||
time.sleep(3)
|
||||
try:
|
||||
row, actual_pp_tokens = run_one_completion(
|
||||
client, full_model_id, pp, tg, prompt_sizer
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
continue
|
||||
row.update(
|
||||
{
|
||||
"model_short_id": short_id,
|
||||
"model_id": full_model_id,
|
||||
"placement_sharding": sharding,
|
||||
"placement_instance_meta": instance_meta,
|
||||
"placement_nodes": n_nodes,
|
||||
"instance_id": instance_id,
|
||||
"pp_tokens": actual_pp_tokens,
|
||||
"tg": tg,
|
||||
"repeat_index": r,
|
||||
}
|
||||
)
|
||||
runs.append(row)
|
||||
all_rows.append(row)
|
||||
|
||||
if runs:
|
||||
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
|
||||
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
|
||||
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
|
||||
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
|
||||
peak = mean(
|
||||
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
|
||||
f"prompt_tokens={ptok} gen_tokens={gtok} "
|
||||
f"peak_memory={format_peak_memory(peak)}\n"
|
||||
)
|
||||
time.sleep(2)
|
||||
finally:
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
except ExoHttpError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
wait_for_instance_gone(client, instance_id)
|
||||
logger.debug(f"Deleted instance {instance_id}")
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
if args.json_out:
|
||||
with open(args.json_out, "w", encoding="utf-8") as f:
|
||||
json.dump(all_rows, f, indent=2, ensure_ascii=False)
|
||||
logger.debug(f"\nWrote results JSON: {args.json_out}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
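For reference, a minimal sketch of how this benchmark might be invoked from a shell, assuming the script above is saved as `exo_bench.py` and an exo API node is reachable on the default port. The model id and the pp/tg values here are illustrative, not taken from the repository:

```bash
# Hypothetical invocation: sweep two prompt sizes and two generation
# lengths over ring placements of at most 2 nodes, 3 repetitions each.
python exo_bench.py --model llama-3.2-1b \
  --pp 128,512 --tg 64,256 \
  --instance-meta ring --max-nodes 2 \
  --warmup 1 --repeat 3 \
  --json-out bench/results.json
```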
dashboard/dashboard.nix (new file, 60 lines)
@@ -0,0 +1,60 @@
{ lib
, config
, dream2nix
, ...
}:
let
  # Read and parse the lock file
  rawLockFile = builtins.fromJSON (builtins.readFile "${config.deps.dashboardSrc}/package-lock.json");

  # For packages with bundleDependencies, filter out deps that are bundled
  # (bundled deps are inside the tarball, not separate lockfile entries)
  fixedPackages = lib.mapAttrs
    (path: entry:
      if entry ? bundleDependencies && entry.bundleDependencies != [ ]
      then entry // {
        dependencies = lib.filterAttrs
          (name: _: !(lib.elem name entry.bundleDependencies))
          (entry.dependencies or { });
      }
      else entry
    )
    (rawLockFile.packages or { });

  fixedLockFile = rawLockFile // { packages = fixedPackages; };
in
{
  imports = [
    dream2nix.modules.dream2nix.nodejs-package-lock-v3
    dream2nix.modules.dream2nix.nodejs-granular-v3
  ];

  name = "exo-dashboard";
  version = "1.0.0";

  mkDerivation = {
    src = config.deps.dashboardSrc;

    buildPhase = ''
      runHook preBuild
      npm run build
      runHook postBuild
    '';

    installPhase = ''
      runHook preInstall
      cp -r build $out/build
      runHook postInstall
    '';
  };

  deps = { nixpkgs, ... }: {
    inherit (nixpkgs) stdenv;
    dashboardSrc = null; # Injected by parts.nix
  };

  nodejs-package-lock-v3 = {
    # Don't use packageLockFile - provide the fixed lock content directly
    packageLock = fixedLockFile;
  };
}
dashboard/package-lock.json (generated, 39 changed lines)
@@ -9,6 +9,8 @@
    "version": "1.0.0",
    "dependencies": {
      "highlight.js": "^11.11.1",
      "katex": "^0.16.27",
      "marked": "^17.0.1",
      "mode-watcher": "^1.1.0"
    },
    "devDependencies": {
@@ -2254,6 +2256,31 @@
        "jiti": "lib/jiti-cli.mjs"
      }
    },
    "node_modules/katex": {
      "version": "0.16.27",
      "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.27.tgz",
      "integrity": "sha512-aeQoDkuRWSqQN6nSvVCEFvfXdqo1OQiCmmW1kc9xSdjutPv7BGO7pqY9sQRJpMOGrEdfDgF2TfRXe5eUAD2Waw==",
      "funding": [
        "https://opencollective.com/katex",
        "https://github.com/sponsors/katex"
      ],
      "license": "MIT",
      "dependencies": {
        "commander": "^8.3.0"
      },
      "bin": {
        "katex": "cli.js"
      }
    },
    "node_modules/katex/node_modules/commander": {
      "version": "8.3.0",
      "resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz",
      "integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==",
      "license": "MIT",
      "engines": {
        "node": ">= 12"
      }
    },
    "node_modules/kleur": {
      "version": "4.1.5",
      "resolved": "https://registry.npmjs.org/kleur/-/kleur-4.1.5.tgz",
@@ -2540,6 +2567,18 @@
        "@jridgewell/sourcemap-codec": "^1.5.5"
      }
    },
    "node_modules/marked": {
      "version": "17.0.1",
      "resolved": "https://registry.npmjs.org/marked/-/marked-17.0.1.tgz",
      "integrity": "sha512-boeBdiS0ghpWcSwoNm/jJBwdpFaMnZWRzjA6SkUMYb40SVaN1x7mmfGKp0jvexGcx+7y2La5zRZsYFZI6Qpypg==",
      "license": "MIT",
      "bin": {
        "marked": "bin/marked.js"
      },
      "engines": {
        "node": ">= 20"
      }
    },
    "node_modules/mode-watcher": {
      "version": "1.1.0",
      "resolved": "https://registry.npmjs.org/mode-watcher/-/mode-watcher-1.1.0.tgz",

@@ -27,7 +27,8 @@
  },
  "dependencies": {
    "highlight.js": "^11.11.1",
    "katex": "^0.16.27",
    "marked": "^17.0.1",
    "mode-watcher": "^1.1.0"
  }
}
dashboard/parts.nix (new file, 44 lines)
@@ -0,0 +1,44 @@
{ inputs, ... }:
{
  perSystem =
    { pkgs, lib, ... }:
    let
      # Filter source to only include dashboard directory
      src = lib.cleanSourceWith {
        src = inputs.self;
        filter =
          path: type:
          let
            baseName = builtins.baseNameOf path;
            inDashboardDir =
              (lib.hasInfix "/dashboard/" path)
              || (lib.hasSuffix "/dashboard" (builtins.dirOf path))
              || (baseName == "dashboard" && type == "directory");
          in
          inDashboardDir;
      };

      # Build the dashboard with dream2nix (includes node_modules in output)
      dashboardFull = inputs.dream2nix.lib.evalModules {
        packageSets.nixpkgs = pkgs;
        modules = [
          ./dashboard.nix
          {
            paths.projectRoot = inputs.self;
            paths.projectRootFile = "flake.nix";
            paths.package = inputs.self + "/dashboard";
          }
          # Inject the filtered source
          {
            deps.dashboardSrc = lib.mkForce "${src}/dashboard";
          }
        ];
      };
    in
    {
      # Extract just the static site from the full build
      packages.dashboard = pkgs.runCommand "exo-dashboard" { } ''
        cp -r ${dashboardFull}/build $out
      '';
    };
}
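Assuming this module is wired into the repository's flake-parts setup (that wiring is not shown in this diff), the static site could then be built with something like:

```bash
# Hypothetical usage: build the dashboard output defined above.
# The result symlink should contain the SvelteKit build/ directory.
nix build .#dashboard
ls result
```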
dashboard/src/app.d.ts (vendored, 1 changed line)
@@ -11,4 +11,3 @@ declare global {
}

export {};
@@ -60,12 +60,39 @@
    return models;
  });

  // Auto-select the first available model if none is selected
  // Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
  let previousModelIds: Set<string> = new Set();

  // Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
  $effect(() => {
    const models = availableModels();
    if (models.length > 0 && !currentModel) {
      setSelectedChatModel(models[0].id);
    const currentModelIds = new Set(models.map(m => m.id));

    if (models.length > 0) {
      // Find newly added models (in current but not in previous)
      const newModels = models.filter(m => !previousModelIds.has(m.id));

      // If no model selected, select the first available
      if (!currentModel) {
        setSelectedChatModel(models[0].id);
      }
      // If current model is stale (no longer has a running instance), reset to first available
      else if (!models.some(m => m.id === currentModel)) {
        setSelectedChatModel(models[0].id);
      }
      // If a new model was just added, select it
      else if (newModels.length > 0 && previousModelIds.size > 0) {
        setSelectedChatModel(newModels[0].id);
      }
    } else {
      // No instances running - clear the selected model
      if (currentModel) {
        setSelectedChatModel('');
      }
    }

    // Update previous model IDs for next comparison
    previousModelIds = currentModelIds;
  });

  function getInstanceModelId(instanceWrapped: unknown): string {
@@ -139,6 +166,11 @@
  }

  function handleKeydown(event: KeyboardEvent) {
    // Prevent form submission during IME composition (e.g., Chinese, Japanese, Korean input)
    if (event.isComposing || event.keyCode === 229) {
      return;
    }

    if (event.key === 'Enter' && !event.shiftKey) {
      event.preventDefault();
      handleSubmit();
@@ -8,89 +8,80 @@
    regenerateLastResponse
  } from '$lib/stores/app.svelte';
  import type { MessageAttachment } from '$lib/stores/app.svelte';
  import { tick, onDestroy } from 'svelte';
  import MarkdownContent from './MarkdownContent.svelte';

  interface Props {
    class?: string;
    scrollParent?: HTMLElement | null;
  }

  let { class: className = '', scrollParent = null }: Props = $props();

  const messageList = $derived(messages());
  const response = $derived(currentResponse());
  const loading = $derived(isLoading());

  // Ref for scroll anchor at bottom
  let scrollAnchorRef: HTMLDivElement | undefined = $state();
  // Scroll management - user controls scroll, show button when not at bottom
  const SCROLL_THRESHOLD = 100;
  let showScrollButton = $state(false);
  let lastMessageCount = 0;
  let containerRef: HTMLDivElement | undefined = $state();

  // Scroll management
  const SCROLL_BOTTOM_THRESHOLD = 120;
  let autoScrollEnabled = true;
  let currentScrollEl: HTMLElement | null = null;

  function resolveScrollElement(): HTMLElement | null {
    if (scrollParent) return scrollParent;
    let node: HTMLElement | null = scrollAnchorRef?.parentElement as HTMLElement | null;
    while (node) {
      const isScrollable = node.scrollHeight > node.clientHeight + 1;
      if (isScrollable) return node;
      node = node.parentElement;
    }
    return null;
  }

  function getScrollContainer(): HTMLElement | null {
    if (scrollParent) return scrollParent;
    return containerRef?.parentElement ?? null;
  }

  function handleScroll() {
    if (!currentScrollEl) return;
    const distanceFromBottom = currentScrollEl.scrollHeight - currentScrollEl.scrollTop - currentScrollEl.clientHeight;
    const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
    autoScrollEnabled = isNearBottom;
  }

  function isNearBottom(el: HTMLElement): boolean {
    return el.scrollHeight - el.scrollTop - el.clientHeight < SCROLL_THRESHOLD;
  }

  function attachScrollListener() {
    const nextEl = resolveScrollElement();
    if (currentScrollEl === nextEl) return;
    if (currentScrollEl) {
      currentScrollEl.removeEventListener('scroll', handleScroll);
    }
    currentScrollEl = nextEl;
    if (currentScrollEl) {
      currentScrollEl.addEventListener('scroll', handleScroll);
      // Initialize state based on current position
      handleScroll();
    }
  }

  onDestroy(() => {
    if (currentScrollEl) {
      currentScrollEl.removeEventListener('scroll', handleScroll);
    }
  });

  $effect(() => {
    // Re-evaluate scroll container if prop changes or after mount
    scrollParent;
    attachScrollListener();
  });

  // Auto-scroll to bottom when messages change or response updates, but only if user is near bottom
  $effect(() => {
    // Track these values to trigger effect
    const _ = messageList.length;
    const __ = response;
    const ___ = loading;

    tick().then(() => {
      const el = currentScrollEl ?? resolveScrollElement();
      if (!el || !scrollAnchorRef) return;
      const distanceFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight;
      const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
      if (autoScrollEnabled || isNearBottom) {
        scrollAnchorRef.scrollIntoView({ behavior: 'smooth', block: 'end' });
        autoScrollEnabled = true;
      }
    });
  });

  function scrollToBottom() {
    const el = getScrollContainer();
    if (el) {
      el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
    }
  }

  function updateScrollButtonVisibility() {
    const el = getScrollContainer();
    if (!el) return;
    showScrollButton = !isNearBottom(el);
  }

  // Attach scroll listener
  $effect(() => {
    const el = scrollParent ?? containerRef?.parentElement;
    if (!el) return;

    el.addEventListener('scroll', updateScrollButtonVisibility, { passive: true });
    // Initial check
    updateScrollButtonVisibility();
    return () => el.removeEventListener('scroll', updateScrollButtonVisibility);
  });

  // Auto-scroll when user sends a new message
  $effect(() => {
    const count = messageList.length;
    if (count > lastMessageCount) {
      const el = getScrollContainer();
      if (el) {
        requestAnimationFrame(() => {
          el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
        });
      }
    }
    lastMessageCount = count;
  });

  // Update scroll button visibility when content changes
  $effect(() => {
    // Track response to trigger re-check during streaming
    const _ = response;

    // Small delay to let DOM update
    requestAnimationFrame(() => updateScrollButtonVisibility());
  });

  // Edit state
  let editingMessageId = $state<string | null>(null);
@@ -231,7 +222,7 @@ function isThinkingExpanded(messageId: string): boolean {
<div class="flex flex-col gap-4 sm:gap-6 {className}">
  {#each messageList as message (message.id)}
    <div class="group flex {message.role === 'user' ? 'justify-end' : 'justify-start'}">
      <div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'max-w-[95%] sm:max-w-[85%]'}">
      <div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'w-full max-w-[98%] sm:max-w-[95%]'}">
        {#if message.role === 'assistant'}
          <!-- Assistant message header -->
          <div class="flex items-center gap-1.5 sm:gap-2 mb-1.5 sm:mb-2">
@@ -305,7 +296,7 @@ function isThinkingExpanded(messageId: string): boolean {
        {:else}
          <div class="{message.role === 'user'
            ? 'command-panel rounded-lg rounded-tr-sm inline-block'
            : 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 inline-block'}">
            : 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 block w-full'}">

            {#if message.role === 'user'}
              <!-- User message styling -->
@@ -331,7 +322,7 @@ function isThinkingExpanded(messageId: string): boolean {
            {/if}

            {#if message.content}
              <div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
              <div class="text-xs text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
                {message.content}
              </div>
            {/if}
@@ -360,7 +351,7 @@ function isThinkingExpanded(messageId: string): boolean {
                  </svg>
                  <span>Thinking...</span>
                </span>
                <span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60">
                <span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60 ml-4">
                  {isThinkingExpanded(message.id) ? 'HIDE' : 'SHOW'}
                </span>
              </button>
@@ -374,8 +365,8 @@ function isThinkingExpanded(messageId: string): boolean {
              {/if}
            </div>
          {/if}
          <div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
            {message.content || (loading ? response : '')}
          <div class="text-xs text-foreground">
            <MarkdownContent content={message.content || (loading ? response : '')} />
            {#if loading && !message.content}
              <span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
            {/if}
@@ -457,6 +448,20 @@ function isThinkingExpanded(messageId: string): boolean {
        </div>
      {/if}

      <!-- Scroll anchor for auto-scroll -->
      <div bind:this={scrollAnchorRef}></div>
      <!-- Invisible element for container reference -->
      <div bind:this={containerRef}></div>

      <!-- Scroll to bottom button -->
      {#if showScrollButton}
        <button
          type="button"
          onclick={scrollToBottom}
          class="sticky bottom-4 left-1/2 -translate-x-1/2 w-10 h-10 rounded-full bg-exo-dark-gray/90 border border-exo-medium-gray/50 flex items-center justify-center text-exo-light-gray hover:text-exo-yellow hover:border-exo-yellow/50 transition-all shadow-lg cursor-pointer z-10"
          title="Scroll to bottom"
        >
          <svg class="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
            <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 14l-7 7m0 0l-7-7m7 7V3" />
          </svg>
        </button>
      {/if}
    </div>
@@ -10,7 +10,9 @@ import {
  clearChat,
  instances,
  debugMode,
  toggleDebugMode
  toggleDebugMode,
  topologyOnlyMode,
  toggleTopologyOnlyMode
} from '$lib/stores/app.svelte';

interface Props {
@@ -23,6 +25,7 @@ import {
  const activeId = $derived(activeConversationId());
  const instanceData = $derived(instances());
  const debugEnabled = $derived(debugMode());
  const topologyOnlyEnabled = $derived(topologyOnlyMode());

  let searchQuery = $state('');
  let editingId = $state<string | null>(null);
@@ -424,6 +427,19 @@ const debugEnabled = $derived(debugMode());
      <div class="text-xs text-white/60 font-mono tracking-wider text-center">
        {conversationList.length} CONVERSATION{conversationList.length !== 1 ? 'S' : ''}
      </div>
      <button
        type="button"
        onclick={toggleTopologyOnlyMode}
        class="p-1.5 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
        title="Toggle topology only mode"
      >
        <svg class="w-4 h-4 {topologyOnlyEnabled ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
          <circle cx="12" cy="5" r="2" fill="currentColor" />
          <circle cx="5" cy="19" r="2" fill="currentColor" />
          <circle cx="19" cy="19" r="2" fill="currentColor" />
          <path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
        </svg>
      </button>
    </div>
  </div>
</aside>
@@ -3,6 +3,9 @@
  export let showHome = true;
  export let onHome: (() => void) | null = null;
  export let showSidebarToggle = false;
  export let sidebarVisible = true;
  export let onToggleSidebar: (() => void) | null = null;

  function handleHome(): void {
    if (onHome) {
@@ -14,13 +17,38 @@
      window.location.hash = '/';
    }
  }

  function handleToggleSidebar(): void {
    if (onToggleSidebar) {
      onToggleSidebar();
    }
  }
</script>

<header class="relative z-20 flex items-center justify-center px-6 pt-8 pb-4 bg-exo-dark-gray">
  <!-- Left: Sidebar Toggle -->
  {#if showSidebarToggle}
    <div class="absolute left-6 top-1/2 -translate-y-1/2">
      <button
        onclick={handleToggleSidebar}
        class="p-2 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
        title={sidebarVisible ? 'Hide sidebar' : 'Show sidebar'}
      >
        <svg class="w-5 h-5 {sidebarVisible ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
          {#if sidebarVisible}
            <path stroke-linecap="round" stroke-linejoin="round" d="M11 19l-7-7 7-7m8 14l-7-7 7-7" />
          {:else}
            <path stroke-linecap="round" stroke-linejoin="round" d="M13 5l7 7-7 7M5 5l7 7-7 7" />
          {/if}
        </svg>
      </button>
    </div>
  {/if}

  <!-- Center: Logo (clickable to go home) -->
  <button
    onclick={handleHome}
    class="hover:opacity-80 transition-opacity {showHome ? 'cursor-pointer' : 'cursor-default'}"
    class="bg-transparent border-none outline-none focus:outline-none transition-opacity duration-200 hover:opacity-90 {showHome ? 'cursor-pointer' : 'cursor-default'}"
    title={showHome ? 'Go to home' : ''}
    disabled={!showHome}
  >
dashboard/src/lib/components/MarkdownContent.svelte (new file, 451 lines)
@@ -0,0 +1,451 @@
<script lang="ts">
  import { marked } from 'marked';
  import hljs from 'highlight.js';
  import katex from 'katex';
  import 'katex/dist/katex.min.css';
  import { browser } from '$app/environment';

  interface Props {
    content: string;
    class?: string;
  }

  let { content, class: className = '' }: Props = $props();

  let containerRef = $state<HTMLDivElement>();
  let processedHtml = $state('');

  // Configure marked with syntax highlighting
  marked.setOptions({
    gfm: true,
    breaks: true
  });

  // Custom renderer for code blocks
  const renderer = new marked.Renderer();

  renderer.code = function ({ text, lang }: { text: string; lang?: string }) {
    const language = lang && hljs.getLanguage(lang) ? lang : 'plaintext';
    const highlighted = hljs.highlight(text, { language }).value;
    const codeId = `code-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`;

    return `
      <div class="code-block-wrapper">
        <div class="code-block-header">
          <span class="code-language">${language}</span>
          <button type="button" class="copy-code-btn" data-code="${encodeURIComponent(text)}" title="Copy code">
            <svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
              <rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
              <path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
            </svg>
          </button>
        </div>
        <pre><code class="hljs language-${language}" data-code-id="${codeId}">${highlighted}</code></pre>
      </div>
    `;
  };

  // Inline code
  renderer.codespan = function ({ text }: { text: string }) {
    return `<code class="inline-code">${text}</code>`;
  };

  marked.use({ renderer });

  /**
   * Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
   * Also protect code blocks from LaTeX processing
   */
  function preprocessLaTeX(text: string): string {
    // Protect code blocks
    const codeBlocks: string[] = [];
    let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
      codeBlocks.push(match);
      return `<<CODE_${codeBlocks.length - 1}>>`;
    });

    // Convert \(...\) to $...$
    processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');

    // Convert \[...\] to $$...$$
    processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');

    // Restore code blocks
    processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);

    return processed;
  }

  /**
   * Render math expressions with KaTeX after HTML is generated
   */
  function renderMath(html: string): string {
    // Render display math ($$...$$)
    html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
      try {
        return katex.renderToString(math.trim(), {
          displayMode: true,
          throwOnError: false,
          output: 'html'
        });
      } catch {
        return `<span class="math-error">$$${math}$$</span>`;
      }
    });

    // Render inline math ($...$) but avoid matching currency like $5
    html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
      // Skip if it looks like currency ($ followed by number)
      if (/^\d/.test(math.trim())) {
        return match;
      }
      try {
        return katex.renderToString(math.trim(), {
          displayMode: false,
          throwOnError: false,
          output: 'html'
        });
      } catch {
        return `<span class="math-error">$${math}$</span>`;
      }
    });

    return html;
  }

  function processMarkdown(text: string): string {
    try {
      // Preprocess LaTeX notation
      const preprocessed = preprocessLaTeX(text);
      // Parse markdown
      let html = marked.parse(preprocessed) as string;
      // Render math expressions
      html = renderMath(html);
      return html;
    } catch (error) {
      console.error('Markdown processing error:', error);
      return text.replace(/\n/g, '<br>');
    }
  }

  async function handleCopyClick(event: Event) {
    const target = event.currentTarget as HTMLButtonElement;
    const encodedCode = target.getAttribute('data-code');
    if (!encodedCode) return;

    const code = decodeURIComponent(encodedCode);

    try {
      await navigator.clipboard.writeText(code);
      // Show copied feedback
      const originalHtml = target.innerHTML;
      target.innerHTML = `
        <svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
          <path d="M20 6L9 17l-5-5"/>
        </svg>
      `;
      target.classList.add('copied');
      setTimeout(() => {
        target.innerHTML = originalHtml;
        target.classList.remove('copied');
      }, 2000);
    } catch (error) {
      console.error('Failed to copy:', error);
    }
  }

  function setupCopyButtons() {
    if (!containerRef || !browser) return;

    const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
    for (const button of buttons) {
      if (button.dataset.listenerBound !== 'true') {
        button.dataset.listenerBound = 'true';
        button.addEventListener('click', handleCopyClick);
      }
    }
  }

  $effect(() => {
    if (content) {
      processedHtml = processMarkdown(content);
    } else {
      processedHtml = '';
    }
  });

  $effect(() => {
    if (containerRef && processedHtml) {
      setupCopyButtons();
    }
  });
</script>

<div bind:this={containerRef} class="markdown-content {className}">
  {@html processedHtml}
</div>

<style>
  .markdown-content {
    line-height: 1.6;
  }

  /* Paragraphs */
  .markdown-content :global(p) {
    margin-bottom: 1rem;
  }

  .markdown-content :global(p:last-child) {
    margin-bottom: 0;
  }

  /* Headers */
  .markdown-content :global(h1) {
    font-size: 1.5rem;
    font-weight: 700;
    margin: 1.5rem 0 0.75rem 0;
    color: var(--exo-yellow, #ffd700);
  }

  .markdown-content :global(h2) {
    font-size: 1.25rem;
    font-weight: 600;
    margin: 1.25rem 0 0.5rem 0;
    color: var(--exo-yellow, #ffd700);
  }

  .markdown-content :global(h3) {
    font-size: 1.125rem;
    font-weight: 600;
    margin: 1rem 0 0.5rem 0;
  }

  .markdown-content :global(h4),
  .markdown-content :global(h5),
  .markdown-content :global(h6) {
    font-size: 1rem;
    font-weight: 600;
    margin: 0.75rem 0 0.25rem 0;
  }

  /* Bold and italic */
  .markdown-content :global(strong) {
    font-weight: 600;
  }

  .markdown-content :global(em) {
    font-style: italic;
  }

  /* Inline code */
  .markdown-content :global(.inline-code) {
    background: rgba(255, 215, 0, 0.1);
    color: var(--exo-yellow, #ffd700);
    padding: 0.125rem 0.375rem;
    border-radius: 0.25rem;
    font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
    font-size: 0.875em;
  }

  /* Links */
  .markdown-content :global(a) {
    color: var(--exo-yellow, #ffd700);
    text-decoration: underline;
    text-underline-offset: 2px;
  }

  .markdown-content :global(a:hover) {
    opacity: 0.8;
  }

  /* Lists */
  .markdown-content :global(ul) {
    list-style-type: disc;
    margin-left: 1.5rem;
    margin-bottom: 1rem;
  }

  .markdown-content :global(ol) {
    list-style-type: decimal;
    margin-left: 1.5rem;
    margin-bottom: 1rem;
  }

  .markdown-content :global(li) {
    margin-bottom: 0.25rem;
  }

  .markdown-content :global(li::marker) {
    color: var(--exo-light-gray, #9ca3af);
  }

  /* Blockquotes */
  .markdown-content :global(blockquote) {
    border-left: 3px solid var(--exo-yellow, #ffd700);
    padding: 0.5rem 1rem;
    margin: 1rem 0;
    background: rgba(255, 215, 0, 0.05);
    border-radius: 0 0.25rem 0.25rem 0;
  }

  /* Tables */
  .markdown-content :global(table) {
    width: 100%;
    margin: 1rem 0;
    border-collapse: collapse;
    font-size: 0.875rem;
  }

  .markdown-content :global(th) {
    background: rgba(255, 215, 0, 0.1);
    border: 1px solid rgba(255, 215, 0, 0.2);
    padding: 0.5rem;
    text-align: left;
    font-weight: 600;
  }

  .markdown-content :global(td) {
    border: 1px solid rgba(255, 255, 255, 0.1);
    padding: 0.5rem;
  }

  /* Horizontal rule */
  .markdown-content :global(hr) {
    border: none;
    border-top: 1px solid rgba(255, 255, 255, 0.1);
    margin: 1.5rem 0;
  }

  /* Code block wrapper */
  .markdown-content :global(.code-block-wrapper) {
    margin: 1rem 0;
    border-radius: 0.5rem;
    overflow: hidden;
    border: 1px solid rgba(255, 215, 0, 0.2);
    background: rgba(0, 0, 0, 0.4);
  }

  .markdown-content :global(.code-block-header) {
    display: flex;
    justify-content: space-between;
    align-items: center;
    padding: 0.5rem 0.75rem;
    background: rgba(255, 215, 0, 0.05);
    border-bottom: 1px solid rgba(255, 215, 0, 0.1);
  }

  .markdown-content :global(.code-language) {
    color: var(--exo-yellow, #ffd700);
    font-size: 0.7rem;
    font-weight: 500;
    text-transform: uppercase;
    letter-spacing: 0.1em;
    font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
  }

  .markdown-content :global(.copy-code-btn) {
    display: flex;
    align-items: center;
    justify-content: center;
    padding: 0.25rem;
    background: transparent;
    border: none;
    color: var(--exo-light-gray, #9ca3af);
    cursor: pointer;
    transition: color 0.2s;
    border-radius: 0.25rem;
  }

  .markdown-content :global(.copy-code-btn:hover) {
    color: var(--exo-yellow, #ffd700);
  }

  .markdown-content :global(.copy-code-btn.copied) {
    color: #22c55e;
  }

  .markdown-content :global(.code-block-wrapper pre) {
    margin: 0;
    padding: 1rem;
    overflow-x: auto;
    background: transparent;
  }

  .markdown-content :global(.code-block-wrapper code) {
    font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
    font-size: 0.8125rem;
    line-height: 1.5;
    background: transparent;
  }

  /* Syntax highlighting - dark theme matching EXO style */
  .markdown-content :global(.hljs) {
    color: #e5e7eb;
  }

  .markdown-content :global(.hljs-keyword),
  .markdown-content :global(.hljs-selector-tag),
  .markdown-content :global(.hljs-literal),
  .markdown-content :global(.hljs-section),
  .markdown-content :global(.hljs-link) {
    color: #c084fc;
  }

  .markdown-content :global(.hljs-string),
  .markdown-content :global(.hljs-title),
  .markdown-content :global(.hljs-name),
  .markdown-content :global(.hljs-type),
  .markdown-content :global(.hljs-attribute),
  .markdown-content :global(.hljs-symbol),
  .markdown-content :global(.hljs-bullet),
  .markdown-content :global(.hljs-addition),
  .markdown-content :global(.hljs-variable),
  .markdown-content :global(.hljs-template-tag),
  .markdown-content :global(.hljs-template-variable) {
    color: #fbbf24;
  }

  .markdown-content :global(.hljs-comment),
  .markdown-content :global(.hljs-quote),
  .markdown-content :global(.hljs-deletion),
  .markdown-content :global(.hljs-meta) {
    color: #6b7280;
  }

  .markdown-content :global(.hljs-number),
  .markdown-content :global(.hljs-regexp),
  .markdown-content :global(.hljs-literal),
  .markdown-content :global(.hljs-built_in) {
    color: #34d399;
  }

  .markdown-content :global(.hljs-function),
  .markdown-content :global(.hljs-class .hljs-title) {
    color: #60a5fa;
  }

  /* KaTeX math styling */
  .markdown-content :global(.katex) {
    font-size: 1.1em;
  }

  .markdown-content :global(.katex-display) {
    margin: 1rem 0;
    overflow-x: auto;
    overflow-y: hidden;
    padding: 0.5rem 0;
  }

  .markdown-content :global(.katex-display > .katex) {
    text-align: center;
  }

  .markdown-content :global(.math-error) {
    color: #f87171;
    font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
    font-size: 0.875em;
    background: rgba(248, 113, 113, 0.1);
    padding: 0.125rem 0.25rem;
    border-radius: 0.25rem;
  }
</style>
@@ -1,5 +1,6 @@
|
||||
<script lang="ts">
|
||||
import type { DownloadProgress, NodeInfo, PlacementPreview } from '$lib/stores/app.svelte';
|
||||
import type { DownloadProgress, NodeInfo, PlacementPreview, TopologyEdge } from '$lib/stores/app.svelte';
|
||||
import { debugMode, topologyData } from '$lib/stores/app.svelte';
|
||||
|
||||
interface Props {
|
||||
model: { id: string; name?: string; storage_size_megabytes?: number };
|
||||
@@ -206,12 +207,8 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
const centerY = topoHeight / 2;
|
||||
const radius = numNodes === 1 ? 0 : numNodes === 2 ? 45 : Math.min(topoWidth, topoHeight) * 0.32;
|
||||
|
||||
// Use API preview data if available
|
||||
// Only use API preview data - no local estimation
|
||||
const hasApiPreview = apiPreview !== null && apiPreview.error === null && apiPreview.memory_delta_by_node !== null;
|
||||
const canFit = hasApiPreview ? true : (() => {
|
||||
const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
|
||||
return totalAvailable >= estimatedMemory;
|
||||
})();
|
||||
const error = apiPreview?.error ?? null;
|
||||
|
||||
let placementNodes: Array<{
|
||||
@@ -232,135 +229,140 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
modelFillHeight: number;
|
||||
}> = [];
|
||||
|
||||
if (hasApiPreview && apiPreview.memory_delta_by_node) {
|
||||
// Use API placement data
|
||||
const memoryDelta = apiPreview.memory_delta_by_node;
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const deltaBytes = memoryDelta[n.id] ?? 0;
|
||||
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
|
||||
const isUsed = deltaBytes > 0;
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
} else if (apiPreview?.error) {
|
||||
// API returned an error - model can't fit, show all nodes as unused
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB: 0,
|
||||
currentPercent,
|
||||
newPercent: currentPercent,
|
||||
isUsed: false,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: 0
|
||||
};
|
||||
});
|
||||
} else {
|
||||
// Fallback: local estimation based on sharding strategy
|
||||
const memoryNeeded = estimatedMemory;
|
||||
// Use API placement data directly
|
||||
const memoryDelta = apiPreview?.memory_delta_by_node ?? {};
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const deltaBytes = memoryDelta[n.id] ?? 0;
|
||||
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
|
||||
const isUsed = deltaBytes > 0;
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
if (sharding === 'Pipeline') {
|
||||
const memoryPerNode = memoryNeeded / numNodes;
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + memoryPerNode) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB: memoryPerNode,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed: true,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
} else {
|
||||
let remaining = memoryNeeded;
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const allocated = Math.min(remaining, n.availableGB);
|
||||
remaining -= allocated;
|
||||
const isUsed = allocated > 0;
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + allocated) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB: allocated,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
|
||||
const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
|
||||
return { nodes: placementNodes, canFit: hasApiPreview || canFit, totalAvailable, topoWidth, topoHeight, error };
|
||||
return { nodes: placementNodes, canFit: hasApiPreview, totalAvailable, topoWidth, topoHeight, error };
|
||||
});
|
||||
|
||||
const canFit = $derived(apiPreview ? apiPreview.error === null : placementPreview().canFit);
|
||||
const placementError = $derived(apiPreview?.error ?? null);
|
||||
const nodeCount = $derived(nodeList().length);
const filterId = $derived(model.id.replace(/[^a-zA-Z0-9]/g, ''));

// Debug mode state
const isDebugMode = $derived(debugMode());
const topology = $derived(topologyData());
const isRdma = $derived(runtime === 'MlxIbv' || runtime === 'MlxJaccl');

// Get interface name for an IP from node data
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
  if (!ip || !topology?.nodes) return null;

  // Strip port if present
  const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;

  // Check specified node first
  const node = topology.nodes[nodeId];
  if (node) {
    const match = node.network_interfaces?.find((iface) =>
      (iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
    );
    if (match?.name) return match.name;

    const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
    if (mapped) return mapped;
  }

  // Fallback: check all nodes
  for (const [, otherNode] of Object.entries(topology.nodes)) {
    if (!otherNode) continue;
    const match = otherNode.network_interfaces?.find((iface) =>
      (iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
    );
    if (match?.name) return match.name;

    const mapped = otherNode.ip_to_interface?.[cleanIp] || otherNode.ip_to_interface?.[ip];
    if (mapped) return mapped;
  }

  return null;
}

// Get directional arrow based on node positions
function getArrow(fromNode: { x: number; y: number }, toNode: { x: number; y: number }): string {
  const dx = toNode.x - fromNode.x;
  const dy = toNode.y - fromNode.y;
  const absX = Math.abs(dx);
  const absY = Math.abs(dy);

  if (absX > absY * 2) {
    return dx > 0 ? '→' : '←';
  } else if (absY > absX * 2) {
    return dy > 0 ? '↓' : '↑';
  } else {
    if (dx > 0 && dy > 0) return '↘';
    if (dx > 0 && dy < 0) return '↗';
    if (dx < 0 && dy > 0) return '↙';
    return '↖';
  }
}

// Get connection info for edges between two nodes
// Returns exactly one connection per direction (A→B and B→A), preferring non-loopback
function getConnectionInfo(nodeId1: string, nodeId2: string): Array<{ ip: string; iface: string | null; from: string; to: string }> {
  if (!topology?.edges) return [];

  // Collect candidates for each direction
  const aToBCandidates: Array<{ ip: string; iface: string | null }> = [];
  const bToACandidates: Array<{ ip: string; iface: string | null }> = [];

  for (const edge of topology.edges) {
    const ip = edge.sendBackIp || '?';
    const iface = edge.sendBackInterface || getInterfaceForIp(edge.source, ip);

    if (edge.source === nodeId1 && edge.target === nodeId2) {
      aToBCandidates.push({ ip, iface });
    } else if (edge.source === nodeId2 && edge.target === nodeId1) {
      bToACandidates.push({ ip, iface });
    }
  }

  // Pick best (prefer non-loopback)
  const pickBest = (candidates: Array<{ ip: string; iface: string | null }>) => {
    if (candidates.length === 0) return null;
    return candidates.find(c => !c.ip.startsWith('127.')) || candidates[0];
  };

  const result: Array<{ ip: string; iface: string | null; from: string; to: string }> = [];

  const bestAtoB = pickBest(aToBCandidates);
  if (bestAtoB) result.push({ ...bestAtoB, from: nodeId1, to: nodeId2 });

  const bestBtoA = pickBest(bToACandidates);
  if (bestBtoA) result.push({ ...bestBtoA, from: nodeId2, to: nodeId1 });

  return result;
}
</script>
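To make the per-direction preference concrete, here is the same pick-best rule as a standalone, runnable sketch (the `Candidate` shape and sample IPs are illustrative, not the component's real data): loopback (`127.*`) entries are only used when nothing else is available.

```typescript
interface Candidate {
  ip: string;
  iface: string | null;
}

// Prefer the first non-loopback candidate; otherwise fall back to the first entry.
function pickBest(candidates: Candidate[]): Candidate | null {
  if (candidates.length === 0) return null;
  return candidates.find((c) => !c.ip.startsWith('127.')) ?? candidates[0];
}

// Example: a loopback edge and a LAN edge reported for the same node pair.
const candidates: Candidate[] = [
  { ip: '127.0.0.1', iface: 'lo0' },
  { ip: '192.168.1.7', iface: 'en0' },
];

console.log(pickBest(candidates)); // { ip: '192.168.1.7', iface: 'en0' }
console.log(pickBest([{ ip: '127.0.0.1', iface: 'lo0' }])); // loopback-only: falls back to it
```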
|
||||
<div class="relative group">
|
||||
@@ -453,6 +455,26 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
|
||||
<!-- Connection lines between nodes (if multiple) -->
|
||||
{#if preview.nodes.length > 1}
|
||||
{@const usedNodes = preview.nodes.filter(n => n.isUsed)}
|
||||
{@const nodePositions = Object.fromEntries(preview.nodes.map(n => [n.id, { x: n.x, y: n.y }]))}
|
||||
{@const allConnections = isDebugMode && usedNodes.length > 1 ? (() => {
|
||||
const conns: Array<{ ip: string; iface: string | null; from: string; to: string; midX: number; midY: number; arrow: string }> = [];
|
||||
for (let i = 0; i < usedNodes.length; i++) {
|
||||
for (let j = i + 1; j < usedNodes.length; j++) {
|
||||
const n1 = usedNodes[i];
|
||||
const n2 = usedNodes[j];
|
||||
const midX = (n1.x + n2.x) / 2;
|
||||
const midY = (n1.y + n2.y) / 2;
|
||||
for (const c of getConnectionInfo(n1.id, n2.id)) {
|
||||
const fromPos = nodePositions[c.from];
|
||||
const toPos = nodePositions[c.to];
|
||||
const arrow = fromPos && toPos ? getArrow(fromPos, toPos) : '→';
|
||||
conns.push({ ...c, midX, midY, arrow });
|
||||
}
|
||||
}
|
||||
}
|
||||
return conns;
|
||||
})() : []}
|
||||
{#each preview.nodes as node, i}
|
||||
{#each preview.nodes.slice(i + 1) as node2}
|
||||
<line
|
||||
@@ -464,6 +486,43 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
/>
|
||||
{/each}
|
||||
{/each}
|
||||
<!-- Debug: Show connection IPs/interfaces in corners -->
|
||||
{#if isDebugMode && allConnections.length > 0}
|
||||
{@const centerX = preview.topoWidth / 2}
|
||||
{@const centerY = preview.topoHeight / 2}
|
||||
{@const quadrants = {
|
||||
topLeft: allConnections.filter(c => c.midX < centerX && c.midY < centerY),
|
||||
topRight: allConnections.filter(c => c.midX >= centerX && c.midY < centerY),
|
||||
bottomLeft: allConnections.filter(c => c.midX < centerX && c.midY >= centerY),
|
||||
bottomRight: allConnections.filter(c => c.midX >= centerX && c.midY >= centerY)
|
||||
}}
|
||||
{@const padding = 4}
|
||||
{@const lineHeight = 8}
|
||||
<!-- Top Left -->
|
||||
{#each quadrants.topLeft as conn, idx}
|
||||
<text x={padding} y={padding + idx * lineHeight} text-anchor="start" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
|
||||
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
<!-- Top Right -->
|
||||
{#each quadrants.topRight as conn, idx}
|
||||
<text x={preview.topoWidth - padding} y={padding + idx * lineHeight} text-anchor="end" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
|
||||
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
<!-- Bottom Left -->
|
||||
{#each quadrants.bottomLeft as conn, idx}
|
||||
<text x={padding} y={preview.topoHeight - padding - (quadrants.bottomLeft.length - 1 - idx) * lineHeight} text-anchor="start" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
|
||||
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
<!-- Bottom Right -->
|
||||
{#each quadrants.bottomRight as conn, idx}
|
||||
<text x={preview.topoWidth - padding} y={preview.topoHeight - padding - (quadrants.bottomRight.length - 1 - idx) * lineHeight} text-anchor="end" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
|
||||
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
{/if}
|
||||
{/if}
|
||||
|
||||
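The `allConnections` IIFE above visits every unordered pair of used nodes exactly once by starting the inner index at `i + 1`, then asks `getConnectionInfo` for the per-direction entries of each pair. A minimal sketch of just that enumeration (sample data is illustrative):

```typescript
// Unordered pair enumeration: each {i, j} pair is visited once, n*(n-1)/2 pairs total.
const usedNodes = ['node-a', 'node-b', 'node-c'];
const pairs: Array<[string, string]> = [];
for (let i = 0; i < usedNodes.length; i++) {
  for (let j = i + 1; j < usedNodes.length; j++) {
    pairs.push([usedNodes[i], usedNodes[j]]);
  }
}
console.log(pairs);
// [["node-a","node-b"], ["node-a","node-c"], ["node-b","node-c"]]
```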
{#each preview.nodes as node}

@@ -24,19 +24,36 @@ function getNodeLabel(nodeId: string): string {

function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
  if (!ip) return { label: '?', missing: true };

  // Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
  const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;

  // Helper to check a node's interfaces
  function checkNode(node: typeof data.nodes[string]): string | null {
    if (!node) return null;

    const matchFromInterfaces = node.network_interfaces?.find((iface) =>
      (iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
    );
    if (matchFromInterfaces?.name) {
      return matchFromInterfaces.name;
    }

    const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
    if (mapped && mapped.trim().length > 0) {
      return mapped;
    }
    return null;
  }

  // Try specified node first
  const result = checkNode(data?.nodes?.[nodeId]);
  if (result) return { label: result, missing: false };

  // Fallback: search all nodes for this IP
  for (const [, otherNode] of Object.entries(data?.nodes || {})) {
    const otherResult = checkNode(otherNode);
    if (otherResult) return { label: otherResult, missing: false };
  }

  return { label: '?', missing: true };
@@ -67,6 +84,7 @@ function wrapLine(text: string, maxLen: number): string[] {
  return lines;
}

// Apple logo path for MacBook Pro screen
const APPLE_LOGO_PATH = "M788.1 340.9c-5.8 4.5-108.2 62.2-108.2 190.5 0 148.4 130.3 200.9 134.2 202.2-.6 3.2-20.7 71.9-68.7 141.9-42.8 61.6-87.5 123.1-155.5 123.1s-85.5-39.5-164-39.5c-76.5 0-103.7 40.8-165.9 40.8s-105.6-57-155.5-127C46.7 790.7 0 663 0 541.8c0-194.4 126.4-297.5 250.8-297.5 66.1 0 121.2 43.4 162.7 43.4 39.5 0 101.1-46 176.3-46 28.5 0 130.9 2.6 198.3 99.2zm-234-181.5c31.1-36.9 53.1-88.1 53.1-139.3 0-7.1-.6-14.3-1.9-20.1-50.6 1.9-110.8 33.7-147.1 75.8-28.5 32.4-55.1 83.6-55.1 135.5 0 7.8 1.3 15.6 1.9 18.1 3.2.6 8.4 1.3 13.6 1.3 45.4 0 102.5-30.4 135.5-71.3z";
const LOGO_NATIVE_WIDTH = 814;
@@ -238,6 +256,7 @@ function wrapLine(text: string, maxLen: number): string[] {
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');

const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
edges.forEach(edge => {
  if (!edge.source || !edge.target || edge.source === edge.target) return;
  if (!positionById[edge.source] || !positionById[edge.target]) return;
@@ -314,110 +333,98 @@ function wrapLine(text: string, maxLen: number): string[] {
        .attr('marker-end', 'url(#arrowhead)');
    }

    // Collect debug labels for later positioning at edges
    if (debugEnabled && entry.connections.length > 0) {
      // Determine which side of viewport based on edge midpoint
      const isLeft = mx < centerX;
      const isTop = my < safeCenterY;

      // Store for batch rendering after all edges processed
      if (!debugEdgeLabels) debugEdgeLabels = [];
      debugEdgeLabels.push({
        connections: entry.connections,
        isLeft,
        isTop,
        mx,
        my
      });
    }
  });

// Render debug labels at viewport edges/corners
if (debugEdgeLabels && debugEdgeLabels.length > 0) {
  const fontSize = isMinimized ? 10 : 12;
  const lineHeight = fontSize + 4;
  const padding = 10;

  // Helper to get arrow based on direction vector
  function getArrow(fromId: string, toId: string): string {
    const fromPos = positionById[fromId];
    const toPos = positionById[toId];
    if (!fromPos || !toPos) return '→';

    const dirX = toPos.x - fromPos.x;
    const dirY = toPos.y - fromPos.y;
    const absX = Math.abs(dirX);
    const absY = Math.abs(dirY);

    if (absX > absY * 2) {
      return dirX > 0 ? '→' : '←';
    } else if (absY > absX * 2) {
      return dirY > 0 ? '↓' : '↑';
    } else {
      if (dirX > 0 && dirY > 0) return '↘';
      if (dirX > 0 && dirY < 0) return '↗';
      if (dirX < 0 && dirY > 0) return '↙';
      return '↖';
    }
  }

  // Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
  const quadrants: Record<string, typeof debugEdgeLabels> = {
    topLeft: [],
    topRight: [],
    bottomLeft: [],
    bottomRight: []
  };

  debugEdgeLabels.forEach(edge => {
    const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
    quadrants[key].push(edge);
  });

  // Render each quadrant
  Object.entries(quadrants).forEach(([quadrant, edges]) => {
    if (edges.length === 0) return;

    const isLeft = quadrant.includes('Left');
    const isTop = quadrant.includes('top');

    let baseX = isLeft ? padding : width - padding;
    let baseY = isTop ? padding : height - padding;
    const textAnchor = isLeft ? 'start' : 'end';

    let currentY = baseY;

    edges.forEach(edge => {
      edge.connections.forEach(conn => {
        const arrow = getArrow(conn.from, conn.to);
        const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;
        debugLabelsGroup.append('text')
          .attr('x', baseX)
          .attr('y', currentY)
          .attr('text-anchor', textAnchor)
          .attr('dominant-baseline', isTop ? 'hanging' : 'auto')
          .attr('font-size', fontSize)
          .attr('font-family', 'SF Mono, monospace')
          .attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.85)')
          .text(label);
        currentY += isTop ? lineHeight : -lineHeight;
      });
    });
  });
}

// Draw nodes
const nodesGroup = svg.append('g').attr('class', 'nodes-group');

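The corner layout above stacks labels away from their anchor: top quadrants grow downward from the padding edge, bottom quadrants grow upward (`currentY += isTop ? lineHeight : -lineHeight`). A minimal standalone sketch of the same idea, with illustrative types and constants rather than the component's:

```typescript
// Corner-stacking sketch: bucket labels by quadrant, then stack away from the corner.
interface CornerLabel { text: string; isLeft: boolean; isTop: boolean; }

function layoutCorners(labels: CornerLabel[], width: number, height: number) {
  const padding = 10;
  const lineHeight = 16;
  const placed: Array<{ text: string; x: number; y: number }> = [];
  const buckets: Record<string, CornerLabel[]> = { topLeft: [], topRight: [], bottomLeft: [], bottomRight: [] };
  for (const l of labels) buckets[(l.isTop ? 'top' : 'bottom') + (l.isLeft ? 'Left' : 'Right')].push(l);

  for (const [key, bucket] of Object.entries(buckets)) {
    const isLeft = key.includes('Left');
    const isTop = key.includes('top');
    const x = isLeft ? padding : width - padding;
    let y = isTop ? padding : height - padding;
    for (const l of bucket) {
      placed.push({ text: l.text, x, y });
      y += isTop ? lineHeight : -lineHeight; // next label moves away from the corner
    }
  }
  return placed;
}

console.log(layoutCorners(
  [{ text: 'A→B', isLeft: true, isTop: true }, { text: 'B→A', isLeft: true, isTop: true }],
  800, 600
)); // two labels at x=10: y=10 and y=26
```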
@@ -968,4 +975,5 @@ function wrapLine(text: string, maxLen: number): string[] {
  from { stroke-dashoffset: 0; }
  to { stroke-dashoffset: -10; }
}
</style>

@@ -1,7 +1,7 @@
export { default as TopologyGraph } from "./TopologyGraph.svelte";
export { default as ChatForm } from "./ChatForm.svelte";
export { default as ChatMessages } from "./ChatMessages.svelte";
export { default as ChatAttachments } from "./ChatAttachments.svelte";
export { default as ChatSidebar } from "./ChatSidebar.svelte";
export { default as ModelCard } from "./ModelCard.svelte";
export { default as MarkdownContent } from "./MarkdownContent.svelte";

File diff suppressed because it is too large
@@ -13,55 +13,124 @@ export interface ChatUploadedFile {
}

export interface ChatAttachment {
  type: "image" | "text" | "pdf" | "audio";
  name: string;
  content?: string;
  base64Url?: string;
  mimeType?: string;
}

export type FileCategory = "image" | "text" | "pdf" | "audio" | "unknown";

export const IMAGE_EXTENSIONS = [
  ".jpg",
  ".jpeg",
  ".png",
  ".gif",
  ".webp",
  ".svg",
];
export const IMAGE_MIME_TYPES = [
  "image/jpeg",
  "image/png",
  "image/gif",
  "image/webp",
  "image/svg+xml",
];

export const TEXT_EXTENSIONS = [
  ".txt",
  ".md",
  ".json",
  ".xml",
  ".yaml",
  ".yml",
  ".csv",
  ".log",
  ".js",
  ".ts",
  ".jsx",
  ".tsx",
  ".py",
  ".java",
  ".cpp",
  ".c",
  ".h",
  ".css",
  ".html",
  ".htm",
  ".sql",
  ".sh",
  ".bat",
  ".rs",
  ".go",
  ".rb",
  ".php",
  ".swift",
  ".kt",
  ".scala",
  ".r",
  ".dart",
  ".vue",
  ".svelte",
];
export const TEXT_MIME_TYPES = [
  "text/plain",
  "text/markdown",
  "text/csv",
  "text/html",
  "text/css",
  "application/json",
  "application/xml",
  "text/xml",
  "application/javascript",
  "text/javascript",
  "application/typescript",
];

export const PDF_EXTENSIONS = [".pdf"];
export const PDF_MIME_TYPES = ["application/pdf"];

export const AUDIO_EXTENSIONS = [".mp3", ".wav", ".ogg", ".m4a"];
export const AUDIO_MIME_TYPES = [
  "audio/mpeg",
  "audio/wav",
  "audio/ogg",
  "audio/mp4",
];

/**
 * Get file category based on MIME type and extension
 */
export function getFileCategory(
  mimeType: string,
  fileName: string,
): FileCategory {
  const extension = fileName.toLowerCase().slice(fileName.lastIndexOf("."));

  if (
    IMAGE_MIME_TYPES.includes(mimeType) ||
    IMAGE_EXTENSIONS.includes(extension)
  ) {
    return "image";
  }
  if (PDF_MIME_TYPES.includes(mimeType) || PDF_EXTENSIONS.includes(extension)) {
    return "pdf";
  }
  if (
    AUDIO_MIME_TYPES.includes(mimeType) ||
    AUDIO_EXTENSIONS.includes(extension)
  ) {
    return "audio";
  }
  if (
    TEXT_MIME_TYPES.includes(mimeType) ||
    TEXT_EXTENSIONS.includes(extension) ||
    mimeType.startsWith("text/")
  ) {
    return "text";
  }
  return "unknown";
}

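A quick sanity check of the categorization order above (MIME match first, extension as fallback, `text/` prefix as a catch-all before `"unknown"`); these calls assume the exports shown in this diff:

```typescript
// Extension fallback: unknown MIME type, known extension.
getFileCategory("application/octet-stream", "notes.md"); // "text"
// A MIME match wins even with an unrecognized extension.
getFileCategory("image/png", "screenshot.data");         // "image"
// The text/ prefix catches types not listed in TEXT_MIME_TYPES.
getFileCategory("text/x-python", "script.unknownext");   // "text"
getFileCategory("application/zip", "archive.zip");       // "unknown"
```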
/**
@@ -69,36 +138,36 @@ export function getFileCategory(mimeType: string, fileName: string): FileCategor
 */
export function getAcceptString(categories: FileCategory[]): string {
  const accepts: string[] = [];

  for (const category of categories) {
    switch (category) {
      case "image":
        accepts.push(...IMAGE_EXTENSIONS, ...IMAGE_MIME_TYPES);
        break;
      case "text":
        accepts.push(...TEXT_EXTENSIONS, ...TEXT_MIME_TYPES);
        break;
      case "pdf":
        accepts.push(...PDF_EXTENSIONS, ...PDF_MIME_TYPES);
        break;
      case "audio":
        accepts.push(...AUDIO_EXTENSIONS, ...AUDIO_MIME_TYPES);
        break;
    }
  }

  return accepts.join(",");
}

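The resulting string plugs straight into an `<input accept="...">` attribute; per the loop above, each category contributes its extensions first, then its MIME types, in the order the categories were passed:

```typescript
getAcceptString(["pdf", "audio"]);
// ".pdf,application/pdf,.mp3,.wav,.ogg,.m4a,audio/mpeg,audio/wav,audio/ogg,audio/mp4"
```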
/**
 * Format file size for display
 */
export function formatFileSize(bytes: number): string {
  if (bytes === 0) return "0 B";
  const k = 1024;
  const sizes = ["B", "KB", "MB", "GB"];
  const i = Math.floor(Math.log(bytes) / Math.log(k));
  return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + " " + sizes[i];
}

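For instance, 1536 bytes gives `i = floor(log(1536)/log(1024)) = 1`, so the value shown is 1536 / 1024 = 1.5 with the `"KB"` unit. A few checks (plain calls to the export above):

```typescript
formatFileSize(0);             // "0 B"
formatFileSize(1536);          // "1.5 KB"  (i = 1)
formatFileSize(1048576);       // "1 MB"    (i = 2; parseFloat drops the trailing ".0")
formatFileSize(3_500_000_000); // "3.3 GB"  (3.5e9 / 1024^3 ≈ 3.26, rounded to one decimal)
```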
/**
@@ -128,42 +197,44 @@ export function readFileAsText(file: File): Promise<string> {
/**
 * Process uploaded files into ChatUploadedFile format
 */
export async function processUploadedFiles(
  files: File[],
): Promise<ChatUploadedFile[]> {
  const results: ChatUploadedFile[] = [];

  for (const file of files) {
    const id =
      Date.now().toString() + Math.random().toString(36).substring(2, 9);
    const category = getFileCategory(file.type, file.name);

    const base: ChatUploadedFile = {
      id,
      name: file.name,
      size: file.size,
      type: file.type,
      file,
    };

    try {
      if (category === "image") {
        const preview = await readFileAsDataURL(file);
        results.push({ ...base, preview });
      } else if (category === "text" || category === "unknown") {
        const textContent = await readFileAsText(file);
        results.push({ ...base, textContent });
      } else if (category === "pdf") {
        results.push(base);
      } else if (category === "audio") {
        const preview = await readFileAsDataURL(file);
        results.push({ ...base, preview });
      } else {
        results.push(base);
      }
    } catch (error) {
      console.error("Error processing file:", file.name, error);
      results.push(base);
    }
  }

  return results;
}

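Usage is a single awaited call; in the browser the `File` objects typically come from an `<input type="file">` change event. A minimal sketch assuming the exports above (the handler name is illustrative):

```typescript
// Hypothetical change handler for an <input type="file"> element.
async function onFilesPicked(event: Event): Promise<void> {
  const input = event.target as HTMLInputElement;
  const files = Array.from(input.files ?? []);
  const uploaded = await processUploadedFiles(files);
  // Images and audio carry a data-URL preview; text/unknown files carry textContent.
  for (const f of uploaded) {
    console.log(f.name, f.size, f.preview ? "has preview" : f.textContent ? "has text" : "raw file");
  }
}
```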
@@ -18,6 +18,10 @@
  selectedChatModel,
  debugMode,
  toggleDebugMode,
  topologyOnlyMode,
  toggleTopologyOnlyMode,
  chatSidebarVisible,
  toggleChatSidebarVisible,
  type DownloadProgress,
  type PlacementPreview
} from '$lib/stores/app.svelte';
@@ -37,6 +41,8 @@
const selectedModelId = $derived(selectedPreviewModelId());
const loadingPreviews = $derived(isLoadingPreviews());
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
const sidebarVisible = $derived(chatSidebarVisible());

let mounted = $state(false);

@@ -45,6 +51,59 @@ const debugEnabled = $derived(debugMode());
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';

// Launch defaults persistence
const LAUNCH_DEFAULTS_KEY = 'exo-launch-defaults';
interface LaunchDefaults {
  modelId: string | null;
  sharding: 'Pipeline' | 'Tensor';
  instanceType: InstanceMeta;
  minNodes: number;
}

function saveLaunchDefaults(): void {
  const defaults: LaunchDefaults = {
    modelId: selectedPreviewModelId(),
    sharding: selectedSharding,
    instanceType: selectedInstanceType,
    minNodes: selectedMinNodes,
  };
  try {
    localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
  } catch (e) {
    console.warn('Failed to save launch defaults:', e);
  }
}

function loadLaunchDefaults(): LaunchDefaults | null {
  try {
    const stored = localStorage.getItem(LAUNCH_DEFAULTS_KEY);
    if (!stored) return null;
    return JSON.parse(stored) as LaunchDefaults;
  } catch (e) {
    console.warn('Failed to load launch defaults:', e);
    return null;
  }
}

function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
  const defaults = loadLaunchDefaults();
  if (!defaults) return;

  // Apply sharding and instance type unconditionally
  selectedSharding = defaults.sharding;
  selectedInstanceType = defaults.instanceType;

  // Apply minNodes if valid (between 1 and maxNodes)
  if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
    selectedMinNodes = defaults.minNodes;
  }

  // Only apply model if it exists in the available models
  if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
    selectPreviewModel(defaults.modelId);
  }
}

let selectedInstanceType = $state<InstanceMeta>('MlxRing');
let selectedMinNodes = $state<number>(1);
let minNodesInitialized = $state(false);
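The persistence here is a plain JSON round-trip through `localStorage`, guarded with try/catch because storage can be unavailable (private browsing, quota) or hold corrupt JSON. A standalone sketch of the same pattern, with a hypothetical key and settings shape:

```typescript
// Generic localStorage round-trip, mirroring save/loadLaunchDefaults above.
interface Settings { sharding: 'Pipeline' | 'Tensor'; minNodes: number; }
const KEY = 'example-settings'; // illustrative key, not the component's

function save(settings: Settings): void {
  try {
    localStorage.setItem(KEY, JSON.stringify(settings));
  } catch (e) {
    console.warn('Failed to save settings:', e); // storage full or blocked
  }
}

function load(): Settings | null {
  try {
    const stored = localStorage.getItem(KEY);
    return stored ? (JSON.parse(stored) as Settings) : null;
  } catch (e) {
    console.warn('Failed to load settings:', e); // corrupt JSON lands here
    return null;
  }
}

save({ sharding: 'Tensor', minNodes: 2 });
console.log(load()); // { sharding: 'Tensor', minNodes: 2 }
```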
@@ -292,6 +351,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
      const data = await response.json();
      // API returns { data: [{ id, name }] } format
      models = data.data || [];
      // Restore last launch defaults if available
      const currentNodeCount = topologyData() ? Object.keys(topologyData()!.nodes).length : 1;
      applyLaunchDefaults(models, currentNodeCount);
    }
  } catch (error) {
    console.error('Failed to fetch models:', error);
@@ -338,10 +400,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
    const errorText = await response.text();
    console.error('Failed to launch instance:', errorText);
  } else {
    // Always auto-select the newly launched model so the user chats to what they just launched
    setSelectedChatModel(modelId);

    // Scroll to the bottom of instances container to show the new instance
    // Use multiple attempts to ensure DOM has updated with the new instance
@@ -472,6 +532,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

  const progress = parseDownloadProgress(downloadPayload);
  if (progress) {
    // Sum all values across nodes - each node downloads independently
    totalBytes += progress.totalBytes;
    downloadedBytes += progress.downloadedBytes;
    totalSpeed += progress.speed;
@@ -489,13 +550,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
    return { isDownloading: false, progress: null, perNode: [] };
  }

  // ETA = total remaining bytes / total speed across all nodes
  const remainingBytes = totalBytes - downloadedBytes;
  const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;

  return {
    isDownloading: true,
    progress: {
      totalBytes,
      downloadedBytes,
      speed: totalSpeed,
      etaMs,
      percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
      completedFiles,
      totalFiles,
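The aggregate ETA treats the cluster as one pipe: total remaining bytes divided by the summed per-node speed. A sketch of the same arithmetic with illustrative names and numbers:

```typescript
// Aggregate ETA across nodes, mirroring the summation above.
function aggregateEtaMs(nodes: Array<{ totalBytes: number; downloadedBytes: number; speed: number }>): number {
  const totalBytes = nodes.reduce((s, n) => s + n.totalBytes, 0);
  const downloadedBytes = nodes.reduce((s, n) => s + n.downloadedBytes, 0);
  const totalSpeed = nodes.reduce((s, n) => s + n.speed, 0); // bytes per second
  const remainingBytes = totalBytes - downloadedBytes;
  return totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;
}

console.log(aggregateEtaMs([
  { totalBytes: 4e9, downloadedBytes: 1e9, speed: 80e6 },
  { totalBytes: 4e9, downloadedBytes: 0.56e9, speed: 40e6 },
])); // 6.44e9 bytes remaining / 120e6 B/s ≈ 53,667 ms
```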
@@ -526,7 +591,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
  // Unwrap the instance
  const [instanceTag, instance] = getTagged(instanceWrapped);
  if (!instance || typeof instance !== 'object') {
    return { isDownloading: false, progress: null, statusText: 'PREPARING', perNode: [] };
  }

  const inst = instance as { shardAssignments?: { nodeToRunner?: Record<string, string>; runnerToShard?: Record<string, unknown>; modelId?: string } };
@@ -576,6 +641,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

  const progress = parseDownloadProgress(downloadPayload);
  if (progress) {
    // Sum all values across nodes - each node downloads independently
    totalBytes += progress.totalBytes;
    downloadedBytes += progress.downloadedBytes;
    totalSpeed += progress.speed;
@@ -596,13 +662,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
    return { isDownloading: false, progress: null, statusText: statusInfo.statusText, perNode: [] };
  }

  // ETA = total remaining bytes / total speed across all nodes
  const remainingBytes = totalBytes - downloadedBytes;
  const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;

  return {
    isDownloading: true,
    progress: {
      totalBytes,
      downloadedBytes,
      speed: totalSpeed,
      etaMs,
      percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
      completedFiles,
      totalFiles,
@@ -618,10 +688,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function getStatusColor(statusText: string): string {
  switch (statusText) {
    case 'FAILED': return 'text-red-400';
    case 'SHUTDOWN': return 'text-gray-400';
    case 'DOWNLOADING': return 'text-blue-400';
    case 'LOADING':
    case 'WARMING UP':
    case 'WAITING':
    case 'INITIALIZING': return 'text-yellow-400';
    case 'RUNNING': return 'text-teal-400';
    case 'READY':
    case 'LOADED': return 'text-green-400';
@@ -632,7 +704,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function deriveInstanceStatus(instanceWrapped: unknown): { statusText: string; statusClass: string } {
  const [, instance] = getTagged(instanceWrapped);
  if (!instance || typeof instance !== 'object') {
    return { statusText: 'PREPARING', statusClass: 'inactive' };
  }

  const inst = instance as { shardAssignments?: { runnerToShard?: Record<string, unknown> } };
@@ -644,12 +716,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
    if (!r) return null;
    const [kind] = getTagged(r);
    const statusMap: Record<string, string> = {
      RunnerWaitingForInitialization: 'WaitingForInitialization',
      RunnerInitializingBackend: 'InitializingBackend',
      RunnerWaitingForModel: 'WaitingForModel',
      RunnerLoading: 'Loading',
      RunnerLoaded: 'Loaded',
      RunnerWarmingUp: 'WarmingUp',
      RunnerReady: 'Ready',
      RunnerRunning: 'Running',
      RunnerShutdown: 'Shutdown',
      RunnerFailed: 'Failed',
    };
    return kind ? statusMap[kind] || null : null;
@@ -658,14 +733,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

  const has = (s: string) => statuses.includes(s);

  if (statuses.length === 0) return { statusText: 'PREPARING', statusClass: 'inactive' };
  if (has('Failed')) return { statusText: 'FAILED', statusClass: 'failed' };
  if (has('Shutdown')) return { statusText: 'SHUTDOWN', statusClass: 'inactive' };
  if (has('Loading')) return { statusText: 'LOADING', statusClass: 'starting' };
  if (has('WarmingUp')) return { statusText: 'WARMING UP', statusClass: 'starting' };
  if (has('Running')) return { statusText: 'RUNNING', statusClass: 'running' };
  if (has('Ready')) return { statusText: 'READY', statusClass: 'loaded' };
  if (has('Loaded')) return { statusText: 'LOADED', statusClass: 'loaded' };
  if (has('WaitingForModel')) return { statusText: 'WAITING', statusClass: 'starting' };
  if (has('InitializingBackend')) return { statusText: 'INITIALIZING', statusClass: 'starting' };
  if (has('WaitingForInitialization')) return { statusText: 'INITIALIZING', statusClass: 'starting' };

  return { statusText: 'RUNNING', statusClass: 'active' };
}
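Because the checks run in order, one failed runner outranks everything else, and the instance only reads READY once no runner is in an earlier state. A quick illustration of that worst-state-wins precedence (a trimmed standalone re-implementation, not the component code):

```typescript
// Precedence over per-runner statuses, in the spirit of deriveInstanceStatus above.
function summarize(statuses: string[]): string {
  const has = (s: string) => statuses.includes(s);
  if (statuses.length === 0) return 'PREPARING';
  if (has('Failed')) return 'FAILED';
  if (has('Loading')) return 'LOADING';
  if (has('Running')) return 'RUNNING';
  if (has('Ready')) return 'READY';
  return 'RUNNING';
}

console.log(summarize(['Ready', 'Loading'])); // "LOADING": the slowest runner dominates
console.log(summarize(['Ready', 'Failed']));  // "FAILED"
console.log(summarize([]));                   // "PREPARING": no runner has reported yet
```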
@@ -683,6 +761,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
  if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;

  // Get the model ID of the instance being deleted before we delete it
  const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
  const wasSelected = selectedChatModel() === deletedInstanceModelId;

  try {
    const response = await fetch(`/instance/${instanceId}`, {
      method: 'DELETE',
@@ -691,6 +773,24 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

    if (!response.ok) {
      console.error('Failed to delete instance:', response.status);
    } else if (wasSelected) {
      // If we deleted the currently selected model, switch to another available model
      // Find another instance that isn't the one we just deleted
      const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
      if (remainingInstances.length > 0) {
        // Select the last instance (most recently added, since objects preserve insertion order)
        const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
        const newModelId = getInstanceModelId(lastInstance);
        if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
          setSelectedChatModel(newModelId);
        } else {
          // Clear selection if no valid model found
          setSelectedChatModel('');
        }
      } else {
        // No more instances, clear the selection
        setSelectedChatModel('');
      }
    }
  } catch (error) {
    console.error('Error deleting instance:', error);
@@ -964,6 +1064,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

function handleSliderMouseUp() {
  isDraggingSlider = false;
  saveLaunchDefaults();
}

// Handle touch events for mobile
@@ -983,6 +1084,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

function handleSliderTouchEnd() {
  isDraggingSlider = false;
  saveLaunchDefaults();
}

const nodeCount = $derived(data ? Object.keys(data.nodes).length : 0);
@@ -1107,16 +1209,47 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
  <div class="shooting-star" style="top: 50%; left: 40%; --duration: 45s; --delay: 30s;"></div>
</div>

{#if !topologyOnlyEnabled}
  <HeaderNav
    showHome={chatStarted}
    onHome={handleGoHome}
    showSidebarToggle={true}
    sidebarVisible={sidebarVisible}
    onToggleSidebar={toggleChatSidebarVisible}
  />
{/if}

<!-- Main Content -->
<main class="flex-1 flex overflow-hidden relative">
  <!-- Left: Conversation History Sidebar (hidden in topology-only mode or when toggled off) -->
  {#if !topologyOnlyEnabled && sidebarVisible}
    <div class="w-80 flex-shrink-0 border-r border-exo-yellow/10">
      <ChatSidebar class="h-full" />
    </div>
  {/if}

  {#if topologyOnlyEnabled}
    <!-- TOPOLOGY ONLY MODE: Full-screen topology -->
    <div class="flex-1 flex flex-col min-h-0 min-w-0 p-4" in:fade={{ duration: 300 }}>
      <div class="flex-1 relative bg-exo-dark-gray/40 rounded-lg overflow-hidden">
        <TopologyGraph class="w-full h-full" highlightedNodes={highlightedNodes()} />
        <!-- Exit topology-only mode button -->
        <button
          type="button"
          onclick={toggleTopologyOnlyMode}
          class="absolute bottom-4 right-4 p-2 rounded border border-exo-yellow/30 bg-exo-dark-gray/80 hover:border-exo-yellow/50 hover:bg-exo-dark-gray transition-colors cursor-pointer backdrop-blur-sm"
          title="Exit topology only mode"
        >
          <svg class="w-5 h-5 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
            <circle cx="12" cy="5" r="2" fill="currentColor" />
            <circle cx="5" cy="19" r="2" fill="currentColor" />
            <circle cx="19" cy="19" r="2" fill="currentColor" />
            <path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
          </svg>
        </button>
      </div>
    </div>
  {:else if !chatStarted}
    <!-- WELCOME STATE: Topology + Instance Controls (no left sidebar for cleaner look) -->
    <div class="flex-1 flex overflow-visible relative" in:fade={{ duration: 300 }} out:fade={{ duration: 200 }}>

@@ -1154,9 +1287,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
      <div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
    </div>

    <div
      bind:this={instancesContainerRef}
      class="max-h-72 xl:max-h-96 space-y-3 overflow-y-auto overflow-x-hidden py-px"
    >
      {#each Object.entries(instanceData) as [id, instance]}
        {@const downloadInfo = getInstanceDownloadStatus(id, instance)}
@@ -1300,14 +1433,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{:else}
  {#each nodeProg.progress.files as f}
    {@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
    {@const isFileComplete = filePercent >= 100}
    <div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
      <div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
        <span class="truncate pr-2">{f.name}</span>
        <span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
      </div>
      <div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
        <div
          class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
          style="width: {filePercent.toFixed(1)}%"
        ></div>
      </div>
@@ -1408,6 +1542,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
onclick={() => {
  if (modelCanFit) {
    selectPreviewModel(model.id);
    saveLaunchDefaults();
    isModelDropdownOpen = false;
    modelDropdownSearch = '';
  }
@@ -1441,7 +1576,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-xs text-white/70 font-mono mb-2">Sharding:</div>
<div class="flex gap-2">
  <button
    onclick={() => { selectedSharding = 'Pipeline'; saveLaunchDefaults(); }}
    class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Pipeline' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
  >
    <span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Pipeline' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1452,7 +1587,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
    Pipeline
  </button>
  <button
    onclick={() => { selectedSharding = 'Tensor'; saveLaunchDefaults(); }}
    class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Tensor' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
  >
    <span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Tensor' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1470,7 +1605,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-xs text-white/70 font-mono mb-2">Instance Type:</div>
<div class="flex gap-2">
  <button
    onclick={() => { selectedInstanceType = 'MlxRing'; saveLaunchDefaults(); }}
    class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxRing' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
  >
    <span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxRing' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1481,7 +1616,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
    MLX Ring
  </button>
  <button
    onclick={() => { selectedInstanceType = 'MlxIbv'; saveLaunchDefaults(); }}
    class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxIbv' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
  >
    <span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxIbv' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1611,13 +1746,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
  in:fade={{ duration: 300, delay: 100 }}
>
  <div class="flex-1 overflow-y-auto px-8 py-6" bind:this={chatScrollRef}>
    <div class="max-w-7xl mx-auto">
      <ChatMessages scrollParent={chatScrollRef} />
    </div>
  </div>

  <div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
    <div class="max-w-7xl mx-auto">
      <ChatForm placeholder="Ask anything" showModelSelector={true} />
    </div>
  </div>
@@ -1655,10 +1790,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<!-- Panel Header -->
<div class="flex items-center gap-2 mb-4">
  <div class="w-2 h-2 bg-exo-yellow rounded-full shadow-[0_0_8px_rgba(255,215,0,0.6)] animate-pulse"></div>
  <h3 class="text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
  <div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
</div>
<div class="space-y-3 max-h-72 xl:max-h-96 overflow-y-auto overflow-x-hidden py-px pr-1">
  {#each Object.entries(instanceData) as [id, instance]}
    {@const downloadInfo = getInstanceDownloadStatus(id, instance)}
    {@const statusText = downloadInfo.statusText}
@@ -1701,28 +1836,28 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
    <div class="flex justify-between items-start mb-2 pl-2">
      <div class="flex items-center gap-2">
        <div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
        <span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
      </div>
      <button
        onclick={() => deleteInstance(id)}
        class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
      >
        DELETE
      </button>
    </div>
    <div class="pl-2">
      <div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
      <div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
      {#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
        <a
          class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
          href={`https://huggingface.co/${instanceModelId}`}
          target="_blank"
          rel="noreferrer noopener"
          aria-label="View model on Hugging Face"
        >
          <span>Hugging Face</span>
          <svg class="w-3.5 h-3.5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
            <path d="M14 3h7v7"/>
            <path d="M10 14l11-11"/>
            <path d="M21 14v6a1 1 0 0 1-1 1h-16a1 1 0 0 1-1-1v-16a1 1 0 0 1 1-1h6"/>
@@ -1733,68 +1868,84 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
      <div class="text-white/60 text-xs font-mono">{instanceInfo.nodeNames.join(', ')}</div>
      {/if}
      {#if debugEnabled && instanceConnections.length > 0}
        <div class="mt-2 space-y-1">
          {#each instanceConnections as conn}
            <div class="text-[11px] leading-snug font-mono text-white/70">
              <span>{conn.from} -> {conn.to}: {conn.ip}</span>
              <span class="{conn.missingIface ? 'text-red-400' : 'text-white/60'}"> ({conn.ifaceLabel})</span>
            </div>
          {/each}
        </div>
      {/if}

      <!-- Download Progress -->
      {#if downloadInfo.isDownloading && downloadInfo.progress}
        <div class="mt-2 space-y-1">
          <div class="flex justify-between text-sm font-mono">
            <span class="text-blue-400">{downloadInfo.progress.percentage.toFixed(1)}%</span>
            <span class="text-exo-light-gray">{formatBytes(downloadInfo.progress.downloadedBytes)}/{formatBytes(downloadInfo.progress.totalBytes)}</span>
          </div>
          <div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
            <div
              class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
              style="width: {downloadInfo.progress.percentage}%"
            ></div>
          </div>
          <div class="flex justify-between text-xs font-mono text-exo-light-gray">
            <span>{formatSpeed(downloadInfo.progress.speed)}</span>
            <span>ETA: {formatEta(downloadInfo.progress.etaMs)}</span>
            <span>{downloadInfo.progress.completedFiles}/{downloadInfo.progress.totalFiles} files</span>
          </div>
          {#if downloadInfo.perNode.length > 0}
            <div class="mt-2 space-y-2 max-h-48 overflow-y-auto pr-1">
              {#each downloadInfo.perNode as nodeProg}
                {@const nodePercent = Math.min(100, Math.max(0, nodeProg.progress.percentage))}
                {@const isExpanded = instanceDownloadExpandedNodes.has(nodeProg.nodeId)}
                <div class="rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2">
                  <button
                    type="button"
                    class="w-full text-left space-y-1.5"
                    onclick={() => toggleInstanceDownloadDetails(nodeProg.nodeId)}
                  >
                    <div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
                      <span class="text-white/80 truncate pr-2">{nodeProg.nodeName}</span>
                      <span class="flex items-center gap-1 text-blue-300">
                        {nodePercent.toFixed(1)}%
                        <svg class="w-3 h-3 text-exo-light-gray" viewBox="0 0 20 20" fill="none" stroke="currentColor" stroke-width="2">
                          <path d="M6 8l4 4 4-4" class={isExpanded ? 'transform rotate-180 origin-center transition-transform duration-150' : 'transition-transform duration-150'}></path>
                        </svg>
                      </span>
                    </div>
                    <div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
                      <div
                        class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
                        style="width: {nodePercent.toFixed(1)}%"
                      ></div>
                    </div>
                    <div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
                      <span>{formatBytes(nodeProg.progress.downloadedBytes)} / {formatBytes(nodeProg.progress.totalBytes)}</span>
                      <span>{formatSpeed(nodeProg.progress.speed)} • ETA {formatEta(nodeProg.progress.etaMs)}</span>
                    </div>
                  </button>

                  {#if isExpanded}
                    <div class="mt-2 space-y-1.5">
                      {#if nodeProg.progress.files.length === 0}
                        <div class="text-[11px] font-mono text-exo-light-gray/70">No file details reported.</div>
                      {:else}
                        {#each nodeProg.progress.files as f}
                          {@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
                          {@const isFileComplete = filePercent >= 100}
                          <div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
                            <div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
                              <span class="truncate pr-2">{f.name}</span>
                              <span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
                            </div>
                            <div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
                              <div
                                class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
                                style="width: {filePercent.toFixed(1)}%"
                              ></div>
                            </div>
                            <div class="flex items-center justify-between text-[10px] text-exo-light-gray/70 mt-0.5">
@@ -1803,27 +1954,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
                            </div>
                          </div>
                        {/each}
                      {/if}
                    </div>
                  {/if}
                </div>
              {/each}
            </div>
          {/if}
          <div class="text-xs text-blue-400 font-mono tracking-wider mt-1">DOWNLOADING</div>
        </div>
      {:else}
        <div class="text-xs {getStatusColor(downloadInfo.statusText)} font-mono tracking-wider mt-1">{downloadInfo.statusText}</div>
      {/if}
    </div>
  </div>
</div>

@@ -199,7 +199,13 @@
|
||||
const rawProgress = (downloadPayload as Record<string, unknown>).download_progress
|
||||
?? (downloadPayload as Record<string, unknown>).downloadProgress
|
||||
?? {};
|
||||
const totalBytes = getBytes((rawProgress as Record<string, unknown>).total_bytes ?? (rawProgress as Record<string, unknown>).totalBytes);
|
||||
// For DownloadCompleted, total_bytes is at top level; for DownloadOngoing, it's inside download_progress
|
||||
const totalBytes = getBytes(
|
||||
(downloadPayload as Record<string, unknown>).total_bytes
|
||||
?? (downloadPayload as Record<string, unknown>).totalBytes
|
||||
?? (rawProgress as Record<string, unknown>).total_bytes
|
||||
?? (rawProgress as Record<string, unknown>).totalBytes
|
||||
);
|
||||
const downloadedBytes = getBytes((rawProgress as Record<string, unknown>).downloaded_bytes ?? (rawProgress as Record<string, unknown>).downloadedBytes);
|
||||
const speed = (rawProgress as Record<string, unknown>).speed as number ?? 0;
|
||||
const etaMs = (rawProgress as Record<string, unknown>).eta_ms as number ?? (rawProgress as Record<string, unknown>).etaMs as number ?? 0;
|
||||
@@ -332,8 +338,13 @@
|
||||
<div class="text-lg font-mono text-white truncate">{node.nodeName}</div>
|
||||
<div class="text-xs text-exo-light-gray font-mono truncate">{node.nodeId}</div>
|
||||
</div>
|
||||
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0">
|
||||
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> /{node.models.length} models</span>
|
||||
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right">
|
||||
<div>
|
||||
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> / {node.models.length} models</span>
|
||||
</div>
|
||||
<div class="text-exo-light-gray normal-case tracking-normal">
|
||||
{formatBytes(node.models.filter(m => m.status === 'completed').reduce((sum, m) => sum + m.totalBytes, 0))} on disk
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -345,13 +356,19 @@
|
||||
<div class="rounded border border-exo-medium-gray/30 bg-exo-dark-gray/60 p-3 space-y-2">
|
||||
<div class="flex items-center justify-between gap-3">
|
||||
<div class="min-w-0 space-y-0.5">
|
||||
<div class="text-sm font-mono text-white truncate">{model.prettyName ?? model.modelId}</div>
|
||||
<div class="text-[11px] text-exo-light-gray font-mono truncate">
|
||||
{model.modelId}
|
||||
</div>
|
||||
<div class="text-[11px] text-exo-light-gray font-mono">
|
||||
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
|
||||
</div>
|
||||
<div
|
||||
class="text-xs font-mono text-white truncate"
|
||||
title={model.prettyName ?? model.modelId}
|
||||
>{model.prettyName ?? model.modelId}</div>
|
||||
<div
|
||||
class="text-[10px] text-exo-light-gray font-mono truncate"
|
||||
title={model.modelId}
|
||||
>{model.modelId}</div>
|
||||
{#if model.status !== 'completed'}
|
||||
<div class="text-[11px] text-exo-light-gray font-mono">
|
||||
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="text-xs font-mono {pct >= 100 ? 'text-green-400' : pct <= 0 ? 'text-red-400' : 'text-exo-yellow'}">
|
||||
@@ -379,7 +396,7 @@
|
||||
</div>
|
||||
|
||||
<div class="flex items-center justify-between text-xs font-mono text-exo-light-gray">
|
||||
<span>{model.status === 'completed' ? 'Completed' : `${formatSpeed(model.speed)} • ETA ${formatEta(model.etaMs)}`}</span>
|
||||
<span>{model.status === 'completed' ? `Completed (${formatBytes(model.totalBytes)})` : `${formatSpeed(model.speed)} • ETA ${formatEta(model.etaMs)}`}</span>
|
||||
{#if model.status !== 'completed'}
|
||||
<span>{model.files.length} file{model.files.length === 1 ? '' : 's'}</span>
|
||||
{/if}
|
||||
@@ -426,14 +443,14 @@
|
||||
<style>
|
||||
.downloads-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fill, minmax(260px, 1fr));
|
||||
grid-template-columns: repeat(auto-fill, minmax(320px, 1fr));
|
||||
}
|
||||
@media (min-width: 1024px) {
|
||||
.downloads-grid {
|
||||
grid-template-columns: repeat(3, minmax(0, 1fr));
|
||||
}
|
||||
}
|
||||
@media (min-width: 1440px) {
|
||||
@media (min-width: 1600px) {
|
||||
.downloads-grid {
|
||||
grid-template-columns: repeat(4, minmax(0, 1fr));
|
||||
}
|
||||
|
||||
@@ -1,16 +1,15 @@
import tailwindcss from '@tailwindcss/vite';
import { sveltekit } from '@sveltejs/kit/vite';
import { defineConfig } from 'vite';
import tailwindcss from "@tailwindcss/vite";
import { sveltekit } from "@sveltejs/kit/vite";
import { defineConfig } from "vite";

export default defineConfig({
  plugins: [tailwindcss(), sveltekit()],
  server: {
    proxy: {
      '/v1': 'http://localhost:52415',
      '/state': 'http://localhost:52415',
      '/models': 'http://localhost:52415',
      '/instance': 'http://localhost:52415'
    }
  }
      "/v1": "http://localhost:52415",
      "/state": "http://localhost:52415",
      "/models": "http://localhost:52415",
      "/instance": "http://localhost:52415",
    },
  },
});

212  docs/api.md  Normal file
@@ -0,0 +1,212 @@

# EXO API – Technical Reference

This document describes the REST API exposed by the **EXO** service, as implemented in:

`src/exo/master/api.py`

The API is used to manage model instances in the cluster, inspect cluster state, and perform inference using an OpenAI-compatible interface.

Base URL example:

```
http://localhost:52415
```

## 1. General / Meta Endpoints

### Get Master Node ID

**GET** `/node_id`

Returns the identifier of the current master node.

**Response (example):**

```json
{
  "node_id": "node-1234"
}
```

### Get Cluster State

**GET** `/state`

Returns the current state of the cluster, including nodes and active instances.

**Response:**
JSON object describing topology, nodes, and instances.

### Get Events

**GET** `/events`

Returns the list of internal events recorded by the master (mainly for debugging and observability).

**Response:**
Array of event objects.

## 2. Model Instance Management

### Create Instance

**POST** `/instance`

Creates a new model instance in the cluster.

**Request body (example):**

```json
{
  "instance": {
    "model_id": "llama-3.2-1b",
    "placement": { }
  }
}
```

**Response:**
JSON description of the created instance.
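
As a rough illustration (not an official client), the endpoint can be exercised with `httpx`, which this changeset adds to the project's dependencies. The payload mirrors the example above; the exact shape of the response is not specified here, so treat the printed result as opaque JSON:

```python
# Sketch: create an instance on a local master at the default port.
# The response fields are not documented here, so we just print them.
import httpx

BASE = "http://localhost:52415"

resp = httpx.post(
    f"{BASE}/instance",
    json={"instance": {"model_id": "llama-3.2-1b", "placement": {}}},
)
resp.raise_for_status()
print(resp.json())  # JSON description of the created instance
```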

### Delete Instance

**DELETE** `/instance/{instance_id}`

Deletes an existing instance by ID.

**Path parameters:**

* `instance_id`: string, ID of the instance to delete

**Response:**
Status / confirmation JSON.

### Get Instance

**GET** `/instance/{instance_id}`

Returns details of a specific instance.

**Path parameters:**

* `instance_id`: string

**Response:**
JSON description of the instance.

### Preview Placements

**GET** `/instance/previews?model_id=...`

Returns possible placement previews for a given model.

**Query parameters:**

* `model_id`: string, required

**Response:**
Array of placement preview objects.

### Compute Placement

**GET** `/instance/placement`

Computes a placement for a potential instance without creating it.

**Query parameters (typical):**

* `model_id`: string
* `sharding`: string or config
* `instance_meta`: JSON-encoded metadata
* `min_nodes`: integer

**Response:**
JSON object describing the proposed placement / instance configuration.
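
Since the parameter list above is only "typical", the following sketch should be read as illustrative rather than a guaranteed contract; parameter names and the response shape may differ in a given build:

```python
# Illustrative only: ask the master for a placement without creating
# an instance. Parameter names follow the "typical" list above.
import httpx

params = {"model_id": "llama-3.2-1b", "min_nodes": 1}
placement = httpx.get(
    "http://localhost:52415/instance/placement", params=params
).json()
print(placement)  # proposed placement / instance configuration
```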

### Place Instance (Dry Operation)

**POST** `/place_instance`

Performs a placement operation for an instance (planning step), without necessarily creating it.

**Request body:**
JSON describing the instance to be placed.

**Response:**
Placement result.

## 3. Models

### List Models

**GET** `/models`
**GET** `/v1/models` (alias)

Returns the list of available models and their metadata.

**Response:**
Array of model descriptors.

## 4. Inference / Chat Completions

### OpenAI-Compatible Chat Completions

**POST** `/v1/chat/completions`

Executes a chat completion request using an OpenAI-compatible schema. Supports streaming and non-streaming modes.

**Request body (example):**

```json
{
  "model": "llama-3.2-1b",
  "messages": [
    { "role": "system", "content": "You are a helpful assistant." },
    { "role": "user", "content": "Hello" }
  ],
  "stream": false
}
```

**Response:**
OpenAI-compatible chat completion response.
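
Because the schema is OpenAI-compatible, the official `openai` Python client can be pointed at EXO by overriding the base URL. The `api_key` value below is a placeholder, since this document does not describe any authentication requirement:

```python
# Sketch: reuse the official openai client against EXO's endpoint.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:52415/v1", api_key="unused")

completion = client.chat.completions.create(
    model="llama-3.2-1b",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ],
    stream=False,
)
print(completion.choices[0].message.content)
```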

### Benchmarked Chat Completions

**POST** `/bench/chat/completions`

Same as `/v1/chat/completions`, but also returns performance and generation statistics.

**Request body:**
Same schema as `/v1/chat/completions`.

**Response:**
Chat completion plus benchmarking metrics.

## 5. Complete Endpoint Summary

```
GET /node_id
GET /state
GET /events

POST /instance
GET /instance/{instance_id}
DELETE /instance/{instance_id}

GET /instance/previews
GET /instance/placement
POST /place_instance

GET /models
GET /v1/models

POST /v1/chat/completions
POST /bench/chat/completions
```
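
A minimal smoke test of the read-only endpoints, assuming a master at the default base URL (response shapes are unspecified here, so the script only prints what it gets back):

```python
# Sketch: quick read-only health check against the meta endpoints.
import httpx

BASE = "http://localhost:52415"

print(httpx.get(f"{BASE}/node_id").json())  # master node id
print(httpx.get(f"{BASE}/models").json())   # available models
state = httpx.get(f"{BASE}/state").json()   # full cluster state
print(type(state))
```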

## 6. Notes

* The `/v1/chat/completions` endpoint is compatible with the OpenAI API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
* The instance placement endpoints allow you to plan and preview cluster allocations before actually creating instances.
* The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.
185  flake.lock  generated
@@ -1,5 +1,42 @@
{
  "nodes": {
    "crane": {
      "locked": {
        "lastModified": 1767744144,
        "narHash": "sha256-9/9ntI0D+HbN4G0TrK3KmHbTvwgswz7p8IEJsWyef8Q=",
        "owner": "ipetkov",
        "repo": "crane",
        "rev": "2fb033290bf6b23f226d4c8b32f7f7a16b043d7e",
        "type": "github"
      },
      "original": {
        "owner": "ipetkov",
        "repo": "crane",
        "type": "github"
      }
    },
    "dream2nix": {
      "inputs": {
        "nixpkgs": [
          "nixpkgs"
        ],
        "purescript-overlay": "purescript-overlay",
        "pyproject-nix": "pyproject-nix"
      },
      "locked": {
        "lastModified": 1765953015,
        "narHash": "sha256-5FBZbbWR1Csp3Y2icfRkxMJw/a/5FGg8hCXej2//bbI=",
        "owner": "nix-community",
        "repo": "dream2nix",
        "rev": "69eb01fa0995e1e90add49d8ca5bcba213b0416f",
        "type": "github"
      },
      "original": {
        "owner": "nix-community",
        "repo": "dream2nix",
        "type": "github"
      }
    },
    "fenix": {
      "inputs": {
        "nixpkgs": [
@@ -8,11 +45,11 @@
        "rust-analyzer-src": "rust-analyzer-src"
      },
      "locked": {
        "lastModified": 1761893049,
        "narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
        "lastModified": 1768287139,
        "narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
        "owner": "nix-community",
        "repo": "fenix",
        "rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
        "rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
        "type": "github"
      },
      "original": {
@@ -21,25 +58,59 @@
        "type": "github"
      }
    },
    "flake-utils": {
      "inputs": {
        "systems": "systems"
      },
    "flake-compat": {
      "flake": false,
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "lastModified": 1696426674,
        "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-parts": {
      "inputs": {
        "nixpkgs-lib": [
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1768135262,
        "narHash": "sha256-PVvu7OqHBGWN16zSi6tEmPwwHQ4rLPU9Plvs8/1TUBY=",
        "owner": "hercules-ci",
        "repo": "flake-parts",
        "rev": "80daad04eddbbf5a4d883996a73f3f542fa437ac",
        "type": "github"
      },
      "original": {
        "owner": "hercules-ci",
        "repo": "flake-parts",
        "type": "github"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1768127708,
        "narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
        "owner": "NixOS",
        "repo": "nixpkgs",
        "rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
        "type": "github"
      },
      "original": {
        "owner": "NixOS",
        "ref": "nixos-unstable",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "nixpkgs-swift": {
      "locked": {
        "lastModified": 1761672384,
        "narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
@@ -50,27 +121,74 @@
      },
      "original": {
        "owner": "NixOS",
        "ref": "nixos-unstable",
        "repo": "nixpkgs",
        "rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
        "type": "github"
      }
    },
    "purescript-overlay": {
      "inputs": {
        "flake-compat": "flake-compat",
        "nixpkgs": [
          "dream2nix",
          "nixpkgs"
        ],
        "slimlock": "slimlock"
      },
      "locked": {
        "lastModified": 1728546539,
        "narHash": "sha256-Sws7w0tlnjD+Bjck1nv29NjC5DbL6nH5auL9Ex9Iz2A=",
        "owner": "thomashoneyman",
        "repo": "purescript-overlay",
        "rev": "4ad4c15d07bd899d7346b331f377606631eb0ee4",
        "type": "github"
      },
      "original": {
        "owner": "thomashoneyman",
        "repo": "purescript-overlay",
        "type": "github"
      }
    },
    "pyproject-nix": {
      "inputs": {
        "nixpkgs": [
          "dream2nix",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1763017646,
        "narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
        "owner": "pyproject-nix",
        "repo": "pyproject.nix",
        "rev": "47bd6f296502842643078d66128f7b5e5370790c",
        "type": "github"
      },
      "original": {
        "owner": "pyproject-nix",
        "repo": "pyproject.nix",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "crane": "crane",
        "dream2nix": "dream2nix",
        "fenix": "fenix",
        "flake-utils": "flake-utils",
        "flake-parts": "flake-parts",
        "nixpkgs": "nixpkgs",
        "nixpkgs-swift": "nixpkgs-swift",
        "treefmt-nix": "treefmt-nix"
      }
    },
    "rust-analyzer-src": {
      "flake": false,
      "locked": {
        "lastModified": 1761849405,
        "narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
        "lastModified": 1768224240,
        "narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
        "owner": "rust-lang",
        "repo": "rust-analyzer",
        "rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
        "rev": "725349602e525df37f377701e001fe8aab807878",
        "type": "github"
      },
      "original": {
@@ -80,18 +198,25 @@
        "type": "github"
      }
    },
    "systems": {
    "slimlock": {
      "inputs": {
        "nixpkgs": [
          "dream2nix",
          "purescript-overlay",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "lastModified": 1688756706,
        "narHash": "sha256-xzkkMv3neJJJ89zo3o2ojp7nFeaZc2G0fYwNXNJRFlo=",
        "owner": "thomashoneyman",
        "repo": "slimlock",
        "rev": "cf72723f59e2340d24881fd7bf61cb113b4c407c",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "owner": "thomashoneyman",
        "repo": "slimlock",
        "type": "github"
      }
    },
@@ -102,11 +227,11 @@
        ]
      },
      "locked": {
        "lastModified": 1762938485,
        "narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
        "lastModified": 1768158989,
        "narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
        "owner": "numtide",
        "repo": "treefmt-nix",
        "rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
        "rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
        "type": "github"
      },
      "original": {

202  flake.nix
@@ -3,122 +3,134 @@

  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
    flake-utils.url = "github:numtide/flake-utils";
    # Provides Rust dev-env integration:

    flake-parts = {
      url = "github:hercules-ci/flake-parts";
      inputs.nixpkgs-lib.follows = "nixpkgs";
    };

    crane.url = "github:ipetkov/crane";

    fenix = {
      url = "github:nix-community/fenix";
      inputs.nixpkgs.follows = "nixpkgs";
    };
    # Provides formatting infrastructure:

    treefmt-nix = {
      url = "github:numtide/treefmt-nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };

    dream2nix = {
      url = "github:nix-community/dream2nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };

    # Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
    nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
  };

  # TODO: figure out caching story
  # nixConfig = {
  #   # nix community cachix
  #   extra-trusted-public-keys = "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=";
  #   extra-substituters = "https://nix-community.cachix.org";
  # };
  nixConfig = {
    extra-trusted-public-keys = "exo.cachix.org-1:okq7hl624TBeAR3kV+g39dUFSiaZgLRkLsFBCuJ2NZI=";
    extra-substituters = "https://exo.cachix.org";
  };

  outputs =
    inputs:
    let
    inputs.flake-parts.lib.mkFlake { inherit inputs; } {
      systems = [
        "x86_64-linux"
        "aarch64-darwin"
        "aarch64-linux"
      ];
      fenixToolchain = system: inputs.fenix.packages.${system}.stable;
    in
    inputs.flake-utils.lib.eachSystem systems (
      system:
      let
        pkgs = import inputs.nixpkgs {
          inherit system;
          overlays = [ ];
        };
        treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
          projectRootFile = "flake.nix";
          programs = {
            ruff-format.enable = true;
            ruff-format.excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
            rustfmt.enable = true;
            rustfmt.package = (fenixToolchain system).rustfmt;
            nixpkgs-fmt.enable = true;

      imports = [
        inputs.treefmt-nix.flakeModule
        ./dashboard/parts.nix
        ./rust/parts.nix
      ];

      perSystem =
        { config, self', inputs', pkgs, lib, system, ... }:
        let
          fenixToolchain = inputs'.fenix.packages.complete;
          # Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
          pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
        in
        {
          treefmt = {
            projectRootFile = "flake.nix";
            programs = {
              nixpkgs-fmt.enable = true;
              ruff-format = {
                enable = true;
                excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
              };
              rustfmt = {
                enable = true;
                package = config.rust.toolchain;
              };
              prettier = {
                enable = true;
                includes = [ "*.ts" ];
              };
              swift-format = {
                enable = true;
                package = pkgsSwift.swiftPackages.swift-format;
              };
            };
          };
        };
    in
    {
      formatter = treefmtEval.config.build.wrapper;
      checks.formatting = treefmtEval.config.build.check inputs.self;
      checks.lint = pkgs.runCommand "lint-check" { } ''
        export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
        ${pkgs.ruff}/bin/ruff check ${inputs.self}/
        touch $out
      '';

      devShells.default = pkgs.mkShell {
        packages =
          with pkgs;
          [
            # PYTHON
            python313
            uv
            ruff
            basedpyright

            # RUST
            ((fenixToolchain system).withComponents [
              "cargo"
              "rustc"
              "clippy"
              "rustfmt"
              "rust-src"
            ])
            cargo-machete
            bacon
            rustup # Just here to make RustRover happy

            # NIX
            nixpkgs-fmt

            # SVELTE
            nodejs

            # MISC
            just
            jq
          ]
          ++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
            # IFCONFIG
            unixtools.ifconfig

            # Build dependencies for Linux
            pkg-config
            openssl
          ])
          ++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
            # MACMON
            macmon
          ]);

        shellHook = ''
          # PYTHON
          export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
          ${pkgs.lib.optionalString pkgs.stdenv.isLinux ''
            # Build environment for Linux
            export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
            export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
          ''}
          echo
          echo "🍎🍎 Run 'just <recipe>' to get started"
          just --list
      checks.lint = pkgs.runCommand "lint-check" { } ''
        export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
        ${pkgs.ruff}/bin/ruff check ${inputs.self}/
        touch $out
      '';

      devShells.default = with pkgs; pkgs.mkShell {
        inputsFrom = [ self'.checks.cargo-build ];

        packages =
          [
            # FORMATTING
            config.treefmt.build.wrapper

            # PYTHON
            python313
            uv
            ruff
            basedpyright

            # RUST
            config.rust.toolchain
            maturin

            # NIX
            nixpkgs-fmt

            # SVELTE
            nodejs

            # MISC
            just
            jq
          ]
          ++ lib.optionals stdenv.isLinux [
            unixtools.ifconfig
          ]
          ++ lib.optionals stdenv.isDarwin [
            macmon
          ];

        OPENSSL_NO_VENDOR = "1";

        shellHook = ''
          export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${python313}/lib"
          ${lib.optionalString stdenv.isLinux ''
            export LD_LIBRARY_PATH="${openssl.out}/lib:$LD_LIBRARY_PATH"
          ''}
        '';
      };
    };
  }
);
};
}

2  justfile
@@ -1,3 +1,5 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"

fmt:
    nix fmt

@@ -1,6 +1,6 @@
[project]
name = "exo"
version = "0.10.0"
version = "0.3.0"
description = "Exo"
readme = "README.md"
requires-python = ">=3.13"
@@ -8,31 +8,22 @@ dependencies = [
  "aiofiles>=24.1.0",
  "aiohttp>=3.12.14",
  "types-aiofiles>=24.1.0.20250708",
  "typeguard>=4.4.4",
  "pydantic>=2.11.7",
  "base58>=2.1.1",
  "cryptography>=45.0.5",
  "fastapi>=0.116.1",
  "filelock>=3.18.0",
  "aiosqlite>=0.21.0",
  "networkx>=3.5",
  "protobuf>=6.32.0",
  "rich>=14.1.0",
  "rustworkx>=0.17.1",
  "sqlmodel>=0.0.24",
  "sqlalchemy[asyncio]>=2.0.43",
  "greenlet>=3.2.4",
  "huggingface-hub>=0.33.4",
  "psutil>=7.0.0",
  "loguru>=0.7.3",
  "textual>=5.3.0",
  "exo_pyo3_bindings", # rust bindings
  "anyio==4.11.0",
  "bidict>=0.23.1",
  "mlx>=0.30.1",
  "mlx-lm>=0.28.3",
  "mlx==0.30.1; sys_platform == 'darwin'",
  "mlx[cpu]==0.30.1; sys_platform == 'linux'",
  "mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
  "tiktoken>=0.12.0", # required for kimi k2 tokenizer
  "hypercorn>=0.18.0",
  "openai-harmony>=0.0.8",
  "httpx>=0.28.1",
]

[project.scripts]
@@ -43,6 +34,7 @@ exo = "exo.main:main"
# dependencies only required for development
[dependency-groups]
dev = [
  "basedpyright>=1.29.0",
  "pyinstaller>=6.17.0",
  "pytest>=8.4.0",
  "pytest-asyncio>=1.0.0",
@@ -80,7 +72,7 @@ build-backend = "uv_build"
###

[tool.basedpyright]
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src"]
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src", "bench"]
typeCheckingMode = "strict"
failOnWarnings = true

@@ -108,6 +100,7 @@ root = "src"

# supported platforms for this project
[tool.uv]
prerelease = "allow"
environments = [
  "sys_platform == 'darwin'",
  "sys_platform == 'linux'",

2  rust/clippy.toml  Normal file
@@ -0,0 +1,2 @@
# we can manually exclude false-positive lint errors for dual packages (if in dependencies)
#allowed-duplicate-crates = ["hashbrown"]
@@ -5,6 +5,8 @@ edition = { workspace = true }
publish = false

[lib]
doctest = false
path = "src/lib.rs"
name = "exo_pyo3_bindings"

# "cdylib" needed to produce shared library for Python to import
@@ -20,24 +22,46 @@ doc = false
workspace = true

[dependencies]
networking.workspace = true
networking = { workspace = true }

# interop
pyo3 = { workspace = true, features = ["experimental-async"] }
pyo3-stub-gen.workspace = true
# pyo3-async-runtimes = { workspace = true, features = ["attributes", "tokio-runtime", "testing"] }
pyo3-log.workspace = true
pyo3 = { version = "0.27.1", features = [
  # "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
  "nightly", # enables better-supported GIL integration
  "experimental-async", # async support in #[pyfunction] & #[pymethods]
  #"experimental-inspect", # inspection of generated binary => easier to automate type-hint generation
  #"py-clone", # adding Clone-ing of `Py<T>` without GIL (may cause panics - remove if panics happen)
  "multiple-pymethods", # allows multiple #[pymethods] sections per class

  # integrations with other libraries
  "arc_lock", "bigdecimal", "either", "hashbrown", "indexmap", "num-bigint", "num-complex", "num-rational",
  "ordered-float", "rust_decimal", "smallvec",
  # "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.17.2" }
pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] }
pyo3-log = "0.13.2"

# macro dependencies
extend.workspace = true
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
pin-project = { workspace = true }

# async runtime
tokio = { workspace = true, features = ["full", "tracing"] }
futures = { workspace = true }

# utility dependencies
postcard = { workspace = true, features = ["use-std"] }
rand.workspace = true
n0-future.workspace = true
once_cell = "1.21.3"
thread_local = "1.1.9"
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }


# Tracing
@@ -46,9 +70,8 @@ n0-future.workspace = true
#console-subscriber = "0.1.5"
#tracing-log = "0.2.0"
log = { workspace = true }
env_logger = { workspace = true }
env_logger = "0.11"


# Networking
iroh = { workspace = true }
iroh-gossip = { workspace = true }
libp2p = { workspace = true, features = ["full"] }

@@ -1,16 +0,0 @@
from exo_pyo3_bindings import RustNetworkingHandle, Keypair
from asyncio import run


async def main():
    nh = await RustNetworkingHandle.create(Keypair.generate_ed25519(), "mdns_example")
    recv = await nh.get_connection_receiver()
    while True:
        cm = await recv.receive()
        print(
            f"Endpoint({cm.endpoint_id}, reachable={list(map(lambda it: it.ip_addr(), cm.current_transport_addrs)) if cm.current_transport_addrs is not None else None})"
        )


if __name__ == "__main__":
    run(main())
@@ -2,63 +2,220 @@
# ruff: noqa: E501, F401

import builtins
import enum
import typing

@typing.final
class EndpointId:
class AllQueuesFullError(builtins.Exception):
    def __new__(cls, *args: typing.Any) -> AllQueuesFullError: ...
    def __repr__(self) -> builtins.str: ...
    def __str__(self) -> builtins.str: ...

@typing.final
class IpAddress:
    def __str__(self) -> builtins.str: ...
    def ip_addr(self) -> builtins.str: ...
    def port(self) -> builtins.int: ...
    def zone_id(self) -> typing.Optional[builtins.int]: ...
class ConnectionUpdate:
    @property
    def update_type(self) -> ConnectionUpdateType:
        r"""
        Whether this is a connection or disconnection event
        """
    @property
    def peer_id(self) -> PeerId:
        r"""
        Identity of the peer that we have connected to or disconnected from.
        """
    @property
    def remote_ipv4(self) -> builtins.str:
        r"""
        Remote connection's IPv4 address.
        """
    @property
    def remote_tcp_port(self) -> builtins.int:
        r"""
        Remote connection's TCP port.
        """

@typing.final
class Keypair:
    r"""
    Identity keypair of a node.
    """
    @staticmethod
    def generate_ed25519() -> Keypair:
        r"""
        Generate a new Ed25519 keypair.
        """
    @staticmethod
    def from_postcard_encoding(bytes: bytes) -> Keypair:
    def generate_ecdsa() -> Keypair:
        r"""
        Decode a postcard structure into a keypair
        Generate a new ECDSA keypair.
        """
    def to_postcard_encoding(self) -> bytes:
        r"""
        Encode a private key with the postcard format
        """
    def endpoint_id(self) -> EndpointId:
        r"""
        Read out the endpoint id corresponding to this keypair
        """

@typing.final
class RustConnectionMessage:
    @property
    def endpoint_id(self) -> EndpointId: ...
    @property
    def current_transport_addrs(self) -> typing.Optional[builtins.set[IpAddress]]: ...

@typing.final
class RustConnectionReceiver:
    async def receive(self) -> RustConnectionMessage: ...

@typing.final
class RustNetworkingHandle:
    @staticmethod
    async def create(identity: Keypair, namespace: builtins.str) -> RustNetworkingHandle: ...
    async def subscribe(self, topic: builtins.str) -> tuple[RustSender, RustReceiver]: ...
    async def get_connection_receiver(self) -> RustConnectionReceiver: ...
    def generate_secp256k1() -> Keypair:
        r"""
        Generate a new Secp256k1 keypair.
        """
    @staticmethod
    def from_protobuf_encoding(bytes: bytes) -> Keypair:
        r"""
        Decode a private key from a protobuf structure and parse it as a `Keypair`.
        """
    @staticmethod
    def rsa_from_pkcs8(bytes: bytes) -> Keypair:
        r"""
        Decode a keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
        format (i.e. unencrypted) as defined in [RFC5208].

        [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
        """
    @staticmethod
    def secp256k1_from_der(bytes: bytes) -> Keypair:
        r"""
        Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
        structure as defined in [RFC5915].

        [RFC5915]: https://tools.ietf.org/html/rfc5915
        """
    @staticmethod
    def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
    def to_protobuf_encoding(self) -> bytes:
        r"""
        Encode a private key as protobuf structure.
        """
    def to_peer_id(self) -> PeerId:
        r"""
        Convert the `Keypair` into the corresponding `PeerId`.
        """

@typing.final
class RustReceiver:
    async def receive(self) -> bytes: ...
class Multiaddr:
    r"""
    Representation of a Multiaddr.
    """
    @staticmethod
    def empty() -> Multiaddr:
        r"""
        Create a new, empty multiaddress.
        """
    @staticmethod
    def with_capacity(n: builtins.int) -> Multiaddr:
        r"""
        Create a new, empty multiaddress with the given capacity.
        """
    @staticmethod
    def from_bytes(bytes: bytes) -> Multiaddr:
        r"""
        Parse a `Multiaddr` value from its byte slice representation.
        """
    @staticmethod
    def from_string(string: builtins.str) -> Multiaddr:
        r"""
        Parse a `Multiaddr` value from its string representation.
        """
    def len(self) -> builtins.int:
        r"""
        Return the length in bytes of this multiaddress.
        """
    def is_empty(self) -> builtins.bool:
        r"""
        Returns true if the length of this multiaddress is 0.
        """
    def to_bytes(self) -> bytes:
        r"""
        Return a copy of this [`Multiaddr`]'s byte representation.
        """
    def to_string(self) -> builtins.str:
        r"""
        Convert a Multiaddr to a string.
        """

@typing.final
class RustSender:
    async def send(self, message: bytes) -> None: ...
class NetworkingHandle:
    def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
    async def connection_update_recv(self) -> ConnectionUpdate:
        r"""
        Receives the next `ConnectionUpdate` from networking.
        """
    async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
        r"""
        Receives at most `limit` `ConnectionUpdate`s from networking and returns them.

        For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
        For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
        will sleep until a `ConnectionUpdate` is sent.
        """
    async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
        r"""
        Subscribe to a `GossipSub` topic.

        Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
        """
    async def gossipsub_unsubscribe(self, topic: builtins.str) -> builtins.bool:
        r"""
        Unsubscribes from a `GossipSub` topic.

        Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
        """
    async def gossipsub_publish(self, topic: builtins.str, data: bytes) -> None:
        r"""
        Publishes a message with multiple topics to the `GossipSub` network.

        If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
        """
    async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
        r"""
        Receives the next message from the `GossipSub` network.
        """
    async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
        r"""
        Receives at most `limit` messages from the `GossipSub` network and returns them.

        For `limit = 0`, an empty collection of messages will be returned immediately.
        For `limit > 0`, if there are no messages in the channel's queue this method
        will sleep until a message is sent.
        """

@typing.final
class NoPeersSubscribedToTopicError(builtins.Exception):
    def __new__(cls, *args: typing.Any) -> NoPeersSubscribedToTopicError: ...
    def __repr__(self) -> builtins.str: ...
    def __str__(self) -> builtins.str: ...

@typing.final
class PeerId:
    r"""
    Identifier of a peer of the network.

    The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
    as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
    """
    @staticmethod
    def random() -> PeerId:
        r"""
        Generates a random peer ID from a cryptographically secure PRNG.

        This is useful for randomly walking on a DHT, or for testing purposes.
        """
    @staticmethod
    def from_bytes(bytes: bytes) -> PeerId:
        r"""
        Parses a `PeerId` from bytes.
        """
    def to_bytes(self) -> bytes:
        r"""
        Returns a raw bytes representation of this `PeerId`.
        """
    def to_base58(self) -> builtins.str:
        r"""
        Returns a base-58 encoded string of this `PeerId`.
        """
    def __repr__(self) -> builtins.str: ...
    def __str__(self) -> builtins.str: ...

@typing.final
class ConnectionUpdateType(enum.Enum):
    r"""
    Connection or disconnection event discriminant type.
    """
    Connected = ...
    Disconnected = ...
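
Taken together, the new `NetworkingHandle` stub suggests usage along the following lines. This is a speculative sketch inferred only from the signatures and docstrings in this diff, not an example shipped with the bindings:

```python
# Speculative sketch based only on the stub signatures above.
import asyncio
from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError


async def main() -> None:
    handle = NetworkingHandle(Keypair.generate_ed25519())

    # Returns False if we were already subscribed to this topic.
    await handle.gossipsub_subscribe("exo-demo")

    try:
        await handle.gossipsub_publish("exo-demo", b"hello")
    except NoPeersSubscribedToTopicError:
        pass  # no peers subscribed to this topic yet

    topic, data = await handle.gossipsub_recv()
    print(topic, data)


if __name__ == "__main__":
    asyncio.run(main())
```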

@@ -8,8 +8,7 @@ version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
  { name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" },
  { name = "Evan Quiney", email = "evanev7@gmail.com" }
  { name = "Andrei Cravtov", email = "the.andrei.cravtov@gmail.com" }
]
requires-python = ">=3.13"
dependencies = []

@@ -1,6 +1,8 @@
//! SEE: <https://pyo3.rs/v0.27.1/async-await.html#detaching-from-the-interpreter-across-await>
//! SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
//!

use pyo3::exceptions::PyRuntimeError;
use pin_project::pin_project;
use pyo3::marker::Ungil;
use pyo3::prelude::*;
use std::{
    future::Future,
@@ -8,36 +10,31 @@ use std::{
    task::{Context, Poll},
};

/// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
#[pin_project]
#[repr(transparent)]
pub struct AllowThreads<F>(F);
pub(crate) struct AllowThreads<F>(#[pin] F);

impl<F> AllowThreads<F>
where
    Self: Future,
{
    pub(crate) const fn new(f: F) -> Self {
    pub fn new(f: F) -> Self {
        Self(f)
    }
}

impl<F> Future for AllowThreads<F>
where
    F: Future + Unpin + Send,
    F::Output: Send,
    F: Future + Ungil,
    F::Output: Ungil,
{
    type Output = Result<F::Output, PyErr>;
    type Output = F::Output;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let waker = cx.waker();
        match Python::try_attach(|py| {
            py.detach(|| pin!(&mut self.0).poll(&mut Context::from_waker(waker)))
        }) {
            Some(Poll::Pending) => Poll::Pending,
            Some(Poll::Ready(t)) => Poll::Ready(Ok(t)),
            // TODO: this doesn't actually work - graceful py shutdown handling
            None => Poll::Ready(Err(PyRuntimeError::new_err(
                "Python runtime shutdown while awaiting a future",
            ))),
        }
        Python::with_gil(|py| {
            py.allow_threads(|| self.project().0.poll(&mut Context::from_waker(waker)))
        })
    }
}

@@ -1,6 +1,7 @@
use pyo3_stub_gen::Result;

fn main() -> Result<()> {
    env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init();
    let stub = exo_pyo3_bindings::stub_info()?;
    stub.generate()?;
    Ok(())

240  rust/exo_pyo3_bindings/src/examples/mod.rs  Normal file
@@ -0,0 +1,240 @@
//! This module exists to hold examples of some pyo3 patterns that may be too complex to
//! re-create from scratch, but too inhomogeneous to create an abstraction/wrapper around.
//!
//! Pattern examples include:
//! - Async task handles: with GC-integrated cleanup
//! - Sync/async callbacks from python: with proper event-loop handling
//!
//! Mutability pattern: https://pyo3.rs/v0.26.0/async-await.html#send--static-constraint
//! - Store mutable fields in tokio's `Mutex<T>`
//! - For async code: take `&self` and `.lock().await`
//! - For sync code: take `&mut self` and `.get_mut()`

use crate::ext::{PyResultExt as _, ResultExt as _, TokioRuntimeExt as _};
use futures::FutureExt as _;
use futures::future::BoxFuture;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::{
    Bound, Py, PyAny, PyErr, PyResult, PyTraverseError, PyVisit, Python, pyclass, pymethods,
};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;

fn needs_tokio_runtime() {
    tokio::runtime::Handle::current();
}

type SyncCallback = Box<dyn Fn() + Send + Sync>;
type AsyncCallback = Box<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;

enum AsyncTaskMessage {
    SyncCallback(SyncCallback),
    AsyncCallback(AsyncCallback),
}

async fn async_task(
    sender: mpsc::UnboundedSender<()>,
    mut receiver: mpsc::UnboundedReceiver<AsyncTaskMessage>,
) {
    log::info!("RUST: async task started");

    // task state
    let mut interval = tokio::time::interval(Duration::from_secs(1));

    let mut sync_cbs: Vec<SyncCallback> = vec![];
    let mut async_cbs: Vec<AsyncCallback> = vec![];

    loop {
        tokio::select! {
            // handle incoming messages from task-handle
            message = receiver.recv() => {
                // handle closed channel by exiting
                let Some(message) = message else {
                    log::info!("RUST: channel closed");
                    break;
                };

                // dispatch incoming event
                match message {
                    AsyncTaskMessage::SyncCallback(cb) => {
                        sync_cbs.push(cb);
                    }
                    AsyncTaskMessage::AsyncCallback(cb) => {
                        async_cbs.push(cb);
                    }
                }
            }

            // handle all other events
            _ = interval.tick() => {
                log::info!("RUST: async task tick");

                // call back all sync callbacks
                for cb in &sync_cbs {
                    cb();
                }

                // call back all async callbacks
                for cb in &async_cbs {
                    cb().await;
                }

                // send event on unbounded channel
                sender.send(()).expect("handle receiver cannot be closed/dropped");
            }
        }
    }

    log::info!("RUST: async task stopped");
}

// #[gen_stub_pyclass]
#[pyclass(name = "AsyncTaskHandle")]
#[derive(Debug)]
struct PyAsyncTaskHandle {
    sender: Option<mpsc::UnboundedSender<AsyncTaskMessage>>,
    receiver: mpsc::UnboundedReceiver<()>,
}

#[allow(clippy::expect_used)]
impl PyAsyncTaskHandle {
    const fn sender(&self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
        self.sender
            .as_ref()
            .expect("The sender should only be None after de-initialization.")
    }

    const fn sender_mut(&mut self) -> &mpsc::UnboundedSender<AsyncTaskMessage> {
        self.sender
            .as_mut()
            .expect("The sender should only be None after de-initialization.")
    }

    const fn new(
        sender: mpsc::UnboundedSender<AsyncTaskMessage>,
        receiver: mpsc::UnboundedReceiver<()>,
    ) -> Self {
        Self {
            sender: Some(sender),
            receiver,
        }
    }
}

// #[gen_stub_pymethods]
#[pymethods]
impl PyAsyncTaskHandle {
    #[new]
    fn py_new(py: Python<'_>) -> PyResult<Self> {
        use pyo3_async_runtimes::tokio::get_runtime;

        // create communication channel TOWARDS our task
        let (h_sender, t_receiver) = mpsc::unbounded_channel::<AsyncTaskMessage>();

        // create communication channel FROM our task
        let (t_sender, h_receiver) = mpsc::unbounded_channel::<()>();

        // perform necessary setup within tokio context - or it crashes
        let () = get_runtime().block_on(async { needs_tokio_runtime() });

        // spawn tokio task with this thread's task-locals - without this, async callbacks on the new threads will not work!!
        _ = get_runtime().spawn_with_scope(py, async move {
            async_task(t_sender, t_receiver).await;
        });
        Ok(Self::new(h_sender, h_receiver))
    }

    /// NOTE: exceptions in callbacks are silently ignored until end of execution
    fn add_sync_callback(
        &self,
        // #[gen_stub(override_type(
        //     type_repr="collections.abc.Callable[[], None]",
        //     imports=("collections.abc")
        // ))]
        callback: Py<PyAny>,
    ) -> PyResult<()> {
        // blocking call to async method -> can do non-blocking if needed
        self.sender()
            .send(AsyncTaskMessage::SyncCallback(Box::new(move || {
                _ = Python::with_gil(|py| callback.call0(py).write_unraisable_with(py));
            })))
            .pyerr()?;
        Ok(())
    }

    /// NOTE: exceptions in callbacks are silently ignored until end of execution
    fn add_async_callback(
        &self,
        // #[gen_stub(override_type(
        //     type_repr="collections.abc.Callable[[], collections.abc.Awaitable[None]]",
        //     imports=("collections.abc")
        // ))]
        callback: Py<PyAny>,
    ) -> PyResult<()> {
        // blocking call to async method -> can do non-blocking if needed
        self.sender()
            .send(AsyncTaskMessage::AsyncCallback(Box::new(move || {
                let c = Python::with_gil(|py| callback.clone_ref(py));
                async move {
                    if let Some(f) = Python::with_gil(|py| {
                        let coroutine = c.call0(py).write_unraisable_with(py)?;
                        pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
                            .write_unraisable_with(py)
                    }) {
                        _ = f.await.write_unraisable();
                    }
                }
                .boxed()
            })))
            .pyerr()?;
        Ok(())
    }

    async fn receive_unit(&mut self) -> PyResult<()> {
        self.receiver
            .recv()
            .await
            .ok_or(PyErr::new::<PyRuntimeError, _>(
                "cannot receive unit on closed channel",
            ))
    }

    fn drain_units(&mut self) -> PyResult<i32> {
        let mut cnt = 0;
        loop {
            match self.receiver.try_recv() {
                Err(TryRecvError::Disconnected) => {
                    return Err(PyErr::new::<PyRuntimeError, _>(
                        "cannot receive unit on closed channel",
                    ));
                }
                Err(TryRecvError::Empty) => return Ok(cnt),
                Ok(()) => {
                    cnt += 1;
                    continue;
                }
            }
        }
    }

    // #[gen_stub(skip)]
    const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
        Ok(()) // This is needed purely so `__clear__` can work
    }

    // #[gen_stub(skip)]
    fn __clear__(&mut self) {
        // TODO: may or may not need to await a "kill-signal" oneshot channel message,
        // to ensure that the networking task is done BEFORE exiting the clear function...
        // but this may require GIL?? and it may not be safe to call GIL here??
        self.sender = None; // Using Option<T> as a trick to force `sender` channel to be dropped
    }
}

pub fn examples_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyAsyncTaskHandle>()?;

    Ok(())
}
@@ -1,66 +0,0 @@
use iroh::{EndpointId, SecretKey, endpoint_info::EndpointIdExt as _};
use postcard::ser_flavors::StdVec;

use crate::ext::ResultExt as _;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use rand::rng;

#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
#[derive(Debug, Clone)]
pub struct PyKeypair(pub(crate) SecretKey);

#[gen_stub_pymethods]
#[pymethods]
impl PyKeypair {
    /// Generate a new Ed25519 keypair.
    #[staticmethod]
    fn generate_ed25519() -> Self {
        Self(SecretKey::generate(&mut rng()))
    }
    /// Decode a postcard structure into a keypair
    #[staticmethod]
    fn from_postcard_encoding(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(postcard::from_bytes(&bytes).pyerr()?))
    }
    /// Encode a private key with the postcard format
    fn to_postcard_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
        let bytes = postcard::serialize_with_flavor(&self.0, StdVec::new()).pyerr()?;
        Ok(PyBytes::new(py, &bytes))
    }
    /// Read out the endpoint id corresponding to this keypair
    fn endpoint_id(&self) -> PyEndpointId {
        PyEndpointId(self.0.public())
    }
}

#[gen_stub_pyclass]
#[pyclass(name = "EndpointId", frozen)]
#[repr(transparent)]
#[derive(Debug, Clone)]
pub struct PyEndpointId(pub(crate) EndpointId);

#[gen_stub_pymethods]
#[pymethods]
impl PyEndpointId {
    pub fn __str__(&self) -> String {
        self.0.to_z32()
    }
}

impl From<EndpointId> for PyEndpointId {
    fn from(value: EndpointId) -> Self {
        Self(value)
    }
}

pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyKeypair>()?;
    m.add_class::<PyEndpointId>()?;

    Ok(())
}
@@ -4,27 +4,65 @@
|
||||
//!
|
||||
//!
|
||||
|
||||
mod allow_threading;
|
||||
mod identity;
|
||||
mod networking;
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(tuple_trait)]
|
||||
#![feature(unboxed_closures)]
|
||||
// #![feature(stmt_expr_attributes)]
|
||||
// #![feature(assert_matches)]
// #![feature(async_fn_in_dyn_trait)]
// #![feature(async_for_loop)]
// #![feature(auto_traits)]
// #![feature(negative_impls)]

extern crate core;

mod allow_threading;
mod examples;
pub(crate) mod networking;
pub(crate) mod pylibp2p;

use crate::identity::ident_submodule;
use crate::networking::networking_submodule;
use crate::pylibp2p::ident::ident_submodule;
use crate::pylibp2p::multiaddr::multiaddr_submodule;
use pyo3::prelude::PyModule;
use pyo3::prelude::*;
use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer;

/// Namespace for all the constants used by this crate.
pub(crate) mod r#const {
    pub const MPSC_CHANNEL_SIZE: usize = 1024;
}

/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {
    use std::error::Error;
    use std::marker::Tuple;

    pub trait SendFn<Args: Tuple + Send + 'static, Output> =
        Fn<Args, Output = Output> + Send + 'static;

    pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
    pub type AnyResult<T> = Result<T, AnyError>;
}

/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {
    use crate::allow_threading::AllowThreads;
    use extend::ext;
    use pyo3::exceptions::{PyConnectionError, PyRuntimeError};
    use pyo3::marker::Ungil;
    use pyo3::types::PyBytes;
    use pyo3::{Py, PyErr, PyResult, Python};
    use tokio::runtime::Runtime;
    use tokio::sync::mpsc;
    use tokio::sync::mpsc::error::TryRecvError;
    use tokio::task::JoinHandle;

    #[ext(pub, name = ByteArrayExt)]
    impl [u8] {
        fn pybytes(&self) -> Py<PyBytes> {
            Python::attach(|py| PyBytes::new(py, self).unbind())
            Python::with_gil(|py| PyBytes::new(py, self).unbind())
        }
    }

@@ -39,9 +77,7 @@ pub(crate) mod ext {
    }

    pub trait FutureExt: Future + Sized {
        /// SEE: <https://pyo3.rs/v0.27.1/async-await.html#detaching-from-the-interpreter-across-await>
        /// An [`AllowThreads`] returns a Future with an Err output if Python has shut down while we
        /// were awaiting something
        /// SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await
        fn allow_threads_py(self) -> AllowThreads<Self>
        where
            AllowThreads<Self>: Future,

@@ -62,7 +98,7 @@ pub(crate) mod ext {
    #[ext(pub, name = PyResultExt)]
    impl<T> PyResult<T> {
        fn write_unraisable(self) -> Option<T> {
            Python::attach(|py| self.write_unraisable_with(py))
            Python::with_gil(|py| self.write_unraisable_with(py))
        }

        fn write_unraisable_with(self, py: Python<'_>) -> Option<T> {

@@ -76,6 +112,85 @@ pub(crate) mod ext {
            }
        }
    }

    #[ext(pub, name = TokioRuntimeExt)]
    impl Runtime {
        fn spawn_with_scope<F>(&self, py: Python<'_>, future: F) -> PyResult<JoinHandle<F::Output>>
        where
            F: Future + Send + 'static,
            F::Output: Send + 'static,
        {
            let locals = pyo3_async_runtimes::tokio::get_current_locals(py)?;
            Ok(self.spawn(pyo3_async_runtimes::tokio::scope(locals, future)))
        }
    }

    #[ext(pub, name = TokioMpscSenderExt)]
    impl<T> mpsc::Sender<T> {
        /// Sends a value, waiting until there is capacity.
        ///
        /// A successful send occurs when it is determined that the other end of the
        /// channel has not hung up already. An unsuccessful send would be one where
        /// the corresponding receiver has already been closed.
        async fn send_py(&self, value: T) -> PyResult<()> {
            self.send(value)
                .await
                .map_err(|_| PyErr::receiver_channel_closed())
        }
    }

    #[ext(pub, name = TokioMpscReceiverExt)]
    impl<T> mpsc::Receiver<T> {
        /// Receives the next value for this receiver.
        async fn recv_py(&mut self) -> PyResult<T> {
            self.recv().await.ok_or_else(PyErr::receiver_channel_closed)
        }

        /// Receives at most `limit` values for this receiver and returns them.
        ///
        /// For `limit = 0`, an empty collection of messages will be returned immediately.
        /// For `limit > 0`, if there are no messages in the channel's queue this method
        /// will sleep until a message is sent.
        async fn recv_many_py(&mut self, limit: usize) -> PyResult<Vec<T>> {
            // get updates from receiver channel
            let mut updates = Vec::with_capacity(limit);
            let received = self.recv_many(&mut updates, limit).await;

            // if we received zero items, then the channel was unexpectedly closed
            if limit != 0 && received == 0 {
                return Err(PyErr::receiver_channel_closed());
            }

            Ok(updates)
        }

        /// Tries to receive the next value for this receiver.
        fn try_recv_py(&mut self) -> PyResult<Option<T>> {
            match self.try_recv() {
                Ok(v) => Ok(Some(v)),
                Err(TryRecvError::Empty) => Ok(None),
                Err(TryRecvError::Disconnected) => Err(PyErr::receiver_channel_closed()),
            }
        }
    }
}

pub(crate) mod private {
    use std::marker::Sized;

    /// Sealed traits support
    pub trait Sealed {}
    impl<T: ?Sized> Sealed for T {}
}

/// A wrapper around [`Py`] that implements [`Clone`] using [`Python::with_gil`].
#[repr(transparent)]
pub(crate) struct ClonePy<T>(pub Py<T>);

impl<T> Clone for ClonePy<T> {
    fn clone(&self) -> Self {
        Python::with_gil(|py| Self(self.0.clone_ref(py)))
    }
}

/// A Python module implemented in Rust. The name of this function must match
@@ -84,18 +199,18 @@ pub(crate) mod ext {
#[pymodule(name = "exo_pyo3_bindings")]
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // install logger
    /*
    use log::LevelFilter;
    #[allow(clippy::expect_used)]
    pyo3_log::Logger::default()
        .filter(LevelFilter::Warn)
        .install()
        .expect("logger install");
    */
    pyo3_log::init();

    // TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
    //       work with maturin, where the types generate correctly, in the right folder, without
    //       too many importing issues...
    ident_submodule(m)?;
    multiaddr_submodule(m)?;
    networking_submodule(m)?;

    // top-level constructs
    // TODO: ...

    Ok(())
}
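Because `main_module` registers the ident, multiaddr, and networking classes directly on the top-level module (see the TODO above), everything is importable from a flat namespace. A minimal smoke test, assuming the extension has been built and installed (e.g. with maturin):

import exo_pyo3_bindings as exo

key = exo.Keypair.generate_ed25519()      # staticmethod defined on PyKeypair
print(key.to_peer_id())                   # PeerId; __str__ prints the base58 form
addr = exo.Multiaddr.from_string("/ip4/127.0.0.1/tcp/4001")
print(addr)                               # round-trips via __str__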
@@ -1,194 +1,570 @@
use crate::ext::{ByteArrayExt as _, FutureExt as _, ResultExt as _};
use crate::identity::{PyEndpointId, PyKeypair};
use iroh::SecretKey;
use iroh::discovery::EndpointInfo;
use iroh::discovery::mdns::DiscoveryEvent;
use iroh_gossip::api::{ApiError, Event, GossipReceiver, GossipSender, Message};
use n0_future::{Stream, StreamExt as _};
use networking::ExoNet;
use pyo3::exceptions::{PyRuntimeError, PyStopAsyncIteration};
use pyo3::prelude::*;
#![allow(
    clippy::multiple_inherent_impl,
    clippy::unnecessary_wraps,
    clippy::unused_self,
    clippy::needless_pass_by_value
)]

use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::pyclass;
use crate::pylibp2p::ident::{PyKeypair, PyPeerId};
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use libp2p::{gossipsub, mdns};
use networking::discovery;
use networking::swarm::create_swarm;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::collections::BTreeSet;
use std::net::SocketAddr;
use std::pin::{Pin, pin};
use std::sync::{Arc, LazyLock};
use tokio::runtime::Runtime;
use tokio::sync::Mutex;
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
use std::net::IpAddr;
use tokio::sync::{Mutex, mpsc, oneshot};
use util::ext::VecExt as _;

#[allow(clippy::expect_used)]
static RUNTIME: LazyLock<Runtime> =
    LazyLock::new(|| Runtime::new().expect("Failed to create tokio runtime"));
mod exception {
    use pyo3::types::PyTuple;
    use pyo3::{PyErrArguments, exceptions::PyException, prelude::*};
    use pyo3_stub_gen::derive::*;

    #[gen_stub_pyclass]
    #[pyclass(name = "IpAddress")]
    #[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
    pub struct PyIpAddress {
        inner: SocketAddr,
    }
    #[gen_stub_pyclass]
    #[pyclass(frozen, extends=PyException, name="NoPeersSubscribedToTopicError")]
    pub struct PyNoPeersSubscribedToTopicError {}

    #[gen_stub_pymethods]
    #[pymethods]
    impl PyIpAddress {
        pub fn __str__(&self) -> String {
            self.inner.to_string()
    impl PyNoPeersSubscribedToTopicError {
        const MSG: &'static str = "\
            No peers are currently subscribed to receive messages on this topic. \
            Wait for peers to subscribe or check your network connectivity.";

        /// Creates a new [`PyErr`] of this type.
        ///
        /// [`PyErr`]: https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
        pub(crate) fn new_err() -> PyErr {
            PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
        }
    }

        pub fn ip_addr(&self) -> String {
            self.inner.ip().to_string()
    #[gen_stub_pymethods]
    #[pymethods]
    impl PyNoPeersSubscribedToTopicError {
        #[new]
        #[pyo3(signature = (*args))]
        #[allow(unused_variables)]
        pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
            Self {}
        }

        fn __repr__(&self) -> String {
            format!("NoPeersSubscribedToTopicError(\"{}\")", Self::MSG)
        }

        fn __str__(&self) -> String {
            Self::MSG.to_string()
        }
    }

        pub const fn port(&self) -> u16 {
            self.inner.port()
    #[gen_stub_pyclass]
    #[pyclass(frozen, extends=PyException, name="AllQueuesFullError")]
    pub struct PyAllQueuesFullError {}

    impl PyAllQueuesFullError {
        const MSG: &'static str =
            "All libp2p peers are unresponsive, resend the message or reconnect.";

        /// Creates a new [`PyErr`] of this type.
        ///
        /// [`PyErr`]: https://docs.rs/pyo3/latest/pyo3/struct.PyErr.html "PyErr in pyo3"
        pub(crate) fn new_err() -> PyErr {
            PyErr::new::<Self, _>(()) // TODO: check if this needs to be replaced???
        }
    }

        pub const fn zone_id(&self) -> Option<u32> {
            match self.inner {
                SocketAddr::V6(ip) => Some(ip.scope_id()),
                SocketAddr::V4(_) => None,
    #[gen_stub_pymethods]
    #[pymethods]
    impl PyAllQueuesFullError {
        #[new]
        #[pyo3(signature = (*args))]
        #[allow(unused_variables)]
        pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
            Self {}
        }

        fn __repr__(&self) -> String {
            format!("AllQueuesFullError(\"{}\")", Self::MSG)
        }

        fn __str__(&self) -> String {
            Self::MSG.to_string()
        }
    }
}
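These exceptions surface on the Python side under the names given in their `#[pyclass(name = ...)]` attributes. A short handling sketch (the publish method is the one defined further below):

from exo_pyo3_bindings import (
    NetworkingHandle,
    NoPeersSubscribedToTopicError,
    AllQueuesFullError,
)

async def publish_once(h: NetworkingHandle, topic: str, data: bytes) -> None:
    try:
        await h.gossipsub_publish(topic, data)
    except NoPeersSubscribedToTopicError:
        # nobody has subscribed to this topic yet; wait and let the caller retry
        print("no subscribers yet")
    except AllQueuesFullError:
        # every peer's send queue is saturated; resend or reconnect
        print("peers unresponsive")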
/// Connection or disconnection event discriminant type.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
#[derive(Debug, Clone, PartialEq)]
enum PyConnectionUpdateType {
    Connected = 0,
    Disconnected,
}

#[gen_stub_pyclass]
#[pyclass(name = "RustNetworkingHandle")]
pub struct PyNetworkingHandle {
    net: Arc<ExoNet>,
#[pyclass(frozen, name = "ConnectionUpdate")]
#[derive(Debug, Clone)]
struct PyConnectionUpdate {
    /// Whether this is a connection or disconnection event
    #[pyo3(get)]
    update_type: PyConnectionUpdateType,

    /// Identity of the peer that we have connected to or disconnected from.
    #[pyo3(get)]
    peer_id: PyPeerId,

    /// Remote connection's IPv4 address.
    #[pyo3(get)]
    remote_ipv4: String,

    /// Remote connection's TCP port.
    #[pyo3(get)]
    remote_tcp_port: u16,
}

enum ToTask {
    GossipsubSubscribe {
        topic: String,
        result_tx: oneshot::Sender<PyResult<bool>>,
    },
    GossipsubUnsubscribe {
        topic: String,
        result_tx: oneshot::Sender<bool>,
    },
    GossipsubPublish {
        topic: String,
        data: Vec<u8>,
        result_tx: oneshot::Sender<PyResult<MessageId>>,
    },
}
#[allow(clippy::enum_glob_use)]
async fn networking_task(
    mut swarm: networking::swarm::Swarm,
    mut to_task_rx: mpsc::Receiver<ToTask>,
    connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
    gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
) {
    use SwarmEvent::*;
    use ToTask::*;
    use mdns::Event::*;
    use networking::swarm::BehaviourEvent::*;

    log::info!("RUST: networking task started");

    loop {
        tokio::select! {
            message = to_task_rx.recv() => {
                // handle closed channel
                let Some(message) = message else {
                    log::info!("RUST: channel closed");
                    break;
                };

                // dispatch incoming messages
                match message {
                    GossipsubSubscribe { topic, result_tx } => {
                        // try to subscribe
                        let result = swarm.behaviour_mut()
                            .gossipsub.subscribe(&IdentTopic::new(topic));

                        // send response oneshot
                        if let Err(e) = result_tx.send(result.pyerr()) {
                            log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
                            continue;
                        }
                    }
                    GossipsubUnsubscribe { topic, result_tx } => {
                        // try to unsubscribe from the topic
                        let result = swarm.behaviour_mut()
                            .gossipsub.unsubscribe(&IdentTopic::new(topic));

                        // send response oneshot (or exit if connection closed)
                        if let Err(e) = result_tx.send(result) {
                            log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
                            continue;
                        }
                    }
                    GossipsubPublish { topic, data, result_tx } => {
                        // try to publish the data -> catch the NoPeersSubscribedToTopic error & convert it to the correct exception
                        let result = swarm.behaviour_mut().gossipsub.publish(
                            IdentTopic::new(topic), data);
                        let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
                            Err(exception::PyNoPeersSubscribedToTopicError::new_err())
                        } else if let Err(PublishError::AllQueuesFull(_)) = result {
                            Err(exception::PyAllQueuesFullError::new_err())
                        } else {
                            result.pyerr()
                        };

                        // send response oneshot (or exit if connection closed)
                        if let Err(e) = result_tx.send(pyresult) {
                            log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
                            continue;
                        }
                    }
                }
            }

            // architectural solution to this problem:
            // create a keep_alive behaviour whose job it is to dial peers discovered by mDNS (and drop them when expired)
            // -> it will emit TRUE connected/disconnected events consumable elsewhere
            //
            // gossipsub will feed off of dial attempts created by networking, and that will bootstrap its peer list;
            // then for actual communication it will dial those peers if need be
            swarm_event = swarm.select_next_some() => {
                match swarm_event {
                    Behaviour(Gossipsub(gossipsub::Event::Message {
                        message: Message {
                            topic,
                            data,
                            ..
                        },
                        ..
                    })) => {
                        // topic-ID is just the topic hash!!! (since we used identity hasher)
                        let message = (topic.into_string(), data);

                        // send incoming message to channel (or exit if connection closed)
                        if let Err(e) = gossipsub_message_tx.send(message).await {
                            log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
                            continue;
                        }
                    },
                    Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
                        // grab IPv4 string
                        let remote_ipv4 = match remote_ip {
                            IpAddr::V4(ip) => ip.to_string(),
                            IpAddr::V6(ip) => {
                                log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
                                continue;
                            }
                        };

                        // send connection event to channel (or exit if connection closed)
                        if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
                            update_type: PyConnectionUpdateType::Connected,
                            peer_id: PyPeerId(peer_id),
                            remote_ipv4,
                            remote_tcp_port,
                        }).await {
                            log::error!("RUST: could not send connection update since channel already closed: {e}");
                            continue;
                        }
                    },
                    Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
                        // grab IPv4 string
                        let remote_ipv4 = match remote_ip {
                            IpAddr::V4(ip) => ip.to_string(),
                            IpAddr::V6(ip) => {
                                log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
                                continue;
                            }
                        };

                        // send disconnection event to channel (or exit if connection closed)
                        if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
                            update_type: PyConnectionUpdateType::Disconnected,
                            peer_id: PyPeerId(peer_id),
                            remote_ipv4,
                            remote_tcp_port,
                        }).await {
                            log::error!("RUST: could not send connection update since channel already closed: {e}");
                            continue;
                        }
                    },
                    e => {
                        log::info!("RUST: other event {e:?}");
                    }
                }
            }
        }
    }

    log::info!("RUST: networking task stopped");
}
#[gen_stub_pyclass]
#[pyclass(name = "NetworkingHandle")]
#[derive(Debug)]
struct PyNetworkingHandle {
    // channels
    to_task_tx: Option<mpsc::Sender<ToTask>>,
    connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
    gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
}

impl Drop for PyNetworkingHandle {
    fn drop(&mut self) {
        // TODO: may or may not need to await a "kill-signal" oneshot channel message,
        //       to ensure that the networking task is done BEFORE exiting the clear function...
        //       but this may require the GIL?? and it may not be safe to take the GIL here??
        self.to_task_tx = None; // Using Option<T> as a trick to force the channel to be dropped
    }
}

#[allow(clippy::expect_used)]
impl PyNetworkingHandle {
    fn new(
        to_task_tx: mpsc::Sender<ToTask>,
        connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
        gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
    ) -> Self {
        Self {
            to_task_tx: Some(to_task_tx),
            connection_update_rx: Mutex::new(connection_update_rx),
            gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
        }
    }

    const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
        self.to_task_tx
            .as_ref()
            .expect("The sender should only be None after de-initialization.")
    }
}
#[gen_stub_pymethods]
#[pymethods]
impl PyNetworkingHandle {
    #[staticmethod]
    pub async fn create(identity: PyKeypair, namespace: String) -> PyResult<Self> {
        let loc: SecretKey = identity.0.clone();
        let net = Arc::new(
            RUNTIME
                .spawn(async move { ExoNet::init_iroh(loc, &namespace).await })
                .await
                // todo: pyerr better
                .pyerr()?
                .pyerr()?,
        );
        let cloned = Arc::clone(&net);
        RUNTIME.spawn(async move { cloned.start_auto_dialer().await });
    // NOTE: `async fn`s here that use `.await` will wrap the future in `.allow_threads_py()`
    //       immediately beforehand to release the interpreter.
    // SEE: https://pyo3.rs/v0.26.0/async-await.html#detaching-from-the-interpreter-across-await

        Ok(Self { net })
    // ---- Lifecycle management methods ----

    #[new]
    fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
        use pyo3_async_runtimes::tokio::get_runtime;

        // create communication channels
        let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
        let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
        let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);

        // get identity
        let identity = identity.borrow().0.clone();

        // create networking swarm (within tokio context!! or it crashes)
        let swarm = get_runtime()
            .block_on(async { create_swarm(identity) })
            .pyerr()?;

        // spawn tokio task running the networking logic
        get_runtime().spawn(async move {
            networking_task(
                swarm,
                to_task_rx,
                connection_update_tx,
                gossipsub_message_tx,
            )
            .await;
        });
        Ok(Self::new(
            to_task_tx,
            connection_update_rx,
            gossipsub_message_rx,
        ))
    }

    async fn subscribe(&self, topic: String) -> PyResult<(PySender, PyReceiver)> {
        let fut = self.net.subscribe(&topic);
        let (send, recv) = fut.await.pyerr()?;
        Ok((PySender { inner: send }, PyReceiver { inner: recv }))
    #[gen_stub(skip)]
    const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
        Ok(()) // This is needed purely so `__clear__` can work
    }

    async fn get_connection_receiver(&self) -> PyResult<PyConnectionReceiver> {
        let stream = self.net.connection_info().await;
        Ok(PyConnectionReceiver {
            inner: Mutex::new(Box::pin(stream)),
        })
    #[gen_stub(skip)]
    fn __clear__(&mut self) {
        // TODO: may or may not need to await a "kill-signal" oneshot channel message,
        //       to ensure that the networking task is done BEFORE exiting the clear function...
        //       but this may require the GIL?? and it may not be safe to take the GIL here??
        self.to_task_tx = None; // Using Option<T> as a trick to force the channel to be dropped
    }
}

#[gen_stub_pyclass]
#[pyclass(name = "RustConnectionMessage")]
pub struct PyConnectionMessage {
    #[pyo3(get)]
    pub endpoint_id: PyEndpointId,
    #[pyo3(get)]
    pub current_transport_addrs: Option<BTreeSet<PyIpAddress>>,
}
    // ---- Connection update receiver methods ----

#[gen_stub_pyclass]
#[pyclass(name = "RustSender")]
struct PySender {
    inner: GossipSender,
}

#[gen_stub_pymethods]
#[pymethods]
impl PySender {
    async fn send(&mut self, message: Py<PyBytes>) -> PyResult<()> {
        let bytes = Python::attach(|py| message.as_bytes(py).to_vec());
        let broadcast_fut = self.inner.broadcast(bytes.into());
        pin!(broadcast_fut).allow_threads_py().await?.pyerr()
    /// Receives the next `ConnectionUpdate` from networking.
    async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
        self.connection_update_rx
            .lock()
            .allow_threads_py() // allow-threads-aware async call
            .await
            .recv_py()
            .allow_threads_py() // allow-threads-aware async call
            .await
    }
}

#[gen_stub_pyclass]
#[pyclass(name = "RustReceiver")]
struct PyReceiver {
    inner: GossipReceiver,
}

#[gen_stub_pymethods]
#[pymethods]
impl PyReceiver {
    async fn receive(&mut self) -> PyResult<Py<PyBytes>> {
        loop {
            let next_fut = self.inner.next();
            match pin!(next_fut).allow_threads_py().await? {
                // Successful cases
                Some(Ok(Event::Received(Message { content, .. }))) => {
                    return Ok(content.to_vec().pybytes());
                }
                Some(Ok(other)) => log::info!("Dropping gossip event {other:?}"),
                None => return Err(PyStopAsyncIteration::new_err("")),
                Some(Err(ApiError::Closed { .. })) => {
                    return Err(PyStopAsyncIteration::new_err(""));
                }

                // Failure case
                Some(Err(other)) => {
                    return Err(PyRuntimeError::new_err(other.to_string()));
                }
            }
        }
    /// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
    ///
    /// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
    /// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
    /// will sleep until a `ConnectionUpdate` is sent.
    async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
        self.connection_update_rx
            .lock()
            .allow_threads_py() // allow-threads-aware async call
            .await
            .recv_many_py(limit)
            .allow_threads_py() // allow-threads-aware async call
            .await
    }
}

#[gen_stub_pyclass]
#[pyclass(name = "RustConnectionReceiver")]
struct PyConnectionReceiver {
    inner: Mutex<Pin<Box<dyn Stream<Item = DiscoveryEvent> + Send>>>,
}
    // TODO: right now this blocks the main thread if anything else is awaiting the channel (because it's behind a mutex),
    //       so it's too dangerous to expose just yet. Figure out better semantics for handling this,
    //       so things don't randomly block.
    // /// Tries to receive the next `ConnectionUpdate` from networking.
    // fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
    //     self.connection_update_rx.blocking_lock().try_recv_py()
    // }
    //
    // /// Checks if the `ConnectionUpdate` channel is empty.
    // fn connection_update_is_empty(&self) -> bool {
    //     self.connection_update_rx.blocking_lock().is_empty()
    // }
    //
    // /// Returns the number of `ConnectionUpdate`s in the channel.
    // fn connection_update_len(&self) -> usize {
    //     self.connection_update_rx.blocking_lock().len()
    // }

#[gen_stub_pymethods]
#[pymethods]
impl PyConnectionReceiver {
    async fn receive(&mut self) -> PyResult<PyConnectionMessage> {
        // Errors on trying to receive twice - which is a dev error. This could just block the
        // async task, but I want the error to persist
        let mut lock = self.inner.try_lock().pyerr()?;
        match lock.next().allow_threads_py().await? {
            // Successful cases
            Some(DiscoveryEvent::Discovered {
                endpoint_info: EndpointInfo { endpoint_id, data },
                ..
            }) => Ok(PyConnectionMessage {
                endpoint_id: endpoint_id.into(),
                current_transport_addrs: Some(
                    data.ip_addrs()
                        .map(|inner| PyIpAddress { inner: *inner })
                        .collect(),
                ),
            }),
            Some(DiscoveryEvent::Expired { endpoint_id }) => Ok(PyConnectionMessage {
                endpoint_id: endpoint_id.into(),
                current_transport_addrs: None,
            }),
            // Failure case
            None => Err(PyStopAsyncIteration::new_err("")),
        }
    // ---- Gossipsub management methods ----

    /// Subscribe to a `GossipSub` topic.
    ///
    /// Returns `True` if the subscription worked. Returns `False` if we were already subscribed.
    async fn gossipsub_subscribe(&self, topic: String) -> PyResult<bool> {
        let (tx, rx) = oneshot::channel();

        // send off request to subscribe
        self.to_task_tx()
            .send_py(ToTask::GossipsubSubscribe {
                topic,
                result_tx: tx,
            })
            .allow_threads_py() // allow-threads-aware async call
            .await?;

        // wait for response & return any errors
        rx.allow_threads_py() // allow-threads-aware async call
            .await
            .map_err(|_| PyErr::receiver_channel_closed())?
    }

    /// Unsubscribes from a `GossipSub` topic.
    ///
    /// Returns `True` if we were subscribed to this topic. Returns `False` if we were not subscribed.
    async fn gossipsub_unsubscribe(&self, topic: String) -> PyResult<bool> {
        let (tx, rx) = oneshot::channel();

        // send off request to unsubscribe
        self.to_task_tx()
            .send_py(ToTask::GossipsubUnsubscribe {
                topic,
                result_tx: tx,
            })
            .allow_threads_py() // allow-threads-aware async call
            .await?;

        // wait for response & convert any errors
        rx.allow_threads_py() // allow-threads-aware async call
            .await
            .map_err(|_| PyErr::receiver_channel_closed())
    }

    /// Publishes a message to a `GossipSub` topic.
    ///
    /// If no peers are found that subscribe to this topic, throws a `NoPeersSubscribedToTopicError` exception.
    async fn gossipsub_publish(&self, topic: String, data: Py<PyBytes>) -> PyResult<()> {
        let (tx, rx) = oneshot::channel();

        // send off request to publish
        let data = Python::with_gil(|py| Vec::from(data.as_bytes(py)));
        self.to_task_tx()
            .send_py(ToTask::GossipsubPublish {
                topic,
                data,
                result_tx: tx,
            })
            .allow_threads_py() // allow-threads-aware async call
            .await?;

        // wait for response & return any errors => ignore the MessageId for now!!!
        let _ = rx
            .allow_threads_py() // allow-threads-aware async call
            .await
            .map_err(|_| PyErr::receiver_channel_closed())??;
        Ok(())
    }

    // ---- Gossipsub message receiver methods ----

    /// Receives the next message from the `GossipSub` network.
    async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
        self.gossipsub_message_rx
            .lock()
            .allow_threads_py() // allow-threads-aware async call
            .await
            .recv_py()
            .allow_threads_py() // allow-threads-aware async call
            .await
            .map(|(t, d)| (t, d.pybytes()))
    }

    /// Receives at most `limit` messages from the `GossipSub` network and returns them.
    ///
    /// For `limit = 0`, an empty collection of messages will be returned immediately.
    /// For `limit > 0`, if there are no messages in the channel's queue this method
    /// will sleep until a message is sent.
    async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
        Ok(self
            .gossipsub_message_rx
            .lock()
            .allow_threads_py() // allow-threads-aware async call
            .await
            .recv_many_py(limit)
            .allow_threads_py() // allow-threads-aware async call
            .await?
            .map(|(t, d)| (t, d.pybytes())))
    }

    // TODO: right now this blocks the main thread if anything else is awaiting the channel (because it's behind a mutex),
    //       so it's too dangerous to expose just yet. Figure out better semantics for handling this,
    //       so things don't randomly block.
    // /// Tries to receive the next message from the `GossipSub` network.
    // fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
    //     Ok(self
    //         .gossipsub_message_rx
    //         .blocking_lock()
    //         .try_recv_py()?
    //         .map(|(t, d)| (t, d.pybytes())))
    // }
    //
    // /// Checks if the `GossipSub` message channel is empty.
    // fn gossipsub_is_empty(&self) -> bool {
    //     self.gossipsub_message_rx.blocking_lock().is_empty()
    // }
    //
    // /// Returns the number of `GossipSub` messages in the channel.
    // fn gossipsub_len(&self) -> usize {
    //     self.gossipsub_message_rx.blocking_lock().len()
    // }
}
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyConnectionMessage>()?;
    m.add_class::<PyReceiver>()?;
    m.add_class::<PySender>()?;
    m.add_class::<PyConnectionReceiver>()?;
    m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
    m.add_class::<exception::PyAllQueuesFullError>()?;

    m.add_class::<PyConnectionUpdateType>()?;
    m.add_class::<PyConnectionUpdate>()?;
    m.add_class::<PyConnectionUpdateType>()?;
    m.add_class::<PyNetworkingHandle>()?;

    Ok(())
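Taken together, the handle above gives Python a small async API: construct, subscribe, publish, and drain two receive channels. A usage sketch (method and class names are the ones defined above; the asyncio integration via pyo3-async-runtimes is assumed to be wired up by the binding):

import asyncio
from exo_pyo3_bindings import Keypair, NetworkingHandle

async def main() -> None:
    h = NetworkingHandle(Keypair.generate_ed25519())

    await h.gossipsub_subscribe("topic")   # True on first subscribe

    async def updates() -> None:
        while True:
            u = await h.connection_update_recv()
            print(u.update_type, u.peer_id, u.remote_ipv4, u.remote_tcp_port)

    async def messages() -> None:
        while True:
            topic, data = await h.gossipsub_recv()
            print(topic, data)

    # drain connection updates and gossip messages concurrently
    asyncio.create_task(updates())
    asyncio.create_task(messages())

    # raises NoPeersSubscribedToTopicError if nobody else has subscribed yet
    await h.gossipsub_publish("topic", b"hello")

asyncio.run(main())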
rust/exo_pyo3_bindings/src/pylibp2p/ident.rs (new file, 159 lines)
@@ -0,0 +1,159 @@
use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};

/// Identity keypair of a node.
#[gen_stub_pyclass]
#[pyclass(name = "Keypair", frozen)]
#[repr(transparent)]
pub struct PyKeypair(pub Keypair);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyKeypair {
    /// Generate a new Ed25519 keypair.
    #[staticmethod]
    fn generate_ed25519() -> Self {
        Self(Keypair::generate_ed25519())
    }

    /// Generate a new ECDSA keypair.
    #[staticmethod]
    fn generate_ecdsa() -> Self {
        Self(Keypair::generate_ecdsa())
    }

    /// Generate a new Secp256k1 keypair.
    #[staticmethod]
    fn generate_secp256k1() -> Self {
        Self(Keypair::generate_secp256k1())
    }

    /// Decode a private key from a protobuf structure and parse it as a `Keypair`.
    #[staticmethod]
    fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
    }

    /// Decode a keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
    /// format (i.e. unencrypted) as defined in [RFC5208].
    ///
    /// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
    #[staticmethod]
    fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
    }

    /// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
    /// structure as defined in [RFC5915].
    ///
    /// [RFC5915]: https://tools.ietf.org/html/rfc5915
    #[staticmethod]
    fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
    }

    #[staticmethod]
    fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let mut bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
    }

    /// Encode a private key as protobuf structure.
    fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
        let bytes = self.0.to_protobuf_encoding().pyerr()?;
        Ok(PyBytes::new(py, &bytes))
    }

    /// Convert the `Keypair` into the corresponding `PeerId`.
    fn to_peer_id(&self) -> PyPeerId {
        PyPeerId(self.0.public().to_peer_id())
    }

    // /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
    // #[gen_stub(skip)]
    // #[new]
    // fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
    //     Self::from_protobuf_encoding(bytes)
    // }
    //
    // #[gen_stub(skip)]
    // fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
    //     *self = Self::from_protobuf_encoding(state)?;
    //     Ok(())
    // }
    //
    // #[gen_stub(skip)]
    // fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
    //     self.to_protobuf_encoding(py)
    // }
    //
    // #[gen_stub(skip)]
    // pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
    //     Ok((self.to_protobuf_encoding(py)?,))
    // }
}

/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
    /// Generates a random peer ID from a cryptographically secure PRNG.
    ///
    /// This is useful for randomly walking on a DHT, or for testing purposes.
    #[staticmethod]
    fn random() -> Self {
        Self(PeerId::random())
    }

    /// Parses a `PeerId` from bytes.
    #[staticmethod]
    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
    }

    /// Returns a raw bytes representation of this `PeerId`.
    fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
        let bytes = self.0.to_bytes();
        PyBytes::new(py, &bytes)
    }

    /// Returns a base-58 encoded string of this `PeerId`.
    fn to_base58(&self) -> String {
        self.0.to_base58()
    }

    fn __repr__(&self) -> String {
        format!("PeerId({})", self.to_base58())
    }

    fn __str__(&self) -> String {
        self.to_base58()
    }
}

pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyKeypair>()?;
    m.add_class::<PyPeerId>()?;

    Ok(())
}
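A short sketch of this `Keypair`/`PeerId` surface from Python (method names as defined above; the identity round-trips through the protobuf encoding, and `PeerId` has no `__eq__` exposed here, so the comparison goes through its base58 string):

from exo_pyo3_bindings import Keypair

key = Keypair.generate_ed25519()
blob = key.to_protobuf_encoding()                 # bytes
restored = Keypair.from_protobuf_encoding(blob)   # same identity
assert str(key.to_peer_id()) == str(restored.to_peer_id())
print(key.to_peer_id().to_base58())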
rust/exo_pyo3_bindings/src/pylibp2p/mod.rs (new file, 8 lines)
@@ -0,0 +1,8 @@
//! A module for exposing Rust's libp2p datatypes over Pyo3
//!
//! TODO: right now we are coupled to libp2p's identity, but eventually we want to create our own
//! independent identity type of some kind or another. This may require handshaking.
//!

pub mod ident;
pub mod multiaddr;
rust/exo_pyo3_bindings/src/pylibp2p/multiaddr.rs (new file, 81 lines)
@@ -0,0 +1,81 @@
use crate::ext::ResultExt as _;
use libp2p::Multiaddr;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::str::FromStr as _;

/// Representation of a Multiaddr.
#[gen_stub_pyclass]
#[pyclass(name = "Multiaddr", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyMultiaddr(pub Multiaddr);

#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyMultiaddr {
    /// Create a new, empty multiaddress.
    #[staticmethod]
    fn empty() -> Self {
        Self(Multiaddr::empty())
    }

    /// Create a new, empty multiaddress with the given capacity.
    #[staticmethod]
    fn with_capacity(n: usize) -> Self {
        Self(Multiaddr::with_capacity(n))
    }

    /// Parse a `Multiaddr` value from its byte slice representation.
    #[staticmethod]
    fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
        let bytes = Vec::from(bytes.as_bytes());
        Ok(Self(Multiaddr::try_from(bytes).pyerr()?))
    }

    /// Parse a `Multiaddr` value from its string representation.
    #[staticmethod]
    fn from_string(string: String) -> PyResult<Self> {
        Ok(Self(Multiaddr::from_str(&string).pyerr()?))
    }

    /// Return the length in bytes of this multiaddress.
    fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns true if the length of this multiaddress is 0.
    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Return a copy of this [`Multiaddr`]'s byte representation.
    fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
        let bytes = self.0.to_vec();
        PyBytes::new(py, &bytes)
    }

    /// Convert a Multiaddr to a string.
    fn to_string(&self) -> String {
        self.0.to_string()
    }

    #[gen_stub(skip)]
    fn __repr__(&self) -> String {
        format!("Multiaddr({})", self.0)
    }

    #[gen_stub(skip)]
    fn __str__(&self) -> String {
        self.to_string()
    }
}

pub fn multiaddr_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyMultiaddr>()?;

    Ok(())
}
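And the `Multiaddr` wrapper from Python (a sketch; note that `len` and `is_empty` are exposed as ordinary methods here rather than as `__len__`/`__bool__`):

from exo_pyo3_bindings import Multiaddr

a = Multiaddr.from_string("/ip4/192.168.1.2/tcp/4001")
assert not a.is_empty()
# bytes and string representations round-trip through the same address
assert Multiaddr.from_bytes(a.to_bytes()).to_string() == a.to_string()
print(a.len(), a)   # byte length, then the canonical string form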
rust/exo_pyo3_bindings/tests/dummy.rs (new file, 54 lines)
@@ -0,0 +1,54 @@
#[cfg(test)]
mod tests {
    use core::mem::drop;
    use core::option::Option::Some;
    use core::time::Duration;
    use tokio;
    use tokio::sync::mpsc;

    #[tokio::test]
    async fn test_drop_channel() {
        struct Ping;

        let (tx, mut rx) = mpsc::channel::<Ping>(10);

        let _ = tokio::spawn(async move {
            println!("TASK: entered");

            loop {
                tokio::select! {
                    result = rx.recv() => {
                        match result {
                            Some(_) => {
                                println!("TASK: pinged");
                            }
                            None => {
                                println!("TASK: closing channel");
                                break;
                            }
                        }
                    }
                    _ = tokio::time::sleep(Duration::from_secs_f32(0.1)) => {
                        println!("TASK: heartbeat");
                    }
                }
            }

            println!("TASK: exited");
        });

        let tx2 = tx.clone();

        tokio::time::sleep(Duration::from_secs_f32(0.11)).await;

        tx.send(Ping).await.expect("Should not fail");
        drop(tx);

        tokio::time::sleep(Duration::from_secs_f32(0.11)).await;

        tx2.send(Ping).await.expect("Should not fail");
        drop(tx2);

        tokio::time::sleep(Duration::from_secs_f32(0.11)).await;
    }
}
@@ -1,47 +1,34 @@
import asyncio

import pytest
from exo_pyo3_bindings import (
    Keypair,
    RustNetworkingHandle,
    RustReceiver,
    RustConnectionReceiver,
)
from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError


@pytest.mark.asyncio
async def test_sleep_on_multiple_items() -> None:
    print("PYTHON: starting handle")
    s_h = await RustNetworkingHandle.create(Keypair.generate_ed25519(), "test")
    r_h = await RustNetworkingHandle.create(Keypair.generate_ed25519(), "test")
    h = NetworkingHandle(Keypair.generate_ed25519())

    await asyncio.sleep(1)

    cm = await r_h.get_connection_receiver()

    _, recv = await r_h.subscribe("topic")
    send, _ = await s_h.subscribe("topic")

    ct = asyncio.create_task(_await_cons(cm))
    mt = asyncio.create_task(_await_msg(recv))
    ct = asyncio.create_task(_await_cons(h))
    mt = asyncio.create_task(_await_msg(h))

    # sleep for 4 ticks
    for i in range(4):
        await asyncio.sleep(1)

    await send.send(b"something or other")

    await ct
    await mt
    try:
        await h.gossipsub_publish("topic", b"something or other")
    except NoPeersSubscribedToTopicError as e:
        print("caught it", e)


async def _await_cons(h: RustConnectionReceiver):
async def _await_cons(h: NetworkingHandle):
    while True:
        c = await h.receive()
        c = await h.connection_update_recv()
        print(f"PYTHON: connection update: {c}")


async def _await_msg(r: RustReceiver):
async def _await_msg(h: NetworkingHandle):
    while True:
        m = await r.receive()
        m = await h.gossipsub_recv()
        print(f"PYTHON: message: {m}")
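For reference, the new side of this test collapses to the following self-contained sketch (assembled from the replacement lines above; `pytest-asyncio` drives the coroutine):

import asyncio

import pytest
from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError


@pytest.mark.asyncio
async def test_sleep_on_multiple_items() -> None:
    print("PYTHON: starting handle")
    h = NetworkingHandle(Keypair.generate_ed25519())

    # background drains for connection updates and gossip messages
    ct = asyncio.create_task(_await_cons(h))
    mt = asyncio.create_task(_await_msg(h))

    # sleep for 4 ticks
    for i in range(4):
        await asyncio.sleep(1)

    try:
        await h.gossipsub_publish("topic", b"something or other")
    except NoPeersSubscribedToTopicError as e:
        # expected when the node is alone on the network
        print("caught it", e)


async def _await_cons(h: NetworkingHandle):
    while True:
        c = await h.connection_update_recv()
        print(f"PYTHON: connection update: {c}")


async def _await_msg(h: NetworkingHandle):
    while True:
        m = await h.gossipsub_recv()
        print(f"PYTHON: message: {m}")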
@@ -1,18 +1,44 @@
[package]
name = "networking"
version.workspace = true
edition.workspace = true
version = { workspace = true }
edition = { workspace = true }
publish = false

[dependencies]
blake3 = { workspace = true, features = ["neon", "rayon"] }
iroh = { workspace = true, features = ["discovery-local-network"] }
iroh-gossip.workspace = true
log.workspace = true
n0-error.workspace = true
n0-future.workspace = true
rand.workspace = true
tokio = { workspace = true, features = ["full"] }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
[lib]
doctest = false
name = "networking"
path = "src/lib.rs"

[lints]
workspace = true

[dependencies]
# datastructures
either = { workspace = true }

# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }

# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }

# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }

# tracing/logging
log = { workspace = true }

# networking
libp2p = { workspace = true, features = ["full"] }
@@ -1,85 +1,74 @@
#![allow(clippy::expect_used, clippy::unwrap_used, clippy::cargo)]

use std::sync::Arc;
use std::time::Duration;

use iroh::SecretKey;
use iroh_gossip::api::{Event, Message};
use n0_future::StreamExt as _;
use networking::ExoNet;
use tokio::time::sleep;
use tokio::{io, io::AsyncBufReadExt as _};
use futures::stream::StreamExt as _;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter;

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt()
    let _ = tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
        .try_init()
        .expect("logger");
        .try_init();

    // Configure swarm
    let net = Arc::new(
        ExoNet::init_iroh(SecretKey::generate(&mut rand::rng()), "chatroom")
            .await
            .expect("iroh init shouldn't fail"),
    );
    let innet = Arc::clone(&net);
    let jh1 = tokio::spawn(async move { innet.start_auto_dialer().await });

    while net.known_peers.lock().await.is_empty() {
        sleep(Duration::from_secs(1)).await;
    }
    let mut swarm =
        swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");

    // Create a Gossipsub topic & subscribe
    let (send, mut recv) = net
        .subscribe("chatting")
        .await
        .expect("topic shouldn't fail");
    let topic = gossipsub::IdentTopic::new("test-net");
    swarm
        .behaviour_mut()
        .gossipsub
        .subscribe(&topic)
        .expect("Subscribing to topic failed");

    // Read full lines from stdin
    let mut stdin = io::BufReader::new(io::stdin()).lines();
    println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");

    let jh2 = tokio::spawn(async move {
        loop {
            if let Ok(Some(line)) = stdin.next_line().await
                && let Err(e) = send.broadcast(line.into()).await
            {
                println!("Publish error: {e:?}");
    // Kick it off
    loop {
        select! {
            // on gossipsub outgoing
            Ok(Some(line)) = stdin.next_line() => {
                if let Err(e) = swarm
                    .behaviour_mut().gossipsub
                    .publish(topic.clone(), line.as_bytes()) {
                    println!("Publish error: {e:?}");
                }
            }
        }
    });

    tokio::spawn(async move {
        while let Some(Ok(event)) = recv.next().await {
            match event {
            event = swarm.select_next_some() => match event {
                // on gossipsub incoming
                Event::Received(Message {
                    content,
                    delivered_from,
                    ..
                }) => println!(
                    "\n\nGot message: '{}' from peer: {delivered_from}\n\n",
                    String::from_utf8_lossy(&content),
                ),
                SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
                    propagation_source: peer_id,
                    message_id: id,
                    message,
                })) => println!(
                    "\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
                    String::from_utf8_lossy(&message.data),
                ),

                // on discovery
                Event::NeighborUp(peer_id) => {
                    println!("\n\nConnected to: {peer_id}\n\n");
                }
                Event::NeighborDown(peer_id) => {
                    eprintln!("\n\nDisconnected from: {peer_id}\n\n");
                }
                Event::Lagged => {
                    eprintln!("\n\nLagged\n\n");
                SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) => match e {
                    discovery::Event::ConnectionEstablished {
                        peer_id, connection_id, remote_ip, remote_tcp_port
                    } => {
                        println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
                    }
                    discovery::Event::ConnectionClosed {
                        peer_id, connection_id, remote_ip, remote_tcp_port
                    } => {
                        eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
                    }
                }

                // ignore outgoing errors: those are normal
                e@SwarmEvent::OutgoingConnectionError { .. } => { log::debug!("Outgoing connection error: {e:?}"); }

                // otherwise log any other event
                e => { log::info!("Other event {e:?}"); }
            }
        }
    })
    .await
    .unwrap();
    jh1.await.unwrap();
    jh2.await.unwrap();
}
}
rust/networking/examples/chatroom_manual.rs (new file, 127 lines)
@@ -0,0 +1,127 @@
// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use futures::stream::StreamExt;
use libp2p::{
    gossipsub, mdns, noise,
    swarm::{NetworkBehaviour, SwarmEvent},
    tcp, yamux,
};
use std::time::Duration;
use std::{error::Error, hash::Hash};
use tokio::{io, io::AsyncBufReadExt, select};
use tracing_subscriber::EnvFilter;

// We create a custom network behaviour that combines Gossipsub and Mdns.
#[derive(NetworkBehaviour)]
struct MyBehaviour {
    gossipsub: gossipsub::Behaviour,
    mdns: mdns::tokio::Behaviour,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let _ = tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .try_init();

    let mut swarm = libp2p::SwarmBuilder::with_new_identity()
        .with_tokio()
        .with_tcp(
            tcp::Config::default(),
            noise::Config::new,
            yamux::Config::default,
        )?
        .with_behaviour(|key| {
            // Set a custom gossipsub configuration
            let gossipsub_config = gossipsub::ConfigBuilder::default()
                .heartbeat_interval(Duration::from_secs(10))
                .validation_mode(gossipsub::ValidationMode::Strict) // This sets the kind of message validation. The default is Strict (enforce message signing)
                .build()
                .map_err(io::Error::other)?; // Temporary hack because `build` does not return a proper `std::error::Error`.

            // build a gossipsub network behaviour
            let gossipsub = gossipsub::Behaviour::new(
                gossipsub::MessageAuthenticity::Signed(key.clone()),
                gossipsub_config,
            )?;

            let mdns =
                mdns::tokio::Behaviour::new(mdns::Config::default(), key.public().to_peer_id())?;
            Ok(MyBehaviour { gossipsub, mdns })
        })?
        .build();

    println!("Running swarm with identity {}", swarm.local_peer_id());

    // Create a Gossipsub topic
    let topic = gossipsub::IdentTopic::new("test-net");
    // subscribes to our topic
    swarm.behaviour_mut().gossipsub.subscribe(&topic)?;

    // Read full lines from stdin
    let mut stdin = io::BufReader::new(io::stdin()).lines();

    // Listen on all interfaces and whatever port the OS assigns
    swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;

    println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");

    // Kick it off
    loop {
        select! {
            Ok(Some(line)) = stdin.next_line() => {
                if let Err(e) = swarm
                    .behaviour_mut().gossipsub
                    .publish(topic.clone(), line.as_bytes()) {
                    println!("Publish error: {e:?}");
                }
            }
            event = swarm.select_next_some() => match event {
                SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Discovered(list))) => {
                    for (peer_id, multiaddr) in list {
                        println!("mDNS discovered a new peer: {peer_id} on {multiaddr}");
                        swarm.behaviour_mut().gossipsub.add_explicit_peer(&peer_id);
                    }
                },
                SwarmEvent::Behaviour(MyBehaviourEvent::Mdns(mdns::Event::Expired(list))) => {
                    for (peer_id, multiaddr) in list {
                        println!("mDNS discovered peer has expired: {peer_id} on {multiaddr}");
                        swarm.behaviour_mut().gossipsub.remove_explicit_peer(&peer_id);
                    }
                },
                SwarmEvent::Behaviour(MyBehaviourEvent::Gossipsub(gossipsub::Event::Message {
                    propagation_source: peer_id,
                    message_id: id,
                    message,
                })) => println!(
                    "Got message: '{}' with id: {id} from peer: {peer_id}",
                    String::from_utf8_lossy(&message.data),
                ),
                SwarmEvent::NewListenAddr { address, .. } => {
                    println!("Local node is listening on {address}");
                }
                e => {
                    println!("Other swarm event: {e:?}");
                }
            }
        }
    }
}
@@ -1,30 +0,0 @@
#![allow(clippy::cargo, clippy::unwrap_used)]
use iroh::{SecretKey, endpoint_info::EndpointIdExt as _};
use n0_future::StreamExt as _;
use networking::ExoNet;

// Launch a mock version of iroh for testing purposes
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .init();

    let key = SecretKey::generate(&mut rand::rng());
    let dbg_key = key.public().to_z32();
    println!("Starting with pk: {dbg_key}");
    let net = ExoNet::init_iroh(key, "").await.unwrap();

    let mut conn_info = net.connection_info().await;

    let task = tokio::task::spawn(async move {
        println!("Inner task started!");
        loop {
            dbg!(conn_info.next().await);
        }
    });

    println!("Task started!");

    task.await.unwrap();
}
rust/networking/src/RESEARCH_NOTES.txt (new file, 44 lines)
@@ -0,0 +1,44 @@
https://github.com/ml-explore/mlx/commit/3fe98bacc7640d857acf3539f1d21b47a32e5609
^raw sockets distributed -> `<net/ndrv.h>` -> https://newosxbook.com/code/xnu-3247.1.106/bsd/net/ndrv.h.auto.html
--> header file for a networking component found in the macOS kernel (XNU) that defines structures for network device driver registration, specifically the ndrv_demux_desc and ndrv_protocol_desc structures used for demultiplexing protocol data at the network interface level. It specifies how to describe protocol data, such as an Ethernet type or a SNAP header, and how to associate these descriptions with a specific protocol family to receive matching packets.
--> Used to bind an NDRV socket so that packets that match given protocol demux descriptions can be received.
--> An NDRV socket is a special kind of socket in the Darwin/macOS operating system's XNU kernel, used for low-level network packet manipulation and binding to specific protocols for packet processing. It allows user-space applications or drivers to directly write Layer 2 (L2) network packets or interact with the network stack at a lower level, often by binding to protocol descriptors like the ndrv_protocol_desc. This type of socket is used for functions such as capturing and injecting packets, especially in network infrastructure software like routers or for kernel-level network monitoring and security tools.
--> also called PF_NDRV sockets --> https://newosxbook.com/bonus/vol1ch16.html
----> they are conceptually similar to https://scapy.disruptivelabs.in/networking/socket-interface PF_RAW or PF_PACKET

https://stackoverflow.com/questions/17169298/af-packet-on-osx
^AF_PACKET duplicates the packets as soon as it receives them from the physical layer (for incoming packets) or just before sending them out to the physical layer (for outgoing packets). -> this is on Linux only
^it doesn't exist on OS X so you can use /dev/bpfX (Berkeley Packet Filter) for sniffing

https://www.unix.com/man_page/mojave/4/ip/
^OS X manpages for IP

https://developer.apple.com/documentation/kernel/implementing_drivers_system_extensions_and_kexts
^driver kit, system extensions & kexts for macOS

----

To set up a Linux system to use a Thunderbolt connection as a network device, connect the two computers with a Thunderbolt cable, load the thunderbolt-net kernel module (usually automatic but modprobe is an option for manual loading), and then the operating system will create virtual Ethernet interfaces (e.g., thunderbolt0) for networking. You can then use standard tools like ifconfig or your desktop environment's network manager to configure these new interfaces for a link-local network.
--> https://gist.github.com/geosp/80fbd39e617b7d1d9421683df4ea224a
----> here is a guide on how to set up thunderbolt-ethernet on linux
----> I may be able to steal the thunderbolt-net code ideas to implement a kernel module for MacOS

https://chatgpt.com/s/t_68af8e41a8548191993281a014f846a7
^GPT discussion about making socket interface

https://chatgpt.com/s/t_68afb798a85c8191973c02a0fa7a48a3 --> link-local address?
https://chatgpt.com/s/t_68afb02987e08191b2b0044d3667ece2
^GPT discussion about accessing TB on MacOS low level interactions

--------------------------------

https://www.intel.com/content/www/us/en/support/articles/000098893/software.html
^Thunderbolt Share & Thunderbolt Networking Mode => intel's equivalent of thunderbolt bridge

---------------------------------

https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/
-->fake ethernet devices on MacOS -> omg??? we can detect thunderbolt bridge, then bind to it, then re-expose it as fake ethernet??
-->ps: https://chatgpt.com/s/t_68afb2b25fb881919526763fb5d7359c, AF/PF_NDRV are one and the same!!!
-->https://github.com/zerotier/ZeroTierOne/blob/dev/osdep/MacEthernetTapAgent.c
|
||||
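
To make the NDRV notes concrete, here is a minimal sketch of opening a PF_NDRV socket from Rust via the libc crate. The `AF_NDRV = 27` constant is an assumption taken from Darwin's <sys/socket.h> (the libc crate does not export it), the function name is hypothetical, and this only works on macOS with sufficient privileges:

    use std::io;

    // Assumed value from Darwin's <sys/socket.h>; worth re-checking against the SDK.
    const AF_NDRV: libc::c_int = 27;

    fn open_ndrv_socket() -> io::Result<libc::c_int> {
        // SAFETY: a plain socket(2) call; we only inspect the returned fd.
        let fd = unsafe { libc::socket(AF_NDRV, libc::SOCK_RAW, 0) };
        if fd < 0 {
            return Err(io::Error::last_os_error());
        }
        // Binding to a specific interface (e.g. a Thunderbolt bridge member
        // like "en5") would follow via bind(2) with a sockaddr_ndrv; omitted here.
        Ok(fd)
    }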
383  rust/networking/src/discovery.rs  Normal file
@@ -0,0 +1,383 @@
use crate::ext::MultiaddrExt;
use crate::keep_alive;
use delegate::delegate;
use either::Either;
use futures::FutureExt;
use futures_timer::Delay;
use libp2p::core::transport::PortUse;
use libp2p::core::{ConnectedPoint, Endpoint};
use libp2p::swarm::behaviour::ConnectionEstablished;
use libp2p::swarm::dial_opts::DialOpts;
use libp2p::swarm::{
    CloseConnection, ConnectionClosed, ConnectionDenied, ConnectionHandler,
    ConnectionHandlerSelect, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent,
    THandlerOutEvent, ToSwarm, dummy,
};
use libp2p::{Multiaddr, PeerId, identity, mdns};
use std::collections::{BTreeSet, HashMap};
use std::convert::Infallible;
use std::io;
use std::net::IpAddr;
use std::task::{Context, Poll};
use std::time::Duration;
use util::wakerdeque::WakerDeque;

const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);

mod managed {
    use libp2p::swarm::NetworkBehaviour;
    use libp2p::{identity, mdns, ping};
    use std::io;
    use std::time::Duration;

    const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
    const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
    const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
    const PING_INTERVAL: Duration = Duration::from_millis(2_500);

    #[derive(NetworkBehaviour)]
    pub struct Behaviour {
        mdns: mdns::tokio::Behaviour,
        ping: ping::Behaviour,
    }

    impl Behaviour {
        pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
            Ok(Self {
                mdns: mdns_behaviour(keypair)?,
                ping: ping_behaviour(),
            })
        }
    }

    fn mdns_behaviour(keypair: &identity::Keypair) -> io::Result<mdns::tokio::Behaviour> {
        use mdns::{Config, tokio};

        // mDNS config => IPv6 disabled for now (see TODO below)
        let mdns_config = Config {
            ttl: MDNS_RECORD_TTL,
            query_interval: MDNS_QUERY_INTERVAL,

            // enable_ipv6: true, // TODO: for some reason, TCP+mDNS don't work well with ipv6?? figure out how to make work
            ..Default::default()
        };

        let mdns_behaviour = tokio::Behaviour::new(mdns_config, keypair.public().to_peer_id());
        Ok(mdns_behaviour?)
    }

    fn ping_behaviour() -> ping::Behaviour {
        ping::Behaviour::new(
            ping::Config::new()
                .with_timeout(PING_TIMEOUT)
                .with_interval(PING_INTERVAL),
        )
    }
}

/// Events for when a listening connection is truly established and truly closed.
#[derive(Debug, Clone)]
pub enum Event {
    ConnectionEstablished {
        peer_id: PeerId,
        connection_id: ConnectionId,
        remote_ip: IpAddr,
        remote_tcp_port: u16,
    },
    ConnectionClosed {
        peer_id: PeerId,
        connection_id: ConnectionId,
        remote_ip: IpAddr,
        remote_tcp_port: u16,
    },
}

/// Discovery behavior that wraps mDNS to produce truly discovered, durable peer-connections.
///
/// The behaviour operates as follows:
/// 1) All true (listening) connections/disconnections are tracked, emitting corresponding events
///    to the swarm.
/// 2) mDNS discovered/expired peers are tracked; discovered but not connected peers are dialed
///    immediately, and expired but connected peers are disconnected from immediately.
/// 3) Every fixed interval: discovered but not connected peers are dialed, and expired but
///    connected peers are disconnected from.
pub struct Behaviour {
    // state-tracking for managed behaviors & mDNS-discovered peers
    managed: managed::Behaviour,
    mdns_discovered: HashMap<PeerId, BTreeSet<Multiaddr>>,

    retry_delay: Delay, // retry interval

    // pending events to emit => waker-backed deque to control polling
    pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
}

impl Behaviour {
    pub fn new(keypair: &identity::Keypair) -> io::Result<Self> {
        Ok(Self {
            managed: managed::Behaviour::new(keypair)?,
            mdns_discovered: HashMap::new(),
            retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
            pending_events: WakerDeque::new(),
        })
    }

    fn dial(&mut self, peer_id: PeerId, addr: Multiaddr) {
        self.pending_events.push_back(ToSwarm::Dial {
            opts: DialOpts::peer_id(peer_id).addresses(vec![addr]).build(),
        })
    }

    fn close_connection(&mut self, peer_id: PeerId, connection: ConnectionId) {
        // push front to make this IMMEDIATE
        self.pending_events.push_front(ToSwarm::CloseConnection {
            peer_id,
            connection: CloseConnection::One(connection),
        })
    }

    fn handle_mdns_discovered(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
        for (p, ma) in peers {
            self.dial(p, ma.clone()); // always connect

            // get peer's multi-addresses or insert if missing
            let Some(mas) = self.mdns_discovered.get_mut(&p) else {
                self.mdns_discovered.insert(p, BTreeSet::from([ma]));
                continue;
            };

            // multiaddress should never already be present - else something has gone wrong
            let is_new_addr = mas.insert(ma);
            assert!(is_new_addr, "cannot discover a discovered peer");
        }
    }

    fn handle_mdns_expired(&mut self, peers: Vec<(PeerId, Multiaddr)>) {
        for (p, ma) in peers {
            // at this point, we *must* have the peer
            let mas = self
                .mdns_discovered
                .get_mut(&p)
                .expect("nonexistent peer cannot expire");

            // at this point, we *must* have the multiaddress
            let was_present = mas.remove(&ma);
            assert!(was_present, "nonexistent multiaddress cannot expire");

            // if empty, remove the peer-id entirely
            if mas.is_empty() {
                self.mdns_discovered.remove(&p);
            }
        }
    }

    fn on_connection_established(
        &mut self,
        peer_id: PeerId,
        connection_id: ConnectionId,
        remote_ip: IpAddr,
        remote_tcp_port: u16,
    ) {
        // send out connected event
        self.pending_events
            .push_back(ToSwarm::GenerateEvent(Event::ConnectionEstablished {
                peer_id,
                connection_id,
                remote_ip,
                remote_tcp_port,
            }));
    }

    fn on_connection_closed(
        &mut self,
        peer_id: PeerId,
        connection_id: ConnectionId,
        remote_ip: IpAddr,
        remote_tcp_port: u16,
    ) {
        // send out disconnected event
        self.pending_events
            .push_back(ToSwarm::GenerateEvent(Event::ConnectionClosed {
                peer_id,
                connection_id,
                remote_ip,
                remote_tcp_port,
            }));
    }
}

impl NetworkBehaviour for Behaviour {
    type ConnectionHandler =
        ConnectionHandlerSelect<dummy::ConnectionHandler, THandler<managed::Behaviour>>;
    type ToSwarm = Event;

    // simply delegate to the underlying mDNS behaviour

    delegate! {
        to self.managed {
            fn handle_pending_inbound_connection(&mut self, connection_id: ConnectionId, local_addr: &Multiaddr, remote_addr: &Multiaddr) -> Result<(), ConnectionDenied>;
            fn handle_pending_outbound_connection(&mut self, connection_id: ConnectionId, maybe_peer: Option<PeerId>, addresses: &[Multiaddr], effective_role: Endpoint) -> Result<Vec<Multiaddr>, ConnectionDenied>;
        }
    }

    fn handle_established_inbound_connection(
        &mut self,
        connection_id: ConnectionId,
        peer: PeerId,
        local_addr: &Multiaddr,
        remote_addr: &Multiaddr,
    ) -> Result<THandler<Self>, ConnectionDenied> {
        Ok(ConnectionHandler::select(
            dummy::ConnectionHandler,
            self.managed.handle_established_inbound_connection(
                connection_id,
                peer,
                local_addr,
                remote_addr,
            )?,
        ))
    }

    #[allow(clippy::needless_question_mark)]
    fn handle_established_outbound_connection(
        &mut self,
        connection_id: ConnectionId,
        peer: PeerId,
        addr: &Multiaddr,
        role_override: Endpoint,
        port_use: PortUse,
    ) -> Result<THandler<Self>, ConnectionDenied> {
        Ok(ConnectionHandler::select(
            dummy::ConnectionHandler,
            self.managed.handle_established_outbound_connection(
                connection_id,
                peer,
                addr,
                role_override,
                port_use,
            )?,
        ))
    }

    fn on_connection_handler_event(
        &mut self,
        peer_id: PeerId,
        connection_id: ConnectionId,
        event: THandlerOutEvent<Self>,
    ) {
        match event {
            Either::Left(ev) => libp2p::core::util::unreachable(ev),
            Either::Right(ev) => {
                self.managed
                    .on_connection_handler_event(peer_id, connection_id, ev)
            }
        }
    }

    // hook into these methods to drive behavior

    fn on_swarm_event(&mut self, event: FromSwarm) {
        self.managed.on_swarm_event(event); // let mDNS handle swarm events

        // handle swarm events to update internal state:
        match event {
            FromSwarm::ConnectionEstablished(ConnectionEstablished {
                peer_id,
                connection_id,
                endpoint,
                ..
            }) => {
                let remote_address = match endpoint {
                    ConnectedPoint::Dialer { address, .. } => address,
                    ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
                };

                if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
                    // handle connection established event which is filtered correctly
                    self.on_connection_established(peer_id, connection_id, ip, port)
                }
            }
            FromSwarm::ConnectionClosed(ConnectionClosed {
                peer_id,
                connection_id,
                endpoint,
                ..
            }) => {
                let remote_address = match endpoint {
                    ConnectedPoint::Dialer { address, .. } => address,
                    ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr,
                };

                if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
                    // handle connection closed event which is filtered correctly
                    self.on_connection_closed(peer_id, connection_id, ip, port)
                }
            }

            // since we are running a TCP/IP transport layer, we are assuming that
            // no address changes can occur, hence encountering one is a fatal error
            FromSwarm::AddressChange(a) => {
                unreachable!("unhandlable: address change encountered: {:?}", a)
            }
            _ => {}
        }
    }

    fn poll(&mut self, cx: &mut Context) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
        // delegate to managed behaviors for any behaviors they need to perform
        match self.managed.poll(cx) {
            Poll::Ready(ToSwarm::GenerateEvent(e)) => {
                match e {
                    // handle discovered and expired events from mDNS
                    managed::BehaviourEvent::Mdns(e) => match e.clone() {
                        mdns::Event::Discovered(peers) => {
                            self.handle_mdns_discovered(peers);
                        }
                        mdns::Event::Expired(peers) => {
                            self.handle_mdns_expired(peers);
                        }
                    },

                    // handle ping events => if error then disconnect
                    managed::BehaviourEvent::Ping(e) => {
                        if e.result.is_err() {
                            self.close_connection(e.peer, e.connection)
                        }
                    }
                }

                // since we just consumed an event, we should immediately wake just in case
                // there are more events to come where that came from
                cx.waker().wake_by_ref();
            }

            // forward any other mDNS event to the swarm or its connection handler(s)
            Poll::Ready(e) => {
                return Poll::Ready(
                    e.map_out(|_| unreachable!("events returning to swarm already handled"))
                        .map_in(Either::Right),
                );
            }

            Poll::Pending => {}
        }

        // retry connecting to all mDNS peers periodically (fails safely if already connected)
        if self.retry_delay.poll_unpin(cx).is_ready() {
            for (p, mas) in self.mdns_discovered.clone() {
                for ma in mas {
                    self.dial(p, ma)
                }
            }
            self.retry_delay.reset(RETRY_CONNECT_INTERVAL) // reset timeout
        }

        // send out any pending events from our own service
        if let Some(e) = self.pending_events.pop_front(cx) {
            return Poll::Ready(e.map_in(Either::Left));
        }

        // wait for pending events
        Poll::Pending
    }
}
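
A sketch of how these events might surface once the behaviour is installed in the composed swarm defined later in this diff; the `swarm::Swarm` and `swarm::BehaviourEvent` names assume that module, and `run` is a hypothetical driver:

    use futures::StreamExt as _;
    use libp2p::swarm::SwarmEvent;
    use networking::{discovery, swarm};

    async fn run(mut swarm: swarm::Swarm) {
        loop {
            match swarm.select_next_some().await {
                // a durable, truly-established TCP connection to a discovered peer
                SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(
                    discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. },
                )) => println!("peer {peer_id} up at {remote_ip}:{remote_tcp_port}"),
                SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(
                    discovery::Event::ConnectionClosed { peer_id, .. },
                )) => println!("peer {peer_id} down"),
                _ => {}
            }
        }
    }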
44  rust/networking/src/keep_alive.rs  Normal file
@@ -0,0 +1,44 @@
use delegate::delegate;
use libp2p::swarm::handler::ConnectionEvent;
use libp2p::swarm::{ConnectionHandlerEvent, SubstreamProtocol, dummy, handler};
use std::task::{Context, Poll};

/// An implementation of [`ConnectionHandler`] that doesn't handle any protocols, but it keeps
/// the connection alive.
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnectionHandler(dummy::ConnectionHandler);

impl ConnectionHandler {
    pub fn new() -> Self {
        ConnectionHandler(dummy::ConnectionHandler)
    }
}

impl handler::ConnectionHandler for ConnectionHandler {
    // delegate types and implementation mostly to the dummy handler
    type FromBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::FromBehaviour;
    type ToBehaviour = <dummy::ConnectionHandler as handler::ConnectionHandler>::ToBehaviour;
    type InboundProtocol =
        <dummy::ConnectionHandler as handler::ConnectionHandler>::InboundProtocol;
    type OutboundProtocol =
        <dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundProtocol;
    type InboundOpenInfo =
        <dummy::ConnectionHandler as handler::ConnectionHandler>::InboundOpenInfo;
    type OutboundOpenInfo =
        <dummy::ConnectionHandler as handler::ConnectionHandler>::OutboundOpenInfo;

    delegate! {
        to self.0 {
            fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;
            fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>>;
            fn on_behaviour_event(&mut self, event: Self::FromBehaviour);
            fn on_connection_event(&mut self, event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol, Self::InboundOpenInfo, Self::OutboundOpenInfo>);
        }
    }

    // specifically override this to force the connection to stay alive
    fn connection_keep_alive(&self) -> bool {
        true
    }
}
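
A sketch of the intended use: a behaviour that wants an otherwise protocol-less connection held open can hand back this handler, alone or `select`ed with another one (mirroring how discovery.rs pairs `dummy::ConnectionHandler` with its managed handler); `idle_handler` is a hypothetical name:

    use libp2p::swarm::{ConnectionHandler as _, dummy};
    use networking::keep_alive;

    // Because `connection_keep_alive()` is true on the keep_alive side, the
    // selected pair keeps the connection open even when both sides are idle.
    fn idle_handler() -> impl libp2p::swarm::ConnectionHandler {
        keep_alive::ConnectionHandler::new().select(dummy::ConnectionHandler)
    }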
@@ -1,149 +1,64 @@
-use std::collections::BTreeSet;
-
-use iroh::{
-    Endpoint, EndpointId, SecretKey, TransportAddr,
-    discovery::{
-        Discovery as _, EndpointData, IntoDiscoveryError,
-        mdns::{DiscoveryEvent, MdnsDiscovery},
-    },
-    endpoint::BindError,
-    endpoint_info::EndpointIdExt as _,
-    protocol::Router,
-};
-use iroh_gossip::{
-    Gossip, TopicId,
-    api::{ApiError, GossipReceiver, GossipSender},
-};
-use n0_error::{e, stack_error};
-use n0_future::{Stream, StreamExt as _};
-use tokio::sync::Mutex;
-
-#[stack_error(derive, add_meta, from_sources)]
-pub enum ExoError {
-    #[error(transparent)]
-    FailedBinding { source: BindError },
-    /// The gossip topic was closed.
-    #[error(transparent)]
-    FailedCommunication { source: ApiError },
-    #[error("No IP Protocol supported on device")]
-    IPNotSupported { source: IntoDiscoveryError },
-    #[error("No peers found before subscribing")]
-    NoPeers,
-}
-
-#[derive(Debug)]
-pub struct ExoNet {
-    pub alpn: String,
-    pub router: Router,
-    pub gossip: Gossip,
-    pub mdns: MdnsDiscovery,
-    pub known_peers: Mutex<BTreeSet<EndpointId>>,
-}
-
-impl ExoNet {
-    #[inline]
-    pub async fn init_iroh(sk: SecretKey, namespace: &str) -> Result<Self, ExoError> {
-        let endpoint = Endpoint::empty_builder(iroh::RelayMode::Disabled)
-            .secret_key(sk)
-            .bind()
-            .await?;
-        let mdns = MdnsDiscovery::builder().build(endpoint.id())?;
-        let endpoint_addr = endpoint.addr();
-
-        let bound = endpoint_addr.ip_addrs().map(|it| TransportAddr::Ip(*it));
-
-        log::info!("publishing {endpoint_addr:?} with mdns");
-        mdns.publish(&EndpointData::new(bound));
-        endpoint.discovery().add(mdns.clone());
-        let alpn = format!("/exo_discovery_network/{namespace}");
-        // max msg size 4MB
-        let gossip = Gossip::builder()
-            .max_message_size(4 * 1024 * 1024)
-            .alpn(&alpn)
-            .spawn(endpoint.clone());
-        let router = Router::builder(endpoint)
-            .accept(&alpn, gossip.clone())
-            .spawn();
-        Ok(Self {
-            alpn,
-            router,
-            gossip,
-            mdns,
-            known_peers: Mutex::new(BTreeSet::new()),
-        })
-    }
-
-    #[inline]
-    pub async fn start_auto_dialer(&self) {
-        let mut recv = self.connection_info().await;
-
-        log::info!(
-            "Starting auto dialer for id {}",
-            self.router.endpoint().id().to_z32()
-        );
-        while let Some(item) = recv.next().await {
-            match item {
-                DiscoveryEvent::Discovered { endpoint_info, .. } => {
-                    let id = endpoint_info.endpoint_id;
-                    if id == self.router.endpoint().id() {
-                        continue;
-                    }
-                    if !self
-                        .known_peers
-                        .lock()
-                        .await
-                        .contains(&endpoint_info.endpoint_id)
-                        && let Ok(conn) = self
-                            .router
-                            .endpoint()
-                            .connect(endpoint_info, self.alpn.as_bytes())
-                            .await
-                        && conn.alpn() == self.alpn.as_bytes()
-                    {
-                        self.known_peers.lock().await.insert(id);
-                        match self.gossip.handle_connection(conn).await {
-                            Ok(()) => log::info!("Successfully dialled"),
-                            Err(_) => log::info!("Failed to dial peer"),
-                        }
-                    }
-                }
-                DiscoveryEvent::Expired { endpoint_id } => {
-                    log::info!("Peer expired {}", endpoint_id.to_z32());
-                    self.known_peers.lock().await.remove(&endpoint_id);
-                }
-            }
-        }
-        log::info!("Auto dialer stopping");
-    }
-
-    #[inline]
-    pub async fn connection_info(&self) -> impl Stream<Item = DiscoveryEvent> + Unpin + use<> {
-        self.mdns.subscribe().await
-    }
-
-    #[inline]
-    pub async fn subscribe(&self, topic: &str) -> Result<(GossipSender, GossipReceiver), ExoError> {
-        if self.known_peers.lock().await.is_empty() {
-            return Err(e!(ExoError::NoPeers));
-        }
-        Ok(self
-            .gossip
-            .subscribe_and_join(
-                str_to_topic_id(topic),
-                self.known_peers.lock().await.clone().into_iter().collect(),
-            )
-            .await?
-            .split())
-    }
-
-    #[inline]
-    #[allow(clippy::expect_used)]
-    pub async fn shutdown(&self) {
-        self.router.shutdown().await.expect("router panic");
-    }
-}
-
-fn str_to_topic_id(data: &str) -> TopicId {
-    TopicId::from_bytes(*blake3::hash(data.as_bytes()).as_bytes())
-}
+//! TODO: crate documentation
+//!
+//! this is here as placeholder documentation
+//!
+//!
+
+// enable Rust-unstable features for convenience
+#![feature(trait_alias)]
+// #![feature(stmt_expr_attributes)]
+// #![feature(unboxed_closures)]
+// #![feature(assert_matches)]
+// #![feature(async_fn_in_dyn_trait)]
+// #![feature(async_for_loop)]
+// #![feature(auto_traits)]
+// #![feature(negative_impls)]
+
+pub mod discovery;
+pub mod keep_alive;
+pub mod swarm;
+
+/// Namespace for all the type/trait aliases used by this crate.
+pub(crate) mod alias {
+    use std::error::Error;
+
+    pub type AnyError = Box<dyn Error + Send + Sync + 'static>;
+    pub type AnyResult<T> = Result<T, AnyError>;
+}
+
+/// Namespace for crate-wide extension traits/methods
+pub(crate) mod ext {
+    use extend::ext;
+    use libp2p::Multiaddr;
+    use libp2p::multiaddr::Protocol;
+    use std::net::IpAddr;
+
+    #[ext(pub, name = MultiaddrExt)]
+    impl Multiaddr {
+        /// If the multiaddress corresponds to a TCP address, extracts it
+        fn try_to_tcp_addr(&self) -> Option<(IpAddr, u16)> {
+            let mut ps = self.into_iter();
+            let ip = if let Some(p) = ps.next() {
+                match p {
+                    Protocol::Ip4(ip) => IpAddr::V4(ip),
+                    Protocol::Ip6(ip) => IpAddr::V6(ip),
+                    _ => return None,
+                }
+            } else {
+                return None;
+            };
+            let Some(Protocol::Tcp(port)) = ps.next() else {
+                return None;
+            };
+            Some((ip, port))
+        }
+    }
+}
+
+pub(crate) mod private {
+    #![allow(dead_code)]
+
+    /// Sealed traits support
+    pub trait Sealed {}
+    impl<T: ?Sized> Sealed for T {}
+}
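
The new `try_to_tcp_addr` helper drives the address filtering in discovery.rs. A small illustrative sketch of its behaviour, written from inside the crate since the `ext` module is `pub(crate)` (the address literals are arbitrary):

    use crate::ext::MultiaddrExt as _;
    use libp2p::Multiaddr;

    fn demo() {
        let tcp: Multiaddr = "/ip4/192.168.1.7/tcp/4001".parse().unwrap();
        assert_eq!(tcp.try_to_tcp_addr().map(|(_, port)| port), Some(4001));

        // anything that is not ip4/ip6 followed by tcp yields None
        let quic: Multiaddr = "/ip4/192.168.1.7/udp/4001/quic-v1".parse().unwrap();
        assert_eq!(quic.try_to_tcp_addr(), None);
    }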
145  rust/networking/src/swarm.rs  Normal file
@@ -0,0 +1,145 @@
use crate::alias;
use crate::swarm::transport::tcp_transport;
pub use behaviour::{Behaviour, BehaviourEvent};
use libp2p::{SwarmBuilder, identity};

pub type Swarm = libp2p::Swarm<Behaviour>;

/// The current version of the network: this prevents devices running different versions of the
/// software from interacting with each other.
///
/// TODO: right now this is a hardcoded constant; figure out what the versioning semantics should
/// even be, and how to inject the right version into this config/initialization. E.g. should
/// this be passed in as a parameter? What about rapidly changing versions in debug builds?
/// This is all very hard to figure out and needs to be mulled over as a team.
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";

/// Create and configure a swarm which listens on all interfaces on an OS-assigned port
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
    let mut swarm = SwarmBuilder::with_existing_identity(keypair)
        .with_tokio()
        .with_other_transport(tcp_transport)?
        .with_behaviour(Behaviour::new)?
        .build();

    // Listen on all interfaces and whatever port the OS assigns
    swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
    Ok(swarm)
}

mod transport {
    use crate::alias;
    use crate::swarm::{NETWORK_VERSION, OVERRIDE_VERSION_ENV_VAR};
    use futures::{AsyncRead, AsyncWrite};
    use keccak_const::Sha3_256;
    use libp2p::core::muxing;
    use libp2p::core::transport::Boxed;
    use libp2p::pnet::{PnetError, PnetOutput};
    use libp2p::{PeerId, Transport, identity, noise, pnet, yamux};
    use std::{env, sync::LazyLock};

    /// Key used for networking's private network; parametrized on the [`NETWORK_VERSION`].
    /// See [`pnet_upgrade`] for more.
    static PNET_PRESHARED_KEY: LazyLock<[u8; 32]> = LazyLock::new(|| {
        let builder = Sha3_256::new().update(b"exo_discovery_network");

        if let Ok(var) = env::var(OVERRIDE_VERSION_ENV_VAR) {
            let bytes = var.into_bytes();
            builder.update(&bytes)
        } else {
            builder.update(NETWORK_VERSION)
        }
        .finalize()
    });

    /// Make the Swarm run on a private network, so as not to clash with public libp2p nodes or
    /// with differently-versioned instances of this same network.
    /// This is implemented as an additional "upgrade" on top of the existing [`libp2p::Transport`] layers.
    async fn pnet_upgrade<TSocket>(
        socket: TSocket,
        _: impl Sized,
    ) -> Result<PnetOutput<TSocket>, PnetError>
    where
        TSocket: AsyncRead + AsyncWrite + Send + Unpin + 'static,
    {
        use pnet::{PnetConfig, PreSharedKey};
        PnetConfig::new(PreSharedKey::new(*PNET_PRESHARED_KEY))
            .handshake(socket)
            .await
    }

    /// TCP/IP transport layer configuration.
    pub fn tcp_transport(
        keypair: &identity::Keypair,
    ) -> alias::AnyResult<Boxed<(PeerId, muxing::StreamMuxerBox)>> {
        use libp2p::{
            core::upgrade::Version,
            tcp::{Config, tokio},
        };

        // `TCP_NODELAY` enabled => avoid latency
        let tcp_config = Config::default().nodelay(true);

        // V1 + lazy flushing => 0-RTT negotiation
        let upgrade_version = Version::V1Lazy;

        // Noise is faster than TLS + we don't care much for security
        let noise_config = noise::Config::new(keypair)?;

        // Use default Yamux config for multiplexing
        let yamux_config = yamux::Config::default();

        // Create new Tokio-driven TCP/IP transport layer
        let base_transport = tokio::Transport::new(tcp_config)
            .and_then(pnet_upgrade)
            .upgrade(upgrade_version)
            .authenticate(noise_config)
            .multiplex(yamux_config);

        // Return boxed transport (to flatten the complex type)
        Ok(base_transport.boxed())
    }
}

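The pre-shared key above is const-evaluated via keccak_const; for reference, the same bytes derived at runtime with the sha3 crate (an assumed dev-dependency, handy for checking interop from a test; `psk_for` is a hypothetical name):

    use sha3::{Digest, Sha3_256};

    // SHA3-256("exo_discovery_network" || namespace) matches *PNET_PRESHARED_KEY
    // when `namespace` is NETWORK_VERSION (or the $EXO_LIBP2P_NAMESPACE override).
    fn psk_for(namespace: &[u8]) -> [u8; 32] {
        let mut hasher = Sha3_256::new();
        hasher.update(b"exo_discovery_network");
        hasher.update(namespace);
        hasher.finalize().into()
    }
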
mod behaviour {
    use crate::{alias, discovery};
    use libp2p::swarm::NetworkBehaviour;
    use libp2p::{gossipsub, identity};
    use std::time::Duration;

    /// Behavior of the Swarm which composes all desired behaviors:
    /// right now it's just [`discovery::Behaviour`] and [`gossipsub::Behaviour`].
    #[derive(NetworkBehaviour)]
    pub struct Behaviour {
        pub discovery: discovery::Behaviour,
        pub gossipsub: gossipsub::Behaviour,
    }

    impl Behaviour {
        pub fn new(keypair: &identity::Keypair) -> alias::AnyResult<Self> {
            Ok(Self {
                discovery: discovery::Behaviour::new(keypair)?,
                gossipsub: gossipsub_behaviour(keypair),
            })
        }
    }

    fn gossipsub_behaviour(keypair: &identity::Keypair) -> gossipsub::Behaviour {
        use gossipsub::{ConfigBuilder, MessageAuthenticity, ValidationMode};

        // build a gossipsub network behaviour
        // => signed message authenticity + strict validation mode means the message-ID is
        //    automatically provided by gossipsub without needing a custom message-ID function
        gossipsub::Behaviour::new(
            MessageAuthenticity::Signed(keypair.clone()),
            ConfigBuilder::default()
                .publish_queue_duration(Duration::from_secs(15))
                .max_transmit_size(1024 * 1024)
                .validation_mode(ValidationMode::Strict)
                .build()
                .expect("the configuration should always be valid"),
        )
        .expect("creating gossipsub behavior should always work")
    }
}
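
Putting the pieces together, a sketch of a minimal node built from this module (the `main` is hypothetical; a real caller would also subscribe gossipsub topics and handle discovery events):

    use futures::StreamExt as _;
    use libp2p::{identity, swarm::SwarmEvent};

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let keypair = identity::Keypair::generate_ed25519();
        let mut swarm = networking::swarm::create_swarm(keypair)?;
        loop {
            if let SwarmEvent::NewListenAddr { address, .. } = swarm.select_next_some().await {
                println!("listening on {address}");
            }
        }
    }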
7  rust/networking/tests/dummy.rs  Normal file
@@ -0,0 +1,7 @@
// maybe this will hold tests in the future...??

#[cfg(test)]
mod tests {
    #[test]
    fn does_nothing() {}
}
145  rust/parts.nix  Normal file
@@ -0,0 +1,145 @@
{ inputs, ... }:
{
  perSystem =
    { config, self', inputs', pkgs, lib, ... }:
    let
      # Fenix nightly toolchain with all components
      fenixPkgs = inputs'.fenix.packages;
      rustToolchain = fenixPkgs.complete.withComponents [
        "cargo"
        "rustc"
        "clippy"
        "rustfmt"
        "rust-src"
        "rust-analyzer"
      ];

      # Crane with fenix toolchain
      craneLib = (inputs.crane.mkLib pkgs).overrideToolchain rustToolchain;

      # Source filtering - only include rust/ directory and root Cargo files
      # This ensures changes to Python/docs/etc don't trigger Rust rebuilds
      src = lib.cleanSourceWith {
        src = inputs.self;
        filter =
          path: type:
          let
            baseName = builtins.baseNameOf path;
            parentDir = builtins.dirOf path;
            inRustDir =
              (lib.hasInfix "/rust/" path)
              || (lib.hasSuffix "/rust" parentDir)
              || (baseName == "rust" && type == "directory");
            isRootCargoFile =
              (baseName == "Cargo.toml" || baseName == "Cargo.lock")
              && (builtins.dirOf path == toString inputs.self);
          in
          isRootCargoFile
          || (inRustDir && (craneLib.filterCargoSources path type || lib.hasSuffix ".toml" path || lib.hasSuffix ".md" path));
      };

      # Common arguments for all Rust builds
      commonArgs = {
        inherit src;
        pname = "exo-rust";
        version = "0.0.1";
        strictDeps = true;

        nativeBuildInputs = [
          pkgs.pkg-config
          pkgs.python313 # Required for pyo3-build-config
        ];

        buildInputs = [
          pkgs.openssl
          pkgs.python313 # Required for pyo3 tests
        ];

        OPENSSL_NO_VENDOR = "1";

        # Required for pyo3 tests to find libpython
        LD_LIBRARY_PATH = lib.makeLibraryPath [ pkgs.python313 ];
      };

      # Build dependencies once for caching
      cargoArtifacts = craneLib.buildDepsOnly (
        commonArgs
        // {
          cargoExtraArgs = "--workspace";
        }
      );
    in
    {
      # Export toolchain for use in treefmt and devShell
      options.rust = {
        toolchain = lib.mkOption {
          type = lib.types.package;
          default = rustToolchain;
          description = "The Rust toolchain to use";
        };
      };

      config = {
        packages = {
          # Python bindings wheel via maturin
          exo_pyo3_bindings = craneLib.buildPackage (
            commonArgs
            // {
              inherit cargoArtifacts;
              pname = "exo_pyo3_bindings";

              nativeBuildInputs = commonArgs.nativeBuildInputs ++ [
                pkgs.maturin
              ];

              buildPhaseCargoCommand = ''
                maturin build \
                  --release \
                  --manylinux off \
                  --manifest-path rust/exo_pyo3_bindings/Cargo.toml \
                  --features "pyo3/extension-module,pyo3/experimental-async" \
                  --interpreter ${pkgs.python313}/bin/python \
                  --out dist
              '';

              # Don't use crane's default install behavior
              doNotPostBuildInstallCargoBinaries = true;

              installPhaseCommand = ''
                mkdir -p $out
                cp dist/*.whl $out/
              '';
            }
          );
        };

        checks = {
          # Full workspace build (all crates)
          cargo-build = craneLib.buildPackage (
            commonArgs
            // {
              inherit cargoArtifacts;
              cargoExtraArgs = "--workspace";
            }
          );
          # Run tests with nextest
          cargo-nextest = craneLib.cargoNextest (
            commonArgs
            // {
              inherit cargoArtifacts;
              cargoExtraArgs = "--workspace";
            }
          );

          # Build documentation
          cargo-doc = craneLib.cargoDoc (
            commonArgs
            // {
              inherit cargoArtifacts;
              cargoExtraArgs = "--workspace";
            }
          );
        };
      };
    };
}
2  rust/rust-toolchain.toml  Normal file
@@ -0,0 +1,2 @@
[toolchain]
channel = "nightly"
25  rust/util/Cargo.toml  Normal file
@@ -0,0 +1,25 @@
[package]
name = "util"
version = { workspace = true }
edition = { workspace = true }
publish = false

[lib]
doctest = false
name = "util"
path = "src/lib.rs"

[lints]
workspace = true

[dependencies]
# macro dependencies
extend = { workspace = true }

# utility dependencies
thiserror = { workspace = true }
once_cell = { workspace = true }
internment = { workspace = true }
derive_more = { workspace = true }
bon = { workspace = true }
recursion = { workspace = true }
53  rust/util/src/lib.rs  Normal file
@@ -0,0 +1,53 @@
//! TODO: crate documentation
//!
//! this is here as placeholder documentation
//!
//!

// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]

pub mod nonempty;
pub mod wakerdeque;

pub(crate) mod private {
    // sealed traits support
    pub trait Sealed {}
    impl<T: ?Sized> Sealed for T {}
}

/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}

/// Namespace for crate-wide extension traits/methods
pub mod ext {
    use extend::ext;

    #[ext(pub, name = BoxedSliceExt)]
    impl<T> Box<[T]> {
        #[inline]
        fn map<B, F>(self, f: F) -> Box<[B]>
        where
            F: FnMut(T) -> B,
        {
            self.into_iter().map(f).collect()
        }
    }

    #[ext(pub, name = VecExt)]
    impl<T> Vec<T> {
        #[inline]
        fn map<B, F>(self, f: F) -> Vec<B>
        where
            F: FnMut(T) -> B,
        {
            self.into_iter().map(f).collect()
        }
    }
}
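
A usage sketch for the extension methods above: `map` consumes the container and rebuilds it, so the element type may change (the `xs`/`ys` names are arbitrary):

    use util::ext::VecExt as _;

    fn demo() {
        let xs = vec![1, 2, 3];
        // by-value map over the whole Vec, changing i32 -> String
        let ys: Vec<String> = xs.map(|x| x.to_string());
        assert_eq!(ys, vec!["1", "2", "3"]);
    }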
138  rust/util/src/nonempty.rs  Normal file
@@ -0,0 +1,138 @@
use std::slice::SliceIndex;
use std::{ops, slice};
use thiserror::Error;

#[derive(Error, Debug)]
#[error("Cannot create a `NonemptyArray` because the supplied slice is empty")]
pub struct EmptySliceError;

/// A pointer to a non-empty fixed-size slice allocated on the heap.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(transparent)]
pub struct NonemptyArray<T>(Box<[T]>);

#[allow(clippy::arbitrary_source_item_ordering)]
impl<T> NonemptyArray<T> {
    #[inline]
    pub fn singleton(value: T) -> Self {
        Self(Box::new([value]))
    }

    #[allow(clippy::missing_errors_doc)]
    #[inline]
    pub fn try_from_boxed_slice<S: Into<Box<[T]>>>(
        boxed_slice: S,
    ) -> Result<Self, EmptySliceError> {
        let boxed_slice = boxed_slice.into();
        if boxed_slice.is_empty() {
            Err(EmptySliceError)
        } else {
            Ok(Self(boxed_slice))
        }
    }

    #[must_use]
    #[inline]
    pub fn into_boxed_slice(self) -> Box<[T]> {
        self.0
    }

    #[must_use]
    #[inline]
    pub fn to_vec(&self) -> Vec<T>
    where
        T: Clone,
    {
        self.0.to_vec()
    }

    #[must_use]
    #[inline]
    pub const fn as_slice(&self) -> &[T] {
        &self.0
    }

    #[allow(clippy::indexing_slicing)]
    #[must_use]
    #[inline]
    pub fn first(&self) -> &T {
        &self.0[0]
    }

    #[allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
    #[must_use]
    #[inline]
    pub fn last(&self) -> &T {
        &self.0[self.0.len() - 1]
    }

    #[must_use]
    #[inline]
    pub fn get<I>(&self, index: I) -> Option<&I::Output>
    where
        I: SliceIndex<[T]>,
    {
        self.0.get(index)
    }

    #[allow(clippy::len_without_is_empty)]
    #[must_use]
    #[inline]
    pub const fn len(&self) -> usize {
        self.0.len()
    }

    #[allow(clippy::iter_without_into_iter)]
    #[inline]
    pub fn iter(&self) -> slice::Iter<'_, T> {
        self.0.iter()
    }

    #[allow(clippy::iter_without_into_iter)]
    #[inline]
    pub fn iter_mut(&mut self) -> slice::IterMut<'_, T> {
        self.0.iter_mut()
    }

    #[inline]
    #[must_use]
    pub fn map<U, F: FnMut(T) -> U>(self, f: F) -> NonemptyArray<U> {
        NonemptyArray(self.0.into_iter().map(f).collect())
    }
}

impl<T> From<NonemptyArray<T>> for Box<[T]> {
    #[inline]
    fn from(value: NonemptyArray<T>) -> Self {
        value.into_boxed_slice()
    }
}

impl<T> ops::Index<usize> for NonemptyArray<T> {
    type Output = T;

    #[inline]
    fn index(&self, index: usize) -> &Self::Output {
        self.0.index(index)
    }
}

impl<T> IntoIterator for NonemptyArray<T> {
    type Item = T;
    type IntoIter = std::vec::IntoIter<T>;

    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        self.into_boxed_slice().into_vec().into_iter()
    }
}

impl<'a, T> IntoIterator for &'a NonemptyArray<T> {
    type Item = &'a T;
    type IntoIter = slice::Iter<'a, T>;

    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}
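
A usage sketch: the non-emptiness invariant is exactly what lets `first`/`last` return `&T` with no `Option` (identifiers arbitrary):

    use util::nonempty::NonemptyArray;

    fn demo() {
        let xs = NonemptyArray::try_from_boxed_slice(vec![10, 20, 30]).unwrap();
        assert_eq!((*xs.first(), *xs.last(), xs.len()), (10, 30, 3));

        // the empty case is rejected up front instead of panicking later
        assert!(NonemptyArray::<i32>::try_from_boxed_slice(Vec::new()).is_err());
    }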
55  rust/util/src/wakerdeque.rs  Normal file
@@ -0,0 +1,55 @@
use std::collections::VecDeque;
use std::fmt::{Debug, Formatter};
use std::task::{Context, Waker};

/// A wrapper around [`VecDeque`] which wakes (if it can) on any `push_*` methods,
/// and updates the internally stored waker by consuming [`Context`] on any `pop_*` methods.
pub struct WakerDeque<T> {
    waker: Option<Waker>,
    deque: VecDeque<T>,
}

impl<T: Debug> Debug for WakerDeque<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        self.deque.fmt(f)
    }
}

impl<T> WakerDeque<T> {
    pub fn new() -> Self {
        Self {
            waker: None,
            deque: VecDeque::new(),
        }
    }

    fn update(&mut self, cx: &mut Context<'_>) {
        self.waker = Some(cx.waker().clone());
    }

    fn wake(&mut self) {
        let Some(ref mut w) = self.waker else { return };
        w.wake_by_ref();
        self.waker = None;
    }

    pub fn pop_front(&mut self, cx: &mut Context<'_>) -> Option<T> {
        self.update(cx);
        self.deque.pop_front()
    }

    pub fn pop_back(&mut self, cx: &mut Context<'_>) -> Option<T> {
        self.update(cx);
        self.deque.pop_back()
    }

    pub fn push_front(&mut self, value: T) {
        self.wake();
        self.deque.push_front(value);
    }

    pub fn push_back(&mut self, value: T) {
        self.wake();
        self.deque.push_back(value);
    }
}
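
A sketch of the polling contract this type encodes (the surrounding `EventSource` type is hypothetical; discovery.rs uses the same pattern for its `pending_events`):

    use std::task::{Context, Poll};
    use util::wakerdeque::WakerDeque;

    struct EventSource {
        pending: WakerDeque<u32>,
    }

    impl EventSource {
        // pop_front(cx) stores the caller's waker, so a later push_back from
        // any other code path wakes this task and gets it polled again.
        fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<u32> {
            match self.pending.pop_front(cx) {
                Some(v) => Poll::Ready(v),
                None => Poll::Pending,
            }
        }
    }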
Some files were not shown because too many files have changed in this diff