Compare commits

..

206 Commits

Author SHA1 Message Date
ciaranbor
9067033f20 Support image editing in runner 2026-01-09 16:15:51 +00:00
ciaranbor
17ef7d8838 Add editing features to adapter 2026-01-09 16:15:51 +00:00
ciaranbor
23962ab4e2 Default partial images to 3 if streaming 2026-01-09 16:15:51 +00:00
ciaranbor
59c5de8256 Add Qwen-Image model adapter 2026-01-09 16:15:51 +00:00
ciaranbor
a43a14da1f Add Qwen-Image-Edit model config 2026-01-09 16:15:51 +00:00
ciaranbor
bb46c12878 Use image generation in streaming mode in UI 2026-01-09 16:15:51 +00:00
ciaranbor
79b5316efe Handle partial image streaming 2026-01-09 16:15:51 +00:00
ciaranbor
40d49c2720 Add streaming params to ImageGenerationTaskParams 2026-01-09 16:15:51 +00:00
ciaranbor
f87d06b1f1 Add Qwen-Image-Edit-2509 2026-01-09 16:15:51 +00:00
ciaranbor
617c6ffdcb Handle image editing time steps 2026-01-09 16:15:51 +00:00
ciaranbor
acd246f49e Fix time steps 2026-01-09 16:15:51 +00:00
ciaranbor
d5536c1a2b Remove duplicate RunnerReady event 2026-01-09 16:15:51 +00:00
ciaranbor
2aba3fc9a9 Fix image_strength meaning 2026-01-09 16:15:51 +00:00
ciaranbor
c7cb22d546 Truncate image data logs 2026-01-09 16:15:51 +00:00
ciaranbor
5a429f4ab6 Chunk image input 2026-01-09 16:15:51 +00:00
ciaranbor
42c427a5bb Avoid logging image data 2026-01-09 16:15:51 +00:00
ciaranbor
4b8976be51 Support image editing 2026-01-09 16:15:51 +00:00
Sami Khan
85bf4aba1c small UI change 2026-01-09 16:15:51 +00:00
Sami Khan
de12873b1a image gen in dashboard 2026-01-09 16:15:51 +00:00
ciaranbor
a86bb97d65 Better llm model type check 2026-01-07 10:17:42 +00:00
ciaranbor
a5c6db7145 Prune blocks before model load 2026-01-06 16:47:58 +00:00
ciaranbor
e74345bb09 Own TODOs 2026-01-06 13:30:17 +00:00
ciaranbor
58d1f159b7 Remove double RunnerReady event 2026-01-06 13:25:01 +00:00
ciaranbor
8183225714 Fix hidden_size for image models 2026-01-06 12:33:28 +00:00
ciaranbor
46f957ee5b Fix uv.lock 2026-01-06 12:25:23 +00:00
ciaranbor
a2f52e04e3 Fix image model cards 2026-01-06 11:00:01 +00:00
ciaranbor
859960608b Skip decode on non-final ranks 2026-01-06 10:51:21 +00:00
ciaranbor
80ad016004 Final rank produces image 2026-01-06 10:51:21 +00:00
ciaranbor
d8938e6e72 Increase number of sync steps 2026-01-06 10:51:21 +00:00
ciaranbor
bd6a6cc6d3 Change Qwen-Image steps 2026-01-06 10:51:21 +00:00
ciaranbor
a88d588de4 Fix Qwen-Image latent shapes 2026-01-06 10:51:21 +00:00
ciaranbor
08e8a30fb7 Fix joint block patch recv shape for non-zero ranks 2026-01-06 10:51:21 +00:00
ciaranbor
d926df8f95 Fix comms issue for models without single blocks 2026-01-06 10:51:21 +00:00
ciaranbor
4ff550106d Support Qwen in DiffusionRunner pipefusion 2026-01-06 10:51:21 +00:00
ciaranbor
1d168dfe61 Implement Qwen pipefusion 2026-01-06 10:51:21 +00:00
ciaranbor
f813b9f5e1 Add guidance_scale parameter to image model config 2026-01-06 10:51:21 +00:00
ciaranbor
0f96083d48 Move orchestration to DiffusionRunner 2026-01-06 10:51:21 +00:00
ciaranbor
1a1b394f6d Add initial QwenModelAdapter 2026-01-06 10:51:21 +00:00
ciaranbor
6ab5a9d3d4 Tweak embeddings interface 2026-01-06 10:51:21 +00:00
ciaranbor
90bf4608df Add Qwen ImageModelConfig 2026-01-06 10:51:21 +00:00
ciaranbor
f2a0fdf25c Use 10% sync steps 2026-01-06 10:51:21 +00:00
ciaranbor
f574b3f57e Update FluxModelAdaper for new interface 2026-01-06 10:51:21 +00:00
ciaranbor
5cca9d8493 Register QwenModelAdapter 2026-01-06 10:51:21 +00:00
ciaranbor
3cd421079b Support multiple forward passes in runner 2026-01-06 10:51:21 +00:00
ciaranbor
d9eb4637ee Extend block wrapper parameters 2026-01-06 10:51:21 +00:00
ciaranbor
19f52e80fd Relax adaptor typing 2026-01-06 10:51:21 +00:00
ciaranbor
3f4162b732 Add Qwen-Image model card 2026-01-06 10:51:21 +00:00
ciaranbor
cad86ee76e Clean up dead code 2026-01-06 10:51:21 +00:00
ciaranbor
d7be6a09b0 Add BaseModelAdaptor 2026-01-06 10:51:21 +00:00
ciaranbor
79603e73ed Refactor filestructure 2026-01-06 10:51:21 +00:00
ciaranbor
78901cfe23 Treat unified blocks as single blocks (equivalent) 2026-01-06 10:51:21 +00:00
ciaranbor
c0ac199ab8 Refactor to handle entire denoising process in Diffusion runner 2026-01-06 10:51:21 +00:00
ciaranbor
b70d6abfa2 Move transformer to adapter 2026-01-06 10:51:21 +00:00
ciaranbor
16bfab9bab Move some more logic to adaptor 2026-01-06 10:51:21 +00:00
ciaranbor
28ee6f6370 Add generic block wrapper 2026-01-06 10:51:21 +00:00
ciaranbor
6b299bab8f Access transformer blocks from adaptor 2026-01-06 10:51:21 +00:00
ciaranbor
a3754a60b6 Better typing 2026-01-06 10:51:21 +00:00
ciaranbor
06039f93f5 Create wrappers at init time 2026-01-06 10:51:21 +00:00
ciaranbor
fcfecc9cd8 Combine model factory and adaptor 2026-01-06 10:51:21 +00:00
ciaranbor
ba798ae4f9 Implement model factory 2026-01-06 10:51:21 +00:00
ciaranbor
9a0e1e93a9 Add adaptor registry 2026-01-06 10:51:21 +00:00
ciaranbor
196f504c82 Remove mflux/generator/generate.py 2026-01-06 10:51:21 +00:00
ciaranbor
e3d89b8d63 Switch to using DistributedImageModel 2026-01-06 10:51:21 +00:00
ciaranbor
cb8079525c Add DistributedImageModel 2026-01-06 10:51:21 +00:00
ciaranbor
cb03c62c4a Use new generic wrappers, etc in denoising 2026-01-06 10:51:21 +00:00
ciaranbor
0653668048 Add generic transformer block wrappers 2026-01-06 10:51:21 +00:00
ciaranbor
0054bc4c14 Add FluxAdaptor 2026-01-06 10:51:21 +00:00
ciaranbor
b7b682b7bb Add ModelAdaptor, derivations implement model specific logic 2026-01-06 10:51:21 +00:00
ciaranbor
f7a651c1c1 Introduce image model config concept 2026-01-06 10:51:21 +00:00
ciaranbor
98e8d74cea Consolidate kv cache patching 2026-01-06 10:51:21 +00:00
ciaranbor
27567f8a4e Support different configuration comms 2026-01-06 10:51:21 +00:00
ciaranbor
28227bb45a Add ImageGenerator protocol 2026-01-06 10:51:21 +00:00
ciaranbor
7683d4a21f Force final patch receive order 2026-01-06 10:51:21 +00:00
ciaranbor
0a3cb77a29 Remove logs 2026-01-06 10:51:21 +00:00
ciaranbor
3f5810c1fe Update patch list 2026-01-06 10:51:21 +00:00
ciaranbor
fc62ae1b9b Slight refactor 2026-01-06 10:51:21 +00:00
ciaranbor
ec5bad4254 Don't need array for prev patches 2026-01-06 10:51:21 +00:00
ciaranbor
f9f54be32b Fix send/recv order 2026-01-06 10:51:21 +00:00
ciaranbor
36daf9183f Fix async single transformer block 2026-01-06 10:51:21 +00:00
ciaranbor
5d38ffc77e Use relative rank variables 2026-01-06 10:51:21 +00:00
ciaranbor
1b4851765a Fix writing patches 2026-01-06 10:51:21 +00:00
ciaranbor
8787eaf3df Collect final image 2026-01-06 10:51:21 +00:00
ciaranbor
e1e3aa7a5e Fix recv_template shape 2026-01-06 10:51:21 +00:00
ciaranbor
0fe5239273 Add logs 2026-01-06 10:51:21 +00:00
ciaranbor
7eddf7404b Optimise async pipeline 2026-01-06 10:51:21 +00:00
ciaranbor
5f3bc30f17 Add next_rank and prev_rank members 2026-01-06 10:51:21 +00:00
ciaranbor
90a7e6601d Add _create_patches method 2026-01-06 10:51:21 +00:00
ciaranbor
ce2691c8d3 Fix shapes 2026-01-06 10:51:21 +00:00
ciaranbor
076d2901e8 Reorder comms 2026-01-06 10:51:20 +00:00
ciaranbor
7a733b584c Remove all_gather from sync pipeline, send from final rank to first rank 2026-01-06 10:51:20 +00:00
ciaranbor
94fee6f2d2 Simplify kv_cache initialization 2026-01-06 10:51:20 +00:00
ciaranbor
ef4fe09424 Fix kv cache 2026-01-06 10:51:20 +00:00
ciaranbor
2919bcf21d Clean up kv caches 2026-01-06 10:51:20 +00:00
ciaranbor
dd84cc9ca2 Fix return 2026-01-06 10:51:20 +00:00
ciaranbor
5a74d76d41 Fix hidden_states shapes 2026-01-06 10:51:20 +00:00
ciaranbor
e115814c74 Only perform projection and scheduler step on last rank 2026-01-06 10:51:20 +00:00
ciaranbor
d85432d4f0 Only compute embeddings on rank 0 2026-01-06 10:51:20 +00:00
ciaranbor
da823a2b02 Remove eval 2026-01-06 10:51:20 +00:00
ciaranbor
8576f4252b Remove eval 2026-01-06 10:51:20 +00:00
ciaranbor
7ca0bc5b55 Only send encoder_hidden_states with the first patch (once per timestep) 2026-01-06 10:51:20 +00:00
ciaranbor
db24f052d7 Remove redundant text kv cache computation 2026-01-06 10:51:20 +00:00
ciaranbor
7b8382be10 Concatenate before all gather 2026-01-06 10:51:20 +00:00
ciaranbor
d3685b0eb5 Increase number of sync steps 2026-01-06 10:51:20 +00:00
ciaranbor
93f4bdc5f9 Reinitialise kv_caches between generations 2026-01-06 10:51:20 +00:00
ciaranbor
8eea0327b8 Eliminate double kv cache computation 2026-01-06 10:51:20 +00:00
ciaranbor
085358e5e0 Add kv cache caching wrappers for sync pipeline transformer blocks 2026-01-06 10:51:20 +00:00
ciaranbor
546efe4dd2 Persist kv caches 2026-01-06 10:51:20 +00:00
ciaranbor
4ddfb6e254 Implement naive async pipeline implementation 2026-01-06 10:51:20 +00:00
ciaranbor
12f20fd94e Use wrapper classes for patched transformer logic 2026-01-06 10:51:20 +00:00
ciaranbor
f7ba70d5ae Add patch-aware joint and single attention wrappers 2026-01-06 10:51:20 +00:00
ciaranbor
4ecad10a66 Fix group.size() 2026-01-06 10:51:20 +00:00
ciaranbor
552ae776fe Add classes to manage kv caches with patch support 2026-01-06 10:51:20 +00:00
ciaranbor
6e0a6e8956 Use heuristic for number of sync steps 2026-01-06 10:51:20 +00:00
ciaranbor
e8b0a2124c Generalise number of denoising steps 2026-01-06 10:51:20 +00:00
ciaranbor
129df1ec89 Add flux1-dev 2026-01-06 10:51:20 +00:00
ciaranbor
a87fe26973 Move scheduler step to inner pipeline 2026-01-06 10:51:20 +00:00
ciaranbor
a9ea223dc7 Add barrier before all_gather 2026-01-06 10:51:20 +00:00
ciaranbor
0af3349f2f Fix transformer blocks pruning 2026-01-06 10:51:20 +00:00
ciaranbor
20e3319a3e Fix image generation api 2026-01-06 10:51:20 +00:00
ciaranbor
4c88fac266 Create queue in try block 2026-01-06 10:51:20 +00:00
ciaranbor
e1d916f743 Conform to rebase 2026-01-06 10:51:20 +00:00
ciaranbor
09c9b2e29f Refactor denoising 2026-01-06 10:51:20 +00:00
ciaranbor
b6359a7199 Move more logic to DistributedFlux 2026-01-06 10:51:20 +00:00
ciaranbor
b5a043f676 Move surrounding logic back to _sync_pipeline 2026-01-06 10:51:20 +00:00
ciaranbor
55e690fd49 Add patching aware member variables 2026-01-06 10:51:20 +00:00
ciaranbor
9e4ffb11ec Implement sync/async switching logic 2026-01-06 10:51:20 +00:00
ciaranbor
d665a8d05a Move current transformer implementation to _sync_pipeline method 2026-01-06 10:51:20 +00:00
ciaranbor
cac77816be Remove some logs 2026-01-06 10:51:20 +00:00
ciaranbor
25b9c3369e Remove old Flux1 implementation 2026-01-06 10:51:20 +00:00
ciaranbor
c19c5b4080 Prune unused transformer blocks 2026-01-06 10:51:20 +00:00
ciaranbor
9592f8b6b0 Add mx.eval 2026-01-06 10:51:20 +00:00
ciaranbor
7d7c16ebc1 Test evals 2026-01-06 10:51:20 +00:00
ciaranbor
450d0ba923 Test only barriers 2026-01-06 10:51:20 +00:00
ciaranbor
ea64062362 All perform final projection 2026-01-06 10:51:20 +00:00
ciaranbor
206b12e912 Another barrier 2026-01-06 10:51:20 +00:00
ciaranbor
eecc1da596 More debug 2026-01-06 10:51:20 +00:00
ciaranbor
44e68e4498 Add barriers 2026-01-06 10:51:20 +00:00
ciaranbor
f1548452fa Add log 2026-01-06 10:51:20 +00:00
ciaranbor
97769c82a9 Restore distributed logging 2026-01-06 10:51:20 +00:00
ciaranbor
26e5b03285 Use bootstrap logger 2026-01-06 10:51:20 +00:00
ciaranbor
8f93a1ff78 Remove logs 2026-01-06 10:51:20 +00:00
ciaranbor
e07dcc43b9 fix single block receive shape 2026-01-06 10:51:20 +00:00
ciaranbor
f91d0797fb Add debug logs 2026-01-06 10:51:20 +00:00
ciaranbor
aaeebaf79e Move communication logic to DistributedTransformer wrapper 2026-01-06 10:51:20 +00:00
ciaranbor
c3075a003e Move inference logic to DistribuedFlux1 2026-01-06 10:51:20 +00:00
ciaranbor
be796e55ac Add DistributedFlux1 class 2026-01-06 10:51:20 +00:00
ciaranbor
6e0c611f37 Rename pipeline to pipefusion 2026-01-06 10:51:20 +00:00
ciaranbor
88996eddcb Further refactor 2026-01-06 10:51:20 +00:00
ciaranbor
fb4fae51fa Refactor warmup 2026-01-06 10:51:20 +00:00
ciaranbor
dbefc209f5 Manually handle flux1 inference 2026-01-06 10:51:20 +00:00
ciaranbor
e6dd95524c Refactor flux1 image generation 2026-01-06 10:51:20 +00:00
ciaranbor
c2a9e5e53b Use quality parameter to set number of inference steps 2026-01-06 10:51:20 +00:00
ciaranbor
21587898bc Chunk image data transfer 2026-01-06 10:51:20 +00:00
ciaranbor
b6f23d0b01 Define EXO_MAX_CHUNK_SIZE 2026-01-06 10:51:20 +00:00
ciaranbor
f00ba03f4b Add indexing info to ImageChunk 2026-01-06 10:50:56 +00:00
ciaranbor
73e3713296 Remove sharding logs 2026-01-06 10:50:56 +00:00
ciaranbor
ecca6b4d20 Temp: reduce flux1.schnell storage size 2026-01-06 10:50:56 +00:00
ciaranbor
8bac08a236 Fix mflux transformer all_gather 2026-01-06 10:50:34 +00:00
ciaranbor
e7cca752fd Add all_gather -> broadcast todo 2026-01-06 10:50:34 +00:00
ciaranbor
540fe8b278 Fix world size 2026-01-06 10:50:34 +00:00
ciaranbor
2972f4620c Fix transition block? 2026-01-06 10:50:34 +00:00
ciaranbor
0ed81d8afa Implement image generation warmup 2026-01-06 10:50:34 +00:00
ciaranbor
66a24d59b9 Add logs 2026-01-06 10:50:11 +00:00
ciaranbor
5dcc359dba Add spiece.model to default patterns 2026-01-06 10:50:11 +00:00
ciaranbor
c2a4d61865 Just download all files for now 2026-01-06 10:49:43 +00:00
ciaranbor
ba12ee4897 Fix get_allow_patterns to include non-indexed safetensors files 2026-01-06 10:49:43 +00:00
ciaranbor
bcd69a3b01 Use half-open layer indexing in get_allow_patterns 2026-01-06 10:49:43 +00:00
ciaranbor
f5eb5d0338 Enable distributed mflux 2026-01-06 10:49:43 +00:00
ciaranbor
058aff5145 Implement mflux transformer sharding and communication pattern 2026-01-06 10:49:43 +00:00
ciaranbor
5cb0bc6a63 Update get_allow_patterns to handle sharding components 2026-01-06 10:49:43 +00:00
ciaranbor
c3aab450c6 Namespace both keys and values for component weight maps 2026-01-06 10:49:43 +00:00
ciaranbor
cf27673e20 Add components to Flux.1-schnell MODEL_CARD 2026-01-06 10:49:43 +00:00
ciaranbor
96c165e297 Add component concept for ModelMetadata 2026-01-06 10:48:42 +00:00
ciaranbor
2a589177cd Fix multiple components weight map key conflicts 2026-01-06 10:48:26 +00:00
ciaranbor
f782b619b6 get_weight_map: handle repos with multiple safetensors.index.json files 2026-01-06 10:48:26 +00:00
ciaranbor
dc661e4b5e Add initial image edits spec 2026-01-06 10:48:26 +00:00
ciaranbor
8b7d8ef394 Add image edits endpoint 2026-01-06 10:47:44 +00:00
ciaranbor
7dd2b328c8 Add ImageToImage task 2026-01-06 10:45:26 +00:00
ciaranbor
73a165702d Allow ModelCards to have multiple tasks 2026-01-06 10:44:53 +00:00
ciaranbor
0c76978b35 Fix text generation 2026-01-06 10:41:38 +00:00
ciaranbor
25188c845e Rename mlx_generate_image to mflux_generate 2026-01-06 10:41:38 +00:00
ciaranbor
df94169aba Initialize mlx or mflux engine based on model task 2026-01-06 10:39:27 +00:00
ciaranbor
a2d4c0de2a Restore warmup for text generation 2026-01-06 10:17:21 +00:00
ciaranbor
2edbc7e026 Add initialize_mflux function 2026-01-06 10:17:21 +00:00
ciaranbor
8f6e360d21 Move image generation to mflux engine 2026-01-06 10:17:21 +00:00
ciaranbor
085b966a5f Just use str for image generation size 2026-01-06 10:17:21 +00:00
ciaranbor
c64a55bfed Use MFlux for image generation 2026-01-06 10:17:21 +00:00
ciaranbor
fee716faab Add get_model_card function 2026-01-06 10:17:21 +00:00
ciaranbor
b88c89ee9c Add ModelTask enum 2026-01-06 10:17:21 +00:00
ciaranbor
9ef7b913e2 ADd flux1-schnell model 2026-01-06 10:08:11 +00:00
ciaranbor
0daa4b36db Add task field to ModelCard 2026-01-06 10:08:11 +00:00
ciaranbor
3c2da43792 Update mflux version 2026-01-06 10:04:51 +00:00
ciaranbor
8c4c53b50a Enable recursive repo downloads 2026-01-06 10:03:18 +00:00
ciaranbor
b2beb4c9cd Add dummy generate_image implementation 2026-01-06 10:03:18 +00:00
ciaranbor
098a11b262 Use base64 encoded str for image data 2026-01-06 10:03:18 +00:00
ciaranbor
bedb9045a0 Handle ImageGeneration tasks in _pending_tasks 2026-01-06 10:03:18 +00:00
ciaranbor
8e23841b4e Add mflux dependency 2026-01-06 10:03:18 +00:00
ciaranbor
4420eac10d Handle ImageGeneration task in runner task processing 2026-01-06 10:02:04 +00:00
ciaranbor
d0772e9e0f Handle ImageGeneration command in master command processing 2026-01-06 10:00:32 +00:00
ciaranbor
8d861168f1 Add image generation to API 2026-01-06 10:00:07 +00:00
ciaranbor
242648dff4 Add ImageGenerationResponse 2026-01-06 09:59:13 +00:00
ciaranbor
9b06b754cb Add ImageGeneration task 2026-01-06 09:59:13 +00:00
ciaranbor
1603984f45 Add ImageGeneration command 2026-01-06 09:58:46 +00:00
ciaranbor
f9418843f8 Add image generation params and response types 2026-01-05 21:51:10 +00:00
ciaranbor
877e7196c3 Add pillow dependency 2026-01-05 21:51:10 +00:00
ciaranbor
db7c4670b9 Fix mlx stream_generate import 2026-01-05 21:18:23 +00:00
163 changed files with 17014 additions and 11277 deletions

159
.github/benchmark-dashboard/README.md vendored Normal file
View File

@@ -0,0 +1,159 @@
# EXO Benchmark Dashboard
A fully self-contained, browser-based dashboard for tracking EXO benchmark performance over time.
## Features
- 📊 **Success Rate Tracking**: Monitor cluster reliability across commits
-**Response Time Analysis**: Track average request completion times
- 🎯 **Throughput Metrics**: Tokens per second visualization
- 📈 **Request Distribution**: Success/failure breakdown over time
- 🔄 **Auto-Refresh**: Updates every 60 seconds
- 📺 **TV-Ready**: Large, clear visualizations perfect for display
- 🔐 **Secure**: Credentials stored in browser localStorage only
- 🌐 **No Backend**: Directly accesses S3 from the browser
## Quick Start
### Option 1: Direct File Access (Simplest)
Just open the HTML file directly in your browser:
```bash
open .github/benchmark-dashboard/index.html
```
Then click "Configure AWS Credentials" and enter your keys.
### Option 2: URL Parameters (For Quick Setup)
```bash
# Serve with credentials in URL (they'll be moved to localStorage)
open ".github/benchmark-dashboard/index.html?accessKey=YOUR_KEY&secretKey=YOUR_SECRET&region=us-east-1"
```
The credentials will be saved to localStorage and removed from the URL immediately.
### Option 3: Simple HTTP Server
```bash
# From repo root
python3 -m http.server 8080
# Then open: http://localhost:8080/.github/benchmark-dashboard/
```
## AWS Credentials
The dashboard needs read-only access to the `exo-benchmark-results` S3 bucket.
### Required IAM Permissions
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
"s3:ListBucket"
],
"Resource": [
"arn:aws:s3:::exo-benchmark-results",
"arn:aws:s3:::exo-benchmark-results/*"
]
}
]
}
```
### Security Notes
- ✅ Credentials stored in browser `localStorage` only
- ✅ Never sent to any server (except AWS)
- ✅ All S3 access happens client-side
- ✅ Use read-only IAM credentials
- ⚠️ Don't commit credentials to git
- ⚠️ Use a dedicated read-only IAM user
## TV/Kiosk Mode
For permanent display on a TV:
### macOS
```bash
open -a "Google Chrome" --args --kiosk ".github/benchmark-dashboard/index.html"
```
### Linux
```bash
chromium-browser --kiosk --app="file://$(pwd)/.github/benchmark-dashboard/index.html"
```
### Auto-start on Boot
Create a simple startup script:
```bash
#!/bin/bash
# /usr/local/bin/start-benchmark-dashboard.sh
cd /path/to/exo
python3 -m http.server 8080 &
sleep 2
chromium-browser --kiosk http://localhost:8080/.github/benchmark-dashboard/
```
## Data Displayed
### Summary Cards
- **Latest Success Rate**: Most recent benchmark success percentage with trend
- **Avg Response Time**: Latest average response time in ms with trend
- **Total Benchmarks**: Count of all benchmarks run
- **Active Configurations**: Number of unique benchmark configs
### Charts
1. **Success Rate Over Time**: Line chart showing reliability trends
2. **Average Response Time**: Performance over time (lower is better)
3. **Throughput**: Tokens/second metric (higher is better)
4. **Request Distribution**: Stacked bar chart of successes/failures
## How It Works
1. **Loads AWS SDK**: Uses AWS SDK for JavaScript (browser version)
2. **Lists S3 Objects**: Fetches all files from `s3://exo-benchmark-results/bench/`
3. **Downloads Results**: Fetches each JSON result file
4. **Parses & Visualizes**: Uses Chart.js to create interactive charts
5. **Auto-Refreshes**: Polls S3 every 60 seconds for new results
## Customization
To modify the dashboard:
1. Edit `index.html`
2. Adjust `REFRESH_INTERVAL` for different polling frequency
3. Modify chart colors/styles in the Chart.js configuration
4. Add new metrics by extending the results parsing
## Troubleshooting
**"AWS credentials not configured"**
- Click "Configure AWS Credentials" and enter your keys
**"Error loading benchmark data"**
- Check AWS credentials are correct
- Verify S3 bucket name is `exo-benchmark-results`
- Ensure IAM user has read permissions
- Check browser console for detailed errors
**"No benchmark results found"**
- Wait for benchmark workflows to run
- Verify results are being uploaded to S3
- Check S3 bucket has files in `bench/` prefix
**Charts not updating**
- Check browser console for errors
- Verify network connectivity to S3
- Try refreshing the page manually

1641
.github/benchmark-dashboard/index.html vendored Normal file
View File

File diff suppressed because it is too large Load Diff

186
.github/configs/README.md vendored Normal file
View File

@@ -0,0 +1,186 @@
# EXO Benchmark Configurations
This directory contains configuration files for the EXO staged benchmark system.
## Overview
The staged benchmark system allows you to run complex, multi-stage load tests against EXO clusters. Each stage can have different characteristics:
- **Prompt Length**: Number of tokens in the input prompt
- **Generation Length**: Maximum tokens to generate in the response
- **Time Between Requests**: Delay (in seconds) between firing consecutive requests
- **Iterations**: Number of requests to send in this stage
Requests are **fire-and-forget** - they don't wait for the previous request to complete. This allows you to test overlapping request handling and measure success rates under load.
## Configuration Files
### `bench_simple.yaml`
A minimal configuration that replicates the behavior of the original `bench.py` script:
- Single stage with 1 iteration
- Short prompt (~20 tokens)
- Generates up to 100 tokens
This is useful for quick smoke tests.
### `bench_config.yaml`
A comprehensive multi-stage benchmark with:
1. **Warmup** (10 requests): Light load with short prompts
2. **Medium Load** (20 requests): Moderate load with medium prompts
3. **Stress Test** (30 requests): Heavy overlapping requests with long prompts
4. **Cooldown** (5 requests): Light load to wind down
This tests the cluster's behavior under varying load patterns.
## Configuration Schema
```yaml
# Hardware configuration - maps runner labels to instance counts
hardware_plan:
M3ULTRA_GPU80_512GB: 4
# Environment variables to set on each node (optional)
environment:
OVERRIDE_MEMORY_MB: 512
# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600
# Model instances to run concurrently
model_ids:
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
# Benchmark stages
stages:
- name: "stage_name" # Human-readable name for this stage
prompt_length: 100 # Target prompt length in tokens
generation_length: 200 # Max tokens to generate
time_between_requests: 2.0 # Seconds between firing requests
iterations: 10 # Number of requests in this stage
```
## Running Benchmarks
### Via GitHub Actions
**Automatic (every commit):**
- The **`bench`** workflow runs automatically on every push
- Uses `bench_simple.yaml` as the default configuration
- All settings (hardware plan, timeout, environment variables, models, stages) are defined in the config file
**Manual (on-demand):**
1. Go to **Actions****bench** workflow
2. Click **Run workflow**
3. Configure:
- **Config File**: Path to your YAML config (default: `.github/configs/bench_simple.yaml`)
- `.github/configs/bench_simple.yaml` for quick tests
- `.github/configs/bench_config.yaml` for complex multi-stage tests
All other settings (hardware plan, timeout, environment variables, models, stages) are read from the specified config file.
### Via Command Line
```bash
# Start EXO on localhost:8000
uv run exo --api-port 8000
# Run simple benchmark (1 stage, 1 iteration)
python3 .github/scripts/bench.py \
--api-port 8000 \
--config .github/configs/bench_simple.yaml \
--expected-nodes 1 \
--is-primary true \
--timeout-seconds 600
# Run complex staged benchmark (4 stages, multiple iterations)
python3 .github/scripts/bench.py \
--api-port 8000 \
--config .github/configs/bench_config.yaml \
--expected-nodes 1 \
--is-primary true \
--timeout-seconds 600
```
## Output Metrics
For each stage, the benchmark reports:
- **Total Requests**: Number of requests fired
- **Successful Requests**: Requests that completed successfully
- **Failed Requests**: Requests that encountered errors
- **Success Rate**: Percentage of successful requests
- **Total Tokens**: Sum of all tokens generated across successful requests
- **Avg Tokens/Request**: Average tokens per successful request
- **Avg Time/Request**: Average completion time per successful request
A JSON summary is also printed for easy parsing and storage.
## Creating Custom Benchmarks
To create a custom benchmark:
1. Copy an existing config file (e.g., `bench_config.yaml`)
2. Modify the stages to match your test scenario
3. Save it in this directory with a descriptive name
4. Run it using the workflow or command line
### Example: Sustained Load Test
```yaml
hardware_plan:
M3ULTRA_GPU80_512GB: 2
environment:
OVERRIDE_MEMORY_MB: 1024
timeout_seconds: 600
model_ids:
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
stages:
- name: "sustained_load"
prompt_length: 200
generation_length: 150
time_between_requests: 0.5 # Very fast - 2 requests/second
iterations: 100 # Run for ~50 seconds
```
### Example: Varying Prompt Sizes
```yaml
hardware_plan:
M4PRO_GPU16_24GB: 3
timeout_seconds: 900
model_ids:
- "mlx-community/Llama-3.2-1B-Instruct-4bit"
stages:
- name: "tiny_prompts"
prompt_length: 10
generation_length: 100
time_between_requests: 1.0
iterations: 10
- name: "medium_prompts"
prompt_length: 200
generation_length: 100
time_between_requests: 1.0
iterations: 10
- name: "large_prompts"
prompt_length: 1000
generation_length: 100
time_between_requests: 1.0
iterations: 10
```
## Tips
- **Overlapping Requests**: Set `time_between_requests` < expected completion time to test concurrent request handling
- **Sequential Requests**: Set `time_between_requests` > expected completion time to ensure requests don't overlap
- **Realistic Load**: Model real usage patterns by varying prompt/generation lengths across stages
- **Success Rate**: A 100% success rate indicates the cluster handled the load well; lower rates suggest capacity limits

49
.github/configs/bench_config.yaml vendored Normal file
View File

@@ -0,0 +1,49 @@
# EXO Staged Benchmark Configuration
# This configuration defines a multi-stage load test for EXO clusters
# Hardware configuration - maps runner labels to instance counts
hardware_plan:
M3ULTRA_GPU80_512GB: 4
# Environment variables to set on each node (optional)
environment:
OVERRIDE_MEMORY_MB: 512
# Timeout for instance and runner readiness (seconds)
timeout_seconds: 600
# Multiple instances run concurrently on the cluster
model_ids:
- "mlx-community/Qwen3-0.6B-4bit"
- "mlx-community/Qwen3-0.6B-4bit"
# Stages run sequentially, each with its own characteristics
stages:
# Stage 1: Light load with short prompts
- name: "warmup"
prompt_length: 50 # Number of tokens in prompt
generation_length: 100 # Max tokens to generate
time_between_requests: 5.0 # Seconds between firing requests
iterations: 10 # Number of requests to send in this stage
# Stage 2: Medium load with medium prompts
- name: "medium_load"
prompt_length: 200
generation_length: 150
time_between_requests: 3.0
iterations: 20
# Stage 3: Heavy load with long prompts - requests will overlap
- name: "stress_test"
prompt_length: 500
generation_length: 200
time_between_requests: 1.0 # Fast firing - will definitely overlap
iterations: 30
# Stage 4: Cool down with simple prompts
- name: "cooldown"
prompt_length: 50
generation_length: 50
time_between_requests: 10.0
iterations: 5

125
.github/configs/bench_simple.yaml vendored Normal file
View File

@@ -0,0 +1,125 @@
# Simple single-shot benchmark
# Tests 2 instances concurrently on 2 nodes
# Hardware configuration - maps runner labels to instance counts
hardware_plan:
puffin4: 1
puffin8: 1
# Environment variables to set on each node
environment:
PLACEHOLDER: "placeholder"
# OVERRIDE_MEMORY_MB: 50000
MLX_METAL_FAST_SYNCH: 1
# Timeout for instance and runner readiness (seconds)
timeout_seconds: 1800
# Model instances to run concurrently
model_ids:
# - "mlx-community/DeepSeek-V3.1-8bit"
# - "mlx-community/Kimi-K2-Instruct-4bit"
- "mlx-community/Kimi-K2-Thinking"
# - "mlx-community/Qwen3-235B-A22B-4bit"
# - "mlx-community/Llama-3.3-70B-Instruct-4bit"
# - "mlx-community/Llama-3.3-70B-Instruct-8bit"
# - "mlx-community/Llama-3.2-1B-Instruct-4bit"
# Sharding strategy: "Pipeline" or "Tensor"
sharding: "Tensor"
# Instance type: "MlxRing" or "MlxIbv"
instance_meta: "MlxIbv"
# If true, run requests sequentially (no overlap); if false, fire-and-forget (default: false)
no_overlap: true
# Benchmark stages
# pp: 64, 256, 1024, 2048, 4096, 8192, 16384
# g: 64, 512
stages:
# - name: "simple"
# prompt_length: 512
# generation_length: 10
# time_between_requests: 2.0
# iterations: 5
# - name: "pp64_g64"
# prompt_length: 64
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp64_g64"
# prompt_length: 64
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp64_g512"
# prompt_length: 64
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp256_g64"
# prompt_length: 256
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
- name: "pp256_g64"
prompt_length: 256
generation_length: 64
time_between_requests: 2.0
iterations: 5
# - name: "pp256_g512"
# prompt_length: 256
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp1024_g64"
# prompt_length: 1024
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp1024_g512"
# prompt_length: 1024
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp2048_g64"
# prompt_length: 2048
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp2048_g512"
# prompt_length: 2048
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp4096_g64"
# prompt_length: 4096
# generation_length: 64
# time_between_requests: 2.0
# iterations: 4
# - name: "pp4096_g512"
# prompt_length: 4096
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10
# - name: "pp8192_g64"
# prompt_length: 8192
# generation_length: 64
# time_between_requests: 2.0
# iterations: 5
# - name: "pp8192_g512"
# prompt_length: 8192
# generation_length: 512
# time_between_requests: 2.0
# iterations: 5
# - name: "pp16384_g64"
# prompt_length: 16384
# generation_length: 64
# time_between_requests: 2.0
# iterations: 10
# - name: "pp16384_g512"
# prompt_length: 16384
# generation_length: 512
# time_between_requests: 2.0
# iterations: 10

1399
.github/scripts/bench.py vendored Normal file
View File

File diff suppressed because it is too large Load Diff

70
.github/scripts/build_matrix.py vendored Normal file
View File

@@ -0,0 +1,70 @@
#!/usr/bin/env python3
import json
import os
from typing import NotRequired, TypedDict, cast
import yaml
class MatrixEntry(TypedDict):
label: str
index: int
class MatrixInclude(TypedDict):
label: str
index: int
is_primary: bool
expected_nodes: int
class Config(TypedDict):
hardware_plan: dict[str, int]
timeout_seconds: NotRequired[int]
environment: NotRequired[dict[str, str]]
# Read the config file
config_file: str = os.environ["CONFIG_FILE"]
with open(config_file, "r") as f:
config: Config = cast(Config, yaml.safe_load(f))
# Extract hardware plan from config
plan: dict[str, int] = config["hardware_plan"]
if not plan:
raise ValueError(f"No hardware_plan found in {config_file}")
# Build matrix entries
entries: list[MatrixEntry] = []
for label, count in plan.items():
for idx in range(count):
entries.append({"label": label, "index": idx})
total_nodes: int = len(entries)
matrix: dict[str, list[MatrixInclude]] = {
"include": [
{
"label": e["label"],
"index": e["index"],
"is_primary": (i == 0),
"expected_nodes": total_nodes,
}
for i, e in enumerate(entries)
]
}
# Extract other config values
timeout_seconds: int = config.get("timeout_seconds", 600)
environment: dict[str, str] = config.get("environment", {})
# Output to GitHub Actions
with open(os.environ["GITHUB_OUTPUT"], "a") as f:
f.write(f"matrix={json.dumps(matrix)}\n")
f.write(f"config_file={config_file}\n")
f.write(f"timeout_seconds={timeout_seconds}\n")
f.write(f"environment={json.dumps(environment)}\n")
print(f"Matrix: {json.dumps(matrix)}")
print(f"Config file: {config_file}")
print(f"Timeout: {timeout_seconds}")
print(f"Environment: {json.dumps(environment)}")

156
.github/workflows/BENCH_USAGE.md vendored Normal file
View File

@@ -0,0 +1,156 @@
# Benchmark Workflow Usage
## Overview
The `bench_matrix.yml` workflow enables distributed benchmarking of models across multiple self-hosted macOS runners with different hardware configurations.
## Workflow Inputs
| Input | Description | Default | Required |
|-------|-------------|---------|----------|
| `model_id` | Model ID to benchmark | `mlx-community/Llama-3.2-1B-Instruct-4bit` | Yes |
| `hardware_plan` | JSON mapping of runner labels to counts | `{"M4PRO_GPU16_24GB": 1}` | Yes |
| `prompt` | Benchmark prompt text | `What is the capital of France?` | No |
| `timeout_seconds` | Timeout for instance/runner readiness | `600` | No |
## Hardware Plan Format
The `hardware_plan` input is a JSON object mapping runner labels to the number of machines:
```json
{
"M4PRO_GPU16_24GB": 2,
"M3ULTRA_GPU80_512GB": 1
}
```
This example would:
- Start 2 runners with the `M4PRO_GPU16_24GB` label
- Start 1 runner with the `M3ULTRA_GPU80_512GB` label
- Total of 3 runners coordinating on a single distributed inference instance
## How It Works
1. **Planning Job** (`plan`)
- Runs on `ubuntu-latest`
- Parses the `hardware_plan` JSON
- Generates a dynamic matrix with one entry per runner
- Only the first runner (index 0) is marked as `is_primary`
2. **Benchmark Worker Jobs** (`bench_worker`)
- Each job runs on a self-hosted macOS runner with the specified label
- All runners start EXO in parallel
- The primary runner creates the model instance
- All runners wait for their assigned runner to be ready (Loaded/Running status)
- The primary runner executes the benchmark and prints results
- The primary runner deletes the instance
## Example Usage
### Single Machine Benchmark
```yaml
model_id: mlx-community/Llama-3.2-1B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 1}'
prompt: What is the capital of France?
timeout_seconds: 600
```
### Multi-Machine Distributed Benchmark
```yaml
model_id: mlx-community/Llama-3.2-3B-Instruct-4bit
hardware_plan: '{"M4PRO_GPU16_24GB": 2, "M3ULTRA_GPU80_512GB": 1}'
prompt: Explain quantum computing in simple terms.
timeout_seconds: 900
```
## Benchmark Output
The primary runner outputs a JSON object with benchmark results:
```json
{
"model_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"instance_id": "abc-123-def",
"tokens": 42,
"elapsed_s": 2.451,
"tps": 17.136
}
```
Where:
- `tokens`: Number of chunks/tokens generated
- `elapsed_s`: Total elapsed time in seconds
- `tps`: Tokens per second (tokens / elapsed_s)
## Runner Requirements
Each self-hosted runner must:
- Be labeled with appropriate hardware tags (e.g., `M4PRO_GPU16_24GB`)
- Have the `self-hosted` and `macOS` labels
- Have Nix installed with flakes enabled
- Have network connectivity to other runners in the same job
## Architecture
```
┌─────────────────────────────────────────────────────────────┐
│ GitHub Actions Workflow (bench_matrix.yml) │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌────────────────┐ │
│ │ Plan Job │ │
│ │ (ubuntu) │──┬─► Matrix: [{label, index, primary}] │
│ └────────────────┘ │ │
│ │ │
│ ┌───────────────────▼──────────────────────────────────┐ │
│ │ Bench Worker Jobs (Matrix) │ │
│ ├──────────────────────────────────────────────────────┤ │
│ │ │ │
│ │ Runner 0 (Primary) Runner 1 Runner 2 │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌──────────┐ │ │
│ │ │ Start EXO │ │ Start EXO │ │ Start EXO│ │ │
│ │ │ Create Inst │ │ Wait... │ │ Wait... │ │ │
│ │ │ Wait Ready │ │ Wait Ready │ │ Wait... │ │ │
│ │ │ Run Bench │ │ (idle) │ │ (idle) │ │ │
│ │ │ Print TPS │ │ │ │ │ │ │
│ │ │ Delete Inst │ │ │ │ │ │ │
│ │ └─────────────┘ └─────────────┘ └──────────┘ │ │
│ └───────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
```
## Implementation Details
### `scripts/bench.py`
A standalone Python script that:
- Creates instance (primary only)
- Polls `/state` endpoint until instance and all runners are ready
- Executes chat completion with timing (primary only)
- Parses SSE stream and counts tokens
- Computes TPS metrics
- Cleans up instance (primary only)
### Key Functions
- `wait_for_instance()`: Polls until instance with model_id appears
- `wait_for_runners_ready()`: Polls until expected number of runners reach Loaded/Running status
- `run_benchmark()`: Executes chat completion, measures time, counts tokens
## Troubleshooting
### Instance never becomes ready
- Check EXO logs in the workflow output
- Verify model_id is valid and accessible
- Increase `timeout_seconds`
### Runner mismatch
- Ensure hardware_plan counts match available labeled runners
- Check runner labels match exactly (case-sensitive)
### Network issues
- Verify runners can communicate on the network
- Check firewall rules between runner hosts

305
.github/workflows/bench.yml vendored Normal file
View File

@@ -0,0 +1,305 @@
name: bench
on: [push]
jobs:
plan:
if: contains(github.event.head_commit.message, '/bench')
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.build.outputs.matrix }}
config_file: ${{ steps.build.outputs.config_file }}
timeout_seconds: ${{ steps.build.outputs.timeout_seconds }}
environment: ${{ steps.build.outputs.environment }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Build matrix from config file
id: build
shell: bash
run: |
set -euo pipefail
CONFIG_FILE='.github/configs/bench_simple.yaml'
export CONFIG_FILE
echo "Config file: $CONFIG_FILE"
python3 .github/scripts/build_matrix.py
bench_worker:
needs: plan
strategy:
fail-fast: false
matrix: ${{ fromJSON(needs.plan.outputs.matrix) }}
name: "bench on ${{ matrix.label }} [${{ matrix.index }}]"
runs-on: [self-hosted, macOS, "${{ matrix.label }}"]
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: false
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
git config --local user.name "github-actions bot"
shell: bash
# TODO: this is mega hacky and I'd like a simpler solution.
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix is already available
if command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
# Try sourcing profile scripts to set up environment properly
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Sourcing multi-user nix-daemon profile script"
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
elif [ -f "$HOME/.nix-profile/etc/profile.d/nix.sh" ]; then
echo "Sourcing single-user nix profile script"
source "$HOME/.nix-profile/etc/profile.d/nix.sh"
elif [ -f /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh ]; then
echo "Sourcing per-user nix profile script"
source /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh
elif [ -f /etc/profile.d/nix.sh ]; then
echo "Sourcing system-wide nix profile script"
source /etc/profile.d/nix.sh
# Fallback: manually add nix to PATH if binary exists
elif [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary, manually adding to PATH"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
elif [ -f "$HOME/.nix-profile/bin/nix" ]; then
echo "Found nix binary in user profile, manually adding to PATH"
export PATH="$HOME/.nix-profile/bin:$PATH"
else
echo "Nix not found. Debugging info:"
echo "USER: $USER"
echo "HOME: $HOME"
echo "Current PATH: $PATH"
echo ""
echo "Checking common Nix locations:"
echo " /nix/var/nix/profiles/default/bin/nix:"
ls -la /nix/var/nix/profiles/default/bin/nix 2>/dev/null || echo " Not found"
echo " /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh:"
ls -la /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh 2>/dev/null || echo " Not found"
echo " ~/.nix-profile/etc/profile.d/nix.sh:"
ls -la "$HOME/.nix-profile/etc/profile.d/nix.sh" 2>/dev/null || echo " Not found"
echo " /nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh:"
ls -la "/nix/var/nix/profiles/per-user/$USER/profile/etc/profile.d/nix.sh" 2>/dev/null || echo " Not found"
echo ""
echo "/nix directory structure:"
ls -la /nix 2>/dev/null || echo " /nix directory not found"
echo ""
echo "/nix/var:"
ls -la /nix/var 2>/dev/null || echo " /nix/var not found"
echo ""
echo "/nix/store:"
ls -la /nix/store 2>/dev/null | head -20 || echo " /nix/store not found"
echo ""
echo "GitHub Actions runner is running as user '$USER'."
echo "If Nix is installed for a different user, either:"
echo " 1. Install Nix for user '$USER' (multi-user install recommended)"
echo " 2. Configure the runner service to run as the user with Nix installed"
echo " 3. Ensure Nix is installed system-wide with proper daemon setup"
exit 1
fi
# Verify nix is available and persist to GITHUB_ENV
if command -v nix >/dev/null 2>&1; then
echo "✓ Nix is available"
nix --version
echo "PATH=$PATH" >> $GITHUB_ENV
if [ -n "$NIX_PATH" ]; then
echo "NIX_PATH=$NIX_PATH" >> $GITHUB_ENV
fi
else
echo "ERROR: Failed to set up Nix"
echo "PATH after setup attempt: $PATH"
exit 1
fi
shell: bash
- name: Setup EXO_HOME and API_PORT
run: |
EXO_HOME=$(mktemp -d -t exo-e2e-XXXXXXXX)
API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
EXO_MODELS_DIR="$HOME/.exo/models"
EXO_LIBP2P_NAMESPACE="bench-${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}"
echo "EXO_HOME=$EXO_HOME" >> "$GITHUB_ENV"
echo "API_PORT=$API_PORT" >> "$GITHUB_ENV"
echo "EXO_MODELS_DIR=$EXO_MODELS_DIR" >> "$GITHUB_ENV"
echo "EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE" >> "$GITHUB_ENV"
echo "Created EXO_HOME: $EXO_HOME"
echo "Generated API_PORT: $API_PORT"
echo "Using models from: $EXO_MODELS_DIR"
echo "Using libp2p namespace: $EXO_LIBP2P_NAMESPACE"
shell: bash
- name: Configure local MLX if available
run: |
echo "=== DEBUG: Checking for local MLX configuration ==="
MODIFIED=false
echo "Checking for /Users/Shared/mlx directory..."
if [ -d "/Users/Shared/mlx" ]; then
echo "✓ Found /Users/Shared/mlx"
ls -la /Users/Shared/mlx | head -5
echo "Enabling local mlx path in pyproject.toml"
sed -i.bak 's|^# mlx = { path = "/Users/Shared/mlx", editable=true }$|mlx = { path = "/Users/Shared/mlx", editable=true }|' pyproject.toml
MODIFIED=true
else
echo "✗ /Users/Shared/mlx not found, will use PyPI version"
fi
echo "Checking for /Users/Shared/mlx-lm directory..."
if [ -d "/Users/Shared/mlx-lm" ]; then
echo "✓ Found /Users/Shared/mlx-lm"
ls -la /Users/Shared/mlx-lm | head -5
echo "Enabling local mlx-lm path in pyproject.toml"
sed -i.bak 's|^# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }$|mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }|' pyproject.toml
MODIFIED=true
else
echo "✗ /Users/Shared/mlx-lm not found, will use PyPI version"
fi
if [ "$MODIFIED" = true ]; then
echo "=== Modified pyproject.toml [tool.uv.sources] section: ==="
sed -n '/\[tool\.uv\.sources\]/,/^\[/{/^\[tool\.uv\.sources\]/p; /^\[/!p;}' pyproject.toml
echo "=== Regenerating uv.lock with local MLX paths... ==="
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command uv lock --upgrade-package mlx --upgrade-package mlx-lm
echo "✓ Lock file regenerated"
else
echo "⚠ No local MLX directories found, using PyPI packages"
fi
echo "=== DEBUG: Local MLX configuration complete ==="
shell: bash
- name: Sync dependencies
run: |
if [ -d "/Users/Shared/test" ]; then
pushd /Users/Shared/test
uv sync --reinstall
popd
fi
echo "Running just sync to ensure clean dependencies..."
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command just sync
shell: bash
- name: Start EXO and run bench script
shell: bash
env:
IS_PRIMARY: ${{ matrix.is_primary }}
EXPECTED_NODES: ${{ matrix.expected_nodes }}
HARDWARE_LABEL: ${{ matrix.label }}
CONFIG_FILE: ${{ needs.plan.outputs.config_file }}
TIMEOUT_SECONDS: ${{ needs.plan.outputs.timeout_seconds }}
ENVIRONMENT_JSON: ${{ needs.plan.outputs.environment }}
run: |
set -euo pipefail
# Parse environment variables from config
ENV_VARS=""
if [ -n "$ENVIRONMENT_JSON" ] && [ "$ENVIRONMENT_JSON" != "{}" ]; then
ENV_VARS=$(echo "$ENVIRONMENT_JSON" | python3 -c "import sys, json; env = json.load(sys.stdin); print(' '.join([f'{k}={v}' for k, v in env.items()]))")
fi
echo "Starting EXO with API_PORT=${API_PORT} EXO_HOME=${EXO_HOME} EXO_LIBP2P_NAMESPACE=${EXO_LIBP2P_NAMESPACE}"
echo "Environment variables from config: $ENV_VARS"
LOG_FILE=/tmp/exo.log
: > "$LOG_FILE"
MASTER_FLAG=""
if [ "$IS_PRIMARY" = "true" ]; then
MASTER_FLAG="-m"
fi
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
"EXO_HOME=$EXO_HOME EXO_MODELS_DIR=$EXO_MODELS_DIR EXO_LIBP2P_NAMESPACE=$EXO_LIBP2P_NAMESPACE $ENV_VARS PYTHONUNBUFFERED=1 PYTHONDEBUG=1 PYTHONPATH=. uv run exo $MASTER_FLAG --api-port $API_PORT" \
>> "$LOG_FILE" 2>&1 &
EXO_PID=$!
echo "Started EXO in background with PID: $EXO_PID"
echo "Log file: $LOG_FILE"
cleanup() {
echo '=== EXO log (tail) ==='
tail -n 300 "$LOG_FILE" || true
if ps -p "$EXO_PID" >/dev/null 2>&1; then
echo "Killing EXO (PID $EXO_PID)"
kill "$EXO_PID" || true
fi
}
trap cleanup EXIT
for i in $(seq 1 60); do
if curl -s "http://localhost:${API_PORT}/state" >/dev/null 2>&1; then
echo "EXO API ready"
break
fi
if ! ps -p "$EXO_PID" >/dev/null 2>&1; then
echo "EXO terminated early"; sed -n '1,200p' "$LOG_FILE" || true; exit 1
fi
sleep 1
done
RESULTS_FILE="/tmp/bench_results_${GITHUB_RUN_ID}_${GITHUB_RUN_ATTEMPT}_$(date +%s).json"
echo "Results will be saved to: $RESULTS_FILE"
echo "RESULTS_FILE=$RESULTS_FILE" >> "$GITHUB_ENV"
echo "Running bench script with config: $CONFIG_FILE, timeout: $TIMEOUT_SECONDS"
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop --command bash -c \
"PYTHONUNBUFFERED=1 uv run --no-project --with pyyaml --with pydantic python .github/scripts/bench.py \
--api-port $API_PORT \
--config $CONFIG_FILE \
--expected-nodes ${EXPECTED_NODES} \
--is-primary ${IS_PRIMARY} \
--timeout-seconds ${TIMEOUT_SECONDS} \
--output $RESULTS_FILE \
--git-commit ${GITHUB_SHA} \
--hardware-labels ${HARDWARE_LABEL}"
- name: Install AWS CLI
if: always() && env.RESULTS_FILE && matrix.is_primary
run: |
if ! command -v aws &> /dev/null; then
echo "AWS CLI not found, installing..."
brew install awscli
else
echo "AWS CLI already installed"
fi
shell: bash
- name: Upload results to S3
if: always() && env.RESULTS_FILE && matrix.is_primary
env:
AWS_ACCESS_KEY_ID: ${{ secrets.S3_BENCHMARKS_AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
AWS_DEFAULT_REGION: us-east-1
run: |
echo "Checking for results file: $RESULTS_FILE"
echo "Is primary: ${{ matrix.is_primary }}"
if [ -f "$RESULTS_FILE" ]; then
TIMESTAMP=$(date -u +%Y/%m/%d/%H%M%S)
S3_KEY="bench/${TIMESTAMP}_${GITHUB_SHA:0:8}_${GITHUB_RUN_ID}.json"
echo "Uploading results to s3://exo-benchmark-results/$S3_KEY"
aws s3 cp "$RESULTS_FILE" "s3://exo-benchmark-results/$S3_KEY" \
--content-type application/json \
--metadata "commit=${GITHUB_SHA},run_id=${GITHUB_RUN_ID},branch=${GITHUB_REF_NAME}"
echo "Results uploaded successfully"
echo "View at: https://exo-benchmark-results.s3.amazonaws.com/$S3_KEY"
else
echo "Results file not found at: $RESULTS_FILE"
echo "Skipping upload"
fi
shell: bash
- name: Cleanup EXO_HOME
run: |
echo "Cleaning up EXO_HOME: $EXO_HOME"
rm -rf "$EXO_HOME"
shell: bash
if: always()

View File

@@ -1,18 +1,6 @@
name: Build EXO macOS DMG
# Release workflow:
# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
# 3. This workflow builds, signs, and notarizes the DMG
# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
# 5. DMG and appcast.xml are uploaded to S3
# 6. The draft GitHub Release is published with the DMG attached
#
# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
# If no draft exists, a release is auto-created with generated notes.
on:
workflow_dispatch:
push:
tags:
- "v*"
@@ -22,17 +10,14 @@ on:
jobs:
build-macos-app:
runs-on: "macos-26"
permissions:
contents: write
env:
SPARKLE_VERSION: 2.9.0-beta.1
SPARKLE_VERSION: 2.8.1
SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
SPARKLE_ED25519_PRIVATE: ${{ secrets.SPARKLE_ED25519_PRIVATE }}
SPARKLE_S3_BUCKET: ${{ secrets.SPARKLE_S3_BUCKET }}
SPARKLE_S3_PREFIX: ${{ secrets.SPARKLE_S3_PREFIX }}
EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT: ${{ secrets.EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT }}
AWS_REGION: ${{ secrets.AWS_REGION }}
EXO_BUILD_NUMBER: ${{ github.run_number }}
EXO_LIBP2P_NAMESPACE: ${{ github.ref_name }}
@@ -49,7 +34,7 @@ jobs:
- name: Derive release version from tag
run: |
if [[ "$GITHUB_REF_NAME" == "test-app" || "${{ github.event_name }}" == "workflow_dispatch" ]]; then
if [[ "$GITHUB_REF_NAME" == "test-app" ]]; then
VERSION="0.0.0-alpha.0"
echo "IS_ALPHA=true" >> $GITHUB_ENV
else
@@ -62,32 +47,6 @@ jobs:
fi
echo "RELEASE_VERSION=$VERSION" >> $GITHUB_ENV
- name: Compute build version from semver
run: |
VERSION="$RELEASE_VERSION"
# Extract major.minor.patch (strip prerelease suffix)
BASE_VERSION="${VERSION%%-*}"
MAJOR=$(echo "$BASE_VERSION" | cut -d. -f1)
MINOR=$(echo "$BASE_VERSION" | cut -d. -f2)
PATCH=$(echo "$BASE_VERSION" | cut -d. -f3)
# Extract prerelease number (e.g., "alpha.2" -> 2, or 999 for releases)
if [[ "$VERSION" == *-* ]]; then
PRERELEASE_PART="${VERSION#*-}"
PRERELEASE_NUM="${PRERELEASE_PART##*.}"
# Default to 0 if not a number
if ! [[ "$PRERELEASE_NUM" =~ ^[0-9]+$ ]]; then
PRERELEASE_NUM=0
fi
else
PRERELEASE_NUM=999
fi
# Compute: PRERELEASE + (1000 * PATCH) + (1_000_000 * MINOR) + (1_000_000_000 * MAJOR)
BUILD_VERSION=$((PRERELEASE_NUM + 1000 * PATCH + 1000000 * MINOR + 1000000000 * MAJOR))
echo "EXO_BUILD_VERSION=$BUILD_VERSION" >> $GITHUB_ENV
echo "Computed build version: $BUILD_VERSION from $VERSION"
- name: Ensure tag commit is on main
if: github.ref_type == 'tag'
run: |
@@ -100,52 +59,6 @@ jobs:
exit 1
fi
- name: Fetch and validate release notes
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Find draft release by name using gh release list (more reliable with default token)
echo "Looking for draft release named '$GITHUB_REF_NAME'..."
DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
if [[ -z "$DRAFT_EXISTS" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
echo "Please create a draft release with release notes before pushing the tag."
exit 1
fi
# Fetch full release details via API to get body and ID
echo "Found draft release, fetching details..."
RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
# Extract release notes
NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "Draft release has no notes (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: Draft release exists but has no release notes"
echo "Please add release notes to the draft release before pushing the tag."
exit 1
fi
# Save release ID for later publishing
RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
echo "$NOTES" > /tmp/release_notes.md
echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
# ============================================================
# Install dependencies
# ============================================================
@@ -172,22 +85,11 @@ jobs:
uv python install
uv sync --locked
- name: Install Nix
uses: cachix/install-nix-action@v31
with:
nix_path: nixpkgs=channel:nixos-unstable
- name: Configure Cachix
uses: cachix/cachix-action@v14
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build dashboard
run: |
DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)
mkdir -p dashboard/build
cp -r "$DASHBOARD_OUT"/* dashboard/build/
cd dashboard
npm ci
npm run build
- name: Install Sparkle CLI
run: |
@@ -260,12 +162,11 @@ jobs:
-configuration Release \
-derivedDataPath build \
MARKETING_VERSION="$RELEASE_VERSION" \
CURRENT_PROJECT_VERSION="$EXO_BUILD_VERSION" \
CURRENT_PROJECT_VERSION="$EXO_BUILD_NUMBER" \
EXO_BUILD_TAG="$RELEASE_VERSION" \
EXO_BUILD_COMMIT="$GITHUB_SHA" \
SPARKLE_FEED_URL="$SPARKLE_FEED_URL" \
SPARKLE_ED25519_PUBLIC="$SPARKLE_ED25519_PUBLIC" \
EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT="$EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT" \
CODE_SIGNING_IDENTITY="$SIGNING_IDENTITY" \
CODE_SIGN_INJECT_BASE_ENTITLEMENTS=YES
mkdir -p ../../output
@@ -363,28 +264,6 @@ jobs:
$CHANNEL_FLAG \
.
- name: Inject release notes into appcast
if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
env:
RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
run: |
# Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
export NOTES=$(cat "$RELEASE_NOTES_FILE")
# Insert description after the enclosure tag for this version
awk '
/<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
print
print " <description sparkle:format=\"markdown\"><![CDATA["
print ENVIRON["NOTES"]
print " ]]></description>"
next
}
{ print }
' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
echo "Injected markdown release notes for version $RELEASE_VERSION"
# ============================================================
# Upload artifacts
# ============================================================
@@ -415,28 +294,5 @@ jobs:
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}${DMG_NAME}"
if [[ "$IS_ALPHA" != "true" ]]; then
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
fi
- name: Publish GitHub Release
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
# Update the draft release with the tag and upload DMG
gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
-f tag_name="$GITHUB_REF_NAME" \
-F draft=false
gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
echo "Published release $GITHUB_REF_NAME with DMG attached"
else
# Alpha without draft release - create one with auto-generated notes
gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
--title "$GITHUB_REF_NAME" \
--generate-notes \
--prerelease
echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
fi
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache

View File

@@ -20,12 +20,6 @@ jobs:
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
name: Configure Cachix
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
@@ -94,19 +88,9 @@ jobs:
- uses: ./.github/actions/typecheck
nix:
name: Build and check (${{ matrix.system }})
runs-on: ${{ matrix.runner }}
strategy:
fail-fast: false
matrix:
include:
- runner: macos-26
system: aarch64-darwin
- runner: ubuntu-latest
system: x86_64-linux
- runner: ubuntu-24.04-arm
system: aarch64-linux
nix-flake-check:
name: Check Nix flake
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -117,20 +101,83 @@ jobs:
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
name: Configure Cachix
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build all Nix outputs
run: |
nix flake show --json | jq -r '
[
(.packages."${{ matrix.system }}" // {} | keys[] | ".#packages.${{ matrix.system }}.\(.)"),
(.devShells."${{ matrix.system }}" // {} | keys[] | ".#devShells.${{ matrix.system }}.\(.)")
] | .[]
' | xargs nix build
- name: Run nix flake check
run: nix flake check
run: |
nix flake check
shell: bash
# ci:
# needs: typecheck
# runs-on: ubuntu-latest
# permissions:
# contents: read
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# steps:
# - name: Checkout repository
# uses: actions/checkout@v4
# with:
# fetch-depth: 0
# token: ${{ secrets.GITHUB_TOKEN }}
# lfs: true
#
# - name: Configure git user
# run: |
# git config --local user.email "github-actions@users.noreply.github.com"
# git config --local user.name "github-actions bot"
# shell: bash
#
# - name: Pull LFS files
# run: |
# echo "Pulling Git LFS files..."
# git lfs pull
# shell: bash
#
# - name: Setup EXO_HOME and API_PORT
# run: |
# EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX)
# # Generate random port (macOS compatible method)
# API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
# echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
# echo "API_PORT=$API_PORT" >> $GITHUB_ENV
# echo "Created EXO_HOME: $EXO_HOME"
# echo "Generated API_PORT: $API_PORT"
# shell: bash
#
# - name: Setup Nix Environment
# run: |
# echo "Checking for nix installation..."
#
# # Check if nix binary exists directly
# if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
# echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
# export PATH="/nix/var/nix/profiles/default/bin:$PATH"
# echo "PATH=$PATH" >> $GITHUB_ENV
# nix --version
# elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
# echo "Found nix profile script, sourcing..."
# source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
# nix --version
# elif command -v nix >/dev/null 2>&1; then
# echo "Nix already in PATH"
# nix --version
# else
# echo "Nix not found. Debugging info:"
# echo "Contents of /nix/var/nix/profiles/default/:"
# ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
# echo "Contents of /nix/var/nix/profiles/default/bin/:"
# ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
# exit 1
# fi
# shell: bash
#
# - uses: ./.github/actions/lint-check
#
# - uses: ./.github/actions/unit-test
#
# - name: Cleanup EXO_HOME
# run: |
# echo "Cleaning up EXO_HOME: $EXO_HOME"
# rm -rf "$EXO_HOME"
# shell: bash
# if: always()

1
.gitignore vendored
View File

@@ -16,7 +16,6 @@ digest.txt
*.xcuserdatad/
**/.DS_Store
app/EXO/build/
dist/
# rust

View File

@@ -1,156 +0,0 @@
"""Type stubs for mlx_lm.models.deepseek_v3"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: Optional[int]
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_theta: float
rope_scaling: Optional[Dict[str, Any]]
attention_bias: bool
class DeepseekV3Attention(nn.Module):
config: ModelArgs
hidden_size: int
num_heads: int
max_position_embeddings: int
rope_theta: float
q_lora_rank: Optional[int]
qk_rope_head_dim: int
kv_lora_rank: int
v_head_dim: int
qk_nope_head_dim: int
q_head_dim: int
scale: float
q_proj: nn.Linear
q_a_proj: nn.Linear
q_a_layernorm: nn.RMSNorm
q_b_proj: nn.Linear
kv_a_proj_with_mqa: nn.Linear
kv_a_layernorm: nn.RMSNorm
kv_b_proj: nn.Linear
o_proj: nn.Linear
rope: Any
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class DeepseekV3MLP(nn.Module):
config: ModelArgs
hidden_size: int
intermediate_size: int
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
def __init__(
self,
config: ModelArgs,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class MoEGate(nn.Module):
config: ModelArgs
top_k: int
norm_topk_prob: bool
n_routed_experts: Optional[int]
routed_scaling_factor: float
n_group: int
topk_group: int
weight: mx.array
e_score_correction_bias: mx.array
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
class DeepseekV3MoE(nn.Module):
config: ModelArgs
num_experts_per_tok: int
switch_mlp: SwitchGLU
gate: MoEGate
shared_experts: DeepseekV3MLP
sharding_group: Optional[mx.distributed.Group]
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class DeepseekV3DecoderLayer(nn.Module):
self_attn: DeepseekV3Attention
mlp: DeepseekV3MLP | DeepseekV3MoE
input_layernorm: nn.RMSNorm
post_attention_layernorm: nn.RMSNorm
def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class DeepseekV3Model(nn.Module):
vocab_size: int
embed_tokens: nn.Embedding
layers: list[DeepseekV3DecoderLayer]
norm: nn.RMSNorm
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
class Model(nn.Module):
model_type: str
model: DeepseekV3Model
lm_head: nn.Linear
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
@property
def layers(self) -> list[DeepseekV3DecoderLayer]: ...

View File

@@ -57,11 +57,6 @@ class SwiGLU(nn.Module):
def __call__(self, x, gate): ...
class SwitchGLU(nn.Module):
gate_proj: SwitchLinear
up_proj: SwitchLinear
down_proj: SwitchLinear
activation: SwiGLU
def __init__(
self,
input_dims: int,

View File

@@ -4,7 +4,6 @@ This type stub file was generated by pyright.
from functools import partial
from pathlib import Path
from typing import Any
from transformers import PreTrainedTokenizerFast
@@ -104,55 +103,37 @@ class TokenizerWrapper:
Accessing any attribute other than the ``detokenizer`` is forwarded to the
huggingface tokenizer.
"""
def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
def add_eos_token(self, token: str): # -> None:
...
@property
def has_thinking(self): # -> bool:
...
@property
def think_start(self): # -> str | None:
...
@property
def think_end(self): # -> str | None:
...
@property
def has_tool_calling(self): # -> bool:
...
@property
def tool_call_start(self): # -> str | None:
...
@property
def tool_call_end(self): # -> str | None:
...
@property
def detokenizer(self): # -> NaiveStreamingDetokenizer:
"""
Get a stateful streaming detokenizer.
"""
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int
all_special_tokens: list[str]
def __init__(
self,
tokenizer: Any,
detokenizer_class: Any = ...,
eos_token_ids: list[int] | None = ...,
chat_template: Any = ...,
tool_parser: Any = ...,
tool_call_start: str | None = ...,
tool_call_end: str | None = ...,
) -> None: ...
def encode(self, text: str, **kwargs: Any) -> list[int]: ...
def decode(self, token_ids: list[int], **kwargs: Any) -> str: ...
def apply_chat_template(
self,
messages: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = False,
tools: Any = None,
**kwargs: Any,
) -> str: ...
def get_vocab(self) -> dict[str, int]: ...
def add_eos_token(self, token: str) -> None: ...
@property
def has_thinking(self) -> bool: ...
@property
def think_start(self) -> str | None: ...
@property
def think_end(self) -> str | None: ...
@property
def has_tool_calling(self) -> bool: ...
@property
def tool_call_start(self) -> str | None: ...
@property
def tool_call_end(self) -> str | None: ...
@property
def detokenizer(self) -> NaiveStreamingDetokenizer:
"""Get a stateful streaming detokenizer."""
def __getattr__(self, attr: str) -> Any: ...
def __setattr__(self, attr: str, value: Any) -> None: ...
def __getattr__(self, attr): # -> set[Any] | Any:
...
def __setattr__(self, attr, value): # -> None:
...
class NewlineTokenizer(PreTrainedTokenizerFast):
"""A tokenizer that replaces newlines with <n> and <n> with new line."""
@@ -165,11 +146,18 @@ class NewlineTokenizer(PreTrainedTokenizerFast):
def batch_decode(self, *args, **kwargs): # -> list[str]:
...
def load(
def load_tokenizer(
model_path: Path,
tokenizer_config_extra: dict[str, Any] | None = None,
eos_token_ids: list[int] | int | None = None,
) -> TokenizerWrapper:
tokenizer_config_extra=...,
return_tokenizer=...,
eos_token_ids=...,
) -> (
TokenizerWrapper
| type[SPMStreamingDetokenizer]
| partial[SPMStreamingDetokenizer]
| type[BPEStreamingDetokenizer]
| type[NaiveStreamingDetokenizer]
):
"""Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use.
@@ -177,7 +165,4 @@ def load(
a Hugging Face repo ID.
"""
# Alias for backward compatibility
load_tokenizer = load
def no_bos_or_eos(sequence: list[int], bos: int, eos: int) -> list[int]: ...
def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...

View File

@@ -1,3 +0,0 @@
{
"useTabs": true
}

View File

@@ -1,6 +0,0 @@
{
"version": 1,
"indentation": {
"spaces": 4
}
}

121
AGENTS.md
View File

@@ -1,121 +0,0 @@
# AGENTS.md
This file provides guidance to AI coding agents when working with code in this repository.
## Project Overview
exo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and libp2p for peer-to-peer networking.
## Build & Run Commands
```bash
# Build the dashboard (required before running exo)
cd dashboard && npm install && npm run build && cd ..
# Run exo (starts both master and worker with API at http://localhost:52415)
uv run exo
# Run with verbose logging
uv run exo -v # or -vv for more verbose
# Run tests (excludes slow tests by default)
uv run pytest
# Run all tests including slow tests
uv run pytest -m ""
# Run a specific test file
uv run pytest src/exo/shared/tests/test_election.py
# Run a specific test function
uv run pytest src/exo/shared/tests/test_election.py::test_function_name
# Type checking (strict mode)
uv run basedpyright
# Linting
uv run ruff check
# Format code (using nix)
nix fmt
```
## Pre-Commit Checks (REQUIRED)
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
```bash
# 1. Type checking - MUST pass with 0 errors
uv run basedpyright
# 2. Linting - MUST pass
uv run ruff check
# 3. Formatting - MUST be applied
nix fmt
# 4. Tests - MUST pass
uv run pytest
```
Run all checks in sequence:
```bash
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
```
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.
## Architecture
### Node Composition
A single exo `Node` (src/exo/main.py) runs multiple components:
- **Router**: libp2p-based pub/sub messaging via Rust bindings (exo_pyo3_bindings)
- **Worker**: Handles inference tasks, downloads models, manages runner processes
- **Master**: Coordinates cluster state, places model instances across nodes
- **Election**: Bully algorithm for master election
- **API**: FastAPI server for OpenAI-compatible chat completions
### Message Flow
Components communicate via typed pub/sub topics (src/exo/routing/topics.py):
- `GLOBAL_EVENTS`: Master broadcasts indexed events to all workers
- `LOCAL_EVENTS`: Workers send events to master for indexing
- `COMMANDS`: Workers/API send commands to master
- `ELECTION_MESSAGES`: Election protocol messages
- `CONNECTION_MESSAGES`: libp2p connection updates
### Event Sourcing
The system uses event sourcing for state management:
- `State` (src/exo/shared/types/state.py): Immutable state object
- `apply()` (src/exo/shared/apply.py): Pure function that applies events to state
- Master indexes events and broadcasts; workers apply indexed events
### Key Type Hierarchy
- `src/exo/shared/types/`: Pydantic models for all shared types
- `events.py`: Event types (discriminated union)
- `commands.py`: Command types
- `tasks.py`: Task types for worker execution
- `state.py`: Cluster state model
### Rust Components
Rust code in `rust/` provides:
- `networking`: libp2p networking (gossipsub, peer discovery)
- `exo_pyo3_bindings`: PyO3 bindings exposing Rust to Python
- `system_custodian`: System-level operations
### Dashboard
Svelte 5 + TypeScript frontend in `dashboard/`. Build output goes to `dashboard/build/` and is served by the API.
## Code Style Requirements
From .cursorrules:
- Strict, exhaustive typing - never bypass the type-checker
- Use `Literal[...]` for enum-like sets, `typing.NewType` for primitives
- Pydantic models with `frozen=True` and `strict=True`
- Pure functions with injectable effect handlers for side-effects
- Descriptive names - no abbreviations or 3-letter acronyms
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.

View File

@@ -1 +0,0 @@
AGENTS.md

19
Cargo.lock generated
View File

@@ -4340,6 +4340,25 @@ dependencies = [
"libc",
]
[[package]]
name = "system_custodian"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
]
[[package]]
name = "tagptr"
version = "0.2.0"

View File

@@ -3,6 +3,7 @@ resolver = "3"
members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/system_custodian",
"rust/util",
]
@@ -24,6 +25,7 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
system_custodian = { path = "rust/system_custodian" }
util = { path = "rust/util" }
# Proc-macro authoring tools

View File

@@ -1,41 +0,0 @@
# Missed things
[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py
[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank=n-1. It should be rank!=0 then rank=0 (this matches the auto_parallel implementation. NOTE: we use a different convention to mlx-lm, our terminal rank is rank=n-1 whereas mlx-lm is rank=0 hence i can see why this was changed wrongly).
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[] logger.warning("You have likely selected ibv for a single node instance; falling back to MlxRing") was changed to debug. That will spam this warning since it happens every time we query instance previews.
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[X] We put cache limit back in utils_mlx.py.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).

119
README.md
View File

@@ -8,7 +8,7 @@
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
<p align="center">
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
<a href="https://discord.gg/72NsF6ux" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
<a href="https://x.com/exolabs" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/twitter/follow/exolabs?style=social" alt="X"></a>
<a href="https://www.apache.org/licenses/LICENSE-2.0.html" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/License-Apache2.0-blue.svg" alt="License: Apache-2.0"></a>
</p>
@@ -27,22 +27,13 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
## Dashboard
exo includes a built-in dashboard for managing your cluster and chatting with models.
<p align="center">
<img src="docs/imgs/dashboard-cluster-view.png" alt="exo dashboard - cluster view showing 4 x M3 Ultra Mac Studio with DeepSeek v3.1 and Kimi-K2-Thinking loaded" width="80%" />
</p>
<p align="center"><em>4 × 512GB M3 Ultra Mac Studio running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)</em></p>
## Benchmarks
<details>
<summary>Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-1-qwen3-235b.jpeg" alt="Benchmark - Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt 5</a>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
@@ -50,7 +41,7 @@ exo includes a built-in dashboard for managing your cluster and chatting with mo
<summary>DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-2-deepseek-3.1-671b.jpeg" alt="Benchmark - DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt 5</a>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
@@ -58,7 +49,7 @@ exo includes a built-in dashboard for managing your cluster and chatting with mo
<summary>Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-3-kimi-k2-thinking.jpeg" alt="Benchmark - Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt 5</a>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
@@ -163,24 +154,6 @@ This starts the exo dashboard and API at http://localhost:52415/
**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.
**Configuration Options:**
- `--no-worker`: Run exo without the worker component. Useful for coordinator-only nodes that handle networking and orchestration but don't execute inference tasks. This is helpful for machines without sufficient GPU resources but with good network connectivity.
```bash
uv run exo --no-worker
```
**File Locations (Linux):**
exo follows the [XDG Base Directory Specification](https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html) on Linux:
- **Configuration files**: `~/.config/exo/` (or `$XDG_CONFIG_HOME/exo/`)
- **Data files**: `~/.local/share/exo/` (or `$XDG_DATA_HOME/exo/`)
- **Cache files**: `~/.cache/exo/` (or `$XDG_CACHE_HOME/exo/`)
You can override these locations by setting the corresponding XDG environment variables.
### macOS App
exo ships a macOS app that runs in the background on your Mac.
@@ -193,37 +166,6 @@ Download the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-
The app will ask for permission to modify system settings and install a new Network profile. Improvements to this are being worked on.
**Custom Namespace for Cluster Isolation:**
The macOS app includes a custom namespace feature that allows you to isolate your exo cluster from others on the same network. This is configured through the `EXO_LIBP2P_NAMESPACE` setting:
- **Use cases**:
- Running multiple separate exo clusters on the same network
- Isolating development/testing clusters from production clusters
- Preventing accidental cluster joining
- **Configuration**: Access this setting in the app's Advanced settings (or set the `EXO_LIBP2P_NAMESPACE` environment variable when running from source)
The namespace is logged on startup for debugging purposes.
#### Uninstalling the macOS App
The recommended way to uninstall is through the app itself: click the menu bar icon → Advanced → Uninstall. This cleanly removes all system components.
If you've already deleted the app, you can run the standalone uninstaller script:
```bash
sudo ./app/EXO/uninstall-exo.sh
```
This removes:
- Network setup LaunchDaemon
- Network configuration script
- Log files
- The "exo" network location
**Note:** You'll need to manually remove EXO from Login Items in System Settings → General → Login Items.
---
### Enabling RDMA on macOS
@@ -345,56 +287,7 @@ curl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID
- List all models: `curl http://localhost:52415/models`
- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`
For further details, see:
- API basic documentation in [docs/api.md](docs/api.md).
- API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
---
## Benchmarking
The `exo-bench` tool measures model prefill and token generation speed across different placement configurations. This helps you optimize model performance and validate improvements.
**Prerequisites:**
- Nodes should be running with `uv run exo` before benchmarking
- The tool uses the `/bench/chat/completions` endpoint
**Basic usage:**
```bash
uv run bench/exo_bench.py \
--model llama-3.2-1b \
--pp 128,256,512 \
--tg 128,256
```
**Key parameters:**
- `--model`: Model to benchmark (short ID or HuggingFace ID)
- `--pp`: Prompt size hints (comma-separated integers)
- `--tg`: Generation lengths (comma-separated integers)
- `--max-nodes`: Limit placements to N nodes (default: 4)
- `--instance-meta`: Filter by `ring`, `jaccl`, or `both` (default: both)
- `--sharding`: Filter by `pipeline`, `tensor`, or `both` (default: both)
- `--repeat`: Number of repetitions per configuration (default: 1)
- `--warmup`: Warmup runs per placement (default: 0)
- `--json-out`: Output file for results (default: bench/results.json)
**Example with filters:**
```bash
uv run bench/exo_bench.py \
--model llama-3.2-1b \
--pp 128,512 \
--tg 128 \
--max-nodes 2 \
--sharding tensor \
--repeat 3 \
--json-out my-results.json
```
The tool outputs performance metrics including prompt tokens per second (prompt_tps), generation tokens per second (generation_tps), and peak memory usage for each configuration.
For further details, see API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
---
@@ -406,4 +299,4 @@ On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working
## Contributing
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.

View File

@@ -19,7 +19,6 @@
25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Validate RDMA connections with ibv_devinfo in the info gatherer
Potential refactors:

View File

@@ -585,7 +585,7 @@
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.9.0-beta.1;
minimumVersion = 2.8.1;
};
};
/* End XCRemoteSwiftPackageReference section */

View File

@@ -6,8 +6,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/sparkle-project/Sparkle.git",
"state" : {
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
"version" : "2.9.0-beta.1"
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
"version" : "2.8.1"
}
}
],

View File

@@ -12,25 +12,20 @@ struct ContentView: View {
@EnvironmentObject private var controller: ExoProcessController
@EnvironmentObject private var stateService: ClusterStateService
@EnvironmentObject private var networkStatusService: NetworkStatusService
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
@EnvironmentObject private var updater: SparkleUpdater
@State private var focusedNode: NodeViewModel?
@State private var deletingInstanceIDs: Set<String> = []
@State private var showAllNodes = false
@State private var showAllInstances = false
@State private var showAdvanced = false
@State private var showDebugInfo = false
@State private var bugReportInFlight = false
@State private var bugReportMessage: String?
@State private var uninstallInProgress = false
@State private var showAdvancedOptions = false
@State private var pendingNamespace: String = ""
var body: some View {
VStack(alignment: .leading, spacing: 12) {
statusSection
if shouldShowLocalNetworkWarning {
localNetworkWarningBanner
}
if shouldShowClusterDetails {
Divider()
overviewSection
@@ -45,7 +40,6 @@ struct ContentView: View {
}
.animation(.easeInOut(duration: 0.3), value: shouldShowClusterDetails)
.animation(.easeInOut(duration: 0.3), value: shouldShowInstances)
.animation(.easeInOut(duration: 0.3), value: shouldShowLocalNetworkWarning)
.padding()
.frame(width: 340)
.onAppear {
@@ -55,67 +49,9 @@ struct ContentView: View {
}
}
private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}
return false
}
private var localNetworkWarningBanner: some View {
VStack(alignment: .leading, spacing: 6) {
HStack(spacing: 6) {
Image(systemName: "exclamationmark.triangle.fill")
.foregroundColor(.orange)
Text("Local Network Access Issue")
.font(.caption)
.fontWeight(.semibold)
}
Text(
"Device discovery won't work. To fix:\n1. Quit EXO\n2. Open System Settings → Privacy & Security → Local Network\n3. Toggle EXO off, then back on\n4. Relaunch EXO"
)
.font(.caption2)
.foregroundColor(.secondary)
.fixedSize(horizontal: false, vertical: true)
Button {
openLocalNetworkSettings()
} label: {
Text("Open Settings")
.font(.caption2)
}
.buttonStyle(.bordered)
.controlSize(.small)
}
.padding(8)
.background(
RoundedRectangle(cornerRadius: 8)
.fill(Color.orange.opacity(0.1))
)
.overlay(
RoundedRectangle(cornerRadius: 8)
.stroke(Color.orange.opacity(0.3), lineWidth: 1)
)
}
private func openLocalNetworkSettings() {
// Open Privacy & Security settings - Local Network section
if let url = URL(
string: "x-apple.systempreferences:com.apple.preference.security?Privacy_LocalNetwork")
{
NSWorkspace.shared.open(url)
}
}
private var topologySection: some View {
Group {
if let topology = stateService.latestSnapshot?.topologyViewModel(
localNodeId: stateService.localNodeId), !topology.nodes.isEmpty
{
if let topology = stateService.latestSnapshot?.topologyViewModel(localNodeId: stateService.localNodeId), !topology.nodes.isEmpty {
TopologyMiniView(topology: topology)
}
}
@@ -149,10 +85,8 @@ struct ContentView: View {
VStack(alignment: .leading, spacing: 4) {
HStack {
VStack(alignment: .leading) {
Text(
"\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB"
)
.font(.headline)
Text("\(overview.usedRam, specifier: "%.0f") / \(overview.totalRam, specifier: "%.0f") GB")
.font(.headline)
Text("Memory")
.font(.caption)
.foregroundColor(.secondary)
@@ -261,7 +195,13 @@ struct ContentView: View {
Divider()
.padding(.vertical, 4)
}
advancedSection
controlButton(title: "Check for Updates") {
updater.checkForUpdates()
}
.padding(.bottom, 8)
advancedOptionsSection
.padding(.bottom, 8)
debugSection
.padding(.bottom, 8)
controlButton(title: "Quit", tint: .secondary) {
controller.stop()
@@ -270,57 +210,7 @@ struct ContentView: View {
}
}
private var advancedSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Advanced")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showAdvanced)
}
.animation(nil, value: showAdvanced)
if showAdvanced {
VStack(alignment: .leading, spacing: 8) {
VStack(alignment: .leading, spacing: 4) {
Text("Cluster Namespace")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
TextField("optional", text: $pendingNamespace)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingNamespace = controller.customNamespace
}
Button("Save & Restart") {
controller.customNamespace = pendingNamespace
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingNamespace == controller.customNamespace)
}
}
HoverButton(title: "Check for Updates", small: true) {
updater.checkForUpdates()
}
debugSection
HoverButton(title: "Uninstall", tint: .red, small: true) {
showUninstallConfirmationAlert()
}
.disabled(uninstallInProgress)
}
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showAdvanced)
}
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void)
-> some View
{
private func controlButton(title: String, tint: Color = .primary, action: @escaping () -> Void) -> some View {
HoverButton(title: title, tint: tint, trailingSystemImage: nil, action: action)
}
@@ -351,12 +241,9 @@ struct ContentView: View {
Button {
isExpanded.wrappedValue.toggle()
} label: {
Label(
isExpanded.wrappedValue ? "Hide" : "Show All",
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
)
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
Label(isExpanded.wrappedValue ? "Hide" : "Show All", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
}
.buttonStyle(.plain)
.font(.caption2)
@@ -444,16 +331,57 @@ struct ContentView: View {
}
}
private var debugSection: some View {
VStack(alignment: .leading, spacing: 4) {
HoverButton(
title: "Debug Info",
tint: .primary,
trailingSystemImage: showDebugInfo ? "chevron.up" : "chevron.down",
small: true
) {
showDebugInfo.toggle()
private var advancedOptionsSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Advanced Options")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showAdvancedOptions)
}
.animation(nil, value: showAdvancedOptions)
if showAdvancedOptions {
VStack(alignment: .leading, spacing: 8) {
VStack(alignment: .leading, spacing: 4) {
Text("Cluster Namespace")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
TextField("optional", text: $pendingNamespace)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingNamespace = controller.customNamespace
}
Button("Save & Restart") {
controller.customNamespace = pendingNamespace
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingNamespace == controller.customNamespace)
}
}
}
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showAdvancedOptions)
}
private var debugSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Debug Info")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showDebugInfo)
}
.animation(nil, value: showDebugInfo)
if showDebugInfo {
VStack(alignment: .leading, spacing: 4) {
Text("Version: \(buildTag)")
@@ -466,63 +394,15 @@ struct ContentView: View {
.font(.caption2)
.foregroundColor(thunderboltStatusColor)
interfaceIpList
rdmaStatusView
sendBugReportButton
.padding(.top, 6)
}
.padding(.leading, 8)
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showDebugInfo)
}
private var rdmaStatusView: some View {
let rdma = networkStatusService.status.rdmaStatus
return VStack(alignment: .leading, spacing: 1) {
Text("RDMA: \(rdmaStatusText(rdma))")
.font(.caption2)
.foregroundColor(rdmaStatusColor(rdma))
if !rdma.devices.isEmpty {
Text(" Devices: \(rdma.devices.joined(separator: ", "))")
.font(.caption2)
.foregroundColor(.secondary)
}
if !rdma.activePorts.isEmpty {
Text(" Active Ports:")
.font(.caption2)
.foregroundColor(.secondary)
ForEach(rdma.activePorts, id: \.device) { port in
Text(" \(port.device) port \(port.port): \(port.state)")
.font(.caption2)
.foregroundColor(.green)
}
}
}
}
private func rdmaStatusText(_ rdma: RDMAStatus) -> String {
switch rdma.rdmaCtlEnabled {
case .some(true):
return "Enabled"
case .some(false):
return "Disabled"
case nil:
return rdma.devices.isEmpty ? "Not Available" : "Available"
}
}
private func rdmaStatusColor(_ rdma: RDMAStatus) -> Color {
switch rdma.rdmaCtlEnabled {
case .some(true):
return .green
case .some(false):
return .orange
case nil:
return rdma.devices.isEmpty ? .secondary : .green
}
}
private var sendBugReportButton: some View {
VStack(alignment: .leading, spacing: 4) {
Button {
@@ -612,88 +492,6 @@ struct ContentView: View {
bugReportInFlight = false
}
private func showUninstallConfirmationAlert() {
let alert = NSAlert()
alert.messageText = "Uninstall EXO"
alert.informativeText = """
This will remove EXO and all its system components:
• Network configuration daemon
• Launch at login registration
• EXO network location
The app will be moved to Trash.
"""
alert.alertStyle = .warning
alert.addButton(withTitle: "Uninstall")
alert.addButton(withTitle: "Cancel")
// Style the Uninstall button as destructive
if let uninstallButton = alert.buttons.first {
uninstallButton.hasDestructiveAction = true
}
let response = alert.runModal()
if response == .alertFirstButtonReturn {
performUninstall()
}
}
private func performUninstall() {
uninstallInProgress = true
// Stop EXO process first
controller.cancelPendingLaunch()
controller.stop()
stateService.stopPolling()
// Run the privileged uninstall on a background thread
// Using .utility QoS to avoid priority inversion with NSAppleScript's subprocess
DispatchQueue.global(qos: .utility).async {
do {
// Remove network setup daemon and components (requires admin privileges)
try NetworkSetupHelper.uninstall()
DispatchQueue.main.async {
// Unregister from launch at login
LaunchAtLoginHelper.disable()
// Move app to trash
self.moveAppToTrash()
// Quit the app
DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) {
NSApplication.shared.terminate(nil)
}
}
} catch {
DispatchQueue.main.async {
self.showErrorAlert(message: error.localizedDescription)
self.uninstallInProgress = false
}
}
}
}
private func showErrorAlert(message: String) {
let alert = NSAlert()
alert.messageText = "Uninstall Failed"
alert.informativeText = message
alert.alertStyle = .critical
alert.addButton(withTitle: "OK")
alert.runModal()
}
private func moveAppToTrash() {
guard let appURL = Bundle.main.bundleURL as URL? else { return }
do {
try FileManager.default.trashItem(at: appURL, resultingItemURL: nil)
} catch {
// If we can't trash the app, that's OK - user can do it manually
// The important system components have already been cleaned up
}
}
private var buildTag: String {
Bundle.main.infoDictionary?["EXOBuildTag"] as? String ?? "unknown"
}
@@ -707,27 +505,14 @@ private struct HoverButton: View {
let title: String
let tint: Color
let trailingSystemImage: String?
let small: Bool
let action: () -> Void
init(
title: String, tint: Color = .primary, trailingSystemImage: String? = nil,
small: Bool = false, action: @escaping () -> Void
) {
self.title = title
self.tint = tint
self.trailingSystemImage = trailingSystemImage
self.small = small
self.action = action
}
@State private var isHovering = false
var body: some View {
Button(action: action) {
HStack {
Text(title)
.font(small ? .caption : nil)
Spacer()
if let systemName = trailingSystemImage {
Image(systemName: systemName)
@@ -735,8 +520,8 @@ private struct HoverButton: View {
}
}
.frame(maxWidth: .infinity, alignment: .leading)
.padding(.vertical, small ? 4 : 6)
.padding(.horizontal, small ? 6 : 8)
.padding(.vertical, 6)
.padding(.horizontal, 8)
.background(
RoundedRectangle(cornerRadius: 6)
.fill(
@@ -751,3 +536,4 @@ private struct HoverButton: View {
.onHover { isHovering = $0 }
}
}

View File

@@ -8,9 +8,9 @@
import AppKit
import CoreImage
import CoreImage.CIFilterBuiltins
import ServiceManagement
import Sparkle
import SwiftUI
import ServiceManagement
import UserNotifications
import os.log
@@ -19,7 +19,6 @@ struct EXOApp: App {
@StateObject private var controller: ExoProcessController
@StateObject private var stateService: ClusterStateService
@StateObject private var networkStatusService: NetworkStatusService
@StateObject private var localNetworkChecker: LocalNetworkChecker
@StateObject private var updater: SparkleUpdater
private let terminationObserver: TerminationObserver
private let ciContext = CIContext(options: nil)
@@ -38,13 +37,9 @@ struct EXOApp: App {
_stateService = StateObject(wrappedValue: service)
let networkStatus = NetworkStatusService()
_networkStatusService = StateObject(wrappedValue: networkStatus)
let localNetwork = LocalNetworkChecker()
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
_updater = StateObject(wrappedValue: updater)
enableLaunchAtLoginIfNeeded()
NetworkSetupHelper.ensureLaunchDaemonInstalled()
// Check local network access BEFORE launching exo
localNetwork.check()
controller.scheduleLaunch(after: 15)
service.startPolling()
networkStatus.startPolling()
@@ -56,7 +51,6 @@ struct EXOApp: App {
.environmentObject(controller)
.environmentObject(stateService)
.environmentObject(networkStatusService)
.environmentObject(localNetworkChecker)
.environmentObject(updater)
} label: {
menuBarIcon
@@ -113,7 +107,7 @@ struct EXOApp: App {
filter.contrast = 0.9
guard let output = filter.outputImage,
let rendered = ciContext.createCGImage(output, from: output.extent)
let rendered = ciContext.createCGImage(output, from: output.extent)
else {
return nil
}
@@ -126,26 +120,7 @@ struct EXOApp: App {
do {
try SMAppService.mainApp.register()
} catch {
Logger().error(
"Failed to register EXO for launch at login: \(error.localizedDescription)")
}
}
}
/// Helper for managing EXO's launch-at-login registration
enum LaunchAtLoginHelper {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LaunchAtLogin")
/// Unregisters EXO from launching at login
static func disable() {
guard SMAppService.mainApp.status == .enabled else { return }
do {
try SMAppService.mainApp.unregister()
logger.info("Unregistered EXO from launch at login")
} catch {
logger.error(
"Failed to unregister EXO from launch at login: \(error.localizedDescription, privacy: .public)"
)
Logger().error("Failed to register EXO for launch at login: \(error.localizedDescription)")
}
}
}
@@ -170,7 +145,7 @@ final class SparkleUpdater: NSObject, ObservableObject {
center.requestAuthorization(options: [.alert, .sound]) { _, _ in }
controller.updater.automaticallyChecksForUpdates = true
controller.updater.automaticallyDownloadsUpdates = false
controller.updater.updateCheckInterval = 900 // 15 minutes
controller.updater.updateCheckInterval = 900 // 15 minutes
DispatchQueue.main.asyncAfter(deadline: .now() + 5) { [weak controller] in
controller?.updater.checkForUpdatesInBackground()
}
@@ -237,8 +212,7 @@ private final class ExoNotificationDelegate: NSObject, UNUserNotificationCenterD
func userNotificationCenter(
_ center: UNUserNotificationCenter,
willPresent notification: UNNotification,
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) ->
Void
withCompletionHandler completionHandler: @escaping (UNNotificationPresentationOptions) -> Void
) {
completionHandler([.banner, .list, .sound])
}

View File

@@ -31,8 +31,7 @@ final class ExoProcessController: ObservableObject {
@Published private(set) var launchCountdownSeconds: Int?
@Published var customNamespace: String = {
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
}()
{
}() {
didSet {
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
}
@@ -222,9 +221,7 @@ final class ExoProcessController: ObservableObject {
if let tag = Bundle.main.infoDictionary?["EXOBuildTag"] as? String, !tag.isEmpty {
return tag
}
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String,
!short.isEmpty
{
if let short = Bundle.main.infoDictionary?["CFBundleShortVersionString"] as? String, !short.isEmpty {
return short
}
return "dev"

View File

@@ -8,15 +8,5 @@
<string>$(EXO_BUILD_TAG)</string>
<key>EXOBuildCommit</key>
<string>$(EXO_BUILD_COMMIT)</string>
<key>EXOBugReportPresignedUrlEndpoint</key>
<string>$(EXO_BUG_REPORT_PRESIGNED_URL_ENDPOINT)</string>
<key>NSLocalNetworkUsageDescription</key>
<string>EXO needs local network access to discover and connect to other devices in your cluster for distributed AI inference.</string>
<key>NSBonjourServices</key>
<array>
<string>_p2p._tcp</string>
<string>_p2p._udp</string>
<string>_libp2p._udp</string>
</array>
</dict>
</plist>

View File

@@ -16,13 +16,10 @@ struct ClusterState: Decodable {
self.instances = rawInstances.mapValues(\.instance)
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
let rawTasks =
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
let rawTasks = try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
self.tasks = rawTasks.compactMapValues(\.task)
self.topology = try container.decodeIfPresent(Topology.self, forKey: .topology)
let rawDownloads =
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
?? [:]
let rawDownloads = try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads) ?? [:]
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
}
@@ -44,8 +41,7 @@ private struct TaggedInstance: Decodable {
let payloads = try container.decode([String: ClusterInstancePayload].self)
guard let entry = payloads.first else {
throw DecodingError.dataCorrupted(
DecodingError.Context(
codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty instance payload")
)
}
self.instance = ClusterInstance(
@@ -81,8 +77,7 @@ struct RunnerStatusSummary: Decodable {
let payloads = try container.decode([String: RunnerStatusDetail].self)
guard let entry = payloads.first else {
throw DecodingError.dataCorrupted(
DecodingError.Context(
codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
DecodingError.Context(codingPath: decoder.codingPath, debugDescription: "Empty runner status payload")
)
}
self.status = entry.key
@@ -262,9 +257,7 @@ struct ChatCompletionTaskParameters: Decodable, Equatable {
func promptPreview() -> String? {
guard let messages else { return nil }
if let userMessage = messages.last(where: {
$0.role?.lowercased() == "user" && ($0.content?.isEmpty == false)
}) {
if let userMessage = messages.last(where: { $0.role?.lowercased() == "user" && ($0.content?.isEmpty == false) }) {
return userMessage.content
}
return messages.last?.content
@@ -372,3 +365,5 @@ extension ClusterState {
func availableModels() -> [ModelOption] { [] }
}

View File

@@ -1,3 +1,4 @@
import CryptoKit
import Foundation
struct BugReportOutcome: Equatable {
@@ -6,17 +7,17 @@ struct BugReportOutcome: Equatable {
}
enum BugReportError: LocalizedError {
case missingCredentials
case invalidEndpoint
case presignedUrlFailed(String)
case uploadFailed(String)
case collectFailed(String)
var errorDescription: String? {
switch self {
case .missingCredentials:
return "Bug report upload credentials are not set."
case .invalidEndpoint:
return "Bug report endpoint is invalid."
case .presignedUrlFailed(let message):
return "Failed to get presigned URLs: \(message)"
case .uploadFailed(let message):
return "Bug report upload failed: \(message)"
case .collectFailed(let message):
@@ -26,13 +27,11 @@ enum BugReportError: LocalizedError {
}
struct BugReportService {
private struct PresignedUrlsRequest: Codable {
let keys: [String]
}
private struct PresignedUrlsResponse: Codable {
let urls: [String: String]
let expiresIn: Int?
struct AWSConfig {
let accessKey: String
let secretKey: String
let region: String
let bucket: String
}
func sendReport(
@@ -40,9 +39,9 @@ struct BugReportService {
now: Date = Date(),
isManual: Bool = false
) async throws -> BugReportOutcome {
let timestamp = Self.runTimestampString(now)
let dayPrefix = Self.dayPrefixString(now)
let prefix = "reports/\(dayPrefix)/\(timestamp)/"
let credentials = try loadCredentials()
let timestamp = ISO8601DateFormatter().string(from: now)
let prefix = "reports/\(timestamp)/"
let logData = readLog()
let ifconfigText = try await captureIfconfig()
@@ -67,82 +66,28 @@ struct BugReportService {
("\(prefix)exo.log", logData),
("\(prefix)state.json", stateData),
("\(prefix)events.json", eventsData),
("\(prefix)report.json", reportJSON),
("\(prefix)report.json", reportJSON)
]
let uploadItems: [(key: String, body: Data)] = uploads.compactMap { item in
guard let body = item.data else { return nil }
return (key: item.path, body: body)
let uploader = try S3Uploader(config: credentials)
for item in uploads {
guard let data = item.data else { continue }
try await uploader.upload(
objectPath: item.path,
body: data
)
}
guard !uploadItems.isEmpty else {
return BugReportOutcome(success: false, message: "No data to upload")
}
let presignedUrls = try await fetchPresignedUploadUrls(keys: uploadItems.map(\.key))
for item in uploadItems {
guard let urlString = presignedUrls[item.key], let url = URL(string: urlString) else {
throw BugReportError.uploadFailed("Missing presigned URL for \(item.key)")
}
try await uploadToPresignedUrl(url: url, body: item.body)
}
return BugReportOutcome(
success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
return BugReportOutcome(success: true, message: "Bug Report sent. Thank you for helping to improve EXO 1.0.")
}
private static func dayPrefixString(_ date: Date) -> String {
var calendar = Calendar(identifier: .gregorian)
calendar.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
let components = calendar.dateComponents([.year, .month, .day], from: date)
let year = components.year ?? 0
let month = components.month ?? 0
let day = components.day ?? 0
return String(format: "%04d/%02d/%02d", year, month, day)
}
private static func runTimestampString(_ date: Date) -> String {
let formatter = DateFormatter()
formatter.locale = Locale(identifier: "en_US_POSIX")
formatter.timeZone = TimeZone(secondsFromGMT: 0) ?? .current
formatter.dateFormat = "yyyy-MM-dd'T'HHmmss.SSS'Z'"
return formatter.string(from: date)
}
private func fetchPresignedUploadUrls(keys: [String], bundle: Bundle = .main) async throws
-> [String: String]
{
guard
let endpointString = bundle.infoDictionary?["EXOBugReportPresignedUrlEndpoint"]
as? String
else {
throw BugReportError.invalidEndpoint
}
let trimmedEndpointString = endpointString.trimmingCharacters(in: .whitespacesAndNewlines)
guard !trimmedEndpointString.isEmpty, let endpoint = URL(string: trimmedEndpointString)
else {
throw BugReportError.invalidEndpoint
}
var request = URLRequest(url: endpoint)
request.httpMethod = "POST"
request.timeoutInterval = 10
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
let encoder = JSONEncoder()
request.httpBody = try encoder.encode(PresignedUrlsRequest(keys: keys))
let (data, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse else {
throw BugReportError.presignedUrlFailed("Non-HTTP response")
}
guard (200..<300).contains(http.statusCode) else {
throw BugReportError.presignedUrlFailed("HTTP status \(http.statusCode)")
}
let decoder = JSONDecoder()
let decoded = try decoder.decode(PresignedUrlsResponse.self, from: data)
return decoded.urls
private func loadCredentials() throws -> AWSConfig {
return AWSConfig(
accessKey: "AKIAYEKP5EMXTOBYDGHX",
secretKey: "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",
region: "us-east-1",
bucket: "exo-bug-reports"
)
}
private func readLog() -> Data? {
@@ -155,8 +100,7 @@ struct BugReportService {
private func captureIfconfig() async throws -> String {
let result = runCommand(["/sbin/ifconfig"])
guard result.exitCode == 0 else {
throw BugReportError.collectFailed(
result.error.isEmpty ? "ifconfig failed" : result.error)
throw BugReportError.collectFailed(result.error.isEmpty ? "ifconfig failed" : result.error)
}
return result.output
}
@@ -164,23 +108,12 @@ struct BugReportService {
private func readDebugInfo() -> DebugInfo {
DebugInfo(
thunderboltBridgeDisabled: readThunderboltBridgeDisabled(),
interfaces: readInterfaces(),
rdma: readRDMADebugInfo()
)
}
private func readRDMADebugInfo() -> DebugInfo.RDMADebugInfo {
DebugInfo.RDMADebugInfo(
rdmaCtlStatus: safeRunCommand(["/usr/bin/rdma_ctl", "status"]),
ibvDevices: safeRunCommand(["/usr/bin/ibv_devices"]),
ibvDevinfo: safeRunCommand(["/usr/bin/ibv_devinfo"])
interfaces: readInterfaces()
)
}
private func readThunderboltBridgeDisabled() -> Bool? {
let result = runCommand([
"/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge",
])
let result = runCommand(["/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
guard result.exitCode == 0 else { return nil }
let output = result.output.lowercased()
if output.contains("enabled") {
@@ -223,8 +156,7 @@ struct BugReportService {
request.timeoutInterval = 5
do {
let (data, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode)
else {
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
return nil
}
return data
@@ -233,36 +165,6 @@ struct BugReportService {
}
}
private func uploadToPresignedUrl(url: URL, body: Data) async throws {
let maxAttempts = 2
var lastError: Error?
for attempt in 1...maxAttempts {
do {
var request = URLRequest(url: url)
request.httpMethod = "PUT"
request.httpBody = body
request.timeoutInterval = 30
let (_, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse else {
throw BugReportError.uploadFailed("Non-HTTP response")
}
guard (200..<300).contains(http.statusCode) else {
throw BugReportError.uploadFailed("HTTP status \(http.statusCode)")
}
return
} catch {
lastError = error
if attempt < maxAttempts {
try await Task.sleep(nanoseconds: 400_000_000)
}
}
}
throw BugReportError.uploadFailed(lastError?.localizedDescription ?? "Unknown error")
}
private func makeReportJson(
timestamp: String,
hostName: String,
@@ -280,7 +182,7 @@ struct BugReportService {
"system": system,
"exo_version": exo.version as Any,
"exo_commit": exo.commit as Any,
"report_type": isManual ? "manual" : "automated",
"report_type": isManual ? "manual" : "automated"
]
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
}
@@ -311,13 +213,10 @@ struct BugReportService {
let user = safeRunCommand(["/usr/bin/whoami"])
let consoleUser = safeRunCommand(["/usr/bin/stat", "-f%Su", "/dev/console"])
let uptime = safeRunCommand(["/usr/bin/uptime"])
let diskRoot = safeRunCommand([
"/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'",
])
let diskRoot = safeRunCommand(["/bin/sh", "-c", "/bin/df -h / | awk 'NR==2 {print $1, $2, $3, $4, $5}'"])
let interfacesList = safeRunCommand(["/usr/sbin/ipconfig", "getiflist"])
let interfacesAndIPs =
interfacesList?
let interfacesAndIPs = interfacesList?
.split(whereSeparator: { $0 == " " || $0 == "\n" })
.compactMap { iface -> [String: Any]? in
let name = String(iface)
@@ -328,8 +227,7 @@ struct BugReportService {
} ?? []
let wifiSSID: String?
let airportPath =
"/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
let airportPath = "/System/Library/PrivateFrameworks/Apple80211.framework/Versions/Current/Resources/airport"
if FileManager.default.isExecutableFile(atPath: airportPath) {
wifiSSID = safeRunCommand([airportPath, "-I"]).flatMap(parseWifiSSID)
} else {
@@ -357,7 +255,7 @@ struct BugReportService {
"disk_root": diskRoot as Any,
"interfaces_and_ips": interfacesAndIPs,
"ipconfig_getiflist": interfacesList as Any,
"wifi_ssid": wifiSSID as Any,
"wifi_ssid": wifiSSID as Any
]
}
@@ -415,8 +313,7 @@ struct BugReportService {
for line in airportOutput.split(separator: "\n") {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("SSID:") {
return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(
in: .whitespaces)
return trimmed.replacingOccurrences(of: "SSID:", with: "").trimmingCharacters(in: .whitespaces)
}
}
return nil
@@ -453,7 +350,6 @@ struct BugReportService {
private struct DebugInfo {
let thunderboltBridgeDisabled: Bool?
let interfaces: [InterfaceStatus]
let rdma: RDMADebugInfo
struct InterfaceStatus {
let name: String
@@ -462,21 +358,7 @@ private struct DebugInfo {
func toDictionary() -> [String: Any] {
[
"name": name,
"ip": ip as Any,
]
}
}
struct RDMADebugInfo {
let rdmaCtlStatus: String?
let ibvDevices: String?
let ibvDevinfo: String?
func toDictionary() -> [String: Any] {
[
"rdma_ctl_status": rdmaCtlStatus as Any,
"ibv_devices": ibvDevices as Any,
"ibv_devinfo": ibvDevinfo as Any,
"ip": ip as Any
]
}
}
@@ -484,8 +366,7 @@ private struct DebugInfo {
func toDictionary() -> [String: Any] {
[
"thunderbolt_bridge_disabled": thunderboltBridgeDisabled as Any,
"interfaces": interfaces.map { $0.toDictionary() },
"rdma": rdma.toDictionary(),
"interfaces": interfaces.map { $0.toDictionary() }
]
}
}
@@ -495,3 +376,163 @@ private struct CommandResult {
let output: String
let error: String
}
private struct S3Uploader {
let config: BugReportService.AWSConfig
init(config: BugReportService.AWSConfig) throws {
self.config = config
}
func upload(objectPath: String, body: Data) async throws {
let host = "\(config.bucket).s3.amazonaws.com"
guard let url = URL(string: "https://\(host)/\(objectPath)") else {
throw BugReportError.invalidEndpoint
}
let now = Date()
let amzDate = awsTimestamp(now)
let dateStamp = dateStamp(now)
let payloadHash = sha256Hex(body)
let headers = [
"host": host,
"x-amz-content-sha256": payloadHash,
"x-amz-date": amzDate
]
let canonicalRequest = buildCanonicalRequest(
method: "PUT",
url: url,
headers: headers,
payloadHash: payloadHash
)
let stringToSign = buildStringToSign(
amzDate: amzDate,
dateStamp: dateStamp,
canonicalRequestHash: sha256Hex(canonicalRequest.data(using: .utf8) ?? Data())
)
let signingKey = deriveKey(secret: config.secretKey, dateStamp: dateStamp, region: config.region, service: "s3")
let signature = hmacHex(key: signingKey, data: Data(stringToSign.utf8))
let signedHeaders = "host;x-amz-content-sha256;x-amz-date"
let authorization = """
AWS4-HMAC-SHA256 Credential=\(config.accessKey)/\(dateStamp)/\(config.region)/s3/aws4_request, SignedHeaders=\(signedHeaders), Signature=\(signature)
"""
var request = URLRequest(url: url)
request.httpMethod = "PUT"
request.httpBody = body
request.setValue(headers["x-amz-content-sha256"], forHTTPHeaderField: "x-amz-content-sha256")
request.setValue(headers["x-amz-date"], forHTTPHeaderField: "x-amz-date")
request.setValue(host, forHTTPHeaderField: "Host")
request.setValue(authorization, forHTTPHeaderField: "Authorization")
let (data, response) = try await URLSession.shared.data(for: request)
guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else {
let statusText = (response as? HTTPURLResponse)?.statusCode ?? -1
_ = data // ignore response body for UX
throw BugReportError.uploadFailed("HTTP status \(statusText)")
}
}
private func buildCanonicalRequest(
method: String,
url: URL,
headers: [String: String],
payloadHash: String
) -> String {
let canonicalURI = encodePath(url.path)
let canonicalQuery = url.query ?? ""
let sortedHeaders = headers.sorted { $0.key < $1.key }
let canonicalHeaders = sortedHeaders
.map { "\($0.key.lowercased()):\($0.value)\n" }
.joined()
let signedHeaders = sortedHeaders.map { $0.key.lowercased() }.joined(separator: ";")
return [
method,
canonicalURI,
canonicalQuery,
canonicalHeaders,
signedHeaders,
payloadHash
].joined(separator: "\n")
}
private func encodePath(_ path: String) -> String {
return path
.split(separator: "/")
.map { segment in
segment.addingPercentEncoding(withAllowedCharacters: Self.rfc3986) ?? String(segment)
}
.joined(separator: "/")
.prependSlashIfNeeded()
}
private func buildStringToSign(
amzDate: String,
dateStamp: String,
canonicalRequestHash: String
) -> String {
"""
AWS4-HMAC-SHA256
\(amzDate)
\(dateStamp)/\(config.region)/s3/aws4_request
\(canonicalRequestHash)
"""
}
private func deriveKey(secret: String, dateStamp: String, region: String, service: String) -> Data {
let kDate = hmac(key: Data(("AWS4" + secret).utf8), data: Data(dateStamp.utf8))
let kRegion = hmac(key: kDate, data: Data(region.utf8))
let kService = hmac(key: kRegion, data: Data(service.utf8))
return hmac(key: kService, data: Data("aws4_request".utf8))
}
private func hmac(key: Data, data: Data) -> Data {
let keySym = SymmetricKey(data: key)
let mac = HMAC<SHA256>.authenticationCode(for: data, using: keySym)
return Data(mac)
}
private func hmacHex(key: Data, data: Data) -> String {
hmac(key: key, data: data).map { String(format: "%02x", $0) }.joined()
}
private func sha256Hex(_ data: Data) -> String {
let digest = SHA256.hash(data: data)
return digest.compactMap { String(format: "%02x", $0) }.joined()
}
private func awsTimestamp(_ date: Date) -> String {
let formatter = DateFormatter()
formatter.dateFormat = "yyyyMMdd'T'HHmmss'Z'"
formatter.timeZone = TimeZone(abbreviation: "UTC")
return formatter.string(from: date)
}
private func dateStamp(_ date: Date) -> String {
let formatter = DateFormatter()
formatter.dateFormat = "yyyyMMdd"
formatter.timeZone = TimeZone(abbreviation: "UTC")
return formatter.string(from: date)
}
private static let rfc3986: CharacterSet = {
var set = CharacterSet.alphanumerics
set.insert(charactersIn: "-._~")
return set
}()
}
private extension String {
func prependSlashIfNeeded() -> String {
if hasPrefix("/") {
return self
}
return "/" + self
}
}

View File

@@ -57,9 +57,7 @@ final class ClusterStateService: ObservableObject {
var request = URLRequest(url: url)
request.cachePolicy = .reloadIgnoringLocalCacheData
let (data, response) = try await session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else {
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
return
}
if let nodeId = try? decoder.decode(String.self, from: data) {
@@ -115,9 +113,7 @@ final class ClusterStateService: ObservableObject {
}
}
func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int)
async
{
func launchInstance(modelId: String, sharding: String, instanceMeta: String, minNodes: Int) async {
do {
var request = URLRequest(url: baseURL.appendingPathComponent("instance"))
request.httpMethod = "POST"
@@ -126,7 +122,7 @@ final class ClusterStateService: ObservableObject {
"model_id": modelId,
"sharding": sharding,
"instance_meta": instanceMeta,
"min_nodes": minNodes,
"min_nodes": minNodes
]
request.httpBody = try JSONSerialization.data(withJSONObject: payload, options: [])
let (_, response) = try await session.data(for: request)
@@ -147,9 +143,7 @@ final class ClusterStateService: ObservableObject {
do {
let url = baseURL.appendingPathComponent("models")
let (data, response) = try await session.data(from: url)
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else {
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
throw URLError(.badServerResponse)
}
let list = try decoder.decode(ModelListResponse.self, from: data)

View File

@@ -1,149 +0,0 @@
import Foundation
import Network
import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
case unknown
case checking
case working
case notWorking(reason: String)
var isHealthy: Bool {
if case .working = self { return true }
return false
}
var displayText: String {
switch self {
case .unknown:
return "Unknown"
case .checking:
return "Checking..."
case .working:
return "Working"
case .notWorking(let reason):
return reason
}
}
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
@Published private(set) var status: Status = .unknown
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
checkTask = Task { [weak self] in
guard let self else { return }
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
/// Checks connectivity using NWConnection to mDNS multicast.
/// The connection attempt triggers the permission prompt if not yet shown.
private func checkConnectivity(timeout: UInt64) async -> Status {
connection?.cancel()
connection = nil
// mDNS multicast address - same as libp2p uses for peer discovery
let host = NWEndpoint.Host("224.0.0.251")
let port = NWEndpoint.Port(integerLiteral: 5353)
let params = NWParameters.udp
params.allowLocalEndpointReuse = true
let conn = NWConnection(host: host, port: port, using: params)
connection = conn
return await withCheckedContinuation { continuation in
var hasResumed = false
let lock = NSLock()
let resumeOnce: (Status) -> Void = { status in
lock.lock()
defer { lock.unlock() }
guard !hasResumed else { return }
hasResumed = true
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { state in
switch state {
case .ready:
resumeOnce(.working)
case .waiting(let error):
let errorStr = "\(error)"
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
// Otherwise keep waiting - might be showing permission prompt
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
|| errorStr.contains("permission") || errorStr.contains("denied")
{
resumeOnce(.notWorking(reason: "Permission denied"))
} else {
resumeOnce(.notWorking(reason: "Failed: \(error.localizedDescription)"))
}
case .cancelled, .setup, .preparing:
break
@unknown default:
break
}
}
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: timeout)
let state = conn.state
switch state {
case .ready:
resumeOnce(.working)
case .waiting, .preparing, .setup:
resumeOnce(.notWorking(reason: "Timeout (may be blocked)"))
default:
resumeOnce(.notWorking(reason: "Timeout"))
}
}
}
}
func stop() {
checkTask?.cancel()
checkTask = nil
connection?.cancel()
connection = nil
}
}

View File

@@ -5,37 +5,64 @@ import os.log
enum NetworkSetupHelper {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
private static let daemonLabel = "io.exo.networksetup"
private static let scriptDestination =
"/Library/Application Support/EXO/disable_bridge.sh"
private static let scriptDestination = "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1791
private static let setupScript = """
#!/usr/bin/env bash
#!/usr/bin/env bash
set -euo pipefail
set -euo pipefail
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
static func ensureLaunchDaemonInstalled() {
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
Task.detached(priority: .utility) {
Task.detached {
do {
if daemonAlreadyInstalled() {
return
@@ -43,86 +70,19 @@ enum NetworkSetupHelper {
try await installLaunchDaemon()
logger.info("Network setup launch daemon installed and started")
} catch {
logger.error(
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
)
logger.error("Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)")
}
}
}
/// Removes all EXO network setup components from the system.
/// This includes the LaunchDaemon, scripts, logs, and network location.
/// Requires admin privileges.
static func uninstall() throws {
let uninstallScript = makeUninstallScript()
try runShellAsAdmin(uninstallScript)
logger.info("EXO network setup components removed successfully")
}
/// Checks if there are any EXO network components installed that need cleanup
static func hasInstalledComponents() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
return scriptExists || plistExists
}
private static func makeUninstallScript() -> String {
"""
set -euo pipefail
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
LOG_OUT="/var/log/\(daemonLabel).log"
LOG_ERR="/var/log/\(daemonLabel).err.log"
# Unload the LaunchDaemon if running
launchctl bootout system/"$LABEL" 2>/dev/null || true
# Remove LaunchDaemon plist
rm -f "$PLIST_DEST"
# Remove the script and parent directory if empty
rm -f "$SCRIPT_DEST"
rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true
# Remove log files
rm -f "$LOG_OUT" "$LOG_ERR"
# Switch back to Automatic network location
networksetup -switchtolocation Automatic 2>/dev/null || true
# Delete the exo network location if it exists
networksetup -listlocations | grep -q '^exo$' && {
networksetup -deletelocation exo 2>/dev/null || true
} || true
# Re-enable Thunderbolt Bridge if it exists
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
} || true
echo "EXO network components removed successfully"
"""
}
private static func daemonAlreadyInstalled() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
guard scriptExists, plistExists else { return false }
guard
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
else {
return false
}
guard
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
let plist = try? PropertyListSerialization.propertyList(
from: data, options: [], format: nil) as? [String: Any]
let plist = try? PropertyListSerialization.propertyList(from: data, options: [], format: nil) as? [String: Any]
else {
return false
}
@@ -132,9 +92,7 @@ enum NetworkSetupHelper {
else {
return false
}
if let programArgs = plist["ProgramArguments"] as? [String],
programArgs.contains(scriptDestination) == false
{
if let programArgs = plist["ProgramArguments"] as? [String], programArgs.contains(scriptDestination) == false {
return false
}
return true
@@ -147,59 +105,58 @@ enum NetworkSetupHelper {
private static func makeInstallerScript() -> String {
"""
set -euo pipefail
set -euo pipefail
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
PLIST_DEST="\(plistDestination)"
mkdir -p "$(dirname "$SCRIPT_DEST")"
mkdir -p "$(dirname "$SCRIPT_DEST")"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
}
private static func runShellAsAdmin(_ script: String) throws {
let escapedScript =
script
let escapedScript = script
.replacingOccurrences(of: "\\", with: "\\\\")
.replacingOccurrences(of: "\"", with: "\\\"")
let appleScriptSource = """
do shell script "\(escapedScript)" with administrator privileges
"""
do shell script "\(escapedScript)" with administrator privileges
"""
guard let appleScript = NSAppleScript(source: appleScriptSource) else {
throw NetworkSetupError.scriptCreationFailed

View File

@@ -35,34 +35,14 @@ struct NetworkStatus: Equatable {
let thunderboltBridgeState: ThunderboltState?
let bridgeInactive: Bool?
let interfaceStatuses: [InterfaceIpStatus]
let rdmaStatus: RDMAStatus
static let empty = NetworkStatus(
thunderboltBridgeState: nil,
bridgeInactive: nil,
interfaceStatuses: [],
rdmaStatus: .empty
interfaceStatuses: []
)
}
struct RDMAStatus: Equatable {
let rdmaCtlEnabled: Bool?
let devices: [String]
let activePorts: [RDMAPort]
var isAvailable: Bool {
rdmaCtlEnabled == true || !devices.isEmpty
}
static let empty = RDMAStatus(rdmaCtlEnabled: nil, devices: [], activePorts: [])
}
struct RDMAPort: Equatable {
let device: String
let port: String
let state: String
}
struct InterfaceIpStatus: Equatable {
let interfaceName: String
let ipAddress: String?
@@ -79,79 +59,10 @@ private struct NetworkStatusFetcher {
NetworkStatus(
thunderboltBridgeState: readThunderboltBridgeState(),
bridgeInactive: readBridgeInactive(),
interfaceStatuses: readInterfaceStatuses(),
rdmaStatus: readRDMAStatus()
interfaceStatuses: readInterfaceStatuses()
)
}
private func readRDMAStatus() -> RDMAStatus {
let rdmaCtlEnabled = readRDMACtlEnabled()
let devices = readRDMADevices()
let activePorts = readRDMAActivePorts()
return RDMAStatus(
rdmaCtlEnabled: rdmaCtlEnabled, devices: devices, activePorts: activePorts)
}
private func readRDMACtlEnabled() -> Bool? {
let result = runCommand(["rdma_ctl", "status"])
guard result.exitCode == 0 else { return nil }
let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
if output.contains("enabled") {
return true
}
if output.contains("disabled") {
return false
}
return nil
}
private func readRDMADevices() -> [String] {
let result = runCommand(["ibv_devices"])
guard result.exitCode == 0 else { return [] }
var devices: [String] = []
for line in result.output.split(separator: "\n") {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("---") || trimmed.lowercased().hasPrefix("device")
|| trimmed.isEmpty
{
continue
}
let parts = trimmed.split(separator: " ", maxSplits: 1)
if let deviceName = parts.first {
devices.append(String(deviceName))
}
}
return devices
}
private func readRDMAActivePorts() -> [RDMAPort] {
let result = runCommand(["ibv_devinfo"])
guard result.exitCode == 0 else { return [] }
var ports: [RDMAPort] = []
var currentDevice: String?
var currentPort: String?
for line in result.output.split(separator: "\n") {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasPrefix("hca_id:") {
currentDevice = trimmed.replacingOccurrences(of: "hca_id:", with: "")
.trimmingCharacters(in: .whitespaces)
} else if trimmed.hasPrefix("port:") {
currentPort = trimmed.replacingOccurrences(of: "port:", with: "")
.trimmingCharacters(in: .whitespaces)
} else if trimmed.hasPrefix("state:") {
let state = trimmed.replacingOccurrences(of: "state:", with: "").trimmingCharacters(
in: .whitespaces)
if let device = currentDevice, let port = currentPort {
if state.lowercased().contains("active") {
ports.append(RDMAPort(device: device, port: port, state: state))
}
}
}
}
return ports
}
private func readThunderboltBridgeState() -> ThunderboltState? {
let result = runCommand(["networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
guard result.exitCode == 0 else {
@@ -174,11 +85,10 @@ private struct NetworkStatusFetcher {
private func readBridgeInactive() -> Bool? {
let result = runCommand(["ifconfig", "bridge0"])
guard result.exitCode == 0 else { return nil }
guard
let statusLine = result.output
.components(separatedBy: .newlines)
.first(where: { $0.contains("status:") })?
.lowercased()
guard let statusLine = result.output
.components(separatedBy: .newlines)
.first(where: { $0.contains("status:") })?
.lowercased()
else {
return nil
}
@@ -261,3 +171,4 @@ private struct NetworkStatusFetcher {
)
}
}

View File

@@ -57,7 +57,7 @@ struct InstanceViewModel: Identifiable, Equatable {
case waiting
case failed
case idle
case preparing
case unknown
var label: String {
switch self {
@@ -68,7 +68,7 @@ struct InstanceViewModel: Identifiable, Equatable {
case .waiting: return "Waiting"
case .failed: return "Failed"
case .idle: return "Idle"
case .preparing: return "Preparing"
case .unknown: return "Unknown"
}
}
}
@@ -107,13 +107,10 @@ extension ClusterState {
let nodeToRunner = instance.shardAssignments.nodeToRunner
let nodeIds = Array(nodeToRunner.keys)
let runnerIds = Array(nodeToRunner.values)
let nodeNames = nodeIds.compactMap {
nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0
}
let nodeNames = nodeIds.compactMap { nodeProfiles[$0]?.friendlyName ?? nodeProfiles[$0]?.modelId ?? $0 }
let statuses = runnerIds.compactMap { runners[$0]?.status.lowercased() }
let downloadProgress = aggregateDownloadProgress(for: nodeIds)
let state = InstanceViewModel.State(
statuses: statuses, hasActiveDownload: downloadProgress != nil)
let state = InstanceViewModel.State(statuses: statuses, hasActiveDownload: downloadProgress != nil)
let chatTasks = (chatTasksByInstance[entry.key] ?? [])
.sorted(by: { $0.sortPriority < $1.sortPriority })
.map { InstanceTaskViewModel(task: $0) }
@@ -168,8 +165,8 @@ extension ClusterState {
}
}
extension InstanceViewModel.State {
fileprivate init(statuses: [String], hasActiveDownload: Bool = false) {
private extension InstanceViewModel.State {
init(statuses: [String], hasActiveDownload: Bool = false) {
if statuses.contains(where: { $0.contains("failed") }) {
self = .failed
} else if hasActiveDownload || statuses.contains(where: { $0.contains("downloading") }) {
@@ -185,7 +182,7 @@ extension InstanceViewModel.State {
} else if statuses.isEmpty {
self = .idle
} else {
self = .preparing
self = .unknown
}
}
}
@@ -246,3 +243,4 @@ extension InstanceTaskViewModel {
self.parameters = task.parameters
}
}

View File

@@ -87,9 +87,7 @@ struct TopologyViewModel {
extension ClusterState {
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
let allNodes = nodeViewModels().filter {
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
}
let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
guard !allNodes.isEmpty else { return nil }
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
@@ -108,24 +106,18 @@ extension ClusterState {
}
// Rotate so the local node (from /node_id API) is first
if let localId = localNodeId,
let index = orderedNodes.firstIndex(where: { $0.id == localId })
{
if let localId = localNodeId, let index = orderedNodes.firstIndex(where: { $0.id == localId }) {
orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
}
let nodeIds = Set(orderedNodes.map(\.id))
let edgesArray: [TopologyEdgeViewModel] =
topology?.connections?.compactMap { connection in
guard nodeIds.contains(connection.localNodeId),
nodeIds.contains(connection.sendBackNodeId)
else { return nil }
return TopologyEdgeViewModel(
sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
} ?? []
let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
return TopologyEdgeViewModel(sourceId: connection.localNodeId, targetId: connection.sendBackNodeId)
} ?? []
let edges = Set(edgesArray)
return TopologyViewModel(
nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
}
}

View File

@@ -20,8 +20,8 @@ struct InstanceRowView: View {
if let progress = instance.downloadProgress {
downloadStatusView(progress: progress)
} else {
statusChip(label: instance.state.label.uppercased(), color: statusColor)
}
statusChip(label: instance.state.label.uppercased(), color: statusColor)
}
}
if let progress = instance.downloadProgress {
GeometryReader { geometry in
@@ -83,7 +83,7 @@ struct InstanceRowView: View {
case .ready: return .teal
case .waiting, .idle: return .gray
case .failed: return .red
case .preparing: return .secondary
case .unknown: return .secondary
}
}
@@ -97,8 +97,7 @@ struct InstanceRowView: View {
.font(.caption)
.fontWeight(.semibold)
if let subtitle = task.subtitle,
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame
{
subtitle.caseInsensitiveCompare(parentModelName) != .orderedSame {
Text(subtitle)
.font(.caption2)
.foregroundColor(.secondary)
@@ -235,12 +234,9 @@ struct InstanceRowView: View {
Button {
isExpanded.wrappedValue.toggle()
} label: {
Label(
isExpanded.wrappedValue ? "Hide" : "Show",
systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down"
)
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
Label(isExpanded.wrappedValue ? "Hide" : "Show", systemImage: isExpanded.wrappedValue ? "chevron.up" : "chevron.down")
.labelStyle(.titleAndIcon)
.contentTransition(.symbolEffect(.replace))
}
.buttonStyle(.plain)
.font(.caption2)
@@ -315,9 +311,7 @@ struct InstanceRowView: View {
}
@ViewBuilder
private func detailRow(
icon: String? = nil, title: String, value: String, tint: Color = .secondary
) -> some View {
private func detailRow(icon: String? = nil, title: String, value: String, tint: Color = .secondary) -> some View {
HStack(alignment: .firstTextBaseline, spacing: 6) {
if let icon {
Image(systemName: icon)
@@ -335,3 +329,4 @@ struct InstanceRowView: View {
}
}
}

View File

@@ -32,3 +32,4 @@ struct NodeDetailView: View {
}
}
}

View File

@@ -28,3 +28,4 @@ struct NodeRowView: View {
.padding(.vertical, 4)
}
}

View File

@@ -76,33 +76,30 @@ struct TopologyMiniView: View {
private func connectionLines(in size: CGSize) -> some View {
let positions = positionedNodes(in: size)
let positionById = Dictionary(
uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
let positionById = Dictionary(uniqueKeysWithValues: positions.map { ($0.node.id, $0.point) })
return Canvas { context, _ in
guard !topology.edges.isEmpty else { return }
let nodeRadius: CGFloat = 32
let arrowLength: CGFloat = 10
let arrowSpread: CGFloat = .pi / 7
for edge in topology.edges {
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId]
else { continue }
guard let start = positionById[edge.sourceId], let end = positionById[edge.targetId] else { continue }
let dx = end.x - start.x
let dy = end.y - start.y
let distance = max(CGFloat(hypot(dx, dy)), 1)
let ux = dx / distance
let uy = dy / distance
let adjustedStart = CGPoint(
x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
let adjustedStart = CGPoint(x: start.x + ux * nodeRadius, y: start.y + uy * nodeRadius)
let adjustedEnd = CGPoint(x: end.x - ux * nodeRadius, y: end.y - uy * nodeRadius)
var linePath = Path()
linePath.move(to: adjustedStart)
linePath.addLine(to: adjustedEnd)
context.stroke(
context.stroke(
linePath,
with: .color(.secondary.opacity(0.3)),
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
)
style: StrokeStyle(lineWidth: 1, dash: [4, 4])
)
let angle = atan2(uy, ux)
let tip = adjustedEnd
@@ -171,3 +168,5 @@ private struct NodeGlyphView: View {
.frame(width: 95)
}
}

View File

@@ -6,7 +6,6 @@
//
import Testing
@testable import EXO
struct EXOTests {

View File

@@ -1,154 +0,0 @@
#!/usr/bin/env bash
#
# EXO Uninstaller Script
#
# This script removes all EXO system components that persist after deleting the app.
# Run with: sudo ./uninstall-exo.sh
#
# Components removed:
# - LaunchDaemon: /Library/LaunchDaemons/io.exo.networksetup.plist
# - Network script: /Library/Application Support/EXO/
# - Log files: /var/log/io.exo.networksetup.*
# - Network location: "exo"
# - Launch at login registration
#
set -euo pipefail
LABEL="io.exo.networksetup"
SCRIPT_DEST="/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
PLIST_DEST="/Library/LaunchDaemons/io.exo.networksetup.plist"
LOG_OUT="/var/log/${LABEL}.log"
LOG_ERR="/var/log/${LABEL}.err.log"
APP_BUNDLE_ID="io.exo.EXO"
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo_info() {
echo -e "${GREEN}[INFO]${NC} $1"
}
echo_warn() {
echo -e "${YELLOW}[WARN]${NC} $1"
}
echo_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
# Check if running as root
if [[ $EUID -ne 0 ]]; then
echo_error "This script must be run as root (use sudo)"
exit 1
fi
echo ""
echo "========================================"
echo " EXO Uninstaller"
echo "========================================"
echo ""
# Unload the LaunchDaemon if running
echo_info "Stopping network setup daemon..."
if launchctl list | grep -q "$LABEL"; then
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
else
echo_warn "Daemon was not running"
fi
# Remove LaunchDaemon plist
if [[ -f "$PLIST_DEST" ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
else
echo_warn "LaunchDaemon plist not found (already removed?)"
fi
# Remove the script and parent directory
if [[ -f "$SCRIPT_DEST" ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
else
echo_warn "Network setup script not found (already removed?)"
fi
# Remove EXO directory if empty
if [[ -d "/Library/Application Support/EXO" ]]; then
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
echo_info "Removed EXO support directory" || \
echo_warn "EXO support directory not empty, leaving in place"
fi
# Remove log files
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
else
echo_warn "Log files not found (already removed?)"
fi
# Switch back to Automatic network location
echo_info "Restoring network configuration..."
if networksetup -listlocations | grep -q "^Automatic$"; then
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
else
echo_warn "Automatic network location not found"
fi
# Delete the exo network location if it exists
if networksetup -listlocations | grep -q "^exo$"; then
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
else
echo_warn "'exo' network location not found (already removed?)"
fi
# Re-enable Thunderbolt Bridge if it exists
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
fi
# Note about launch at login registration
# SMAppService-based login items cannot be removed from a shell script.
# They can only be unregistered from within the app itself or manually via System Settings.
echo_warn "Launch at login must be removed manually:"
echo_warn " System Settings → General → Login Items → Remove EXO"
# Check if EXO.app exists in common locations
APP_FOUND=false
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
if [[ -d "$app_path" ]]; then
if [[ "$APP_FOUND" == false ]]; then
echo ""
APP_FOUND=true
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
fi
done
echo ""
echo "========================================"
echo_info "EXO uninstall complete!"
echo "========================================"
echo ""
echo "The following have been removed:"
echo " • Network setup LaunchDaemon"
echo " • Network configuration script"
echo " • Log files"
echo " • 'exo' network location"
echo ""
echo "Your network has been restored to use the 'Automatic' location."
echo "Thunderbolt Bridge has been re-enabled (if present)."
echo ""
echo "Manual step required:"
echo " Remove EXO from Login Items in System Settings → General → Login Items"
echo ""

View File

@@ -1,563 +0,0 @@
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
from __future__ import annotations
import argparse
import contextlib
import http.client
import json
import os
import time
from collections.abc import Callable
from statistics import mean
from typing import Any
from urllib.parse import urlencode
from loguru import logger
from transformers import AutoTokenizer
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
super().__init__(f"HTTP {status} {reason}: {body_preview}")
self.status = status
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
def request_json(
self,
method: str,
path: str,
params: dict[str, Any] | None = None,
body: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
) -> Any:
if not path.startswith("/"):
path = "/" + path
if params:
path = path + "?" + urlencode(params)
conn = http.client.HTTPConnection(self.host, self.port, timeout=self.timeout_s)
try:
payload: bytes | None = None
hdrs: dict[str, str] = {"Accept": "application/json"}
if body is not None:
payload = json.dumps(body).encode("utf-8")
hdrs["Content-Type"] = "application/json"
if headers:
hdrs.update(headers)
conn.request(method.upper(), path, body=payload, headers=hdrs)
resp = conn.getresponse()
raw = resp.read()
text = raw.decode("utf-8", errors="replace") if raw else ""
if resp.status >= 400:
raise ExoHttpError(resp.status, resp.reason, text[:300])
if not text:
return None
return json.loads(text)
finally:
conn.close()
def post_bench_chat_completions(self, payload: dict[str, Any]) -> dict[str, Any]:
return self.request_json("POST", "/bench/chat/completions", body=payload)
def unwrap_instance(instance: dict[str, Any]) -> dict[str, Any]:
if len(instance) != 1:
raise KeyError(f"Expected 1 key, got keys={list(instance.keys())}")
tag = next(iter(instance))
inner = instance[tag]
if not isinstance(inner, dict):
raise TypeError(f"payload for {tag} must be dict, got {type(inner)}")
return inner
def instance_id_from_instance(instance: dict[str, Any]) -> str:
inner = unwrap_instance(instance)
return str(inner["instanceId"])
def nodes_used_in_instance(instance: dict[str, Any]) -> int:
inner = unwrap_instance(instance)
return len(inner["shardAssignments"]["nodeToRunner"])
def runner_ids_from_instance(instance: dict[str, Any]) -> list[str]:
inner = unwrap_instance(instance)
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
return list(runner_to_shard.keys())
def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
time.sleep(0.1)
raise TimeoutError(f"Instance {instance_id} did not become ready within {timeout=}")
def wait_for_instance_gone(
client: ExoClient, instance_id: str, timeout: float = 3.0
) -> None:
start_time = time.time()
while time.time() - start_time < timeout:
try:
client.request_json("GET", f"/instance/{instance_id}")
time.sleep(0.4)
except ExoHttpError as e:
if e.status == 404:
return
raise TimeoutError(f"Instance {instance_id} did not get deleted within {timeout=}")
def format_peak_memory(b: float) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if b < 1024.0:
return f"{b:.2f}{unit}"
b /= 1024.0
raise ValueError("You're using petabytes of memory. Something went wrong...")
def parse_int_list(values: list[str]) -> list[int]:
items: list[int] = []
for v in values:
for part in v.split(","):
part = part.strip()
if part:
items.append(int(part))
seen: set[int] = set()
out: list[int] = []
for x in items:
if x not in seen:
out.append(x)
seen.add(x)
return out
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
models = client.request_json("GET", "/models") or {}
data = models.get("data") or []
for m in data:
if m.get("id") == model_arg:
short_id = str(m["id"])
full_id = str(m.get("hugging_face_id") or m["id"])
return short_id, full_id
for m in data:
if m.get("hugging_face_id") == model_arg:
short_id = str(m["id"])
full_id = str(m["hugging_face_id"])
return short_id, full_id
raise ValueError(f"Model not found in /models: {model_arg}")
def placement_filter(instance_meta: str, wanted: str) -> bool:
s = (instance_meta or "").lower()
if wanted == "both":
return ("ring" in s) or ("jaccl" in s)
return wanted in s
def sharding_filter(sharding: str, wanted: str) -> bool:
s = (sharding or "").lower()
if wanted == "both":
return ("pipeline" in s) or ("tensor" in s)
return wanted in s
def run_one_completion(
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
) -> tuple[dict[str, Any], int]:
content, pp_tokens = prompt_sizer.build(pp_hint)
payload: dict[str, Any] = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": tg,
}
t0 = time.perf_counter()
out = client.post_bench_chat_completions(payload)
elapsed = time.perf_counter() - t0
stats = out.get("generation_stats")
preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]
return {
"elapsed_s": elapsed,
"output_text_preview": preview,
"stats": stats,
}, pp_tokens
class PromptSizer:
def __init__(self, tokenizer: Any, atom: str = "a "):
self.tokenizer = tokenizer
self.atom = atom
self.count_fn = PromptSizer._make_counter(tokenizer)
self.base_tokens = self.count_fn("")
@staticmethod
def _make_counter(tokenizer: Any) -> Callable[[str], int]:
def count_fn(user_content: str) -> int:
messages = [{"role": "user", "content": user_content}]
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
# Fix for transformers 5.x
if hasattr(ids, "input_ids"):
ids = ids.input_ids
return int(len(ids))
return count_fn
def build(self, target_prompt_tokens: int) -> tuple[str, int]:
target = int(target_prompt_tokens)
if target < self.base_tokens:
raise RuntimeError(
f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
)
content = ""
tok = self.count_fn(content)
while tok < target:
content += self.atom
tok = self.count_fn(content)
if tok != target:
raise RuntimeError(
f"Overshot: got {tok} tokens (target {target}). "
f"Pick a different atom (try ' a' or '\\n' or '0 ')."
)
return content, tok
def main() -> int:
ap = argparse.ArgumentParser(
prog="exo-bench",
description="Benchmark exo model throughput across placement previews.",
)
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument(
"--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
)
ap.add_argument("--model", required=True, help="Model short id or huggingface id")
ap.add_argument(
"--pp",
nargs="+",
required=True,
help="Prompt-size hints (ints). Accepts commas.",
)
ap.add_argument(
"--tg",
nargs="+",
required=True,
help="Generation lengths (ints). Accepts commas.",
)
ap.add_argument(
"--max-nodes",
type=int,
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
ap.add_argument(
"--sharding", choices=["pipeline", "tensor", "both"], default="both"
)
ap.add_argument(
"--skip-pipeline-jaccl",
action="store_true",
help="Pipeline jaccl is often pointless, skip by default",
)
ap.add_argument(
"--repeat", type=int, default=1, help="Repetitions per (pp,tg) pair."
)
ap.add_argument(
"--warmup",
type=int,
default=0,
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
default="bench/results.json",
help="Write raw per-run results JSON to this path.",
)
ap.add_argument(
"--dry-run", action="store_true", help="List selected placements and exit."
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
tg_list = parse_int_list(args.tg)
if not pp_list or not tg_list:
logger.error("pp and tg lists must be non-empty")
return 2
if args.repeat <= 0:
logger.error("--repeat must be >= 1")
return 2
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
short_id, full_model_id = resolve_model_short_id(client, args.model)
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": short_id}
)
previews = previews_resp.get("previews") or []
tokenizer = AutoTokenizer.from_pretrained(
full_model_id,
trust_remote_code=True,
)
if tokenizer is None:
raise RuntimeError("[exo-bench] tokenizer load failed")
try:
prompt_sizer = PromptSizer(tokenizer)
logger.debug(f"[exo-bench] loaded tokenizer: {full_model_id} for prompt sizer")
except Exception:
logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
raise
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), args.instance_meta):
continue
if not sharding_filter(str(p.get("sharding", "")), args.sharding):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
# Skip tensor ring single node as it is pointless when pipeline ring
if n == 1 and (
(args.sharding == "both" and "tensor" in p.get("sharding", "").lower())
or (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
):
continue
if (
args.skip_pipeline_jaccl
and (
args.instance_meta == "both"
and "jaccl" in p.get("instance_meta", "").lower()
)
and (
args.sharding == "both" and "pipeline" in p.get("sharding", "").lower()
)
):
continue
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
if not selected:
logger.error("No valid placements matched your filters.")
return 1
selected.sort(
key=lambda p: (
str(p.get("instance_meta", "")),
str(p.get("sharding", "")),
-nodes_used_in_instance(p["instance"]),
),
reverse=True,
)
logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}")
logger.info(f"placements: {len(selected)}")
for p in selected:
logger.info(
f" - {p['sharding']} / {p['instance_meta']} / nodes={nodes_used_in_instance(p['instance'])}"
)
if args.dry_run:
return 0
all_rows: list[dict[str, Any]] = []
for preview in selected:
instance = preview["instance"]
instance_id = instance_id_from_instance(instance)
sharding = str(preview["sharding"])
instance_meta = str(preview["instance_meta"])
n_nodes = nodes_used_in_instance(instance)
logger.info("=" * 80)
logger.info(
f"PLACEMENT: {sharding} / {instance_meta} / nodes={n_nodes} / instance_id={instance_id}"
)
client.request_json("POST", "/instance", body={"instance": instance})
try:
wait_for_instance_ready(client, instance_id)
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize placement: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
continue
time.sleep(1)
try:
for i in range(args.warmup):
run_one_completion(
client, full_model_id, pp_list[0], tg_list[0], prompt_sizer
)
logger.debug(f" warmup {i + 1}/{args.warmup} done")
for pp in pp_list:
# if (
# pp * n_nodes > 2048
# and "ring" in instance_meta.lower()
# and "tensor" in sharding.lower()
# ):
# model_card = MODEL_CARDS[short_id]
# if model_card.metadata.storage_size > Memory.from_gb(10):
# logger.info(
# f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
# )
# continue
for tg in tg_list:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):
time.sleep(3)
try:
row, actual_pp_tokens = run_one_completion(
client, full_model_id, pp, tg, prompt_sizer
)
except Exception as e:
logger.error(e)
continue
row.update(
{
"model_short_id": short_id,
"model_id": full_model_id,
"placement_sharding": sharding,
"placement_instance_meta": instance_meta,
"placement_nodes": n_nodes,
"instance_id": instance_id,
"pp_tokens": actual_pp_tokens,
"tg": tg,
"repeat_index": r,
}
)
runs.append(row)
all_rows.append(row)
if runs:
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
peak = mean(
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
)
logger.info(
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
f"prompt_tokens={ptok} gen_tokens={gtok} "
f"peak_memory={format_peak_memory(peak)}\n"
)
time.sleep(2)
finally:
try:
client.request_json("DELETE", f"/instance/{instance_id}")
except ExoHttpError as e:
if e.status != 404:
raise
wait_for_instance_gone(client, instance_id)
logger.debug(f"Deleted instance {instance_id}")
time.sleep(5)
if args.json_out:
with open(args.json_out, "w", encoding="utf-8") as f:
json.dump(all_rows, f, indent=2, ensure_ascii=False)
logger.debug(f"\nWrote results JSON: {args.json_out}")
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -1,60 +0,0 @@
{ lib
, config
, dream2nix
, ...
}:
let
# Read and parse the lock file
rawLockFile = builtins.fromJSON (builtins.readFile "${config.deps.dashboardSrc}/package-lock.json");
# For packages with bundleDependencies, filter out deps that are bundled
# (bundled deps are inside the tarball, not separate lockfile entries)
fixedPackages = lib.mapAttrs
(path: entry:
if entry ? bundleDependencies && entry.bundleDependencies != [ ]
then entry // {
dependencies = lib.filterAttrs
(name: _: !(lib.elem name entry.bundleDependencies))
(entry.dependencies or { });
}
else entry
)
(rawLockFile.packages or { });
fixedLockFile = rawLockFile // { packages = fixedPackages; };
in
{
imports = [
dream2nix.modules.dream2nix.nodejs-package-lock-v3
dream2nix.modules.dream2nix.nodejs-granular-v3
];
name = "exo-dashboard";
version = "1.0.0";
mkDerivation = {
src = config.deps.dashboardSrc;
buildPhase = ''
runHook preBuild
npm run build
runHook postBuild
'';
installPhase = ''
runHook preInstall
cp -r build $out/build
runHook postInstall
'';
};
deps = { nixpkgs, ... }: {
inherit (nixpkgs) stdenv;
dashboardSrc = null; # Injected by parts.nix
};
nodejs-package-lock-v3 = {
# Don't use packageLockFile - provide the fixed lock content directly
packageLock = fixedLockFile;
};
}

View File

@@ -1,44 +0,0 @@
{ inputs, ... }:
{
perSystem =
{ pkgs, lib, ... }:
let
# Filter source to only include dashboard directory
src = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
inDashboardDir =
(lib.hasInfix "/dashboard/" path)
|| (lib.hasSuffix "/dashboard" (builtins.dirOf path))
|| (baseName == "dashboard" && type == "directory");
in
inDashboardDir;
};
# Build the dashboard with dream2nix (includes node_modules in output)
dashboardFull = inputs.dream2nix.lib.evalModules {
packageSets.nixpkgs = pkgs;
modules = [
./dashboard.nix
{
paths.projectRoot = inputs.self;
paths.projectRootFile = "flake.nix";
paths.package = inputs.self + "/dashboard";
}
# Inject the filtered source
{
deps.dashboardSrc = lib.mkForce "${src}/dashboard";
}
];
};
in
{
# Extract just the static site from the full build
packages.dashboard = pkgs.runCommand "exo-dashboard" { } ''
cp -r ${dashboardFull}/build $out
'';
};
}

View File

@@ -11,3 +11,4 @@ declare global {
}
export {};

View File

@@ -1,5 +1,5 @@
<script lang="ts">
import { isLoading, sendMessage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import { isLoading, sendMessage, generateImage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import ChatAttachments from './ChatAttachments.svelte';
import type { ChatUploadedFile } from '$lib/types/files';
import { processUploadedFiles, getAcceptString } from '$lib/types/files';
@@ -10,6 +10,7 @@
showHelperText?: boolean;
autofocus?: boolean;
showModelSelector?: boolean;
modelTasks?: Record<string, string[]>;
}
let {
@@ -17,7 +18,8 @@
placeholder = 'Ask anything',
showHelperText = false,
autofocus = true,
showModelSelector = false
showModelSelector = false,
modelTasks = {}
}: Props = $props();
let message = $state('');
@@ -48,51 +50,40 @@
// Accept all supported file types
const acceptString = getAcceptString(['image', 'text', 'pdf']);
// Check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const tasks = modelTasks[modelId] || [];
return tasks.includes('TextToImage') || tasks.includes('ImageToImage');
}
// Check if the currently selected model supports image generation
const isImageModel = $derived(() => {
if (!currentModel) return false;
return modelSupportsImageGeneration(currentModel);
});
// Extract available models from running instances
const availableModels = $derived(() => {
const models: Array<{id: string, label: string}> = [];
const models: Array<{id: string, label: string, isImageModel: boolean}> = [];
for (const [, instance] of Object.entries(instanceData)) {
const modelId = getInstanceModelId(instance);
if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
models.push({ id: modelId, label: modelId.split('/').pop() || modelId });
models.push({
id: modelId,
label: modelId.split('/').pop() || modelId,
isImageModel: modelSupportsImageGeneration(modelId)
});
}
}
return models;
});
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
let previousModelIds: Set<string> = new Set();
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
// Auto-select the first available model if none is selected
$effect(() => {
const models = availableModels();
const currentModelIds = new Set(models.map(m => m.id));
if (models.length > 0) {
// Find newly added models (in current but not in previous)
const newModels = models.filter(m => !previousModelIds.has(m.id));
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel('');
}
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
}
// Update previous model IDs for next comparison
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {
@@ -187,7 +178,12 @@
uploadedFiles = [];
resetTextareaHeight();
sendMessage(content, files);
// Use image generation for image models
if (isImageModel() && content) {
generateImage(content);
} else {
sendMessage(content, files);
}
// Refocus the textarea after sending
setTimeout(() => textareaRef?.focus(), 10);
@@ -324,7 +320,14 @@
{:else}
<span class="w-3"></span>
{/if}
<span class="truncate">{model.label}</span>
{#if model.isImageModel}
<svg class="w-3.5 h-3.5 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate flex-1">{model.label}</span>
</button>
{/each}
</div>
@@ -384,7 +387,7 @@
onkeydown={handleKeydown}
oninput={handleInput}
onpaste={handlePaste}
{placeholder}
placeholder={isImageModel() ? 'Describe the image you want to generate...' : placeholder}
disabled={loading}
rows={1}
class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"
@@ -398,14 +401,23 @@
{!canSend || loading
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label="Send message"
aria-label={isImageModel() ? "Generate image" : "Send message"}
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<span class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"></span>
<span class="hidden sm:inline">PROCESSING</span>
<span class="hidden sm:inline">{isImageModel() ? 'GENERATING' : 'PROCESSING'}</span>
<span class="sm:hidden">...</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}

View File

@@ -365,10 +365,58 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
</div>
{/if}
<!-- Generated Images -->
{#if message.attachments?.some(a => a.type === 'generated-image')}
<div class="mb-3">
{#each message.attachments.filter(a => a.type === 'generated-image') as attachment}
<div class="relative group/img inline-block">
<img
src={attachment.preview}
alt=""
class="max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20"
/>
<!-- Download button overlay -->
<button
type="button"
class="absolute top-2 right-2 p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow opacity-0 group-hover/img:opacity-100 transition-opacity hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
onclick={() => {
if (attachment.preview) {
const link = document.createElement('a');
link.href = attachment.preview;
link.download = `generated-image-${Date.now()}.png`;
link.click();
}
}}
title="Download image"
>
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
</svg>
</button>
</div>
{/each}
</div>
{/if}
<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{#if message.content === 'Generating image...'}
<div class="flex items-center gap-3 text-exo-yellow">
<div class="relative">
<div class="w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"></div>
<svg class="absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
</div>
<span class="font-mono tracking-wider uppercase text-sm">Generating image...</span>
</div>
{:else if message.content || (loading && !message.attachments?.some(a => a.type === 'generated-image'))}
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
{/if}
</div>
</div>

View File

@@ -53,285 +53,62 @@
marked.use({ renderer });
/**
* Unescape HTML entities that marked may have escaped
*/
function unescapeHtmlEntities(text: string): string {
return text
.replace(/&lt;/g, '<')
.replace(/&gt;/g, '>')
.replace(/&amp;/g, '&')
.replace(/&quot;/g, '"')
.replace(/&#39;/g, "'");
}
// Storage for math expressions extracted before markdown processing
const mathExpressions: Map<string, { content: string; displayMode: boolean }> = new Map();
let mathCounter = 0;
// Storage for HTML snippets that need protection from markdown
const htmlSnippets: Map<string, string> = new Map();
let htmlCounter = 0;
// Use alphanumeric placeholders that won't be interpreted as HTML tags
const MATH_PLACEHOLDER_PREFIX = 'MATHPLACEHOLDER';
const CODE_PLACEHOLDER_PREFIX = 'CODEPLACEHOLDER';
const HTML_PLACEHOLDER_PREFIX = 'HTMLPLACEHOLDER';
/**
* Preprocess LaTeX: extract math, handle LaTeX document commands, and protect content
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
* Also protect code blocks from LaTeX processing
*/
function preprocessLaTeX(text: string): string {
// Reset storage
mathExpressions.clear();
mathCounter = 0;
htmlSnippets.clear();
htmlCounter = 0;
// Protect code blocks first
// Protect code blocks
const codeBlocks: string[] = [];
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
codeBlocks.push(match);
return `${CODE_PLACEHOLDER_PREFIX}${codeBlocks.length - 1}END`;
return `<<CODE_${codeBlocks.length - 1}>>`;
});
// Remove LaTeX document commands
processed = processed.replace(/\\documentclass(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\usepackage(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\begin\{document\}/g, '');
processed = processed.replace(/\\end\{document\}/g, '');
processed = processed.replace(/\\maketitle/g, '');
processed = processed.replace(/\\title\{[^}]*\}/g, '');
processed = processed.replace(/\\author\{[^}]*\}/g, '');
processed = processed.replace(/\\date\{[^}]*\}/g, '');
// Remove \require{...} commands (MathJax-specific, not supported by KaTeX)
processed = processed.replace(/\$\\require\{[^}]*\}\$/g, '');
processed = processed.replace(/\\require\{[^}]*\}/g, '');
// Remove unsupported LaTeX commands/environments (tikzpicture, figure, center, etc.)
processed = processed.replace(/\\begin\{tikzpicture\}[\s\S]*?\\end\{tikzpicture\}/g, () => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">📐</span><span class="latex-diagram-text">Diagram</span></div>');
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\begin\{figure\}[\s\S]*?\\end\{figure\}/g, () => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">🖼️</span><span class="latex-diagram-text">Figure</span></div>');
htmlCounter++;
return placeholder;
});
// Strip center environment (layout only, no content change)
processed = processed.replace(/\\begin\{center\}/g, '');
processed = processed.replace(/\\end\{center\}/g, '');
// Strip other layout environments
processed = processed.replace(/\\begin\{flushleft\}/g, '');
processed = processed.replace(/\\end\{flushleft\}/g, '');
processed = processed.replace(/\\begin\{flushright\}/g, '');
processed = processed.replace(/\\end\{flushright\}/g, '');
processed = processed.replace(/\\label\{[^}]*\}/g, '');
processed = processed.replace(/\\caption\{[^}]*\}/g, '');
// Protect escaped dollar signs (e.g., \$50 should become $50, not LaTeX)
processed = processed.replace(/\\\$/g, 'ESCAPEDDOLLARPLACEHOLDER');
// Convert LaTeX math environments to display math (both bare and wrapped in $...$)
const mathEnvs = ['align', 'align\\*', 'equation', 'equation\\*', 'gather', 'gather\\*', 'multline', 'multline\\*', 'eqnarray', 'eqnarray\\*', 'array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix', 'cases'];
for (const env of mathEnvs) {
// Handle $\begin{env}...\end{env}$ (with dollar signs, possibly multiline)
const wrappedRegex = new RegExp(`\\$\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}\\$`, 'g');
processed = processed.replace(wrappedRegex, (_, args, content) => {
const cleanEnv = env.replace('\\*', '*');
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
mathCounter++;
return placeholder;
});
// Handle bare \begin{env}...\end{env} (without dollar signs)
const bareRegex = new RegExp(`\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}`, 'g');
processed = processed.replace(bareRegex, (_, args, content) => {
const cleanEnv = env.replace('\\*', '*');
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
mathCounter++;
return placeholder;
});
}
// Convert LaTeX proof environments to styled blocks (use placeholders for HTML)
processed = processed.replace(
/\\begin\{proof\}([\s\S]*?)\\end\{proof\}/g,
(_, content) => {
const html = `<div class="latex-proof"><div class="latex-proof-header">Proof</div><div class="latex-proof-content">${content}</div></div>`;
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, html);
htmlCounter++;
return placeholder;
}
);
// Convert LaTeX theorem-like environments
const theoremEnvs = ['theorem', 'lemma', 'corollary', 'proposition', 'definition', 'remark', 'example'];
for (const env of theoremEnvs) {
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'gi');
const envName = env.charAt(0).toUpperCase() + env.slice(1);
processed = processed.replace(envRegex, (_, content) => {
const html = `<div class="latex-theorem"><div class="latex-theorem-header">${envName}</div><div class="latex-theorem-content">${content}</div></div>`;
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, html);
htmlCounter++;
return placeholder;
});
}
// Convert LaTeX text formatting commands (use placeholders to protect from markdown)
processed = processed.replace(/\\emph\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<em>${content}</em>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\textit\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<em>${content}</em>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\textbf\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<strong>${content}</strong>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\texttt\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<code class="inline-code">${content}</code>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\underline\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<u>${content}</u>`);
htmlCounter++;
return placeholder;
});
// Handle LaTeX line breaks and spacing
processed = processed.replace(/\\\\(?:\s*\n)?/g, '\n'); // \\ -> newline
processed = processed.replace(/\\newline/g, '\n');
processed = processed.replace(/\\par\b/g, '\n\n');
processed = processed.replace(/\\quad/g, ' ');
processed = processed.replace(/\\qquad/g, ' ');
processed = processed.replace(/~~/g, ' '); // non-breaking space
// Remove other common LaTeX commands that don't render
processed = processed.replace(/\\centering/g, '');
processed = processed.replace(/\\noindent/g, '');
processed = processed.replace(/\\hfill/g, '');
processed = processed.replace(/\\vspace\{[^}]*\}/g, '');
processed = processed.replace(/\\hspace\{[^}]*\}/g, ' ');
// Convert \(...\) to placeholder (display: false)
processed = processed.replace(/\\\(([\s\S]+?)\\\)/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: false });
mathCounter++;
return placeholder;
});
// Convert \[...\] to placeholder (display: true)
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: true });
mathCounter++;
return placeholder;
});
// Extract display math ($$...$$) BEFORE markdown processing
processed = processed.replace(/\$\$([\s\S]*?)\$\$/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: true });
mathCounter++;
return placeholder;
});
// Extract inline math ($...$) BEFORE markdown processing
// Allow single-line only, skip currency patterns like $5 or $50
processed = processed.replace(/\$([^\$\n]+?)\$/g, (match, content) => {
if (/^\d/.test(content.trim())) {
return match; // Keep as-is for currency
}
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: false });
mathCounter++;
return placeholder;
});
// Restore escaped dollar signs
processed = processed.replace(/ESCAPEDDOLLARPLACEHOLDER/g, '$');
// Convert \(...\) to $...$
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
// Convert \[...\] to $$...$$
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
// Restore code blocks
processed = processed.replace(new RegExp(`${CODE_PLACEHOLDER_PREFIX}(\\d+)END`, 'g'), (_, index) => codeBlocks[parseInt(index)]);
// Clean up any remaining stray backslashes from unrecognized commands
processed = processed.replace(/\\(?=[a-zA-Z])/g, ''); // Remove \ before letters (unrecognized commands)
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
return processed;
}
/**
* Render math expressions with KaTeX and restore HTML placeholders
* Render math expressions with KaTeX after HTML is generated
*/
function renderMath(html: string): string {
// Replace all math placeholders with rendered KaTeX
for (const [placeholder, { content, displayMode }] of mathExpressions) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
// Render display math ($$...$$)
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
try {
return katex.renderToString(math.trim(), {
displayMode: true,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$$${math}$$</span>`;
}
});
html = html.replace(regex, () => {
try {
const rendered = katex.renderToString(content, {
displayMode,
throwOnError: false,
output: 'html'
});
if (displayMode) {
return `
<div class="math-display-wrapper">
<div class="math-display-header">
<span class="math-label">LaTeX</span>
<button type="button" class="copy-math-btn" data-math-source="${encodeURIComponent(content)}" title="Copy LaTeX source">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
</svg>
</button>
</div>
<div class="math-display-content">
${rendered}
</div>
</div>
`;
} else {
return `<span class="math-inline">${rendered}</span>`;
}
} catch {
const display = displayMode ? `$$${content}$$` : `$${content}$`;
return `<span class="math-error"><span class="math-error-icon">⚠</span> ${display}</span>`;
}
});
}
// Restore HTML placeholders (for \textbf, \emph, etc.)
for (const [placeholder, htmlContent] of htmlSnippets) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
html = html.replace(regex, htmlContent);
}
// Render inline math ($...$) but avoid matching currency like $5
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
// Skip if it looks like currency ($ followed by number)
if (/^\d/.test(math.trim())) {
return match;
}
try {
return katex.renderToString(math.trim(), {
displayMode: false,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$${math}$</span>`;
}
});
return html;
}
@@ -377,50 +154,16 @@
}
}
async function handleMathCopyClick(event: Event) {
const target = event.currentTarget as HTMLButtonElement;
const encodedSource = target.getAttribute('data-math-source');
if (!encodedSource) return;
const source = decodeURIComponent(encodedSource);
try {
await navigator.clipboard.writeText(source);
// Show copied feedback
const originalHtml = target.innerHTML;
target.innerHTML = `
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M20 6L9 17l-5-5"/>
</svg>
`;
target.classList.add('copied');
setTimeout(() => {
target.innerHTML = originalHtml;
target.classList.remove('copied');
}, 2000);
} catch (error) {
console.error('Failed to copy math:', error);
}
}
function setupCopyButtons() {
if (!containerRef || !browser) return;
const codeButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of codeButtons) {
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of buttons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleCopyClick);
}
}
const mathButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-math-btn');
for (const button of mathButtons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleMathCopyClick);
}
}
}
$effect(() => {
@@ -681,290 +424,28 @@
color: #60a5fa;
}
/* KaTeX math styling - Base */
/* KaTeX math styling */
.markdown-content :global(.katex) {
font-size: 1.1em;
color: oklch(0.9 0 0);
}
/* Display math container wrapper */
.markdown-content :global(.math-display-wrapper) {
.markdown-content :global(.katex-display) {
margin: 1rem 0;
border-radius: 0.5rem;
overflow: hidden;
border: 1px solid rgba(255, 215, 0, 0.15);
background: rgba(0, 0, 0, 0.3);
transition: border-color 0.2s ease, box-shadow 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover) {
border-color: rgba(255, 215, 0, 0.25);
box-shadow: 0 0 12px rgba(255, 215, 0, 0.08);
}
/* Display math header - hidden by default, slides in on hover */
.markdown-content :global(.math-display-header) {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.375rem 0.75rem;
background: rgba(255, 215, 0, 0.03);
border-bottom: 1px solid rgba(255, 215, 0, 0.08);
opacity: 0;
max-height: 0;
padding-top: 0;
padding-bottom: 0;
overflow: hidden;
transition:
opacity 0.2s ease,
max-height 0.2s ease,
padding 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover .math-display-header) {
opacity: 1;
max-height: 2.5rem;
padding: 0.375rem 0.75rem;
}
.markdown-content :global(.math-label) {
color: rgba(255, 215, 0, 0.7);
font-size: 0.65rem;
font-weight: 500;
text-transform: uppercase;
letter-spacing: 0.1em;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
}
.markdown-content :global(.copy-math-btn) {
display: flex;
align-items: center;
justify-content: center;
padding: 0.25rem;
background: transparent;
border: none;
color: var(--exo-light-gray, #9ca3af);
cursor: pointer;
transition: color 0.2s;
border-radius: 0.25rem;
opacity: 0;
transition:
color 0.2s,
opacity 0.15s ease;
}
.markdown-content :global(.math-display-wrapper:hover .copy-math-btn) {
opacity: 1;
}
.markdown-content :global(.copy-math-btn:hover) {
color: var(--exo-yellow, #ffd700);
}
.markdown-content :global(.copy-math-btn.copied) {
color: #22c55e;
}
/* Display math content area */
.markdown-content :global(.math-display-content) {
padding: 1rem 1.25rem;
overflow-x: auto;
overflow-y: hidden;
padding: 0.5rem 0;
}
/* Custom scrollbar for math overflow */
.markdown-content :global(.math-display-content::-webkit-scrollbar) {
height: 6px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-track) {
background: rgba(255, 255, 255, 0.05);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb) {
background: rgba(255, 215, 0, 0.2);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb:hover) {
background: rgba(255, 215, 0, 0.35);
}
.markdown-content :global(.math-display-content .katex-display) {
margin: 0;
padding: 0;
}
.markdown-content :global(.math-display-content .katex-display > .katex) {
.markdown-content :global(.katex-display > .katex) {
text-align: center;
}
/* Inline math wrapper */
.markdown-content :global(.math-inline) {
display: inline;
padding: 0 0.125rem;
border-radius: 0.25rem;
transition: background-color 0.15s ease;
}
.markdown-content :global(.math-inline:hover) {
background: rgba(255, 215, 0, 0.05);
}
/* Dark theme KaTeX overrides */
.markdown-content :global(.katex .mord),
.markdown-content :global(.katex .minner),
.markdown-content :global(.katex .mop),
.markdown-content :global(.katex .mbin),
.markdown-content :global(.katex .mrel),
.markdown-content :global(.katex .mpunct) {
color: oklch(0.9 0 0);
}
/* Fraction lines and rules */
.markdown-content :global(.katex .frac-line),
.markdown-content :global(.katex .overline-line),
.markdown-content :global(.katex .underline-line),
.markdown-content :global(.katex .hline),
.markdown-content :global(.katex .rule) {
border-color: oklch(0.85 0 0) !important;
background: oklch(0.85 0 0);
}
/* Square roots and SVG elements */
.markdown-content :global(.katex .sqrt-line) {
border-color: oklch(0.85 0 0) !important;
}
.markdown-content :global(.katex svg) {
fill: oklch(0.85 0 0);
stroke: oklch(0.85 0 0);
}
.markdown-content :global(.katex svg path) {
stroke: oklch(0.85 0 0);
}
/* Delimiters (parentheses, brackets, braces) */
.markdown-content :global(.katex .delimsizing),
.markdown-content :global(.katex .delim-size1),
.markdown-content :global(.katex .delim-size2),
.markdown-content :global(.katex .delim-size3),
.markdown-content :global(.katex .delim-size4),
.markdown-content :global(.katex .mopen),
.markdown-content :global(.katex .mclose) {
color: oklch(0.75 0 0);
}
/* Math error styling */
.markdown-content :global(.math-error) {
display: inline-flex;
align-items: center;
gap: 0.375rem;
color: #f87171;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
background: rgba(248, 113, 113, 0.1);
padding: 0.25rem 0.5rem;
padding: 0.125rem 0.25rem;
border-radius: 0.25rem;
border: 1px solid rgba(248, 113, 113, 0.2);
}
.markdown-content :global(.math-error-icon) {
font-size: 0.875em;
opacity: 0.9;
}
/* LaTeX proof environment */
.markdown-content :global(.latex-proof) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 255, 255, 0.02);
border-left: 3px solid rgba(255, 215, 0, 0.4);
border-radius: 0 0.375rem 0.375rem 0;
}
.markdown-content :global(.latex-proof-header) {
font-weight: 600;
font-style: italic;
color: oklch(0.85 0 0);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-proof-header::after) {
content: '.';
}
.markdown-content :global(.latex-proof-content) {
color: oklch(0.9 0 0);
}
.markdown-content :global(.latex-proof-content p:last-child) {
margin-bottom: 0;
}
/* QED symbol at end of proof */
.markdown-content :global(.latex-proof-content::after) {
content: '∎';
display: block;
text-align: right;
color: oklch(0.7 0 0);
margin-top: 0.5rem;
}
/* LaTeX theorem-like environments */
.markdown-content :global(.latex-theorem) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 215, 0, 0.03);
border: 1px solid rgba(255, 215, 0, 0.15);
border-radius: 0.375rem;
}
.markdown-content :global(.latex-theorem-header) {
font-weight: 700;
color: var(--exo-yellow, #ffd700);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-theorem-header::after) {
content: '.';
}
.markdown-content :global(.latex-theorem-content) {
color: oklch(0.9 0 0);
font-style: italic;
}
.markdown-content :global(.latex-theorem-content p:last-child) {
margin-bottom: 0;
}
/* LaTeX diagram/figure placeholder */
.markdown-content :global(.latex-diagram-placeholder) {
display: flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
margin: 1rem 0;
padding: 1.5rem 2rem;
background: rgba(255, 255, 255, 0.02);
border: 1px dashed rgba(255, 215, 0, 0.25);
border-radius: 0.5rem;
color: rgba(255, 215, 0, 0.6);
font-size: 0.875rem;
}
.markdown-content :global(.latex-diagram-icon) {
font-size: 1.25rem;
opacity: 0.8;
}
.markdown-content :global(.latex-diagram-text) {
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.75rem;
text-transform: uppercase;
letter-spacing: 0.05em;
}
</style>

View File

@@ -197,7 +197,7 @@ function toggleNodeDetails(nodeId: string): void {
// Uses API preview data when available, falls back to local estimation
const placementPreview = $derived(() => {
const nodeArray = nodeList();
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, topoWidth: 260, topoHeight: 90, error: null };
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, error: null };
const numNodes = nodeArray.length;
const iconSize = numNodes === 1 ? 50 : 36;

View File

@@ -1,7 +1,7 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$lib/stores/app.svelte';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
interface Props {
class?: string;
@@ -24,14 +24,14 @@ function getNodeLabel(nodeId: string): string {
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
if (!ip) return { label: '?', missing: true };
// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
// Helper to check a node's interfaces
function checkNode(node: NodeInfo | undefined): string | null {
function checkNode(node: typeof data.nodes[string]): string | null {
if (!node) return null;
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
@@ -39,19 +39,17 @@ function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missin
return matchFromInterfaces.name;
}
if (node.ip_to_interface) {
const mapped = node.ip_to_interface[cleanIp] || (ip ? node.ip_to_interface[ip] : undefined);
if (mapped && mapped.trim().length > 0) {
return mapped;
}
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return mapped;
}
return null;
}
// Try specified node first
const result = checkNode(data?.nodes?.[nodeId]);
if (result) return { label: result, missing: false };
// Fallback: search all nodes for this IP
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
const otherResult = checkNode(otherNode);
@@ -257,24 +255,21 @@ function wrapLine(text: string, maxLen: number): string[] {
const arrowsGroup = svg.append('g').attr('class', 'arrows-group');
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');
type ConnectionInfo = { from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean };
type PairEntry = { a: string; b: string; aToB: boolean; bToA: boolean; connections: ConnectionInfo[] };
type DebugEdgeLabelEntry = { connections: ConnectionInfo[]; isLeft: boolean; isTop: boolean; mx: number; my: number };
const pairMap = new Map<string, PairEntry>();
const debugEdgeLabels: DebugEdgeLabelEntry[] = [];
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
edges.forEach(edge => {
if (!edge.source || !edge.target || edge.source === edge.target) return;
if (!positionById[edge.source] || !positionById[edge.target]) return;
const a = edge.source < edge.target ? edge.source : edge.target;
const b = edge.source < edge.target ? edge.target : edge.source;
const key = `${a}|${b}`;
const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false, connections: [] };
if (edge.source === a) entry.aToB = true;
else entry.bToA = true;
const ip = edge.sendBackIp || '?';
const ip = edge.sendBackIp || edge.sendBackMultiaddr?.ip_address || '?';
const ifaceInfo = getInterfaceLabel(edge.source, ip);
entry.connections.push({
from: edge.source,
@@ -343,8 +338,9 @@ function wrapLine(text: string, maxLen: number): string[] {
// Determine which side of viewport based on edge midpoint
const isLeft = mx < centerX;
const isTop = my < safeCenterY;
// Store for batch rendering after all edges processed
if (!debugEdgeLabels) debugEdgeLabels = [];
debugEdgeLabels.push({
connections: entry.connections,
isLeft,
@@ -385,32 +381,32 @@ function wrapLine(text: string, maxLen: number): string[] {
}
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
const quadrants: Record<string, DebugEdgeLabelEntry[]> = {
const quadrants: Record<string, typeof debugEdgeLabels> = {
topLeft: [],
topRight: [],
bottomLeft: [],
bottomRight: []
};
debugEdgeLabels.forEach(edge => {
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
quadrants[key].push(edge);
});
// Render each quadrant
Object.entries(quadrants).forEach(([quadrant, quadrantEdges]) => {
if (quadrantEdges.length === 0) return;
Object.entries(quadrants).forEach(([quadrant, edges]) => {
if (edges.length === 0) return;
const isLeft = quadrant.includes('Left');
const isTop = quadrant.includes('top');
let baseX = isLeft ? padding : width - padding;
let baseY = isTop ? padding : height - padding;
const textAnchor = isLeft ? 'start' : 'end';
let currentY = baseY;
quadrantEdges.forEach(edge => {
edges.forEach(edge => {
edge.connections.forEach(conn => {
const arrow = getArrow(conn.from, conn.to);
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;

View File

@@ -1,7 +1,8 @@
export { default as TopologyGraph } from "./TopologyGraph.svelte";
export { default as ChatForm } from "./ChatForm.svelte";
export { default as ChatMessages } from "./ChatMessages.svelte";
export { default as ChatAttachments } from "./ChatAttachments.svelte";
export { default as ChatSidebar } from "./ChatSidebar.svelte";
export { default as ModelCard } from "./ModelCard.svelte";
export { default as MarkdownContent } from "./MarkdownContent.svelte";
export { default as TopologyGraph } from './TopologyGraph.svelte';
export { default as ChatForm } from './ChatForm.svelte';
export { default as ChatMessages } from './ChatMessages.svelte';
export { default as ChatAttachments } from './ChatAttachments.svelte';
export { default as ChatSidebar } from './ChatSidebar.svelte';
export { default as ModelCard } from './ModelCard.svelte';
export { default as MarkdownContent } from './MarkdownContent.svelte';

View File

File diff suppressed because it is too large Load Diff

View File

@@ -13,124 +13,55 @@ export interface ChatUploadedFile {
}
export interface ChatAttachment {
type: "image" | "text" | "pdf" | "audio";
type: 'image' | 'text' | 'pdf' | 'audio';
name: string;
content?: string;
base64Url?: string;
mimeType?: string;
}
export type FileCategory = "image" | "text" | "pdf" | "audio" | "unknown";
export type FileCategory = 'image' | 'text' | 'pdf' | 'audio' | 'unknown';
export const IMAGE_EXTENSIONS = [
".jpg",
".jpeg",
".png",
".gif",
".webp",
".svg",
];
export const IMAGE_MIME_TYPES = [
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
"image/svg+xml",
];
export const IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.gif', '.webp', '.svg'];
export const IMAGE_MIME_TYPES = ['image/jpeg', 'image/png', 'image/gif', 'image/webp', 'image/svg+xml'];
export const TEXT_EXTENSIONS = [
".txt",
".md",
".json",
".xml",
".yaml",
".yml",
".csv",
".log",
".js",
".ts",
".jsx",
".tsx",
".py",
".java",
".cpp",
".c",
".h",
".css",
".html",
".htm",
".sql",
".sh",
".bat",
".rs",
".go",
".rb",
".php",
".swift",
".kt",
".scala",
".r",
".dart",
".vue",
".svelte",
'.txt', '.md', '.json', '.xml', '.yaml', '.yml', '.csv', '.log',
'.js', '.ts', '.jsx', '.tsx', '.py', '.java', '.cpp', '.c', '.h',
'.css', '.html', '.htm', '.sql', '.sh', '.bat', '.rs', '.go',
'.rb', '.php', '.swift', '.kt', '.scala', '.r', '.dart', '.vue', '.svelte'
];
export const TEXT_MIME_TYPES = [
"text/plain",
"text/markdown",
"text/csv",
"text/html",
"text/css",
"application/json",
"application/xml",
"text/xml",
"application/javascript",
"text/javascript",
"application/typescript",
'text/plain', 'text/markdown', 'text/csv', 'text/html', 'text/css',
'application/json', 'application/xml', 'text/xml', 'application/javascript',
'text/javascript', 'application/typescript'
];
export const PDF_EXTENSIONS = [".pdf"];
export const PDF_MIME_TYPES = ["application/pdf"];
export const PDF_EXTENSIONS = ['.pdf'];
export const PDF_MIME_TYPES = ['application/pdf'];
export const AUDIO_EXTENSIONS = [".mp3", ".wav", ".ogg", ".m4a"];
export const AUDIO_MIME_TYPES = [
"audio/mpeg",
"audio/wav",
"audio/ogg",
"audio/mp4",
];
export const AUDIO_EXTENSIONS = ['.mp3', '.wav', '.ogg', '.m4a'];
export const AUDIO_MIME_TYPES = ['audio/mpeg', 'audio/wav', 'audio/ogg', 'audio/mp4'];
/**
* Get file category based on MIME type and extension
*/
export function getFileCategory(
mimeType: string,
fileName: string,
): FileCategory {
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf("."));
if (
IMAGE_MIME_TYPES.includes(mimeType) ||
IMAGE_EXTENSIONS.includes(extension)
) {
return "image";
export function getFileCategory(mimeType: string, fileName: string): FileCategory {
const extension = fileName.toLowerCase().slice(fileName.lastIndexOf('.'));
if (IMAGE_MIME_TYPES.includes(mimeType) || IMAGE_EXTENSIONS.includes(extension)) {
return 'image';
}
if (PDF_MIME_TYPES.includes(mimeType) || PDF_EXTENSIONS.includes(extension)) {
return "pdf";
return 'pdf';
}
if (
AUDIO_MIME_TYPES.includes(mimeType) ||
AUDIO_EXTENSIONS.includes(extension)
) {
return "audio";
if (AUDIO_MIME_TYPES.includes(mimeType) || AUDIO_EXTENSIONS.includes(extension)) {
return 'audio';
}
if (
TEXT_MIME_TYPES.includes(mimeType) ||
TEXT_EXTENSIONS.includes(extension) ||
mimeType.startsWith("text/")
) {
return "text";
if (TEXT_MIME_TYPES.includes(mimeType) || TEXT_EXTENSIONS.includes(extension) || mimeType.startsWith('text/')) {
return 'text';
}
return "unknown";
return 'unknown';
}
/**
@@ -138,36 +69,36 @@ export function getFileCategory(
*/
export function getAcceptString(categories: FileCategory[]): string {
const accepts: string[] = [];
for (const category of categories) {
switch (category) {
case "image":
case 'image':
accepts.push(...IMAGE_EXTENSIONS, ...IMAGE_MIME_TYPES);
break;
case "text":
case 'text':
accepts.push(...TEXT_EXTENSIONS, ...TEXT_MIME_TYPES);
break;
case "pdf":
case 'pdf':
accepts.push(...PDF_EXTENSIONS, ...PDF_MIME_TYPES);
break;
case "audio":
case 'audio':
accepts.push(...AUDIO_EXTENSIONS, ...AUDIO_MIME_TYPES);
break;
}
}
return accepts.join(",");
return accepts.join(',');
}
/**
* Format file size for display
*/
export function formatFileSize(bytes: number): string {
if (bytes === 0) return "0 B";
if (bytes === 0) return '0 B';
const k = 1024;
const sizes = ["B", "KB", "MB", "GB"];
const sizes = ['B', 'KB', 'MB', 'GB'];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + " " + sizes[i];
return parseFloat((bytes / Math.pow(k, i)).toFixed(1)) + ' ' + sizes[i];
}
/**
@@ -197,44 +128,42 @@ export function readFileAsText(file: File): Promise<string> {
/**
* Process uploaded files into ChatUploadedFile format
*/
export async function processUploadedFiles(
files: File[],
): Promise<ChatUploadedFile[]> {
export async function processUploadedFiles(files: File[]): Promise<ChatUploadedFile[]> {
const results: ChatUploadedFile[] = [];
for (const file of files) {
const id =
Date.now().toString() + Math.random().toString(36).substring(2, 9);
const id = Date.now().toString() + Math.random().toString(36).substring(2, 9);
const category = getFileCategory(file.type, file.name);
const base: ChatUploadedFile = {
id,
name: file.name,
size: file.size,
type: file.type,
file,
file
};
try {
if (category === "image") {
if (category === 'image') {
const preview = await readFileAsDataURL(file);
results.push({ ...base, preview });
} else if (category === "text" || category === "unknown") {
} else if (category === 'text' || category === 'unknown') {
const textContent = await readFileAsText(file);
results.push({ ...base, textContent });
} else if (category === "pdf") {
} else if (category === 'pdf') {
results.push(base);
} else if (category === "audio") {
} else if (category === 'audio') {
const preview = await readFileAsDataURL(file);
results.push({ ...base, preview });
} else {
results.push(base);
}
} catch (error) {
console.error("Error processing file:", file.name, error);
console.error('Error processing file:', file.name, error);
results.push(base);
}
}
return results;
}

View File

@@ -47,7 +47,30 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number, tasks?: string[], hugging_face_id?: string}>>([]);
// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
const tasks: Record<string, string[]> = {};
for (const model of models) {
if (model.tasks && model.tasks.length > 0) {
// Map by short ID
tasks[model.id] = model.tasks;
// Also map by hugging_face_id from the API response
if (model.hugging_face_id) {
tasks[model.hugging_face_id] = model.tasks;
}
}
}
return tasks;
});
// Helper to check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const model = models.find(m => m.id === modelId || m.hugging_face_id === modelId);
if (!model?.tasks) return false;
return model.tasks.includes('TextToImage') || model.tasks.includes('ImageToImage');
}
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -400,8 +423,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Auto-select the launched model only if no model is currently selected
if (!selectedChatModel()) {
setSelectedChatModel(modelId);
}
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -434,8 +459,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
// Model meta is nested: shard.model_card.model_id
const modelMeta = shardData.model_card ?? shardData.modelCard;
// Model meta is nested: shard.model_meta.model_id
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -591,7 +616,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
// Unwrap the instance
const [instanceTag, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== 'object') {
return { isDownloading: false, progress: null, statusText: 'PREPARING', perNode: [] };
return { isDownloading: false, progress: null, statusText: 'UNKNOWN', perNode: [] };
}
const inst = instance as { shardAssignments?: { nodeToRunner?: Record<string, string>; runnerToShard?: Record<string, unknown>; modelId?: string } };
@@ -704,7 +729,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function deriveInstanceStatus(instanceWrapped: unknown): { statusText: string; statusClass: string } {
const [, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== 'object') {
return { statusText: 'PREPARING', statusClass: 'inactive' };
return { statusText: 'UNKNOWN', statusClass: 'inactive' };
}
const inst = instance as { shardAssignments?: { runnerToShard?: Record<string, unknown> } };
@@ -733,7 +758,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const has = (s: string) => statuses.includes(s);
if (statuses.length === 0) return { statusText: 'PREPARING', statusClass: 'inactive' };
if (statuses.length === 0) return { statusText: 'UNKNOWN', statusClass: 'inactive' };
if (has('Failed')) return { statusText: 'FAILED', statusClass: 'failed' };
if (has('Shutdown')) return { statusText: 'SHUTDOWN', statusClass: 'inactive' };
if (has('Loading')) return { statusText: 'LOADING', statusClass: 'starting' };
@@ -761,10 +786,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
// Get the model ID of the instance being deleted before we delete it
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
const wasSelected = selectedChatModel() === deletedInstanceModelId;
try {
const response = await fetch(`/instance/${instanceId}`, {
method: 'DELETE',
@@ -773,24 +794,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!response.ok) {
console.error('Failed to delete instance:', response.status);
} else if (wasSelected) {
// If we deleted the currently selected model, switch to another available model
// Find another instance that isn't the one we just deleted
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
if (remainingInstances.length > 0) {
// Select the last instance (most recently added, since objects preserve insertion order)
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel('');
}
} else {
// No more instances, clear the selection
setSelectedChatModel('');
}
}
} catch (error) {
console.error('Error deleting instance:', error);
@@ -915,7 +918,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const runnerEntries = Object.entries(runnerToShard).map(([runnerId, shardWrapped]) => {
const [tag, shard] = getTagged(shardWrapped);
const meta = (shard as { modelMeta?: { worldSize?: number; nLayers?: number; deviceRank?: number } } | undefined);
const deviceRank = meta?.modelMeta?.deviceRank ?? 0;
const deviceRank = (meta?.deviceRank as number | undefined) ?? 0;
return { runnerId, tag, deviceRank };
});
@@ -1270,6 +1273,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
placeholder="Ask anything"
showHelperText={false}
showModelSelector={true}
modelTasks={modelTasks()}
/>
</div>
</div>
@@ -1287,9 +1291,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
</div>
<div
<div
bind:this={instancesContainerRef}
class="max-h-72 xl:max-h-96 space-y-3 overflow-y-auto overflow-x-hidden py-px"
class="max-h-72 space-y-3 overflow-y-auto"
>
{#each Object.entries(instanceData) as [id, instance]}
{@const downloadInfo = getInstanceDownloadStatus(id, instance)}
@@ -1491,8 +1495,18 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{@const foundModel = models.find(m => m.id === selectedModelId)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="text-exo-light-gray truncate">{foundModel.name || foundModel.id}</span>
{@const isImageModel = modelSupportsImageGeneration(foundModel.id)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="flex items-center gap-2 text-exo-light-gray truncate">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{foundModel.name || foundModel.id}</span>
</span>
<span class="text-white/50 text-xs flex-shrink-0">{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB</span>
</span>
{:else}
@@ -1537,6 +1551,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(model.id)}
<button
type="button"
onclick={() => {
@@ -1556,7 +1571,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
: 'text-white/30 cursor-default'
}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span class="flex-shrink-0 text-xs {modelCanFit ? 'text-white/50' : 'text-red-400/60'}">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
@@ -1753,7 +1777,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-7xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} />
<ChatForm placeholder="Ask anything" showModelSelector={true} modelTasks={modelTasks()} />
</div>
</div>
</div>
@@ -1793,7 +1817,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<h3 class="text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
</div>
<div class="space-y-3 max-h-72 xl:max-h-96 overflow-y-auto overflow-x-hidden py-px pr-1">
<div class="space-y-3 max-h-72 overflow-y-auto pr-1">
{#each Object.entries(instanceData) as [id, instance]}
{@const downloadInfo = getInstanceDownloadStatus(id, instance)}
{@const statusText = downloadInfo.statusText}

View File

@@ -98,7 +98,7 @@
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
const modelMeta = shardData.model_card ?? shardData.modelCard;
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -190,7 +190,7 @@
const shardKeys = Object.keys(shardObj);
if (shardKeys.length !== 1) return null;
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
const modelMeta = shardData?.model_card ?? shardData?.modelCard;
const modelMeta = shardData?.model_meta ?? shardData?.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.prettyName as string) ?? null;
@@ -199,13 +199,7 @@
const rawProgress = (downloadPayload as Record<string, unknown>).download_progress
?? (downloadPayload as Record<string, unknown>).downloadProgress
?? {};
// For DownloadCompleted, total_bytes is at top level; for DownloadOngoing, it's inside download_progress
const totalBytes = getBytes(
(downloadPayload as Record<string, unknown>).total_bytes
?? (downloadPayload as Record<string, unknown>).totalBytes
?? (rawProgress as Record<string, unknown>).total_bytes
?? (rawProgress as Record<string, unknown>).totalBytes
);
const totalBytes = getBytes((rawProgress as Record<string, unknown>).total_bytes ?? (rawProgress as Record<string, unknown>).totalBytes);
const downloadedBytes = getBytes((rawProgress as Record<string, unknown>).downloaded_bytes ?? (rawProgress as Record<string, unknown>).downloadedBytes);
const speed = (rawProgress as Record<string, unknown>).speed as number ?? 0;
const etaMs = (rawProgress as Record<string, unknown>).eta_ms as number ?? (rawProgress as Record<string, unknown>).etaMs as number ?? 0;
@@ -338,13 +332,8 @@
<div class="text-lg font-mono text-white truncate">{node.nodeName}</div>
<div class="text-xs text-exo-light-gray font-mono truncate">{node.nodeId}</div>
</div>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right">
<div>
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> / {node.models.length} models</span>
</div>
<div class="text-exo-light-gray normal-case tracking-normal">
{formatBytes(node.models.filter(m => m.status === 'completed').reduce((sum, m) => sum + m.totalBytes, 0))} on disk
</div>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0">
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> /{node.models.length} models</span>
</div>
</div>
@@ -396,7 +385,7 @@
</div>
<div class="flex items-center justify-between text-xs font-mono text-exo-light-gray">
<span>{model.status === 'completed' ? `Completed (${formatBytes(model.totalBytes)})` : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
<span>{model.status === 'completed' ? 'Completed' : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
{#if model.status !== 'completed'}
<span>{model.files.length} file{model.files.length === 1 ? '' : 's'}</span>
{/if}

View File

@@ -1,15 +1,16 @@
import tailwindcss from "@tailwindcss/vite";
import { sveltekit } from "@sveltejs/kit/vite";
import { defineConfig } from "vite";
import tailwindcss from '@tailwindcss/vite';
import { sveltekit } from '@sveltejs/kit/vite';
import { defineConfig } from 'vite';
export default defineConfig({
plugins: [tailwindcss(), sveltekit()],
server: {
proxy: {
"/v1": "http://localhost:52415",
"/state": "http://localhost:52415",
"/models": "http://localhost:52415",
"/instance": "http://localhost:52415",
},
},
'/v1': 'http://localhost:52415',
'/state': 'http://localhost:52415',
'/models': 'http://localhost:52415',
'/instance': 'http://localhost:52415'
}
}
});

View File

@@ -1,212 +0,0 @@
# EXO API Technical Reference
This document describes the REST API exposed by the **EXO ** service, as implemented in:
`src/exo/master/api.py`
The API is used to manage model instances in the cluster, inspect cluster state, and perform inference using an OpenAI-compatible interface.
Base URL example:
```
http://localhost:52415
```
## 1. General / Meta Endpoints
### Get Master Node ID
**GET** `/node_id`
Returns the identifier of the current master node.
**Response (example):**
```json
{
"node_id": "node-1234"
}
```
### Get Cluster State
**GET** `/state`
Returns the current state of the cluster, including nodes and active instances.
**Response:**
JSON object describing topology, nodes, and instances.
### Get Events
**GET** `/events`
Returns the list of internal events recorded by the master (mainly for debugging and observability).
**Response:**
Array of event objects.
## 2. Model Instance Management
### Create Instance
**POST** `/instance`
Creates a new model instance in the cluster.
**Request body (example):**
```json
{
"instance": {
"model_id": "llama-3.2-1b",
"placement": { }
}
}
```
**Response:**
JSON description of the created instance.
### Delete Instance
**DELETE** `/instance/{instance_id}`
Deletes an existing instance by ID.
**Path parameters:**
* `instance_id`: string, ID of the instance to delete
**Response:**
Status / confirmation JSON.
### Get Instance
**GET** `/instance/{instance_id}`
Returns details of a specific instance.
**Path parameters:**
* `instance_id`: string
**Response:**
JSON description of the instance.
### Preview Placements
**GET** `/instance/previews?model_id=...`
Returns possible placement previews for a given model.
**Query parameters:**
* `model_id`: string, required
**Response:**
Array of placement preview objects.
### Compute Placement
**GET** `/instance/placement`
Computes a placement for a potential instance without creating it.
**Query parameters (typical):**
* `model_id`: string
* `sharding`: string or config
* `instance_meta`: JSON-encoded metadata
* `min_nodes`: integer
**Response:**
JSON object describing the proposed placement / instance configuration.
### Place Instance (Dry Operation)
**POST** `/place_instance`
Performs a placement operation for an instance (planning step), without necessarily creating it.
**Request body:**
JSON describing the instance to be placed.
**Response:**
Placement result.
## 3. Models
### List Models
**GET** `/models`
**GET** `/v1/models` (alias)
Returns the list of available models and their metadata.
**Response:**
Array of model descriptors.
## 4. Inference / Chat Completions
### OpenAI-Compatible Chat Completions
**POST** `/v1/chat/completions`
Executes a chat completion request using an OpenAI-compatible schema. Supports streaming and non-streaming modes.
**Request body (example):**
```json
{
"model": "llama-3.2-1b",
"messages": [
{ "role": "system", "content": "You are a helpful assistant." },
{ "role": "user", "content": "Hello" }
],
"stream": false
}
```
**Response:**
OpenAI-compatible chat completion response.
### Benchmarked Chat Completions
**POST** `/bench/chat/completions`
Same as `/v1/chat/completions`, but also returns performance and generation statistics.
**Request body:**
Same schema as `/v1/chat/completions`.
**Response:**
Chat completion plus benchmarking metrics.
## 5. Complete Endpoint Summary
```
GET /node_id
GET /state
GET /events
POST /instance
GET /instance/{instance_id}
DELETE /instance/{instance_id}
GET /instance/previews
GET /instance/placement
POST /place_instance
GET /models
GET /v1/models
POST /v1/chat/completions
POST /bench/chat/completions
```
## 6. Notes
* The `/v1/chat/completions` endpoint is compatible with the OpenAI API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
* The instance placement endpoints allow you to plan and preview cluster allocations before actually creating instances.
* The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 187 KiB

181
flake.lock generated
View File

@@ -1,42 +1,5 @@
{
"nodes": {
"crane": {
"locked": {
"lastModified": 1767744144,
"narHash": "sha256-9/9ntI0D+HbN4G0TrK3KmHbTvwgswz7p8IEJsWyef8Q=",
"owner": "ipetkov",
"repo": "crane",
"rev": "2fb033290bf6b23f226d4c8b32f7f7a16b043d7e",
"type": "github"
},
"original": {
"owner": "ipetkov",
"repo": "crane",
"type": "github"
}
},
"dream2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"purescript-overlay": "purescript-overlay",
"pyproject-nix": "pyproject-nix"
},
"locked": {
"lastModified": 1765953015,
"narHash": "sha256-5FBZbbWR1Csp3Y2icfRkxMJw/a/5FGg8hCXej2//bbI=",
"owner": "nix-community",
"repo": "dream2nix",
"rev": "69eb01fa0995e1e90add49d8ca5bcba213b0416f",
"type": "github"
},
"original": {
"owner": "nix-community",
"repo": "dream2nix",
"type": "github"
}
},
"fenix": {
"inputs": {
"nixpkgs": [
@@ -45,11 +8,11 @@
"rust-analyzer-src": "rust-analyzer-src"
},
"locked": {
"lastModified": 1768287139,
"narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
"lastModified": 1761893049,
"narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
"owner": "nix-community",
"repo": "fenix",
"rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
"type": "github"
},
"original": {
@@ -58,59 +21,25 @@
"type": "github"
}
},
"flake-compat": {
"flake": false,
"locked": {
"lastModified": 1696426674,
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
"owner": "edolstra",
"repo": "flake-compat",
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
"type": "github"
},
"original": {
"owner": "edolstra",
"repo": "flake-compat",
"type": "github"
}
},
"flake-parts": {
"flake-utils": {
"inputs": {
"nixpkgs-lib": [
"nixpkgs"
]
"systems": "systems"
},
"locked": {
"lastModified": 1768135262,
"narHash": "sha256-PVvu7OqHBGWN16zSi6tEmPwwHQ4rLPU9Plvs8/1TUBY=",
"owner": "hercules-ci",
"repo": "flake-parts",
"rev": "80daad04eddbbf5a4d883996a73f3f542fa437ac",
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "hercules-ci",
"repo": "flake-parts",
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1768127708,
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"nixpkgs-swift": {
"locked": {
"lastModified": 1761672384,
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
@@ -121,74 +50,27 @@
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"type": "github"
}
},
"purescript-overlay": {
"inputs": {
"flake-compat": "flake-compat",
"nixpkgs": [
"dream2nix",
"nixpkgs"
],
"slimlock": "slimlock"
},
"locked": {
"lastModified": 1728546539,
"narHash": "sha256-Sws7w0tlnjD+Bjck1nv29NjC5DbL6nH5auL9Ex9Iz2A=",
"owner": "thomashoneyman",
"repo": "purescript-overlay",
"rev": "4ad4c15d07bd899d7346b331f377606631eb0ee4",
"type": "github"
},
"original": {
"owner": "thomashoneyman",
"repo": "purescript-overlay",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"dream2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"type": "github"
}
},
"root": {
"inputs": {
"crane": "crane",
"dream2nix": "dream2nix",
"fenix": "fenix",
"flake-parts": "flake-parts",
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"treefmt-nix": "treefmt-nix"
}
},
"rust-analyzer-src": {
"flake": false,
"locked": {
"lastModified": 1768224240,
"narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
"lastModified": 1761849405,
"narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
"owner": "rust-lang",
"repo": "rust-analyzer",
"rev": "725349602e525df37f377701e001fe8aab807878",
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
"type": "github"
},
"original": {
@@ -198,25 +80,18 @@
"type": "github"
}
},
"slimlock": {
"inputs": {
"nixpkgs": [
"dream2nix",
"purescript-overlay",
"nixpkgs"
]
},
"systems": {
"locked": {
"lastModified": 1688756706,
"narHash": "sha256-xzkkMv3neJJJ89zo3o2ojp7nFeaZc2G0fYwNXNJRFlo=",
"owner": "thomashoneyman",
"repo": "slimlock",
"rev": "cf72723f59e2340d24881fd7bf61cb113b4c407c",
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "thomashoneyman",
"repo": "slimlock",
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
@@ -227,11 +102,11 @@
]
},
"locked": {
"lastModified": 1768158989,
"narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
"lastModified": 1762938485,
"narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
"owner": "numtide",
"repo": "treefmt-nix",
"rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
"rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
"type": "github"
},
"original": {

196
flake.nix
View File

@@ -3,134 +3,118 @@
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
flake-parts = {
url = "github:hercules-ci/flake-parts";
inputs.nixpkgs-lib.follows = "nixpkgs";
};
crane.url = "github:ipetkov/crane";
flake-utils.url = "github:numtide/flake-utils";
# Provides Rust dev-env integration:
fenix = {
url = "github:nix-community/fenix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Provides formatting infrastructure:
treefmt-nix = {
url = "github:numtide/treefmt-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
dream2nix = {
url = "github:nix-community/dream2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
};
nixConfig = {
extra-trusted-public-keys = "exo.cachix.org-1:okq7hl624TBeAR3kV+g39dUFSiaZgLRkLsFBCuJ2NZI=";
extra-substituters = "https://exo.cachix.org";
};
# TODO: figure out caching story
# nixConfig = {
# # nix community cachix
# extra-trusted-public-keys = "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=";
# extra-substituters = "https://nix-community.cachix.org";
# };
outputs =
inputs:
inputs.flake-parts.lib.mkFlake { inherit inputs; } {
let
systems = [
"x86_64-linux"
"aarch64-darwin"
"aarch64-linux"
];
fenixToolchain = system: inputs.fenix.packages.${system}.complete;
in
inputs.flake-utils.lib.eachSystem systems (
system:
let
pkgs = import inputs.nixpkgs {
inherit system;
overlays = [ inputs.fenix.overlays.default ];
};
treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
projectRootFile = "flake.nix";
programs.ruff-format.enable = true;
programs.ruff-format.excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
programs.rustfmt.enable = true;
programs.rustfmt.package = (fenixToolchain system).rustfmt;
programs.nixpkgs-fmt.enable = true;
};
in
{
formatter = treefmtEval.config.build.wrapper;
checks.formatting = treefmtEval.config.build.check inputs.self;
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
imports = [
inputs.treefmt-nix.flakeModule
./dashboard/parts.nix
./rust/parts.nix
];
devShells.default = pkgs.mkShell {
packages =
with pkgs;
[
# PYTHON
python313
uv
ruff
basedpyright
perSystem =
{ config, self', inputs', pkgs, lib, system, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
treefmt = {
projectRootFile = "flake.nix";
programs = {
nixpkgs-fmt.enable = true;
ruff-format = {
enable = true;
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
};
rustfmt = {
enable = true;
package = config.rust.toolchain;
};
prettier = {
enable = true;
includes = [ "*.ts" ];
};
swift-format = {
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
};
};
# RUST
((fenixToolchain system).withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
])
rustup # Just here to make RustRover happy
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
# NIX
nixpkgs-fmt
# SVELTE
nodejs
# MISC
just
jq
]
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
# IFCONFIG
unixtools.ifconfig
# Build dependencies for Linux
pkg-config
openssl
])
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
# MACMON
macmon
]);
shellHook = ''
# PYTHON
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
${pkgs.lib.optionalString pkgs.stdenv.isLinux ''
# Build environment for Linux
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
''}
echo
echo "🍎🍎 Run 'just <recipe>' to get started"
just --list
'';
devShells.default = with pkgs; pkgs.mkShell {
inputsFrom = [ self'.checks.cargo-build ];
packages =
[
# FORMATTING
config.treefmt.build.wrapper
# PYTHON
python313
uv
ruff
basedpyright
# RUST
config.rust.toolchain
maturin
# NIX
nixpkgs-fmt
# SVELTE
nodejs
# MISC
just
jq
]
++ lib.optionals stdenv.isLinux [
unixtools.ifconfig
]
++ lib.optionals stdenv.isDarwin [
macmon
];
OPENSSL_NO_VENDOR = "1";
shellHook = ''
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${python313}/lib"
${lib.optionalString stdenv.isLinux ''
export LD_LIBRARY_PATH="${openssl.out}/lib:$LD_LIBRARY_PATH"
''}
'';
};
};
};
}
);
}

View File

@@ -1,5 +1,3 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt

View File

@@ -8,23 +8,35 @@ dependencies = [
"aiofiles>=24.1.0",
"aiohttp>=3.12.14",
"types-aiofiles>=24.1.0.20250708",
"typeguard>=4.4.4",
"pydantic>=2.11.7",
"base58>=2.1.1",
"cryptography>=45.0.5",
"fastapi>=0.116.1",
"filelock>=3.18.0",
"aiosqlite>=0.21.0",
"networkx>=3.5",
"protobuf>=6.32.0",
"rich>=14.1.0",
"rustworkx>=0.17.1",
"sqlmodel>=0.0.24",
"sqlalchemy[asyncio]>=2.0.43",
"greenlet>=3.2.4",
"huggingface-hub>=0.33.4",
"psutil>=7.0.0",
"loguru>=0.7.3",
"textual>=5.3.0",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"bidict>=0.23.1",
"mlx>=0.30.1; sys_platform == 'darwin'",
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
"mlx-lm>=0.28.3",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"tomlkit>=0.14.0",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux>=0.12.1",
]
[project.scripts]
@@ -35,7 +47,6 @@ exo = "exo.main:main"
# dependencies only required for development
[dependency-groups]
dev = [
"basedpyright>=1.29.0",
"pyinstaller>=6.17.0",
"pytest>=8.4.0",
"pytest-asyncio>=1.0.0",
@@ -73,7 +84,7 @@ build-backend = "uv_build"
###
[tool.basedpyright]
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src", "bench"]
include = [".venv/lib/mlx", ".venv/lib/mlx_lm", "src"]
typeCheckingMode = "strict"
failOnWarnings = true
@@ -101,7 +112,6 @@ root = "src"
# supported platforms for this project
[tool.uv]
prerelease = "allow"
environments = [
"sys_platform == 'darwin'",
"sys_platform == 'linux'",
@@ -127,6 +137,3 @@ env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

View File

@@ -1,145 +0,0 @@
{ inputs, ... }:
{
perSystem =
{ config, self', inputs', pkgs, lib, ... }:
let
# Fenix nightly toolchain with all components
fenixPkgs = inputs'.fenix.packages;
rustToolchain = fenixPkgs.complete.withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
"rust-analyzer"
];
# Crane with fenix toolchain
craneLib = (inputs.crane.mkLib pkgs).overrideToolchain rustToolchain;
# Source filtering - only include rust/ directory and root Cargo files
# This ensures changes to Python/docs/etc don't trigger Rust rebuilds
src = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
parentDir = builtins.dirOf path;
inRustDir =
(lib.hasInfix "/rust/" path)
|| (lib.hasSuffix "/rust" parentDir)
|| (baseName == "rust" && type == "directory");
isRootCargoFile =
(baseName == "Cargo.toml" || baseName == "Cargo.lock")
&& (builtins.dirOf path == toString inputs.self);
in
isRootCargoFile
|| (inRustDir && (craneLib.filterCargoSources path type || lib.hasSuffix ".toml" path || lib.hasSuffix ".md" path));
};
# Common arguments for all Rust builds
commonArgs = {
inherit src;
pname = "exo-rust";
version = "0.0.1";
strictDeps = true;
nativeBuildInputs = [
pkgs.pkg-config
pkgs.python313 # Required for pyo3-build-config
];
buildInputs = [
pkgs.openssl
pkgs.python313 # Required for pyo3 tests
];
OPENSSL_NO_VENDOR = "1";
# Required for pyo3 tests to find libpython
LD_LIBRARY_PATH = lib.makeLibraryPath [ pkgs.python313 ];
};
# Build dependencies once for caching
cargoArtifacts = craneLib.buildDepsOnly (
commonArgs
// {
cargoExtraArgs = "--workspace";
}
);
in
{
# Export toolchain for use in treefmt and devShell
options.rust = {
toolchain = lib.mkOption {
type = lib.types.package;
default = rustToolchain;
description = "The Rust toolchain to use";
};
};
config = {
packages = {
# Python bindings wheel via maturin
exo_pyo3_bindings = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
pname = "exo_pyo3_bindings";
nativeBuildInputs = commonArgs.nativeBuildInputs ++ [
pkgs.maturin
];
buildPhaseCargoCommand = ''
maturin build \
--release \
--manylinux off \
--manifest-path rust/exo_pyo3_bindings/Cargo.toml \
--features "pyo3/extension-module,pyo3/experimental-async" \
--interpreter ${pkgs.python313}/bin/python \
--out dist
'';
# Don't use crane's default install behavior
doNotPostBuildInstallCargoBinaries = true;
installPhaseCommand = ''
mkdir -p $out
cp dist/*.whl $out/
'';
}
);
};
checks = {
# Full workspace build (all crates)
cargo-build = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
# Run tests with nextest
cargo-nextest = craneLib.cargoNextest (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
# Build documentation
cargo-doc = craneLib.cargoDoc (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
};
};
};
}

View File

@@ -0,0 +1,47 @@
[package]
name = "system_custodian"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "system_custodian"
path = "src/lib.rs"
[[bin]]
path = "src/bin/main.rs"
name = "system_custodian"
doc = false
[lints]
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }

View File

@@ -0,0 +1,4 @@
//! TODO: documentation
//!
fn main() {}

View File

@@ -0,0 +1,69 @@
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
//!
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
//! etc.) is in an appropriate state to facilitate the running of Exo application.
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
//! service which Exo application use to _control & query_ it.
//!
//! # Lifecycle
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
//! destructive to the user's pre-existing configurations.
//!
//! # Responsibilities
//! TODO: these are purely on MacOS, but change to be more broad
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
//! 1. duplicate the current network set
//! 2. modify existing services to turn on IPv6 if not there
//! 3. remove any bridge services & add any missing services that AREN'T bridge
//! TODO: In the future:
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
//! logic, this would be the place to spin that up.
//!
//! Then it will watch the SCDynamicStore for:
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
//! interface of any changes
//! 2. watch for any __undesirable__ changes to configuration and revert it
//!
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
//! 1. all that info should coalesce back to the overall state colleted -> should be queryable
//! over D-Bus
//! TODO:
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
//! then this would be the place to carry out discovery and propper handshakes with devices
//! on the other end of the link.
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {}

View File

@@ -1,7 +1,6 @@
import argparse
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Self
@@ -196,8 +195,6 @@ class Node:
def main():
args = Args.parse()
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system
@@ -205,14 +202,6 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
anyio.run(node.run)
logger.info("EXO Shutdown complete")
@@ -226,7 +215,6 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
def parse(cls) -> Self:
@@ -268,20 +256,6 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",
action="store_true",
dest="fast_synch",
default=None,
help="Force MLX FAST_SYNCH on (for JACCL backend)",
)
fast_synch_group.add_argument(
"--no-fast-synch",
action="store_false",
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -1,42 +1,46 @@
import base64
import json
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import cast
from typing import Literal, cast
import anyio
from anyio import BrokenResourceError, create_task_group
from anyio import create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException, Request
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
GenerationStats,
ImageData,
ImageEditsInternalParams,
ImageGenerationResponse,
ImageGenerationTaskParams,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -44,24 +48,23 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk, ToolCallChunk
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
SendInputChunk,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
IndexedEvent,
)
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -71,9 +74,11 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def chunk_to_response(
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
@@ -85,25 +90,27 @@ def chunk_to_response(
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
finish_reason=chunk.finish_reason,
)
if isinstance(chunk, TokenChunk)
else StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=[tool.model_dump() for tool in chunk.tool_calls],
),
finish_reason="tool_calls",
)
],
)
async def resolve_model_card(model_id: ModelId) -> ModelCard:
def get_model_card(model_id: str) -> ModelCard | None:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
else:
return await ModelCard.from_hf(model_id)
for _, model_card in MODEL_CARDS.items():
if model_id == model_card.model_id:
return model_card
async def resolve_model_meta(model_id: str) -> ModelMetadata:
model_card = get_model_card(model_id)
if model_card is not None:
return model_card.metadata
return await get_model_meta(model_id)
class API:
@@ -134,7 +141,6 @@ class API:
self.paused_ev: anyio.Event = anyio.Event()
self.app = FastAPI()
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
@@ -147,9 +153,8 @@ class API:
name="dashboard",
)
self._chat_completion_queues: dict[
CommandId, Sender[TokenChunk | ToolCallChunk]
] = {}
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._tg: TaskGroup | None = None
def reset(self, new_session_id: SessionId, result_clock: int):
@@ -158,6 +163,7 @@ class API:
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._chat_completion_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)
def unpause(self, result_clock: int):
@@ -167,21 +173,6 @@ class API:
self.paused_ev.set()
self.paused_ev = anyio.Event()
def _setup_exception_handlers(self) -> None:
self.app.exception_handler(HTTPException)(self.http_exception_handler)
async def http_exception_handler(
self, _: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
def _setup_cors(self) -> None:
self.app.add_middleware(
CORSMiddleware,
@@ -204,13 +195,16 @@ class API:
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.post("/v1/images/generations", response_model=None)(
self.image_generations
)
self.app.post("/v1/images/edits")(self.image_edits)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_card=await resolve_model_card(payload.model_id),
model_meta=await resolve_model_meta(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -220,15 +214,15 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_card=command.model_card,
model_meta=command.model_meta,
)
async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_card = await resolve_model_card(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
available_memory = self._calculate_total_available_memory()
if required_memory > available_memory:
@@ -245,28 +239,26 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_card=model_card,
model_meta=model_meta,
)
async def get_placement(
self,
model_id: ModelId,
model_id: str,
sharding: Sharding = Sharding.Pipeline,
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_card = await resolve_model_card(model_id)
model_meta = await resolve_model_meta(model_id)
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
model_meta=model_meta,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -293,7 +285,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
cards = [card for card in MODEL_CARDS.values() if card.short_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
@@ -311,33 +303,32 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for model_card in cards:
for card in cards:
model_meta = card.metadata
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
model_meta=model_meta,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add((card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
@@ -348,17 +339,17 @@ class API:
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add((card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
@@ -367,7 +358,7 @@ class API:
memory_delta_by_node: dict[str, int] = {}
if node_ids:
total_bytes = model_card.storage_size.in_bytes
total_bytes = model_meta.storage_size.in_bytes
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
@@ -375,14 +366,14 @@ class API:
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
card.model_id,
sharding,
instance_meta,
len(node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
@@ -390,7 +381,7 @@ class API:
error=None,
)
)
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
seen.add((card.model_id, sharding, instance_meta, len(node_ids)))
return PlacementPreviewResponse(previews=previews)
@@ -413,24 +404,52 @@ class API:
instance_id=instance_id,
)
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk | ToolCallChunk, None]:
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
try:
self._chat_completion_queues[command_id], recv = channel[
TokenChunk | ToolCallChunk
]()
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if (
isinstance(chunk, TokenChunk)
and chunk.finish_reason is not None
):
break
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -446,23 +465,11 @@ class API:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, TokenChunk) and chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
@@ -470,11 +477,11 @@ class API:
yield f"data: {chunk_response.model_dump_json()}\n\n"
if isinstance(chunk, ToolCallChunk) or chunk.finish_reason is not None:
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
@@ -482,31 +489,7 @@ class API:
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ToolCallChunk):
finish_reason = "tool_calls"
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model or chunk.model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
tool_calls=[
tool.model_dump() for tool in chunk.tool_calls
],
),
)
],
)
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
@@ -534,56 +517,6 @@ class API:
],
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ToolCallChunk):
raise HTTPException(
status_code=500,
detail="Tool call in bench",
)
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
text_parts.append(chunk.text)
stats = chunk.stats or stats
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = "".join(text_parts)
assert model is not None
resp = BenchChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant", content=combined_text
),
finish_reason=finish_reason,
)
],
generation_stats=stats,
)
return resp
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
@@ -593,8 +526,10 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
if not any(
instance.shard_assignments.model_id == payload.model
@@ -611,17 +546,22 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id),
self._generate_chat_stream(command.command_id, parse_gpt_oss),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
async def image_generations(
self, payload: ImageGenerationTaskParams
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image generation requests.
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -632,20 +572,297 @@ class API:
status_code=404, detail=f"No instance found for model {payload.model}"
)
payload.stream = False
command = ChatCompletion(request_params=payload)
command = ImageGeneration(
request_params=payload,
)
await self._send(command)
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
# Check if streaming is requested
if payload.stream and payload.partial_images and payload.partial_images > 0:
return StreamingResponse(
self._generate_image_stream(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
),
media_type="text/event-stream",
)
# Non-streaming: collect all image chunks
return await self._collect_image_generation(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
)
async def _generate_image_stream(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> AsyncGenerator[str, None]:
"""Generate SSE stream of partial and final images."""
# Track chunks: {(image_index, is_partial): {chunk_index: data}}
image_chunks: dict[tuple[int, bool], dict[int, str]] = {}
image_total_chunks: dict[tuple[int, bool], int] = {}
image_metadata: dict[tuple[int, bool], tuple[int | None, int | None]] = {}
images_complete = 0
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
with recv as chunks:
async for chunk in chunks:
key = (chunk.image_index, chunk.is_partial)
if key not in image_chunks:
image_chunks[key] = {}
image_total_chunks[key] = chunk.total_chunks
image_metadata[key] = (
chunk.partial_index,
chunk.total_partials,
)
image_chunks[key][chunk.chunk_index] = chunk.data
# Check if this image is complete
if len(image_chunks[key]) == image_total_chunks[key]:
full_data = "".join(
image_chunks[key][i] for i in range(len(image_chunks[key]))
)
partial_idx, total_partials = image_metadata[key]
if chunk.is_partial:
# Yield partial image event
event_data = {
"type": "partial",
"partial_index": partial_idx,
"total_partials": total_partials,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
else:
# Final image
event_data = {
"type": "final",
"image_index": chunk.image_index,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
images_complete += 1
if images_complete >= num_images:
yield "data: [DONE]\n\n"
break
# Clean up completed image chunks
del image_chunks[key]
del image_total_chunks[key]
del image_metadata[key]
except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]
async def _collect_image_generation(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> ImageGenerationResponse:
"""Collect all image chunks (non-streaming) and return a single response."""
# Track chunks per image: {image_index: {chunk_index: data}}
# Only track non-partial (final) images
image_chunks: dict[int, dict[int, str]] = {}
image_total_chunks: dict[int, int] = {}
images_complete = 0
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
while images_complete < num_images:
with recv as chunks:
async for chunk in chunks:
# Skip partial images in non-streaming mode
if chunk.is_partial:
continue
if chunk.image_index not in image_chunks:
image_chunks[chunk.image_index] = {}
image_total_chunks[chunk.image_index] = chunk.total_chunks
image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
# Check if this image is complete
if (
len(image_chunks[chunk.image_index])
== image_total_chunks[chunk.image_index]
):
images_complete += 1
if images_complete >= num_images:
break
# Reassemble images in order
images: list[ImageData] = []
for image_idx in range(num_images):
chunks_dict = image_chunks[image_idx]
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
images.append(
ImageData(
b64_json=full_data if response_format == "b64_json" else None,
url=None, # URL format not implemented yet
)
)
return ImageGenerationResponse(data=images)
except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]
async def image_edits(
self,
image: UploadFile = File(...),
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
) -> ImageGenerationResponse:
"""Handle image editing requests (img2img)."""
model_meta = await resolve_model_meta(model)
resolved_model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(resolved_model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {resolved_model}"
)
# Read and base64 encode the uploaded image
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")
# Map input_fidelity to image_strength
image_strength = 0.7 if input_fidelity == "high" else 0.3
# Split image into chunks to stay under gossipsub message size limit
data_chunks = [
image_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)
# Create command first to get command_id
command = ImageEdits(
request_params=ImageEditsInternalParams(
image_data="", # Empty - will be assembled at worker from chunks
total_input_chunks=total_chunks,
prompt=prompt,
model=resolved_model,
n=n,
size=size,
response_format=response_format,
image_strength=image_strength,
),
)
# Send input chunks BEFORE the command
logger.info(
f"Sending input image: {len(image_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(data_chunks):
await self._send(
SendInputChunk(
chunk=InputImageChunk(
idx=chunk_index,
model=resolved_model,
command_id=command.command_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
)
)
)
# Now send the main command
await self._send(command)
num_images = n
# Track chunks per image: {image_index: {chunk_index: data}}
image_chunks: dict[int, dict[int, str]] = {}
image_total_chunks: dict[int, int] = {}
images_complete = 0
try:
self._image_generation_queues[command.command_id], recv = channel[
ImageChunk
]()
while images_complete < num_images:
with recv as chunks:
async for chunk in chunks:
if chunk.image_index not in image_chunks:
image_chunks[chunk.image_index] = {}
image_total_chunks[chunk.image_index] = chunk.total_chunks
image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
if (
len(image_chunks[chunk.image_index])
== image_total_chunks[chunk.image_index]
):
images_complete += 1
if images_complete >= num_images:
break
images: list[ImageData] = []
for image_idx in range(num_images):
chunks_dict = image_chunks[image_idx]
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
images.append(
ImageData(
b64_json=full_data if response_format == "b64_json" else None,
url=None, # URL format not implemented yet
)
)
return ImageGenerationResponse(data=images)
except anyio.get_cancelled_exc_class():
raise
finally:
# Send TaskFinished command
await self._send(TaskFinished(finished_command_id=command.command_id))
del self._image_generation_queues[command.command_id]
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for memory in self.state.node_memory.values():
total_available += memory.ram_available
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
return total_available
@@ -654,13 +871,14 @@ class API:
return ModelList(
data=[
ModelListModel(
id=card.model_id,
id=card.short_id,
hugging_face_id=card.model_id,
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
tasks=[task.value for task in card.tasks],
)
for card in MODEL_CARDS.values()
]
@@ -699,13 +917,16 @@ class API:
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
assert isinstance(event.chunk, (TokenChunk, ToolCallChunk))
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
if event.command_id in self._chat_completion_queues:
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
elif event.command_id in self._image_generation_queues:
assert isinstance(event.chunk, ImageChunk)
await self._image_generation_queues[event.command_id].send(
event.chunk
)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -2,6 +2,7 @@ from datetime import datetime, timedelta, timezone
import anyio
from anyio.abc import TaskGroup
from fastapi.routing import request_response
from loguru import logger
from exo.master.placement import (
@@ -11,13 +12,17 @@ from exo.master.placement import (
place_instance,
)
from exo.shared.apply import apply
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.commands import (
ChatCompletion,
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskFinished,
TestCommand,
)
@@ -26,8 +31,8 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
TaskDeleted,
@@ -36,6 +41,12 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion as ChatCompletionTask,
)
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
)
from exo.shared.types.tasks import (
ImageGeneration as ImageGenerationTask,
)
from exo.shared.types.tasks import (
TaskId,
TaskStatus,
@@ -100,6 +111,7 @@ class Master:
async for forwarder_command in commands:
try:
logger.info(f"Executing command: {forwarder_command.command}")
generated_events: list[Event] = []
command = forwarder_command.command
match command:
@@ -147,6 +159,92 @@ class Master:
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageGeneration():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageEdits():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
@@ -159,8 +257,6 @@ class Master:
command,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
@@ -176,6 +272,13 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
generated_events.append(
InputChunkReceived(
command_id=chunk.command_id,
chunk=chunk,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -203,7 +306,9 @@ class Master:
async def _plan(self) -> None:
while True:
# kill broken instances
connected_node_ids = set(self.state.topology.list_nodes())
connected_node_ids = set(
[x.node_id for x in self.state.topology.list_nodes()]
)
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
@@ -238,8 +343,6 @@ class Master:
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
if isinstance(event, NodeGatheredInfo):
event.when = str(datetime.now(tz=timezone.utc))
self._event_log.append(event)
await self._send_event(indexed)

View File

@@ -6,25 +6,23 @@ from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
Cycle,
filter_cycles_by_memory,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.models import ModelId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -54,33 +52,37 @@ def place_instance(
command: PlaceInstance,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
logger.info("finding cycles:")
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_memory, command.model_card.storage_size
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
)
if len(cycles_with_sufficient_memory) == 0:
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
)
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_card.supports_tensor:
if not command.model_meta.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_card.hidden_size % len(cycle) == 0
if command.model_meta.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_card.hidden_size} candidate cycles"
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
@@ -90,38 +92,44 @@ def place_instance(
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
smallest_tb_cycles = [
cycle for cycle in smallest_cycles if topology.is_thunderbolt_cycle(cycle)
cycle
for cycle in smallest_cycles
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
]
if smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
cycles_with_leaf_nodes: list[Cycle] = [
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
cycle
for cycle in smallest_cycles
if any(topology.node_is_leaf(node_id) for node_id in cycle)
if any(topology.node_is_leaf(node.node_id) for node in cycle)
]
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(node_memory[node_id].ram_available for node_id in cycle),
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
),
start=Memory(),
),
)
shard_assignments = get_shard_assignments(
command.model_card, selected_cycle, command.sharding, node_memory
command.model_meta, selected_cycle, command.sharding
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
"You have likely selected ibv for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
@@ -129,20 +137,19 @@ def place_instance(
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxJaccl:
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
[node_id for node_id in selected_cycle],
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
selected_cycle,
cycle_digraph,
)
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
coordinator=selected_cycle.node_ids[0],
selected_cycle,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
node_network=node_network,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
ibv_devices=mlx_ibv_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
@@ -151,7 +158,6 @@ def place_instance(
selected_cycle=selected_cycle,
cycle_digraph=cycle_digraph,
ephemeral_port=ephemeral_port,
node_network=node_network,
)
target_instances[instance_id] = MlxRingInstance(
instance_id=instance_id,

View File

@@ -1,13 +1,15 @@
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import TypeGuard, cast
from loguru import logger
from pydantic import BaseModel
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
@@ -17,113 +19,67 @@ from exo.shared.types.worker.shards import (
)
class NodeWithProfile(BaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(
cycles: list[Cycle],
node_memory: Mapping[NodeId, MemoryUsage],
required_memory: Memory,
) -> list[Cycle]:
filtered_cycles: list[Cycle] = []
cycles: list[list[NodeInfo]], required_memory: Memory
) -> list[list[NodeInfo]]:
filtered_cycles: list[list[NodeInfo]] = []
for cycle in cycles:
if not all(node in node_memory for node in cycle):
if not narrow_all_nodes(cycle):
continue
total_mem = sum(
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
start=Memory(),
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
)
if total_mem >= required_memory:
filtered_cycles.append(cycle)
filtered_cycles.append(cast(list[NodeInfo], cycle))
return filtered_cycles
def get_smallest_cycles(
cycles: list[Cycle],
) -> list[Cycle]:
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def allocate_layers_proportionally(
total_layers: int,
memory_fractions: list[float],
) -> list[int]:
n = len(memory_fractions)
if n == 0:
raise ValueError("Cannot allocate layers to an empty node list")
if total_layers < n:
raise ValueError(
f"Cannot distribute {total_layers} layers across {n} nodes "
"(need at least 1 layer per node)"
)
# Largest remainder: floor each, then distribute remainder by fractional part
raw = [f * total_layers for f in memory_fractions]
result = [int(r) for r in raw]
by_remainder = sorted(range(n), key=lambda i: raw[i] - result[i], reverse=True)
for i in range(total_layers - sum(result)):
result[by_remainder[i]] += 1
# Ensure minimum 1 per node by taking from the largest
for i in range(n):
if result[i] == 0:
max_idx = max(range(n), key=lambda j: result[j])
assert result[max_idx] > 1
result[max_idx] -= 1
result[i] = 1
return result
def get_shard_assignments_for_pipeline_parallel(
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
if not cycle.node_ids:
raise ValueError("Cannot create shard assignments for empty node cycle")
cycle_memory = sum(
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_card.n_layers
world_size = len(cycle)
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
for node_id in cycle.node_ids
],
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node_memory[node_id].ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node_id}) has insufficient memory: "
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
)
layers_assigned = 0
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
for i, node in enumerate(selected_cycle):
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
else:
node_layers = round(
total_layers
* (
node.node_profile.memory.ram_available.in_bytes
/ cycle_memory.in_bytes
)
)
node_layers = max(1, node_layers)
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=i,
world_size=world_size,
start_layer=layers_assigned,
@@ -132,11 +88,11 @@ def get_shard_assignments_for_pipeline_parallel(
)
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
node_to_runner[node.node_id] = runner_id
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -145,17 +101,17 @@ def get_shard_assignments_for_pipeline_parallel(
def get_shard_assignments_for_tensor_parallel(
model_card: ModelCard,
cycle: Cycle,
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
total_layers = model_card.n_layers
world_size = len(cycle)
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for i, node_id in enumerate(cycle):
for i, node in enumerate(selected_cycle):
shard = TensorShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=i,
world_size=world_size,
start_layer=0,
@@ -166,10 +122,10 @@ def get_shard_assignments_for_tensor_parallel(
runner_id = RunnerId()
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
node_to_runner[node.node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -178,22 +134,22 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_card: ModelCard,
cycle: Cycle,
model_meta: ModelMetadata,
selected_cycle: list[NodeInfo],
sharding: Sharding,
node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
if not narrow_all_nodes(selected_cycle):
raise ValueError("All nodes must have profiles to create shard assignments")
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
model_card=model_card,
cycle=cycle,
node_memory=node_memory,
model_meta=model_meta,
selected_cycle=selected_cycle,
)
case Sharding.Tensor:
return get_shard_assignments_for_tensor_parallel(
model_card=model_card,
cycle=cycle,
model_meta=model_meta,
selected_cycle=selected_cycle,
)
@@ -208,40 +164,38 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
)
return []
cycle = cycles[0]
get_thunderbolt = False
if cycle_digraph.is_thunderbolt_cycle(cycle):
if cycle_digraph.is_thunderbolt_cycle(cycles[0]):
get_thunderbolt = True
logger.info(f"Using thunderbolt cycle: {get_thunderbolt}")
cycle = cycles[0]
hosts: list[Host] = []
for i in range(len(cycle)):
current_node = cycle.node_ids[i]
next_node = cycle.node_ids[(i + 1) % len(cycle)]
current_node = cycle[i]
next_node = cycle[(i + 1) % len(cycle)]
for connection in cycle_digraph.get_all_connections_between(
source=current_node, sink=next_node
):
if not isinstance(connection, SocketConnection):
continue
if get_thunderbolt and not connection.is_thunderbolt():
continue
host = Host(
ip=connection.sink_multiaddr.ip_address,
port=connection.sink_multiaddr.port,
)
hosts.append(host)
break
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == current_node.node_id
and connection.send_back_node_id == next_node.node_id
):
if get_thunderbolt and not connection.is_thunderbolt():
continue
assert connection.send_back_multiaddr is not None
host = Host(
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,
)
hosts.append(host)
break
return hosts
def get_mlx_jaccl_devices_matrix(
selected_cycle: list[NodeId],
def get_mlx_ibv_devices_matrix(
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
) -> list[list[str | None]]:
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -260,37 +214,72 @@ def get_mlx_jaccl_devices_matrix(
if i == j:
continue
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(conn, RDMAConnection):
matrix[i][j] = conn.source_rdma_iface
# Find the IP J uses to talk to I
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_rdma_interface_name_for_ip(
connection_ip, node_i
):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
)
break
else:
logger.warning(
f"Failed to find interface name between {node_i} and {node_j}"
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
)
raise ValueError(
"Current jaccl backend requires all-to-all RDMA connections"
"Current ibv backend requires all-to-all rdma connections"
)
return matrix
def _find_connection_ip(
node_i: NodeId,
node_j: NodeId,
node_i: NodeInfo,
node_j: NodeInfo,
cycle_digraph: Topology,
) -> Generator[tuple[str, bool]]:
"""Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(connection, SocketConnection):
yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
def _find_rdma_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
if node_info.node_profile is None:
return None
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
for interface in node_info.node_profile.network_interfaces:
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
continue
logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address != ip_address:
continue
logger.info("Found")
return f"rdma_{interface.name}"
return None
def _find_interface_name_for_ip(
ip_address: str, node_network: NodeNetworkInfo
ip_address: str,
node_info: NodeInfo,
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
for interface in node_network.interfaces:
if node_info.node_profile is None:
return None
for interface in node_info.node_profile.network_interfaces:
if interface.ip_address == ip_address:
return interface.name
@@ -298,10 +287,7 @@ def _find_interface_name_for_ip(
def _find_ip_prioritised(
node_id: NodeId,
other_node_id: NodeId,
cycle_digraph: Topology,
node_network: Mapping[NodeId, NodeNetworkInfo],
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.
@@ -312,14 +298,9 @@ def _find_ip_prioritised(
3. Non-Thunderbolt connections
4. Any other IP address
"""
ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {
_find_interface_name_for_ip(
ip, node_network.get(other_node_id, NodeNetworkInfo())
): ip
for ip, _ in ips
}
iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}
en0_ip = iface_map.get("en0")
if en0_ip:
@@ -343,10 +324,9 @@ def _find_ip_prioritised(
def get_mlx_ring_hosts_by_node(
selected_cycle: Cycle,
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
ephemeral_port: int,
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> dict[NodeId, list[Host]]:
"""Generate per-node host lists for MLX ring backend.
@@ -361,13 +341,14 @@ def get_mlx_ring_hosts_by_node(
hosts_by_node: dict[NodeId, list[Host]] = {}
for rank, node_id in enumerate(selected_cycle):
for rank, node in enumerate(selected_cycle):
node_id = node.node_id
left_rank = (rank - 1) % world_size
right_rank = (rank + 1) % world_size
hosts_for_node: list[Host] = []
for idx, other_node_id in enumerate(selected_cycle):
for idx, other_node in enumerate(selected_cycle):
if idx == rank:
hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
continue
@@ -377,12 +358,10 @@ def get_mlx_ring_hosts_by_node(
hosts_for_node.append(Host(ip="198.51.100.1", port=0))
continue
connection_ip = _find_ip_prioritised(
node_id, other_node_id, cycle_digraph, node_network
)
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node_id}"
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
@@ -396,34 +375,31 @@ def get_mlx_ring_hosts_by_node(
def get_mlx_jaccl_coordinators(
coordinator: NodeId,
selected_cycle: list[NodeInfo],
coordinator_port: int,
cycle_digraph: Topology,
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
"""Get the coordinator addresses for MLX Jaccl (rank 0 device).
Select an IP address that each node can reach for the rank 0 node. Returns
address in format "X.X.X.X:PORT" per node.
"""
logger.info(f"Selecting coordinator: {coordinator}")
rank_0_node = selected_cycle[0]
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
def get_ip_for_node(n: NodeId) -> str:
if n == coordinator:
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_network)
if ip is not None:
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
return ip
logger.warning(
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError(
"Current jaccl backend requires all participating devices to be able to communicate"
f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
)
raise ValueError("Current ibv backend requires all-to-all rdma connections")
return {
n: f"{get_ip_for_node(n)}:{coordinator_port}"
for n in cycle_digraph.list_nodes()
n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
}

View File

@@ -1,37 +1,67 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
NodeNetworkInfo,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
def create_node_memory(memory: int) -> MemoryUsage:
return MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
)
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
if node_id is None:
node_id = NodeId()
return NodeInfo(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
return _create_node
def create_node_network() -> NodeNetworkInfo:
return NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address=f"169.254.0.{i}")
for i in range(10)
]
)
# TODO: this is a hack to get the port for the send_back_multiaddr
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
ip_counter = 1
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
nonlocal ip_counter
# assign unique ips
ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
def create_socket_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
)
def create_rdma_connection(iface: int) -> RDMAConnection:
return RDMAConnection(
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
)
return _create_connection

View File

@@ -1,107 +0,0 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.worker.tests.constants import MODEL_A_ID
def test_http_exception_handler_formats_openai_style() -> None:
"""Test that HTTPException is converted to OpenAI-style error format."""
from exo.master.api import API
app = FastAPI()
# Setup exception handler
api = object.__new__(API)
api.app = app
api._setup_exception_handlers() # pyright: ignore[reportPrivateUsage]
# Add test routes that raise HTTPException
@app.get("/test-error")
async def _test_error() -> None:
raise HTTPException(status_code=500, detail="Test error message")
@app.get("/test-not-found")
async def _test_not_found() -> None:
raise HTTPException(status_code=404, detail="Resource not found")
client = TestClient(app)
# Test 500 error
response = client.get("/test-error")
assert response.status_code == 500
data: dict[str, Any] = response.json()
assert "error" in data
assert data["error"]["message"] == "Test error message"
assert data["error"]["type"] == "Internal Server Error"
assert data["error"]["code"] == 500
# Test 404 error
response = client.get("/test-not-found")
assert response.status_code == 404
data = response.json()
assert "error" in data
assert data["error"]["message"] == "Resource not found"
assert data["error"]["type"] == "Not Found"
assert data["error"]["code"] == 404
def test_finish_reason_includes_error() -> None:
valid_reasons = get_args(FinishReason)
assert "error" in valid_reasons
def test_token_chunk_with_error_fields() -> None:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="",
token_id=0,
finish_reason="error",
error_message="Something went wrong",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Something went wrong"
def test_token_chunk_without_error() -> None:
chunk = TokenChunk(
idx=1,
model=MODEL_A_ID,
text="Hello",
token_id=42,
finish_reason=None,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
def test_error_response_construction() -> None:
error_response = ErrorResponse(
error=ErrorInfo(
message="Generation failed",
type="InternalServerError",
code=500,
)
)
assert error_response.error.message == "Generation failed"
assert error_response.error.code == 500
def test_normal_finish_reasons_still_work() -> None:
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="done",
token_id=100,
finish_reason=reason, # type: ignore[arg-type]
)
assert chunk.finish_reason == reason

View File

@@ -7,7 +7,6 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
@@ -20,12 +19,15 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceCreated,
NodeGatheredInfo,
NodePerformanceMeasured,
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
@@ -73,22 +75,29 @@ async def test_master():
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
# inject a NodePerformanceProfile event
logger.info("inject a NodePerformanceProfile event")
await local_event_sender.send(
ForwarderEvent(
origin_idx=0,
origin=sender_node_id,
session=session_id,
event=(
NodeGatheredInfo(
NodePerformanceMeasured(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
info=MemoryUsage(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
),
@@ -99,7 +108,7 @@ async def test_master():
logger.info("wait for initial topology event")
while len(list(master.state.topology.list_nodes())) == 0:
await anyio.sleep(0.001)
while len(master.state.node_memory) == 0:
while len(master.state.node_profiles) == 0:
await anyio.sleep(0.001)
logger.info("inject a CreateInstance Command")
@@ -109,8 +118,9 @@ async def test_master():
command=(
PlaceInstance(
command_id=CommandId(),
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
@@ -153,7 +163,7 @@ async def test_master():
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodeGatheredInfo)
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
created_instance = events[1].event.instance
assert isinstance(created_instance, MlxRingInstance)
@@ -166,8 +176,9 @@ async def test_master():
start_layer=0,
end_layer=16,
n_layers=16,
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,

View File

@@ -1,24 +1,20 @@
from typing import Callable
import pytest
from loguru import logger
from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.tests.conftest import (
create_node_memory,
create_node_network,
create_rdma_connection,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -30,6 +26,11 @@ from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
return Topology()
@pytest.fixture
def instance() -> Instance:
return MlxRingInstance(
@@ -43,20 +44,21 @@ def instance() -> Instance:
@pytest.fixture
def model_card() -> ModelCard:
return ModelCard(
def model_meta() -> ModelMetadata:
return ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=30,
supports_tensor=True,
)
def place_instance_command(model_card: ModelCard) -> PlaceInstance:
def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
return PlaceInstance(
command_id=CommandId(),
model_card=model_card,
model_meta=model_meta,
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
min_nodes=1,
@@ -68,75 +70,47 @@ def place_instance_command(model_card: ModelCard) -> PlaceInstance:
[
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 468, 1092), 12, (2, 3, 7)),
((312, 518, 1024), 12, (2, 3, 7)),
],
)
def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
model_card: ModelCard,
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
model_card.n_layers = total_layers
model_card.storage_size.in_bytes = sum(
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_card)
cic = place_instance_command(model_meta)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
# fully connected (directed) between the 3 nodes
conn_a_b = Connection(
source=node_id_a, sink=node_id_b, edge=create_socket_connection(1)
)
conn_b_c = Connection(
source=node_id_b, sink=node_id_c, edge=create_socket_connection(2)
)
conn_c_a = Connection(
source=node_id_c, sink=node_id_a, edge=create_socket_connection(3)
)
conn_c_b = Connection(
source=node_id_c, sink=node_id_b, edge=create_socket_connection(4)
)
conn_a_c = Connection(
source=node_id_a, sink=node_id_c, edge=create_socket_connection(5)
)
conn_b_a = Connection(
source=node_id_b, sink=node_id_a, edge=create_socket_connection(6)
)
node_memory = {
node_id_a: create_node_memory(available_memory[0]),
node_id_b: create_node_memory(available_memory[1]),
node_id_c: create_node_memory(available_memory[2]),
}
node_network = {
node_id_a: create_node_network(),
node_id_b: create_node_network(),
node_id_c: create_node_network(),
}
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
topology.add_connection(conn_b_a)
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
# Add bidirectional connections for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
# act
placements = place_instance(cic, topology, {}, node_memory, node_network)
placements = place_instance(cic, topology, {})
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert instance.shard_assignments.model_id == model_card.model_id
assert instance.shard_assignments.model_id == model_meta.model_id
runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
@@ -156,22 +130,23 @@ def test_get_instance_placements_create_instance(
assert shards_sorted[-1].end_layer == total_layers
def test_get_instance_placements_one_node_exact_fit() -> None:
def test_get_instance_placements_one_node_exact_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
ModelCard(
ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, node_memory, node_network)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -182,22 +157,23 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
def test_get_instance_placements_one_node_fits_with_extra_memory(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
node_memory = {node_id: create_node_memory(1001 * 1024)}
node_network = {node_id: create_node_network()}
topology.add_node(create_node(1001 * 1024, node_id))
cic = place_instance_command(
ModelCard(
ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, node_memory, node_network)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -208,16 +184,17 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_not_fit() -> None:
def test_get_instance_placements_one_node_not_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -225,7 +202,7 @@ def test_get_instance_placements_one_node_not_fit() -> None:
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
place_instance(cic, topology, {}, node_memory, node_network)
place_instance(cic, topology, {})
def test_get_transition_events_no_change(instance: Instance):
@@ -270,175 +247,217 @@ def test_get_transition_events_delete_instance(instance: Instance):
assert events[0].instance_id == instance_id
def test_placement_selects_leaf_nodes(
model_card: ModelCard,
def test_placement_selects_cycle_with_most_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
topology = Topology()
# Arrange two 3-node cycles with different total memory.
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
# The algorithm should select the cycle with the most available memory.
model_card.storage_size = Memory.from_bytes(1000)
# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
model_meta.n_layers = 12
# Create node ids
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_id_d = NodeId()
node_id_e = NodeId()
node_id_f = NodeId()
node_memory = {
node_id_a: create_node_memory(500),
node_id_b: create_node_memory(600),
node_id_c: create_node_memory(600),
node_id_d: create_node_memory(500),
}
node_network = {
node_id_a: create_node_network(),
node_id_b: create_node_network(),
node_id_c: create_node_network(),
node_id_d: create_node_network(),
}
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
topology.add_node(create_node(800, node_id_c))
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_node(node_id_d)
# D-E-F cycle total memory = 1800 (> A-B-C total)
topology.add_node(create_node(600, node_id_d))
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
# Daisy chain topology (directed)
topology.add_connection(
Connection(source=node_id_a, sink=node_id_b, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_b, sink=node_id_a, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_b, sink=node_id_c, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_c, sink=node_id_b, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_c, sink=node_id_d, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))
# Build bidirectional cycles for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_d))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_e))
topology.add_connection(create_connection(node_id_f, node_id_d))
topology.add_connection(create_connection(node_id_d, node_id_f))
cic = place_instance_command(
model_meta=model_meta,
)
cic = place_instance_command(model_card=model_card)
# Act
placements = place_instance(cic, topology, {})
# act
placements = place_instance(cic, topology, {}, node_memory, node_network)
# assert
# Assert: D-E-F cycle should be selected as it has more total memory
assert len(placements) == 1
instance = list(placements.values())[0]
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
(
node_id_c,
node_id_d,
)
)
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert more_memory_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
def test_tensor_rdma_backend_connectivity_matrix(
model_card: ModelCard,
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
topology = Topology()
model_card.n_layers = 12
model_card.storage_size.in_bytes = 1500
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
node_a = NodeId()
node_b = NodeId()
node_c = NodeId()
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_memory = {
node_a: create_node_memory(500),
node_b: create_node_memory(500),
node_c: create_node_memory(500),
}
node_a = create_node(500, node_id_a)
node_b = create_node(500, node_id_b)
node_c = create_node(500, node_id_c)
ethernet_interface = NetworkInterfaceInfo(
name="en0",
ip_address="10.0.0.1",
)
ethernet_conn = SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
ip_address="192.168.1.100",
)
node_network = {
node_a: NodeNetworkInfo(interfaces=[ethernet_interface]),
node_b: NodeNetworkInfo(interfaces=[ethernet_interface]),
node_c: NodeNetworkInfo(interfaces=[ethernet_interface]),
}
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(node_id_a, node_id_b)
conn_b_c = create_connection(node_id_b, node_id_c)
conn_c_a = create_connection(node_id_c, node_id_a)
conn_b_a = create_connection(node_id_b, node_id_a)
conn_c_b = create_connection(node_id_c, node_id_b)
conn_a_c = create_connection(node_id_a, node_id_c)
assert conn_a_b.send_back_multiaddr is not None
assert conn_b_c.send_back_multiaddr is not None
assert conn_c_a.send_back_multiaddr is not None
assert conn_b_a.send_back_multiaddr is not None
assert conn_c_b.send_back_multiaddr is not None
assert conn_a_c.send_back_multiaddr is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
# RDMA connections (directed)
topology.add_connection(
Connection(source=node_a, sink=node_b, edge=create_rdma_connection(3))
)
topology.add_connection(
Connection(source=node_b, sink=node_a, edge=create_rdma_connection(3))
)
topology.add_connection(
Connection(source=node_b, sink=node_c, edge=create_rdma_connection(4))
)
topology.add_connection(
Connection(source=node_c, sink=node_b, edge=create_rdma_connection(4))
)
topology.add_connection(
Connection(source=node_a, sink=node_c, edge=create_rdma_connection(5))
)
topology.add_connection(
Connection(source=node_c, sink=node_a, edge=create_rdma_connection(5))
)
# Ethernet connections (directed)
topology.add_connection(Connection(source=node_a, sink=node_b, edge=ethernet_conn))
topology.add_connection(Connection(source=node_b, sink=node_c, edge=ethernet_conn))
topology.add_connection(Connection(source=node_c, sink=node_a, edge=ethernet_conn))
topology.add_connection(Connection(source=node_a, sink=node_c, edge=ethernet_conn))
topology.add_connection(Connection(source=node_b, sink=node_a, edge=ethernet_conn))
topology.add_connection(Connection(source=node_c, sink=node_b, edge=ethernet_conn))
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
cic = PlaceInstance(
sharding=Sharding.Tensor,
instance_meta=InstanceMeta.MlxJaccl,
command_id=CommandId(),
model_card=model_card,
model_meta=model_meta,
min_nodes=1,
)
# act
placements = place_instance(cic, topology, {}, node_memory, node_network)
placements = place_instance(cic, topology, {})
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert isinstance(instance, MlxJacclInstance)
assert instance.jaccl_devices is not None
assert instance.ibv_devices is not None
assert instance.jaccl_coordinators is not None
matrix = instance.jaccl_devices
matrix = instance.ibv_devices
assert len(matrix) == 3
for i in range(3):
assert matrix[i][i] is None
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
idx_a = node_to_idx[node_a]
idx_b = node_to_idx[node_b]
idx_c = node_to_idx[node_c]
idx_a = node_to_idx[node_id_a]
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
logger.info(matrix)
assert matrix[idx_a][idx_b] == "rdma_en4"
assert matrix[idx_b][idx_c] == "rdma_en3"
assert matrix[idx_c][idx_a] == "rdma_en3"
# Verify coordinators are set for all nodes
assert len(instance.jaccl_coordinators) == 3
@@ -450,5 +469,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
if node_id == assigned_nodes[0]:
assert coordinator.startswith("0.0.0.0:")
else:
# Non-rank-0 nodes should have valid IP addresses (can be link-local)
ip_part = coordinator.split(":")[0]
# Just verify it's a valid IP format
assert len(ip_part.split(".")) == 4

View File

@@ -1,182 +1,162 @@
from typing import Callable
import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
from exo.master.tests.conftest import (
create_node_memory,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
NetworkInterfaceInfo,
NodeNetworkInfo,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding
def test_filter_cycles_by_memory():
@pytest.fixture
def topology() -> Topology:
topology = Topology()
return topology
def test_filter_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
connection1 = Connection(
source=node1_id, sink=node2_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
)
node1_mem = create_node_memory(1000 * 1024)
node2_mem = create_node_memory(1000 * 1024)
node_memory = {node1_id: node1_mem, node2_id: node2_mem}
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology = Topology()
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
cycles = [c for c in topology.get_cycles() if len(c) != 1]
cycles = topology.get_cycles()
assert len(cycles) == 1
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(cycles, node_memory, Memory.from_bytes(1))
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
# assert
assert len(filtered_cycles) == 1
assert len(filtered_cycles[0]) == 2
assert set(n for n in filtered_cycles[0]) == {node1_id, node2_id}
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
def test_filter_cycles_by_insufficient_memory():
def test_filter_cycles_by_insufficient_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
connection1 = Connection(
source=node1_id, sink=node2_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
)
node1_mem = create_node_memory(1000 * 1024)
node2_mem = create_node_memory(1000 * 1024)
node_memory = {node1_id: node1_mem, node2_id: node2_mem}
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology = Topology()
topology.add_node(node1_id)
topology.add_node(node2_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
# act
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), node_memory, Memory.from_kb(2001)
topology.get_cycles(), Memory.from_kb(2001)
)
# assert
assert len(filtered_cycles) == 0
def test_filter_multiple_cycles_by_memory():
def test_filter_multiple_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)
)
connection4 = Connection(
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
node_a_mem = create_node_memory(500 * 1024)
node_b_mem = create_node_memory(500 * 1024)
node_c_mem = create_node_memory(1000 * 1024)
node_memory = {
node_a_id: node_a_mem,
node_b_id: node_b_mem,
node_c_id: node_c_mem,
}
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(connection4)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(cycles, node_memory, Memory.from_kb(1500))
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
# assert
assert len(filtered_cycles) == 1
assert len(filtered_cycles[0]) == 3
assert set(n for n in filtered_cycles[0]) == {
assert set(n.node_id for n in filtered_cycles[0]) == {
node_a_id,
node_b_id,
node_c_id,
}
def test_get_smallest_cycles():
def test_get_smallest_cycles(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)
)
connection4 = Connection(
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(connection4)
cycles = [c for c in topology.get_cycles() if len(c) != 1] # ignore singletons
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
# act
smallest_cycles = get_smallest_cycles(cycles)
smallest_cycles = get_smallest_cycles(topology.get_cycles())
# assert
assert len(smallest_cycles) == 1
assert len(smallest_cycles[0]) == 2
assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}
assert set(n.node_id for n in smallest_cycles[0]) == {node_a_id, node_b_id}
@pytest.mark.parametrize(
@@ -185,12 +165,12 @@ def test_get_smallest_cycles():
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
# Edge case: one node has ~90% of memory - should not over-allocate.
# Each node must have enough memory for at least 1 layer (50 KB = 1000/20).
((900, 50, 50), 20, (18, 1, 1)),
],
)
def test_get_shard_assignments(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
@@ -200,61 +180,44 @@ def test_get_shard_assignments(
node_b_id = NodeId()
node_c_id = NodeId()
# create connections (A -> B -> C -> A forms a 3-cycle, plus B -> A also exists)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
connection4 = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(4)
)
node_a = create_node(available_memory[0] * 1024, node_a_id)
node_b = create_node(available_memory[1] * 1024, node_b_id)
node_c = create_node(available_memory[2] * 1024, node_c_id)
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(connection4)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
node_a_mem = create_node_memory(available_memory[0] * 1024)
node_b_mem = create_node_memory(available_memory[1] * 1024)
node_c_mem = create_node_memory(available_memory[2] * 1024)
node_memory = {
node_a_id: node_a_mem,
node_b_id: node_b_mem,
node_c_id: node_c_mem,
}
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_card = ModelCard(
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
# pick the 3-node cycle deterministically (cycle ordering can vary)
selected_cycle = next(cycle for cycle in cycles if len(cycle) == 3)
selected_cycle = cycles[0]
# act
shard_assignments = get_shard_assignments(
model_card, selected_cycle, Sharding.Pipeline, node_memory=node_memory
model_meta, selected_cycle, Sharding.Pipeline
)
# assert
runner_id_a = shard_assignments.node_to_runner[node_a_id]
runner_id_b = shard_assignments.node_to_runner[node_b_id]
runner_id_c = shard_assignments.node_to_runner[node_c_id]
assert (
shard_assignments.runner_to_shard[runner_id_c].end_layer
- shard_assignments.runner_to_shard[runner_id_c].start_layer
== expected_layers[2]
)
assert (
shard_assignments.runner_to_shard[runner_id_a].end_layer
- shard_assignments.runner_to_shard[runner_id_a].start_layer
@@ -265,37 +228,30 @@ def test_get_shard_assignments(
- shard_assignments.runner_to_shard[runner_id_b].start_layer
== expected_layers[1]
)
assert (
shard_assignments.runner_to_shard[runner_id_c].end_layer
- shard_assignments.runner_to_shard[runner_id_c].start_layer
== expected_layers[2]
)
def test_get_hosts_from_subgraph():
def test_get_hosts_from_subgraph(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
node_a = create_node(500, node_a_id)
node_b = create_node(500, node_b_id)
node_c = create_node(1000, node_c_id)
connection1 = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
connection2 = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
connection3 = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(connection1)
topology.add_connection(connection2)
topology.add_connection(connection3)
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
# act
hosts = get_hosts_from_subgraph(topology)
@@ -303,68 +259,95 @@ def test_get_hosts_from_subgraph():
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip="169.254.0.1", port=1234),
Host(ip="169.254.0.2", port=1234),
Host(ip="169.254.0.3", port=1234),
Host(ip=("169.254.0.2"), port=5001),
Host(ip=("169.254.0.3"), port=5002),
Host(ip=("169.254.0.4"), port=5003),
]
for expected_host in expected_hosts:
assert expected_host in hosts
def test_get_mlx_jaccl_coordinators():
def test_get_mlx_jaccl_coordinators(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
# fully connected (directed) between the 3 nodes
conn_a_b = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
# Update node profiles with network interfaces before adding to topology
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
],
system=node_a.node_profile.system,
)
conn_b_a = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
],
system=node_b.node_profile.system,
)
conn_b_c = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(3)
)
conn_c_b = Connection(
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
conn_c_a = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(5)
)
conn_a_c = Connection(
source=node_a_id, sink=node_c_id, edge=create_socket_connection(6)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
],
system=node_c.node_profile.system,
)
network_a = NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.5"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.2"),
]
)
network_b = NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.1"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.4"),
]
)
network_c = NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.3"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.6"),
]
)
node_network = {
node_a_id: network_a,
node_b_id: network_b,
node_c_id: network_c,
}
topology = Topology()
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_a)
@@ -373,12 +356,11 @@ def test_get_mlx_jaccl_coordinators():
topology.add_connection(conn_c_a)
topology.add_connection(conn_a_c)
cycle = [node_a, node_b, node_c]
# act
coordinators = get_mlx_jaccl_coordinators(
node_a_id,
coordinator_port=5000,
cycle_digraph=topology,
node_network=node_network,
cycle, coordinator_port=5000, cycle_digraph=topology
)
# assert
@@ -399,129 +381,19 @@ def test_get_mlx_jaccl_coordinators():
f"Coordinator for {node_id} should use port 5000"
)
# Rank 0 (node_a) treats this as the listen socket so should listen on all IPs
# Rank 0 (node_a) treats this as the listen socket so should listen on all
# IPs
assert coordinators[node_a_id].startswith("0.0.0.0:"), (
"Rank 0 node should use 0.0.0.0 as coordinator listen address"
"Rank 0 node should use localhost as coordinator"
)
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
# node_b uses the IP from conn_b_a (node_b -> node_a)
assert isinstance(conn_b_a.edge, SocketConnection)
assert (
coordinators[node_b_id] == f"{conn_b_a.edge.sink_multiaddr.ip_address}:5000"
assert coordinators[node_b_id] == (
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
), "node_b should use the IP from conn_b_a"
# node_c uses the IP from conn_c_a (node_c -> node_a)
assert isinstance(conn_c_a.edge, SocketConnection)
assert coordinators[node_c_id] == (
f"{conn_c_a.edge.sink_multiaddr.ip_address}:5000"
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
class TestAllocateLayersProportionally:
def test_empty_node_list_raises(self):
with pytest.raises(ValueError, match="empty node list"):
allocate_layers_proportionally(total_layers=10, memory_fractions=[])
def test_zero_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=0, memory_fractions=[0.5, 0.5])
def test_negative_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=-1, memory_fractions=[0.5, 0.5])
def test_fewer_layers_than_nodes_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(
total_layers=2, memory_fractions=[0.33, 0.33, 0.34]
)
def test_equal_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.25, 0.25]
)
assert result == [3, 3, 3, 3]
assert sum(result) == 12
def test_proportional_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.50]
)
assert result == [3, 3, 6]
assert sum(result) == 12
def test_extreme_imbalance_ensures_minimum(self):
result = allocate_layers_proportionally(
total_layers=20, memory_fractions=[0.975, 0.0125, 0.0125]
)
assert all(layers >= 1 for layers in result)
assert sum(result) == 20
# Small nodes get minimum 1 layer
assert result == [18, 1, 1]
def test_single_node_gets_all_layers(self):
result = allocate_layers_proportionally(total_layers=10, memory_fractions=[1.0])
assert result == [10]
def test_minimum_viable_allocation(self):
result = allocate_layers_proportionally(
total_layers=3, memory_fractions=[0.33, 0.33, 0.34]
)
assert result == [1, 1, 1]
assert sum(result) == 3
def test_get_shard_assignments_insufficient_memory_raises():
"""Test that ValueError is raised when a node has insufficient memory for its layers."""
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
# Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
node_a_mem = create_node_memory(900 * 1024)
node_b_mem = create_node_memory(50 * 1024)
node_c_mem = create_node_memory(10 * 1024) # Insufficient memory
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
conn_a_b = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
conn_b_c = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
conn_c_a = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
conn_b_a = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(3)
)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
node_memory = {
node_a_id: node_a_mem,
node_b_id: node_b_mem,
node_c_id: node_c_mem,
}
model_card = ModelCard(
model_id=ModelId("test-model"),
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(
model_card, selected_cycle, Sharding.Pipeline, node_memory
)

View File

@@ -1,9 +1,13 @@
import pytest
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
@pytest.fixture
@@ -12,97 +16,189 @@ def topology() -> Topology:
@pytest.fixture
def socket_connection() -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
def connection() -> Connection:
return Connection(
local_node_id=NodeId(),
send_back_node_id=NodeId(),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
def test_add_node(topology: Topology):
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=memory_profile,
network_interfaces=[],
system=system_profile,
)
@pytest.fixture
def connection_profile() -> ConnectionProfile:
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
# arrange
node_id = NodeId()
# act
topology.add_node(node_id)
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
# assert
assert topology.node_is_leaf(node_id)
data = topology.get_node_profile(node_id)
assert data == node_profile
def test_add_connection(topology: Topology, socket_connection: SocketConnection):
def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
connection = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
data = list(topology.list_connections())
data = topology.get_connection_profile(connection)
# assert
assert data == [connection]
assert data == connection.connection_profile
assert topology.node_is_leaf(node_a)
assert topology.node_is_leaf(node_b)
def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# act
topology.update_node_profile(
connection.local_node_id, node_profile=new_node_profile
)
# assert
data = topology.get_node_profile(connection.local_node_id)
assert data == new_node_profile
def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_connection_profile = ConnectionProfile(
throughput=2000, latency=2000, jitter=2000
)
connection = Connection(
local_node_id=connection.local_node_id,
send_back_node_id=connection.send_back_node_id,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile,
)
# act
topology.update_connection_profile(connection)
# assert
data = topology.get_connection_profile(connection)
assert data == new_connection_profile
def test_remove_connection_still_connected(
topology: Topology, socket_connection: SocketConnection
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
conn = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(conn)
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_connection(conn)
topology.remove_connection(connection)
# assert
assert list(topology.get_all_connections_between(node_a, node_b)) == []
assert topology.get_connection_profile(connection) is None
def test_remove_node_still_connected(
topology: Topology, socket_connection: SocketConnection
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
conn = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(conn)
assert list(topology.out_edges(node_a)) == [conn]
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_node(node_b)
topology.remove_node(connection.local_node_id)
# assert
assert list(topology.out_edges(node_a)) == []
assert topology.get_node_profile(connection.local_node_id) is None
def test_list_nodes(topology: Topology, socket_connection: SocketConnection):
def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
node_a = NodeId()
node_b = NodeId()
conn = Connection(source=node_a, sink=node_b, edge=socket_connection)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_connection(conn)
assert list(topology.out_edges(node_a)) == [conn]
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
nodes = list(topology.list_nodes())
# assert
assert len(nodes) == 2
assert all(isinstance(node, NodeId) for node in nodes)
assert set(node for node in nodes) == set([node_a, node_b])
assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == {
connection.local_node_id,
connection.send_back_node_id,
}

View File

@@ -9,10 +9,13 @@ from exo.shared.types.events import (
ChunkGenerated,
Event,
IndexedEvent,
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
NodeCreated,
NodeDownloadProgress,
NodeGatheredInfo,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeTimedOut,
RunnerDeleted,
RunnerStatusUpdated,
@@ -25,46 +28,36 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import (
NodeIdentity,
NodeNetworkInfo,
NodeThunderboltInfo,
)
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import Connection, RDMAConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
MacmonMetrics,
MacThunderboltConnections,
MacThunderboltIdentifiers,
MemoryUsage,
MiscData,
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
)
def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged()
): # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
): # Pass-through events that don't modify state
return state
case InstanceCreated():
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case NodeDownloadProgress():
return apply_node_download_progress(event, state)
case NodeGatheredInfo():
return apply_node_gathered_info(event, state)
case NodeMemoryMeasured():
return apply_node_memory_measured(event, state)
case RunnerDeleted():
return apply_runner_deleted(event, state)
case RunnerStatusUpdated():
@@ -196,133 +189,120 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.remove_node(event.node_id)
topology = copy.copy(state.topology)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
}
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
}
downloads = {
key: value for key, value in state.downloads.items() if key != event.node_id
}
# Clean up all granular node mappings
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
node_system = {
key: value for key, value in state.node_system.items() if key != event.node_id
}
node_network = {
key: value for key, value in state.node_network.items() if key != event.node_id
}
node_thunderbolt = {
key: value
for key, value in state.node_thunderbolt.items()
if key != event.node_id
}
return state.model_copy(
update={
"downloads": downloads,
"topology": topology,
"node_profiles": node_profiles,
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_system": node_system,
"node_network": node_network,
"node_thunderbolt": node_thunderbolt,
}
)
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_node(event.node_id)
info = event.info
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: event.node_profile,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
state = state.model_copy(update={"node_profiles": new_profiles})
topology = copy.copy(state.topology)
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
return state.model_copy(
update={
"node_profiles": new_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
# Build update dict with only the mappings that change
update: dict[str, object] = {
"last_seen": {
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
existing = state.node_profiles.get(event.node_id)
topology = copy.copy(state.topology)
if existing is None:
created = NodePerformanceProfile(
model_id="unknown",
chip_id="unknown",
friendly_name="Unknown",
memory=event.memory,
network_interfaces=[],
system=SystemPerformanceProfile(
# TODO: flops_fp16=0.0,
gpu_usage=0.0,
temp=0.0,
sys_power=0.0,
pcpu_usage=0.0,
ecpu_usage=0.0,
ane_power=0.0,
),
)
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: created,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
},
"topology": topology,
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
# TODO: NodeCreated
topology.update_node_profile(event.node_id, created)
return state.model_copy(
update={
"node_profiles": created_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
updated = existing.model_copy(update={"memory": event.memory})
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: updated,
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, updated)
return state.model_copy(
update={"node_profiles": updated_profiles, "topology": topology}
)
match info:
case MacmonMetrics():
update["node_system"] = {
**state.node_system,
event.node_id: info.system_profile,
}
update["node_memory"] = {**state.node_memory, event.node_id: info.memory}
case MemoryUsage():
update["node_memory"] = {**state.node_memory, event.node_id: info}
case NodeConfig():
pass
case MiscData():
current_identity = state.node_identities.get(event.node_id, NodeIdentity())
new_identity = current_identity.model_copy(
update={"friendly_name": info.friendly_name}
)
update["node_identities"] = {
**state.node_identities,
event.node_id: new_identity,
}
case StaticNodeInformation():
current_identity = state.node_identities.get(event.node_id, NodeIdentity())
new_identity = current_identity.model_copy(
update={"model_id": info.model, "chip_id": info.chip}
)
update["node_identities"] = {
**state.node_identities,
event.node_id: new_identity,
}
case NodeNetworkInterfaces():
update["node_network"] = {
**state.node_network,
event.node_id: NodeNetworkInfo(interfaces=info.ifaces),
}
case MacThunderboltIdentifiers():
update["node_thunderbolt"] = {
**state.node_thunderbolt,
event.node_id: NodeThunderboltInfo(interfaces=info.idents),
}
case MacThunderboltConnections():
conn_map = {
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
for nid in state.node_thunderbolt
for tb_ident in state.node_thunderbolt[nid].interfaces
}
as_rdma_conns = [
Connection(
source=event.node_id,
sink=conn_map[tb_conn.sink_uuid][0],
edge=RDMAConnection(
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
),
)
for tb_conn in info.conns
if tb_conn.source_uuid in conn_map
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
return state.model_copy(update=update)
def apply_topology_node_created(event: NodeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_node(NodeInfo(node_id=event.node_id))
return state.model_copy(update={"topology": topology})
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_connection(event.conn)
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
return state.model_copy(update={"topology": topology})
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.remove_connection(event.conn)
topology = copy.copy(state.topology)
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})

View File

@@ -38,10 +38,11 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024

View File

@@ -11,6 +11,9 @@ class InterceptLogger(HypercornLogger):
def __init__(self, config: Config):
super().__init__(config)
assert self.error_logger
# TODO: Decide if we want to provide access logs
# assert self.access_logger
# self.access_logger.handlers = [_InterceptHandler()]
self.error_logger.handlers = [_InterceptHandler()]
@@ -26,11 +29,6 @@ class _InterceptHandler(logging.Handler):
def logger_setup(log_file: Path | None, verbosity: int = 0):
"""Set up logging for this process - formatting, file handles, verbosity and output"""
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logger.remove()
# replace all stdlib loggers with _InterceptHandlers that log to loguru

View File

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,126 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)
_model_meta_cache: dict[str, ModelMetadata] = {}
async def get_model_meta(model_id: str) -> ModelMetadata:
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
async def _get_model_meta(model_id: str) -> ModelMetadata:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)

View File

@@ -7,8 +7,8 @@ import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@@ -31,8 +31,9 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=1000,

View File

@@ -2,7 +2,6 @@ from exo.shared.apply import apply_node_download_progress
from exo.shared.tests.conftest import get_pipeline_shard_metadata
from exo.shared.types.common import NodeId
from exo.shared.types.events import NodeDownloadProgress
from exo.shared.types.memory import Memory
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
@@ -14,7 +13,6 @@ def test_apply_node_download_progress():
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
new_state = apply_node_download_progress(
@@ -30,12 +28,10 @@ def test_apply_two_node_download_progress():
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
total_bytes=Memory(),
)
state = State(downloads={NodeId("node-1"): [event1]})
@@ -43,4 +39,7 @@ def test_apply_two_node_download_progress():
NodeDownloadProgress(download_progress=event2), state
)
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

View File

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.topology import Connection
def test_state_serialization_roundtrip() -> None:
@@ -12,11 +12,9 @@ def test_state_serialization_roundtrip() -> None:
node_b = NodeId("node-b")
connection = Connection(
source=node_a,
sink=node_b,
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
),
local_node_id=node_a,
send_back_node_id=node_b,
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
)
state = State()
@@ -25,11 +23,5 @@ def test_state_serialization_roundtrip() -> None:
json_repr = state.model_dump_json()
restored_state = State.model_validate_json(json_repr)
assert (
state.topology.to_snapshot().nodes
== restored_state.topology.to_snapshot().nodes
)
assert set(state.topology.to_snapshot().connections) == set(
restored_state.topology.to_snapshot().connections
)
assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
assert restored_state.model_dump_json() == json_repr

View File

@@ -1,227 +1,203 @@
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Iterable
import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.topology import (
Connection,
Cycle,
RDMAConnection,
SocketConnection,
)
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
class TopologySnapshot(BaseModel):
nodes: Sequence[NodeId]
connections: Mapping[
NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]
]
nodes: list[NodeInfo]
connections: list[Connection]
model_config = ConfigDict(frozen=True, extra="forbid")
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
@dataclass
class Topology:
_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
init=False, default_factory=rx.PyDiGraph
)
_vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
def __init__(self) -> None:
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
def to_snapshot(self) -> TopologySnapshot:
return TopologySnapshot(
nodes=list(self.list_nodes()), connections=self.map_connections()
nodes=list(self.list_nodes()),
connections=list(self.list_connections()),
)
@classmethod
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
topology = cls()
for node_id in snapshot.nodes:
for node in snapshot.nodes:
with contextlib.suppress(ValueError):
topology.add_node(node_id)
topology.add_node(node)
for source in snapshot.connections:
for sink in snapshot.connections[source]:
for edge in snapshot.connections[source][sink]:
topology.add_connection(
Connection(source=source, sink=sink, edge=edge)
)
for connection in snapshot.connections:
topology.add_connection(connection)
return topology
def add_node(self, node_id: NodeId) -> None:
if node_id in self._vertex_indices:
def add_node(self, node: NodeInfo) -> None:
if node.node_id in self._node_id_to_rx_id_map:
return
rx_id = self._graph.add_node(node_id)
self._vertex_indices[node_id] = rx_id
rx_id = self._graph.add_node(node)
self._node_id_to_rx_id_map[node.node_id] = rx_id
self._rx_id_to_node_id_map[rx_id] = node.node_id
def node_is_leaf(self, node_id: NodeId) -> bool:
return (
node_id in self._vertex_indices
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
node_id in self._node_id_to_rx_id_map
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
)
def neighbours(self, node_id: NodeId) -> list[NodeId]:
return [
self._graph[rx_id]
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
self._rx_id_to_node_id_map[rx_id]
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
]
def out_edges(self, node_id: NodeId) -> Iterable[Connection]:
if node_id not in self._vertex_indices:
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
if node_id not in self._node_id_to_rx_id_map:
return []
return (
Connection(source=self._graph[source], sink=self._graph[sink], edge=edge)
for source, sink, edge in self._graph.out_edges(
self._vertex_indices[node_id]
return [
(self._rx_id_to_node_id_map[nid], conn)
for _, nid, conn in self._graph.out_edges(
self._node_id_to_rx_id_map[node_id]
)
)
]
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._vertex_indices
return node_id in self._node_id_to_rx_id_map
def add_connection(self, conn: Connection) -> None:
source, sink, edge = conn.source, conn.sink, conn.edge
del conn
if edge in self.get_all_connections_between(source, sink):
def contains_connection(self, connection: Connection) -> bool:
return connection in self._edge_id_to_rx_id_map
def add_connection(
self,
connection: Connection,
) -> None:
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.local_node_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
if connection in self._edge_id_to_rx_id_map:
return
if source not in self._vertex_indices:
self.add_node(source)
if sink not in self._vertex_indices:
self.add_node(sink)
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
_ = self._graph.add_edge(src_id, sink_id, edge)
def list_nodes(self) -> Iterable[NodeInfo]:
return (self._graph[i] for i in self._graph.node_indices())
def get_all_connections_between(
self, source: NodeId, sink: NodeId
) -> Iterable[SocketConnection | RDMAConnection]:
if source not in self._vertex_indices:
return []
if sink not in self._vertex_indices:
return []
def list_connections(self) -> Iterable[Connection]:
return (connection for _, _, connection in self._graph.weighted_edge_list())
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
try:
return self._graph.get_all_edge_data(src_id, sink_id)
except rx.NoEdgeBetweenNodes:
return []
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
except KeyError:
return None
def list_nodes(self) -> Iterable[NodeId]:
return self._graph.nodes()
def update_node_profile(
self, node_id: NodeId, node_profile: NodePerformanceProfile
) -> None:
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph[rx_idx].node_profile = node_profile
def map_connections(
self,
) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
for src_id, sink_id, connection in self._graph.weighted_edge_list():
source = self._graph[src_id]
sink = self._graph[sink_id]
if source not in base:
base[source] = {}
if sink not in base[source]:
base[source][sink] = []
base[source][sink].append(connection)
return base
def update_connection_profile(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.update_edge_by_index(rx_idx, connection)
def list_connections(
self,
) -> Iterable[Connection]:
return (
(
Connection(
source=self._graph[src_id],
sink=self._graph[sink_id],
edge=connection,
)
)
for src_id, sink_id, connection in self._graph.weighted_edge_list()
)
def get_connection_profile(
self, connection: Connection
) -> ConnectionProfile | None:
try:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
except KeyError:
return None
def remove_node(self, node_id: NodeId) -> None:
if node_id not in self._vertex_indices:
if node_id not in self._node_id_to_rx_id_map:
return
rx_idx = self._vertex_indices[node_id]
for connection in self.list_connections():
if (
connection.local_node_id == node_id
or connection.send_back_node_id == node_id
):
self.remove_connection(connection)
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph.remove_node(rx_idx)
del self._vertex_indices[node_id]
del self._node_id_to_rx_id_map[node_id]
del self._rx_id_to_node_id_map[rx_idx]
def replace_all_out_rdma_connections(
self, source: NodeId, new_connections: Sequence[Connection]
) -> None:
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
self._graph.remove_edge_from_index(conn_idx)
for conn in new_connections:
self.add_connection(conn)
def remove_connection(self, conn: Connection) -> None:
if (
conn.source not in self._vertex_indices
or conn.sink not in self._vertex_indices
):
def remove_connection(self, connection: Connection) -> None:
if connection not in self._edge_id_to_rx_id_map:
return
for conn_idx in self._graph.edge_indices_from_endpoints(
self._vertex_indices[conn.source], self._vertex_indices[conn.sink]
):
if self._graph.get_edge_data_by_index(conn_idx) == conn.edge:
self._graph.remove_edge_from_index(conn_idx)
def get_cycles(self) -> list[Cycle]:
"""Get simple cycles in the graph, including singleton cycles"""
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]
def get_cycles(self) -> list[list[NodeInfo]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[Cycle] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = Cycle(node_ids=[self._graph[idx] for idx in cycle_idx])
cycle = [self._graph[idx] for idx in cycle_idx]
cycles.append(cycle)
for node_id in self.list_nodes():
cycles.append(Cycle(node_ids=[node_id]))
return cycles
def get_cycles_tb(self) -> list[Cycle]:
def get_cycles_tb(self) -> list[list[NodeInfo]]:
tb_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
]
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
tb_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycles: list[Cycle] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = Cycle(node_ids=[tb_graph[idx] for idx in cycle_idx])
cycle = [tb_graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
node_idxs = [node.node_id for node in nodes]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
topology = Topology()
for node_id in node_ids:
topology.add_node(node_id)
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for connection in self.list_connections():
if connection.source in node_ids and connection.sink in node_ids:
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
topology.add_connection(connection)
return topology
def is_thunderbolt_cycle(self, cycle: Cycle) -> bool:
node_idxs = [node for node in cycle]
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
node_idxs = [node.node_id for node in cycle]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:

View File

@@ -1,31 +1,21 @@
import time
from collections.abc import Generator
from typing import Any, Literal
from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
"stop", "length", "tool_calls", "content_filter", "function_call"
]
class ErrorInfo(BaseModel):
message: str
type: str
param: str | None = None
code: int
class ErrorResponse(BaseModel):
error: ErrorInfo
class ModelListModel(BaseModel):
id: str
object: str = "model"
@@ -39,6 +29,7 @@ class ModelListModel(BaseModel):
tags: list[str] = Field(default=[])
storage_size_megabytes: int = Field(default=0)
supports_tensor: bool = Field(default=False)
tasks: list[str] = Field(default=[])
class ModelList(BaseModel):
@@ -63,10 +54,6 @@ class ChatCompletionMessage(BaseModel):
function_call: dict[str, Any] | None = None
class BenchChatCompletionMessage(ChatCompletionMessage):
pass
class TopLogprobItem(BaseModel):
token: str
logprob: float
@@ -129,18 +116,6 @@ class ChatCompletionResponse(BaseModel):
service_tier: str | None = None
class GenerationStats(BaseModel):
prompt_tps: float
generation_tps: float
prompt_tokens: int
generation_tokens: int
peak_memory_usage: Memory
class BenchChatCompletionResponse(ChatCompletionResponse):
generation_stats: GenerationStats | None = None
class ChatCompletionTaskParams(BaseModel):
model: str
frequency_penalty: float | None = None
@@ -163,12 +138,8 @@ class ChatCompletionTaskParams(BaseModel):
user: str | None = None
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
pass
class PlaceInstanceParams(BaseModel):
model_id: ModelId
model_id: str
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
@@ -206,10 +177,80 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_card: ModelCard
model_meta: ModelMetadata
class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId
class ImageGenerationTaskParams(BaseModel):
prompt: str
# background: str | None = None
model: str
# moderation: str | None = None
n: int | None = 1
# output_compression: int | None = None
output_format: Literal["png", "jpeg", "webp"] = "png"
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
stream: bool | None = False
# style: str | None = "vivid"
# user: str | None = None
class ImageEditsTaskParams(BaseModel):
image: UploadFile
prompt: str
input_fidelity: float = 0.7
model: str
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
# user: str | None = None
class ImageEditsInternalParams(BaseModel):
"""Serializable version of ImageEditsTaskParams for distributed task execution."""
image_data: str = "" # Base64-encoded image (empty when using chunked transfer)
total_input_chunks: int = 0
prompt: str
model: str
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
image_strength: float = 0.7
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} chars>"
elif name is not None:
yield name, value
class ImageData(BaseModel):
b64_json: str | None = None
url: str | None = None
revised_prompt: str | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "b64_json" and value is not None:
yield name, f"<{len(value)} chars>"
elif name is not None:
yield name, value
class ImageGenerationResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
data: list[ImageData]

View File

@@ -1,11 +1,12 @@
from collections.abc import Generator
from enum import Enum
from typing import Any
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .worker.runner_response import ToolCallItem
from .common import CommandId
from .models import ModelId
class ChunkType(str, Enum):
@@ -22,16 +23,37 @@ class TokenChunk(BaseChunk):
text: str
token_id: int
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
error_message: str | None = None
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
class ImageChunk(BaseChunk):
data: bytes
data: str
chunk_index: int
total_chunks: int
image_index: int
is_partial: bool = False
partial_index: int | None = None
total_partials: int | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data":
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk
class InputImageChunk(BaseChunk):
command_id: CommandId
data: str
chunk_index: int
total_chunks: int
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data":
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value
GenerationChunk = TokenChunk | ImageChunk

View File

@@ -1,8 +1,13 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -20,8 +25,16 @@ class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
class ImageGeneration(BaseCommand):
request_params: ImageGenerationTaskParams
class ImageEdits(BaseCommand):
request_params: ImageEditsInternalParams
class PlaceInstance(BaseCommand):
model_card: ModelCard
model_meta: ModelMetadata
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
@@ -39,6 +52,12 @@ class TaskFinished(BaseCommand):
finished_command_id: CommandId
class SendInputChunk(BaseCommand):
"""Command to send an input image chunk (converted to event by master)."""
chunk: InputImageChunk
class RequestEventLog(BaseCommand):
since_idx: int
@@ -47,10 +66,13 @@ Command = (
TestCommand
| RequestEventLog
| ChatCompletion
| ImageGeneration
| ImageEdits
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskFinished
| SendInputChunk
)

View File

@@ -16,23 +16,13 @@ class Id(str):
cls, _source: type, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# Just use a plain string schema
return core_schema.no_info_after_validator_function(
cls, core_schema.str_schema()
)
return core_schema.str_schema()
class NodeId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
class SessionId(CamelCaseModel):
master_node_id: NodeId
election_clock: int

View File

@@ -2,14 +2,14 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -76,15 +76,25 @@ class RunnerDeleted(BaseEvent):
runner_id: RunnerId
# TODO
class NodeCreated(BaseEvent):
node_id: NodeId
class NodeTimedOut(BaseEvent):
node_id: NodeId
# TODO: bikeshed this name
class NodeGatheredInfo(BaseEvent):
class NodePerformanceMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
info: GatheredInfo
node_profile: NodePerformanceProfile
class NodeMemoryMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile
class NodeDownloadProgress(BaseEvent):
@@ -96,12 +106,17 @@ class ChunkGenerated(BaseEvent):
chunk: GenerationChunk
class InputChunkReceived(BaseEvent):
command_id: CommandId
chunk: InputImageChunk
class TopologyEdgeCreated(BaseEvent):
conn: Connection
edge: Connection
class TopologyEdgeDeleted(BaseEvent):
conn: Connection
edge: Connection
Event = (
@@ -115,10 +130,13 @@ Event = (
| InstanceDeleted
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodeGatheredInfo
| NodePerformanceMeasured
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| TopologyEdgeCreated
| TopologyEdgeDeleted
)

View File

@@ -0,0 +1,36 @@
from enum import Enum
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
pass
class ModelTask(str, Enum):
TextGeneration = "TextGeneration"
TextToImage = "TextToImage"
ImageToImage = "ImageToImage"
class ComponentInfo(CamelCaseModel):
component_name: str
component_path: str
storage_size: Memory
n_layers: PositiveInt | None
can_shard: bool
safetensors_index_filename: str | None
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
components: list[ComponentInfo] | None = None

View File

@@ -1,11 +1,10 @@
import re
from typing import ClassVar
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
from pydantic import BaseModel, computed_field, field_validator
class Multiaddr(BaseModel):
model_config = ConfigDict(frozen=True)
address: str
PATTERNS: ClassVar[list[str]] = [

View File

@@ -1,14 +1,12 @@
from collections.abc import Sequence
from typing import Self
import psutil
from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import ThunderboltIdentifier
from exo.utils.pydantic_ext import CamelCaseModel
class MemoryUsage(CamelCaseModel):
class MemoryPerformanceProfile(CamelCaseModel):
ram_total: Memory
ram_available: Memory
swap_total: Memory
@@ -46,6 +44,7 @@ class SystemPerformanceProfile(CamelCaseModel):
sys_power: float = 0.0
pcpu_usage: float = 0.0
ecpu_usage: float = 0.0
ane_power: float = 0.0
class NetworkInterfaceInfo(CamelCaseModel):
@@ -53,21 +52,16 @@ class NetworkInterfaceInfo(CamelCaseModel):
ip_address: str
class NodeIdentity(CamelCaseModel):
"""Static and slow-changing node identification data."""
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
class NodePerformanceProfile(CamelCaseModel):
model_id: str
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile
class NodeNetworkInfo(CamelCaseModel):
"""Network interface information for a node."""
interfaces: Sequence[NetworkInterfaceInfo] = []
class NodeThunderboltInfo(CamelCaseModel):
"""Thunderbolt interface identifiers for a node."""
interfaces: Sequence[ThunderboltIdentifier] = []
class ConnectionProfile(CamelCaseModel):
throughput: float
latency: float
jitter: float

View File

@@ -7,13 +7,7 @@ from pydantic.alias_generators import to_camel
from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import (
MemoryUsage,
NodeIdentity,
NodeNetworkInfo,
NodeThunderboltInfo,
SystemPerformanceProfile,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -41,17 +35,11 @@ class State(CamelCaseModel):
runners: Mapping[RunnerId, RunnerStatus] = {}
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
tasks: Mapping[TaskId, Task] = {}
node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
last_seen: Mapping[NodeId, datetime] = {}
topology: Topology = Field(default_factory=Topology)
last_event_applied_idx: int = Field(default=-1, ge=-1)
# Granular node state mappings (update independently at different frequencies)
node_identities: Mapping[NodeId, NodeIdentity] = {}
node_memory: Mapping[NodeId, MemoryUsage] = {}
node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
node_network: Mapping[NodeId, NodeNetworkInfo] = {}
node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
@field_serializer("topology", mode="plain")
def _encode_topology(self, value: Topology) -> TopologySnapshot:
return value.to_snapshot()

View File

@@ -2,7 +2,11 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.common import CommandId, Id
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
@@ -56,6 +60,22 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class ImageEdits(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageEditsInternalParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
@@ -67,5 +87,7 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| ImageGeneration
| ImageEdits
| Shutdown
)

View File

@@ -1,81 +0,0 @@
import anyio
from pydantic import BaseModel, Field
from exo.utils.pydantic_ext import CamelCaseModel
class ThunderboltConnection(CamelCaseModel):
source_uuid: str
sink_uuid: str
class ThunderboltIdentifier(CamelCaseModel):
rdma_interface: str
domain_uuid: str
## Intentionally minimal, only collecting data we care about - there's a lot more
class _ReceptacleTag(BaseModel, extra="ignore"):
receptacle_id_key: str | None = None
class _ConnectivityItem(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
class ThunderboltConnectivityData(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
items: list[_ConnectivityItem] | None = Field(None, alias="_items")
receptacle_1_tag: _ReceptacleTag | None = None
def ident(self, ifaces: dict[str, str]) -> ThunderboltIdentifier | None:
if (
self.domain_uuid_key is None
or self.receptacle_1_tag is None
or self.receptacle_1_tag.receptacle_id_key is None
):
return
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
assert tag in ifaces # doesn't need to be an assertion but im confident
# if tag not in ifaces: return None
iface = f"rdma_{ifaces[tag]}"
return ThunderboltIdentifier(
rdma_interface=iface, domain_uuid=self.domain_uuid_key
)
def conn(self) -> ThunderboltConnection | None:
if self.domain_uuid_key is None or self.items is None:
return
sink_key = next(
(
item.domain_uuid_key
for item in self.items
if item.domain_uuid_key is not None
),
None,
)
if sink_key is None:
return None
return ThunderboltConnection(
source_uuid=self.domain_uuid_key, sink_uuid=sink_key
)
class ThunderboltConnectivity(BaseModel, extra="ignore"):
SPThunderboltDataType: list[ThunderboltConnectivityData] = []
@classmethod
async def gather(cls) -> list[ThunderboltConnectivityData] | None:
proc = await anyio.run_process(
["system_profiler", "SPThunderboltDataType", "-json"], check=False
)
if proc.returncode != 0:
return None
# Saving you from PascalCase while avoiding too much pydantic
return ThunderboltConnectivity.model_validate_json(
proc.stdout
).SPThunderboltDataType

Some files were not shown because too many files have changed in this diff Show More