mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-16 01:51:03 -05:00
Compare commits
59 Commits
linux-cpu-
...
JakeHillio
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c0b5f69e0 | ||
|
|
c22dad8a7d | ||
|
|
4bc4d50685 | ||
|
|
e0aab46fd8 | ||
|
|
82ba42bae9 | ||
|
|
3671528fa4 | ||
|
|
e6434ec446 | ||
|
|
bdb43e1dbb | ||
|
|
e4a01e2b0e | ||
|
|
1200a7db64 | ||
|
|
47ceb54bc1 | ||
|
|
f8112fdf25 | ||
|
|
e388f59480 | ||
|
|
e5e74e1eef | ||
|
|
b968d6f0a0 | ||
|
|
3bfffd9b4f | ||
|
|
007eb80029 | ||
|
|
8d7b6789b3 | ||
|
|
3c5b7ea670 | ||
|
|
b74a610537 | ||
|
|
18c4e49f91 | ||
|
|
d85b5d3781 | ||
|
|
caafc48693 | ||
|
|
cca8c9984a | ||
|
|
d1e88def42 | ||
|
|
59e7594e34 | ||
|
|
c65320acd3 | ||
|
|
b9a78f6f3a | ||
|
|
8f7f0e893a | ||
|
|
4759b09d4c | ||
|
|
ca680185f3 | ||
|
|
383309e24e | ||
|
|
55463a9806 | ||
|
|
56af61fac9 | ||
|
|
f76d543d98 | ||
|
|
ea841aca37 | ||
|
|
077b1bc732 | ||
|
|
4963c33162 | ||
|
|
4f6fcd9e93 | ||
|
|
839b67f318 | ||
|
|
47b8e0ce12 | ||
|
|
17f9b583a4 | ||
|
|
844bcc7ce6 | ||
|
|
c1be5184b2 | ||
|
|
1ec550dff1 | ||
|
|
283c0e39e4 | ||
|
|
35be4c55c3 | ||
|
|
31d4cd8409 | ||
|
|
8a6da58404 | ||
|
|
16e2bfd3b3 | ||
|
|
ade3ee7ec5 | ||
|
|
fea42473dd | ||
|
|
ca7adcc2a8 | ||
|
|
9d9e24f969 | ||
|
|
b5d424b658 | ||
|
|
b465134012 | ||
|
|
eabdcab978 | ||
|
|
8e9332d6a7 | ||
|
|
4b65d5f896 |
159
.github/benchmark-dashboard/README.md
vendored
159
.github/benchmark-dashboard/README.md
vendored
@@ -1,159 +0,0 @@
|
||||
# EXO Benchmark Dashboard
|
||||
|
||||
A fully self-contained, browser-based dashboard for tracking EXO benchmark performance over time.
|
||||
|
||||
## Features
|
||||
|
||||
- 📊 **Success Rate Tracking**: Monitor cluster reliability across commits
|
||||
- ⚡ **Response Time Analysis**: Track average request completion times
|
||||
- 🎯 **Throughput Metrics**: Tokens per second visualization
|
||||
- 📈 **Request Distribution**: Success/failure breakdown over time
|
||||
- 🔄 **Auto-Refresh**: Updates every 60 seconds
|
||||
- 📺 **TV-Ready**: Large, clear visualizations perfect for display
|
||||
- 🔐 **Secure**: Credentials stored in browser localStorage only
|
||||
- 🌐 **No Backend**: Directly accesses S3 from the browser
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Option 1: Direct File Access (Simplest)
|
||||
|
||||
Just open the HTML file directly in your browser:
|
||||
|
||||
```bash
|
||||
open .github/benchmark-dashboard/index.html
|
||||
```
|
||||
|
||||
Then click "Configure AWS Credentials" and enter your keys.
|
||||
|
||||
### Option 2: URL Parameters (For Quick Setup)
|
||||
|
||||
```bash
|
||||
# Serve with credentials in URL (they'll be moved to localStorage)
|
||||
open ".github/benchmark-dashboard/index.html?accessKey=YOUR_KEY&secretKey=YOUR_SECRET®ion=us-east-1"
|
||||
```
|
||||
|
||||
The credentials will be saved to localStorage and removed from the URL immediately.
|
||||
|
||||
### Option 3: Simple HTTP Server
|
||||
|
||||
```bash
|
||||
# From repo root
|
||||
python3 -m http.server 8080
|
||||
|
||||
# Then open: http://localhost:8080/.github/benchmark-dashboard/
|
||||
```
|
||||
|
||||
## AWS Credentials
|
||||
|
||||
The dashboard needs read-only access to the `exo-benchmark-results` S3 bucket.
|
||||
|
||||
### Required IAM Permissions
|
||||
|
||||
```json
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"s3:GetObject",
|
||||
"s3:ListBucket"
|
||||
],
|
||||
"Resource": [
|
||||
"arn:aws:s3:::exo-benchmark-results",
|
||||
"arn:aws:s3:::exo-benchmark-results/*"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Security Notes
|
||||
|
||||
- ✅ Credentials stored in browser `localStorage` only
|
||||
- ✅ Never sent to any server (except AWS)
|
||||
- ✅ All S3 access happens client-side
|
||||
- ✅ Use read-only IAM credentials
|
||||
- ⚠️ Don't commit credentials to git
|
||||
- ⚠️ Use a dedicated read-only IAM user
|
||||
|
||||
## TV/Kiosk Mode
|
||||
|
||||
For permanent display on a TV:
|
||||
|
||||
### macOS
|
||||
```bash
|
||||
open -a "Google Chrome" --args --kiosk ".github/benchmark-dashboard/index.html"
|
||||
```
|
||||
|
||||
### Linux
|
||||
```bash
|
||||
chromium-browser --kiosk --app="file://$(pwd)/.github/benchmark-dashboard/index.html"
|
||||
```
|
||||
|
||||
### Auto-start on Boot
|
||||
|
||||
Create a simple startup script:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# /usr/local/bin/start-benchmark-dashboard.sh
|
||||
|
||||
cd /path/to/exo
|
||||
python3 -m http.server 8080 &
|
||||
sleep 2
|
||||
chromium-browser --kiosk http://localhost:8080/.github/benchmark-dashboard/
|
||||
```
|
||||
|
||||
## Data Displayed
|
||||
|
||||
### Summary Cards
|
||||
- **Latest Success Rate**: Most recent benchmark success percentage with trend
|
||||
- **Avg Response Time**: Latest average response time in ms with trend
|
||||
- **Total Benchmarks**: Count of all benchmarks run
|
||||
- **Active Configurations**: Number of unique benchmark configs
|
||||
|
||||
### Charts
|
||||
1. **Success Rate Over Time**: Line chart showing reliability trends
|
||||
2. **Average Response Time**: Performance over time (lower is better)
|
||||
3. **Throughput**: Tokens/second metric (higher is better)
|
||||
4. **Request Distribution**: Stacked bar chart of successes/failures
|
||||
|
||||
## How It Works
|
||||
|
||||
1. **Loads AWS SDK**: Uses AWS SDK for JavaScript (browser version)
|
||||
2. **Lists S3 Objects**: Fetches all files from `s3://exo-benchmark-results/bench/`
|
||||
3. **Downloads Results**: Fetches each JSON result file
|
||||
4. **Parses & Visualizes**: Uses Chart.js to create interactive charts
|
||||
5. **Auto-Refreshes**: Polls S3 every 60 seconds for new results
|
||||
|
||||
## Customization
|
||||
|
||||
To modify the dashboard:
|
||||
|
||||
1. Edit `index.html`
|
||||
2. Adjust `REFRESH_INTERVAL` for different polling frequency
|
||||
3. Modify chart colors/styles in the Chart.js configuration
|
||||
4. Add new metrics by extending the results parsing
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**"AWS credentials not configured"**
|
||||
- Click "Configure AWS Credentials" and enter your keys
|
||||
|
||||
**"Error loading benchmark data"**
|
||||
- Check AWS credentials are correct
|
||||
- Verify S3 bucket name is `exo-benchmark-results`
|
||||
- Ensure IAM user has read permissions
|
||||
- Check browser console for detailed errors
|
||||
|
||||
**"No benchmark results found"**
|
||||
- Wait for benchmark workflows to run
|
||||
- Verify results are being uploaded to S3
|
||||
- Check S3 bucket has files in `bench/` prefix
|
||||
|
||||
**Charts not updating**
|
||||
- Check browser console for errors
|
||||
- Verify network connectivity to S3
|
||||
- Try refreshing the page manually
|
||||
|
||||
1641
.github/benchmark-dashboard/index.html
vendored
1641
.github/benchmark-dashboard/index.html
vendored
File diff suppressed because it is too large
Load Diff
186
.github/configs/README.md
vendored
186
.github/configs/README.md
vendored
@@ -1,186 +0,0 @@
|
||||
# EXO Benchmark Configurations
|
||||
|
||||
This directory contains configuration files for the EXO staged benchmark system.
|
||||
|
||||
## Overview
|
||||
|
||||
The staged benchmark system allows you to run complex, multi-stage load tests against EXO clusters. Each stage can have different characteristics:
|
||||
|
||||
- **Prompt Length**: Number of tokens in the input prompt
|
||||
- **Generation Length**: Maximum tokens to generate in the response
|
||||
- **Time Between Requests**: Delay (in seconds) between firing consecutive requests
|
||||
- **Iterations**: Number of requests to send in this stage
|
||||
|
||||
Requests are **fire-and-forget** - they don't wait for the previous request to complete. This allows you to test overlapping request handling and measure success rates under load.
|
||||
|
||||
## Configuration Files
|
||||
|
||||
### `bench_simple.yaml`
|
||||
A minimal configuration that replicates the behavior of the original `bench.py` script:
|
||||
- Single stage with 1 iteration
|
||||
- Short prompt (~20 tokens)
|
||||
- Generates up to 100 tokens
|
||||
|
||||
This is useful for quick smoke tests.
|
||||
|
||||
### `bench_config.yaml`
|
||||
A comprehensive multi-stage benchmark with:
|
||||
1. **Warmup** (10 requests): Light load with short prompts
|
||||
2. **Medium Load** (20 requests): Moderate load with medium prompts
|
||||
3. **Stress Test** (30 requests): Heavy overlapping requests with long prompts
|
||||
4. **Cooldown** (5 requests): Light load to wind down
|
||||
|
||||
This tests the cluster's behavior under varying load patterns.
|
||||
|
||||
## Configuration Schema
|
||||
|
||||
```yaml
|
||||
# Hardware configuration - maps runner labels to instance counts
|
||||
hardware_plan:
|
||||
M3ULTRA_GPU80_512GB: 4
|
||||
|
||||
# Environment variables to set on each node (optional)
|
||||
environment:
|
||||
OVERRIDE_MEMORY_MB: 512
|
||||
|
||||
# Timeout for instance and runner readiness (seconds)
|
||||
timeout_seconds: 600
|
||||
|
||||
# Model instances to run concurrently
|
||||
model_ids:
|
||||
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
|
||||
# Benchmark stages
|
||||
stages:
|
||||
- name: "stage_name" # Human-readable name for this stage
|
||||
prompt_length: 100 # Target prompt length in tokens
|
||||
generation_length: 200 # Max tokens to generate
|
||||
time_between_requests: 2.0 # Seconds between firing requests
|
||||
iterations: 10 # Number of requests in this stage
|
||||
```
|
||||
|
||||
## Running Benchmarks
|
||||
|
||||
### Via GitHub Actions
|
||||
|
||||
**Automatic (every commit):**
|
||||
- The **`bench`** workflow runs automatically on every push
|
||||
- Uses `bench_simple.yaml` as the default configuration
|
||||
- All settings (hardware plan, timeout, environment variables, models, stages) are defined in the config file
|
||||
|
||||
**Manual (on-demand):**
|
||||
1. Go to **Actions** → **bench** workflow
|
||||
2. Click **Run workflow**
|
||||
3. Configure:
|
||||
- **Config File**: Path to your YAML config (default: `.github/configs/bench_simple.yaml`)
|
||||
- `.github/configs/bench_simple.yaml` for quick tests
|
||||
- `.github/configs/bench_config.yaml` for complex multi-stage tests
|
||||
|
||||
All other settings (hardware plan, timeout, environment variables, models, stages) are read from the specified config file.
|
||||
|
||||
### Via Command Line
|
||||
|
||||
```bash
|
||||
# Start EXO on localhost:8000
|
||||
uv run exo --api-port 8000
|
||||
|
||||
# Run simple benchmark (1 stage, 1 iteration)
|
||||
python3 .github/scripts/bench.py \
|
||||
--api-port 8000 \
|
||||
--config .github/configs/bench_simple.yaml \
|
||||
--expected-nodes 1 \
|
||||
--is-primary true \
|
||||
--timeout-seconds 600
|
||||
|
||||
# Run complex staged benchmark (4 stages, multiple iterations)
|
||||
python3 .github/scripts/bench.py \
|
||||
--api-port 8000 \
|
||||
--config .github/configs/bench_config.yaml \
|
||||
--expected-nodes 1 \
|
||||
--is-primary true \
|
||||
--timeout-seconds 600
|
||||
```
|
||||
|
||||
## Output Metrics
|
||||
|
||||
For each stage, the benchmark reports:
|
||||
|
||||
- **Total Requests**: Number of requests fired
|
||||
- **Successful Requests**: Requests that completed successfully
|
||||
- **Failed Requests**: Requests that encountered errors
|
||||
- **Success Rate**: Percentage of successful requests
|
||||
- **Total Tokens**: Sum of all tokens generated across successful requests
|
||||
- **Avg Tokens/Request**: Average tokens per successful request
|
||||
- **Avg Time/Request**: Average completion time per successful request
|
||||
|
||||
A JSON summary is also printed for easy parsing and storage.
|
||||
|
||||
## Creating Custom Benchmarks
|
||||
|
||||
To create a custom benchmark:
|
||||
|
||||
1. Copy an existing config file (e.g., `bench_config.yaml`)
|
||||
2. Modify the stages to match your test scenario
|
||||
3. Save it in this directory with a descriptive name
|
||||
4. Run it using the workflow or command line
|
||||
|
||||
### Example: Sustained Load Test
|
||||
|
||||
```yaml
|
||||
hardware_plan:
|
||||
M3ULTRA_GPU80_512GB: 2
|
||||
|
||||
environment:
|
||||
OVERRIDE_MEMORY_MB: 1024
|
||||
|
||||
timeout_seconds: 600
|
||||
|
||||
model_ids:
|
||||
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
|
||||
stages:
|
||||
- name: "sustained_load"
|
||||
prompt_length: 200
|
||||
generation_length: 150
|
||||
time_between_requests: 0.5 # Very fast - 2 requests/second
|
||||
iterations: 100 # Run for ~50 seconds
|
||||
```
|
||||
|
||||
### Example: Varying Prompt Sizes
|
||||
|
||||
```yaml
|
||||
hardware_plan:
|
||||
M4PRO_GPU16_24GB: 3
|
||||
|
||||
timeout_seconds: 900
|
||||
|
||||
model_ids:
|
||||
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
|
||||
stages:
|
||||
- name: "tiny_prompts"
|
||||
prompt_length: 10
|
||||
generation_length: 100
|
||||
time_between_requests: 1.0
|
||||
iterations: 10
|
||||
|
||||
- name: "medium_prompts"
|
||||
prompt_length: 200
|
||||
generation_length: 100
|
||||
time_between_requests: 1.0
|
||||
iterations: 10
|
||||
|
||||
- name: "large_prompts"
|
||||
prompt_length: 1000
|
||||
generation_length: 100
|
||||
time_between_requests: 1.0
|
||||
iterations: 10
|
||||
```
|
||||
|
||||
## Tips
|
||||
|
||||
- **Overlapping Requests**: Set `time_between_requests` < expected completion time to test concurrent request handling
|
||||
- **Sequential Requests**: Set `time_between_requests` > expected completion time to ensure requests don't overlap
|
||||
- **Realistic Load**: Model real usage patterns by varying prompt/generation lengths across stages
|
||||
- **Success Rate**: A 100% success rate indicates the cluster handled the load well; lower rates suggest capacity limits
|
||||
|
||||
49
.github/configs/bench_config.yaml
vendored
49
.github/configs/bench_config.yaml
vendored
@@ -1,49 +0,0 @@
|
||||
# EXO Staged Benchmark Configuration
|
||||
# This configuration defines a multi-stage load test for EXO clusters
|
||||
|
||||
# Hardware configuration - maps runner labels to instance counts
|
||||
hardware_plan:
|
||||
M3ULTRA_GPU80_512GB: 4
|
||||
|
||||
# Environment variables to set on each node (optional)
|
||||
environment:
|
||||
OVERRIDE_MEMORY_MB: 512
|
||||
|
||||
# Timeout for instance and runner readiness (seconds)
|
||||
timeout_seconds: 600
|
||||
|
||||
# Multiple instances run concurrently on the cluster
|
||||
model_ids:
|
||||
- "mlx-community/Qwen3-0.6B-4bit"
|
||||
- "mlx-community/Qwen3-0.6B-4bit"
|
||||
|
||||
# Stages run sequentially, each with its own characteristics
|
||||
stages:
|
||||
# Stage 1: Light load with short prompts
|
||||
- name: "warmup"
|
||||
prompt_length: 50 # Number of tokens in prompt
|
||||
generation_length: 100 # Max tokens to generate
|
||||
time_between_requests: 5.0 # Seconds between firing requests
|
||||
iterations: 10 # Number of requests to send in this stage
|
||||
|
||||
# Stage 2: Medium load with medium prompts
|
||||
- name: "medium_load"
|
||||
prompt_length: 200
|
||||
generation_length: 150
|
||||
time_between_requests: 3.0
|
||||
iterations: 20
|
||||
|
||||
# Stage 3: Heavy load with long prompts - requests will overlap
|
||||
- name: "stress_test"
|
||||
prompt_length: 500
|
||||
generation_length: 200
|
||||
time_between_requests: 1.0 # Fast firing - will definitely overlap
|
||||
iterations: 30
|
||||
|
||||
# Stage 4: Cool down with simple prompts
|
||||
- name: "cooldown"
|
||||
prompt_length: 50
|
||||
generation_length: 50
|
||||
time_between_requests: 10.0
|
||||
iterations: 5
|
||||
|
||||
125
.github/configs/bench_simple.yaml
vendored
125
.github/configs/bench_simple.yaml
vendored
@@ -1,125 +0,0 @@
|
||||
# Simple single-shot benchmark
|
||||
# Tests 2 instances concurrently on 2 nodes
|
||||
|
||||
# Hardware configuration - maps runner labels to instance counts
|
||||
hardware_plan:
|
||||
puffin4: 1
|
||||
puffin8: 1
|
||||
|
||||
# Environment variables to set on each node
|
||||
environment:
|
||||
PLACEHOLDER: "placeholder"
|
||||
# OVERRIDE_MEMORY_MB: 50000
|
||||
MLX_METAL_FAST_SYNCH: 1
|
||||
|
||||
# Timeout for instance and runner readiness (seconds)
|
||||
timeout_seconds: 1800
|
||||
|
||||
# Model instances to run concurrently
|
||||
model_ids:
|
||||
# - "mlx-community/DeepSeek-V3.1-8bit"
|
||||
# - "mlx-community/Kimi-K2-Instruct-4bit"
|
||||
- "mlx-community/Kimi-K2-Thinking"
|
||||
# - "mlx-community/Qwen3-235B-A22B-4bit"
|
||||
# - "mlx-community/Llama-3.3-70B-Instruct-4bit"
|
||||
# - "mlx-community/Llama-3.3-70B-Instruct-8bit"
|
||||
# - "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
|
||||
# Sharding strategy: "Pipeline" or "Tensor"
|
||||
sharding: "Tensor"
|
||||
|
||||
# Instance type: "MlxRing" or "MlxIbv"
|
||||
instance_meta: "MlxIbv"
|
||||
|
||||
# If true, run requests sequentially (no overlap); if false, fire-and-forget (default: false)
|
||||
no_overlap: true
|
||||
|
||||
# Benchmark stages
|
||||
# pp: 64, 256, 1024, 2048, 4096, 8192, 16384
|
||||
# g: 64, 512
|
||||
stages:
|
||||
# - name: "simple"
|
||||
# prompt_length: 512
|
||||
# generation_length: 10
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp64_g64"
|
||||
# prompt_length: 64
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp64_g64"
|
||||
# prompt_length: 64
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp64_g512"
|
||||
# prompt_length: 64
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
# - name: "pp256_g64"
|
||||
# prompt_length: 256
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
- name: "pp256_g64"
|
||||
prompt_length: 256
|
||||
generation_length: 64
|
||||
time_between_requests: 2.0
|
||||
iterations: 5
|
||||
# - name: "pp256_g512"
|
||||
# prompt_length: 256
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
# - name: "pp1024_g64"
|
||||
# prompt_length: 1024
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp1024_g512"
|
||||
# prompt_length: 1024
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
# - name: "pp2048_g64"
|
||||
# prompt_length: 2048
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp2048_g512"
|
||||
# prompt_length: 2048
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
# - name: "pp4096_g64"
|
||||
# prompt_length: 4096
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 4
|
||||
# - name: "pp4096_g512"
|
||||
# prompt_length: 4096
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
# - name: "pp8192_g64"
|
||||
# prompt_length: 8192
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp8192_g512"
|
||||
# prompt_length: 8192
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 5
|
||||
# - name: "pp16384_g64"
|
||||
# prompt_length: 16384
|
||||
# generation_length: 64
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
# - name: "pp16384_g512"
|
||||
# prompt_length: 16384
|
||||
# generation_length: 512
|
||||
# time_between_requests: 2.0
|
||||
# iterations: 10
|
||||
1399
.github/scripts/bench.py
vendored
1399
.github/scripts/bench.py
vendored
File diff suppressed because it is too large
Load Diff
70
.github/scripts/build_matrix.py
vendored
70
.github/scripts/build_matrix.py
vendored
@@ -1,70 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import os
|
||||
from typing import NotRequired, TypedDict, cast
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
class MatrixEntry(TypedDict):
|
||||
label: str
|
||||
index: int
|
||||
|
||||
|
||||
class MatrixInclude(TypedDict):
|
||||
label: str
|
||||
index: int
|
||||
is_primary: bool
|
||||
expected_nodes: int
|
||||
|
||||
|
||||
class Config(TypedDict):
|
||||
hardware_plan: dict[str, int]
|
||||
timeout_seconds: NotRequired[int]
|
||||
environment: NotRequired[dict[str, str]]
|
||||
|
||||
|
||||
# Read the config file
|
||||
config_file: str = os.environ["CONFIG_FILE"]
|
||||
with open(config_file, "r") as f:
|
||||
config: Config = cast(Config, yaml.safe_load(f))
|
||||
|
||||
# Extract hardware plan from config
|
||||
plan: dict[str, int] = config["hardware_plan"]
|
||||
if not plan:
|
||||
raise ValueError(f"No hardware_plan found in {config_file}")
|
||||
|
||||
# Build matrix entries
|
||||
entries: list[MatrixEntry] = []
|
||||
for label, count in plan.items():
|
||||
for idx in range(count):
|
||||
entries.append({"label": label, "index": idx})
|
||||
|
||||
total_nodes: int = len(entries)
|
||||
matrix: dict[str, list[MatrixInclude]] = {
|
||||
"include": [
|
||||
{
|
||||
"label": e["label"],
|
||||
"index": e["index"],
|
||||
"is_primary": (i == 0),
|
||||
"expected_nodes": total_nodes,
|
||||
}
|
||||
for i, e in enumerate(entries)
|
||||
]
|
||||
}
|
||||
|
||||
# Extract other config values
|
||||
timeout_seconds: int = config.get("timeout_seconds", 600)
|
||||
environment: dict[str, str] = config.get("environment", {})
|
||||
|
||||
# Output to GitHub Actions
|
||||
with open(os.environ["GITHUB_OUTPUT"], "a") as f:
|
||||
f.write(f"matrix={json.dumps(matrix)}\n")
|
||||
f.write(f"config_file={config_file}\n")
|
||||
f.write(f"timeout_seconds={timeout_seconds}\n")
|
||||
f.write(f"environment={json.dumps(environment)}\n")
|
||||
|
||||
print(f"Matrix: {json.dumps(matrix)}")
|
||||
print(f"Config file: {config_file}")
|
||||
print(f"Timeout: {timeout_seconds}")
|
||||
print(f"Environment: {json.dumps(environment)}")
|
||||
156
.github/workflows/BENCH_USAGE.md
vendored
156
.github/workflows/BENCH_USAGE.md
vendored
@@ -1,156 +0,0 @@
|
||||
# Benchmark Workflow Usage
|
||||
|
||||
## Overview
|
||||
|
||||
The `bench_matrix.yml` workflow enables distributed benchmarking of models across multiple self-hosted macOS runners with different hardware configurations.
|
||||
|
||||
## Workflow Inputs
|
||||
|
||||
| Input | Description | Default | Required |
|
||||
|-------|-------------|---------|----------|
|
||||
| `model_id` | Model ID to benchmark | `mlx-community/Llama-3.2-1B-Instruct-4bit` | Yes |
|
||||
| `hardware_plan` | JSON mapping of runner labels to counts | `{"M4PRO_GPU16_24GB": 1}` | Yes |
|
||||
| `prompt` | Benchmark prompt text | `What is the capital of France?` | No |
|
||||
| `timeout_seconds` | Timeout for instance/runner readiness | `600` | No |
|
||||
|
||||
## Hardware Plan Format
|
||||
|
||||
The `hardware_plan` input is a JSON object mapping runner labels to the number of machines:
|
||||
|
||||
```json
|
||||
{
|
||||
"M4PRO_GPU16_24GB": 2,
|
||||
"M3ULTRA_GPU80_512GB": 1
|
||||
}
|
||||
```
|
||||
|
||||
This example would:
|
||||
- Start 2 runners with the `M4PRO_GPU16_24GB` label
|
||||
- Start 1 runner with the `M3ULTRA_GPU80_512GB` label
|
||||
- Total of 3 runners coordinating on a single distributed inference instance
|
||||
|
||||
## How It Works
|
||||
|
||||
1. **Planning Job** (`plan`)
|
||||
- Runs on `ubuntu-latest`
|
||||
- Parses the `hardware_plan` JSON
|
||||
- Generates a dynamic matrix with one entry per runner
|
||||
- Only the first runner (index 0) is marked as `is_primary`
|
||||
|
||||
2. **Benchmark Worker Jobs** (`bench_worker`)
|
||||
- Each job runs on a self-hosted macOS runner with the specified label
|
||||
- All runners start EXO in parallel
|
||||
- The primary runner creates the model instance
|
||||
- All runners wait for their assigned runner to be ready (Loaded/Running status)
|
||||
- The primary runner executes the benchmark and prints results
|
||||
- The primary runner deletes the instance
|
||||
|
||||
## Example Usage
|
||||
|
||||
### Single Machine Benchmark
|
||||
|
||||
```yaml
|
||||
model_id: mlx-community/Llama-3.2-1B-Instruct-4bit
|
||||
hardware_plan: '{"M4PRO_GPU16_24GB": 1}'
|
||||
prompt: What is the capital of France?
|
||||
timeout_seconds: 600
|
||||
```
|
||||
|
||||
### Multi-Machine Distributed Benchmark
|
||||
|
||||
```yaml
|
||||
model_id: mlx-community/Llama-3.2-3B-Instruct-4bit
|
||||
hardware_plan: '{"M4PRO_GPU16_24GB": 2, "M3ULTRA_GPU80_512GB": 1}'
|
||||
prompt: Explain quantum computing in simple terms.
|
||||
timeout_seconds: 900
|
||||
```
|
||||
|
||||
## Benchmark Output
|
||||
|
||||
The primary runner outputs a JSON object with benchmark results:
|
||||
|
||||
```json
|
||||
{
|
||||
"model_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",
|
||||
"instance_id": "abc-123-def",
|
||||
"tokens": 42,
|
||||
"elapsed_s": 2.451,
|
||||
"tps": 17.136
|
||||
}
|
||||
```
|
||||
|
||||
Where:
|
||||
- `tokens`: Number of chunks/tokens generated
|
||||
- `elapsed_s`: Total elapsed time in seconds
|
||||
- `tps`: Tokens per second (tokens / elapsed_s)
|
||||
|
||||
## Runner Requirements
|
||||
|
||||
Each self-hosted runner must:
|
||||
- Be labeled with appropriate hardware tags (e.g., `M4PRO_GPU16_24GB`)
|
||||
- Have the `self-hosted` and `macOS` labels
|
||||
- Have Nix installed with flakes enabled
|
||||
- Have network connectivity to other runners in the same job
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ GitHub Actions Workflow (bench_matrix.yml) │
|
||||
├─────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌────────────────┐ │
|
||||
│ │ Plan Job │ │
|
||||
│ │ (ubuntu) │──┬─► Matrix: [{label, index, primary}] │
|
||||
│ └────────────────┘ │ │
|
||||
│ │ │
|
||||
│ ┌───────────────────▼──────────────────────────────────┐ │
|
||||
│ │ Bench Worker Jobs (Matrix) │ │
|
||||
│ ├──────────────────────────────────────────────────────┤ │
|
||||
│ │ │ │
|
||||
│ │ Runner 0 (Primary) Runner 1 Runner 2 │ │
|
||||
│ │ ┌─────────────┐ ┌─────────────┐ ┌──────────┐ │ │
|
||||
│ │ │ Start EXO │ │ Start EXO │ │ Start EXO│ │ │
|
||||
│ │ │ Create Inst │ │ Wait... │ │ Wait... │ │ │
|
||||
│ │ │ Wait Ready │ │ Wait Ready │ │ Wait... │ │ │
|
||||
│ │ │ Run Bench │ │ (idle) │ │ (idle) │ │ │
|
||||
│ │ │ Print TPS │ │ │ │ │ │ │
|
||||
│ │ │ Delete Inst │ │ │ │ │ │ │
|
||||
│ │ └─────────────┘ └─────────────┘ └──────────┘ │ │
|
||||
│ └───────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### `scripts/bench.py`
|
||||
|
||||
A standalone Python script that:
|
||||
- Creates instance (primary only)
|
||||
- Polls `/state` endpoint until instance and all runners are ready
|
||||
- Executes chat completion with timing (primary only)
|
||||
- Parses SSE stream and counts tokens
|
||||
- Computes TPS metrics
|
||||
- Cleans up instance (primary only)
|
||||
|
||||
### Key Functions
|
||||
|
||||
- `wait_for_instance()`: Polls until instance with model_id appears
|
||||
- `wait_for_runners_ready()`: Polls until expected number of runners reach Loaded/Running status
|
||||
- `run_benchmark()`: Executes chat completion, measures time, counts tokens
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Instance never becomes ready
|
||||
- Check EXO logs in the workflow output
|
||||
- Verify model_id is valid and accessible
|
||||
- Increase `timeout_seconds`
|
||||
|
||||
### Runner mismatch
|
||||
- Ensure hardware_plan counts match available labeled runners
|
||||
- Check runner labels match exactly (case-sensitive)
|
||||
|
||||
### Network issues
|
||||
- Verify runners can communicate on the network
|
||||
- Check firewall rules between runner hosts
|
||||
|
||||
305
.github/workflows/bench.yml
vendored
305
.github/workflows/bench.yml
vendored
@@ -1,305 +0,0 @@
|
||||
name: bench
|
||||
|
||||
on: [push]
|
||||
|
||||
jobs:
|
||||
plan:
|
||||
if: contains(github.event.head_commit.message, '/bench')
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.build.outputs.matrix }}
|
||||
config_file: ${{ steps.build.outputs.config_file }}
|
||||
timeout_seconds: ${{ steps.build.outputs.timeout_seconds }}
|
||||
environment: ${{ steps.build.outputs.environment }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Build matrix from config file
|
||||
id: build
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
CONFIG_FILE='.github/configs/bench_simple.yaml'
|
||||
export CONFIG_FILE
|
||||
echo "Config file: $CONFIG_FILE"
|
||||
python3 .github/scripts/build_matrix.py
|
||||
|
||||
bench_worker:
|
||||
needs: plan
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix: ${{ fromJSON(needs.plan.outputs.matrix) }}
|
||||
name: "bench on ${{ matrix.label }} [${{ matrix.index }}]"
|
||||
runs-on: [self-hosted, macOS, "${{ matrix.label }}"]
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: false
|
||||
|
||||
- name: Configure git user
|
||||
run: |
|
||||
git config --local user.email "github-actions@users.noreply.github.com"
|
||||
git config --local user.name "github-actions bot"
|
||||
shell: bash
|
||||
|
||||
# TODO: this is mega hacky and I'd like a simpler solution.
|
||||
- name: Setup Nix Environment
|
||||
run: |
|
||||
echo "Checking for nix installation..."
|
||||
|
||||
# Check if nix is already available
|
||||
if command -v nix >/dev/null 2>&1; then
|
||||
echo "Nix already in PATH"
|
||||
# Try sourcing profile scripts to set up environment properly
|
||||
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
|
||||
echo "Sourcing multi-user nix-daemon profile script"
|
||||
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
|
||||
elif [ -f "$HOME/.nix-profile/etc/profile.d/nix.sh" ]; then
|
||||
echo "Sourcing single-user nix profile script"
|
||||
source "$HOME/.nix-profile/etc/profile.d/nix.sh"
|
||||
elif [ -f /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh ]; then
|
||||
echo "Sourcing per-user nix profile script"
|
||||
source /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh
|
||||
elif [ -f /etc/profile.d/nix.sh ]; then
|
||||
echo "Sourcing system-wide nix profile script"
|
||||
source /etc/profile.d/nix.sh
|
||||
# Fallback: manually add nix to PATH if binary exists
|
||||
elif [ -f /nix/var/nix/profiles/default/bin/nix ]; then
|
||||
echo "Found nix binary, manually adding to PATH"
|
||||
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
|
||||
elif [ -f "$HOME/.nix-profile/bin/nix" ]; then
|
||||
echo "Found nix binary in user profile, manually adding to PATH"
|
||||
export PATH="$HOME/.nix-profile/bin:$PATH"
|
||||
else
|
||||
echo "Nix not found. Debugging info:"
|
||||
echo "USER: $USER"
|
||||
echo "HOME: $HOME"
|
||||
echo "Current PATH: $PATH"
|
||||
echo ""
|
||||
echo "Checking common Nix locations:"
|
||||
echo " /nix/var/nix/profiles/default/bin/nix:"
|
||||
ls -la /nix/var/nix/profiles/default/bin/nix 2>/dev/null || echo " Not found"
|
||||
echo " /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh:"
|
||||
ls -la /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh 2>/dev/null || echo " Not found"
|
||||
echo " ~/.nix-profile/etc/profile.d/nix.sh:"
|
||||
ls -la "$HOME/.nix-profile/etc/profile.d/nix.sh" 2>/dev/null || echo " Not found"
|
||||
echo " /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh:"
|
||||
ls -la "/nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh" 2>/dev/null || echo " Not found"
|
||||
echo ""
|
||||
echo "/nix directory structure:"
|
||||
ls -la /nix 2>/dev/null || echo " /nix directory not found"
|
||||
echo ""
|
||||
echo "/nix/var:"
|
||||
ls -la /nix/var 2>/dev/null || echo " /nix/var not found"
|
||||
echo ""
|
||||
echo "/nix/store:"
|
||||
ls -la /nix/store 2>/dev/null | head -20 || echo " /nix/store not found"
|
||||
echo ""
|
||||
echo "GitHub Actions runner is running as user '$USER'."
|
||||
echo "If Nix is installed for a different user, either:"
|
||||
echo " 1. Install Nix for user '$USER' (multi-user install recommended)"
|
||||
echo " 2. Configure the runner service to run as the user with Nix installed"
|
||||
echo " 3. Ensure Nix is installed system-wide with proper daemon setup"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify nix is available and persist to GITHUB_ENV
|
||||
if command -v nix >/dev/null 2>&1; then
|
||||
echo "✓ Nix is available"
|
||||
nix --version
|
||||
echo "PATH=$PATH" >> $GITHUB_ENV
|
||||
if [ -n "$NIX_PATH" ]; then
|
||||
echo "NIX_PATH=$NIX_PATH" >> $GITHUB_ENV
|
||||
fi
|
||||
else
|
||||
echo "ERROR: Failed to set up Nix"
|
||||
echo "PATH after setup attempt: $PATH"
|
||||
exit 1
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Setup EXO_HOME and API_PORT
|
||||
run: |
|
||||
EXO_HOME=$(mktemp -d -t exo-e2e-XXXXXXXX)
|
||||
API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
|
||||
EXO_MODELS_DIR="$HOME/.exo/models"
|
||||
EXO_LIBP2P_NAMESPACE="bench-${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"
|
||||
echo "EXO_HOME=$EXO_HOME" >> "$GITHUB_ENV"
|
||||
echo "API_PORT=$API_PORT" >> "$GITHUB_ENV"
|
||||
echo "EXO_MODELS_DIR=$EXO_MODELS_DIR" >> "$GITHUB_ENV"
|
||||
echo "EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE" >> "$GITHUB_ENV"
|
||||
echo "Created EXO_HOME: $EXO_HOME"
|
||||
echo "Generated API_PORT: $API_PORT"
|
||||
echo "Using models from: $EXO_MODELS_DIR"
|
||||
echo "Using libp2p namespace: $EXO_LIBP2P_NAMESPACE"
|
||||
shell: bash
|
||||
|
||||
- name: Configure local MLX if available
|
||||
run: |
|
||||
echo "=== DEBUG: Checking for local MLX configuration ==="
|
||||
MODIFIED=false
|
||||
|
||||
echo "Checking for /Users/Shared/mlx directory..."
|
||||
if [ -d "/Users/Shared/mlx" ]; then
|
||||
echo "✓ Found /Users/Shared/mlx"
|
||||
ls -la /Users/Shared/mlx | head -5
|
||||
echo "Enabling local mlx path in pyproject.toml"
|
||||
sed -i.bak 's|^# mlx = { path = "/Users/Shared/mlx", editable=true }$|mlx = { path = "/Users/Shared/mlx", editable=true }|' pyproject.toml
|
||||
MODIFIED=true
|
||||
else
|
||||
echo "✗ /Users/Shared/mlx not found, will use PyPI version"
|
||||
fi
|
||||
|
||||
echo "Checking for /Users/Shared/mlx-lm directory..."
|
||||
if [ -d "/Users/Shared/mlx-lm" ]; then
|
||||
echo "✓ Found /Users/Shared/mlx-lm"
|
||||
ls -la /Users/Shared/mlx-lm | head -5
|
||||
echo "Enabling local mlx-lm path in pyproject.toml"
|
||||
sed -i.bak 's|^# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }$|mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }|' pyproject.toml
|
||||
MODIFIED=true
|
||||
else
|
||||
echo "✗ /Users/Shared/mlx-lm not found, will use PyPI version"
|
||||
fi
|
||||
|
||||
if [ "$MODIFIED" = true ]; then
|
||||
echo "=== Modified pyproject.toml [tool.uv.sources] section: ==="
|
||||
sed -n '/\[tool\.uv\.sources\]/,/^\[/{/^\[tool\.uv\.sources\]/p; /^\[/!p;}' pyproject.toml
|
||||
echo "=== Regenerating uv.lock with local MLX paths... ==="
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command uv lock --upgrade-package mlx --upgrade-package mlx-lm
|
||||
echo "✓ Lock file regenerated"
|
||||
else
|
||||
echo "⚠ No local MLX directories found, using PyPI packages"
|
||||
fi
|
||||
echo "=== DEBUG: Local MLX configuration complete ==="
|
||||
shell: bash
|
||||
|
||||
- name: Sync dependencies
|
||||
run: |
|
||||
if [ -d "/Users/Shared/test" ]; then
|
||||
pushd /Users/Shared/test
|
||||
uv sync --reinstall
|
||||
popd
|
||||
fi
|
||||
echo "Running just sync to ensure clean dependencies..."
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just sync
|
||||
shell: bash
|
||||
|
||||
- name: Start EXO and run bench script
|
||||
shell: bash
|
||||
env:
|
||||
IS_PRIMARY: ${{ matrix.is_primary }}
|
||||
EXPECTED_NODES: ${{ matrix.expected_nodes }}
|
||||
HARDWARE_LABEL: ${{ matrix.label }}
|
||||
CONFIG_FILE: ${{ needs.plan.outputs.config_file }}
|
||||
TIMEOUT_SECONDS: ${{ needs.plan.outputs.timeout_seconds }}
|
||||
ENVIRONMENT_JSON: ${{ needs.plan.outputs.environment }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Parse environment variables from config
|
||||
ENV_VARS=""
|
||||
if [ -n "$ENVIRONMENT_JSON" ] && [ "$ENVIRONMENT_JSON" != "{}" ]; then
|
||||
ENV_VARS=$(echo "$ENVIRONMENT_JSON" | python3 -c "import sys, json; env = json.load(sys.stdin); print(' '.join([f'{k}={v}' for k, v in env.items()]))")
|
||||
fi
|
||||
|
||||
echo "Starting EXO with API_PORT=${API_PORT} EXO_HOME=${EXO_HOME} EXO_LIBP2P_NAMESPACE=${EXO_LIBP2P_NAMESPACE}"
|
||||
echo "Environment variables from config: $ENV_VARS"
|
||||
LOG_FILE=/tmp/exo.log
|
||||
: > "$LOG_FILE"
|
||||
|
||||
MASTER_FLAG=""
|
||||
if [ "$IS_PRIMARY" = "true" ]; then
|
||||
MASTER_FLAG="-m"
|
||||
fi
|
||||
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
|
||||
"EXO_HOME=$EXO_HOME EXO_MODELS_DIR=$EXO_MODELS_DIR EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE $ENV_VARS PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run exo $MASTER_FLAG --api-port $API_PORT" \
|
||||
>> "$LOG_FILE" 2>&1 &
|
||||
|
||||
EXO_PID=$!
|
||||
echo "Started EXO in background with PID: $EXO_PID"
|
||||
echo "Log file: $LOG_FILE"
|
||||
|
||||
cleanup() {
|
||||
echo '=== EXO log (tail) ==='
|
||||
tail -n 300 "$LOG_FILE" || true
|
||||
if ps -p "$EXO_PID" >/dev/null 2>&1; then
|
||||
echo "Killing EXO (PID $EXO_PID)"
|
||||
kill "$EXO_PID" || true
|
||||
fi
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
for i in $(seq 1 60); do
|
||||
if curl -s "http://localhost:${API_PORT}/state" >/dev/null 2>&1; then
|
||||
echo "EXO API ready"
|
||||
break
|
||||
fi
|
||||
if ! ps -p "$EXO_PID" >/dev/null 2>&1; then
|
||||
echo "EXO terminated early"; sed -n '1,200p' "$LOG_FILE" || true; exit 1
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
RESULTS_FILE="/tmp/bench_results_${GITHUB_RUN_ID}_${GITHUB_RUN_ATTEMPT}_$(date +%s).json"
|
||||
echo "Results will be saved to: $RESULTS_FILE"
|
||||
echo "RESULTS_FILE=$RESULTS_FILE" >> "$GITHUB_ENV"
|
||||
|
||||
echo "Running bench script with config: $CONFIG_FILE, timeout: $TIMEOUT_SECONDS"
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
|
||||
"PYTHONUNBUFFERED=1 uv run --no-project --with pyyaml --with pydantic python .github/scripts/bench.py \
|
||||
--api-port $API_PORT \
|
||||
--config $CONFIG_FILE \
|
||||
--expected-nodes ${EXPECTED_NODES} \
|
||||
--is-primary ${IS_PRIMARY} \
|
||||
--timeout-seconds ${TIMEOUT_SECONDS} \
|
||||
--output $RESULTS_FILE \
|
||||
--git-commit ${GITHUB_SHA} \
|
||||
--hardware-labels ${HARDWARE_LABEL}"
|
||||
|
||||
- name: Install AWS CLI
|
||||
if: always() && env.RESULTS_FILE && matrix.is_primary
|
||||
run: |
|
||||
if ! command -v aws &> /dev/null; then
|
||||
echo "AWS CLI not found, installing..."
|
||||
brew install awscli
|
||||
else
|
||||
echo "AWS CLI already installed"
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Upload results to S3
|
||||
if: always() && env.RESULTS_FILE && matrix.is_primary
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.S3_BENCHMARKS_AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_DEFAULT_REGION: us-east-1
|
||||
run: |
|
||||
echo "Checking for results file: $RESULTS_FILE"
|
||||
echo "Is primary: ${{ matrix.is_primary }}"
|
||||
|
||||
if [ -f "$RESULTS_FILE" ]; then
|
||||
TIMESTAMP=$(date -u +%Y/%m/%d/%H%M%S)
|
||||
S3_KEY="bench/${TIMESTAMP}_${GITHUB_SHA:0:8}_${GITHUB_RUN_ID}.json"
|
||||
echo "Uploading results to s3://exo-benchmark-results/$S3_KEY"
|
||||
|
||||
aws s3 cp "$RESULTS_FILE" "s3://exo-benchmark-results/$S3_KEY" \
|
||||
--content-type application/json \
|
||||
--metadata "commit=${GITHUB_SHA},run_id=${GITHUB_RUN_ID},branch=${GITHUB_REF_NAME}"
|
||||
|
||||
echo "Results uploaded successfully"
|
||||
echo "View at: https://exo-benchmark-results.s3.amazonaws.com/$S3_KEY"
|
||||
else
|
||||
echo "Results file not found at: $RESULTS_FILE"
|
||||
echo "Skipping upload"
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Cleanup EXO_HOME
|
||||
run: |
|
||||
echo "Cleaning up EXO_HOME: $EXO_HOME"
|
||||
rm -rf "$EXO_HOME"
|
||||
shell: bash
|
||||
if: always()
|
||||
52
.github/workflows/build-app.yml
vendored
52
.github/workflows/build-app.yml
vendored
@@ -1,6 +1,7 @@
|
||||
name: Build EXO macOS DMG
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
@@ -18,6 +19,7 @@ jobs:
|
||||
SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}
|
||||
SPARKLE_S3_BUCKET: ${{ secrets.SPARKLE_S3_BUCKET }}
|
||||
SPARKLE_S3_PREFIX: ${{ secrets.SPARKLE_S3_PREFIX }}
|
||||
EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT: ${{ secrets.EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT }}
|
||||
AWS_REGION: ${{ secrets.AWS_REGION }}
|
||||
EXO_BUILD_NUMBER: ${{ github.run_number }}
|
||||
EXO_LIBP2P_NAMESPACE: ${{ github.ref_name }}
|
||||
@@ -34,7 +36,7 @@ jobs:
|
||||
|
||||
- name: Derive release version from tag
|
||||
run: |
|
||||
if [[ "$GITHUB_REF_NAME" == "test-app" ]]; then
|
||||
if [[ "$GITHUB_REF_NAME" == "test-app" || "${{ github.event_name }}" == "workflow_dispatch" ]]; then
|
||||
VERSION="0.0.0-alpha.0"
|
||||
echo "IS_ALPHA=true" >> $GITHUB_ENV
|
||||
else
|
||||
@@ -47,6 +49,32 @@ jobs:
|
||||
fi
|
||||
echo "RELEASE_VERSION=$VERSION" >> $GITHUB_ENV
|
||||
|
||||
- name: Compute build version from semver
|
||||
run: |
|
||||
VERSION="$RELEASE_VERSION"
|
||||
# Extract major.minor.patch (strip prerelease suffix)
|
||||
BASE_VERSION="${VERSION%%-*}"
|
||||
MAJOR=$(echo "$BASE_VERSION" | cut -d. -f1)
|
||||
MINOR=$(echo "$BASE_VERSION" | cut -d. -f2)
|
||||
PATCH=$(echo "$BASE_VERSION" | cut -d. -f3)
|
||||
|
||||
# Extract prerelease number (e.g., "alpha.2" -> 2, or 999 for releases)
|
||||
if [[ "$VERSION" == *-* ]]; then
|
||||
PRERELEASE_PART="${VERSION#*-}"
|
||||
PRERELEASE_NUM="${PRERELEASE_PART##*.}"
|
||||
# Default to 0 if not a number
|
||||
if ! [[ "$PRERELEASE_NUM" =~ ^[0-9]+$ ]]; then
|
||||
PRERELEASE_NUM=0
|
||||
fi
|
||||
else
|
||||
PRERELEASE_NUM=999
|
||||
fi
|
||||
|
||||
# Compute: PRERELEASE + (1000 * PATCH) + (1_000_000 * MINOR) + (1_000_000_000 * MAJOR)
|
||||
BUILD_VERSION=$((PRERELEASE_NUM + 1000 * PATCH + 1000000 * MINOR + 1000000000 * MAJOR))
|
||||
echo "EXO_BUILD_VERSION=$BUILD_VERSION" >> $GITHUB_ENV
|
||||
echo "Computed build version: $BUILD_VERSION from $VERSION"
|
||||
|
||||
- name: Ensure tag commit is on main
|
||||
if: github.ref_type == 'tag'
|
||||
run: |
|
||||
@@ -85,11 +113,22 @@ jobs:
|
||||
uv python install
|
||||
uv sync --locked
|
||||
|
||||
- name: Install Nix
|
||||
uses: cachix/install-nix-action@v31
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
|
||||
- name: Configure Cachix
|
||||
uses: cachix/cachix-action@v14
|
||||
with:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Build dashboard
|
||||
run: |
|
||||
cd dashboard
|
||||
npm ci
|
||||
npm run build
|
||||
DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)
|
||||
mkdir -p dashboard/build
|
||||
cp -r "$DASHBOARD_OUT"/* dashboard/build/
|
||||
|
||||
- name: Install Sparkle CLI
|
||||
run: |
|
||||
@@ -162,11 +201,12 @@ jobs:
|
||||
-configuration Release \
|
||||
-derivedDataPath build \
|
||||
MARKETING_VERSION="$RELEASE_VERSION" \
|
||||
CURRENT_PROJECT_VERSION="$EXO_BUILD_NUMBER" \
|
||||
CURRENT_PROJECT_VERSION="$EXO_BUILD_VERSION" \
|
||||
EXO_BUILD_TAG="$RELEASE_VERSION" \
|
||||
EXO_BUILD_COMMIT="$GITHUB_SHA" \
|
||||
SPARKLE_FEED_URL="$SPARKLE_FEED_URL" \
|
||||
SPARKLE_ED25519_PUBLIC="$SPARKLE_ED25519_PUBLIC" \
|
||||
EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT="$EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT" \
|
||||
CODE_SIGNING_IDENTITY="$SIGNING_IDENTITY" \
|
||||
CODE_SIGN_INJECT_BASE_ENTITLEMENTS=YES
|
||||
mkdir -p ../../output
|
||||
@@ -294,5 +334,5 @@ jobs:
|
||||
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}${DMG_NAME}"
|
||||
if [[ "$IS_ALPHA" != "true" ]]; then
|
||||
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
|
||||
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
|
||||
fi
|
||||
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
|
||||
|
||||
117
.github/workflows/pipeline.yml
vendored
117
.github/workflows/pipeline.yml
vendored
@@ -20,6 +20,12 @@ jobs:
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
|
||||
- uses: cachix/cachix-action@v14
|
||||
name: Configure Cachix
|
||||
with:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Configure git user
|
||||
run: |
|
||||
git config --local user.email "github-actions@users.noreply.github.com"
|
||||
@@ -88,9 +94,19 @@ jobs:
|
||||
|
||||
- uses: ./.github/actions/typecheck
|
||||
|
||||
nix-flake-check:
|
||||
name: Check Nix flake
|
||||
runs-on: ubuntu-latest
|
||||
nix:
|
||||
name: Build and check (${{ matrix.system }})
|
||||
runs-on: ${{ matrix.runner }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- runner: macos-26
|
||||
system: aarch64-darwin
|
||||
- runner: ubuntu-latest
|
||||
system: x86_64-linux
|
||||
- runner: ubuntu-24.04-arm
|
||||
system: aarch64-linux
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -101,83 +117,20 @@ jobs:
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
|
||||
- name: Run nix flake check
|
||||
run: |
|
||||
nix flake check
|
||||
shell: bash
|
||||
- uses: cachix/cachix-action@v14
|
||||
name: Configure Cachix
|
||||
with:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
# ci:
|
||||
# needs: typecheck
|
||||
# runs-on: ubuntu-latest
|
||||
# permissions:
|
||||
# contents: read
|
||||
# env:
|
||||
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
# steps:
|
||||
# - name: Checkout repository
|
||||
# uses: actions/checkout@v4
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
# token: ${{ secrets.GITHUB_TOKEN }}
|
||||
# lfs: true
|
||||
#
|
||||
# - name: Configure git user
|
||||
# run: |
|
||||
# git config --local user.email "github-actions@users.noreply.github.com"
|
||||
# git config --local user.name "github-actions bot"
|
||||
# shell: bash
|
||||
#
|
||||
# - name: Pull LFS files
|
||||
# run: |
|
||||
# echo "Pulling Git LFS files..."
|
||||
# git lfs pull
|
||||
# shell: bash
|
||||
#
|
||||
# - name: Setup EXO_HOME and API_PORT
|
||||
# run: |
|
||||
# EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX)
|
||||
# # Generate random port (macOS compatible method)
|
||||
# API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
|
||||
# echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
|
||||
# echo "API_PORT=$API_PORT" >> $GITHUB_ENV
|
||||
# echo "Created EXO_HOME: $EXO_HOME"
|
||||
# echo "Generated API_PORT: $API_PORT"
|
||||
# shell: bash
|
||||
#
|
||||
# - name: Setup Nix Environment
|
||||
# run: |
|
||||
# echo "Checking for nix installation..."
|
||||
#
|
||||
# # Check if nix binary exists directly
|
||||
# if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
|
||||
# echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
|
||||
# export PATH="/nix/var/nix/profiles/default/bin:$PATH"
|
||||
# echo "PATH=$PATH" >> $GITHUB_ENV
|
||||
# nix --version
|
||||
# elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
|
||||
# echo "Found nix profile script, sourcing..."
|
||||
# source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
|
||||
# nix --version
|
||||
# elif command -v nix >/dev/null 2>&1; then
|
||||
# echo "Nix already in PATH"
|
||||
# nix --version
|
||||
# else
|
||||
# echo "Nix not found. Debugging info:"
|
||||
# echo "Contents of /nix/var/nix/profiles/default/:"
|
||||
# ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
|
||||
# echo "Contents of /nix/var/nix/profiles/default/bin/:"
|
||||
# ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
|
||||
# exit 1
|
||||
# fi
|
||||
# shell: bash
|
||||
#
|
||||
# - uses: ./.github/actions/lint-check
|
||||
#
|
||||
# - uses: ./.github/actions/unit-test
|
||||
#
|
||||
# - name: Cleanup EXO_HOME
|
||||
# run: |
|
||||
# echo "Cleaning up EXO_HOME: $EXO_HOME"
|
||||
# rm -rf "$EXO_HOME"
|
||||
# shell: bash
|
||||
# if: always()
|
||||
- name: Build all Nix outputs
|
||||
run: |
|
||||
nix flake show --json | jq -r '
|
||||
[
|
||||
(.packages."${{ matrix.system }}" // {} | keys[] | ".#packages.${{ matrix.system }}.\(.)"),
|
||||
(.devShells."${{ matrix.system }}" // {} | keys[] | ".#devShells.${{ matrix.system }}.\(.)")
|
||||
] | .[]
|
||||
' | xargs nix build
|
||||
|
||||
- name: Run nix flake check
|
||||
run: nix flake check
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -7,6 +7,8 @@ digest.txt
|
||||
# nix
|
||||
.direnv/
|
||||
|
||||
# IDEA (PyCharm)
|
||||
.idea
|
||||
|
||||
# xcode / macos
|
||||
*.xcuserstate
|
||||
@@ -14,6 +16,7 @@ digest.txt
|
||||
*.xcuserdatad/
|
||||
**/.DS_Store
|
||||
app/EXO/build/
|
||||
dist/
|
||||
|
||||
|
||||
# rust
|
||||
|
||||
156
.mlx_typings/mlx_lm/models/deepseek_v3.pyi
Normal file
156
.mlx_typings/mlx_lm/models/deepseek_v3.pyi
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Type stubs for mlx_lm.models.deepseek_v3"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
moe_intermediate_size: int
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
n_shared_experts: Optional[int]
|
||||
n_routed_experts: Optional[int]
|
||||
routed_scaling_factor: float
|
||||
kv_lora_rank: int
|
||||
q_lora_rank: Optional[int]
|
||||
qk_rope_head_dim: int
|
||||
v_head_dim: int
|
||||
qk_nope_head_dim: int
|
||||
topk_method: str
|
||||
scoring_func: str
|
||||
norm_topk_prob: bool
|
||||
n_group: int
|
||||
topk_group: int
|
||||
num_experts_per_tok: int
|
||||
moe_layer_freq: int
|
||||
first_k_dense_replace: int
|
||||
max_position_embeddings: int
|
||||
rms_norm_eps: float
|
||||
rope_theta: float
|
||||
rope_scaling: Optional[Dict[str, Any]]
|
||||
attention_bias: bool
|
||||
|
||||
class DeepseekV3Attention(nn.Module):
|
||||
config: ModelArgs
|
||||
hidden_size: int
|
||||
num_heads: int
|
||||
max_position_embeddings: int
|
||||
rope_theta: float
|
||||
q_lora_rank: Optional[int]
|
||||
qk_rope_head_dim: int
|
||||
kv_lora_rank: int
|
||||
v_head_dim: int
|
||||
qk_nope_head_dim: int
|
||||
q_head_dim: int
|
||||
scale: float
|
||||
q_proj: nn.Linear
|
||||
q_a_proj: nn.Linear
|
||||
q_a_layernorm: nn.RMSNorm
|
||||
q_b_proj: nn.Linear
|
||||
kv_a_proj_with_mqa: nn.Linear
|
||||
kv_a_layernorm: nn.RMSNorm
|
||||
kv_b_proj: nn.Linear
|
||||
o_proj: nn.Linear
|
||||
rope: Any
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class DeepseekV3MLP(nn.Module):
|
||||
config: ModelArgs
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
gate_proj: nn.Linear
|
||||
up_proj: nn.Linear
|
||||
down_proj: nn.Linear
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelArgs,
|
||||
hidden_size: Optional[int] = None,
|
||||
intermediate_size: Optional[int] = None,
|
||||
) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
config: ModelArgs
|
||||
top_k: int
|
||||
norm_topk_prob: bool
|
||||
n_routed_experts: Optional[int]
|
||||
routed_scaling_factor: float
|
||||
n_group: int
|
||||
topk_group: int
|
||||
weight: mx.array
|
||||
e_score_correction_bias: mx.array
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
|
||||
config: ModelArgs
|
||||
num_experts_per_tok: int
|
||||
switch_mlp: SwitchGLU
|
||||
gate: MoEGate
|
||||
shared_experts: DeepseekV3MLP
|
||||
sharding_group: Optional[mx.distributed.Group]
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class DeepseekV3DecoderLayer(nn.Module):
|
||||
self_attn: DeepseekV3Attention
|
||||
mlp: DeepseekV3MLP | DeepseekV3MoE
|
||||
input_layernorm: nn.RMSNorm
|
||||
post_attention_layernorm: nn.RMSNorm
|
||||
|
||||
def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class DeepseekV3Model(nn.Module):
|
||||
vocab_size: int
|
||||
embed_tokens: nn.Embedding
|
||||
layers: list[DeepseekV3DecoderLayer]
|
||||
norm: nn.RMSNorm
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Model(nn.Module):
|
||||
model_type: str
|
||||
model: DeepseekV3Model
|
||||
lm_head: nn.Linear
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
|
||||
@property
|
||||
def layers(self) -> list[DeepseekV3DecoderLayer]: ...
|
||||
@@ -57,6 +57,11 @@ class SwiGLU(nn.Module):
|
||||
def __call__(self, x, gate): ...
|
||||
|
||||
class SwitchGLU(nn.Module):
|
||||
gate_proj: SwitchLinear
|
||||
up_proj: SwitchLinear
|
||||
down_proj: SwitchLinear
|
||||
activation: SwiGLU
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
|
||||
@@ -4,6 +4,7 @@ This type stub file was generated by pyright.
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
@@ -103,37 +104,55 @@ class TokenizerWrapper:
|
||||
Accessing any attribute other than the ``detokenizer`` is forwarded to the
|
||||
huggingface tokenizer.
|
||||
"""
|
||||
def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
|
||||
def add_eos_token(self, token: str): # -> None:
|
||||
...
|
||||
@property
|
||||
def has_thinking(self): # -> bool:
|
||||
...
|
||||
@property
|
||||
def think_start(self): # -> str | None:
|
||||
...
|
||||
@property
|
||||
def think_end(self): # -> str | None:
|
||||
...
|
||||
@property
|
||||
def has_tool_calling(self): # -> bool:
|
||||
...
|
||||
@property
|
||||
def tool_call_start(self): # -> str | None:
|
||||
...
|
||||
@property
|
||||
def tool_call_end(self): # -> str | None:
|
||||
...
|
||||
@property
|
||||
def detokenizer(self): # -> NaiveStreamingDetokenizer:
|
||||
"""
|
||||
Get a stateful streaming detokenizer.
|
||||
"""
|
||||
|
||||
def __getattr__(self, attr): # -> set[Any] | Any:
|
||||
...
|
||||
def __setattr__(self, attr, value): # -> None:
|
||||
...
|
||||
_tokenizer: PreTrainedTokenizerFast
|
||||
eos_token_id: int | None
|
||||
eos_token: str | None
|
||||
bos_token_id: int | None
|
||||
bos_token: str | None
|
||||
vocab_size: int
|
||||
all_special_tokens: list[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Any,
|
||||
detokenizer_class: Any = ...,
|
||||
eos_token_ids: list[int] | None = ...,
|
||||
chat_template: Any = ...,
|
||||
tool_parser: Any = ...,
|
||||
tool_call_start: str | None = ...,
|
||||
tool_call_end: str | None = ...,
|
||||
) -> None: ...
|
||||
def encode(self, text: str, **kwargs: Any) -> list[int]: ...
|
||||
def decode(self, token_ids: list[int], **kwargs: Any) -> str: ...
|
||||
def apply_chat_template(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tokenize: bool = False,
|
||||
add_generation_prompt: bool = False,
|
||||
tools: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> str: ...
|
||||
def get_vocab(self) -> dict[str, int]: ...
|
||||
def add_eos_token(self, token: str) -> None: ...
|
||||
@property
|
||||
def has_thinking(self) -> bool: ...
|
||||
@property
|
||||
def think_start(self) -> str | None: ...
|
||||
@property
|
||||
def think_end(self) -> str | None: ...
|
||||
@property
|
||||
def has_tool_calling(self) -> bool: ...
|
||||
@property
|
||||
def tool_call_start(self) -> str | None: ...
|
||||
@property
|
||||
def tool_call_end(self) -> str | None: ...
|
||||
@property
|
||||
def detokenizer(self) -> NaiveStreamingDetokenizer:
|
||||
"""Get a stateful streaming detokenizer."""
|
||||
|
||||
def __getattr__(self, attr: str) -> Any: ...
|
||||
def __setattr__(self, attr: str, value: Any) -> None: ...
|
||||
|
||||
class NewlineTokenizer(PreTrainedTokenizerFast):
|
||||
"""A tokenizer that replaces newlines with <n> and <n> with new line."""
|
||||
@@ -146,18 +165,11 @@ class NewlineTokenizer(PreTrainedTokenizerFast):
|
||||
def batch_decode(self, *args, **kwargs): # -> list[str]:
|
||||
...
|
||||
|
||||
def load_tokenizer(
|
||||
def load(
|
||||
model_path: Path,
|
||||
tokenizer_config_extra=...,
|
||||
return_tokenizer=...,
|
||||
eos_token_ids=...,
|
||||
) -> (
|
||||
TokenizerWrapper
|
||||
| type[SPMStreamingDetokenizer]
|
||||
| partial[SPMStreamingDetokenizer]
|
||||
| type[BPEStreamingDetokenizer]
|
||||
| type[NaiveStreamingDetokenizer]
|
||||
):
|
||||
tokenizer_config_extra: dict[str, Any] | None = None,
|
||||
eos_token_ids: list[int] | int | None = None,
|
||||
) -> TokenizerWrapper:
|
||||
"""Load a huggingface tokenizer and try to infer the type of streaming
|
||||
detokenizer to use.
|
||||
|
||||
@@ -165,4 +177,7 @@ def load_tokenizer(
|
||||
a Hugging Face repo ID.
|
||||
"""
|
||||
|
||||
def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...
|
||||
# Alias for backward compatibility
|
||||
load_tokenizer = load
|
||||
|
||||
def no_bos_or_eos(sequence: list[int], bos: int, eos: int) -> list[int]: ...
|
||||
|
||||
6
.swift-format
Normal file
6
.swift-format
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"version": 1,
|
||||
"indentation": {
|
||||
"spaces": 4
|
||||
}
|
||||
}
|
||||
96
AGENTS.md
Normal file
96
AGENTS.md
Normal file
@@ -0,0 +1,96 @@
|
||||
# AGENTS.md
|
||||
|
||||
This file provides guidance to AI coding agents when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
exo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and libp2p for peer-to-peer networking.
|
||||
|
||||
## Build & Run Commands
|
||||
|
||||
```bash
|
||||
# Build the dashboard (required before running exo)
|
||||
cd dashboard && npm install && npm run build && cd ..
|
||||
|
||||
# Run exo (starts both master and worker with API at http://localhost:52415)
|
||||
uv run exo
|
||||
|
||||
# Run with verbose logging
|
||||
uv run exo -v # or -vv for more verbose
|
||||
|
||||
# Run tests (excludes slow tests by default)
|
||||
uv run pytest
|
||||
|
||||
# Run all tests including slow tests
|
||||
uv run pytest -m ""
|
||||
|
||||
# Run a specific test file
|
||||
uv run pytest src/exo/shared/tests/test_election.py
|
||||
|
||||
# Run a specific test function
|
||||
uv run pytest src/exo/shared/tests/test_election.py::test_function_name
|
||||
|
||||
# Type checking (strict mode)
|
||||
uv run basedpyright
|
||||
|
||||
# Linting
|
||||
uv run ruff check
|
||||
|
||||
# Format code (using nix)
|
||||
nix fmt
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
### Node Composition
|
||||
A single exo `Node` (src/exo/main.py) runs multiple components:
|
||||
- **Router**: libp2p-based pub/sub messaging via Rust bindings (exo_pyo3_bindings)
|
||||
- **Worker**: Handles inference tasks, downloads models, manages runner processes
|
||||
- **Master**: Coordinates cluster state, places model instances across nodes
|
||||
- **Election**: Bully algorithm for master election
|
||||
- **API**: FastAPI server for OpenAI-compatible chat completions
|
||||
|
||||
### Message Flow
|
||||
Components communicate via typed pub/sub topics (src/exo/routing/topics.py):
|
||||
- `GLOBAL_EVENTS`: Master broadcasts indexed events to all workers
|
||||
- `LOCAL_EVENTS`: Workers send events to master for indexing
|
||||
- `COMMANDS`: Workers/API send commands to master
|
||||
- `ELECTION_MESSAGES`: Election protocol messages
|
||||
- `CONNECTION_MESSAGES`: libp2p connection updates
|
||||
|
||||
### Event Sourcing
|
||||
The system uses event sourcing for state management:
|
||||
- `State` (src/exo/shared/types/state.py): Immutable state object
|
||||
- `apply()` (src/exo/shared/apply.py): Pure function that applies events to state
|
||||
- Master indexes events and broadcasts; workers apply indexed events
|
||||
|
||||
### Key Type Hierarchy
|
||||
- `src/exo/shared/types/`: Pydantic models for all shared types
|
||||
- `events.py`: Event types (discriminated union)
|
||||
- `commands.py`: Command types
|
||||
- `tasks.py`: Task types for worker execution
|
||||
- `state.py`: Cluster state model
|
||||
|
||||
### Rust Components
|
||||
Rust code in `rust/` provides:
|
||||
- `networking`: libp2p networking (gossipsub, peer discovery)
|
||||
- `exo_pyo3_bindings`: PyO3 bindings exposing Rust to Python
|
||||
- `system_custodian`: System-level operations
|
||||
|
||||
### Dashboard
|
||||
Svelte 5 + TypeScript frontend in `dashboard/`. Build output goes to `dashboard/build/` and is served by the API.
|
||||
|
||||
## Code Style Requirements
|
||||
|
||||
From .cursorrules:
|
||||
- Strict, exhaustive typing - never bypass the type-checker
|
||||
- Use `Literal[...]` for enum-like sets, `typing.NewType` for primitives
|
||||
- Pydantic models with `frozen=True` and `strict=True`
|
||||
- Pure functions with injectable effect handlers for side-effects
|
||||
- Descriptive names - no abbreviations or 3-letter acronyms
|
||||
- Catch exceptions only where you can handle them meaningfully
|
||||
- Use `@final` and immutability wherever applicable
|
||||
|
||||
## Testing
|
||||
|
||||
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
|
||||
19
Cargo.lock
generated
19
Cargo.lock
generated
@@ -4340,25 +4340,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system_custodian"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"either",
|
||||
"extend",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"impl-trait-for-tuples",
|
||||
"keccak-const",
|
||||
"log",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tagptr"
|
||||
version = "0.2.0"
|
||||
|
||||
@@ -3,7 +3,6 @@ resolver = "3"
|
||||
members = [
|
||||
"rust/networking",
|
||||
"rust/exo_pyo3_bindings",
|
||||
"rust/system_custodian",
|
||||
"rust/util",
|
||||
]
|
||||
|
||||
@@ -25,7 +24,6 @@ opt-level = 3
|
||||
[workspace.dependencies]
|
||||
## Crate members as common dependencies
|
||||
networking = { path = "rust/networking" }
|
||||
system_custodian = { path = "rust/system_custodian" }
|
||||
util = { path = "rust/util" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
|
||||
41
MISSED_THINGS.md
Normal file
41
MISSED_THINGS.md
Normal file
@@ -0,0 +1,41 @@
|
||||
# Missed things
|
||||
[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py
|
||||
[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank=n-1. It should be rank!=0 then rank=0 (this matches the auto_parallel implementation. NOTE: we use a different convention to mlx-lm, our terminal rank is rank=n-1 whereas mlx-lm is rank=0 hence i can see why this was changed wrongly).
|
||||
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
|
||||
[X] Fetching download status of all models on start
|
||||
[X] Deduplication of tasks in plan_step.
|
||||
[X] resolve_allow_patterns should just be wildcard now.
|
||||
[] no mx_barrier in genreate.py mlx_generate at the end.
|
||||
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
|
||||
[] GPTOSS support dropped in auto_parallel.py.
|
||||
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
|
||||
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
|
||||
[] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[X] We put cache limit back in utils_mlx.py.
|
||||
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
|
||||
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
|
||||
[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
|
||||
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
|
||||
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
|
||||
[] logger.warning("You have likely selected ibv for a single node instance; falling back to MlxRing") was changed to debug. That will spam this warning since it happens every time we query instance previews.
|
||||
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).
|
||||
|
||||
|
||||
|
||||
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
|
||||
[X] Fetching download status of all models on start
|
||||
[X] Deduplication of tasks in plan_step.
|
||||
[X] resolve_allow_patterns should just be wildcard now.
|
||||
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
|
||||
[X] We put cache limit back in utils_mlx.py.
|
||||
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
|
||||
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
|
||||
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).
|
||||
|
||||
|
||||
108
README.md
108
README.md
@@ -8,7 +8,7 @@
|
||||
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
|
||||
|
||||
<p align="center">
|
||||
<a href="https://discord.gg/72NsF6ux" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
|
||||
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
|
||||
<a href="https://x.com/exolabs" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/twitter/follow/exolabs?style=social" alt="X"></a>
|
||||
<a href="https://www.apache.org/licenses/LICENSE-2.0.html" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/License-Apache2.0-blue.svg" alt="License: Apache-2.0"></a>
|
||||
</p>
|
||||
@@ -61,10 +61,10 @@ Devices running exo automatically discover each other, without needing any manua
|
||||
|
||||
There are two ways to run exo:
|
||||
|
||||
### Run from Source (Mac & Linux)
|
||||
### Run from Source (macOS)
|
||||
|
||||
**Prerequisites:**
|
||||
- [brew](https://github.com/Homebrew/brew) (for simple package management on MacOS)
|
||||
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
|
||||
|
||||
```bash
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
@@ -98,6 +98,62 @@ uv run exo
|
||||
|
||||
This starts the exo dashboard and API at http://localhost:52415/
|
||||
|
||||
### Run from Source (Linux)
|
||||
|
||||
**Prerequisites:**
|
||||
|
||||
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
|
||||
- [node](https://github.com/nodejs/node) (for building the dashboard) - version 18 or higher
|
||||
- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings, nightly for now)
|
||||
|
||||
**Installation methods:**
|
||||
|
||||
**Option 1: Using system package manager (Ubuntu/Debian example):**
|
||||
```bash
|
||||
# Install Node.js and npm
|
||||
sudo apt update
|
||||
sudo apt install nodejs npm
|
||||
|
||||
# Install uv
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Install Rust (using rustup)
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
||||
rustup toolchain install nightly
|
||||
```
|
||||
|
||||
**Option 2: Using Homebrew on Linux (if preferred):**
|
||||
```bash
|
||||
# Install Homebrew on Linux
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
|
||||
# Install dependencies
|
||||
brew install uv node
|
||||
|
||||
# Install Rust (using rustup)
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
||||
rustup toolchain install nightly
|
||||
```
|
||||
|
||||
**Note:** The `macmon` package is macOS-only and not required for Linux.
|
||||
|
||||
Clone the repo, build the dashboard, and run exo:
|
||||
|
||||
```bash
|
||||
# Clone exo
|
||||
git clone https://github.com/exo-explore/exo
|
||||
|
||||
# Build dashboard
|
||||
cd exo/dashboard && npm install && npm run build && cd ..
|
||||
|
||||
# Run exo
|
||||
uv run exo
|
||||
```
|
||||
|
||||
This starts the exo dashboard and API at http://localhost:52415/
|
||||
|
||||
**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.
|
||||
|
||||
### macOS App
|
||||
|
||||
exo ships a macOS app that runs in the background on your Mac.
|
||||
@@ -110,6 +166,47 @@ Download the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-
|
||||
|
||||
The app will ask for permission to modify system settings and install a new Network profile. Improvements to this are being worked on.
|
||||
|
||||
#### Uninstalling the macOS App
|
||||
|
||||
The recommended way to uninstall is through the app itself: click the menu bar icon → Advanced → Uninstall. This cleanly removes all system components.
|
||||
|
||||
If you've already deleted the app, you can run the standalone uninstaller script:
|
||||
|
||||
```bash
|
||||
sudo ./app/EXO/uninstall-exo.sh
|
||||
```
|
||||
|
||||
This removes:
|
||||
- Network setup LaunchDaemon
|
||||
- Network configuration script
|
||||
- Log files
|
||||
- The "exo" network location
|
||||
|
||||
**Note:** You'll need to manually remove EXO from Login Items in System Settings → General → Login Items.
|
||||
|
||||
---
|
||||
|
||||
### Enabling RDMA on macOS
|
||||
|
||||
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
|
||||
|
||||
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
|
||||
|
||||
To enable RDMA on macOS, follow these steps:
|
||||
|
||||
1. Shut down your Mac.
|
||||
2. Hold down the power button for 10 seconds until the boot menu appears.
|
||||
3. Select "Options" to enter Recovery mode.
|
||||
4. When the Recovery UI appears, open the Terminal from the Utilities menu.
|
||||
5. In the Terminal, type:
|
||||
```
|
||||
rdma_ctl enable
|
||||
```
|
||||
and press Enter.
|
||||
6. Reboot your Mac.
|
||||
|
||||
After that, RDMA will be enabled in macOS and exo will take care of the rest.
|
||||
|
||||
---
|
||||
|
||||
### Using the API
|
||||
@@ -208,7 +305,10 @@ curl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID
|
||||
- List all models: `curl http://localhost:52415/models`
|
||||
- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`
|
||||
|
||||
For further details, see API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
|
||||
For further details, see:
|
||||
|
||||
- API basic documentation in [docs/api.md](docs/api.md).
|
||||
- API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
|
||||
|
||||
---
|
||||
|
||||
|
||||
1
TODO.md
1
TODO.md
@@ -19,7 +19,6 @@
|
||||
25. Rethink retry logic
|
||||
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
|
||||
27. Log cleanup - per-module log filters and default to DEBUG log levels
|
||||
28. Validate RDMA connections with ibv_devinfo in the info gatherer
|
||||
|
||||
Potential refactors:
|
||||
|
||||
|
||||
@@ -12,18 +12,25 @@ struct ContentView: View {
|
||||
@EnvironmentObject private var controller: ExoProcessController
|
||||
@EnvironmentObject private var stateService: ClusterStateService
|
||||
@EnvironmentObject private var networkStatusService: NetworkStatusService
|
||||
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@EnvironmentObject private var updater: SparkleUpdater
|
||||
@State private var focusedNode: NodeViewModel?
|
||||
@State private var deletingInstanceIDs: Set<String> = []
|
||||
@State private var showAllNodes = false
|
||||
@State private var showAllInstances = false
|
||||
@State private var showAdvanced = false
|
||||
@State private var showDebugInfo = false
|
||||
@State private var bugReportInFlight = false
|
||||
@State private var bugReportMessage: String?
|
||||
@State private var uninstallInProgress = false
|
||||
@State private var pendingNamespace: String = ""
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 12) {
|
||||
statusSection
|
||||
if shouldShowLocalNetworkWarning {
|
||||
localNetworkWarningBanner
|
||||
}
|
||||
if shouldShowClusterDetails {
|
||||
Divider()
|
||||
overviewSection
|
||||
@@ -38,6 +45,7 @@ struct ContentView: View {
|
||||
}
|
||||
.animation(.easeInOut(duration: 0.3), value: shouldShowClusterDetails)
|
||||
.animation(.easeInOut(duration: 0.3), value: shouldShowInstances)
|
||||
.animation(.easeInOut(duration: 0.3), value: shouldShowLocalNetworkWarning)
|
||||
.padding()
|
||||
.frame(width: 340)
|
||||
.onAppear {
|
||||
@@ -47,9 +55,62 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
private var shouldShowLocalNetworkWarning: Bool {
|
||||
if case .notWorking = localNetworkChecker.status {
|
||||
return controller.status != .stopped
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
private var localNetworkWarningBanner: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
HStack(spacing: 6) {
|
||||
Image(systemName: "exclamationmark.triangle.fill")
|
||||
.foregroundColor(.orange)
|
||||
Text("Local Network Access Issue")
|
||||
.font(.caption)
|
||||
.fontWeight(.semibold)
|
||||
}
|
||||
Text(
|
||||
"Device discovery won't work. To fix:\n1. Quit EXO\n2. Open System Settings → Privacy & Security → Local Network\n3. Toggle EXO off, then back on\n4. Relaunch EXO"
|
||||
)
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
Button {
|
||||
openLocalNetworkSettings()
|
||||
} label: {
|
||||
Text("Open Settings")
|
||||
.font(.caption2)
|
||||
}
|
||||
.buttonStyle(.bordered)
|
||||
.controlSize(.small)
|
||||
}
|
||||
.padding(8)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.fill(Color.orange.opacity(0.1))
|
||||
)
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 8)
|
||||
.stroke(Color.orange.opacity(0.3), lineWidth: 1)
|
||||
)
|
||||
}
|
||||
|
||||
private func openLocalNetworkSettings() {
|
||||
// Open Privacy & Security settings - Local Network section
|
||||
if let url = URL(
|
||||
string: "x-apple.systempreferences:com.apple.preference.security?Privacy_LocalNetwork")
|
||||
{
|
||||
NSWorkspace.shared.open(url)
|
||||
}
|
||||
}
|
||||
|
||||
private var topologySection: some View {
|
||||
Group {
|
||||
if let topology = stateService.latestSnapshot?.topologyViewModel(), !topology.nodes.isEmpty {
|
||||
if let topology = stateService.latestSnapshot?.topologyViewModel(
|
||||
localNodeId: stateService.localNodeId), !topology.nodes.isEmpty
|
||||
{
|
||||
TopologyMiniView(topology: topology)
|
||||
}
|
||||
}
|
||||
@@ -83,8 +144,10 @@ struct ContentView: View {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
HStack {
|
||||
VStack(alignment: .leading) {
|
||||
Text("\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB")
|
||||
.font(.headline)
|
||||
Text(
|
||||
"\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB"
|
||||
)
|
||||
.font(.headline)
|
||||
Text("Memory")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
@@ -193,11 +256,7 @@ struct ContentView: View {
|
||||
Divider()
|
||||
.padding(.vertical, 4)
|
||||
}
|
||||
controlButton(title: "Check for Updates") {
|
||||
updater.checkForUpdates()
|
||||
}
|
||||
.padding(.bottom, 8)
|
||||
debugSection
|
||||
advancedSection
|
||||
.padding(.bottom, 8)
|
||||
controlButton(title: "Quit", tint: .secondary) {
|
||||
controller.stop()
|
||||
@@ -206,7 +265,57 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void) -> some View {
|
||||
private var advancedSection: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
HStack {
|
||||
Text("Advanced")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
Spacer()
|
||||
collapseButton(isExpanded: $showAdvanced)
|
||||
}
|
||||
.animation(nil, value: showAdvanced)
|
||||
if showAdvanced {
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text("Cluster Namespace")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
HStack {
|
||||
TextField("optional", text: $pendingNamespace)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
.font(.caption2)
|
||||
.onAppear {
|
||||
pendingNamespace = controller.customNamespace
|
||||
}
|
||||
Button("Save & Restart") {
|
||||
controller.customNamespace = pendingNamespace
|
||||
if controller.status == .running || controller.status == .starting {
|
||||
controller.restart()
|
||||
}
|
||||
}
|
||||
.font(.caption2)
|
||||
.disabled(pendingNamespace == controller.customNamespace)
|
||||
}
|
||||
}
|
||||
HoverButton(title: "Check for Updates", small: true) {
|
||||
updater.checkForUpdates()
|
||||
}
|
||||
debugSection
|
||||
HoverButton(title: "Uninstall", tint: .red, small: true) {
|
||||
showUninstallConfirmationAlert()
|
||||
}
|
||||
.disabled(uninstallInProgress)
|
||||
}
|
||||
.transition(.opacity)
|
||||
}
|
||||
}
|
||||
.animation(.easeInOut(duration: 0.25), value: showAdvanced)
|
||||
}
|
||||
|
||||
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void)
|
||||
-> some View
|
||||
{
|
||||
HoverButton(title: title, tint: tint, trailingSystemImage: nil, action: action)
|
||||
}
|
||||
|
||||
@@ -237,9 +346,12 @@ struct ContentView: View {
|
||||
Button {
|
||||
isExpanded.wrappedValue.toggle()
|
||||
} label: {
|
||||
Label(isExpanded.wrappedValue ? "Hide" : "Show All", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
|
||||
.labelStyle(.titleAndIcon)
|
||||
.contentTransition(.symbolEffect(.replace))
|
||||
Label(
|
||||
isExpanded.wrappedValue ? "Hide" : "Show All",
|
||||
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
|
||||
)
|
||||
.labelStyle(.titleAndIcon)
|
||||
.contentTransition(.symbolEffect(.replace))
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.font(.caption2)
|
||||
@@ -328,15 +440,15 @@ struct ContentView: View {
|
||||
}
|
||||
|
||||
private var debugSection: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
HStack {
|
||||
Text("Debug Info")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
Spacer()
|
||||
collapseButton(isExpanded: $showDebugInfo)
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
HoverButton(
|
||||
title: "Debug Info",
|
||||
tint: .primary,
|
||||
trailingSystemImage: showDebugInfo ? "chevron.up" : "chevron.down",
|
||||
small: true
|
||||
) {
|
||||
showDebugInfo.toggle()
|
||||
}
|
||||
.animation(nil, value: showDebugInfo)
|
||||
if showDebugInfo {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text("Version: \(buildTag)")
|
||||
@@ -349,15 +461,63 @@ struct ContentView: View {
|
||||
.font(.caption2)
|
||||
.foregroundColor(thunderboltStatusColor)
|
||||
interfaceIpList
|
||||
rdmaStatusView
|
||||
sendBugReportButton
|
||||
.padding(.top, 6)
|
||||
}
|
||||
.padding(.leading, 8)
|
||||
.transition(.opacity)
|
||||
}
|
||||
}
|
||||
.animation(.easeInOut(duration: 0.25), value: showDebugInfo)
|
||||
}
|
||||
|
||||
private var rdmaStatusView: some View {
|
||||
let rdma = networkStatusService.status.rdmaStatus
|
||||
return VStack(alignment: .leading, spacing: 1) {
|
||||
Text("RDMA: \(rdmaStatusText(rdma))")
|
||||
.font(.caption2)
|
||||
.foregroundColor(rdmaStatusColor(rdma))
|
||||
if !rdma.devices.isEmpty {
|
||||
Text(" Devices: \(rdma.devices.joined(separator: ", "))")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
if !rdma.activePorts.isEmpty {
|
||||
Text(" Active Ports:")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
ForEach(rdma.activePorts, id: \.device) { port in
|
||||
Text(" \(port.device) port \(port.port): \(port.state)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.green)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func rdmaStatusText(_ rdma: RDMAStatus) -> String {
|
||||
switch rdma.rdmaCtlEnabled {
|
||||
case .some(true):
|
||||
return "Enabled"
|
||||
case .some(false):
|
||||
return "Disabled"
|
||||
case nil:
|
||||
return rdma.devices.isEmpty ? "Not Available" : "Available"
|
||||
}
|
||||
}
|
||||
|
||||
private func rdmaStatusColor(_ rdma: RDMAStatus) -> Color {
|
||||
switch rdma.rdmaCtlEnabled {
|
||||
case .some(true):
|
||||
return .green
|
||||
case .some(false):
|
||||
return .orange
|
||||
case nil:
|
||||
return rdma.devices.isEmpty ? .secondary : .green
|
||||
}
|
||||
}
|
||||
|
||||
private var sendBugReportButton: some View {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Button {
|
||||
@@ -447,6 +607,88 @@ struct ContentView: View {
|
||||
bugReportInFlight = false
|
||||
}
|
||||
|
||||
private func showUninstallConfirmationAlert() {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Uninstall EXO"
|
||||
alert.informativeText = """
|
||||
This will remove EXO and all its system components:
|
||||
|
||||
• Network configuration daemon
|
||||
• Launch at login registration
|
||||
• EXO network location
|
||||
|
||||
The app will be moved to Trash.
|
||||
"""
|
||||
alert.alertStyle = .warning
|
||||
alert.addButton(withTitle: "Uninstall")
|
||||
alert.addButton(withTitle: "Cancel")
|
||||
|
||||
// Style the Uninstall button as destructive
|
||||
if let uninstallButton = alert.buttons.first {
|
||||
uninstallButton.hasDestructiveAction = true
|
||||
}
|
||||
|
||||
let response = alert.runModal()
|
||||
if response == .alertFirstButtonReturn {
|
||||
performUninstall()
|
||||
}
|
||||
}
|
||||
|
||||
private func performUninstall() {
|
||||
uninstallInProgress = true
|
||||
|
||||
// Stop EXO process first
|
||||
controller.cancelPendingLaunch()
|
||||
controller.stop()
|
||||
stateService.stopPolling()
|
||||
|
||||
// Run the privileged uninstall on a background thread
|
||||
// Using .utility QoS to avoid priority inversion with NSAppleScript's subprocess
|
||||
DispatchQueue.global(qos: .utility).async {
|
||||
do {
|
||||
// Remove network setup daemon and components (requires admin privileges)
|
||||
try NetworkSetupHelper.uninstall()
|
||||
|
||||
DispatchQueue.main.async {
|
||||
// Unregister from launch at login
|
||||
LaunchAtLoginHelper.disable()
|
||||
|
||||
// Move app to trash
|
||||
self.moveAppToTrash()
|
||||
|
||||
// Quit the app
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) {
|
||||
NSApplication.shared.terminate(nil)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
DispatchQueue.main.async {
|
||||
self.showErrorAlert(message: error.localizedDescription)
|
||||
self.uninstallInProgress = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func showErrorAlert(message: String) {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Uninstall Failed"
|
||||
alert.informativeText = message
|
||||
alert.alertStyle = .critical
|
||||
alert.addButton(withTitle: "OK")
|
||||
alert.runModal()
|
||||
}
|
||||
|
||||
private func moveAppToTrash() {
|
||||
guard let appURL = Bundle.main.bundleURL as URL? else { return }
|
||||
do {
|
||||
try FileManager.default.trashItem(at: appURL, resultingItemURL: nil)
|
||||
} catch {
|
||||
// If we can't trash the app, that's OK - user can do it manually
|
||||
// The important system components have already been cleaned up
|
||||
}
|
||||
}
|
||||
|
||||
private var buildTag: String {
|
||||
Bundle.main.infoDictionary?["EXOBuildTag"] as? String ?? "unknown"
|
||||
}
|
||||
@@ -460,14 +702,27 @@ private struct HoverButton: View {
|
||||
let title: String
|
||||
let tint: Color
|
||||
let trailingSystemImage: String?
|
||||
let small: Bool
|
||||
let action: () -> Void
|
||||
|
||||
init(
|
||||
title: String, tint: Color = .primary, trailingSystemImage: String? = nil,
|
||||
small: Bool = false, action: @escaping () -> Void
|
||||
) {
|
||||
self.title = title
|
||||
self.tint = tint
|
||||
self.trailingSystemImage = trailingSystemImage
|
||||
self.small = small
|
||||
self.action = action
|
||||
}
|
||||
|
||||
@State private var isHovering = false
|
||||
|
||||
var body: some View {
|
||||
Button(action: action) {
|
||||
HStack {
|
||||
Text(title)
|
||||
.font(small ? .caption : nil)
|
||||
Spacer()
|
||||
if let systemName = trailingSystemImage {
|
||||
Image(systemName: systemName)
|
||||
@@ -475,8 +730,8 @@ private struct HoverButton: View {
|
||||
}
|
||||
}
|
||||
.frame(maxWidth: .infinity, alignment: .leading)
|
||||
.padding(.vertical, 6)
|
||||
.padding(.horizontal, 8)
|
||||
.padding(.vertical, small ? 4 : 6)
|
||||
.padding(.horizontal, small ? 6 : 8)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 6)
|
||||
.fill(
|
||||
@@ -491,4 +746,3 @@ private struct HoverButton: View {
|
||||
.onHover { isHovering = $0 }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,9 +8,9 @@
|
||||
import AppKit
|
||||
import CoreImage
|
||||
import CoreImage.CIFilterBuiltins
|
||||
import ServiceManagement
|
||||
import Sparkle
|
||||
import SwiftUI
|
||||
import ServiceManagement
|
||||
import UserNotifications
|
||||
import os.log
|
||||
|
||||
@@ -19,6 +19,7 @@ struct EXOApp: App {
|
||||
@StateObject private var controller: ExoProcessController
|
||||
@StateObject private var stateService: ClusterStateService
|
||||
@StateObject private var networkStatusService: NetworkStatusService
|
||||
@StateObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@StateObject private var updater: SparkleUpdater
|
||||
private let terminationObserver: TerminationObserver
|
||||
private let ciContext = CIContext(options: nil)
|
||||
@@ -37,9 +38,13 @@ struct EXOApp: App {
|
||||
_stateService = StateObject(wrappedValue: service)
|
||||
let networkStatus = NetworkStatusService()
|
||||
_networkStatusService = StateObject(wrappedValue: networkStatus)
|
||||
let localNetwork = LocalNetworkChecker()
|
||||
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
|
||||
_updater = StateObject(wrappedValue: updater)
|
||||
enableLaunchAtLoginIfNeeded()
|
||||
NetworkSetupHelper.ensureLaunchDaemonInstalled()
|
||||
// Check local network access BEFORE launching exo
|
||||
localNetwork.check()
|
||||
controller.scheduleLaunch(after: 15)
|
||||
service.startPolling()
|
||||
networkStatus.startPolling()
|
||||
@@ -51,6 +56,7 @@ struct EXOApp: App {
|
||||
.environmentObject(controller)
|
||||
.environmentObject(stateService)
|
||||
.environmentObject(networkStatusService)
|
||||
.environmentObject(localNetworkChecker)
|
||||
.environmentObject(updater)
|
||||
} label: {
|
||||
menuBarIcon
|
||||
@@ -107,7 +113,7 @@ struct EXOApp: App {
|
||||
filter.contrast = 0.9
|
||||
|
||||
guard let output = filter.outputImage,
|
||||
let rendered = ciContext.createCGImage(output, from: output.extent)
|
||||
let rendered = ciContext.createCGImage(output, from: output.extent)
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
@@ -120,7 +126,26 @@ struct EXOApp: App {
|
||||
do {
|
||||
try SMAppService.mainApp.register()
|
||||
} catch {
|
||||
Logger().error("Failed to register EXO for launch at login: \(error.localizedDescription)")
|
||||
Logger().error(
|
||||
"Failed to register EXO for launch at login: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for managing EXO's launch-at-login registration
|
||||
enum LaunchAtLoginHelper {
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LaunchAtLogin")
|
||||
|
||||
/// Unregisters EXO from launching at login
|
||||
static func disable() {
|
||||
guard SMAppService.mainApp.status == .enabled else { return }
|
||||
do {
|
||||
try SMAppService.mainApp.unregister()
|
||||
logger.info("Unregistered EXO from launch at login")
|
||||
} catch {
|
||||
logger.error(
|
||||
"Failed to unregister EXO from launch at login: \(error.localizedDescription, privacy: .public)"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -145,7 +170,7 @@ final class SparkleUpdater: NSObject, ObservableObject {
|
||||
center.requestAuthorization(options: [.alert, .sound]) { _, _ in }
|
||||
controller.updater.automaticallyChecksForUpdates = true
|
||||
controller.updater.automaticallyDownloadsUpdates = false
|
||||
controller.updater.updateCheckInterval = 900 // 15 minutes
|
||||
controller.updater.updateCheckInterval = 900 // 15 minutes
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 5) { [weak controller] in
|
||||
controller?.updater.checkForUpdatesInBackground()
|
||||
}
|
||||
@@ -212,7 +237,8 @@ private final class ExoNotificationDelegate: NSObject, UNUserNotificationCenterD
|
||||
func userNotificationCenter(
|
||||
_ center: UNUserNotificationCenter,
|
||||
willPresent notification: UNNotification,
|
||||
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) -> Void
|
||||
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) ->
|
||||
Void
|
||||
) {
|
||||
completionHandler([.banner, .list, .sound])
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ import AppKit
|
||||
import Combine
|
||||
import Foundation
|
||||
|
||||
private let customNamespaceKey = "EXOCustomNamespace"
|
||||
|
||||
@MainActor
|
||||
final class ExoProcessController: ObservableObject {
|
||||
enum Status: Equatable {
|
||||
@@ -27,6 +29,14 @@ final class ExoProcessController: ObservableObject {
|
||||
@Published private(set) var status: Status = .stopped
|
||||
@Published private(set) var lastError: String?
|
||||
@Published private(set) var launchCountdownSeconds: Int?
|
||||
@Published var customNamespace: String = {
|
||||
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
|
||||
}()
|
||||
{
|
||||
didSet {
|
||||
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
|
||||
}
|
||||
}
|
||||
|
||||
private var process: Process?
|
||||
private var runtimeDirectoryURL: URL?
|
||||
@@ -180,7 +190,7 @@ final class ExoProcessController: ObservableObject {
|
||||
private func makeEnvironment(for runtimeURL: URL) -> [String: String] {
|
||||
var environment = ProcessInfo.processInfo.environment
|
||||
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
|
||||
environment["EXO_LIBP2P_NAMESPACE"] = buildTag()
|
||||
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
|
||||
|
||||
var paths: [String] = []
|
||||
if let existing = environment["PATH"], !existing.isEmpty {
|
||||
@@ -212,11 +222,19 @@ final class ExoProcessController: ObservableObject {
|
||||
if let tag = Bundle.main.infoDictionary?["EXOBuildTag"] as? String, !tag.isEmpty {
|
||||
return tag
|
||||
}
|
||||
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String, !short.isEmpty {
|
||||
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String,
|
||||
!short.isEmpty
|
||||
{
|
||||
return short
|
||||
}
|
||||
return "dev"
|
||||
}
|
||||
|
||||
private func computeNamespace() -> String {
|
||||
let base = buildTag()
|
||||
let custom = customNamespace.trimmingCharacters(in: .whitespaces)
|
||||
return custom.isEmpty ? base : custom
|
||||
}
|
||||
}
|
||||
|
||||
struct RuntimeError: LocalizedError {
|
||||
|
||||
@@ -8,5 +8,15 @@
|
||||
<string>$(EXO_BUILD_TAG)</string>
|
||||
<key>EXOBuildCommit</key>
|
||||
<string>$(EXO_BUILD_COMMIT)</string>
|
||||
<key>EXOBugReportPresignedUrlEndpoint</key>
|
||||
<string>$(EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT)</string>
|
||||
<key>NSLocalNetworkUsageDescription</key>
|
||||
<string>EXO needs local network access to discover and connect to other devices in your cluster for distributed AI inference.</string>
|
||||
<key>NSBonjourServices</key>
|
||||
<array>
|
||||
<string>_p2p._tcp</string>
|
||||
<string>_p2p._udp</string>
|
||||
<string>_libp2p._udp</string>
|
||||
</array>
|
||||
</dict>
|
||||
</plist>
|
||||
|
||||
@@ -16,10 +16,13 @@ struct ClusterState: Decodable {
|
||||
self.instances = rawInstances.mapValues(\.instance)
|
||||
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
|
||||
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
|
||||
let rawTasks = try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
|
||||
let rawTasks =
|
||||
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
|
||||
self.tasks = rawTasks.compactMapValues(\.task)
|
||||
self.topology = try container.decodeIfPresent(Topology.self, forKey: .topology)
|
||||
let rawDownloads = try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads) ?? [:]
|
||||
let rawDownloads =
|
||||
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
|
||||
?? [:]
|
||||
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
|
||||
}
|
||||
|
||||
@@ -41,7 +44,8 @@ private struct TaggedInstance: Decodable {
|
||||
let payloads = try container.decode([String: ClusterInstancePayload].self)
|
||||
guard let entry = payloads.first else {
|
||||
throw DecodingError.dataCorrupted(
|
||||
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
|
||||
DecodingError.Context(
|
||||
codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
|
||||
)
|
||||
}
|
||||
self.instance = ClusterInstance(
|
||||
@@ -77,7 +81,8 @@ struct RunnerStatusSummary: Decodable {
|
||||
let payloads = try container.decode([String: RunnerStatusDetail].self)
|
||||
guard let entry = payloads.first else {
|
||||
throw DecodingError.dataCorrupted(
|
||||
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
|
||||
DecodingError.Context(
|
||||
codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
|
||||
)
|
||||
}
|
||||
self.status = entry.key
|
||||
@@ -257,7 +262,9 @@ struct ChatCompletionTaskParameters: Decodable, Equatable {
|
||||
|
||||
func promptPreview() -> String? {
|
||||
guard let messages else { return nil }
|
||||
if let userMessage = messages.last(where: { $0.role?.lowercased() == "user" && ($0.content?.isEmpty == false) }) {
|
||||
if let userMessage = messages.last(where: {
|
||||
$0.role?.lowercased() == "user" && ($0.content?.isEmpty == false)
|
||||
}) {
|
||||
return userMessage.content
|
||||
}
|
||||
return messages.last?.content
|
||||
@@ -365,5 +372,3 @@ extension ClusterState {
|
||||
|
||||
func availableModels() -> [ModelOption] { [] }
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import CryptoKit
|
||||
import Foundation
|
||||
|
||||
struct BugReportOutcome: Equatable {
|
||||
@@ -7,17 +6,17 @@ struct BugReportOutcome: Equatable {
|
||||
}
|
||||
|
||||
enum BugReportError: LocalizedError {
|
||||
case missingCredentials
|
||||
case invalidEndpoint
|
||||
case presignedUrlFailed(String)
|
||||
case uploadFailed(String)
|
||||
case collectFailed(String)
|
||||
|
||||
var errorDescription: String? {
|
||||
switch self {
|
||||
case .missingCredentials:
|
||||
return "Bug report upload credentials are not set."
|
||||
case .invalidEndpoint:
|
||||
return "Bug report endpoint is invalid."
|
||||
case .presignedUrlFailed(let message):
|
||||
return "Failed to get presigned URLs: \(message)"
|
||||
case .uploadFailed(let message):
|
||||
return "Bug report upload failed: \(message)"
|
||||
case .collectFailed(let message):
|
||||
@@ -27,11 +26,13 @@ enum BugReportError: LocalizedError {
|
||||
}
|
||||
|
||||
struct BugReportService {
|
||||
struct AWSConfig {
|
||||
let accessKey: String
|
||||
let secretKey: String
|
||||
let region: String
|
||||
let bucket: String
|
||||
private struct PresignedUrlsRequest: Codable {
|
||||
let keys: [String]
|
||||
}
|
||||
|
||||
private struct PresignedUrlsResponse: Codable {
|
||||
let urls: [String: String]
|
||||
let expiresIn: Int?
|
||||
}
|
||||
|
||||
func sendReport(
|
||||
@@ -39,9 +40,9 @@ struct BugReportService {
|
||||
now: Date = Date(),
|
||||
isManual: Bool = false
|
||||
) async throws -> BugReportOutcome {
|
||||
let credentials = try loadCredentials()
|
||||
let timestamp = ISO8601DateFormatter().string(from: now)
|
||||
let prefix = "reports/\(timestamp)/"
|
||||
let timestamp = Self.runTimestampString(now)
|
||||
let dayPrefix = Self.dayPrefixString(now)
|
||||
let prefix = "reports/\(dayPrefix)/\(timestamp)/"
|
||||
|
||||
let logData = readLog()
|
||||
let ifconfigText = try await captureIfconfig()
|
||||
@@ -66,29 +67,82 @@ struct BugReportService {
|
||||
("\(prefix)exo.log", logData),
|
||||
("\(prefix)state.json", stateData),
|
||||
("\(prefix)events.json", eventsData),
|
||||
("\(prefix)report.json", reportJSON)
|
||||
("\(prefix)report.json", reportJSON),
|
||||
]
|
||||
|
||||
let uploader = try S3Uploader(config: credentials)
|
||||
for item in uploads {
|
||||
guard let data = item.data else { continue }
|
||||
try await uploader.upload(
|
||||
objectPath: item.path,
|
||||
body: data
|
||||
)
|
||||
let uploadItems: [(key: String, body: Data)] = uploads.compactMap { item in
|
||||
guard let body = item.data else { return nil }
|
||||
return (key: item.path, body: body)
|
||||
}
|
||||
|
||||
return BugReportOutcome(success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
|
||||
guard !uploadItems.isEmpty else {
|
||||
return BugReportOutcome(success: false, message: "No data to upload")
|
||||
}
|
||||
|
||||
let presignedUrls = try await fetchPresignedUploadUrls(keys: uploadItems.map(\.key))
|
||||
for item in uploadItems {
|
||||
guard let urlString = presignedUrls[item.key], let url = URL(string: urlString) else {
|
||||
throw BugReportError.uploadFailed("Missing presigned URL for \(item.key)")
|
||||
}
|
||||
try await uploadToPresignedUrl(url: url, body: item.body)
|
||||
}
|
||||
|
||||
return BugReportOutcome(
|
||||
success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
|
||||
}
|
||||
|
||||
private func loadCredentials() throws -> AWSConfig {
|
||||
// These credentials are write-only and necessary to receive bug reports from users
|
||||
return AWSConfig(
|
||||
accessKey: "AKIAYEKP5EMXTOBYDGHX",
|
||||
secretKey: "Ep5gIlUZ1o8ssTLQwmyy34yPGfTPEYQ4evE8NdPE",
|
||||
region: "us-east-1",
|
||||
bucket: "exo-bug-reports"
|
||||
)
|
||||
private static func dayPrefixString(_ date: Date) -> String {
|
||||
var calendar = Calendar(identifier: .gregorian)
|
||||
calendar.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
|
||||
let components = calendar.dateComponents([.year, .month, .day], from: date)
|
||||
let year = components.year ?? 0
|
||||
let month = components.month ?? 0
|
||||
let day = components.day ?? 0
|
||||
return String(format: "%04d/%02d/%02d", year, month, day)
|
||||
}
|
||||
|
||||
private static func runTimestampString(_ date: Date) -> String {
|
||||
let formatter = DateFormatter()
|
||||
formatter.locale = Locale(identifier: "en_US_POSIX")
|
||||
formatter.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
|
||||
formatter.dateFormat = "yyyy-MM-dd'T'HHmmss.SSS'Z'"
|
||||
return formatter.string(from: date)
|
||||
}
|
||||
|
||||
private func fetchPresignedUploadUrls(keys: [String], bundle: Bundle = .main) async throws
|
||||
-> [String: String]
|
||||
{
|
||||
guard
|
||||
let endpointString = bundle.infoDictionary?["EXOBugReportPresignedUrlEndpoint"]
|
||||
as? String
|
||||
else {
|
||||
throw BugReportError.invalidEndpoint
|
||||
}
|
||||
let trimmedEndpointString = endpointString.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
guard !trimmedEndpointString.isEmpty, let endpoint = URL(string: trimmedEndpointString)
|
||||
else {
|
||||
throw BugReportError.invalidEndpoint
|
||||
}
|
||||
|
||||
var request = URLRequest(url: endpoint)
|
||||
request.httpMethod = "POST"
|
||||
request.timeoutInterval = 10
|
||||
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
|
||||
|
||||
let encoder = JSONEncoder()
|
||||
request.httpBody = try encoder.encode(PresignedUrlsRequest(keys: keys))
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
guard let http = response as? HTTPURLResponse else {
|
||||
throw BugReportError.presignedUrlFailed("Non-HTTP response")
|
||||
}
|
||||
guard (200..<300).contains(http.statusCode) else {
|
||||
throw BugReportError.presignedUrlFailed("HTTP status \(http.statusCode)")
|
||||
}
|
||||
|
||||
let decoder = JSONDecoder()
|
||||
let decoded = try decoder.decode(PresignedUrlsResponse.self, from: data)
|
||||
return decoded.urls
|
||||
}
|
||||
|
||||
private func readLog() -> Data? {
|
||||
@@ -101,7 +155,8 @@ struct BugReportService {
|
||||
private func captureIfconfig() async throws -> String {
|
||||
let result = runCommand(["/sbin/ifconfig"])
|
||||
guard result.exitCode == 0 else {
|
||||
throw BugReportError.collectFailed(result.error.isEmpty ? "ifconfig failed" : result.error)
|
||||
throw BugReportError.collectFailed(
|
||||
result.error.isEmpty ? "ifconfig failed" : result.error)
|
||||
}
|
||||
return result.output
|
||||
}
|
||||
@@ -109,12 +164,23 @@ struct BugReportService {
|
||||
private func readDebugInfo() -> DebugInfo {
|
||||
DebugInfo(
|
||||
thunderboltBridgeDisabled: readThunderboltBridgeDisabled(),
|
||||
interfaces: readInterfaces()
|
||||
interfaces: readInterfaces(),
|
||||
rdma: readRDMADebugInfo()
|
||||
)
|
||||
}
|
||||
|
||||
private func readRDMADebugInfo() -> DebugInfo.RDMADebugInfo {
|
||||
DebugInfo.RDMADebugInfo(
|
||||
rdmaCtlStatus: safeRunCommand(["/usr/bin/rdma_ctl", "status"]),
|
||||
ibvDevices: safeRunCommand(["/usr/bin/ibv_devices"]),
|
||||
ibvDevinfo: safeRunCommand(["/usr/bin/ibv_devinfo"])
|
||||
)
|
||||
}
|
||||
|
||||
private func readThunderboltBridgeDisabled() -> Bool? {
|
||||
let result = runCommand(["/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
|
||||
let result = runCommand([
|
||||
"/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge",
|
||||
])
|
||||
guard result.exitCode == 0 else { return nil }
|
||||
let output = result.output.lowercased()
|
||||
if output.contains("enabled") {
|
||||
@@ -157,7 +223,8 @@ struct BugReportService {
|
||||
request.timeoutInterval = 5
|
||||
do {
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
|
||||
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode)
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
@@ -166,6 +233,36 @@ struct BugReportService {
|
||||
}
|
||||
}
|
||||
|
||||
private func uploadToPresignedUrl(url: URL, body: Data) async throws {
|
||||
let maxAttempts = 2
|
||||
var lastError: Error?
|
||||
|
||||
for attempt in 1...maxAttempts {
|
||||
do {
|
||||
var request = URLRequest(url: url)
|
||||
request.httpMethod = "PUT"
|
||||
request.httpBody = body
|
||||
request.timeoutInterval = 30
|
||||
|
||||
let (_, response) = try await URLSession.shared.data(for: request)
|
||||
guard let http = response as? HTTPURLResponse else {
|
||||
throw BugReportError.uploadFailed("Non-HTTP response")
|
||||
}
|
||||
guard (200..<300).contains(http.statusCode) else {
|
||||
throw BugReportError.uploadFailed("HTTP status \(http.statusCode)")
|
||||
}
|
||||
return
|
||||
} catch {
|
||||
lastError = error
|
||||
if attempt < maxAttempts {
|
||||
try await Task.sleep(nanoseconds: 400_000_000)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw BugReportError.uploadFailed(lastError?.localizedDescription ?? "Unknown error")
|
||||
}
|
||||
|
||||
private func makeReportJson(
|
||||
timestamp: String,
|
||||
hostName: String,
|
||||
@@ -183,7 +280,7 @@ struct BugReportService {
|
||||
"system": system,
|
||||
"exo_version": exo.version as Any,
|
||||
"exo_commit": exo.commit as Any,
|
||||
"report_type": isManual ? "manual" : "automated"
|
||||
"report_type": isManual ? "manual" : "automated",
|
||||
]
|
||||
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
|
||||
}
|
||||
@@ -214,10 +311,13 @@ struct BugReportService {
|
||||
let user = safeRunCommand(["/usr/bin/whoami"])
|
||||
let consoleUser = safeRunCommand(["/usr/bin/stat", "-f%Su", "/dev/console"])
|
||||
let uptime = safeRunCommand(["/usr/bin/uptime"])
|
||||
let diskRoot = safeRunCommand(["/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'"])
|
||||
let diskRoot = safeRunCommand([
|
||||
"/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'",
|
||||
])
|
||||
|
||||
let interfacesList = safeRunCommand(["/usr/sbin/ipconfig", "getiflist"])
|
||||
let interfacesAndIPs = interfacesList?
|
||||
let interfacesAndIPs =
|
||||
interfacesList?
|
||||
.split(whereSeparator: { $0 == " " || $0 == "\n" })
|
||||
.compactMap { iface -> [String: Any]? in
|
||||
let name = String(iface)
|
||||
@@ -228,7 +328,8 @@ struct BugReportService {
|
||||
} ?? []
|
||||
|
||||
let wifiSSID: String?
|
||||
let airportPath = "/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
|
||||
let airportPath =
|
||||
"/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
|
||||
if FileManager.default.isExecutableFile(atPath: airportPath) {
|
||||
wifiSSID = safeRunCommand([airportPath, "-I"]).flatMap(parseWifiSSID)
|
||||
} else {
|
||||
@@ -256,7 +357,7 @@ struct BugReportService {
|
||||
"disk_root": diskRoot as Any,
|
||||
"interfaces_and_ips": interfacesAndIPs,
|
||||
"ipconfig_getiflist": interfacesList as Any,
|
||||
"wifi_ssid": wifiSSID as Any
|
||||
"wifi_ssid": wifiSSID as Any,
|
||||
]
|
||||
}
|
||||
|
||||
@@ -314,7 +415,8 @@ struct BugReportService {
|
||||
for line in airportOutput.split(separator: "\n") {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
if trimmed.hasPrefix("SSID:") {
|
||||
return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(in: .whitespaces)
|
||||
return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(
|
||||
in: .whitespaces)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -351,6 +453,7 @@ struct BugReportService {
|
||||
private struct DebugInfo {
|
||||
let thunderboltBridgeDisabled: Bool?
|
||||
let interfaces: [InterfaceStatus]
|
||||
let rdma: RDMADebugInfo
|
||||
|
||||
struct InterfaceStatus {
|
||||
let name: String
|
||||
@@ -359,7 +462,21 @@ private struct DebugInfo {
|
||||
func toDictionary() -> [String: Any] {
|
||||
[
|
||||
"name": name,
|
||||
"ip": ip as Any
|
||||
"ip": ip as Any,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
struct RDMADebugInfo {
|
||||
let rdmaCtlStatus: String?
|
||||
let ibvDevices: String?
|
||||
let ibvDevinfo: String?
|
||||
|
||||
func toDictionary() -> [String: Any] {
|
||||
[
|
||||
"rdma_ctl_status": rdmaCtlStatus as Any,
|
||||
"ibv_devices": ibvDevices as Any,
|
||||
"ibv_devinfo": ibvDevinfo as Any,
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -367,7 +484,8 @@ private struct DebugInfo {
|
||||
func toDictionary() -> [String: Any] {
|
||||
[
|
||||
"thunderbolt_bridge_disabled": thunderboltBridgeDisabled as Any,
|
||||
"interfaces": interfaces.map { $0.toDictionary() }
|
||||
"interfaces": interfaces.map { $0.toDictionary() },
|
||||
"rdma": rdma.toDictionary(),
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -377,163 +495,3 @@ private struct CommandResult {
|
||||
let output: String
|
||||
let error: String
|
||||
}
|
||||
|
||||
private struct S3Uploader {
|
||||
let config: BugReportService.AWSConfig
|
||||
|
||||
init(config: BugReportService.AWSConfig) throws {
|
||||
self.config = config
|
||||
}
|
||||
|
||||
func upload(objectPath: String, body: Data) async throws {
|
||||
let host = "\(config.bucket).s3.amazonaws.com"
|
||||
guard let url = URL(string: "https://\(host)/\(objectPath)") else {
|
||||
throw BugReportError.invalidEndpoint
|
||||
}
|
||||
|
||||
let now = Date()
|
||||
let amzDate = awsTimestamp(now)
|
||||
let dateStamp = dateStamp(now)
|
||||
let payloadHash = sha256Hex(body)
|
||||
|
||||
let headers = [
|
||||
"host": host,
|
||||
"x-amz-content-sha256": payloadHash,
|
||||
"x-amz-date": amzDate
|
||||
]
|
||||
|
||||
let canonicalRequest = buildCanonicalRequest(
|
||||
method: "PUT",
|
||||
url: url,
|
||||
headers: headers,
|
||||
payloadHash: payloadHash
|
||||
)
|
||||
|
||||
let stringToSign = buildStringToSign(
|
||||
amzDate: amzDate,
|
||||
dateStamp: dateStamp,
|
||||
canonicalRequestHash: sha256Hex(canonicalRequest.data(using: .utf8) ?? Data())
|
||||
)
|
||||
|
||||
let signingKey = deriveKey(secret: config.secretKey, dateStamp: dateStamp, region: config.region, service: "s3")
|
||||
let signature = hmacHex(key: signingKey, data: Data(stringToSign.utf8))
|
||||
|
||||
let signedHeaders = "host;x-amz-content-sha256;x-amz-date"
|
||||
let authorization = """
|
||||
AWS4-HMAC-SHA256 Credential=\(config.accessKey)/\(dateStamp)/\(config.region)/s3/aws4_request, SignedHeaders=\(signedHeaders), Signature=\(signature)
|
||||
"""
|
||||
|
||||
var request = URLRequest(url: url)
|
||||
request.httpMethod = "PUT"
|
||||
request.httpBody = body
|
||||
request.setValue(headers["x-amz-content-sha256"], forHTTPHeaderField: "x-amz-content-sha256")
|
||||
request.setValue(headers["x-amz-date"], forHTTPHeaderField: "x-amz-date")
|
||||
request.setValue(host, forHTTPHeaderField: "Host")
|
||||
request.setValue(authorization, forHTTPHeaderField: "Authorization")
|
||||
|
||||
let (data, response) = try await URLSession.shared.data(for: request)
|
||||
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
|
||||
let statusText = (response as? HTTPURLResponse)?.statusCode ?? -1
|
||||
_ = data // ignore response body for UX
|
||||
throw BugReportError.uploadFailed("HTTP status \(statusText)")
|
||||
}
|
||||
}
|
||||
|
||||
private func buildCanonicalRequest(
|
||||
method: String,
|
||||
url: URL,
|
||||
headers: [String: String],
|
||||
payloadHash: String
|
||||
) -> String {
|
||||
let canonicalURI = encodePath(url.path)
|
||||
let canonicalQuery = url.query ?? ""
|
||||
let sortedHeaders = headers.sorted { $0.key < $1.key }
|
||||
let canonicalHeaders = sortedHeaders
|
||||
.map { "\($0.key.lowercased()):\($0.value)\n" }
|
||||
.joined()
|
||||
let signedHeaders = sortedHeaders.map { $0.key.lowercased() }.joined(separator: ";")
|
||||
|
||||
return [
|
||||
method,
|
||||
canonicalURI,
|
||||
canonicalQuery,
|
||||
canonicalHeaders,
|
||||
signedHeaders,
|
||||
payloadHash
|
||||
].joined(separator: "\n")
|
||||
}
|
||||
|
||||
private func encodePath(_ path: String) -> String {
|
||||
return path
|
||||
.split(separator: "/")
|
||||
.map { segment in
|
||||
segment.addingPercentEncoding(withAllowedCharacters: Self.rfc3986) ?? String(segment)
|
||||
}
|
||||
.joined(separator: "/")
|
||||
.prependSlashIfNeeded()
|
||||
}
|
||||
|
||||
private func buildStringToSign(
|
||||
amzDate: String,
|
||||
dateStamp: String,
|
||||
canonicalRequestHash: String
|
||||
) -> String {
|
||||
"""
|
||||
AWS4-HMAC-SHA256
|
||||
\(amzDate)
|
||||
\(dateStamp)/\(config.region)/s3/aws4_request
|
||||
\(canonicalRequestHash)
|
||||
"""
|
||||
}
|
||||
|
||||
private func deriveKey(secret: String, dateStamp: String, region: String, service: String) -> Data {
|
||||
let kDate = hmac(key: Data(("AWS4" + secret).utf8), data: Data(dateStamp.utf8))
|
||||
let kRegion = hmac(key: kDate, data: Data(region.utf8))
|
||||
let kService = hmac(key: kRegion, data: Data(service.utf8))
|
||||
return hmac(key: kService, data: Data("aws4_request".utf8))
|
||||
}
|
||||
|
||||
private func hmac(key: Data, data: Data) -> Data {
|
||||
let keySym = SymmetricKey(data: key)
|
||||
let mac = HMAC<SHA256>.authenticationCode(for: data, using: keySym)
|
||||
return Data(mac)
|
||||
}
|
||||
|
||||
private func hmacHex(key: Data, data: Data) -> String {
|
||||
hmac(key: key, data: data).map { String(format: "%02x", $0) }.joined()
|
||||
}
|
||||
|
||||
private func sha256Hex(_ data: Data) -> String {
|
||||
let digest = SHA256.hash(data: data)
|
||||
return digest.compactMap { String(format: "%02x", $0) }.joined()
|
||||
}
|
||||
|
||||
private func awsTimestamp(_ date: Date) -> String {
|
||||
let formatter = DateFormatter()
|
||||
formatter.dateFormat = "yyyyMMdd'T'HHmmss'Z'"
|
||||
formatter.timeZone = TimeZone(abbreviation: "UTC")
|
||||
return formatter.string(from: date)
|
||||
}
|
||||
|
||||
private func dateStamp(_ date: Date) -> String {
|
||||
let formatter = DateFormatter()
|
||||
formatter.dateFormat = "yyyyMMdd"
|
||||
formatter.timeZone = TimeZone(abbreviation: "UTC")
|
||||
return formatter.string(from: date)
|
||||
}
|
||||
|
||||
private static let rfc3986: CharacterSet = {
|
||||
var set = CharacterSet.alphanumerics
|
||||
set.insert(charactersIn: "-._~")
|
||||
return set
|
||||
}()
|
||||
}
|
||||
|
||||
private extension String {
|
||||
func prependSlashIfNeeded() -> String {
|
||||
if hasPrefix("/") {
|
||||
return self
|
||||
}
|
||||
return "/" + self
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ final class ClusterStateService: ObservableObject {
|
||||
@Published private(set) var lastError: String?
|
||||
@Published private(set) var lastActionMessage: String?
|
||||
@Published private(set) var modelOptions: [ModelOption] = []
|
||||
@Published private(set) var localNodeId: String?
|
||||
|
||||
private var timer: Timer?
|
||||
private let decoder: JSONDecoder
|
||||
@@ -29,6 +30,7 @@ final class ClusterStateService: ObservableObject {
|
||||
func startPolling(interval: TimeInterval = 0.5) {
|
||||
stopPolling()
|
||||
Task {
|
||||
await fetchLocalNodeId()
|
||||
await fetchModels()
|
||||
await fetchSnapshot()
|
||||
}
|
||||
@@ -46,9 +48,33 @@ final class ClusterStateService: ObservableObject {
|
||||
latestSnapshot = nil
|
||||
lastError = nil
|
||||
lastActionMessage = nil
|
||||
localNodeId = nil
|
||||
}
|
||||
|
||||
private func fetchLocalNodeId() async {
|
||||
do {
|
||||
let url = baseURL.appendingPathComponent("node_id")
|
||||
var request = URLRequest(url: url)
|
||||
request.cachePolicy = .reloadIgnoringLocalCacheData
|
||||
let (data, response) = try await session.data(for: request)
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
(200..<300).contains(httpResponse.statusCode)
|
||||
else {
|
||||
return
|
||||
}
|
||||
if let nodeId = try? decoder.decode(String.self, from: data) {
|
||||
localNodeId = nodeId
|
||||
}
|
||||
} catch {
|
||||
// Silently ignore - localNodeId will remain nil and retry on next poll
|
||||
}
|
||||
}
|
||||
|
||||
private func fetchSnapshot() async {
|
||||
// Retry fetching local node ID if not yet set
|
||||
if localNodeId == nil {
|
||||
await fetchLocalNodeId()
|
||||
}
|
||||
do {
|
||||
var request = URLRequest(url: endpoint)
|
||||
request.cachePolicy = .reloadIgnoringLocalCacheData
|
||||
@@ -89,7 +115,9 @@ final class ClusterStateService: ObservableObject {
|
||||
}
|
||||
}
|
||||
|
||||
func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int) async {
|
||||
func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int)
|
||||
async
|
||||
{
|
||||
do {
|
||||
var request = URLRequest(url: baseURL.appendingPathComponent("instance"))
|
||||
request.httpMethod = "POST"
|
||||
@@ -98,7 +126,7 @@ final class ClusterStateService: ObservableObject {
|
||||
"model_id": modelId,
|
||||
"sharding": sharding,
|
||||
"instance_meta": instanceMeta,
|
||||
"min_nodes": minNodes
|
||||
"min_nodes": minNodes,
|
||||
]
|
||||
request.httpBody = try JSONSerialization.data(withJSONObject: payload, options: [])
|
||||
let (_, response) = try await session.data(for: request)
|
||||
@@ -119,7 +147,9 @@ final class ClusterStateService: ObservableObject {
|
||||
do {
|
||||
let url = baseURL.appendingPathComponent("models")
|
||||
let (data, response) = try await session.data(from: url)
|
||||
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
(200..<300).contains(httpResponse.statusCode)
|
||||
else {
|
||||
throw URLError(.badServerResponse)
|
||||
}
|
||||
let list = try decoder.decode(ModelListResponse.self, from: data)
|
||||
|
||||
150
app/EXO/EXO/Services/LocalNetworkChecker.swift
Normal file
150
app/EXO/EXO/Services/LocalNetworkChecker.swift
Normal file
@@ -0,0 +1,150 @@
|
||||
import Foundation
|
||||
import Network
|
||||
import os.log
|
||||
|
||||
/// Checks if the app's local network permission is actually functional.
|
||||
///
|
||||
/// macOS local network permission can appear enabled in System Preferences but not
|
||||
/// actually work after a restart. This service detects this by creating a UDP
|
||||
/// connection to the mDNS multicast address (224.0.0.251:5353).
|
||||
@MainActor
|
||||
final class LocalNetworkChecker: ObservableObject {
|
||||
enum Status: Equatable {
|
||||
case unknown
|
||||
case checking
|
||||
case working
|
||||
case notWorking(reason: String)
|
||||
|
||||
var isHealthy: Bool {
|
||||
if case .working = self { return true }
|
||||
return false
|
||||
}
|
||||
|
||||
var displayText: String {
|
||||
switch self {
|
||||
case .unknown:
|
||||
return "Unknown"
|
||||
case .checking:
|
||||
return "Checking..."
|
||||
case .working:
|
||||
return "Working"
|
||||
case .notWorking(let reason):
|
||||
return reason
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
|
||||
|
||||
@Published private(set) var status: Status = .unknown
|
||||
@Published private(set) var lastConnectionState: String = "none"
|
||||
|
||||
private var connection: NWConnection?
|
||||
private var checkTask: Task<Void, Never>?
|
||||
|
||||
/// Checks if local network access is working.
|
||||
func check() {
|
||||
checkTask?.cancel()
|
||||
status = .checking
|
||||
lastConnectionState = "connecting"
|
||||
|
||||
checkTask = Task { [weak self] in
|
||||
guard let self else { return }
|
||||
let result = await self.performCheck()
|
||||
self.status = result
|
||||
Self.logger.info("Local network check complete: \(result.displayText)")
|
||||
}
|
||||
}
|
||||
|
||||
private func performCheck() async -> Status {
|
||||
Self.logger.info("Checking local network access via UDP multicast")
|
||||
|
||||
connection?.cancel()
|
||||
connection = nil
|
||||
|
||||
// mDNS multicast address - same as libp2p uses for peer discovery
|
||||
let host = NWEndpoint.Host("224.0.0.251")
|
||||
let port = NWEndpoint.Port(integerLiteral: 5353)
|
||||
|
||||
let params = NWParameters.udp
|
||||
params.allowLocalEndpointReuse = true
|
||||
|
||||
let conn = NWConnection(host: host, port: port, using: params)
|
||||
connection = conn
|
||||
|
||||
return await withCheckedContinuation { continuation in
|
||||
var hasResumed = false
|
||||
let lock = NSLock()
|
||||
|
||||
let resumeOnce: (Status) -> Void = { status in
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
guard !hasResumed else { return }
|
||||
hasResumed = true
|
||||
continuation.resume(returning: status)
|
||||
}
|
||||
|
||||
conn.stateUpdateHandler = { [weak self] state in
|
||||
let stateStr: String
|
||||
switch state {
|
||||
case .setup: stateStr = "setup"
|
||||
case .preparing: stateStr = "preparing"
|
||||
case .ready: stateStr = "ready"
|
||||
case .waiting(let e): stateStr = "waiting(\(e))"
|
||||
case .failed(let e): stateStr = "failed(\(e))"
|
||||
case .cancelled: stateStr = "cancelled"
|
||||
@unknown default: stateStr = "unknown"
|
||||
}
|
||||
|
||||
Task { @MainActor in
|
||||
self?.lastConnectionState = stateStr
|
||||
}
|
||||
|
||||
switch state {
|
||||
case .ready:
|
||||
resumeOnce(.working)
|
||||
case .waiting(let error):
|
||||
let errorStr = "\(error)"
|
||||
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
|
||||
resumeOnce(.notWorking(reason: "Connection blocked"))
|
||||
}
|
||||
case .failed(let error):
|
||||
let errorStr = "\(error)"
|
||||
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
|
||||
|| errorStr.contains("permission") || errorStr.contains("denied")
|
||||
{
|
||||
resumeOnce(.notWorking(reason: "Permission denied"))
|
||||
} else {
|
||||
resumeOnce(.notWorking(reason: "Failed: \(error.localizedDescription)"))
|
||||
}
|
||||
case .cancelled, .setup, .preparing:
|
||||
break
|
||||
@unknown default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
conn.start(queue: .main)
|
||||
|
||||
Task {
|
||||
try? await Task.sleep(nanoseconds: 3_000_000_000)
|
||||
let state = conn.state
|
||||
switch state {
|
||||
case .ready:
|
||||
resumeOnce(.working)
|
||||
case .waiting, .preparing, .setup:
|
||||
resumeOnce(.notWorking(reason: "Timeout (may be blocked)"))
|
||||
default:
|
||||
resumeOnce(.notWorking(reason: "Timeout"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stop() {
|
||||
checkTask?.cancel()
|
||||
checkTask = nil
|
||||
connection?.cancel()
|
||||
connection = nil
|
||||
}
|
||||
}
|
||||
@@ -5,64 +5,66 @@ import os.log
|
||||
enum NetworkSetupHelper {
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
|
||||
private static let daemonLabel = "io.exo.networksetup"
|
||||
private static let scriptDestination = "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
|
||||
private static let scriptDestination =
|
||||
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
|
||||
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
|
||||
private static let requiredStartInterval: Int = 1791
|
||||
|
||||
private static let setupScript = """
|
||||
#!/usr/bin/env bash
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
set -euo pipefail
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && {
|
||||
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
|
||||
}
|
||||
ifconfig bridge0 destroy 2>/dev/null || true
|
||||
}
|
||||
# Remove bridge0 interface
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && {
|
||||
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
|
||||
}
|
||||
ifconfig bridge0 destroy 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
|
||||
networksetup -listlocations | grep -q exo || {
|
||||
networksetup -createlocation exo
|
||||
}
|
||||
networksetup -listlocations | grep -q exo || {
|
||||
networksetup -createlocation exo
|
||||
}
|
||||
|
||||
networksetup -switchtolocation exo
|
||||
networksetup -listallhardwareports \\
|
||||
| awk -F': ' '/Hardware Port: / {print $2}' \\
|
||||
| while IFS=":" read -r name; do
|
||||
case "$name" in
|
||||
"Ethernet Adapter"*)
|
||||
;;
|
||||
"Thunderbolt Bridge")
|
||||
;;
|
||||
"Thunderbolt "*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "EXO $name" \\
|
||||
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
networksetup -setdhcp "EXO $name"
|
||||
;;
|
||||
*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "$name" \\
|
||||
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
;;
|
||||
esac
|
||||
done
|
||||
networksetup -switchtolocation exo
|
||||
networksetup -listallhardwareports \\
|
||||
| awk -F': ' '/Hardware Port: / {print $2}' \\
|
||||
| while IFS=":" read -r name; do
|
||||
case "$name" in
|
||||
"Ethernet Adapter"*)
|
||||
;;
|
||||
"Thunderbolt Bridge")
|
||||
;;
|
||||
"Thunderbolt "*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "EXO $name" \\
|
||||
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
networksetup -setdhcp "EXO $name"
|
||||
;;
|
||||
*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "$name" \\
|
||||
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
"""
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
"""
|
||||
|
||||
static func ensureLaunchDaemonInstalled() {
|
||||
Task.detached {
|
||||
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
|
||||
Task.detached(priority: .utility) {
|
||||
do {
|
||||
if daemonAlreadyInstalled() {
|
||||
return
|
||||
@@ -70,11 +72,70 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
try await installLaunchDaemon()
|
||||
logger.info("Network setup launch daemon installed and started")
|
||||
} catch {
|
||||
logger.error("Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)")
|
||||
logger.error(
|
||||
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes all EXO network setup components from the system.
|
||||
/// This includes the LaunchDaemon, scripts, logs, and network location.
|
||||
/// Requires admin privileges.
|
||||
static func uninstall() throws {
|
||||
let uninstallScript = makeUninstallScript()
|
||||
try runShellAsAdmin(uninstallScript)
|
||||
logger.info("EXO network setup components removed successfully")
|
||||
}
|
||||
|
||||
/// Checks if there are any EXO network components installed that need cleanup
|
||||
static func hasInstalledComponents() -> Bool {
|
||||
let manager = FileManager.default
|
||||
let scriptExists = manager.fileExists(atPath: scriptDestination)
|
||||
let plistExists = manager.fileExists(atPath: plistDestination)
|
||||
return scriptExists || plistExists
|
||||
}
|
||||
|
||||
private static func makeUninstallScript() -> String {
|
||||
"""
|
||||
set -euo pipefail
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
LOG_OUT="/var/log/\(daemonLabel).log"
|
||||
LOG_ERR="/var/log/\(daemonLabel).err.log"
|
||||
|
||||
# Unload the LaunchDaemon if running
|
||||
launchctl bootout system/"$LABEL" 2>/dev/null || true
|
||||
|
||||
# Remove LaunchDaemon plist
|
||||
rm -f "$PLIST_DEST"
|
||||
|
||||
# Remove the script and parent directory if empty
|
||||
rm -f "$SCRIPT_DEST"
|
||||
rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true
|
||||
|
||||
# Remove log files
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
|
||||
# Switch back to Automatic network location
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
|
||||
# Delete the exo network location if it exists
|
||||
networksetup -listlocations | grep -q '^exo$' && {
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
} || true
|
||||
|
||||
# Re-enable Thunderbolt Bridge if it exists
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
} || true
|
||||
|
||||
echo "EXO network components removed successfully"
|
||||
"""
|
||||
}
|
||||
|
||||
private static func daemonAlreadyInstalled() -> Bool {
|
||||
let manager = FileManager.default
|
||||
let scriptExists = manager.fileExists(atPath: scriptDestination)
|
||||
@@ -82,7 +143,8 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
guard scriptExists, plistExists else { return false }
|
||||
guard
|
||||
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
|
||||
let plist = try? PropertyListSerialization.propertyList(from: data, options: [], format: nil) as? [String: Any]
|
||||
let plist = try? PropertyListSerialization.propertyList(
|
||||
from: data, options: [], format: nil) as? [String: Any]
|
||||
else {
|
||||
return false
|
||||
}
|
||||
@@ -92,7 +154,9 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
else {
|
||||
return false
|
||||
}
|
||||
if let programArgs = plist["ProgramArguments"] as? [String], programArgs.contains(scriptDestination) == false {
|
||||
if let programArgs = plist["ProgramArguments"] as? [String],
|
||||
programArgs.contains(scriptDestination) == false
|
||||
{
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -105,58 +169,59 @@ networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
|
||||
private static func makeInstallerScript() -> String {
|
||||
"""
|
||||
set -euo pipefail
|
||||
set -euo pipefail
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
|
||||
mkdir -p "$(dirname "$SCRIPT_DEST")"
|
||||
mkdir -p "$(dirname "$SCRIPT_DEST")"
|
||||
|
||||
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
|
||||
\(setupScript)
|
||||
EOF_SCRIPT
|
||||
chmod 755 "$SCRIPT_DEST"
|
||||
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
|
||||
\(setupScript)
|
||||
EOF_SCRIPT
|
||||
chmod 755 "$SCRIPT_DEST"
|
||||
|
||||
cat > "$PLIST_DEST" <<'EOF_PLIST'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>\(daemonLabel)</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/bin/bash</string>
|
||||
<string>\(scriptDestination)</string>
|
||||
</array>
|
||||
<key>StartInterval</key>
|
||||
<integer>\(requiredStartInterval)</integer>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/var/log/\(daemonLabel).log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/var/log/\(daemonLabel).err.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF_PLIST
|
||||
cat > "$PLIST_DEST" <<'EOF_PLIST'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>\(daemonLabel)</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/bin/bash</string>
|
||||
<string>\(scriptDestination)</string>
|
||||
</array>
|
||||
<key>StartInterval</key>
|
||||
<integer>\(requiredStartInterval)</integer>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/var/log/\(daemonLabel).log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/var/log/\(daemonLabel).err.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF_PLIST
|
||||
|
||||
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
|
||||
launchctl bootstrap system "$PLIST_DEST"
|
||||
launchctl enable system/"$LABEL"
|
||||
launchctl kickstart -k system/"$LABEL"
|
||||
"""
|
||||
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
|
||||
launchctl bootstrap system "$PLIST_DEST"
|
||||
launchctl enable system/"$LABEL"
|
||||
launchctl kickstart -k system/"$LABEL"
|
||||
"""
|
||||
}
|
||||
|
||||
private static func runShellAsAdmin(_ script: String) throws {
|
||||
let escapedScript = script
|
||||
let escapedScript =
|
||||
script
|
||||
.replacingOccurrences(of: "\\", with: "\\\\")
|
||||
.replacingOccurrences(of: "\"", with: "\\\"")
|
||||
|
||||
let appleScriptSource = """
|
||||
do shell script "\(escapedScript)" with administrator privileges
|
||||
"""
|
||||
do shell script "\(escapedScript)" with administrator privileges
|
||||
"""
|
||||
|
||||
guard let appleScript = NSAppleScript(source: appleScriptSource) else {
|
||||
throw NetworkSetupError.scriptCreationFailed
|
||||
|
||||
@@ -35,14 +35,34 @@ struct NetworkStatus: Equatable {
|
||||
let thunderboltBridgeState: ThunderboltState?
|
||||
let bridgeInactive: Bool?
|
||||
let interfaceStatuses: [InterfaceIpStatus]
|
||||
let rdmaStatus: RDMAStatus
|
||||
|
||||
static let empty = NetworkStatus(
|
||||
thunderboltBridgeState: nil,
|
||||
bridgeInactive: nil,
|
||||
interfaceStatuses: []
|
||||
interfaceStatuses: [],
|
||||
rdmaStatus: .empty
|
||||
)
|
||||
}
|
||||
|
||||
struct RDMAStatus: Equatable {
|
||||
let rdmaCtlEnabled: Bool?
|
||||
let devices: [String]
|
||||
let activePorts: [RDMAPort]
|
||||
|
||||
var isAvailable: Bool {
|
||||
rdmaCtlEnabled == true || !devices.isEmpty
|
||||
}
|
||||
|
||||
static let empty = RDMAStatus(rdmaCtlEnabled: nil, devices: [], activePorts: [])
|
||||
}
|
||||
|
||||
struct RDMAPort: Equatable {
|
||||
let device: String
|
||||
let port: String
|
||||
let state: String
|
||||
}
|
||||
|
||||
struct InterfaceIpStatus: Equatable {
|
||||
let interfaceName: String
|
||||
let ipAddress: String?
|
||||
@@ -59,10 +79,79 @@ private struct NetworkStatusFetcher {
|
||||
NetworkStatus(
|
||||
thunderboltBridgeState: readThunderboltBridgeState(),
|
||||
bridgeInactive: readBridgeInactive(),
|
||||
interfaceStatuses: readInterfaceStatuses()
|
||||
interfaceStatuses: readInterfaceStatuses(),
|
||||
rdmaStatus: readRDMAStatus()
|
||||
)
|
||||
}
|
||||
|
||||
private func readRDMAStatus() -> RDMAStatus {
|
||||
let rdmaCtlEnabled = readRDMACtlEnabled()
|
||||
let devices = readRDMADevices()
|
||||
let activePorts = readRDMAActivePorts()
|
||||
return RDMAStatus(
|
||||
rdmaCtlEnabled: rdmaCtlEnabled, devices: devices, activePorts: activePorts)
|
||||
}
|
||||
|
||||
private func readRDMACtlEnabled() -> Bool? {
|
||||
let result = runCommand(["rdma_ctl", "status"])
|
||||
guard result.exitCode == 0 else { return nil }
|
||||
let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
if output.contains("enabled") {
|
||||
return true
|
||||
}
|
||||
if output.contains("disabled") {
|
||||
return false
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
private func readRDMADevices() -> [String] {
|
||||
let result = runCommand(["ibv_devices"])
|
||||
guard result.exitCode == 0 else { return [] }
|
||||
var devices: [String] = []
|
||||
for line in result.output.split(separator: "\n") {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
if trimmed.hasPrefix("---") || trimmed.lowercased().hasPrefix("device")
|
||||
|| trimmed.isEmpty
|
||||
{
|
||||
continue
|
||||
}
|
||||
let parts = trimmed.split(separator: " ", maxSplits: 1)
|
||||
if let deviceName = parts.first {
|
||||
devices.append(String(deviceName))
|
||||
}
|
||||
}
|
||||
return devices
|
||||
}
|
||||
|
||||
private func readRDMAActivePorts() -> [RDMAPort] {
|
||||
let result = runCommand(["ibv_devinfo"])
|
||||
guard result.exitCode == 0 else { return [] }
|
||||
var ports: [RDMAPort] = []
|
||||
var currentDevice: String?
|
||||
var currentPort: String?
|
||||
|
||||
for line in result.output.split(separator: "\n") {
|
||||
let trimmed = line.trimmingCharacters(in: .whitespaces)
|
||||
if trimmed.hasPrefix("hca_id:") {
|
||||
currentDevice = trimmed.replacingOccurrences(of: "hca_id:", with: "")
|
||||
.trimmingCharacters(in: .whitespaces)
|
||||
} else if trimmed.hasPrefix("port:") {
|
||||
currentPort = trimmed.replacingOccurrences(of: "port:", with: "")
|
||||
.trimmingCharacters(in: .whitespaces)
|
||||
} else if trimmed.hasPrefix("state:") {
|
||||
let state = trimmed.replacingOccurrences(of: "state:", with: "").trimmingCharacters(
|
||||
in: .whitespaces)
|
||||
if let device = currentDevice, let port = currentPort {
|
||||
if state.lowercased().contains("active") {
|
||||
ports.append(RDMAPort(device: device, port: port, state: state))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ports
|
||||
}
|
||||
|
||||
private func readThunderboltBridgeState() -> ThunderboltState? {
|
||||
let result = runCommand(["networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
|
||||
guard result.exitCode == 0 else {
|
||||
@@ -85,10 +174,11 @@ private struct NetworkStatusFetcher {
|
||||
private func readBridgeInactive() -> Bool? {
|
||||
let result = runCommand(["ifconfig", "bridge0"])
|
||||
guard result.exitCode == 0 else { return nil }
|
||||
guard let statusLine = result.output
|
||||
.components(separatedBy: .newlines)
|
||||
.first(where: { $0.contains("status:") })?
|
||||
.lowercased()
|
||||
guard
|
||||
let statusLine = result.output
|
||||
.components(separatedBy: .newlines)
|
||||
.first(where: { $0.contains("status:") })?
|
||||
.lowercased()
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
@@ -171,4 +261,3 @@ private struct NetworkStatusFetcher {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ struct InstanceViewModel: Identifiable, Equatable {
|
||||
case waiting
|
||||
case failed
|
||||
case idle
|
||||
case unknown
|
||||
case preparing
|
||||
|
||||
var label: String {
|
||||
switch self {
|
||||
@@ -68,7 +68,7 @@ struct InstanceViewModel: Identifiable, Equatable {
|
||||
case .waiting: return "Waiting"
|
||||
case .failed: return "Failed"
|
||||
case .idle: return "Idle"
|
||||
case .unknown: return "Unknown"
|
||||
case .preparing: return "Preparing"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -107,10 +107,13 @@ extension ClusterState {
|
||||
let nodeToRunner = instance.shardAssignments.nodeToRunner
|
||||
let nodeIds = Array(nodeToRunner.keys)
|
||||
let runnerIds = Array(nodeToRunner.values)
|
||||
let nodeNames = nodeIds.compactMap { nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0 }
|
||||
let nodeNames = nodeIds.compactMap {
|
||||
nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0
|
||||
}
|
||||
let statuses = runnerIds.compactMap { runners[$0]?.status.lowercased() }
|
||||
let downloadProgress = aggregateDownloadProgress(for: nodeIds)
|
||||
let state = InstanceViewModel.State(statuses: statuses, hasActiveDownload: downloadProgress != nil)
|
||||
let state = InstanceViewModel.State(
|
||||
statuses: statuses, hasActiveDownload: downloadProgress != nil)
|
||||
let chatTasks = (chatTasksByInstance[entry.key] ?? [])
|
||||
.sorted(by: { $0.sortPriority < $1.sortPriority })
|
||||
.map { InstanceTaskViewModel(task: $0) }
|
||||
@@ -165,8 +168,8 @@ extension ClusterState {
|
||||
}
|
||||
}
|
||||
|
||||
private extension InstanceViewModel.State {
|
||||
init(statuses: [String], hasActiveDownload: Bool = false) {
|
||||
extension InstanceViewModel.State {
|
||||
fileprivate init(statuses: [String], hasActiveDownload: Bool = false) {
|
||||
if statuses.contains(where: { $0.contains("failed") }) {
|
||||
self = .failed
|
||||
} else if hasActiveDownload || statuses.contains(where: { $0.contains("downloading") }) {
|
||||
@@ -182,7 +185,7 @@ private extension InstanceViewModel.State {
|
||||
} else if statuses.isEmpty {
|
||||
self = .idle
|
||||
} else {
|
||||
self = .unknown
|
||||
self = .preparing
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -243,4 +246,3 @@ extension InstanceTaskViewModel {
|
||||
self.parameters = task.parameters
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -85,9 +85,11 @@ struct TopologyViewModel {
|
||||
}
|
||||
|
||||
extension ClusterState {
|
||||
func topologyViewModel() -> TopologyViewModel? {
|
||||
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
|
||||
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
|
||||
let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
|
||||
let allNodes = nodeViewModels().filter {
|
||||
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
|
||||
}
|
||||
guard !allNodes.isEmpty else { return nil }
|
||||
|
||||
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
|
||||
@@ -105,17 +107,25 @@ extension ClusterState {
|
||||
orderedNodes = allNodes
|
||||
}
|
||||
|
||||
// Rotate so the local node (from /node_id API) is first
|
||||
if let localId = localNodeId,
|
||||
let index = orderedNodes.firstIndex(where: { $0.id == localId })
|
||||
{
|
||||
orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
|
||||
}
|
||||
|
||||
let nodeIds = Set(orderedNodes.map(\.id))
|
||||
let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
|
||||
guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
|
||||
return TopologyEdgeViewModel(sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
|
||||
} ?? []
|
||||
let edgesArray: [TopologyEdgeViewModel] =
|
||||
topology?.connections?.compactMap { connection in
|
||||
guard nodeIds.contains(connection.localNodeId),
|
||||
nodeIds.contains(connection.sendBackNodeId)
|
||||
else { return nil }
|
||||
return TopologyEdgeViewModel(
|
||||
sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
|
||||
} ?? []
|
||||
let edges = Set(edgesArray)
|
||||
|
||||
let topologyRootId = topology?.nodes.first?.nodeId
|
||||
let currentId = orderedNodes.first(where: { $0.id == topologyRootId })?.id ?? orderedNodes.first?.id
|
||||
|
||||
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: currentId)
|
||||
return TopologyViewModel(
|
||||
nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,8 +20,8 @@ struct InstanceRowView: View {
|
||||
if let progress = instance.downloadProgress {
|
||||
downloadStatusView(progress: progress)
|
||||
} else {
|
||||
statusChip(label: instance.state.label.uppercased(), color: statusColor)
|
||||
}
|
||||
statusChip(label: instance.state.label.uppercased(), color: statusColor)
|
||||
}
|
||||
}
|
||||
if let progress = instance.downloadProgress {
|
||||
GeometryReader { geometry in
|
||||
@@ -83,7 +83,7 @@ struct InstanceRowView: View {
|
||||
case .ready: return .teal
|
||||
case .waiting, .idle: return .gray
|
||||
case .failed: return .red
|
||||
case .unknown: return .secondary
|
||||
case .preparing: return .secondary
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,7 +97,8 @@ struct InstanceRowView: View {
|
||||
.font(.caption)
|
||||
.fontWeight(.semibold)
|
||||
if let subtitle = task.subtitle,
|
||||
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame {
|
||||
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame
|
||||
{
|
||||
Text(subtitle)
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
@@ -234,9 +235,12 @@ struct InstanceRowView: View {
|
||||
Button {
|
||||
isExpanded.wrappedValue.toggle()
|
||||
} label: {
|
||||
Label(isExpanded.wrappedValue ? "Hide" : "Show", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
|
||||
.labelStyle(.titleAndIcon)
|
||||
.contentTransition(.symbolEffect(.replace))
|
||||
Label(
|
||||
isExpanded.wrappedValue ? "Hide" : "Show",
|
||||
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
|
||||
)
|
||||
.labelStyle(.titleAndIcon)
|
||||
.contentTransition(.symbolEffect(.replace))
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.font(.caption2)
|
||||
@@ -311,7 +315,9 @@ struct InstanceRowView: View {
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
private func detailRow(icon: String? = nil, title: String, value: String, tint: Color = .secondary) -> some View {
|
||||
private func detailRow(
|
||||
icon: String? = nil, title: String, value: String, tint: Color = .secondary
|
||||
) -> some View {
|
||||
HStack(alignment: .firstTextBaseline, spacing: 6) {
|
||||
if let icon {
|
||||
Image(systemName: icon)
|
||||
@@ -329,4 +335,3 @@ struct InstanceRowView: View {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,4 +32,3 @@ struct NodeDetailView: View {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,4 +28,3 @@ struct NodeRowView: View {
|
||||
.padding(.vertical, 4)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -76,30 +76,33 @@ struct TopologyMiniView: View {
|
||||
|
||||
private func connectionLines(in size: CGSize) -> some View {
|
||||
let positions = positionedNodes(in: size)
|
||||
let positionById = Dictionary(uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
|
||||
let positionById = Dictionary(
|
||||
uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
|
||||
return Canvas { context, _ in
|
||||
guard !topology.edges.isEmpty else { return }
|
||||
let nodeRadius: CGFloat = 32
|
||||
let arrowLength: CGFloat = 10
|
||||
let arrowSpread: CGFloat = .pi / 7
|
||||
for edge in topology.edges {
|
||||
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId] else { continue }
|
||||
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId]
|
||||
else { continue }
|
||||
let dx = end.x - start.x
|
||||
let dy = end.y - start.y
|
||||
let distance = max(CGFloat(hypot(dx, dy)), 1)
|
||||
let ux = dx / distance
|
||||
let uy = dy / distance
|
||||
let adjustedStart = CGPoint(x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
|
||||
let adjustedStart = CGPoint(
|
||||
x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
|
||||
let adjustedEnd = CGPoint(x: end.x - ux * nodeRadius, y: end.y - uy * nodeRadius)
|
||||
|
||||
var linePath = Path()
|
||||
linePath.move(to: adjustedStart)
|
||||
linePath.addLine(to: adjustedEnd)
|
||||
context.stroke(
|
||||
context.stroke(
|
||||
linePath,
|
||||
with: .color(.secondary.opacity(0.3)),
|
||||
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
|
||||
)
|
||||
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
|
||||
)
|
||||
|
||||
let angle = atan2(uy, ux)
|
||||
let tip = adjustedEnd
|
||||
@@ -168,5 +171,3 @@ private struct NodeGlyphView: View {
|
||||
.frame(width: 95)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
//
|
||||
|
||||
import Testing
|
||||
|
||||
@testable import EXO
|
||||
|
||||
struct EXOTests {
|
||||
|
||||
154
app/EXO/uninstall-exo.sh
Executable file
154
app/EXO/uninstall-exo.sh
Executable file
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# EXO Uninstaller Script
|
||||
#
|
||||
# This script removes all EXO system components that persist after deleting the app.
|
||||
# Run with: sudo ./uninstall-exo.sh
|
||||
#
|
||||
# Components removed:
|
||||
# - LaunchDaemon: /Library/LaunchDaemons/io.exo.networksetup.plist
|
||||
# - Network script: /Library/Application Support/EXO/
|
||||
# - Log files: /var/log/io.exo.networksetup.*
|
||||
# - Network location: "exo"
|
||||
# - Launch at login registration
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
LABEL="io.exo.networksetup"
|
||||
SCRIPT_DEST="/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
|
||||
PLIST_DEST="/Library/LaunchDaemons/io.exo.networksetup.plist"
|
||||
LOG_OUT="/var/log/${LABEL}.log"
|
||||
LOG_ERR="/var/log/${LABEL}.err.log"
|
||||
APP_BUNDLE_ID="io.exo.EXO"
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
echo_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
echo_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
echo_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Check if running as root
|
||||
if [[ $EUID -ne 0 ]]; then
|
||||
echo_error "This script must be run as root (use sudo)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " EXO Uninstaller"
|
||||
echo "========================================"
|
||||
echo ""
|
||||
|
||||
# Unload the LaunchDaemon if running
|
||||
echo_info "Stopping network setup daemon..."
|
||||
if launchctl list | grep -q "$LABEL"; then
|
||||
launchctl bootout system/"$LABEL" 2>/dev/null || true
|
||||
echo_info "Daemon stopped"
|
||||
else
|
||||
echo_warn "Daemon was not running"
|
||||
fi
|
||||
|
||||
# Remove LaunchDaemon plist
|
||||
if [[ -f "$PLIST_DEST" ]]; then
|
||||
rm -f "$PLIST_DEST"
|
||||
echo_info "Removed LaunchDaemon plist"
|
||||
else
|
||||
echo_warn "LaunchDaemon plist not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Remove the script and parent directory
|
||||
if [[ -f "$SCRIPT_DEST" ]]; then
|
||||
rm -f "$SCRIPT_DEST"
|
||||
echo_info "Removed network setup script"
|
||||
else
|
||||
echo_warn "Network setup script not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Remove EXO directory if empty
|
||||
if [[ -d "/Library/Application Support/EXO" ]]; then
|
||||
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
|
||||
echo_info "Removed EXO support directory" || \
|
||||
echo_warn "EXO support directory not empty, leaving in place"
|
||||
fi
|
||||
|
||||
# Remove log files
|
||||
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
echo_info "Removed log files"
|
||||
else
|
||||
echo_warn "Log files not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Switch back to Automatic network location
|
||||
echo_info "Restoring network configuration..."
|
||||
if networksetup -listlocations | grep -q "^Automatic$"; then
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
echo_info "Switched to Automatic network location"
|
||||
else
|
||||
echo_warn "Automatic network location not found"
|
||||
fi
|
||||
|
||||
# Delete the exo network location if it exists
|
||||
if networksetup -listlocations | grep -q "^exo$"; then
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
echo_info "Deleted 'exo' network location"
|
||||
else
|
||||
echo_warn "'exo' network location not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Re-enable Thunderbolt Bridge if it exists
|
||||
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
echo_info "Re-enabled Thunderbolt Bridge"
|
||||
fi
|
||||
|
||||
# Note about launch at login registration
|
||||
# SMAppService-based login items cannot be removed from a shell script.
|
||||
# They can only be unregistered from within the app itself or manually via System Settings.
|
||||
echo_warn "Launch at login must be removed manually:"
|
||||
echo_warn " System Settings → General → Login Items → Remove EXO"
|
||||
|
||||
# Check if EXO.app exists in common locations
|
||||
APP_FOUND=false
|
||||
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
|
||||
if [[ -d "$app_path" ]]; then
|
||||
if [[ "$APP_FOUND" == false ]]; then
|
||||
echo ""
|
||||
APP_FOUND=true
|
||||
fi
|
||||
echo_warn "EXO.app found at: $app_path"
|
||||
echo_warn "You may want to move it to Trash manually."
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo_info "EXO uninstall complete!"
|
||||
echo "========================================"
|
||||
echo ""
|
||||
echo "The following have been removed:"
|
||||
echo " • Network setup LaunchDaemon"
|
||||
echo " • Network configuration script"
|
||||
echo " • Log files"
|
||||
echo " • 'exo' network location"
|
||||
echo ""
|
||||
echo "Your network has been restored to use the 'Automatic' location."
|
||||
echo "Thunderbolt Bridge has been re-enabled (if present)."
|
||||
echo ""
|
||||
echo "Manual step required:"
|
||||
echo " Remove EXO from Login Items in System Settings → General → Login Items"
|
||||
echo ""
|
||||
|
||||
526
bench/exo_bench.py
Normal file
526
bench/exo_bench.py
Normal file
@@ -0,0 +1,526 @@
|
||||
#!/usr/bin/env python3
|
||||
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import http.client
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from statistics import mean
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.types.memory import Memory
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
super().__init__(f"HTTP {status} {reason}: {body_preview}")
|
||||
self.status = status
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def request_json(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
body: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> Any:
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
if params:
|
||||
path = path + "?" + urlencode(params)
|
||||
|
||||
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
|
||||
try:
|
||||
payload: bytes | None = None
|
||||
hdrs: dict[str, str] = {"Accept": "application/json"}
|
||||
|
||||
if body is not None:
|
||||
payload = json.dumps(body).encode("utf-8")
|
||||
hdrs["Content-Type"] = "application/json"
|
||||
if headers:
|
||||
hdrs.update(headers)
|
||||
|
||||
conn.request(method.upper(), path, body=payload, headers=hdrs)
|
||||
resp = conn.getresponse()
|
||||
raw = resp.read()
|
||||
text = raw.decode("utf-8", errors="replace") if raw else ""
|
||||
|
||||
if resp.status >= 400:
|
||||
raise ExoHttpError(resp.status, resp.reason, text[:300])
|
||||
|
||||
if not text:
|
||||
return None
|
||||
return json.loads(text)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return self.request_json("POST", "/bench/chat/completions", body=payload)
|
||||
|
||||
|
||||
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
|
||||
if len(instance) != 1:
|
||||
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
|
||||
|
||||
tag = next(iter(instance))
|
||||
inner = instance[tag]
|
||||
if not isinstance(inner, dict):
|
||||
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
|
||||
return inner
|
||||
|
||||
|
||||
def instance_id_from_instance(instance: dict[str, Any]) -> str:
|
||||
inner = unwrap_instance(instance)
|
||||
return str(inner["instanceId"])
|
||||
|
||||
|
||||
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
|
||||
inner = unwrap_instance(instance)
|
||||
return len(inner["shardAssignments"]["nodeToRunner"])
|
||||
|
||||
|
||||
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
|
||||
inner = unwrap_instance(instance)
|
||||
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
|
||||
return list(runner_to_shard.keys())
|
||||
|
||||
|
||||
def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
|
||||
|
||||
|
||||
def wait_for_instance_gone(
|
||||
client: ExoClient, instance_id: str, timeout: float = 3.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
client.request_json("GET", f"/instance/{instance_id}")
|
||||
time.sleep(0.4)
|
||||
except ExoHttpError as e:
|
||||
if e.status == 404:
|
||||
return
|
||||
|
||||
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
|
||||
|
||||
|
||||
def format_peak_memory(b: float) -> str:
|
||||
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
||||
if b < 1024.0:
|
||||
return f"{b:.2f}{unit}"
|
||||
b /= 1024.0
|
||||
raise ValueError("You're using petabytes of memory. Something went wrong...")
|
||||
|
||||
|
||||
def parse_int_list(values: list[str]) -> list[int]:
|
||||
items: list[int] = []
|
||||
for v in values:
|
||||
for part in v.split(","):
|
||||
part = part.strip()
|
||||
if part:
|
||||
items.append(int(part))
|
||||
|
||||
seen: set[int] = set()
|
||||
out: list[int] = []
|
||||
for x in items:
|
||||
if x not in seen:
|
||||
out.append(x)
|
||||
seen.add(x)
|
||||
return out
|
||||
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
models = client.request_json("GET", "/models") or {}
|
||||
data = models.get("data") or []
|
||||
|
||||
for m in data:
|
||||
if m.get("id") == model_arg:
|
||||
short_id = str(m["id"])
|
||||
full_id = str(m.get("hugging_face_id") or m["id"])
|
||||
return short_id, full_id
|
||||
|
||||
for m in data:
|
||||
if m.get("hugging_face_id") == model_arg:
|
||||
short_id = str(m["id"])
|
||||
full_id = str(m["hugging_face_id"])
|
||||
return short_id, full_id
|
||||
|
||||
raise ValueError(f"Model not found in /models: {model_arg}")
|
||||
|
||||
|
||||
def placement_filter(instance_meta: str, wanted: str) -> bool:
|
||||
s = (instance_meta or "").lower()
|
||||
if wanted == "both":
|
||||
return ("ring" in s) or ("jaccl" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def sharding_filter(sharding: str, wanted: str) -> bool:
|
||||
s = (sharding or "").lower()
|
||||
if wanted == "both":
|
||||
return ("pipeline" in s) or ("tensor" in s)
|
||||
return wanted in s
|
||||
|
||||
|
||||
def run_one_completion(
|
||||
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
content, pp_tokens = prompt_sizer.build(pp_hint)
|
||||
payload: dict[str, Any] = {
|
||||
"model": model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"max_tokens": tg,
|
||||
}
|
||||
|
||||
t0 = time.perf_counter()
|
||||
out = client.post_bench_chat_completions(payload)
|
||||
elapsed = time.perf_counter() - t0
|
||||
|
||||
stats = out.get("generation_stats")
|
||||
|
||||
preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]
|
||||
|
||||
return {
|
||||
"elapsed_s": elapsed,
|
||||
"output_text_preview": preview,
|
||||
"stats": stats,
|
||||
}, pp_tokens
|
||||
|
||||
|
||||
class PromptSizer:
|
||||
def __init__(self, tokenizer: Any, atom: str = "a "):
|
||||
self.tokenizer = tokenizer
|
||||
self.atom = atom
|
||||
self.count_fn = PromptSizer._make_counter(tokenizer)
|
||||
self.base_tokens = self.count_fn("")
|
||||
|
||||
@staticmethod
|
||||
def _make_counter(tokenizer: Any) -> Callable[[str], int]:
|
||||
def count_fn(user_content: str) -> int:
|
||||
messages = [{"role": "user", "content": user_content}]
|
||||
ids = tokenizer.apply_chat_template(
|
||||
messages, tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
return int(len(ids))
|
||||
|
||||
return count_fn
|
||||
|
||||
def build(self, target_prompt_tokens: int) -> tuple[str, int]:
|
||||
target = int(target_prompt_tokens)
|
||||
if target < self.base_tokens:
|
||||
raise RuntimeError(
|
||||
f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
|
||||
)
|
||||
|
||||
content = ""
|
||||
tok = self.count_fn(content)
|
||||
|
||||
while tok < target:
|
||||
content += self.atom
|
||||
tok = self.count_fn(content)
|
||||
|
||||
if tok != target:
|
||||
raise RuntimeError(
|
||||
f"Overshot: got {tok} tokens (target {target}). "
|
||||
f"Pick a different atom (try ' a' or '\\n' or '0 ')."
|
||||
)
|
||||
|
||||
return content, tok
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-bench",
|
||||
description="Benchmark exo model throughput across placement previews.",
|
||||
)
|
||||
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
|
||||
ap.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
|
||||
)
|
||||
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
|
||||
ap.add_argument(
|
||||
"--pp",
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="Prompt-size hints (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--tg",
|
||||
nargs="+",
|
||||
required=True,
|
||||
help="Generation lengths (ints). Accepts commas.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--max-nodes",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-pipeline-jaccl",
|
||||
action="store_true",
|
||||
help="Pipeline jaccl is often pointless, skip by default",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--warmup",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
default="bench/results.json",
|
||||
help="Write raw per-run results JSON to this path.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dry-run", action="store_true", help="List selected placements and exit."
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
tg_list = parse_int_list(args.tg)
|
||||
if not pp_list or not tg_list:
|
||||
logger.error("pp and tg lists must be non-empty")
|
||||
return 2
|
||||
if args.repeat <= 0:
|
||||
logger.error("--repeat must be >= 1")
|
||||
return 2
|
||||
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": short_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
full_model_id,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if tokenizer is None:
|
||||
raise RuntimeError("[exo-bench] tokenizer load failed")
|
||||
|
||||
try:
|
||||
prompt_sizer = PromptSizer(tokenizer)
|
||||
logger.debug(f"[exo-bench] loaded tokenizer: {full_model_id} for prompt sizer")
|
||||
except Exception:
|
||||
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
|
||||
raise
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
# Skip tensor ring single node as it is pointless when pipeline ring
|
||||
if n == 1 and (
|
||||
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
|
||||
or (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
args.skip_pipeline_jaccl
|
||||
and (
|
||||
args.instance_meta == "both"
|
||||
and "jaccl" in p.get("instance_meta", "").lower()
|
||||
)
|
||||
and (
|
||||
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if 0 < n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
logger.error("No valid placements matched your filters.")
|
||||
return 1
|
||||
|
||||
selected.sort(
|
||||
key=lambda p: (
|
||||
str(p.get("instance_meta", "")),
|
||||
str(p.get("sharding", "")),
|
||||
-nodes_used_in_instance(p["instance"]),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
|
||||
logger.info(f"placements: {len(selected)}")
|
||||
for p in selected:
|
||||
logger.info(
|
||||
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
|
||||
)
|
||||
|
||||
if args.dry_run:
|
||||
return 0
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
|
||||
for preview in selected:
|
||||
instance = preview["instance"]
|
||||
instance_id = instance_id_from_instance(instance)
|
||||
|
||||
sharding = str(preview["sharding"])
|
||||
instance_meta = str(preview["instance_meta"])
|
||||
n_nodes = nodes_used_in_instance(instance)
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info(
|
||||
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
|
||||
)
|
||||
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
try:
|
||||
for i in range(args.warmup):
|
||||
run_one_completion(
|
||||
client, full_model_id, pp_list[0], tg_list[0], prompt_sizer
|
||||
)
|
||||
logger.debug(f" warmup {i + 1}/{args.warmup} done")
|
||||
|
||||
for pp in pp_list:
|
||||
if (
|
||||
pp * n_nodes > 2048
|
||||
and "ring" in instance_meta.lower()
|
||||
and "tensor" in sharding.lower()
|
||||
):
|
||||
model_card = MODEL_CARDS[short_id]
|
||||
if model_card.metadata.storage_size > Memory.from_gb(10):
|
||||
logger.info(
|
||||
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
|
||||
)
|
||||
continue
|
||||
for tg in tg_list:
|
||||
runs: list[dict[str, Any]] = []
|
||||
for r in range(args.repeat):
|
||||
time.sleep(3)
|
||||
try:
|
||||
row, actual_pp_tokens = run_one_completion(
|
||||
client, full_model_id, pp, tg, prompt_sizer
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
continue
|
||||
row.update(
|
||||
{
|
||||
"model_short_id": short_id,
|
||||
"model_id": full_model_id,
|
||||
"placement_sharding": sharding,
|
||||
"placement_instance_meta": instance_meta,
|
||||
"placement_nodes": n_nodes,
|
||||
"instance_id": instance_id,
|
||||
"pp_tokens": actual_pp_tokens,
|
||||
"tg": tg,
|
||||
"repeat_index": r,
|
||||
}
|
||||
)
|
||||
runs.append(row)
|
||||
all_rows.append(row)
|
||||
|
||||
if runs:
|
||||
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
|
||||
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
|
||||
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
|
||||
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
|
||||
peak = mean(
|
||||
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
|
||||
f"prompt_tokens={ptok} gen_tokens={gtok} "
|
||||
f"peak_memory={format_peak_memory(peak)}\n"
|
||||
)
|
||||
time.sleep(2)
|
||||
finally:
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
except ExoHttpError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
wait_for_instance_gone(client, instance_id)
|
||||
logger.debug(f"Deleted instance {instance_id}")
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
if args.json_out:
|
||||
with open(args.json_out, "w", encoding="utf-8") as f:
|
||||
json.dump(all_rows, f, indent=2, ensure_ascii=False)
|
||||
logger.debug(f"\nWrote results JSON: {args.json_out}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
60
dashboard/dashboard.nix
Normal file
60
dashboard/dashboard.nix
Normal file
@@ -0,0 +1,60 @@
|
||||
{ lib
|
||||
, config
|
||||
, dream2nix
|
||||
, ...
|
||||
}:
|
||||
let
|
||||
# Read and parse the lock file
|
||||
rawLockFile = builtins.fromJSON (builtins.readFile "${config.deps.dashboardSrc}/package-lock.json");
|
||||
|
||||
# For packages with bundleDependencies, filter out deps that are bundled
|
||||
# (bundled deps are inside the tarball, not separate lockfile entries)
|
||||
fixedPackages = lib.mapAttrs
|
||||
(path: entry:
|
||||
if entry ? bundleDependencies && entry.bundleDependencies != [ ]
|
||||
then entry // {
|
||||
dependencies = lib.filterAttrs
|
||||
(name: _: !(lib.elem name entry.bundleDependencies))
|
||||
(entry.dependencies or { });
|
||||
}
|
||||
else entry
|
||||
)
|
||||
(rawLockFile.packages or { });
|
||||
|
||||
fixedLockFile = rawLockFile // { packages = fixedPackages; };
|
||||
in
|
||||
{
|
||||
imports = [
|
||||
dream2nix.modules.dream2nix.nodejs-package-lock-v3
|
||||
dream2nix.modules.dream2nix.nodejs-granular-v3
|
||||
];
|
||||
|
||||
name = "exo-dashboard";
|
||||
version = "1.0.0";
|
||||
|
||||
mkDerivation = {
|
||||
src = config.deps.dashboardSrc;
|
||||
|
||||
buildPhase = ''
|
||||
runHook preBuild
|
||||
npm run build
|
||||
runHook postBuild
|
||||
'';
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
cp -r build $out/build
|
||||
runHook postInstall
|
||||
'';
|
||||
};
|
||||
|
||||
deps = { nixpkgs, ... }: {
|
||||
inherit (nixpkgs) stdenv;
|
||||
dashboardSrc = null; # Injected by parts.nix
|
||||
};
|
||||
|
||||
nodejs-package-lock-v3 = {
|
||||
# Don't use packageLockFile - provide the fixed lock content directly
|
||||
packageLock = fixedLockFile;
|
||||
};
|
||||
}
|
||||
128
dashboard/package-lock.json
generated
128
dashboard/package-lock.json
generated
@@ -9,6 +9,8 @@
|
||||
"version": "1.0.0",
|
||||
"dependencies": {
|
||||
"highlight.js": "^11.11.1",
|
||||
"katex": "^0.16.27",
|
||||
"marked": "^17.0.1",
|
||||
"mode-watcher": "^1.1.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
@@ -19,6 +21,8 @@
|
||||
"@types/d3": "^7.4.3",
|
||||
"@types/node": "^22",
|
||||
"d3": "^7.9.0",
|
||||
"prettier": "^3.4.2",
|
||||
"prettier-plugin-svelte": "^3.3.3",
|
||||
"svelte": "^5.0.0",
|
||||
"svelte-check": "^4.0.0",
|
||||
"tailwindcss": "^4.0.0",
|
||||
@@ -1159,6 +1163,66 @@
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/core": {
|
||||
"version": "1.6.0",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"@emnapi/wasi-threads": "1.1.0",
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/runtime": {
|
||||
"version": "1.6.0",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@emnapi/wasi-threads": {
|
||||
"version": "1.1.0",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@napi-rs/wasm-runtime": {
|
||||
"version": "1.0.7",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"@emnapi/core": "^1.5.0",
|
||||
"@emnapi/runtime": "^1.5.0",
|
||||
"@tybys/wasm-util": "^0.10.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/@tybys/wasm-util": {
|
||||
"version": "0.10.1",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-wasm32-wasi/node_modules/tslib": {
|
||||
"version": "2.8.1",
|
||||
"dev": true,
|
||||
"inBundle": true,
|
||||
"license": "0BSD",
|
||||
"optional": true
|
||||
},
|
||||
"node_modules/@tailwindcss/oxide-win32-arm64-msvc": {
|
||||
"version": "4.1.17",
|
||||
"resolved": "https://registry.npmjs.org/@tailwindcss/oxide-win32-arm64-msvc/-/oxide-win32-arm64-msvc-4.1.17.tgz",
|
||||
@@ -2254,6 +2318,31 @@
|
||||
"jiti": "lib/jiti-cli.mjs"
|
||||
}
|
||||
},
|
||||
"node_modules/katex": {
|
||||
"version": "0.16.27",
|
||||
"resolved": "https://registry.npmjs.org/katex/-/katex-0.16.27.tgz",
|
||||
"integrity": "sha512-aeQoDkuRWSqQN6nSvVCEFvfXdqo1OQiCmmW1kc9xSdjutPv7BGO7pqY9sQRJpMOGrEdfDgF2TfRXe5eUAD2Waw==",
|
||||
"funding": [
|
||||
"https://opencollective.com/katex",
|
||||
"https://github.com/sponsors/katex"
|
||||
],
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"commander": "^8.3.0"
|
||||
},
|
||||
"bin": {
|
||||
"katex": "cli.js"
|
||||
}
|
||||
},
|
||||
"node_modules/katex/node_modules/commander": {
|
||||
"version": "8.3.0",
|
||||
"resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz",
|
||||
"integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 12"
|
||||
}
|
||||
},
|
||||
"node_modules/kleur": {
|
||||
"version": "4.1.5",
|
||||
"resolved": "https://registry.npmjs.org/kleur/-/kleur-4.1.5.tgz",
|
||||
@@ -2540,6 +2629,18 @@
|
||||
"@jridgewell/sourcemap-codec": "^1.5.5"
|
||||
}
|
||||
},
|
||||
"node_modules/marked": {
|
||||
"version": "17.0.1",
|
||||
"resolved": "https://registry.npmjs.org/marked/-/marked-17.0.1.tgz",
|
||||
"integrity": "sha512-boeBdiS0ghpWcSwoNm/jJBwdpFaMnZWRzjA6SkUMYb40SVaN1x7mmfGKp0jvexGcx+7y2La5zRZsYFZI6Qpypg==",
|
||||
"license": "MIT",
|
||||
"bin": {
|
||||
"marked": "bin/marked.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 20"
|
||||
}
|
||||
},
|
||||
"node_modules/mode-watcher": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/mode-watcher/-/mode-watcher-1.1.0.tgz",
|
||||
@@ -2649,6 +2750,33 @@
|
||||
"node": "^10 || ^12 || >=14"
|
||||
}
|
||||
},
|
||||
"node_modules/prettier": {
|
||||
"version": "3.8.0",
|
||||
"resolved": "https://registry.npmjs.org/prettier/-/prettier-3.8.0.tgz",
|
||||
"integrity": "sha512-yEPsovQfpxYfgWNhCfECjG5AQaO+K3dp6XERmOepyPDVqcJm+bjyCVO3pmU+nAPe0N5dDvekfGezt/EIiRe1TA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"bin": {
|
||||
"prettier": "bin/prettier.cjs"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/prettier/prettier?sponsor=1"
|
||||
}
|
||||
},
|
||||
"node_modules/prettier-plugin-svelte": {
|
||||
"version": "3.4.1",
|
||||
"resolved": "https://registry.npmjs.org/prettier-plugin-svelte/-/prettier-plugin-svelte-3.4.1.tgz",
|
||||
"integrity": "sha512-xL49LCloMoZRvSwa6IEdN2GV6cq2IqpYGstYtMT+5wmml1/dClEoI0MZR78MiVPpu6BdQFfN0/y73yO6+br5Pg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"prettier": "^3.0.0",
|
||||
"svelte": "^3.2.0 || ^4.0.0-next.0 || ^5.0.0-next.0"
|
||||
}
|
||||
},
|
||||
"node_modules/readdirp": {
|
||||
"version": "4.1.2",
|
||||
"resolved": "https://registry.npmjs.org/readdirp/-/readdirp-4.1.2.tgz",
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
"check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json"
|
||||
},
|
||||
"devDependencies": {
|
||||
"prettier": "^3.4.2",
|
||||
"prettier-plugin-svelte": "^3.3.3",
|
||||
"@sveltejs/adapter-static": "^3.0.10",
|
||||
"@sveltejs/kit": "^2.48.4",
|
||||
"@sveltejs/vite-plugin-svelte": "^5.0.0",
|
||||
@@ -27,7 +29,8 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"highlight.js": "^11.11.1",
|
||||
"katex": "^0.16.27",
|
||||
"marked": "^17.0.1",
|
||||
"mode-watcher": "^1.1.0"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
53
dashboard/parts.nix
Normal file
53
dashboard/parts.nix
Normal file
@@ -0,0 +1,53 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ pkgs, lib, ... }:
|
||||
let
|
||||
# Filter source to only include dashboard directory
|
||||
dashboardSrc = lib.cleanSourceWith {
|
||||
src = inputs.self;
|
||||
filter =
|
||||
path: type:
|
||||
let
|
||||
baseName = builtins.baseNameOf path;
|
||||
inDashboardDir =
|
||||
(lib.hasInfix "/dashboard/" path)
|
||||
|| (lib.hasSuffix "/dashboard" (builtins.dirOf path))
|
||||
|| (baseName == "dashboard" && type == "directory");
|
||||
in
|
||||
inDashboardDir;
|
||||
};
|
||||
|
||||
# Build the dashboard with dream2nix (includes node_modules in output)
|
||||
dashboardFull = inputs.dream2nix.lib.evalModules {
|
||||
packageSets.nixpkgs = pkgs;
|
||||
modules = [
|
||||
./dashboard.nix
|
||||
{
|
||||
paths.projectRoot = inputs.self;
|
||||
paths.projectRootFile = "flake.nix";
|
||||
paths.package = inputs.self + "/dashboard";
|
||||
}
|
||||
# Inject the filtered source
|
||||
{
|
||||
deps.dashboardSrc = lib.mkForce "${dashboardSrc}/dashboard";
|
||||
}
|
||||
];
|
||||
};
|
||||
in
|
||||
{
|
||||
# Extract just the static site from the full build
|
||||
packages.dashboard = pkgs.runCommand "exo-dashboard" { } ''
|
||||
cp -r ${dashboardFull}/build $out
|
||||
'';
|
||||
|
||||
# Prettier with svelte plugin for treefmt
|
||||
packages.prettier-svelte = pkgs.writeShellScriptBin "prettier-svelte" ''
|
||||
export NODE_PATH="${dashboardFull}/lib/node_modules/exo-dashboard/node_modules"
|
||||
exec ${pkgs.nodejs}/bin/node \
|
||||
${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
|
||||
--plugin "${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
|
||||
"$@"
|
||||
'';
|
||||
};
|
||||
}
|
||||
1
dashboard/src/app.d.ts
vendored
1
dashboard/src/app.d.ts
vendored
@@ -11,4 +11,3 @@ declare global {
|
||||
}
|
||||
|
||||
export {};
|
||||
|
||||
|
||||
@@ -13,11 +13,16 @@
|
||||
function getFileIcon(file: ChatUploadedFile): string {
|
||||
const category = getFileCategory(file.type, file.name);
|
||||
switch (category) {
|
||||
case 'image': return '🖼';
|
||||
case 'text': return '📄';
|
||||
case 'pdf': return '📑';
|
||||
case 'audio': return '🎵';
|
||||
default: return '📎';
|
||||
case 'image':
|
||||
return '🖼';
|
||||
case 'text':
|
||||
return '📄';
|
||||
case 'pdf':
|
||||
return '📑';
|
||||
case 'audio':
|
||||
return '🎵';
|
||||
default:
|
||||
return '📎';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,18 +38,16 @@
|
||||
{#if files.length > 0}
|
||||
<div class="flex flex-wrap gap-2 mb-3 px-1">
|
||||
{#each files as file (file.id)}
|
||||
<div class="group relative flex items-center gap-2 bg-exo-dark-gray/80 border border-exo-yellow/30 rounded px-2.5 py-1.5 text-xs font-mono transition-all hover:border-exo-yellow/50 hover:shadow-[0_0_10px_rgba(255,215,0,0.1)]">
|
||||
<div
|
||||
class="group relative flex items-center gap-2 bg-exo-dark-gray/80 border border-exo-yellow/30 rounded px-2.5 py-1.5 text-xs font-mono transition-all hover:border-exo-yellow/50 hover:shadow-[0_0_10px_rgba(255,215,0,0.1)]"
|
||||
>
|
||||
<!-- File preview or icon -->
|
||||
{#if file.preview && getFileCategory(file.type, file.name) === 'image'}
|
||||
<img
|
||||
src={file.preview}
|
||||
alt={file.name}
|
||||
class="w-8 h-8 object-cover rounded border border-exo-yellow/20"
|
||||
/>
|
||||
<img src={file.preview} alt={file.name} class="w-8 h-8 object-cover rounded border border-exo-yellow/20" />
|
||||
{:else}
|
||||
<span class="text-base">{getFileIcon(file)}</span>
|
||||
{/if}
|
||||
|
||||
|
||||
<!-- File info -->
|
||||
<div class="flex flex-col min-w-0">
|
||||
<span class="text-exo-yellow truncate max-w-[120px]" title={file.name}>
|
||||
@@ -54,7 +57,7 @@
|
||||
{formatFileSize(file.size)}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Remove button -->
|
||||
{#if !readonly && onRemove}
|
||||
<button
|
||||
@@ -72,4 +75,3 @@
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
|
||||
@@ -12,13 +12,7 @@
|
||||
showModelSelector?: boolean;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
placeholder = 'Ask anything',
|
||||
showHelperText = false,
|
||||
autofocus = true,
|
||||
showModelSelector = false
|
||||
}: Props = $props();
|
||||
let { class: className = '', placeholder = 'Ask anything', showHelperText = false, autofocus = true, showModelSelector = false }: Props = $props();
|
||||
|
||||
let message = $state('');
|
||||
let textareaRef: HTMLTextAreaElement | undefined = $state();
|
||||
@@ -31,7 +25,7 @@
|
||||
const currentTtft = $derived(ttftMs());
|
||||
const currentTps = $derived(tps());
|
||||
const currentTokens = $derived(totalTokens());
|
||||
|
||||
|
||||
// Custom dropdown state
|
||||
let isModelDropdownOpen = $state(false);
|
||||
let dropdownButtonRef: HTMLButtonElement | undefined = $state();
|
||||
@@ -50,7 +44,7 @@
|
||||
|
||||
// Extract available models from running instances
|
||||
const availableModels = $derived(() => {
|
||||
const models: Array<{id: string, label: string}> = [];
|
||||
const models: Array<{ id: string; label: string }> = [];
|
||||
for (const [, instance] of Object.entries(instanceData)) {
|
||||
const modelId = getInstanceModelId(instance);
|
||||
if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
|
||||
@@ -98,18 +92,18 @@
|
||||
|
||||
function handlePaste(event: ClipboardEvent) {
|
||||
if (!event.clipboardData) return;
|
||||
|
||||
|
||||
const files = Array.from(event.clipboardData.items)
|
||||
.filter(item => item.kind === 'file')
|
||||
.map(item => item.getAsFile())
|
||||
.filter((file): file is File => file !== null);
|
||||
|
||||
|
||||
if (files.length > 0) {
|
||||
event.preventDefault();
|
||||
handleFiles(files);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// Handle long text paste as file
|
||||
const text = event.clipboardData.getData('text/plain');
|
||||
if (text.length > 2500) {
|
||||
@@ -132,13 +126,18 @@
|
||||
function handleDrop(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
isDragOver = false;
|
||||
|
||||
|
||||
if (event.dataTransfer?.files) {
|
||||
handleFiles(Array.from(event.dataTransfer.files));
|
||||
}
|
||||
}
|
||||
|
||||
function handleKeydown(event: KeyboardEvent) {
|
||||
// Prevent form submission during IME composition (e.g., Chinese, Japanese, Korean input)
|
||||
if (event.isComposing || event.keyCode === 229) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.key === 'Enter' && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
handleSubmit();
|
||||
@@ -147,16 +146,16 @@
|
||||
|
||||
function handleSubmit() {
|
||||
if ((!message.trim() && uploadedFiles.length === 0) || loading) return;
|
||||
|
||||
|
||||
const content = message.trim();
|
||||
const files = [...uploadedFiles];
|
||||
|
||||
|
||||
message = '';
|
||||
uploadedFiles = [];
|
||||
resetTextareaHeight();
|
||||
|
||||
|
||||
sendMessage(content, files);
|
||||
|
||||
|
||||
// Refocus the textarea after sending
|
||||
setTimeout(() => textareaRef?.focus(), 10);
|
||||
}
|
||||
@@ -179,13 +178,13 @@
|
||||
|
||||
// Track previous loading state to detect when loading completes
|
||||
let wasLoading = $state(false);
|
||||
|
||||
|
||||
$effect(() => {
|
||||
if (autofocus && textareaRef) {
|
||||
setTimeout(() => textareaRef?.focus(), 10);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
// Refocus after loading completes (AI response finished)
|
||||
$effect(() => {
|
||||
if (wasLoading && !loading && textareaRef) {
|
||||
@@ -198,37 +197,29 @@
|
||||
</script>
|
||||
|
||||
<!-- Hidden file input -->
|
||||
<input
|
||||
bind:this={fileInputRef}
|
||||
type="file"
|
||||
accept={acceptString}
|
||||
multiple
|
||||
class="hidden"
|
||||
onchange={handleFileInputChange}
|
||||
/>
|
||||
<input bind:this={fileInputRef} type="file" accept={acceptString} multiple class="hidden" onchange={handleFileInputChange} />
|
||||
|
||||
<form
|
||||
onsubmit={(e) => { e.preventDefault(); handleSubmit(); }}
|
||||
<form
|
||||
onsubmit={e => {
|
||||
e.preventDefault();
|
||||
handleSubmit();
|
||||
}}
|
||||
class="w-full {className}"
|
||||
ondragover={handleDragOver}
|
||||
ondragleave={handleDragLeave}
|
||||
ondrop={handleDrop}
|
||||
>
|
||||
<div
|
||||
class="relative command-panel rounded overflow-hidden transition-all duration-200 {isDragOver ? 'ring-2 ring-exo-yellow ring-opacity-50' : ''}"
|
||||
>
|
||||
<div class="relative command-panel rounded overflow-hidden transition-all duration-200 {isDragOver ? 'ring-2 ring-exo-yellow ring-opacity-50' : ''}">
|
||||
<!-- Top accent line -->
|
||||
<div class="absolute top-0 left-0 right-0 h-px bg-gradient-to-r from-transparent via-exo-yellow/50 to-transparent"></div>
|
||||
|
||||
|
||||
<!-- Drag overlay -->
|
||||
{#if isDragOver}
|
||||
<div class="absolute inset-0 bg-exo-dark-gray/80 z-10 flex items-center justify-center">
|
||||
<div class="text-exo-yellow text-sm font-mono tracking-wider uppercase">
|
||||
DROP FILES HERE
|
||||
</div>
|
||||
<div class="text-exo-yellow text-sm font-mono tracking-wider uppercase">DROP FILES HERE</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
|
||||
<!-- Model selector (when enabled) -->
|
||||
{#if showModelSelector && availableModels().length > 0}
|
||||
<div class="flex items-center justify-between gap-2 px-3 py-2 border-b border-exo-medium-gray/30">
|
||||
@@ -239,8 +230,10 @@
|
||||
<button
|
||||
bind:this={dropdownButtonRef}
|
||||
type="button"
|
||||
onclick={() => isModelDropdownOpen = !isModelDropdownOpen}
|
||||
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-1.5 text-xs font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isModelDropdownOpen ? 'border-exo-yellow/70' : ''}"
|
||||
onclick={() => (isModelDropdownOpen = !isModelDropdownOpen)}
|
||||
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-1.5 text-xs font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isModelDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
>
|
||||
{#if availableModels().find(m => m.id === currentModel)}
|
||||
<span class="text-exo-yellow truncate">{availableModels().find(m => m.id === currentModel)?.label}</span>
|
||||
@@ -256,18 +249,13 @@
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
{#if isModelDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => isModelDropdownOpen = false}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<button type="button" class="fixed inset-0 z-[9998] cursor-default" onclick={() => (isModelDropdownOpen = false)} aria-label="Close dropdown"></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto"
|
||||
style="bottom: calc(100vh - {dropdownPosition().top}px + 4px); left: {dropdownPosition().left}px; width: {dropdownPosition().width}px;"
|
||||
>
|
||||
@@ -279,11 +267,9 @@
|
||||
setSelectedChatModel(model.id);
|
||||
isModelDropdownOpen = false;
|
||||
}}
|
||||
class="w-full px-3 py-2 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {
|
||||
currentModel === model.id
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'
|
||||
}"
|
||||
class="w-full px-3 py-2 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {currentModel === model.id
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if currentModel === model.id}
|
||||
<svg class="w-3 h-3 flex-shrink-0" fill="currentColor" viewBox="0 0 20 20">
|
||||
@@ -317,17 +303,14 @@
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
|
||||
<!-- Attached files preview -->
|
||||
{#if uploadedFiles.length > 0}
|
||||
<div class="px-3 pt-3">
|
||||
<ChatAttachments
|
||||
files={uploadedFiles}
|
||||
onRemove={handleFileRemove}
|
||||
/>
|
||||
<ChatAttachments files={uploadedFiles} onRemove={handleFileRemove} />
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
|
||||
<!-- Input area -->
|
||||
<div class="flex items-start gap-2 sm:gap-3 py-3 px-3 sm:px-4">
|
||||
<!-- Attach file button -->
|
||||
@@ -339,13 +322,18 @@
|
||||
title="Attach file"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M15.172 7l-6.586 6.586a2 2 0 102.828 2.828l6.414-6.586a4 4 0 00-5.656-5.656l-6.415 6.585a6 6 0 108.486 8.486L20.5 13" />
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M15.172 7l-6.586 6.586a2 2 0 102.828 2.828l6.414-6.586a4 4 0 00-5.656-5.656l-6.415 6.585a6 6 0 108.486 8.486L20.5 13"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
|
||||
|
||||
<!-- Terminal prompt -->
|
||||
<span class="text-exo-yellow text-sm font-bold flex-shrink-0 leading-7">▶</span>
|
||||
|
||||
|
||||
<textarea
|
||||
bind:this={textareaRef}
|
||||
bind:value={message}
|
||||
@@ -358,14 +346,12 @@
|
||||
class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"
|
||||
style="min-height: 28px; max-height: 150px;"
|
||||
></textarea>
|
||||
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || loading}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || loading
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
{!canSend || loading ? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed' : 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label="Send message"
|
||||
>
|
||||
{#if loading}
|
||||
@@ -379,11 +365,11 @@
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Bottom accent line -->
|
||||
<div class="absolute bottom-0 left-0 right-0 h-px bg-gradient-to-r from-transparent via-exo-yellow/30 to-transparent"></div>
|
||||
</div>
|
||||
|
||||
|
||||
{#if showHelperText}
|
||||
<p class="mt-2 sm:mt-3 text-center text-xs sm:text-xs text-exo-light-gray tracking-[0.1em] sm:tracking-[0.15em] uppercase">
|
||||
<kbd class="px-1 sm:px-1.5 py-0.5 rounded bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50">ENTER</kbd>
|
||||
|
||||
@@ -1,96 +1,80 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
messages,
|
||||
currentResponse,
|
||||
isLoading,
|
||||
deleteMessage,
|
||||
editAndRegenerate,
|
||||
regenerateLastResponse
|
||||
} from '$lib/stores/app.svelte';
|
||||
import { messages, currentResponse, isLoading, deleteMessage, editAndRegenerate, regenerateLastResponse } from '$lib/stores/app.svelte';
|
||||
import type { MessageAttachment } from '$lib/stores/app.svelte';
|
||||
import { tick, onDestroy } from 'svelte';
|
||||
import MarkdownContent from './MarkdownContent.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
scrollParent?: HTMLElement | null;
|
||||
}
|
||||
interface Props {
|
||||
class?: string;
|
||||
scrollParent?: HTMLElement | null;
|
||||
}
|
||||
|
||||
let { class: className = '', scrollParent = null }: Props = $props();
|
||||
let { class: className = '', scrollParent = null }: Props = $props();
|
||||
|
||||
const messageList = $derived(messages());
|
||||
const response = $derived(currentResponse());
|
||||
const loading = $derived(isLoading());
|
||||
|
||||
// Ref for scroll anchor at bottom
|
||||
let scrollAnchorRef: HTMLDivElement | undefined = $state();
|
||||
// Scroll management - user controls scroll, show button when not at bottom
|
||||
const SCROLL_THRESHOLD = 100;
|
||||
let showScrollButton = $state(false);
|
||||
let lastMessageCount = 0;
|
||||
let containerRef: HTMLDivElement | undefined = $state();
|
||||
|
||||
// Scroll management
|
||||
const SCROLL_BOTTOM_THRESHOLD = 120;
|
||||
let autoScrollEnabled = true;
|
||||
let currentScrollEl: HTMLElement | null = null;
|
||||
|
||||
function resolveScrollElement(): HTMLElement | null {
|
||||
if (scrollParent) return scrollParent;
|
||||
let node: HTMLElement | null = scrollAnchorRef?.parentElement as HTMLElement | null;
|
||||
while (node) {
|
||||
const isScrollable = node.scrollHeight > node.clientHeight + 1;
|
||||
if (isScrollable) return node;
|
||||
node = node.parentElement;
|
||||
function getScrollContainer(): HTMLElement | null {
|
||||
if (scrollParent) return scrollParent;
|
||||
return containerRef?.parentElement ?? null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function handleScroll() {
|
||||
if (!currentScrollEl) return;
|
||||
const distanceFromBottom = currentScrollEl.scrollHeight - currentScrollEl.scrollTop - currentScrollEl.clientHeight;
|
||||
const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
|
||||
autoScrollEnabled = isNearBottom;
|
||||
}
|
||||
|
||||
function attachScrollListener() {
|
||||
const nextEl = resolveScrollElement();
|
||||
if (currentScrollEl === nextEl) return;
|
||||
if (currentScrollEl) {
|
||||
currentScrollEl.removeEventListener('scroll', handleScroll);
|
||||
function isNearBottom(el: HTMLElement): boolean {
|
||||
return el.scrollHeight - el.scrollTop - el.clientHeight < SCROLL_THRESHOLD;
|
||||
}
|
||||
currentScrollEl = nextEl;
|
||||
if (currentScrollEl) {
|
||||
currentScrollEl.addEventListener('scroll', handleScroll);
|
||||
// Initialize state based on current position
|
||||
handleScroll();
|
||||
}
|
||||
}
|
||||
|
||||
onDestroy(() => {
|
||||
if (currentScrollEl) {
|
||||
currentScrollEl.removeEventListener('scroll', handleScroll);
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
// Re-evaluate scroll container if prop changes or after mount
|
||||
scrollParent;
|
||||
attachScrollListener();
|
||||
});
|
||||
|
||||
// Auto-scroll to bottom when messages change or response updates, but only if user is near bottom
|
||||
$effect(() => {
|
||||
// Track these values to trigger effect
|
||||
const _ = messageList.length;
|
||||
const __ = response;
|
||||
const ___ = loading;
|
||||
|
||||
tick().then(() => {
|
||||
const el = currentScrollEl ?? resolveScrollElement();
|
||||
if (!el || !scrollAnchorRef) return;
|
||||
const distanceFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight;
|
||||
const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
|
||||
if (autoScrollEnabled || isNearBottom) {
|
||||
scrollAnchorRef.scrollIntoView({ behavior: 'smooth', block: 'end' });
|
||||
autoScrollEnabled = true;
|
||||
function scrollToBottom() {
|
||||
const el = getScrollContainer();
|
||||
if (el) {
|
||||
el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
|
||||
}
|
||||
}
|
||||
|
||||
function updateScrollButtonVisibility() {
|
||||
const el = getScrollContainer();
|
||||
if (!el) return;
|
||||
showScrollButton = !isNearBottom(el);
|
||||
}
|
||||
|
||||
// Attach scroll listener
|
||||
$effect(() => {
|
||||
const el = scrollParent ?? containerRef?.parentElement;
|
||||
if (!el) return;
|
||||
|
||||
el.addEventListener('scroll', updateScrollButtonVisibility, { passive: true });
|
||||
// Initial check
|
||||
updateScrollButtonVisibility();
|
||||
return () => el.removeEventListener('scroll', updateScrollButtonVisibility);
|
||||
});
|
||||
|
||||
// Auto-scroll when user sends a new message
|
||||
$effect(() => {
|
||||
const count = messageList.length;
|
||||
if (count > lastMessageCount) {
|
||||
const el = getScrollContainer();
|
||||
if (el) {
|
||||
requestAnimationFrame(() => {
|
||||
el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
|
||||
});
|
||||
}
|
||||
}
|
||||
lastMessageCount = count;
|
||||
});
|
||||
|
||||
// Update scroll button visibility when content changes
|
||||
$effect(() => {
|
||||
// Track response to trigger re-check during streaming
|
||||
const _ = response;
|
||||
|
||||
// Small delay to let DOM update
|
||||
requestAnimationFrame(() => updateScrollButtonVisibility());
|
||||
});
|
||||
});
|
||||
|
||||
// Edit state
|
||||
let editingMessageId = $state<string | null>(null);
|
||||
@@ -100,14 +84,14 @@ $effect(() => {
|
||||
// Delete confirmation state
|
||||
let deleteConfirmId = $state<string | null>(null);
|
||||
|
||||
// Copied state for feedback
|
||||
let copiedMessageId = $state<string | null>(null);
|
||||
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
|
||||
// Copied state for feedback
|
||||
let copiedMessageId = $state<string | null>(null);
|
||||
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
|
||||
|
||||
function formatTimestamp(timestamp: number): string {
|
||||
return new Date(timestamp).toLocaleTimeString('en-US', {
|
||||
return new Date(timestamp).toLocaleTimeString('en-US', {
|
||||
hour12: false,
|
||||
hour: '2-digit',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
second: '2-digit'
|
||||
});
|
||||
@@ -115,9 +99,12 @@ let expandedThinkingMessageIds = $state<Set<string>>(new Set());
|
||||
|
||||
function getAttachmentIcon(attachment: MessageAttachment): string {
|
||||
switch (attachment.type) {
|
||||
case 'image': return '🖼';
|
||||
case 'text': return '📄';
|
||||
default: return '📎';
|
||||
case 'image':
|
||||
return '🖼';
|
||||
case 'text':
|
||||
return '📄';
|
||||
default:
|
||||
return '📎';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,19 +128,19 @@ let expandedThinkingMessageIds = $state<Set<string>>(new Set());
|
||||
}
|
||||
}
|
||||
|
||||
function toggleThinkingVisibility(messageId: string) {
|
||||
const next = new Set(expandedThinkingMessageIds);
|
||||
if (next.has(messageId)) {
|
||||
next.delete(messageId);
|
||||
} else {
|
||||
next.add(messageId);
|
||||
function toggleThinkingVisibility(messageId: string) {
|
||||
const next = new Set(expandedThinkingMessageIds);
|
||||
if (next.has(messageId)) {
|
||||
next.delete(messageId);
|
||||
} else {
|
||||
next.add(messageId);
|
||||
}
|
||||
expandedThinkingMessageIds = next;
|
||||
}
|
||||
expandedThinkingMessageIds = next;
|
||||
}
|
||||
|
||||
function isThinkingExpanded(messageId: string): boolean {
|
||||
return expandedThinkingMessageIds.has(messageId);
|
||||
}
|
||||
function isThinkingExpanded(messageId: string): boolean {
|
||||
return expandedThinkingMessageIds.has(messageId);
|
||||
}
|
||||
|
||||
function handleStartEdit(messageId: string, content: string) {
|
||||
editingMessageId = messageId;
|
||||
@@ -231,7 +218,7 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
<div class="flex flex-col gap-4 sm:gap-6 {className}">
|
||||
{#each messageList as message (message.id)}
|
||||
<div class="group flex {message.role === 'user' ? 'justify-end' : 'justify-start'}">
|
||||
<div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'max-w-[95%] sm:max-w-[85%]'}">
|
||||
<div class={message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'w-full max-w-[98%] sm:max-w-[95%]'}>
|
||||
{#if message.role === 'assistant'}
|
||||
<!-- Assistant message header -->
|
||||
<div class="flex items-center gap-1.5 sm:gap-2 mb-1.5 sm:mb-2">
|
||||
@@ -240,7 +227,9 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
<span class="text-xs sm:text-sm text-exo-light-gray tracking-wider tabular-nums">{formatTimestamp(message.timestamp)}</span>
|
||||
{#if message.ttftMs || message.tps}
|
||||
<span class="text-xs text-exo-light-gray/80 font-mono ml-2">
|
||||
{#if message.ttftMs}<span class="text-exo-light-gray/50">TTFT</span> {message.ttftMs.toFixed(0)}ms{/if}{#if message.ttftMs && message.tps}<span class="text-exo-light-gray/30 mx-1">•</span>{/if}{#if message.tps}{message.tps.toFixed(1)} <span class="text-exo-light-gray/50">tok/s</span>{/if}
|
||||
{#if message.ttftMs}<span class="text-exo-light-gray/50">TTFT</span> {message.ttftMs.toFixed(0)}ms{/if}{#if message.ttftMs && message.tps}<span class="text-exo-light-gray/30 mx-1"
|
||||
>•</span
|
||||
>{/if}{#if message.tps}{message.tps.toFixed(1)} <span class="text-exo-light-gray/50">tok/s</span>{/if}
|
||||
</span>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -252,22 +241,22 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
<div class="w-1.5 h-1.5 sm:w-2 sm:h-2 bg-exo-light-gray/50 rounded-full"></div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
|
||||
{#if deleteConfirmId === message.id}
|
||||
<!-- Delete confirmation -->
|
||||
<div class="bg-red-500/10 border border-red-500/30 rounded-lg p-3">
|
||||
<p class="text-xs text-red-400 mb-3">Delete this message{message.role === 'user' ? ' and all responses after it' : ''}?</p>
|
||||
<div class="flex gap-2 justify-end">
|
||||
<button
|
||||
onclick={handleCancelDelete}
|
||||
class="px-3 py-1.5 text-sm font-mono tracking-wider uppercase bg-exo-medium-gray/20 text-exo-light-gray border border-exo-medium-gray/30 rounded hover:bg-exo-medium-gray/30 transition-colors cursor-pointer"
|
||||
>
|
||||
CANCEL
|
||||
</button>
|
||||
<button
|
||||
onclick={handleConfirmDelete}
|
||||
class="px-3 py-1.5 text-sm font-mono tracking-wider uppercase bg-red-500/20 text-red-400 border border-red-500/30 rounded hover:bg-red-500/30 transition-colors cursor-pointer"
|
||||
>
|
||||
<button
|
||||
onclick={handleCancelDelete}
|
||||
class="px-3 py-1.5 text-sm font-mono tracking-wider uppercase bg-exo-medium-gray/20 text-exo-light-gray border border-exo-medium-gray/30 rounded hover:bg-exo-medium-gray/30 transition-colors cursor-pointer"
|
||||
>
|
||||
CANCEL
|
||||
</button>
|
||||
<button
|
||||
onclick={handleConfirmDelete}
|
||||
class="px-3 py-1.5 text-sm font-mono tracking-wider uppercase bg-red-500/20 text-red-400 border border-red-500/30 rounded hover:bg-red-500/30 transition-colors cursor-pointer"
|
||||
>
|
||||
DELETE
|
||||
</button>
|
||||
</div>
|
||||
@@ -284,17 +273,17 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
style="min-height: 60px; max-height: 200px;"
|
||||
></textarea>
|
||||
<div class="flex gap-2 justify-end mt-2">
|
||||
<button
|
||||
onclick={handleCancelEdit}
|
||||
class="px-3 py-1.5 text-sm font-mono tracking-wider uppercase bg-exo-medium-gray/20 text-exo-light-gray border border-exo-medium-gray/30 rounded hover:bg-exo-medium-gray/30 transition-colors cursor-pointer"
|
||||
>
|
||||
CANCEL
|
||||
</button>
|
||||
<button
|
||||
onclick={handleSaveEdit}
|
||||
disabled={!editContent.trim()}
|
||||
class="px-3 py-1.5 text-sm font-mono tracking-wider uppercase bg-transparent text-exo-yellow border border-exo-yellow/30 rounded hover:border-exo-yellow/50 transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-1.5 cursor-pointer"
|
||||
>
|
||||
<button
|
||||
onclick={handleCancelEdit}
|
||||
class="px-3 py-1.5 text-sm font-mono tracking-wider uppercase bg-exo-medium-gray/20 text-exo-light-gray border border-exo-medium-gray/30 rounded hover:bg-exo-medium-gray/30 transition-colors cursor-pointer"
|
||||
>
|
||||
CANCEL
|
||||
</button>
|
||||
<button
|
||||
onclick={handleSaveEdit}
|
||||
disabled={!editContent.trim()}
|
||||
class="px-3 py-1.5 text-sm font-mono tracking-wider uppercase bg-transparent text-exo-yellow border border-exo-yellow/30 rounded hover:border-exo-yellow/50 transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-1.5 cursor-pointer"
|
||||
>
|
||||
<svg class="w-3 h-3" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 19l9 2-9-18-9 18 9-2zm0 0v-8" />
|
||||
</svg>
|
||||
@@ -303,10 +292,7 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="{message.role === 'user'
|
||||
? 'command-panel rounded-lg rounded-tr-sm inline-block'
|
||||
: 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 inline-block'}">
|
||||
|
||||
<div class={message.role === 'user' ? 'command-panel rounded-lg rounded-tr-sm inline-block' : 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 block w-full'}>
|
||||
{#if message.role === 'user'}
|
||||
<!-- User message styling -->
|
||||
<div class="px-4 py-3">
|
||||
@@ -316,11 +302,7 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
{#each message.attachments as attachment}
|
||||
<div class="flex items-center gap-2 bg-exo-dark-gray/60 border border-exo-yellow/20 rounded px-2 py-1 text-xs font-mono">
|
||||
{#if attachment.type === 'image' && attachment.preview}
|
||||
<img
|
||||
src={attachment.preview}
|
||||
alt={attachment.name}
|
||||
class="w-12 h-12 object-cover rounded border border-exo-yellow/20"
|
||||
/>
|
||||
<img src={attachment.preview} alt={attachment.name} class="w-12 h-12 object-cover rounded border border-exo-yellow/20" />
|
||||
{:else}
|
||||
<span>{getAttachmentIcon(attachment)}</span>
|
||||
{/if}
|
||||
@@ -329,9 +311,9 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
|
||||
{#if message.content}
|
||||
<div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
|
||||
<div class="text-xs text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
|
||||
{message.content}
|
||||
</div>
|
||||
{/if}
|
||||
@@ -360,22 +342,19 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</svg>
|
||||
<span>Thinking...</span>
|
||||
</span>
|
||||
<span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60">
|
||||
<span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60 ml-4">
|
||||
{isThinkingExpanded(message.id) ? 'HIDE' : 'SHOW'}
|
||||
</span>
|
||||
</button>
|
||||
{#if isThinkingExpanded(message.id)}
|
||||
<div
|
||||
id={`thinking-panel-${message.id}`}
|
||||
class="px-3 pb-3 text-xs text-exo-light-gray/90 font-mono whitespace-pre-wrap break-words leading-relaxed"
|
||||
>
|
||||
<div id={`thinking-panel-${message.id}`} class="px-3 pb-3 text-xs text-exo-light-gray/90 font-mono whitespace-pre-wrap break-words leading-relaxed">
|
||||
{message.thinking.trim()}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
<div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
|
||||
{message.content || (loading ? response : '')}
|
||||
<div class="text-xs text-foreground">
|
||||
<MarkdownContent content={message.content || (loading ? response : '')} />
|
||||
{#if loading && !message.content}
|
||||
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
|
||||
{/if}
|
||||
@@ -383,26 +362,27 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Action buttons -->
|
||||
<div class="flex items-center gap-1 mt-1.5 opacity-0 group-hover:opacity-100 transition-opacity {message.role === 'user' ? 'justify-end' : 'justify-start'}">
|
||||
<!-- Copy button -->
|
||||
<button
|
||||
onclick={() => handleCopy(message.content, message.id)}
|
||||
class="p-1.5 text-exo-light-gray hover:text-exo-yellow transition-colors rounded cursor-pointer"
|
||||
title="Copy message"
|
||||
>
|
||||
<button onclick={() => handleCopy(message.content, message.id)} class="p-1.5 text-exo-light-gray hover:text-exo-yellow transition-colors rounded cursor-pointer" title="Copy message">
|
||||
{#if copiedMessageId === message.id}
|
||||
<svg class="w-3.5 h-3.5 text-green-400" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 13l4 4L19 7" />
|
||||
</svg>
|
||||
{:else}
|
||||
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z" />
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
|
||||
/>
|
||||
</svg>
|
||||
{/if}
|
||||
</button>
|
||||
|
||||
|
||||
<!-- Edit button (user messages only) -->
|
||||
{#if message.role === 'user'}
|
||||
<button
|
||||
@@ -411,24 +391,30 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
title="Edit message"
|
||||
>
|
||||
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" />
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
|
||||
<!-- Regenerate button (last assistant message only) -->
|
||||
{#if message.role === 'assistant' && isLastAssistantMessage(message.id) && !loading}
|
||||
<button
|
||||
onclick={handleRegenerate}
|
||||
class="p-1.5 text-exo-light-gray hover:text-exo-yellow transition-colors rounded cursor-pointer"
|
||||
title="Regenerate response"
|
||||
>
|
||||
<button onclick={handleRegenerate} class="p-1.5 text-exo-light-gray hover:text-exo-yellow transition-colors rounded cursor-pointer" title="Regenerate response">
|
||||
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
|
||||
<!-- Delete button -->
|
||||
<button
|
||||
onclick={() => handleDeleteClick(message.id)}
|
||||
@@ -436,7 +422,12 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
title="Delete message"
|
||||
>
|
||||
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" />
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
@@ -444,7 +435,7 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
|
||||
|
||||
{#if messageList.length === 0}
|
||||
<div class="flex-1 flex flex-col items-center justify-center text-center pt-[20vh]">
|
||||
<div class="w-12 h-12 sm:w-16 sm:h-16 border border-exo-yellow/20 rounded-full flex items-center justify-center mb-3 sm:mb-4">
|
||||
@@ -456,7 +447,21 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
<p class="text-sm sm:text-xs text-exo-light-gray tracking-wider mt-1">ENTER A QUERY TO BEGIN</p>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Scroll anchor for auto-scroll -->
|
||||
<div bind:this={scrollAnchorRef}></div>
|
||||
|
||||
<!-- Invisible element for container reference -->
|
||||
<div bind:this={containerRef}></div>
|
||||
|
||||
<!-- Scroll to bottom button -->
|
||||
{#if showScrollButton}
|
||||
<button
|
||||
type="button"
|
||||
onclick={scrollToBottom}
|
||||
class="sticky bottom-4 left-1/2 -translate-x-1/2 w-10 h-10 rounded-full bg-exo-dark-gray/90 border border-exo-medium-gray/50 flex items-center justify-center text-exo-light-gray hover:text-exo-yellow hover:border-exo-yellow/50 transition-all shadow-lg cursor-pointer z-10"
|
||||
title="Scroll to bottom"
|
||||
>
|
||||
<svg class="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 14l-7 7m0 0l-7-7m7 7V3" />
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
conversations,
|
||||
activeConversationId,
|
||||
createConversation,
|
||||
loadConversation,
|
||||
deleteConversation,
|
||||
import {
|
||||
conversations,
|
||||
activeConversationId,
|
||||
createConversation,
|
||||
loadConversation,
|
||||
deleteConversation,
|
||||
deleteAllConversations,
|
||||
renameConversation,
|
||||
clearChat,
|
||||
instances,
|
||||
debugMode,
|
||||
toggleDebugMode
|
||||
toggleDebugMode,
|
||||
topologyOnlyMode,
|
||||
toggleTopologyOnlyMode
|
||||
} from '$lib/stores/app.svelte';
|
||||
|
||||
interface Props {
|
||||
@@ -21,8 +23,9 @@ import {
|
||||
|
||||
const conversationList = $derived(conversations());
|
||||
const activeId = $derived(activeConversationId());
|
||||
const instanceData = $derived(instances());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const instanceData = $derived(instances());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const topologyOnlyEnabled = $derived(topologyOnlyMode());
|
||||
|
||||
let searchQuery = $state('');
|
||||
let editingId = $state<string | null>(null);
|
||||
@@ -30,11 +33,7 @@ const debugEnabled = $derived(debugMode());
|
||||
let deleteConfirmId = $state<string | null>(null);
|
||||
let showDeleteAllConfirm = $state(false);
|
||||
|
||||
const filteredConversations = $derived(
|
||||
searchQuery.trim()
|
||||
? conversationList.filter(c => c.name.toLowerCase().includes(searchQuery.toLowerCase()))
|
||||
: conversationList
|
||||
);
|
||||
const filteredConversations = $derived(searchQuery.trim() ? conversationList.filter(c => c.name.toLowerCase().includes(searchQuery.toLowerCase())) : conversationList);
|
||||
|
||||
function handleNewChat() {
|
||||
createConversation();
|
||||
@@ -104,7 +103,7 @@ const debugEnabled = $derived(debugMode());
|
||||
const date = new Date(timestamp);
|
||||
const now = new Date();
|
||||
const diffDays = Math.floor((now.getTime() - date.getTime()) / (1000 * 60 * 60 * 24));
|
||||
|
||||
|
||||
if (diffDays === 0) {
|
||||
return date.toLocaleTimeString('en-US', { hour: '2-digit', minute: '2-digit' });
|
||||
} else if (diffDays === 1) {
|
||||
@@ -116,7 +115,7 @@ const debugEnabled = $derived(debugMode());
|
||||
}
|
||||
}
|
||||
|
||||
function getLastAssistantStats(conversation: typeof conversationList[0]): { ttftMs?: number; tps?: number } | null {
|
||||
function getLastAssistantStats(conversation: (typeof conversationList)[0]): { ttftMs?: number; tps?: number } | null {
|
||||
// Find the last assistant message with stats
|
||||
for (let i = conversation.messages.length - 1; i >= 0; i--) {
|
||||
const msg = conversation.messages[i];
|
||||
@@ -180,7 +179,7 @@ const debugEnabled = $derived(debugMode());
|
||||
return { sharding, instanceType };
|
||||
}
|
||||
|
||||
function resolveConversationInfo(conversation: typeof conversationList[0]): { modelLabel: string; strategyLabel: string } {
|
||||
function resolveConversationInfo(conversation: (typeof conversationList)[0]): { modelLabel: string; strategyLabel: string } {
|
||||
// Attempt to match conversation model to an instance
|
||||
let matchedInstance: unknown = null;
|
||||
let modelId = conversation.modelId ?? null;
|
||||
@@ -254,7 +253,7 @@ const debugEnabled = $derived(debugMode());
|
||||
{searchQuery ? 'SEARCH RESULTS' : 'CONVERSATIONS'}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
|
||||
{#each filteredConversations as conversation (conversation.id)}
|
||||
{@const info = resolveConversationInfo(conversation)}
|
||||
<div class="px-2">
|
||||
@@ -309,11 +308,9 @@ const debugEnabled = $derived(debugMode());
|
||||
role="button"
|
||||
tabindex="0"
|
||||
onclick={() => handleSelectConversation(conversation.id)}
|
||||
onkeydown={(e) => e.key === 'Enter' && handleSelectConversation(conversation.id)}
|
||||
onkeydown={e => e.key === 'Enter' && handleSelectConversation(conversation.id)}
|
||||
class="group w-full flex items-center justify-between p-2 rounded mb-1 transition-all text-left cursor-pointer
|
||||
{activeId === conversation.id
|
||||
? 'bg-transparent border border-exo-yellow/30'
|
||||
: 'hover:border-exo-yellow/20 border border-transparent'}"
|
||||
{activeId === conversation.id ? 'bg-transparent border border-exo-yellow/30' : 'hover:border-exo-yellow/20 border border-transparent'}"
|
||||
>
|
||||
<div class="flex-1 min-w-0 pr-2">
|
||||
<div class="text-sm truncate {activeId === conversation.id ? 'text-exo-yellow' : 'text-white/90'}">
|
||||
@@ -330,30 +327,36 @@ const debugEnabled = $derived(debugMode());
|
||||
</div>
|
||||
{#if stats}
|
||||
<div class="text-xs text-white/60 font-mono mt-1">
|
||||
{#if stats.ttftMs}<span class="text-white/40">TTFT</span> {stats.ttftMs.toFixed(0)}ms{/if}{#if stats.ttftMs && stats.tps}<span class="text-white/30 mx-1.5">•</span>{/if}{#if stats.tps}{stats.tps.toFixed(1)} <span class="text-white/40">tok/s</span>{/if}
|
||||
{#if stats.ttftMs}<span class="text-white/40">TTFT</span> {stats.ttftMs.toFixed(0)}ms{/if}{#if stats.ttftMs && stats.tps}<span class="text-white/30 mx-1.5">•</span
|
||||
>{/if}{#if stats.tps}{stats.tps.toFixed(1)} <span class="text-white/40">tok/s</span>{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
|
||||
<div class="flex items-center gap-1 opacity-0 group-hover:opacity-100 transition-opacity">
|
||||
<button
|
||||
type="button"
|
||||
onclick={(e) => handleStartEdit(conversation.id, conversation.name, e)}
|
||||
class="p-1 text-exo-light-gray hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
title="Rename"
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
onclick={e => handleStartEdit(conversation.id, conversation.name, e)}
|
||||
class="p-1 text-exo-light-gray hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
title="Rename"
|
||||
>
|
||||
<svg class="w-3 h-3" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" />
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onclick={(e) => handleDeleteClick(conversation.id, e)}
|
||||
class="p-1 text-exo-light-gray hover:text-red-400 transition-colors cursor-pointer"
|
||||
title="Delete"
|
||||
>
|
||||
<button type="button" onclick={e => handleDeleteClick(conversation.id, e)} class="p-1 text-exo-light-gray hover:text-red-400 transition-colors cursor-pointer" title="Delete">
|
||||
<svg class="w-3 h-3" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" />
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
@@ -366,7 +369,12 @@ const debugEnabled = $derived(debugMode());
|
||||
<div class="flex flex-col items-center justify-center h-full p-4 text-center">
|
||||
<div class="w-12 h-12 border border-exo-yellow/20 rounded-full flex items-center justify-center mb-3">
|
||||
<svg class="w-6 h-6 text-exo-yellow/40" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M8 12h.01M12 12h.01M16 12h.01M21 12c0 4.418-4.03 8-9 8a9.863 9.863 0 01-4.255-.949L3 20l1.395-3.72C3.512 15.042 3 13.574 3 12c0-4.418 4.03-8 9-8s9 3.582 9 8z" />
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="1.5"
|
||||
d="M8 12h.01M12 12h.01M16 12h.01M21 12c0 4.418-4.03 8-9 8a9.863 9.863 0 01-4.255-.949L3 20l1.395-3.72C3.512 15.042 3 13.574 3 12c0-4.418 4.03-8 9-8s9 3.582 9 8z"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
<p class="text-xs text-white/70 font-mono tracking-wider uppercase mb-1">
|
||||
@@ -385,46 +393,60 @@ const debugEnabled = $derived(debugMode());
|
||||
<div class="bg-red-500/10 border border-red-500/30 rounded p-2 mb-2">
|
||||
<p class="text-xs text-red-400 text-center mb-2">Delete all {conversationList.length} conversations?</p>
|
||||
<div class="flex gap-2">
|
||||
<button
|
||||
onclick={handleConfirmDeleteAll}
|
||||
class="flex-1 py-1.5 text-xs font-mono tracking-wider uppercase bg-red-500/20 text-red-400 border border-red-500/30 rounded hover:bg-red-500/30 transition-colors cursor-pointer"
|
||||
>
|
||||
DELETE ALL
|
||||
</button>
|
||||
<button
|
||||
onclick={handleCancelDeleteAll}
|
||||
class="flex-1 py-1.5 text-xs font-mono tracking-wider uppercase bg-exo-medium-gray/20 text-exo-light-gray border border-exo-medium-gray/30 rounded hover:bg-exo-medium-gray/30 transition-colors cursor-pointer"
|
||||
>
|
||||
<button
|
||||
onclick={handleConfirmDeleteAll}
|
||||
class="flex-1 py-1.5 text-xs font-mono tracking-wider uppercase bg-red-500/20 text-red-400 border border-red-500/30 rounded hover:bg-red-500/30 transition-colors cursor-pointer"
|
||||
>
|
||||
DELETE ALL
|
||||
</button>
|
||||
<button
|
||||
onclick={handleCancelDeleteAll}
|
||||
class="flex-1 py-1.5 text-xs font-mono tracking-wider uppercase bg-exo-medium-gray/20 text-exo-light-gray border border-exo-medium-gray/30 rounded hover:bg-exo-medium-gray/30 transition-colors cursor-pointer"
|
||||
>
|
||||
CANCEL
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{:else if conversationList.length > 0}
|
||||
<button
|
||||
onclick={handleDeleteAllClick}
|
||||
class="w-full flex items-center justify-center gap-2 py-1.5 text-sm font-mono tracking-wider uppercase text-white/70 hover:text-red-400 hover:bg-red-500/10 border border-transparent hover:border-red-500/20 rounded transition-all cursor-pointer"
|
||||
>
|
||||
<button
|
||||
onclick={handleDeleteAllClick}
|
||||
class="w-full flex items-center justify-center gap-2 py-1.5 text-sm font-mono tracking-wider uppercase text-white/70 hover:text-red-400 hover:bg-red-500/10 border border-transparent hover:border-red-500/20 rounded transition-all cursor-pointer"
|
||||
>
|
||||
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" />
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
|
||||
/>
|
||||
</svg>
|
||||
DELETE ALL CHATS
|
||||
</button>
|
||||
{/if}
|
||||
<div class="flex items-center justify-center gap-3 {conversationList.length > 0 && !showDeleteAllConfirm ? 'mt-2' : ''}">
|
||||
<button
|
||||
type="button"
|
||||
onclick={toggleDebugMode}
|
||||
class="p-1.5 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
|
||||
title="Toggle debug mode"
|
||||
>
|
||||
<svg class="w-4 h-4 {debugEnabled ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="currentColor" viewBox="0 0 24 24">
|
||||
<path d="M19 8h-1.81A6.002 6.002 0 0 0 12 2a6.002 6.002 0 0 0-5.19 3H5a1 1 0 0 0 0 2h1v2H5a1 1 0 0 0 0 2h1v2H5a1 1 0 0 0 0 2h1.81A6.002 6.002 0 0 0 12 22a6.002 6.002 0 0 0 5.19-3H19a1 1 0 0 0 0-2h-1v-2h1a1 1 0 0 0 0-2h-1v-2h1a1 1 0 1 0 0-2Zm-5 10.32V19a1 1 0 1 1-2 0v-.68a3.999 3.999 0 0 1-3-3.83V9.32a3.999 3.999 0 0 1 3-3.83V5a1 1 0 0 1 2 0v.49a3.999 3.999 0 0 1 3 3.83v5.17a3.999 3.999 0 0 1-3 3.83Z"/>
|
||||
</svg>
|
||||
</button>
|
||||
<div class="text-xs text-white/60 font-mono tracking-wider text-center">
|
||||
{conversationList.length} CONVERSATION{conversationList.length !== 1 ? 'S' : ''}
|
||||
<div class="flex items-center justify-center gap-3 {conversationList.length > 0 && !showDeleteAllConfirm ? 'mt-2' : ''}">
|
||||
<button type="button" onclick={toggleDebugMode} class="p-1.5 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer" title="Toggle debug mode">
|
||||
<svg class="w-4 h-4 {debugEnabled ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="currentColor" viewBox="0 0 24 24">
|
||||
<path
|
||||
d="M19 8h-1.81A6.002 6.002 0 0 0 12 2a6.002 6.002 0 0 0-5.19 3H5a1 1 0 0 0 0 2h1v2H5a1 1 0 0 0 0 2h1v2H5a1 1 0 0 0 0 2h1.81A6.002 6.002 0 0 0 12 22a6.002 6.002 0 0 0 5.19-3H19a1 1 0 0 0 0-2h-1v-2h1a1 1 0 0 0 0-2h-1v-2h1a1 1 0 1 0 0-2Zm-5 10.32V19a1 1 0 1 1-2 0v-.68a3.999 3.999 0 0 1-3-3.83V9.32a3.999 3.999 0 0 1 3-3.83V5a1 1 0 0 1 2 0v.49a3.999 3.999 0 0 1 3 3.83v5.17a3.999 3.999 0 0 1-3 3.83Z"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
<div class="text-xs text-white/60 font-mono tracking-wider text-center">
|
||||
{conversationList.length} CONVERSATION{conversationList.length !== 1 ? 'S' : ''}
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onclick={toggleTopologyOnlyMode}
|
||||
class="p-1.5 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
|
||||
title="Toggle topology only mode"
|
||||
>
|
||||
<svg class="w-4 h-4 {topologyOnlyEnabled ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<circle cx="12" cy="5" r="2" fill="currentColor" />
|
||||
<circle cx="5" cy="19" r="2" fill="currentColor" />
|
||||
<circle cx="19" cy="19" r="2" fill="currentColor" />
|
||||
<path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</aside>
|
||||
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
|
||||
export let showHome = true;
|
||||
export let onHome: (() => void) | null = null;
|
||||
export let showSidebarToggle = false;
|
||||
export let sidebarVisible = true;
|
||||
export let onToggleSidebar: (() => void) | null = null;
|
||||
|
||||
function handleHome(): void {
|
||||
if (onHome) {
|
||||
@@ -14,13 +17,38 @@
|
||||
window.location.hash = '/';
|
||||
}
|
||||
}
|
||||
|
||||
function handleToggleSidebar(): void {
|
||||
if (onToggleSidebar) {
|
||||
onToggleSidebar();
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<header class="relative z-20 flex items-center justify-center px-6 pt-8 pb-4 bg-exo-dark-gray">
|
||||
<!-- Left: Sidebar Toggle -->
|
||||
{#if showSidebarToggle}
|
||||
<div class="absolute left-6 top-1/2 -translate-y-1/2">
|
||||
<button
|
||||
onclick={handleToggleSidebar}
|
||||
class="p-2 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
|
||||
title={sidebarVisible ? 'Hide sidebar' : 'Show sidebar'}
|
||||
>
|
||||
<svg class="w-5 h-5 {sidebarVisible ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
{#if sidebarVisible}
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M11 19l-7-7 7-7m8 14l-7-7 7-7" />
|
||||
{:else}
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M13 5l7 7-7 7M5 5l7 7-7 7" />
|
||||
{/if}
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Center: Logo (clickable to go home) -->
|
||||
<button
|
||||
onclick={handleHome}
|
||||
class="hover:opacity-80 transition-opacity {showHome ? 'cursor-pointer' : 'cursor-default'}"
|
||||
class="bg-transparent border-none outline-none focus:outline-none transition-opacity duration-200 hover:opacity-90 {showHome ? 'cursor-pointer' : 'cursor-default'}"
|
||||
title={showHome ? 'Go to home' : ''}
|
||||
disabled={!showHome}
|
||||
>
|
||||
@@ -36,16 +64,17 @@
|
||||
title="Back to topology view"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M3 12l2-2m0 0l7-7 7 7M5 10v10a1 1 0 001 1h3m10-11l2 2m-2-2v10a1 1 0 01-1 1h-3m-6 0a1 1 0 001-1v-4a1 1 0 011-1h2a1 1 0 011 1v4a1 1 0 001 1m-6 0h6" />
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M3 12l2-2m0 0l7-7 7 7M5 10v10a1 1 0 001 1h3m10-11l2 2m-2-2v10a1 1 0 01-1 1h-3m-6 0a1 1 0 001-1v-4a1 1 0 011-1h2a1 1 0 011 1v4a1 1 0 001 1m-6 0h6"
|
||||
/>
|
||||
</svg>
|
||||
Home
|
||||
</button>
|
||||
{/if}
|
||||
<a
|
||||
href="/#/downloads"
|
||||
class="text-sm text-exo-light-gray hover:text-exo-yellow transition-colors tracking-wider uppercase flex items-center gap-2 cursor-pointer"
|
||||
title="View downloads overview"
|
||||
>
|
||||
<a href="/#/downloads" class="text-sm text-exo-light-gray hover:text-exo-yellow transition-colors tracking-wider uppercase flex items-center gap-2 cursor-pointer" title="View downloads overview">
|
||||
<svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M12 3v12" />
|
||||
<path d="M7 12l5 5 5-5" />
|
||||
|
||||
451
dashboard/src/lib/components/MarkdownContent.svelte
Normal file
451
dashboard/src/lib/components/MarkdownContent.svelte
Normal file
@@ -0,0 +1,451 @@
|
||||
<script lang="ts">
|
||||
import { marked } from 'marked';
|
||||
import hljs from 'highlight.js';
|
||||
import katex from 'katex';
|
||||
import 'katex/dist/katex.min.css';
|
||||
import { browser } from '$app/environment';
|
||||
|
||||
interface Props {
|
||||
content: string;
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { content, class: className = '' }: Props = $props();
|
||||
|
||||
let containerRef = $state<HTMLDivElement>();
|
||||
let processedHtml = $state('');
|
||||
|
||||
// Configure marked with syntax highlighting
|
||||
marked.setOptions({
|
||||
gfm: true,
|
||||
breaks: true
|
||||
});
|
||||
|
||||
// Custom renderer for code blocks
|
||||
const renderer = new marked.Renderer();
|
||||
|
||||
renderer.code = function ({ text, lang }: { text: string; lang?: string }) {
|
||||
const language = lang && hljs.getLanguage(lang) ? lang : 'plaintext';
|
||||
const highlighted = hljs.highlight(text, { language }).value;
|
||||
const codeId = `code-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`;
|
||||
|
||||
return `
|
||||
<div class="code-block-wrapper">
|
||||
<div class="code-block-header">
|
||||
<span class="code-language">${language}</span>
|
||||
<button type="button" class="copy-code-btn" data-code="${encodeURIComponent(text)}" title="Copy code">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
|
||||
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
<pre><code class="hljs language-${language}" data-code-id="${codeId}">${highlighted}</code></pre>
|
||||
</div>
|
||||
`;
|
||||
};
|
||||
|
||||
// Inline code
|
||||
renderer.codespan = function ({ text }: { text: string }) {
|
||||
return `<code class="inline-code">${text}</code>`;
|
||||
};
|
||||
|
||||
marked.use({ renderer });
|
||||
|
||||
/**
|
||||
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
|
||||
* Also protect code blocks from LaTeX processing
|
||||
*/
|
||||
function preprocessLaTeX(text: string): string {
|
||||
// Protect code blocks
|
||||
const codeBlocks: string[] = [];
|
||||
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, match => {
|
||||
codeBlocks.push(match);
|
||||
return `<<CODE_${codeBlocks.length - 1}>>`;
|
||||
});
|
||||
|
||||
// Convert \(...\) to $...$
|
||||
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
|
||||
|
||||
// Convert \[...\] to $$...$$
|
||||
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
|
||||
|
||||
// Restore code blocks
|
||||
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
|
||||
|
||||
return processed;
|
||||
}
|
||||
|
||||
/**
|
||||
* Render math expressions with KaTeX after HTML is generated
|
||||
*/
|
||||
function renderMath(html: string): string {
|
||||
// Render display math ($$...$$)
|
||||
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
|
||||
try {
|
||||
return katex.renderToString(math.trim(), {
|
||||
displayMode: true,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
} catch {
|
||||
return `<span class="math-error">$$${math}$$</span>`;
|
||||
}
|
||||
});
|
||||
|
||||
// Render inline math ($...$) but avoid matching currency like $5
|
||||
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
|
||||
// Skip if it looks like currency ($ followed by number)
|
||||
if (/^\d/.test(math.trim())) {
|
||||
return match;
|
||||
}
|
||||
try {
|
||||
return katex.renderToString(math.trim(), {
|
||||
displayMode: false,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
} catch {
|
||||
return `<span class="math-error">$${math}$</span>`;
|
||||
}
|
||||
});
|
||||
|
||||
return html;
|
||||
}
|
||||
|
||||
function processMarkdown(text: string): string {
|
||||
try {
|
||||
// Preprocess LaTeX notation
|
||||
const preprocessed = preprocessLaTeX(text);
|
||||
// Parse markdown
|
||||
let html = marked.parse(preprocessed) as string;
|
||||
// Render math expressions
|
||||
html = renderMath(html);
|
||||
return html;
|
||||
} catch (error) {
|
||||
console.error('Markdown processing error:', error);
|
||||
return text.replace(/\n/g, '<br>');
|
||||
}
|
||||
}
|
||||
|
||||
async function handleCopyClick(event: Event) {
|
||||
const target = event.currentTarget as HTMLButtonElement;
|
||||
const encodedCode = target.getAttribute('data-code');
|
||||
if (!encodedCode) return;
|
||||
|
||||
const code = decodeURIComponent(encodedCode);
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(code);
|
||||
// Show copied feedback
|
||||
const originalHtml = target.innerHTML;
|
||||
target.innerHTML = `
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M20 6L9 17l-5-5"/>
|
||||
</svg>
|
||||
`;
|
||||
target.classList.add('copied');
|
||||
setTimeout(() => {
|
||||
target.innerHTML = originalHtml;
|
||||
target.classList.remove('copied');
|
||||
}, 2000);
|
||||
} catch (error) {
|
||||
console.error('Failed to copy:', error);
|
||||
}
|
||||
}
|
||||
|
||||
function setupCopyButtons() {
|
||||
if (!containerRef || !browser) return;
|
||||
|
||||
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
|
||||
for (const button of buttons) {
|
||||
if (button.dataset.listenerBound !== 'true') {
|
||||
button.dataset.listenerBound = 'true';
|
||||
button.addEventListener('click', handleCopyClick);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (content) {
|
||||
processedHtml = processMarkdown(content);
|
||||
} else {
|
||||
processedHtml = '';
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (containerRef && processedHtml) {
|
||||
setupCopyButtons();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div bind:this={containerRef} class="markdown-content {className}">
|
||||
{@html processedHtml}
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.markdown-content {
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
/* Paragraphs */
|
||||
.markdown-content :global(p) {
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(p:last-child) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
/* Headers */
|
||||
.markdown-content :global(h1) {
|
||||
font-size: 1.5rem;
|
||||
font-weight: 700;
|
||||
margin: 1.5rem 0 0.75rem 0;
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
}
|
||||
|
||||
.markdown-content :global(h2) {
|
||||
font-size: 1.25rem;
|
||||
font-weight: 600;
|
||||
margin: 1.25rem 0 0.5rem 0;
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
}
|
||||
|
||||
.markdown-content :global(h3) {
|
||||
font-size: 1.125rem;
|
||||
font-weight: 600;
|
||||
margin: 1rem 0 0.5rem 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(h4),
|
||||
.markdown-content :global(h5),
|
||||
.markdown-content :global(h6) {
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
margin: 0.75rem 0 0.25rem 0;
|
||||
}
|
||||
|
||||
/* Bold and italic */
|
||||
.markdown-content :global(strong) {
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.markdown-content :global(em) {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
/* Inline code */
|
||||
.markdown-content :global(.inline-code) {
|
||||
background: rgba(255, 215, 0, 0.1);
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
padding: 0.125rem 0.375rem;
|
||||
border-radius: 0.25rem;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.875em;
|
||||
}
|
||||
|
||||
/* Links */
|
||||
.markdown-content :global(a) {
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
text-decoration: underline;
|
||||
text-underline-offset: 2px;
|
||||
}
|
||||
|
||||
.markdown-content :global(a:hover) {
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
/* Lists */
|
||||
.markdown-content :global(ul) {
|
||||
list-style-type: disc;
|
||||
margin-left: 1.5rem;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(ol) {
|
||||
list-style-type: decimal;
|
||||
margin-left: 1.5rem;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(li) {
|
||||
margin-bottom: 0.25rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(li::marker) {
|
||||
color: var(--exo-light-gray, #9ca3af);
|
||||
}
|
||||
|
||||
/* Blockquotes */
|
||||
.markdown-content :global(blockquote) {
|
||||
border-left: 3px solid var(--exo-yellow, #ffd700);
|
||||
padding: 0.5rem 1rem;
|
||||
margin: 1rem 0;
|
||||
background: rgba(255, 215, 0, 0.05);
|
||||
border-radius: 0 0.25rem 0.25rem 0;
|
||||
}
|
||||
|
||||
/* Tables */
|
||||
.markdown-content :global(table) {
|
||||
width: 100%;
|
||||
margin: 1rem 0;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(th) {
|
||||
background: rgba(255, 215, 0, 0.1);
|
||||
border: 1px solid rgba(255, 215, 0, 0.2);
|
||||
padding: 0.5rem;
|
||||
text-align: left;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.markdown-content :global(td) {
|
||||
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||
padding: 0.5rem;
|
||||
}
|
||||
|
||||
/* Horizontal rule */
|
||||
.markdown-content :global(hr) {
|
||||
border: none;
|
||||
border-top: 1px solid rgba(255, 255, 255, 0.1);
|
||||
margin: 1.5rem 0;
|
||||
}
|
||||
|
||||
/* Code block wrapper */
|
||||
.markdown-content :global(.code-block-wrapper) {
|
||||
margin: 1rem 0;
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
border: 1px solid rgba(255, 215, 0, 0.2);
|
||||
background: rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
.markdown-content :global(.code-block-header) {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0.5rem 0.75rem;
|
||||
background: rgba(255, 215, 0, 0.05);
|
||||
border-bottom: 1px solid rgba(255, 215, 0, 0.1);
|
||||
}
|
||||
|
||||
.markdown-content :global(.code-language) {
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
font-size: 0.7rem;
|
||||
font-weight: 500;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.1em;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-code-btn) {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 0.25rem;
|
||||
background: transparent;
|
||||
border: none;
|
||||
color: var(--exo-light-gray, #9ca3af);
|
||||
cursor: pointer;
|
||||
transition: color 0.2s;
|
||||
border-radius: 0.25rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-code-btn:hover) {
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-code-btn.copied) {
|
||||
color: #22c55e;
|
||||
}
|
||||
|
||||
.markdown-content :global(.code-block-wrapper pre) {
|
||||
margin: 0;
|
||||
padding: 1rem;
|
||||
overflow-x: auto;
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.markdown-content :global(.code-block-wrapper code) {
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.8125rem;
|
||||
line-height: 1.5;
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
/* Syntax highlighting - dark theme matching EXO style */
|
||||
.markdown-content :global(.hljs) {
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-keyword),
|
||||
.markdown-content :global(.hljs-selector-tag),
|
||||
.markdown-content :global(.hljs-literal),
|
||||
.markdown-content :global(.hljs-section),
|
||||
.markdown-content :global(.hljs-link) {
|
||||
color: #c084fc;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-string),
|
||||
.markdown-content :global(.hljs-title),
|
||||
.markdown-content :global(.hljs-name),
|
||||
.markdown-content :global(.hljs-type),
|
||||
.markdown-content :global(.hljs-attribute),
|
||||
.markdown-content :global(.hljs-symbol),
|
||||
.markdown-content :global(.hljs-bullet),
|
||||
.markdown-content :global(.hljs-addition),
|
||||
.markdown-content :global(.hljs-variable),
|
||||
.markdown-content :global(.hljs-template-tag),
|
||||
.markdown-content :global(.hljs-template-variable) {
|
||||
color: #fbbf24;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-comment),
|
||||
.markdown-content :global(.hljs-quote),
|
||||
.markdown-content :global(.hljs-deletion),
|
||||
.markdown-content :global(.hljs-meta) {
|
||||
color: #6b7280;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-number),
|
||||
.markdown-content :global(.hljs-regexp),
|
||||
.markdown-content :global(.hljs-literal),
|
||||
.markdown-content :global(.hljs-built_in) {
|
||||
color: #34d399;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-function),
|
||||
.markdown-content :global(.hljs-class .hljs-title) {
|
||||
color: #60a5fa;
|
||||
}
|
||||
|
||||
/* KaTeX math styling */
|
||||
.markdown-content :global(.katex) {
|
||||
font-size: 1.1em;
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex-display) {
|
||||
margin: 1rem 0;
|
||||
overflow-x: auto;
|
||||
overflow-y: hidden;
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex-display > .katex) {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-error) {
|
||||
color: #f87171;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.875em;
|
||||
background: rgba(248, 113, 113, 0.1);
|
||||
padding: 0.125rem 0.25rem;
|
||||
border-radius: 0.25rem;
|
||||
}
|
||||
</style>
|
||||
@@ -1,17 +1,18 @@
|
||||
<script lang="ts">
|
||||
import type { DownloadProgress, NodeInfo, PlacementPreview } from '$lib/stores/app.svelte';
|
||||
import type { DownloadProgress, NodeInfo, PlacementPreview, TopologyEdge } from '$lib/stores/app.svelte';
|
||||
import { debugMode, topologyData } from '$lib/stores/app.svelte';
|
||||
|
||||
interface Props {
|
||||
interface Props {
|
||||
model: { id: string; name?: string; storage_size_megabytes?: number };
|
||||
isLaunching?: boolean;
|
||||
downloadStatus?: {
|
||||
isDownloading: boolean;
|
||||
progress: DownloadProgress | null;
|
||||
perNode?: Array<{
|
||||
nodeId: string;
|
||||
nodeName: string;
|
||||
progress: DownloadProgress;
|
||||
}>;
|
||||
perNode?: Array<{
|
||||
nodeId: string;
|
||||
nodeName: string;
|
||||
progress: DownloadProgress;
|
||||
}>;
|
||||
} | null;
|
||||
nodes?: Record<string, NodeInfo>;
|
||||
sharding?: 'Pipeline' | 'Tensor';
|
||||
@@ -22,38 +23,27 @@ interface Props {
|
||||
modelIdOverride?: string | null;
|
||||
}
|
||||
|
||||
let {
|
||||
model,
|
||||
isLaunching = false,
|
||||
downloadStatus = null,
|
||||
nodes = {},
|
||||
sharding = 'Pipeline',
|
||||
runtime = 'MlxRing',
|
||||
onLaunch,
|
||||
tags = [],
|
||||
apiPreview = null,
|
||||
modelIdOverride = null
|
||||
}: Props = $props();
|
||||
let { model, isLaunching = false, downloadStatus = null, nodes = {}, sharding = 'Pipeline', runtime = 'MlxRing', onLaunch, tags = [], apiPreview = null, modelIdOverride = null }: Props = $props();
|
||||
|
||||
// Estimate memory requirements from model name
|
||||
// Uses regex with word boundaries to avoid false matches like '4bit' matching '4b'
|
||||
function estimateMemoryGB(modelId: string, modelName?: string): number {
|
||||
// Check both ID and name for quantization info
|
||||
const combined = `${modelId} ${modelName || ''}`.toLowerCase();
|
||||
|
||||
|
||||
// Detect quantization level - affects memory by roughly 2x between levels
|
||||
const is4bit = combined.includes('4bit') || combined.includes('4-bit') || combined.includes(':4bit');
|
||||
const is8bit = combined.includes('8bit') || combined.includes('8-bit') || combined.includes(':8bit');
|
||||
// 4-bit = 0.5 bytes/param, 8-bit = 1 byte/param, fp16 = 2 bytes/param
|
||||
const quantMultiplier = is4bit ? 0.5 : is8bit ? 1 : 2;
|
||||
const id = modelId.toLowerCase();
|
||||
|
||||
|
||||
// Known large models that don't follow the standard naming pattern
|
||||
// DeepSeek V3 has 685B parameters
|
||||
if (id.includes('deepseek-v3')) {
|
||||
return Math.round(685 * quantMultiplier);
|
||||
}
|
||||
// DeepSeek V2 has 236B parameters
|
||||
// DeepSeek V2 has 236B parameters
|
||||
if (id.includes('deepseek-v2')) {
|
||||
return Math.round(236 * quantMultiplier);
|
||||
}
|
||||
@@ -61,14 +51,14 @@ interface Props {
|
||||
if (id.includes('llama-4')) {
|
||||
return Math.round(400 * quantMultiplier);
|
||||
}
|
||||
|
||||
|
||||
// Match parameter counts with word boundaries (e.g., "70b" but not "4bit")
|
||||
const paramMatch = id.match(/(\d+(?:\.\d+)?)\s*b(?![a-z])/i);
|
||||
if (paramMatch) {
|
||||
const params = parseFloat(paramMatch[1]);
|
||||
return Math.max(4, Math.round(params * quantMultiplier));
|
||||
}
|
||||
|
||||
|
||||
// Fallback patterns for explicit size markers (assume fp16 baseline, adjust for quant)
|
||||
if (id.includes('405b') || id.includes('400b')) return Math.round(405 * quantMultiplier);
|
||||
if (id.includes('180b')) return Math.round(180 * quantMultiplier);
|
||||
@@ -82,7 +72,7 @@ interface Props {
|
||||
if (id.includes('8b') || id.includes('9b') || id.includes('7b')) return Math.round(8 * quantMultiplier);
|
||||
if (id.includes('3b') || id.includes('3.8b')) return Math.round(4 * quantMultiplier);
|
||||
if (id.includes('2b') || id.includes('1b') || id.includes('1.5b') || id.includes('0.5b')) return Math.round(2 * quantMultiplier);
|
||||
|
||||
|
||||
return 16; // Default fallback
|
||||
}
|
||||
|
||||
@@ -113,25 +103,21 @@ interface Props {
|
||||
const isDownloading = $derived(downloadStatus?.isDownloading ?? false);
|
||||
const progress = $derived(downloadStatus?.progress);
|
||||
const percentage = $derived(progress?.percentage ?? 0);
|
||||
let expandedNodes = $state<Set<string>>(new Set());
|
||||
let expandedNodes = $state<Set<string>>(new Set());
|
||||
|
||||
function toggleNodeDetails(nodeId: string): void {
|
||||
const next = new Set(expandedNodes);
|
||||
if (next.has(nodeId)) {
|
||||
next.delete(nodeId);
|
||||
} else {
|
||||
next.add(nodeId);
|
||||
function toggleNodeDetails(nodeId: string): void {
|
||||
const next = new Set(expandedNodes);
|
||||
if (next.has(nodeId)) {
|
||||
next.delete(nodeId);
|
||||
} else {
|
||||
next.add(nodeId);
|
||||
}
|
||||
expandedNodes = next;
|
||||
}
|
||||
expandedNodes = next;
|
||||
}
|
||||
|
||||
|
||||
// Use actual storage_size_megabytes from API if available, otherwise fall back to estimate
|
||||
const estimatedMemory = $derived(
|
||||
model.storage_size_megabytes
|
||||
? Math.round(model.storage_size_megabytes / 1024)
|
||||
: estimateMemoryGB(model.id, model.name)
|
||||
);
|
||||
|
||||
const estimatedMemory = $derived(model.storage_size_megabytes ? Math.round(model.storage_size_megabytes / 1024) : estimateMemoryGB(model.id, model.name));
|
||||
|
||||
function getDeviceType(name: string): 'macbook' | 'studio' | 'mini' | 'unknown' {
|
||||
const lower = name.toLowerCase();
|
||||
if (lower.includes('macbook')) return 'macbook';
|
||||
@@ -139,7 +125,7 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
if (lower.includes('mini')) return 'mini';
|
||||
return 'unknown';
|
||||
}
|
||||
|
||||
|
||||
const clampPercent = (value: number): number => Math.min(100, Math.max(0, value));
|
||||
const huggingFaceModelId = $derived(modelIdOverride ?? model.id);
|
||||
|
||||
@@ -148,7 +134,7 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
// topology payload is missing them. Topology order is preserved exactly so
|
||||
// that the mini preview matches the main TopologyGraph layout.
|
||||
const nodeList = $derived(() => {
|
||||
const nodesFromTopology = Object.keys(nodes).map((id) => {
|
||||
const nodesFromTopology = Object.keys(nodes).map(id => {
|
||||
const info = nodes[id];
|
||||
const totalBytes = info.macmon_info?.memory?.ram_total ?? info.system_info?.memory ?? 0;
|
||||
const usedBytes = info.macmon_info?.memory?.ram_usage ?? 0;
|
||||
@@ -168,10 +154,10 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
if (previewIds.length === 0) return nodesFromTopology;
|
||||
|
||||
// Append any preview-only nodes (not in topology) at the end
|
||||
const topologyIds = new Set(nodesFromTopology.map((n) => n.id));
|
||||
const topologyIds = new Set(nodesFromTopology.map(n => n.id));
|
||||
const extraPreviewNodes = previewIds
|
||||
.filter((id) => !topologyIds.has(id))
|
||||
.map((id) => {
|
||||
.filter(id => !topologyIds.has(id))
|
||||
.map(id => {
|
||||
const deltaBytes = previewEntries?.[id] ?? 0;
|
||||
const deltaGB = deltaBytes / (1024 * 1024 * 1024);
|
||||
const totalGB = Math.max(deltaGB * 1.2, 1);
|
||||
@@ -191,13 +177,13 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
|
||||
return [...nodesFromTopology, ...extraPreviewNodes];
|
||||
});
|
||||
|
||||
|
||||
// Calculate placement preview with all SVG metrics pre-computed
|
||||
// Uses API preview data when available, falls back to local estimation
|
||||
const placementPreview = $derived(() => {
|
||||
const nodeArray = nodeList();
|
||||
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, error: null };
|
||||
|
||||
|
||||
const numNodes = nodeArray.length;
|
||||
const iconSize = numNodes === 1 ? 50 : 36;
|
||||
const topoWidth = 260;
|
||||
@@ -205,20 +191,16 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
const centerX = topoWidth / 2;
|
||||
const centerY = topoHeight / 2;
|
||||
const radius = numNodes === 1 ? 0 : numNodes === 2 ? 45 : Math.min(topoWidth, topoHeight) * 0.32;
|
||||
|
||||
// Use API preview data if available
|
||||
|
||||
// Only use API preview data - no local estimation
|
||||
const hasApiPreview = apiPreview !== null && apiPreview.error === null && apiPreview.memory_delta_by_node !== null;
|
||||
const canFit = hasApiPreview ? true : (() => {
|
||||
const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
|
||||
return totalAvailable >= estimatedMemory;
|
||||
})();
|
||||
const error = apiPreview?.error ?? null;
|
||||
|
||||
let placementNodes: Array<{
|
||||
|
||||
let placementNodes: Array<{
|
||||
id: string;
|
||||
deviceName: string;
|
||||
deviceType: 'macbook' | 'studio' | 'mini' | 'unknown';
|
||||
totalGB: number;
|
||||
totalGB: number;
|
||||
currentUsedGB: number;
|
||||
modelUsageGB: number;
|
||||
currentPercent: number;
|
||||
@@ -231,136 +213,137 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
currentFillHeight: number;
|
||||
modelFillHeight: number;
|
||||
}> = [];
|
||||
|
||||
if (hasApiPreview && apiPreview.memory_delta_by_node) {
|
||||
// Use API placement data
|
||||
const memoryDelta = apiPreview.memory_delta_by_node;
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const deltaBytes = memoryDelta[n.id] ?? 0;
|
||||
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
|
||||
const isUsed = deltaBytes > 0;
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
} else if (apiPreview?.error) {
|
||||
// API returned an error - model can't fit, show all nodes as unused
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB: 0,
|
||||
currentPercent,
|
||||
newPercent: currentPercent,
|
||||
isUsed: false,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: 0
|
||||
};
|
||||
});
|
||||
} else {
|
||||
// Fallback: local estimation based on sharding strategy
|
||||
const memoryNeeded = estimatedMemory;
|
||||
|
||||
if (sharding === 'Pipeline') {
|
||||
const memoryPerNode = memoryNeeded / numNodes;
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + memoryPerNode) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB: memoryPerNode,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed: true,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
} else {
|
||||
let remaining = memoryNeeded;
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const allocated = Math.min(remaining, n.availableGB);
|
||||
remaining -= allocated;
|
||||
const isUsed = allocated > 0;
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + allocated) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB: allocated,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Use API placement data directly
|
||||
const memoryDelta = apiPreview?.memory_delta_by_node ?? {};
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const deltaBytes = memoryDelta[n.id] ?? 0;
|
||||
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
|
||||
const isUsed = deltaBytes > 0;
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
|
||||
const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
|
||||
return { nodes: placementNodes, canFit: hasApiPreview || canFit, totalAvailable, topoWidth, topoHeight, error };
|
||||
return { nodes: placementNodes, canFit: hasApiPreview, totalAvailable, topoWidth, topoHeight, error };
|
||||
});
|
||||
|
||||
|
||||
const canFit = $derived(apiPreview ? apiPreview.error === null : placementPreview().canFit);
|
||||
const placementError = $derived(apiPreview?.error ?? null);
|
||||
const nodeCount = $derived(nodeList().length);
|
||||
const filterId = $derived(model.id.replace(/[^a-zA-Z0-9]/g, ''));
|
||||
|
||||
// Debug mode state
|
||||
const isDebugMode = $derived(debugMode());
|
||||
const topology = $derived(topologyData());
|
||||
const isRdma = $derived(runtime === 'MlxIbv' || runtime === 'MlxJaccl');
|
||||
|
||||
// Get interface name for an IP from node data
|
||||
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
|
||||
if (!ip || !topology?.nodes) return null;
|
||||
|
||||
// Strip port if present
|
||||
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
|
||||
|
||||
// Check specified node first
|
||||
const node = topology.nodes[nodeId];
|
||||
if (node) {
|
||||
const match = node.network_interfaces?.find(iface => (iface.addresses || []).some(addr => addr === cleanIp || addr === ip));
|
||||
if (match?.name) return match.name;
|
||||
|
||||
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
|
||||
if (mapped) return mapped;
|
||||
}
|
||||
|
||||
// Fallback: check all nodes
|
||||
for (const [, otherNode] of Object.entries(topology.nodes)) {
|
||||
if (!otherNode) continue;
|
||||
const match = otherNode.network_interfaces?.find(iface => (iface.addresses || []).some(addr => addr === cleanIp || addr === ip));
|
||||
if (match?.name) return match.name;
|
||||
|
||||
const mapped = otherNode.ip_to_interface?.[cleanIp] || otherNode.ip_to_interface?.[ip];
|
||||
if (mapped) return mapped;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
// Get directional arrow based on node positions
|
||||
function getArrow(fromNode: { x: number; y: number }, toNode: { x: number; y: number }): string {
|
||||
const dx = toNode.x - fromNode.x;
|
||||
const dy = toNode.y - fromNode.y;
|
||||
const absX = Math.abs(dx);
|
||||
const absY = Math.abs(dy);
|
||||
|
||||
if (absX > absY * 2) {
|
||||
return dx > 0 ? '→' : '←';
|
||||
} else if (absY > absX * 2) {
|
||||
return dy > 0 ? '↓' : '↑';
|
||||
} else {
|
||||
if (dx > 0 && dy > 0) return '↘';
|
||||
if (dx > 0 && dy < 0) return '↗';
|
||||
if (dx < 0 && dy > 0) return '↙';
|
||||
return '↖';
|
||||
}
|
||||
}
|
||||
|
||||
// Get connection info for edges between two nodes
|
||||
// Returns exactly one connection per direction (A→B and B→A), preferring non-loopback
|
||||
function getConnectionInfo(nodeId1: string, nodeId2: string): Array<{ ip: string; iface: string | null; from: string; to: string }> {
|
||||
if (!topology?.edges) return [];
|
||||
|
||||
// Collect candidates for each direction
|
||||
const aToBCandidates: Array<{ ip: string; iface: string | null }> = [];
|
||||
const bToACandidates: Array<{ ip: string; iface: string | null }> = [];
|
||||
|
||||
for (const edge of topology.edges) {
|
||||
const ip = edge.sendBackIp || '?';
|
||||
const iface = edge.sendBackInterface || getInterfaceForIp(edge.source, ip);
|
||||
|
||||
if (edge.source === nodeId1 && edge.target === nodeId2) {
|
||||
aToBCandidates.push({ ip, iface });
|
||||
} else if (edge.source === nodeId2 && edge.target === nodeId1) {
|
||||
bToACandidates.push({ ip, iface });
|
||||
}
|
||||
}
|
||||
|
||||
// Pick best (prefer non-loopback)
|
||||
const pickBest = (candidates: Array<{ ip: string; iface: string | null }>) => {
|
||||
if (candidates.length === 0) return null;
|
||||
return candidates.find(c => !c.ip.startsWith('127.')) || candidates[0];
|
||||
};
|
||||
|
||||
const result: Array<{ ip: string; iface: string | null; from: string; to: string }> = [];
|
||||
|
||||
const bestAtoB = pickBest(aToBCandidates);
|
||||
if (bestAtoB) result.push({ ...bestAtoB, from: nodeId1, to: nodeId2 });
|
||||
|
||||
const bestBtoA = pickBest(bToACandidates);
|
||||
if (bestBtoA) result.push({ ...bestBtoA, from: nodeId2, to: nodeId1 });
|
||||
|
||||
return result;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="relative group">
|
||||
@@ -369,34 +352,42 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
<div class="absolute -top-px -right-px w-2 h-2 border-r border-t {canFit ? 'border-exo-yellow/30 group-hover:border-exo-yellow/60' : 'border-red-500/30'} transition-colors"></div>
|
||||
<div class="absolute -bottom-px -left-px w-2 h-2 border-l border-b {canFit ? 'border-exo-yellow/30 group-hover:border-exo-yellow/60' : 'border-red-500/30'} transition-colors"></div>
|
||||
<div class="absolute -bottom-px -right-px w-2 h-2 border-r border-b {canFit ? 'border-exo-yellow/30 group-hover:border-exo-yellow/60' : 'border-red-500/30'} transition-colors"></div>
|
||||
|
||||
<div class="bg-exo-dark-gray/60 border {canFit ? 'border-exo-yellow/20 group-hover:border-exo-yellow/40' : 'border-red-500/20'} p-3 transition-all duration-200 group-hover:shadow-[0_0_15px_rgba(255,215,0,0.1)]">
|
||||
|
||||
<div
|
||||
class="bg-exo-dark-gray/60 border {canFit
|
||||
? 'border-exo-yellow/20 group-hover:border-exo-yellow/40'
|
||||
: 'border-red-500/20'} p-3 transition-all duration-200 group-hover:shadow-[0_0_15px_rgba(255,215,0,0.1)]"
|
||||
>
|
||||
<!-- Model Name & Memory Required -->
|
||||
<div class="flex items-start justify-between gap-2 mb-2">
|
||||
<div class="flex-1 min-w-0">
|
||||
<div class="flex items-center gap-2">
|
||||
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate" title={model.name || model.id}>
|
||||
{model.name || model.id}
|
||||
</div>
|
||||
{#if huggingFaceModelId}
|
||||
<a
|
||||
class="shrink-0 text-white/60 hover:text-exo-yellow transition-colors"
|
||||
href={`https://huggingface.co/${huggingFaceModelId}`}
|
||||
target="_blank"
|
||||
rel="noreferrer noopener"
|
||||
aria-label="View model on Hugging Face"
|
||||
>
|
||||
<svg class="w-3.5 h-3.5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M14 3h7v7"/>
|
||||
<path d="M10 14l11-11"/>
|
||||
<path d="M21 14v6a1 1 0 0 1-1 1h-16a1 1 0 0 1-1-1v-16a1 1 0 0 1 1-1h6"/>
|
||||
</svg>
|
||||
</a>
|
||||
{/if}
|
||||
<div class="flex items-center gap-2">
|
||||
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate" title={model.name || model.id}>
|
||||
{model.name || model.id}
|
||||
</div>
|
||||
{#if huggingFaceModelId}
|
||||
<a
|
||||
class="shrink-0 text-white/60 hover:text-exo-yellow transition-colors"
|
||||
href={`https://huggingface.co/${huggingFaceModelId}`}
|
||||
target="_blank"
|
||||
rel="noreferrer noopener"
|
||||
aria-label="View model on Hugging Face"
|
||||
>
|
||||
<svg class="w-3.5 h-3.5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M14 3h7v7" />
|
||||
<path d="M10 14l11-11" />
|
||||
<path d="M21 14v6a1 1 0 0 1-1 1h-16a1 1 0 0 1-1-1v-16a1 1 0 0 1 1-1h6" />
|
||||
</svg>
|
||||
</a>
|
||||
{/if}
|
||||
{#if tags.length > 0}
|
||||
<div class="flex gap-1 flex-shrink-0">
|
||||
{#each tags as tag}
|
||||
<span class="px-1.5 py-0.5 text-xs font-mono tracking-wider uppercase rounded {tag === 'FASTEST' ? 'bg-green-500/20 text-green-400 border border-green-500/30' : 'bg-purple-500/20 text-purple-400 border border-purple-500/30'}">
|
||||
<span
|
||||
class="px-1.5 py-0.5 text-xs font-mono tracking-wider uppercase rounded {tag === 'FASTEST'
|
||||
? 'bg-green-500/20 text-green-400 border border-green-500/30'
|
||||
: 'bg-purple-500/20 text-purple-400 border border-purple-500/30'}"
|
||||
>
|
||||
{tag}
|
||||
</span>
|
||||
{/each}
|
||||
@@ -415,48 +406,74 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Configuration Badge -->
|
||||
<div class="flex items-center gap-1.5 mb-2">
|
||||
<span class="px-1.5 py-0.5 text-xs font-mono tracking-wider uppercase bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/40">
|
||||
{sharding}
|
||||
</span>
|
||||
<span class="px-1.5 py-0.5 text-xs font-mono tracking-wider uppercase bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/40">
|
||||
{runtime === 'MlxRing' ? 'MLX Ring' : runtime === 'MlxIbv' || runtime === 'MlxJaccl' ? 'MLX RDMA' : runtime}
|
||||
</span>
|
||||
<span class="px-1.5 py-0.5 text-xs font-mono tracking-wider uppercase bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/40">
|
||||
{runtime === 'MlxRing' ? 'MLX Ring' : runtime === 'MlxIbv' || runtime === 'MlxJaccl' ? 'MLX RDMA' : runtime}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Mini Topology Preview -->
|
||||
{#if placementPreview().nodes.length > 0}
|
||||
{@const preview = placementPreview()}
|
||||
<div class="mb-3 bg-exo-black/60 rounded border border-exo-medium-gray/20 p-2 relative overflow-hidden">
|
||||
<!-- Scanline effect -->
|
||||
<div class="absolute inset-0 bg-[repeating-linear-gradient(0deg,transparent,transparent_2px,rgba(255,215,0,0.02)_2px,rgba(255,215,0,0.02)_4px)] pointer-events-none"></div>
|
||||
|
||||
|
||||
<svg width="100%" height={preview.topoHeight} viewBox="0 0 {preview.topoWidth} {preview.topoHeight}" class="overflow-visible">
|
||||
<defs>
|
||||
<!-- Glow filter for active nodes -->
|
||||
<filter id="nodeGlow-{filterId}" x="-50%" y="-50%" width="200%" height="200%">
|
||||
<feGaussianBlur stdDeviation="2" result="blur"/>
|
||||
<feGaussianBlur stdDeviation="2" result="blur" />
|
||||
<feMerge>
|
||||
<feMergeNode in="blur"/>
|
||||
<feMergeNode in="SourceGraphic"/>
|
||||
<feMergeNode in="blur" />
|
||||
<feMergeNode in="SourceGraphic" />
|
||||
</feMerge>
|
||||
</filter>
|
||||
|
||||
|
||||
<!-- Strong glow for new memory -->
|
||||
<filter id="memGlow-{filterId}" x="-100%" y="-100%" width="300%" height="300%">
|
||||
<feGaussianBlur stdDeviation="3" result="blur"/>
|
||||
<feComposite in="SourceGraphic" in2="blur" operator="over"/>
|
||||
<feGaussianBlur stdDeviation="3" result="blur" />
|
||||
<feComposite in="SourceGraphic" in2="blur" operator="over" />
|
||||
</filter>
|
||||
</defs>
|
||||
|
||||
|
||||
<!-- Connection lines between nodes (if multiple) -->
|
||||
{#if preview.nodes.length > 1}
|
||||
{@const usedNodes = preview.nodes.filter(n => n.isUsed)}
|
||||
{@const nodePositions = Object.fromEntries(preview.nodes.map(n => [n.id, { x: n.x, y: n.y }]))}
|
||||
{@const allConnections =
|
||||
isDebugMode && usedNodes.length > 1
|
||||
? (() => {
|
||||
const conns: Array<{ ip: string; iface: string | null; from: string; to: string; midX: number; midY: number; arrow: string }> = [];
|
||||
for (let i = 0; i < usedNodes.length; i++) {
|
||||
for (let j = i + 1; j < usedNodes.length; j++) {
|
||||
const n1 = usedNodes[i];
|
||||
const n2 = usedNodes[j];
|
||||
const midX = (n1.x + n2.x) / 2;
|
||||
const midY = (n1.y + n2.y) / 2;
|
||||
for (const c of getConnectionInfo(n1.id, n2.id)) {
|
||||
const fromPos = nodePositions[c.from];
|
||||
const toPos = nodePositions[c.to];
|
||||
const arrow = fromPos && toPos ? getArrow(fromPos, toPos) : '→';
|
||||
conns.push({ ...c, midX, midY, arrow });
|
||||
}
|
||||
}
|
||||
}
|
||||
return conns;
|
||||
})()
|
||||
: []}
|
||||
{#each preview.nodes as node, i}
|
||||
{#each preview.nodes.slice(i + 1) as node2}
|
||||
<line
|
||||
x1={node.x} y1={node.y} x2={node2.x} y2={node2.y}
|
||||
<line
|
||||
x1={node.x}
|
||||
y1={node.y}
|
||||
x2={node2.x}
|
||||
y2={node2.y}
|
||||
stroke={node.isUsed && node2.isUsed ? '#FFD700' : '#374151'}
|
||||
stroke-width="1"
|
||||
stroke-dasharray={node.isUsed && node2.isUsed ? '4,2' : '2,4'}
|
||||
@@ -464,47 +481,99 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
/>
|
||||
{/each}
|
||||
{/each}
|
||||
<!-- Debug: Show connection IPs/interfaces in corners -->
|
||||
{#if isDebugMode && allConnections.length > 0}
|
||||
{@const centerX = preview.topoWidth / 2}
|
||||
{@const centerY = preview.topoHeight / 2}
|
||||
{@const quadrants = {
|
||||
topLeft: allConnections.filter(c => c.midX < centerX && c.midY < centerY),
|
||||
topRight: allConnections.filter(c => c.midX >= centerX && c.midY < centerY),
|
||||
bottomLeft: allConnections.filter(c => c.midX < centerX && c.midY >= centerY),
|
||||
bottomRight: allConnections.filter(c => c.midX >= centerX && c.midY >= centerY)
|
||||
}}
|
||||
{@const padding = 4}
|
||||
{@const lineHeight = 8}
|
||||
<!-- Top Left -->
|
||||
{#each quadrants.topLeft as conn, idx}
|
||||
<text
|
||||
x={padding}
|
||||
y={padding + idx * lineHeight}
|
||||
text-anchor="start"
|
||||
dominant-baseline="hanging"
|
||||
font-size="6"
|
||||
font-family="SF Mono, Monaco, monospace"
|
||||
fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}
|
||||
>
|
||||
{conn.arrow}
|
||||
{isRdma ? conn.iface || '?' : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
<!-- Top Right -->
|
||||
{#each quadrants.topRight as conn, idx}
|
||||
<text
|
||||
x={preview.topoWidth - padding}
|
||||
y={padding + idx * lineHeight}
|
||||
text-anchor="end"
|
||||
dominant-baseline="hanging"
|
||||
font-size="6"
|
||||
font-family="SF Mono, Monaco, monospace"
|
||||
fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}
|
||||
>
|
||||
{conn.arrow}
|
||||
{isRdma ? conn.iface || '?' : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
<!-- Bottom Left -->
|
||||
{#each quadrants.bottomLeft as conn, idx}
|
||||
<text
|
||||
x={padding}
|
||||
y={preview.topoHeight - padding - (quadrants.bottomLeft.length - 1 - idx) * lineHeight}
|
||||
text-anchor="start"
|
||||
dominant-baseline="auto"
|
||||
font-size="6"
|
||||
font-family="SF Mono, Monaco, monospace"
|
||||
fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}
|
||||
>
|
||||
{conn.arrow}
|
||||
{isRdma ? conn.iface || '?' : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
<!-- Bottom Right -->
|
||||
{#each quadrants.bottomRight as conn, idx}
|
||||
<text
|
||||
x={preview.topoWidth - padding}
|
||||
y={preview.topoHeight - padding - (quadrants.bottomRight.length - 1 - idx) * lineHeight}
|
||||
text-anchor="end"
|
||||
dominant-baseline="auto"
|
||||
font-size="6"
|
||||
font-family="SF Mono, Monaco, monospace"
|
||||
fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}
|
||||
>
|
||||
{conn.arrow}
|
||||
{isRdma ? conn.iface || '?' : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
{/if}
|
||||
{/if}
|
||||
|
||||
|
||||
{#each preview.nodes as node}
|
||||
<g
|
||||
transform="translate({node.x}, {node.y})"
|
||||
opacity={node.isUsed ? 1 : 0.25}
|
||||
filter={node.isUsed ? `url(#nodeGlow-${filterId})` : 'none'}
|
||||
>
|
||||
<g transform="translate({node.x}, {node.y})" opacity={node.isUsed ? 1 : 0.25} filter={node.isUsed ? `url(#nodeGlow-${filterId})` : 'none'}>
|
||||
<!-- Device icon based on type -->
|
||||
{#if node.deviceType === 'macbook'}
|
||||
<!-- MacBook Pro icon with memory fill -->
|
||||
<g transform="translate({-node.iconSize/2}, {-node.iconSize/2})">
|
||||
<g transform="translate({-node.iconSize / 2}, {-node.iconSize / 2})">
|
||||
<!-- Screen bezel -->
|
||||
<rect
|
||||
x="2" y="0"
|
||||
width={node.iconSize - 4} height={node.iconSize * 0.65}
|
||||
rx="2"
|
||||
fill="none"
|
||||
stroke={node.isUsed ? '#FFD700' : '#4B5563'}
|
||||
stroke-width="1.5"
|
||||
/>
|
||||
<rect x="2" y="0" width={node.iconSize - 4} height={node.iconSize * 0.65} rx="2" fill="none" stroke={node.isUsed ? '#FFD700' : '#4B5563'} stroke-width="1.5" />
|
||||
<!-- Screen area (memory fill container) -->
|
||||
<rect
|
||||
x="4" y="2"
|
||||
width={node.iconSize - 8} height={node.screenHeight}
|
||||
fill="#0a0a0a"
|
||||
/>
|
||||
<rect x="4" y="2" width={node.iconSize - 8} height={node.screenHeight} fill="#0a0a0a" />
|
||||
<!-- Current memory fill (gray) -->
|
||||
<rect
|
||||
x="4"
|
||||
y={2 + node.screenHeight - node.currentFillHeight}
|
||||
width={node.iconSize - 8}
|
||||
height={node.currentFillHeight}
|
||||
fill="#374151"
|
||||
/>
|
||||
<rect x="4" y={2 + node.screenHeight - node.currentFillHeight} width={node.iconSize - 8} height={node.currentFillHeight} fill="#374151" />
|
||||
<!-- New model memory fill (glowing yellow) -->
|
||||
{#if node.modelUsageGB > 0 && node.isUsed}
|
||||
<rect
|
||||
x="4"
|
||||
<rect
|
||||
x="4"
|
||||
y={2 + node.screenHeight - node.currentFillHeight - node.modelFillHeight}
|
||||
width={node.iconSize - 8}
|
||||
width={node.iconSize - 8}
|
||||
height={node.modelFillHeight}
|
||||
fill="#FFD700"
|
||||
filter="url(#memGlow-{filterId})"
|
||||
@@ -512,7 +581,7 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
/>
|
||||
{/if}
|
||||
<!-- Base/keyboard -->
|
||||
<path
|
||||
<path
|
||||
d="M 0 {node.iconSize * 0.68} L {node.iconSize} {node.iconSize * 0.68} L {node.iconSize - 2} {node.iconSize * 0.78} L 2 {node.iconSize * 0.78} Z"
|
||||
fill="none"
|
||||
stroke={node.isUsed ? '#FFD700' : '#4B5563'}
|
||||
@@ -521,35 +590,18 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
</g>
|
||||
{:else if node.deviceType === 'studio'}
|
||||
<!-- Mac Studio icon -->
|
||||
<g transform="translate({-node.iconSize/2}, {-node.iconSize/2})">
|
||||
<rect
|
||||
x="2" y="2"
|
||||
width={node.iconSize - 4} height={node.iconSize - 4}
|
||||
rx="4"
|
||||
fill="none"
|
||||
stroke={node.isUsed ? '#FFD700' : '#4B5563'}
|
||||
stroke-width="1.5"
|
||||
/>
|
||||
<g transform="translate({-node.iconSize / 2}, {-node.iconSize / 2})">
|
||||
<rect x="2" y="2" width={node.iconSize - 4} height={node.iconSize - 4} rx="4" fill="none" stroke={node.isUsed ? '#FFD700' : '#4B5563'} stroke-width="1.5" />
|
||||
<!-- Memory fill background -->
|
||||
<rect
|
||||
x="4" y="4"
|
||||
width={node.iconSize - 8} height={node.iconSize - 8}
|
||||
fill="#0a0a0a"
|
||||
/>
|
||||
<rect x="4" y="4" width={node.iconSize - 8} height={node.iconSize - 8} fill="#0a0a0a" />
|
||||
<!-- Current memory fill -->
|
||||
<rect
|
||||
x="4"
|
||||
y={4 + (node.iconSize - 8) * (1 - node.currentPercent / 100)}
|
||||
width={node.iconSize - 8}
|
||||
height={(node.iconSize - 8) * (node.currentPercent / 100)}
|
||||
fill="#374151"
|
||||
/>
|
||||
<rect x="4" y={4 + (node.iconSize - 8) * (1 - node.currentPercent / 100)} width={node.iconSize - 8} height={(node.iconSize - 8) * (node.currentPercent / 100)} fill="#374151" />
|
||||
<!-- New model memory fill -->
|
||||
{#if node.modelUsageGB > 0 && node.isUsed}
|
||||
<rect
|
||||
x="4"
|
||||
<rect
|
||||
x="4"
|
||||
y={4 + (node.iconSize - 8) * (1 - node.newPercent / 100)}
|
||||
width={node.iconSize - 8}
|
||||
width={node.iconSize - 8}
|
||||
height={(node.iconSize - 8) * ((node.newPercent - node.currentPercent) / 100)}
|
||||
fill="#FFD700"
|
||||
filter="url(#memGlow-{filterId})"
|
||||
@@ -559,36 +611,25 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
</g>
|
||||
{:else if node.deviceType === 'mini'}
|
||||
<!-- Mac Mini icon -->
|
||||
<g transform="translate({-node.iconSize/2}, {-node.iconSize/2})">
|
||||
<rect
|
||||
x="2" y={node.iconSize * 0.3}
|
||||
width={node.iconSize - 4} height={node.iconSize * 0.4}
|
||||
rx="3"
|
||||
fill="none"
|
||||
stroke={node.isUsed ? '#FFD700' : '#4B5563'}
|
||||
stroke-width="1.5"
|
||||
/>
|
||||
<g transform="translate({-node.iconSize / 2}, {-node.iconSize / 2})">
|
||||
<rect x="2" y={node.iconSize * 0.3} width={node.iconSize - 4} height={node.iconSize * 0.4} rx="3" fill="none" stroke={node.isUsed ? '#FFD700' : '#4B5563'} stroke-width="1.5" />
|
||||
<!-- Memory fill background -->
|
||||
<rect
|
||||
x="4" y={node.iconSize * 0.32}
|
||||
width={node.iconSize - 8} height={node.iconSize * 0.36}
|
||||
fill="#0a0a0a"
|
||||
/>
|
||||
<rect x="4" y={node.iconSize * 0.32} width={node.iconSize - 8} height={node.iconSize * 0.36} fill="#0a0a0a" />
|
||||
<!-- Current memory fill -->
|
||||
<rect
|
||||
x="4"
|
||||
y={node.iconSize * 0.32 + (node.iconSize * 0.36) * (1 - node.currentPercent / 100)}
|
||||
width={node.iconSize - 8}
|
||||
height={(node.iconSize * 0.36) * (node.currentPercent / 100)}
|
||||
<rect
|
||||
x="4"
|
||||
y={node.iconSize * 0.32 + node.iconSize * 0.36 * (1 - node.currentPercent / 100)}
|
||||
width={node.iconSize - 8}
|
||||
height={node.iconSize * 0.36 * (node.currentPercent / 100)}
|
||||
fill="#374151"
|
||||
/>
|
||||
<!-- New model memory fill -->
|
||||
{#if node.modelUsageGB > 0 && node.isUsed}
|
||||
<rect
|
||||
x="4"
|
||||
y={node.iconSize * 0.32 + (node.iconSize * 0.36) * (1 - node.newPercent / 100)}
|
||||
width={node.iconSize - 8}
|
||||
height={(node.iconSize * 0.36) * ((node.newPercent - node.currentPercent) / 100)}
|
||||
<rect
|
||||
x="4"
|
||||
y={node.iconSize * 0.32 + node.iconSize * 0.36 * (1 - node.newPercent / 100)}
|
||||
width={node.iconSize - 8}
|
||||
height={node.iconSize * 0.36 * ((node.newPercent - node.currentPercent) / 100)}
|
||||
fill="#FFD700"
|
||||
filter="url(#memGlow-{filterId})"
|
||||
class="animate-pulse-slow"
|
||||
@@ -597,19 +638,20 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
</g>
|
||||
{:else}
|
||||
<!-- Unknown device - hexagon -->
|
||||
<g transform="translate({-node.iconSize/2}, {-node.iconSize/2})">
|
||||
<polygon
|
||||
points="{node.iconSize/2},0 {node.iconSize},{node.iconSize*0.25} {node.iconSize},{node.iconSize*0.75} {node.iconSize/2},{node.iconSize} 0,{node.iconSize*0.75} 0,{node.iconSize*0.25}"
|
||||
<g transform="translate({-node.iconSize / 2}, {-node.iconSize / 2})">
|
||||
<polygon
|
||||
points="{node.iconSize / 2},0 {node.iconSize},{node.iconSize * 0.25} {node.iconSize},{node.iconSize * 0.75} {node.iconSize / 2},{node.iconSize} 0,{node.iconSize *
|
||||
0.75} 0,{node.iconSize * 0.25}"
|
||||
fill={node.isUsed ? 'rgba(255,215,0,0.1)' : '#0a0a0a'}
|
||||
stroke={node.isUsed ? '#FFD700' : '#4B5563'}
|
||||
stroke-width="1.5"
|
||||
/>
|
||||
</g>
|
||||
{/if}
|
||||
|
||||
|
||||
<!-- Percentage label -->
|
||||
<text
|
||||
y={node.iconSize/2 + 12}
|
||||
<text
|
||||
y={node.iconSize / 2 + 12}
|
||||
text-anchor="middle"
|
||||
font-size="8"
|
||||
font-family="SF Mono, Monaco, monospace"
|
||||
@@ -627,13 +669,12 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
<button
|
||||
onclick={onLaunch}
|
||||
disabled={isLaunching || !canFit}
|
||||
class="w-full py-2 text-sm font-mono tracking-wider uppercase border transition-all duration-200
|
||||
{isLaunching
|
||||
? 'bg-transparent text-exo-yellow border-exo-yellow/50 cursor-wait'
|
||||
: !canFit
|
||||
? 'bg-red-500/10 text-red-400/70 border-red-500/30 cursor-not-allowed'
|
||||
: 'bg-transparent text-exo-light-gray border-exo-light-gray/40 hover:text-exo-yellow hover:border-exo-yellow/50 cursor-pointer'
|
||||
}"
|
||||
class="w-full py-2 text-sm font-mono tracking-wider uppercase border transition-all duration-200
|
||||
{isLaunching
|
||||
? 'bg-transparent text-exo-yellow border-exo-yellow/50 cursor-wait'
|
||||
: !canFit
|
||||
? 'bg-red-500/10 text-red-400/70 border-red-500/30 cursor-not-allowed'
|
||||
: 'bg-transparent text-exo-light-gray border-exo-light-gray/40 hover:text-exo-yellow hover:border-exo-yellow/50 cursor-pointer'}"
|
||||
>
|
||||
{#if isLaunching}
|
||||
<span class="flex items-center justify-center gap-1.5">
|
||||
@@ -651,8 +692,13 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
|
||||
<style>
|
||||
@keyframes pulse-slow {
|
||||
0%, 100% { opacity: 0.8; }
|
||||
50% { opacity: 1; }
|
||||
0%,
|
||||
100% {
|
||||
opacity: 0.8;
|
||||
}
|
||||
50% {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
.animate-pulse-slow {
|
||||
animation: pulse-slow 1.5s ease-in-out infinite;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<script lang="ts">
|
||||
import { onMount, onDestroy } from 'svelte';
|
||||
import * as d3 from 'd3';
|
||||
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
|
||||
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -13,62 +13,78 @@ import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.sv
|
||||
let svgContainer: SVGSVGElement | undefined = $state();
|
||||
let resizeObserver: ResizeObserver | undefined;
|
||||
|
||||
const isMinimized = $derived(isTopologyMinimized());
|
||||
const data = $derived(topologyData());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const isMinimized = $derived(isTopologyMinimized());
|
||||
const data = $derived(topologyData());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
|
||||
function getNodeLabel(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
return node?.friendly_name || nodeId.slice(0, 8);
|
||||
}
|
||||
|
||||
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
|
||||
if (!ip) return { label: '?', missing: true };
|
||||
const node = data?.nodes?.[nodeId];
|
||||
if (!node) return { label: '?', missing: true };
|
||||
|
||||
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
|
||||
(iface.addresses || []).some((addr) => addr === ip)
|
||||
);
|
||||
if (matchFromInterfaces?.name) {
|
||||
return { label: matchFromInterfaces.name, missing: false };
|
||||
function getNodeLabel(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
return node?.friendly_name || nodeId.slice(0, 8);
|
||||
}
|
||||
|
||||
const mapped = node.ip_to_interface?.[ip];
|
||||
if (mapped && mapped.trim().length > 0) {
|
||||
return { label: mapped, missing: false };
|
||||
}
|
||||
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
|
||||
if (!ip) return { label: '?', missing: true };
|
||||
|
||||
return { label: '?', missing: true };
|
||||
}
|
||||
// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
|
||||
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
|
||||
|
||||
function wrapLine(text: string, maxLen: number): string[] {
|
||||
if (text.length <= maxLen) return [text];
|
||||
const words = text.split(' ');
|
||||
const lines: string[] = [];
|
||||
let current = '';
|
||||
for (const word of words) {
|
||||
if (word.length > maxLen) {
|
||||
if (current) {
|
||||
lines.push(current);
|
||||
current = '';
|
||||
// Helper to check a node's interfaces
|
||||
function checkNode(node: (typeof data.nodes)[string]): string | null {
|
||||
if (!node) return null;
|
||||
|
||||
const matchFromInterfaces = node.network_interfaces?.find(iface => (iface.addresses || []).some(addr => addr === cleanIp || addr === ip));
|
||||
if (matchFromInterfaces?.name) {
|
||||
return matchFromInterfaces.name;
|
||||
}
|
||||
for (let i = 0; i < word.length; i += maxLen) {
|
||||
lines.push(word.slice(i, i + maxLen));
|
||||
|
||||
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
|
||||
if (mapped && mapped.trim().length > 0) {
|
||||
return mapped;
|
||||
}
|
||||
} else if ((current + ' ' + word).trim().length > maxLen) {
|
||||
lines.push(current);
|
||||
current = word;
|
||||
} else {
|
||||
current = current ? `${current} ${word}` : word;
|
||||
return null;
|
||||
}
|
||||
|
||||
// Try specified node first
|
||||
const result = checkNode(data?.nodes?.[nodeId]);
|
||||
if (result) return { label: result, missing: false };
|
||||
|
||||
// Fallback: search all nodes for this IP
|
||||
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
|
||||
const otherResult = checkNode(otherNode);
|
||||
if (otherResult) return { label: otherResult, missing: false };
|
||||
}
|
||||
|
||||
return { label: '?', missing: true };
|
||||
}
|
||||
|
||||
function wrapLine(text: string, maxLen: number): string[] {
|
||||
if (text.length <= maxLen) return [text];
|
||||
const words = text.split(' ');
|
||||
const lines: string[] = [];
|
||||
let current = '';
|
||||
for (const word of words) {
|
||||
if (word.length > maxLen) {
|
||||
if (current) {
|
||||
lines.push(current);
|
||||
current = '';
|
||||
}
|
||||
for (let i = 0; i < word.length; i += maxLen) {
|
||||
lines.push(word.slice(i, i + maxLen));
|
||||
}
|
||||
} else if ((current + ' ' + word).trim().length > maxLen) {
|
||||
lines.push(current);
|
||||
current = word;
|
||||
} else {
|
||||
current = current ? `${current} ${word}` : word;
|
||||
}
|
||||
}
|
||||
if (current) lines.push(current);
|
||||
return lines;
|
||||
}
|
||||
if (current) lines.push(current);
|
||||
return lines;
|
||||
}
|
||||
|
||||
// Apple logo path for MacBook Pro screen
|
||||
const APPLE_LOGO_PATH = "M788.1 340.9c-5.8 4.5-108.2 62.2-108.2 190.5 0 148.4 130.3 200.9 134.2 202.2-.6 3.2-20.7 71.9-68.7 141.9-42.8 61.6-87.5 123.1-155.5 123.1s-85.5-39.5-164-39.5c-76.5 0-103.7 40.8-165.9 40.8s-105.6-57-155.5-127C46.7 790.7 0 663 0 541.8c0-194.4 126.4-297.5 250.8-297.5 66.1 0 121.2 43.4 162.7 43.4 39.5 0 101.1-46 176.3-46 28.5 0 130.9 2.6 198.3 99.2zm-234-181.5c31.1-36.9 53.1-88.1 53.1-139.3 0-7.1-.6-14.3-1.9-20.1-50.6 1.9-110.8 33.7-147.1 75.8-28.5 32.4-55.1 83.6-55.1 135.5 0 7.8 1.3 15.6 1.9 18.1 3.2.6 8.4 1.3 13.6 1.3 45.4 0 102.5-30.4 135.5-71.3z";
|
||||
const APPLE_LOGO_PATH =
|
||||
'M788.1 340.9c-5.8 4.5-108.2 62.2-108.2 190.5 0 148.4 130.3 200.9 134.2 202.2-.6 3.2-20.7 71.9-68.7 141.9-42.8 61.6-87.5 123.1-155.5 123.1s-85.5-39.5-164-39.5c-76.5 0-103.7 40.8-165.9 40.8s-105.6-57-155.5-127C46.7 790.7 0 663 0 541.8c0-194.4 126.4-297.5 250.8-297.5 66.1 0 121.2 43.4 162.7 43.4 39.5 0 101.1-46 176.3-46 28.5 0 130.9 2.6 198.3 99.2zm-234-181.5c31.1-36.9 53.1-88.1 53.1-139.3 0-7.1-.6-14.3-1.9-20.1-50.6 1.9-110.8 33.7-147.1 75.8-28.5 32.4-55.1 83.6-55.1 135.5 0 7.8 1.3 15.6 1.9 18.1 3.2.6 8.4 1.3 13.6 1.3 45.4 0 102.5-30.4 135.5-71.3z';
|
||||
const LOGO_NATIVE_WIDTH = 814;
|
||||
const LOGO_NATIVE_HEIGHT = 1000;
|
||||
|
||||
@@ -83,17 +99,17 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
function getTemperatureColor(temp: number): string {
|
||||
// Default for N/A temp - light gray
|
||||
if (isNaN(temp) || temp === null) return 'rgba(179, 179, 179, 0.8)';
|
||||
|
||||
const coolTemp = 45; // Temp for pure blue
|
||||
|
||||
const coolTemp = 45; // Temp for pure blue
|
||||
const midTemp = 57.5; // Temp for pure yellow
|
||||
const hotTemp = 75; // Temp for pure red
|
||||
|
||||
const coolColor = { r: 93, g: 173, b: 226 }; // #5DADE2 (Blue)
|
||||
const midColor = { r: 255, g: 215, b: 0 }; // #FFD700 (Yellow)
|
||||
const hotColor = { r: 244, g: 67, b: 54 }; // #F44336 (Red)
|
||||
|
||||
const hotTemp = 75; // Temp for pure red
|
||||
|
||||
const coolColor = { r: 93, g: 173, b: 226 }; // #5DADE2 (Blue)
|
||||
const midColor = { r: 255, g: 215, b: 0 }; // #FFD700 (Yellow)
|
||||
const hotColor = { r: 244, g: 67, b: 54 }; // #F44336 (Red)
|
||||
|
||||
let r: number, g: number, b: number;
|
||||
|
||||
|
||||
if (temp <= coolTemp) {
|
||||
({ r, g, b } = coolColor);
|
||||
} else if (temp <= midTemp) {
|
||||
@@ -109,7 +125,7 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
} else {
|
||||
({ r, g, b } = hotColor);
|
||||
}
|
||||
|
||||
|
||||
return `rgb(${r}, ${g}, ${b})`;
|
||||
}
|
||||
|
||||
@@ -132,23 +148,17 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
|
||||
// Add defs for clip paths and filters
|
||||
const defs = svg.append('defs');
|
||||
|
||||
|
||||
// Glow filter
|
||||
const glowFilter = defs.append('filter')
|
||||
.attr('id', 'glow')
|
||||
.attr('x', '-50%')
|
||||
.attr('y', '-50%')
|
||||
.attr('width', '200%')
|
||||
.attr('height', '200%');
|
||||
glowFilter.append('feGaussianBlur')
|
||||
.attr('stdDeviation', '2')
|
||||
.attr('result', 'coloredBlur');
|
||||
const glowFilter = defs.append('filter').attr('id', 'glow').attr('x', '-50%').attr('y', '-50%').attr('width', '200%').attr('height', '200%');
|
||||
glowFilter.append('feGaussianBlur').attr('stdDeviation', '2').attr('result', 'coloredBlur');
|
||||
const glowMerge = glowFilter.append('feMerge');
|
||||
glowMerge.append('feMergeNode').attr('in', 'coloredBlur');
|
||||
glowMerge.append('feMergeNode').attr('in', 'SourceGraphic');
|
||||
|
||||
// Arrowhead marker for directional edges
|
||||
const marker = defs.append('marker')
|
||||
const marker = defs
|
||||
.append('marker')
|
||||
.attr('id', 'arrowhead')
|
||||
.attr('viewBox', '0 0 10 10')
|
||||
.attr('refX', '10')
|
||||
@@ -156,7 +166,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.attr('markerWidth', '11')
|
||||
.attr('markerHeight', '11')
|
||||
.attr('orient', 'auto-start-reverse');
|
||||
marker.append('path')
|
||||
marker
|
||||
.append('path')
|
||||
.attr('d', 'M 0 0 L 10 5 L 0 10')
|
||||
.attr('fill', 'none')
|
||||
.attr('stroke', 'var(--exo-light-gray, #B3B3B3)')
|
||||
@@ -166,7 +177,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.style('animation', 'none');
|
||||
|
||||
if (nodeIds.length === 0) {
|
||||
svg.append('text')
|
||||
svg
|
||||
.append('text')
|
||||
.attr('x', centerX)
|
||||
.attr('y', centerY)
|
||||
.attr('text-anchor', 'middle')
|
||||
@@ -181,23 +193,21 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
|
||||
const numNodes = nodeIds.length;
|
||||
const minDimension = Math.min(width, height);
|
||||
|
||||
|
||||
// Dynamic scaling - larger nodes for big displays
|
||||
const sizeScale = numNodes === 1 ? 1 : Math.max(0.6, 1 - (numNodes - 1) * 0.10);
|
||||
const baseNodeRadius = isMinimized
|
||||
? Math.max(36, Math.min(60, minDimension * 0.22))
|
||||
: Math.min(120, minDimension * 0.20);
|
||||
const sizeScale = numNodes === 1 ? 1 : Math.max(0.6, 1 - (numNodes - 1) * 0.1);
|
||||
const baseNodeRadius = isMinimized ? Math.max(36, Math.min(60, minDimension * 0.22)) : Math.min(120, minDimension * 0.2);
|
||||
const nodeRadius = baseNodeRadius * sizeScale;
|
||||
|
||||
|
||||
// Orbit radius - balanced spacing for nodes
|
||||
const circumference = numNodes * nodeRadius * 4;
|
||||
const radiusFromCircumference = circumference / (2 * Math.PI);
|
||||
const minOrbitRadius = Math.max(radiusFromCircumference, minDimension * 0.18);
|
||||
const maxOrbitRadius = minDimension * 0.30;
|
||||
const orbitRadius = isMinimized
|
||||
const maxOrbitRadius = minDimension * 0.3;
|
||||
const orbitRadius = isMinimized
|
||||
? Math.min(maxOrbitRadius, Math.max(minOrbitRadius, minDimension * 0.26))
|
||||
: Math.min(maxOrbitRadius, Math.max(minOrbitRadius, minDimension * (0.22 + numNodes * 0.02)));
|
||||
|
||||
|
||||
// Determine display mode based on space and node count
|
||||
const showFullLabels = !isMinimized && numNodes <= 4;
|
||||
const showCompactLabels = !isMinimized && numNodes > 4;
|
||||
@@ -206,7 +216,7 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
const topPadding = 70; // Space for "NETWORK TOPOLOGY" label and node names
|
||||
const bottomPadding = 70; // Space for stats and bottom label
|
||||
const safeCenterY = topPadding + (height - topPadding - bottomPadding) / 2;
|
||||
|
||||
|
||||
// Calculate node positions
|
||||
const nodesWithPositions = nodeIds.map((id, index) => {
|
||||
if (numNodes === 1) {
|
||||
@@ -220,7 +230,7 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
}
|
||||
// Distribute nodes around the orbit
|
||||
// Start from top (-90 degrees) and go clockwise
|
||||
const angle = (index / numNodes) * 2 * Math.PI - (Math.PI / 2);
|
||||
const angle = (index / numNodes) * 2 * Math.PI - Math.PI / 2;
|
||||
return {
|
||||
id,
|
||||
data: nodes[id],
|
||||
@@ -230,7 +240,9 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
});
|
||||
|
||||
const positionById: Record<string, { x: number; y: number }> = {};
|
||||
nodesWithPositions.forEach(n => { positionById[n.id] = { x: n.x, y: n.y }; });
|
||||
nodesWithPositions.forEach(n => {
|
||||
positionById[n.id] = { x: n.x, y: n.y };
|
||||
});
|
||||
|
||||
// Draw edges
|
||||
const linksGroup = svg.append('g').attr('class', 'links-group');
|
||||
@@ -238,15 +250,16 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');
|
||||
|
||||
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
|
||||
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
|
||||
edges.forEach(edge => {
|
||||
if (!edge.source || !edge.target || edge.source === edge.target) return;
|
||||
if (!positionById[edge.source] || !positionById[edge.target]) return;
|
||||
|
||||
|
||||
const a = edge.source < edge.target ? edge.source : edge.target;
|
||||
const b = edge.source < edge.target ? edge.target : edge.source;
|
||||
const key = `${a}|${b}`;
|
||||
const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false, connections: [] };
|
||||
|
||||
|
||||
if (edge.source === a) entry.aToB = true;
|
||||
else entry.bToA = true;
|
||||
|
||||
@@ -268,12 +281,7 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
if (!posA || !posB) return;
|
||||
|
||||
// Base dashed line
|
||||
linksGroup.append('line')
|
||||
.attr('x1', posA.x)
|
||||
.attr('y1', posA.y)
|
||||
.attr('x2', posB.x)
|
||||
.attr('y2', posB.y)
|
||||
.attr('class', 'graph-link');
|
||||
linksGroup.append('line').attr('x1', posA.x).attr('y1', posA.y).attr('x2', posB.x).attr('y2', posB.y).attr('class', 'graph-link');
|
||||
|
||||
// Calculate midpoint and direction for arrows
|
||||
const dx = posB.x - posA.x;
|
||||
@@ -290,7 +298,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
if (entry.aToB) {
|
||||
const tipX = mx - ux * tipOffset;
|
||||
const tipY = my - uy * tipOffset;
|
||||
arrowsGroup.append('line')
|
||||
arrowsGroup
|
||||
.append('line')
|
||||
.attr('x1', tipX - ux * carrier)
|
||||
.attr('y1', tipY - uy * carrier)
|
||||
.attr('x2', tipX)
|
||||
@@ -304,7 +313,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
if (entry.bToA) {
|
||||
const tipX = mx + ux * tipOffset;
|
||||
const tipY = my + uy * tipOffset;
|
||||
arrowsGroup.append('line')
|
||||
arrowsGroup
|
||||
.append('line')
|
||||
.attr('x1', tipX + ux * carrier)
|
||||
.attr('y1', tipY + uy * carrier)
|
||||
.attr('x2', tipX)
|
||||
@@ -314,110 +324,99 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.attr('marker-end', 'url(#arrowhead)');
|
||||
}
|
||||
|
||||
// Collect debug labels for later positioning at edges
|
||||
if (debugEnabled && entry.connections.length > 0) {
|
||||
const maxBoxes = 6;
|
||||
const fontSize = isMinimized ? 8 : 9;
|
||||
const lineGap = 2;
|
||||
const labelOffsetOut = Math.max(140, minDimension * 0.38);
|
||||
const labelOffsetSide = isMinimized ? 16 : 20;
|
||||
const boxWidth = 170;
|
||||
const maxLineLen = 26;
|
||||
// Determine which side of viewport based on edge midpoint
|
||||
const isLeft = mx < centerX;
|
||||
const isTop = my < safeCenterY;
|
||||
|
||||
const connections = entry.connections.slice(0, maxBoxes);
|
||||
if (entry.connections.length > maxBoxes) {
|
||||
const remaining = entry.connections.length - maxBoxes;
|
||||
connections.push({
|
||||
from: '',
|
||||
to: '',
|
||||
ip: `(+${remaining} more)`,
|
||||
ifaceLabel: '',
|
||||
missingIface: false
|
||||
});
|
||||
}
|
||||
|
||||
let dirX = mx - centerX;
|
||||
let dirY = my - centerY;
|
||||
const dirLen = Math.hypot(dirX, dirY);
|
||||
if (dirLen < 1) {
|
||||
dirX = -uy;
|
||||
dirY = ux;
|
||||
} else {
|
||||
dirX /= dirLen;
|
||||
dirY /= dirLen;
|
||||
}
|
||||
|
||||
const nx = -dirY;
|
||||
const ny = dirX;
|
||||
|
||||
const labelXRaw = mx + dirX * labelOffsetOut + nx * labelOffsetSide;
|
||||
const labelYRaw = my + dirY * labelOffsetOut + ny * labelOffsetSide;
|
||||
const clampPad = Math.min(120, minDimension * 0.12);
|
||||
const labelX = Math.max(clampPad, Math.min(width - clampPad, labelXRaw));
|
||||
const labelY = Math.max(clampPad, Math.min(height - clampPad, labelYRaw));
|
||||
|
||||
const labelGroup = debugLabelsGroup.append('g')
|
||||
.attr('transform', `translate(${labelX}, ${labelY})`);
|
||||
|
||||
const textGroup = labelGroup.append('g');
|
||||
|
||||
connections.forEach((conn, idx) => {
|
||||
const rawLines = conn.from && conn.to
|
||||
? [
|
||||
`${getNodeLabel(conn.from)}→${getNodeLabel(conn.to)}`,
|
||||
`${conn.ip}`,
|
||||
`${conn.ifaceLabel}`
|
||||
]
|
||||
: [conn.ip];
|
||||
|
||||
const wrapped = rawLines.flatMap(line => wrapLine(line, maxLineLen));
|
||||
|
||||
wrapped.forEach((line, lineIdx) => {
|
||||
textGroup.append('text')
|
||||
.attr('x', 0)
|
||||
.attr('y', (idx * (wrapped.length * (fontSize + lineGap))) + lineIdx * (fontSize + lineGap))
|
||||
.attr('text-anchor', 'middle')
|
||||
.attr('dominant-baseline', 'hanging')
|
||||
.attr('font-size', fontSize)
|
||||
.attr('font-family', 'SF Mono, monospace')
|
||||
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.9)')
|
||||
.text(line);
|
||||
});
|
||||
// Store for batch rendering after all edges processed
|
||||
if (!debugEdgeLabels) debugEdgeLabels = [];
|
||||
debugEdgeLabels.push({
|
||||
connections: entry.connections,
|
||||
isLeft,
|
||||
isTop,
|
||||
mx,
|
||||
my
|
||||
});
|
||||
|
||||
const bbox = textGroup.node()?.getBBox();
|
||||
if (bbox) {
|
||||
const paddedWidth = Math.max(boxWidth, bbox.width + 14);
|
||||
const boxHeight = bbox.height + 8;
|
||||
const boxMinX = labelX - paddedWidth / 2;
|
||||
const boxMaxX = labelX + paddedWidth / 2;
|
||||
const boxMinY = labelY + bbox.y - 4;
|
||||
const boxMaxY = boxMinY + boxHeight;
|
||||
|
||||
const clampPadDynamic = Math.min(140, minDimension * 0.18);
|
||||
let shiftX = 0;
|
||||
let shiftY = 0;
|
||||
if (boxMinX < clampPadDynamic) shiftX = clampPadDynamic - boxMinX;
|
||||
if (boxMaxX > width - clampPadDynamic) shiftX = (width - clampPadDynamic) - boxMaxX;
|
||||
if (boxMinY < clampPadDynamic) shiftY = clampPadDynamic - boxMinY;
|
||||
if (boxMaxY > height - clampPadDynamic) shiftY = (height - clampPadDynamic) - boxMaxY;
|
||||
|
||||
const finalX = labelX + shiftX;
|
||||
const finalY = labelY + shiftY;
|
||||
labelGroup.attr('transform', `translate(${finalX}, ${finalY})`);
|
||||
|
||||
labelGroup.insert('rect', 'g')
|
||||
.attr('x', -paddedWidth / 2)
|
||||
.attr('y', bbox.y - 4)
|
||||
.attr('width', paddedWidth)
|
||||
.attr('height', boxHeight)
|
||||
.attr('rx', 4)
|
||||
.attr('fill', 'rgba(0,0,0,0.75)')
|
||||
.attr('stroke', 'rgba(255,255,255,0.12)')
|
||||
.attr('stroke-width', 0.6);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Render debug labels at viewport edges/corners
|
||||
if (debugEdgeLabels && debugEdgeLabels.length > 0) {
|
||||
const fontSize = isMinimized ? 10 : 12;
|
||||
const lineHeight = fontSize + 4;
|
||||
const padding = 10;
|
||||
|
||||
// Helper to get arrow based on direction vector
|
||||
function getArrow(fromId: string, toId: string): string {
|
||||
const fromPos = positionById[fromId];
|
||||
const toPos = positionById[toId];
|
||||
if (!fromPos || !toPos) return '→';
|
||||
|
||||
const dirX = toPos.x - fromPos.x;
|
||||
const dirY = toPos.y - fromPos.y;
|
||||
const absX = Math.abs(dirX);
|
||||
const absY = Math.abs(dirY);
|
||||
|
||||
if (absX > absY * 2) {
|
||||
return dirX > 0 ? '→' : '←';
|
||||
} else if (absY > absX * 2) {
|
||||
return dirY > 0 ? '↓' : '↑';
|
||||
} else {
|
||||
if (dirX > 0 && dirY > 0) return '↘';
|
||||
if (dirX > 0 && dirY < 0) return '↗';
|
||||
if (dirX < 0 && dirY > 0) return '↙';
|
||||
return '↖';
|
||||
}
|
||||
}
|
||||
|
||||
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
|
||||
const quadrants: Record<string, typeof debugEdgeLabels> = {
|
||||
topLeft: [],
|
||||
topRight: [],
|
||||
bottomLeft: [],
|
||||
bottomRight: []
|
||||
};
|
||||
|
||||
debugEdgeLabels.forEach(edge => {
|
||||
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
|
||||
quadrants[key].push(edge);
|
||||
});
|
||||
|
||||
// Render each quadrant
|
||||
Object.entries(quadrants).forEach(([quadrant, edges]) => {
|
||||
if (edges.length === 0) return;
|
||||
|
||||
const isLeft = quadrant.includes('Left');
|
||||
const isTop = quadrant.includes('top');
|
||||
|
||||
let baseX = isLeft ? padding : width - padding;
|
||||
let baseY = isTop ? padding : height - padding;
|
||||
const textAnchor = isLeft ? 'start' : 'end';
|
||||
|
||||
let currentY = baseY;
|
||||
|
||||
edges.forEach(edge => {
|
||||
edge.connections.forEach(conn => {
|
||||
const arrow = getArrow(conn.from, conn.to);
|
||||
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;
|
||||
debugLabelsGroup
|
||||
.append('text')
|
||||
.attr('x', baseX)
|
||||
.attr('y', currentY)
|
||||
.attr('text-anchor', textAnchor)
|
||||
.attr('dominant-baseline', isTop ? 'hanging' : 'auto')
|
||||
.attr('font-size', fontSize)
|
||||
.attr('font-family', 'SF Mono, monospace')
|
||||
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.85)')
|
||||
.text(label);
|
||||
currentY += isTop ? lineHeight : -lineHeight;
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Draw nodes
|
||||
const nodesGroup = svg.append('g').attr('class', 'nodes-group');
|
||||
|
||||
@@ -451,28 +450,25 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
}
|
||||
}
|
||||
|
||||
const nodeG = nodesGroup.append('g')
|
||||
.attr('class', 'graph-node')
|
||||
.style('cursor', 'pointer');
|
||||
const nodeG = nodesGroup.append('g').attr('class', 'graph-node').style('cursor', 'pointer');
|
||||
|
||||
// Add tooltip
|
||||
nodeG.append('title')
|
||||
.text(`${friendlyName}\nID: ${nodeInfo.id.slice(-8)}\nMemory: ${formatBytes(ramUsed)}/${formatBytes(ramTotal)}`);
|
||||
nodeG.append('title').text(`${friendlyName}\nID: ${nodeInfo.id.slice(-8)}\nMemory: ${formatBytes(ramUsed)}/${formatBytes(ramTotal)}`);
|
||||
|
||||
let iconBaseWidth = nodeRadius * 1.2;
|
||||
let iconBaseHeight = nodeRadius * 1.0;
|
||||
const clipPathId = `clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, '-')}`;
|
||||
|
||||
const modelLower = modelId.toLowerCase();
|
||||
|
||||
|
||||
// Check if this node should be highlighted (from hovered instance)
|
||||
const isHighlighted = highlightedNodes.has(nodeInfo.id);
|
||||
|
||||
// Holographic wireframe colors - yellow border when highlighted
|
||||
const wireColor = isHighlighted ? 'rgba(255,215,0,0.9)' : 'rgba(179,179,179,0.8)';
|
||||
const wireColorBright = 'rgba(255,255,255,0.9)';
|
||||
const fillColor = isHighlighted ? 'rgba(255,215,0,0.15)' : 'rgba(255,215,0,0.08)';
|
||||
const strokeWidth = isHighlighted ? 2.5 : 1.5;
|
||||
|
||||
// Holographic wireframe colors - yellow border when highlighted
|
||||
const wireColor = isHighlighted ? 'rgba(255,215,0,0.9)' : 'rgba(179,179,179,0.8)';
|
||||
const wireColorBright = 'rgba(255,255,255,0.9)';
|
||||
const fillColor = isHighlighted ? 'rgba(255,215,0,0.15)' : 'rgba(255,215,0,0.08)';
|
||||
const strokeWidth = isHighlighted ? 2.5 : 1.5;
|
||||
const screenFill = 'rgba(0,20,40,0.9)';
|
||||
const glowColor = 'rgba(255,215,0,0.3)';
|
||||
|
||||
@@ -487,7 +483,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
|
||||
// Create clip path for memory fill area (front body)
|
||||
const studioClipId = `studio-clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, '-')}`;
|
||||
defs.append('clipPath')
|
||||
defs
|
||||
.append('clipPath')
|
||||
.attr('id', studioClipId)
|
||||
.append('rect')
|
||||
.attr('x', x)
|
||||
@@ -497,7 +494,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.attr('rx', cornerRadius - 1);
|
||||
|
||||
// Main body (uniform color)
|
||||
nodeG.append('rect')
|
||||
nodeG
|
||||
.append('rect')
|
||||
.attr('x', x)
|
||||
.attr('y', y)
|
||||
.attr('width', iconBaseWidth)
|
||||
@@ -511,7 +509,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
if (ramUsagePercent > 0) {
|
||||
const memFillTotalHeight = iconBaseHeight - topSurfaceHeight;
|
||||
const memFillActualHeight = (ramUsagePercent / 100) * memFillTotalHeight;
|
||||
nodeG.append('rect')
|
||||
nodeG
|
||||
.append('rect')
|
||||
.attr('x', x)
|
||||
.attr('y', y + topSurfaceHeight + (memFillTotalHeight - memFillActualHeight))
|
||||
.attr('width', iconBaseWidth)
|
||||
@@ -529,7 +528,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
const vSlot2X = x + iconBaseWidth * 0.28;
|
||||
|
||||
[vSlot1X, vSlot2X].forEach(vx => {
|
||||
nodeG.append('rect')
|
||||
nodeG
|
||||
.append('rect')
|
||||
.attr('x', vx - vSlotWidth / 2)
|
||||
.attr('y', vSlotY)
|
||||
.attr('width', vSlotWidth)
|
||||
@@ -541,14 +541,14 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
// Horizontal slot (SD card)
|
||||
const hSlotWidth = iconBaseWidth * 0.2;
|
||||
const hSlotX = x + iconBaseWidth * 0.5 - hSlotWidth / 2;
|
||||
nodeG.append('rect')
|
||||
nodeG
|
||||
.append('rect')
|
||||
.attr('x', hSlotX)
|
||||
.attr('y', vSlotY)
|
||||
.attr('width', hSlotWidth)
|
||||
.attr('height', slotHeight * 0.6)
|
||||
.attr('fill', detailColor)
|
||||
.attr('rx', 1);
|
||||
|
||||
} else if (modelLower === 'mac mini') {
|
||||
// Mac Mini - classic flat box with memory fill
|
||||
iconBaseWidth = nodeRadius * 1.3;
|
||||
@@ -556,11 +556,12 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
const x = nodeInfo.x - iconBaseWidth / 2;
|
||||
const y = nodeInfo.y - iconBaseHeight / 2;
|
||||
const cornerRadius = 3;
|
||||
const topSurfaceHeight = iconBaseHeight * 0.20;
|
||||
const topSurfaceHeight = iconBaseHeight * 0.2;
|
||||
|
||||
// Create clip path for memory fill area
|
||||
const miniClipId = `mini-clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, '-')}`;
|
||||
defs.append('clipPath')
|
||||
defs
|
||||
.append('clipPath')
|
||||
.attr('id', miniClipId)
|
||||
.append('rect')
|
||||
.attr('x', x)
|
||||
@@ -570,7 +571,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.attr('rx', cornerRadius - 1);
|
||||
|
||||
// Main body (uniform color)
|
||||
nodeG.append('rect')
|
||||
nodeG
|
||||
.append('rect')
|
||||
.attr('x', x)
|
||||
.attr('y', y)
|
||||
.attr('width', iconBaseWidth)
|
||||
@@ -584,7 +586,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
if (ramUsagePercent > 0) {
|
||||
const memFillTotalHeight = iconBaseHeight - topSurfaceHeight;
|
||||
const memFillActualHeight = (ramUsagePercent / 100) * memFillTotalHeight;
|
||||
nodeG.append('rect')
|
||||
nodeG
|
||||
.append('rect')
|
||||
.attr('x', x)
|
||||
.attr('y', y + topSurfaceHeight + (memFillTotalHeight - memFillActualHeight))
|
||||
.attr('width', iconBaseWidth)
|
||||
@@ -595,14 +598,15 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
|
||||
// Front panel details - vertical slots (no horizontal slot for Mini)
|
||||
const detailColor = 'rgba(0,0,0,0.35)';
|
||||
const slotHeight = iconBaseHeight * 0.20;
|
||||
const slotHeight = iconBaseHeight * 0.2;
|
||||
const vSlotWidth = iconBaseWidth * 0.045;
|
||||
const vSlotY = y + topSurfaceHeight + (iconBaseHeight - topSurfaceHeight) * 0.45;
|
||||
const vSlot1X = x + iconBaseWidth * 0.20;
|
||||
const vSlot2X = x + iconBaseWidth * 0.30;
|
||||
const vSlot1X = x + iconBaseWidth * 0.2;
|
||||
const vSlot2X = x + iconBaseWidth * 0.3;
|
||||
|
||||
[vSlot1X, vSlot2X].forEach(vx => {
|
||||
nodeG.append('rect')
|
||||
nodeG
|
||||
.append('rect')
|
||||
.attr('x', vx - vSlotWidth / 2)
|
||||
.attr('y', vSlotY)
|
||||
.attr('width', vSlotWidth)
|
||||
@@ -610,7 +614,6 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.attr('fill', detailColor)
|
||||
.attr('rx', 1.2);
|
||||
});
|
||||
|
||||
} else if (modelLower === 'macbook pro' || modelLower.includes('macbook')) {
|
||||
// MacBook Pro - classic style with memory fill on screen
|
||||
iconBaseWidth = nodeRadius * 1.6;
|
||||
@@ -618,15 +621,16 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
const x = nodeInfo.x - iconBaseWidth / 2;
|
||||
const y = nodeInfo.y - iconBaseHeight / 2;
|
||||
|
||||
const screenHeight = iconBaseHeight * 0.70;
|
||||
const baseHeight = iconBaseHeight * 0.30;
|
||||
const screenHeight = iconBaseHeight * 0.7;
|
||||
const baseHeight = iconBaseHeight * 0.3;
|
||||
const screenWidth = iconBaseWidth * 0.85;
|
||||
const screenX = nodeInfo.x - screenWidth / 2;
|
||||
const screenBezel = 3;
|
||||
|
||||
// Create clip path for screen content
|
||||
const screenClipId = `screen-clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, '-')}`;
|
||||
defs.append('clipPath')
|
||||
defs
|
||||
.append('clipPath')
|
||||
.attr('id', screenClipId)
|
||||
.append('rect')
|
||||
.attr('x', screenX + screenBezel)
|
||||
@@ -636,7 +640,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.attr('rx', 2);
|
||||
|
||||
// Screen outer frame
|
||||
nodeG.append('rect')
|
||||
nodeG
|
||||
.append('rect')
|
||||
.attr('x', screenX)
|
||||
.attr('y', y)
|
||||
.attr('width', screenWidth)
|
||||
@@ -647,7 +652,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.attr('stroke-width', strokeWidth);
|
||||
|
||||
// Screen inner (dark background)
|
||||
nodeG.append('rect')
|
||||
nodeG
|
||||
.append('rect')
|
||||
.attr('x', screenX + screenBezel)
|
||||
.attr('y', y + screenBezel)
|
||||
.attr('width', screenWidth - screenBezel * 2)
|
||||
@@ -659,7 +665,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
if (ramUsagePercent > 0) {
|
||||
const memFillTotalHeight = screenHeight - screenBezel * 2;
|
||||
const memFillActualHeight = (ramUsagePercent / 100) * memFillTotalHeight;
|
||||
nodeG.append('rect')
|
||||
nodeG
|
||||
.append('rect')
|
||||
.attr('x', screenX + screenBezel)
|
||||
.attr('y', y + screenBezel + (memFillTotalHeight - memFillActualHeight))
|
||||
.attr('width', screenWidth - screenBezel * 2)
|
||||
@@ -671,13 +678,9 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
// Apple logo on screen (centered, on top of memory fill)
|
||||
const targetLogoHeight = screenHeight * 0.22;
|
||||
const logoScale = targetLogoHeight / LOGO_NATIVE_HEIGHT;
|
||||
const logoX = nodeInfo.x - (LOGO_NATIVE_WIDTH * logoScale / 2);
|
||||
const logoY = y + screenHeight / 2 - (LOGO_NATIVE_HEIGHT * logoScale / 2);
|
||||
nodeG.append('path')
|
||||
.attr('d', APPLE_LOGO_PATH)
|
||||
.attr('transform', `translate(${logoX}, ${logoY}) scale(${logoScale})`)
|
||||
.attr('fill', '#FFFFFF')
|
||||
.attr('opacity', 0.9);
|
||||
const logoX = nodeInfo.x - (LOGO_NATIVE_WIDTH * logoScale) / 2;
|
||||
const logoY = y + screenHeight / 2 - (LOGO_NATIVE_HEIGHT * logoScale) / 2;
|
||||
nodeG.append('path').attr('d', APPLE_LOGO_PATH).attr('transform', `translate(${logoX}, ${logoY}) scale(${logoScale})`).attr('fill', '#FFFFFF').attr('opacity', 0.9);
|
||||
|
||||
// Base (keyboard) - trapezoidal
|
||||
const baseY = y + screenHeight;
|
||||
@@ -686,7 +689,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
const baseTopX = nodeInfo.x - baseTopWidth / 2;
|
||||
const baseBottomX = nodeInfo.x - baseBottomWidth / 2;
|
||||
|
||||
nodeG.append('path')
|
||||
nodeG
|
||||
.append('path')
|
||||
.attr('d', `M ${baseTopX} ${baseY} L ${baseTopX + baseTopWidth} ${baseY} L ${baseBottomX + baseBottomWidth} ${baseY + baseHeight} L ${baseBottomX} ${baseY + baseHeight} Z`)
|
||||
.attr('fill', '#2c2c2c')
|
||||
.attr('stroke', wireColor)
|
||||
@@ -697,66 +701,44 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
const keyboardY = baseY + 3;
|
||||
const keyboardWidth = baseTopWidth - 12;
|
||||
const keyboardHeight = baseHeight * 0.55;
|
||||
nodeG.append('rect')
|
||||
.attr('x', keyboardX)
|
||||
.attr('y', keyboardY)
|
||||
.attr('width', keyboardWidth)
|
||||
.attr('height', keyboardHeight)
|
||||
.attr('fill', 'rgba(0,0,0,0.2)')
|
||||
.attr('rx', 2);
|
||||
nodeG.append('rect').attr('x', keyboardX).attr('y', keyboardY).attr('width', keyboardWidth).attr('height', keyboardHeight).attr('fill', 'rgba(0,0,0,0.2)').attr('rx', 2);
|
||||
|
||||
// Trackpad
|
||||
const trackpadWidth = baseTopWidth * 0.4;
|
||||
const trackpadX = nodeInfo.x - trackpadWidth / 2;
|
||||
const trackpadY = baseY + keyboardHeight + 5;
|
||||
const trackpadHeight = baseHeight * 0.30;
|
||||
nodeG.append('rect')
|
||||
.attr('x', trackpadX)
|
||||
.attr('y', trackpadY)
|
||||
.attr('width', trackpadWidth)
|
||||
.attr('height', trackpadHeight)
|
||||
.attr('fill', 'rgba(255,255,255,0.08)')
|
||||
.attr('rx', 2);
|
||||
|
||||
const trackpadHeight = baseHeight * 0.3;
|
||||
nodeG.append('rect').attr('x', trackpadX).attr('y', trackpadY).attr('width', trackpadWidth).attr('height', trackpadHeight).attr('fill', 'rgba(255,255,255,0.08)').attr('rx', 2);
|
||||
} else {
|
||||
// Default/Unknown - holographic hexagon
|
||||
const hexRadius = nodeRadius * 0.6;
|
||||
const hexPoints = Array.from({ length: 6 }, (_, i) => {
|
||||
const angle = (i * 60 - 30) * Math.PI / 180;
|
||||
const angle = ((i * 60 - 30) * Math.PI) / 180;
|
||||
return `${nodeInfo.x + hexRadius * Math.cos(angle)},${nodeInfo.y + hexRadius * Math.sin(angle)}`;
|
||||
}).join(' ');
|
||||
|
||||
// Main shape
|
||||
nodeG.append('polygon')
|
||||
.attr('points', hexPoints)
|
||||
.attr('fill', fillColor)
|
||||
.attr('stroke', wireColor)
|
||||
.attr('stroke-width', strokeWidth);
|
||||
nodeG.append('polygon').attr('points', hexPoints).attr('fill', fillColor).attr('stroke', wireColor).attr('stroke-width', strokeWidth);
|
||||
}
|
||||
|
||||
// --- Vertical GPU Bar (right side of icon) ---
|
||||
// Show in both full mode and minimized mode (scaled appropriately)
|
||||
if (showFullLabels || isMinimized) {
|
||||
const gpuBarWidth = isMinimized ? Math.max(16, nodeRadius * 0.32) : Math.max(28, nodeRadius * 0.30);
|
||||
const gpuBarWidth = isMinimized ? Math.max(16, nodeRadius * 0.32) : Math.max(28, nodeRadius * 0.3);
|
||||
const gpuBarHeight = iconBaseHeight * 0.95;
|
||||
const barXOffset = iconBaseWidth / 2 + (isMinimized ? 5 : 10);
|
||||
const gpuBarX = nodeInfo.x + barXOffset;
|
||||
const gpuBarY = nodeInfo.y - gpuBarHeight / 2;
|
||||
|
||||
// GPU Bar Background (grey, no border)
|
||||
nodeG.append('rect')
|
||||
.attr('x', gpuBarX)
|
||||
.attr('y', gpuBarY)
|
||||
.attr('width', gpuBarWidth)
|
||||
.attr('height', gpuBarHeight)
|
||||
.attr('fill', 'rgba(80, 80, 90, 0.7)')
|
||||
.attr('rx', 2);
|
||||
nodeG.append('rect').attr('x', gpuBarX).attr('y', gpuBarY).attr('width', gpuBarWidth).attr('height', gpuBarHeight).attr('fill', 'rgba(80, 80, 90, 0.7)').attr('rx', 2);
|
||||
|
||||
// GPU Bar Fill (from bottom up, colored by temperature)
|
||||
if (gpuUsagePercent > 0) {
|
||||
const fillHeight = (gpuUsagePercent / 100) * gpuBarHeight;
|
||||
const gpuFillColor = getTemperatureColor(gpuTemp);
|
||||
nodeG.append('rect')
|
||||
nodeG
|
||||
.append('rect')
|
||||
.attr('x', gpuBarX)
|
||||
.attr('y', gpuBarY + (gpuBarHeight - fillHeight))
|
||||
.attr('width', gpuBarWidth)
|
||||
@@ -777,7 +759,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
const powerText = sysPower !== null ? `${sysPower.toFixed(0)}W` : '-';
|
||||
|
||||
// GPU Usage %
|
||||
nodeG.append('text')
|
||||
nodeG
|
||||
.append('text')
|
||||
.attr('x', gpuTextX)
|
||||
.attr('y', gpuTextY - lineSpacing)
|
||||
.attr('text-anchor', 'middle')
|
||||
@@ -789,7 +772,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.text(gpuUsageText);
|
||||
|
||||
// Temperature
|
||||
nodeG.append('text')
|
||||
nodeG
|
||||
.append('text')
|
||||
.attr('x', gpuTextX)
|
||||
.attr('y', gpuTextY)
|
||||
.attr('text-anchor', 'middle')
|
||||
@@ -801,7 +785,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.text(tempText);
|
||||
|
||||
// Power (Watts)
|
||||
nodeG.append('text')
|
||||
nodeG
|
||||
.append('text')
|
||||
.attr('x', gpuTextX)
|
||||
.attr('y', gpuTextY + lineSpacing)
|
||||
.attr('text-anchor', 'middle')
|
||||
@@ -818,15 +803,14 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
// FULL MODE: Name above, memory info below (1-4 nodes)
|
||||
const nameY = nodeInfo.y - iconBaseHeight / 2 - 15;
|
||||
const fontSize = Math.max(10, nodeRadius * 0.16);
|
||||
|
||||
|
||||
// Truncate name based on node count
|
||||
const maxNameLen = numNodes === 1 ? 22 : (numNodes === 2 ? 18 : numNodes === 3 ? 16 : 14);
|
||||
const displayName = friendlyName.length > maxNameLen
|
||||
? friendlyName.slice(0, maxNameLen - 2) + '..'
|
||||
: friendlyName;
|
||||
|
||||
const maxNameLen = numNodes === 1 ? 22 : numNodes === 2 ? 18 : numNodes === 3 ? 16 : 14;
|
||||
const displayName = friendlyName.length > maxNameLen ? friendlyName.slice(0, maxNameLen - 2) + '..' : friendlyName;
|
||||
|
||||
// Name label above
|
||||
nodeG.append('text')
|
||||
nodeG
|
||||
.append('text')
|
||||
.attr('x', nodeInfo.x)
|
||||
.attr('y', nameY)
|
||||
.attr('text-anchor', 'middle')
|
||||
@@ -839,32 +823,34 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
|
||||
// Memory info below - used in grey, total in yellow
|
||||
const infoY = nodeInfo.y + iconBaseHeight / 2 + 16;
|
||||
const memText = nodeG.append('text')
|
||||
const memText = nodeG
|
||||
.append('text')
|
||||
.attr('x', nodeInfo.x)
|
||||
.attr('y', infoY)
|
||||
.attr('text-anchor', 'middle')
|
||||
.attr('font-size', fontSize * 0.85)
|
||||
.attr('font-family', 'SF Mono, Monaco, monospace');
|
||||
memText.append('tspan')
|
||||
memText
|
||||
.append('tspan')
|
||||
.attr('fill', 'rgba(255,215,0,0.9)')
|
||||
.text(`${formatBytes(ramUsed)}`);
|
||||
memText.append('tspan')
|
||||
memText
|
||||
.append('tspan')
|
||||
.attr('fill', 'rgba(179,179,179,0.9)')
|
||||
.text(`/${formatBytes(ramTotal)}`);
|
||||
memText.append('tspan')
|
||||
memText
|
||||
.append('tspan')
|
||||
.attr('fill', 'rgba(179,179,179,0.7)')
|
||||
.text(` (${ramUsagePercent.toFixed(0)}%)`);
|
||||
|
||||
} else if (showCompactLabels) {
|
||||
// COMPACT MODE: Just name and basic info (4+ nodes)
|
||||
const fontSize = Math.max(7, nodeRadius * 0.11);
|
||||
|
||||
|
||||
// Very compact name below icon
|
||||
const nameY = nodeInfo.y + iconBaseHeight / 2 + 9;
|
||||
const shortName = friendlyName.length > 10
|
||||
? friendlyName.slice(0, 8) + '..'
|
||||
: friendlyName;
|
||||
nodeG.append('text')
|
||||
const shortName = friendlyName.length > 10 ? friendlyName.slice(0, 8) + '..' : friendlyName;
|
||||
nodeG
|
||||
.append('text')
|
||||
.attr('x', nodeInfo.x)
|
||||
.attr('y', nameY)
|
||||
.attr('text-anchor', 'middle')
|
||||
@@ -875,7 +861,8 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
|
||||
// Single line of key stats
|
||||
const statsY = nameY + 9;
|
||||
nodeG.append('text')
|
||||
nodeG
|
||||
.append('text')
|
||||
.attr('x', nodeInfo.x)
|
||||
.attr('y', statsY)
|
||||
.attr('text-anchor', 'middle')
|
||||
@@ -883,17 +870,15 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.attr('font-size', fontSize * 0.85)
|
||||
.attr('font-family', 'SF Mono, Monaco, monospace')
|
||||
.text(`${ramUsagePercent.toFixed(0)}%${!isNaN(gpuTemp) ? ' ' + gpuTemp.toFixed(0) + '°C' : ''}`);
|
||||
|
||||
} else {
|
||||
// MINIMIZED MODE: Show name above and memory info below (like main topology)
|
||||
const fontSize = 8;
|
||||
|
||||
|
||||
// Friendly name (shortened) above icon
|
||||
const nameY = nodeInfo.y - iconBaseHeight / 2 - 8;
|
||||
const shortName = friendlyName.length > 12
|
||||
? friendlyName.slice(0, 10) + '..'
|
||||
: friendlyName;
|
||||
nodeG.append('text')
|
||||
const shortName = friendlyName.length > 12 ? friendlyName.slice(0, 10) + '..' : friendlyName;
|
||||
nodeG
|
||||
.append('text')
|
||||
.attr('x', nodeInfo.x)
|
||||
.attr('y', nameY)
|
||||
.attr('text-anchor', 'middle')
|
||||
@@ -905,24 +890,27 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
|
||||
// Memory info below icon - used in grey, total in yellow (same as main topology)
|
||||
const infoY = nodeInfo.y + iconBaseHeight / 2 + 10;
|
||||
const memTextMini = nodeG.append('text')
|
||||
const memTextMini = nodeG
|
||||
.append('text')
|
||||
.attr('x', nodeInfo.x)
|
||||
.attr('y', infoY)
|
||||
.attr('text-anchor', 'middle')
|
||||
.attr('font-size', fontSize * 0.85)
|
||||
.attr('font-family', 'SF Mono, Monaco, monospace');
|
||||
memTextMini.append('tspan')
|
||||
memTextMini
|
||||
.append('tspan')
|
||||
.attr('fill', 'rgba(255,215,0,0.9)')
|
||||
.text(`${formatBytes(ramUsed)}`);
|
||||
memTextMini.append('tspan')
|
||||
memTextMini
|
||||
.append('tspan')
|
||||
.attr('fill', 'rgba(179,179,179,0.9)')
|
||||
.text(`/${formatBytes(ramTotal)}`);
|
||||
memTextMini.append('tspan')
|
||||
memTextMini
|
||||
.append('tspan')
|
||||
.attr('fill', 'rgba(179,179,179,0.7)')
|
||||
.text(` (${ramUsagePercent.toFixed(0)}%)`);
|
||||
}
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
@@ -945,27 +933,30 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
});
|
||||
</script>
|
||||
|
||||
<svg
|
||||
bind:this={svgContainer}
|
||||
class="w-full h-full {className}"
|
||||
></svg>
|
||||
<svg bind:this={svgContainer} class="w-full h-full {className}"></svg>
|
||||
|
||||
<style>
|
||||
:global(.graph-node) {
|
||||
transition: transform 0.2s ease, opacity 0.2s ease;
|
||||
transition:
|
||||
transform 0.2s ease,
|
||||
opacity 0.2s ease;
|
||||
}
|
||||
:global(.graph-node:hover) {
|
||||
filter: brightness(1.1);
|
||||
}
|
||||
:global(.graph-link) {
|
||||
stroke: var(--exo-light-gray, #B3B3B3);
|
||||
stroke: var(--exo-light-gray, #b3b3b3);
|
||||
stroke-width: 1px;
|
||||
stroke-dasharray: 4, 4;
|
||||
opacity: 0.8;
|
||||
animation: flowAnimation 0.75s linear infinite;
|
||||
}
|
||||
@keyframes flowAnimation {
|
||||
from { stroke-dashoffset: 0; }
|
||||
to { stroke-dashoffset: -10; }
|
||||
from {
|
||||
stroke-dashoffset: 0;
|
||||
}
|
||||
to {
|
||||
stroke-dashoffset: -10;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
export { default as TopologyGraph } from './TopologyGraph.svelte';
|
||||
export { default as ChatForm } from './ChatForm.svelte';
|
||||
export { default as ChatMessages } from './ChatMessages.svelte';
|
||||
export { default as ChatAttachments } from './ChatAttachments.svelte';
|
||||
export { default as ChatSidebar } from './ChatSidebar.svelte';
|
||||
export { default as ModelCard } from './ModelCard.svelte';
|
||||
|
||||
export { default as TopologyGraph } from "./TopologyGraph.svelte";
|
||||
export { default as ChatForm } from "./ChatForm.svelte";
|
||||
export { default as ChatMessages } from "./ChatMessages.svelte";
|
||||
export { default as ChatAttachments } from "./ChatAttachments.svelte";
|
||||
export { default as ChatSidebar } from "./ChatSidebar.svelte";
|
||||
export { default as ModelCard } from "./ModelCard.svelte";
|
||||
export { default as MarkdownContent } from "./MarkdownContent.svelte";
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,55 +13,93 @@ export interface ChatUploadedFile {
|
||||
}
|
||||
|
||||
export interface ChatAttachment {
|
||||
type: 'image' | 'text' | 'pdf' | 'audio';
|
||||
type: "image" | "text" | "pdf" | "audio";
|
||||
name: string;
|
||||
content?: string;
|
||||
base64Url?: string;
|
||||
mimeType?: string;
|
||||
}
|
||||
|
||||
export type FileCategory = 'image' | 'text' | 'pdf' | 'audio' | 'unknown';
|
||||
export type FileCategory = "image" | "text" | "pdf" | "audio" | "unknown";
|
||||
|
||||
export const IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.gif', '.webp', '.svg'];
|
||||
export const IMAGE_MIME_TYPES = ['image/jpeg', 'image/png', 'image/gif', 'image/webp', 'image/svg+xml'];
|
||||
export const IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".svg"];
|
||||
export const IMAGE_MIME_TYPES = ["image/jpeg", "image/png", "image/gif", "image/webp", "image/svg+xml"];
|
||||
|
||||
export const TEXT_EXTENSIONS = [
|
||||
'.txt', '.md', '.json', '.xml', '.yaml', '.yml', '.csv', '.log',
|
||||
'.js', '.ts', '.jsx', '.tsx', '.py', '.java', '.cpp', '.c', '.h',
|
||||
'.css', '.html', '.htm', '.sql', '.sh', '.bat', '.rs', '.go',
|
||||
'.rb', '.php', '.swift', '.kt', '.scala', '.r', '.dart', '.vue', '.svelte'
|
||||
".txt",
|
||||
".md",
|
||||
".json",
|
||||
".xml",
|
||||
".yaml",
|
||||
".yml",
|
||||
".csv",
|
||||
".log",
|
||||
".js",
|
||||
".ts",
|
||||
".jsx",
|
||||
".tsx",
|
||||
".py",
|
||||
".java",
|
||||
".cpp",
|
||||
".c",
|
||||
".h",
|
||||
".css",
|
||||
".html",
|
||||
".htm",
|
||||
".sql",
|
||||
".sh",
|
||||
".bat",
|
||||
".rs",
|
||||
".go",
|
||||
".rb",
|
||||
".php",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".r",
|
||||
".dart",
|
||||
".vue",
|
||||
".svelte"
|
||||
];
|
||||
export const TEXT_MIME_TYPES = [
|
||||
'text/plain', 'text/markdown', 'text/csv', 'text/html', 'text/css',
|
||||
'application/json', 'application/xml', 'text/xml', 'application/javascript',
|
||||
'text/javascript', 'application/typescript'
|
||||
"text/plain",
|
||||
"text/markdown",
|
||||
"text/csv",
|
||||
"text/html",
|
||||
"text/css",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"text/xml",
|
||||
"application/javascript",
|
||||
"text/javascript",
|
||||
"application/typescript"
|
||||
];
|
||||
|
||||
export const PDF_EXTENSIONS = ['.pdf'];
|
||||
export const PDF_MIME_TYPES = ['application/pdf'];
|
||||
export const PDF_EXTENSIONS = [".pdf"];
|
||||
export const PDF_MIME_TYPES = ["application/pdf"];
|
||||
|
||||
export const AUDIO_EXTENSIONS = ['.mp3', '.wav', '.ogg', '.m4a'];
|
||||
export const AUDIO_MIME_TYPES = ['audio/mpeg', 'audio/wav', 'audio/ogg', 'audio/mp4'];
|
||||
export const AUDIO_EXTENSIONS = [".mp3", ".wav", ".ogg", ".m4a"];
|
||||
export const AUDIO_MIME_TYPES = ["audio/mpeg", "audio/wav", "audio/ogg", "audio/mp4"];
|
||||
|
||||
/**
|
||||
* Get file category based on MIME type and extension
|
||||
*/
|
||||
export function getFileCategory(mimeType: string, fileName: string): FileCategory {
|
||||
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf('.'));
|
||||
|
||||
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf("."));
|
||||
|
||||
if (IMAGE_MIME_TYPES.includes(mimeType) || IMAGE_EXTENSIONS.includes(extension)) {
|
||||
return 'image';
|
||||
return "image";
|
||||
}
|
||||
if (PDF_MIME_TYPES.includes(mimeType) || PDF_EXTENSIONS.includes(extension)) {
|
||||
return 'pdf';
|
||||
return "pdf";
|
||||
}
|
||||
if (AUDIO_MIME_TYPES.includes(mimeType) || AUDIO_EXTENSIONS.includes(extension)) {
|
||||
return 'audio';
|
||||
return "audio";
|
||||
}
|
||||
if (TEXT_MIME_TYPES.includes(mimeType) || TEXT_EXTENSIONS.includes(extension) || mimeType.startsWith('text/')) {
|
||||
return 'text';
|
||||
if (TEXT_MIME_TYPES.includes(mimeType) || TEXT_EXTENSIONS.includes(extension) || mimeType.startsWith("text/")) {
|
||||
return "text";
|
||||
}
|
||||
return 'unknown';
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -69,36 +107,36 @@ export function getFileCategory(mimeType: string, fileName: string): FileCategor
|
||||
*/
|
||||
export function getAcceptString(categories: FileCategory[]): string {
|
||||
const accepts: string[] = [];
|
||||
|
||||
|
||||
for (const category of categories) {
|
||||
switch (category) {
|
||||
case 'image':
|
||||
case "image":
|
||||
accepts.push(...IMAGE_EXTENSIONS, ...IMAGE_MIME_TYPES);
|
||||
break;
|
||||
case 'text':
|
||||
case "text":
|
||||
accepts.push(...TEXT_EXTENSIONS, ...TEXT_MIME_TYPES);
|
||||
break;
|
||||
case 'pdf':
|
||||
case "pdf":
|
||||
accepts.push(...PDF_EXTENSIONS, ...PDF_MIME_TYPES);
|
||||
break;
|
||||
case 'audio':
|
||||
case "audio":
|
||||
accepts.push(...AUDIO_EXTENSIONS, ...AUDIO_MIME_TYPES);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return accepts.join(',');
|
||||
|
||||
return accepts.join(",");
|
||||
}
|
||||
|
||||
/**
|
||||
* Format file size for display
|
||||
*/
|
||||
export function formatFileSize(bytes: number): string {
|
||||
if (bytes === 0) return '0 B';
|
||||
if (bytes === 0) return "0 B";
|
||||
const k = 1024;
|
||||
const sizes = ['B', 'KB', 'MB', 'GB'];
|
||||
const sizes = ["B", "KB", "MB", "GB"];
|
||||
const i = Math.floor(Math.log(bytes) / Math.log(k));
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
|
||||
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + " " + sizes[i];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -130,11 +168,11 @@ export function readFileAsText(file: File): Promise<string> {
|
||||
*/
|
||||
export async function processUploadedFiles(files: File[]): Promise<ChatUploadedFile[]> {
|
||||
const results: ChatUploadedFile[] = [];
|
||||
|
||||
|
||||
for (const file of files) {
|
||||
const id = Date.now().toString() + Math.random().toString(36).substring(2, 9);
|
||||
const category = getFileCategory(file.type, file.name);
|
||||
|
||||
|
||||
const base: ChatUploadedFile = {
|
||||
id,
|
||||
name: file.name,
|
||||
@@ -142,28 +180,27 @@ export async function processUploadedFiles(files: File[]): Promise<ChatUploadedF
|
||||
type: file.type,
|
||||
file
|
||||
};
|
||||
|
||||
|
||||
try {
|
||||
if (category === 'image') {
|
||||
if (category === "image") {
|
||||
const preview = await readFileAsDataURL(file);
|
||||
results.push({ ...base, preview });
|
||||
} else if (category === 'text' || category === 'unknown') {
|
||||
} else if (category === "text" || category === "unknown") {
|
||||
const textContent = await readFileAsText(file);
|
||||
results.push({ ...base, textContent });
|
||||
} else if (category === 'pdf') {
|
||||
} else if (category === "pdf") {
|
||||
results.push(base);
|
||||
} else if (category === 'audio') {
|
||||
} else if (category === "audio") {
|
||||
const preview = await readFileAsDataURL(file);
|
||||
results.push({ ...base, preview });
|
||||
} else {
|
||||
results.push(base);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error processing file:', file.name, error);
|
||||
console.error("Error processing file:", file.name, error);
|
||||
results.push(base);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
<script lang="ts">
|
||||
import '../app.css';
|
||||
|
||||
|
||||
let { children } = $props();
|
||||
</script>
|
||||
|
||||
@@ -12,4 +12,3 @@
|
||||
<div class="min-h-screen bg-background text-foreground">
|
||||
{@render children?.()}
|
||||
</div>
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,12 +1,6 @@
|
||||
<script lang="ts">
|
||||
import { onMount } from 'svelte';
|
||||
import {
|
||||
topologyData,
|
||||
downloads,
|
||||
type DownloadProgress,
|
||||
refreshState,
|
||||
lastUpdate as lastUpdateStore
|
||||
} from '$lib/stores/app.svelte';
|
||||
import { topologyData, downloads, type DownloadProgress, refreshState, lastUpdate as lastUpdateStore } from '$lib/stores/app.svelte';
|
||||
import HeaderNav from '$lib/components/HeaderNav.svelte';
|
||||
|
||||
type FileProgress = {
|
||||
@@ -166,11 +160,7 @@
|
||||
|
||||
for (const [nodeId, nodeDownloads] of entries) {
|
||||
const modelMap = new Map<string, ModelEntry>();
|
||||
const nodeEntries = Array.isArray(nodeDownloads)
|
||||
? nodeDownloads
|
||||
: nodeDownloads && typeof nodeDownloads === 'object'
|
||||
? Object.values(nodeDownloads as Record<string, unknown>)
|
||||
: [];
|
||||
const nodeEntries = Array.isArray(nodeDownloads) ? nodeDownloads : nodeDownloads && typeof nodeDownloads === 'object' ? Object.values(nodeDownloads as Record<string, unknown>) : [];
|
||||
|
||||
for (const downloadWrapped of nodeEntries) {
|
||||
if (!downloadWrapped || typeof downloadWrapped !== 'object') continue;
|
||||
@@ -196,13 +186,17 @@
|
||||
return (meta.prettyName as string) ?? null;
|
||||
})();
|
||||
|
||||
const rawProgress = (downloadPayload as Record<string, unknown>).download_progress
|
||||
?? (downloadPayload as Record<string, unknown>).downloadProgress
|
||||
?? {};
|
||||
const totalBytes = getBytes((rawProgress as Record<string, unknown>).total_bytes ?? (rawProgress as Record<string, unknown>).totalBytes);
|
||||
const rawProgress = (downloadPayload as Record<string, unknown>).download_progress ?? (downloadPayload as Record<string, unknown>).downloadProgress ?? {};
|
||||
// For DownloadCompleted, total_bytes is at top level; for DownloadOngoing, it's inside download_progress
|
||||
const totalBytes = getBytes(
|
||||
(downloadPayload as Record<string, unknown>).total_bytes ??
|
||||
(downloadPayload as Record<string, unknown>).totalBytes ??
|
||||
(rawProgress as Record<string, unknown>).total_bytes ??
|
||||
(rawProgress as Record<string, unknown>).totalBytes
|
||||
);
|
||||
const downloadedBytes = getBytes((rawProgress as Record<string, unknown>).downloaded_bytes ?? (rawProgress as Record<string, unknown>).downloadedBytes);
|
||||
const speed = (rawProgress as Record<string, unknown>).speed as number ?? 0;
|
||||
const etaMs = (rawProgress as Record<string, unknown>).eta_ms as number ?? (rawProgress as Record<string, unknown>).etaMs as number ?? 0;
|
||||
const speed = ((rawProgress as Record<string, unknown>).speed as number) ?? 0;
|
||||
const etaMs = ((rawProgress as Record<string, unknown>).eta_ms as number) ?? ((rawProgress as Record<string, unknown>).etaMs as number) ?? 0;
|
||||
const percentage = totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0;
|
||||
|
||||
const files: FileProgress[] = [];
|
||||
@@ -239,26 +233,25 @@
|
||||
const existing = modelMap.get(modelId);
|
||||
if (!existing) {
|
||||
modelMap.set(modelId, entry);
|
||||
} else if (
|
||||
(entry.status === 'completed' && existing.status !== 'completed') ||
|
||||
(entry.status === existing.status && entry.downloadedBytes > existing.downloadedBytes)
|
||||
) {
|
||||
} else if ((entry.status === 'completed' && existing.status !== 'completed') || (entry.status === existing.status && entry.downloadedBytes > existing.downloadedBytes)) {
|
||||
modelMap.set(modelId, entry);
|
||||
}
|
||||
}
|
||||
|
||||
let models = Array.from(modelMap.values()).sort((a, b) => b.percentage - a.percentage);
|
||||
if (models.length === 0 && nodeEntries.length > 0) {
|
||||
models = [{
|
||||
modelId: 'Unknown download',
|
||||
percentage: 0,
|
||||
downloadedBytes: 0,
|
||||
totalBytes: 0,
|
||||
speed: 0,
|
||||
etaMs: 0,
|
||||
status: 'downloading',
|
||||
files: []
|
||||
}];
|
||||
models = [
|
||||
{
|
||||
modelId: 'Unknown download',
|
||||
percentage: 0,
|
||||
downloadedBytes: 0,
|
||||
totalBytes: 0,
|
||||
speed: 0,
|
||||
etaMs: 0,
|
||||
status: 'downloading',
|
||||
files: []
|
||||
}
|
||||
];
|
||||
}
|
||||
|
||||
built.push({
|
||||
@@ -332,8 +325,13 @@
|
||||
<div class="text-lg font-mono text-white truncate">{node.nodeName}</div>
|
||||
<div class="text-xs text-exo-light-gray font-mono truncate">{node.nodeId}</div>
|
||||
</div>
|
||||
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0">
|
||||
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> /{node.models.length} models</span>
|
||||
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right">
|
||||
<div>
|
||||
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> / {node.models.length} models</span>
|
||||
</div>
|
||||
<div class="text-exo-light-gray normal-case tracking-normal">
|
||||
{formatBytes(node.models.filter(m => m.status === 'completed').reduce((sum, m) => sum + m.totalBytes, 0))} on disk
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -345,25 +343,19 @@
|
||||
<div class="rounded border border-exo-medium-gray/30 bg-exo-dark-gray/60 p-3 space-y-2">
|
||||
<div class="flex items-center justify-between gap-3">
|
||||
<div class="min-w-0 space-y-0.5">
|
||||
<div class="text-sm font-mono text-white truncate">{model.prettyName ?? model.modelId}</div>
|
||||
<div class="text-[11px] text-exo-light-gray font-mono truncate">
|
||||
{model.modelId}
|
||||
</div>
|
||||
<div class="text-[11px] text-exo-light-gray font-mono">
|
||||
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
|
||||
</div>
|
||||
<div class="text-xs font-mono text-white truncate" title={model.prettyName ?? model.modelId}>{model.prettyName ?? model.modelId}</div>
|
||||
<div class="text-[10px] text-exo-light-gray font-mono truncate" title={model.modelId}>{model.modelId}</div>
|
||||
{#if model.status !== 'completed'}
|
||||
<div class="text-[11px] text-exo-light-gray font-mono">
|
||||
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="text-xs font-mono {pct >= 100 ? 'text-green-400' : pct <= 0 ? 'text-red-400' : 'text-exo-yellow'}">
|
||||
{pct.toFixed(1)}%
|
||||
</span>
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
|
||||
onclick={() => toggleExpand(key)}
|
||||
aria-expanded={isExpanded}
|
||||
title="Toggle file details"
|
||||
>
|
||||
<button type="button" class="text-exo-light-gray hover:text-exo-yellow transition-colors" onclick={() => toggleExpand(key)} aria-expanded={isExpanded} title="Toggle file details">
|
||||
<svg class="w-4 h-4" viewBox="0 0 20 20" fill="none" stroke="currentColor" stroke-width="2">
|
||||
<path d="M6 8l4 4 4-4" class={isExpanded ? 'transform rotate-180 origin-center transition-transform duration-150' : 'transition-transform duration-150'}></path>
|
||||
</svg>
|
||||
@@ -372,14 +364,11 @@
|
||||
</div>
|
||||
|
||||
<div class="relative h-2 bg-exo-black/60 rounded-sm overflow-hidden">
|
||||
<div
|
||||
class={`absolute inset-y-0 left-0 bg-gradient-to-r ${gradient} transition-all duration-300`}
|
||||
style={`width: ${pct.toFixed(1)}%`}
|
||||
></div>
|
||||
<div class={`absolute inset-y-0 left-0 bg-gradient-to-r ${gradient} transition-all duration-300`} style={`width: ${pct.toFixed(1)}%`}></div>
|
||||
</div>
|
||||
|
||||
<div class="flex items-center justify-between text-xs font-mono text-exo-light-gray">
|
||||
<span>{model.status === 'completed' ? 'Completed' : `${formatSpeed(model.speed)} • ETA ${formatEta(model.etaMs)}`}</span>
|
||||
<span>{model.status === 'completed' ? `Completed (${formatBytes(model.totalBytes)})` : `${formatSpeed(model.speed)} • ETA ${formatEta(model.etaMs)}`}</span>
|
||||
{#if model.status !== 'completed'}
|
||||
<span>{model.files.length} file{model.files.length === 1 ? '' : 's'}</span>
|
||||
{/if}
|
||||
@@ -396,13 +385,10 @@
|
||||
<div class="rounded border border-exo-medium-gray/20 bg-exo-black/40 p-2 space-y-1">
|
||||
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray/90">
|
||||
<span class="truncate pr-2">{f.name}</span>
|
||||
<span class="{fpct >= 100 ? 'text-green-400' : fpct <= 0 ? 'text-red-400' : 'text-exo-yellow'}">{fpct.toFixed(1)}%</span>
|
||||
<span class={fpct >= 100 ? 'text-green-400' : fpct <= 0 ? 'text-red-400' : 'text-exo-yellow'}>{fpct.toFixed(1)}%</span>
|
||||
</div>
|
||||
<div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
|
||||
<div
|
||||
class={`absolute inset-y-0 left-0 bg-gradient-to-r ${fgradient} transition-all duration-300`}
|
||||
style={`width: ${fpct.toFixed(1)}%`}
|
||||
></div>
|
||||
<div class={`absolute inset-y-0 left-0 bg-gradient-to-r ${fgradient} transition-all duration-300`} style={`width: ${fpct.toFixed(1)}%`}></div>
|
||||
</div>
|
||||
<div class="flex items-center justify-between text-[10px] text-exo-light-gray/70">
|
||||
<span>{formatBytes(f.downloadedBytes)} / {formatBytes(f.totalBytes)}</span>
|
||||
@@ -419,21 +405,20 @@
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.downloads-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fill, minmax(260px, 1fr));
|
||||
grid-template-columns: repeat(auto-fill, minmax(320px, 1fr));
|
||||
}
|
||||
@media (min-width: 1024px) {
|
||||
.downloads-grid {
|
||||
grid-template-columns: repeat(3, minmax(0, 1fr));
|
||||
}
|
||||
}
|
||||
@media (min-width: 1440px) {
|
||||
@media (min-width: 1600px) {
|
||||
.downloads-grid {
|
||||
grid-template-columns: repeat(4, minmax(0, 1fr));
|
||||
}
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import tailwindcss from '@tailwindcss/vite';
|
||||
import { sveltekit } from '@sveltejs/kit/vite';
|
||||
import { defineConfig } from 'vite';
|
||||
import tailwindcss from "@tailwindcss/vite";
|
||||
import { sveltekit } from "@sveltejs/kit/vite";
|
||||
import { defineConfig } from "vite";
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [tailwindcss(), sveltekit()],
|
||||
server: {
|
||||
proxy: {
|
||||
'/v1': 'http://localhost:52415',
|
||||
'/state': 'http://localhost:52415',
|
||||
'/models': 'http://localhost:52415',
|
||||
'/instance': 'http://localhost:52415'
|
||||
"/v1": "http://localhost:52415",
|
||||
"/state": "http://localhost:52415",
|
||||
"/models": "http://localhost:52415",
|
||||
"/instance": "http://localhost:52415"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
212
docs/api.md
Normal file
212
docs/api.md
Normal file
@@ -0,0 +1,212 @@
|
||||
# EXO API – Technical Reference
|
||||
|
||||
This document describes the REST API exposed by the **EXO ** service, as implemented in:
|
||||
|
||||
`src/exo/master/api.py`
|
||||
|
||||
The API is used to manage model instances in the cluster, inspect cluster state, and perform inference using an OpenAI-compatible interface.
|
||||
|
||||
Base URL example:
|
||||
|
||||
```
|
||||
http://localhost:52415
|
||||
```
|
||||
|
||||
## 1. General / Meta Endpoints
|
||||
|
||||
### Get Master Node ID
|
||||
|
||||
**GET** `/node_id`
|
||||
|
||||
Returns the identifier of the current master node.
|
||||
|
||||
**Response (example):**
|
||||
|
||||
```json
|
||||
{
|
||||
"node_id": "node-1234"
|
||||
}
|
||||
```
|
||||
|
||||
### Get Cluster State
|
||||
|
||||
**GET** `/state`
|
||||
|
||||
Returns the current state of the cluster, including nodes and active instances.
|
||||
|
||||
**Response:**
|
||||
JSON object describing topology, nodes, and instances.
|
||||
|
||||
### Get Events
|
||||
|
||||
**GET** `/events`
|
||||
|
||||
Returns the list of internal events recorded by the master (mainly for debugging and observability).
|
||||
|
||||
**Response:**
|
||||
Array of event objects.
|
||||
|
||||
## 2. Model Instance Management
|
||||
|
||||
### Create Instance
|
||||
|
||||
**POST** `/instance`
|
||||
|
||||
Creates a new model instance in the cluster.
|
||||
|
||||
**Request body (example):**
|
||||
|
||||
```json
|
||||
{
|
||||
"instance": {
|
||||
"model_id": "llama-3.2-1b",
|
||||
"placement": { }
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response:**
|
||||
JSON description of the created instance.
|
||||
|
||||
### Delete Instance
|
||||
|
||||
**DELETE** `/instance/{instance_id}`
|
||||
|
||||
Deletes an existing instance by ID.
|
||||
|
||||
**Path parameters:**
|
||||
|
||||
* `instance_id`: string, ID of the instance to delete
|
||||
|
||||
**Response:**
|
||||
Status / confirmation JSON.
|
||||
|
||||
### Get Instance
|
||||
|
||||
**GET** `/instance/{instance_id}`
|
||||
|
||||
Returns details of a specific instance.
|
||||
|
||||
**Path parameters:**
|
||||
|
||||
* `instance_id`: string
|
||||
|
||||
**Response:**
|
||||
JSON description of the instance.
|
||||
|
||||
### Preview Placements
|
||||
|
||||
**GET** `/instance/previews?model_id=...`
|
||||
|
||||
Returns possible placement previews for a given model.
|
||||
|
||||
**Query parameters:**
|
||||
|
||||
* `model_id`: string, required
|
||||
|
||||
**Response:**
|
||||
Array of placement preview objects.
|
||||
|
||||
### Compute Placement
|
||||
|
||||
**GET** `/instance/placement`
|
||||
|
||||
Computes a placement for a potential instance without creating it.
|
||||
|
||||
**Query parameters (typical):**
|
||||
|
||||
* `model_id`: string
|
||||
* `sharding`: string or config
|
||||
* `instance_meta`: JSON-encoded metadata
|
||||
* `min_nodes`: integer
|
||||
|
||||
**Response:**
|
||||
JSON object describing the proposed placement / instance configuration.
|
||||
|
||||
### Place Instance (Dry Operation)
|
||||
|
||||
**POST** `/place_instance`
|
||||
|
||||
Performs a placement operation for an instance (planning step), without necessarily creating it.
|
||||
|
||||
**Request body:**
|
||||
JSON describing the instance to be placed.
|
||||
|
||||
**Response:**
|
||||
Placement result.
|
||||
|
||||
## 3. Models
|
||||
|
||||
### List Models
|
||||
|
||||
**GET** `/models`
|
||||
**GET** `/v1/models` (alias)
|
||||
|
||||
Returns the list of available models and their metadata.
|
||||
|
||||
**Response:**
|
||||
Array of model descriptors.
|
||||
|
||||
## 4. Inference / Chat Completions
|
||||
|
||||
### OpenAI-Compatible Chat Completions
|
||||
|
||||
**POST** `/v1/chat/completions`
|
||||
|
||||
Executes a chat completion request using an OpenAI-compatible schema. Supports streaming and non-streaming modes.
|
||||
|
||||
**Request body (example):**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama-3.2-1b",
|
||||
"messages": [
|
||||
{ "role": "system", "content": "You are a helpful assistant." },
|
||||
{ "role": "user", "content": "Hello" }
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**Response:**
|
||||
OpenAI-compatible chat completion response.
|
||||
|
||||
### Benchmarked Chat Completions
|
||||
|
||||
**POST** `/bench/chat/completions`
|
||||
|
||||
Same as `/v1/chat/completions`, but also returns performance and generation statistics.
|
||||
|
||||
**Request body:**
|
||||
Same schema as `/v1/chat/completions`.
|
||||
|
||||
**Response:**
|
||||
Chat completion plus benchmarking metrics.
|
||||
|
||||
## 5. Complete Endpoint Summary
|
||||
|
||||
```
|
||||
GET /node_id
|
||||
GET /state
|
||||
GET /events
|
||||
|
||||
POST /instance
|
||||
GET /instance/{instance_id}
|
||||
DELETE /instance/{instance_id}
|
||||
|
||||
GET /instance/previews
|
||||
GET /instance/placement
|
||||
POST /place_instance
|
||||
|
||||
GET /models
|
||||
GET /v1/models
|
||||
|
||||
POST /v1/chat/completions
|
||||
POST /bench/chat/completions
|
||||
```
|
||||
|
||||
## 6. Notes
|
||||
|
||||
* The `/v1/chat/completions` endpoint is compatible with the OpenAI API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
|
||||
* The instance placement endpoints allow you to plan and preview cluster allocations before actually creating instances.
|
||||
* The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.
|
||||
185
flake.lock
generated
185
flake.lock
generated
@@ -1,5 +1,42 @@
|
||||
{
|
||||
"nodes": {
|
||||
"crane": {
|
||||
"locked": {
|
||||
"lastModified": 1767744144,
|
||||
"narHash": "sha256-9/9ntI0D+HbN4G0TrK3KmHbTvwgswz7p8IEJsWyef8Q=",
|
||||
"owner": "ipetkov",
|
||||
"repo": "crane",
|
||||
"rev": "2fb033290bf6b23f226d4c8b32f7f7a16b043d7e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "ipetkov",
|
||||
"repo": "crane",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"dream2nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"purescript-overlay": "purescript-overlay",
|
||||
"pyproject-nix": "pyproject-nix"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1765953015,
|
||||
"narHash": "sha256-5FBZbbWR1Csp3Y2icfRkxMJw/a/5FGg8hCXej2//bbI=",
|
||||
"owner": "nix-community",
|
||||
"repo": "dream2nix",
|
||||
"rev": "69eb01fa0995e1e90add49d8ca5bcba213b0416f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"repo": "dream2nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"fenix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
@@ -8,11 +45,11 @@
|
||||
"rust-analyzer-src": "rust-analyzer-src"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1761893049,
|
||||
"narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
|
||||
"lastModified": 1768287139,
|
||||
"narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
|
||||
"rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -21,25 +58,59 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"flake-compat": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1731533236,
|
||||
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||
"lastModified": 1696426674,
|
||||
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-parts": {
|
||||
"inputs": {
|
||||
"nixpkgs-lib": [
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1768135262,
|
||||
"narHash": "sha256-PVvu7OqHBGWN16zSi6tEmPwwHQ4rLPU9Plvs8/1TUBY=",
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"rev": "80daad04eddbbf5a4d883996a73f3f542fa437ac",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1768127708,
|
||||
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs-swift": {
|
||||
"locked": {
|
||||
"lastModified": 1761672384,
|
||||
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
|
||||
@@ -50,27 +121,74 @@
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"purescript-overlay": {
|
||||
"inputs": {
|
||||
"flake-compat": "flake-compat",
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"nixpkgs"
|
||||
],
|
||||
"slimlock": "slimlock"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1728546539,
|
||||
"narHash": "sha256-Sws7w0tlnjD+Bjck1nv29NjC5DbL6nH5auL9Ex9Iz2A=",
|
||||
"owner": "thomashoneyman",
|
||||
"repo": "purescript-overlay",
|
||||
"rev": "4ad4c15d07bd899d7346b331f377606631eb0ee4",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "thomashoneyman",
|
||||
"repo": "purescript-overlay",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763017646,
|
||||
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"crane": "crane",
|
||||
"dream2nix": "dream2nix",
|
||||
"fenix": "fenix",
|
||||
"flake-utils": "flake-utils",
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-swift": "nixpkgs-swift",
|
||||
"treefmt-nix": "treefmt-nix"
|
||||
}
|
||||
},
|
||||
"rust-analyzer-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1761849405,
|
||||
"narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
|
||||
"lastModified": 1768224240,
|
||||
"narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
|
||||
"owner": "rust-lang",
|
||||
"repo": "rust-analyzer",
|
||||
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
|
||||
"rev": "725349602e525df37f377701e001fe8aab807878",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -80,18 +198,25 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"slimlock": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"purescript-overlay",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"lastModified": 1688756706,
|
||||
"narHash": "sha256-xzkkMv3neJJJ89zo3o2ojp7nFeaZc2G0fYwNXNJRFlo=",
|
||||
"owner": "thomashoneyman",
|
||||
"repo": "slimlock",
|
||||
"rev": "cf72723f59e2340d24881fd7bf61cb113b4c407c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"owner": "thomashoneyman",
|
||||
"repo": "slimlock",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
@@ -102,11 +227,11 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1762938485,
|
||||
"narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
|
||||
"lastModified": 1768158989,
|
||||
"narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
|
||||
"owner": "numtide",
|
||||
"repo": "treefmt-nix",
|
||||
"rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
|
||||
"rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
210
flake.nix
210
flake.nix
@@ -3,118 +3,148 @@
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
# Provides Rust dev-env integration:
|
||||
|
||||
flake-parts = {
|
||||
url = "github:hercules-ci/flake-parts";
|
||||
inputs.nixpkgs-lib.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
crane.url = "github:ipetkov/crane";
|
||||
|
||||
fenix = {
|
||||
url = "github:nix-community/fenix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
# Provides formatting infrastructure:
|
||||
|
||||
treefmt-nix = {
|
||||
url = "github:numtide/treefmt-nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
dream2nix = {
|
||||
url = "github:nix-community/dream2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
|
||||
};
|
||||
|
||||
# TODO: figure out caching story
|
||||
# nixConfig = {
|
||||
# # nix community cachix
|
||||
# extra-trusted-public-keys = "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=";
|
||||
# extra-substituters = "https://nix-community.cachix.org";
|
||||
# };
|
||||
nixConfig = {
|
||||
extra-trusted-public-keys = "exo.cachix.org-1:okq7hl624TBeAR3kV+g39dUFSiaZgLRkLsFBCuJ2NZI=";
|
||||
extra-substituters = "https://exo.cachix.org";
|
||||
};
|
||||
|
||||
outputs =
|
||||
inputs:
|
||||
let
|
||||
inputs.flake-parts.lib.mkFlake { inherit inputs; } {
|
||||
systems = [
|
||||
"x86_64-linux"
|
||||
"aarch64-darwin"
|
||||
"aarch64-linux"
|
||||
];
|
||||
fenixToolchain = system: inputs.fenix.packages.${system}.complete;
|
||||
in
|
||||
inputs.flake-utils.lib.eachSystem systems (
|
||||
system:
|
||||
let
|
||||
pkgs = import inputs.nixpkgs {
|
||||
inherit system;
|
||||
overlays = [ inputs.fenix.overlays.default ];
|
||||
};
|
||||
treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
|
||||
projectRootFile = "flake.nix";
|
||||
programs.ruff-format.enable = true;
|
||||
programs.ruff-format.excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
|
||||
programs.rustfmt.enable = true;
|
||||
programs.rustfmt.package = (fenixToolchain system).rustfmt;
|
||||
programs.nixpkgs-fmt.enable = true;
|
||||
};
|
||||
in
|
||||
{
|
||||
formatter = treefmtEval.config.build.wrapper;
|
||||
checks.formatting = treefmtEval.config.build.check inputs.self;
|
||||
checks.lint = pkgs.runCommand "lint-check" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
|
||||
devShells.default = pkgs.mkShell {
|
||||
packages =
|
||||
with pkgs;
|
||||
[
|
||||
# PYTHON
|
||||
python313
|
||||
uv
|
||||
ruff
|
||||
basedpyright
|
||||
imports = [
|
||||
inputs.treefmt-nix.flakeModule
|
||||
./dashboard/parts.nix
|
||||
./rust/parts.nix
|
||||
];
|
||||
|
||||
# RUST
|
||||
((fenixToolchain system).withComponents [
|
||||
"cargo"
|
||||
"rustc"
|
||||
"clippy"
|
||||
"rustfmt"
|
||||
"rust-src"
|
||||
])
|
||||
rustup # Just here to make RustRover happy
|
||||
perSystem =
|
||||
{ config, self', inputs', pkgs, lib, system, ... }:
|
||||
let
|
||||
fenixToolchain = inputs'.fenix.packages.complete;
|
||||
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
|
||||
in
|
||||
{
|
||||
treefmt = {
|
||||
projectRootFile = "flake.nix";
|
||||
programs = {
|
||||
nixpkgs-fmt.enable = true;
|
||||
ruff-format = {
|
||||
enable = true;
|
||||
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
|
||||
};
|
||||
rustfmt = {
|
||||
enable = true;
|
||||
package = config.rust.toolchain;
|
||||
};
|
||||
prettier = {
|
||||
enable = true;
|
||||
package = self'.packages.prettier-svelte;
|
||||
includes = [ "*.ts" "*.svelte" ];
|
||||
settings = {
|
||||
useTabs = true;
|
||||
singleQuote = false;
|
||||
printWidth = 200;
|
||||
arrowParens = "avoid";
|
||||
trailingComma = "none";
|
||||
overrides = [
|
||||
{
|
||||
files = "*.svelte";
|
||||
options.singleQuote = true;
|
||||
}
|
||||
];
|
||||
};
|
||||
};
|
||||
swift-format = {
|
||||
enable = true;
|
||||
package = pkgsSwift.swiftPackages.swift-format;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
# NIX
|
||||
nixpkgs-fmt
|
||||
|
||||
# SVELTE
|
||||
nodejs
|
||||
|
||||
# MISC
|
||||
just
|
||||
jq
|
||||
]
|
||||
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
|
||||
# IFCONFIG
|
||||
unixtools.ifconfig
|
||||
|
||||
# Build dependencies for Linux
|
||||
pkg-config
|
||||
openssl
|
||||
])
|
||||
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
|
||||
# MACMON
|
||||
macmon
|
||||
]);
|
||||
|
||||
shellHook = ''
|
||||
# PYTHON
|
||||
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
|
||||
${pkgs.lib.optionalString pkgs.stdenv.isLinux ''
|
||||
# Build environment for Linux
|
||||
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
|
||||
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
|
||||
''}
|
||||
echo
|
||||
echo "🍎🍎 Run 'just <recipe>' to get started"
|
||||
just --list
|
||||
checks.lint = pkgs.runCommand "lint-check" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
|
||||
devShells.default = with pkgs; pkgs.mkShell {
|
||||
inputsFrom = [ self'.checks.cargo-build ];
|
||||
|
||||
packages =
|
||||
[
|
||||
# FORMATTING
|
||||
config.treefmt.build.wrapper
|
||||
|
||||
# PYTHON
|
||||
python313
|
||||
uv
|
||||
ruff
|
||||
basedpyright
|
||||
|
||||
# RUST
|
||||
config.rust.toolchain
|
||||
maturin
|
||||
|
||||
# NIX
|
||||
nixpkgs-fmt
|
||||
|
||||
# SVELTE
|
||||
nodejs
|
||||
|
||||
# MISC
|
||||
just
|
||||
jq
|
||||
]
|
||||
++ lib.optionals stdenv.isLinux [
|
||||
unixtools.ifconfig
|
||||
]
|
||||
++ lib.optionals stdenv.isDarwin [
|
||||
macmon
|
||||
];
|
||||
|
||||
OPENSSL_NO_VENDOR = "1";
|
||||
|
||||
shellHook = ''
|
||||
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${python313}/lib"
|
||||
${lib.optionalString stdenv.isLinux ''
|
||||
export LD_LIBRARY_PATH="${openssl.out}/lib:$LD_LIBRARY_PATH"
|
||||
''}
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -8,31 +8,21 @@ dependencies = [
|
||||
"aiofiles>=24.1.0",
|
||||
"aiohttp>=3.12.14",
|
||||
"types-aiofiles>=24.1.0.20250708",
|
||||
"typeguard>=4.4.4",
|
||||
"pydantic>=2.11.7",
|
||||
"base58>=2.1.1",
|
||||
"cryptography>=45.0.5",
|
||||
"fastapi>=0.116.1",
|
||||
"filelock>=3.18.0",
|
||||
"aiosqlite>=0.21.0",
|
||||
"networkx>=3.5",
|
||||
"protobuf>=6.32.0",
|
||||
"rich>=14.1.0",
|
||||
"rustworkx>=0.17.1",
|
||||
"sqlmodel>=0.0.24",
|
||||
"sqlalchemy[asyncio]>=2.0.43",
|
||||
"greenlet>=3.2.4",
|
||||
"huggingface-hub>=0.33.4",
|
||||
"psutil>=7.0.0",
|
||||
"loguru>=0.7.3",
|
||||
"textual>=5.3.0",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"bidict>=0.23.1",
|
||||
"mlx>=0.30.1",
|
||||
"mlx-lm>=0.28.3",
|
||||
"mlx==0.30.1; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.1; sys_platform == 'linux'",
|
||||
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -43,6 +33,7 @@ exo = "exo.main:main"
|
||||
# dependencies only required for development
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"basedpyright>=1.29.0",
|
||||
"pyinstaller>=6.17.0",
|
||||
"pytest>=8.4.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
@@ -80,7 +71,7 @@ build-backend = "uv_build"
|
||||
###
|
||||
|
||||
[tool.basedpyright]
|
||||
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src"]
|
||||
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src", "bench"]
|
||||
typeCheckingMode = "strict"
|
||||
failOnWarnings = true
|
||||
|
||||
@@ -108,6 +99,7 @@ root = "src"
|
||||
|
||||
# supported platforms for this project
|
||||
[tool.uv]
|
||||
prerelease = "allow"
|
||||
environments = [
|
||||
"sys_platform == 'darwin'",
|
||||
"sys_platform == 'linux'",
|
||||
|
||||
145
rust/parts.nix
Normal file
145
rust/parts.nix
Normal file
@@ -0,0 +1,145 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ config, self', inputs', pkgs, lib, ... }:
|
||||
let
|
||||
# Fenix nightly toolchain with all components
|
||||
fenixPkgs = inputs'.fenix.packages;
|
||||
rustToolchain = fenixPkgs.complete.withComponents [
|
||||
"cargo"
|
||||
"rustc"
|
||||
"clippy"
|
||||
"rustfmt"
|
||||
"rust-src"
|
||||
"rust-analyzer"
|
||||
];
|
||||
|
||||
# Crane with fenix toolchain
|
||||
craneLib = (inputs.crane.mkLib pkgs).overrideToolchain rustToolchain;
|
||||
|
||||
# Source filtering - only include rust/ directory and root Cargo files
|
||||
# This ensures changes to Python/docs/etc don't trigger Rust rebuilds
|
||||
src = lib.cleanSourceWith {
|
||||
src = inputs.self;
|
||||
filter =
|
||||
path: type:
|
||||
let
|
||||
baseName = builtins.baseNameOf path;
|
||||
parentDir = builtins.dirOf path;
|
||||
inRustDir =
|
||||
(lib.hasInfix "/rust/" path)
|
||||
|| (lib.hasSuffix "/rust" parentDir)
|
||||
|| (baseName == "rust" && type == "directory");
|
||||
isRootCargoFile =
|
||||
(baseName == "Cargo.toml" || baseName == "Cargo.lock")
|
||||
&& (builtins.dirOf path == toString inputs.self);
|
||||
in
|
||||
isRootCargoFile
|
||||
|| (inRustDir && (craneLib.filterCargoSources path type || lib.hasSuffix ".toml" path || lib.hasSuffix ".md" path));
|
||||
};
|
||||
|
||||
# Common arguments for all Rust builds
|
||||
commonArgs = {
|
||||
inherit src;
|
||||
pname = "exo-rust";
|
||||
version = "0.0.1";
|
||||
strictDeps = true;
|
||||
|
||||
nativeBuildInputs = [
|
||||
pkgs.pkg-config
|
||||
pkgs.python313 # Required for pyo3-build-config
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
pkgs.openssl
|
||||
pkgs.python313 # Required for pyo3 tests
|
||||
];
|
||||
|
||||
OPENSSL_NO_VENDOR = "1";
|
||||
|
||||
# Required for pyo3 tests to find libpython
|
||||
LD_LIBRARY_PATH = lib.makeLibraryPath [ pkgs.python313 ];
|
||||
};
|
||||
|
||||
# Build dependencies once for caching
|
||||
cargoArtifacts = craneLib.buildDepsOnly (
|
||||
commonArgs
|
||||
// {
|
||||
cargoExtraArgs = "--workspace";
|
||||
}
|
||||
);
|
||||
in
|
||||
{
|
||||
# Export toolchain for use in treefmt and devShell
|
||||
options.rust = {
|
||||
toolchain = lib.mkOption {
|
||||
type = lib.types.package;
|
||||
default = rustToolchain;
|
||||
description = "The Rust toolchain to use";
|
||||
};
|
||||
};
|
||||
|
||||
config = {
|
||||
packages = {
|
||||
# Python bindings wheel via maturin
|
||||
exo_pyo3_bindings = craneLib.buildPackage (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
pname = "exo_pyo3_bindings";
|
||||
|
||||
nativeBuildInputs = commonArgs.nativeBuildInputs ++ [
|
||||
pkgs.maturin
|
||||
];
|
||||
|
||||
buildPhaseCargoCommand = ''
|
||||
maturin build \
|
||||
--release \
|
||||
--manylinux off \
|
||||
--manifest-path rust/exo_pyo3_bindings/Cargo.toml \
|
||||
--features "pyo3/extension-module,pyo3/experimental-async" \
|
||||
--interpreter ${pkgs.python313}/bin/python \
|
||||
--out dist
|
||||
'';
|
||||
|
||||
# Don't use crane's default install behavior
|
||||
doNotPostBuildInstallCargoBinaries = true;
|
||||
|
||||
installPhaseCommand = ''
|
||||
mkdir -p $out
|
||||
cp dist/*.whl $out/
|
||||
'';
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Full workspace build (all crates)
|
||||
cargo-build = craneLib.buildPackage (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoExtraArgs = "--workspace";
|
||||
}
|
||||
);
|
||||
# Run tests with nextest
|
||||
cargo-nextest = craneLib.cargoNextest (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoExtraArgs = "--workspace";
|
||||
}
|
||||
);
|
||||
|
||||
# Build documentation
|
||||
cargo-doc = craneLib.cargoDoc (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoExtraArgs = "--workspace";
|
||||
}
|
||||
);
|
||||
};
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
[package]
|
||||
name = "system_custodian"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "system_custodian"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
path = "src/bin/main.rs"
|
||||
name = "system_custodian"
|
||||
doc = false
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# datastructures
|
||||
either = { workspace = true }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
# tracing/logging
|
||||
log = { workspace = true }
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
//! TODO: documentation
|
||||
//!
|
||||
|
||||
fn main() {}
|
||||
@@ -1,69 +0,0 @@
|
||||
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
|
||||
//!
|
||||
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
|
||||
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
|
||||
//! etc.) is in an appropriate state to facilitate the running of Exo application.
|
||||
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
|
||||
//! service which Exo application use to _control & query_ it.
|
||||
//!
|
||||
//! # Lifecycle
|
||||
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
|
||||
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
|
||||
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
|
||||
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
|
||||
//! destructive to the user's pre-existing configurations.
|
||||
//!
|
||||
//! # Responsibilities
|
||||
//! TODO: these are purely on MacOS, but change to be more broad
|
||||
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
|
||||
//! 1. duplicate the current network set
|
||||
//! 2. modify existing services to turn on IPv6 if not there
|
||||
//! 3. remove any bridge services & add any missing services that AREN'T bridge
|
||||
//! TODO: In the future:
|
||||
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
|
||||
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
|
||||
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
|
||||
//! logic, this would be the place to spin that up.
|
||||
//!
|
||||
//! Then it will watch the SCDynamicStore for:
|
||||
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
|
||||
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
|
||||
//! interface of any changes
|
||||
//! 2. watch for any __undesirable__ changes to configuration and revert it
|
||||
//!
|
||||
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
|
||||
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
|
||||
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
|
||||
//! 1. all that info should coalesce back to the overall state colleted -> should be queryable
|
||||
//! over D-Bus
|
||||
//! TODO:
|
||||
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
|
||||
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
|
||||
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
|
||||
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
|
||||
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
|
||||
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
|
||||
//! then this would be the place to carry out discovery and propper handshakes with devices
|
||||
//! on the other end of the link.
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
#![feature(type_alias_impl_trait)]
|
||||
#![feature(specialization)]
|
||||
#![feature(unboxed_closures)]
|
||||
#![feature(const_trait_impl)]
|
||||
#![feature(fn_traits)]
|
||||
|
||||
pub(crate) mod private {
|
||||
// sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {}
|
||||
@@ -1,5 +1,7 @@
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import resource
|
||||
import signal
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Self
|
||||
@@ -27,7 +29,7 @@ from exo.worker.main import Worker
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
worker: Worker
|
||||
worker: Worker | None
|
||||
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
|
||||
election_result_receiver: Receiver[ElectionResult]
|
||||
master: Master | None
|
||||
@@ -61,15 +63,19 @@ class Node:
|
||||
else:
|
||||
api = None
|
||||
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
)
|
||||
if not args.no_worker:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
)
|
||||
else:
|
||||
worker = None
|
||||
|
||||
# We start every node with a master
|
||||
master = Master(
|
||||
node_id,
|
||||
@@ -99,8 +105,9 @@ class Node:
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.worker.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.worker:
|
||||
tg.start_soon(self.worker.run)
|
||||
if self.master:
|
||||
tg.start_soon(self.master.run)
|
||||
if self.api:
|
||||
@@ -189,11 +196,14 @@ class Node:
|
||||
|
||||
def main():
|
||||
args = Args.parse()
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
logger.info("Starting EXO")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
|
||||
node = anyio.run(Node.create, args)
|
||||
anyio.run(node.run)
|
||||
@@ -207,6 +217,7 @@ class Args(CamelCaseModel):
|
||||
spawn_api: bool = False
|
||||
api_port: PositiveInt = 52415
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
@@ -244,6 +255,10 @@ class Args(CamelCaseModel):
|
||||
dest="api_port",
|
||||
default=52415,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-worker",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
|
||||
|
||||
@@ -13,6 +13,12 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.typing import ASGIFramework
|
||||
from loguru import logger
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
@@ -21,11 +27,16 @@ from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.models.model_meta import get_model_meta
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionResponse,
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteInstanceResponse,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -56,7 +67,7 @@ from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
HIDE_THINKING = False
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
@@ -161,7 +172,10 @@ class API:
|
||||
self.app.delete("/instance/{instance_id}")(self.delete_instance)
|
||||
self.app.get("/models")(self.get_models)
|
||||
self.app.get("/v1/models")(self.get_models)
|
||||
self.app.post("/v1/chat/completions")(self.chat_completions)
|
||||
self.app.post("/v1/chat/completions", response_model=None)(
|
||||
self.chat_completions
|
||||
)
|
||||
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
|
||||
@@ -177,17 +191,32 @@ class API:
|
||||
return CreateInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
model_meta=command.model_meta,
|
||||
)
|
||||
|
||||
async def create_instance(
|
||||
self, payload: CreateInstanceParams
|
||||
) -> CreateInstanceResponse:
|
||||
command = CreateInstance(instance=payload.instance)
|
||||
instance = payload.instance
|
||||
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
|
||||
required_memory = model_meta.storage_size
|
||||
available_memory = self._calculate_total_available_memory()
|
||||
|
||||
if required_memory > available_memory:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Insufficient memory to create instance. Required: {required_memory.in_gb:.1f}GB, Available: {available_memory.in_gb:.1f}GB",
|
||||
)
|
||||
|
||||
command = CreateInstance(
|
||||
instance=instance,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
return CreateInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
model_meta=model_meta,
|
||||
)
|
||||
|
||||
async def get_placement(
|
||||
@@ -207,7 +236,6 @@ class API:
|
||||
instance_meta=instance_meta,
|
||||
min_nodes=min_nodes,
|
||||
),
|
||||
node_profiles=self.state.node_profiles,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
)
|
||||
@@ -263,7 +291,6 @@ class API:
|
||||
instance_meta=instance_meta,
|
||||
min_nodes=min_nodes,
|
||||
),
|
||||
node_profiles=self.state.node_profiles,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
)
|
||||
@@ -354,32 +381,52 @@ class API:
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
thinking = False
|
||||
|
||||
async for chunk in token_chunks:
|
||||
stream.process(chunk.token_id)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield chunk.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield chunk.model_copy(update={"text": delta})
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
if thinking:
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
yield chunk
|
||||
break
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> AsyncGenerator[TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
|
||||
try:
|
||||
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
|
||||
|
||||
is_thinking = False
|
||||
with recv as token_chunks:
|
||||
async for chunk in token_chunks:
|
||||
if HIDE_THINKING:
|
||||
if chunk.text == "<think>":
|
||||
is_thinking = True
|
||||
if chunk.text == "</think>":
|
||||
is_thinking = False
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
if not (is_thinking and HIDE_THINKING):
|
||||
logger.debug(f"chunk_response: {chunk_response}")
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
if parse_gpt_oss:
|
||||
async for chunk in self._process_gpt_oss(token_chunks):
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
else:
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# TODO: TaskCancelled
|
||||
@@ -394,6 +441,98 @@ class API:
|
||||
await self._send(command)
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
logger.debug(f"chunk_response: {chunk_response}")
|
||||
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def _collect_chat_completion(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> ChatCompletionResponse:
|
||||
"""Collect all token chunks for a chat completion and return a single response."""
|
||||
|
||||
text_parts: list[str] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=combined_text,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def _collect_chat_completion_with_stats(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
stats = chunk.stats or stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
resp = BenchChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant", content=combined_text
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
generation_stats=stats,
|
||||
)
|
||||
return resp
|
||||
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
|
||||
logger.warning(
|
||||
"TODO: we should send a notification to the user to download the model"
|
||||
@@ -401,10 +540,12 @@ class API:
|
||||
|
||||
async def chat_completions(
|
||||
self, payload: ChatCompletionTaskParams
|
||||
) -> StreamingResponse:
|
||||
"""Handle chat completions with proper streaming response."""
|
||||
) -> ChatCompletionResponse | StreamingResponse:
|
||||
"""Handle chat completions, supporting both streaming and non-streaming responses."""
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
payload.model = model_meta.model_id
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
logger.info(f"{parse_gpt_oss=}")
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == payload.model
|
||||
@@ -419,17 +560,48 @@ class API:
|
||||
request_params=payload,
|
||||
)
|
||||
await self._send(command)
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id),
|
||||
media_type="text/event-stream",
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id, parse_gpt_oss),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionTaskParams
|
||||
) -> BenchChatCompletionResponse:
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
payload.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == payload.model
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(payload.model)
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"No instance found for model {payload.model}"
|
||||
)
|
||||
|
||||
payload.stream = False
|
||||
|
||||
command = ChatCompletion(request_params=payload)
|
||||
await self._send(command)
|
||||
|
||||
response = await self._collect_chat_completion_with_stats(
|
||||
command.command_id,
|
||||
parse_gpt_oss,
|
||||
)
|
||||
return response
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
total_available = Memory()
|
||||
|
||||
for profile in self.state.node_profiles.values():
|
||||
total_available += profile.memory.ram_available
|
||||
for node in self.state.topology.list_nodes():
|
||||
if node.node_profile is not None:
|
||||
total_available += node.node_profile.memory.ram_available
|
||||
|
||||
return total_available
|
||||
|
||||
@@ -443,6 +615,8 @@ class API:
|
||||
name=card.name,
|
||||
description=card.description,
|
||||
tags=card.tags,
|
||||
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
|
||||
supports_tensor=card.metadata.supports_tensor,
|
||||
)
|
||||
for card in MODEL_CARDS.values()
|
||||
]
|
||||
@@ -459,7 +633,7 @@ class API:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._applystate)
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
print_startup_banner(self.port)
|
||||
await serve(
|
||||
@@ -471,7 +645,7 @@ class API:
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def _applystate(self):
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
if f_event.origin != self.session_id.master_node_id:
|
||||
|
||||
@@ -158,7 +158,6 @@ class Master:
|
||||
command,
|
||||
self.state.topology,
|
||||
self.state.instances,
|
||||
self.state.node_profiles,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
@@ -201,7 +200,9 @@ class Master:
|
||||
async def _plan(self) -> None:
|
||||
while True:
|
||||
# kill broken instances
|
||||
connected_node_ids = set([x for x in self.state.topology.list_nodes()])
|
||||
connected_node_ids = set(
|
||||
[x.node_id for x in self.state.topology.list_nodes()]
|
||||
)
|
||||
for instance_id, instance in self.state.instances.items():
|
||||
for node_id in instance.shard_assignments.node_to_runner:
|
||||
if node_id not in connected_node_ids:
|
||||
|
||||
@@ -6,11 +6,10 @@ from typing import Sequence
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.placement_utils import (
|
||||
NodeWithProfile,
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_mlx_ibv_devices_matrix,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_mlx_jaccl_devices_matrix,
|
||||
get_mlx_ring_hosts_by_node,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
@@ -20,10 +19,10 @@ from exo.shared.types.commands import (
|
||||
DeleteInstance,
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import NodePerformanceProfile
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.topology import NodeInfo
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
@@ -31,6 +30,7 @@ from exo.shared.types.worker.instances import (
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
|
||||
def random_ephemeral_port() -> int:
|
||||
@@ -52,32 +52,55 @@ def place_instance(
|
||||
command: PlaceInstance,
|
||||
topology: Topology,
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
node_profiles: Mapping[NodeId, NodePerformanceProfile],
|
||||
) -> dict[InstanceId, Instance]:
|
||||
all_nodes = list(topology.list_nodes())
|
||||
|
||||
cycles = topology.get_cycles() + [[node] for node in all_nodes]
|
||||
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
|
||||
cycles_with_sufficient_memory = filter_cycles_by_memory(
|
||||
candidate_cycles, node_profiles, command.model_meta.storage_size
|
||||
logger.info("finding cycles:")
|
||||
cycles = topology.get_cycles()
|
||||
singleton_cycles = [[node] for node in all_nodes]
|
||||
candidate_cycles = list(
|
||||
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
|
||||
)
|
||||
if len(cycles_with_sufficient_memory) == 0:
|
||||
cycles_with_sufficient_memory = filter_cycles_by_memory(
|
||||
candidate_cycles, command.model_meta.storage_size
|
||||
)
|
||||
if not cycles_with_sufficient_memory:
|
||||
raise ValueError("No cycles found with sufficient memory")
|
||||
|
||||
if command.sharding == Sharding.Tensor:
|
||||
if not command.model_meta.supports_tensor:
|
||||
raise ValueError(
|
||||
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
|
||||
)
|
||||
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
|
||||
cycles_with_sufficient_memory = [
|
||||
cycle
|
||||
for cycle in cycles_with_sufficient_memory
|
||||
if command.model_meta.hidden_size % len(cycle) == 0
|
||||
]
|
||||
if not cycles_with_sufficient_memory:
|
||||
raise ValueError(
|
||||
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
|
||||
)
|
||||
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
|
||||
"mlx-community/DeepSeek-V3.1-8bit"
|
||||
):
|
||||
raise ValueError(
|
||||
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
|
||||
)
|
||||
|
||||
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
||||
|
||||
smallest_tb_cycles = [
|
||||
cycle
|
||||
for cycle in smallest_cycles
|
||||
if topology.get_subgraph_from_nodes(
|
||||
[node.node_id for node in cycle]
|
||||
).is_thunderbolt_cycle([node.node_id for node in cycle])
|
||||
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
|
||||
]
|
||||
|
||||
if smallest_tb_cycles != []:
|
||||
smallest_cycles = smallest_tb_cycles
|
||||
|
||||
cycles_with_leaf_nodes: list[list[NodeWithProfile]] = [
|
||||
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
|
||||
cycle
|
||||
for cycle in smallest_cycles
|
||||
if any(topology.node_is_leaf(node.node_id) for node in cycle)
|
||||
@@ -86,7 +109,11 @@ def place_instance(
|
||||
selected_cycle = max(
|
||||
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
|
||||
key=lambda cycle: sum(
|
||||
(node.node_profile.memory.ram_available for node in cycle),
|
||||
(
|
||||
node.node_profile.memory.ram_available
|
||||
for node in cycle
|
||||
if node.node_profile is not None
|
||||
),
|
||||
start=Memory(),
|
||||
),
|
||||
)
|
||||
@@ -95,16 +122,14 @@ def place_instance(
|
||||
command.model_meta, selected_cycle, command.sharding
|
||||
)
|
||||
|
||||
cycle_digraph: Topology = topology.get_subgraph_from_nodes(
|
||||
[node.node_id for node in selected_cycle]
|
||||
)
|
||||
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
|
||||
|
||||
instance_id = InstanceId()
|
||||
target_instances = dict(deepcopy(current_instances))
|
||||
|
||||
if len(selected_cycle) == 1:
|
||||
logger.warning(
|
||||
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
|
||||
"You have likely selected ibv for a single node instance; falling back to MlxRing"
|
||||
)
|
||||
|
||||
command.instance_meta = InstanceMeta.MlxRing
|
||||
@@ -112,32 +137,33 @@ def place_instance(
|
||||
# TODO: Single node instances
|
||||
match command.instance_meta:
|
||||
case InstanceMeta.MlxJaccl:
|
||||
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
|
||||
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
|
||||
selected_cycle,
|
||||
cycle_digraph,
|
||||
)
|
||||
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
|
||||
coordinator=selected_cycle[0].node_id,
|
||||
selected_cycle,
|
||||
coordinator_port=random_ephemeral_port(),
|
||||
cycle_digraph=cycle_digraph,
|
||||
)
|
||||
target_instances[instance_id] = MlxJacclInstance(
|
||||
instance_id=instance_id,
|
||||
shard_assignments=shard_assignments,
|
||||
jaccl_devices=mlx_jaccl_devices,
|
||||
ibv_devices=mlx_ibv_devices,
|
||||
jaccl_coordinators=mlx_jaccl_coordinators,
|
||||
)
|
||||
case InstanceMeta.MlxRing:
|
||||
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
|
||||
ephemeral_port = random_ephemeral_port()
|
||||
hosts_by_node = get_mlx_ring_hosts_by_node(
|
||||
selected_cycle=selected_cycle,
|
||||
cycle_digraph=cycle_digraph,
|
||||
ephemeral_port=ephemeral_port,
|
||||
)
|
||||
target_instances[instance_id] = MlxRingInstance(
|
||||
instance_id=instance_id,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=[
|
||||
Host(
|
||||
ip=host.ip,
|
||||
port=random_ephemeral_port(),
|
||||
)
|
||||
for host in hosts
|
||||
],
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
)
|
||||
|
||||
return target_instances
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Generator
|
||||
from typing import TypeGuard, cast
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
@@ -8,7 +9,7 @@ from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.profiling import NodePerformanceProfile
|
||||
from exo.shared.types.topology import RDMAConnection, SocketConnection
|
||||
from exo.shared.types.topology import NodeInfo
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
@@ -23,32 +24,27 @@ class NodeWithProfile(BaseModel):
|
||||
node_profile: NodePerformanceProfile
|
||||
|
||||
|
||||
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
|
||||
return all(node.node_profile is not None for node in nodes)
|
||||
|
||||
|
||||
def filter_cycles_by_memory(
|
||||
cycles: list[list[NodeId]],
|
||||
node_profiles: Mapping[NodeId, NodePerformanceProfile],
|
||||
required_memory: Memory,
|
||||
) -> list[list[NodeWithProfile]]:
|
||||
filtered_cycles: list[list[NodeWithProfile]] = []
|
||||
cycles: list[list[NodeInfo]], required_memory: Memory
|
||||
) -> list[list[NodeInfo]]:
|
||||
filtered_cycles: list[list[NodeInfo]] = []
|
||||
for cycle in cycles:
|
||||
if not all(node in node_profiles for node in cycle):
|
||||
if not narrow_all_nodes(cycle):
|
||||
continue
|
||||
|
||||
total_mem = sum(
|
||||
(node_profiles[node].memory.ram_available for node in cycle), start=Memory()
|
||||
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
|
||||
)
|
||||
if total_mem >= required_memory:
|
||||
filtered_cycles.append(
|
||||
[
|
||||
NodeWithProfile(node_id=node, node_profile=node_profiles[node])
|
||||
for node in cycle
|
||||
]
|
||||
)
|
||||
filtered_cycles.append(cast(list[NodeInfo], cycle))
|
||||
return filtered_cycles
|
||||
|
||||
|
||||
def get_smallest_cycles(
|
||||
cycles: list[list[NodeWithProfile]],
|
||||
) -> list[list[NodeWithProfile]]:
|
||||
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
|
||||
min_nodes = min(len(cycle) for cycle in cycles)
|
||||
return [cycle for cycle in cycles if len(cycle) == min_nodes]
|
||||
|
||||
@@ -139,9 +135,11 @@ def get_shard_assignments_for_tensor_parallel(
|
||||
|
||||
def get_shard_assignments(
|
||||
model_meta: ModelMetadata,
|
||||
selected_cycle: list[NodeWithProfile],
|
||||
selected_cycle: list[NodeInfo],
|
||||
sharding: Sharding,
|
||||
) -> ShardAssignments:
|
||||
if not narrow_all_nodes(selected_cycle):
|
||||
raise ValueError("All nodes must have profiles to create shard assignments")
|
||||
match sharding:
|
||||
case Sharding.Pipeline:
|
||||
return get_shard_assignments_for_pipeline_parallel(
|
||||
@@ -178,16 +176,17 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
|
||||
current_node = cycle[i]
|
||||
next_node = cycle[(i + 1) % len(cycle)]
|
||||
|
||||
for src, sink, connection in cycle_digraph.list_connections():
|
||||
if not isinstance(connection, SocketConnection):
|
||||
continue
|
||||
|
||||
if src == current_node and sink == next_node:
|
||||
for connection in cycle_digraph.list_connections():
|
||||
if (
|
||||
connection.local_node_id == current_node.node_id
|
||||
and connection.send_back_node_id == next_node.node_id
|
||||
):
|
||||
if get_thunderbolt and not connection.is_thunderbolt():
|
||||
continue
|
||||
assert connection.send_back_multiaddr is not None
|
||||
host = Host(
|
||||
ip=connection.sink_multiaddr.ip_address,
|
||||
port=connection.sink_multiaddr.port,
|
||||
ip=connection.send_back_multiaddr.ip_address,
|
||||
port=connection.send_back_multiaddr.port,
|
||||
)
|
||||
hosts.append(host)
|
||||
break
|
||||
@@ -195,7 +194,8 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
|
||||
return hosts
|
||||
|
||||
|
||||
def get_mlx_jaccl_devices_matrix(
|
||||
def get_mlx_ibv_devices_matrix(
|
||||
selected_cycle: list[NodeInfo],
|
||||
cycle_digraph: Topology,
|
||||
) -> list[list[str | None]]:
|
||||
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
|
||||
@@ -204,7 +204,6 @@ def get_mlx_jaccl_devices_matrix(
|
||||
to device j, or None if no connection exists or no interface name is found.
|
||||
Diagonal elements are always None.
|
||||
"""
|
||||
selected_cycle = list(cycle_digraph.list_nodes())
|
||||
num_nodes = len(selected_cycle)
|
||||
matrix: list[list[str | None]] = [
|
||||
[None for _ in range(num_nodes)] for _ in range(num_nodes)
|
||||
@@ -215,55 +214,192 @@ def get_mlx_jaccl_devices_matrix(
|
||||
if i == j:
|
||||
continue
|
||||
|
||||
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
|
||||
if isinstance(conn, RDMAConnection):
|
||||
matrix[i][j] = conn.source_rdma_iface
|
||||
# Find the IP J uses to talk to I
|
||||
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
|
||||
# This is a local IP on I, which is attached to an interface: find that interface
|
||||
if interface_name := _find_rdma_interface_name_for_ip(
|
||||
connection_ip, node_i
|
||||
):
|
||||
matrix[i][j] = interface_name
|
||||
logger.info(
|
||||
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
|
||||
)
|
||||
raise ValueError(
|
||||
"Current jaccl backend requires all-to-all RDMA connections"
|
||||
"Current ibv backend requires all-to-all rdma connections"
|
||||
)
|
||||
|
||||
return matrix
|
||||
|
||||
|
||||
def _find_connection_ip(
|
||||
node_i: NodeId,
|
||||
node_j: NodeId,
|
||||
node_i: NodeInfo,
|
||||
node_j: NodeInfo,
|
||||
cycle_digraph: Topology,
|
||||
) -> Generator[str]:
|
||||
"""Find all IP addresses that connect node i to node j."""
|
||||
# TODO: Prioritise ETHERNET > ??WIFI > TB for coordinator
|
||||
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
|
||||
if isinstance(connection, SocketConnection):
|
||||
yield connection.sink_multiaddr.ip_address
|
||||
) -> Generator[tuple[str, bool]]:
|
||||
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
|
||||
for connection in cycle_digraph.list_connections():
|
||||
if (
|
||||
connection.local_node_id == node_i.node_id
|
||||
and connection.send_back_node_id == node_j.node_id
|
||||
):
|
||||
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
|
||||
|
||||
|
||||
def _find_rdma_interface_name_for_ip(
|
||||
ip_address: str,
|
||||
node_info: NodeInfo,
|
||||
) -> str | None:
|
||||
if node_info.node_profile is None:
|
||||
return None
|
||||
|
||||
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
|
||||
for interface in node_info.node_profile.network_interfaces:
|
||||
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
|
||||
continue
|
||||
logger.info(f" | {interface.name}: {interface.ip_address}")
|
||||
if interface.ip_address != ip_address:
|
||||
continue
|
||||
|
||||
logger.info("Found")
|
||||
return f"rdma_{interface.name}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_interface_name_for_ip(
|
||||
ip_address: str,
|
||||
node_info: NodeInfo,
|
||||
) -> str | None:
|
||||
"""Find the interface name for an IP address on a node (any interface)."""
|
||||
if node_info.node_profile is None:
|
||||
return None
|
||||
|
||||
for interface in node_info.node_profile.network_interfaces:
|
||||
if interface.ip_address == ip_address:
|
||||
return interface.name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_ip_prioritised(
|
||||
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
|
||||
) -> str | None:
|
||||
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
|
||||
"""Find an IP address between nodes with prioritization.
|
||||
|
||||
Priority order:
|
||||
1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
|
||||
2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
|
||||
3. Non-Thunderbolt connections
|
||||
4. Any other IP address
|
||||
"""
|
||||
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
|
||||
# We expect a unique iface -> ip mapping
|
||||
iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}
|
||||
|
||||
en0_ip = iface_map.get("en0")
|
||||
if en0_ip:
|
||||
return en0_ip
|
||||
|
||||
en1_ip = iface_map.get("en1")
|
||||
if en1_ip:
|
||||
return en1_ip
|
||||
|
||||
non_thunderbolt_ip = next(
|
||||
(ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
|
||||
)
|
||||
|
||||
if non_thunderbolt_ip:
|
||||
return non_thunderbolt_ip
|
||||
|
||||
if ips:
|
||||
return ips[0][0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_mlx_ring_hosts_by_node(
|
||||
selected_cycle: list[NodeInfo],
|
||||
cycle_digraph: Topology,
|
||||
ephemeral_port: int,
|
||||
) -> dict[NodeId, list[Host]]:
|
||||
"""Generate per-node host lists for MLX ring backend.
|
||||
|
||||
Each node gets a list where:
|
||||
- Self position: Host(ip="0.0.0.0", port=ephemeral_port)
|
||||
- Left/right neighbors: actual connection IPs
|
||||
- Non-neighbors: Host(ip="198.51.100.1", port=0) placeholder (RFC 5737 TEST-NET-2)
|
||||
"""
|
||||
world_size = len(selected_cycle)
|
||||
if world_size == 0:
|
||||
return {}
|
||||
|
||||
hosts_by_node: dict[NodeId, list[Host]] = {}
|
||||
|
||||
for rank, node in enumerate(selected_cycle):
|
||||
node_id = node.node_id
|
||||
left_rank = (rank - 1) % world_size
|
||||
right_rank = (rank + 1) % world_size
|
||||
|
||||
hosts_for_node: list[Host] = []
|
||||
|
||||
for idx, other_node in enumerate(selected_cycle):
|
||||
if idx == rank:
|
||||
hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
|
||||
continue
|
||||
|
||||
if idx not in {left_rank, right_rank}:
|
||||
# Placeholder IP from RFC 5737 TEST-NET-2
|
||||
hosts_for_node.append(Host(ip="198.51.100.1", port=0))
|
||||
continue
|
||||
|
||||
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
|
||||
if connection_ip is None:
|
||||
logger.warning(
|
||||
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
|
||||
)
|
||||
raise ValueError(
|
||||
"MLX ring backend requires connectivity between neighbouring nodes"
|
||||
)
|
||||
|
||||
hosts_for_node.append(Host(ip=connection_ip, port=ephemeral_port))
|
||||
|
||||
hosts_by_node[node_id] = hosts_for_node
|
||||
|
||||
return hosts_by_node
|
||||
|
||||
|
||||
def get_mlx_jaccl_coordinators(
|
||||
coordinator: NodeId,
|
||||
selected_cycle: list[NodeInfo],
|
||||
coordinator_port: int,
|
||||
cycle_digraph: Topology,
|
||||
) -> dict[NodeId, str]:
|
||||
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
|
||||
"""Get the coordinator addresses for MLX Jaccl (rank 0 device).
|
||||
|
||||
Select an IP address that each node can reach for the rank 0 node. Returns
|
||||
address in format "X.X.X.X:PORT" per node.
|
||||
"""
|
||||
selected_cycle = list(cycle_digraph.list_nodes())
|
||||
logger.info(f"Selecting coordinator: {coordinator}")
|
||||
rank_0_node = selected_cycle[0]
|
||||
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
|
||||
|
||||
def get_ip_for_node(n: NodeId) -> str:
|
||||
if n == coordinator:
|
||||
def get_ip_for_node(n: NodeInfo) -> str:
|
||||
if n.node_id == rank_0_node.node_id:
|
||||
return "0.0.0.0"
|
||||
|
||||
for ip in _find_connection_ip(n, coordinator, cycle_digraph):
|
||||
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
|
||||
if ip:
|
||||
return ip
|
||||
|
||||
logger.warning(
|
||||
f"Failed to find directly connected ip between {n} and {coordinator}"
|
||||
)
|
||||
raise ValueError(
|
||||
"Current jaccl backend requires all participating devices to be able to communicate"
|
||||
f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
|
||||
)
|
||||
raise ValueError("Current ibv backend requires all-to-all rdma connections")
|
||||
|
||||
return {n: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle}
|
||||
return {
|
||||
n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
|
||||
}
|
||||
|
||||
@@ -1,36 +1,67 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
MemoryPerformanceProfile,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.topology import RDMAConnection, SocketConnection
|
||||
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
|
||||
|
||||
|
||||
def create_node_profile(memory: int) -> NodePerformanceProfile:
|
||||
return NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=MemoryUsage.from_bytes(
|
||||
ram_total=1000,
|
||||
ram_available=memory,
|
||||
swap_total=1000,
|
||||
swap_available=1000,
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
)
|
||||
@pytest.fixture
|
||||
def create_node():
|
||||
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
|
||||
if node_id is None:
|
||||
node_id = NodeId()
|
||||
return NodeInfo(
|
||||
node_id=node_id,
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile.from_bytes(
|
||||
ram_total=1000,
|
||||
ram_available=memory,
|
||||
swap_total=1000,
|
||||
swap_available=1000,
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
),
|
||||
)
|
||||
|
||||
return _create_node
|
||||
|
||||
|
||||
# TODO: this is a hack to get the port for the send_back_multiaddr
|
||||
def create_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
|
||||
return SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
|
||||
)
|
||||
@pytest.fixture
|
||||
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
|
||||
port_counter = 1235
|
||||
ip_counter = 1
|
||||
|
||||
def _create_connection(
|
||||
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
|
||||
) -> Connection:
|
||||
nonlocal port_counter
|
||||
nonlocal ip_counter
|
||||
# assign unique ips
|
||||
ip_counter += 1
|
||||
if send_back_port is None:
|
||||
send_back_port = port_counter
|
||||
port_counter += 1
|
||||
return Connection(
|
||||
local_node_id=source_node_id,
|
||||
send_back_node_id=sink_node_id,
|
||||
send_back_multiaddr=Multiaddr(
|
||||
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
|
||||
),
|
||||
connection_profile=ConnectionProfile(
|
||||
throughput=1000, latency=1000, jitter=1000
|
||||
),
|
||||
)
|
||||
|
||||
def create_rdma_connection(iface: int) -> RDMAConnection:
|
||||
return RDMAConnection(
|
||||
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
|
||||
)
|
||||
return _create_connection
|
||||
|
||||
@@ -19,13 +19,15 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
NodeGatheredInfo,
|
||||
NodePerformanceMeasured,
|
||||
TaskCreated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
MemoryPerformanceProfile,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
|
||||
from exo.shared.types.tasks import TaskStatus
|
||||
@@ -81,14 +83,21 @@ async def test_master():
|
||||
origin=sender_node_id,
|
||||
session=session_id,
|
||||
event=(
|
||||
NodeGatheredInfo(
|
||||
NodePerformanceMeasured(
|
||||
when=str(datetime.now(tz=timezone.utc)),
|
||||
node_id=node_id,
|
||||
info=MemoryUsage(
|
||||
ram_total=Memory.from_bytes(678948 * 1024),
|
||||
ram_available=Memory.from_bytes(678948 * 1024),
|
||||
swap_total=Memory.from_bytes(0),
|
||||
swap_available=Memory.from_bytes(0),
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="maccy",
|
||||
chip_id="arm",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile(
|
||||
ram_total=Memory.from_bytes(678948 * 1024),
|
||||
ram_available=Memory.from_bytes(678948 * 1024),
|
||||
swap_total=Memory.from_bytes(0),
|
||||
swap_available=Memory.from_bytes(0),
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
),
|
||||
)
|
||||
),
|
||||
@@ -114,6 +123,8 @@ async def test_master():
|
||||
pretty_name="Llama 3.2 1B",
|
||||
n_layers=16,
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
@@ -152,34 +163,40 @@ async def test_master():
|
||||
assert events[0].idx == 0
|
||||
assert events[1].idx == 1
|
||||
assert events[2].idx == 2
|
||||
assert isinstance(events[0].event, NodeGatheredInfo)
|
||||
assert isinstance(events[0].event, NodePerformanceMeasured)
|
||||
assert isinstance(events[1].event, InstanceCreated)
|
||||
runner_id = list(
|
||||
events[1].event.instance.shard_assignments.runner_to_shard.keys()
|
||||
)[0]
|
||||
assert events[1].event.instance == MlxRingInstance(
|
||||
instance_id=events[1].event.instance.instance_id,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
runner_to_shard={
|
||||
(runner_id): PipelineShardMetadata(
|
||||
start_layer=0,
|
||||
end_layer=16,
|
||||
created_instance = events[1].event.instance
|
||||
assert isinstance(created_instance, MlxRingInstance)
|
||||
runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
|
||||
# Validate the shard assignments
|
||||
expected_shard_assignments = ShardAssignments(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
runner_to_shard={
|
||||
(runner_id): PipelineShardMetadata(
|
||||
start_layer=0,
|
||||
end_layer=16,
|
||||
n_layers=16,
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
pretty_name="Llama 3.2 1B",
|
||||
n_layers=16,
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
pretty_name="Llama 3.2 1B",
|
||||
n_layers=16,
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
)
|
||||
},
|
||||
node_to_runner={node_id: runner_id},
|
||||
),
|
||||
hosts=[],
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
)
|
||||
},
|
||||
node_to_runner={node_id: runner_id},
|
||||
)
|
||||
assert created_instance.shard_assignments == expected_shard_assignments
|
||||
# For single-node, hosts_by_node should have one entry with self-binding
|
||||
assert len(created_instance.hosts_by_node) == 1
|
||||
assert node_id in created_instance.hosts_by_node
|
||||
assert len(created_instance.hosts_by_node[node_id]) == 1
|
||||
assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
|
||||
assert created_instance.ephemeral_port > 0
|
||||
assert isinstance(events[2].event, TaskCreated)
|
||||
assert events[2].event.task.task_status == TaskStatus.Pending
|
||||
assert isinstance(events[2].event.task, ChatCompletionTask)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
from loguru import logger
|
||||
|
||||
@@ -5,20 +7,14 @@ from exo.master.placement import (
|
||||
get_transition_events,
|
||||
place_instance,
|
||||
)
|
||||
from exo.master.tests.conftest import (
|
||||
create_connection,
|
||||
create_node_profile,
|
||||
create_rdma_connection,
|
||||
)
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import PlaceInstance
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo
|
||||
from exo.shared.types.topology import SocketConnection
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
|
||||
from exo.shared.types.topology import Connection, NodeInfo
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
@@ -30,6 +26,11 @@ from exo.shared.types.worker.runners import ShardAssignments
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def topology() -> Topology:
|
||||
return Topology()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def instance() -> Instance:
|
||||
return MlxRingInstance(
|
||||
@@ -37,7 +38,8 @@ def instance() -> Instance:
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
|
||||
),
|
||||
hosts=[],
|
||||
hosts_by_node={},
|
||||
ephemeral_port=50000,
|
||||
)
|
||||
|
||||
|
||||
@@ -48,6 +50,8 @@ def model_meta() -> ModelMetadata:
|
||||
storage_size=Memory.from_kb(1000),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
hidden_size=30,
|
||||
supports_tensor=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -73,33 +77,34 @@ def test_get_instance_placements_create_instance(
|
||||
available_memory: tuple[int, int, int],
|
||||
total_layers: int,
|
||||
expected_layers: tuple[int, int, int],
|
||||
topology: Topology,
|
||||
model_meta: ModelMetadata,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
model_meta.n_layers = total_layers
|
||||
model_meta.storage_size.in_bytes = sum(
|
||||
available_memory
|
||||
) # make it exactly fit across all nodes
|
||||
topology = Topology()
|
||||
|
||||
cic = place_instance_command(model_meta)
|
||||
node_id_a = NodeId()
|
||||
node_id_b = NodeId()
|
||||
node_id_c = NodeId()
|
||||
profiles = {
|
||||
node_id_a: create_node_profile(available_memory[0]),
|
||||
node_id_b: create_node_profile(available_memory[1]),
|
||||
node_id_c: create_node_profile(available_memory[2]),
|
||||
}
|
||||
topology.add_node(node_id_a)
|
||||
topology.add_node(node_id_b)
|
||||
topology.add_node(node_id_c)
|
||||
topology.add_connection(node_id_a, node_id_b, create_connection(1))
|
||||
topology.add_connection(node_id_b, node_id_c, create_connection(2))
|
||||
topology.add_connection(node_id_c, node_id_a, create_connection(3))
|
||||
topology.add_node(create_node(available_memory[0], node_id_a))
|
||||
topology.add_node(create_node(available_memory[1], node_id_b))
|
||||
topology.add_node(create_node(available_memory[2], node_id_c))
|
||||
# Add bidirectional connections for ring topology
|
||||
topology.add_connection(create_connection(node_id_a, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_a))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_c))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_a))
|
||||
topology.add_connection(create_connection(node_id_a, node_id_c))
|
||||
|
||||
# act
|
||||
placements = place_instance(cic, topology, {}, profiles)
|
||||
placements = place_instance(cic, topology, {})
|
||||
|
||||
# assert
|
||||
assert len(placements) == 1
|
||||
@@ -125,20 +130,23 @@ def test_get_instance_placements_create_instance(
|
||||
assert shards_sorted[-1].end_layer == total_layers
|
||||
|
||||
|
||||
def test_get_instance_placements_one_node_exact_fit() -> None:
|
||||
def test_get_instance_placements_one_node_exact_fit(
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
) -> None:
|
||||
topology = Topology()
|
||||
node_id = NodeId()
|
||||
topology.add_node(node_id)
|
||||
profiles = {node_id: create_node_profile(1000 * 1024)}
|
||||
topology.add_node(create_node(1000 * 1024, node_id))
|
||||
cic = place_instance_command(
|
||||
ModelMetadata(
|
||||
model_id=ModelId("test-model"),
|
||||
storage_size=Memory.from_kb(1000),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
),
|
||||
)
|
||||
placements = place_instance(cic, topology, {}, profiles)
|
||||
placements = place_instance(cic, topology, {})
|
||||
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
@@ -149,20 +157,23 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
|
||||
assert len(instance.shard_assignments.runner_to_shard) == 1
|
||||
|
||||
|
||||
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
|
||||
def test_get_instance_placements_one_node_fits_with_extra_memory(
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
) -> None:
|
||||
topology = Topology()
|
||||
node_id = NodeId()
|
||||
topology.add_node(node_id)
|
||||
profiles = {node_id: create_node_profile(1001 * 1024)}
|
||||
topology.add_node(create_node(1001 * 1024, node_id))
|
||||
cic = place_instance_command(
|
||||
ModelMetadata(
|
||||
model_id=ModelId("test-model"),
|
||||
storage_size=Memory.from_kb(1000),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
),
|
||||
)
|
||||
placements = place_instance(cic, topology, {}, profiles)
|
||||
placements = place_instance(cic, topology, {})
|
||||
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
@@ -173,22 +184,25 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
|
||||
assert len(instance.shard_assignments.runner_to_shard) == 1
|
||||
|
||||
|
||||
def test_get_instance_placements_one_node_not_fit() -> None:
|
||||
def test_get_instance_placements_one_node_not_fit(
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
) -> None:
|
||||
topology = Topology()
|
||||
node_id = NodeId()
|
||||
topology.add_node(node_id)
|
||||
profiles = {node_id: create_node_profile(1000 * 1024)}
|
||||
topology.add_node(create_node(1000 * 1024, node_id))
|
||||
cic = place_instance_command(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("test-model"),
|
||||
storage_size=Memory.from_kb(1001),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
|
||||
place_instance(cic, topology, {}, profiles)
|
||||
place_instance(cic, topology, {})
|
||||
|
||||
|
||||
def test_get_transition_events_no_change(instance: Instance):
|
||||
@@ -233,103 +247,179 @@ def test_get_transition_events_delete_instance(instance: Instance):
|
||||
assert events[0].instance_id == instance_id
|
||||
|
||||
|
||||
def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
def test_placement_selects_cycle_with_most_memory(
|
||||
topology: Topology,
|
||||
model_meta: ModelMetadata,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
topology = Topology()
|
||||
# Arrange two 3-node cycles with different total memory.
|
||||
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
|
||||
# The algorithm should select the cycle with the most available memory.
|
||||
|
||||
model_meta.storage_size = Memory.from_bytes(1000)
|
||||
# Model requires more than any single node but fits within a 3-node cycle
|
||||
model_meta.storage_size.in_bytes = 1500
|
||||
model_meta.n_layers = 12
|
||||
|
||||
# Create node ids
|
||||
node_id_a = NodeId()
|
||||
node_id_b = NodeId()
|
||||
node_id_c = NodeId()
|
||||
node_id_d = NodeId()
|
||||
node_id_e = NodeId()
|
||||
node_id_f = NodeId()
|
||||
|
||||
profiles = {
|
||||
node_id_a: create_node_profile(500),
|
||||
node_id_b: create_node_profile(600),
|
||||
node_id_c: create_node_profile(600),
|
||||
node_id_d: create_node_profile(500),
|
||||
}
|
||||
# A-B-C cycle total memory = 1600 (< D-E-F total)
|
||||
topology.add_node(create_node(400, node_id_a))
|
||||
topology.add_node(create_node(400, node_id_b))
|
||||
topology.add_node(create_node(800, node_id_c))
|
||||
|
||||
topology.add_node(node_id_a)
|
||||
topology.add_node(node_id_b)
|
||||
topology.add_node(node_id_c)
|
||||
topology.add_node(node_id_d)
|
||||
# D-E-F cycle total memory = 1800 (> A-B-C total)
|
||||
topology.add_node(create_node(600, node_id_d))
|
||||
topology.add_node(create_node(600, node_id_e))
|
||||
topology.add_node(create_node(600, node_id_f))
|
||||
|
||||
# Daisy chain topology
|
||||
topology.add_connection(node_id_a, node_id_b, create_connection(1))
|
||||
topology.add_connection(node_id_b, node_id_a, create_connection(1))
|
||||
topology.add_connection(node_id_b, node_id_c, create_connection(1))
|
||||
topology.add_connection(node_id_c, node_id_b, create_connection(1))
|
||||
topology.add_connection(node_id_c, node_id_d, create_connection(1))
|
||||
topology.add_connection(node_id_d, node_id_c, create_connection(1))
|
||||
# Build bidirectional cycles for ring topology
|
||||
topology.add_connection(create_connection(node_id_a, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_a))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_c))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_a))
|
||||
topology.add_connection(create_connection(node_id_a, node_id_c))
|
||||
|
||||
logger.info(list(topology.list_connections()))
|
||||
topology.add_connection(create_connection(node_id_d, node_id_e))
|
||||
topology.add_connection(create_connection(node_id_e, node_id_d))
|
||||
topology.add_connection(create_connection(node_id_e, node_id_f))
|
||||
topology.add_connection(create_connection(node_id_f, node_id_e))
|
||||
topology.add_connection(create_connection(node_id_f, node_id_d))
|
||||
topology.add_connection(create_connection(node_id_d, node_id_f))
|
||||
|
||||
cic = place_instance_command(
|
||||
model_meta=model_meta,
|
||||
)
|
||||
|
||||
# act
|
||||
placements = place_instance(cic, topology, {}, profiles)
|
||||
# Act
|
||||
placements = place_instance(cic, topology, {})
|
||||
|
||||
# assert
|
||||
# Assert: D-E-F cycle should be selected as it has more total memory
|
||||
assert len(placements) == 1
|
||||
instance = list(placements.values())[0]
|
||||
instance_id = list(placements.keys())[0]
|
||||
instance = placements[instance_id]
|
||||
|
||||
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
|
||||
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
|
||||
(node_id_c, node_id_d)
|
||||
)
|
||||
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
|
||||
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
|
||||
|
||||
assert more_memory_cycle_nodes.issubset(assigned_nodes)
|
||||
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
|
||||
|
||||
|
||||
def test_tensor_rdma_backend_connectivity_matrix(
|
||||
topology: Topology,
|
||||
model_meta: ModelMetadata,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
topology = Topology()
|
||||
model_meta.n_layers = 12
|
||||
model_meta.storage_size.in_bytes = 1500
|
||||
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
node_c = NodeId()
|
||||
node_id_a = NodeId()
|
||||
node_id_b = NodeId()
|
||||
node_id_c = NodeId()
|
||||
|
||||
profiles = {
|
||||
node_a: create_node_profile(500),
|
||||
node_b: create_node_profile(500),
|
||||
node_c: create_node_profile(500),
|
||||
}
|
||||
node_a = create_node(500, node_id_a)
|
||||
node_b = create_node(500, node_id_b)
|
||||
node_c = create_node(500, node_id_c)
|
||||
|
||||
ethernet_interface = NetworkInterfaceInfo(
|
||||
name="en0",
|
||||
ip_address="192.168.1.100",
|
||||
)
|
||||
ethernet_conn = SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address=f"/ip4/192.168.1.{100}/tcp/{8000}")
|
||||
)
|
||||
|
||||
profiles[node_a].network_interfaces = [ethernet_interface]
|
||||
profiles[node_b].network_interfaces = [ethernet_interface]
|
||||
profiles[node_c].network_interfaces = [ethernet_interface]
|
||||
assert node_a.node_profile is not None
|
||||
assert node_b.node_profile is not None
|
||||
assert node_c.node_profile is not None
|
||||
|
||||
conn_a_b = create_connection(node_id_a, node_id_b)
|
||||
conn_b_c = create_connection(node_id_b, node_id_c)
|
||||
conn_c_a = create_connection(node_id_c, node_id_a)
|
||||
|
||||
conn_b_a = create_connection(node_id_b, node_id_a)
|
||||
conn_c_b = create_connection(node_id_c, node_id_b)
|
||||
conn_a_c = create_connection(node_id_a, node_id_c)
|
||||
|
||||
assert conn_a_b.send_back_multiaddr is not None
|
||||
assert conn_b_c.send_back_multiaddr is not None
|
||||
assert conn_c_a.send_back_multiaddr is not None
|
||||
|
||||
assert conn_b_a.send_back_multiaddr is not None
|
||||
assert conn_c_b.send_back_multiaddr is not None
|
||||
assert conn_a_c.send_back_multiaddr is not None
|
||||
|
||||
node_a.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=node_a.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_c_a.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_b_a.send_back_multiaddr.ip_address,
|
||||
),
|
||||
ethernet_interface,
|
||||
],
|
||||
system=node_a.node_profile.system,
|
||||
)
|
||||
node_b.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=node_b.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_c_b.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_a_b.send_back_multiaddr.ip_address,
|
||||
),
|
||||
ethernet_interface,
|
||||
],
|
||||
system=node_b.node_profile.system,
|
||||
)
|
||||
node_c.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=node_c.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_a_c.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_b_c.send_back_multiaddr.ip_address,
|
||||
),
|
||||
ethernet_interface,
|
||||
],
|
||||
system=node_c.node_profile.system,
|
||||
)
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
topology.add_connection(node_a, node_b, create_rdma_connection(3))
|
||||
topology.add_connection(node_b, node_c, create_rdma_connection(4))
|
||||
topology.add_connection(node_c, node_a, create_rdma_connection(5))
|
||||
topology.add_connection(node_b, node_a, create_rdma_connection(3))
|
||||
topology.add_connection(node_c, node_b, create_rdma_connection(4))
|
||||
topology.add_connection(node_a, node_c, create_rdma_connection(5))
|
||||
|
||||
topology.add_connection(node_a, node_b, ethernet_conn)
|
||||
topology.add_connection(node_b, node_c, ethernet_conn)
|
||||
topology.add_connection(node_c, node_a, ethernet_conn)
|
||||
topology.add_connection(node_a, node_c, ethernet_conn)
|
||||
topology.add_connection(node_b, node_a, ethernet_conn)
|
||||
topology.add_connection(node_c, node_b, ethernet_conn)
|
||||
topology.add_connection(conn_a_b)
|
||||
topology.add_connection(conn_b_c)
|
||||
topology.add_connection(conn_c_a)
|
||||
topology.add_connection(conn_b_a)
|
||||
topology.add_connection(conn_c_b)
|
||||
topology.add_connection(conn_a_c)
|
||||
|
||||
cic = PlaceInstance(
|
||||
sharding=Sharding.Tensor,
|
||||
@@ -339,7 +429,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
min_nodes=1,
|
||||
)
|
||||
|
||||
placements = place_instance(cic, topology, {}, profiles)
|
||||
placements = place_instance(cic, topology, {})
|
||||
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
@@ -347,10 +437,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
|
||||
assert isinstance(instance, MlxJacclInstance)
|
||||
|
||||
assert instance.jaccl_devices is not None
|
||||
assert instance.ibv_devices is not None
|
||||
assert instance.jaccl_coordinators is not None
|
||||
|
||||
matrix = instance.jaccl_devices
|
||||
matrix = instance.ibv_devices
|
||||
assert len(matrix) == 3
|
||||
|
||||
for i in range(3):
|
||||
@@ -359,15 +449,15 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
|
||||
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
|
||||
|
||||
idx_a = node_to_idx[node_a]
|
||||
idx_b = node_to_idx[node_b]
|
||||
idx_c = node_to_idx[node_c]
|
||||
idx_a = node_to_idx[node_id_a]
|
||||
idx_b = node_to_idx[node_id_b]
|
||||
idx_c = node_to_idx[node_id_c]
|
||||
|
||||
logger.info(matrix)
|
||||
|
||||
assert matrix[idx_a][idx_b] == "rdma_en3"
|
||||
assert matrix[idx_b][idx_c] == "rdma_en4"
|
||||
assert matrix[idx_c][idx_a] == "rdma_en5"
|
||||
assert matrix[idx_a][idx_b] == "rdma_en4"
|
||||
assert matrix[idx_b][idx_c] == "rdma_en3"
|
||||
assert matrix[idx_c][idx_a] == "rdma_en3"
|
||||
|
||||
# Verify coordinators are set for all nodes
|
||||
assert len(instance.jaccl_coordinators) == 3
|
||||
|
||||
@@ -1,48 +1,56 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.master.placement_utils import (
|
||||
NodeWithProfile,
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.master.tests.conftest import create_connection, create_node_profile
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
|
||||
from exo.shared.types.topology import Connection, NodeInfo
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
|
||||
def test_filter_cycles_by_memory():
|
||||
@pytest.fixture
|
||||
def topology() -> Topology:
|
||||
topology = Topology()
|
||||
return topology
|
||||
|
||||
|
||||
def test_filter_cycles_by_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
node1_id = NodeId()
|
||||
node2_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
node1 = create_node_profile(1000 * 1024)
|
||||
node2 = create_node_profile(1000 * 1024)
|
||||
node_profiles = {node1_id: node1, node2_id: node2}
|
||||
node1 = create_node(1000 * 1024, node1_id)
|
||||
node2 = create_node(1000 * 1024, node2_id)
|
||||
|
||||
topology.add_node(node1_id)
|
||||
topology.add_node(node2_id)
|
||||
topology.add_node(node1)
|
||||
topology.add_node(node2)
|
||||
|
||||
connection1 = create_connection(1)
|
||||
connection2 = create_connection(2)
|
||||
connection1 = create_connection(node1_id, node2_id)
|
||||
connection2 = create_connection(node2_id, node1_id)
|
||||
|
||||
topology.add_connection(node1_id, node2_id, connection1)
|
||||
topology.add_connection(node2_id, node1_id, connection2)
|
||||
topology.add_connection(connection1)
|
||||
topology.add_connection(connection2)
|
||||
|
||||
cycles = topology.get_cycles()
|
||||
assert len(cycles) == 1
|
||||
assert len(cycles[0]) == 2
|
||||
|
||||
# act
|
||||
filtered_cycles = filter_cycles_by_memory(
|
||||
cycles, node_profiles, Memory.from_bytes(1)
|
||||
)
|
||||
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
|
||||
|
||||
# assert
|
||||
assert len(filtered_cycles) == 1
|
||||
@@ -50,65 +58,64 @@ def test_filter_cycles_by_memory():
|
||||
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
|
||||
|
||||
|
||||
def test_filter_cycles_by_insufficient_memory():
|
||||
def test_filter_cycles_by_insufficient_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
node1_id = NodeId()
|
||||
node2_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
node1 = create_node_profile(1000 * 1024)
|
||||
node2 = create_node_profile(1000 * 1024)
|
||||
node_profiles = {node1_id: node1, node2_id: node2}
|
||||
node1 = create_node(1000 * 1024, node1_id)
|
||||
node2 = create_node(1000 * 1024, node2_id)
|
||||
|
||||
topology.add_node(node1_id)
|
||||
topology.add_node(node2_id)
|
||||
topology.add_node(node1)
|
||||
topology.add_node(node2)
|
||||
|
||||
connection1 = create_connection(1)
|
||||
connection2 = create_connection(2)
|
||||
connection1 = create_connection(node1_id, node2_id)
|
||||
connection2 = create_connection(node2_id, node1_id)
|
||||
|
||||
topology.add_connection(node1_id, node2_id, connection1)
|
||||
topology.add_connection(node2_id, node1_id, connection2)
|
||||
topology.add_connection(connection1)
|
||||
topology.add_connection(connection2)
|
||||
|
||||
# act
|
||||
filtered_cycles = filter_cycles_by_memory(
|
||||
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
|
||||
topology.get_cycles(), Memory.from_kb(2001)
|
||||
)
|
||||
|
||||
# assert
|
||||
assert len(filtered_cycles) == 0
|
||||
|
||||
|
||||
def test_filter_multiple_cycles_by_memory():
|
||||
def test_filter_multiple_cycles_by_memory(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
node_a = create_node_profile(500 * 1024)
|
||||
node_b = create_node_profile(500 * 1024)
|
||||
node_c = create_node_profile(1000 * 1024)
|
||||
node_profiles = {
|
||||
node_a_id: node_a,
|
||||
node_b_id: node_b,
|
||||
node_c_id: node_c,
|
||||
}
|
||||
node_a = create_node(500 * 1024, node_a_id)
|
||||
node_b = create_node(500 * 1024, node_b_id)
|
||||
node_c = create_node(1000 * 1024, node_c_id)
|
||||
|
||||
topology.add_node(node_a_id)
|
||||
topology.add_node(node_b_id)
|
||||
topology.add_node(node_c_id)
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
|
||||
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||
topology.add_connection(node_b_id, node_a_id, create_connection(2))
|
||||
topology.add_connection(node_a_id, node_c_id, create_connection(3))
|
||||
topology.add_connection(node_c_id, node_b_id, create_connection(4))
|
||||
topology.add_connection(create_connection(node_a_id, node_b_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_a_id))
|
||||
|
||||
topology.add_connection(create_connection(node_a_id, node_c_id))
|
||||
topology.add_connection(create_connection(node_c_id, node_b_id))
|
||||
|
||||
cycles = topology.get_cycles()
|
||||
|
||||
# act
|
||||
filtered_cycles = filter_cycles_by_memory(
|
||||
cycles, node_profiles, Memory.from_kb(1500)
|
||||
)
|
||||
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
|
||||
|
||||
# assert
|
||||
assert len(filtered_cycles) == 1
|
||||
@@ -120,38 +127,31 @@ def test_filter_multiple_cycles_by_memory():
|
||||
}
|
||||
|
||||
|
||||
def test_get_smallest_cycles():
|
||||
def test_get_smallest_cycles(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
node_a = create_node_profile(500 * 1024)
|
||||
node_b = create_node_profile(500 * 1024)
|
||||
node_c = create_node_profile(1000 * 1024)
|
||||
node_profiles = {
|
||||
node_a_id: node_a,
|
||||
node_b_id: node_b,
|
||||
node_c_id: node_c,
|
||||
}
|
||||
node_a = create_node(500 * 1024, node_a_id)
|
||||
node_b = create_node(500 * 1024, node_b_id)
|
||||
node_c = create_node(1000 * 1024, node_c_id)
|
||||
|
||||
topology.add_node(node_a_id)
|
||||
topology.add_node(node_b_id)
|
||||
topology.add_node(node_c_id)
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
|
||||
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||
topology.add_connection(node_b_id, node_a_id, create_connection(2))
|
||||
topology.add_connection(node_a_id, node_c_id, create_connection(3))
|
||||
topology.add_connection(node_c_id, node_b_id, create_connection(4))
|
||||
|
||||
cycles = [
|
||||
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
|
||||
for cycle in topology.get_cycles()
|
||||
]
|
||||
topology.add_connection(create_connection(node_a_id, node_b_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_c_id))
|
||||
topology.add_connection(create_connection(node_c_id, node_a_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_a_id))
|
||||
|
||||
# act
|
||||
smallest_cycles = get_smallest_cycles(cycles)
|
||||
smallest_cycles = get_smallest_cycles(topology.get_cycles())
|
||||
|
||||
# assert
|
||||
assert len(smallest_cycles) == 1
|
||||
@@ -168,6 +168,9 @@ def test_get_smallest_cycles():
|
||||
],
|
||||
)
|
||||
def test_get_shard_assignments(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
available_memory: tuple[int, int, int],
|
||||
total_layers: int,
|
||||
expected_layers: tuple[int, int, int],
|
||||
@@ -176,37 +179,29 @@ def test_get_shard_assignments(
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
node_a = create_node_profile(available_memory[0] * 1024)
|
||||
node_b = create_node_profile(available_memory[1] * 1024)
|
||||
node_c = create_node_profile(available_memory[2] * 1024)
|
||||
node_profiles = {
|
||||
node_a_id: node_a,
|
||||
node_b_id: node_b,
|
||||
node_c_id: node_c,
|
||||
}
|
||||
node_a = create_node(available_memory[0] * 1024, node_a_id)
|
||||
node_b = create_node(available_memory[1] * 1024, node_b_id)
|
||||
node_c = create_node(available_memory[2] * 1024, node_c_id)
|
||||
|
||||
topology.add_node(node_a_id)
|
||||
topology.add_node(node_b_id)
|
||||
topology.add_node(node_c_id)
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
|
||||
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||
topology.add_connection(node_b_id, node_c_id, create_connection(2))
|
||||
topology.add_connection(node_c_id, node_a_id, create_connection(3))
|
||||
topology.add_connection(node_b_id, node_a_id, create_connection(4))
|
||||
topology.add_connection(create_connection(node_a_id, node_b_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_c_id))
|
||||
topology.add_connection(create_connection(node_c_id, node_a_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_a_id))
|
||||
|
||||
model_meta = ModelMetadata(
|
||||
model_id=ModelId("test-model"),
|
||||
pretty_name="Test Model",
|
||||
n_layers=total_layers,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
)
|
||||
|
||||
cycles = [
|
||||
[NodeWithProfile(node_id=nid, node_profile=node_profiles[nid]) for nid in cycle]
|
||||
for cycle in topology.get_cycles()
|
||||
]
|
||||
cycles = topology.get_cycles()
|
||||
selected_cycle = cycles[0]
|
||||
|
||||
# act
|
||||
@@ -235,21 +230,28 @@ def test_get_shard_assignments(
|
||||
)
|
||||
|
||||
|
||||
def test_get_hosts_from_subgraph():
|
||||
def test_get_hosts_from_subgraph(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
|
||||
):
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
topology.add_node(node_a_id)
|
||||
topology.add_node(node_b_id)
|
||||
topology.add_node(node_c_id)
|
||||
node_a = create_node(500, node_a_id)
|
||||
node_b = create_node(500, node_b_id)
|
||||
node_c = create_node(1000, node_c_id)
|
||||
|
||||
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||
topology.add_connection(node_b_id, node_a_id, create_connection(2))
|
||||
topology.add_connection(node_a_id, node_c_id, create_connection(3))
|
||||
topology.add_connection(node_c_id, node_b_id, create_connection(4))
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
|
||||
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
|
||||
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
|
||||
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
|
||||
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
|
||||
|
||||
# act
|
||||
hosts = get_hosts_from_subgraph(topology)
|
||||
@@ -257,47 +259,108 @@ def test_get_hosts_from_subgraph():
|
||||
# assert
|
||||
assert len(hosts) == 3
|
||||
expected_hosts = [
|
||||
Host(ip=("169.254.0.2"), port=1234),
|
||||
Host(ip=("169.254.0.3"), port=1234),
|
||||
Host(ip=("169.254.0.4"), port=1234),
|
||||
Host(ip=("169.254.0.2"), port=5001),
|
||||
Host(ip=("169.254.0.3"), port=5002),
|
||||
Host(ip=("169.254.0.4"), port=5003),
|
||||
]
|
||||
for expected_host in expected_hosts:
|
||||
assert expected_host in hosts
|
||||
|
||||
|
||||
def test_get_mlx_jaccl_coordinators():
|
||||
def test_get_mlx_jaccl_coordinators(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
|
||||
):
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
topology = Topology()
|
||||
|
||||
topology.add_node(node_a_id)
|
||||
topology.add_node(node_b_id)
|
||||
topology.add_node(node_c_id)
|
||||
node_a = create_node(500 * 1024, node_a_id)
|
||||
node_b = create_node(500 * 1024, node_b_id)
|
||||
node_c = create_node(1000 * 1024, node_c_id)
|
||||
|
||||
topology.add_connection(node_a_id, node_b_id, create_connection(1))
|
||||
topology.add_connection(node_b_id, node_a_id, create_connection(2))
|
||||
topology.add_connection(node_a_id, node_c_id, create_connection(3))
|
||||
topology.add_connection(node_c_id, node_b_id, create_connection(4))
|
||||
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
|
||||
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
|
||||
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
|
||||
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
|
||||
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
|
||||
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
|
||||
|
||||
conn_a_b = create_connection(1)
|
||||
conn_b_a = create_connection(2)
|
||||
conn_b_c = create_connection(3)
|
||||
conn_c_b = create_connection(4)
|
||||
conn_c_a = create_connection(5)
|
||||
conn_a_c = create_connection(6)
|
||||
# Update node profiles with network interfaces before adding to topology
|
||||
assert node_a.node_profile is not None
|
||||
assert node_b.node_profile is not None
|
||||
assert node_c.node_profile is not None
|
||||
|
||||
topology.add_connection(node_a_id, node_b_id, conn_a_b)
|
||||
topology.add_connection(node_b_id, node_a_id, conn_b_a)
|
||||
topology.add_connection(node_b_id, node_c_id, conn_b_c)
|
||||
topology.add_connection(node_c_id, node_b_id, conn_c_b)
|
||||
topology.add_connection(node_c_id, node_a_id, conn_c_a)
|
||||
topology.add_connection(node_a_id, node_c_id, conn_a_c)
|
||||
node_a.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=node_a.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_a_b.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_a_c.send_back_multiaddr.ip_address,
|
||||
),
|
||||
],
|
||||
system=node_a.node_profile.system,
|
||||
)
|
||||
node_b.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=node_b.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_b_a.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_b_c.send_back_multiaddr.ip_address,
|
||||
),
|
||||
],
|
||||
system=node_b.node_profile.system,
|
||||
)
|
||||
node_c.node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=node_c.node_profile.memory,
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(
|
||||
name="en3",
|
||||
ip_address=conn_c_b.send_back_multiaddr.ip_address,
|
||||
),
|
||||
NetworkInterfaceInfo(
|
||||
name="en4",
|
||||
ip_address=conn_c_a.send_back_multiaddr.ip_address,
|
||||
),
|
||||
],
|
||||
system=node_c.node_profile.system,
|
||||
)
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
|
||||
topology.add_connection(conn_a_b)
|
||||
topology.add_connection(conn_b_a)
|
||||
topology.add_connection(conn_b_c)
|
||||
topology.add_connection(conn_c_b)
|
||||
topology.add_connection(conn_c_a)
|
||||
topology.add_connection(conn_a_c)
|
||||
|
||||
cycle = [node_a, node_b, node_c]
|
||||
|
||||
# act
|
||||
coordinators = get_mlx_jaccl_coordinators(
|
||||
node_a_id, coordinator_port=5000, cycle_digraph=topology
|
||||
cycle, coordinator_port=5000, cycle_digraph=topology
|
||||
)
|
||||
|
||||
# assert
|
||||
@@ -326,11 +389,11 @@ def test_get_mlx_jaccl_coordinators():
|
||||
|
||||
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
|
||||
# node_b uses the IP from conn_b_a (node_b -> node_a)
|
||||
assert coordinators[node_b_id] == (f"{conn_b_a.sink_multiaddr.ip_address}:5000"), (
|
||||
"node_b should use the IP from conn_b_a"
|
||||
)
|
||||
assert coordinators[node_b_id] == (
|
||||
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
|
||||
), "node_b should use the IP from conn_b_a"
|
||||
|
||||
# node_c uses the IP from conn_c_a (node_c -> node_a)
|
||||
assert coordinators[node_c_id] == (f"{conn_c_a.sink_multiaddr.ip_address}:5000"), (
|
||||
"node_c should use the IP from conn_c_a"
|
||||
)
|
||||
assert coordinators[node_c_id] == (
|
||||
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
|
||||
), "node_c should use the IP from conn_c_a"
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import pytest
|
||||
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
MemoryPerformanceProfile,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.topology import SocketConnection
|
||||
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -17,15 +16,20 @@ def topology() -> Topology:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection() -> SocketConnection:
|
||||
return SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
|
||||
def connection() -> Connection:
|
||||
return Connection(
|
||||
local_node_id=NodeId(),
|
||||
send_back_node_id=NodeId(),
|
||||
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
|
||||
connection_profile=ConnectionProfile(
|
||||
throughput=1000, latency=1000, jitter=1000
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def node_profile() -> NodePerformanceProfile:
|
||||
memory_profile = MemoryUsage.from_bytes(
|
||||
memory_profile = MemoryPerformanceProfile.from_bytes(
|
||||
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
|
||||
)
|
||||
system_profile = SystemPerformanceProfile()
|
||||
@@ -39,85 +43,162 @@ def node_profile() -> NodePerformanceProfile:
|
||||
)
|
||||
|
||||
|
||||
def test_add_node(topology: Topology):
|
||||
@pytest.fixture
|
||||
def connection_profile() -> ConnectionProfile:
|
||||
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
|
||||
|
||||
|
||||
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
|
||||
# arrange
|
||||
node_id = NodeId()
|
||||
|
||||
# act
|
||||
topology.add_node(node_id)
|
||||
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
|
||||
|
||||
# assert
|
||||
assert topology.node_is_leaf(node_id)
|
||||
data = topology.get_node_profile(node_id)
|
||||
assert data == node_profile
|
||||
|
||||
|
||||
def test_add_connection(topology: Topology, connection: SocketConnection):
|
||||
def test_add_connection(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_connection(node_a, node_b, connection)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
data = list(conn for _, _, conn in topology.list_connections())
|
||||
data = topology.get_connection_profile(connection)
|
||||
|
||||
# assert
|
||||
assert data == [connection]
|
||||
assert data == connection.connection_profile
|
||||
|
||||
assert topology.node_is_leaf(node_a)
|
||||
assert topology.node_is_leaf(node_b)
|
||||
|
||||
def test_update_node_profile(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
new_node_profile = NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile.from_bytes(
|
||||
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
)
|
||||
|
||||
# act
|
||||
topology.update_node_profile(
|
||||
connection.local_node_id, node_profile=new_node_profile
|
||||
)
|
||||
|
||||
# assert
|
||||
data = topology.get_node_profile(connection.local_node_id)
|
||||
assert data == new_node_profile
|
||||
|
||||
|
||||
def test_update_connection_profile(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
new_connection_profile = ConnectionProfile(
|
||||
throughput=2000, latency=2000, jitter=2000
|
||||
)
|
||||
connection = Connection(
|
||||
local_node_id=connection.local_node_id,
|
||||
send_back_node_id=connection.send_back_node_id,
|
||||
send_back_multiaddr=connection.send_back_multiaddr,
|
||||
connection_profile=new_connection_profile,
|
||||
)
|
||||
|
||||
# act
|
||||
topology.update_connection_profile(connection)
|
||||
|
||||
# assert
|
||||
data = topology.get_connection_profile(connection)
|
||||
assert data == new_connection_profile
|
||||
|
||||
|
||||
def test_remove_connection_still_connected(
|
||||
topology: Topology, connection: SocketConnection
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_connection(node_a, node_b, connection)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
topology.remove_connection(node_a, node_b, connection)
|
||||
topology.remove_connection(connection)
|
||||
|
||||
# assert
|
||||
assert list(topology.get_all_connections_between(node_a, node_b)) == []
|
||||
assert topology.get_connection_profile(connection) is None
|
||||
|
||||
|
||||
def test_remove_node_still_connected(topology: Topology, connection: SocketConnection):
|
||||
def test_remove_node_still_connected(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_connection(node_a, node_b, connection)
|
||||
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
topology.remove_node(node_b)
|
||||
topology.remove_node(connection.local_node_id)
|
||||
|
||||
# assert
|
||||
assert list(topology.out_edges(node_a)) == []
|
||||
assert topology.get_node_profile(connection.local_node_id) is None
|
||||
|
||||
|
||||
def test_list_nodes(topology: Topology, connection: SocketConnection):
|
||||
def test_list_nodes(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_connection(node_a, node_b, connection)
|
||||
assert list(topology.out_edges(node_a)) == [(node_b, connection)]
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
nodes = list(topology.list_nodes())
|
||||
|
||||
# assert
|
||||
assert len(nodes) == 2
|
||||
assert all(isinstance(node, NodeId) for node in nodes)
|
||||
assert {node for node in nodes} == {node_a, node_b}
|
||||
assert all(isinstance(node, NodeInfo) for node in nodes)
|
||||
assert {node.node_id for node in nodes} == {
|
||||
connection.local_node_id,
|
||||
connection.send_back_node_id,
|
||||
}
|
||||
|
||||
@@ -11,8 +11,10 @@ from exo.shared.types.events import (
|
||||
IndexedEvent,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
NodeCreated,
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
NodeMemoryMeasured,
|
||||
NodePerformanceMeasured,
|
||||
NodeTimedOut,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
@@ -25,23 +27,13 @@ from exo.shared.types.events import (
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
)
|
||||
from exo.shared.types.profiling import NodePerformanceProfile
|
||||
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.topology import RDMAConnection
|
||||
from exo.shared.types.topology import NodeInfo
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
|
||||
from exo.utils.info_gatherer.info_gatherer import (
|
||||
MacmonMetrics,
|
||||
MacTBConnections,
|
||||
MacTBIdentifiers,
|
||||
MemoryUsage,
|
||||
MiscData,
|
||||
NodeConfig,
|
||||
NodeNetworkInterfaces,
|
||||
StaticNodeInformation,
|
||||
)
|
||||
|
||||
|
||||
def event_apply(event: Event, state: State) -> State:
|
||||
@@ -55,12 +47,16 @@ def event_apply(event: Event, state: State) -> State:
|
||||
return apply_instance_created(event, state)
|
||||
case InstanceDeleted():
|
||||
return apply_instance_deleted(event, state)
|
||||
case NodeCreated():
|
||||
return apply_topology_node_created(event, state)
|
||||
case NodeTimedOut():
|
||||
return apply_node_timed_out(event, state)
|
||||
case NodePerformanceMeasured():
|
||||
return apply_node_performance_measured(event, state)
|
||||
case NodeDownloadProgress():
|
||||
return apply_node_download_progress(event, state)
|
||||
case NodeGatheredInfo():
|
||||
return apply_node_gathered_info(event, state)
|
||||
case NodeMemoryMeasured():
|
||||
return apply_node_memory_measured(event, state)
|
||||
case RunnerDeleted():
|
||||
return apply_runner_deleted(event, state)
|
||||
case RunnerStatusUpdated():
|
||||
@@ -192,7 +188,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
|
||||
|
||||
|
||||
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
topology = copy.deepcopy(state.topology)
|
||||
topology = copy.copy(state.topology)
|
||||
state.topology.remove_node(event.node_id)
|
||||
node_profiles = {
|
||||
key: value for key, value in state.node_profiles.items() if key != event.node_id
|
||||
@@ -200,12 +196,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
last_seen = {
|
||||
key: value for key, value in state.last_seen.items() if key != event.node_id
|
||||
}
|
||||
downloads = {
|
||||
key: value for key, value in state.downloads.items() if key != event.node_id
|
||||
}
|
||||
return state.model_copy(
|
||||
update={
|
||||
"downloads": downloads,
|
||||
"topology": topology,
|
||||
"node_profiles": node_profiles,
|
||||
"last_seen": last_seen,
|
||||
@@ -213,69 +205,103 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
)
|
||||
|
||||
|
||||
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
topology = copy.deepcopy(state.topology)
|
||||
topology.add_node(event.node_id)
|
||||
info = event.info
|
||||
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
|
||||
# TODO: should be broken up into individual events instead of this monster
|
||||
match info:
|
||||
case MacmonMetrics():
|
||||
profile.system = info.system_profile
|
||||
profile.memory = info.memory
|
||||
case MemoryUsage():
|
||||
profile.memory = info
|
||||
case NodeConfig():
|
||||
pass
|
||||
case MiscData():
|
||||
profile.friendly_name = info.friendly_name
|
||||
case StaticNodeInformation():
|
||||
profile.model_id = info.model
|
||||
profile.chip_id = info.chip
|
||||
# TODO: makes me slightly sad
|
||||
case NodeNetworkInterfaces():
|
||||
profile.network_interfaces = info.ifaces
|
||||
case MacTBIdentifiers():
|
||||
profile.tb_interfaces = info.idents
|
||||
case MacTBConnections():
|
||||
conn_map = {
|
||||
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
|
||||
for nid in state.node_profiles
|
||||
for tb_ident in state.node_profiles[nid].tb_interfaces
|
||||
}
|
||||
as_rdma_conns = [
|
||||
(
|
||||
conn_map[tb_conn.sink_uuid][0],
|
||||
RDMAConnection(
|
||||
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
|
||||
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
|
||||
),
|
||||
)
|
||||
for tb_conn in info.conns
|
||||
if tb_conn.source_uuid in conn_map
|
||||
if tb_conn.sink_uuid in conn_map
|
||||
]
|
||||
topology.replace_all_out_tb_connections(event.node_id, as_rdma_conns)
|
||||
|
||||
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
|
||||
new_profiles = {**state.node_profiles, event.node_id: profile}
|
||||
def apply_node_performance_measured(
|
||||
event: NodePerformanceMeasured, state: State
|
||||
) -> State:
|
||||
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
|
||||
**state.node_profiles,
|
||||
event.node_id: event.node_profile,
|
||||
}
|
||||
last_seen: Mapping[NodeId, datetime] = {
|
||||
**state.last_seen,
|
||||
event.node_id: datetime.fromisoformat(event.when),
|
||||
}
|
||||
state = state.model_copy(update={"node_profiles": new_profiles})
|
||||
topology = copy.copy(state.topology)
|
||||
# TODO: NodeCreated
|
||||
if not topology.contains_node(event.node_id):
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
topology.update_node_profile(event.node_id, event.node_profile)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"node_profiles": new_profiles,
|
||||
"last_seen": last_seen,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
|
||||
existing = state.node_profiles.get(event.node_id)
|
||||
topology = copy.copy(state.topology)
|
||||
|
||||
if existing is None:
|
||||
created = NodePerformanceProfile(
|
||||
model_id="unknown",
|
||||
chip_id="unknown",
|
||||
friendly_name="Unknown",
|
||||
memory=event.memory,
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(
|
||||
# TODO: flops_fp16=0.0,
|
||||
gpu_usage=0.0,
|
||||
temp=0.0,
|
||||
sys_power=0.0,
|
||||
pcpu_usage=0.0,
|
||||
ecpu_usage=0.0,
|
||||
ane_power=0.0,
|
||||
),
|
||||
)
|
||||
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
|
||||
**state.node_profiles,
|
||||
event.node_id: created,
|
||||
}
|
||||
last_seen: Mapping[NodeId, datetime] = {
|
||||
**state.last_seen,
|
||||
event.node_id: datetime.fromisoformat(event.when),
|
||||
}
|
||||
if not topology.contains_node(event.node_id):
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
# TODO: NodeCreated
|
||||
topology.update_node_profile(event.node_id, created)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"node_profiles": created_profiles,
|
||||
"topology": topology,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
)
|
||||
|
||||
updated = existing.model_copy(update={"memory": event.memory})
|
||||
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
|
||||
**state.node_profiles,
|
||||
event.node_id: updated,
|
||||
}
|
||||
# TODO: NodeCreated
|
||||
if not topology.contains_node(event.node_id):
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
topology.update_node_profile(event.node_id, updated)
|
||||
return state.model_copy(
|
||||
update={"node_profiles": updated_profiles, "topology": topology}
|
||||
)
|
||||
|
||||
|
||||
def apply_topology_node_created(event: NodeCreated, state: State) -> State:
|
||||
topology = copy.copy(state.topology)
|
||||
topology.add_node(NodeInfo(node_id=event.node_id))
|
||||
return state.model_copy(update={"topology": topology})
|
||||
|
||||
|
||||
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
|
||||
topology = copy.deepcopy(state.topology)
|
||||
topology.add_connection(event.source, event.sink, event.edge)
|
||||
topology = copy.copy(state.topology)
|
||||
topology.add_connection(event.edge)
|
||||
return state.model_copy(update={"topology": topology})
|
||||
|
||||
|
||||
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
|
||||
topology = copy.deepcopy(state.topology)
|
||||
topology.remove_connection(event.sink, event.source, event.edge)
|
||||
topology = copy.copy(state.topology)
|
||||
if not topology.contains_connection(event.edge):
|
||||
return state
|
||||
topology.remove_connection(event.edge)
|
||||
# TODO: Clean up removing the reverse connection
|
||||
return state.model_copy(update={"topology": topology})
|
||||
|
||||
@@ -38,7 +38,6 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
|
||||
|
||||
# Identity (config)
|
||||
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
|
||||
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
|
||||
|
||||
# libp2p topics for event forwarding
|
||||
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
|
||||
|
||||
@@ -24,8 +24,6 @@ class _InterceptHandler(logging.Handler):
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
|
||||
return
|
||||
|
||||
logger.opt(depth=3, exception=record.exc_info).log(level, record.getMessage())
|
||||
|
||||
|
||||
|
||||
@@ -14,32 +14,6 @@ class ModelCard(CamelCaseModel):
|
||||
|
||||
MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# deepseek v3
|
||||
# "deepseek-v3-0324:4bit": ModelCard(
|
||||
# short_id="deepseek-v3-0324:4bit",
|
||||
# model_id="mlx-community/DeepSeek-V3-0324-4bit",
|
||||
# name="DeepSeek V3 0324 (4-bit)",
|
||||
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
|
||||
# pretty_name="DeepSeek V3 0324 (4-bit)",
|
||||
# storage_size=Memory.from_kb(409706307),
|
||||
# n_layers=61,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-v3-0324": ModelCard(
|
||||
# short_id="deepseek-v3-0324",
|
||||
# model_id="mlx-community/DeepSeek-v3-0324-8bit",
|
||||
# name="DeepSeek V3 0324 (8-bit)",
|
||||
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
|
||||
# pretty_name="DeepSeek V3 0324 (8-bit)",
|
||||
# storage_size=Memory.from_kb(754706307),
|
||||
# n_layers=61,
|
||||
# ),
|
||||
# ),
|
||||
"deepseek-v3.1-4bit": ModelCard(
|
||||
short_id="deepseek-v3.1-4bit",
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
@@ -51,6 +25,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="DeepSeek V3.1 (4-bit)",
|
||||
storage_size=Memory.from_gb(378),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"deepseek-v3.1-8bit": ModelCard(
|
||||
@@ -64,65 +40,10 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="DeepSeek V3.1 (8-bit)",
|
||||
storage_size=Memory.from_gb(713),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "deepseek-v3.2": ModelCard(
|
||||
# short_id="deepseek-v3.2",
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
|
||||
# name="DeepSeek V3.2 (8-bit)",
|
||||
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
|
||||
# pretty_name="DeepSeek V3.2 (8-bit)",
|
||||
# storage_size=Memory.from_kb(754706307),
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-v3.2-4bit": ModelCard(
|
||||
# short_id="deepseek-v3.2-4bit",
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
|
||||
# name="DeepSeek V3.2 (4-bit)",
|
||||
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
|
||||
# pretty_name="DeepSeek V3.2 (4-bit)",
|
||||
# storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!!
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# deepseek r1
|
||||
# "deepseek-r1-0528-4bit": ModelCard(
|
||||
# short_id="deepseek-r1-0528-4bit",
|
||||
# model_id="mlx-community/DeepSeek-R1-0528-4bit",
|
||||
# name="DeepSeek-R1-0528 (4-bit)",
|
||||
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
|
||||
# pretty_name="DeepSeek R1 671B (4-bit)",
|
||||
# storage_size=Memory.from_kb(409706307),
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-r1-0528": ModelCard(
|
||||
# short_id="deepseek-r1-0528",
|
||||
# model_id="mlx-community/DeepSeek-R1-0528-8bit",
|
||||
# name="DeepSeek-R1-0528 (8-bit)",
|
||||
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
|
||||
# pretty_name="DeepSeek R1 671B (8-bit)",
|
||||
# storage_size=Memory.from_bytes(754998771712),
|
||||
# n_layers=61,
|
||||
# . hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# kimi k2
|
||||
"kimi-k2-instruct-4bit": ModelCard(
|
||||
short_id="kimi-k2-instruct-4bit",
|
||||
@@ -135,6 +56,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Kimi K2 Instruct (4-bit)",
|
||||
storage_size=Memory.from_gb(578),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"kimi-k2-thinking": ModelCard(
|
||||
@@ -148,6 +71,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Kimi K2 Thinking (4-bit)",
|
||||
storage_size=Memory.from_gb(658),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# llama-3.1
|
||||
@@ -162,6 +87,38 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.1 8B (4-bit)",
|
||||
storage_size=Memory.from_mb(4423),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.1-8b-8bit": ModelCard(
|
||||
short_id="llama-3.1-8b-8bit",
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
name="Llama 3.1 8B (8-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
pretty_name="Llama 3.1 8B (8-bit)",
|
||||
storage_size=Memory.from_mb(8540),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.1-8b-bf16": ModelCard(
|
||||
short_id="llama-3.1-8b-bf16",
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
name="Llama 3.1 8B (BF16)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
pretty_name="Llama 3.1 8B (BF16)",
|
||||
storage_size=Memory.from_mb(16100),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.1-70b": ModelCard(
|
||||
@@ -175,6 +132,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.1 70B (4-bit)",
|
||||
storage_size=Memory.from_mb(38769),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# llama-3.2
|
||||
@@ -189,6 +148,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.2 1B (4-bit)",
|
||||
storage_size=Memory.from_mb(696),
|
||||
n_layers=16,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.2-3b": ModelCard(
|
||||
@@ -202,6 +163,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.2 3B (4-bit)",
|
||||
storage_size=Memory.from_mb(1777),
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.2-3b-8bit": ModelCard(
|
||||
@@ -215,6 +178,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.2 3B (8-bit)",
|
||||
storage_size=Memory.from_mb(3339),
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# llama-3.3
|
||||
@@ -229,6 +194,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.3 70B",
|
||||
storage_size=Memory.from_mb(38769),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.3-70b-8bit": ModelCard(
|
||||
@@ -242,6 +209,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.3 70B (8-bit)",
|
||||
storage_size=Memory.from_mb(73242),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.3-70b-fp16": ModelCard(
|
||||
@@ -255,20 +224,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Llama 3.3 70B (FP16)",
|
||||
storage_size=Memory.from_mb(137695),
|
||||
n_layers=80,
|
||||
),
|
||||
),
|
||||
# phi-3
|
||||
"phi-3-mini": ModelCard(
|
||||
short_id="phi-3-mini",
|
||||
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
|
||||
name="Phi 3 Mini 128k (4-bit)",
|
||||
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
|
||||
pretty_name="Phi 3 Mini 128k (4-bit)",
|
||||
storage_size=Memory.from_mb(2099),
|
||||
n_layers=32,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# qwen3
|
||||
@@ -283,6 +240,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 0.6B (4-bit)",
|
||||
storage_size=Memory.from_mb(327),
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
),
|
||||
),
|
||||
"qwen3-0.6b-8bit": ModelCard(
|
||||
@@ -296,6 +255,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 0.6B (8-bit)",
|
||||
storage_size=Memory.from_mb(666),
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
),
|
||||
),
|
||||
"qwen3-30b": ModelCard(
|
||||
@@ -309,6 +270,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 30B A3B (4-bit)",
|
||||
storage_size=Memory.from_mb(16797),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-30b-8bit": ModelCard(
|
||||
@@ -322,6 +285,68 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 30B A3B (8-bit)",
|
||||
storage_size=Memory.from_mb(31738),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-4bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-4bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
name="Qwen3 80B A3B (4-bit)",
|
||||
description="""Qwen3 80B""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
pretty_name="Qwen3 80B A3B (4-bit)",
|
||||
storage_size=Memory.from_mb(44800),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-8bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-8bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
name="Qwen3 80B A3B (8-bit)",
|
||||
description="""Qwen3 80B""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
pretty_name="Qwen3 80B A3B (8-bit)",
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-4bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-thinking-4bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
name="Qwen3 80B A3B Thinking (4-bit)",
|
||||
description="""Qwen3 80B Reasoning model""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
pretty_name="Qwen3 80B A3B (4-bit)",
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-8bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-thinking-8bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
name="Qwen3 80B A3B Thinking (8-bit)",
|
||||
description="""Qwen3 80B Reasoning model""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
pretty_name="Qwen3 80B A3B (8-bit)",
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-235b-a22b-4bit": ModelCard(
|
||||
@@ -335,6 +360,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 235B A22B (4-bit)",
|
||||
storage_size=Memory.from_gb(132),
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-235b-a22b-8bit": ModelCard(
|
||||
@@ -348,6 +375,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 235B A22B (8-bit)",
|
||||
storage_size=Memory.from_gb(250),
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-coder-480b-a35b-4bit": ModelCard(
|
||||
@@ -361,6 +390,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
|
||||
storage_size=Memory.from_gb(270),
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-coder-480b-a35b-8bit": ModelCard(
|
||||
@@ -374,78 +405,148 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
|
||||
storage_size=Memory.from_gb(540),
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# granite
|
||||
"granite-3.3-2b": ModelCard(
|
||||
short_id="granite-3.3-2b",
|
||||
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
|
||||
name="Granite 3.3 2B (FP16)",
|
||||
description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
|
||||
# gpt-oss
|
||||
"gpt-oss-120b-MXFP4-Q8": ModelCard(
|
||||
short_id="gpt-oss-120b-MXFP4-Q8",
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
|
||||
pretty_name="Granite 3.3 2B (FP16)",
|
||||
storage_size=Memory.from_mb(4951),
|
||||
n_layers=40,
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
storage_size=Memory.from_kb(68_996_301),
|
||||
n_layers=36,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"gpt-oss-20b-4bit": ModelCard(
|
||||
short_id="gpt-oss-20b-4bit",
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
storage_size=Memory.from_kb(11_744_051),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# glm 4.5
|
||||
"glm-4.5-air-8bit": ModelCard(
|
||||
# Needs to be quantized g32 or g16 to work with tensor parallel
|
||||
short_id="glm-4.5-air-8bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
name="GLM 4.5 Air 8bit",
|
||||
description="""GLM 4.5 Air 8bit""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
pretty_name="GLM 4.5 Air 8bit",
|
||||
storage_size=Memory.from_gb(114),
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=False,
|
||||
),
|
||||
),
|
||||
"glm-4.5-air-bf16": ModelCard(
|
||||
short_id="glm-4.5-air-bf16",
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
name="GLM 4.5 Air bf16",
|
||||
description="""GLM 4.5 Air bf16""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
pretty_name="GLM 4.5 Air bf16",
|
||||
storage_size=Memory.from_gb(214),
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# glm 4.7
|
||||
"glm-4.7-4bit": ModelCard(
|
||||
short_id="glm-4.7-4bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
|
||||
name="GLM 4.7 4bit",
|
||||
description="GLM 4.7 4bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
|
||||
pretty_name="GLM 4.7 4bit",
|
||||
storage_size=Memory.from_bytes(198556925568),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"glm-4.7-6bit": ModelCard(
|
||||
short_id="glm-4.7-6bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
|
||||
name="GLM 4.7 6bit",
|
||||
description="GLM 4.7 6bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
|
||||
pretty_name="GLM 4.7 6bit",
|
||||
storage_size=Memory.from_bytes(286737579648),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"glm-4.7-8bit-gs32": ModelCard(
|
||||
short_id="glm-4.7-8bit-gs32",
|
||||
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
|
||||
name="GLM 4.7 8bit (gs32)",
|
||||
description="GLM 4.7 8bit (gs32)",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
|
||||
pretty_name="GLM 4.7 8bit (gs32)",
|
||||
storage_size=Memory.from_bytes(396963397248),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# minimax-m2
|
||||
"minimax-m2.1-8bit": ModelCard(
|
||||
short_id="minimax-m2.1-8bit",
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
|
||||
name="MiniMax M2.1 8bit",
|
||||
description="MiniMax M2.1 8bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
|
||||
pretty_name="MiniMax M2.1 8bit",
|
||||
storage_size=Memory.from_bytes(242986745856),
|
||||
n_layers=61,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"minimax-m2.1-3bit": ModelCard(
|
||||
short_id="minimax-m2.1-3bit",
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
|
||||
name="MiniMax M2.1 3bit",
|
||||
description="MiniMax M2.1 3bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
|
||||
pretty_name="MiniMax M2.1 3bit",
|
||||
storage_size=Memory.from_bytes(100086644736),
|
||||
n_layers=61,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "granite-3.3-8b": ModelCard(
|
||||
# short_id="granite-3.3-8b",
|
||||
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
|
||||
# name="Granite 3.3 8B",
|
||||
# description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
|
||||
# pretty_name="Granite 3.3 8B",
|
||||
# storage_size=Memory.from_kb(15958720),
|
||||
# n_layers=40,
|
||||
# ),
|
||||
# ),
|
||||
# smol-lm
|
||||
# "smol-lm-135m": ModelCard(
|
||||
# short_id="smol-lm-135m",
|
||||
# model_id="mlx-community/SmolLM-135M-4bit",
|
||||
# name="Smol LM 135M",
|
||||
# description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
|
||||
# pretty_name="Smol LM 135M",
|
||||
# storage_size=Memory.from_kb(73940),
|
||||
# n_layers=30,
|
||||
# ),
|
||||
# ),
|
||||
# gpt-oss
|
||||
# "gpt-oss-120b-MXFP4-Q8": ModelCard(
|
||||
# short_id="gpt-oss-120b-MXFP4-Q8",
|
||||
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
# name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
# description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
# pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
# storage_size=Memory.from_kb(68_996_301),
|
||||
# n_layers=36,
|
||||
# hidden_size=2880,
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
# "gpt-oss-20b-4bit": ModelCard(
|
||||
# short_id="gpt-oss-20b-4bit",
|
||||
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
# name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
# description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
# pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
# storage_size=Memory.from_kb(11_744_051),
|
||||
# n_layers=24,
|
||||
# hidden_size=2880,
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.worker.download.download_utils import (
|
||||
@@ -25,6 +26,7 @@ class ConfigData(BaseModel):
|
||||
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
|
||||
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
|
||||
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
|
||||
@property
|
||||
def layer_count(self) -> int:
|
||||
@@ -106,10 +108,19 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
|
||||
config_data = await get_config_data(model_id)
|
||||
num_layers = config_data.layer_count
|
||||
mem_size_bytes = await get_safetensors_size(model_id)
|
||||
model_card = next(
|
||||
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
|
||||
None,
|
||||
)
|
||||
|
||||
return ModelMetadata(
|
||||
model_id=ModelId(model_id),
|
||||
pretty_name=model_id,
|
||||
pretty_name=model_card.name if model_card is not None else model_id,
|
||||
storage_size=mem_size_bytes,
|
||||
n_layers=num_layers,
|
||||
hidden_size=config_data.hidden_size or 0,
|
||||
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
|
||||
supports_tensor=model_card.metadata.supports_tensor
|
||||
if model_card is not None
|
||||
else False,
|
||||
)
|
||||
|
||||
@@ -36,6 +36,8 @@ def get_pipeline_shard_metadata(
|
||||
pretty_name=str(model_id),
|
||||
storage_size=Memory.from_mb(100000),
|
||||
n_layers=32,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
),
|
||||
device_rank=device_rank,
|
||||
world_size=world_size,
|
||||
|
||||
@@ -2,6 +2,7 @@ from exo.shared.apply import apply_node_download_progress
|
||||
from exo.shared.tests.conftest import get_pipeline_shard_metadata
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import NodeDownloadProgress
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
|
||||
@@ -13,6 +14,7 @@ def test_apply_node_download_progress():
|
||||
event = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard1,
|
||||
total_bytes=Memory(),
|
||||
)
|
||||
|
||||
new_state = apply_node_download_progress(
|
||||
@@ -28,10 +30,12 @@ def test_apply_two_node_download_progress():
|
||||
event1 = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard1,
|
||||
total_bytes=Memory(),
|
||||
)
|
||||
event2 = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard2,
|
||||
total_bytes=Memory(),
|
||||
)
|
||||
state = State(downloads={NodeId("node-1"): [event1]})
|
||||
|
||||
@@ -39,4 +43,7 @@ def test_apply_two_node_download_progress():
|
||||
NodeDownloadProgress(download_progress=event2), state
|
||||
)
|
||||
|
||||
# TODO: This test is failing. We should support the following:
|
||||
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
|
||||
# 2. Downloading a model, it completes, then downloading a different model on the same node.
|
||||
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.topology import SocketConnection
|
||||
from exo.shared.types.topology import Connection
|
||||
|
||||
|
||||
def test_state_serialization_roundtrip() -> None:
|
||||
@@ -11,12 +11,14 @@ def test_state_serialization_roundtrip() -> None:
|
||||
node_a = NodeId("node-a")
|
||||
node_b = NodeId("node-b")
|
||||
|
||||
connection = SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
|
||||
connection = Connection(
|
||||
local_node_id=node_a,
|
||||
send_back_node_id=node_b,
|
||||
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
|
||||
)
|
||||
|
||||
state = State()
|
||||
state.topology.add_connection(node_a, node_b, connection)
|
||||
state.topology.add_connection(connection)
|
||||
|
||||
json_repr = state.model_dump_json()
|
||||
restored_state = State.model_validate_json(json_repr)
|
||||
|
||||
@@ -1,219 +1,203 @@
|
||||
import contextlib
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterable
|
||||
|
||||
import rustworkx as rx
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.topology import RDMAConnection, SocketConnection
|
||||
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
|
||||
from exo.shared.types.topology import Connection, NodeInfo
|
||||
|
||||
|
||||
class TopologySnapshot(BaseModel):
|
||||
nodes: Sequence[NodeId]
|
||||
connections: Mapping[
|
||||
NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]
|
||||
]
|
||||
nodes: list[NodeInfo]
|
||||
connections: list[Connection]
|
||||
|
||||
model_config = ConfigDict(frozen=True, extra="forbid")
|
||||
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Topology:
|
||||
# the _graph can be used as a int -> NodeId map.
|
||||
_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
|
||||
init=False, default_factory=rx.PyDiGraph
|
||||
)
|
||||
_vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
|
||||
def __init__(self) -> None:
|
||||
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
|
||||
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
|
||||
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
|
||||
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
|
||||
|
||||
def to_snapshot(self) -> TopologySnapshot:
|
||||
return TopologySnapshot(
|
||||
nodes=list(self.list_nodes()), connections=self.map_connections()
|
||||
nodes=list(self.list_nodes()),
|
||||
connections=list(self.list_connections()),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
|
||||
topology = cls()
|
||||
|
||||
for node_id in snapshot.nodes:
|
||||
for node in snapshot.nodes:
|
||||
with contextlib.suppress(ValueError):
|
||||
topology.add_node(node_id)
|
||||
topology.add_node(node)
|
||||
|
||||
for source in snapshot.connections:
|
||||
for sink in snapshot.connections[source]:
|
||||
for conn in snapshot.connections[source][sink]:
|
||||
topology.add_connection(source, sink, conn)
|
||||
for connection in snapshot.connections:
|
||||
topology.add_connection(connection)
|
||||
|
||||
return topology
|
||||
|
||||
def add_node(self, node_id: NodeId) -> None:
|
||||
if node_id in self._vertex_indices:
|
||||
def add_node(self, node: NodeInfo) -> None:
|
||||
if node.node_id in self._node_id_to_rx_id_map:
|
||||
return
|
||||
rx_id = self._graph.add_node(node_id)
|
||||
self._vertex_indices[node_id] = rx_id
|
||||
rx_id = self._graph.add_node(node)
|
||||
self._node_id_to_rx_id_map[node.node_id] = rx_id
|
||||
self._rx_id_to_node_id_map[rx_id] = node.node_id
|
||||
|
||||
def node_is_leaf(self, node_id: NodeId) -> bool:
|
||||
return (
|
||||
node_id in self._vertex_indices
|
||||
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
|
||||
node_id in self._node_id_to_rx_id_map
|
||||
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
|
||||
)
|
||||
|
||||
def neighbours(self, node_id: NodeId) -> list[NodeId]:
|
||||
return [
|
||||
self._graph[rx_id]
|
||||
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
|
||||
self._rx_id_to_node_id_map[rx_id]
|
||||
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
|
||||
]
|
||||
|
||||
def out_edges(
|
||||
self, node_id: NodeId
|
||||
) -> Iterable[tuple[NodeId, SocketConnection | RDMAConnection]]:
|
||||
if node_id not in self._vertex_indices:
|
||||
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
|
||||
if node_id not in self._node_id_to_rx_id_map:
|
||||
return []
|
||||
return (
|
||||
(self._graph[nid], conn)
|
||||
for _, nid, conn in self._graph.out_edges(self._vertex_indices[node_id])
|
||||
)
|
||||
return [
|
||||
(self._rx_id_to_node_id_map[nid], conn)
|
||||
for _, nid, conn in self._graph.out_edges(
|
||||
self._node_id_to_rx_id_map[node_id]
|
||||
)
|
||||
]
|
||||
|
||||
def contains_node(self, node_id: NodeId) -> bool:
|
||||
return node_id in self._vertex_indices
|
||||
return node_id in self._node_id_to_rx_id_map
|
||||
|
||||
def contains_connection(self, connection: Connection) -> bool:
|
||||
return connection in self._edge_id_to_rx_id_map
|
||||
|
||||
def add_connection(
|
||||
self,
|
||||
source: NodeId,
|
||||
sink: NodeId,
|
||||
connection: SocketConnection | RDMAConnection,
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
if connection in self.get_all_connections_between(source, sink):
|
||||
if connection.local_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(NodeInfo(node_id=connection.local_node_id))
|
||||
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
|
||||
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
|
||||
|
||||
if connection in self._edge_id_to_rx_id_map:
|
||||
return
|
||||
|
||||
if source not in self._vertex_indices:
|
||||
self.add_node(source)
|
||||
if sink not in self._vertex_indices:
|
||||
self.add_node(sink)
|
||||
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
|
||||
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
|
||||
|
||||
src_id = self._vertex_indices[source]
|
||||
sink_id = self._vertex_indices[sink]
|
||||
rx_id = self._graph.add_edge(src_id, sink_id, connection)
|
||||
self._edge_id_to_rx_id_map[connection] = rx_id
|
||||
|
||||
_ = self._graph.add_edge(src_id, sink_id, connection)
|
||||
def list_nodes(self) -> Iterable[NodeInfo]:
|
||||
return (self._graph[i] for i in self._graph.node_indices())
|
||||
|
||||
def get_all_connections_between(
|
||||
self, source: NodeId, sink: NodeId
|
||||
) -> Iterable[SocketConnection | RDMAConnection]:
|
||||
if source not in self._vertex_indices:
|
||||
return []
|
||||
if sink not in self._vertex_indices:
|
||||
return []
|
||||
def list_connections(self) -> Iterable[Connection]:
|
||||
return (connection for _, _, connection in self._graph.weighted_edge_list())
|
||||
|
||||
src_id = self._vertex_indices[source]
|
||||
sink_id = self._vertex_indices[sink]
|
||||
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
|
||||
try:
|
||||
return self._graph.get_all_edge_data(src_id, sink_id)
|
||||
except rx.NoEdgeBetweenNodes:
|
||||
return []
|
||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
||||
return self._graph.get_node_data(rx_idx).node_profile
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def list_nodes(self) -> Iterable[NodeId]:
|
||||
return self._graph.nodes()
|
||||
def update_node_profile(
|
||||
self, node_id: NodeId, node_profile: NodePerformanceProfile
|
||||
) -> None:
|
||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
||||
self._graph[rx_idx].node_profile = node_profile
|
||||
|
||||
def map_connections(
|
||||
self,
|
||||
) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
|
||||
base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
|
||||
for src_id, sink_id, connection in self._graph.weighted_edge_list():
|
||||
source = self._graph[src_id]
|
||||
sink = self._graph[sink_id]
|
||||
if source not in base:
|
||||
base[source] = {}
|
||||
if sink not in base[source]:
|
||||
base[source][sink] = []
|
||||
base[source][sink].append(connection)
|
||||
return base
|
||||
def update_connection_profile(self, connection: Connection) -> None:
|
||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
||||
self._graph.update_edge_by_index(rx_idx, connection)
|
||||
|
||||
def list_connections(
|
||||
self,
|
||||
) -> Iterable[tuple[NodeId, NodeId, SocketConnection | RDMAConnection]]:
|
||||
return (
|
||||
(
|
||||
self._graph[src_id],
|
||||
self._graph[sink_id],
|
||||
connection,
|
||||
)
|
||||
for src_id, sink_id, connection in self._graph.weighted_edge_list()
|
||||
)
|
||||
def get_connection_profile(
|
||||
self, connection: Connection
|
||||
) -> ConnectionProfile | None:
|
||||
try:
|
||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
||||
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def remove_node(self, node_id: NodeId) -> None:
|
||||
if node_id not in self._vertex_indices:
|
||||
if node_id not in self._node_id_to_rx_id_map:
|
||||
return
|
||||
|
||||
rx_idx = self._vertex_indices[node_id]
|
||||
for connection in self.list_connections():
|
||||
if (
|
||||
connection.local_node_id == node_id
|
||||
or connection.send_back_node_id == node_id
|
||||
):
|
||||
self.remove_connection(connection)
|
||||
|
||||
rx_idx = self._node_id_to_rx_id_map[node_id]
|
||||
self._graph.remove_node(rx_idx)
|
||||
|
||||
del self._vertex_indices[node_id]
|
||||
del self._node_id_to_rx_id_map[node_id]
|
||||
del self._rx_id_to_node_id_map[rx_idx]
|
||||
|
||||
def replace_all_out_tb_connections(
|
||||
self, source: NodeId, new_connections: Sequence[tuple[NodeId, RDMAConnection]]
|
||||
) -> None:
|
||||
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
|
||||
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
|
||||
self._graph.remove_edge_from_index(conn_idx)
|
||||
for sink, conn in new_connections:
|
||||
self.add_connection(source, sink, conn)
|
||||
|
||||
def remove_connection(
|
||||
self, source: NodeId, sink: NodeId, edge: SocketConnection | RDMAConnection
|
||||
) -> None:
|
||||
if source not in self._vertex_indices or sink not in self._vertex_indices:
|
||||
def remove_connection(self, connection: Connection) -> None:
|
||||
if connection not in self._edge_id_to_rx_id_map:
|
||||
return
|
||||
for conn_idx in self._graph.edge_indices_from_endpoints(
|
||||
self._vertex_indices[source], self._vertex_indices[sink]
|
||||
):
|
||||
if self._graph.get_edge_data_by_index(conn_idx) == edge:
|
||||
self._graph.remove_edge_from_index(conn_idx)
|
||||
rx_idx = self._edge_id_to_rx_id_map[connection]
|
||||
self._graph.remove_edge_from_index(rx_idx)
|
||||
del self._edge_id_to_rx_id_map[connection]
|
||||
|
||||
def get_cycles(self) -> list[list[NodeId]]:
|
||||
def get_cycles(self) -> list[list[NodeInfo]]:
|
||||
cycle_idxs = rx.simple_cycles(self._graph)
|
||||
cycles: list[list[NodeId]] = []
|
||||
cycles: list[list[NodeInfo]] = []
|
||||
for cycle_idx in cycle_idxs:
|
||||
cycle = [self._graph[idx] for idx in cycle_idx]
|
||||
cycles.append(cycle)
|
||||
|
||||
return cycles
|
||||
|
||||
def get_cycles_tb(self) -> list[list[NodeId]]:
|
||||
def get_cycles_tb(self) -> list[list[NodeInfo]]:
|
||||
tb_edges = [
|
||||
(u, v, conn)
|
||||
for u, v, conn in self._graph.weighted_edge_list()
|
||||
if conn.is_thunderbolt()
|
||||
]
|
||||
|
||||
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
|
||||
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
|
||||
tb_graph.add_nodes_from(self._graph.nodes())
|
||||
|
||||
for u, v, conn in tb_edges:
|
||||
if isinstance(conn, SocketConnection):
|
||||
tb_graph.add_edge(u, v, conn)
|
||||
tb_graph.add_edge(u, v, conn)
|
||||
|
||||
cycle_idxs = rx.simple_cycles(tb_graph)
|
||||
cycles: list[list[NodeId]] = []
|
||||
cycles: list[list[NodeInfo]] = []
|
||||
for cycle_idx in cycle_idxs:
|
||||
cycle = [tb_graph[idx] for idx in cycle_idx]
|
||||
cycles.append(cycle)
|
||||
|
||||
return cycles
|
||||
|
||||
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
|
||||
rx_idxs = [self._vertex_indices[idx] for idx in node_ids]
|
||||
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
|
||||
node_idxs = [node.node_id for node in nodes]
|
||||
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
|
||||
topology = Topology()
|
||||
for rx_idx in rx_idxs:
|
||||
topology.add_node(self._graph[rx_idx])
|
||||
for source, sink, connection in self.list_connections():
|
||||
if source in node_ids and sink in node_ids:
|
||||
topology.add_connection(source, sink, connection)
|
||||
for connection in self.list_connections():
|
||||
if (
|
||||
connection.local_node_id in node_idxs
|
||||
and connection.send_back_node_id in node_idxs
|
||||
):
|
||||
topology.add_connection(connection)
|
||||
return topology
|
||||
|
||||
def is_thunderbolt_cycle(self, cycle: list[NodeId]) -> bool:
|
||||
node_idxs = [node for node in cycle]
|
||||
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
|
||||
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
|
||||
node_idxs = [node.node_id for node in cycle]
|
||||
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
|
||||
for rid in rx_idxs:
|
||||
for neighbor_rid in self._graph.neighbors(rid):
|
||||
if neighbor_rid not in rx_idxs:
|
||||
|
||||
@@ -5,7 +5,8 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
@@ -51,6 +52,10 @@ class ChatCompletionMessage(BaseModel):
|
||||
function_call: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class BenchChatCompletionMessage(ChatCompletionMessage):
|
||||
pass
|
||||
|
||||
|
||||
class TopLogprobItem(BaseModel):
|
||||
token: str
|
||||
logprob: float
|
||||
@@ -113,6 +118,18 @@ class ChatCompletionResponse(BaseModel):
|
||||
service_tier: str | None = None
|
||||
|
||||
|
||||
class GenerationStats(BaseModel):
|
||||
prompt_tps: float
|
||||
generation_tps: float
|
||||
prompt_tokens: int
|
||||
generation_tokens: int
|
||||
peak_memory_usage: Memory
|
||||
|
||||
|
||||
class BenchChatCompletionResponse(ChatCompletionResponse):
|
||||
generation_stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class ChatCompletionTaskParams(BaseModel):
|
||||
model: str
|
||||
frequency_penalty: float | None = None
|
||||
@@ -135,6 +152,10 @@ class ChatCompletionTaskParams(BaseModel):
|
||||
user: str | None = None
|
||||
|
||||
|
||||
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
|
||||
pass
|
||||
|
||||
|
||||
class PlaceInstanceParams(BaseModel):
|
||||
model_id: str
|
||||
sharding: Sharding = Sharding.Pipeline
|
||||
@@ -174,6 +195,7 @@ class DeleteInstanceTaskParams(BaseModel):
|
||||
class CreateInstanceResponse(BaseModel):
|
||||
message: str
|
||||
command_id: CommandId
|
||||
model_meta: ModelMetadata
|
||||
|
||||
|
||||
class DeleteInstanceResponse(BaseModel):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
from exo.shared.types.api import GenerationStats
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
@@ -20,6 +21,7 @@ class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
|
||||
@@ -2,14 +2,14 @@ from datetime import datetime
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.topology import SocketConnection
|
||||
from exo.shared.topology import Connection, NodePerformanceProfile
|
||||
from exo.shared.types.chunks import GenerationChunk
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
@@ -76,15 +76,25 @@ class RunnerDeleted(BaseEvent):
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
# TODO
|
||||
class NodeCreated(BaseEvent):
|
||||
node_id: NodeId
|
||||
|
||||
|
||||
class NodeTimedOut(BaseEvent):
|
||||
node_id: NodeId
|
||||
|
||||
|
||||
# TODO: bikeshed this naem
|
||||
class NodeGatheredInfo(BaseEvent):
|
||||
class NodePerformanceMeasured(BaseEvent):
|
||||
node_id: NodeId
|
||||
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
|
||||
info: GatheredInfo # NB: this model is UNTAGGED!!! be warned for ser/de errors.
|
||||
node_profile: NodePerformanceProfile
|
||||
|
||||
|
||||
class NodeMemoryMeasured(BaseEvent):
|
||||
node_id: NodeId
|
||||
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
|
||||
memory: MemoryPerformanceProfile
|
||||
|
||||
|
||||
class NodeDownloadProgress(BaseEvent):
|
||||
@@ -97,15 +107,11 @@ class ChunkGenerated(BaseEvent):
|
||||
|
||||
|
||||
class TopologyEdgeCreated(BaseEvent):
|
||||
source: NodeId
|
||||
sink: NodeId
|
||||
edge: SocketConnection
|
||||
edge: Connection
|
||||
|
||||
|
||||
class TopologyEdgeDeleted(BaseEvent):
|
||||
source: NodeId
|
||||
sink: NodeId
|
||||
edge: SocketConnection
|
||||
edge: Connection
|
||||
|
||||
|
||||
Event = (
|
||||
@@ -119,8 +125,10 @@ Event = (
|
||||
| InstanceDeleted
|
||||
| RunnerStatusUpdated
|
||||
| RunnerDeleted
|
||||
| NodeCreated
|
||||
| NodeTimedOut
|
||||
| NodeGatheredInfo
|
||||
| NodePerformanceMeasured
|
||||
| NodeMemoryMeasured
|
||||
| NodeDownloadProgress
|
||||
| ChunkGenerated
|
||||
| TopologyEdgeCreated
|
||||
|
||||
@@ -14,3 +14,5 @@ class ModelMetadata(CamelCaseModel):
|
||||
pretty_name: str
|
||||
storage_size: Memory
|
||||
n_layers: PositiveInt
|
||||
hidden_size: PositiveInt
|
||||
supports_tensor: bool
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import re
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
|
||||
from pydantic import BaseModel, computed_field, field_validator
|
||||
|
||||
|
||||
class Multiaddr(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
address: str
|
||||
|
||||
PATTERNS: ClassVar[list[str]] = [
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Self
|
||||
|
||||
import psutil
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.thunderbolt import TBIdentifier
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
class MemoryUsage(CamelCaseModel):
|
||||
class MemoryPerformanceProfile(CamelCaseModel):
|
||||
ram_total: Memory
|
||||
ram_available: Memory
|
||||
swap_total: Memory
|
||||
@@ -46,6 +44,7 @@ class SystemPerformanceProfile(CamelCaseModel):
|
||||
sys_power: float = 0.0
|
||||
pcpu_usage: float = 0.0
|
||||
ecpu_usage: float = 0.0
|
||||
ane_power: float = 0.0
|
||||
|
||||
|
||||
class NetworkInterfaceInfo(CamelCaseModel):
|
||||
@@ -54,16 +53,15 @@ class NetworkInterfaceInfo(CamelCaseModel):
|
||||
|
||||
|
||||
class NodePerformanceProfile(CamelCaseModel):
|
||||
model_id: str = "Unknown"
|
||||
chip_id: str = "Unknown"
|
||||
friendly_name: str = "Unknown"
|
||||
memory: MemoryUsage = MemoryUsage.from_bytes(
|
||||
ram_total=0, ram_available=0, swap_total=0, swap_available=0
|
||||
)
|
||||
network_interfaces: Sequence[NetworkInterfaceInfo] = []
|
||||
tb_interfaces: Sequence[TBIdentifier] = []
|
||||
system: SystemPerformanceProfile = SystemPerformanceProfile()
|
||||
model_id: str
|
||||
chip_id: str
|
||||
friendly_name: str
|
||||
memory: MemoryPerformanceProfile
|
||||
network_interfaces: list[NetworkInterfaceInfo] = []
|
||||
system: SystemPerformanceProfile
|
||||
|
||||
|
||||
class ConnectionProfile(CamelCaseModel):
|
||||
pass
|
||||
throughput: float
|
||||
latency: float
|
||||
jitter: float
|
||||
|
||||
@@ -40,6 +40,10 @@ class LoadModel(BaseTask): # emitted by Worker
|
||||
pass
|
||||
|
||||
|
||||
class ConnectToGroup(BaseTask): # emitted by Worker
|
||||
pass
|
||||
|
||||
|
||||
class StartWarmup(BaseTask): # emitted by Worker
|
||||
pass
|
||||
|
||||
@@ -57,5 +61,11 @@ class Shutdown(BaseTask): # emitted by Worker
|
||||
|
||||
|
||||
Task = (
|
||||
CreateRunner | DownloadModel | LoadModel | StartWarmup | ChatCompletion | Shutdown
|
||||
CreateRunner
|
||||
| DownloadModel
|
||||
| ConnectToGroup
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| ChatCompletion
|
||||
| Shutdown
|
||||
)
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
import anyio
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
class TBConnection(CamelCaseModel):
|
||||
source_uuid: str
|
||||
sink_uuid: str
|
||||
|
||||
|
||||
class TBIdentifier(CamelCaseModel):
|
||||
rdma_interface: str
|
||||
domain_uuid: str
|
||||
|
||||
|
||||
# Intentionally minimal, only collecting data we care about - there's a lot more
|
||||
|
||||
|
||||
class TBReceptacleTag(BaseModel, extra="ignore"):
|
||||
receptacle_id_key: str
|
||||
|
||||
|
||||
class TBConnectivityItem(BaseModel, extra="ignore"):
|
||||
domain_uuid_key: str | None
|
||||
|
||||
|
||||
class TBConnectivityData(BaseModel, extra="ignore"):
|
||||
domain_uuid_key: str | None
|
||||
device_name_key: str
|
||||
items: list[TBConnectivityItem] | None = Field(None, alias="_items")
|
||||
receptacle_1_tag: TBReceptacleTag
|
||||
|
||||
def ident(self, ifaces: dict[str, str]) -> TBIdentifier | None:
|
||||
if self.domain_uuid_key is None:
|
||||
return
|
||||
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
|
||||
iface = f"rdma_{ifaces[tag]}"
|
||||
return TBIdentifier(rdma_interface=iface, domain_uuid=self.domain_uuid_key)
|
||||
|
||||
def conn(self) -> TBConnection | None:
|
||||
if self.domain_uuid_key is None or self.items is None:
|
||||
return
|
||||
|
||||
sink_key = next(
|
||||
item.domain_uuid_key
|
||||
for item in self.items
|
||||
if item.domain_uuid_key is not None
|
||||
)
|
||||
return TBConnection(source_uuid=self.domain_uuid_key, sink_uuid=sink_key)
|
||||
|
||||
|
||||
class TBConnectivity(BaseModel):
|
||||
SPThunderboltDataType: list[TBConnectivityData]
|
||||
|
||||
@classmethod
|
||||
async def gather(cls) -> list[TBConnectivityData] | None:
|
||||
proc = await anyio.run_process(
|
||||
["system_profiler", "SPThunderboltDataType", "-json"], check=False
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
return None
|
||||
# Saving you from PascalCase while avoiding too much pydantic
|
||||
return TBConnectivity.model_validate_json(proc.stdout).SPThunderboltDataType
|
||||
@@ -1,32 +1,37 @@
|
||||
from enum import Enum
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.utils.pydantic_ext import FrozenModel
|
||||
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
class RDMAConnection(FrozenModel):
|
||||
source_rdma_iface: str
|
||||
sink_rdma_iface: str
|
||||
class NodeInfo(CamelCaseModel):
|
||||
node_id: NodeId
|
||||
node_profile: NodePerformanceProfile | None = None
|
||||
|
||||
|
||||
class Connection(CamelCaseModel):
|
||||
local_node_id: NodeId
|
||||
send_back_node_id: NodeId
|
||||
send_back_multiaddr: Multiaddr
|
||||
connection_profile: ConnectionProfile | None = None
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
self.local_node_id,
|
||||
self.send_back_node_id,
|
||||
self.send_back_multiaddr.address,
|
||||
)
|
||||
)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Connection):
|
||||
raise ValueError("Cannot compare Connection with non-Connection")
|
||||
return (
|
||||
self.local_node_id == other.local_node_id
|
||||
and self.send_back_node_id == other.send_back_node_id
|
||||
and self.send_back_multiaddr == other.send_back_multiaddr
|
||||
)
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
logger.warning("duh")
|
||||
return True
|
||||
|
||||
|
||||
# TODO
|
||||
class LinkType(str, Enum):
|
||||
Thunderbolt = "Thunderbolt"
|
||||
Ethernet = "Ethernet"
|
||||
WiFi = "WiFi"
|
||||
|
||||
|
||||
class SocketConnection(FrozenModel):
|
||||
sink_multiaddr: Multiaddr
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.sink_multiaddr.ip_address)
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
|
||||
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
|
||||
|
||||
@@ -28,7 +28,7 @@ class DownloadPending(BaseDownloadProgress):
|
||||
|
||||
|
||||
class DownloadCompleted(BaseDownloadProgress):
|
||||
pass
|
||||
total_bytes: Memory
|
||||
|
||||
|
||||
class DownloadFailed(BaseDownloadProgress):
|
||||
|
||||
@@ -25,11 +25,12 @@ class BaseInstance(TaggedModel):
|
||||
|
||||
|
||||
class MlxRingInstance(BaseInstance):
|
||||
hosts: list[Host]
|
||||
hosts_by_node: dict[NodeId, list[Host]]
|
||||
ephemeral_port: int
|
||||
|
||||
|
||||
class MlxJacclInstance(BaseInstance):
|
||||
jaccl_devices: list[list[str | None]]
|
||||
ibv_devices: list[list[str | None]]
|
||||
jaccl_coordinators: dict[NodeId, str]
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user