Compare commits

...

23 Commits

Author SHA1 Message Date
rltakashige
077b1bc732 exo-bench (Benchmark model pp & tg speed) (#1099)
## Motivation

This PR implements benchmarking in the style of llama-bench. The main
difficulty here is the fact that exo is not a library - it exposes an
endpoint. This means that benchmarking numbers will be inaccurate if the
API is measured.

The solution assumes nodes are set up with uv run exo (or via the app),
and then hits the new endpoint /bench/chat/completions to retrieve
generation statistics directly from mlx_lm.
<!-- Why is this change needed? What problem does it solve? -->

This will allow us to release benchmarks for models and perform
regression tests.

TODO: Performance benchmarking.
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->
- Adds /bench/chat/completions endpoint
- Adds BenchChatCompletion/Response
- Adds a logits processor to prevent response from ending early
- Adds a "Prompt Sizer" which downloads the tokenizer and dynamically
adjusts the prompt of "a" to fit the desired prompt size.
- Reduce prefill step size to 2048 for now (in future, dynamically
adjust this value)

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->
Benchmarked Llama, Qwen, DeepSeek and Kimi models. Will require several
fixes to run consistently on all configurations (to be done in the
future).
Manually tested the normal API to verify chat requests complete as
expected.

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
Not really possible. Type checker passes.
2026-01-06 17:39:09 +00:00
Alex Cheema
4963c33162 Fix Discord link in README.md. Fixes #1096 (#1097)
## Motivation

Discord link expired.

## Changes

Replace discord invite link with permanent link.

## Why It Works

It's permanent now.

## Test Plan

Clicked the link. It works.
2026-01-06 14:05:09 +00:00
madanlalit
4f6fcd9e93 feat(macos-app): add custom namespace UI for cluster isolation
Add Advanced Options section with custom namespace field that allows
users to override EXO_LIBP2P_NAMESPACE environment variable. This
enables splitting machines that can see each other into separate
clusters.

- Added customNamespace property with UserDefaults persistence
- Added Advanced Options collapsible section with text field
- Added Save & Restart button that auto-restarts exo process
- Namespace replaces buildTag when custom value is set
- Falls back to buildTag (version) when namespace is empty
2026-01-05 15:25:00 +01:00
Evan Quiney
839b67f318 [feat] Add an option to disable the worker (#1091)
## Motivation

Workerless machines can be used for networking without running any gpu
jobs - add a cli flag that adds this basic functionality.

## Changes

Adds the --no-worker cli flag

## Test Plan

### Manual Testing

Exo starts as expected

### Automated Testing

None
2026-01-05 12:05:03 +00:00
Drifter4242
47b8e0ce12 feat: remember last launch settings (model, sharding, instance type) (#1028)
## Motivation

Saves the last launch settings, so that the next time you run exo it
will default to the same launch settings.
This is just a small quality of life improvement.

## Changes

When you launch it saves the settings to the web browser local storage.
When it fills out the model list, it reads the settings and sets the
default.

I reviewed, tested and edited the code, but some of the code was written
by Claude Opus. I hope that's ok.

## Why It Works

See above

## Test Plan

### Manual Testing

I have two Macbook Studio M3 Ultras, each with 512Gb ram, connected with
Thunderbolt 5. I ran Kimi K2 Thinking with MLX Ring and Tensor Split.
I ran exo multiple times to confirm that the default works.

### Automated Testing

No changes to automated testing.
2026-01-05 11:27:14 +00:00
Evan Quiney
17f9b583a4 Task Deduplication (#1062) 2026-01-03 20:01:49 +00:00
RickyChen / 陳昭儒
844bcc7ce6 fix: prevent form submission during IME composition (#1069)
## Problem
When typing in Chinese (or other IME-based languages like
Japanese/Korean), pressing Enter to select a character from the IME
candidate list would incorrectly submit the message instead of
confirming the character selection.

## Solution
Added IME composition state detection in the `handleKeydown` function in
`ChatForm.svelte`:
- Check `event.isComposing` to detect active IME composition
- Fallback to `event.keyCode === 229` for broader browser compatibility
- Return early when IME is active, allowing normal character selection

## Changes
- Modified `dashboard/src/lib/components/ChatForm.svelte` 
- Added IME composition check before Enter key handling

Co-authored-by: Ricky Chen <rickychen@Rickys-MacBook-Pro.local>
2025-12-31 17:11:04 +00:00
Evan Quiney
c1be5184b2 Fix tests broken by 283c (#1063)
Some tests were broken by #1058 and #1046 - this fixes them.
2025-12-31 01:53:55 +00:00
Alex Cheema
1ec550dff1 Emit download progress on start, and change downloads to be keyed by model_id (#1044)
## Motivation

We added a download page to the dashboard which shows the currently
download status of each model on each node. Users have reported this to
be extremely useful.

However, we don't currently fetch the download progress on start, so it
doesn't show any model's download status.

## Changes

Fetch and emit model download status on start of worker, and
periodically every 5 mins.
Also to support this, I changed download_status to be keyed by model_id
instead of shard, since we want download_status of each model, not each
shard.

## Why It Works

The dashboard already implements the correct functionality, we just
weren't populating the download status in the state. Now it gets
populated and shows correctly.

## Test Plan

### Manual Testing
On a cluster of 2 x 512GB M3 Ultra Mac Studio, I launched an instance
onto one node that hadn't been downloaded. I checked the download page
and it showed the in progress download. I downloaded it to completion,
restarted exo on both nodes, and then opened the download page and it
showed the model as 100% downloaded and other models as 0% that hadn't
been downloaded.

---------

Co-authored-by: Evan <evanev7@gmail.com>
2025-12-31 01:18:10 +00:00
Alex Cheema
283c0e39e4 Placement filters for tensor parallel supports_tensor, tensor dimension and pipeline parallel deepseek v3.1 (#1058)
## Motivation

Certain placements are not valid. Added filters to exclude these placements. There were invalid placement previews being shown in the dashboard which would then fail when the user actually tries to launch an instance with that placement.


## Changes

Three filters added:

1. Certain models do not support tensor parallel at all. Checks `supports_tensor` on the model_meta.
2. For models that do support tensor parallelism, certain tensor parallel sizes are not valid. This check is actually not correct right now but it works fine for now. The actual correct check is more involved.
3. For unknown reasons, deepseek v3.1 (8-bit) does not work with tensor parallelism.

## Why It Works

`place_instance` now raises an `Exception` for invalid placements.

## Test Plan

### Manual Testing
Since `/instance/previews` enumerates all possible placements and runs `place_instance`, I checked the dashboard to see if invalid placements are still shown.
2025-12-31 00:33:40 +00:00
Alex Cheema
35be4c55c3 prioritise mlx jaccl coordinator ip (en0 -> en1 -> non-TB5 -> other) 2025-12-31 00:10:19 +00:00
Alex Cheema
31d4cd8409 set KV_CACHE_BITS to None to disable quantized kv cache 2025-12-31 00:03:30 +00:00
Alex Cheema
8a6da58404 remove mx.set_cache_limit 2025-12-30 23:58:15 +00:00
Alex Cheema
16e2bfd3b3 log EXO_LIBP2P_NAMESPACE on start 2025-12-30 04:08:47 +00:00
Alex Cheema
ade3ee7ec5 fix warmup order. should be rank!=0 then rank=0 2025-12-30 03:29:34 +00:00
Evan Quiney
fea42473dd Place local node at the top of the dashboard. (#1033)
@samiamjidkhan and @AlexCheema's work moving the topology to place the
local node at the top of the topology in the app dashboard.
2025-12-28 21:12:47 +00:00
Alex Cheema
ca7adcc2a8 Update README.md with instructions to enable RDMA. (#1031)
## Motivation

We didn't have instructions for enabling RDMA on macOS.

## Changes

I added instructions for enabling RDMA on macOS.

## Why It Works

Tried it on my M4 Max MacBook Pro and works.

## Test Plan

### Manual Testing
Tried it on my M4 Max MacBook Pro and works.

### Automated Testing
In the future, we could automate this from fresh macOS builds using KVM
over IP. See #1030
2025-12-28 20:56:26 +00:00
Evan Quiney
9d9e24f969 some dashboard updates (#1017)
Mostly @samiamjidkhan and @AlexCheema's work in progress.

---------

Co-authored-by: Sami Khan <smsak99@gmail.com>
Co-authored-by: Alex Cheema
2025-12-28 20:50:23 +00:00
Jake Hillion
b5d424b658 placement: generate per-node host lists for MLX ring backend
Pipeline + MLX Ring worked with 2 nodes but failed to initialize with
3 or more nodes. The MLX ring backend requires each node to know its
specific left and right neighbors in the ring, but the previous
implementation provided a single flat host list shared by all nodes.

With 2 nodes, a flat list [host0, host1] accidentally worked because
each node could find its only neighbor. With 3+ nodes, each node needs
a customized view:
- Rank 0: [self, right_neighbor, placeholder]
- Rank 1: [left_neighbor, self, right_neighbor]
- Rank 2: [placeholder, left_neighbor, self]

Changed MlxRingInstance from `hosts: list[Host]` to
`hosts_by_node: dict[NodeId, list[Host]]` with `ephemeral_port: int`.

Added `get_mlx_ring_hosts_by_node()` which generates per-node host
lists where:
- Self position uses 0.0.0.0 for local binding
- Left/right neighbors use actual connection IPs
- Non-neighbors use 198.51.100.1 (RFC 5737 TEST-NET-2 placeholder)

Also added IP prioritization (en0 > en1 > non-Thunderbolt > any) to
prefer stable network interfaces.

Fixed topology discovery recording loopback addresses (127.0.0.1) as
valid connections to remote nodes. The reachability check now verifies
node identity via HTTP GET /node_id rather than just checking if the
port is open.

Test plan:

- Built a DMG [0]
- Installed on all Macs and started cluster.
- Requested a 3 node Pipeline + MLX Ring Llama 3.3 70B (FP16).
- It started and I was able to send a few chat messages.

Eventually my instance seemed to get into a broken state and chat
stopped working, but this commit is a clear step forward.

[0] https://github.com/exo-explore/exo/actions/runs/20473983471/job/58834969418
2025-12-28 20:38:20 +00:00
Drifter4242
b465134012 Fix Kimi K2 Thinking download by adding tiktoken.model to download patterns (#1024)
Kimi-K2 Thinking uses tiktoken.model for its tokenizer, which wasn't
being downloaded. This adds it to the default_patterns alongside
tokenizer.model.
I'm a bit confused why this isn't a problem for other people - I know
that others have used Kimi K2 (I wonder if they manually fixed the
download).

## Motivation

I downloaded Kimi K2 Thinking and it didn't work because it didn't
download tiktoken.model file.

## Changes

Added tiktoken.model to the default patterns.

## Why It Works

Now downloads the file.

## Test Plan

### Manual Testing

I have two Macbook Studio M3 Ultras, each with 512Gb ram, connected with
Thunderbolt 5. I ran Kimi K2 Thinking with MLX Ring and Tensor Split. It
ran successfully.

### Automated Testing
No automated test changes. I don't think they are needed.
2025-12-28 19:30:31 +00:00
Matiwos Kebede
eabdcab978 Fix linux docs (#1022)
This PR updates the "Run from Source (Mac & Linux)" section in README.md
to clarify Linux instructions.

Changes include:
- Split the section into macOS and Linux subsections.
- Added native Linux package manager commands (apt, dnf, pacman) for
dependencies: uv, node, npm.
- Clarified that macmon is macOS-only.
- Noted that Homebrew on Linux is optional, with native package managers
preferred.

These changes improve clarity for Linux users and fix confusion from the
previous macOS-centric instructions.
2025-12-27 19:56:44 +00:00
Evan Quiney
8e9332d6a7 Separate out the Runner's behaviour into a "connect" phase and a "load" phase (#1006)
## Motivation

We should ensure all runners are connected before loading the model -
this gives us finer grained control in the future for the workers
planning mechanism over the runners state.

## Changes

- Introduced task ConnectToGroup, preceeding LoadModel
- Introduced runner statuses Idle, Connecting, Connected
- Separated out initialize_mlx from shard_and_load
- Single instances never go through the connecting phase

## Test Plan

# Automated Testing
Added a test for checking event ordering in a standard workflow.

# Manual testing
Tested Llama 3.2 1b and Kimi K2 Thinking loads and shuts down repeatedly
on multiple configurations.
Not exhaustive, however.

---------

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2025-12-27 16:28:42 +00:00
Heath Dutton🕴️
4b65d5f896 Fix race condition in mlx_distributed_init with concurrent instances (#1012)
## Motivation

Fixes #1005

When multiple instances initialize concurrently with the same rank, they
overwrite each other's coordination files (hosts_{rank}.json), causing
"[jaccl] Malformed device file" errors and initialization failures.

## Changes

- Changed coordination filename from `./hosts_{rank}.json` to
`./hosts_{instance_id}_{rank}.json` to make it unique per instance
- Added cleanup in a finally block to remove coordination files after
initialization completes
- Applied fix to both MlxRingInstance and MlxJacclInstance cases

## Why It Works

Each instance now gets a unique coordination file based on its
instance_id, preventing concurrent instances from overwriting each
other's files. The cleanup logic ensures files are removed after use,
preventing accumulation and handling both success and failure cases.

## Test Plan

### Manual Testing
Code review and logic verification. The fix prevents the race condition
by ensuring filename uniqueness per instance.

### Automated Testing
No new tests added. Existing tests continue to pass.

---------

Co-authored-by: Ryuichi Leo Takashige <rl.takashige@gmail.com>
2025-12-27 16:13:26 +00:00
70 changed files with 3665 additions and 5039 deletions

View File

@@ -1,159 +0,0 @@
# EXO Benchmark Dashboard
A fully self-contained, browser-based dashboard for tracking EXO benchmark performance over time.
## Features
- 📊 **Success Rate Tracking**: Monitor cluster reliability across commits
-**Response Time Analysis**: Track average request completion times
- 🎯 **Throughput Metrics**: Tokens per second visualization
- 📈 **Request Distribution**: Success/failure breakdown over time
- 🔄 **Auto-Refresh**: Updates every 60 seconds
- 📺 **TV-Ready**: Large, clear visualizations perfect for display
- 🔐 **Secure**: Credentials stored in browser localStorage only
- 🌐 **No Backend**: Directly accesses S3 from the browser
## Quick Start
### Option 1: Direct File Access (Simplest)
Just open the HTML file directly in your browser:
```bash
open .github/benchmark-dashboard/index.html
```
Then click "Configure AWS Credentials" and enter your keys.
### Option 2: URL Parameters (For Quick Setup)
```bash
# Serve with credentials in URL (they'll be moved to localStorage)
open ".github/benchmark-dashboard/index.html?accessKey=YOUR_KEY&secretKey=YOUR_SECRET&region=us-east-1"
```
The credentials will be saved to localStorage and removed from the URL immediately.
### Option 3: Simple HTTP Server
```bash
# From repo root
python3 -m http.server 8080
# Then open: http://localhost:8080/.github/benchmark-dashboard/
```
## AWS Credentials
The dashboard needs read-only access to the `exo-benchmark-results` S3 bucket.
### Required IAM Permissions
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:ListBucket"
],
"Resource": [
"arn:aws:s3:::exo-benchmark-results",
"arn:aws:s3:::exo-benchmark-results/*"
]
}
]
}
```
### Security Notes
- ✅ Credentials stored in browser `localStorage` only
- ✅ Never sent to any server (except AWS)
- ✅ All S3 access happens client-side
- ✅ Use read-only IAM credentials
- ⚠️ Don't commit credentials to git
- ⚠️ Use a dedicated read-only IAM user
## TV/Kiosk Mode
For permanent display on a TV:
### macOS
```bash
open -a "Google Chrome" --args --kiosk ".github/benchmark-dashboard/index.html"
```
### Linux
```bash
chromium-browser --kiosk --app="file://$(pwd)/.github/benchmark-dashboard/index.html"
```
### Auto-start on Boot
Create a simple startup script:
```bash
#!/bin/bash
# /usr/local/bin/start-benchmark-dashboard.sh
cd /path/to/exo
python3 -m http.server 8080 &
sleep 2
chromium-browser --kiosk http://localhost:8080/.github/benchmark-dashboard/
```
## Data Displayed
### Summary Cards
- **Latest Success Rate**: Most recent benchmark success percentage with trend
- **Avg Response Time**: Latest average response time in ms with trend
- **Total Benchmarks**: Count of all benchmarks run
- **Active Configurations**: Number of unique benchmark configs
### Charts
1. **Success Rate Over Time**: Line chart showing reliability trends
2. **Average Response Time**: Performance over time (lower is better)
3. **Throughput**: Tokens/second metric (higher is better)
4. **Request Distribution**: Stacked bar chart of successes/failures
## How It Works
1. **Loads AWS SDK**: Uses AWS SDK for JavaScript (browser version)
2. **Lists S3 Objects**: Fetches all files from `s3://exo-benchmark-results/bench/`
3. **Downloads Results**: Fetches each JSON result file
4. **Parses & Visualizes**: Uses Chart.js to create interactive charts
5. **Auto-Refreshes**: Polls S3 every 60 seconds for new results
## Customization
To modify the dashboard:
1. Edit `index.html`
2. Adjust `REFRESH_INTERVAL` for different polling frequency
3. Modify chart colors/styles in the Chart.js configuration
4. Add new metrics by extending the results parsing
## Troubleshooting
**"AWS credentials not configured"**
- Click "Configure AWS Credentials" and enter your keys
**"Error loading benchmark data"**
- Check AWS credentials are correct
- Verify S3 bucket name is `exo-benchmark-results`
- Ensure IAM user has read permissions
- Check browser console for detailed errors
**"No benchmark results found"**
- Wait for benchmark workflows to run
- Verify results are being uploaded to S3
- Check S3 bucket has files in `bench/` prefix
**Charts not updating**
- Check browser console for errors
- Verify network connectivity to S3
- Try refreshing the page manually

View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,186 +0,0 @@
# EXO Benchmark Configurations
This directory contains configuration files for the EXO staged benchmark system.
## Overview
The staged benchmark system allows you to run complex, multi-stage load tests against EXO clusters. Each stage can have different characteristics:
- **Prompt Length**: Number of tokens in the input prompt
- **Generation Length**: Maximum tokens to generate in the response
- **Time Between Requests**: Delay (in seconds) between firing consecutive requests
- **Iterations**: Number of requests to send in this stage
Requests are **fire-and-forget** - they don't wait for the previous request to complete. This allows you to test overlapping request handling and measure success rates under load.
## Configuration Files
### `bench_simple.yaml`
A minimal configuration that replicates the behavior of the original `bench.py` script:
- Single stage with 1 iteration
- Short prompt (~20 tokens)
- Generates up to 100 tokens
This is useful for quick smoke tests.
### `bench_config.yaml`
A comprehensive multi-stage benchmark with:
1. **Warmup** (10 requests): Light load with short prompts
2. **Medium Load** (20 requests): Moderate load with medium prompts
3. **Stress Test** (30 requests): Heavy overlapping requests with long prompts
4. **Cooldown** (5 requests): Light load to wind down
This tests the cluster's behavior under varying load patterns.
## Configuration Schema
```yaml
# Hardware configuration - maps runner labels to instance counts
hardware_plan:
M3ULTRA_GPU80_512GB: 4
# Environment variables to set on each node (optional)
environment:
OVERRIDE_MEMORY_MB: 512
# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600
# Model instances to run concurrently
model_ids:
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
# Benchmark stages
stages:
- name: "stage_name" # Human-readable name for this stage
prompt_length: 100 # Target prompt length in tokens
generation_length: 200 # Max tokens to generate
time_between_requests: 2.0 # Seconds between firing requests
iterations: 10 # Number of requests in this stage
```
## Running Benchmarks
### Via GitHub Actions
**Automatic (every commit):**
- The **`bench`** workflow runs automatically on every push
- Uses `bench_simple.yaml` as the default configuration
- All settings (hardware plan, timeout, environment variables, models, stages) are defined in the config file
**Manual (on-demand):**
1. Go to **Actions****bench** workflow
2. Click **Run workflow**
3. Configure:
- **Config File**: Path to your YAML config (default: `.github/configs/bench_simple.yaml`)
- `.github/configs/bench_simple.yaml` for quick tests
- `.github/configs/bench_config.yaml` for complex multi-stage tests
All other settings (hardware plan, timeout, environment variables, models, stages) are read from the specified config file.
### Via Command Line
```bash
# Start EXO on localhost:8000
uv run exo --api-port 8000
# Run simple benchmark (1 stage, 1 iteration)
python3 .github/scripts/bench.py \
--api-port 8000 \
--config .github/configs/bench_simple.yaml \
--expected-nodes 1 \
--is-primary true \
--timeout-seconds 600
# Run complex staged benchmark (4 stages, multiple iterations)
python3 .github/scripts/bench.py \
--api-port 8000 \
--config .github/configs/bench_config.yaml \
--expected-nodes 1 \
--is-primary true \
--timeout-seconds 600
```
## Output Metrics
For each stage, the benchmark reports:
- **Total Requests**: Number of requests fired
- **Successful Requests**: Requests that completed successfully
- **Failed Requests**: Requests that encountered errors
- **Success Rate**: Percentage of successful requests
- **Total Tokens**: Sum of all tokens generated across successful requests
- **Avg Tokens/Request**: Average tokens per successful request
- **Avg Time/Request**: Average completion time per successful request
A JSON summary is also printed for easy parsing and storage.
## Creating Custom Benchmarks
To create a custom benchmark:
1. Copy an existing config file (e.g., `bench_config.yaml`)
2. Modify the stages to match your test scenario
3. Save it in this directory with a descriptive name
4. Run it using the workflow or command line
### Example: Sustained Load Test
```yaml
hardware_plan:
M3ULTRA_GPU80_512GB: 2
environment:
OVERRIDE_MEMORY_MB: 1024
timeout_seconds: 600
model_ids:
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
stages:
- name: "sustained_load"
prompt_length: 200
generation_length: 150
time_between_requests: 0.5 # Very fast - 2 requests/second
iterations: 100 # Run for ~50 seconds
```
### Example: Varying Prompt Sizes
```yaml
hardware_plan:
M4PRO_GPU16_24GB: 3
timeout_seconds: 900
model_ids:
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
stages:
- name: "tiny_prompts"
prompt_length: 10
generation_length: 100
time_between_requests: 1.0
iterations: 10
- name: "medium_prompts"
prompt_length: 200
generation_length: 100
time_between_requests: 1.0
iterations: 10
- name: "large_prompts"
prompt_length: 1000
generation_length: 100
time_between_requests: 1.0
iterations: 10
```
## Tips
- **Overlapping Requests**: Set `time_between_requests` < expected completion time to test concurrent request handling
- **Sequential Requests**: Set `time_between_requests` > expected completion time to ensure requests don't overlap
- **Realistic Load**: Model real usage patterns by varying prompt/generation lengths across stages
- **Success Rate**: A 100% success rate indicates the cluster handled the load well; lower rates suggest capacity limits

View File

@@ -1,49 +0,0 @@
# EXO Staged Benchmark Configuration
# This configuration defines a multi-stage load test for EXO clusters
# Hardware configuration - maps runner labels to instance counts
hardware_plan:
M3ULTRA_GPU80_512GB: 4
# Environment variables to set on each node (optional)
environment:
OVERRIDE_MEMORY_MB: 512
# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600
# Multiple instances run concurrently on the cluster
model_ids:
- "mlx-community/Qwen3-0.6B-4bit"
- "mlx-community/Qwen3-0.6B-4bit"
# Stages run sequentially, each with its own characteristics
stages:
# Stage 1: Light load with short prompts
- name: "warmup"
prompt_length: 50 # Number of tokens in prompt
generation_length: 100 # Max tokens to generate
time_between_requests: 5.0 # Seconds between firing requests
iterations: 10 # Number of requests to send in this stage
# Stage 2: Medium load with medium prompts
- name: "medium_load"
prompt_length: 200
generation_length: 150
time_between_requests: 3.0
iterations: 20
# Stage 3: Heavy load with long prompts - requests will overlap
- name: "stress_test"
prompt_length: 500
generation_length: 200
time_between_requests: 1.0 # Fast firing - will definitely overlap
iterations: 30
# Stage 4: Cool down with simple prompts
- name: "cooldown"
prompt_length: 50
generation_length: 50
time_between_requests: 10.0
iterations: 5

View File

@@ -1,125 +0,0 @@
# Simple single-shot benchmark
# Tests 2 instances concurrently on 2 nodes
# Hardware configuration - maps runner labels to instance counts
hardware_plan:
puffin4: 1
puffin8: 1
# Environment variables to set on each node
environment:
PLACEHOLDER: "placeholder"
# OVERRIDE_MEMORY_MB: 50000
MLX_METAL_FAST_SYNCH: 1
# Timeout for instance and runner readiness (seconds)
timeout_seconds: 1800
# Model instances to run concurrently
model_ids:
# - "mlx-community/DeepSeek-V3.1-8bit"
# - "mlx-community/Kimi-K2-Instruct-4bit"
- "mlx-community/Kimi-K2-Thinking"
# - "mlx-community/Qwen3-235B-A22B-4bit"
# - "mlx-community/Llama-3.3-70B-Instruct-4bit"
# - "mlx-community/Llama-3.3-70B-Instruct-8bit"
# - "mlx-community/Llama-3.2-1B-Instruct-4bit"
# Sharding strategy: "Pipeline" or "Tensor"
sharding: "Tensor"
# Instance type: "MlxRing" or "MlxIbv"
instance_meta: "MlxIbv"
# If true, run requests sequentially (no overlap); if false, fire-and-forget (default: false)
no_overlap: true
# Benchmark stages
# pp: 64, 256, 1024, 2048, 4096, 8192, 16384
# g: 64, 512
stages:
# - name: "simple"
# prompt_length: 512
# generation_length: 10
# time_between_requests: 2.0
# iterations: 5
# - name: "pp64_g64"
# prompt_length: 64
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp64_g64"
# prompt_length: 64
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp64_g512"
# prompt_length: 64
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp256_g64"
# prompt_length: 256
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
- name: "pp256_g64"
prompt_length: 256
generation_length: 64
time_between_requests: 2.0
iterations: 5
# - name: "pp256_g512"
# prompt_length: 256
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp1024_g64"
# prompt_length: 1024
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp1024_g512"
# prompt_length: 1024
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp2048_g64"
# prompt_length: 2048
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp2048_g512"
# prompt_length: 2048
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp4096_g64"
# prompt_length: 4096
# generation_length: 64
# time_between_requests: 2.0
# iterations: 4
# - name: "pp4096_g512"
# prompt_length: 4096
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp8192_g64"
# prompt_length: 8192
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp8192_g512"
# prompt_length: 8192
# generation_length: 512
# time_between_requests: 2.0
# iterations: 5
# - name: "pp16384_g64"
# prompt_length: 16384
# generation_length: 64
# time_between_requests: 2.0
# iterations: 10
# - name: "pp16384_g512"
# prompt_length: 16384
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10

1399
.github/scripts/bench.py vendored
View File

File diff suppressed because it is too large Load Diff

View File

@@ -1,70 +0,0 @@
#!/usr/bin/env python3
import json
import os
from typing import NotRequired, TypedDict, cast
import yaml
class MatrixEntry(TypedDict):
label: str
index: int
class MatrixInclude(TypedDict):
label: str
index: int
is_primary: bool
expected_nodes: int
class Config(TypedDict):
hardware_plan: dict[str, int]
timeout_seconds: NotRequired[int]
environment: NotRequired[dict[str, str]]
# Read the config file
config_file: str = os.environ["CONFIG_FILE"]
with open(config_file, "r") as f:
config: Config = cast(Config, yaml.safe_load(f))
# Extract hardware plan from config
plan: dict[str, int] = config["hardware_plan"]
if not plan:
raise ValueError(f"No hardware_plan found in {config_file}")
# Build matrix entries
entries: list[MatrixEntry] = []
for label, count in plan.items():
for idx in range(count):
entries.append({"label": label, "index": idx})
total_nodes: int = len(entries)
matrix: dict[str, list[MatrixInclude]] = {
"include": [
{
"label": e["label"],
"index": e["index"],
"is_primary": (i == 0),
"expected_nodes": total_nodes,
}
for i, e in enumerate(entries)
]
}
# Extract other config values
timeout_seconds: int = config.get("timeout_seconds", 600)
environment: dict[str, str] = config.get("environment", {})
# Output to GitHub Actions
with open(os.environ["GITHUB_OUTPUT"], "a") as f:
f.write(f"matrix={json.dumps(matrix)}\n")
f.write(f"config_file={config_file}\n")
f.write(f"timeout_seconds={timeout_seconds}\n")
f.write(f"environment={json.dumps(environment)}\n")
print(f"Matrix: {json.dumps(matrix)}")
print(f"Config file: {config_file}")
print(f"Timeout: {timeout_seconds}")
print(f"Environment: {json.dumps(environment)}")

View File

@@ -1,156 +0,0 @@
# Benchmark Workflow Usage
## Overview
The `bench_matrix.yml` workflow enables distributed benchmarking of models across multiple self-hosted macOS runners with different hardware configurations.
## Workflow Inputs
| Input | Description | Default | Required |
|-------|-------------|---------|----------|
| `model_id` | Model ID to benchmark | `mlx-community/Llama-3.2-1B-Instruct-4bit` | Yes |
| `hardware_plan` | JSON mapping of runner labels to counts | `{"M4PRO_GPU16_24GB": 1}` | Yes |
| `prompt` | Benchmark prompt text | `What is the capital of France?` | No |
| `timeout_seconds` | Timeout for instance/runner readiness | `600` | No |
## Hardware Plan Format
The `hardware_plan` input is a JSON object mapping runner labels to the number of machines:
```json
{
"M4PRO_GPU16_24GB": 2,
"M3ULTRA_GPU80_512GB": 1
}
```
This example would:
- Start 2 runners with the `M4PRO_GPU16_24GB` label
- Start 1 runner with the `M3ULTRA_GPU80_512GB` label
- Total of 3 runners coordinating on a single distributed inference instance
## How It Works
1. **Planning Job** (`plan`)
- Runs on `ubuntu-latest`
- Parses the `hardware_plan` JSON
- Generates a dynamic matrix with one entry per runner
- Only the first runner (index 0) is marked as `is_primary`
2. **Benchmark Worker Jobs** (`bench_worker`)
- Each job runs on a self-hosted macOS runner with the specified label
- All runners start EXO in parallel
- The primary runner creates the model instance
- All runners wait for their assigned runner to be ready (Loaded/Running status)
- The primary runner executes the benchmark and prints results
- The primary runner deletes the instance
## Example Usage
### Single Machine Benchmark
```yaml
model_id: mlx-community/Llama-3.2-1B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 1}'
prompt: What is the capital of France?
timeout_seconds: 600
```
### Multi-Machine Distributed Benchmark
```yaml
model_id: mlx-community/Llama-3.2-3B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 2, "M3ULTRA_GPU80_512GB": 1}'
prompt: Explain quantum computing in simple terms.
timeout_seconds: 900
```
## Benchmark Output
The primary runner outputs a JSON object with benchmark results:
```json
{
"model_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"instance_id": "abc-123-def",
"tokens": 42,
"elapsed_s": 2.451,
"tps": 17.136
}
```
Where:
- `tokens`: Number of chunks/tokens generated
- `elapsed_s`: Total elapsed time in seconds
- `tps`: Tokens per second (tokens / elapsed_s)
## Runner Requirements
Each self-hosted runner must:
- Be labeled with appropriate hardware tags (e.g., `M4PRO_GPU16_24GB`)
- Have the `self-hosted` and `macOS` labels
- Have Nix installed with flakes enabled
- Have network connectivity to other runners in the same job
## Architecture
```
┌─────────────────────────────────────────────────────────────┐
│ GitHub Actions Workflow (bench_matrix.yml) │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌────────────────┐ │
│ │ Plan Job │ │
│ │ (ubuntu) │──┬─► Matrix: [{label, index, primary}] │
│ └────────────────┘ │ │
│ │ │
│ ┌───────────────────▼──────────────────────────────────┐ │
│ │ Bench Worker Jobs (Matrix) │ │
│ ├──────────────────────────────────────────────────────┤ │
│ │ │ │
│ │ Runner 0 (Primary) Runner 1 Runner 2 │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌──────────┐ │ │
│ │ │ Start EXO │ │ Start EXO │ │ Start EXO│ │ │
│ │ │ Create Inst │ │ Wait... │ │ Wait... │ │ │
│ │ │ Wait Ready │ │ Wait Ready │ │ Wait... │ │ │
│ │ │ Run Bench │ │ (idle) │ │ (idle) │ │ │
│ │ │ Print TPS │ │ │ │ │ │ │
│ │ │ Delete Inst │ │ │ │ │ │ │
│ │ └─────────────┘ └─────────────┘ └──────────┘ │ │
│ └───────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
```
## Implementation Details
### `scripts/bench.py`
A standalone Python script that:
- Creates instance (primary only)
- Polls `/state` endpoint until instance and all runners are ready
- Executes chat completion with timing (primary only)
- Parses SSE stream and counts tokens
- Computes TPS metrics
- Cleans up instance (primary only)
### Key Functions
- `wait_for_instance()`: Polls until instance with model_id appears
- `wait_for_runners_ready()`: Polls until expected number of runners reach Loaded/Running status
- `run_benchmark()`: Executes chat completion, measures time, counts tokens
## Troubleshooting
### Instance never becomes ready
- Check EXO logs in the workflow output
- Verify model_id is valid and accessible
- Increase `timeout_seconds`
### Runner mismatch
- Ensure hardware_plan counts match available labeled runners
- Check runner labels match exactly (case-sensitive)
### Network issues
- Verify runners can communicate on the network
- Check firewall rules between runner hosts

View File

@@ -1,305 +0,0 @@
name: bench
on: [push]
jobs:
plan:
if: contains(github.event.head_commit.message, '/bench')
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.build.outputs.matrix }}
config_file: ${{ steps.build.outputs.config_file }}
timeout_seconds: ${{ steps.build.outputs.timeout_seconds }}
environment: ${{ steps.build.outputs.environment }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Build matrix from config file
id: build
shell: bash
run: |
set -euo pipefail
CONFIG_FILE='.github/configs/bench_simple.yaml'
export CONFIG_FILE
echo "Config file: $CONFIG_FILE"
python3 .github/scripts/build_matrix.py
bench_worker:
needs: plan
strategy:
fail-fast: false
matrix: ${{ fromJSON(needs.plan.outputs.matrix) }}
name: "bench on ${{ matrix.label }} [${{ matrix.index }}]"
runs-on: [self-hosted, macOS, "${{ matrix.label }}"]
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: false
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
git config --local user.name "github-actions bot"
shell: bash
# TODO: this is mega hacky and I'd like a simpler solution.
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix is already available
if command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
# Try sourcing profile scripts to set up environment properly
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Sourcing multi-user nix-daemon profile script"
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
elif [ -f "$HOME/.nix-profile/etc/profile.d/nix.sh" ]; then
echo "Sourcing single-user nix profile script"
source "$HOME/.nix-profile/etc/profile.d/nix.sh"
elif [ -f /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh ]; then
echo "Sourcing per-user nix profile script"
source /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh
elif [ -f /etc/profile.d/nix.sh ]; then
echo "Sourcing system-wide nix profile script"
source /etc/profile.d/nix.sh
# Fallback: manually add nix to PATH if binary exists
elif [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary, manually adding to PATH"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
elif [ -f "$HOME/.nix-profile/bin/nix" ]; then
echo "Found nix binary in user profile, manually adding to PATH"
export PATH="$HOME/.nix-profile/bin:$PATH"
else
echo "Nix not found. Debugging info:"
echo "USER: $USER"
echo "HOME: $HOME"
echo "Current PATH: $PATH"
echo ""
echo "Checking common Nix locations:"
echo " /nix/var/nix/profiles/default/bin/nix:"
ls -la /nix/var/nix/profiles/default/bin/nix 2>/dev/null || echo " Not found"
echo " /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh:"
ls -la /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh 2>/dev/null || echo " Not found"
echo " ~/.nix-profile/etc/profile.d/nix.sh:"
ls -la "$HOME/.nix-profile/etc/profile.d/nix.sh" 2>/dev/null || echo " Not found"
echo " /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh:"
ls -la "/nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh" 2>/dev/null || echo " Not found"
echo ""
echo "/nix directory structure:"
ls -la /nix 2>/dev/null || echo " /nix directory not found"
echo ""
echo "/nix/var:"
ls -la /nix/var 2>/dev/null || echo " /nix/var not found"
echo ""
echo "/nix/store:"
ls -la /nix/store 2>/dev/null | head -20 || echo " /nix/store not found"
echo ""
echo "GitHub Actions runner is running as user '$USER'."
echo "If Nix is installed for a different user, either:"
echo " 1. Install Nix for user '$USER' (multi-user install recommended)"
echo " 2. Configure the runner service to run as the user with Nix installed"
echo " 3. Ensure Nix is installed system-wide with proper daemon setup"
exit 1
fi
# Verify nix is available and persist to GITHUB_ENV
if command -v nix >/dev/null 2>&1; then
echo "✓ Nix is available"
nix --version
echo "PATH=$PATH" >> $GITHUB_ENV
if [ -n "$NIX_PATH" ]; then
echo "NIX_PATH=$NIX_PATH" >> $GITHUB_ENV
fi
else
echo "ERROR: Failed to set up Nix"
echo "PATH after setup attempt: $PATH"
exit 1
fi
shell: bash
- name: Setup EXO_HOME and API_PORT
run: |
EXO_HOME=$(mktemp -d -t exo-e2e-XXXXXXXX)
API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
EXO_MODELS_DIR="$HOME/.exo/models"
EXO_LIBP2P_NAMESPACE="bench-${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"
echo "EXO_HOME=$EXO_HOME" >> "$GITHUB_ENV"
echo "API_PORT=$API_PORT" >> "$GITHUB_ENV"
echo "EXO_MODELS_DIR=$EXO_MODELS_DIR" >> "$GITHUB_ENV"
echo "EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE" >> "$GITHUB_ENV"
echo "Created EXO_HOME: $EXO_HOME"
echo "Generated API_PORT: $API_PORT"
echo "Using models from: $EXO_MODELS_DIR"
echo "Using libp2p namespace: $EXO_LIBP2P_NAMESPACE"
shell: bash
- name: Configure local MLX if available
run: |
echo "=== DEBUG: Checking for local MLX configuration ==="
MODIFIED=false
echo "Checking for /Users/Shared/mlx directory..."
if [ -d "/Users/Shared/mlx" ]; then
echo "✓ Found /Users/Shared/mlx"
ls -la /Users/Shared/mlx | head -5
echo "Enabling local mlx path in pyproject.toml"
sed -i.bak 's|^# mlx = { path = "/Users/Shared/mlx", editable=true }$|mlx = { path = "/Users/Shared/mlx", editable=true }|' pyproject.toml
MODIFIED=true
else
echo "✗ /Users/Shared/mlx not found, will use PyPI version"
fi
echo "Checking for /Users/Shared/mlx-lm directory..."
if [ -d "/Users/Shared/mlx-lm" ]; then
echo "✓ Found /Users/Shared/mlx-lm"
ls -la /Users/Shared/mlx-lm | head -5
echo "Enabling local mlx-lm path in pyproject.toml"
sed -i.bak 's|^# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }$|mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }|' pyproject.toml
MODIFIED=true
else
echo "✗ /Users/Shared/mlx-lm not found, will use PyPI version"
fi
if [ "$MODIFIED" = true ]; then
echo "=== Modified pyproject.toml [tool.uv.sources] section: ==="
sed -n '/\[tool\.uv\.sources\]/,/^\[/{/^\[tool\.uv\.sources\]/p; /^\[/!p;}' pyproject.toml
echo "=== Regenerating uv.lock with local MLX paths... ==="
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command uv lock --upgrade-package mlx --upgrade-package mlx-lm
echo "✓ Lock file regenerated"
else
echo "⚠ No local MLX directories found, using PyPI packages"
fi
echo "=== DEBUG: Local MLX configuration complete ==="
shell: bash
- name: Sync dependencies
run: |
if [ -d "/Users/Shared/test" ]; then
pushd /Users/Shared/test
uv sync --reinstall
popd
fi
echo "Running just sync to ensure clean dependencies..."
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just sync
shell: bash
- name: Start EXO and run bench script
shell: bash
env:
IS_PRIMARY: ${{ matrix.is_primary }}
EXPECTED_NODES: ${{ matrix.expected_nodes }}
HARDWARE_LABEL: ${{ matrix.label }}
CONFIG_FILE: ${{ needs.plan.outputs.config_file }}
TIMEOUT_SECONDS: ${{ needs.plan.outputs.timeout_seconds }}
ENVIRONMENT_JSON: ${{ needs.plan.outputs.environment }}
run: |
set -euo pipefail
# Parse environment variables from config
ENV_VARS=""
if [ -n "$ENVIRONMENT_JSON" ] && [ "$ENVIRONMENT_JSON" != "{}" ]; then
ENV_VARS=$(echo "$ENVIRONMENT_JSON" | python3 -c "import sys, json; env = json.load(sys.stdin); print(' '.join([f'{k}={v}' for k, v in env.items()]))")
fi
echo "Starting EXO with API_PORT=${API_PORT} EXO_HOME=${EXO_HOME} EXO_LIBP2P_NAMESPACE=${EXO_LIBP2P_NAMESPACE}"
echo "Environment variables from config: $ENV_VARS"
LOG_FILE=/tmp/exo.log
: > "$LOG_FILE"
MASTER_FLAG=""
if [ "$IS_PRIMARY" = "true" ]; then
MASTER_FLAG="-m"
fi
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
"EXO_HOME=$EXO_HOME EXO_MODELS_DIR=$EXO_MODELS_DIR EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE $ENV_VARS PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run exo $MASTER_FLAG --api-port $API_PORT" \
>> "$LOG_FILE" 2>&1 &
EXO_PID=$!
echo "Started EXO in background with PID: $EXO_PID"
echo "Log file: $LOG_FILE"
cleanup() {
echo '=== EXO log (tail) ==='
tail -n 300 "$LOG_FILE" || true
if ps -p "$EXO_PID" >/dev/null 2>&1; then
echo "Killing EXO (PID $EXO_PID)"
kill "$EXO_PID" || true
fi
}
trap cleanup EXIT
for i in $(seq 1 60); do
if curl -s "http://localhost:${API_PORT}/state" >/dev/null 2>&1; then
echo "EXO API ready"
break
fi
if ! ps -p "$EXO_PID" >/dev/null 2>&1; then
echo "EXO terminated early"; sed -n '1,200p' "$LOG_FILE" || true; exit 1
fi
sleep 1
done
RESULTS_FILE="/tmp/bench_results_${GITHUB_RUN_ID}_${GITHUB_RUN_ATTEMPT}_$(date +%s).json"
echo "Results will be saved to: $RESULTS_FILE"
echo "RESULTS_FILE=$RESULTS_FILE" >> "$GITHUB_ENV"
echo "Running bench script with config: $CONFIG_FILE, timeout: $TIMEOUT_SECONDS"
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
"PYTHONUNBUFFERED=1 uv run --no-project --with pyyaml --with pydantic python .github/scripts/bench.py \
--api-port $API_PORT \
--config $CONFIG_FILE \
--expected-nodes ${EXPECTED_NODES} \
--is-primary ${IS_PRIMARY} \
--timeout-seconds ${TIMEOUT_SECONDS} \
--output $RESULTS_FILE \
--git-commit ${GITHUB_SHA} \
--hardware-labels ${HARDWARE_LABEL}"
- name: Install AWS CLI
if: always() && env.RESULTS_FILE && matrix.is_primary
run: |
if ! command -v aws &> /dev/null; then
echo "AWS CLI not found, installing..."
brew install awscli
else
echo "AWS CLI already installed"
fi
shell: bash
- name: Upload results to S3
if: always() && env.RESULTS_FILE && matrix.is_primary
env:
AWS_ACCESS_KEY_ID: ${{ secrets.S3_BENCHMARKS_AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
AWS_DEFAULT_REGION: us-east-1
run: |
echo "Checking for results file: $RESULTS_FILE"
echo "Is primary: ${{ matrix.is_primary }}"
if [ -f "$RESULTS_FILE" ]; then
TIMESTAMP=$(date -u +%Y/%m/%d/%H%M%S)
S3_KEY="bench/${TIMESTAMP}_${GITHUB_SHA:0:8}_${GITHUB_RUN_ID}.json"
echo "Uploading results to s3://exo-benchmark-results/$S3_KEY"
aws s3 cp "$RESULTS_FILE" "s3://exo-benchmark-results/$S3_KEY" \
--content-type application/json \
--metadata "commit=${GITHUB_SHA},run_id=${GITHUB_RUN_ID},branch=${GITHUB_REF_NAME}"
echo "Results uploaded successfully"
echo "View at: https://exo-benchmark-results.s3.amazonaws.com/$S3_KEY"
else
echo "Results file not found at: $RESULTS_FILE"
echo "Skipping upload"
fi
shell: bash
- name: Cleanup EXO_HOME
run: |
echo "Cleaning up EXO_HOME: $EXO_HOME"
rm -rf "$EXO_HOME"
shell: bash
if: always()

2
.gitignore vendored
View File

@@ -7,6 +7,8 @@ digest.txt
# nix
.direnv/
# IDEA (PyCharm)
.idea
# xcode / macos
*.xcuserstate

View File

@@ -8,7 +8,7 @@
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
<p align="center">
<a href="https://discord.gg/72NsF6ux" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
<a href="https://x.com/exolabs" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/twitter/follow/exolabs?style=social" alt="X"></a>
<a href="https://www.apache.org/licenses/LICENSE-2.0.html" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/License-Apache2.0-blue.svg" alt="License: Apache-2.0"></a>
</p>
@@ -61,10 +61,10 @@ Devices running exo automatically discover each other, without needing any manua
There are two ways to run exo:
### Run from Source (Mac & Linux)
### Run from Source (macOS)
**Prerequisites:**
- [brew](https://github.com/Homebrew/brew) (for simple package management on MacOS)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
@@ -98,6 +98,62 @@ uv run exo
This starts the exo dashboard and API at http://localhost:52415/
### Run from Source (Linux)
**Prerequisites:**
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [node](https://github.com/nodejs/node) (for building the dashboard) - version 18 or higher
- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings, nightly for now)
**Installation methods:**
**Option 1: Using system package manager (Ubuntu/Debian example):**
```bash
# Install Node.js and npm
sudo apt update
sudo apt install nodejs npm
# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh
# Install Rust (using rustup)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```
**Option 2: Using Homebrew on Linux (if preferred):**
```bash
# Install Homebrew on Linux
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
# Install dependencies
brew install uv node
# Install Rust (using rustup)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
rustup toolchain install nightly
```
**Note:** The `macmon` package is macOS-only and not required for Linux.
Clone the repo, build the dashboard, and run exo:
```bash
# Clone exo
git clone https://github.com/exo-explore/exo
# Build dashboard
cd exo/dashboard && npm install && npm run build && cd ..
# Run exo
uv run exo
```
This starts the exo dashboard and API at http://localhost:52415/
**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.
### macOS App
exo ships a macOS app that runs in the background on your Mac.
@@ -112,6 +168,29 @@ The app will ask for permission to modify system settings and install a new Netw
---
### Enabling RDMA on macOS
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
To enable RDMA on macOS, follow these steps:
1. Shut down your Mac.
2. Hold down the power button for 10 seconds until the boot menu appears.
3. Select "Options" to enter Recovery mode.
4. When the Recovery UI appears, open the Terminal from the Utilities menu.
5. In the Terminal, type:
```
rdma_ctl enable
```
and press Enter.
6. Reboot your Mac.
After that, RDMA will be enabled in macOS and exo will take care of the rest.
---
### Using the API
If you prefer to interact with exo via the API, here is an example creating an instance of a small model (`mlx-community/Llama-3.2-1B-Instruct-4bit`), sending a chat completions request and deleting the instance.

View File

@@ -20,6 +20,8 @@ struct ContentView: View {
@State private var showDebugInfo = false
@State private var bugReportInFlight = false
@State private var bugReportMessage: String?
@State private var showAdvancedOptions = false
@State private var pendingNamespace: String = ""
var body: some View {
VStack(alignment: .leading, spacing: 12) {
@@ -49,7 +51,7 @@ struct ContentView: View {
private var topologySection: some View {
Group {
if let topology = stateService.latestSnapshot?.topologyViewModel(), !topology.nodes.isEmpty {
if let topology = stateService.latestSnapshot?.topologyViewModel(localNodeId: stateService.localNodeId), !topology.nodes.isEmpty {
TopologyMiniView(topology: topology)
}
}
@@ -197,6 +199,8 @@ struct ContentView: View {
updater.checkForUpdates()
}
.padding(.bottom, 8)
advancedOptionsSection
.padding(.bottom, 8)
debugSection
.padding(.bottom, 8)
controlButton(title: "Quit", tint: .secondary) {
@@ -327,6 +331,47 @@ struct ContentView: View {
}
}
private var advancedOptionsSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Advanced Options")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showAdvancedOptions)
}
.animation(nil, value: showAdvancedOptions)
if showAdvancedOptions {
VStack(alignment: .leading, spacing: 8) {
VStack(alignment: .leading, spacing: 4) {
Text("Cluster Namespace")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
TextField("optional", text: $pendingNamespace)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingNamespace = controller.customNamespace
}
Button("Save & Restart") {
controller.customNamespace = pendingNamespace
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingNamespace == controller.customNamespace)
}
}
}
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showAdvancedOptions)
}
private var debugSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {

View File

@@ -2,6 +2,8 @@ import AppKit
import Combine
import Foundation
private let customNamespaceKey = "EXOCustomNamespace"
@MainActor
final class ExoProcessController: ObservableObject {
enum Status: Equatable {
@@ -27,6 +29,13 @@ final class ExoProcessController: ObservableObject {
@Published private(set) var status: Status = .stopped
@Published private(set) var lastError: String?
@Published private(set) var launchCountdownSeconds: Int?
@Published var customNamespace: String = {
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
}() {
didSet {
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
}
}
private var process: Process?
private var runtimeDirectoryURL: URL?
@@ -180,7 +189,7 @@ final class ExoProcessController: ObservableObject {
private func makeEnvironment(for runtimeURL: URL) -> [String: String] {
var environment = ProcessInfo.processInfo.environment
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
environment["EXO_LIBP2P_NAMESPACE"] = buildTag()
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
var paths: [String] = []
if let existing = environment["PATH"], !existing.isEmpty {
@@ -217,6 +226,12 @@ final class ExoProcessController: ObservableObject {
}
return "dev"
}
private func computeNamespace() -> String {
let base = buildTag()
let custom = customNamespace.trimmingCharacters(in: .whitespaces)
return custom.isEmpty ? base : custom
}
}
struct RuntimeError: LocalizedError {

View File

@@ -82,7 +82,6 @@ struct BugReportService {
}
private func loadCredentials() throws -> AWSConfig {
// These credentials are write-only and necessary to receive bug reports from users
return AWSConfig(
accessKey: "AKIAYEKP5EMXTOBYDGHX",
secretKey: "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",

View File

@@ -7,6 +7,7 @@ final class ClusterStateService: ObservableObject {
@Published private(set) var lastError: String?
@Published private(set) var lastActionMessage: String?
@Published private(set) var modelOptions: [ModelOption] = []
@Published private(set) var localNodeId: String?
private var timer: Timer?
private let decoder: JSONDecoder
@@ -29,6 +30,7 @@ final class ClusterStateService: ObservableObject {
func startPolling(interval: TimeInterval = 0.5) {
stopPolling()
Task {
await fetchLocalNodeId()
await fetchModels()
await fetchSnapshot()
}
@@ -46,9 +48,31 @@ final class ClusterStateService: ObservableObject {
latestSnapshot = nil
lastError = nil
lastActionMessage = nil
localNodeId = nil
}
private func fetchLocalNodeId() async {
do {
let url = baseURL.appendingPathComponent("node_id")
var request = URLRequest(url: url)
request.cachePolicy = .reloadIgnoringLocalCacheData
let (data, response) = try await session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
return
}
if let nodeId = try? decoder.decode(String.self, from: data) {
localNodeId = nodeId
}
} catch {
// Silently ignore - localNodeId will remain nil and retry on next poll
}
}
private func fetchSnapshot() async {
// Retry fetching local node ID if not yet set
if localNodeId == nil {
await fetchLocalNodeId()
}
do {
var request = URLRequest(url: endpoint)
request.cachePolicy = .reloadIgnoringLocalCacheData

View File

@@ -85,7 +85,7 @@ struct TopologyViewModel {
}
extension ClusterState {
func topologyViewModel() -> TopologyViewModel? {
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
guard !allNodes.isEmpty else { return nil }
@@ -105,6 +105,11 @@ extension ClusterState {
orderedNodes = allNodes
}
// Rotate so the local node (from /node_id API) is first
if let localId = localNodeId, let index = orderedNodes.firstIndex(where: { $0.id == localId }) {
orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
}
let nodeIds = Set(orderedNodes.map(\.id))
let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
@@ -112,10 +117,7 @@ extension ClusterState {
} ?? []
let edges = Set(edgesArray)
let topologyRootId = topology?.nodes.first?.nodeId
let currentId = orderedNodes.first(where: { $0.id == topologyRootId })?.id ?? orderedNodes.first?.id
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: currentId)
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
}
}

526
bench/exo_bench.py Normal file
View File

@@ -0,0 +1,526 @@
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
from __future__ import annotations
import argparse
import http.client
import json
import os
import time
from collections.abc import Callable
from statistics import mean
from typing import Any
from urllib.parse import urlencode
from loguru import logger
from transformers import AutoTokenizer
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
time.sleep(0.1)
continue
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def format_peak_memory(b: float) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if b < 1024.0:
return f"{b:.2f}{unit}"
b /= 1024.0
raise ValueError("You're using petabytes of memory. Something went wrong...")
def parse_int_list(values: list[str]) -> list[int]:
items: list[int] = []
for v in values:
for part in v.split(","):
part = part.strip()
if part:
items.append(int(part))
seen: set[int] = set()
out: list[int] = []
for x in items:
if x not in seen:
out.append(x)
seen.add(x)
return out
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if m.get("id") == model_arg:
short_id = str(m["id"])
full_id = str(m.get("hugging_face_id") or m["id"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["id"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def run_one_completion(
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
) -> tuple[dict[str, Any], int]:
content, pp_tokens = prompt_sizer.build(pp_hint)
payload: dict[str, Any] = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": tg,
}
t0 = time.perf_counter()
out = client.post_bench_chat_completions(payload)
elapsed = time.perf_counter() - t0
stats = out.get("generation_stats")
preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]
return {
"elapsed_s": elapsed,
"output_text_preview": preview,
"stats": stats,
}, pp_tokens
class PromptSizer:
def __init__(self, tokenizer: Any, atom: str = "a "):
self.tokenizer = tokenizer
self.atom = atom
self.count_fn = PromptSizer._make_counter(tokenizer)
self.base_tokens = self.count_fn("")
@staticmethod
def _make_counter(tokenizer: Any) -> Callable[[str], int]:
def count_fn(user_content: str) -> int:
messages = [{"role": "user", "content": user_content}]
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
return int(len(ids))
return count_fn
def build(self, target_prompt_tokens: int) -> tuple[str, int]:
target = int(target_prompt_tokens)
if target < self.base_tokens:
raise RuntimeError(
f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
)
content = ""
tok = self.count_fn(content)
while tok < target:
content += self.atom
tok = self.count_fn(content)
if tok != target:
raise RuntimeError(
f"Overshot: got {tok} tokens (target {target}). "
f"Pick a different atom (try ' a' or '\\n' or '0 ')."
)
return content, tok
def main() -> int:
ap = argparse.ArgumentParser(
prog="exo-bench",
description="Benchmark exo model throughput across placement previews.",
)
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
ap.add_argument(
"--pp",
nargs="+",
required=True,
help="Prompt-size hints (ints). Accepts commas.",
)
ap.add_argument(
"--tg",
nargs="+",
required=True,
help="Generation lengths (ints). Accepts commas.",
)
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Pipeline jaccl is often pointless, skip by default",
)
ap.add_argument(
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
)
ap.add_argument(
"--warmup",
type=int,
default=0,
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
default="bench/results.json",
help="Write raw per-run results JSON to this path.",
)
ap.add_argument(
"--dry-run", action="store_true", help="List selected placements and exit."
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
tg_list = parse_int_list(args.tg)
if not pp_list or not tg_list:
logger.error("pp and tg lists must be non-empty")
return 2
if args.repeat <= 0:
logger.error("--repeat must be >= 1")
return 2
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
short_id, full_model_id = resolve_model_short_id(client, args.model)
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": short_id}
)
previews = previews_resp.get("previews") or []
tokenizer = AutoTokenizer.from_pretrained(
full_model_id,
trust_remote_code=True,
)
if tokenizer is None:
raise RuntimeError("[exo-bench] tokenizer load failed")
try:
prompt_sizer = PromptSizer(tokenizer)
logger.debug(f"[exo-bench] loaded tokenizer: {full_model_id} for prompt sizer")
except Exception:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if 0 < n <= args.max_nodes:
selected.append(p)
if not selected:
logger.error("No valid placements matched your filters.")
return 1
selected.sort(
key=lambda p: (
str(p.get("instance_meta", "")),
str(p.get("sharding", "")),
-nodes_used_in_instance(p["instance"]),
),
reverse=True,
)
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
logger.info(f"placements: {len(selected)}")
for p in selected:
logger.info(
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
)
if args.dry_run:
return 0
all_rows: list[dict[str, Any]] = []
for preview in selected:
instance = preview["instance"]
instance_id = instance_id_from_instance(instance)
sharding = str(preview["sharding"])
instance_meta = str(preview["instance_meta"])
n_nodes = nodes_used_in_instance(instance)
logger.info("=" * 80)
logger.info(
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
)
client.request_json("POST", "/instance", body={"instance": instance})
wait_for_instance_ready(client, instance_id)
time.sleep(1)
try:
for i in range(args.warmup):
run_one_completion(
client, full_model_id, pp_list[0], tg_list[0], prompt_sizer
)
logger.debug(f" warmup {i + 1}/{args.warmup} done")
for pp in pp_list:
if (
pp * n_nodes > 2048
and "ring" in instance_meta.lower()
and "tensor" in sharding.lower()
):
model_card = MODEL_CARDS[short_id]
if model_card.metadata.storage_size > Memory.from_gb(10):
logger.info(
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
)
continue
for tg in tg_list:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):
time.sleep(3)
try:
row, actual_pp_tokens = run_one_completion(
client, full_model_id, pp, tg, prompt_sizer
)
except Exception as e:
logger.error(e)
continue
row.update(
{
"model_short_id": short_id,
"model_id": full_model_id,
"placement_sharding": sharding,
"placement_instance_meta": instance_meta,
"placement_nodes": n_nodes,
"instance_id": instance_id,
"pp_tokens": actual_pp_tokens,
"tg": tg,
"repeat_index": r,
}
)
runs.append(row)
all_rows.append(row)
if runs:
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
peak = mean(
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
)
logger.info(
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
f"prompt_tokens={ptok} gen_tokens={gtok} "
f"peak_memory={format_peak_memory(peak)}\n"
)
time.sleep(2)
finally:
try:
client.request_json("DELETE", f"/instance/{instance_id}")
except ExoHttpError as e:
if e.status != 404:
raise
wait_for_instance_gone(client, instance_id)
logger.debug(f"Deleted instance {instance_id}")
time.sleep(5)
if args.json_out:
with open(args.json_out, "w", encoding="utf-8") as f:
json.dump(all_rows, f, indent=2, ensure_ascii=False)
logger.debug(f"\nWrote results JSON: {args.json_out}")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -9,6 +9,8 @@
"version": "1.0.0",
"dependencies": {
"highlight.js": "^11.11.1",
"katex": "^0.16.27",
"marked": "^17.0.1",
"mode-watcher": "^1.1.0"
},
"devDependencies": {
@@ -861,7 +863,6 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -901,7 +902,6 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1518,7 +1518,6 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1528,7 +1527,6 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1941,7 +1939,6 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2254,6 +2251,31 @@
"jiti": "lib/jiti-cli.mjs"
}
},
"node_modules/katex": {
"version": "0.16.27",
"resolved": "https://registry.npmjs.org/katex/-/katex-0.16.27.tgz",
"integrity": "sha512-aeQoDkuRWSqQN6nSvVCEFvfXdqo1OQiCmmW1kc9xSdjutPv7BGO7pqY9sQRJpMOGrEdfDgF2TfRXe5eUAD2Waw==",
"funding": [
"https://opencollective.com/katex",
"https://github.com/sponsors/katex"
],
"license": "MIT",
"dependencies": {
"commander": "^8.3.0"
},
"bin": {
"katex": "cli.js"
}
},
"node_modules/katex/node_modules/commander": {
"version": "8.3.0",
"resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz",
"integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==",
"license": "MIT",
"engines": {
"node": ">= 12"
}
},
"node_modules/kleur": {
"version": "4.1.5",
"resolved": "https://registry.npmjs.org/kleur/-/kleur-4.1.5.tgz",
@@ -2540,6 +2562,18 @@
"@jridgewell/sourcemap-codec": "^1.5.5"
}
},
"node_modules/marked": {
"version": "17.0.1",
"resolved": "https://registry.npmjs.org/marked/-/marked-17.0.1.tgz",
"integrity": "sha512-boeBdiS0ghpWcSwoNm/jJBwdpFaMnZWRzjA6SkUMYb40SVaN1x7mmfGKp0jvexGcx+7y2La5zRZsYFZI6Qpypg==",
"license": "MIT",
"bin": {
"marked": "bin/marked.js"
},
"engines": {
"node": ">= 20"
}
},
"node_modules/mode-watcher": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/mode-watcher/-/mode-watcher-1.1.0.tgz",
@@ -2612,7 +2646,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2800,7 +2833,6 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2945,7 +2977,6 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -2967,7 +2998,6 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -27,7 +27,8 @@
},
"dependencies": {
"highlight.js": "^11.11.1",
"katex": "^0.16.27",
"marked": "^17.0.1",
"mode-watcher": "^1.1.0"
}
}

View File

@@ -139,6 +139,11 @@
}
function handleKeydown(event: KeyboardEvent) {
// Prevent form submission during IME composition (e.g., Chinese, Japanese, Korean input)
if (event.isComposing || event.keyCode === 229) {
return;
}
if (event.key === 'Enter' && !event.shiftKey) {
event.preventDefault();
handleSubmit();

View File

@@ -8,89 +8,80 @@
regenerateLastResponse
} from '$lib/stores/app.svelte';
import type { MessageAttachment } from '$lib/stores/app.svelte';
import { tick, onDestroy } from 'svelte';
import MarkdownContent from './MarkdownContent.svelte';
interface Props {
class?: string;
scrollParent?: HTMLElement | null;
}
interface Props {
class?: string;
scrollParent?: HTMLElement | null;
}
let { class: className = '', scrollParent = null }: Props = $props();
let { class: className = '', scrollParent = null }: Props = $props();
const messageList = $derived(messages());
const response = $derived(currentResponse());
const loading = $derived(isLoading());
// Ref for scroll anchor at bottom
let scrollAnchorRef: HTMLDivElement | undefined = $state();
// Scroll management - user controls scroll, show button when not at bottom
const SCROLL_THRESHOLD = 100;
let showScrollButton = $state(false);
let lastMessageCount = 0;
let containerRef: HTMLDivElement | undefined = $state();
// Scroll management
const SCROLL_BOTTOM_THRESHOLD = 120;
let autoScrollEnabled = true;
let currentScrollEl: HTMLElement | null = null;
function resolveScrollElement(): HTMLElement | null {
if (scrollParent) return scrollParent;
let node: HTMLElement | null = scrollAnchorRef?.parentElement as HTMLElement | null;
while (node) {
const isScrollable = node.scrollHeight > node.clientHeight + 1;
if (isScrollable) return node;
node = node.parentElement;
function getScrollContainer(): HTMLElement | null {
if (scrollParent) return scrollParent;
return containerRef?.parentElement ?? null;
}
return null;
}
function handleScroll() {
if (!currentScrollEl) return;
const distanceFromBottom = currentScrollEl.scrollHeight - currentScrollEl.scrollTop - currentScrollEl.clientHeight;
const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
autoScrollEnabled = isNearBottom;
}
function attachScrollListener() {
const nextEl = resolveScrollElement();
if (currentScrollEl === nextEl) return;
if (currentScrollEl) {
currentScrollEl.removeEventListener('scroll', handleScroll);
function isNearBottom(el: HTMLElement): boolean {
return el.scrollHeight - el.scrollTop - el.clientHeight < SCROLL_THRESHOLD;
}
currentScrollEl = nextEl;
if (currentScrollEl) {
currentScrollEl.addEventListener('scroll', handleScroll);
// Initialize state based on current position
handleScroll();
}
}
onDestroy(() => {
if (currentScrollEl) {
currentScrollEl.removeEventListener('scroll', handleScroll);
}
});
$effect(() => {
// Re-evaluate scroll container if prop changes or after mount
scrollParent;
attachScrollListener();
});
// Auto-scroll to bottom when messages change or response updates, but only if user is near bottom
$effect(() => {
// Track these values to trigger effect
const _ = messageList.length;
const __ = response;
const ___ = loading;
tick().then(() => {
const el = currentScrollEl ?? resolveScrollElement();
if (!el || !scrollAnchorRef) return;
const distanceFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight;
const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
if (autoScrollEnabled || isNearBottom) {
scrollAnchorRef.scrollIntoView({ behavior: 'smooth', block: 'end' });
autoScrollEnabled = true;
function scrollToBottom() {
const el = getScrollContainer();
if (el) {
el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
}
}
function updateScrollButtonVisibility() {
const el = getScrollContainer();
if (!el) return;
showScrollButton = !isNearBottom(el);
}
// Attach scroll listener
$effect(() => {
const el = scrollParent ?? containerRef?.parentElement;
if (!el) return;
el.addEventListener('scroll', updateScrollButtonVisibility, { passive: true });
// Initial check
updateScrollButtonVisibility();
return () => el.removeEventListener('scroll', updateScrollButtonVisibility);
});
// Auto-scroll when user sends a new message
$effect(() => {
const count = messageList.length;
if (count > lastMessageCount) {
const el = getScrollContainer();
if (el) {
requestAnimationFrame(() => {
el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
});
}
}
lastMessageCount = count;
});
// Update scroll button visibility when content changes
$effect(() => {
// Track response to trigger re-check during streaming
const _ = response;
// Small delay to let DOM update
requestAnimationFrame(() => updateScrollButtonVisibility());
});
});
// Edit state
let editingMessageId = $state<string | null>(null);
@@ -231,7 +222,7 @@ function isThinkingExpanded(messageId: string): boolean {
<div class="flex flex-col gap-4 sm:gap-6 {className}">
{#each messageList as message (message.id)}
<div class="group flex {message.role === 'user' ? 'justify-end' : 'justify-start'}">
<div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'max-w-[95%] sm:max-w-[85%]'}">
<div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'w-full max-w-[98%] sm:max-w-[95%]'}">
{#if message.role === 'assistant'}
<!-- Assistant message header -->
<div class="flex items-center gap-1.5 sm:gap-2 mb-1.5 sm:mb-2">
@@ -305,7 +296,7 @@ function isThinkingExpanded(messageId: string): boolean {
{:else}
<div class="{message.role === 'user'
? 'command-panel rounded-lg rounded-tr-sm inline-block'
: 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 inline-block'}">
: 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 block w-full'}">
{#if message.role === 'user'}
<!-- User message styling -->
@@ -331,7 +322,7 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
{#if message.content}
<div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
<div class="text-xs text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
{message.content}
</div>
{/if}
@@ -360,7 +351,7 @@ function isThinkingExpanded(messageId: string): boolean {
</svg>
<span>Thinking...</span>
</span>
<span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60">
<span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60 ml-4">
{isThinkingExpanded(message.id) ? 'HIDE' : 'SHOW'}
</span>
</button>
@@ -374,8 +365,8 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
</div>
{/if}
<div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
{message.content || (loading ? response : '')}
<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
@@ -457,6 +448,20 @@ function isThinkingExpanded(messageId: string): boolean {
</div>
{/if}
<!-- Scroll anchor for auto-scroll -->
<div bind:this={scrollAnchorRef}></div>
<!-- Invisible element for container reference -->
<div bind:this={containerRef}></div>
<!-- Scroll to bottom button -->
{#if showScrollButton}
<button
type="button"
onclick={scrollToBottom}
class="sticky bottom-4 left-1/2 -translate-x-1/2 w-10 h-10 rounded-full bg-exo-dark-gray/90 border border-exo-medium-gray/50 flex items-center justify-center text-exo-light-gray hover:text-exo-yellow hover:border-exo-yellow/50 transition-all shadow-lg cursor-pointer z-10"
title="Scroll to bottom"
>
<svg class="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 14l-7 7m0 0l-7-7m7 7V3" />
</svg>
</button>
{/if}
</div>

View File

@@ -10,7 +10,9 @@ import {
clearChat,
instances,
debugMode,
toggleDebugMode
toggleDebugMode,
topologyOnlyMode,
toggleTopologyOnlyMode
} from '$lib/stores/app.svelte';
interface Props {
@@ -23,6 +25,7 @@ import {
const activeId = $derived(activeConversationId());
const instanceData = $derived(instances());
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
let searchQuery = $state('');
let editingId = $state<string | null>(null);
@@ -424,6 +427,19 @@ const debugEnabled = $derived(debugMode());
<div class="text-xs text-white/60 font-mono tracking-wider text-center">
{conversationList.length} CONVERSATION{conversationList.length !== 1 ? 'S' : ''}
</div>
<button
type="button"
onclick={toggleTopologyOnlyMode}
class="p-1.5 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
title="Toggle topology only mode"
>
<svg class="w-4 h-4 {topologyOnlyEnabled ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<circle cx="12" cy="5" r="2" fill="currentColor" />
<circle cx="5" cy="19" r="2" fill="currentColor" />
<circle cx="19" cy="19" r="2" fill="currentColor" />
<path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
</svg>
</button>
</div>
</div>
</aside>

View File

@@ -3,6 +3,9 @@
export let showHome = true;
export let onHome: (() => void) | null = null;
export let showSidebarToggle = false;
export let sidebarVisible = true;
export let onToggleSidebar: (() => void) | null = null;
function handleHome(): void {
if (onHome) {
@@ -14,13 +17,38 @@
window.location.hash = '/';
}
}
function handleToggleSidebar(): void {
if (onToggleSidebar) {
onToggleSidebar();
}
}
</script>
<header class="relative z-20 flex items-center justify-center px-6 pt-8 pb-4 bg-exo-dark-gray">
<!-- Left: Sidebar Toggle -->
{#if showSidebarToggle}
<div class="absolute left-6 top-1/2 -translate-y-1/2">
<button
onclick={handleToggleSidebar}
class="p-2 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
title={sidebarVisible ? 'Hide sidebar' : 'Show sidebar'}
>
<svg class="w-5 h-5 {sidebarVisible ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
{#if sidebarVisible}
<path stroke-linecap="round" stroke-linejoin="round" d="M11 19l-7-7 7-7m8 14l-7-7 7-7" />
{:else}
<path stroke-linecap="round" stroke-linejoin="round" d="M13 5l7 7-7 7M5 5l7 7-7 7" />
{/if}
</svg>
</button>
</div>
{/if}
<!-- Center: Logo (clickable to go home) -->
<button
onclick={handleHome}
class="hover:opacity-80 transition-opacity {showHome ? 'cursor-pointer' : 'cursor-default'}"
class="bg-transparent border-none outline-none focus:outline-none transition-opacity duration-200 hover:opacity-90 {showHome ? 'cursor-pointer' : 'cursor-default'}"
title={showHome ? 'Go to home' : ''}
disabled={!showHome}
>

View File

@@ -0,0 +1,451 @@
<script lang="ts">
import { marked } from 'marked';
import hljs from 'highlight.js';
import katex from 'katex';
import 'katex/dist/katex.min.css';
import { browser } from '$app/environment';
interface Props {
content: string;
class?: string;
}
let { content, class: className = '' }: Props = $props();
let containerRef = $state<HTMLDivElement>();
let processedHtml = $state('');
// Configure marked with syntax highlighting
marked.setOptions({
gfm: true,
breaks: true
});
// Custom renderer for code blocks
const renderer = new marked.Renderer();
renderer.code = function ({ text, lang }: { text: string; lang?: string }) {
const language = lang && hljs.getLanguage(lang) ? lang : 'plaintext';
const highlighted = hljs.highlight(text, { language }).value;
const codeId = `code-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`;
return `
<div class="code-block-wrapper">
<div class="code-block-header">
<span class="code-language">${language}</span>
<button type="button" class="copy-code-btn" data-code="${encodeURIComponent(text)}" title="Copy code">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
</svg>
</button>
</div>
<pre><code class="hljs language-${language}" data-code-id="${codeId}">${highlighted}</code></pre>
</div>
`;
};
// Inline code
renderer.codespan = function ({ text }: { text: string }) {
return `<code class="inline-code">${text}</code>`;
};
marked.use({ renderer });
/**
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
* Also protect code blocks from LaTeX processing
*/
function preprocessLaTeX(text: string): string {
// Protect code blocks
const codeBlocks: string[] = [];
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
codeBlocks.push(match);
return `<<CODE_${codeBlocks.length - 1}>>`;
});
// Convert \(...\) to $...$
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
// Convert \[...\] to $$...$$
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
// Restore code blocks
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
return processed;
}
/**
* Render math expressions with KaTeX after HTML is generated
*/
function renderMath(html: string): string {
// Render display math ($$...$$)
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
try {
return katex.renderToString(math.trim(), {
displayMode: true,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$$${math}$$</span>`;
}
});
// Render inline math ($...$) but avoid matching currency like $5
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
// Skip if it looks like currency ($ followed by number)
if (/^\d/.test(math.trim())) {
return match;
}
try {
return katex.renderToString(math.trim(), {
displayMode: false,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$${math}$</span>`;
}
});
return html;
}
function processMarkdown(text: string): string {
try {
// Preprocess LaTeX notation
const preprocessed = preprocessLaTeX(text);
// Parse markdown
let html = marked.parse(preprocessed) as string;
// Render math expressions
html = renderMath(html);
return html;
} catch (error) {
console.error('Markdown processing error:', error);
return text.replace(/\n/g, '<br>');
}
}
async function handleCopyClick(event: Event) {
const target = event.currentTarget as HTMLButtonElement;
const encodedCode = target.getAttribute('data-code');
if (!encodedCode) return;
const code = decodeURIComponent(encodedCode);
try {
await navigator.clipboard.writeText(code);
// Show copied feedback
const originalHtml = target.innerHTML;
target.innerHTML = `
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M20 6L9 17l-5-5"/>
</svg>
`;
target.classList.add('copied');
setTimeout(() => {
target.innerHTML = originalHtml;
target.classList.remove('copied');
}, 2000);
} catch (error) {
console.error('Failed to copy:', error);
}
}
function setupCopyButtons() {
if (!containerRef || !browser) return;
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of buttons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleCopyClick);
}
}
}
$effect(() => {
if (content) {
processedHtml = processMarkdown(content);
} else {
processedHtml = '';
}
});
$effect(() => {
if (containerRef && processedHtml) {
setupCopyButtons();
}
});
</script>
<div bind:this={containerRef} class="markdown-content {className}">
{@html processedHtml}
</div>
<style>
.markdown-content {
line-height: 1.6;
}
/* Paragraphs */
.markdown-content :global(p) {
margin-bottom: 1rem;
}
.markdown-content :global(p:last-child) {
margin-bottom: 0;
}
/* Headers */
.markdown-content :global(h1) {
font-size: 1.5rem;
font-weight: 700;
margin: 1.5rem 0 0.75rem 0;
color: var(--exo-yellow, #ffd700);
}
.markdown-content :global(h2) {
font-size: 1.25rem;
font-weight: 600;
margin: 1.25rem 0 0.5rem 0;
color: var(--exo-yellow, #ffd700);
}
.markdown-content :global(h3) {
font-size: 1.125rem;
font-weight: 600;
margin: 1rem 0 0.5rem 0;
}
.markdown-content :global(h4),
.markdown-content :global(h5),
.markdown-content :global(h6) {
font-size: 1rem;
font-weight: 600;
margin: 0.75rem 0 0.25rem 0;
}
/* Bold and italic */
.markdown-content :global(strong) {
font-weight: 600;
}
.markdown-content :global(em) {
font-style: italic;
}
/* Inline code */
.markdown-content :global(.inline-code) {
background: rgba(255, 215, 0, 0.1);
color: var(--exo-yellow, #ffd700);
padding: 0.125rem 0.375rem;
border-radius: 0.25rem;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
}
/* Links */
.markdown-content :global(a) {
color: var(--exo-yellow, #ffd700);
text-decoration: underline;
text-underline-offset: 2px;
}
.markdown-content :global(a:hover) {
opacity: 0.8;
}
/* Lists */
.markdown-content :global(ul) {
list-style-type: disc;
margin-left: 1.5rem;
margin-bottom: 1rem;
}
.markdown-content :global(ol) {
list-style-type: decimal;
margin-left: 1.5rem;
margin-bottom: 1rem;
}
.markdown-content :global(li) {
margin-bottom: 0.25rem;
}
.markdown-content :global(li::marker) {
color: var(--exo-light-gray, #9ca3af);
}
/* Blockquotes */
.markdown-content :global(blockquote) {
border-left: 3px solid var(--exo-yellow, #ffd700);
padding: 0.5rem 1rem;
margin: 1rem 0;
background: rgba(255, 215, 0, 0.05);
border-radius: 0 0.25rem 0.25rem 0;
}
/* Tables */
.markdown-content :global(table) {
width: 100%;
margin: 1rem 0;
border-collapse: collapse;
font-size: 0.875rem;
}
.markdown-content :global(th) {
background: rgba(255, 215, 0, 0.1);
border: 1px solid rgba(255, 215, 0, 0.2);
padding: 0.5rem;
text-align: left;
font-weight: 600;
}
.markdown-content :global(td) {
border: 1px solid rgba(255, 255, 255, 0.1);
padding: 0.5rem;
}
/* Horizontal rule */
.markdown-content :global(hr) {
border: none;
border-top: 1px solid rgba(255, 255, 255, 0.1);
margin: 1.5rem 0;
}
/* Code block wrapper */
.markdown-content :global(.code-block-wrapper) {
margin: 1rem 0;
border-radius: 0.5rem;
overflow: hidden;
border: 1px solid rgba(255, 215, 0, 0.2);
background: rgba(0, 0, 0, 0.4);
}
.markdown-content :global(.code-block-header) {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.5rem 0.75rem;
background: rgba(255, 215, 0, 0.05);
border-bottom: 1px solid rgba(255, 215, 0, 0.1);
}
.markdown-content :global(.code-language) {
color: var(--exo-yellow, #ffd700);
font-size: 0.7rem;
font-weight: 500;
text-transform: uppercase;
letter-spacing: 0.1em;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
}
.markdown-content :global(.copy-code-btn) {
display: flex;
align-items: center;
justify-content: center;
padding: 0.25rem;
background: transparent;
border: none;
color: var(--exo-light-gray, #9ca3af);
cursor: pointer;
transition: color 0.2s;
border-radius: 0.25rem;
}
.markdown-content :global(.copy-code-btn:hover) {
color: var(--exo-yellow, #ffd700);
}
.markdown-content :global(.copy-code-btn.copied) {
color: #22c55e;
}
.markdown-content :global(.code-block-wrapper pre) {
margin: 0;
padding: 1rem;
overflow-x: auto;
background: transparent;
}
.markdown-content :global(.code-block-wrapper code) {
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.8125rem;
line-height: 1.5;
background: transparent;
}
/* Syntax highlighting - dark theme matching EXO style */
.markdown-content :global(.hljs) {
color: #e5e7eb;
}
.markdown-content :global(.hljs-keyword),
.markdown-content :global(.hljs-selector-tag),
.markdown-content :global(.hljs-literal),
.markdown-content :global(.hljs-section),
.markdown-content :global(.hljs-link) {
color: #c084fc;
}
.markdown-content :global(.hljs-string),
.markdown-content :global(.hljs-title),
.markdown-content :global(.hljs-name),
.markdown-content :global(.hljs-type),
.markdown-content :global(.hljs-attribute),
.markdown-content :global(.hljs-symbol),
.markdown-content :global(.hljs-bullet),
.markdown-content :global(.hljs-addition),
.markdown-content :global(.hljs-variable),
.markdown-content :global(.hljs-template-tag),
.markdown-content :global(.hljs-template-variable) {
color: #fbbf24;
}
.markdown-content :global(.hljs-comment),
.markdown-content :global(.hljs-quote),
.markdown-content :global(.hljs-deletion),
.markdown-content :global(.hljs-meta) {
color: #6b7280;
}
.markdown-content :global(.hljs-number),
.markdown-content :global(.hljs-regexp),
.markdown-content :global(.hljs-literal),
.markdown-content :global(.hljs-built_in) {
color: #34d399;
}
.markdown-content :global(.hljs-function),
.markdown-content :global(.hljs-class .hljs-title) {
color: #60a5fa;
}
/* KaTeX math styling */
.markdown-content :global(.katex) {
font-size: 1.1em;
}
.markdown-content :global(.katex-display) {
margin: 1rem 0;
overflow-x: auto;
overflow-y: hidden;
padding: 0.5rem 0;
}
.markdown-content :global(.katex-display > .katex) {
text-align: center;
}
.markdown-content :global(.math-error) {
color: #f87171;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
background: rgba(248, 113, 113, 0.1);
padding: 0.125rem 0.25rem;
border-radius: 0.25rem;
}
</style>

View File

@@ -1,5 +1,6 @@
<script lang="ts">
import type { DownloadProgress, NodeInfo, PlacementPreview } from '$lib/stores/app.svelte';
import type { DownloadProgress, NodeInfo, PlacementPreview, TopologyEdge } from '$lib/stores/app.svelte';
import { debugMode, topologyData } from '$lib/stores/app.svelte';
interface Props {
model: { id: string; name?: string; storage_size_megabytes?: number };
@@ -206,12 +207,8 @@ function toggleNodeDetails(nodeId: string): void {
const centerY = topoHeight / 2;
const radius = numNodes === 1 ? 0 : numNodes === 2 ? 45 : Math.min(topoWidth, topoHeight) * 0.32;
// Use API preview data if available
// Only use API preview data - no local estimation
const hasApiPreview = apiPreview !== null && apiPreview.error === null && apiPreview.memory_delta_by_node !== null;
const canFit = hasApiPreview ? true : (() => {
const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
return totalAvailable >= estimatedMemory;
})();
const error = apiPreview?.error ?? null;
let placementNodes: Array<{
@@ -232,135 +229,140 @@ function toggleNodeDetails(nodeId: string): void {
modelFillHeight: number;
}> = [];
if (hasApiPreview && apiPreview.memory_delta_by_node) {
// Use API placement data
const memoryDelta = apiPreview.memory_delta_by_node;
placementNodes = nodeArray.map((n, i) => {
const deltaBytes = memoryDelta[n.id] ?? 0;
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
const isUsed = deltaBytes > 0;
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
const screenHeight = iconSize * 0.58;
return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB,
currentPercent,
newPercent,
isUsed,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
};
});
} else if (apiPreview?.error) {
// API returned an error - model can't fit, show all nodes as unused
placementNodes = nodeArray.map((n, i) => {
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const screenHeight = iconSize * 0.58;
return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB: 0,
currentPercent,
newPercent: currentPercent,
isUsed: false,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: 0
};
});
} else {
// Fallback: local estimation based on sharding strategy
const memoryNeeded = estimatedMemory;
// Use API placement data directly
const memoryDelta = apiPreview?.memory_delta_by_node ?? {};
placementNodes = nodeArray.map((n, i) => {
const deltaBytes = memoryDelta[n.id] ?? 0;
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
const isUsed = deltaBytes > 0;
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
const screenHeight = iconSize * 0.58;
if (sharding === 'Pipeline') {
const memoryPerNode = memoryNeeded / numNodes;
placementNodes = nodeArray.map((n, i) => {
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const newPercent = clampPercent(((n.usedGB + memoryPerNode) / safeTotal) * 100);
const screenHeight = iconSize * 0.58;
return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB: memoryPerNode,
currentPercent,
newPercent,
isUsed: true,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
};
});
} else {
let remaining = memoryNeeded;
placementNodes = nodeArray.map((n, i) => {
const allocated = Math.min(remaining, n.availableGB);
remaining -= allocated;
const isUsed = allocated > 0;
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
const safeTotal = Math.max(n.totalGB, 0.001);
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
const newPercent = clampPercent(((n.usedGB + allocated) / safeTotal) * 100);
const screenHeight = iconSize * 0.58;
return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB: allocated,
currentPercent,
newPercent,
isUsed,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
};
});
}
}
return {
id: n.id,
deviceName: n.deviceName,
deviceType: n.deviceType,
totalGB: n.totalGB,
currentUsedGB: n.usedGB,
modelUsageGB,
currentPercent,
newPercent,
isUsed,
x: centerX + Math.cos(angle) * radius,
y: centerY + Math.sin(angle) * radius,
iconSize,
screenHeight,
currentFillHeight: screenHeight * (currentPercent / 100),
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
};
});
const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
return { nodes: placementNodes, canFit: hasApiPreview || canFit, totalAvailable, topoWidth, topoHeight, error };
return { nodes: placementNodes, canFit: hasApiPreview, totalAvailable, topoWidth, topoHeight, error };
});
const canFit = $derived(apiPreview ? apiPreview.error === null : placementPreview().canFit);
const placementError = $derived(apiPreview?.error ?? null);
const nodeCount = $derived(nodeList().length);
const filterId = $derived(model.id.replace(/[^a-zA-Z0-9]/g, ''));
// Debug mode state
const isDebugMode = $derived(debugMode());
const topology = $derived(topologyData());
const isRdma = $derived(runtime === 'MlxIbv' || runtime === 'MlxJaccl');
// Get interface name for an IP from node data
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
if (!ip || !topology?.nodes) return null;
// Strip port if present
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
// Check specified node first
const node = topology.nodes[nodeId];
if (node) {
const match = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
if (match?.name) return match.name;
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped) return mapped;
}
// Fallback: check all nodes
for (const [, otherNode] of Object.entries(topology.nodes)) {
if (!otherNode) continue;
const match = otherNode.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
if (match?.name) return match.name;
const mapped = otherNode.ip_to_interface?.[cleanIp] || otherNode.ip_to_interface?.[ip];
if (mapped) return mapped;
}
return null;
}
// Get directional arrow based on node positions
function getArrow(fromNode: { x: number; y: number }, toNode: { x: number; y: number }): string {
const dx = toNode.x - fromNode.x;
const dy = toNode.y - fromNode.y;
const absX = Math.abs(dx);
const absY = Math.abs(dy);
if (absX > absY * 2) {
return dx > 0 ? '→' : '←';
} else if (absY > absX * 2) {
return dy > 0 ? '↓' : '↑';
} else {
if (dx > 0 && dy > 0) return '↘';
if (dx > 0 && dy < 0) return '↗';
if (dx < 0 && dy > 0) return '↙';
return '↖';
}
}
// Get connection info for edges between two nodes
// Returns exactly one connection per direction (A→B and B→A), preferring non-loopback
function getConnectionInfo(nodeId1: string, nodeId2: string): Array<{ ip: string; iface: string | null; from: string; to: string }> {
if (!topology?.edges) return [];
// Collect candidates for each direction
const aToBCandidates: Array<{ ip: string; iface: string | null }> = [];
const bToACandidates: Array<{ ip: string; iface: string | null }> = [];
for (const edge of topology.edges) {
const ip = edge.sendBackIp || '?';
const iface = edge.sendBackInterface || getInterfaceForIp(edge.source, ip);
if (edge.source === nodeId1 && edge.target === nodeId2) {
aToBCandidates.push({ ip, iface });
} else if (edge.source === nodeId2 && edge.target === nodeId1) {
bToACandidates.push({ ip, iface });
}
}
// Pick best (prefer non-loopback)
const pickBest = (candidates: Array<{ ip: string; iface: string | null }>) => {
if (candidates.length === 0) return null;
return candidates.find(c => !c.ip.startsWith('127.')) || candidates[0];
};
const result: Array<{ ip: string; iface: string | null; from: string; to: string }> = [];
const bestAtoB = pickBest(aToBCandidates);
if (bestAtoB) result.push({ ...bestAtoB, from: nodeId1, to: nodeId2 });
const bestBtoA = pickBest(bToACandidates);
if (bestBtoA) result.push({ ...bestBtoA, from: nodeId2, to: nodeId1 });
return result;
}
</script>
<div class="relative group">
@@ -453,6 +455,26 @@ function toggleNodeDetails(nodeId: string): void {
<!-- Connection lines between nodes (if multiple) -->
{#if preview.nodes.length > 1}
{@const usedNodes = preview.nodes.filter(n => n.isUsed)}
{@const nodePositions = Object.fromEntries(preview.nodes.map(n => [n.id, { x: n.x, y: n.y }]))}
{@const allConnections = isDebugMode && usedNodes.length > 1 ? (() => {
const conns: Array<{ ip: string; iface: string | null; from: string; to: string; midX: number; midY: number; arrow: string }> = [];
for (let i = 0; i < usedNodes.length; i++) {
for (let j = i + 1; j < usedNodes.length; j++) {
const n1 = usedNodes[i];
const n2 = usedNodes[j];
const midX = (n1.x + n2.x) / 2;
const midY = (n1.y + n2.y) / 2;
for (const c of getConnectionInfo(n1.id, n2.id)) {
const fromPos = nodePositions[c.from];
const toPos = nodePositions[c.to];
const arrow = fromPos && toPos ? getArrow(fromPos, toPos) : '→';
conns.push({ ...c, midX, midY, arrow });
}
}
}
return conns;
})() : []}
{#each preview.nodes as node, i}
{#each preview.nodes.slice(i + 1) as node2}
<line
@@ -464,6 +486,43 @@ function toggleNodeDetails(nodeId: string): void {
/>
{/each}
{/each}
<!-- Debug: Show connection IPs/interfaces in corners -->
{#if isDebugMode && allConnections.length > 0}
{@const centerX = preview.topoWidth / 2}
{@const centerY = preview.topoHeight / 2}
{@const quadrants = {
topLeft: allConnections.filter(c => c.midX < centerX && c.midY < centerY),
topRight: allConnections.filter(c => c.midX >= centerX && c.midY < centerY),
bottomLeft: allConnections.filter(c => c.midX < centerX && c.midY >= centerY),
bottomRight: allConnections.filter(c => c.midX >= centerX && c.midY >= centerY)
}}
{@const padding = 4}
{@const lineHeight = 8}
<!-- Top Left -->
{#each quadrants.topLeft as conn, idx}
<text x={padding} y={padding + idx * lineHeight} text-anchor="start" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
</text>
{/each}
<!-- Top Right -->
{#each quadrants.topRight as conn, idx}
<text x={preview.topoWidth - padding} y={padding + idx * lineHeight} text-anchor="end" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
</text>
{/each}
<!-- Bottom Left -->
{#each quadrants.bottomLeft as conn, idx}
<text x={padding} y={preview.topoHeight - padding - (quadrants.bottomLeft.length - 1 - idx) * lineHeight} text-anchor="start" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
</text>
{/each}
<!-- Bottom Right -->
{#each quadrants.bottomRight as conn, idx}
<text x={preview.topoWidth - padding} y={preview.topoHeight - padding - (quadrants.bottomRight.length - 1 - idx) * lineHeight} text-anchor="end" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
</text>
{/each}
{/if}
{/if}
{#each preview.nodes as node}

View File

@@ -24,19 +24,36 @@ function getNodeLabel(nodeId: string): string {
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
if (!ip) return { label: '?', missing: true };
const node = data?.nodes?.[nodeId];
if (!node) return { label: '?', missing: true };
// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
// Helper to check a node's interfaces
function checkNode(node: typeof data.nodes[string]): string | null {
if (!node) return null;
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
if (matchFromInterfaces?.name) {
return matchFromInterfaces.name;
}
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === ip)
);
if (matchFromInterfaces?.name) {
return { label: matchFromInterfaces.name, missing: false };
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return mapped;
}
return null;
}
const mapped = node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return { label: mapped, missing: false };
// Try specified node first
const result = checkNode(data?.nodes?.[nodeId]);
if (result) return { label: result, missing: false };
// Fallback: search all nodes for this IP
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
const otherResult = checkNode(otherNode);
if (otherResult) return { label: otherResult, missing: false };
}
return { label: '?', missing: true };
@@ -67,6 +84,7 @@ function wrapLine(text: string, maxLen: number): string[] {
return lines;
}
// Apple logo path for MacBook Pro screen
const APPLE_LOGO_PATH = "M788.1 340.9c-5.8 4.5-108.2 62.2-108.2 190.5 0 148.4 130.3 200.9 134.2 202.2-.6 3.2-20.7 71.9-68.7 141.9-42.8 61.6-87.5 123.1-155.5 123.1s-85.5-39.5-164-39.5c-76.5 0-103.7 40.8-165.9 40.8s-105.6-57-155.5-127C46.7 790.7 0 663 0 541.8c0-194.4 126.4-297.5 250.8-297.5 66.1 0 121.2 43.4 162.7 43.4 39.5 0 101.1-46 176.3-46 28.5 0 130.9 2.6 198.3 99.2zm-234-181.5c31.1-36.9 53.1-88.1 53.1-139.3 0-7.1-.6-14.3-1.9-20.1-50.6 1.9-110.8 33.7-147.1 75.8-28.5 32.4-55.1 83.6-55.1 135.5 0 7.8 1.3 15.6 1.9 18.1 3.2.6 8.4 1.3 13.6 1.3 45.4 0 102.5-30.4 135.5-71.3z";
const LOGO_NATIVE_WIDTH = 814;
@@ -238,6 +256,7 @@ function wrapLine(text: string, maxLen: number): string[] {
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
edges.forEach(edge => {
if (!edge.source || !edge.target || edge.source === edge.target) return;
if (!positionById[edge.source] || !positionById[edge.target]) return;
@@ -314,110 +333,98 @@ function wrapLine(text: string, maxLen: number): string[] {
.attr('marker-end', 'url(#arrowhead)');
}
// Collect debug labels for later positioning at edges
if (debugEnabled && entry.connections.length > 0) {
const maxBoxes = 6;
const fontSize = isMinimized ? 8 : 9;
const lineGap = 2;
const labelOffsetOut = Math.max(140, minDimension * 0.38);
const labelOffsetSide = isMinimized ? 16 : 20;
const boxWidth = 170;
const maxLineLen = 26;
const connections = entry.connections.slice(0, maxBoxes);
if (entry.connections.length > maxBoxes) {
const remaining = entry.connections.length - maxBoxes;
connections.push({
from: '',
to: '',
ip: `(+${remaining} more)`,
ifaceLabel: '',
missingIface: false
});
}
let dirX = mx - centerX;
let dirY = my - centerY;
const dirLen = Math.hypot(dirX, dirY);
if (dirLen < 1) {
dirX = -uy;
dirY = ux;
} else {
dirX /= dirLen;
dirY /= dirLen;
}
const nx = -dirY;
const ny = dirX;
const labelXRaw = mx + dirX * labelOffsetOut + nx * labelOffsetSide;
const labelYRaw = my + dirY * labelOffsetOut + ny * labelOffsetSide;
const clampPad = Math.min(120, minDimension * 0.12);
const labelX = Math.max(clampPad, Math.min(width - clampPad, labelXRaw));
const labelY = Math.max(clampPad, Math.min(height - clampPad, labelYRaw));
const labelGroup = debugLabelsGroup.append('g')
.attr('transform', `translate(${labelX}, ${labelY})`);
const textGroup = labelGroup.append('g');
connections.forEach((conn, idx) => {
const rawLines = conn.from && conn.to
? [
`${getNodeLabel(conn.from)}${getNodeLabel(conn.to)}`,
`${conn.ip}`,
`${conn.ifaceLabel}`
]
: [conn.ip];
const wrapped = rawLines.flatMap(line => wrapLine(line, maxLineLen));
wrapped.forEach((line, lineIdx) => {
textGroup.append('text')
.attr('x', 0)
.attr('y', (idx * (wrapped.length * (fontSize + lineGap))) + lineIdx * (fontSize + lineGap))
.attr('text-anchor', 'middle')
.attr('dominant-baseline', 'hanging')
.attr('font-size', fontSize)
.attr('font-family', 'SF Mono, monospace')
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.9)')
.text(line);
});
// Determine which side of viewport based on edge midpoint
const isLeft = mx < centerX;
const isTop = my < safeCenterY;
// Store for batch rendering after all edges processed
if (!debugEdgeLabels) debugEdgeLabels = [];
debugEdgeLabels.push({
connections: entry.connections,
isLeft,
isTop,
mx,
my
});
const bbox = textGroup.node()?.getBBox();
if (bbox) {
const paddedWidth = Math.max(boxWidth, bbox.width + 14);
const boxHeight = bbox.height + 8;
const boxMinX = labelX - paddedWidth / 2;
const boxMaxX = labelX + paddedWidth / 2;
const boxMinY = labelY + bbox.y - 4;
const boxMaxY = boxMinY + boxHeight;
const clampPadDynamic = Math.min(140, minDimension * 0.18);
let shiftX = 0;
let shiftY = 0;
if (boxMinX < clampPadDynamic) shiftX = clampPadDynamic - boxMinX;
if (boxMaxX > width - clampPadDynamic) shiftX = (width - clampPadDynamic) - boxMaxX;
if (boxMinY < clampPadDynamic) shiftY = clampPadDynamic - boxMinY;
if (boxMaxY > height - clampPadDynamic) shiftY = (height - clampPadDynamic) - boxMaxY;
const finalX = labelX + shiftX;
const finalY = labelY + shiftY;
labelGroup.attr('transform', `translate(${finalX}, ${finalY})`);
labelGroup.insert('rect', 'g')
.attr('x', -paddedWidth / 2)
.attr('y', bbox.y - 4)
.attr('width', paddedWidth)
.attr('height', boxHeight)
.attr('rx', 4)
.attr('fill', 'rgba(0,0,0,0.75)')
.attr('stroke', 'rgba(255,255,255,0.12)')
.attr('stroke-width', 0.6);
}
}
});
// Render debug labels at viewport edges/corners
if (debugEdgeLabels && debugEdgeLabels.length > 0) {
const fontSize = isMinimized ? 10 : 12;
const lineHeight = fontSize + 4;
const padding = 10;
// Helper to get arrow based on direction vector
function getArrow(fromId: string, toId: string): string {
const fromPos = positionById[fromId];
const toPos = positionById[toId];
if (!fromPos || !toPos) return '→';
const dirX = toPos.x - fromPos.x;
const dirY = toPos.y - fromPos.y;
const absX = Math.abs(dirX);
const absY = Math.abs(dirY);
if (absX > absY * 2) {
return dirX > 0 ? '→' : '←';
} else if (absY > absX * 2) {
return dirY > 0 ? '↓' : '↑';
} else {
if (dirX > 0 && dirY > 0) return '↘';
if (dirX > 0 && dirY < 0) return '↗';
if (dirX < 0 && dirY > 0) return '↙';
return '↖';
}
}
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
const quadrants: Record<string, typeof debugEdgeLabels> = {
topLeft: [],
topRight: [],
bottomLeft: [],
bottomRight: []
};
debugEdgeLabels.forEach(edge => {
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
quadrants[key].push(edge);
});
// Render each quadrant
Object.entries(quadrants).forEach(([quadrant, edges]) => {
if (edges.length === 0) return;
const isLeft = quadrant.includes('Left');
const isTop = quadrant.includes('top');
let baseX = isLeft ? padding : width - padding;
let baseY = isTop ? padding : height - padding;
const textAnchor = isLeft ? 'start' : 'end';
let currentY = baseY;
edges.forEach(edge => {
edge.connections.forEach(conn => {
const arrow = getArrow(conn.from, conn.to);
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;
debugLabelsGroup.append('text')
.attr('x', baseX)
.attr('y', currentY)
.attr('text-anchor', textAnchor)
.attr('dominant-baseline', isTop ? 'hanging' : 'auto')
.attr('font-size', fontSize)
.attr('font-family', 'SF Mono, monospace')
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.85)')
.text(label);
currentY += isTop ? lineHeight : -lineHeight;
});
});
});
}
// Draw nodes
const nodesGroup = svg.append('g').attr('class', 'nodes-group');
@@ -968,4 +975,5 @@ function wrapLine(text: string, maxLen: number): string[] {
from { stroke-dashoffset: 0; }
to { stroke-dashoffset: -10; }
}
</style>

View File

@@ -4,4 +4,5 @@ export { default as ChatMessages } from './ChatMessages.svelte';
export { default as ChatAttachments } from './ChatAttachments.svelte';
export { default as ChatSidebar } from './ChatSidebar.svelte';
export { default as ModelCard } from './ModelCard.svelte';
export { default as MarkdownContent } from './MarkdownContent.svelte';

View File

@@ -327,6 +327,8 @@ class AppStore {
isTopologyMinimized = $state(false);
isSidebarOpen = $state(false); // Hidden by default, shown when in chat mode
debugMode = $state(false);
topologyOnlyMode = $state(false);
chatSidebarVisible = $state(true); // Shown by default
private fetchInterval: ReturnType<typeof setInterval> | null = null;
private previewsInterval: ReturnType<typeof setInterval> | null = null;
@@ -337,6 +339,8 @@ class AppStore {
this.startPolling();
this.loadConversationsFromStorage();
this.loadDebugModeFromStorage();
this.loadTopologyOnlyModeFromStorage();
this.loadChatSidebarVisibleFromStorage();
}
}
@@ -394,6 +398,44 @@ class AppStore {
}
}
private loadTopologyOnlyModeFromStorage() {
try {
const stored = localStorage.getItem('exo-topology-only-mode');
if (stored !== null) {
this.topologyOnlyMode = stored === 'true';
}
} catch (error) {
console.error('Failed to load topology only mode:', error);
}
}
private saveTopologyOnlyModeToStorage() {
try {
localStorage.setItem('exo-topology-only-mode', this.topologyOnlyMode ? 'true' : 'false');
} catch (error) {
console.error('Failed to save topology only mode:', error);
}
}
private loadChatSidebarVisibleFromStorage() {
try {
const stored = localStorage.getItem('exo-chat-sidebar-visible');
if (stored !== null) {
this.chatSidebarVisible = stored === 'true';
}
} catch (error) {
console.error('Failed to load chat sidebar visibility:', error);
}
}
private saveChatSidebarVisibleToStorage() {
try {
localStorage.setItem('exo-chat-sidebar-visible', this.chatSidebarVisible ? 'true' : 'false');
} catch (error) {
console.error('Failed to save chat sidebar visibility:', error);
}
}
/**
* Create a new conversation
*/
@@ -698,6 +740,34 @@ class AppStore {
this.saveDebugModeToStorage();
}
getTopologyOnlyMode(): boolean {
return this.topologyOnlyMode;
}
setTopologyOnlyMode(enabled: boolean) {
this.topologyOnlyMode = enabled;
this.saveTopologyOnlyModeToStorage();
}
toggleTopologyOnlyMode() {
this.topologyOnlyMode = !this.topologyOnlyMode;
this.saveTopologyOnlyModeToStorage();
}
getChatSidebarVisible(): boolean {
return this.chatSidebarVisible;
}
setChatSidebarVisible(visible: boolean) {
this.chatSidebarVisible = visible;
this.saveChatSidebarVisibleToStorage();
}
toggleChatSidebarVisible() {
this.chatSidebarVisible = !this.chatSidebarVisible;
this.saveChatSidebarVisibleToStorage();
}
startPolling() {
this.fetchState();
this.fetchInterval = setInterval(() => this.fetchState(), 1000);
@@ -888,8 +958,6 @@ class AppStore {
if (lastUserIndex === -1) return;
const lastUserMessage = this.messages[lastUserIndex];
// Remove any messages after the user message
this.messages = this.messages.slice(0, lastUserIndex + 1);
@@ -930,7 +998,10 @@ class AppStore {
}
if (!modelToUse) {
assistantMessage.content = 'Error: No model available. Please launch an instance first.';
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = 'Error: No model available. Please launch an instance first.';
}
this.isLoading = false;
this.updateActiveConversation();
return;
@@ -948,7 +1019,10 @@ class AppStore {
if (!response.ok) {
const errorText = await response.text();
assistantMessage.content = `Error: ${response.status} - ${errorText}`;
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = `Error: ${response.status} - ${errorText}`;
}
this.isLoading = false;
this.updateActiveConversation();
return;
@@ -956,7 +1030,10 @@ class AppStore {
const reader = response.body?.getReader();
if (!reader) {
assistantMessage.content = 'Error: No response stream available';
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = 'Error: No response stream available';
}
this.isLoading = false;
this.updateActiveConversation();
return;
@@ -984,9 +1061,16 @@ class AppStore {
const delta = json.choices?.[0]?.delta?.content;
if (delta) {
fullContent += delta;
const { displayContent } = this.stripThinkingTags(fullContent);
const { displayContent, thinkingContent } = this.stripThinkingTags(fullContent);
this.currentResponse = displayContent;
assistantMessage.content = displayContent;
// Update the assistant message in place (triggers Svelte reactivity)
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = displayContent;
this.messages[idx].thinking = thinkingContent || undefined;
}
this.persistActiveConversation();
}
} catch {
// Skip malformed JSON
@@ -995,16 +1079,25 @@ class AppStore {
}
}
const { displayContent } = this.stripThinkingTags(fullContent);
assistantMessage.content = displayContent;
this.currentResponse = '';
this.updateActiveConversation();
// Final cleanup of the message
const { displayContent, thinkingContent } = this.stripThinkingTags(fullContent);
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = displayContent;
this.messages[idx].thinking = thinkingContent || undefined;
}
this.persistActiveConversation();
} catch (error) {
assistantMessage.content = `Error: ${error instanceof Error ? error.message : 'Unknown error'}`;
this.updateActiveConversation();
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = `Error: ${error instanceof Error ? error.message : 'Unknown error'}`;
}
this.persistActiveConversation();
} finally {
this.isLoading = false;
this.currentResponse = '';
this.updateActiveConversation();
}
}
@@ -1364,6 +1457,8 @@ export const lastUpdate = () => appStore.lastUpdate;
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
export const selectedChatModel = () => appStore.selectedChatModel;
export const debugMode = () => appStore.getDebugMode();
export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
// Actions
export const startChat = () => appStore.startChat();
@@ -1391,5 +1486,9 @@ export const isSidebarOpen = () => appStore.isSidebarOpen;
export const toggleSidebar = () => appStore.toggleSidebar();
export const toggleDebugMode = () => appStore.toggleDebugMode();
export const setDebugMode = (enabled: boolean) => appStore.setDebugMode(enabled);
export const toggleTopologyOnlyMode = () => appStore.toggleTopologyOnlyMode();
export const setTopologyOnlyMode = (enabled: boolean) => appStore.setTopologyOnlyMode(enabled);
export const toggleChatSidebarVisible = () => appStore.toggleChatSidebarVisible();
export const setChatSidebarVisible = (visible: boolean) => appStore.setChatSidebarVisible(visible);
export const refreshState = () => appStore.fetchState();

View File

@@ -18,6 +18,10 @@
selectedChatModel,
debugMode,
toggleDebugMode,
topologyOnlyMode,
toggleTopologyOnlyMode,
chatSidebarVisible,
toggleChatSidebarVisible,
type DownloadProgress,
type PlacementPreview
} from '$lib/stores/app.svelte';
@@ -37,6 +41,8 @@
const selectedModelId = $derived(selectedPreviewModelId());
const loadingPreviews = $derived(isLoadingPreviews());
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
@@ -45,6 +51,59 @@ const debugEnabled = $derived(debugMode());
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
// Launch defaults persistence
const LAUNCH_DEFAULTS_KEY = 'exo-launch-defaults';
interface LaunchDefaults {
modelId: string | null;
sharding: 'Pipeline' | 'Tensor';
instanceType: InstanceMeta;
minNodes: number;
}
function saveLaunchDefaults(): void {
const defaults: LaunchDefaults = {
modelId: selectedPreviewModelId(),
sharding: selectedSharding,
instanceType: selectedInstanceType,
minNodes: selectedMinNodes,
};
try {
localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
} catch (e) {
console.warn('Failed to save launch defaults:', e);
}
}
function loadLaunchDefaults(): LaunchDefaults | null {
try {
const stored = localStorage.getItem(LAUNCH_DEFAULTS_KEY);
if (!stored) return null;
return JSON.parse(stored) as LaunchDefaults;
} catch (e) {
console.warn('Failed to load launch defaults:', e);
return null;
}
}
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
const defaults = loadLaunchDefaults();
if (!defaults) return;
// Apply sharding and instance type unconditionally
selectedSharding = defaults.sharding;
selectedInstanceType = defaults.instanceType;
// Apply minNodes if valid (between 1 and maxNodes)
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
selectedMinNodes = defaults.minNodes;
}
// Only apply model if it exists in the available models
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
selectPreviewModel(defaults.modelId);
}
}
let selectedInstanceType = $state<InstanceMeta>('MlxRing');
let selectedMinNodes = $state<number>(1);
let minNodesInitialized = $state(false);
@@ -292,6 +351,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const data = await response.json();
// API returns { data: [{ id, name }] } format
models = data.data || [];
// Restore last launch defaults if available
const currentNodeCount = topologyData() ? Object.keys(topologyData()!.nodes).length : 1;
applyLaunchDefaults(models, currentNodeCount);
}
} catch (error) {
console.error('Failed to fetch models:', error);
@@ -472,6 +534,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const progress = parseDownloadProgress(downloadPayload);
if (progress) {
// Sum all values across nodes - each node downloads independently
totalBytes += progress.totalBytes;
downloadedBytes += progress.downloadedBytes;
totalSpeed += progress.speed;
@@ -489,13 +552,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
return { isDownloading: false, progress: null, perNode: [] };
}
// ETA = total remaining bytes / total speed across all nodes
const remainingBytes = totalBytes - downloadedBytes;
const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;
return {
isDownloading: true,
progress: {
totalBytes,
downloadedBytes,
speed: totalSpeed,
etaMs: totalSpeed > 0 ? ((totalBytes - downloadedBytes) / totalSpeed) * 1000 : 0,
etaMs,
percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
completedFiles,
totalFiles,
@@ -576,6 +643,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const progress = parseDownloadProgress(downloadPayload);
if (progress) {
// Sum all values across nodes - each node downloads independently
totalBytes += progress.totalBytes;
downloadedBytes += progress.downloadedBytes;
totalSpeed += progress.speed;
@@ -596,13 +664,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
return { isDownloading: false, progress: null, statusText: statusInfo.statusText, perNode: [] };
}
// ETA = total remaining bytes / total speed across all nodes
const remainingBytes = totalBytes - downloadedBytes;
const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;
return {
isDownloading: true,
progress: {
totalBytes,
downloadedBytes,
speed: totalSpeed,
etaMs: totalSpeed > 0 ? ((totalBytes - downloadedBytes) / totalSpeed) * 1000 : 0,
etaMs,
percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
completedFiles,
totalFiles,
@@ -618,10 +690,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function getStatusColor(statusText: string): string {
switch (statusText) {
case 'FAILED': return 'text-red-400';
case 'SHUTDOWN': return 'text-gray-400';
case 'DOWNLOADING': return 'text-blue-400';
case 'LOADING':
case 'WARMING UP':
case 'WAITING': return 'text-yellow-400';
case 'WAITING':
case 'INITIALIZING': return 'text-yellow-400';
case 'RUNNING': return 'text-teal-400';
case 'READY':
case 'LOADED': return 'text-green-400';
@@ -644,12 +718,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!r) return null;
const [kind] = getTagged(r);
const statusMap: Record<string, string> = {
RunnerWaitingForInitialization: 'WaitingForInitialization',
RunnerInitializingBackend: 'InitializingBackend',
RunnerWaitingForModel: 'WaitingForModel',
RunnerLoading: 'Loading',
RunnerLoaded: 'Loaded',
RunnerWarmingUp: 'WarmingUp',
RunnerReady: 'Ready',
RunnerRunning: 'Running',
RunnerShutdown: 'Shutdown',
RunnerFailed: 'Failed',
};
return kind ? statusMap[kind] || null : null;
@@ -660,12 +737,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (statuses.length === 0) return { statusText: 'UNKNOWN', statusClass: 'inactive' };
if (has('Failed')) return { statusText: 'FAILED', statusClass: 'failed' };
if (has('Shutdown')) return { statusText: 'SHUTDOWN', statusClass: 'inactive' };
if (has('Loading')) return { statusText: 'LOADING', statusClass: 'starting' };
if (has('WarmingUp')) return { statusText: 'WARMING UP', statusClass: 'starting' };
if (has('Running')) return { statusText: 'RUNNING', statusClass: 'running' };
if (has('Ready')) return { statusText: 'READY', statusClass: 'loaded' };
if (has('Loaded')) return { statusText: 'LOADED', statusClass: 'loaded' };
if (has('WaitingForModel')) return { statusText: 'WAITING', statusClass: 'starting' };
if (has('InitializingBackend')) return { statusText: 'INITIALIZING', statusClass: 'starting' };
if (has('WaitingForInitialization')) return { statusText: 'INITIALIZING', statusClass: 'starting' };
return { statusText: 'RUNNING', statusClass: 'active' };
}
@@ -964,6 +1044,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function handleSliderMouseUp() {
isDraggingSlider = false;
saveLaunchDefaults();
}
// Handle touch events for mobile
@@ -983,6 +1064,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function handleSliderTouchEnd() {
isDraggingSlider = false;
saveLaunchDefaults();
}
const nodeCount = $derived(data ? Object.keys(data.nodes).length : 0);
@@ -1107,16 +1189,47 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="shooting-star" style="top: 50%; left: 40%; --duration: 45s; --delay: 30s;"></div>
</div>
<HeaderNav showHome={chatStarted} onHome={handleGoHome} />
{#if !topologyOnlyEnabled}
<HeaderNav
showHome={chatStarted}
onHome={handleGoHome}
showSidebarToggle={true}
sidebarVisible={sidebarVisible}
onToggleSidebar={toggleChatSidebarVisible}
/>
{/if}
<!-- Main Content -->
<main class="flex-1 flex overflow-hidden relative">
<!-- Left: Conversation History Sidebar (always visible) -->
<!-- Left: Conversation History Sidebar (hidden in topology-only mode or when toggled off) -->
{#if !topologyOnlyEnabled && sidebarVisible}
<div class="w-80 flex-shrink-0 border-r border-exo-yellow/10">
<ChatSidebar class="h-full" />
</div>
{/if}
{#if !chatStarted}
{#if topologyOnlyEnabled}
<!-- TOPOLOGY ONLY MODE: Full-screen topology -->
<div class="flex-1 flex flex-col min-h-0 min-w-0 p-4" in:fade={{ duration: 300 }}>
<div class="flex-1 relative bg-exo-dark-gray/40 rounded-lg overflow-hidden">
<TopologyGraph class="w-full h-full" highlightedNodes={highlightedNodes()} />
<!-- Exit topology-only mode button -->
<button
type="button"
onclick={toggleTopologyOnlyMode}
class="absolute bottom-4 right-4 p-2 rounded border border-exo-yellow/30 bg-exo-dark-gray/80 hover:border-exo-yellow/50 hover:bg-exo-dark-gray transition-colors cursor-pointer backdrop-blur-sm"
title="Exit topology only mode"
>
<svg class="w-5 h-5 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<circle cx="12" cy="5" r="2" fill="currentColor" />
<circle cx="5" cy="19" r="2" fill="currentColor" />
<circle cx="19" cy="19" r="2" fill="currentColor" />
<path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
</svg>
</button>
</div>
</div>
{:else if !chatStarted}
<!-- WELCOME STATE: Topology + Instance Controls (no left sidebar for cleaner look) -->
<div class="flex-1 flex overflow-visible relative" in:fade={{ duration: 300 }} out:fade={{ duration: 200 }}>
@@ -1300,14 +1413,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{:else}
{#each nodeProg.progress.files as f}
{@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
{@const isFileComplete = filePercent >= 100}
<div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
<div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
<span class="truncate pr-2">{f.name}</span>
<span class="text-white/80">{filePercent.toFixed(1)}%</span>
<span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
</div>
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
style="width: {filePercent.toFixed(1)}%"
></div>
</div>
@@ -1408,6 +1522,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
onclick={() => {
if (modelCanFit) {
selectPreviewModel(model.id);
saveLaunchDefaults();
isModelDropdownOpen = false;
modelDropdownSearch = '';
}
@@ -1441,7 +1556,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-xs text-white/70 font-mono mb-2">Sharding:</div>
<div class="flex gap-2">
<button
onclick={() => selectedSharding = 'Pipeline'}
onclick={() => { selectedSharding = 'Pipeline'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Pipeline' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Pipeline' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1452,7 +1567,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
Pipeline
</button>
<button
onclick={() => selectedSharding = 'Tensor'}
onclick={() => { selectedSharding = 'Tensor'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Tensor' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Tensor' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1470,7 +1585,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-xs text-white/70 font-mono mb-2">Instance Type:</div>
<div class="flex gap-2">
<button
onclick={() => selectedInstanceType = 'MlxRing'}
onclick={() => { selectedInstanceType = 'MlxRing'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxRing' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxRing' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1481,7 +1596,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
MLX Ring
</button>
<button
onclick={() => selectedInstanceType = 'MlxIbv'}
onclick={() => { selectedInstanceType = 'MlxIbv'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxIbv' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxIbv' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1611,13 +1726,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
in:fade={{ duration: 300, delay: 100 }}
>
<div class="flex-1 overflow-y-auto px-8 py-6" bind:this={chatScrollRef}>
<div class="max-w-3xl mx-auto">
<div class="max-w-7xl mx-auto">
<ChatMessages scrollParent={chatScrollRef} />
</div>
</div>
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-3xl mx-auto">
<div class="max-w-7xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} />
</div>
</div>
@@ -1655,7 +1770,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<!-- Panel Header -->
<div class="flex items-center gap-2 mb-4">
<div class="w-2 h-2 bg-exo-yellow rounded-full shadow-[0_0_8px_rgba(255,215,0,0.6)] animate-pulse"></div>
<h3 class="text-sm text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
<h3 class="text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
</div>
<div class="space-y-3 max-h-72 overflow-y-auto pr-1">
@@ -1701,28 +1816,28 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex justify-between items-start mb-2 pl-2">
<div class="flex items-center gap-2">
<div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
<span class="text-exo-light-gray font-mono text-xs tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
<span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
</div>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400/80 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<div class="pl-2">
<div class="text-exo-yellow text-sm font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
<a
class="inline-flex items-center gap-1 text-[10px] text-white/60 hover:text-exo-yellow transition-colors mt-0.5"
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
href={`https://huggingface.co/${instanceModelId}`}
target="_blank"
rel="noreferrer noopener"
aria-label="View model on Hugging Face"
>
<span>Hugging Face</span>
<svg class="w-3 h-3" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<svg class="w-3.5 h-3.5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M14 3h7v7"/>
<path d="M10 14l11-11"/>
<path d="M21 14v6a1 1 0 0 1-1 1h-16a1 1 0 0 1-1-1v-16a1 1 0 0 1 1-1h6"/>
@@ -1733,68 +1848,84 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-white/60 text-xs font-mono">{instanceInfo.nodeNames.join(', ')}</div>
{/if}
{#if debugEnabled && instanceConnections.length > 0}
<div class="mt-1 space-y-0.5">
{#each instanceConnections as conn}
<div class="text-[10px] leading-snug font-mono text-white/70">
<span>{conn.from} -> {conn.to}: {conn.ip}</span>
<span class="{conn.missingIface ? 'text-red-400' : 'text-white/60'}"> ({conn.ifaceLabel})</span>
</div>
{/each}
<div class="mt-2 space-y-1">
{#each instanceConnections as conn}
<div class="text-[11px] leading-snug font-mono text-white/70">
<span>{conn.from} -> {conn.to}: {conn.ip}</span>
<span class="{conn.missingIface ? 'text-red-400' : 'text-white/60'}"> ({conn.ifaceLabel})</span>
</div>
{/each}
</div>
{/if}
<!-- Download Progress -->
{#if downloadInfo.isDownloading && downloadInfo.progress}
<div class="mt-2 space-y-1">
<div class="flex justify-between text-xs font-mono">
<span class="text-blue-400">{downloadInfo.progress.percentage.toFixed(1)}%</span>
<span class="text-exo-light-gray">{formatBytes(downloadInfo.progress.downloadedBytes)}/{formatBytes(downloadInfo.progress.totalBytes)}</span>
</div>
{/if}
<!-- Download Progress -->
{#if downloadInfo.isDownloading && downloadInfo.progress}
<div class="mt-2 space-y-1">
<div class="flex justify-between text-sm font-mono">
<span class="text-blue-400">{downloadInfo.progress.percentage.toFixed(1)}%</span>
<span class="text-exo-light-gray">{formatBytes(downloadInfo.progress.downloadedBytes)}/{formatBytes(downloadInfo.progress.totalBytes)}</span>
</div>
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden">
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
style="width: {downloadInfo.progress.percentage}%"
></div>
</div>
<div class="flex justify-between text-xs font-mono text-exo-light-gray">
<span>{formatSpeed(downloadInfo.progress.speed)}</span>
<span>ETA: {formatEta(downloadInfo.progress.etaMs)}</span>
<span>{downloadInfo.progress.completedFiles}/{downloadInfo.progress.totalFiles} files</span>
</div>
<div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
style="width: {downloadInfo.progress.percentage}%"
></div>
</div>
{#if downloadInfo.perNode.length > 0}
<div class="mt-2 space-y-1.5 max-h-48 overflow-y-auto pr-1">
{#each downloadInfo.perNode as nodeProg}
<div class="rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2">
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray mb-1">
<div class="flex justify-between text-xs font-mono text-exo-light-gray">
<span>{formatSpeed(downloadInfo.progress.speed)}</span>
<span>ETA: {formatEta(downloadInfo.progress.etaMs)}</span>
<span>{downloadInfo.progress.completedFiles}/{downloadInfo.progress.totalFiles} files</span>
</div>
</div>
{#if downloadInfo.perNode.length > 0}
<div class="mt-2 space-y-2 max-h-48 overflow-y-auto pr-1">
{#each downloadInfo.perNode as nodeProg}
{@const nodePercent = Math.min(100, Math.max(0, nodeProg.progress.percentage))}
{@const isExpanded = instanceDownloadExpandedNodes.has(nodeProg.nodeId)}
<div class="rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2">
<button
type="button"
class="w-full text-left space-y-1.5"
onclick={() => toggleInstanceDownloadDetails(nodeProg.nodeId)}
>
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
<span class="text-white/80 truncate pr-2">{nodeProg.nodeName}</span>
<span class="text-blue-300">{Math.min(100, Math.max(0, nodeProg.progress.percentage)).toFixed(1)}%</span>
<span class="flex items-center gap-1 text-blue-300">
{nodePercent.toFixed(1)}%
<svg class="w-3 h-3 text-exo-light-gray" viewBox="0 0 20 20" fill="none" stroke="currentColor" stroke-width="2">
<path d="M6 8l4 4 4-4" class={isExpanded ? 'transform rotate-180 origin-center transition-transform duration-150' : 'transition-transform duration-150'}></path>
</svg>
</span>
</div>
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mb-1.5">
<div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
<div
class="absolute inset-y-0 left-0 bg-blue-500/80 transition-all duration-300"
style="width: {Math.min(100, Math.max(0, nodeProg.progress.percentage)).toFixed(1)}%"
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
style="width: {nodePercent.toFixed(1)}%"
></div>
</div>
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray mb-1">
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
<span>{formatBytes(nodeProg.progress.downloadedBytes)} / {formatBytes(nodeProg.progress.totalBytes)}</span>
<span>{formatSpeed(nodeProg.progress.speed)} • ETA {formatEta(nodeProg.progress.etaMs)}</span>
</div>
{#if nodeProg.progress.files.length > 0}
{@const inProgressFiles = nodeProg.progress.files.filter(f => (f.percentage ?? 0) < 100)}
{@const completedFiles = nodeProg.progress.files.filter(f => (f.percentage ?? 0) >= 100)}
{#if inProgressFiles.length > 0}
<div class="space-y-1">
{#each inProgressFiles as f}
<div class="text-[10px] font-mono text-exo-light-gray/80">
<div class="flex items-center justify-between">
</button>
{#if isExpanded}
<div class="mt-2 space-y-1.5">
{#if nodeProg.progress.files.length === 0}
<div class="text-[11px] font-mono text-exo-light-gray/70">No file details reported.</div>
{:else}
{#each nodeProg.progress.files as f}
{@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
{@const isFileComplete = filePercent >= 100}
<div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
<div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
<span class="truncate pr-2">{f.name}</span>
<span class="text-white/70">{Math.min(100, Math.max(0, f.percentage)).toFixed(1)}%</span>
<span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
</div>
<div class="relative h-1 bg-exo-black/50 rounded-sm overflow-hidden mt-0.5">
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70"
style="width: {Math.min(100, Math.max(0, f.percentage)).toFixed(1)}%"
class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
style="width: {filePercent.toFixed(1)}%"
></div>
</div>
<div class="flex items-center justify-between text-[10px] text-exo-light-gray/70 mt-0.5">
@@ -1803,27 +1934,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
</div>
</div>
{/each}
</div>
{/if}
{#if completedFiles.length > 0}
<div class="pt-1 space-y-0.5">
{#each completedFiles as f}
<div class="text-[10px] font-mono text-exo-light-gray/70 flex items-center justify-between">
<span class="truncate pr-2">{f.name}</span>
<span class="text-white/60">100%</span>
</div>
{/each}
</div>
{/if}
{/if}
</div>
{/if}
</div>
{/each}
</div>
{/if}
<div class="text-sm text-blue-400 font-mono tracking-wider mt-1">DOWNLOADING</div>
{:else}
<div class="text-sm {getStatusColor(downloadInfo.statusText)} font-mono tracking-wider mt-1">{downloadInfo.statusText}</div>
</div>
{/each}
</div>
{/if}
<div class="text-xs text-blue-400 font-mono tracking-wider mt-1">DOWNLOADING</div>
{:else}
<div class="text-xs {getStatusColor(downloadInfo.statusText)} font-mono tracking-wider mt-1">{downloadInfo.statusText}</div>
{/if}
</div>
</div>
</div>

View File

@@ -345,13 +345,19 @@
<div class="rounded border border-exo-medium-gray/30 bg-exo-dark-gray/60 p-3 space-y-2">
<div class="flex items-center justify-between gap-3">
<div class="min-w-0 space-y-0.5">
<div class="text-sm font-mono text-white truncate">{model.prettyName ?? model.modelId}</div>
<div class="text-[11px] text-exo-light-gray font-mono truncate">
{model.modelId}
</div>
<div class="text-[11px] text-exo-light-gray font-mono">
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
</div>
<div
class="text-xs font-mono text-white truncate"
title={model.prettyName ?? model.modelId}
>{model.prettyName ?? model.modelId}</div>
<div
class="text-[10px] text-exo-light-gray font-mono truncate"
title={model.modelId}
>{model.modelId}</div>
{#if model.status !== 'completed'}
<div class="text-[11px] text-exo-light-gray font-mono">
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
</div>
{/if}
</div>
<div class="flex items-center gap-2">
<span class="text-xs font-mono {pct >= 100 ? 'text-green-400' : pct <= 0 ? 'text-red-400' : 'text-exo-yellow'}">
@@ -426,14 +432,14 @@
<style>
.downloads-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(260px, 1fr));
grid-template-columns: repeat(auto-fill, minmax(320px, 1fr));
}
@media (min-width: 1024px) {
.downloads-grid {
grid-template-columns: repeat(3, minmax(0, 1fr));
}
}
@media (min-width: 1440px) {
@media (min-width: 1600px) {
.downloads-grid {
grid-template-columns: repeat(4, minmax(0, 1fr));
}

View File

@@ -29,10 +29,12 @@ dependencies = [
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"bidict>=0.23.1",
"mlx>=0.30.1",
"mlx>=0.30.1; sys_platform == 'darwin'",
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
"mlx-lm>=0.28.3",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
]
[project.scripts]
@@ -80,7 +82,7 @@ build-backend = "uv_build"
###
[tool.basedpyright]
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src"]
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src", "bench"]
typeCheckingMode = "strict"
failOnWarnings = true

View File

@@ -1,5 +1,6 @@
import argparse
import multiprocessing as mp
import os
import signal
from dataclasses import dataclass, field
from typing import Self
@@ -27,7 +28,7 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
worker: Worker
worker: Worker | None
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
election_result_receiver: Receiver[ElectionResult]
master: Master | None
@@ -61,15 +62,19 @@ class Node:
else:
api = None
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
)
if not args.no_worker:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
)
else:
worker = None
# We start every node with a master
master = Master(
node_id,
@@ -99,8 +104,9 @@ class Node:
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.worker.run)
tg.start_soon(self.election.run)
if self.worker:
tg.start_soon(self.worker.run)
if self.master:
tg.start_soon(self.master.run)
if self.api:
@@ -194,6 +200,7 @@ def main():
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
node = anyio.run(Node.create, args)
anyio.run(node.run)
@@ -207,6 +214,7 @@ class Args(CamelCaseModel):
spawn_api: bool = False
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
@classmethod
def parse(cls) -> Self:
@@ -244,6 +252,10 @@ class Args(CamelCaseModel):
dest="api_port",
default=52415,
)
parser.add_argument(
"--no-worker",
action="store_true",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -13,6 +13,12 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
@@ -21,11 +27,16 @@ from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
FinishReason,
GenerationStats,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -56,7 +67,7 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
HIDE_THINKING = False
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def chunk_to_response(
@@ -161,7 +172,10 @@ class API:
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/v1/chat/completions")(self.chat_completions)
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -177,17 +191,32 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=command.model_meta,
)
async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
command = CreateInstance(instance=payload.instance)
instance = payload.instance
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
available_memory = self._calculate_total_available_memory()
if required_memory > available_memory:
raise HTTPException(
status_code=400,
detail=f"Insufficient memory to create instance. Required: {required_memory.in_gb:.1f}GB, Available: {available_memory.in_gb:.1f}GB",
)
command = CreateInstance(
instance=instance,
)
await self._send(command)
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=model_meta,
)
async def get_placement(
@@ -352,32 +381,52 @@ class API:
instance_id=instance_id,
)
async def _generate_chat_stream(
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
try:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
is_thinking = False
with recv as token_chunks:
async for chunk in token_chunks:
if HIDE_THINKING:
if chunk.text == "<think>":
is_thinking = True
if chunk.text == "</think>":
is_thinking = False
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
if not (is_thinking and HIDE_THINKING):
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
break
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -392,6 +441,98 @@ class API:
await self._send(command)
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId, parse_gpt_oss: bool
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
),
finish_reason=finish_reason,
)
],
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId, parse_gpt_oss: bool
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
text_parts.append(chunk.text)
stats = chunk.stats or stats
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = "".join(text_parts)
assert model is not None
resp = BenchChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant", content=combined_text
),
finish_reason=finish_reason,
)
],
generation_stats=stats,
)
return resp
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
@@ -399,10 +540,12 @@ class API:
async def chat_completions(
self, payload: ChatCompletionTaskParams
) -> StreamingResponse:
"""Handle chat completions with proper streaming response."""
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
if not any(
instance.shard_assignments.model_id == payload.model
@@ -417,10 +560,40 @@ class API:
request_params=payload,
)
await self._send(command)
return StreamingResponse(
self._generate_chat_stream(command.command_id),
media_type="text/event-stream",
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id, parse_gpt_oss),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
payload.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == payload.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(payload.model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {payload.model}"
)
payload.stream = False
command = ChatCompletion(request_params=payload)
await self._send(command)
response = await self._collect_chat_completion_with_stats(
command.command_id,
parse_gpt_oss,
)
return response
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
@@ -442,6 +615,8 @@ class API:
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
)
for card in MODEL_CARDS.values()
]
@@ -458,7 +633,7 @@ class API:
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._applystate)
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
print_startup_banner(self.port)
await serve(
@@ -470,7 +645,7 @@ class API:
self.command_sender.close()
self.global_event_receiver.close()
async def _applystate(self):
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.origin != self.session_id.master_node_id:

View File

@@ -7,9 +7,9 @@ from loguru import logger
from exo.master.placement_utils import (
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
)
@@ -19,9 +19,9 @@ from exo.shared.types.commands import (
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import Host
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
@@ -30,6 +30,7 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.shards import Sharding
def random_ephemeral_port() -> int:
@@ -66,6 +67,28 @@ def place_instance(
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_meta.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_meta.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
)
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
smallest_tb_cycles = [
@@ -130,17 +153,17 @@ def place_instance(
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
ephemeral_port = random_ephemeral_port()
hosts_by_node = get_mlx_ring_hosts_by_node(
selected_cycle=selected_cycle,
cycle_digraph=cycle_digraph,
ephemeral_port=ephemeral_port,
)
target_instances[instance_id] = MlxRingInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
hosts=[
Host(
ip=host.ip,
port=random_ephemeral_port(),
)
for host in hosts
],
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
)
return target_instances

View File

@@ -215,9 +215,11 @@ def get_mlx_ibv_devices_matrix(
continue
# Find the IP J uses to talk to I
for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
if interface_name := _find_rdma_interface_name_for_ip(
connection_ip, node_i
):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
@@ -238,17 +240,17 @@ def _find_connection_ip(
node_i: NodeInfo,
node_j: NodeInfo,
cycle_digraph: Topology,
) -> Generator[str]:
"""Find all IP addresses that connect node i to node j."""
) -> Generator[tuple[str, bool]]:
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
def _find_interface_name_for_ip(
def _find_rdma_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
@@ -269,6 +271,109 @@ def _find_interface_name_for_ip(
return None
def _find_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
if node_info.node_profile is None:
return None
for interface in node_info.node_profile.network_interfaces:
if interface.ip_address == ip_address:
return interface.name
return None
def _find_ip_prioritised(
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.
Priority order:
1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
3. Non-Thunderbolt connections
4. Any other IP address
"""
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}
en0_ip = iface_map.get("en0")
if en0_ip:
return en0_ip
en1_ip = iface_map.get("en1")
if en1_ip:
return en1_ip
non_thunderbolt_ip = next(
(ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
)
if non_thunderbolt_ip:
return non_thunderbolt_ip
if ips:
return ips[0][0]
return None
def get_mlx_ring_hosts_by_node(
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
ephemeral_port: int,
) -> dict[NodeId, list[Host]]:
"""Generate per-node host lists for MLX ring backend.
Each node gets a list where:
- Self position: Host(ip="0.0.0.0", port=ephemeral_port)
- Left/right neighbors: actual connection IPs
- Non-neighbors: Host(ip="198.51.100.1", port=0) placeholder (RFC 5737 TEST-NET-2)
"""
world_size = len(selected_cycle)
if world_size == 0:
return {}
hosts_by_node: dict[NodeId, list[Host]] = {}
for rank, node in enumerate(selected_cycle):
node_id = node.node_id
left_rank = (rank - 1) % world_size
right_rank = (rank + 1) % world_size
hosts_for_node: list[Host] = []
for idx, other_node in enumerate(selected_cycle):
if idx == rank:
hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
continue
if idx not in {left_rank, right_rank}:
# Placeholder IP from RFC 5737 TEST-NET-2
hosts_for_node.append(Host(ip="198.51.100.1", port=0))
continue
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
)
hosts_for_node.append(Host(ip=connection_ip, port=ephemeral_port))
hosts_by_node[node_id] = hosts_for_node
return hosts_by_node
def get_mlx_jaccl_coordinators(
selected_cycle: list[NodeInfo],
coordinator_port: int,
@@ -280,13 +385,14 @@ def get_mlx_jaccl_coordinators(
address in format "X.X.X.X:PORT" per node.
"""
rank_0_node = selected_cycle[0]
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
return ip
logger.warning(

View File

@@ -123,6 +123,8 @@ async def test_master():
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
supports_tensor=True,
),
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
@@ -163,32 +165,38 @@ async def test_master():
assert events[2].idx == 2
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
runner_id = list(
events[1].event.instance.shard_assignments.runner_to_shard.keys()
)[0]
assert events[1].event.instance == MlxRingInstance(
instance_id=events[1].event.instance.instance_id,
shard_assignments=ShardAssignments(
model_id=ModelId("llama-3.2-1b"),
runner_to_shard={
(runner_id): PipelineShardMetadata(
start_layer=0,
end_layer=16,
created_instance = events[1].event.instance
assert isinstance(created_instance, MlxRingInstance)
runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
# Validate the shard assignments
expected_shard_assignments = ShardAssignments(
model_id=ModelId("llama-3.2-1b"),
runner_to_shard={
(runner_id): PipelineShardMetadata(
start_layer=0,
end_layer=16,
n_layers=16,
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
),
device_rank=0,
world_size=1,
)
},
node_to_runner={node_id: runner_id},
),
hosts=[],
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
supports_tensor=True,
),
device_rank=0,
world_size=1,
)
},
node_to_runner={node_id: runner_id},
)
assert created_instance.shard_assignments == expected_shard_assignments
# For single-node, hosts_by_node should have one entry with self-binding
assert len(created_instance.hosts_by_node) == 1
assert node_id in created_instance.hosts_by_node
assert len(created_instance.hosts_by_node[node_id]) == 1
assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
assert created_instance.ephemeral_port > 0
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, ChatCompletionTask)

View File

@@ -38,7 +38,8 @@ def instance() -> Instance:
shard_assignments=ShardAssignments(
model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
),
hosts=[],
hosts_by_node={},
ephemeral_port=50000,
)
@@ -49,6 +50,8 @@ def model_meta() -> ModelMetadata:
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=30,
supports_tensor=True,
)
@@ -92,9 +95,13 @@ def test_get_instance_placements_create_instance(
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
# Add bidirectional connections for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
# act
placements = place_instance(cic, topology, {})
@@ -135,6 +142,8 @@ def test_get_instance_placements_one_node_exact_fit(
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {})
@@ -160,6 +169,8 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {})
@@ -185,6 +196,8 @@ def test_get_instance_placements_one_node_not_fit(
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
@@ -234,17 +247,15 @@ def test_get_transition_events_delete_instance(instance: Instance):
assert events[0].instance_id == instance_id
def test_placement_prioritizes_leaf_cycle_with_less_memory(
def test_placement_selects_cycle_with_most_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
# neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
# non-leaves. Ensure both cycles have sufficient total memory, with the A-B-C
# cycle having LESS total memory than D-E-F. The algorithm should still choose
# the cycle that contains a leaf node.
# Arrange two 3-node cycles with different total memory.
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
# The algorithm should select the cycle with the most available memory.
# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
@@ -258,11 +269,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
node_id_e = NodeId()
node_id_f = NodeId()
# Extra sink nodes to make D/E/F non-leaf via additional outgoing edges
node_id_x = NodeId()
node_id_y = NodeId()
node_id_z = NodeId()
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
@@ -273,24 +279,20 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
# Extra nodes with tiny memory so they can't form singleton placements
topology.add_node(create_node(10, node_id_x))
topology.add_node(create_node(10, node_id_y))
topology.add_node(create_node(10, node_id_z))
# Build directed cycles
# Build bidirectional cycles for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_d))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_e))
topology.add_connection(create_connection(node_id_f, node_id_d))
# Add extra outgoing edges from D/E/F so none of them are leaves
topology.add_connection(create_connection(node_id_d, node_id_x))
topology.add_connection(create_connection(node_id_e, node_id_y))
topology.add_connection(create_connection(node_id_f, node_id_z))
topology.add_connection(create_connection(node_id_d, node_id_f))
cic = place_instance_command(
model_meta=model_meta,
@@ -299,18 +301,17 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
# Act
placements = place_instance(cic, topology, {})
# Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
# D-E-F has more total memory.
# Assert: D-E-F cycle should be selected as it has more total memory
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
expected_leaf_cycle_nodes = {node_id_a, node_id_b, node_id_c}
non_leaf_cycle_nodes = {node_id_d, node_id_e, node_id_f}
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert expected_leaf_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(non_leaf_cycle_nodes)
assert more_memory_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
def test_tensor_rdma_backend_connectivity_matrix(

View File

@@ -198,6 +198,8 @@ def test_get_shard_assignments(
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]

View File

@@ -51,6 +51,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"deepseek-v3.1-8bit": ModelCard(
@@ -64,6 +66,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# "deepseek-v3.2": ModelCard(
@@ -135,6 +139,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"kimi-k2-thinking": ModelCard(
@@ -148,6 +154,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# llama-3.1
@@ -162,6 +170,38 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-70b": ModelCard(
@@ -175,6 +215,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# llama-3.2
@@ -189,6 +231,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
),
"llama-3.2-3b": ModelCard(
@@ -202,6 +246,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
"llama-3.2-3b-8bit": ModelCard(
@@ -215,6 +261,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
# llama-3.3
@@ -229,6 +277,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-8bit": ModelCard(
@@ -242,6 +292,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-fp16": ModelCard(
@@ -255,20 +307,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
),
),
# phi-3
"phi-3-mini": ModelCard(
short_id="phi-3-mini",
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
name="Phi 3 Mini 128k (4-bit)",
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
pretty_name="Phi 3 Mini 128k (4-bit)",
storage_size=Memory.from_mb(2099),
n_layers=32,
hidden_size=8192,
supports_tensor=True,
),
),
# qwen3
@@ -283,6 +323,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-0.6b-8bit": ModelCard(
@@ -296,6 +338,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-30b": ModelCard(
@@ -309,6 +353,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-30b-8bit": ModelCard(
@@ -322,6 +368,68 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-235b-a22b-4bit": ModelCard(
@@ -335,6 +443,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-235b-a22b-8bit": ModelCard(
@@ -348,6 +458,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
@@ -361,6 +473,8 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
@@ -374,77 +488,84 @@ MODEL_CARDS: dict[str, ModelCard] = {
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
# granite
"granite-3.3-2b": ModelCard(
short_id="granite-3.3-2b",
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
name="Granite 3.3 2B (FP16)",
description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
pretty_name="Granite 3.3 2B (FP16)",
storage_size=Memory.from_mb(4951),
n_layers=40,
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
),
# "granite-3.3-8b": ModelCard(
# short_id="granite-3.3-8b",
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
# name="Granite 3.3 8B",
# description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# Needs to be quantized g32 or g16.
"glm-4.5-air-8bit": ModelCard(
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
),
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
# short_id="devstral-2-123b-instruct-2512-8bit",
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
# pretty_name="Granite 3.3 8B",
# storage_size=Memory.from_kb(15958720),
# n_layers=40,
# ),
# ),
# smol-lm
# "smol-lm-135m": ModelCard(
# short_id="smol-lm-135m",
# model_id="mlx-community/SmolLM-135M-4bit",
# name="Smol LM 135M",
# description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
# pretty_name="Smol LM 135M",
# storage_size=Memory.from_kb(73940),
# n_layers=30,
# ),
# ),
# gpt-oss
# "gpt-oss-120b-MXFP4-Q8": ModelCard(
# short_id="gpt-oss-120b-MXFP4-Q8",
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
# name="GPT-OSS 120B (MXFP4-Q8, MLX)",
# description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
# pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
# storage_size=Memory.from_kb(68_996_301),
# n_layers=36,
# hidden_size=2880,
# supports_tensor=True,
# ),
# ),
# "gpt-oss-20b-4bit": ModelCard(
# short_id="gpt-oss-20b-4bit",
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
# name="GPT-OSS 20B (MXFP4-Q4, MLX)",
# description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
# pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
# storage_size=Memory.from_kb(11_744_051),
# n_layers=24,
# hidden_size=2880,
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# storage_size=Memory.from_kb(133_000_000),
# n_layers=88,
# hidden_size=12288,
# supports_tensor=True,
# ),
# ),

View File

@@ -6,6 +6,7 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
@@ -25,6 +26,7 @@ class ConfigData(BaseModel):
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
@@ -106,10 +108,19 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_id,
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)

View File

@@ -36,6 +36,8 @@ def get_pipeline_shard_metadata(
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=1000,
supports_tensor=True,
),
device_rank=device_rank,
world_size=world_size,

View File

@@ -19,7 +19,7 @@ def test_apply_node_download_progress():
NodeDownloadProgress(download_progress=event), state
)
assert new_state == State(downloads={NodeId("node-1"): [event]})
assert new_state.downloads == {NodeId("node-1"): [event]}
def test_apply_two_node_download_progress():
@@ -42,4 +42,4 @@ def test_apply_two_node_download_progress():
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state == State(downloads={NodeId("node-1"): [event1, event2]})
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

View File

@@ -5,7 +5,8 @@ from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.types.common import CommandId
from exo.shared.types.models import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
@@ -51,6 +52,10 @@ class ChatCompletionMessage(BaseModel):
function_call: dict[str, Any] | None = None
class BenchChatCompletionMessage(ChatCompletionMessage):
pass
class TopLogprobItem(BaseModel):
token: str
logprob: float
@@ -113,6 +118,18 @@ class ChatCompletionResponse(BaseModel):
service_tier: str | None = None
class GenerationStats(BaseModel):
prompt_tps: float
generation_tps: float
prompt_tokens: int
generation_tokens: int
peak_memory_usage: Memory
class BenchChatCompletionResponse(ChatCompletionResponse):
generation_stats: GenerationStats | None = None
class ChatCompletionTaskParams(BaseModel):
model: str
frequency_penalty: float | None = None
@@ -135,6 +152,10 @@ class ChatCompletionTaskParams(BaseModel):
user: str | None = None
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
pass
class PlaceInstanceParams(BaseModel):
model_id: str
sharding: Sharding = Sharding.Pipeline
@@ -174,6 +195,7 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_meta: ModelMetadata
class DeleteInstanceResponse(BaseModel):

View File

@@ -1,5 +1,6 @@
from enum import Enum
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -20,6 +21,7 @@ class TokenChunk(BaseChunk):
text: str
token_id: int
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
class ImageChunk(BaseChunk):

View File

@@ -14,3 +14,5 @@ class ModelMetadata(CamelCaseModel):
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool

View File

@@ -40,6 +40,10 @@ class LoadModel(BaseTask): # emitted by Worker
pass
class ConnectToGroup(BaseTask): # emitted by Worker
pass
class StartWarmup(BaseTask): # emitted by Worker
pass
@@ -57,5 +61,11 @@ class Shutdown(BaseTask): # emitted by Worker
Task = (
CreateRunner | DownloadModel | LoadModel | StartWarmup | ChatCompletion | Shutdown
CreateRunner
| DownloadModel
| ConnectToGroup
| LoadModel
| StartWarmup
| ChatCompletion
| Shutdown
)

View File

@@ -25,7 +25,8 @@ class BaseInstance(TaggedModel):
class MlxRingInstance(BaseInstance):
hosts: list[Host]
hosts_by_node: dict[NodeId, list[Host]]
ephemeral_port: int
class MlxJacclInstance(BaseInstance):

View File

@@ -1,4 +1,4 @@
from exo.shared.types.api import FinishReason
from exo.shared.types.api import FinishReason, GenerationStats
from exo.utils.pydantic_ext import TaggedModel
@@ -15,6 +15,7 @@ class GenerationResponse(BaseRunnerResponse):
token: int
# logprobs: list[float] | None = None # too big. we can change to be top-k
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
class FinishedResponse(BaseRunnerResponse):

View File

@@ -21,7 +21,15 @@ class BaseRunnerStatus(TaggedModel):
return isinstance(self, RunnerRunning)
class RunnerWaitingForModel(BaseRunnerStatus):
class RunnerIdle(BaseRunnerStatus):
pass
class RunnerConnecting(BaseRunnerStatus):
pass
class RunnerConnected(BaseRunnerStatus):
pass
@@ -45,6 +53,10 @@ class RunnerRunning(BaseRunnerStatus):
pass
class RunnerShuttingDown(BaseRunnerStatus):
pass
class RunnerShutdown(BaseRunnerStatus):
pass
@@ -54,12 +66,15 @@ class RunnerFailed(BaseRunnerStatus):
RunnerStatus = (
RunnerWaitingForModel
RunnerIdle
| RunnerConnecting
| RunnerConnected
| RunnerLoading
| RunnerLoaded
| RunnerWarmingUp
| RunnerReady
| RunnerRunning
| RunnerShuttingDown
| RunnerShutdown
| RunnerFailed
)

View File

@@ -450,6 +450,11 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# TODO: 'Smart' downloads are disabled because:
# (i) We don't handle all kinds of files;
# (ii) We don't have sticky sessions.
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)

View File

@@ -95,7 +95,15 @@ def extract_layer_num(tensor_name: str) -> int | None:
def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list[str]:
default_patterns = set(
["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt", "*.jinja"]
[
"*.json",
"*.py",
"tokenizer.model",
"tiktoken.model",
"*.tiktoken",
"*.txt",
"*.jinja",
]
)
shard_specific_patterns: set[str] = set()
if weight_map:

View File

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from copy import copy
from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
@@ -12,7 +13,7 @@ from exo.shared.types.worker.shards import (
from exo.worker.download.download_utils import RepoDownloadProgress
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Shoudl this be a classmethod?
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
class ShardDownloader(ABC):
@abstractmethod
async def ensure_shard(
@@ -43,34 +44,7 @@ class ShardDownloader(ABC):
Yields:
tuple[Path, RepoDownloadProgress]: The path and progress of a shard download.
"""
yield (
Path("/tmp/noop_shard"),
RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
),
)
yield (Path("/tmp/noop_shard"), NOOP_DOWNLOAD_PROGRESS)
@abstractmethod
async def get_shard_download_status_for_shard(
@@ -94,46 +68,41 @@ class NoopShardDownloader(ShardDownloader):
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
yield (
Path("/tmp/noop_shard"),
RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
),
NOOP_DOWNLOAD_PROGRESS,
)
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata
) -> RepoDownloadProgress:
return RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=shard,
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
)
dp = copy(NOOP_DOWNLOAD_PROGRESS)
dp.shard = shard
return dp
NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,
supports_tensor=False,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
),
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
)

View File

@@ -9,8 +9,7 @@ MAX_KV_SIZE: int | None = 3200
KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = 8
TEMPERATURE: float = 1.0
KV_CACHE_BITS: int | None = None
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

View File

@@ -6,7 +6,13 @@ from mlx_lm.models.cache import KVCache
from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import ChatCompletionMessage, FinishReason
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
FinishReason,
GenerationStats,
)
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
@@ -72,7 +78,7 @@ def warmup_inference(
max_tokens=50,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=65536,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
@@ -80,17 +86,42 @@ def warmup_inference(
tokens_generated += 1
logger.info("Generated ALL warmup tokens")
# TODO: Do we want an mx_barrier?
# At least this version is actively incorrect, as it should use mx_barrier(group)
mx_barrier()
return tokens_generated
def ban_token_ids(token_ids: list[int]) -> Callable[[mx.array, mx.array], mx.array]:
token_ids = [int(t) for t in token_ids]
def proc(_history: mx.array, logits: mx.array) -> mx.array:
for tid in token_ids:
logits[..., tid] = -1e9
return logits
return proc
def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
eos: list[int] | None = getattr(tokenizer, "eos_token_ids", None)
if eos is None:
return []
return eos
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
task: ChatCompletionTaskParams,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
# Currently we support chat-completion tasks only.
logger.info(f"task_params: {task}")
@@ -101,6 +132,12 @@ def mlx_generate(
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
eos_ids = eos_ids_from_tokenizer(tokenizer)
logits_processors = [ban_token_ids(eos_ids)]
max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,
@@ -108,26 +145,40 @@ def mlx_generate(
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
prefill_step_size=65536,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
logger.info(out.text)
if out.finish_reason is not None and out.finish_reason not in get_args(
FinishReason
):
# We don't throw here as this failure case is really not all that bad
# Just log the error and move on
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
if out.finish_reason not in get_args(FinishReason):
# We don't throw here as this failure case is really not all that bad
# Just log the error and move on
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
yield GenerationResponse(
text=out.text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
stats=stats,
)
if out.finish_reason is not None:
break
# TODO: Do we want an mx_barrier?

View File

@@ -13,7 +13,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
TEMPERATURE,
TRUST_REMOTE_CODE,
)
@@ -21,6 +20,8 @@ try:
from mlx_lm.tokenizer_utils import load_tokenizer
except ImportError:
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
import contextlib
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.utils import load_model
@@ -48,6 +49,7 @@ from exo.worker.engines.mlx.auto_parallel import (
)
from exo.worker.runner.bootstrap import logger
Group = mx.distributed.Group
# Needed for 8 bit model
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
@@ -67,7 +69,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
)
def mx_barrier(group: mx.distributed.Group | None = None):
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
@@ -77,7 +79,7 @@ def mx_barrier(group: mx.distributed.Group | None = None):
)
def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None):
def broadcast_from_zero(value: int, group: Group | None = None):
if group is None:
return value
@@ -99,91 +101,97 @@ class HostList(RootModel[list[str]]):
def mlx_distributed_init(
bound_instance: BoundInstance,
) -> mx.distributed.Group:
) -> Group:
"""
Initialize the MLX distributed (runs in thread pool).
Either hosts or mlx_ibv_devices must be provided:
- hosts: traditional host-based connectivity using MLX_HOSTFILE
- mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
- strict: if True, raise an error if the distributed backend is not available
Initialize MLX distributed.
"""
rank = bound_instance.bound_shard.device_rank
logger.info(f"Starting initialization for rank {rank}")
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts=hosts):
hostfile = f"./hosts_{rank}.json"
hosts_json = HostList.from_hosts(hosts).model_dump_json()
coordination_file = None
try:
# TODO: singleton instances
match bound_instance.instance:
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
with open(hostfile, "w") as f:
_ = f.write(hosts_json)
with open(coordination_file, "w") as f:
_ = f.write(hosts_json)
logger.info(f"rank {rank} hostfile: {hostfile} hosts: {hosts_json}")
logger.info(
f"rank {rank} hostfile: {coordination_file} hosts: {hosts_json}"
)
os.environ["MLX_HOSTFILE"] = hostfile
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_RING_VERBOSE"] = "1"
group = mx.distributed.init(backend="ring", strict=True)
os.environ["MLX_HOSTFILE"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_RING_VERBOSE"] = "1"
group = mx.distributed.init(backend="ring", strict=True)
case MlxJacclInstance(
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
devices_file = f"./hosts_{rank}.json"
ibv_devices_json = json.dumps(ibv_devices)
case MlxJacclInstance(
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
ibv_devices_json = json.dumps(ibv_devices)
with open(devices_file, "w") as f:
_ = f.write(ibv_devices_json)
with open(coordination_file, "w") as f:
_ = f.write(ibv_devices_json)
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = devices_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"Rank {rank} mlx distributed initialization complete")
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group
return group
finally:
with contextlib.suppress(FileNotFoundError):
if coordination_file:
os.remove(coordination_file)
def initialize_mlx(
bound_instance: BoundInstance,
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
"""
Initialize the MLX model, tokenizer, and sampler. Runs in the MLX thread.
"""
) -> Group:
# should we unseed it?
# TODO: pass in seed from params
mx.random.seed(42)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
assert len(bound_instance.instance.shard_assignments.node_to_runner) > 1, (
"Tried to initialize mlx for a single node instance"
)
return mlx_distributed_init(bound_instance)
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE)
def load_mlx_items(
bound_instance: BoundInstance, group: Group | None
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
# TODO: pass temperature
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
logger.info("Created a sampler")
if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
pass
# model, config = quantize_model(
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
# )
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
group = mlx_distributed_init(bound_instance)
start_time = time.perf_counter()
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
end_time = time.perf_counter()
@@ -193,14 +201,12 @@ def initialize_mlx(
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
logger.debug(model)
return cast(Model, model), tokenizer, sampler
def shard_and_load(
shard_metadata: ShardMetadata,
group: mx.distributed.Group,
group: Group,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
@@ -389,11 +395,15 @@ def set_wired_limit_for_model(model_size: Memory):
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
)
kv_bytes = int(0.02 * model_bytes)
target_cache = int(1.10 * (model_bytes + kv_bytes))
target_cache = min(target_cache, max_rec_size)
mx.set_cache_limit(target_cache)
mx.set_wired_limit(max_rec_size)
logger.info(
f"Wired limit set to {max_rec_size}. Cache limit set to {target_cache}."
)
logger.info(f"Wired limit set to {max_rec_size}.")
def mlx_cleanup(
model: Model | None, tokenizer: TokenizerWrapper | None, group: Group | None
) -> None:
del model, tokenizer, group
mx.clear_cache()
import gc
gc.collect()

View File

@@ -23,6 +23,7 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
@@ -83,7 +84,7 @@ class Worker:
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup | None = None
@@ -128,6 +129,7 @@ class Worker:
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -200,11 +202,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard not in self.download_status:
if shard.model_meta.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -217,7 +219,7 @@ class Worker:
progress = DownloadCompleted(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -228,7 +230,7 @@ class Worker:
)
)
else:
self.event_sender.send_nowait(
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
@@ -349,7 +351,7 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata] = status
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
@@ -363,7 +365,7 @@ class Worker:
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
self.download_status[shard] = status
self.download_status[shard.model_meta.model_id] = status
# Footgun!
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
@@ -384,7 +386,7 @@ class Worker:
progress
),
)
self.download_status[shard] = status
self.download_status[shard.model_meta.model_id] = status
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
)
@@ -414,9 +416,14 @@ class Worker:
while True:
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(self.state.topology)
conns = await check_reachable(self.state.topology, self.node_id)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
logger.warning(
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = Connection(
local_node_id=self.node_id,
send_back_node_id=nid,
@@ -439,3 +446,40 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info("Fetching and emitting existing download progress...")
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id, shard_metadata=progress.shard
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -3,8 +3,10 @@
from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadModel,
LoadModel,
@@ -14,20 +16,25 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
)
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerFailed,
RunnerId,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerStatus,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -36,7 +43,7 @@ def plan(
# Runners is expected to be FRESH and so should not come from state
runners: Mapping[RunnerId, RunnerSupervisor],
# DL_status is expected to be FRESH and so should not come from state
download_status: Mapping[ShardMetadata, DownloadProgress],
download_status: Mapping[ModelId, DownloadProgress],
# gdls is not expected to be fresh
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
instances: Mapping[InstanceId, Instance],
@@ -48,6 +55,7 @@ def plan(
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(runners, download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners)
@@ -103,12 +111,15 @@ def _create_runner(
def _model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ShardMetadata, DownloadProgress],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
if (
isinstance(runner.status, RunnerWaitingForModel)
and runner.bound_instance.bound_shard not in download_status
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
)
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
return DownloadModel(
@@ -117,14 +128,54 @@ def _model_needs_download(
)
""" --- TODO!
def _init_backend(
def _init_distributed_backend(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
) -> LoadModel | None:
for runner in runner.values()
pass
"""
):
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
if is_single_node_instance:
continue
runner_is_idle = isinstance(runner.status, RunnerIdle)
all_runners_connecting = all(
isinstance(
all_runners.get(global_runner_id),
(RunnerConnecting, RunnerIdle),
)
for global_runner_id in shard_assignments.runner_to_shard
)
if not (runner_is_idle and all_runners_connecting):
continue
runner_id = runner.bound_instance.bound_runner_id
shard = runner.bound_instance.bound_shard
device_rank = shard.device_rank
world_size = shard.world_size
assert device_rank < world_size
assert device_rank >= 0
accepting_ranks = device_rank < world_size - 1
# Rank = n-1
connecting_rank_ready = device_rank == world_size - 1 and all(
isinstance(all_runners.get(global_runner_id, None), RunnerConnecting)
for global_runner_id in shard_assignments.runner_to_shard
if global_runner_id != runner_id
)
if not (accepting_ranks or connecting_rank_ready):
continue
return ConnectToGroup(instance_id=instance.instance_id)
return None
def _load_model(
@@ -136,31 +187,33 @@ def _load_model(
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
all_downloads_complete_local = all(
all_local_downloads_complete = all(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata == shard_assignments.runner_to_shard[rid]
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid, rid in shard_assignments.node_to_runner.items()
for nid in shard_assignments.node_to_runner
)
if not all_local_downloads_complete:
continue
runner_is_waiting = isinstance(runner.status, RunnerWaitingForModel)
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
all_runners_expecting_model = all(
is_runner_waiting = isinstance(runner.status, RunnerConnected)
all_ready_for_model = all(
isinstance(
all_runners.get(global_runner_id),
(RunnerWaitingForModel, RunnerLoading, RunnerLoaded),
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerLoaded),
)
for global_runner_id in shard_assignments.runner_to_shard
)
if (
all_downloads_complete_local
and runner_is_waiting
and all_runners_expecting_model
):
if is_runner_waiting and all_ready_for_model:
return LoadModel(instance_id=instance.instance_id)
return None
@@ -183,8 +236,8 @@ def _ready_to_warmup(
assert device_rank < world_size
assert device_rank >= 0
# Rank != n-1
accepting_ranks_ready = device_rank != world_size - 1 and all(
# Rank != 0
accepting_ranks_ready = device_rank > 0 and all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerLoaded, RunnerWarmingUp),
@@ -192,8 +245,8 @@ def _ready_to_warmup(
for global_runner_id in shard_assignments.runner_to_shard
)
# Rank = n-1
connecting_rank_ready = device_rank == world_size - 1 and all(
# Rank = 0
connecting_rank_ready = device_rank == 0 and all(
isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp)
for global_runner_id in shard_assignments.runner_to_shard
if global_runner_id != runner_id
@@ -221,6 +274,14 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard

View File

@@ -2,16 +2,13 @@ import os
import loguru
from exo.shared.types.events import Event
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import MpReceiver, MpSender
logger: "loguru.Logger"
if os.getenv("EXO_TESTS") == "1":
logger = loguru.logger
logger: "loguru.Logger" = loguru.logger
def entrypoint(
@@ -30,6 +27,23 @@ def entrypoint(
logger = _logger
# Import main after setting global logger - this lets us just import logger from this module
from exo.worker.runner.runner import main
try:
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver)
except Exception as e:
logger.opt(exception=e).warning(
f"Runner {bound_instance.bound_runner_id} crashed with critical exception {e}"
)
event_sender.send(
RunnerStatusUpdated(
runner_id=bound_instance.bound_runner_id,
runner_status=RunnerFailed(error_message=str(e)),
)
)
finally:
event_sender.close()
task_receiver.close()
event_sender.join()
task_receiver.join()
logger.info("bye from the runner")

View File

@@ -11,6 +11,7 @@ from exo.shared.types.events import (
)
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
@@ -22,20 +23,25 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerFailed,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_cleanup,
mlx_force_oom,
)
from exo.worker.runner.bootstrap import logger
@@ -63,9 +69,10 @@ def main(
model = None
tokenizer = None
sampler = None
group = None
current_status: RunnerStatus = RunnerWaitingForModel()
logger.info("runner waiting for model")
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
@@ -78,9 +85,26 @@ def main(
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case LoadModel() if isinstance(
current_status, (RunnerWaitingForModel, RunnerFailed)
case ConnectToGroup() if isinstance(
current_status, (RunnerIdle, RunnerFailed)
):
logger.info("runner connecting")
current_status = RunnerConnecting()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
group = initialize_mlx(bound_instance)
logger.info("runner connected")
current_status = RunnerConnected()
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
case LoadModel() if (
isinstance(current_status, RunnerConnected)
and group is not None
) or (isinstance(current_status, RunnerIdle) and group is None):
current_status = RunnerLoading()
logger.info("runner loading")
event_sender.send(
@@ -89,15 +113,12 @@ def main(
)
)
model, tokenizer, sampler = initialize_mlx(bound_instance)
model, tokenizer, sampler = load_mlx_items(
bound_instance, group
)
current_status = RunnerLoaded()
logger.info("runner loaded")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
@@ -123,11 +144,6 @@ def main(
)
current_status = RunnerReady()
logger.info("runner ready")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=RunnerReady()
)
)
case ChatCompletion(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
@@ -164,6 +180,7 @@ def main(
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
@@ -172,29 +189,32 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
mlx_cleanup(model, tokenizer, group)
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=RunnerReady()
runner_id=runner_id, runner_status=current_status
)
)
case Shutdown():
logger.info("runner shutting down")
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
break
current_status = RunnerShutdown()
case _:
raise ValueError("Received task outside of state machine")
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
)
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
)
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
if isinstance(current_status, RunnerShutdown):
break
except ClosedResourceError:
logger.warning("runner communication closed unexpectedly")
except Exception as e:

View File

@@ -14,13 +14,23 @@ from anyio import (
from anyio.abc import TaskGroup
from loguru import logger
from exo.shared.types.events import Event, RunnerStatusUpdated, TaskAcknowledged
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerConnecting,
RunnerFailed,
RunnerIdle,
RunnerLoading,
RunnerRunning,
RunnerShuttingDown,
RunnerStatus,
RunnerWaitingForModel,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
@@ -39,10 +49,10 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
# err_path: str
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerWaitingForModel, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -77,7 +87,6 @@ class RunnerSupervisor:
_ev_recv=ev_recv,
_task_sender=task_sender,
_event_sender=event_sender,
# err_path=err_path,
)
return self
@@ -118,6 +127,10 @@ class RunnerSupervisor:
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.completed:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -138,6 +151,22 @@ class RunnerSupervisor:
if isinstance(event, TaskAcknowledged):
self.pending.pop(event.task_id).set()
continue
if (
isinstance(event, TaskStatusUpdated)
and event.task_status == TaskStatus.Complete
):
# If a task has just been completed, we should be working on it.
assert isinstance(
self.status,
(
RunnerRunning,
RunnerWarmingUp,
RunnerLoading,
RunnerConnecting,
RunnerShuttingDown,
),
)
self.completed.add(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)

View File

@@ -9,9 +9,11 @@ MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
NODE_C: Final[NodeId] = NodeId("cccccccc-cccc-4ccc-8ccc-cccccccccccc")
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
RUNNER_3_ID: Final[RunnerId] = RunnerId("Runner3")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
@@ -24,3 +26,9 @@ TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
COMMAND_1_ID: Final[CommandId] = CommandId("77777777-7777-4777-8777-777777777777")
COMMAND_2_ID: Final[CommandId] = CommandId("88888888-8888-4888-8888-888888888888")
SHUTDOWN_TASK_ID = TaskId("shutdown")
CHAT_COMPLETION_TASK_ID = TaskId("chat-completion")
INITIALIZATION_TASK_ID = TaskId("initialisation")
LOAD_TASK_ID = TaskId("load")
WARMUP_TASK_ID = TaskId("warmup")

View File

@@ -1,9 +1,9 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask
from exo.shared.types.tasks import BaseTask, TaskId
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
@@ -14,10 +14,12 @@ from exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignm
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
# Runner supervisor without multiprocessing logic.
@dataclass(frozen=True)
class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
class OtherTask(BaseTask):
@@ -35,6 +37,8 @@ def get_pipeline_shard_metadata(
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=2048,
supports_tensor=False,
),
device_rank=device_rank,
world_size=world_size,
@@ -67,5 +71,27 @@ def get_mlx_ring_instance(
shard_assignments=get_shard_assignments(
model_id, node_to_runner, runner_to_shard
),
hosts=[],
hosts_by_node={},
ephemeral_port=50000,
)
def get_bound_mlx_ring_instance(
instance_id: InstanceId, model_id: ModelId, runner_id: RunnerId, node_id: NodeId
) -> BoundInstance:
shard = get_pipeline_shard_metadata(model_id=model_id, device_rank=0, world_size=2)
other_shard = get_pipeline_shard_metadata(
model_id=model_id, device_rank=1, world_size=2
)
instance = get_mlx_ring_instance(
instance_id=instance_id,
model_id=model_id,
node_to_runner={
node_id: runner_id,
NodeId("other_node"): RunnerId("other_runner"),
},
runner_to_shard={runner_id: shard, RunnerId("other_runner"): other_shard},
)
return BoundInstance(
instance=instance, bound_runner_id=runner_id, bound_node_id=node_id
)

View File

@@ -1,12 +1,13 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerWaitingForModel,
RunnerConnected,
RunnerIdle,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
@@ -38,16 +39,14 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ShardMetadata, DownloadProgress] = {}
download_status: dict[ModelId, DownloadProgress] = {}
result = plan_mod.plan(
node_id=NODE_A,
@@ -82,20 +81,20 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
bound_instance=bound_instance, status=RunnerConnected()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerWaitingForModel(),
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
}
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
}
# Global view has completed downloads for both nodes
@@ -133,17 +132,15 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# Local status claims the shard is downloaded already
local_download_status = {
shard: DownloadCompleted(shard_metadata=shard, node_id=NODE_A) # type: ignore[reportUnhashable]
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
@@ -183,19 +180,19 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerWaitingForModel()
bound_instance=bound_instance, status=RunnerConnected()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerWaitingForModel(),
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
}
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
}
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
@@ -213,3 +210,22 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
)
assert result is None
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
], # NODE_B has no downloads completed yet
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is not None

View File

@@ -5,9 +5,9 @@ from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.tasks import ChatCompletion, Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import (
RunnerIdle,
RunnerReady,
RunnerRunning,
RunnerWaitingForModel,
)
from exo.worker.tests.constants import (
COMMAND_1_ID,
@@ -99,7 +99,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerReady(),
RUNNER_2_ID: RunnerWaitingForModel(),
RUNNER_2_ID: RunnerIdle(),
}
task = ChatCompletion(

View File

@@ -2,8 +2,9 @@ import exo.worker.plan as plan_mod
from exo.shared.types.tasks import StartWarmup
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerIdle,
RunnerLoaded,
RunnerWaitingForModel,
RunnerLoading,
RunnerWarmingUp,
)
from exo.worker.tests.constants import (
@@ -11,8 +12,10 @@ from exo.worker.tests.constants import (
MODEL_A_ID,
NODE_A,
NODE_B,
NODE_C,
RUNNER_1_ID,
RUNNER_2_ID,
RUNNER_3_ID,
)
from exo.worker.tests.unittests.conftest import (
FakeRunnerSupervisor,
@@ -21,18 +24,19 @@ from exo.worker.tests.unittests.conftest import (
)
def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
"""
For non-zero device_rank shards, StartWarmup should be emitted when all
shards in the instance are Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=3)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=3)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=2, world_size=3)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID, NODE_C: RUNNER_3_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1, RUNNER_3_ID: shard2},
)
bound_instance = BoundInstance(
@@ -47,6 +51,7 @@ def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
RUNNER_3_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
@@ -128,7 +133,7 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWaitingForModel(),
RUNNER_1_ID: RunnerIdle(),
RUNNER_2_ID: RunnerLoaded(),
}
@@ -149,6 +154,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
"""
Rank-zero shard should not start warmup until all non-zero ranks are
already WarmingUp.
For accepting ranks (device_rank != 0), StartWarmup should be
emitted when all shards in the instance are Loaded/WarmingUp.
In a 2-node setup, rank 1 is the accepting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -159,6 +167,153 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 1 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
"""
For connecting rank (device_rank == world_size - 1), StartWarmup should
only be emitted once all the other runners are already warming up.
In a 2-node setup, rank 1 is the connecting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerWarmingUp(),
RUNNER_2_ID: RunnerLoaded(),
}
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warming():
"""
Accepting rank should not start warmup while any shard is not Loaded/WarmingUp.
In a 2-node setup, rank 0 is the accepting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 0 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoading(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None
def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
"""
Connecting rank (device_rank == 0) should not start warmup
until all other ranks are already WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)

View File

@@ -0,0 +1,212 @@
# Check tasks are complete before runner is ever ready.
from collections.abc import Iterable
from typing import Callable
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
RunnerIdle,
RunnerLoaded,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
from ...constants import (
CHAT_COMPLETION_TASK_ID,
COMMAND_1_ID,
INITIALIZATION_TASK_ID,
INSTANCE_1_ID,
LOAD_TASK_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
SHUTDOWN_TASK_ID,
WARMUP_TASK_ID,
)
from ..conftest import get_bound_mlx_ring_instance
def make_nothin[T, U, V](res: T) -> Callable[[], T]:
def nothin(*_1: U, **_2: V) -> T:
return res
return nothin
nothin = make_nothin(None)
INIT_TASK = ConnectToGroup(
task_id=INITIALIZATION_TASK_ID,
instance_id=INSTANCE_1_ID,
)
LOAD_TASK = LoadModel(
task_id=LOAD_TASK_ID,
instance_id=INSTANCE_1_ID,
)
WARMUP_TASK = StartWarmup(
task_id=WARMUP_TASK_ID,
instance_id=INSTANCE_1_ID,
)
SHUTDOWN_TASK = Shutdown(
task_id=SHUTDOWN_TASK_ID,
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
)
CHAT_PARAMS = ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
stream=True,
max_tokens=4,
temperature=0.0,
)
CHAT_TASK = ChatCompletion(
task_id=CHAT_COMPLETION_TASK_ID,
command_id=COMMAND_1_ID,
task_params=CHAT_PARAMS,
instance_id=INSTANCE_1_ID,
)
def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Event]):
for test_event, true_event in zip(test_events, true_events, strict=True):
test_event.event_id = true_event.event_id
assert test_event == true_event, f"{test_event} != {true_event}"
@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1, 1)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NODE_A,
)
task_sender, task_receiver = mp_channel[Task]()
event_sender, event_receiver = mp_channel[Event]()
with task_sender, event_receiver:
for t in tasks:
task_sender.send(t)
# worst monkeypatch known to man
# this is some c++ nonsense
event_sender.close = nothin
event_sender.join = nothin
task_receiver.close = nothin
task_receiver.join = nothin
mlx_runner.main(bound_instance, event_sender, task_receiver)
return event_receiver.collect()
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
expected_chunk = ChunkGenerated(
command_id=COMMAND_1_ID,
chunk=TokenChunk(
idx=0,
model=MODEL_A_ID,
text="hi",
token_id=0,
finish_reason="stop",
),
)
assert_events_equal(
events,
[
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerIdle()),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=LOAD_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=WARMUP_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
),
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
),
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),
# SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
],
)

View File

@@ -0,0 +1 @@
# TODO:

View File

@@ -1,33 +1,64 @@
import socket
import http.client
from anyio import create_task_group, to_thread
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
# TODO: ref. api port
async def check_reachability(
target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
) -> None:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1) # 1 second timeout
try:
result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
except socket.gaierror:
# seems to throw on ipv6 loopback. oh well
# logger.warning(f"invalid {target_ip=}")
"""Check if a node is reachable at the given IP and verify its identity."""
def _fetch_remote_node_id() -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None
body = response.read().decode("utf-8").strip()
# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]
return NodeId(body) or None
except OSError:
return None
finally:
connection.close()
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id is None:
return
finally:
sock.close()
if result == 0:
if target_node_id not in out:
out[target_node_id] = set()
out[target_node_id].add(target_ip)
if remote_node_id == self_node_id:
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
f"ip={target_ip}, expected_node_id={expected_node_id}, "
f"remote_node_id={remote_node_id}"
)
return
if remote_node_id not in out:
out[remote_node_id] = set()
out[remote_node_id].add(target_ip)
async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
async with create_task_group() as tg:
for node in topology.list_nodes():
@@ -35,7 +66,11 @@ async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability, iface.ip_address, node.node_id, reachable
check_reachability,
iface.ip_address,
node.node_id,
self_node_id,
reachable,
)
return reachable

View File

@@ -1,24 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
networksetup -listallnetworkservices | grep -q '^Thunderbolt Bridge$' \
&& echo "Disabling bridge in networksetup" \
&& networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
networksetup -listallnetworkservices | grep -q '^\*Thunderbolt Bridge$' \
&& echo "Bridge disabled in networksetup"
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && echo "Removing bridge members in ifconfig" && {
ifconfig bridge0 | \
awk '/member/ {print $2}' | \
xargs -n1 sudo ifconfig bridge0 deletem
}
ifconfig bridge0 | grep -q 'status: active' && sudo ifconfig bridge0 down
ifconfig bridge0 | grep -q 'status: inactive' && echo "Bridge disabled in ifconfig"
}
for iface in $(seq 2 7); do
sudo ipconfig set "en$iface" dhcp && echo "enabled dhcp on en$iface" || echo "failed to enable dhcp on en$iface"
done

41
uv.lock generated
View File

@@ -334,8 +334,10 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "networkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -374,9 +376,11 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mlx", specifier = ">=0.30.1" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = ">=0.30.1" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = ">=0.30.1" },
{ name = "mlx-lm", specifier = ">=0.28.3" },
{ name = "networkx", specifier = ">=3.5" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "protobuf", specifier = ">=6.32.0" },
{ name = "psutil", specifier = ">=7.0.0" },
{ name = "pydantic", specifier = ">=2.11.7" },
@@ -801,6 +805,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/ff/1e1968f107b4221a98dc26832586b1f646b27ddf3e55c95051c09d751f0a/mlx-0.30.1-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:d18012d5cf0f013bc4a405cfd1e9d2d28e798f4d2dc4f15aa0fbffff73c02ba2", size = 687114, upload-time = "2025-12-18T01:55:56.506Z" },
]
[package.optional-dependencies]
cpu = [
{ name = "mlx-cpu", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx-cpu"
version = "0.30.1"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/64/51/32903727a68a61e972383e28a775c1f5e5f0628552c85cbc6103d68c0dc4/mlx_cpu-0.30.1-py3-none-manylinux_2_35_aarch64.whl", hash = "sha256:3f5dc2e4d0849181f8253508bb6a0854250483fc63d43ac79ec614b19824b172", size = 8992394, upload-time = "2025-12-18T00:16:13.696Z" },
{ url = "https://files.pythonhosted.org/packages/0c/74/69c21bb907f3c4064881ab0653029c939ae15fc4e63a5301ef8643cb1d68/mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl", hash = "sha256:c9ea6992d8c001e1123dfd3b4d4405ff576c787eec52656ad405e3d033a8be60", size = 10553055, upload-time = "2025-12-18T00:16:16.104Z" },
]
[[package]]
name = "mlx-lm"
version = "0.28.3"
@@ -946,6 +964,27 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e2/c1/6dba12fdf68b02a21ac411c9df19afa66bed2540f467150ca64d246b463d/numpy-2.3.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e1708fac43ef8b419c975926ce1eaf793b0c13b7356cfab6ab0dc34c0a02ac0f", size = 18652691, upload-time = "2025-10-15T16:17:46.247Z" },
]
[[package]]
name = "openai-harmony"
version = "0.0.8"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3e/92/2d038d096f29179c7c9571b431f9e739f87a487121901725e23fe338dd9d/openai_harmony-0.0.8.tar.gz", hash = "sha256:6e43f98e6c242fa2de6f8ea12eab24af63fa2ed3e89c06341fb9d92632c5cbdf", size = 284777, upload-time = "2025-11-05T19:07:06.727Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/45/c6/2502f416d46be3ec08bb66d696cccffb57781a499e3ff2e4d7c174af4e8f/openai_harmony-0.0.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:029ec25ca74abe48fdb58eb9fdd2a8c1618581fc33ce8e5653f8a1ffbfbd9326", size = 2627806, upload-time = "2025-11-05T19:06:57.063Z" },
{ url = "https://files.pythonhosted.org/packages/d3/d2/ce6953ca87db9cae3e775024184da7d1c5cb88cead19a2d75b42f00a959c/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4f709815924ec325b9a890e6ab2bbb0ceec8e319a4e257328eb752cf36b2efc", size = 2948463, upload-time = "2025-11-05T19:06:48.17Z" },
{ url = "https://files.pythonhosted.org/packages/fa/4c/b553c9651662d6ce102ca7f3629d268b23df1abe5841e24bed81e8a8e949/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5cfcfd963b50a41fc656c84d3440ca6eecdccd6c552158ce790b8f2e33dfb5a9", size = 2704083, upload-time = "2025-11-05T19:06:50.205Z" },
{ url = "https://files.pythonhosted.org/packages/9b/af/4eec8f9ab9c27bcdb444460c72cf43011d176fc44c79d6e113094ca1e152/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a3a16972aa1cee38ea958470cd04ac9a2d5ac38fdcf77ab686611246220c158", size = 2959765, upload-time = "2025-11-05T19:06:53.62Z" },
{ url = "https://files.pythonhosted.org/packages/11/3c/33f3374e4624e0e776f6b13b73c45a7ead7f9c4529f8369ed5bfcaa30cac/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b4d5cfa168e74d08f8ba6d58a7e49bc7daef4d58951ec69b66b0d56f4927a68d", size = 3427031, upload-time = "2025-11-05T19:06:51.829Z" },
{ url = "https://files.pythonhosted.org/packages/25/3f/1a192b93bb47c6b44cd98ba8cc1d3d2a9308f1bb700c3017e6352da11bda/openai_harmony-0.0.8-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c007d277218a50db8839e599ed78e0fffe5130f614c3f6d93ae257f282071a29", size = 2953260, upload-time = "2025-11-05T19:06:55.406Z" },
{ url = "https://files.pythonhosted.org/packages/5b/f8/93b582cad3531797c3db7c2db5400fd841538ccddfd9f5e3df61be99a630/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:8565d4f5a0638da1bffde29832ed63c9e695c558611053add3b2dc0b56c92dbc", size = 3127044, upload-time = "2025-11-05T19:06:59.553Z" },
{ url = "https://files.pythonhosted.org/packages/1d/10/4327dbf87f75ae813405fd9a9b4a5cde63d506ffed0a096a440a4cabd89c/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:cbaa3bda75ef0d8836e1f8cc84af62f971b1d756d740efc95c38c3e04c0bfde2", size = 2932931, upload-time = "2025-11-05T19:07:01.437Z" },
{ url = "https://files.pythonhosted.org/packages/8a/c8/1774eec4f6f360ef57618fb8f52e3d3af245b2491bd0297513aa09eec04b/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:772922a9bd24e133950fad71eb1550836f415a88e8c77870e12d0c3bd688ddc2", size = 2996140, upload-time = "2025-11-05T19:07:03.438Z" },
{ url = "https://files.pythonhosted.org/packages/60/c3/3d1e01e2dba517a91760e4a03e4f20ffc75039a6fe584d0e6f9b5c78fd15/openai_harmony-0.0.8-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:007b0476a1f331f8130783f901f1da6f5a7057af1a4891f1b6a31dec364189b5", size = 3205080, upload-time = "2025-11-05T19:07:05.078Z" },
]
[[package]]
name = "packaging"
version = "25.0"