Compare commits


239 Commits

Author SHA1 Message Date
ciaranbor
f60f0d8f4c Correctly handle num_sync_steps for diffusion steps override 2026-01-19 23:18:57 +00:00
ciaranbor
faab7c20c5 Simplify DistributedImageModel 2026-01-19 22:28:01 +00:00
ciaranbor
29f008395f Hide internal mflux image type 2026-01-19 21:58:48 +00:00
ciaranbor
35dd79de25 Only run two steps during warmup 2026-01-19 21:30:12 +00:00
ciaranbor
ba3a42cf91 Document image apis 2026-01-19 20:36:24 +00:00
ciaranbor
b9fa65d375 Clean up ImageModelConfig 2026-01-19 19:55:52 +00:00
ciaranbor
5a76641edd Add AdvancedImageParams to api 2026-01-19 19:28:56 +00:00
ciaranbor
6e1a6901fa Capture partial images earlier 2026-01-19 19:08:24 +00:00
ciaranbor
9ab66d5f5f Remove redundant tensor allocations for recv templates 2026-01-19 18:55:08 +00:00
ciaranbor
2d374c90b6 Remove recv evals, use async_eval for sends 2026-01-19 17:33:30 +00:00
ciaranbor
8c93c908e8 Remove ImageGenerator protocol 2026-01-19 17:08:16 +00:00
ciaranbor
d0ac6ee823 Batch CFG 2026-01-19 14:17:31 +00:00
ciaranbor
04cd02b1b5 Add image generation benchmarking endpoints 2026-01-18 21:04:27 +00:00
ciaranbor
4519548ef0 Consolidate patched and unpatched qkv computation logic 2026-01-18 10:33:36 +00:00
ciaranbor
94785a76b4 Use mixin for common block wrapper functionality 2026-01-18 10:18:00 +00:00
ciaranbor
1f088eaac9 Move last rank to first rank comms outside the CFG step 2026-01-18 09:47:14 +00:00
ciaranbor
a315e68f33 Revert 2026-01-17 22:35:44 +00:00
ciaranbor
2160ebd559 Run all positive then all negative 2026-01-17 21:36:34 +00:00
ciaranbor
d63096393c Rank 0 shouldn't receive on negative pass 2026-01-17 19:48:48 +00:00
ciaranbor
75d0ac291d Fix negative pass text_seq_len 2026-01-17 19:34:43 +00:00
ciaranbor
5ac44a03a4 Add distributed CFG support 2026-01-17 19:12:00 +00:00
ciaranbor
01148d57aa Enable CFG for Qwen-Image 2026-01-17 16:24:08 +00:00
ciaranbor
29191b0913 Use transformer block wrapper classes 2026-01-17 15:21:50 +00:00
ciaranbor
ad6f2bfdde Refactor 2026-01-16 19:26:04 +00:00
ciaranbor
7925cc826a Fix flux tokenizer 2026-01-15 18:39:56 +00:00
ciaranbor
b9c9f53a73 Reduce image generation and image edits code duplication 2026-01-15 18:28:24 +00:00
ciaranbor
7eea39de9d Update mflux to 0.14.2 2026-01-15 17:43:33 +00:00
ciaranbor
a736c3824c Add python-multipart dependency 2026-01-15 16:31:47 +00:00
ciaranbor
58970a7ba2 Register tasks for newly added models 2026-01-15 16:31:47 +00:00
ciaranbor
fe9bbcd3d0 Linting 2026-01-15 16:31:47 +00:00
ciaranbor
b164e62fd7 Start image editing time steps at 0 2026-01-15 16:31:47 +00:00
ciaranbor
6078e9a8a1 Ignore image_strength 2026-01-15 16:31:47 +00:00
ciaranbor
ff52182aff Handle conditioning latents in sync pipeline 2026-01-15 16:31:47 +00:00
ciaranbor
de4efff6ac Use dummy image for editing warmup 2026-01-15 16:31:47 +00:00
ciaranbor
012b87abbf Support streaming for image editing 2026-01-15 16:31:47 +00:00
ciaranbor
68a64b5671 Support image editing in runner 2026-01-15 16:31:47 +00:00
ciaranbor
da8579c3ab Add editing features to adapter 2026-01-15 16:31:47 +00:00
ciaranbor
db222fb69e Default partial images to 3 if streaming 2026-01-15 16:31:47 +00:00
ciaranbor
a4329997c8 Add Qwen-Image model adapter 2026-01-15 16:31:47 +00:00
ciaranbor
3d9f8a5161 Add Qwen-Image-Edit model config 2026-01-15 16:31:47 +00:00
ciaranbor
5095d52a3d Use image generation in streaming mode in UI 2026-01-15 16:31:47 +00:00
ciaranbor
59b92b7c56 Handle partial image streaming 2026-01-15 16:31:47 +00:00
ciaranbor
1be6caacb4 Add streaming params to ImageGenerationTaskParams 2026-01-15 16:31:47 +00:00
ciaranbor
b2cf0390da Add Qwen-Image-Edit-2509 2026-01-15 16:31:47 +00:00
ciaranbor
148ee6c347 Handle image editing time steps 2026-01-15 16:31:47 +00:00
ciaranbor
e40976290d Fix time steps 2026-01-15 16:31:47 +00:00
ciaranbor
69178664b3 Fix image_strength meaning 2026-01-15 16:31:47 +00:00
ciaranbor
b8ef6345ef Truncate image data logs 2026-01-15 16:31:47 +00:00
ciaranbor
fc534e3a0d Chunk image input 2026-01-15 16:31:47 +00:00
ciaranbor
c8cb0b04b4 Avoid logging image data 2026-01-15 16:31:47 +00:00
ciaranbor
77334d60c7 Support image editing 2026-01-15 16:31:47 +00:00
Sami Khan
71d684861c small UI change 2026-01-15 16:31:47 +00:00
Sami Khan
95e277acf7 image gen in dashboard 2026-01-15 16:31:46 +00:00
ciaranbor
63284fd5fe Better llm model type check 2026-01-15 16:31:46 +00:00
ciaranbor
982567db05 Prune blocks before model load 2026-01-15 16:31:46 +00:00
ciaranbor
d72a41e9bc Own TODOs 2026-01-15 16:31:46 +00:00
ciaranbor
8921b786ab Remove double RunnerReady event 2026-01-15 16:31:46 +00:00
ciaranbor
52113075a4 Fix hidden_size for image models 2026-01-15 16:31:46 +00:00
ciaranbor
9d8add1977 Fix image model cards 2026-01-15 16:31:46 +00:00
ciaranbor
daf7fe495b Skip decode on non-final ranks 2026-01-15 16:31:46 +00:00
ciaranbor
b494cc3c11 Final rank produces image 2026-01-15 16:31:46 +00:00
ciaranbor
e9e6f93945 Increase number of sync steps 2026-01-15 16:31:46 +00:00
ciaranbor
3410874ee9 Change Qwen-Image steps 2026-01-15 16:31:46 +00:00
ciaranbor
a6ba92bf6b Fix Qwen-Image latent shapes 2026-01-15 16:31:46 +00:00
ciaranbor
280364d872 Fix joint block patch recv shape for non-zero ranks 2026-01-15 16:31:46 +00:00
ciaranbor
832f687d85 Fix comms issue for models without single blocks 2026-01-15 16:31:46 +00:00
ciaranbor
769162b509 Support Qwen in DiffusionRunner pipefusion 2026-01-15 16:31:46 +00:00
ciaranbor
f80a9789a5 Implement Qwen pipefusion 2026-01-15 16:31:46 +00:00
ciaranbor
c159f2f7b9 Add guidance_scale parameter to image model config 2026-01-15 16:31:46 +00:00
ciaranbor
f4270a6056 Move orchestration to DiffusionRunner 2026-01-15 16:31:46 +00:00
ciaranbor
bd48be8b0e Add initial QwenModelAdapter 2026-01-15 16:31:46 +00:00
ciaranbor
2b556ac7fb Tweak embeddings interface 2026-01-15 16:31:46 +00:00
ciaranbor
3bccde49d0 Add Qwen ImageModelConfig 2026-01-15 16:31:46 +00:00
ciaranbor
1fa952adfc Use 10% sync steps 2026-01-15 16:31:46 +00:00
ciaranbor
f4909aa7c6 Update FluxModelAdapter for new interface 2026-01-15 16:31:46 +00:00
ciaranbor
89a2bd4d18 Register QwenModelAdapter 2026-01-15 16:31:46 +00:00
ciaranbor
623f623297 Support multiple forward passes in runner 2026-01-15 16:31:46 +00:00
ciaranbor
4022d0585b Extend block wrapper parameters 2026-01-15 16:31:46 +00:00
ciaranbor
e2155579f4 Relax adaptor typing 2026-01-15 16:31:46 +00:00
ciaranbor
5a1a124e65 Add Qwen-Image model card 2026-01-15 16:31:46 +00:00
ciaranbor
3121827263 Clean up dead code 2026-01-15 16:31:46 +00:00
ciaranbor
480b72b1b1 Add BaseModelAdaptor 2026-01-15 16:31:46 +00:00
ciaranbor
28986bb678 Refactor filestructure 2026-01-15 16:31:46 +00:00
ciaranbor
c53fc6a16f Treat unified blocks as single blocks (equivalent) 2026-01-15 16:31:46 +00:00
ciaranbor
2e86b0f5a9 Refactor to handle entire denoising process in Diffusion runner 2026-01-15 16:31:46 +00:00
ciaranbor
98f0a29085 Move transformer to adapter 2026-01-15 16:31:46 +00:00
ciaranbor
1c0f2daf3c Move some more logic to adaptor 2026-01-15 16:31:46 +00:00
ciaranbor
814a836db1 Add generic block wrapper 2026-01-15 16:31:46 +00:00
ciaranbor
ef03ef049c Access transformer blocks from adaptor 2026-01-15 16:31:46 +00:00
ciaranbor
ba8567418d Better typing 2026-01-15 16:31:46 +00:00
ciaranbor
801ecf4483 Create wrappers at init time 2026-01-15 16:31:46 +00:00
ciaranbor
87e25961f5 Combine model factory and adaptor 2026-01-15 16:31:46 +00:00
ciaranbor
1d1014eaef Implement model factory 2026-01-15 16:31:46 +00:00
ciaranbor
8bd077de52 Add adaptor registry 2026-01-15 16:31:46 +00:00
ciaranbor
8377da5e22 Remove mflux/generator/generate.py 2026-01-15 16:31:46 +00:00
ciaranbor
4aa3b75000 Switch to using DistributedImageModel 2026-01-15 16:31:46 +00:00
ciaranbor
3d24aab421 Add DistributedImageModel 2026-01-15 16:31:46 +00:00
ciaranbor
04bb688005 Use new generic wrappers, etc in denoising 2026-01-15 16:31:46 +00:00
ciaranbor
1c70cea40c Add generic transformer block wrappers 2026-01-15 16:31:46 +00:00
ciaranbor
5928f369c5 Add FluxAdaptor 2026-01-15 16:31:46 +00:00
ciaranbor
ce1b66e5e6 Add ModelAdaptor, derivations implement model specific logic 2026-01-15 16:31:46 +00:00
ciaranbor
5d503a1ffb Introduce image model config concept 2026-01-15 16:31:46 +00:00
ciaranbor
f94a5ec8df Consolidate kv cache patching 2026-01-15 16:31:46 +00:00
ciaranbor
d45f9d98c0 Support different configuration comms 2026-01-15 16:31:46 +00:00
ciaranbor
7d4faf04fb Add ImageGenerator protocol 2026-01-15 16:31:46 +00:00
ciaranbor
2d4ba878cb Force final patch receive order 2026-01-15 16:31:46 +00:00
ciaranbor
bca5a9ffe3 Remove logs 2026-01-15 16:31:46 +00:00
ciaranbor
6bbd134880 Update patch list 2026-01-15 16:31:46 +00:00
ciaranbor
1414da68ec Slight refactor 2026-01-15 16:31:46 +00:00
ciaranbor
c88156f5ab Don't need array for prev patches 2026-01-15 16:31:46 +00:00
ciaranbor
770982c830 Fix send/recv order 2026-01-15 16:31:46 +00:00
ciaranbor
8d99ed8133 Fix async single transformer block 2026-01-15 16:31:46 +00:00
ciaranbor
c3d8fbc5ed Use relative rank variables 2026-01-15 16:31:46 +00:00
ciaranbor
a72830c301 Fix writing patches 2026-01-15 16:31:46 +00:00
ciaranbor
1fc355a2b1 Collect final image 2026-01-15 16:31:46 +00:00
ciaranbor
19dc8380c6 Fix recv_template shape 2026-01-15 16:31:46 +00:00
ciaranbor
11148923ca Add logs 2026-01-15 16:31:46 +00:00
ciaranbor
329b7d5f36 Optimise async pipeline 2026-01-15 16:31:46 +00:00
ciaranbor
1752aaa44a Add next_rank and prev_rank members 2026-01-15 16:31:46 +00:00
ciaranbor
a3fa833ae4 Add _create_patches method 2026-01-15 16:31:46 +00:00
ciaranbor
4661013cbb Fix shapes 2026-01-15 16:31:46 +00:00
ciaranbor
54e80a314d Reorder comms 2026-01-15 16:31:46 +00:00
ciaranbor
7990d8b1ef Remove all_gather from sync pipeline, send from final rank to first rank 2026-01-15 16:31:46 +00:00
ciaranbor
289bbe3253 Simplify kv_cache initialization 2026-01-15 16:31:46 +00:00
ciaranbor
ea06742295 Fix kv cache 2026-01-15 16:31:46 +00:00
ciaranbor
5bf986db6b Clean up kv caches 2026-01-15 16:31:46 +00:00
ciaranbor
0a9c1f7212 Fix return 2026-01-15 16:31:46 +00:00
ciaranbor
4b84aa5f70 Fix hidden_states shapes 2026-01-15 16:31:46 +00:00
ciaranbor
148f6550ed Only perform projection and scheduler step on last rank 2026-01-15 16:31:46 +00:00
ciaranbor
ecf2f40b4c Only compute embeddings on rank 0 2026-01-15 16:31:46 +00:00
ciaranbor
eea030b8c2 Remove eval 2026-01-15 16:31:46 +00:00
ciaranbor
73c92dfe60 Remove eval 2026-01-15 16:31:46 +00:00
ciaranbor
a6f7c4b822 Only send encoder_hidden_states with the first patch (once per timestep) 2026-01-15 16:31:46 +00:00
ciaranbor
4b81f8a672 Remove redundant text kv cache computation 2026-01-15 16:31:46 +00:00
ciaranbor
6e57d817d1 Concatenate before all gather 2026-01-15 16:31:46 +00:00
ciaranbor
4905107ea2 Increase number of sync steps 2026-01-15 16:31:46 +00:00
ciaranbor
e56c970e74 Reinitialise kv_caches between generations 2026-01-15 16:31:46 +00:00
ciaranbor
5d3bc83a63 Eliminate double kv cache computation 2026-01-15 16:31:46 +00:00
ciaranbor
88356eb0a0 Add kv cache caching wrappers for sync pipeline transformer blocks 2026-01-15 16:31:46 +00:00
ciaranbor
cab296ada7 Persist kv caches 2026-01-15 16:31:46 +00:00
ciaranbor
bfc6650a13 Implement naive async pipeline implementation 2026-01-15 16:31:46 +00:00
ciaranbor
66c091ae88 Use wrapper classes for patched transformer logic 2026-01-15 16:31:45 +00:00
ciaranbor
1ca1a3e490 Add patch-aware joint and single attention wrappers 2026-01-15 16:31:45 +00:00
ciaranbor
b778213792 Fix group.size() 2026-01-15 16:31:45 +00:00
ciaranbor
14a3a5d41c Add classes to manage kv caches with patch support 2026-01-15 16:31:45 +00:00
ciaranbor
bef9589510 Use heuristic for number of sync steps 2026-01-15 16:31:45 +00:00
ciaranbor
d39fbf796d Generalise number of denoising steps 2026-01-15 16:31:45 +00:00
ciaranbor
19ef6ea748 Add flux1-dev 2026-01-15 16:31:45 +00:00
ciaranbor
431ddf947e Move scheduler step to inner pipeline 2026-01-15 16:31:45 +00:00
ciaranbor
b5485bf6ef Add barrier before all_gather 2026-01-15 16:31:45 +00:00
ciaranbor
8325d5b865 Fix transformer blocks pruning 2026-01-15 16:31:45 +00:00
ciaranbor
44de96c15c Fix image generation api 2026-01-15 16:31:45 +00:00
ciaranbor
00c88a1102 Create queue in try block 2026-01-15 16:31:45 +00:00
ciaranbor
594487caed Conform to rebase 2026-01-15 16:31:45 +00:00
ciaranbor
7d9df93b7a Refactor denoising 2026-01-15 16:31:45 +00:00
ciaranbor
692907d2de Move more logic to DistributedFlux 2026-01-15 16:31:45 +00:00
ciaranbor
4018f698a1 Move surrounding logic back to _sync_pipeline 2026-01-15 16:31:45 +00:00
ciaranbor
330c7bb9cf Add patching aware member variables 2026-01-15 16:31:45 +00:00
ciaranbor
c8d54af8b6 Implement sync/async switching logic 2026-01-15 16:31:45 +00:00
ciaranbor
ba798e6bd3 Move current transformer implementation to _sync_pipeline method 2026-01-15 16:31:45 +00:00
ciaranbor
bf25de116a Remove some logs 2026-01-15 16:31:45 +00:00
ciaranbor
64e0dd06a8 Remove old Flux1 implementation 2026-01-15 16:31:45 +00:00
ciaranbor
5dafb7aceb Prune unused transformer blocks 2026-01-15 16:31:45 +00:00
ciaranbor
bbe0b58642 Add mx.eval 2026-01-15 16:31:45 +00:00
ciaranbor
b3233e35f0 Test evals 2026-01-15 16:31:45 +00:00
ciaranbor
887441e666 Test only barriers 2026-01-15 16:31:45 +00:00
ciaranbor
e3231ae22b All perform final projection 2026-01-15 16:31:45 +00:00
ciaranbor
b2918f5e42 Another barrier 2026-01-15 16:31:45 +00:00
ciaranbor
4d9b893d7a More debug 2026-01-15 16:31:45 +00:00
ciaranbor
2494a05790 Add barriers 2026-01-15 16:31:45 +00:00
ciaranbor
9802f27545 Add log 2026-01-15 16:31:45 +00:00
ciaranbor
926b197ea5 Restore distributed logging 2026-01-15 16:31:45 +00:00
ciaranbor
580d1738fc Use bootstrap logger 2026-01-15 16:31:45 +00:00
ciaranbor
a94aacb72b Remove logs 2026-01-15 16:31:45 +00:00
ciaranbor
e6758829c7 fix single block receive shape 2026-01-15 16:31:45 +00:00
ciaranbor
c892352860 Add debug logs 2026-01-15 16:31:45 +00:00
ciaranbor
23048f0fbb Move communication logic to DistributedTransformer wrapper 2026-01-15 16:31:45 +00:00
ciaranbor
3a45e55dcf Move inference logic to DistributedFlux1 2026-01-15 16:31:45 +00:00
ciaranbor
50ba4a38f1 Add DistributedFlux1 class 2026-01-15 16:31:45 +00:00
ciaranbor
d4f49b9a38 Rename pipeline to pipefusion 2026-01-15 16:31:45 +00:00
ciaranbor
57135bda07 Further refactor 2026-01-15 16:31:45 +00:00
ciaranbor
ab492c76e9 Refactor warmup 2026-01-15 16:31:45 +00:00
ciaranbor
e55b3d496f Manually handle flux1 inference 2026-01-15 16:31:45 +00:00
ciaranbor
c20ad0d5fe Refactor flux1 image generation 2026-01-15 16:31:45 +00:00
ciaranbor
b02fb39747 Use quality parameter to set number of inference steps 2026-01-15 16:31:45 +00:00
ciaranbor
d257abed82 Chunk image data transfer 2026-01-15 16:31:45 +00:00
ciaranbor
e84a14b650 Define EXO_MAX_CHUNK_SIZE 2026-01-15 16:31:45 +00:00
ciaranbor
04128b65a7 Add indexing info to ImageChunk 2026-01-15 16:31:45 +00:00
ciaranbor
3d6e675af8 Remove sharding logs 2026-01-15 16:31:45 +00:00
ciaranbor
7b3320cd0e Temp: reduce flux1.schnell storage size 2026-01-15 16:31:45 +00:00
ciaranbor
1b7208bc04 Fix mflux transformer all_gather 2026-01-15 16:31:45 +00:00
ciaranbor
eef91921f2 Add all_gather -> broadcast todo 2026-01-15 16:31:45 +00:00
ciaranbor
3e8ab46d69 Fix world size 2026-01-15 16:31:45 +00:00
ciaranbor
02f811dd7e Fix transition block? 2026-01-15 16:31:45 +00:00
ciaranbor
1b7eb4abb2 Implement image generation warmup 2026-01-15 16:31:45 +00:00
ciaranbor
c8f27976c9 Add logs 2026-01-15 16:31:45 +00:00
ciaranbor
56e6ae4984 Add spiece.model to default patterns 2026-01-15 16:31:45 +00:00
ciaranbor
bccb2977ec Just download all files for now 2026-01-15 16:31:45 +00:00
ciaranbor
3d38e1977e Fix get_allow_patterns to include non-indexed safetensors files 2026-01-15 16:31:45 +00:00
ciaranbor
beb6371caf Use half-open layer indexing in get_allow_patterns 2026-01-15 16:31:45 +00:00
ciaranbor
6f66b387a8 Enable distributed mflux 2026-01-15 16:31:45 +00:00
ciaranbor
b0b789d971 Implement mflux transformer sharding and communication pattern 2026-01-15 16:31:45 +00:00
ciaranbor
6921df88a1 Update get_allow_patterns to handle sharding components 2026-01-15 16:31:45 +00:00
ciaranbor
b4cd0517c9 Namespace both keys and values for component weight maps 2026-01-15 16:31:45 +00:00
ciaranbor
ece3f207ad Add components to Flux.1-schnell MODEL_CARD 2026-01-15 16:31:45 +00:00
ciaranbor
29575a1fea Add component concept for ModelMetadata 2026-01-15 16:31:45 +00:00
ciaranbor
8d0cdb2b52 Fix multiple components weight map key conflicts 2026-01-15 16:31:45 +00:00
ciaranbor
eeac072a6b get_weight_map: handle repos with multiple safetensors.index.json files 2026-01-15 16:31:45 +00:00
ciaranbor
7ed6b75b41 Add initial image edits spec 2026-01-15 16:31:45 +00:00
ciaranbor
6e00899385 Add image edits endpoint 2026-01-15 16:31:45 +00:00
ciaranbor
ea3bab243a Add ImageToImage task 2026-01-15 16:31:45 +00:00
ciaranbor
497a2c065d Allow ModelCards to have multiple tasks 2026-01-15 16:31:45 +00:00
ciaranbor
4dd1a7c1b6 Fix text generation 2026-01-15 16:31:45 +00:00
ciaranbor
670a0f0c4a Rename mlx_generate_image to mflux_generate 2026-01-15 16:31:45 +00:00
ciaranbor
01a0d6d141 Initialize mlx or mflux engine based on model task 2026-01-15 16:31:45 +00:00
ciaranbor
761d2d82a7 Restore warmup for text generation 2026-01-15 16:31:45 +00:00
ciaranbor
248bea1839 Add initialize_mflux function 2026-01-15 16:31:45 +00:00
ciaranbor
f41d0129e5 Move image generation to mflux engine 2026-01-15 16:31:45 +00:00
ciaranbor
65f9d666b5 Just use str for image generation size 2026-01-15 16:31:45 +00:00
ciaranbor
7b13b361d0 Use MFlux for image generation 2026-01-15 16:31:45 +00:00
ciaranbor
dad82a605c Add get_model_card function 2026-01-15 16:31:45 +00:00
ciaranbor
6573b47abf Add ModelTask enum 2026-01-15 16:31:45 +00:00
ciaranbor
4596a7ac24 Add flux1-schnell model 2026-01-15 16:31:45 +00:00
ciaranbor
e580b45eb2 Add task field to ModelCard 2026-01-15 16:31:04 +00:00
ciaranbor
def080c7e3 Update mflux version 2026-01-15 16:31:04 +00:00
ciaranbor
806239f14b Enable recursive repo downloads 2026-01-15 16:31:04 +00:00
ciaranbor
154f3561e7 Add dummy generate_image implementation 2026-01-15 16:31:04 +00:00
ciaranbor
07f7601948 Use base64 encoded str for image data 2026-01-15 16:31:03 +00:00
ciaranbor
229bd05473 Handle ImageGeneration tasks in _pending_tasks 2026-01-15 16:31:03 +00:00
ciaranbor
ac0c187aed Add mflux dependency 2026-01-15 16:31:03 +00:00
ciaranbor
083de373db Handle ImageGeneration task in runner task processing 2026-01-15 16:31:03 +00:00
ciaranbor
ae95172e41 Handle ImageGeneration command in master command processing 2026-01-15 16:31:03 +00:00
ciaranbor
a688001446 Add image generation to API 2026-01-15 16:31:03 +00:00
ciaranbor
796f291d85 Add ImageGenerationResponse 2026-01-15 16:31:03 +00:00
ciaranbor
73a09cf98c Add ImageGeneration task 2026-01-15 16:31:03 +00:00
ciaranbor
6c6dfd9ec7 Add ImageGeneration command 2026-01-15 16:31:03 +00:00
ciaranbor
41dbcf0b37 Add image generation params and response types 2026-01-15 16:31:03 +00:00
ciaranbor
b291950c1a Add pillow dependency 2026-01-15 16:31:03 +00:00
ciaranbor
f48b3dd870 Fix mlx stream_generate import 2026-01-15 16:31:03 +00:00
117 changed files with 10803 additions and 5643 deletions

View File

@@ -1,16 +1,5 @@
name: Build EXO macOS DMG
# Release workflow:
# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
# 3. This workflow builds, signs, and notarizes the DMG
# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
# 5. DMG and appcast.xml are uploaded to S3
# 6. The draft GitHub Release is published with the DMG attached
#
# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
# If no draft exists, a release is auto-created with generated notes.
on:
workflow_dispatch:
push:
@@ -22,10 +11,8 @@ on:
jobs:
build-macos-app:
runs-on: "macos-26"
permissions:
contents: write
env:
SPARKLE_VERSION: 2.9.0-beta.1
SPARKLE_VERSION: 2.8.1
SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
@@ -100,52 +87,6 @@ jobs:
exit 1
fi
- name: Fetch and validate release notes
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Find draft release by name using gh release list (more reliable with default token)
echo "Looking for draft release named '$GITHUB_REF_NAME'..."
DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
if [[ -z "$DRAFT_EXISTS" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
echo "Please create a draft release with release notes before pushing the tag."
exit 1
fi
# Fetch full release details via API to get body and ID
echo "Found draft release, fetching details..."
RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
# Extract release notes
NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "Draft release has no notes (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: Draft release exists but has no release notes"
echo "Please add release notes to the draft release before pushing the tag."
exit 1
fi
# Save release ID for later publishing
RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
echo "$NOTES" > /tmp/release_notes.md
echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
# ============================================================
# Install dependencies
# ============================================================
@@ -363,28 +304,6 @@ jobs:
$CHANNEL_FLAG \
.
- name: Inject release notes into appcast
if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
env:
RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
run: |
# Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
export NOTES=$(cat "$RELEASE_NOTES_FILE")
# Insert description after the enclosure tag for this version
awk '
/<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
print
print " <description sparkle:format=\"markdown\"><![CDATA["
print ENVIRON["NOTES"]
print " ]]></description>"
next
}
{ print }
' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
echo "Injected markdown release notes for version $RELEASE_VERSION"
# ============================================================
# Upload artifacts
# ============================================================
@@ -417,26 +336,3 @@ jobs:
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
fi
- name: Publish GitHub Release
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
# Update the draft release with the tag and upload DMG
gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
-f tag_name="$GITHUB_REF_NAME" \
-F draft=false
gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
echo "Published release $GITHUB_REF_NAME with DMG attached"
else
# Alpha without draft release - create one with auto-generated notes
gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
--title "$GITHUB_REF_NAME" \
--generate-notes \
--prerelease
echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
fi

View File

@@ -40,31 +40,6 @@ uv run ruff check
nix fmt
```
## Pre-Commit Checks (REQUIRED)
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
```bash
# 1. Type checking - MUST pass with 0 errors
uv run basedpyright
# 2. Linting - MUST pass
uv run ruff check
# 3. Formatting - MUST be applied
nix fmt
# 4. Tests - MUST pass
uv run pytest
```
Run all checks in sequence:
```bash
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
```
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.
## Architecture
### Node Composition

Cargo.lock generated
View File

@@ -4340,6 +4340,25 @@ dependencies = [
"libc",
]
[[package]]
name = "system_custodian"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
]
[[package]]
name = "tagptr"
version = "0.2.0"

View File

@@ -3,6 +3,7 @@ resolver = "3"
members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/system_custodian",
"rust/util",
]
@@ -24,6 +25,7 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
system_custodian = { path = "rust/system_custodian" }
util = { path = "rust/util" }
# Proc-macro authoring tools

View File

@@ -27,22 +27,13 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
## Dashboard
exo includes a built-in dashboard for managing your cluster and chatting with models.
<p align="center">
<img src="docs/imgs/dashboard-cluster-view.png" alt="exo dashboard - cluster view showing 4 x M3 Ultra Mac Studio with DeepSeek v3.1 and Kimi-K2-Thinking loaded" width="80%" />
</p>
<p align="center"><em>4 × 512GB M3 Ultra Mac Studio running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)</em></p>
## Benchmarks
<details>
<summary>Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-1-qwen3-235b.jpeg" alt="Benchmark - Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt 5</a>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
@@ -50,7 +41,7 @@ exo includes a built-in dashboard for managing your cluster and chatting with mo
<summary>DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-2-deepseek-3.1-671b.jpeg" alt="Benchmark - DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt 5</a>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
@@ -58,7 +49,7 @@ exo includes a built-in dashboard for managing your cluster and chatting with mo
<summary>Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
<img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-3-kimi-k2-thinking.jpeg" alt="Benchmark - Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
<p>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt 5</a>
<strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio RDMA over Thunderbolt5</a>
</p>
</details>
@@ -163,24 +154,6 @@ This starts the exo dashboard and API at http://localhost:52415/
**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.
**Configuration Options:**
- `--no-worker`: Run exo without the worker component. Useful for coordinator-only nodes that handle networking and orchestration but don't execute inference tasks. This is helpful for machines without sufficient GPU resources but with good network connectivity.
```bash
uv run exo --no-worker
```
**File Locations (Linux):**
exo follows the [XDG Base Directory Specification](https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html) on Linux:
- **Configuration files**: `~/.config/exo/` (or `$XDG_CONFIG_HOME/exo/`)
- **Data files**: `~/.local/share/exo/` (or `$XDG_DATA_HOME/exo/`)
- **Cache files**: `~/.cache/exo/` (or `$XDG_CACHE_HOME/exo/`)
You can override these locations by setting the corresponding XDG environment variables.
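For example, a minimal sketch of relocating the configuration directory before launching (the path here is purely illustrative):
```bash
# Illustrative only: relocate exo's config dir via XDG_CONFIG_HOME.
# exo will then read and write its configuration under $XDG_CONFIG_HOME/exo/.
export XDG_CONFIG_HOME="$HOME/exo-configs"
uv run exo
```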
### macOS App
exo ships a macOS app that runs in the background on your Mac.
@@ -193,19 +166,6 @@ Download the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-
The app will ask for permission to modify system settings and install a new Network profile. Improvements to this are being worked on.
**Custom Namespace for Cluster Isolation:**
The macOS app includes a custom namespace feature that allows you to isolate your exo cluster from others on the same network. This is configured through the `EXO_LIBP2P_NAMESPACE` setting:
- **Use cases**:
- Running multiple separate exo clusters on the same network
- Isolating development/testing clusters from production clusters
- Preventing accidental cluster joining
- **Configuration**: Access this setting in the app's Advanced settings (or set the `EXO_LIBP2P_NAMESPACE` environment variable when running from source)
The namespace is logged on startup for debugging purposes.
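When running from source, a minimal sketch of the environment-variable route looks like this (the namespace value is just an example; any string shared by the nodes you want clustered together will do):
```bash
# Illustrative namespace: only nodes started with the same value will form a cluster.
EXO_LIBP2P_NAMESPACE=team-a-dev uv run exo
```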
#### Uninstalling the macOS App
The recommended way to uninstall is through the app itself: click the menu bar icon → Advanced → Uninstall. This cleanly removes all system components.
@@ -352,52 +312,6 @@ For further details, see:
---
## Benchmarking
The `exo-bench` tool measures model prefill and token generation speed across different placement configurations. This helps you optimize model performance and validate improvements.
**Prerequisites:**
- Nodes should be running with `uv run exo` before benchmarking
- The tool uses the `/bench/chat/completions` endpoint
**Basic usage:**
```bash
uv run bench/exo_bench.py \
--model llama-3.2-1b \
--pp 128,256,512 \
--tg 128,256
```
**Key parameters:**
- `--model`: Model to benchmark (short ID or HuggingFace ID)
- `--pp`: Prompt size hints (comma-separated integers)
- `--tg`: Generation lengths (comma-separated integers)
- `--max-nodes`: Limit placements to N nodes (default: 4)
- `--instance-meta`: Filter by `ring`, `jaccl`, or `both` (default: both)
- `--sharding`: Filter by `pipeline`, `tensor`, or `both` (default: both)
- `--repeat`: Number of repetitions per configuration (default: 1)
- `--warmup`: Warmup runs per placement (default: 0)
- `--json-out`: Output file for results (default: bench/results.json)
**Example with filters:**
```bash
uv run bench/exo_bench.py \
--model llama-3.2-1b \
--pp 128,512 \
--tg 128 \
--max-nodes 2 \
--sharding tensor \
--repeat 3 \
--json-out my-results.json
```
The tool outputs performance metrics including prompt tokens per second (prompt_tps), generation tokens per second (generation_tps), and peak memory usage for each configuration.
---
## Hardware Accelerator Support
On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working on extending hardware accelerator support. If you'd like support for a new hardware platform, please [search for an existing feature request](https://github.com/exo-explore/exo/issues) and add a thumbs up so we know what hardware is important to the community.
@@ -406,4 +320,4 @@ On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working
## Contributing
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.

View File

@@ -19,7 +19,6 @@
25. Rethink retry logic
26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
27. Log cleanup - per-module log filters and default to DEBUG log levels
28. Validate RDMA connections with ibv_devinfo in the info gatherer
Potential refactors:

View File

@@ -585,7 +585,7 @@
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.9.0-beta.1;
minimumVersion = 2.8.1;
};
};
/* End XCRemoteSwiftPackageReference section */

View File

@@ -6,8 +6,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/sparkle-project/Sparkle.git",
"state" : {
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
"version" : "2.9.0-beta.1"
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
"version" : "2.8.1"
}
}
],

View File

@@ -56,11 +56,6 @@ struct ContentView: View {
}
private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}

View File

@@ -5,8 +5,8 @@ import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
@@ -35,43 +35,30 @@ final class LocalNetworkChecker: ObservableObject {
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
lastConnectionState = "connecting"
checkTask = Task { [weak self] in
guard let self else { return }
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
let result = await self.performCheck()
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
/// Checks connectivity using NWConnection to mDNS multicast.
/// The connection attempt triggers the permission prompt if not yet shown.
private func checkConnectivity(timeout: UInt64) async -> Status {
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
connection?.cancel()
connection = nil
@@ -97,7 +84,22 @@ final class LocalNetworkChecker: ObservableObject {
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { state in
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
switch state {
case .ready:
resumeOnce(.working)
@@ -106,7 +108,6 @@ final class LocalNetworkChecker: ObservableObject {
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
// Otherwise keep waiting - might be showing permission prompt
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
@@ -126,7 +127,7 @@ final class LocalNetworkChecker: ObservableObject {
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: timeout)
try? await Task.sleep(nanoseconds: 3_000_000_000)
let state = conn.state
switch state {
case .ready:

View File

@@ -6,7 +6,7 @@ enum NetworkSetupHelper {
private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
private static let daemonLabel = "io.exo.networksetup"
private static let scriptDestination =
"/Library/Application Support/EXO/disable_bridge.sh"
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1791
@@ -28,6 +28,35 @@ enum NetworkSetupHelper {
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
@@ -112,13 +141,6 @@ enum NetworkSetupHelper {
let scriptExists = manager.fileExists(atPath: scriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
guard scriptExists, plistExists else { return false }
guard
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
else {
return false
}
guard
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
let plist = try? PropertyListSerialization.propertyList(

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import argparse
import contextlib
import http.client
import json
import os
@@ -27,7 +26,7 @@ class ExoHttpError(RuntimeError):
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
@@ -105,46 +104,22 @@ def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
@@ -266,9 +241,6 @@ class PromptSizer:
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
# Fix for transformers 5.x
if hasattr(ids, "input_ids"):
ids = ids.input_ids
return int(len(ids))
return count_fn
@@ -324,12 +296,6 @@ def main() -> int:
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
@@ -351,7 +317,7 @@ def main() -> int:
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
@@ -430,7 +396,7 @@ def main() -> int:
):
continue
if args.min_nodes <= n <= args.max_nodes:
if 0 < n <= args.max_nodes:
selected.append(p)
if not selected:
@@ -472,13 +438,7 @@ def main() -> int:
)
client.request_json("POST", "/instance", body={"instance": instance})
try:
wait_for_instance_ready(client, instance_id)
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize placement: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
continue
wait_for_instance_ready(client, instance_id)
time.sleep(1)
@@ -496,9 +456,9 @@ def main() -> int:
and "tensor" in sharding.lower()
):
model_card = MODEL_CARDS[short_id]
if model_card.storage_size > Memory.from_gb(10):
if model_card.metadata.storage_size > Memory.from_gb(10):
logger.info(
f"Skipping tensor ring as this is too slow for model of size {model_card.storage_size} on {n_nodes=}"
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
)
continue
for tg in tg_list:

View File

@@ -863,7 +863,6 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -903,7 +902,6 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1520,7 +1518,6 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1530,7 +1527,6 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1943,7 +1939,6 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2651,7 +2646,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2839,7 +2833,6 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2984,7 +2977,6 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3006,7 +2998,6 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -1,5 +1,5 @@
<script lang="ts">
import { isLoading, sendMessage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import { isLoading, sendMessage, generateImage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import ChatAttachments from './ChatAttachments.svelte';
import type { ChatUploadedFile } from '$lib/types/files';
import { processUploadedFiles, getAcceptString } from '$lib/types/files';
@@ -10,6 +10,7 @@
showHelperText?: boolean;
autofocus?: boolean;
showModelSelector?: boolean;
modelTasks?: Record<string, string[]>;
}
let {
@@ -17,7 +18,8 @@
placeholder = 'Ask anything',
showHelperText = false,
autofocus = true,
showModelSelector = false
showModelSelector = false,
modelTasks = {}
}: Props = $props();
let message = $state('');
@@ -48,51 +50,40 @@
// Accept all supported file types
const acceptString = getAcceptString(['image', 'text', 'pdf']);
// Check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const tasks = modelTasks[modelId] || [];
return tasks.includes('TextToImage') || tasks.includes('ImageToImage');
}
// Check if the currently selected model supports image generation
const isImageModel = $derived(() => {
if (!currentModel) return false;
return modelSupportsImageGeneration(currentModel);
});
// Extract available models from running instances
const availableModels = $derived(() => {
const models: Array<{id: string, label: string}> = [];
const models: Array<{id: string, label: string, isImageModel: boolean}> = [];
for (const [, instance] of Object.entries(instanceData)) {
const modelId = getInstanceModelId(instance);
if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
models.push({ id: modelId, label: modelId.split('/').pop() || modelId });
models.push({
id: modelId,
label: modelId.split('/').pop() || modelId,
isImageModel: modelSupportsImageGeneration(modelId)
});
}
}
return models;
});
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
let previousModelIds: Set<string> = new Set();
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
// Auto-select the first available model if none is selected
$effect(() => {
const models = availableModels();
const currentModelIds = new Set(models.map(m => m.id));
if (models.length > 0) {
// Find newly added models (in current but not in previous)
const newModels = models.filter(m => !previousModelIds.has(m.id));
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel('');
}
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
}
// Update previous model IDs for next comparison
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {
@@ -187,7 +178,12 @@
uploadedFiles = [];
resetTextareaHeight();
sendMessage(content, files);
// Use image generation for image models
if (isImageModel() && content) {
generateImage(content);
} else {
sendMessage(content, files);
}
// Refocus the textarea after sending
setTimeout(() => textareaRef?.focus(), 10);
@@ -324,7 +320,14 @@
{:else}
<span class="w-3"></span>
{/if}
<span class="truncate">{model.label}</span>
{#if model.isImageModel}
<svg class="w-3.5 h-3.5 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate flex-1">{model.label}</span>
</button>
{/each}
</div>
@@ -384,7 +387,7 @@
onkeydown={handleKeydown}
oninput={handleInput}
onpaste={handlePaste}
{placeholder}
placeholder={isImageModel() ? 'Describe the image you want to generate...' : placeholder}
disabled={loading}
rows={1}
class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"
@@ -398,14 +401,23 @@
{!canSend || loading
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label="Send message"
aria-label={isImageModel() ? "Generate image" : "Send message"}
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<span class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"></span>
<span class="hidden sm:inline">PROCESSING</span>
<span class="hidden sm:inline">{isImageModel() ? 'GENERATING' : 'PROCESSING'}</span>
<span class="sm:hidden">...</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}

View File

@@ -365,10 +365,58 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
</div>
{/if}
<!-- Generated Images -->
{#if message.attachments?.some(a => a.type === 'generated-image')}
<div class="mb-3">
{#each message.attachments.filter(a => a.type === 'generated-image') as attachment}
<div class="relative group/img inline-block">
<img
src={attachment.preview}
alt=""
class="max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20"
/>
<!-- Download button overlay -->
<button
type="button"
class="absolute top-2 right-2 p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow opacity-0 group-hover/img:opacity-100 transition-opacity hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
onclick={() => {
if (attachment.preview) {
const link = document.createElement('a');
link.href = attachment.preview;
link.download = `generated-image-${Date.now()}.png`;
link.click();
}
}}
title="Download image"
>
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
</svg>
</button>
</div>
{/each}
</div>
{/if}
<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{#if message.content === 'Generating image...'}
<div class="flex items-center gap-3 text-exo-yellow">
<div class="relative">
<div class="w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"></div>
<svg class="absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
</div>
<span class="font-mono tracking-wider uppercase text-sm">Generating image...</span>
</div>
{:else if message.content || (loading && !message.attachments?.some(a => a.type === 'generated-image'))}
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
{/if}
</div>
</div>

View File

@@ -53,285 +53,62 @@
marked.use({ renderer });
/**
* Unescape HTML entities that marked may have escaped
*/
function unescapeHtmlEntities(text: string): string {
return text
.replace(/&lt;/g, '<')
.replace(/&gt;/g, '>')
.replace(/&amp;/g, '&')
.replace(/&quot;/g, '"')
.replace(/&#39;/g, "'");
}
// Storage for math expressions extracted before markdown processing
const mathExpressions: Map<string, { content: string; displayMode: boolean }> = new Map();
let mathCounter = 0;
// Storage for HTML snippets that need protection from markdown
const htmlSnippets: Map<string, string> = new Map();
let htmlCounter = 0;
// Use alphanumeric placeholders that won't be interpreted as HTML tags
const MATH_PLACEHOLDER_PREFIX = 'MATHPLACEHOLDER';
const CODE_PLACEHOLDER_PREFIX = 'CODEPLACEHOLDER';
const HTML_PLACEHOLDER_PREFIX = 'HTMLPLACEHOLDER';
/**
* Preprocess LaTeX: extract math, handle LaTeX document commands, and protect content
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
* Also protect code blocks from LaTeX processing
*/
function preprocessLaTeX(text: string): string {
// Reset storage
mathExpressions.clear();
mathCounter = 0;
htmlSnippets.clear();
htmlCounter = 0;
// Protect code blocks first
// Protect code blocks
const codeBlocks: string[] = [];
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
codeBlocks.push(match);
return `${CODE_PLACEHOLDER_PREFIX}${codeBlocks.length - 1}END`;
return `<<CODE_${codeBlocks.length - 1}>>`;
});
// Remove LaTeX document commands
processed = processed.replace(/\\documentclass(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\usepackage(\[[^\]]*\])?\{[^}]*\}/g, '');
processed = processed.replace(/\\begin\{document\}/g, '');
processed = processed.replace(/\\end\{document\}/g, '');
processed = processed.replace(/\\maketitle/g, '');
processed = processed.replace(/\\title\{[^}]*\}/g, '');
processed = processed.replace(/\\author\{[^}]*\}/g, '');
processed = processed.replace(/\\date\{[^}]*\}/g, '');
// Remove \require{...} commands (MathJax-specific, not supported by KaTeX)
processed = processed.replace(/\$\\require\{[^}]*\}\$/g, '');
processed = processed.replace(/\\require\{[^}]*\}/g, '');
// Remove unsupported LaTeX commands/environments (tikzpicture, figure, center, etc.)
processed = processed.replace(/\\begin\{tikzpicture\}[\s\S]*?\\end\{tikzpicture\}/g, () => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">📐</span><span class="latex-diagram-text">Diagram</span></div>');
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\begin\{figure\}[\s\S]*?\\end\{figure\}/g, () => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">🖼️</span><span class="latex-diagram-text">Figure</span></div>');
htmlCounter++;
return placeholder;
});
// Strip center environment (layout only, no content change)
processed = processed.replace(/\\begin\{center\}/g, '');
processed = processed.replace(/\\end\{center\}/g, '');
// Strip other layout environments
processed = processed.replace(/\\begin\{flushleft\}/g, '');
processed = processed.replace(/\\end\{flushleft\}/g, '');
processed = processed.replace(/\\begin\{flushright\}/g, '');
processed = processed.replace(/\\end\{flushright\}/g, '');
processed = processed.replace(/\\label\{[^}]*\}/g, '');
processed = processed.replace(/\\caption\{[^}]*\}/g, '');
// Protect escaped dollar signs (e.g., \$50 should become $50, not LaTeX)
processed = processed.replace(/\\\$/g, 'ESCAPEDDOLLARPLACEHOLDER');
// Convert LaTeX math environments to display math (both bare and wrapped in $...$)
const mathEnvs = ['align', 'align\\*', 'equation', 'equation\\*', 'gather', 'gather\\*', 'multline', 'multline\\*', 'eqnarray', 'eqnarray\\*', 'array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix', 'cases'];
for (const env of mathEnvs) {
// Handle $\begin{env}...\end{env}$ (with dollar signs, possibly multiline)
const wrappedRegex = new RegExp(`\\$\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}\\$`, 'g');
processed = processed.replace(wrappedRegex, (_, args, content) => {
const cleanEnv = env.replace('\\*', '*');
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
mathCounter++;
return placeholder;
});
// Handle bare \begin{env}...\end{env} (without dollar signs)
const bareRegex = new RegExp(`\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}`, 'g');
processed = processed.replace(bareRegex, (_, args, content) => {
const cleanEnv = env.replace('\\*', '*');
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
mathCounter++;
return placeholder;
});
}
// Convert LaTeX proof environments to styled blocks (use placeholders for HTML)
processed = processed.replace(
/\\begin\{proof\}([\s\S]*?)\\end\{proof\}/g,
(_, content) => {
const html = `<div class="latex-proof"><div class="latex-proof-header">Proof</div><div class="latex-proof-content">${content}</div></div>`;
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, html);
htmlCounter++;
return placeholder;
}
);
// Convert LaTeX theorem-like environments
const theoremEnvs = ['theorem', 'lemma', 'corollary', 'proposition', 'definition', 'remark', 'example'];
for (const env of theoremEnvs) {
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'gi');
const envName = env.charAt(0).toUpperCase() + env.slice(1);
processed = processed.replace(envRegex, (_, content) => {
const html = `<div class="latex-theorem"><div class="latex-theorem-header">${envName}</div><div class="latex-theorem-content">${content}</div></div>`;
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, html);
htmlCounter++;
return placeholder;
});
}
// Convert LaTeX text formatting commands (use placeholders to protect from markdown)
processed = processed.replace(/\\emph\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<em>${content}</em>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\textit\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<em>${content}</em>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\textbf\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<strong>${content}</strong>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\texttt\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<code class="inline-code">${content}</code>`);
htmlCounter++;
return placeholder;
});
processed = processed.replace(/\\underline\{([^}]*)\}/g, (_, content) => {
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
htmlSnippets.set(placeholder, `<u>${content}</u>`);
htmlCounter++;
return placeholder;
});
// Handle LaTeX line breaks and spacing
processed = processed.replace(/\\\\(?:\s*\n)?/g, '\n'); // \\ -> newline
processed = processed.replace(/\\newline/g, '\n');
processed = processed.replace(/\\par\b/g, '\n\n');
processed = processed.replace(/\\quad/g, ' ');
processed = processed.replace(/\\qquad/g, ' ');
processed = processed.replace(/~~/g, ' '); // non-breaking space
// Remove other common LaTeX commands that don't render
processed = processed.replace(/\\centering/g, '');
processed = processed.replace(/\\noindent/g, '');
processed = processed.replace(/\\hfill/g, '');
processed = processed.replace(/\\vspace\{[^}]*\}/g, '');
processed = processed.replace(/\\hspace\{[^}]*\}/g, ' ');
// Convert \(...\) to placeholder (display: false)
processed = processed.replace(/\\\(([\s\S]+?)\\\)/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: false });
mathCounter++;
return placeholder;
});
// Convert \[...\] to placeholder (display: true)
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: true });
mathCounter++;
return placeholder;
});
// Extract display math ($$...$$) BEFORE markdown processing
processed = processed.replace(/\$\$([\s\S]*?)\$\$/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: true });
mathCounter++;
return placeholder;
});
// Extract inline math ($...$) BEFORE markdown processing
// Allow single-line only, skip currency patterns like $5 or $50
processed = processed.replace(/\$([^\$\n]+?)\$/g, (match, content) => {
if (/^\d/.test(content.trim())) {
return match; // Keep as-is for currency
}
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: false });
mathCounter++;
return placeholder;
});
// Restore escaped dollar signs
processed = processed.replace(/ESCAPEDDOLLARPLACEHOLDER/g, '$');
// Convert \(...\) to $...$
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
// Convert \[...\] to $$...$$
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
// Restore code blocks
processed = processed.replace(new RegExp(`${CODE_PLACEHOLDER_PREFIX}(\\d+)END`, 'g'), (_, index) => codeBlocks[parseInt(index)]);
// Clean up any remaining stray backslashes from unrecognized commands
processed = processed.replace(/\\(?=[a-zA-Z])/g, ''); // Remove \ before letters (unrecognized commands)
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
return processed;
}
/**
* Render math expressions with KaTeX and restore HTML placeholders
* Render math expressions with KaTeX after HTML is generated
*/
function renderMath(html: string): string {
// Replace all math placeholders with rendered KaTeX
for (const [placeholder, { content, displayMode }] of mathExpressions) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
// Render display math ($$...$$)
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
try {
return katex.renderToString(math.trim(), {
displayMode: true,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$$${math}$$</span>`;
}
});
html = html.replace(regex, () => {
try {
const rendered = katex.renderToString(content, {
displayMode,
throwOnError: false,
output: 'html'
});
if (displayMode) {
return `
<div class="math-display-wrapper">
<div class="math-display-header">
<span class="math-label">LaTeX</span>
<button type="button" class="copy-math-btn" data-math-source="${encodeURIComponent(content)}" title="Copy LaTeX source">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
</svg>
</button>
</div>
<div class="math-display-content">
${rendered}
</div>
</div>
`;
} else {
return `<span class="math-inline">${rendered}</span>`;
}
} catch {
const display = displayMode ? `$$${content}$$` : `$${content}$`;
return `<span class="math-error"><span class="math-error-icon">⚠</span> ${display}</span>`;
}
});
}
// Restore HTML placeholders (for \textbf, \emph, etc.)
for (const [placeholder, htmlContent] of htmlSnippets) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
html = html.replace(regex, htmlContent);
}
// Render inline math ($...$) but avoid matching currency like $5
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
// Skip if it looks like currency ($ followed by number)
if (/^\d/.test(math.trim())) {
return match;
}
try {
return katex.renderToString(math.trim(), {
displayMode: false,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$${math}$</span>`;
}
});
return html;
}
@@ -377,50 +154,16 @@
}
}
async function handleMathCopyClick(event: Event) {
const target = event.currentTarget as HTMLButtonElement;
const encodedSource = target.getAttribute('data-math-source');
if (!encodedSource) return;
const source = decodeURIComponent(encodedSource);
try {
await navigator.clipboard.writeText(source);
// Show copied feedback
const originalHtml = target.innerHTML;
target.innerHTML = `
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M20 6L9 17l-5-5"/>
</svg>
`;
target.classList.add('copied');
setTimeout(() => {
target.innerHTML = originalHtml;
target.classList.remove('copied');
}, 2000);
} catch (error) {
console.error('Failed to copy math:', error);
}
}
function setupCopyButtons() {
if (!containerRef || !browser) return;
const codeButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of codeButtons) {
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of buttons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleCopyClick);
}
}
const mathButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-math-btn');
for (const button of mathButtons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleMathCopyClick);
}
}
}
$effect(() => {
@@ -681,290 +424,28 @@
color: #60a5fa;
}
/* KaTeX math styling - Base */
/* KaTeX math styling */
.markdown-content :global(.katex) {
font-size: 1.1em;
color: oklch(0.9 0 0);
}
/* Display math container wrapper */
.markdown-content :global(.math-display-wrapper) {
.markdown-content :global(.katex-display) {
margin: 1rem 0;
border-radius: 0.5rem;
overflow: hidden;
border: 1px solid rgba(255, 215, 0, 0.15);
background: rgba(0, 0, 0, 0.3);
transition: border-color 0.2s ease, box-shadow 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover) {
border-color: rgba(255, 215, 0, 0.25);
box-shadow: 0 0 12px rgba(255, 215, 0, 0.08);
}
/* Display math header - hidden by default, slides in on hover */
.markdown-content :global(.math-display-header) {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.375rem 0.75rem;
background: rgba(255, 215, 0, 0.03);
border-bottom: 1px solid rgba(255, 215, 0, 0.08);
opacity: 0;
max-height: 0;
padding-top: 0;
padding-bottom: 0;
overflow: hidden;
transition:
opacity 0.2s ease,
max-height 0.2s ease,
padding 0.2s ease;
}
.markdown-content :global(.math-display-wrapper:hover .math-display-header) {
opacity: 1;
max-height: 2.5rem;
padding: 0.375rem 0.75rem;
}
.markdown-content :global(.math-label) {
color: rgba(255, 215, 0, 0.7);
font-size: 0.65rem;
font-weight: 500;
text-transform: uppercase;
letter-spacing: 0.1em;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
}
.markdown-content :global(.copy-math-btn) {
display: flex;
align-items: center;
justify-content: center;
padding: 0.25rem;
background: transparent;
border: none;
color: var(--exo-light-gray, #9ca3af);
cursor: pointer;
transition: color 0.2s;
border-radius: 0.25rem;
opacity: 0;
transition:
color 0.2s,
opacity 0.15s ease;
}
.markdown-content :global(.math-display-wrapper:hover .copy-math-btn) {
opacity: 1;
}
.markdown-content :global(.copy-math-btn:hover) {
color: var(--exo-yellow, #ffd700);
}
.markdown-content :global(.copy-math-btn.copied) {
color: #22c55e;
}
/* Display math content area */
.markdown-content :global(.math-display-content) {
padding: 1rem 1.25rem;
overflow-x: auto;
overflow-y: hidden;
padding: 0.5rem 0;
}
/* Custom scrollbar for math overflow */
.markdown-content :global(.math-display-content::-webkit-scrollbar) {
height: 6px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-track) {
background: rgba(255, 255, 255, 0.05);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb) {
background: rgba(255, 215, 0, 0.2);
border-radius: 3px;
}
.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb:hover) {
background: rgba(255, 215, 0, 0.35);
}
.markdown-content :global(.math-display-content .katex-display) {
margin: 0;
padding: 0;
}
.markdown-content :global(.math-display-content .katex-display > .katex) {
.markdown-content :global(.katex-display > .katex) {
text-align: center;
}
/* Inline math wrapper */
.markdown-content :global(.math-inline) {
display: inline;
padding: 0 0.125rem;
border-radius: 0.25rem;
transition: background-color 0.15s ease;
}
.markdown-content :global(.math-inline:hover) {
background: rgba(255, 215, 0, 0.05);
}
/* Dark theme KaTeX overrides */
.markdown-content :global(.katex .mord),
.markdown-content :global(.katex .minner),
.markdown-content :global(.katex .mop),
.markdown-content :global(.katex .mbin),
.markdown-content :global(.katex .mrel),
.markdown-content :global(.katex .mpunct) {
color: oklch(0.9 0 0);
}
/* Fraction lines and rules */
.markdown-content :global(.katex .frac-line),
.markdown-content :global(.katex .overline-line),
.markdown-content :global(.katex .underline-line),
.markdown-content :global(.katex .hline),
.markdown-content :global(.katex .rule) {
border-color: oklch(0.85 0 0) !important;
background: oklch(0.85 0 0);
}
/* Square roots and SVG elements */
.markdown-content :global(.katex .sqrt-line) {
border-color: oklch(0.85 0 0) !important;
}
.markdown-content :global(.katex svg) {
fill: oklch(0.85 0 0);
stroke: oklch(0.85 0 0);
}
.markdown-content :global(.katex svg path) {
stroke: oklch(0.85 0 0);
}
/* Delimiters (parentheses, brackets, braces) */
.markdown-content :global(.katex .delimsizing),
.markdown-content :global(.katex .delim-size1),
.markdown-content :global(.katex .delim-size2),
.markdown-content :global(.katex .delim-size3),
.markdown-content :global(.katex .delim-size4),
.markdown-content :global(.katex .mopen),
.markdown-content :global(.katex .mclose) {
color: oklch(0.75 0 0);
}
/* Math error styling */
.markdown-content :global(.math-error) {
display: inline-flex;
align-items: center;
gap: 0.375rem;
color: #f87171;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
background: rgba(248, 113, 113, 0.1);
padding: 0.25rem 0.5rem;
padding: 0.125rem 0.25rem;
border-radius: 0.25rem;
border: 1px solid rgba(248, 113, 113, 0.2);
}
.markdown-content :global(.math-error-icon) {
font-size: 0.875em;
opacity: 0.9;
}
/* LaTeX proof environment */
.markdown-content :global(.latex-proof) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 255, 255, 0.02);
border-left: 3px solid rgba(255, 215, 0, 0.4);
border-radius: 0 0.375rem 0.375rem 0;
}
.markdown-content :global(.latex-proof-header) {
font-weight: 600;
font-style: italic;
color: oklch(0.85 0 0);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-proof-header::after) {
content: '.';
}
.markdown-content :global(.latex-proof-content) {
color: oklch(0.9 0 0);
}
.markdown-content :global(.latex-proof-content p:last-child) {
margin-bottom: 0;
}
/* QED symbol at end of proof */
.markdown-content :global(.latex-proof-content::after) {
content: '∎';
display: block;
text-align: right;
color: oklch(0.7 0 0);
margin-top: 0.5rem;
}
/* LaTeX theorem-like environments */
.markdown-content :global(.latex-theorem) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 215, 0, 0.03);
border: 1px solid rgba(255, 215, 0, 0.15);
border-radius: 0.375rem;
}
.markdown-content :global(.latex-theorem-header) {
font-weight: 700;
color: var(--exo-yellow, #ffd700);
margin-bottom: 0.5rem;
}
.markdown-content :global(.latex-theorem-header::after) {
content: '.';
}
.markdown-content :global(.latex-theorem-content) {
color: oklch(0.9 0 0);
font-style: italic;
}
.markdown-content :global(.latex-theorem-content p:last-child) {
margin-bottom: 0;
}
/* LaTeX diagram/figure placeholder */
.markdown-content :global(.latex-diagram-placeholder) {
display: flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
margin: 1rem 0;
padding: 1.5rem 2rem;
background: rgba(255, 255, 255, 0.02);
border: 1px dashed rgba(255, 215, 0, 0.25);
border-radius: 0.5rem;
color: rgba(255, 215, 0, 0.6);
font-size: 0.875rem;
}
.markdown-content :global(.latex-diagram-icon) {
font-size: 1.25rem;
opacity: 0.8;
}
.markdown-content :global(.latex-diagram-text) {
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.75rem;
text-transform: uppercase;
letter-spacing: 0.05em;
}
</style>

View File

@@ -197,7 +197,7 @@ function toggleNodeDetails(nodeId: string): void {
// Uses API preview data when available, falls back to local estimation
const placementPreview = $derived(() => {
const nodeArray = nodeList();
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, topoWidth: 260, topoHeight: 90, error: null };
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, error: null };
const numNodes = nodeArray.length;
const iconSize = numNodes === 1 ? 50 : 36;

View File

@@ -1,7 +1,7 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$lib/stores/app.svelte';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
interface Props {
class?: string;
@@ -24,14 +24,14 @@ function getNodeLabel(nodeId: string): string {
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
if (!ip) return { label: '?', missing: true };
// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
// Helper to check a node's interfaces
function checkNode(node: NodeInfo | undefined): string | null {
function checkNode(node: typeof data.nodes[string]): string | null {
if (!node) return null;
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
@@ -39,19 +39,17 @@ function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missin
return matchFromInterfaces.name;
}
if (node.ip_to_interface) {
const mapped = node.ip_to_interface[cleanIp] || (ip ? node.ip_to_interface[ip] : undefined);
if (mapped && mapped.trim().length > 0) {
return mapped;
}
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return mapped;
}
return null;
}
// Try specified node first
const result = checkNode(data?.nodes?.[nodeId]);
if (result) return { label: result, missing: false };
// Fallback: search all nodes for this IP
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
const otherResult = checkNode(otherNode);
@@ -257,24 +255,21 @@ function wrapLine(text: string, maxLen: number): string[] {
const arrowsGroup = svg.append('g').attr('class', 'arrows-group');
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');
type ConnectionInfo = { from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean };
type PairEntry = { a: string; b: string; aToB: boolean; bToA: boolean; connections: ConnectionInfo[] };
type DebugEdgeLabelEntry = { connections: ConnectionInfo[]; isLeft: boolean; isTop: boolean; mx: number; my: number };
const pairMap = new Map<string, PairEntry>();
const debugEdgeLabels: DebugEdgeLabelEntry[] = [];
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
edges.forEach(edge => {
if (!edge.source || !edge.target || edge.source === edge.target) return;
if (!positionById[edge.source] || !positionById[edge.target]) return;
const a = edge.source < edge.target ? edge.source : edge.target;
const b = edge.source < edge.target ? edge.target : edge.source;
const key = `${a}|${b}`;
const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false, connections: [] };
if (edge.source === a) entry.aToB = true;
else entry.bToA = true;
const ip = edge.sendBackIp || '?';
const ip = edge.sendBackIp || edge.sendBackMultiaddr?.ip_address || '?';
const ifaceInfo = getInterfaceLabel(edge.source, ip);
entry.connections.push({
from: edge.source,
@@ -343,8 +338,9 @@ function wrapLine(text: string, maxLen: number): string[] {
// Determine which side of viewport based on edge midpoint
const isLeft = mx < centerX;
const isTop = my < safeCenterY;
// Store for batch rendering after all edges processed
if (!debugEdgeLabels) debugEdgeLabels = [];
debugEdgeLabels.push({
connections: entry.connections,
isLeft,
@@ -385,32 +381,32 @@ function wrapLine(text: string, maxLen: number): string[] {
}
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
const quadrants: Record<string, DebugEdgeLabelEntry[]> = {
const quadrants: Record<string, typeof debugEdgeLabels> = {
topLeft: [],
topRight: [],
bottomLeft: [],
bottomRight: []
};
debugEdgeLabels.forEach(edge => {
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
quadrants[key].push(edge);
});
// Render each quadrant
Object.entries(quadrants).forEach(([quadrant, quadrantEdges]) => {
if (quadrantEdges.length === 0) return;
Object.entries(quadrants).forEach(([quadrant, edges]) => {
if (edges.length === 0) return;
const isLeft = quadrant.includes('Left');
const isTop = quadrant.includes('top');
let baseX = isLeft ? padding : width - padding;
let baseY = isTop ? padding : height - padding;
const textAnchor = isLeft ? 'start' : 'end';
let currentY = baseY;
quadrantEdges.forEach(edge => {
edges.forEach(edge => {
edge.connections.forEach(conn => {
const arrow = getArrow(conn.from, conn.to);
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;

View File

File diff suppressed because it is too large.

View File

@@ -47,7 +47,30 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number, tasks?: string[], hugging_face_id?: string}>>([]);
// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
const tasks: Record<string, string[]> = {};
for (const model of models) {
if (model.tasks && model.tasks.length > 0) {
// Map by short ID
tasks[model.id] = model.tasks;
// Also map by hugging_face_id from the API response
if (model.hugging_face_id) {
tasks[model.hugging_face_id] = model.tasks;
}
}
}
return tasks;
});
// Helper to check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const model = models.find(m => m.id === modelId || m.hugging_face_id === modelId);
if (!model?.tasks) return false;
return model.tasks.includes('TextToImage') || model.tasks.includes('ImageToImage');
}
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -400,8 +423,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Auto-select the launched model only if no model is currently selected
if (!selectedChatModel()) {
setSelectedChatModel(modelId);
}
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -434,8 +459,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
// Model meta is nested: shard.model_card.model_id
const modelMeta = shardData.model_card ?? shardData.modelCard;
// Model meta is nested: shard.model_meta.model_id
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -761,10 +786,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
// Get the model ID of the instance being deleted before we delete it
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
const wasSelected = selectedChatModel() === deletedInstanceModelId;
try {
const response = await fetch(`/instance/${instanceId}`, {
method: 'DELETE',
@@ -773,24 +794,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!response.ok) {
console.error('Failed to delete instance:', response.status);
} else if (wasSelected) {
// If we deleted the currently selected model, switch to another available model
// Find another instance that isn't the one we just deleted
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
if (remainingInstances.length > 0) {
// Select the last instance (most recently added, since objects preserve insertion order)
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel('');
}
} else {
// No more instances, clear the selection
setSelectedChatModel('');
}
}
} catch (error) {
console.error('Error deleting instance:', error);
@@ -915,7 +918,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const runnerEntries = Object.entries(runnerToShard).map(([runnerId, shardWrapped]) => {
const [tag, shard] = getTagged(shardWrapped);
const meta = (shard as { modelMeta?: { worldSize?: number; nLayers?: number; deviceRank?: number } } | undefined);
const deviceRank = meta?.modelMeta?.deviceRank ?? 0;
const deviceRank = (meta?.deviceRank as number | undefined) ?? 0;
return { runnerId, tag, deviceRank };
});
@@ -1270,6 +1273,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
placeholder="Ask anything"
showHelperText={false}
showModelSelector={true}
modelTasks={modelTasks()}
/>
</div>
</div>
@@ -1491,8 +1495,18 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{@const foundModel = models.find(m => m.id === selectedModelId)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="text-exo-light-gray truncate">{foundModel.name || foundModel.id}</span>
{@const isImageModel = modelSupportsImageGeneration(foundModel.id)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="flex items-center gap-2 text-exo-light-gray truncate">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{foundModel.name || foundModel.id}</span>
</span>
<span class="text-white/50 text-xs flex-shrink-0">{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB</span>
</span>
{:else}
@@ -1537,6 +1551,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(model.id)}
<button
type="button"
onclick={() => {
@@ -1556,7 +1571,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
: 'text-white/30 cursor-default'
}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span class="flex-shrink-0 text-xs {modelCanFit ? 'text-white/50' : 'text-red-400/60'}">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
@@ -1753,7 +1777,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-7xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} />
<ChatForm placeholder="Ask anything" showModelSelector={true} modelTasks={modelTasks()} />
</div>
</div>
</div>

View File

@@ -98,7 +98,7 @@
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
const modelMeta = shardData.model_card ?? shardData.modelCard;
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -190,7 +190,7 @@
const shardKeys = Object.keys(shardObj);
if (shardKeys.length !== 1) return null;
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
const modelMeta = shardData?.model_card ?? shardData?.modelCard;
const modelMeta = shardData?.model_meta ?? shardData?.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.prettyName as string) ?? null;

View File

@@ -1,6 +1,6 @@
# EXO API Technical Reference
This document describes the REST API exposed by the **EXO ** service, as implemented in:
This document describes the REST API exposed by the **EXO** service, as implemented in:
`src/exo/master/api.py`
@@ -183,7 +183,70 @@ Same schema as `/v1/chat/completions`.
**Response:**
Chat completion plus benchmarking metrics.
## 5. Complete Endpoint Summary
## 5. Image Generation & Editing
### Image Generation
**POST** `/v1/images/generations`
Executes an image generation request using an OpenAI-compatible schema, extended with an optional `advanced_params` object.
**Request body (example):**
```json
{
  "prompt": "a robot playing chess",
  "model": "flux-dev",
  "stream": false
}
```
**Advanced Parameters (`advanced_params`):**
| Parameter | Type | Constraints | Description |
|-----------|------|-------------|-------------|
| `seed` | int | >= 0 | Random seed for reproducible generation |
| `num_inference_steps` | int | 1-100 | Number of denoising steps |
| `guidance` | float | 1.0-20.0 | Classifier-free guidance scale |
| `negative_prompt` | string | - | Text describing what to avoid in the image |
**Response:**
OpenAI-compatible image generation response.
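A minimal, non-authoritative usage sketch (assuming the API listens on the default port `52415`, a `flux-dev` instance is already running, and `advanced_params` is sent as a nested JSON object with the fields from the table above):
```python
# Hypothetical client sketch for POST /v1/images/generations.
# Assumptions: default api_port 52415, a running "flux-dev" instance, and
# advanced_params passed as a nested object (field names from the table above).
import httpx

payload = {
    "prompt": "a robot playing chess",
    "model": "flux-dev",
    "stream": False,
    "advanced_params": {
        "seed": 42,                   # >= 0, reproducible generation
        "num_inference_steps": 28,    # 1-100 denoising steps
        "guidance": 4.0,              # 1.0-20.0 classifier-free guidance scale
        "negative_prompt": "blurry, low quality",
    },
}

resp = httpx.post("http://localhost:52415/v1/images/generations", json=payload, timeout=None)
resp.raise_for_status()
# The response is assumed to be OpenAI-compatible, i.e. base64 image data
# under data[0]["b64_json"] when no response_format is specified.
print(resp.json()["data"][0]["b64_json"][:64])
```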
### Benchmarked Image Generation
**POST** `/bench/images/generations`
Same as `/v1/images/generations`, but also returns generation statistics.
**Request body:**
Same schema as `/v1/images/generations`.
**Response:**
Image generation plus benchmarking metrics.
### Image Editing
**POST** `/v1/images/edits`
Executes an image editing request using an OpenAI-compatible schema, extended with the same optional `advanced_params` object as `/v1/images/generations`.
**Response:**
Same format as `/v1/images/generations`.
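A hedged sketch of an edit request; the multipart field names (`image`, `prompt`, `model`) and the model id are assumptions borrowed from the OpenAI Images API convention, not confirmed by this document:
```python
# Hypothetical sketch for POST /v1/images/edits (multipart/form-data).
# Field names and the "qwen-image-edit" model id are assumptions; adjust to
# the actual handler signature and whichever model instance is running.
import httpx

with open("input.png", "rb") as f:
    resp = httpx.post(
        "http://localhost:52415/v1/images/edits",
        files={"image": ("input.png", f, "image/png")},
        data={"prompt": "make the robot gold-plated", "model": "qwen-image-edit"},
        timeout=None,
    )
resp.raise_for_status()
print(resp.json()["data"][0]["b64_json"][:64])
```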
### Benchmarked Image Editing
**POST** `/bench/images/edits`
Same as `/v1/images/edits`, but also returns generation statistics.
**Request body:**
Same schema as `/v1/images/edits`.
**Response:**
Same format as `/bench/images/generations`, including `generation_stats`.
## 6. Complete Endpoint Summary
```
GET /node_id
@@ -203,10 +266,16 @@ GET /v1/models
POST /v1/chat/completions
POST /bench/chat/completions
POST /v1/images/generations
POST /bench/images/generations
POST /v1/images/edits
POST /bench/images/edits
```
## 6. Notes
## 7. Notes
* The `/v1/chat/completions` endpoint is compatible with the OpenAI API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
* The `/v1/chat/completions` endpoint is compatible with the OpenAI Chat API format, so existing OpenAI clients can be pointed to EXO by changing the base URL (see the sketch after these notes).
* The `/v1/images/generations` and `/v1/images/edits` endpoints are compatible with the OpenAI Images API format.
* The instance placement endpoints allow you to plan and preview cluster allocations before actually creating instances.
* The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.
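A minimal sketch of pointing an existing OpenAI client at EXO by overriding only the base URL; the port, the placeholder API key, and the model ids are assumptions:
```python
# Hypothetical sketch: reuse the official openai client against EXO.
# Assumes the default api_port 52415; the api_key is a placeholder.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:52415/v1", api_key="not-needed")

# Chat completions (assumed model id; use any model with a running instance).
chat = client.chat.completions.create(
    model="gpt-oss-20b",
    messages=[{"role": "user", "content": "Hello from EXO"}],
)
print(chat.choices[0].message.content)

# Image generation via the OpenAI Images API surface.
image = client.images.generate(model="flux-dev", prompt="a robot playing chess")
first = image.data[0]
print(first.b64_json[:64] if first.b64_json else first.url)
```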

View File

Binary file not shown.


View File

@@ -1,5 +1,3 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt

View File

@@ -23,7 +23,9 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux>=0.14.2",
"python-multipart>=0.0.21",
]
[project.scripts]
@@ -126,6 +128,3 @@ env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

View File

@@ -81,6 +81,20 @@
config = {
packages = {
# The system_custodian binary
system_custodian = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "-p system_custodian";
meta = {
description = "System custodian daemon for exo";
mainProgram = "system_custodian";
};
}
);
# Python bindings wheel via maturin
exo_pyo3_bindings = craneLib.buildPackage (
commonArgs

View File

@@ -0,0 +1,47 @@
[package]
name = "system_custodian"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "system_custodian"
path = "src/lib.rs"
[[bin]]
path = "src/bin/main.rs"
name = "system_custodian"
doc = false
[lints]
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }

View File

@@ -0,0 +1,4 @@
//! TODO: documentation
//!
fn main() {}

View File

@@ -0,0 +1,69 @@
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
//!
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
//! launch of the Exo application, and is responsible for ensuring the system (configuration, settings,
//! etc.) is in an appropriate state to facilitate the running of the Exo application.
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
//! service which the Exo application uses to _control & query_ it.
//!
//! # Lifecycle
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
//! destructive to the user's pre-existing configurations.
//!
//! # Responsibilities
//! TODO: these are currently macOS-specific; broaden them to other platforms later
//! The **_System Custodian_** daemon is responsible for using the System Configuration framework to
//! 1. duplicate the current network set
//! 2. modify existing services to turn on IPv6 if not there
//! 3. remove any bridge services & add any missing services that AREN'T bridge
//! TODO: In the future:
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
//! logic, this would be the place to spin that up.
//!
//! Then it will watch the SCDynamicStore for:
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
//! interface of any changes
//! 2. watch for any __undesirable__ changes to configuration and revert it
//!
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
//! 1. all that info should coalesce back to the overall state collected -> should be queryable
//! over D-Bus
//! TODO:
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
//! then this would be the place to carry out discovery and proper handshakes with devices
//! on the other end of the link.
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {}

View File

@@ -205,14 +205,6 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
anyio.run(node.run)
logger.info("EXO Shutdown complete")
@@ -226,7 +218,6 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
def parse(cls) -> Self:
@@ -268,20 +259,6 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",
action="store_true",
dest="fast_synch",
default=None,
help="Force MLX FAST_SYNCH on (for JACCL backend)",
)
fast_synch_group.add_argument(
"--no-fast-synch",
action="store_false",
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -1,39 +1,52 @@
import base64
import json
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import cast
from typing import Literal, cast
import anyio
from anyio import BrokenResourceError, create_task_group
from anyio import create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException, Request
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
BenchImageGenerationResponse,
BenchImageGenerationTaskParams,
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
GenerationStats,
ImageData,
ImageEditsInternalParams,
ImageGenerationResponse,
ImageGenerationStats,
ImageGenerationTaskParams,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -41,24 +54,23 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
SendInputChunk,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
IndexedEvent,
)
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -68,6 +80,8 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
@@ -86,12 +100,23 @@ def chunk_to_response(
)
async def resolve_model_card(model_id: str) -> ModelCard:
def get_model_card(model_id: str) -> ModelCard | None:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
else:
return await get_model_card(model_id)
for _, model_card in MODEL_CARDS.items():
if model_id == model_card.model_id:
return model_card
async def resolve_model_meta(model_id: str) -> ModelMetadata:
model_card = get_model_card(model_id)
if model_card is not None:
return model_card.metadata
return await get_model_meta(model_id)
class API:
@@ -122,7 +147,6 @@ class API:
self.paused_ev: anyio.Event = anyio.Event()
self.app = FastAPI()
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
@@ -136,6 +160,7 @@ class API:
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._tg: TaskGroup | None = None
def reset(self, new_session_id: SessionId, result_clock: int):
@@ -144,6 +169,7 @@ class API:
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._chat_completion_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)
def unpause(self, result_clock: int):
@@ -153,21 +179,6 @@ class API:
self.paused_ev.set()
self.paused_ev = anyio.Event()
def _setup_exception_handlers(self) -> None:
self.app.exception_handler(HTTPException)(self.http_exception_handler)
async def http_exception_handler(
self, _: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
def _setup_cors(self) -> None:
self.app.add_middleware(
CORSMiddleware,
@@ -191,12 +202,18 @@ class API:
self.chat_completions
)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.post("/v1/images/generations", response_model=None)(
self.image_generations
)
self.app.post("/bench/images/generations")(self.bench_image_generations)
self.app.post("/v1/images/edits", response_model=None)(self.image_edits)
self.app.post("/bench/images/edits")(self.bench_image_edits)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_card=await resolve_model_card(payload.model_id),
model_meta=await resolve_model_meta(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -206,15 +223,15 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_card=command.model_card,
model_meta=command.model_meta,
)
async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_card = await resolve_model_card(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
available_memory = self._calculate_total_available_memory()
if required_memory > available_memory:
@@ -231,7 +248,7 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_card=model_card,
model_meta=model_meta,
)
async def get_placement(
@@ -241,17 +258,16 @@ class API:
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_card = await resolve_model_card(model_id)
model_meta = await resolve_model_meta(model_id)
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
model_meta=model_meta,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -278,7 +294,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
cards = [card for card in MODEL_CARDS.values() if card.short_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
@@ -296,32 +312,32 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for model_card in cards:
for card in cards:
model_meta = card.metadata
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
model_meta=model_meta,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add((card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
@@ -332,17 +348,17 @@ class API:
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add((card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
@@ -351,7 +367,7 @@ class API:
memory_delta_by_node: dict[str, int] = {}
if node_ids:
total_bytes = model_card.storage_size.in_bytes
total_bytes = model_meta.storage_size.in_bytes
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
@@ -359,14 +375,14 @@ class API:
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
card.model_id,
sharding,
instance_meta,
len(node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
@@ -374,7 +390,7 @@ class API:
error=None,
)
)
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
seen.add((card.model_id, sharding, instance_meta, len(node_ids)))
return PlacementPreviewResponse(previews=previews)
@@ -397,8 +413,35 @@ class API:
instance_id=instance_id,
)
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
@@ -406,10 +449,16 @@ class API:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -425,23 +474,11 @@ class API:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
@@ -453,7 +490,7 @@ class API:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
@@ -461,13 +498,7 @@ class API:
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
@@ -496,7 +527,7 @@ class API:
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
model: str | None = None
@@ -504,13 +535,7 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
@@ -549,8 +574,10 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
if not any(
instance.shard_assignments.model_id == payload.model
@@ -567,17 +594,18 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id),
self._generate_chat_stream(command.command_id, parse_gpt_oss),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
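Editor's note: the gpt-oss handling is gated purely by a substring check on the resolved model id, so only ids containing "gpt-oss" go through the reasoning-tag post-processing. A minimal sketch of that check (the model ids below are illustrative, not taken from this diff):

    for model_id in ("openai/gpt-oss-20b", "mlx-community/Llama-3.2-1B"):
        parse_gpt_oss = "gpt-oss" in model_id.lower()
        print(model_id, parse_gpt_oss)  # True for the first id, False for the second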
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
model_meta = await resolve_model_meta(payload.model)
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
payload.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -593,15 +621,388 @@ class API:
command = ChatCompletion(request_params=payload)
await self._send(command)
response = await self._collect_chat_completion_with_stats(command.command_id)
response = await self._collect_chat_completion_with_stats(
command.command_id,
parse_gpt_oss,
)
return response
async def _validate_image_model(self, model: str) -> ModelId:
"""Validate model exists and return resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_meta = await resolve_model_meta(model)
resolved_model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(resolved_model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {resolved_model}"
)
return resolved_model
async def image_generations(
self, payload: ImageGenerationTaskParams
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image generation requests.
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(payload.model)
command = ImageGeneration(
request_params=payload,
)
await self._send(command)
# Check if streaming is requested
if payload.stream and payload.partial_images and payload.partial_images > 0:
return StreamingResponse(
self._generate_image_stream(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
),
media_type="text/event-stream",
)
# Non-streaming: collect all image chunks
return await self._collect_image_generation(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
)
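For orientation, a minimal client-side sketch of the streaming path above. The endpoint URL, port and the prompt field are assumptions for illustration; only model, stream, partial_images, n and response_format appear in this diff:

    import json
    import httpx

    payload = {
        "model": "qwen-image",      # illustrative model id
        "prompt": "a red bicycle",  # assumed request field
        "stream": True,
        "partial_images": 3,
    }
    # Hypothetical host/port and route; the diff only shows the handler body.
    with httpx.stream("POST", "http://localhost:8000/v1/images/generations",
                      json=payload, timeout=None) as response:
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue
            body = line[len("data: "):]
            if body == "[DONE]":
                break
            event = json.loads(body)
            print(event["type"], event.get("partial_index"), event.get("image_index"))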
async def _generate_image_stream(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> AsyncGenerator[str, None]:
"""Generate SSE stream of partial and final images."""
# Track chunks: {(image_index, is_partial): {chunk_index: data}}
image_chunks: dict[tuple[int, bool], dict[int, str]] = {}
image_total_chunks: dict[tuple[int, bool], int] = {}
image_metadata: dict[tuple[int, bool], tuple[int | None, int | None]] = {}
images_complete = 0
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
with recv as chunks:
async for chunk in chunks:
key = (chunk.image_index, chunk.is_partial)
if key not in image_chunks:
image_chunks[key] = {}
image_total_chunks[key] = chunk.total_chunks
image_metadata[key] = (
chunk.partial_index,
chunk.total_partials,
)
image_chunks[key][chunk.chunk_index] = chunk.data
# Check if this image is complete
if len(image_chunks[key]) == image_total_chunks[key]:
full_data = "".join(
image_chunks[key][i] for i in range(len(image_chunks[key]))
)
partial_idx, total_partials = image_metadata[key]
if chunk.is_partial:
# Yield partial image event
event_data = {
"type": "partial",
"partial_index": partial_idx,
"total_partials": total_partials,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
else:
# Final image
event_data = {
"type": "final",
"image_index": chunk.image_index,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
images_complete += 1
if images_complete >= num_images:
yield "data: [DONE]\n\n"
break
# Clean up completed image chunks
del image_chunks[key]
del image_total_chunks[key]
del image_metadata[key]
except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]
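In other words, each completed partial flushes one SSE frame of the form data: {"type": "partial", "partial_index": ..., "total_partials": ..., "data": {"b64_json": ...}}, each finished image flushes data: {"type": "final", "image_index": ..., "data": {"b64_json": ...}}, and the stream terminates with data: [DONE] once num_images final images have been emitted.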
async def _collect_image_chunks(
self,
command_id: CommandId,
num_images: int,
response_format: str,
capture_stats: bool = False,
) -> tuple[list[ImageData], ImageGenerationStats | None]:
"""Collect image chunks and optionally capture stats."""
# Track chunks per image: {image_index: {chunk_index: data}}
# Only track non-partial (final) images
image_chunks: dict[int, dict[int, str]] = {}
image_total_chunks: dict[int, int] = {}
images_complete = 0
stats: ImageGenerationStats | None = None
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
while images_complete < num_images:
with recv as chunks:
async for chunk in chunks:
if chunk.is_partial:
continue
if chunk.image_index not in image_chunks:
image_chunks[chunk.image_index] = {}
image_total_chunks[chunk.image_index] = chunk.total_chunks
image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
if capture_stats and chunk.stats is not None:
stats = chunk.stats
if (
len(image_chunks[chunk.image_index])
== image_total_chunks[chunk.image_index]
):
images_complete += 1
if images_complete >= num_images:
break
images: list[ImageData] = []
for image_idx in range(num_images):
chunks_dict = image_chunks[image_idx]
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
images.append(
ImageData(
b64_json=full_data if response_format == "b64_json" else None,
url=None,
)
)
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]
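The reassembly above is order-independent because chunks are keyed by chunk_index and rejoined by index once the expected count has arrived. A self-contained sketch with a made-up payload:

    # chunk_index -> data; arrival order does not matter
    chunks = {2: "cc", 0: "aa", 1: "bb"}
    full_data = "".join(chunks[i] for i in range(len(chunks)))
    assert full_data == "aabbcc"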
async def _collect_image_generation(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> ImageGenerationResponse:
"""Collect all image chunks (non-streaming) and return a single response."""
images, _ = await self._collect_image_chunks(
command_id, num_images, response_format, capture_stats=False
)
return ImageGenerationResponse(data=images)
async def _collect_image_generation_with_stats(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> BenchImageGenerationResponse:
images, stats = await self._collect_image_chunks(
command_id, num_images, response_format, capture_stats=True
)
return BenchImageGenerationResponse(data=images, generation_stats=stats)
async def bench_image_generations(
self, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(payload.model)
payload.stream = False
payload.partial_images = 0
command = ImageGeneration(
request_params=payload,
)
await self._send(command)
return await self._collect_image_generation_with_stats(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
)
async def _send_image_edits_command(
self,
image: UploadFile,
prompt: str,
model: str,
n: int,
size: str,
response_format: Literal["url", "b64_json"],
input_fidelity: Literal["low", "high"],
stream: bool,
partial_images: int,
bench: bool,
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")
image_strength = 0.7 if input_fidelity == "high" else 0.3
data_chunks = [
image_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)
command = ImageEdits(
request_params=ImageEditsInternalParams(
image_data="",
total_input_chunks=total_chunks,
prompt=prompt,
model=resolved_model,
n=n,
size=size,
response_format=response_format,
image_strength=image_strength,
stream=stream,
partial_images=partial_images,
bench=bench,
),
)
logger.info(
f"Sending input image: {len(image_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(data_chunks):
await self._send(
SendInputChunk(
chunk=InputImageChunk(
idx=chunk_index,
model=resolved_model,
command_id=command.command_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
)
)
)
await self._send(command)
return command
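The upload side mirrors the download path: the base64-encoded image is sliced into EXO_MAX_CHUNK_SIZE pieces and shipped as InputImageChunk messages that the receiver rejoins by index. A standalone sketch of the slicing; the chunk size here is illustrative, not the real constant:

    EXO_MAX_CHUNK_SIZE = 4  # illustrative value only
    image_data = "aGVsbG8gd29ybGQ="
    data_chunks = [
        image_data[i : i + EXO_MAX_CHUNK_SIZE]
        for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
    ]
    assert "".join(data_chunks) == image_data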
async def image_edits(
self,
image: UploadFile = File(...),
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
stream: bool = Form(False),
partial_images: int = Form(0),
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image editing requests (img2img)."""
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
n=n,
size=size,
response_format=response_format,
input_fidelity=input_fidelity,
stream=stream,
partial_images=partial_images,
bench=False,
)
if stream and partial_images and partial_images > 0:
return StreamingResponse(
self._generate_image_stream(
command_id=command.command_id,
num_images=n,
response_format=response_format,
),
media_type="text/event-stream",
)
return await self._collect_image_generation(
command_id=command.command_id,
num_images=n,
response_format=response_format,
)
async def bench_image_edits(
self,
image: UploadFile = File(...),
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
) -> BenchImageGenerationResponse:
"""Handle benchmark image editing requests with generation stats."""
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
n=n,
size=size,
response_format=response_format,
input_fidelity=input_fidelity,
stream=False,
partial_images=0,
bench=True,
)
return await self._collect_image_generation_with_stats(
command_id=command.command_id,
num_images=n,
response_format=response_format,
)
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for profile in self.state.node_profiles.values():
total_available += profile.memory.ram_available
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
return total_available
@@ -610,13 +1011,14 @@ class API:
return ModelList(
data=[
ModelListModel(
id=card.model_id,
id=card.short_id,
hugging_face_id=card.model_id,
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
tasks=[task.value for task in card.tasks],
)
for card in MODEL_CARDS.values()
]
@@ -655,13 +1057,16 @@ class API:
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
assert isinstance(event.chunk, TokenChunk)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
if event.command_id in self._chat_completion_queues:
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
elif event.command_id in self._image_generation_queues:
assert isinstance(event.chunk, ImageChunk)
await self._image_generation_queues[event.command_id].send(
event.chunk
)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -16,8 +16,11 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskFinished,
TestCommand,
)
@@ -26,8 +29,8 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
TaskDeleted,
@@ -36,6 +39,12 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion as ChatCompletionTask,
)
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
)
from exo.shared.types.tasks import (
ImageGeneration as ImageGenerationTask,
)
from exo.shared.types.tasks import (
TaskId,
TaskStatus,
@@ -100,13 +109,14 @@ class Master:
async for forwarder_command in commands:
try:
logger.info(f"Executing command: {forwarder_command.command}")
generated_events: list[Event] = []
command = forwarder_command.command
instance_task_counts: dict[InstanceId, int] = {}
match command:
case TestCommand():
pass
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
@@ -147,6 +157,90 @@ class Master:
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageGeneration():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageEdits():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
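Both image cases above route the new task to the instance currently serving the fewest tasks. A compact sketch of that selection with illustrative instance ids:

    instance_task_counts = {"inst-a": 3, "inst-b": 1, "inst-c": 2}  # illustrative
    available_instance_ids = sorted(
        instance_task_counts.keys(),
        key=lambda instance_id: instance_task_counts[instance_id],
    )
    chosen = available_instance_ids[0]  # "inst-b", the least-loaded instance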
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
@@ -159,7 +253,6 @@ class Master:
command,
self.state.topology,
self.state.instances,
self.state.node_profiles,
)
transition_events = get_transition_events(
self.state.instances, placement
@@ -175,6 +268,13 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
generated_events.append(
InputChunkReceived(
command_id=chunk.command_id,
chunk=chunk,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -202,7 +302,9 @@ class Master:
async def _plan(self) -> None:
while True:
# kill broken instances
connected_node_ids = set(self.state.topology.list_nodes())
connected_node_ids = set(
[x.node_id for x in self.state.topology.list_nodes()]
)
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
@@ -237,8 +339,6 @@ class Master:
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
if isinstance(event, NodeGatheredInfo):
event.when = str(datetime.now(tz=timezone.utc))
self._event_log.append(event)
await self._send_event(indexed)

View File

@@ -6,25 +6,23 @@ from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
Cycle,
filter_cycles_by_memory,
get_mlx_ibv_devices_matrix,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.models import ModelId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -54,32 +52,37 @@ def place_instance(
command: PlaceInstance,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
logger.info("finding cycles:")
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_profiles, command.model_card.storage_size
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
)
if len(cycles_with_sufficient_memory) == 0:
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
)
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_card.supports_tensor:
if not command.model_meta.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works well enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_card.hidden_size % len(cycle) == 0
if command.model_meta.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_card.hidden_size} candidate cycles"
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
@@ -89,38 +92,44 @@ def place_instance(
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
smallest_tb_cycles = [
cycle for cycle in smallest_cycles if topology.is_thunderbolt_cycle(cycle)
cycle
for cycle in smallest_cycles
if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
]
if smallest_tb_cycles != []:
smallest_cycles = smallest_tb_cycles
cycles_with_leaf_nodes: list[Cycle] = [
cycles_with_leaf_nodes: list[list[NodeInfo]] = [
cycle
for cycle in smallest_cycles
if any(topology.node_is_leaf(node_id) for node_id in cycle)
if any(topology.node_is_leaf(node.node_id) for node in cycle)
]
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(node_profiles[node_id].memory.ram_available for node_id in cycle),
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
),
start=Memory(),
),
)
shard_assignments = get_shard_assignments(
command.model_card, selected_cycle, command.sharding, node_profiles
command.model_meta, selected_cycle, command.sharding
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
"You have likely selected ibv for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
@@ -128,20 +137,19 @@ def place_instance(
# TODO: Single node instances
match command.instance_meta:
case InstanceMeta.MlxJaccl:
mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
[node_id for node_id in selected_cycle],
mlx_ibv_devices = get_mlx_ibv_devices_matrix(
selected_cycle,
cycle_digraph,
)
mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
coordinator=selected_cycle.node_ids[0],
selected_cycle,
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
node_profiles=node_profiles,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
ibv_devices=mlx_ibv_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
)
case InstanceMeta.MlxRing:
@@ -150,7 +158,6 @@ def place_instance(
selected_cycle=selected_cycle,
cycle_digraph=cycle_digraph,
ephemeral_port=ephemeral_port,
node_profiles=node_profiles,
)
target_instances[instance_id] = MlxRingInstance(
instance_id=instance_id,

View File

@@ -1,13 +1,15 @@
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import TypeGuard, cast
from loguru import logger
from pydantic import BaseModel
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
@@ -17,113 +19,67 @@ from exo.shared.types.worker.shards import (
)
class NodeWithProfile(BaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
def filter_cycles_by_memory(
cycles: list[Cycle],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
required_memory: Memory,
) -> list[Cycle]:
filtered_cycles: list[Cycle] = []
cycles: list[list[NodeInfo]], required_memory: Memory
) -> list[list[NodeInfo]]:
filtered_cycles: list[list[NodeInfo]] = []
for cycle in cycles:
if not all(node in node_profiles for node in cycle):
if not narrow_all_nodes(cycle):
continue
total_mem = sum(
(node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
start=Memory(),
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
)
if total_mem >= required_memory:
filtered_cycles.append(cycle)
filtered_cycles.append(cast(list[NodeInfo], cycle))
return filtered_cycles
def get_smallest_cycles(
cycles: list[Cycle],
) -> list[Cycle]:
def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def allocate_layers_proportionally(
total_layers: int,
memory_fractions: list[float],
) -> list[int]:
n = len(memory_fractions)
if n == 0:
raise ValueError("Cannot allocate layers to an empty node list")
if total_layers < n:
raise ValueError(
f"Cannot distribute {total_layers} layers across {n} nodes "
"(need at least 1 layer per node)"
)
# Largest remainder: floor each, then distribute remainder by fractional part
raw = [f * total_layers for f in memory_fractions]
result = [int(r) for r in raw]
by_remainder = sorted(range(n), key=lambda i: raw[i] - result[i], reverse=True)
for i in range(total_layers - sum(result)):
result[by_remainder[i]] += 1
# Ensure minimum 1 per node by taking from the largest
for i in range(n):
if result[i] == 0:
max_idx = max(range(n), key=lambda j: result[j])
assert result[max_idx] > 1
result[max_idx] -= 1
result[i] = 1
return result
def get_shard_assignments_for_pipeline_parallel(
model_card: ModelCard,
cycle: Cycle,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
if not cycle.node_ids:
raise ValueError("Cannot create shard assignments for empty node cycle")
cycle_memory = sum(
(node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_card.n_layers
world_size = len(cycle)
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node_profiles[node_id].memory.ram_available.in_bytes / cycle_memory.in_bytes
for node_id in cycle.node_ids
],
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node_profiles[node_id].memory.ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node_id}) has insufficient memory: "
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
)
layers_assigned = 0
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
for i, node in enumerate(selected_cycle):
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
else:
node_layers = round(
total_layers
* (
node.node_profile.memory.ram_available.in_bytes
/ cycle_memory.in_bytes
)
)
node_layers = max(1, node_layers)
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=i,
world_size=world_size,
start_layer=layers_assigned,
@@ -132,11 +88,11 @@ def get_shard_assignments_for_pipeline_parallel(
)
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
node_to_runner[node.node_id] = runner_id
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
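A worked example of the proportional split above, matching the parametrized placement test later in this diff: with ram_available of 500, 500 and 1000 (total 2000) and 12 layers, ranks 0 and 1 each get round(12 * 500 / 2000) = 3 layers and the last rank takes the remainder, 12 - 6 = 6, i.e. the expected (3, 3, 6) split.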
@@ -145,17 +101,17 @@ def get_shard_assignments_for_pipeline_parallel(
def get_shard_assignments_for_tensor_parallel(
model_card: ModelCard,
cycle: Cycle,
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
total_layers = model_card.n_layers
world_size = len(cycle)
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for i, node_id in enumerate(cycle):
for i, node in enumerate(selected_cycle):
shard = TensorShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=i,
world_size=world_size,
start_layer=0,
@@ -166,10 +122,10 @@ def get_shard_assignments_for_tensor_parallel(
runner_id = RunnerId()
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
node_to_runner[node.node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -178,22 +134,22 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_card: ModelCard,
cycle: Cycle,
model_meta: ModelMetadata,
selected_cycle: list[NodeInfo],
sharding: Sharding,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> ShardAssignments:
if not narrow_all_nodes(selected_cycle):
raise ValueError("All nodes must have profiles to create shard assignments")
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
model_card=model_card,
cycle=cycle,
node_profiles=node_profiles,
model_meta=model_meta,
selected_cycle=selected_cycle,
)
case Sharding.Tensor:
return get_shard_assignments_for_tensor_parallel(
model_card=model_card,
cycle=cycle,
model_meta=model_meta,
selected_cycle=selected_cycle,
)
@@ -208,40 +164,38 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
)
return []
cycle = cycles[0]
get_thunderbolt = False
if cycle_digraph.is_thunderbolt_cycle(cycle):
if cycle_digraph.is_thunderbolt_cycle(cycles[0]):
get_thunderbolt = True
logger.info(f"Using thunderbolt cycle: {get_thunderbolt}")
cycle = cycles[0]
hosts: list[Host] = []
for i in range(len(cycle)):
current_node = cycle.node_ids[i]
next_node = cycle.node_ids[(i + 1) % len(cycle)]
current_node = cycle[i]
next_node = cycle[(i + 1) % len(cycle)]
for connection in cycle_digraph.get_all_connections_between(
source=current_node, sink=next_node
):
if not isinstance(connection, SocketConnection):
continue
if get_thunderbolt and not connection.is_thunderbolt():
continue
host = Host(
ip=connection.sink_multiaddr.ip_address,
port=connection.sink_multiaddr.port,
)
hosts.append(host)
break
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == current_node.node_id
and connection.send_back_node_id == next_node.node_id
):
if get_thunderbolt and not connection.is_thunderbolt():
continue
assert connection.send_back_multiaddr is not None
host = Host(
ip=connection.send_back_multiaddr.ip_address,
port=connection.send_back_multiaddr.port,
)
hosts.append(host)
break
return hosts
def get_mlx_jaccl_devices_matrix(
selected_cycle: list[NodeId],
def get_mlx_ibv_devices_matrix(
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
) -> list[list[str | None]]:
"""Build connectivity matrix mapping device i to device j via RDMA interface names.
@@ -260,37 +214,72 @@ def get_mlx_jaccl_devices_matrix(
if i == j:
continue
for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(conn, RDMAConnection):
matrix[i][j] = conn.source_rdma_iface
# Find the IP J uses to talk to I
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
# This is a local IP on I, which is attached to an interface: find that interface
if interface_name := _find_rdma_interface_name_for_ip(
connection_ip, node_i
):
matrix[i][j] = interface_name
logger.info(
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
)
break
else:
logger.warning(
f"Failed to find interface name between {node_i} and {node_j}"
f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
)
raise ValueError(
"Current jaccl backend requires all-to-all RDMA connections"
"Current ibv backend requires all-to-all rdma connections"
)
return matrix
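For three fully RDMA-connected nodes the resulting matrix might look like the sketch below, where row i holds the interface names on node i that face each peer j; the interface numbers are illustrative:

    matrix = [
        [None,       "rdma_en3", "rdma_en4"],  # node 0's interfaces towards 1 and 2
        ["rdma_en4", None,       "rdma_en3"],  # node 1's interfaces towards 0 and 2
        ["rdma_en3", "rdma_en4", None],        # node 2's interfaces towards 0 and 1
    ]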
def _find_connection_ip(
node_i: NodeId,
node_j: NodeId,
node_i: NodeInfo,
node_j: NodeInfo,
cycle_digraph: Topology,
) -> Generator[tuple[str, bool]]:
"""Find all IP addresses that connect node i to node j."""
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
if isinstance(connection, SocketConnection):
yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
for connection in cycle_digraph.list_connections():
if (
connection.local_node_id == node_i.node_id
and connection.send_back_node_id == node_j.node_id
):
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
def _find_rdma_interface_name_for_ip(
ip_address: str,
node_info: NodeInfo,
) -> str | None:
if node_info.node_profile is None:
return None
logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
for interface in node_info.node_profile.network_interfaces:
if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
continue
logger.info(f" | {interface.name}: {interface.ip_address}")
if interface.ip_address != ip_address:
continue
logger.info("Found")
return f"rdma_{interface.name}"
return None
def _find_interface_name_for_ip(
ip_address: str, node_profile: NodePerformanceProfile
ip_address: str,
node_info: NodeInfo,
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
for interface in node_profile.network_interfaces:
if node_info.node_profile is None:
return None
for interface in node_info.node_profile.network_interfaces:
if interface.ip_address == ip_address:
return interface.name
@@ -298,10 +287,7 @@ def _find_interface_name_for_ip(
def _find_ip_prioritised(
node_id: NodeId,
other_node_id: NodeId,
cycle_digraph: Topology,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.
@@ -312,12 +298,9 @@ def _find_ip_prioritised(
3. Non-Thunderbolt connections
4. Any other IP address
"""
ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {
_find_interface_name_for_ip(ip, node_profiles[other_node_id]): ip
for ip, _ in ips
}
iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}
en0_ip = iface_map.get("en0")
if en0_ip:
@@ -341,10 +324,9 @@ def _find_ip_prioritised(
def get_mlx_ring_hosts_by_node(
selected_cycle: Cycle,
selected_cycle: list[NodeInfo],
cycle_digraph: Topology,
ephemeral_port: int,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, list[Host]]:
"""Generate per-node host lists for MLX ring backend.
@@ -359,13 +341,14 @@ def get_mlx_ring_hosts_by_node(
hosts_by_node: dict[NodeId, list[Host]] = {}
for rank, node_id in enumerate(selected_cycle):
for rank, node in enumerate(selected_cycle):
node_id = node.node_id
left_rank = (rank - 1) % world_size
right_rank = (rank + 1) % world_size
hosts_for_node: list[Host] = []
for idx, other_node_id in enumerate(selected_cycle):
for idx, other_node in enumerate(selected_cycle):
if idx == rank:
hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
continue
@@ -375,12 +358,10 @@ def get_mlx_ring_hosts_by_node(
hosts_for_node.append(Host(ip="198.51.100.1", port=0))
continue
connection_ip = _find_ip_prioritised(
node_id, other_node_id, cycle_digraph, node_profiles
)
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
if connection_ip is None:
logger.warning(
f"Failed to find prioritised connection IP between {node_id} and {other_node_id}"
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
)
raise ValueError(
"MLX ring backend requires connectivity between neighbouring nodes"
@@ -394,34 +375,31 @@ def get_mlx_ring_hosts_by_node(
def get_mlx_jaccl_coordinators(
coordinator: NodeId,
selected_cycle: list[NodeInfo],
coordinator_port: int,
cycle_digraph: Topology,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
"""Get the coordinator addresses for MLX Jaccl (rank 0 device).
Select an IP address that each node can reach for the rank 0 node. Returns
address in format "X.X.X.X:PORT" per node.
"""
logger.info(f"Selecting coordinator: {coordinator}")
rank_0_node = selected_cycle[0]
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
def get_ip_for_node(n: NodeId) -> str:
if n == coordinator:
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_profiles)
if ip is not None:
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
return ip
logger.warning(
f"Failed to find directly connected ip between {n} and {coordinator}"
)
raise ValueError(
"Current jaccl backend requires all participating devices to be able to communicate"
f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
)
raise ValueError("Current ibv backend requires all-to-all rdma connections")
return {
n: f"{get_ip_for_node(n)}:{coordinator_port}"
for n in cycle_digraph.list_nodes()
n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
}
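The returned mapping therefore pairs every node in the cycle with an address it can use to reach the rank 0 node, with rank 0 itself binding to all interfaces. An illustrative result for a three-node cycle (port and IPs made up; the IPs echo the link-local range used by the test fixtures below):

    {
        node_a_id: "0.0.0.0:50123",      # rank 0 listens on all interfaces
        node_b_id: "169.254.0.2:50123",
        node_c_id: "169.254.0.3:50123",
    }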

View File

@@ -1,39 +1,67 @@
from typing import Callable
import pytest
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
def create_node_profile(memory: int) -> NodePerformanceProfile:
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[
NetworkInterfaceInfo(name="en0", ip_address=f"169.254.0.{i}")
for i in range(10)
],
system=SystemPerformanceProfile(),
)
@pytest.fixture
def create_node():
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
if node_id is None:
node_id = NodeId()
return NodeInfo(
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
return _create_node
def create_socket_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
return SocketConnection(
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
)
# TODO: this is a hack to get the port for the send_back_multiaddr
@pytest.fixture
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
port_counter = 1235
ip_counter = 1
def _create_connection(
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
) -> Connection:
nonlocal port_counter
nonlocal ip_counter
# assign unique ips
ip_counter += 1
if send_back_port is None:
send_back_port = port_counter
port_counter += 1
return Connection(
local_node_id=source_node_id,
send_back_node_id=sink_node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
def create_rdma_connection(iface: int) -> RDMAConnection:
return RDMAConnection(
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
)
return _create_connection

View File

@@ -1,107 +0,0 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.worker.tests.constants import MODEL_A_ID
def test_http_exception_handler_formats_openai_style() -> None:
"""Test that HTTPException is converted to OpenAI-style error format."""
from exo.master.api import API
app = FastAPI()
# Setup exception handler
api = object.__new__(API)
api.app = app
api._setup_exception_handlers() # pyright: ignore[reportPrivateUsage]
# Add test routes that raise HTTPException
@app.get("/test-error")
async def _test_error() -> None:
raise HTTPException(status_code=500, detail="Test error message")
@app.get("/test-not-found")
async def _test_not_found() -> None:
raise HTTPException(status_code=404, detail="Resource not found")
client = TestClient(app)
# Test 500 error
response = client.get("/test-error")
assert response.status_code == 500
data: dict[str, Any] = response.json()
assert "error" in data
assert data["error"]["message"] == "Test error message"
assert data["error"]["type"] == "Internal Server Error"
assert data["error"]["code"] == 500
# Test 404 error
response = client.get("/test-not-found")
assert response.status_code == 404
data = response.json()
assert "error" in data
assert data["error"]["message"] == "Resource not found"
assert data["error"]["type"] == "Not Found"
assert data["error"]["code"] == 404
def test_finish_reason_includes_error() -> None:
valid_reasons = get_args(FinishReason)
assert "error" in valid_reasons
def test_token_chunk_with_error_fields() -> None:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="",
token_id=0,
finish_reason="error",
error_message="Something went wrong",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Something went wrong"
def test_token_chunk_without_error() -> None:
chunk = TokenChunk(
idx=1,
model=MODEL_A_ID,
text="Hello",
token_id=42,
finish_reason=None,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
def test_error_response_construction() -> None:
error_response = ErrorResponse(
error=ErrorInfo(
message="Generation failed",
type="InternalServerError",
code=500,
)
)
assert error_response.error.message == "Generation failed"
assert error_response.error.code == 500
def test_normal_finish_reasons_still_work() -> None:
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="done",
token_id=100,
finish_reason=reason, # type: ignore[arg-type]
)
assert chunk.finish_reason == reason

View File

@@ -7,7 +7,6 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
@@ -20,12 +19,15 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceCreated,
NodeGatheredInfo,
NodePerformanceMeasured,
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
@@ -81,14 +83,21 @@ async def test_master():
origin=sender_node_id,
session=session_id,
event=(
NodeGatheredInfo(
NodePerformanceMeasured(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
info=MemoryUsage(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(),
),
)
),
@@ -109,8 +118,9 @@ async def test_master():
command=(
PlaceInstance(
command_id=CommandId(),
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
@@ -153,7 +163,7 @@ async def test_master():
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodeGatheredInfo)
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
created_instance = events[1].event.instance
assert isinstance(created_instance, MlxRingInstance)
@@ -166,8 +176,9 @@ async def test_master():
start_layer=0,
end_layer=16,
n_layers=16,
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,

View File

@@ -1,23 +1,20 @@
from typing import Callable
import pytest
from loguru import logger
from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.tests.conftest import (
create_node_profile,
create_rdma_connection,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -29,6 +26,11 @@ from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
return Topology()
@pytest.fixture
def instance() -> Instance:
return MlxRingInstance(
@@ -42,20 +44,21 @@ def instance() -> Instance:
@pytest.fixture
def model_card() -> ModelCard:
return ModelCard(
def model_meta() -> ModelMetadata:
return ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=30,
supports_tensor=True,
)
def place_instance_command(model_card: ModelCard) -> PlaceInstance:
def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
return PlaceInstance(
command_id=CommandId(),
model_card=model_card,
model_meta=model_meta,
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
min_nodes=1,
@@ -67,70 +70,47 @@ def place_instance_command(model_card: ModelCard) -> PlaceInstance:
[
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 468, 1092), 12, (2, 3, 7)),
((312, 518, 1024), 12, (2, 3, 7)),
],
)
def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
model_card: ModelCard,
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
model_card.n_layers = total_layers
model_card.storage_size.in_bytes = sum(
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_card)
cic = place_instance_command(model_meta)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
# fully connected (directed) between the 3 nodes
conn_a_b = Connection(
source=node_id_a, sink=node_id_b, edge=create_socket_connection(1)
)
conn_b_c = Connection(
source=node_id_b, sink=node_id_c, edge=create_socket_connection(2)
)
conn_c_a = Connection(
source=node_id_c, sink=node_id_a, edge=create_socket_connection(3)
)
conn_c_b = Connection(
source=node_id_c, sink=node_id_b, edge=create_socket_connection(4)
)
conn_a_c = Connection(
source=node_id_a, sink=node_id_c, edge=create_socket_connection(5)
)
conn_b_a = Connection(
source=node_id_b, sink=node_id_a, edge=create_socket_connection(6)
)
profiles = {
node_id_a: create_node_profile(available_memory[0]),
node_id_b: create_node_profile(available_memory[1]),
node_id_c: create_node_profile(available_memory[2]),
}
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
topology.add_connection(conn_b_a)
topology.add_node(create_node(available_memory[0], node_id_a))
topology.add_node(create_node(available_memory[1], node_id_b))
topology.add_node(create_node(available_memory[2], node_id_c))
# Add bidirectional connections for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
# act
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert instance.shard_assignments.model_id == model_card.model_id
assert instance.shard_assignments.model_id == model_meta.model_id
runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
@@ -150,21 +130,23 @@ def test_get_instance_placements_create_instance(
assert shards_sorted[-1].end_layer == total_layers
def test_get_instance_placements_one_node_exact_fit() -> None:
def test_get_instance_placements_one_node_exact_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
ModelCard(
ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -175,21 +157,23 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
def test_get_instance_placements_one_node_fits_with_extra_memory(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1001 * 1024)}
topology.add_node(create_node(1001 * 1024, node_id))
cic = place_instance_command(
ModelCard(
ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {})
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -200,15 +184,17 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
assert len(instance.shard_assignments.runner_to_shard) == 1
def test_get_instance_placements_one_node_not_fit() -> None:
def test_get_instance_placements_one_node_not_fit(
create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -216,7 +202,7 @@ def test_get_instance_placements_one_node_not_fit() -> None:
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
place_instance(cic, topology, {}, profiles)
place_instance(cic, topology, {})
def test_get_transition_events_no_change(instance: Instance):
@@ -261,169 +247,217 @@ def test_get_transition_events_delete_instance(instance: Instance):
assert events[0].instance_id == instance_id
def test_placement_selects_leaf_nodes(
model_card: ModelCard,
def test_placement_selects_cycle_with_most_memory(
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
topology = Topology()
# Arrange two 3-node cycles with different total memory.
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
# The algorithm should select the cycle with the most available memory.
# Model requires more than any single node but fits within a 3-node cycle
model_card.storage_size.in_bytes = 1500
model_card.n_layers = 12
model_meta.storage_size.in_bytes = 1500
model_meta.n_layers = 12
# Create node ids
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_id_d = NodeId()
node_id_e = NodeId()
node_id_f = NodeId()
profiles = {
node_id_a: create_node_profile(500),
node_id_b: create_node_profile(600),
node_id_c: create_node_profile(600),
node_id_d: create_node_profile(500),
}
# A-B-C cycle total memory = 1600 (< D-E-F total)
topology.add_node(create_node(400, node_id_a))
topology.add_node(create_node(400, node_id_b))
topology.add_node(create_node(800, node_id_c))
topology.add_node(node_id_a)
topology.add_node(node_id_b)
topology.add_node(node_id_c)
topology.add_node(node_id_d)
# D-E-F cycle total memory = 1800 (> A-B-C total)
topology.add_node(create_node(600, node_id_d))
topology.add_node(create_node(600, node_id_e))
topology.add_node(create_node(600, node_id_f))
# Daisy chain topology (directed)
topology.add_connection(
Connection(source=node_id_a, sink=node_id_b, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_b, sink=node_id_a, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_b, sink=node_id_c, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_c, sink=node_id_b, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_c, sink=node_id_d, edge=create_socket_connection(1))
)
topology.add_connection(
Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))
# Build bidirectional cycles for ring topology
topology.add_connection(create_connection(node_id_a, node_id_b))
topology.add_connection(create_connection(node_id_b, node_id_a))
topology.add_connection(create_connection(node_id_b, node_id_c))
topology.add_connection(create_connection(node_id_c, node_id_b))
topology.add_connection(create_connection(node_id_c, node_id_a))
topology.add_connection(create_connection(node_id_a, node_id_c))
topology.add_connection(create_connection(node_id_d, node_id_e))
topology.add_connection(create_connection(node_id_e, node_id_d))
topology.add_connection(create_connection(node_id_e, node_id_f))
topology.add_connection(create_connection(node_id_f, node_id_e))
topology.add_connection(create_connection(node_id_f, node_id_d))
topology.add_connection(create_connection(node_id_d, node_id_f))
cic = place_instance_command(
model_meta=model_meta,
)
cic = place_instance_command(model_card=model_card)
# Act
placements = place_instance(cic, topology, {})
# act
placements = place_instance(cic, topology, {}, profiles)
# assert
# Assert: D-E-F cycle should be selected as it has more total memory
assert len(placements) == 1
instance = list(placements.values())[0]
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
(
node_id_c,
node_id_d,
)
)
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
assert more_memory_cycle_nodes.issubset(assigned_nodes)
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
def test_tensor_rdma_backend_connectivity_matrix(
model_card: ModelCard,
topology: Topology,
model_meta: ModelMetadata,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
topology = Topology()
model_card.n_layers = 12
model_card.storage_size.in_bytes = 1500
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_a = create_node(500, node_id_a)
node_b = create_node(500, node_id_b)
node_c = create_node(500, node_id_c)
ethernet_interface = NetworkInterfaceInfo(
name="en0",
ip_address="10.0.0.1",
)
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
conn_a_b = create_connection(node_id_a, node_id_b)
conn_b_c = create_connection(node_id_b, node_id_c)
conn_c_a = create_connection(node_id_c, node_id_a)
conn_b_a = create_connection(node_id_b, node_id_a)
conn_c_b = create_connection(node_id_c, node_id_b)
conn_a_c = create_connection(node_id_a, node_id_c)
assert conn_a_b.send_back_multiaddr is not None
assert conn_b_c.send_back_multiaddr is not None
assert conn_c_a.send_back_multiaddr is not None
assert conn_b_a.send_back_multiaddr is not None
assert conn_c_b.send_back_multiaddr is not None
assert conn_a_c.send_back_multiaddr is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
ethernet_interface,
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
topology.add_connection(conn_c_b)
topology.add_connection(conn_a_c)
cic = PlaceInstance(
sharding=Sharding.Tensor,
instance_meta=InstanceMeta.MlxJaccl,
command_id=CommandId(),
model_card=model_card,
model_meta=model_meta,
min_nodes=1,
)
# act
placements = place_instance(cic, topology, {})
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert isinstance(instance, MlxJacclInstance)
assert instance.jaccl_devices is not None
assert instance.ibv_devices is not None
assert instance.jaccl_coordinators is not None
matrix = instance.ibv_devices
assert len(matrix) == 3
for i in range(3):
assert matrix[i][i] is None
assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}
idx_a = node_to_idx[node_id_a]
idx_b = node_to_idx[node_id_b]
idx_c = node_to_idx[node_id_c]
assert matrix[idx_a][idx_b] == "rdma_en3"
assert matrix[idx_b][idx_c] == "rdma_en4"
assert matrix[idx_c][idx_a] == "rdma_en5"
logger.info(matrix)
assert matrix[idx_a][idx_b] == "rdma_en4"
assert matrix[idx_b][idx_c] == "rdma_en3"
assert matrix[idx_c][idx_a] == "rdma_en3"
# Verify coordinators are set for all nodes
assert len(instance.jaccl_coordinators) == 3
@@ -435,5 +469,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
if node_id == assigned_nodes[0]:
assert coordinator.startswith("0.0.0.0:")
else:
# Non-rank-0 nodes should have valid IP addresses (can be link-local)
ip_part = coordinator.split(":")[0]
# Just verify it's a valid IP format
assert len(ip_part.split(".")) == 4

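The two tests above pin down the placement behaviour: candidate rings are the topology's cycles, cycles that cannot hold the model are filtered out, and among the remaining candidates the one with the most total available memory wins. A minimal sketch of that selection rule, using plain byte counts instead of exo's NodeInfo/Topology types (the exact ordering of the filtering steps is an assumption):

def pick_cycle(cycles: list[list[int]], required_bytes: int) -> list[int] | None:
    # each cycle is represented here as a list of per-node available-memory byte counts
    viable = [c for c in cycles if sum(c) >= required_bytes]  # filter_cycles_by_memory analogue
    if not viable:
        return None
    smallest = min(len(c) for c in viable)  # get_smallest_cycles analogue
    candidates = [c for c in viable if len(c) == smallest]
    return max(candidates, key=sum)  # prefer the ring with the most total available memory

# e.g. pick_cycle([[400, 400, 800], [600, 600, 600]], required_bytes=1500) -> [600, 600, 600]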
View File

@@ -1,187 +1,162 @@
from copy import copy
from typing import Callable
import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
from exo.master.tests.conftest import create_node_profile, create_socket_connection
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding
@pytest.fixture
def topology() -> Topology:
topology = Topology()
return topology
def test_filter_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
cycles = topology.get_cycles()
assert len(cycles) == 1
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))
# assert
assert len(filtered_cycles) == 1
assert len(filtered_cycles[0]) == 2
assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}
def test_filter_cycles_by_insufficient_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node1_id = NodeId()
node2_id = NodeId()
node1 = create_node(1000 * 1024, node1_id)
node2 = create_node(1000 * 1024, node2_id)
topology.add_node(node1)
topology.add_node(node2)
connection1 = create_connection(node1_id, node2_id)
connection2 = create_connection(node2_id, node1_id)
topology.add_connection(connection1)
topology.add_connection(connection2)
# act
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), Memory.from_kb(2001)
)
# assert
assert len(filtered_cycles) == 0
def test_filter_multiple_cycles_by_memory(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))
# assert
assert len(filtered_cycles) == 1
assert len(filtered_cycles[0]) == 3
assert set(n.node_id for n in filtered_cycles[0]) == {
node_a_id,
node_b_id,
node_c_id,
}
def test_get_smallest_cycles(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
# act
smallest_cycles = get_smallest_cycles(topology.get_cycles())
# assert
assert len(smallest_cycles) == 1
assert len(smallest_cycles[0]) == 2
assert set(n.node_id for n in smallest_cycles[0]) == {node_a_id, node_b_id}
@pytest.mark.parametrize(
@@ -190,12 +165,12 @@ def test_get_smallest_cycles():
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
# Edge case: one node has ~90% of memory - should not over-allocate.
# Each node must have enough memory for at least 1 layer (50 KB = 1000/20).
((900, 50, 50), 20, (18, 1, 1)),
],
)
def test_get_shard_assignments(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
@@ -205,61 +180,44 @@ def test_get_shard_assignments(
node_b_id = NodeId()
node_c_id = NodeId()
# create connections (A -> B -> C -> A forms a 3-cycle, plus B -> A also exists)
node_a = create_node(available_memory[0] * 1024, node_a_id)
node_b = create_node(available_memory[1] * 1024, node_b_id)
node_c = create_node(available_memory[2] * 1024, node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
# pick the 3-node cycle deterministically (cycle ordering can vary)
selected_cycle = next(cycle for cycle in cycles if len(cycle) == 3)
# act
shard_assignments = get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline
)
# assert
runner_id_a = shard_assignments.node_to_runner[node_a_id]
runner_id_b = shard_assignments.node_to_runner[node_b_id]
runner_id_c = shard_assignments.node_to_runner[node_c_id]
assert (
shard_assignments.runner_to_shard[runner_id_c].end_layer
- shard_assignments.runner_to_shard[runner_id_c].start_layer
== expected_layers[2]
)
assert (
shard_assignments.runner_to_shard[runner_id_a].end_layer
- shard_assignments.runner_to_shard[runner_id_a].start_layer
@@ -270,37 +228,30 @@ def test_get_shard_assignments(
- shard_assignments.runner_to_shard[runner_id_b].start_layer
== expected_layers[1]
)
assert (
shard_assignments.runner_to_shard[runner_id_c].end_layer
- shard_assignments.runner_to_shard[runner_id_c].start_layer
== expected_layers[2]
)
def test_get_hosts_from_subgraph(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
node_a = create_node(500, node_a_id)
node_b = create_node(500, node_b_id)
node_c = create_node(1000, node_c_id)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
topology.add_connection(create_connection(node_b_id, node_a_id, 5004))
# act
hosts = get_hosts_from_subgraph(topology)
@@ -308,78 +259,95 @@ def test_get_hosts_from_subgraph():
# assert
assert len(hosts) == 3
expected_hosts = [
Host(ip="169.254.0.1", port=1234),
Host(ip="169.254.0.2", port=1234),
Host(ip="169.254.0.3", port=1234),
Host(ip=("169.254.0.2"), port=5001),
Host(ip=("169.254.0.3"), port=5002),
Host(ip=("169.254.0.4"), port=5003),
]
for expected_host in expected_hosts:
assert expected_host in hosts
def test_get_mlx_jaccl_coordinators(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
node_a = create_node(500 * 1024, node_a_id)
node_b = create_node(500 * 1024, node_b_id)
node_c = create_node(1000 * 1024, node_c_id)
conn_a_b = create_connection(node_a_id, node_b_id, 5001)
conn_b_a = create_connection(node_b_id, node_a_id, 5002)
conn_b_c = create_connection(node_b_id, node_c_id, 5003)
conn_c_b = create_connection(node_c_id, node_b_id, 5004)
conn_c_a = create_connection(node_c_id, node_a_id, 5005)
conn_a_c = create_connection(node_a_id, node_c_id, 5006)
# Update node profiles with network interfaces before adding to topology
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
node_a.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_a.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_a_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_a_c.send_back_multiaddr.ip_address,
),
],
system=node_a.node_profile.system,
)
node_b.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_b.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_b_a.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_b_c.send_back_multiaddr.ip_address,
),
],
system=node_b.node_profile.system,
)
node_c.node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=node_c.node_profile.memory,
network_interfaces=[
NetworkInterfaceInfo(
name="en3",
ip_address=conn_c_b.send_back_multiaddr.ip_address,
),
NetworkInterfaceInfo(
name="en4",
ip_address=conn_c_a.send_back_multiaddr.ip_address,
),
],
system=node_c.node_profile.system,
)
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_a)
@@ -388,12 +356,11 @@ def test_get_mlx_jaccl_coordinators():
topology.add_connection(conn_c_a)
topology.add_connection(conn_a_c)
cycle = [node_a, node_b, node_c]
# act
coordinators = get_mlx_jaccl_coordinators(
cycle, coordinator_port=5000, cycle_digraph=topology
)
# assert
@@ -414,127 +381,19 @@ def test_get_mlx_jaccl_coordinators():
f"Coordinator for {node_id} should use port 5000"
)
# Rank 0 (node_a) treats this as the listen socket so should listen on all IPs
assert coordinators[node_a_id].startswith("0.0.0.0:"), (
"Rank 0 node should use 0.0.0.0 as coordinator listen address"
)
# Non-rank-0 nodes should use the specific IP from their connection to rank 0
# node_b uses the IP from conn_b_a (node_b -> node_a)
assert coordinators[node_b_id] == (
f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
), "node_b should use the IP from conn_b_a"
# node_c uses the IP from conn_c_a (node_c -> node_a)
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
class TestAllocateLayersProportionally:
def test_empty_node_list_raises(self):
with pytest.raises(ValueError, match="empty node list"):
allocate_layers_proportionally(total_layers=10, memory_fractions=[])
def test_zero_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=0, memory_fractions=[0.5, 0.5])
def test_negative_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=-1, memory_fractions=[0.5, 0.5])
def test_fewer_layers_than_nodes_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(
total_layers=2, memory_fractions=[0.33, 0.33, 0.34]
)
def test_equal_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.25, 0.25]
)
assert result == [3, 3, 3, 3]
assert sum(result) == 12
def test_proportional_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.50]
)
assert result == [3, 3, 6]
assert sum(result) == 12
def test_extreme_imbalance_ensures_minimum(self):
result = allocate_layers_proportionally(
total_layers=20, memory_fractions=[0.975, 0.0125, 0.0125]
)
assert all(layers >= 1 for layers in result)
assert sum(result) == 20
# Small nodes get minimum 1 layer
assert result == [18, 1, 1]
def test_single_node_gets_all_layers(self):
result = allocate_layers_proportionally(total_layers=10, memory_fractions=[1.0])
assert result == [10]
def test_minimum_viable_allocation(self):
result = allocate_layers_proportionally(
total_layers=3, memory_fractions=[0.33, 0.33, 0.34]
)
assert result == [1, 1, 1]
assert sum(result) == 3
def test_get_shard_assignments_insufficient_memory_raises():
"""Test that ValueError is raised when a node has insufficient memory for its layers."""
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
topology = Topology()
# Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
node_a = create_node_profile(900 * 1024)
node_b = create_node_profile(50 * 1024)
node_c = create_node_profile(10 * 1024) # Insufficient memory
topology.add_node(node_a_id)
topology.add_node(node_b_id)
topology.add_node(node_c_id)
conn_a_b = Connection(
source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
)
conn_b_c = Connection(
source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
)
conn_c_a = Connection(
source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
)
conn_b_a = Connection(
source=node_b_id, sink=node_a_id, edge=create_socket_connection(3)
)
topology.add_connection(conn_a_b)
topology.add_connection(conn_b_c)
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
model_card = ModelCard(
model_id=ModelId("test-model"),
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(model_card, selected_cycle, Sharding.Pipeline, profiles)

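The parametrized cases in test_get_shard_assignments and the TestAllocateLayersProportionally table above fully constrain the layer split: each node gets the floor of its proportional share but never fewer than one layer, and any shortfall or surplus is settled against the largest shares. The sketch below reproduces those expected tuples; the name and signature mirror the tests, but this is an illustration, not the repo's implementation:

def allocate_layers_proportionally(total_layers: int, memory_fractions: list[float]) -> list[int]:
    if not memory_fractions:
        raise ValueError("empty node list")
    if total_layers < len(memory_fractions):
        raise ValueError("need at least 1 layer per node")
    # floor of the proportional share, clamped to at least one layer per node
    layers = [max(1, int(total_layers * f)) for f in memory_fractions]
    # hand out any missing layers to the nodes furthest below their ideal share
    while sum(layers) < total_layers:
        i = max(range(len(layers)), key=lambda i: total_layers * memory_fractions[i] - layers[i])
        layers[i] += 1
    # take back any excess from the largest allocations (the 1-layer minimum is preserved)
    while sum(layers) > total_layers:
        i = max(range(len(layers)), key=lambda i: layers[i])
        layers[i] -= 1
    return layers

# ((312, 518, 1024), 12) -> fractions ~(0.17, 0.28, 0.55) -> [2, 3, 7]
# ((900, 50, 50), 20)    -> [18, 1, 1]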
View File

@@ -1,14 +1,13 @@
import pytest
from exo.shared.topology import Topology
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
@pytest.fixture
@@ -17,15 +16,20 @@ def topology() -> Topology:
@pytest.fixture
def connection() -> Connection:
return Connection(
local_node_id=NodeId(),
send_back_node_id=NodeId(),
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
connection_profile=ConnectionProfile(
throughput=1000, latency=1000, jitter=1000
),
)
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryUsage.from_bytes(
memory_profile = MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
@@ -39,91 +43,162 @@ def node_profile() -> NodePerformanceProfile:
)
@pytest.fixture
def connection_profile() -> ConnectionProfile:
return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)
def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
# arrange
node_id = NodeId()
# act
topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))
# assert
data = topology.get_node_profile(node_id)
assert data == node_profile
def test_add_connection(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
data = topology.get_connection_profile(connection)
# assert
assert data == connection.connection_profile
def test_update_node_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_node_profile = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryPerformanceProfile.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
# act
topology.update_node_profile(
connection.local_node_id, node_profile=new_node_profile
)
# assert
data = topology.get_node_profile(connection.local_node_id)
assert data == new_node_profile
def test_update_connection_profile(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
new_connection_profile = ConnectionProfile(
throughput=2000, latency=2000, jitter=2000
)
connection = Connection(
local_node_id=connection.local_node_id,
send_back_node_id=connection.send_back_node_id,
send_back_multiaddr=connection.send_back_multiaddr,
connection_profile=new_connection_profile,
)
# act
topology.update_connection_profile(connection)
# assert
data = topology.get_connection_profile(connection)
assert data == new_connection_profile
def test_remove_connection_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_connection(connection)
# assert
assert topology.get_connection_profile(connection) is None
def test_remove_node_still_connected(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
topology.remove_node(connection.local_node_id)
# assert
assert topology.get_node_profile(connection.local_node_id) is None
def test_list_nodes(
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
# arrange
topology.add_node(
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
)
topology.add_node(
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
)
topology.add_connection(connection)
# act
nodes = list(topology.list_nodes())
# assert
assert len(nodes) == 2
assert all(isinstance(node, NodeInfo) for node in nodes)
assert {node.node_id for node in nodes} == {
connection.local_node_id,
connection.send_back_node_id,
}

View File

@@ -9,10 +9,13 @@ from exo.shared.types.events import (
ChunkGenerated,
Event,
IndexedEvent,
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
NodeCreated,
NodeDownloadProgress,
NodeGatheredInfo,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeTimedOut,
RunnerDeleted,
RunnerStatusUpdated,
@@ -25,42 +28,36 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import Connection, RDMAConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
MacmonMetrics,
MacThunderboltConnections,
MacThunderboltIdentifiers,
MemoryUsage,
MiscData,
NodeConfig,
NodeNetworkInterfaces,
StaticNodeInformation,
)
def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
): # Pass-through events that don't modify state
return state
case InstanceCreated():
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case NodeDownloadProgress():
return apply_node_download_progress(event, state)
case NodeGatheredInfo():
return apply_node_gathered_info(event, state)
case NodeMemoryMeasured():
return apply_node_memory_measured(event, state)
case RunnerDeleted():
return apply_runner_deleted(event, state)
case RunnerStatusUpdated():
@@ -192,7 +189,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.copy(state.topology)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
@@ -200,12 +197,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
}
downloads = {
key: value for key, value in state.downloads.items() if key != event.node_id
}
return state.model_copy(
update={
"downloads": downloads,
"topology": topology,
"node_profiles": node_profiles,
"last_seen": last_seen,
@@ -213,68 +206,103 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
)
def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_node(event.node_id)
info = event.info
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
match info:
case MacmonMetrics():
profile.system = info.system_profile
profile.memory = info.memory
case MemoryUsage():
profile.memory = info
case NodeConfig():
pass
case MiscData():
profile.friendly_name = info.friendly_name
case StaticNodeInformation():
profile.model_id = info.model
profile.chip_id = info.chip
case NodeNetworkInterfaces():
profile.network_interfaces = info.ifaces
case MacThunderboltIdentifiers():
profile.tb_interfaces = info.idents
case MacThunderboltConnections():
conn_map = {
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
for nid in state.node_profiles
for tb_ident in state.node_profiles[nid].tb_interfaces
}
as_rdma_conns = [
Connection(
source=event.node_id,
sink=conn_map[tb_conn.sink_uuid][0],
edge=RDMAConnection(
source_rdma_iface=conn_map[tb_conn.source_uuid][1],
sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
),
)
for tb_conn in info.conns
if tb_conn.source_uuid in conn_map
if tb_conn.sink_uuid in conn_map
]
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
new_profiles = {**state.node_profiles, event.node_id: profile}
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: event.node_profile,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
state = state.model_copy(update={"node_profiles": new_profiles})
topology = copy.copy(state.topology)
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
return state.model_copy(
update={
"node_profiles": new_profiles,
"last_seen": last_seen,
"topology": topology,
"last_seen": last_seen,
}
)
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
existing = state.node_profiles.get(event.node_id)
topology = copy.copy(state.topology)
if existing is None:
created = NodePerformanceProfile(
model_id="unknown",
chip_id="unknown",
friendly_name="Unknown",
memory=event.memory,
network_interfaces=[],
system=SystemPerformanceProfile(
# TODO: flops_fp16=0.0,
gpu_usage=0.0,
temp=0.0,
sys_power=0.0,
pcpu_usage=0.0,
ecpu_usage=0.0,
ane_power=0.0,
),
)
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: created,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
# TODO: NodeCreated
topology.update_node_profile(event.node_id, created)
return state.model_copy(
update={
"node_profiles": created_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
updated = existing.model_copy(update={"memory": event.memory})
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: updated,
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, updated)
return state.model_copy(
update={"node_profiles": updated_profiles, "topology": topology}
)
def apply_topology_node_created(event: NodeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_node(NodeInfo(node_id=event.node_id))
return state.model_copy(update={"topology": topology})
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
topology = copy.copy(state.topology)
topology.add_connection(event.edge)
return state.model_copy(update={"topology": topology})
def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
topology = copy.copy(state.topology)
if not topology.contains_connection(event.edge):
return state
topology.remove_connection(event.edge)
# TODO: Clean up removing the reverse connection
return state.model_copy(update={"topology": topology})

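Every handler in the module above follows the same reducer shape: take an event and the current immutable State, shallow-copy only the structures being modified, and hand back a new State via model_copy rather than mutating in place. A stripped-down sketch of that pattern with stand-in dataclasses (not exo's actual Event/State models):

from dataclasses import dataclass, field, replace

@dataclass(frozen=True)
class State:
    last_seen: dict[str, str] = field(default_factory=dict)

@dataclass(frozen=True)
class NodeSeen:
    node_id: str
    when: str

def apply_event(event: NodeSeen, state: State) -> State:
    # rebuild only the mapping that changes; the old State object stays untouched
    last_seen = {**state.last_seen, event.node_id: event.when}
    return replace(state, last_seen=last_seen)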
View File

@@ -38,10 +38,11 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"
# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"
# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024

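EXO_MAX_CHUNK_SIZE (512 KiB) reads like an upper bound on a single payload forwarded over the libp2p topics above, with larger blobs split across multiple messages; that purpose is inferred from the name, not stated in this file. A generic chunking helper of that shape:

EXO_MAX_CHUNK_SIZE = 512 * 1024  # mirrors the constant above

def iter_chunks(payload: bytes, chunk_size: int = EXO_MAX_CHUNK_SIZE):
    # yield successive slices no longer than chunk_size
    for offset in range(0, len(payload), chunk_size):
        yield payload[offset : offset + chunk_size]

# a 1.5 MiB payload would be forwarded as three 512 KiB chunks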
View File

@@ -11,6 +11,9 @@ class InterceptLogger(HypercornLogger):
def __init__(self, config: Config):
super().__init__(config)
assert self.error_logger
# TODO: Decide if we want to provide access logs
# assert self.access_logger
# self.access_logger.handlers = [_InterceptHandler()]
self.error_logger.handlers = [_InterceptHandler()]
@@ -26,11 +29,6 @@ class _InterceptHandler(logging.Handler):
def logger_setup(log_file: Path | None, verbosity: int = 0):
"""Set up logging for this process - formatting, file handles, verbosity and output"""
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logger.remove()
# replace all stdlib loggers with _InterceptHandlers that log to loguru

View File

@@ -1,281 +1,772 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.shared.types.models import ComponentInfo, ModelId, ModelMetadata, ModelTask
from exo.utils.pydantic_ext import CamelCaseModel
class ModelCard(CamelCaseModel):
short_id: str
model_id: ModelId
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
name: str
description: str
tasks: list[ModelTask]
tags: list[str]
metadata: ModelMetadata
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"deepseek-v3.1-8bit": ModelCard(
short_id="deepseek-v3.1-8bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"kimi-k2-thinking": ModelCard(
short_id="kimi-k2-thinking",
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# llama-3.1
"llama-3.1-8b": ModelCard(
short_id="llama-3.1-8b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-70b": ModelCard(
short_id="llama-3.1-70b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# llama-3.2
"llama-3.2-1b": ModelCard(
short_id="llama-3.2-1b",
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
),
"llama-3.2-3b": ModelCard(
short_id="llama-3.2-3b",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
"llama-3.2-3b-8bit": ModelCard(
short_id="llama-3.2-3b-8bit",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
# llama-3.3
"llama-3.3-70b": ModelCard(
short_id="llama-3.3-70b",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-8bit": ModelCard(
short_id="llama-3.3-70b-8bit",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-fp16": ModelCard(
short_id="llama-3.3-70b-fp16",
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# qwen3
"qwen3-0.6b": ModelCard(
short_id="qwen3-0.6b",
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-0.6b-8bit": ModelCard(
short_id="qwen3-0.6b-8bit",
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-30b": ModelCard(
short_id="qwen3-30b",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-30b-8bit": ModelCard(
short_id="qwen3-30b-8bit",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-235b-a22b-4bit": ModelCard(
short_id="qwen3-235b-a22b-4bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-235b-a22b-8bit": ModelCard(
short_id="qwen3-235b-a22b-8bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
short_id="qwen3-coder-480b-a35b-4bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
short_id="qwen3-coder-480b-a35b-8bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
"flux1-schnell": ModelCard(
short_id="flux1-schnell",
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
name="FLUX.1 [schnell]",
description="""FLUX.1 [schnell] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
pretty_name="FLUX.1 [schnell]",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(23782357120), # + 9524621312),
n_layers=57, # sharded layers
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"flux1-dev": ModelCard(
short_id="flux1-dev",
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
name="FLUX.1 [dev]",
description="""FLUX.1 [dev] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
pretty_name="FLUX.1 [dev]",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57, # sharded layers
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"qwen-image": ModelCard(
short_id="qwen-image",
model_id=ModelId("Qwen/Qwen-Image"),
name="Qwen Image",
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("Qwen/Qwen-Image"),
pretty_name="Qwen Image",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"qwen-image-edit-2509": ModelCard(
short_id="qwen-image-edit-2509",
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
name="Qwen Image Edit 2509",
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
tasks=[ModelTask.ImageToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
pretty_name="Qwen Image Edit 2509",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
}
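As a rough illustration of how an entry in the registry above might be consumed downstream (the print_card_summary helper is hypothetical, not part of the codebase; it only touches fields shown in the dictionary):

from exo.shared.models.model_cards import MODEL_CARDS

def print_card_summary(short_id: str) -> None:
    # Pull one registry entry and report the fields the scheduler relies on.
    card = MODEL_CARDS[short_id]
    meta = card.metadata
    print(
        f"{card.name}: {meta.n_layers} layers, "
        f"{meta.storage_size.in_bytes / 1e9:.1f} GB on disk, "
        f"tensor parallel: {meta.supports_tensor}"
    )

print_card_summary("qwen3-30b")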

View File

@@ -6,8 +6,9 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
@@ -91,18 +92,18 @@ async def get_safetensors_size(model_id: str) -> Memory:
return Memory.from_bytes(info.safetensors.total)
_model_card_cache: dict[str, ModelCard] = {}
_model_meta_cache: dict[str, ModelMetadata] = {}
async def get_model_card(model_id: str) -> ModelCard:
if model_id in _model_card_cache:
return _model_card_cache[model_id]
model_card = await _get_model_card(model_id)
_model_card_cache[model_id] = model_card
return model_card
async def get_model_meta(model_id: str) -> ModelMetadata:
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
async def _get_model_card(model_id: str) -> ModelCard:
async def _get_model_meta(model_id: str) -> ModelMetadata:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
@@ -112,11 +113,14 @@ async def _get_model_card(model_id: str) -> ModelCard:
None,
)
return ModelCard(
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.supports_tensor if model_card is not None else False,
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)
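A quick sketch of the memoization the hunk above introduces; the import path of get_model_meta is not shown in the diff, so it is assumed here.

# Assumes get_model_meta is imported from the module this hunk patches.
async def demo_cache() -> None:
    meta = await get_model_meta("mlx-community/Qwen3-0.6B-8bit")
    # The second call is served from _model_meta_cache, so no further
    # Hugging Face round-trip is made and the same object is returned.
    assert meta is await get_model_meta("mlx-community/Qwen3-0.6B-8bit")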

View File

@@ -7,8 +7,8 @@ import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@@ -31,8 +31,9 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=1000,

View File

@@ -43,4 +43,7 @@ def test_apply_two_node_download_progress():
NodeDownloadProgress(download_progress=event2), state
)
# TODO: This test is failing. We should support the following:
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
# 2. Downloading a model, it completes, then downloading a different model on the same node.
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

View File

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.topology import Connection
def test_state_serialization_roundtrip() -> None:
@@ -12,11 +12,9 @@ def test_state_serialization_roundtrip() -> None:
node_b = NodeId("node-b")
connection = Connection(
source=node_a,
sink=node_b,
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
),
local_node_id=node_a,
send_back_node_id=node_b,
send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
)
state = State()
@@ -25,11 +23,5 @@ def test_state_serialization_roundtrip() -> None:
json_repr = state.model_dump_json()
restored_state = State.model_validate_json(json_repr)
assert (
state.topology.to_snapshot().nodes
== restored_state.topology.to_snapshot().nodes
)
assert set(state.topology.to_snapshot().connections) == set(
restored_state.topology.to_snapshot().connections
)
assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
assert restored_state.model_dump_json() == json_repr

View File

@@ -1,227 +1,203 @@
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Iterable
import rustworkx as rx
from pydantic import BaseModel, ConfigDict
from exo.shared.types.common import NodeId
from exo.shared.types.topology import (
Connection,
Cycle,
RDMAConnection,
SocketConnection,
)
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
class TopologySnapshot(BaseModel):
nodes: Sequence[NodeId]
connections: Mapping[
NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]
]
nodes: list[NodeInfo]
connections: list[Connection]
model_config = ConfigDict(frozen=True, extra="forbid")
model_config = ConfigDict(frozen=True, extra="forbid", strict=True)
@dataclass
class Topology:
_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
init=False, default_factory=rx.PyDiGraph
)
_vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
def __init__(self) -> None:
self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
self._edge_id_to_rx_id_map: dict[Connection, int] = dict()
def to_snapshot(self) -> TopologySnapshot:
return TopologySnapshot(
nodes=list(self.list_nodes()), connections=self.map_connections()
nodes=list(self.list_nodes()),
connections=list(self.list_connections()),
)
@classmethod
def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
topology = cls()
for node_id in snapshot.nodes:
for node in snapshot.nodes:
with contextlib.suppress(ValueError):
topology.add_node(node_id)
topology.add_node(node)
for source in snapshot.connections:
for sink in snapshot.connections[source]:
for edge in snapshot.connections[source][sink]:
topology.add_connection(
Connection(source=source, sink=sink, edge=edge)
)
for connection in snapshot.connections:
topology.add_connection(connection)
return topology
def add_node(self, node_id: NodeId) -> None:
if node_id in self._vertex_indices:
def add_node(self, node: NodeInfo) -> None:
if node.node_id in self._node_id_to_rx_id_map:
return
rx_id = self._graph.add_node(node_id)
self._vertex_indices[node_id] = rx_id
rx_id = self._graph.add_node(node)
self._node_id_to_rx_id_map[node.node_id] = rx_id
self._rx_id_to_node_id_map[rx_id] = node.node_id
def node_is_leaf(self, node_id: NodeId) -> bool:
return (
node_id in self._vertex_indices
and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
node_id in self._node_id_to_rx_id_map
and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
)
def neighbours(self, node_id: NodeId) -> list[NodeId]:
return [
self._graph[rx_id]
for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
self._rx_id_to_node_id_map[rx_id]
for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
]
def out_edges(self, node_id: NodeId) -> Iterable[Connection]:
if node_id not in self._vertex_indices:
def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
if node_id not in self._node_id_to_rx_id_map:
return []
return (
Connection(source=self._graph[source], sink=self._graph[sink], edge=edge)
for source, sink, edge in self._graph.out_edges(
self._vertex_indices[node_id]
return [
(self._rx_id_to_node_id_map[nid], conn)
for _, nid, conn in self._graph.out_edges(
self._node_id_to_rx_id_map[node_id]
)
)
]
def contains_node(self, node_id: NodeId) -> bool:
return node_id in self._vertex_indices
return node_id in self._node_id_to_rx_id_map
def add_connection(self, conn: Connection) -> None:
source, sink, edge = conn.source, conn.sink, conn.edge
del conn
if edge in self.get_all_connections_between(source, sink):
def contains_connection(self, connection: Connection) -> bool:
return connection in self._edge_id_to_rx_id_map
def add_connection(
self,
connection: Connection,
) -> None:
if connection.local_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.local_node_id))
if connection.send_back_node_id not in self._node_id_to_rx_id_map:
self.add_node(NodeInfo(node_id=connection.send_back_node_id))
if connection in self._edge_id_to_rx_id_map:
return
if source not in self._vertex_indices:
self.add_node(source)
if sink not in self._vertex_indices:
self.add_node(sink)
src_id = self._node_id_to_rx_id_map[connection.local_node_id]
sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
rx_id = self._graph.add_edge(src_id, sink_id, connection)
self._edge_id_to_rx_id_map[connection] = rx_id
_ = self._graph.add_edge(src_id, sink_id, edge)
def list_nodes(self) -> Iterable[NodeInfo]:
return (self._graph[i] for i in self._graph.node_indices())
def get_all_connections_between(
self, source: NodeId, sink: NodeId
) -> Iterable[SocketConnection | RDMAConnection]:
if source not in self._vertex_indices:
return []
if sink not in self._vertex_indices:
return []
def list_connections(self) -> Iterable[Connection]:
return (connection for _, _, connection in self._graph.weighted_edge_list())
src_id = self._vertex_indices[source]
sink_id = self._vertex_indices[sink]
def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
try:
return self._graph.get_all_edge_data(src_id, sink_id)
except rx.NoEdgeBetweenNodes:
return []
rx_idx = self._node_id_to_rx_id_map[node_id]
return self._graph.get_node_data(rx_idx).node_profile
except KeyError:
return None
def list_nodes(self) -> Iterable[NodeId]:
return self._graph.nodes()
def update_node_profile(
self, node_id: NodeId, node_profile: NodePerformanceProfile
) -> None:
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph[rx_idx].node_profile = node_profile
def map_connections(
self,
) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
for src_id, sink_id, connection in self._graph.weighted_edge_list():
source = self._graph[src_id]
sink = self._graph[sink_id]
if source not in base:
base[source] = {}
if sink not in base[source]:
base[source][sink] = []
base[source][sink].append(connection)
return base
def update_connection_profile(self, connection: Connection) -> None:
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.update_edge_by_index(rx_idx, connection)
def list_connections(
self,
) -> Iterable[Connection]:
return (
(
Connection(
source=self._graph[src_id],
sink=self._graph[sink_id],
edge=connection,
)
)
for src_id, sink_id, connection in self._graph.weighted_edge_list()
)
def get_connection_profile(
self, connection: Connection
) -> ConnectionProfile | None:
try:
rx_idx = self._edge_id_to_rx_id_map[connection]
return self._graph.get_edge_data_by_index(rx_idx).connection_profile
except KeyError:
return None
def remove_node(self, node_id: NodeId) -> None:
if node_id not in self._vertex_indices:
if node_id not in self._node_id_to_rx_id_map:
return
rx_idx = self._vertex_indices[node_id]
for connection in self.list_connections():
if (
connection.local_node_id == node_id
or connection.send_back_node_id == node_id
):
self.remove_connection(connection)
rx_idx = self._node_id_to_rx_id_map[node_id]
self._graph.remove_node(rx_idx)
del self._vertex_indices[node_id]
del self._node_id_to_rx_id_map[node_id]
del self._rx_id_to_node_id_map[rx_idx]
def replace_all_out_rdma_connections(
self, source: NodeId, new_connections: Sequence[Connection]
) -> None:
for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
self._graph.remove_edge_from_index(conn_idx)
for conn in new_connections:
self.add_connection(conn)
def remove_connection(self, conn: Connection) -> None:
if (
conn.source not in self._vertex_indices
or conn.sink not in self._vertex_indices
):
def remove_connection(self, connection: Connection) -> None:
if connection not in self._edge_id_to_rx_id_map:
return
for conn_idx in self._graph.edge_indices_from_endpoints(
self._vertex_indices[conn.source], self._vertex_indices[conn.sink]
):
if self._graph.get_edge_data_by_index(conn_idx) == conn.edge:
self._graph.remove_edge_from_index(conn_idx)
def get_cycles(self) -> list[Cycle]:
"""Get simple cycles in the graph, including singleton cycles"""
rx_idx = self._edge_id_to_rx_id_map[connection]
self._graph.remove_edge_from_index(rx_idx)
del self._edge_id_to_rx_id_map[connection]
def get_cycles(self) -> list[list[NodeInfo]]:
cycle_idxs = rx.simple_cycles(self._graph)
cycles: list[Cycle] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = Cycle(node_ids=[self._graph[idx] for idx in cycle_idx])
cycle = [self._graph[idx] for idx in cycle_idx]
cycles.append(cycle)
for node_id in self.list_nodes():
cycles.append(Cycle(node_ids=[node_id]))
return cycles
def get_cycles_tb(self) -> list[Cycle]:
def get_cycles_tb(self) -> list[list[NodeInfo]]:
tb_edges = [
(u, v, conn)
for u, v, conn in self._graph.weighted_edge_list()
if conn.is_thunderbolt()
]
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
tb_graph.add_nodes_from(self._graph.nodes())
for u, v, conn in tb_edges:
if isinstance(conn, SocketConnection):
tb_graph.add_edge(u, v, conn)
tb_graph.add_edge(u, v, conn)
cycle_idxs = rx.simple_cycles(tb_graph)
cycles: list[Cycle] = []
cycles: list[list[NodeInfo]] = []
for cycle_idx in cycle_idxs:
cycle = Cycle(node_ids=[tb_graph[idx] for idx in cycle_idx])
cycle = [tb_graph[idx] for idx in cycle_idx]
cycles.append(cycle)
return cycles
def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
node_idxs = [node.node_id for node in nodes]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
topology = Topology()
for node_id in node_ids:
topology.add_node(node_id)
for rx_idx in rx_idxs:
topology.add_node(self._graph[rx_idx])
for connection in self.list_connections():
if connection.source in node_ids and connection.sink in node_ids:
if (
connection.local_node_id in node_idxs
and connection.send_back_node_id in node_idxs
):
topology.add_connection(connection)
return topology
def is_thunderbolt_cycle(self, cycle: Cycle) -> bool:
node_idxs = [node for node in cycle]
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
node_idxs = [node.node_id for node in cycle]
rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
for rid in rx_idxs:
for neighbor_rid in self._graph.neighbors(rid):
if neighbor_rid not in rx_idxs:
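A rough usage sketch of the reworked Topology API above, assuming the import paths used elsewhere in these diffs; the multiaddr string is illustrative.

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.topology import Connection, NodeInfo

topology = Topology()
a, b = NodeId("node-a"), NodeId("node-b")
topology.add_node(NodeInfo(node_id=a))
# add_connection back-fills a NodeInfo for any endpoint not yet registered.
topology.add_connection(
    Connection(
        local_node_id=a,
        send_back_node_id=b,
        send_back_multiaddr=Multiaddr(address="/ip4/169.254.0.2/tcp/10001"),
    )
)
assert topology.contains_node(b)

# Snapshots are frozen pydantic models and round-trip losslessly.
snapshot = topology.to_snapshot()
assert Topology.from_snapshot(snapshot).to_snapshot() == snapshot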

View File

@@ -1,31 +1,22 @@
import time
from typing import Any, Literal
from collections.abc import Generator
from typing import Annotated, Any, Literal
from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
"stop", "length", "tool_calls", "content_filter", "function_call"
]
class ErrorInfo(BaseModel):
message: str
type: str
param: str | None = None
code: int
class ErrorResponse(BaseModel):
error: ErrorInfo
class ModelListModel(BaseModel):
id: str
object: str = "model"
@@ -39,6 +30,7 @@ class ModelListModel(BaseModel):
tags: list[str] = Field(default=[])
storage_size_megabytes: int = Field(default=0)
supports_tensor: bool = Field(default=False)
tasks: list[str] = Field(default=[])
class ModelList(BaseModel):
@@ -137,6 +129,19 @@ class GenerationStats(BaseModel):
peak_memory_usage: Memory
class ImageGenerationStats(BaseModel):
seconds_per_step: float
total_generation_time: float
num_inference_steps: int
num_images: int
image_width: int
image_height: int
peak_memory_usage: Memory
class BenchChatCompletionResponse(ChatCompletionResponse):
generation_stats: GenerationStats | None = None
@@ -206,10 +211,110 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_card: ModelCard
model_meta: ModelMetadata
class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId
class AdvancedImageParams(BaseModel):
seed: Annotated[int, Field(ge=0)] | None = None
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
guidance: Annotated[float, Field(ge=1.0, le=20.0)] | None = None
negative_prompt: str | None = None
class ImageGenerationTaskParams(BaseModel):
prompt: str
background: str | None = None
model: str
moderation: str | None = None
n: int | None = 1
output_compression: int | None = None
output_format: Literal["png", "jpeg", "webp"] = "png"
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
stream: bool | None = False
style: str | None = "vivid"
user: str | None = None
advanced_params: AdvancedImageParams | None = None
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
bench: bool = True
class ImageEditsTaskParams(BaseModel):
image: UploadFile
prompt: str
background: str | None = None
input_fidelity: float | None = None
mask: UploadFile | None = None
model: str
n: int | None = 1
output_compression: int | None = None
output_format: Literal["png", "jpeg", "webp"] = "png"
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
stream: bool | None = False
user: str | None = None
advanced_params: AdvancedImageParams | None = None
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
class ImageEditsInternalParams(BaseModel):
"""Serializable version of ImageEditsTaskParams for distributed task execution."""
image_data: str = "" # Base64-encoded image (empty when using chunked transfer)
total_input_chunks: int = 0
prompt: str
model: str
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
image_strength: float | None = 0.7
stream: bool = False
partial_images: int | None = 0
advanced_params: AdvancedImageParams | None = None
bench: bool = False
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} chars>"
elif name is not None:
yield name, value
class ImageData(BaseModel):
b64_json: str | None = None
url: str | None = None
revised_prompt: str | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "b64_json" and self.b64_json is not None:
yield name, f"<{len(self.b64_json)} chars>"
elif name is not None:
yield name, value
class ImageGenerationResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
data: list[ImageData]
class BenchImageGenerationResponse(ImageGenerationResponse):
generation_stats: ImageGenerationStats | None = None
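A hedged construction sketch for the new image request/response models above; every field value here is an example, not a default taken from the codebase.

params = ImageGenerationTaskParams(
    prompt="a watercolor lighthouse at dusk",
    model="qwen-image",
    n=1,
    size="1024x1024",
    stream=True,
    partial_images=3,
    advanced_params=AdvancedImageParams(seed=42, num_inference_steps=20),
)

# Responses carry base64 payloads by default (response_format="b64_json").
response = ImageGenerationResponse(data=[ImageData(b64_json="<base64 png bytes>")])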

View File

@@ -1,10 +1,13 @@
from collections.abc import Generator
from enum import Enum
from typing import Any
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats
from exo.shared.types.api import GenerationStats, ImageGenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .common import CommandId
from .models import ModelId
class ChunkType(str, Enum):
@@ -22,11 +25,38 @@ class TokenChunk(BaseChunk):
token_id: int
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
error_message: str | None = None
class ImageChunk(BaseChunk):
data: bytes
data: str
chunk_index: int
total_chunks: int
image_index: int
is_partial: bool = False
partial_index: int | None = None
total_partials: int | None = None
stats: ImageGenerationStats | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data" and hasattr(value, "__len__"):
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value
class InputImageChunk(BaseChunk):
command_id: CommandId
data: str
chunk_index: int
total_chunks: int
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data" and hasattr(value, "__len__"):
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value
GenerationChunk = TokenChunk | ImageChunk
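The string data field plus chunk_index/total_chunks imply base64 payloads are shipped in pieces; a self-contained sketch of that split (the 256 KiB chunk size and helper name are assumptions, not values from the codebase):

import base64

def split_b64(data: str, chunk_size: int = 256 * 1024) -> list[str]:
    # Slice the encoded string into fixed-size pieces; each piece would
    # populate one ImageChunk / InputImageChunk with chunk_index=i and
    # total_chunks=len(pieces), and the receiver reassembles them in order.
    return [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)]

encoded = base64.b64encode(b"\x89PNG" + b"\x00" * 1_000_000).decode("ascii")
pieces = split_b64(encoded)
assert "".join(pieces) == encoded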

View File

@@ -1,8 +1,13 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -20,8 +25,16 @@ class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
class ImageGeneration(BaseCommand):
request_params: ImageGenerationTaskParams
class ImageEdits(BaseCommand):
request_params: ImageEditsInternalParams
class PlaceInstance(BaseCommand):
model_card: ModelCard
model_meta: ModelMetadata
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
@@ -39,6 +52,12 @@ class TaskFinished(BaseCommand):
finished_command_id: CommandId
class SendInputChunk(BaseCommand):
"""Command to send an input image chunk (converted to event by master)."""
chunk: InputImageChunk
class RequestEventLog(BaseCommand):
since_idx: int
@@ -47,10 +66,13 @@ Command = (
TestCommand
| RequestEventLog
| ChatCompletion
| ImageGeneration
| ImageEdits
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskFinished
| SendInputChunk
)

View File

@@ -16,9 +16,7 @@ class Id(str):
cls, _source: type, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# Just use a plain string schema
return core_schema.no_info_after_validator_function(
cls, core_schema.str_schema()
)
return core_schema.str_schema()
class NodeId(Id):
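A small sketch of what the simplified core schema means for validation, assuming NodeId is importable from exo.shared.types.common as elsewhere in these diffs; the Example model is hypothetical.

from pydantic import BaseModel

from exo.shared.types.common import NodeId

class Example(BaseModel):
    node_id: NodeId

# With the plain str schema, the field validates as an ordinary string; the
# NodeId annotation now documents intent rather than relying on a validator
# that wraps values back into the Id subclass.
m = Example(node_id="node-a")
assert isinstance(m.node_id, str)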

View File

@@ -2,14 +2,14 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -76,15 +76,25 @@ class RunnerDeleted(BaseEvent):
runner_id: RunnerId
# TODO
class NodeCreated(BaseEvent):
node_id: NodeId
class NodeTimedOut(BaseEvent):
node_id: NodeId
# TODO: bikeshed this name
class NodeGatheredInfo(BaseEvent):
class NodePerformanceMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
info: GatheredInfo
node_profile: NodePerformanceProfile
class NodeMemoryMeasured(BaseEvent):
node_id: NodeId
when: str # this is a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile
class NodeDownloadProgress(BaseEvent):
@@ -96,12 +106,17 @@ class ChunkGenerated(BaseEvent):
chunk: GenerationChunk
class InputChunkReceived(BaseEvent):
command_id: CommandId
chunk: InputImageChunk
class TopologyEdgeCreated(BaseEvent):
conn: Connection
edge: Connection
class TopologyEdgeDeleted(BaseEvent):
conn: Connection
edge: Connection
Event = (
@@ -115,10 +130,13 @@ Event = (
| InstanceDeleted
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodeGatheredInfo
| NodePerformanceMeasured
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| TopologyEdgeCreated
| TopologyEdgeDeleted
)

View File

@@ -0,0 +1,36 @@
from enum import Enum
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
pass
class ModelTask(str, Enum):
TextGeneration = "TextGeneration"
TextToImage = "TextToImage"
ImageToImage = "ImageToImage"
class ComponentInfo(CamelCaseModel):
component_name: str
component_path: str
storage_size: Memory
n_layers: PositiveInt | None
can_shard: bool
safetensors_index_filename: str | None
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
components: list[ComponentInfo] | None = None

View File

@@ -1,11 +1,10 @@
import re
from typing import ClassVar
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
from pydantic import BaseModel, computed_field, field_validator
class Multiaddr(BaseModel):
model_config = ConfigDict(frozen=True)
address: str
PATTERNS: ClassVar[list[str]] = [

View File

@@ -1,14 +1,12 @@
from collections.abc import Sequence
from typing import Self
import psutil
from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import ThunderboltIdentifier
from exo.utils.pydantic_ext import CamelCaseModel
class MemoryUsage(CamelCaseModel):
class MemoryPerformanceProfile(CamelCaseModel):
ram_total: Memory
ram_available: Memory
swap_total: Memory
@@ -46,6 +44,7 @@ class SystemPerformanceProfile(CamelCaseModel):
sys_power: float = 0.0
pcpu_usage: float = 0.0
ecpu_usage: float = 0.0
ane_power: float = 0.0
class NetworkInterfaceInfo(CamelCaseModel):
@@ -54,12 +53,15 @@ class NetworkInterfaceInfo(CamelCaseModel):
class NodePerformanceProfile(CamelCaseModel):
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
memory: MemoryUsage = MemoryUsage.from_bytes(
ram_total=0, ram_available=0, swap_total=0, swap_available=0
)
network_interfaces: Sequence[NetworkInterfaceInfo] = []
tb_interfaces: Sequence[ThunderboltIdentifier] = []
system: SystemPerformanceProfile = SystemPerformanceProfile()
model_id: str
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile
class ConnectionProfile(CamelCaseModel):
throughput: float
latency: float
jitter: float

View File

@@ -2,7 +2,11 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.common import CommandId, Id
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
@@ -56,6 +60,22 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class ImageEdits(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageEditsInternalParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
@@ -67,5 +87,7 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| ImageGeneration
| ImageEdits
| Shutdown
)

View File

@@ -1,81 +0,0 @@
import anyio
from pydantic import BaseModel, Field
from exo.utils.pydantic_ext import CamelCaseModel
class ThunderboltConnection(CamelCaseModel):
source_uuid: str
sink_uuid: str
class ThunderboltIdentifier(CamelCaseModel):
rdma_interface: str
domain_uuid: str
## Intentionally minimal, only collecting data we care about - there's a lot more
class _ReceptacleTag(BaseModel, extra="ignore"):
receptacle_id_key: str | None = None
class _ConnectivityItem(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
class ThunderboltConnectivityData(BaseModel, extra="ignore"):
domain_uuid_key: str | None = None
items: list[_ConnectivityItem] | None = Field(None, alias="_items")
receptacle_1_tag: _ReceptacleTag | None = None
def ident(self, ifaces: dict[str, str]) -> ThunderboltIdentifier | None:
if (
self.domain_uuid_key is None
or self.receptacle_1_tag is None
or self.receptacle_1_tag.receptacle_id_key is None
):
return
tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
assert tag in ifaces # doesn't need to be an assertion but im confident
# if tag not in ifaces: return None
iface = f"rdma_{ifaces[tag]}"
return ThunderboltIdentifier(
rdma_interface=iface, domain_uuid=self.domain_uuid_key
)
def conn(self) -> ThunderboltConnection | None:
if self.domain_uuid_key is None or self.items is None:
return
sink_key = next(
(
item.domain_uuid_key
for item in self.items
if item.domain_uuid_key is not None
),
None,
)
if sink_key is None:
return None
return ThunderboltConnection(
source_uuid=self.domain_uuid_key, sink_uuid=sink_key
)
class ThunderboltConnectivity(BaseModel, extra="ignore"):
SPThunderboltDataType: list[ThunderboltConnectivityData] = []
@classmethod
async def gather(cls) -> list[ThunderboltConnectivityData] | None:
proc = await anyio.run_process(
["system_profiler", "SPThunderboltDataType", "-json"], check=False
)
if proc.returncode != 0:
return None
# Saving you from PascalCase while avoiding too much pydantic
return ThunderboltConnectivity.model_validate_json(
proc.stdout
).SPThunderboltDataType

View File

@@ -1,41 +1,37 @@
from collections.abc import Iterator
from dataclasses import dataclass
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.utils.pydantic_ext import FrozenModel
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel
@dataclass(frozen=True)
class Cycle:
node_ids: list[NodeId]
def __len__(self) -> int:
return self.node_ids.__len__()
def __iter__(self) -> Iterator[NodeId]:
return self.node_ids.__iter__()
class NodeInfo(CamelCaseModel):
node_id: NodeId
node_profile: NodePerformanceProfile | None = None
class RDMAConnection(FrozenModel):
source_rdma_iface: str
sink_rdma_iface: str
class Connection(CamelCaseModel):
local_node_id: NodeId
send_back_node_id: NodeId
send_back_multiaddr: Multiaddr
connection_profile: ConnectionProfile | None = None
def __hash__(self) -> int:
return hash(
(
self.local_node_id,
self.send_back_node_id,
self.send_back_multiaddr.address,
)
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise ValueError("Cannot compare Connection with non-Connection")
return (
self.local_node_id == other.local_node_id
and self.send_back_node_id == other.send_back_node_id
and self.send_back_multiaddr == other.send_back_multiaddr
)
def is_thunderbolt(self) -> bool:
return True
class SocketConnection(FrozenModel):
sink_multiaddr: Multiaddr
def __hash__(self):
return hash(self.sink_multiaddr.ip_address)
def is_thunderbolt(self) -> bool:
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
class Connection(FrozenModel):
source: NodeId
sink: NodeId
edge: RDMAConnection | SocketConnection
return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
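A quick sketch of the equality semantics defined above; the multiaddr is illustrative and the import paths mirror the ones used in these diffs.

from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.topology import Connection

conn = Connection(
    local_node_id=NodeId("node-a"),
    send_back_node_id=NodeId("node-b"),
    send_back_multiaddr=Multiaddr(address="/ip4/169.254.0.2/tcp/10001"),
)
same_link = Connection(
    local_node_id=NodeId("node-a"),
    send_back_node_id=NodeId("node-b"),
    send_back_multiaddr=Multiaddr(address="/ip4/169.254.0.2/tcp/10001"),
)
# __eq__ and __hash__ ignore connection_profile, so re-announcing an existing
# link deduplicates instead of creating a parallel edge.
assert conn == same_link and len({conn, same_link}) == 1
# Link-local 169.254.* multiaddrs are treated as Thunderbolt bridges.
assert conn.is_thunderbolt()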

View File

@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):
class MlxJacclInstance(BaseInstance):
jaccl_devices: list[list[str | None]]
ibv_devices: list[list[str | None]]
jaccl_coordinators: dict[NodeId, str]

View File

@@ -0,0 +1,43 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Callable
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
SystemPerformanceProfile,
)
class ResourceCollector(ABC):
@abstractmethod
async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ...
class SystemResourceCollector(ResourceCollector):
async def collect(self) -> SystemPerformanceProfile: ...
class MemoryResourceCollector(ResourceCollector):
async def collect(self) -> MemoryPerformanceProfile: ...
class ResourceMonitor:
data_collectors: list[ResourceCollector]
effect_handlers: set[
Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None]
]
async def _collect(
self,
) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]:
tasks: list[
Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile]
] = [collector.collect() for collector in self.data_collectors]
return await asyncio.gather(*tasks)
async def collect(self) -> None:
profiles = await self._collect()
for profile in profiles:
for effect_handler in self.effect_handlers:
effect_handler(profile)
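A rough wiring sketch of the collector/handler fan-out above. ResourceMonitor's constructor and module path are not shown in the hunk, so the direct attribute assignment and the assumption that SystemPerformanceProfile still default-constructs are both illustrative.

import asyncio

from exo.shared.types.profiling import SystemPerformanceProfile

# Assumes SystemResourceCollector and ResourceMonitor are imported from the
# new module introduced above.
class StubSystemCollector(SystemResourceCollector):
    async def collect(self) -> SystemPerformanceProfile:
        # Illustrative stub; a real collector would read macmon or psutil.
        return SystemPerformanceProfile()

monitor = ResourceMonitor()
monitor.data_collectors = [StubSystemCollector()]
monitor.effect_handlers = {lambda profile: print(type(profile).__name__)}

asyncio.run(monitor.collect())  # prints "SystemPerformanceProfile"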

View File

@@ -1,4 +1,7 @@
from exo.shared.types.api import FinishReason, GenerationStats
from collections.abc import Generator
from typing import Any, Literal
from exo.shared.types.api import FinishReason, GenerationStats, ImageGenerationStats
from exo.utils.pydantic_ext import TaggedModel
@@ -18,5 +21,32 @@ class GenerationResponse(BaseRunnerResponse):
stats: GenerationStats | None = None
class ImageGenerationResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
stats: ImageGenerationStats | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} bytes>"
elif name is not None:
yield name, value
class PartialImageResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
partial_index: int
total_partials: int
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} bytes>"
elif name is not None:
yield name, value
class FinishedResponse(BaseRunnerResponse):
pass

View File

@@ -2,8 +2,8 @@ from collections.abc import Mapping
from pydantic import model_validator
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.models import ModelMetadata
from exo.utils.pydantic_ext import TaggedModel
@@ -17,7 +17,7 @@ class BaseShardMetadata(TaggedModel):
Replaces previous `Shard` object.
"""
model_card: ModelCard
model_meta: ModelMetadata
device_rank: int
world_size: int
@@ -41,7 +41,7 @@ class BaseShardMetadata(TaggedModel):
def __hash__(self) -> int:
return hash(
(
self.model_card.model_id,
self.model_meta.model_id,
self.start_layer,
self.end_layer,
self.n_layers,

View File

@@ -1,235 +0,0 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast
import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import (
ThunderboltConnection,
ThunderboltConnectivity,
ThunderboltIdentifier,
)
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel
from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces
IS_DARWIN = sys.platform == "darwin"
class StaticNodeInformation(TaggedModel):
"""Node information that should NEVER change, to be gathered once at startup"""
model: str
chip: str
@classmethod
async def gather(cls) -> Self:
model, chip = await get_model_and_chip()
return cls(model=model, chip=chip)
class NodeNetworkInterfaces(TaggedModel):
ifaces: Sequence[NetworkInterfaceInfo]
class MacThunderboltIdentifiers(TaggedModel):
idents: Sequence[ThunderboltIdentifier]
class MacThunderboltConnections(TaggedModel):
conns: Sequence[ThunderboltConnection]
class NodeConfig(TaggedModel):
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
@classmethod
async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE)
await cfg_file.touch(exist_ok=True)
async with await cfg_file.open("rb") as f:
try:
contents = (await f.read()).decode("utf-8")
data = tomllib.loads(contents)
return cls.model_validate(data)
except (tomllib.TOMLDecodeError, UnicodeDecodeError):
logger.warning("Invalid config file, skipping...")
return None
class MiscData(TaggedModel):
"""Node information that may slowly change that doesn't fall into the other categories"""
friendly_name: str
@classmethod
async def gather(cls) -> Self:
return cls(friendly_name=await get_friendly_name())
async def _gather_iface_map() -> dict[str, str] | None:
proc = await anyio.run_process(
["networksetup", "-listallhardwareports"], check=False
)
if proc.returncode != 0:
return None
ports: dict[str, str] = {}
port = ""
for line in proc.stdout.decode("utf-8").split("\n"):
if line.startswith("Hardware Port:"):
port = line.split(": ")[1]
elif line.startswith("Device:"):
ports[port] = line.split(": ")[1]
port = ""
if "" in ports:
del ports[""]
return ports
GatheredInfo = (
MacmonMetrics
| MemoryUsage
| NodeNetworkInterfaces
| MacThunderboltIdentifiers
| MacThunderboltConnections
| NodeConfig
| MiscData
| StaticNodeInformation
)
@dataclass
class InfoGatherer:
info_sender: Sender[GatheredInfo]
interface_watcher_interval: float | None = 10
misc_poll_interval: float | None = 60
system_profiler_interval: float | None = 5 if IS_DARWIN else None
memory_poll_rate: float | None = None if IS_DARWIN else 1
macmon_interval: float | None = 1 if IS_DARWIN else None
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
async def run(self):
async with self._tg as tg:
if IS_DARWIN:
if (macmon_path := shutil.which("macmon")) is not None:
tg.start_soon(self._monitor_macmon, macmon_path)
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
tg.start_soon(self._watch_system_info)
tg.start_soon(self._monitor_memory_usage)
tg.start_soon(self._monitor_misc)
nc = await NodeConfig.gather()
if nc is not None:
await self.info_sender.send(nc)
sni = await StaticNodeInformation.gather()
await self.info_sender.send(sni)
def shutdown(self):
self._tg.cancel_scope.cancel()
async def _monitor_misc(self):
if self.misc_poll_interval is None:
return
prev = await MiscData.gather()
await self.info_sender.send(prev)
while True:
curr = await MiscData.gather()
if prev != curr:
prev = curr
await self.info_sender.send(curr)
await anyio.sleep(self.misc_poll_interval)
async def _monitor_system_profiler_thunderbolt_data(self):
if self.system_profiler_interval is None:
return
iface_map = await _gather_iface_map()
if iface_map is None:
return
old_idents = []
while True:
data = await ThunderboltConnectivity.gather()
assert data is not None
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
if idents != old_idents:
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
old_idents = idents
conns = [it for i in data if (it := i.conn()) is not None]
await self.info_sender.send(MacThunderboltConnections(conns=conns))
await anyio.sleep(self.system_profiler_interval)
async def _monitor_memory_usage(self):
override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
override_memory: int | None = (
Memory.from_mb(int(override_memory_env)).in_bytes
if override_memory_env
else None
)
if self.memory_poll_rate is None:
return
while True:
await self.info_sender.send(
MemoryUsage.from_psutil(override_memory=override_memory)
)
await anyio.sleep(self.memory_poll_rate)
async def _watch_system_info(self):
if self.interface_watcher_interval is None:
return
old_nics = []
while True:
nics = get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
await anyio.sleep(self.interface_watcher_interval)
async def _monitor_macmon(self, macmon_path: str):
if self.macmon_interval is None:
return
# macmon pipe --interval [interval in ms]
try:
async with await open_process(
[macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
) as p:
if not p.stdout:
logger.critical("MacMon closed stdout")
return
async for text in TextReceiveStream(
BufferedByteReceiveStream(p.stdout)
):
await self.info_sender.send(MacmonMetrics.from_raw_json(text))
except CalledProcessError as e:
stderr_msg = "no stderr"
stderr_output = cast(bytes | str | None, e.stderr)
if stderr_output is not None:
stderr_msg = (
stderr_output.decode()
if isinstance(stderr_output, bytes)
else str(stderr_output)
)
logger.warning(
f"MacMon failed with return code {e.returncode}: {stderr_msg}"
)

View File

@@ -1,70 +0,0 @@
from typing import Self
from pydantic import BaseModel
from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import TaggedModel
class _TempMetrics(BaseModel, extra="ignore"):
"""Temperature-related metrics returned by macmon."""
cpu_temp_avg: float
gpu_temp_avg: float
class _MemoryMetrics(BaseModel, extra="ignore"):
"""Memory-related metrics returned by macmon."""
ram_total: int
ram_usage: int
swap_total: int
swap_usage: int
class RawMacmonMetrics(BaseModel, extra="ignore"):
"""Complete set of metrics returned by macmon.
Unknown fields are ignored for forward-compatibility.
"""
timestamp: str # ignored
temp: _TempMetrics
memory: _MemoryMetrics
ecpu_usage: tuple[int, float] # freq mhz, usage %
pcpu_usage: tuple[int, float] # freq mhz, usage %
gpu_usage: tuple[int, float] # freq mhz, usage %
all_power: float
ane_power: float
cpu_power: float
gpu_power: float
gpu_ram_power: float
ram_power: float
sys_power: float
class MacmonMetrics(TaggedModel):
system_profile: SystemPerformanceProfile
memory: MemoryUsage
@classmethod
def from_raw(cls, raw: RawMacmonMetrics) -> Self:
return cls(
system_profile=SystemPerformanceProfile(
gpu_usage=raw.gpu_usage[1],
temp=raw.temp.gpu_temp_avg,
sys_power=raw.sys_power,
pcpu_usage=raw.pcpu_usage[1],
ecpu_usage=raw.ecpu_usage[1],
),
memory=MemoryUsage.from_bytes(
ram_total=raw.memory.ram_total,
ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
swap_total=raw.memory.swap_total,
swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
),
)
@classmethod
def from_raw_json(cls, json: str) -> Self:
return cls.from_raw(RawMacmonMetrics.model_validate_json(json))
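For reference, a raw macmon sample maps onto the models above roughly as follows. All numbers are illustrative, and the snippet assumes the classes above are in scope (the module is removed in this change, so no import path is shown):

sample = """{
  "timestamp": "2026-01-19T12:00:00Z",
  "temp": {"cpu_temp_avg": 48.0, "gpu_temp_avg": 52.0},
  "memory": {"ram_total": 68719476736, "ram_usage": 30064771072,
             "swap_total": 2147483648, "swap_usage": 0},
  "ecpu_usage": [1200, 12.5],
  "pcpu_usage": [3200, 40.0],
  "gpu_usage": [1400, 35.0],
  "all_power": 20.0, "ane_power": 0.0, "cpu_power": 6.0, "gpu_power": 9.0,
  "gpu_ram_power": 1.0, "ram_power": 2.0, "sys_power": 22.0
}"""
metrics = MacmonMetrics.from_raw_json(sample)
# gpu_usage is a (freq_mhz, percent) pair; only the percentage is kept.
metrics.system_profile.gpu_usage  # 35.0
# ram_available is derived as ram_total - ram_usage before MemoryUsage.from_bytes.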

View File

@@ -1,114 +0,0 @@
from collections.abc import Mapping
import anyio
import httpx
from anyio import create_task_group
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile
REACHABILITY_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
out: dict[NodeId, set[str]],
client: httpx.AsyncClient,
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
if ":" in target_ip:
# TODO: use real IpAddress types
url = f"http://[{target_ip}]:52415/node_id"
else:
url = f"http://{target_ip}:52415/node_id"
remote_node_id = None
last_error = None
for _ in range(REACHABILITY_ATTEMPTS):
try:
r = await client.get(url)
if r.status_code != 200:
await anyio.sleep(1)
continue
body = r.text.strip().strip('"')
if not body:
await anyio.sleep(1)
continue
remote_node_id = NodeId(body)
break
# expected failure cases
except (
httpx.TimeoutException,
httpx.NetworkError,
):
await anyio.sleep(1)
# other failures should be logged on last attempt
except httpx.HTTPError as e:
last_error = e
await anyio.sleep(1)
if last_error is not None:
logger.warning(
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
if remote_node_id is None:
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
f"ip={target_ip}, expected_node_id={expected_node_id}, "
f"remote_node_id={remote_node_id}"
)
return
if remote_node_id not in out:
out[remote_node_id] = set()
out[remote_node_id].add(target_ip)
async def check_reachable(
topology: Topology,
self_node_id: NodeId,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
# these are intentionally httpx's defaults so we can tune them later
timeout = httpx.Timeout(timeout=5.0)
limits = httpx.Limits(
max_connections=100,
max_keepalive_connections=20,
keepalive_expiry=5,
)
async with (
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
for node_id in topology.list_nodes():
if node_id not in node_profiles:
continue
if node_id == self_node_id:
continue
for iface in node_profiles[node_id].network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node_id,
reachable,
client,
)
return reachable
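A minimal usage sketch of the per-IP helper above; the address and node id are hypothetical, and check_reachability is assumed to be in scope from this module:

import anyio
import httpx
from exo.shared.types.common import NodeId

async def probe() -> None:
    out: dict[NodeId, set[str]] = {}
    async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=5.0)) as client:
        await check_reachability("192.168.1.20", NodeId("node-a"), out, client)
    # Populated only if the node answered /node_id with the expected id;
    # timeouts, HTTP errors, and id mismatches leave `out` unchanged.
    print(out)

anyio.run(probe)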

View File

@@ -1,24 +0,0 @@
import sys
import pytest
from exo.shared.types.thunderbolt import (
ThunderboltConnectivity,
)
from exo.utils.info_gatherer.info_gatherer import (
_gather_iface_map, # pyright: ignore[reportPrivateUsage]
)
@pytest.mark.anyio
@pytest.mark.skipif(
sys.platform != "darwin", reason="Thunderbolt info can only be gathered on macos"
)
async def test_tb_parsing():
data = await ThunderboltConnectivity.gather()
ifaces = await _gather_iface_map()
assert ifaces
assert data
for datum in data:
datum.ident(ifaces)
datum.conn()

View File

@@ -19,20 +19,11 @@ class CamelCaseModel(BaseModel):
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
# I want to reenable this ASAP, but it's causing an issue with TaskStatus
strict=True,
)
class FrozenModel(BaseModel):
model_config = ConfigDict(
alias_generator=to_camel,
validate_by_name=True,
extra="forbid",
strict=True,
frozen=True,
)
class TaggedModel(CamelCaseModel):
@model_serializer(mode="wrap")
def _serialize(self, handler: SerializerFunctionWrapHandler):

View File

@@ -28,8 +28,9 @@ def bar(send: MpSender[str]):
send.close()
# not async, just want the fail_after
@pytest.mark.anyio
async def test_channel_ipc():
async def test_channel_setup():
with fail_after(0.5):
s, r = mp_channel[str]()
p1 = mp.Process(target=foo, args=(r,))

View File

@@ -5,11 +5,11 @@ import shutil
import ssl
import time
import traceback
from collections.abc import Awaitable
from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal
from urllib.parse import urljoin
from huggingface_hub._snapshot_download import snapshot_download
import aiofiles
import aiofiles.os as aios
@@ -246,15 +246,12 @@ def create_http_session(
sock_read_timeout = 1800
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(
cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
)
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
connect=connect_timeout,
@@ -445,12 +442,31 @@ def calculate_repo_progress(
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_file = await download_file_with_retry(
repo_id, revision, "model.safetensors.index.json", target_dir
index_files_dir = snapshot_download(
repo_id=repo_id, local_dir=target_dir, allow_patterns="*.safetensors.index.json"
)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
return index_data.weight_map
index_files = list(Path(index_files_dir).glob("**/*.safetensors.index.json"))
weight_map: dict[str, str] = {}
for index_file in index_files:
relative_dir = index_file.parent.relative_to(index_files_dir)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
if relative_dir != Path("."):
prefixed_weight_map = {
f"{relative_dir}/{key}": str(relative_dir / value)
for key, value in index_data.weight_map.items()
}
weight_map = weight_map | prefixed_weight_map
else:
weight_map = weight_map | index_data.weight_map
return weight_map
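# Example of the namespacing above (hypothetical component layout): an index file at
#   transformer/diffusion_pytorch_model.safetensors.index.json
# whose weight_map contains
#   "blocks.0.attn.to_q.weight" -> "diffusion_pytorch_model-00001-of-00002.safetensors"
# ends up in the merged map as
#   "transformer/blocks.0.attn.to_q.weight" -> "transformer/diffusion_pytorch_model-00001-of-00002.safetensors"
# while a top-level model.safetensors.index.json is merged without any prefix.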
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
@@ -460,10 +476,10 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_card.model_id))
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)
except Exception:
logger.error(f"Error getting weight map for {shard.model_card.model_id=}")
logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
logger.error(traceback.format_exc())
return ["*"]
@@ -526,24 +542,24 @@ async def download_progress_for_local_path(
async def download_shard(
shard: ShardMetadata,
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], None],
max_parallel_downloads: int = 8,
skip_download: bool = False,
allow_patterns: list[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=}")
logger.info(f"Downloading {shard.model_meta.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_card.model_id)):
logger.info(f"Using local model path {shard.model_card.model_id}")
local_path = Path(str(shard.model_card.model_id))
if await aios.path.exists(str(shard.model_meta.model_id)):
logger.info(f"Using local model path {shard.model_meta.model_id}")
local_path = Path(str(shard.model_meta.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_card.model_id), shard, local_path
str(shard.model_meta.model_id), shard, local_path
)
revision = "main"
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
target_dir = await ensure_models_dir() / str(shard.model_meta.model_id).replace(
"/", "--"
)
if not skip_download:
@@ -552,13 +568,11 @@ async def download_shard(
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_card.model_id), revision, recursive=True
str(shard.model_meta.model_id), revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -567,9 +581,9 @@ async def download_shard(
)
file_progress: dict[str, RepoFileDownloadProgress] = {}
async def on_progress_wrapper(
def on_progress_wrapper(
file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
) -> None:
):
start_time = (
file_progress[file.path].start_time
if file.path in file_progress
@@ -592,7 +606,7 @@ async def download_shard(
else timedelta(seconds=0)
)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=str(shard.model_meta.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(curr_bytes),
@@ -605,11 +619,11 @@ async def download_shard(
else "in_progress",
start_time=start_time,
)
await on_progress(
on_progress(
shard,
calculate_repo_progress(
shard,
str(shard.model_card.model_id),
str(shard.model_meta.model_id),
revision,
file_progress,
all_start_time,
@@ -619,7 +633,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=str(shard.model_meta.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(downloaded_bytes),
@@ -633,21 +647,14 @@ async def download_shard(
semaphore = asyncio.Semaphore(max_parallel_downloads)
def schedule_progress(
file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
) -> None:
asyncio.create_task(
on_progress_wrapper(file, curr_bytes, total_bytes, is_renamed)
)
async def download_with_semaphore(file: FileListEntry) -> None:
async def download_with_semaphore(file: FileListEntry):
async with semaphore:
await download_file_with_retry(
str(shard.model_card.model_id),
str(shard.model_meta.model_id),
revision,
file.path,
target_dir,
lambda curr_bytes, total_bytes, is_renamed: schedule_progress(
lambda curr_bytes, total_bytes, is_renamed: on_progress_wrapper(
file, curr_bytes, total_bytes, is_renamed
),
)
@@ -657,9 +664,9 @@ async def download_shard(
*[download_with_semaphore(file) for file in filtered_file_list]
)
final_repo_progress = calculate_repo_progress(
shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time
)
await on_progress(shard, final_repo_progress)
on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):
return target_dir / gguf.path, final_repo_progress
else:

View File

@@ -100,26 +100,68 @@ def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list
"*.py",
"tokenizer.model",
"tiktoken.model",
"*/spiece.model",
"*.tiktoken",
"*.txt",
"*.jinja",
]
)
shard_specific_patterns: set[str] = set()
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num <= shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
if shard.model_meta.components is not None:
shardable_component = next(
(c for c in shard.model_meta.components if c.can_shard), None
)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
if weight_map and shardable_component:
for tensor_name, filename in weight_map.items():
# Strip component prefix from tensor name (added by weight map namespacing)
# E.g., "transformer/blocks.0.weight" -> "blocks.0.weight"
if "/" in tensor_name:
_, tensor_name_no_prefix = tensor_name.split("/", 1)
else:
tensor_name_no_prefix = tensor_name
# Determine which component this file belongs to from filename
component_path = Path(filename).parts[0] if "/" in filename else None
if component_path == shardable_component.component_path.rstrip("/"):
layer_num = extract_layer_num(tensor_name_no_prefix)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
if shard.is_first_layer or shard.is_last_layer:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns = set(["*.safetensors"])
# TODO(ciaran): temporary - Include all files from non-shardable components that have no index file
for component in shard.model_meta.components:
if not component.can_shard and component.safetensors_index_filename is None:
component_pattern = f"{component.component_path.rstrip('/')}/*"
shard_specific_patterns.add(component_pattern)
else:
shard_specific_patterns = set(["*.safetensors"])
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
else:
shard_specific_patterns = set(["*.safetensors"])
logger.info(f"get_allow_patterns {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)
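The legacy (no-components) branch above reduces to a simple rule: keep every file holding at least one tensor whose layer number falls in [start_layer, end_layer), plus every file whose tensors carry no layer number. A self-contained sketch of that rule, with a simplified stand-in for the project's extract_layer_num and a hypothetical weight map:

import re

def extract_layer_num(tensor_name: str) -> int | None:
    # Simplified stand-in for the project helper.
    m = re.search(r"layers\.(\d+)\.", tensor_name)
    return int(m.group(1)) if m else None

weight_map = {
    "model.layers.0.mlp.weight": "model-00001.safetensors",
    "model.layers.12.mlp.weight": "model-00002.safetensors",
    "model.embed_tokens.weight": "model-00001.safetensors",
}
start_layer, end_layer = 0, 8

keep = {f for t, f in weight_map.items()
        if (n := extract_layer_num(t)) is not None and start_layer <= n < end_layer}
keep |= {f for t, f in weight_map.items() if extract_layer_num(t) is None}
print(sorted(keep))  # ['model-00001.safetensors']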

View File

@@ -1,10 +1,9 @@
import asyncio
from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -20,21 +19,21 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: str) -> ShardMetadata:
model_card = await get_model_card(model_id)
model_meta = await get_model_meta(model_id)
return PipelineShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=model_card.n_layers,
n_layers=model_card.n_layers,
end_layer=model_meta.n_layers,
n_layers=model_meta.n_layers,
)
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_card=base_shard.model_card,
model_meta=base_shard.model_meta,
device_rank=base_shard.device_rank,
world_size=base_shard.world_size,
start_layer=base_shard.start_layer,
@@ -49,8 +48,7 @@ class SingletonShardDownloader(ShardDownloader):
self.active_downloads: dict[ShardMetadata, asyncio.Task[Path]] = {}
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
) -> None:
self.shard_downloader.on_progress(callback)
@@ -85,19 +83,18 @@ class CachedShardDownloader(ShardDownloader):
self.cache: dict[tuple[str, ShardMetadata], Path] = {}
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
) -> None:
self.shard_downloader.on_progress(callback)
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False
) -> Path:
if (shard.model_card.model_id, shard) in self.cache:
return self.cache[(shard.model_card.model_id, shard)]
if (shard.model_meta.model_id, shard) in self.cache:
return self.cache[(shard.model_meta.model_id, shard)]
target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
self.cache[(shard.model_card.model_id, shard)] = target_dir
self.cache[(shard.model_meta.model_id, shard)] = target_dir
return target_dir
async def get_shard_download_status(
@@ -116,18 +113,17 @@ class ResumableShardDownloader(ShardDownloader):
def __init__(self, max_parallel_downloads: int = 8):
self.max_parallel_downloads = max_parallel_downloads
self.on_progress_callbacks: list[
Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]]
Callable[[ShardMetadata, RepoDownloadProgress], None]
] = []
async def on_progress_wrapper(
def on_progress_wrapper(
self, shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
for callback in self.on_progress_callbacks:
await callback(shard, progress)
callback(shard, progress)
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
) -> None:
self.on_progress_callbacks.append(callback)

View File

@@ -1,12 +1,11 @@
from abc import ABC, abstractmethod
from collections.abc import Awaitable
from copy import copy
from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -32,8 +31,7 @@ class ShardDownloader(ABC):
@abstractmethod
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
) -> None:
pass
@@ -61,8 +59,7 @@ class NoopShardDownloader(ShardDownloader):
return Path("/tmp/noop_shard")
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
) -> None:
pass
@@ -86,8 +83,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,

View File

@@ -0,0 +1,12 @@
from exo.worker.engines.image.distributed_model import (
DistributedImageModel,
initialize_image_model,
)
from exo.worker.engines.image.generate import generate_image, warmup_image_generator
__all__ = [
"DistributedImageModel",
"generate_image",
"initialize_image_model",
"warmup_image_generator",
]

View File

@@ -0,0 +1,50 @@
from enum import Enum
from math import ceil
from pydantic import BaseModel
class BlockType(Enum):
JOINT = "joint" # Separate image/text streams
SINGLE = "single" # Concatenated streams
class TransformerBlockConfig(BaseModel):
model_config = {"frozen": True}
block_type: BlockType
count: int
has_separate_text_output: bool # True for joint blocks that output text separately
class ImageModelConfig(BaseModel):
model_family: str
block_configs: tuple[TransformerBlockConfig, ...]
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
num_sync_steps_factor: float # Fraction of steps for sync phase
guidance_scale: float | None = None # None or <= 1.0 disables CFG
@property
def total_blocks(self) -> int:
return sum(bc.count for bc in self.block_configs)
@property
def joint_block_count(self) -> int:
return sum(
bc.count for bc in self.block_configs if bc.block_type == BlockType.JOINT
)
@property
def single_block_count(self) -> int:
return sum(
bc.count for bc in self.block_configs if bc.block_type == BlockType.SINGLE
)
def get_steps_for_quality(self, quality: str) -> int:
return self.default_steps[quality]
def get_num_sync_steps(self, steps: int) -> int:
return ceil(steps * self.num_sync_steps_factor)
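A quick illustration of the helpers above; the values mirror the Flux schnell defaults defined later in this change, and the import path matches the one used by the adapters below:

from exo.worker.engines.image.config import (
    BlockType,
    ImageModelConfig,
    TransformerBlockConfig,
)

cfg = ImageModelConfig(
    model_family="flux",
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=19, has_separate_text_output=True
        ),
        TransformerBlockConfig(
            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
        ),
    ),
    default_steps={"low": 1, "medium": 2, "high": 4},
    num_sync_steps_factor=0.5,
)
cfg.total_blocks                             # 57
steps = cfg.get_steps_for_quality("medium")  # 2
cfg.get_num_sync_steps(steps)                # ceil(2 * 0.5) == 1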

View File

@@ -0,0 +1,166 @@
from collections.abc import Generator
from pathlib import Path
from typing import Literal, Optional
import mlx.core as mx
from mflux.models.common.config.config import Config
from PIL import Image
from exo.shared.types.api import AdvancedImageParams
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
create_adapter_for_model,
get_config_for_model,
)
from exo.worker.engines.image.models.base import ModelAdapter
from exo.worker.engines.image.pipeline import DiffusionRunner
from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier
from exo.worker.runner.bootstrap import logger
class DistributedImageModel:
_config: ImageModelConfig
_adapter: ModelAdapter
_runner: DiffusionRunner
def __init__(
self,
model_id: str,
local_path: Path,
shard_metadata: PipelineShardMetadata,
group: Optional[mx.distributed.Group] = None,
quantize: int | None = None,
):
config = get_config_for_model(model_id)
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
if group is not None:
adapter.slice_transformer_blocks(
start_layer=shard_metadata.start_layer,
end_layer=shard_metadata.end_layer,
)
runner = DiffusionRunner(
config=config,
adapter=adapter,
group=group,
shard_metadata=shard_metadata,
)
if group is not None:
logger.info("Initialized distributed diffusion runner")
mx.eval(adapter.model.parameters())
# TODO(ciaran): Do we need this?
mx.eval(adapter.model)
mx_barrier(group)
logger.info(f"Transformer sharded for rank {group.rank()}")
else:
logger.info("Single-node initialization")
self._config = config
self._adapter = adapter
self._runner = runner
@classmethod
def from_bound_instance(
cls, bound_instance: BoundInstance
) -> "DistributedImageModel":
model_id = bound_instance.bound_shard.model_meta.model_id
model_path = build_model_path(model_id)
shard_metadata = bound_instance.bound_shard
if not isinstance(shard_metadata, PipelineShardMetadata):
raise ValueError("Expected PipelineShardMetadata for image generation")
is_distributed = (
len(bound_instance.instance.shard_assignments.node_to_runner) > 1
)
if is_distributed:
logger.info("Starting distributed init for image model")
group = mlx_distributed_init(bound_instance)
else:
group = None
return cls(
model_id=model_id,
local_path=model_path,
shard_metadata=shard_metadata,
group=group,
)
def get_steps_for_quality(self, quality: Literal["low", "medium", "high"]) -> int:
"""Get the number of inference steps for a quality level."""
return self._config.get_steps_for_quality(quality)
def generate(
self,
prompt: str,
height: int,
width: int,
quality: Literal["low", "medium", "high"] = "medium",
seed: int = 2,
image_path: Path | None = None,
partial_images: int = 0,
advanced_params: AdvancedImageParams | None = None,
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
if (
advanced_params is not None
and advanced_params.num_inference_steps is not None
):
steps = advanced_params.num_inference_steps
else:
steps = self._config.get_steps_for_quality(quality)
guidance_override: float | None = None
if advanced_params is not None and advanced_params.guidance is not None:
guidance_override = advanced_params.guidance
negative_prompt: str | None = None
if advanced_params is not None and advanced_params.negative_prompt is not None:
negative_prompt = advanced_params.negative_prompt
# For edit mode: compute dimensions from input image
# This also stores image_paths in the adapter for encode_prompt()
if image_path is not None:
computed_dims = self._adapter.set_image_dimensions(image_path)
if computed_dims is not None:
# Override user-provided dimensions with computed ones
width, height = computed_dims
config = Config(
num_inference_steps=steps,
height=height,
width=width,
image_path=image_path,
model_config=self._adapter.model.model_config,
)
num_sync_steps = self._config.get_num_sync_steps(steps)
for result in self._runner.generate_image(
runtime_config=config,
prompt=prompt,
seed=seed,
partial_images=partial_images,
guidance_override=guidance_override,
negative_prompt=negative_prompt,
num_sync_steps=num_sync_steps,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
yield (image, partial_idx, total_partials)
else:
logger.info("generated image")
yield result
def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:
return DistributedImageModel.from_bound_instance(bound_instance)

View File

@@ -0,0 +1,170 @@
import base64
import io
import random
import tempfile
import time
from pathlib import Path
from typing import Generator, Literal
import mlx.core as mx
from PIL import Image
from exo.shared.types.api import (
AdvancedImageParams,
ImageEditsInternalParams,
ImageGenerationStats,
ImageGenerationTaskParams,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.runner_response import (
ImageGenerationResponse,
PartialImageResponse,
)
from exo.worker.engines.image.distributed_model import DistributedImageModel
def parse_size(size_str: str | None) -> tuple[int, int]:
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
if not size_str or size_str == "auto":
size_str = "1024x1024"
try:
parts = size_str.split("x")
if len(parts) == 2:
width, height = int(parts[0]), int(parts[1])
return (width, height)
except (ValueError, AttributeError):
pass
# Default fallback
return (1024, 1024)
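A quick check of the fallback behaviour (inputs are illustrative):

parse_size("1024x768")  # (1024, 768)
parse_size("auto")      # (1024, 1024)
parse_size(None)        # (1024, 1024)
parse_size("bogus")     # (1024, 1024) -- falls through to the default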
def warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:
"""Warmup the image generator with a small image."""
with tempfile.TemporaryDirectory() as tmpdir:
# Create a small dummy image for warmup (needed for edit models)
dummy_image = Image.new("RGB", (256, 256), color=(128, 128, 128))
dummy_path = Path(tmpdir) / "warmup.png"
dummy_image.save(dummy_path)
warmup_params = AdvancedImageParams(num_inference_steps=2)
for result in model.generate(
prompt="Warmup",
height=256,
width=256,
quality="low",
image_path=dummy_path,
advanced_params=warmup_params,
):
if not isinstance(result, tuple):
return result
return None
def generate_image(
model: DistributedImageModel,
task: ImageGenerationTaskParams | ImageEditsInternalParams,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
"""Generate image(s), optionally yielding partial results.
When partial_images > 0 or stream=True, yields PartialImageResponse for
intermediate images, then ImageGenerationResponse for the final image.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0)
ImageGenerationResponse for the final complete image
"""
width, height = parse_size(task.size)
quality: Literal["low", "medium", "high"] = task.quality or "medium"
advanced_params = task.advanced_params
if advanced_params is not None and advanced_params.seed is not None:
seed = advanced_params.seed
else:
seed = random.randint(0, 2**32 - 1)
is_bench = getattr(task, "bench", False)
generation_start_time: float = 0.0
if is_bench:
mx.reset_peak_memory()
generation_start_time = time.perf_counter()
partial_images = task.partial_images or (3 if task.stream else 0)
image_path: Path | None = None
with tempfile.TemporaryDirectory() as tmpdir:
if isinstance(task, ImageEditsInternalParams):
# Decode base64 image data and save to temp file
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
# Iterate over generator results
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
image.save(buffer, format=image_format)
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
)
else:
image = result
stats: ImageGenerationStats | None = None
if is_bench:
generation_end_time = time.perf_counter()
total_generation_time = generation_end_time - generation_start_time
num_inference_steps = model.get_steps_for_quality(quality)
seconds_per_step = (
total_generation_time / num_inference_steps
if num_inference_steps > 0
else 0.0
)
peak_memory_gb = mx.get_peak_memory() / (1024**3)
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=task.n or 1,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
)
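A minimal consumer sketch for the generator above. model and task are assumed to already exist (a DistributedImageModel and an ImageGenerationTaskParams); their construction is out of scope here:

from exo.shared.types.worker.runner_response import PartialImageResponse
from exo.worker.engines.image.generate import generate_image

for response in generate_image(model, task):
    if isinstance(response, PartialImageResponse):
        print(f"partial {response.partial_index + 1}/{response.total_partials}")
    else:
        # Final ImageGenerationResponse; stats is populated only for bench runs.
        with open(f"image.{response.format}", "wb") as out:
            out.write(response.image_data)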

View File

@@ -0,0 +1,84 @@
from pathlib import Path
from typing import Callable
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter
from exo.worker.engines.image.models.flux import (
FLUX_DEV_CONFIG,
FLUX_SCHNELL_CONFIG,
FluxModelAdapter,
)
from exo.worker.engines.image.models.qwen import (
QWEN_IMAGE_CONFIG,
QWEN_IMAGE_EDIT_CONFIG,
QwenEditModelAdapter,
QwenModelAdapter,
)
__all__: list[str] = []
# Type alias for adapter factory functions
# Factory takes (config, model_id, local_path, quantize) and returns a ModelAdapter
AdapterFactory = Callable[[ImageModelConfig, str, Path, int | None], ModelAdapter]
# Registry maps model_family string to adapter factory
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
"flux": FluxModelAdapter,
"qwen-edit": QwenEditModelAdapter,
"qwen": QwenModelAdapter,
}
# Config registry: maps model ID patterns to configs
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
"flux.1-dev": FLUX_DEV_CONFIG,
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
"qwen-image": QWEN_IMAGE_CONFIG,
}
def get_config_for_model(model_id: str) -> ImageModelConfig:
"""Get configuration for a model ID.
Args:
model_id: The model identifier (e.g., "black-forest-labs/FLUX.1-schnell")
Returns:
The model configuration
Raises:
ValueError: If no configuration found for model ID
"""
model_id_lower = model_id.lower()
for pattern, config in _CONFIG_REGISTRY.items():
if pattern in model_id_lower:
return config
raise ValueError(f"No configuration found for model: {model_id}")
def create_adapter_for_model(
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
) -> ModelAdapter:
"""Create a model adapter for the given configuration.
Args:
config: The model configuration
model_id: The model identifier
local_path: Path to the model weights
quantize: Optional quantization bits
Returns:
A ModelAdapter instance
Raises:
ValueError: If no adapter found for model family
"""
factory = _ADAPTER_REGISTRY.get(config.model_family)
if factory is None:
raise ValueError(f"No adapter found for model family: {config.model_family}")
return factory(config, model_id, local_path, quantize)
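A small sketch of how resolution works (model ids are illustrative): matching is a case-insensitive substring check in insertion order, which is why the edit config is registered before the plain one.

from exo.worker.engines.image.models import get_config_for_model
from exo.worker.engines.image.models.flux import FLUX_SCHNELL_CONFIG

cfg = get_config_for_model("black-forest-labs/FLUX.1-schnell")
assert cfg is FLUX_SCHNELL_CONFIG            # "flux.1-schnell" is a substring of the lowered id
get_config_for_model("some/unknown-model")   # raises ValueError: no configuration found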

View File

@@ -0,0 +1,376 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.common.latent_creator.latent_creator import Img2Img, LatentCreator
from mflux.utils.image_util import ImageUtil
from PIL import Image
from exo.worker.engines.image.config import ImageModelConfig
if TYPE_CHECKING:
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
class PromptData(ABC):
"""Abstract base class for encoded prompt data.
All adapters must return prompt data that inherits from this class.
Model-specific prompt data classes can add additional attributes
(e.g., attention masks for Qwen).
"""
@property
@abstractmethod
def prompt_embeds(self) -> mx.array:
"""Text embeddings from encoder."""
...
@property
@abstractmethod
def pooled_prompt_embeds(self) -> mx.array:
"""Pooled text embeddings (for Flux) or placeholder (for Qwen)."""
...
@property
@abstractmethod
def negative_prompt_embeds(self) -> mx.array | None:
"""Negative prompt embeddings for CFG (None if not using CFG)."""
...
@property
@abstractmethod
def negative_pooled_prompt_embeds(self) -> mx.array | None:
"""Negative pooled embeddings for CFG (None if not using CFG)."""
...
@abstractmethod
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
"""Get encoder hidden states mask for attention.
Args:
positive: If True, return mask for positive prompt pass.
If False, return mask for negative prompt pass.
Returns:
Attention mask array (Qwen) or None (Flux).
"""
...
@property
@abstractmethod
def cond_image_grid(
self,
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
"""Conditioning image grid dimensions for edit mode.
Returns:
Grid dimensions (Qwen edit) or None (standard generation).
"""
...
@property
@abstractmethod
def conditioning_latents(self) -> mx.array | None:
"""Conditioning latents for edit mode.
Returns:
Conditioning latents array for image editing, None for standard generation.
"""
...
@abstractmethod
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
"""Get embeddings for CFG with batch_size=2.
Combines positive and negative embeddings into batched tensors for
a single forward pass. Pads shorter sequences to max length. Attention
mask is used to mask padding.
Returns:
None if model doesn't support CFG, otherwise tuple of:
- batched_embeds: [2, max_seq, hidden] (positive then negative)
- batched_mask: [2, max_seq] attention mask
- batched_pooled: [2, hidden] pooled embeddings or None
- conditioning_latents: [2, latent_seq, latent_dim] or None
TODO(ciaran): type this
"""
...
class ModelAdapter(ABC):
"""Base class for model adapters with shared utilities."""
_config: ImageModelConfig
_model: Any
_transformer: Any
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def model(self) -> Any:
return self._model
@property
def transformer(self) -> Any:
return self._transformer
@property
@abstractmethod
def hidden_dim(self) -> int:
"""Return the size of hidden_dim."""
...
@property
@abstractmethod
def needs_cfg(self) -> bool:
"""Whether this model uses classifier-free guidance.
Returns:
True if model requires two forward passes with guidance (e.g., Qwen)
False if model uses a single forward pass (e.g., Flux)
"""
...
@abstractmethod
def _get_latent_creator(self) -> type:
"""Return the latent creator class for this model."""
...
@abstractmethod
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list["JointBlockWrapper"]:
"""Create wrapped joint transformer blocks with pipefusion support.
Args:
text_seq_len: Number of text tokens (constant for generation)
encoder_hidden_states_mask: Attention mask for text (Qwen only)
Returns:
List of wrapped joint blocks ready for pipefusion
"""
...
@abstractmethod
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list["SingleBlockWrapper"]:
"""Create wrapped single transformer blocks with pipefusion support.
Args:
text_seq_len: Number of text tokens (constant for generation)
Returns:
List of wrapped single blocks ready for pipefusion
"""
...
@abstractmethod
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
"""Remove transformer blocks outside the assigned range.
This should be called BEFORE mx.eval() to avoid loading unused weights
in distributed mode.
Args:
start_layer: First layer index (inclusive) assigned to this node
end_layer: Last layer index (exclusive) assigned to this node
"""
...
def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
"""Default implementation: no dimension computation needed.
Override in edit adapters to compute dimensions from input image.
Returns:
None (use user-specified dimensions)
"""
return None
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
"""Create initial latents. Uses model-specific latent creator."""
return LatentCreator.create_for_txt2img_or_img2img(
seed=seed,
height=runtime_config.height,
width=runtime_config.width,
img2img=Img2Img(
vae=self.model.vae,
latent_creator=self._get_latent_creator(),
sigmas=runtime_config.scheduler.sigmas,
init_time_step=runtime_config.init_time_step,
image_path=runtime_config.image_path,
),
)
def decode_latents(
self,
latents: mx.array,
runtime_config: Config,
seed: int,
prompt: str,
) -> Image.Image:
"""Decode latents to image. Shared implementation."""
latents = self._get_latent_creator().unpack_latents(
latents=latents,
height=runtime_config.height,
width=runtime_config.width,
)
decoded = self.model.vae.decode(latents)
# TODO(ciaran):
# from mflux.models.common.vae.vae_util import VAEUtil
# VAEUtil.decode(vae=self.model.vae, latents=latents, tiling_config=self.tiling_config)
generated_image = ImageUtil.to_image(
decoded_latents=decoded,
config=runtime_config,
seed=seed,
prompt=prompt,
quantization=self.model.bits,
lora_paths=self.model.lora_paths,
lora_scales=self.model.lora_scales,
image_path=runtime_config.image_path,
image_strength=runtime_config.image_strength,
generation_time=0,
)
return generated_image.image
@abstractmethod
def encode_prompt(
self, prompt: str, negative_prompt: str | None = None
) -> "PromptData":
"""Encode prompt into model-specific prompt data.
Args:
prompt: Text prompt
negative_prompt: Negative prompt for CFG
Returns:
PromptData containing embeddings (and model-specific extras)
"""
...
@abstractmethod
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
"""Compute x_embedder and context_embedder outputs.
Args:
hidden_states: Input latent states
prompt_embeds: Text embeddings from encoder
Returns:
Tuple of (embedded_hidden_states, embedded_encoder_states)
"""
...
@abstractmethod
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
"""Compute time/text embeddings for conditioning.
Args:
t: Current timestep
runtime_config: Runtime configuration
pooled_prompt_embeds: Pooled text embeddings (used by Flux)
hidden_states: Image hidden states
Returns:
Text embeddings tensor
"""
...
@abstractmethod
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> Any:
"""Compute rotary position embeddings.
Args:
prompt_embeds: Text embeddings
runtime_config: Runtime configuration
encoder_hidden_states_mask: Attention mask for text (Qwen)
cond_image_grid: Conditioning image grid dimensions (Qwen edit)
kontext_image_ids: Kontext image position IDs (Flux)
Returns:
Flux: mx.array
Qwen: tuple[mx.array, mx.array]
"""
...
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
@abstractmethod
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
"""Apply classifier-free guidance to combine positive/negative predictions.
Only called when needs_cfg is True.
Args:
noise_positive: Noise prediction from positive prompt
noise_negative: Noise prediction from negative prompt
guidance_scale: Guidance strength
Returns:
Guided noise prediction
"""
...
def final_projection(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
"""Apply final norm and projection.
Args:
hidden_states: Hidden states (image only, text already removed)
text_embeddings: Conditioning embeddings
Returns:
Projected output
"""
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
return self._transformer.proj_out(hidden_states)
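merge_streams simply places text tokens in front of image tokens along the sequence axis, which is the layout the block wrappers below assume. A tiny shape check (dimensions are illustrative):

import mlx.core as mx

text = mx.zeros((1, 7, 16))    # [batch, text_seq_len, hidden]
image = mx.zeros((1, 64, 16))  # [batch, image_seq_len, hidden]
mx.concatenate([text, image], axis=1).shape  # (1, 71, 16)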

View File

@@ -0,0 +1,11 @@
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
from exo.worker.engines.image.models.flux.config import (
FLUX_DEV_CONFIG,
FLUX_SCHNELL_CONFIG,
)
__all__ = [
"FluxModelAdapter",
"FLUX_DEV_CONFIG",
"FLUX_SCHNELL_CONFIG",
]

View File

@@ -0,0 +1,218 @@
from pathlib import Path
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.variants.txt2img.flux import Flux1
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter, PromptData
from exo.worker.engines.image.models.flux.wrappers import (
FluxJointBlockWrapper,
FluxSingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
class FluxPromptData(PromptData):
"""Container for Flux prompt encoding results."""
def __init__(self, prompt_embeds: mx.array, pooled_prompt_embeds: mx.array):
self._prompt_embeds = prompt_embeds
self._pooled_prompt_embeds = pooled_prompt_embeds
@property
def prompt_embeds(self) -> mx.array:
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
return self._pooled_prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array | None:
"""Flux does not use CFG."""
return None
@property
def negative_pooled_prompt_embeds(self) -> mx.array | None:
"""Flux does not use CFG."""
return None
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
"""Flux does not use encoder hidden states mask."""
return None
@property
def cond_image_grid(
self,
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
"""Flux does not use conditioning image grid."""
return None
@property
def conditioning_latents(self) -> mx.array | None:
"""Flux does not use conditioning latents."""
return None
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
"""Flux does not use CFG."""
return None
class FluxModelAdapter(ModelAdapter):
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = Flux1(
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
model_path=str(local_path),
quantize=quantize,
)
self._transformer = self._model.transformer
@property
def hidden_dim(self) -> int:
return self._transformer.x_embedder.weight.shape[0]
@property
def needs_cfg(self) -> bool:
return False
def _get_latent_creator(self) -> type:
return FluxLatentCreator
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list[JointBlockWrapper]:
"""Create wrapped joint blocks for Flux."""
return [
FluxJointBlockWrapper(block, text_seq_len)
for block in self._transformer.transformer_blocks
]
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list[SingleBlockWrapper]:
"""Create wrapped single blocks for Flux."""
return [
FluxSingleBlockWrapper(block, text_seq_len)
for block in self._transformer.single_transformer_blocks
]
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
all_joint = list(self._transformer.transformer_blocks)
all_single = list(self._transformer.single_transformer_blocks)
total_joint_blocks = len(all_joint)
if end_layer <= total_joint_blocks:
# All assigned are joint blocks
joint_start, joint_end = start_layer, end_layer
single_start, single_end = 0, 0
elif start_layer >= total_joint_blocks:
# All assigned are single blocks
joint_start, joint_end = 0, 0
single_start = start_layer - total_joint_blocks
single_end = end_layer - total_joint_blocks
else:
# Spans both joint and single
joint_start, joint_end = start_layer, total_joint_blocks
single_start = 0
single_end = end_layer - total_joint_blocks
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
self._transformer.single_transformer_blocks = all_single[
single_start:single_end
]
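# Worked example of the split above (hypothetical shard): with 19 joint and 38 single
# blocks, start_layer=10 and end_layer=30 span both kinds, so this rank keeps
# joint blocks [10:19] and single blocks [0:30-19] == [0:11].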
def encode_prompt(
self, prompt: str, negative_prompt: str | None = None
) -> FluxPromptData:
del negative_prompt
assert isinstance(self.model.prompt_cache, dict)
assert isinstance(self.model.tokenizers, dict)
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
prompt=prompt,
prompt_cache=self.model.prompt_cache,
t5_tokenizer=self.model.tokenizers["t5"],
clip_tokenizer=self.model.tokenizers["clip"],
t5_text_encoder=self.model.t5_text_encoder,
clip_text_encoder=self.model.clip_text_encoder,
)
return FluxPromptData(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
embedded_hidden = self._transformer.x_embedder(hidden_states)
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None, # Ignored by Flux
) -> mx.array:
if pooled_prompt_embeds is None:
raise ValueError(
"pooled_prompt_embeds is required for Flux text embeddings"
)
# hidden_states is ignored - Flux uses pooled_prompt_embeds instead
return Transformer.compute_text_embeddings(
t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
return Transformer.compute_rotary_embeddings(
prompt_embeds,
self._transformer.pos_embed,
runtime_config,
kontext_image_ids,
)
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
raise NotImplementedError("Flux does not use classifier-free guidance")

View File

@@ -0,0 +1,34 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
FLUX_SCHNELL_CONFIG = ImageModelConfig(
model_family="flux",
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
),
TransformerBlockConfig(
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
),
),
default_steps={"low": 1, "medium": 2, "high": 4},
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
)
FLUX_DEV_CONFIG = ImageModelConfig(
model_family="flux",
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
),
TransformerBlockConfig(
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
)

View File

@@ -0,0 +1,279 @@
import mlx.core as mx
from mflux.models.flux.model.flux_transformer.common.attention_utils import (
AttentionUtils,
)
from mflux.models.flux.model.flux_transformer.joint_transformer_block import (
JointTransformerBlock,
)
from mflux.models.flux.model.flux_transformer.single_transformer_block import (
SingleTransformerBlock,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
class FluxJointBlockWrapper(JointBlockWrapper):
"""Flux-specific joint block wrapper with pipefusion support."""
def __init__(self, block: JointTransformerBlock, text_seq_len: int):
super().__init__(block, text_seq_len)
# Cache attention parameters from block
self._num_heads = block.attn.num_heads
self._head_dim = block.attn.head_dimension
# Intermediate state stored between _compute_qkv and _apply_output
self._gate_msa: mx.array | None = None
self._shift_mlp: mx.array | None = None
self._scale_mlp: mx.array | None = None
self._gate_mlp: mx.array | None = None
self._c_gate_msa: mx.array | None = None
self._c_shift_mlp: mx.array | None = None
self._c_scale_mlp: mx.array | None = None
self._c_gate_mlp: mx.array | None = None
def _compute_qkv(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
patch_mode: bool = False,
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V for sequence with Flux-specific logic.
Args:
hidden_states: Image hidden states [B, num_img_tokens, D] or patch [B, patch_len, D]
encoder_hidden_states: Text hidden states [B, text_seq_len, D]
text_embeddings: Conditioning embeddings [B, D]
rotary_embeddings: Rotary position embeddings
patch_mode: If True, slice RoPE for current patch range
"""
attn = self.block.attn
# 1. Compute norms (store gates for _apply_output)
(
norm_hidden,
self._gate_msa,
self._shift_mlp,
self._scale_mlp,
self._gate_mlp,
) = self.block.norm1(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
)
(
norm_encoder,
self._c_gate_msa,
self._c_shift_mlp,
self._c_scale_mlp,
self._c_gate_mlp,
) = self.block.norm1_context(
hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
)
# 2. Compute Q, K, V for image
img_query, img_key, img_value = AttentionUtils.process_qkv(
hidden_states=norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
# 3. Compute Q, K, V for text
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
hidden_states=norm_encoder,
to_q=attn.add_q_proj,
to_k=attn.add_k_proj,
to_v=attn.add_v_proj,
norm_q=attn.norm_added_q,
norm_k=attn.norm_added_k,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
# 4. Concatenate Q, K, V: [text, image/patch]
query = mx.concatenate([txt_query, img_query], axis=2)
key = mx.concatenate([txt_key, img_key], axis=2)
value = mx.concatenate([txt_value, img_value], axis=2)
# 5. Apply RoPE (slice for patch mode)
if patch_mode:
text_rope = rotary_embeddings[:, :, : self._text_seq_len, ...]
patch_img_rope = rotary_embeddings[
:,
:,
self._text_seq_len + self._patch_start : self._text_seq_len
+ self._patch_end,
...,
]
rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
else:
rope = rotary_embeddings
query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=rope)
return query, key, value
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array:
"""Compute scaled dot-product attention."""
batch_size = query.shape[0]
return AttentionUtils.compute_attention(
query=query,
key=key,
value=value,
batch_size=batch_size,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
) -> tuple[mx.array, mx.array]:
"""Apply output projection, feed-forward, and residuals."""
attn = self.block.attn
# 1. Extract text and image attention outputs
context_attn_output = attn_out[:, : self._text_seq_len, :]
hidden_attn_output = attn_out[:, self._text_seq_len :, :]
# 2. Project outputs
hidden_attn_output = attn.to_out[0](hidden_attn_output)
context_attn_output = attn.to_add_out(context_attn_output)
# 3. Apply norm and feed forward (using stored gates)
hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=hidden_states,
attn_output=hidden_attn_output,
gate_mlp=self._gate_mlp,
gate_msa=self._gate_msa,
scale_mlp=self._scale_mlp,
shift_mlp=self._shift_mlp,
norm_layer=self.block.norm2,
ff_layer=self.block.ff,
)
encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=encoder_hidden_states,
attn_output=context_attn_output,
gate_mlp=self._c_gate_mlp,
gate_msa=self._c_gate_msa,
scale_mlp=self._c_scale_mlp,
shift_mlp=self._c_shift_mlp,
norm_layer=self.block.norm2_context,
ff_layer=self.block.ff_context,
)
return encoder_hidden_states, hidden_states
class FluxSingleBlockWrapper(SingleBlockWrapper):
"""Flux-specific single block wrapper with pipefusion support."""
def __init__(self, block: SingleTransformerBlock, text_seq_len: int):
super().__init__(block, text_seq_len)
# Cache attention parameters from block
self._num_heads = block.attn.num_heads
self._head_dim = block.attn.head_dimension
# Intermediate state stored between _compute_qkv and _apply_output
self._gate: mx.array | None = None
self._norm_hidden: mx.array | None = None
def _compute_qkv(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
patch_mode: bool = False,
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V for [text, image] sequence.
Args:
hidden_states: Concatenated [text, image] hidden states
text_embeddings: Conditioning embeddings [B, D]
rotary_embeddings: Rotary position embeddings
patch_mode: If True, slice RoPE for current patch range
"""
attn = self.block.attn
# 1. Compute norm (store for _apply_output)
self._norm_hidden, self._gate = self.block.norm(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
)
# 2. Compute Q, K, V
query, key, value = AttentionUtils.process_qkv(
hidden_states=self._norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
# 3. Apply RoPE (slice for patch mode)
if patch_mode:
text_rope = rotary_embeddings[:, :, : self._text_seq_len, ...]
patch_img_rope = rotary_embeddings[
:,
:,
self._text_seq_len + self._patch_start : self._text_seq_len
+ self._patch_end,
...,
]
rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
else:
rope = rotary_embeddings
query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=rope)
return query, key, value
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array:
"""Compute scaled dot-product attention."""
batch_size = query.shape[0]
return AttentionUtils.compute_attention(
query=query,
key=key,
value=value,
batch_size=batch_size,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
"""Apply feed forward and projection with residual."""
# Residual from original hidden_states
residual = hidden_states
# Apply feed forward and projection (using stored norm and gate)
output = self.block._apply_feed_forward_and_projection(
norm_hidden_states=self._norm_hidden,
attn_output=attn_out,
gate=self._gate,
)
return residual + output
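Both wrappers slice rotary embeddings the same way in patch mode: the sequence axis is laid out as [text | image], so an image patch covering positions [patch_start, patch_end) lives at absolute positions [text_seq_len + patch_start, text_seq_len + patch_end). A shape-only sketch (all sizes illustrative):

import mlx.core as mx

text_seq_len, patch_start, patch_end = 4, 8, 16
rope = mx.zeros((1, 1, 4 + 64, 2))  # [batch, 1, text + image positions, rope dims]

text_rope = rope[:, :, :text_seq_len, ...]
patch_rope = rope[:, :, text_seq_len + patch_start : text_seq_len + patch_end, ...]
mx.concatenate([text_rope, patch_rope], axis=2).shape  # (1, 1, 12, 2)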

View File

@@ -0,0 +1,13 @@
from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter
from exo.worker.engines.image.models.qwen.config import (
QWEN_IMAGE_CONFIG,
QWEN_IMAGE_EDIT_CONFIG,
)
from exo.worker.engines.image.models.qwen.edit_adapter import QwenEditModelAdapter
__all__ = [
"QwenModelAdapter",
"QwenEditModelAdapter",
"QWEN_IMAGE_CONFIG",
"QWEN_IMAGE_EDIT_CONFIG",
]

View File

@@ -0,0 +1,318 @@
from pathlib import Path
import mlx.core as mx
from mflux.models.common.config import ModelConfig
from mflux.models.common.config.config import Config
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_text_encoder.qwen_prompt_encoder import (
QwenPromptEncoder,
)
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter, PromptData
from exo.worker.engines.image.models.qwen.wrappers import QwenJointBlockWrapper
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
class QwenPromptData(PromptData):
"""Container for Qwen prompt encoding results.
Implements PromptData protocol with additional Qwen-specific attributes.
"""
def __init__(
self,
prompt_embeds: mx.array,
prompt_mask: mx.array,
negative_prompt_embeds: mx.array,
negative_prompt_mask: mx.array,
):
self._prompt_embeds = prompt_embeds
self._prompt_mask = prompt_mask
self._negative_prompt_embeds = negative_prompt_embeds
self._negative_prompt_mask = negative_prompt_mask
@property
def prompt_embeds(self) -> mx.array:
"""Text embeddings from encoder."""
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
return self._prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array:
"""Negative prompt embeddings for CFG."""
return self._negative_prompt_embeds
@property
def negative_pooled_prompt_embeds(self) -> mx.array:
"""Placeholder - Qwen doesn't use pooled embeds."""
return self._negative_prompt_embeds
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array:
"""Return encoder_hidden_states_mask for the appropriate prompt."""
if positive:
return self._prompt_mask
else:
return self._negative_prompt_mask
@property
def cond_image_grid(
self,
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
"""Standard Qwen does not use conditioning image grid."""
return None
@property
def conditioning_latents(self) -> mx.array | None:
"""Standard Qwen does not use conditioning latents."""
return None
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
"""Batch positive and negative embeddings for CFG with batch_size=2.
Pads shorter sequence to max length using zeros for embeddings
and zeros (masked) for attention mask.
Returns:
Tuple of (batched_embeds, batched_mask, None, conditioning_latents)
- batched_embeds: [2, max_seq, hidden]
- batched_mask: [2, max_seq]
- None for pooled (Qwen doesn't use it)
- conditioning_latents: [2, latent_seq, latent_dim] or None
"""
pos_embeds = self._prompt_embeds # [1, pos_seq, hidden]
neg_embeds = self._negative_prompt_embeds # [1, neg_seq, hidden]
pos_mask = self._prompt_mask # [1, pos_seq]
neg_mask = self._negative_prompt_mask # [1, neg_seq]
pos_seq_len = pos_embeds.shape[1]
neg_seq_len = neg_embeds.shape[1]
max_seq_len = max(pos_seq_len, neg_seq_len)
hidden_dim = pos_embeds.shape[2]
if pos_seq_len < max_seq_len:
pad_len = max_seq_len - pos_seq_len
pos_embeds = mx.concatenate(
[
pos_embeds,
mx.zeros((1, pad_len, hidden_dim), dtype=pos_embeds.dtype),
],
axis=1,
)
pos_mask = mx.concatenate(
[pos_mask, mx.zeros((1, pad_len), dtype=pos_mask.dtype)],
axis=1,
)
elif neg_seq_len < max_seq_len:
pad_len = max_seq_len - neg_seq_len
neg_embeds = mx.concatenate(
[
neg_embeds,
mx.zeros((1, pad_len, hidden_dim), dtype=neg_embeds.dtype),
],
axis=1,
)
neg_mask = mx.concatenate(
[neg_mask, mx.zeros((1, pad_len), dtype=neg_mask.dtype)],
axis=1,
)
batched_embeds = mx.concatenate([pos_embeds, neg_embeds], axis=0)
batched_mask = mx.concatenate([pos_mask, neg_mask], axis=0)
# TODO(ciaran): currently None but maybe we will deduplicate with edit
# adapter
cond_latents = self.conditioning_latents
if cond_latents is not None:
cond_latents = mx.concatenate([cond_latents, cond_latents], axis=0)
return batched_embeds, batched_mask, None, cond_latents
class QwenModelAdapter(ModelAdapter):
"""Adapter for Qwen-Image model.
Key differences from Flux:
- Single text encoder (vs dual T5+CLIP)
- 60 joint-style blocks, no single blocks
- 3D RoPE returning ((img_cos, img_sin), (txt_cos, txt_sin))
- Norm-preserving CFG with negative prompts
- Uses attention mask for variable-length text
"""
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = QwenImage(
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
model_path=str(local_path),
quantize=quantize,
)
self._transformer = self._model.transformer
@property
def hidden_dim(self) -> int:
return self._transformer.inner_dim
@property
def needs_cfg(self) -> bool:
gs = self._config.guidance_scale
return gs is not None and gs > 1.0
def _get_latent_creator(self) -> type:
return QwenLatentCreator
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list[JointBlockWrapper]:
"""Create wrapped joint blocks for Qwen."""
return [
QwenJointBlockWrapper(block, text_seq_len, encoder_hidden_states_mask)
for block in self._transformer.transformer_blocks
]
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list[SingleBlockWrapper]:
"""Qwen has no single blocks."""
return []
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
self._transformer.transformer_blocks = self._transformer.transformer_blocks[
start_layer:end_layer
]
def encode_prompt(
self, prompt: str, negative_prompt: str | None = None
) -> QwenPromptData:
assert isinstance(self.model.prompt_cache, dict)
assert isinstance(self.model.tokenizers, dict)
if negative_prompt is None or negative_prompt == "":
negative_prompt = " "
prompt_embeds, prompt_mask, neg_embeds, neg_mask = (
QwenPromptEncoder.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
prompt_cache=self.model.prompt_cache,
qwen_tokenizer=self.model.tokenizers["qwen"],
qwen_text_encoder=self.model.text_encoder,
)
)
return QwenPromptData(
prompt_embeds=prompt_embeds,
prompt_mask=prompt_mask,
negative_prompt_embeds=neg_embeds,
negative_prompt_mask=neg_mask,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
"""Compute image and text embeddings."""
embedded_hidden = self._transformer.img_in(hidden_states)
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
"""Compute time/text embeddings.
For Qwen, the time_text_embed only uses hidden_states for:
- batch_size (shape[0])
- dtype
This allows us to pass any tensor (latents, prompt_embeds) as a fallback
when embedded hidden_states are not yet available.
"""
# Use hidden_states if provided, otherwise fall back to pooled_prompt_embeds
# (which for Qwen is the same as prompt_embeds)
ref_tensor = (
hidden_states if hidden_states is not None else pooled_prompt_embeds
)
if ref_tensor is None:
raise ValueError(
"Either hidden_states or pooled_prompt_embeds is required "
"for Qwen text embeddings"
)
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
batch_size = ref_tensor.shape[0]
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
return self._transformer.time_text_embed(timestep, ref_tensor)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]:
"""Compute 3D rotary embeddings for Qwen.
Qwen uses video-aware 3D RoPE with separate embeddings for image and text.
Returns:
tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]:
((img_cos, img_sin), (txt_cos, txt_sin))
"""
if encoder_hidden_states_mask is None:
raise ValueError(
"encoder_hidden_states_mask is required for Qwen RoPE computation"
)
return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
encoder_hidden_states_mask=encoder_hidden_states_mask,
pos_embed=self._transformer.pos_embed,
config=runtime_config,
cond_image_grid=cond_image_grid,
)
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
return self._model.compute_guided_noise(
noise=noise_positive,
noise_negative=noise_negative,
guidance=guidance_scale,
)
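A minimal usage sketch of the adapter above; the model id, local path, and prompts are placeholders, and in practice the engine constructs and drives the adapter:

from pathlib import Path

from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter
from exo.worker.engines.image.models.qwen.config import QWEN_IMAGE_CONFIG

adapter = QwenModelAdapter(
    config=QWEN_IMAGE_CONFIG,
    model_id="Qwen/Qwen-Image",  # placeholder id
    local_path=Path("/models/qwen-image"),  # placeholder path
    quantize=8,
)
prompt_data = adapter.encode_prompt("a lighthouse at dawn", negative_prompt="blurry")
if adapter.needs_cfg:
    batched_embeds, batched_mask, _, cond_latents = prompt_data.get_batched_cfg_data()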

View File

@@ -0,0 +1,29 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
QWEN_IMAGE_CONFIG = ImageModelConfig(
model_family="qwen",
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=60, has_separate_text_output=True
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125,  # ~3 sync steps for medium (25 steps)
guidance_scale=3.5, # Set to None or < 1.0 to disable CFG
)
QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
model_family="qwen-edit",
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=60, has_separate_text_output=True
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125,
guidance_scale=3.5,
)
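The sync-step arithmetic these fields feed into, as a rough sketch (the runner's exact rounding may differ):

steps = QWEN_IMAGE_CONFIG.default_steps["medium"]  # 25
sync_steps = round(steps * QWEN_IMAGE_CONFIG.num_sync_steps_factor)  # 0.125 * 25 ~= 3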

View File

@@ -0,0 +1,459 @@
import math
from pathlib import Path
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.variants.edit.qwen_edit_util import QwenEditUtil
from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter, PromptData
from exo.worker.engines.image.models.qwen.wrappers import QwenJointBlockWrapper
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
class QwenEditPromptData(PromptData):
"""Container for Qwen edit prompt encoding results.
Includes vision-language encoded embeddings and edit-specific conditioning.
"""
def __init__(
self,
prompt_embeds: mx.array,
prompt_mask: mx.array,
negative_prompt_embeds: mx.array,
negative_prompt_mask: mx.array,
conditioning_latents: mx.array,
qwen_image_ids: mx.array,
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]],
):
self._prompt_embeds = prompt_embeds
self._prompt_mask = prompt_mask
self._negative_prompt_embeds = negative_prompt_embeds
self._negative_prompt_mask = negative_prompt_mask
self._conditioning_latents = conditioning_latents
self._qwen_image_ids = qwen_image_ids
self._cond_image_grid = cond_image_grid
@property
def prompt_embeds(self) -> mx.array:
"""Text embeddings from vision-language encoder."""
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
return self._prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array:
"""Negative prompt embeddings for CFG."""
return self._negative_prompt_embeds
@property
def negative_pooled_prompt_embeds(self) -> mx.array:
"""Placeholder - Qwen doesn't use pooled embeds."""
return self._negative_prompt_embeds
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array:
"""Return encoder_hidden_states_mask for the appropriate prompt."""
if positive:
return self._prompt_mask
else:
return self._negative_prompt_mask
@property
def cond_image_grid(self) -> tuple[int, int, int] | list[tuple[int, int, int]]:
"""Conditioning image grid dimensions."""
return self._cond_image_grid
@property
def conditioning_latents(self) -> mx.array:
"""Static image conditioning latents to concatenate with generated latents."""
return self._conditioning_latents
@property
def qwen_image_ids(self) -> mx.array:
"""Spatial position IDs for conditioning images."""
return self._qwen_image_ids
@property
def is_edit_mode(self) -> bool:
"""Indicates this is edit mode with conditioning latents."""
return True
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
"""Batch positive and negative embeddings for CFG with batch_size=2.
Pads shorter sequence to max length using zeros for embeddings
and zeros (masked) for attention mask. Duplicates conditioning
latents for both positive and negative passes.
Returns:
Tuple of (batched_embeds, batched_mask, None, batched_cond_latents)
- batched_embeds: [2, max_seq, hidden]
- batched_mask: [2, max_seq]
- None for pooled (Qwen doesn't use it)
- batched_cond_latents: [2, latent_seq, latent_dim]
TODO(ciaran): type this
"""
pos_embeds = self._prompt_embeds # [1, pos_seq, hidden]
neg_embeds = self._negative_prompt_embeds # [1, neg_seq, hidden]
pos_mask = self._prompt_mask # [1, pos_seq]
neg_mask = self._negative_prompt_mask # [1, neg_seq]
pos_seq_len = pos_embeds.shape[1]
neg_seq_len = neg_embeds.shape[1]
max_seq_len = max(pos_seq_len, neg_seq_len)
hidden_dim = pos_embeds.shape[2]
if pos_seq_len < max_seq_len:
pad_len = max_seq_len - pos_seq_len
pos_embeds = mx.concatenate(
[
pos_embeds,
mx.zeros((1, pad_len, hidden_dim), dtype=pos_embeds.dtype),
],
axis=1,
)
pos_mask = mx.concatenate(
[pos_mask, mx.zeros((1, pad_len), dtype=pos_mask.dtype)],
axis=1,
)
if neg_seq_len < max_seq_len:
pad_len = max_seq_len - neg_seq_len
neg_embeds = mx.concatenate(
[
neg_embeds,
mx.zeros((1, pad_len, hidden_dim), dtype=neg_embeds.dtype),
],
axis=1,
)
neg_mask = mx.concatenate(
[neg_mask, mx.zeros((1, pad_len), dtype=neg_mask.dtype)],
axis=1,
)
batched_embeds = mx.concatenate([pos_embeds, neg_embeds], axis=0)
batched_mask = mx.concatenate([pos_mask, neg_mask], axis=0)
batched_cond_latents = mx.concatenate(
[self._conditioning_latents, self._conditioning_latents], axis=0
)
return batched_embeds, batched_mask, None, batched_cond_latents
class QwenEditModelAdapter(ModelAdapter):
"""Adapter for Qwen-Image-Edit model.
Key differences from standard QwenModelAdapter:
- Uses QwenImageEdit model with vision-language components
- Encodes prompts WITH input images via VL tokenizer/encoder
- Creates conditioning latents from input images
- Supports image editing with concatenated latents during diffusion
"""
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = QwenImageEdit(
quantize=quantize,
model_path=str(local_path),
)
self._transformer = self._model.transformer
self._vl_width: int | None = None
self._vl_height: int | None = None
self._vae_width: int | None = None
self._vae_height: int | None = None
self._image_paths: list[str] | None = None
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def model(self) -> QwenImageEdit:
return self._model
@property
def transformer(self) -> QwenTransformer:
return self._transformer
@property
def hidden_dim(self) -> int:
return self._transformer.inner_dim
@property
def needs_cfg(self) -> bool:
gs = self._config.guidance_scale
return gs is not None and gs > 1.0
def _get_latent_creator(self) -> type:
return QwenLatentCreator
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list[JointBlockWrapper]:
"""Create wrapped joint blocks for Qwen Edit."""
return [
QwenJointBlockWrapper(block, text_seq_len, encoder_hidden_states_mask)
for block in self._transformer.transformer_blocks
]
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list[SingleBlockWrapper]:
"""Qwen has no single blocks."""
return []
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
self._transformer.transformer_blocks = self._transformer.transformer_blocks[
start_layer:end_layer
]
def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
"""Compute and store dimensions from input image.
Also stores image_paths for use in encode_prompt().
Returns:
(output_width, output_height) for runtime config
"""
vl_w, vl_h, vae_w, vae_h, out_w, out_h = self._compute_dimensions_from_image(
image_path
)
self._vl_width = vl_w
self._vl_height = vl_h
self._vae_width = vae_w
self._vae_height = vae_h
self._image_paths = [str(image_path)]
return out_w, out_h
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
"""Create initial noise latents (pure noise for edit mode)."""
return QwenLatentCreator.create_noise(
seed=seed,
height=runtime_config.height,
width=runtime_config.width,
)
def encode_prompt(
self, prompt: str, negative_prompt: str | None = None
) -> QwenEditPromptData:
if (
self._image_paths is None
or self._vl_height is None
or self._vl_width is None
or self._vae_height is None
or self._vae_width is None
):
raise RuntimeError(
"set_image_dimensions() must be called before encode_prompt() "
"for QwenEditModelAdapter"
)
if negative_prompt is None or negative_prompt == "":
negative_prompt = " "
image_paths = self._image_paths
# TODO(ciaran): config is untyped and unused, unsure if Config or RuntimeConfig is intended
(
prompt_embeds,
prompt_mask,
negative_prompt_embeds,
negative_prompt_mask,
) = self._model._encode_prompts_with_images(
prompt,
negative_prompt,
image_paths,
self._config,
self._vl_width,
self._vl_height,
)
(
conditioning_latents,
qwen_image_ids,
cond_h_patches,
cond_w_patches,
num_images,
) = QwenEditUtil.create_image_conditioning_latents(
vae=self._model.vae,
height=self._vae_height,
width=self._vae_width,
image_paths=image_paths,
vl_width=self._vl_width,
vl_height=self._vl_height,
)
# Build cond_image_grid
if num_images > 1:
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] = [
(1, cond_h_patches, cond_w_patches) for _ in range(num_images)
]
else:
cond_image_grid = (1, cond_h_patches, cond_w_patches)
return QwenEditPromptData(
prompt_embeds=prompt_embeds,
prompt_mask=prompt_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_mask=negative_prompt_mask,
conditioning_latents=conditioning_latents,
qwen_image_ids=qwen_image_ids,
cond_image_grid=cond_image_grid,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
"""Compute image and text embeddings."""
embedded_hidden = self._transformer.img_in(hidden_states)
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
"""Compute time/text embeddings."""
ref_tensor = (
hidden_states if hidden_states is not None else pooled_prompt_embeds
)
if ref_tensor is None:
raise ValueError(
"Either hidden_states or pooled_prompt_embeds is required "
"for Qwen text embeddings"
)
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
batch_size = ref_tensor.shape[0]
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
return self._transformer.time_text_embed(timestep, ref_tensor)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]:
"""Compute 3D rotary embeddings for Qwen edit."""
if encoder_hidden_states_mask is None:
raise ValueError(
"encoder_hidden_states_mask is required for Qwen RoPE computation"
)
return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
encoder_hidden_states_mask=encoder_hidden_states_mask,
pos_embed=self._transformer.pos_embed,
config=runtime_config,
cond_image_grid=cond_image_grid,
)
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
"""Merge image and text streams."""
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
return QwenImage.compute_guided_noise(
noise=noise_positive,
noise_negative=noise_negative,
guidance=guidance_scale,
)
def _compute_dimensions_from_image(
self, image_path: Path
) -> tuple[int, int, int, int, int, int]:
"""Compute VL and VAE dimensions from input image.
Returns:
(vl_width, vl_height, vae_width, vae_height, output_width, output_height)
"""
from mflux.utils.image_util import ImageUtil
pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
image_size = pil_image.size
# Vision-language dimensions (384x384 target area)
condition_image_size = 384 * 384
condition_ratio = image_size[0] / image_size[1]
vl_width = math.sqrt(condition_image_size * condition_ratio)
vl_height = vl_width / condition_ratio
vl_width = round(vl_width / 32) * 32
vl_height = round(vl_height / 32) * 32
# VAE dimensions (1024x1024 target area)
vae_image_size = 1024 * 1024
vae_ratio = image_size[0] / image_size[1]
vae_width = math.sqrt(vae_image_size * vae_ratio)
vae_height = vae_width / vae_ratio
vae_width = round(vae_width / 32) * 32
vae_height = round(vae_height / 32) * 32
# Output dimensions from input image aspect ratio
target_area = 1024 * 1024
ratio = image_size[0] / image_size[1]
output_width = math.sqrt(target_area * ratio)
output_height = output_width / ratio
output_width = round(output_width / 32) * 32
output_height = round(output_height / 32) * 32
# Ensure multiple of 16 for VAE
vae_scale_factor = 8
multiple_of = vae_scale_factor * 2
output_width = output_width // multiple_of * multiple_of
output_height = output_height // multiple_of * multiple_of
return (
int(vl_width),
int(vl_height),
int(vae_width),
int(vae_height),
int(output_width),
int(output_height),
)
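A worked example of the dimension math in _compute_dimensions_from_image, assuming a hypothetical 1536x1024 input (aspect ratio 1.5):

# VL target area 384*384:    width ~= 470.3 -> 480,  height ~= 313.5 -> 320
# VAE target area 1024*1024: width ~= 1254.1 -> 1248, height ~= 836.1 -> 832
# Output follows the VAE area and is already a multiple of 16 -> 1248 x 832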

View File

@@ -0,0 +1,220 @@
import mlx.core as mx
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
QwenTransformerBlock,
)
from exo.worker.engines.image.pipeline.block_wrapper import JointBlockWrapper
class QwenJointBlockWrapper(JointBlockWrapper):
"""Qwen-specific joint block wrapper with pipefusion support.
Qwen differs from Flux in several ways:
- Uses modulation parameters computed from text_embeddings
- Uses 3D RoPE with separate (cos, sin) for image and text
- Uses attention mask for variable-length text
"""
def __init__(
self,
block: QwenTransformerBlock,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
):
super().__init__(block, text_seq_len)
self._encoder_hidden_states_mask = encoder_hidden_states_mask
# Cache attention parameters from block
self._num_heads = block.attn.num_heads
self._head_dim = block.attn.head_dim
# Intermediate state stored between _compute_qkv and _apply_output
self._img_mod1: mx.array | None = None
self._img_mod2: mx.array | None = None
self._txt_mod1: mx.array | None = None
self._txt_mod2: mx.array | None = None
self._img_gate1: mx.array | None = None
self._txt_gate1: mx.array | None = None
def set_encoder_mask(self, mask: mx.array | None) -> None:
"""Set the encoder hidden states mask for attention."""
self._encoder_hidden_states_mask = mask
def _compute_qkv(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
patch_mode: bool = False,
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V for sequence with Qwen-specific logic.
Args:
hidden_states: Image hidden states [B, num_img_tokens, D] or patch [B, patch_len, D]
encoder_hidden_states: Text hidden states [B, text_seq_len, D]
text_embeddings: Conditioning embeddings [B, D]
rotary_embeddings: Tuple of ((img_cos, img_sin), (txt_cos, txt_sin))
patch_mode: If True, slice RoPE for current patch range
"""
batch_size = hidden_states.shape[0]
img_seq_len = hidden_states.shape[1]
attn = self.block.attn
# 1. Compute modulation parameters
img_mod_params = self.block.img_mod_linear(
self.block.img_mod_silu(text_embeddings)
)
txt_mod_params = self.block.txt_mod_linear(
self.block.txt_mod_silu(text_embeddings)
)
self._img_mod1, self._img_mod2 = mx.split(img_mod_params, 2, axis=-1)
self._txt_mod1, self._txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
# 2. Apply normalization and modulation
img_normed = self.block.img_norm1(hidden_states)
img_modulated, self._img_gate1 = QwenTransformerBlock._modulate(
img_normed, self._img_mod1
)
txt_normed = self.block.txt_norm1(encoder_hidden_states)
txt_modulated, self._txt_gate1 = QwenTransformerBlock._modulate(
txt_normed, self._txt_mod1
)
# 3. Compute Q, K, V for image
img_query = attn.to_q(img_modulated)
img_key = attn.to_k(img_modulated)
img_value = attn.to_v(img_modulated)
# 4. Compute Q, K, V for text
txt_query = attn.add_q_proj(txt_modulated)
txt_key = attn.add_k_proj(txt_modulated)
txt_value = attn.add_v_proj(txt_modulated)
# 5. Reshape to [B, S, H, D]
img_query = mx.reshape(
img_query, (batch_size, img_seq_len, self._num_heads, self._head_dim)
)
img_key = mx.reshape(
img_key, (batch_size, img_seq_len, self._num_heads, self._head_dim)
)
img_value = mx.reshape(
img_value, (batch_size, img_seq_len, self._num_heads, self._head_dim)
)
txt_query = mx.reshape(
txt_query,
(batch_size, self._text_seq_len, self._num_heads, self._head_dim),
)
txt_key = mx.reshape(
txt_key, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)
)
txt_value = mx.reshape(
txt_value, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)
)
# 6. Apply RMSNorm to Q, K
img_query = attn.norm_q(img_query)
img_key = attn.norm_k(img_key)
txt_query = attn.norm_added_q(txt_query)
txt_key = attn.norm_added_k(txt_key)
# 7. Apply RoPE (Qwen uses 3D RoPE with separate embeddings)
(img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
if patch_mode:
# Slice image RoPE for patch, keep full text RoPE
img_cos = img_cos[self._patch_start : self._patch_end]
img_sin = img_sin[self._patch_start : self._patch_end]
img_query = QwenAttention._apply_rope_qwen(img_query, img_cos, img_sin)
img_key = QwenAttention._apply_rope_qwen(img_key, img_cos, img_sin)
txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
# 8. Transpose to [B, H, S, D] for attention
img_query = mx.transpose(img_query, (0, 2, 1, 3))
img_key = mx.transpose(img_key, (0, 2, 1, 3))
img_value = mx.transpose(img_value, (0, 2, 1, 3))
txt_query = mx.transpose(txt_query, (0, 2, 1, 3))
txt_key = mx.transpose(txt_key, (0, 2, 1, 3))
txt_value = mx.transpose(txt_value, (0, 2, 1, 3))
# 9. Concatenate [text, image/patch]
query = mx.concatenate([txt_query, img_query], axis=2)
key = mx.concatenate([txt_key, img_key], axis=2)
value = mx.concatenate([txt_value, img_value], axis=2)
return query, key, value
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array:
"""Compute scaled dot-product attention with Qwen-specific mask."""
attn = self.block.attn
# Build attention mask
mask = QwenAttention._convert_mask_for_qwen(
mask=self._encoder_hidden_states_mask,
joint_seq_len=key.shape[2],
txt_seq_len=self._text_seq_len,
)
# Transpose back to [B, S, H, D] for Qwen's attention
query_bshd = mx.transpose(query, (0, 2, 1, 3))
key_bshd = mx.transpose(key, (0, 2, 1, 3))
value_bshd = mx.transpose(value, (0, 2, 1, 3))
return attn._compute_attention_qwen(
query=query_bshd,
key=key_bshd,
value=value_bshd,
mask=mask,
block_idx=None,
)
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
) -> tuple[mx.array, mx.array]:
"""Apply output projection, feed-forward, and residuals."""
attn = self.block.attn
# 1. Extract text and image attention outputs
txt_attn_output = attn_out[:, : self._text_seq_len, :]
img_attn_output = attn_out[:, self._text_seq_len :, :]
# 2. Project outputs
img_attn_output = attn.attn_to_out[0](img_attn_output)
txt_attn_output = attn.to_add_out(txt_attn_output)
# 3. Apply residual + gate for attention
hidden_states = hidden_states + self._img_gate1 * img_attn_output
encoder_hidden_states = (
encoder_hidden_states + self._txt_gate1 * txt_attn_output
)
# 4. Apply feed-forward for image
img_normed2 = self.block.img_norm2(hidden_states)
img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
img_normed2, self._img_mod2
)
img_mlp_output = self.block.img_ff(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output
# 5. Apply feed-forward for text
txt_normed2 = self.block.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
txt_normed2, self._txt_mod2
)
txt_mlp_output = self.block.txt_ff(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
return encoder_hidden_states, hidden_states
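A rough shape sketch for the joint attention layout above (head count, head size, and sequence lengths are placeholder numbers, not Qwen's real dimensions):

# CACHING mode, B=2 (batched CFG), H=24, head_dim=128, 20 text + 4096 image tokens:
#   query/key/value: [2, 24, 4116, 128]; the image slice of K/V fills the patch cache.
# PATCHED mode with a 1024-token patch:
#   query: [2, 24, 20 + 1024, 128], while key/value come back from the cache as
#   [2, 24, 4116, 128] (fresh text K/V concatenated ahead of cached image K/V).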

View File

@@ -0,0 +1,15 @@
from exo.worker.engines.image.pipeline.block_wrapper import (
BlockWrapperMode,
JointBlockWrapper,
SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
from exo.worker.engines.image.pipeline.runner import DiffusionRunner
__all__ = [
"BlockWrapperMode",
"DiffusionRunner",
"ImagePatchKVCache",
"JointBlockWrapper",
"SingleBlockWrapper",
]

View File

@@ -0,0 +1,392 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Self
import mlx.core as mx
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class BlockWrapperMode(Enum):
CACHING = "caching" # Sync mode: compute full attention, populate cache
PATCHED = "patched" # Async mode: compute patch attention, use cached KV
class BlockWrapperMixin:
"""Common cache management logic for block wrappers.
This includes:
- KV cache creation and management
- Mode switching (CACHING vs PATCHED)
- Patch range setting
"""
_text_seq_len: int
_kv_cache: ImagePatchKVCache | None
_mode: BlockWrapperMode
_patch_start: int
_patch_end: int
def _init_cache_state(self, text_seq_len: int) -> None:
self._text_seq_len = text_seq_len
self._kv_cache = None
self._mode = BlockWrapperMode.CACHING
self._patch_start = 0
self._patch_end = 0
def set_patch(
self,
mode: BlockWrapperMode,
patch_start: int = 0,
patch_end: int = 0,
) -> Self:
"""Set mode and patch range.
Args:
mode: CACHING (full attention) or PATCHED (use cached KV)
patch_start: Start token index within image (for PATCHED mode)
patch_end: End token index within image (for PATCHED mode)
Returns:
Self for method chaining
"""
self._mode = mode
self._patch_start = patch_start
self._patch_end = patch_end
return self
def set_text_seq_len(self, text_seq_len: int) -> None:
self._text_seq_len = text_seq_len
def _get_active_cache(self) -> ImagePatchKVCache | None:
return self._kv_cache
def _ensure_cache(self, img_key: mx.array) -> None:
if self._kv_cache is None:
batch, num_heads, img_seq_len, head_dim = img_key.shape
self._kv_cache = ImagePatchKVCache(
batch_size=batch,
num_heads=num_heads,
image_seq_len=img_seq_len,
head_dim=head_dim,
)
def _cache_full_image_kv(self, img_key: mx.array, img_value: mx.array) -> None:
self._ensure_cache(img_key)
cache = self._get_active_cache()
assert cache is not None
cache.update_image_patch(0, img_key.shape[2], img_key, img_value)
def _cache_patch_kv(self, img_key: mx.array, img_value: mx.array) -> None:
cache = self._get_active_cache()
assert cache is not None
cache.update_image_patch(self._patch_start, self._patch_end, img_key, img_value)
def _get_full_kv(
self, text_key: mx.array, text_value: mx.array
) -> tuple[mx.array, mx.array]:
cache = self._get_active_cache()
assert cache is not None
return cache.get_full_kv(text_key, text_value)
def reset_cache(self) -> None:
self._kv_cache = None
class JointBlockWrapper(BlockWrapperMixin, ABC):
"""Base class for joint transformer block wrappers with pipefusion support.
Subclass this to add pipefusion support to any model's joint blocks.
The wrapper:
- Owns its KV cache (created lazily on first CACHING forward)
- Controls the forward pass flow (CACHING vs PATCHED mode)
- Handles patch slicing and cache operations
Model subclass provides:
- _compute_qkv: Compute Q, K, V tensors (norms, projections, RoPE)
- _compute_attention: Run scaled dot-product attention
- _apply_output: Apply output projection, feed-forward, residuals
"""
def __init__(self, block: Any, text_seq_len: int):
"""Initialize the joint block wrapper.
Args:
block: The joint transformer block to wrap
text_seq_len: Number of text tokens (constant for entire generation)
"""
self.block = block
self._init_cache_state(text_seq_len)
def set_encoder_mask(self, mask: mx.array | None) -> None: # noqa: B027
"""Set the encoder hidden states mask for attention.
Override in subclasses that use attention masks (e.g., Qwen).
Default is a no-op for models that don't use masks (e.g., Flux).
"""
del mask # Unused in base class
def __call__(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> tuple[mx.array, mx.array]:
"""Apply the joint block.
Args:
hidden_states: Image hidden states [B, num_img_tokens, D]
encoder_hidden_states: Text hidden states [B, text_seq_len, D]
text_embeddings: Conditioning embeddings [B, D]
rotary_embeddings: Rotary position embeddings (model-specific format)
Returns:
Tuple of (encoder_hidden_states, hidden_states) - text and image outputs
"""
if self._mode == BlockWrapperMode.CACHING:
return self._forward_caching(
hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
)
return self._forward_patched(
hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
)
def _forward_caching(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> tuple[mx.array, mx.array]:
"""CACHING mode: Full attention, store image K/V in cache."""
# Model computes Q/K/V for full sequence
query, key, value = self._compute_qkv(
hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
)
img_key = key[:, :, self._text_seq_len :, :]
img_value = value[:, :, self._text_seq_len :, :]
self._cache_full_image_kv(img_key, img_value)
attn_out = self._compute_attention(query, key, value)
return self._apply_output(
attn_out, hidden_states, encoder_hidden_states, text_embeddings
)
def _forward_patched(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> tuple[mx.array, mx.array]:
"""PATCHED mode: Compute patch Q/K/V, use cached image K/V for attention."""
# hidden_states is already the patch (provided by runner)
patch_hidden = hidden_states
query, key, value = self._compute_qkv(
patch_hidden,
encoder_hidden_states,
text_embeddings,
rotary_embeddings,
patch_mode=True,
)
text_key = key[:, :, : self._text_seq_len, :]
text_value = value[:, :, : self._text_seq_len, :]
img_key = key[:, :, self._text_seq_len :, :]
img_value = value[:, :, self._text_seq_len :, :]
self._cache_patch_kv(img_key, img_value)
full_key, full_value = self._get_full_kv(text_key, text_value)
attn_out = self._compute_attention(query, full_key, full_value)
return self._apply_output(
attn_out, patch_hidden, encoder_hidden_states, text_embeddings
)
@abstractmethod
def _compute_qkv(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
patch_mode: bool = False,
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V tensors for sequence.
Includes normalization, projections, concatenation, and RoPE.
Args:
hidden_states: Image hidden states [B, num_img_tokens, D] or patch [B, patch_len, D]
encoder_hidden_states: Text hidden states [B, text_seq_len, D]
text_embeddings: Conditioning embeddings [B, D]
rotary_embeddings: Rotary position embeddings
patch_mode: If True, slice RoPE for current patch range
Returns:
Tuple of (query, key, value) with shape [B, H, text+img/patch, head_dim]
"""
...
@abstractmethod
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array:
"""Compute scaled dot-product attention.
Args:
query: Query tensor [B, H, Q_len, head_dim]
key: Key tensor [B, H, KV_len, head_dim]
value: Value tensor [B, H, KV_len, head_dim]
Returns:
Attention output [B, Q_len, D]
"""
...
@abstractmethod
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
) -> tuple[mx.array, mx.array]:
"""Apply output projection, feed-forward, and residuals.
Args:
attn_out: Attention output [B, text+img, D]
hidden_states: Original image hidden states (for residual)
encoder_hidden_states: Original text hidden states (for residual)
text_embeddings: Conditioning embeddings
Returns:
Tuple of (encoder_hidden_states, hidden_states) - updated text and image
"""
...
class SingleBlockWrapper(BlockWrapperMixin, ABC):
"""Base class for single-stream transformer block wrappers.
Similar to JointBlockWrapper but for blocks that operate on a single
concatenated [text, image] stream rather than separate streams.
"""
def __init__(self, block: Any, text_seq_len: int):
"""Initialize the single block wrapper.
Args:
block: The single transformer block to wrap
text_seq_len: Number of text tokens (constant for entire generation)
"""
self.block = block
self._init_cache_state(text_seq_len)
def __call__(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> mx.array:
"""Apply the single block.
Args:
hidden_states: Concatenated [text, image] hidden states
text_embeddings: Conditioning embeddings [B, D]
rotary_embeddings: Rotary position embeddings
Returns:
Updated hidden states [B, text+img, D]
"""
if self._mode == BlockWrapperMode.CACHING:
return self._forward_caching(
hidden_states, text_embeddings, rotary_embeddings
)
return self._forward_patched(hidden_states, text_embeddings, rotary_embeddings)
def _forward_caching(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> mx.array:
"""CACHING mode: Full attention, store image K/V in cache."""
query, key, value = self._compute_qkv(
hidden_states, text_embeddings, rotary_embeddings
)
img_key = key[:, :, self._text_seq_len :, :]
img_value = value[:, :, self._text_seq_len :, :]
self._cache_full_image_kv(img_key, img_value)
attn_out = self._compute_attention(query, key, value)
return self._apply_output(attn_out, hidden_states, text_embeddings)
def _forward_patched(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> mx.array:
"""PATCHED mode: Compute patch Q/K/V, use cached image K/V for attention."""
# hidden_states is already [text, patch]
query, key, value = self._compute_qkv(
hidden_states, text_embeddings, rotary_embeddings, patch_mode=True
)
text_key = key[:, :, : self._text_seq_len, :]
text_value = value[:, :, : self._text_seq_len, :]
img_key = key[:, :, self._text_seq_len :, :]
img_value = value[:, :, self._text_seq_len :, :]
self._cache_patch_kv(img_key, img_value)
full_key, full_value = self._get_full_kv(text_key, text_value)
attn_out = self._compute_attention(query, full_key, full_value)
return self._apply_output(attn_out, hidden_states, text_embeddings)
@abstractmethod
def _compute_qkv(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
patch_mode: bool = False,
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V tensors for sequence.
Args:
hidden_states: Concatenated [text, image] hidden states
text_embeddings: Conditioning embeddings [B, D]
rotary_embeddings: Rotary position embeddings
patch_mode: If True, slice RoPE for current patch range
Returns:
Tuple of (query, key, value) with shape [B, H, seq_len, head_dim]
"""
...
@abstractmethod
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array:
"""Compute scaled dot-product attention."""
...
@abstractmethod
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
"""Apply output projection, feed-forward, and residuals."""
...
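A minimal driver sketch for the mode switching described above; joint_wrappers, rope, and the token range are placeholders for what DiffusionRunner supplies:

# Sync step: full image through every owned block, populating each block's image KV cache.
for wrapper in joint_wrappers:
    wrapper.set_patch(BlockWrapperMode.CACHING)
    encoder_hidden_states, hidden_states = wrapper(
        hidden_states, encoder_hidden_states, text_embeddings, rope
    )

# Async step: only a patch of image tokens is passed; cached K/V covers the rest.
patch_hidden = hidden_states[:, 0:1024, :]  # the image-token patch this stage owns
for wrapper in joint_wrappers:
    wrapper.set_patch(BlockWrapperMode.PATCHED, patch_start=0, patch_end=1024)
    encoder_hidden_states, patch_hidden = wrapper(
        patch_hidden, encoder_hidden_states, text_embeddings, rope
    )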

View File

@@ -0,0 +1,72 @@
import mlx.core as mx
class ImagePatchKVCache:
"""KV cache that stores only IMAGE K/V with patch-level updates.
Only caches image K/V since:
- Text K/V is always computed fresh (same for all patches)
- Only image portion needs stale/fresh cache management across patches
"""
def __init__(
self,
batch_size: int,
num_heads: int,
image_seq_len: int,
head_dim: int,
dtype: mx.Dtype = mx.float32,
):
self.batch_size = batch_size
self.num_heads = num_heads
self.image_seq_len = image_seq_len
self.head_dim = head_dim
self._dtype = dtype
self.key_cache = mx.zeros(
(batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
)
self.value_cache = mx.zeros(
(batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
)
def update_image_patch(
self, patch_start: int, patch_end: int, key: mx.array, value: mx.array
) -> None:
"""Update cache with fresh K/V for an image patch slice.
Args:
patch_start: Start token index within image portion (0-indexed)
patch_end: End token index within image portion
key: Fresh key tensor [batch, heads, patch_seq_len, head_dim]
value: Fresh value tensor [batch, heads, patch_seq_len, head_dim]
"""
self.key_cache[:, :, patch_start:patch_end, :] = key
self.value_cache[:, :, patch_start:patch_end, :] = value
def get_full_kv(
self, text_key: mx.array, text_value: mx.array
) -> tuple[mx.array, mx.array]:
"""Return full K/V by concatenating fresh text K/V with cached image K/V.
Args:
text_key: Fresh text key tensor [batch, heads, text_seq_len, head_dim]
text_value: Fresh text value tensor [batch, heads, text_seq_len, head_dim]
Returns:
Tuple of (full_key, full_value) with shape [batch, heads, text+image, head_dim]
"""
full_key = mx.concatenate([text_key, self.key_cache], axis=2)
full_value = mx.concatenate([text_value, self.value_cache], axis=2)
return full_key, full_value
def reset(self) -> None:
"""Reset cache to zeros."""
self.key_cache = mx.zeros(
(self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
dtype=self._dtype,
)
self.value_cache = mx.zeros(
(self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
dtype=self._dtype,
)
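A self-contained sketch of the cache contract (sizes are illustrative):

import mlx.core as mx
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache

cache = ImagePatchKVCache(batch_size=1, num_heads=4, image_seq_len=8, head_dim=16)
patch_k = mx.ones((1, 4, 4, 16))
patch_v = mx.ones((1, 4, 4, 16))
cache.update_image_patch(0, 4, patch_k, patch_v)  # refresh the first half of the image K/V
text_k = mx.zeros((1, 4, 2, 16))
text_v = mx.zeros((1, 4, 2, 16))
full_k, full_v = cache.get_full_kv(text_k, text_v)  # each [1, 4, 2 + 8, 16]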

View File

@@ -0,0 +1,979 @@
from math import ceil
from typing import Optional
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.utils.exceptions import StopImageGenerationException
from tqdm import tqdm
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter, PromptData
from exo.worker.engines.image.pipeline.block_wrapper import (
BlockWrapperMode,
JointBlockWrapper,
SingleBlockWrapper,
)
def calculate_patch_heights(latent_height: int, num_patches: int):
patch_height = ceil(latent_height / num_patches)
actual_num_patches = ceil(latent_height / patch_height)
patch_heights = [patch_height] * (actual_num_patches - 1)
last_height = latent_height - patch_height * (actual_num_patches - 1)
patch_heights.append(last_height)
return patch_heights, actual_num_patches
def calculate_token_indices(patch_heights: list[int], latent_width: int):
tokens_per_row = latent_width
token_ranges = []
cumulative_height = 0
for h in patch_heights:
start_token = tokens_per_row * cumulative_height
end_token = tokens_per_row * (cumulative_height + h)
token_ranges.append((start_token, end_token))
cumulative_height += h
return token_ranges
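# Worked example for the two helpers above (illustrative sizes):
#   latent_height=64, latent_width=64, num_patches=3
#   calculate_patch_heights -> patch_height=ceil(64/3)=22, heights=[22, 22, 20], 3 patches
#   calculate_token_indices -> [(0, 1408), (1408, 2816), (2816, 4096)]  (64 tokens per row)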
class DiffusionRunner:
"""Orchestrates the diffusion loop for image generation.
This class owns the entire diffusion process, handling both single-node
and distributed (PipeFusion) modes.
In distributed mode, it implements PipeFusion with:
- Sync pipeline for initial timesteps (full image, all devices in lockstep)
- Async pipeline for later timesteps (patches processed independently)
"""
def __init__(
self,
config: ImageModelConfig,
adapter: ModelAdapter,
group: Optional[mx.distributed.Group],
shard_metadata: PipelineShardMetadata,
num_patches: Optional[int] = None,
):
"""Initialize the diffusion runner.
Args:
config: Model configuration (architecture, block counts, etc.)
adapter: Model adapter for model-specific operations
group: MLX distributed group (None for single-node mode)
shard_metadata: Pipeline shard metadata with layer assignments
num_patches: Number of patches for async mode (defaults to world_size)
"""
self.config = config
self.adapter = adapter
self.group = group
if group is None:
self.rank = 0
self.world_size = 1
self.next_rank = 0
self.prev_rank = 0
self.start_layer = 0
self.end_layer = config.total_blocks
else:
self.rank = shard_metadata.device_rank
self.world_size = shard_metadata.world_size
self.next_rank = (self.rank + 1) % self.world_size
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
self.start_layer = shard_metadata.start_layer
self.end_layer = shard_metadata.end_layer
self.num_patches = num_patches if num_patches else max(1, self.world_size)
self.total_joint = config.joint_block_count
self.total_single = config.single_block_count
self.total_layers = config.total_blocks
self._guidance_override: float | None = None
self._compute_assigned_blocks()
def _compute_assigned_blocks(self) -> None:
"""Determine which joint/single blocks this stage owns."""
start = self.start_layer
end = self.end_layer
if end <= self.total_joint:
self.joint_start = start
self.joint_end = end
self.single_start = 0
self.single_end = 0
elif start >= self.total_joint:
self.joint_start = 0
self.joint_end = 0
self.single_start = start - self.total_joint
self.single_end = end - self.total_joint
else:
self.joint_start = start
self.joint_end = self.total_joint
self.single_start = 0
self.single_end = end - self.total_joint
self.has_joint_blocks = self.joint_end > self.joint_start
self.has_single_blocks = self.single_end > self.single_start
self.owns_concat_stage = self.has_joint_blocks and (
self.has_single_blocks or self.end_layer == self.total_joint
)
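# Ownership sketch (illustrative): with total_joint=60 and total_single=0 (as in the Qwen
# configs above), a shard covering layers [30, 60) gets joint_start=30, joint_end=60,
# no single blocks, and owns_concat_stage=True because end_layer == total_joint.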
# Wrappers created lazily on first forward (need text_seq_len)
self.joint_block_wrappers: list[JointBlockWrapper] | None = None
self.single_block_wrappers: list[SingleBlockWrapper] | None = None
self._wrappers_initialized = False
self._current_text_seq_len: int | None = None
@property
def is_first_stage(self) -> bool:
return self.rank == 0
@property
def is_last_stage(self) -> bool:
return self.rank == self.world_size - 1
@property
def is_distributed(self) -> bool:
return self.group is not None
def _get_effective_guidance_scale(self) -> float | None:
if self._guidance_override is not None:
return self._guidance_override
return self.config.guidance_scale
def _ensure_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> None:
"""Lazily create block wrappers on first forward pass.
Wrappers need text_seq_len which is only known after prompt encoding.
Re-initializes if text_seq_len changes (e.g., warmup vs real generation).
"""
if self._wrappers_initialized and self._current_text_seq_len == text_seq_len:
return
self.joint_block_wrappers = self.adapter.get_joint_block_wrappers(
text_seq_len=text_seq_len,
encoder_hidden_states_mask=encoder_hidden_states_mask,
)
self.single_block_wrappers = self.adapter.get_single_block_wrappers(
text_seq_len=text_seq_len,
)
self._wrappers_initialized = True
self._current_text_seq_len = text_seq_len
def _reset_all_caches(self) -> None:
"""Reset KV caches on all wrappers for a new generation."""
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.reset_cache()
if self.single_block_wrappers:
for wrapper in self.single_block_wrappers:
wrapper.reset_cache()
def _set_text_seq_len(self, text_seq_len: int) -> None:
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_text_seq_len(text_seq_len)
if self.single_block_wrappers:
for wrapper in self.single_block_wrappers:
wrapper.set_text_seq_len(text_seq_len)
def _calculate_capture_steps(
self,
partial_images: int,
init_time_step: int,
num_inference_steps: int,
) -> set[int]:
"""Calculate which timesteps should produce partial images.
Places the first partial after step 1 for fast initial feedback,
then evenly spaces remaining partials with equal gaps between them
and from the last partial to the final image.
Args:
partial_images: Number of partial images to capture
init_time_step: Starting timestep (for img2img this may not be 0)
num_inference_steps: Total inference steps
Returns:
Set of timestep indices to capture
"""
if partial_images <= 0:
return set()
total_steps = num_inference_steps - init_time_step
if total_steps <= 1:
return set()
if partial_images >= total_steps - 1:
return set(range(init_time_step, num_inference_steps - 1))
capture_steps: set[int] = set()
first_capture = init_time_step + 1
capture_steps.add(first_capture)
if partial_images == 1:
return capture_steps
final_step = num_inference_steps - 1
remaining_range = final_step - first_capture
for i in range(1, partial_images):
step_idx = first_capture + int(i * remaining_range / partial_images)
capture_steps.add(step_idx)
return capture_steps
def generate_image(
self,
runtime_config: Config,
prompt: str,
seed: int,
partial_images: int = 0,
guidance_override: float | None = None,
negative_prompt: str | None = None,
num_sync_steps: int = 1,
):
"""Primary entry point for image generation.
Orchestrates the full generation flow:
1. Create initial latents
2. Encode prompt (and negative prompt for CFG)
3. Run diffusion loop (yielding partials if requested)
4. Decode to image
When partial_images > 0, yields (GeneratedImage, partial_index, total_partials)
tuples for intermediate images, then yields the final GeneratedImage.
Args:
runtime_config: Runtime generation config (steps, height, width)
prompt: Text prompt
seed: Random seed
partial_images: Number of intermediate images to yield (0 for none)
guidance_override: Optional override for guidance scale (CFG)
negative_prompt: Optional negative prompt for CFG
num_sync_steps: Number of initial sync-pipeline steps before switching to async patches
Yields:
Partial images as (GeneratedImage, partial_index, total_partials) tuples
Final GeneratedImage
"""
self._guidance_override = guidance_override
latents = self.adapter.create_latents(seed, runtime_config)
prompt_data = self.adapter.encode_prompt(prompt, negative_prompt)
capture_steps = self._calculate_capture_steps(
partial_images=partial_images,
init_time_step=runtime_config.init_time_step,
num_inference_steps=runtime_config.num_inference_steps,
)
diffusion_gen = self._run_diffusion_loop(
latents=latents,
prompt_data=prompt_data,
runtime_config=runtime_config,
seed=seed,
prompt=prompt,
capture_steps=capture_steps,
num_sync_steps=num_sync_steps,
)
partial_index = 0
total_partials = len(capture_steps)
if capture_steps:
try:
while True:
partial_latents, _step = next(diffusion_gen)
if self.is_last_stage:
partial_image = self.adapter.decode_latents(
partial_latents, runtime_config, seed, prompt
)
yield (partial_image, partial_index, total_partials)
partial_index += 1
except StopIteration as e:
latents = e.value
else:
try:
while True:
next(diffusion_gen)
except StopIteration as e:
latents = e.value
if self.is_last_stage:
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt)
def _run_diffusion_loop(
self,
latents: mx.array,
prompt_data: PromptData,
runtime_config: Config,
seed: int,
prompt: str,
num_sync_steps: int,
capture_steps: set[int] | None = None,
):
if capture_steps is None:
capture_steps = set()
self._reset_all_caches()
time_steps = tqdm(range(runtime_config.num_inference_steps))
ctx = self.adapter.model.callbacks.start(
seed=seed, prompt=prompt, config=runtime_config
)
ctx.before_loop(
latents=latents,
)
for t in time_steps:
try:
latents = self._diffusion_step(
t=t,
config=runtime_config,
latents=latents,
prompt_data=prompt_data,
num_sync_steps=num_sync_steps,
)
ctx.in_loop(
t=t,
latents=latents,
)
mx.eval(latents)
# Yield partial latents at capture steps (only on last stage)
if t in capture_steps and self.is_last_stage:
yield (latents, t)
except KeyboardInterrupt: # noqa: PERF203
ctx.interruption(t=t, latents=latents)
raise StopImageGenerationException(
f"Stopping image generation at step {t + 1}/{len(time_steps)}"
) from None
ctx.after_loop(latents=latents)
return latents
def _forward_pass(
self,
latents: mx.array,
prompt_embeds: mx.array,
pooled_prompt_embeds: mx.array,
t: int,
config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
conditioning_latents: mx.array | None = None,
) -> mx.array:
"""Run a single forward pass through the transformer.
This is the internal method called by adapters via compute_step_noise.
Returns noise prediction without applying scheduler step.
For edit mode, concatenates conditioning latents with generated latents
before the transformer, and extracts only the generated portion after.
Args:
latents: Input latents (already scaled by caller)
prompt_embeds: Text embeddings
pooled_prompt_embeds: Pooled text embeddings (Flux) or placeholder (Qwen)
t: Current timestep
config: Runtime configuration
encoder_hidden_states_mask: Attention mask for text (Qwen)
cond_image_grid: Conditioning image grid dimensions (Qwen edit)
conditioning_latents: Conditioning latents for edit mode
Returns:
Noise prediction tensor
"""
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_hidden_states_mask)
if self.joint_block_wrappers and encoder_hidden_states_mask is not None:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_hidden_states_mask)
scaled_latents = config.scheduler.scale_model_input(latents, t)
# For edit mode: concatenate with conditioning latents
original_latent_tokens = scaled_latents.shape[1]
if conditioning_latents is not None:
scaled_latents = mx.concatenate(
[scaled_latents, conditioning_latents], axis=1
)
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
scaled_latents, prompt_embeds
)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds, hidden_states=hidden_states
)
rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
encoder_hidden_states_mask=encoder_hidden_states_mask,
cond_image_grid=cond_image_grid,
)
assert self.joint_block_wrappers is not None
for wrapper in self.joint_block_wrappers:
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
)
if self.joint_block_wrappers:
hidden_states = self.adapter.merge_streams(
hidden_states, encoder_hidden_states
)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
)
# Extract image portion
hidden_states = hidden_states[:, text_seq_len:, ...]
# For edit mode: extract only the generated portion (exclude conditioning latents)
if conditioning_latents is not None:
hidden_states = hidden_states[:, :original_latent_tokens, ...]
return self.adapter.final_projection(hidden_states, text_embeddings)
def _diffusion_step(
self,
t: int,
config: Config,
latents: mx.array,
prompt_data: PromptData,
num_sync_steps: int,
) -> mx.array:
if self.group is None:
return self._single_node_step(t, config, latents, prompt_data)
elif t < config.init_time_step + num_sync_steps:
return self._sync_pipeline_step(
t,
config,
latents,
prompt_data,
)
else:
return self._async_pipeline_step(
t,
config,
latents,
prompt_data,
is_first_async_step=t == config.init_time_step + num_sync_steps,
)
def _single_node_step(
self,
t: int,
config: Config,
latents: mx.array,
prompt_data: PromptData,
) -> mx.array:
cond_image_grid = prompt_data.cond_image_grid
needs_cfg = self.adapter.needs_cfg
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
pooled_embeds = (
batched_pooled if batched_pooled is not None else prompt_embeds
)
step_latents = mx.concatenate([latents, latents], axis=0)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
cond_latents = prompt_data.conditioning_latents
step_latents = latents
noise = self._forward_pass(
step_latents,
prompt_embeds,
pooled_embeds,
t=t,
config=config,
encoder_hidden_states_mask=encoder_mask,
cond_image_grid=cond_image_grid,
conditioning_latents=cond_latents,
)
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale=guidance_scale
)
return config.scheduler.step(noise=noise, timestep=t, latents=latents)
def _create_patches(
self,
latents: mx.array,
config: Config,
) -> tuple[list[mx.array], list[tuple[int, int]]]:
latent_height = config.height // 16
latent_width = config.width // 16
patch_heights, _ = calculate_patch_heights(latent_height, self.num_patches)
token_indices = calculate_token_indices(patch_heights, latent_width)
patch_latents = [latents[:, start:end, :] for start, end in token_indices]
return patch_latents, token_indices
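calculate_patch_heights and calculate_token_indices are defined elsewhere in the repository; the hypothetical stand-ins below only illustrate the expected contract, splitting latent_height rows into num_patches bands and mapping each band to a token range of width latent_width:

def calculate_patch_heights_sketch(latent_height: int, num_patches: int) -> tuple[list[int], list[int]]:
    # Hypothetical stand-in: distribute rows as evenly as possible.
    base, rem = divmod(latent_height, num_patches)
    heights = [base + (1 if i < rem else 0) for i in range(num_patches)]
    offsets = [sum(heights[:i]) for i in range(num_patches)]
    return heights, offsets

def calculate_token_indices_sketch(patch_heights: list[int], latent_width: int) -> list[tuple[int, int]]:
    # Each latent row contributes latent_width tokens, so a row band maps to a token range.
    indices: list[tuple[int, int]] = []
    start = 0
    for h in patch_heights:
        indices.append((start, start + h * latent_width))
        start += h * latent_width
    return indices

heights, _ = calculate_patch_heights_sketch(64, 3)
print(calculate_token_indices_sketch(heights, 64))  # [(0, 1408), (1408, 2752), (2752, 4096)]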
def _run_sync_pass(
self,
t: int,
config: Config,
scaled_hidden_states: mx.array,
prompt_embeds: mx.array,
pooled_prompt_embeds: mx.array,
encoder_hidden_states_mask: mx.array | None,
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] | None,
kontext_image_ids: mx.array | None,
num_img_tokens: int,
original_latent_tokens: int,
conditioning_latents: mx.array | None,
) -> mx.array | None:
hidden_states = scaled_hidden_states
batch_size = hidden_states.shape[0]
text_seq_len = prompt_embeds.shape[1]
hidden_dim = self.adapter.hidden_dim
dtype = scaled_hidden_states.dtype
self._set_text_seq_len(text_seq_len)
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_hidden_states_mask)
if self.is_first_stage:
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
hidden_states, prompt_embeds
)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
encoder_hidden_states_mask=encoder_hidden_states_mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
if self.has_joint_blocks:
if not self.is_first_stage:
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
assert self.joint_block_wrappers is not None
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if self.owns_concat_stage:
concatenated = self.adapter.merge_streams(
hidden_states, encoder_hidden_states
)
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
concatenated = mx.distributed.send(
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
elif self.has_joint_blocks and not self.is_last_stage:
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if not self.is_last_stage:
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
hidden_states = hidden_states[:, text_seq_len:, ...]
if conditioning_latents is not None:
hidden_states = hidden_states[:, :original_latent_tokens, ...]
if self.is_last_stage:
return self.adapter.final_projection(hidden_states, text_embeddings)
return None
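Throughout _run_sync_pass, stage-to-stage traffic follows one pattern: sends are queued with mx.async_eval so compute overlaps communication, and receives pass an explicit shape/dtype template rather than a preallocated tensor. A minimal sketch of that pattern (the helper names are assumptions, not the repo's API):

import mlx.core as mx

def send_to_next_stage(x: mx.array, next_rank: int, group: mx.distributed.Group) -> None:
    # Queue the send and let it evaluate asynchronously alongside later compute.
    sent = mx.distributed.send(x, next_rank, group=group)
    mx.async_eval(sent)

def recv_from_prev_stage(
    shape: tuple[int, ...], dtype: mx.Dtype, prev_rank: int, group: mx.distributed.Group
) -> mx.array:
    # recv is given the expected shape and dtype instead of a template tensor.
    return mx.distributed.recv(shape, dtype, prev_rank, group=group)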
def _sync_pipeline_step(
self,
t: int,
config: Config,
hidden_states: mx.array,
prompt_data: PromptData,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
prev_latents = hidden_states
needs_cfg = self.adapter.needs_cfg
cond_image_grid = prompt_data.cond_image_grid
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t)
original_latent_tokens = scaled_hidden_states.shape[1]
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
pooled_embeds = (
batched_pooled if batched_pooled is not None else prompt_embeds
)
step_latents = mx.concatenate(
[scaled_hidden_states, scaled_hidden_states], axis=0
)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
cond_latents = prompt_data.conditioning_latents
step_latents = scaled_hidden_states
if cond_latents is not None:
num_img_tokens = original_latent_tokens + cond_latents.shape[1]
else:
num_img_tokens = original_latent_tokens
if self.is_first_stage and cond_latents is not None:
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_mask)
noise = self._run_sync_pass(
t,
config,
step_latents,
prompt_embeds,
pooled_embeds,
encoder_mask,
cond_image_grid,
kontext_image_ids,
num_img_tokens,
original_latent_tokens,
cond_latents,
)
if self.is_last_stage:
assert noise is not None
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale
)
hidden_states = config.scheduler.step(
noise=noise, timestep=t, latents=prev_latents
)
if not self.is_first_stage:
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
mx.async_eval(hidden_states)
elif self.is_first_stage:
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
else:
hidden_states = prev_latents
return hidden_states
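At the end of each synchronous step only the last stage holds the freshly stepped latents: it sends them to rank 0, rank 0 receives them with recv_like, and intermediate stages simply carry their previous latents forward. A compact sketch of that last-to-first hand-off (helper name and signature are assumptions):

import mlx.core as mx

def ring_handoff_sketch(
    latents: mx.array, rank: int, world_size: int, group: mx.distributed.Group
) -> mx.array:
    if rank == world_size - 1 and rank != 0:
        # Last stage ships the stepped latents back to the first stage.
        sent = mx.distributed.send(latents, 0, group=group)
        mx.async_eval(sent)
        return latents
    if rank == 0 and world_size > 1:
        # First stage waits for the updated latents from the last stage.
        return mx.distributed.recv_like(latents, src=world_size - 1, group=group)
    return latents  # intermediate stages keep their previous latents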
def _async_pipeline_step(
self,
t: int,
config: Config,
latents: mx.array,
prompt_data: PromptData,
is_first_async_step: bool,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
"""Execute async pipeline step with batched CFG."""
patch_latents, token_indices = self._create_patches(latents, config)
needs_cfg = self.adapter.needs_cfg
cond_image_grid = prompt_data.cond_image_grid
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, _ = batched_data
pooled_embeds = (
batched_pooled if batched_pooled is not None else prompt_embeds
)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_mask)
self._set_text_seq_len(text_seq_len)
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_mask)
text_embeddings = self.adapter.compute_text_embeddings(t, config, pooled_embeds)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
encoder_hidden_states_mask=encoder_mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
prev_patch_latents = list(patch_latents)
encoder_hidden_states: mx.array | None = None
for patch_idx in range(len(patch_latents)):
patch = patch_latents[patch_idx]
if (
self.is_first_stage
and not self.is_last_stage
and not is_first_async_step
):
patch = mx.distributed.recv_like(
patch, src=self.prev_rank, group=self.group
)
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
noise, encoder_hidden_states = self._run_single_patch_pass(
patch=step_patch,
patch_idx=patch_idx,
token_indices=token_indices[patch_idx],
prompt_embeds=prompt_embeds,
text_embeddings=text_embeddings,
image_rotary_embeddings=image_rotary_embeddings,
encoder_hidden_states=encoder_hidden_states,
)
if self.is_last_stage:
assert noise is not None
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale
)
patch_latents[patch_idx] = config.scheduler.step(
noise=noise,
timestep=t,
latents=prev_patch_latents[patch_idx],
)
if not self.is_first_stage and t != config.num_inference_steps - 1:
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx],
self.next_rank,
group=self.group,
)
mx.async_eval(patch_latents[patch_idx])
return mx.concatenate(patch_latents, axis=1)
def _run_single_patch_pass(
self,
patch: mx.array,
patch_idx: int,
token_indices: tuple[int, int],
prompt_embeds: mx.array,
text_embeddings: mx.array,
image_rotary_embeddings: mx.array,
encoder_hidden_states: mx.array | None,
) -> tuple[mx.array | None, mx.array | None]:
"""Process a single patch through the forward pipeline.
Handles stage-to-stage communication (stage i -> stage i+1).
Ring communication (last stage -> first stage) is handled by the caller.
Args:
patch: The patch latents to process
patch_idx: Index of this patch (0-indexed)
token_indices: (start_token, end_token) for this patch
prompt_embeds: Text embeddings (for compute_embeddings on first stage)
text_embeddings: Precomputed text embeddings
image_rotary_embeddings: Precomputed rotary embeddings
encoder_hidden_states: Encoder hidden states (passed between patches)
Returns:
(noise_prediction, encoder_hidden_states) - noise is None for non-last stages
"""
start_token, end_token = token_indices
batch_size = patch.shape[0]
text_seq_len = prompt_embeds.shape[1]
hidden_dim = self.adapter.hidden_dim
if self.has_joint_blocks:
if not self.is_first_stage:
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
if patch_idx == 0:
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
if self.is_first_stage:
patch, encoder_hidden_states = self.adapter.compute_embeddings(
patch, prompt_embeds
)
assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
encoder_hidden_states, patch = wrapper(
hidden_states=patch,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if self.owns_concat_stage:
assert encoder_hidden_states is not None
patch_concat = self.adapter.merge_streams(patch, encoder_hidden_states)
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
patch_concat = mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
elif self.has_joint_blocks and not self.is_last_stage:
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
if patch_idx == 0:
assert encoder_hidden_states is not None
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
patch = wrapper(
hidden_states=patch,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if not self.is_last_stage:
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
noise: mx.array | None = None
if self.is_last_stage:
patch_img_only = patch[:, text_seq_len:, :]
noise = self.adapter.final_projection(patch_img_only, text_embeddings)
return noise, encoder_hidden_states

View File

@@ -46,11 +46,9 @@ class CustomMlxLayer(nn.Module):
def __init__(self, original_layer: _LayerCallable):
super().__init__()
# Set twice to avoid __setattr__ recursion
object.__setattr__(self, "_original_layer", original_layer)
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
self.original_layer: _LayerCallable = original_layer
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -60,7 +58,7 @@ class CustomMlxLayer(nn.Module):
return super().__getattr__(name)
except AttributeError:
original_layer = object.__getattribute__(self, "_original_layer")
return getattr(original_layer, name)
return object.__getattribute__(original_layer, name)
class PipelineFirstLayer(CustomMlxLayer):
@@ -108,6 +106,7 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
# TODO(ciaran): This is overkill
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
return output
@@ -170,21 +169,11 @@ def pipeline_auto_parallel(
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
start_layer:end_layer
]
# We can assume the model has at least one layer thanks to placement.
# If a layer type doesn't exist, we can set it to 0.
inner_model_instance.swa_idx = (
0
if "sliding_attention" not in inner_model_instance.layer_types # type: ignore
else inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
)
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
)
inner_model_instance.ga_idx = (
0
if "full_attention" not in inner_model_instance.layer_types # type: ignore
else inner_model_instance.layer_types.index( # type: ignore
"full_attention"
)
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
"full_attention"
)
_set_layers(model, layers)

View File

@@ -1,7 +1,7 @@
from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

View File

@@ -2,9 +2,7 @@ import json
import os
import resource
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -22,7 +20,6 @@ except ImportError:
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
@@ -75,7 +72,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
/ model_shard_meta.n_layers
* model_shard_meta.model_card.storage_size.in_kb
* model_shard_meta.model_meta.storage_size.in_kb
/ (
1
if isinstance(model_shard_meta, PipelineShardMetadata)
@@ -84,45 +81,6 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
)
class ModelLoadingTimeoutError(Exception):
pass
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -186,26 +144,20 @@ def mlx_distributed_init(
group = mx.distributed.init(backend="ring", strict=True)
case MlxJacclInstance(
jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
):
assert all(
jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
)
# Use RDMA connectivity matrix
coordination_file = (
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
)
jaccl_devices_json = json.dumps(jaccl_devices)
ibv_devices_json = json.dumps(ibv_devices)
with open(coordination_file, "w") as f:
_ = f.write(jaccl_devices_json)
_ = f.write(ibv_devices_json)
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
@@ -235,13 +187,11 @@ def initialize_mlx(
def load_mlx_items(
bound_instance: BoundInstance,
group: Group | None,
on_timeout: TimeoutCallback | None = None,
bound_instance: BoundInstance, group: Group | None
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
@@ -251,9 +201,7 @@ def load_mlx_items(
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
@@ -267,9 +215,8 @@ def load_mlx_items(
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_card.model_id)
model_path = build_model_path(shard_metadata.model_meta.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
@@ -304,15 +251,7 @@ def shard_and_load(
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb / 5
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
mx.eval(model.parameters())
# TODO: Do we need this?
mx.eval(model)
@@ -328,7 +267,7 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
@@ -426,8 +365,6 @@ def apply_chat_template(
tools=chat_task_data.tools,
)
logger.info(prompt)
return prompt
@@ -459,11 +396,6 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")

View File

@@ -8,31 +8,36 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import ImageEditsInternalParams
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
EventId,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
NodeDownloadProgress,
NodeGatheredInfo,
NodeMemoryMeasured,
NodePerformanceMeasured,
TaskCreated,
TaskStatusUpdated,
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
DownloadModel,
ImageEdits,
Shutdown,
Task,
TaskStatus,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.topology import Connection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
@@ -43,14 +48,14 @@ from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
from exo.worker.utils.net_profile import check_reachable
class Worker:
@@ -84,7 +89,7 @@ class Worker:
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
self._tg: TaskGroup | None = None
self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
@@ -93,16 +98,44 @@ class Worker:
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
async def run(self):
logger.info("Starting Worker")
info_send, info_recv = channel[GatheredInfo]()
info_gatherer: InfoGatherer = InfoGatherer(info_send)
# TODO: CLEANUP HEADER
async def resource_monitor_callback(
node_performance_profile: NodePerformanceProfile,
) -> None:
await self.event_sender.send(
NodePerformanceMeasured(
node_id=self.node_id,
node_profile=node_performance_profile,
when=str(datetime.now(tz=timezone.utc)),
),
)
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
async def memory_monitor_callback(
memory_profile: MemoryPerformanceProfile,
) -> None:
await self.event_sender.send(
NodeMemoryMeasured(
node_id=self.node_id,
memory=memory_profile,
when=str(datetime.now(tz=timezone.utc)),
)
)
# END CLEANUP
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self.plan_step)
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
@@ -116,17 +149,6 @@ class Worker:
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
async for info in info_stream:
await self.event_sender.send(
NodeGatheredInfo(
node_id=self.node_id,
when=str(datetime.now(tz=timezone.utc)),
info=info,
)
)
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -146,6 +168,7 @@ class Worker:
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
assert self._tg
# Request the next index.
self._tg.start_soon(
self._nack_request, self.state.last_event_applied_idx + 1
@@ -157,6 +180,17 @@ class Worker:
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
if cmd_id not in self.input_chunk_buffer:
self.input_chunk_buffer[cmd_id] = {}
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
event.chunk.data
)
async def plan_step(self):
while True:
await anyio.sleep(0.1)
@@ -169,6 +203,8 @@ class Worker:
self.state.instances,
self.state.runners,
self.state.tasks,
self.input_chunk_buffer,
self.input_chunk_counts,
)
if task is None:
continue
@@ -186,11 +222,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_card.model_id not in self.download_status:
if shard.model_meta.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_card.model_id] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -205,7 +241,7 @@ class Worker:
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -232,11 +268,48 @@ class Worker:
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
chunks = self.input_chunk_buffer.get(cmd_id, {})
assembled = "".join(chunks[i] for i in range(len(chunks)))
logger.info(
f"Assembled input image from {len(chunks)} chunks, "
f"total size: {len(assembled)} bytes"
)
# Create modified task with assembled image data
modified_task = ImageEdits(
task_id=task.task_id,
command_id=task.command_id,
instance_id=task.instance_id,
task_status=task.task_status,
task_params=ImageEditsInternalParams(
image_data=assembled,
total_input_chunks=task.task_params.total_input_chunks,
prompt=task.task_params.prompt,
model=task.task_params.model,
n=task.task_params.n,
quality=task.task_params.quality,
output_format=task.task_params.output_format,
response_format=task.task_params.response_format,
size=task.task_params.size,
image_strength=task.task_params.image_strength,
),
)
# Cleanup buffers
if cmd_id in self.input_chunk_buffer:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
def shutdown(self):
self._tg.cancel_scope.cancel()
if self._tg:
self._tg.cancel_scope.cancel()
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
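The ImageEdits branch above reassembles the base64 input image from the per-command chunk buffer in index order before dispatching the task. A standalone sketch of that buffering and assembly, where Chunk is an assumed stand-in for event.chunk:

from dataclasses import dataclass

@dataclass
class Chunk:  # hypothetical stand-in for event.chunk
    chunk_index: int
    total_chunks: int
    data: str

def buffer_chunk(
    buffers: dict[str, dict[int, str]], counts: dict[str, int], cmd_id: str, chunk: Chunk
) -> None:
    buffers.setdefault(cmd_id, {})[chunk.chunk_index] = chunk.data
    counts[cmd_id] = chunk.total_chunks

def assemble(buffers: dict[str, dict[int, str]], counts: dict[str, int], cmd_id: str) -> str | None:
    chunks = buffers.get(cmd_id, {})
    if not chunks or len(chunks) < counts.get(cmd_id, 0):
        return None  # still waiting for chunks
    return "".join(chunks[i] for i in range(len(chunks)))

buffers: dict[str, dict[int, str]] = {}
counts: dict[str, int] = {}
for i, part in enumerate(["iVBO", "Rw0K", "Ggo="]):
    buffer_chunk(buffers, counts, "cmd-1", Chunk(i, 3, part))
print(assemble(buffers, counts, "cmd-1"))  # iVBORw0KGgo=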
@@ -253,28 +326,24 @@ class Worker:
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
conn=Connection(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)
case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
conn=Connection(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
edge=Connection(
local_node_id=self.node_id,
send_back_node_id=msg.node_id,
send_back_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
)
)
async def _nack_request(self, since_idx: int) -> None:
@@ -323,6 +392,7 @@ class Worker:
event_sender=self.event_sender.clone(),
)
self.runners[task.bound_instance.bound_runner_id] = runner
assert self._tg
self._tg.start_soon(runner.run)
return runner
@@ -339,13 +409,14 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata.model_card.model_id] = status
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
# TODO: i hate callbacks
def download_progress_callback(
shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal self
@@ -356,11 +427,12 @@ class Worker:
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
self.download_status[shard.model_meta.model_id] = status
# Footgun!
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
)
await self.event_sender.send(
self.event_sender.send_nowait(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
@@ -376,13 +448,14 @@ class Worker:
progress
),
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
self.download_status[shard.model_meta.model_id] = status
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
assert self._tg
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
async def _forward_events(self) -> None:
@@ -403,14 +476,9 @@ class Worker:
async def _poll_connection_updates(self):
while True:
edges = set(
conn.edge for conn in self.state.topology.out_edges(self.node_id)
)
conns = await check_reachable(
self.state.topology,
self.node_id,
self.state.node_profiles,
)
# TODO: EdgeDeleted
edges = set(self.state.topology.list_connections())
conns = await check_reachable(self.state.topology, self.node_id)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
@@ -418,33 +486,26 @@ class Worker:
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = SocketConnection(
edge = Connection(
local_node_id=self.node_id,
send_back_node_id=nid,
# nonsense multiaddr
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
if "." in ip
# nonsense multiaddr
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
)
if edge not in edges:
logger.debug(f"ping discovered {edge=}")
await self.event_sender.send(
TopologyEdgeCreated(
conn=Connection(
source=self.node_id, sink=nid, edge=edge
)
)
)
await self.event_sender.send(TopologyEdgeCreated(edge=edge))
for conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn.edge, SocketConnection):
continue
for nid, conn in self.state.topology.out_edges(self.node_id):
if (
conn.sink not in conns
or conn.edge.sink_multiaddr.ip_address
not in conns.get(conn.sink, set())
nid not in conns
or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
):
logger.debug(f"ping failed to discover {conn=}")
await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await anyio.sleep(10)
@@ -478,7 +539,7 @@ class Worker:
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
self.download_status[progress.shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)

View File

@@ -2,13 +2,15 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadModel,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -49,6 +51,8 @@ def plan(
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_counts: Mapping[CommandId, int] | None = None,
) -> Task | None:
# Python short circuiting OR logic should evaluate these sequentially.
return (
@@ -58,7 +62,7 @@ def plan(
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
)
@@ -114,7 +118,7 @@ def _model_needs_download(
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_card.model_id
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
@@ -191,7 +195,7 @@ def _load_model(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner
@@ -262,14 +266,24 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
if not isinstance(task, ChatCompletion):
# TODO(ciaran): do this better!
if not isinstance(task, (ChatCompletion, ImageGeneration, ImageEdits)):
continue
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
continue
# For ImageEdits tasks, verify all input chunks have been received
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
for runner in runners.values():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
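_pending_tasks now gates ImageEdits on chunk completeness, so a task is only dispatched once every input chunk has been buffered. A tiny sketch of that readiness check (not the repository's code):

def image_edits_ready(total_input_chunks: int, received_chunks: dict[int, str]) -> bool:
    # Mirrors the gate above: dispatch only once all chunks are present.
    if total_input_chunks == 0:
        return True  # no input image to wait for
    return len(received_chunks) >= total_input_chunks

print(image_edits_ready(3, {0: "a", 1: "b"}))           # False: still waiting
print(image_edits_ready(3, {0: "a", 1: "b", 2: "c"}))   # True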

View File

@@ -17,23 +17,15 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
fast_synch_override != "off"
and (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.jaccl_devices) >= 2
)
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
global logger
logger = _logger
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main

Some files were not shown because too many files have changed in this diff.