Compare commits

..

252 Commits

Author SHA1 Message Date
ciaranbor
09a4454427 Prevent running image editing model without an input image 2026-01-20 16:52:02 +00:00
ciaranbor
12d2da7e47 Type coercion for ModelTask 2026-01-20 16:47:49 +00:00
ciaranbor
e98affdb4d Fallback for resolving model card 2026-01-20 16:19:11 +00:00
ciaranbor
c33f79a133 Reflect ModelCard simplification 2026-01-20 16:07:57 +00:00
ciaranbor
4d826b465e Fix text model runtime check 2026-01-20 16:07:57 +00:00
ciaranbor
3fbe0b09bb Fix image streaming for editing 2026-01-20 16:07:57 +00:00
ciaranbor
2c9bfc0d8a Propagate additional image edit api params 2026-01-20 16:07:57 +00:00
ciaranbor
f4d7606ae3 Support image editing in UI 2026-01-20 16:07:57 +00:00
ciaranbor
fb81b5c0db Allow dropdowns to fit values 2026-01-20 16:07:57 +00:00
ciaranbor
39a26c6745 Better param dropdowns 2026-01-20 16:07:57 +00:00
ciaranbor
35e62c8087 Better typing 2026-01-20 16:07:57 +00:00
ciaranbor
f107ae0b43 Restore recv evals 2026-01-20 16:07:57 +00:00
ciaranbor
ba2faf3a71 Remove outputFormat param 2026-01-20 16:07:57 +00:00
ciaranbor
305f4b79d3 Expose image generation settings to UI 2026-01-20 16:07:57 +00:00
ciaranbor
40624bf494 DiffusionRunner type errors 2026-01-20 16:07:57 +00:00
ciaranbor
1277265ae2 Correctly handle num_sync_steps for diffusion steps override 2026-01-20 16:07:57 +00:00
ciaranbor
b5d1e78aba Simplify DistributedImageModel 2026-01-20 16:07:57 +00:00
ciaranbor
c712726778 Hide internal mflux image type 2026-01-20 16:07:57 +00:00
ciaranbor
c4a1f2cde8 Only run two steps during warmup 2026-01-20 16:07:57 +00:00
ciaranbor
4e287ab471 Document image apis 2026-01-20 16:07:57 +00:00
ciaranbor
beb39442ab Clean up ImageModelConfig 2026-01-20 16:07:57 +00:00
ciaranbor
1684907879 Add AdvancedImageParams to api 2026-01-20 16:07:57 +00:00
ciaranbor
842569655c Capture partial images earlier 2026-01-20 16:07:57 +00:00
ciaranbor
c437154755 Remove redundant tensor allocations for recv templates 2026-01-20 16:07:57 +00:00
ciaranbor
841f89cb0a Remove recv evals, use async_eval for sends 2026-01-20 16:07:57 +00:00
ciaranbor
8bc12db8d3 Remove ImageGenerator protocol 2026-01-20 16:07:57 +00:00
ciaranbor
0fd4ff1151 Batch CFG 2026-01-20 16:07:57 +00:00
ciaranbor
9cba51aee6 Add image generation benchmarking endpoints 2026-01-20 16:07:57 +00:00
ciaranbor
874ad0e4c0 Consolidate patched and unpatched qkv computation logic 2026-01-20 16:07:57 +00:00
ciaranbor
102681334d Use mixin for common block wrapper functionality 2026-01-20 16:07:57 +00:00
ciaranbor
bda4c903a4 Move last rank to first rank comms outside the CFG step 2026-01-20 16:07:57 +00:00
ciaranbor
b13c75671c Revert 2026-01-20 16:07:57 +00:00
ciaranbor
b3b3ae6ae6 Run all positive then all negative 2026-01-20 16:07:57 +00:00
ciaranbor
0fcd26c9d6 Rank 0 shouldn't receive on negative pass 2026-01-20 16:07:57 +00:00
ciaranbor
d8c0d3e1c5 Fix negative pass text_seq_len 2026-01-20 16:07:57 +00:00
ciaranbor
acfacbf34c Add distributed CFG support 2026-01-20 16:07:57 +00:00
ciaranbor
8b39485250 Enable CFG for Qwen-Image 2026-01-20 16:07:57 +00:00
ciaranbor
5d5676103e Use transformer block wrapper classes 2026-01-20 16:07:57 +00:00
ciaranbor
3af221bf75 Refactor 2026-01-20 16:07:57 +00:00
ciaranbor
e6993ad537 Fix flux tokenizer 2026-01-20 16:07:57 +00:00
ciaranbor
0fc68617fb Reduce image generation and image edits code duplication 2026-01-20 16:07:57 +00:00
ciaranbor
983420619e Update mflux to 0.14.2 2026-01-20 16:07:57 +00:00
ciaranbor
9381edfaa6 Add python-multipart dependency 2026-01-20 16:07:57 +00:00
ciaranbor
f3a840011d Linting 2026-01-20 16:07:57 +00:00
ciaranbor
971b8367ab Start image editing time steps at 0 2026-01-20 16:07:57 +00:00
ciaranbor
2d0335b7cc Ignore image_strength 2026-01-20 16:07:57 +00:00
ciaranbor
89b2471183 Handle conditioning latents in sync pipeline 2026-01-20 16:07:57 +00:00
ciaranbor
e39aec8189 Use dummy image for editing warmup 2026-01-20 16:07:57 +00:00
ciaranbor
00a828494b Support streaming for image editing 2026-01-20 16:07:57 +00:00
ciaranbor
0c2b760179 Support image editing in runner 2026-01-20 16:07:57 +00:00
ciaranbor
430ea689b1 Add editing features to adapter 2026-01-20 16:07:57 +00:00
ciaranbor
9e5c64eb07 Default partial images to 3 if streaming 2026-01-20 16:07:57 +00:00
ciaranbor
d6b44f4ba1 Add Qwen-Image model adapter 2026-01-20 16:07:57 +00:00
ciaranbor
ebcefc1fb3 Add Qwen-Image-Edit model config 2026-01-20 16:07:57 +00:00
ciaranbor
e7d9a66c15 Use image generation in streaming mode in UI 2026-01-20 16:07:57 +00:00
ciaranbor
5280b4d4ae Handle partial image streaming 2026-01-20 16:07:57 +00:00
ciaranbor
847ca9c21f Add streaming params to ImageGenerationTaskParams 2026-01-20 16:07:57 +00:00
ciaranbor
69f6d37c66 Add Qwen-Image-Edit-2509 2026-01-20 16:07:56 +00:00
ciaranbor
c74d59c810 Handle image editing time steps 2026-01-20 16:07:56 +00:00
ciaranbor
99e8223c0b Fix time steps 2026-01-20 16:07:56 +00:00
ciaranbor
fd321398f2 Fix image_strength meaning 2026-01-20 16:07:56 +00:00
ciaranbor
c1e8bcb54c Truncate image data logs 2026-01-20 16:07:56 +00:00
ciaranbor
82b0edf277 Chunk image input 2026-01-20 16:07:56 +00:00
ciaranbor
37a353448f Avoid logging image data 2026-01-20 16:07:56 +00:00
ciaranbor
699b63d1da Support image editing 2026-01-20 16:07:56 +00:00
Sami Khan
aaa7dcc0b3 small UI change 2026-01-20 16:07:56 +00:00
Sami Khan
9d1963a891 image gen in dashboard 2026-01-20 16:07:56 +00:00
ciaranbor
35f1425490 Better llm model type check 2026-01-20 16:07:56 +00:00
ciaranbor
0cb008328f Prune blocks before model load 2026-01-20 16:07:56 +00:00
ciaranbor
796e9ccd9a Own TODOs 2026-01-20 16:07:56 +00:00
ciaranbor
7a1c6b2a18 Remove double RunnerReady event 2026-01-20 16:07:56 +00:00
ciaranbor
f992389cb3 Fix hidden_size for image models 2026-01-20 16:07:56 +00:00
ciaranbor
12cf753614 Fix image model cards 2026-01-20 16:07:56 +00:00
ciaranbor
8673168b21 Skip decode on non-final ranks 2026-01-20 16:07:56 +00:00
ciaranbor
91bf4b5b5d Final rank produces image 2026-01-20 16:07:56 +00:00
ciaranbor
bf4e114ec9 Increase number of sync steps 2026-01-20 16:07:56 +00:00
ciaranbor
f39792b5b9 Change Qwen-Image steps 2026-01-20 16:07:56 +00:00
ciaranbor
2bc2cd2108 Fix Qwen-Image latent shapes 2026-01-20 16:07:56 +00:00
ciaranbor
88289b617b Fix joint block patch recv shape for non-zero ranks 2026-01-20 16:07:56 +00:00
ciaranbor
5295873bb2 Fix comms issue for models without single blocks 2026-01-20 16:07:56 +00:00
ciaranbor
1de5fb6716 Support Qwen in DiffusionRunner pipefusion 2026-01-20 16:07:56 +00:00
ciaranbor
4a47448fbf Implement Qwen pipefusion 2026-01-20 16:07:56 +00:00
ciaranbor
b381d4b518 Add guidance_scale parameter to image model config 2026-01-20 16:07:56 +00:00
ciaranbor
8be71107fd Move orchestration to DiffusionRunner 2026-01-20 16:07:56 +00:00
ciaranbor
c67356f4dd Add initial QwenModelAdapter 2026-01-20 16:07:56 +00:00
ciaranbor
65e2a24d05 Tweak embeddings interface 2026-01-20 16:07:56 +00:00
ciaranbor
4093078730 Add Qwen ImageModelConfig 2026-01-20 16:07:56 +00:00
ciaranbor
76cbcabcab Use 10% sync steps 2026-01-20 16:07:56 +00:00
ciaranbor
36944d40ae Update FluxModelAdaper for new interface 2026-01-20 16:07:56 +00:00
ciaranbor
85f9bb4b27 Register QwenModelAdapter 2026-01-20 16:07:56 +00:00
ciaranbor
2e590c18dd Support multiple forward passes in runner 2026-01-20 16:07:56 +00:00
ciaranbor
4e881a6c21 Extend block wrapper parameters 2026-01-20 16:07:56 +00:00
ciaranbor
af8d373411 Relax adaptor typing 2026-01-20 16:07:56 +00:00
ciaranbor
55eb7620bf Add Qwen-Image model card 2026-01-20 16:07:56 +00:00
ciaranbor
0fb299db17 Clean up dead code 2026-01-20 16:07:56 +00:00
ciaranbor
a7870e4904 Add BaseModelAdaptor 2026-01-20 16:07:56 +00:00
ciaranbor
95c7b26178 Refactor filestructure 2026-01-20 16:07:56 +00:00
ciaranbor
8747a74b32 Treat unified blocks as single blocks (equivalent) 2026-01-20 16:07:56 +00:00
ciaranbor
6c9e582ffa Refactor to handle entire denoising process in Diffusion runner 2026-01-20 16:07:56 +00:00
ciaranbor
ae07a76d99 Move transformer to adapter 2026-01-20 16:07:56 +00:00
ciaranbor
c41f7adfdc Move some more logic to adaptor 2026-01-20 16:07:56 +00:00
ciaranbor
5d38221731 Add generic block wrapper 2026-01-20 16:07:56 +00:00
ciaranbor
a1d7eb61b6 Access transformer blocks from adaptor 2026-01-20 16:07:56 +00:00
ciaranbor
dc51132bb7 Better typing 2026-01-20 16:07:56 +00:00
ciaranbor
5c9227ce42 Create wrappers at init time 2026-01-20 16:07:56 +00:00
ciaranbor
e0cd04c5f3 Combine model factory and adaptor 2026-01-20 16:07:56 +00:00
ciaranbor
c2aab343c4 Implement model factory 2026-01-20 16:07:56 +00:00
ciaranbor
3bcdd46bb1 Add adaptor registry 2026-01-20 16:07:56 +00:00
ciaranbor
46181a35ae Remove mflux/generator/generate.py 2026-01-20 16:07:56 +00:00
ciaranbor
e29d0b4a0e Switch to using DistributedImageModel 2026-01-20 16:07:56 +00:00
ciaranbor
633147cb02 Add DistributedImageModel 2026-01-20 16:07:56 +00:00
ciaranbor
3f5b4a43db Use new generic wrappers, etc in denoising 2026-01-20 16:07:56 +00:00
ciaranbor
9ee7a3e92b Add generic transformer block wrappers 2026-01-20 16:07:56 +00:00
ciaranbor
3e40b2beb5 Add FluxAdaptor 2026-01-20 16:07:56 +00:00
ciaranbor
4d6f339e6f Add ModelAdaptor, derivations implement model specific logic 2026-01-20 16:07:56 +00:00
ciaranbor
282b63effb Introduce image model config concept 2026-01-20 16:07:56 +00:00
ciaranbor
cb537b0110 Consolidate kv cache patching 2026-01-20 16:07:56 +00:00
ciaranbor
afc1643eb2 Support different configuration comms 2026-01-20 16:07:56 +00:00
ciaranbor
02ca7d5a4b Add ImageGenerator protocol 2026-01-20 16:07:56 +00:00
ciaranbor
73abda4f17 Force final patch receive order 2026-01-20 16:07:56 +00:00
ciaranbor
f492fd4be8 Remove logs 2026-01-20 16:07:56 +00:00
ciaranbor
e59b51788b Update patch list 2026-01-20 16:07:56 +00:00
ciaranbor
f1b08bdf68 Slight refactor 2026-01-20 16:07:56 +00:00
ciaranbor
1e74c5ec4f Don't need array for prev patches 2026-01-20 16:07:56 +00:00
ciaranbor
839e845876 Fix send/recv order 2026-01-20 16:07:56 +00:00
ciaranbor
8ab312de44 Fix async single transformer block 2026-01-20 16:07:56 +00:00
ciaranbor
9e0d646505 Use relative rank variables 2026-01-20 16:07:56 +00:00
ciaranbor
6a21dca2e0 Fix writing patches 2026-01-20 16:07:56 +00:00
ciaranbor
b491607a8f Collect final image 2026-01-20 16:07:56 +00:00
ciaranbor
258754a5e8 Fix recv_template shape 2026-01-20 16:07:56 +00:00
ciaranbor
09ec079be8 Add logs 2026-01-20 16:07:56 +00:00
ciaranbor
3da22204db Optimise async pipeline 2026-01-20 16:07:56 +00:00
ciaranbor
60d7ea6265 Add next_rank and prev_rank members 2026-01-20 16:07:56 +00:00
ciaranbor
f37751b31f Add _create_patches method 2026-01-20 16:07:56 +00:00
ciaranbor
0f357e1f9b Fix shapes 2026-01-20 16:07:56 +00:00
ciaranbor
9b0a621987 Reorder comms 2026-01-20 16:07:56 +00:00
ciaranbor
5574eb57e5 Remove all_gather from sync pipeline, send from final rank to first rank 2026-01-20 16:07:56 +00:00
ciaranbor
ddc67f09cc Simplify kv_cache initialization 2026-01-20 16:07:56 +00:00
ciaranbor
27b343316b Fix kv cache 2026-01-20 16:07:56 +00:00
ciaranbor
99705175ee Clean up kv caches 2026-01-20 16:07:56 +00:00
ciaranbor
8aad72b4d7 Fix return 2026-01-20 16:07:56 +00:00
ciaranbor
2ce3833b17 Fix hidden_states shapes 2026-01-20 16:07:56 +00:00
ciaranbor
ac535d5725 Only perform projection and scheduler step on last rank 2026-01-20 16:07:56 +00:00
ciaranbor
9dae2eafd4 Only compute embeddings on rank 0 2026-01-20 16:07:56 +00:00
ciaranbor
d082db113d Remove eval 2026-01-20 16:07:56 +00:00
ciaranbor
48298104a8 Remove eval 2026-01-20 16:07:56 +00:00
ciaranbor
49823aa6d5 Only send encoder_hidden_states with the first patch (once per timestep) 2026-01-20 16:07:56 +00:00
ciaranbor
76eb5171f4 Remove redundant text kv cache computation 2026-01-20 16:07:56 +00:00
ciaranbor
7fa8ebacad Concatenate before all gather 2026-01-20 16:07:56 +00:00
ciaranbor
cebd3de003 Increase number of sync steps 2026-01-20 16:07:56 +00:00
ciaranbor
e6cd1291b9 Reinitialise kv_caches between generations 2026-01-20 16:07:56 +00:00
ciaranbor
52d9e77bed Eliminate double kv cache computation 2026-01-20 16:07:56 +00:00
ciaranbor
3d41745b51 Add kv cache caching wrappers for sync pipeline transformer blocks 2026-01-20 16:07:56 +00:00
ciaranbor
c5c3f43c6c Persist kv caches 2026-01-20 16:07:56 +00:00
ciaranbor
8ac0686125 Implement naive async pipeline implementation 2026-01-20 16:07:56 +00:00
ciaranbor
4d8a759eb2 Use wrapper classes for patched transformer logic 2026-01-20 16:07:56 +00:00
ciaranbor
5d7c0847c1 Add patch-aware joint and single attention wrappers 2026-01-20 16:07:56 +00:00
ciaranbor
c4e916fc02 Fix group.size() 2026-01-20 16:07:56 +00:00
ciaranbor
06ce520f4d Add classes to manage kv caches with patch support 2026-01-20 16:07:56 +00:00
ciaranbor
1083808072 Use heuristic for number of sync steps 2026-01-20 16:07:56 +00:00
ciaranbor
686fa6e04c Generalise number of denoising steps 2026-01-20 16:07:56 +00:00
ciaranbor
fb0e64f4c0 Add flux1-dev 2026-01-20 16:07:56 +00:00
ciaranbor
11f2a5bda5 Move scheduler step to inner pipeline 2026-01-20 16:07:56 +00:00
ciaranbor
2b547a18f8 Add barrier before all_gather 2026-01-20 16:07:56 +00:00
ciaranbor
c4d558d550 Fix transformer blocks pruning 2026-01-20 16:07:56 +00:00
ciaranbor
5d5a1aa561 Fix image generation api 2026-01-20 16:07:56 +00:00
ciaranbor
33d34e376b Create queue in try block 2026-01-20 16:07:56 +00:00
ciaranbor
fa1d6de18c Conform to rebase 2026-01-20 16:07:56 +00:00
ciaranbor
bdd3863d2d Refactor denoising 2026-01-20 16:07:56 +00:00
ciaranbor
48b687af45 Move more logic to DistributedFlux 2026-01-20 16:07:56 +00:00
ciaranbor
a91e0f55ca Move surrounding logic back to _sync_pipeline 2026-01-20 16:07:56 +00:00
ciaranbor
35ef128f04 Add patching aware member variables 2026-01-20 16:07:56 +00:00
ciaranbor
34d3b84405 Implement sync/async switching logic 2026-01-20 16:07:56 +00:00
ciaranbor
7a4b9e7884 Move current transformer implementation to _sync_pipeline method 2026-01-20 16:07:56 +00:00
ciaranbor
da251c0c47 Remove some logs 2026-01-20 16:07:56 +00:00
ciaranbor
87b6940580 Remove old Flux1 implementation 2026-01-20 16:07:56 +00:00
ciaranbor
630af5b0f1 Prune unused transformer blocks 2026-01-20 16:07:56 +00:00
ciaranbor
2ac27234fb Add mx.eval 2026-01-20 16:07:56 +00:00
ciaranbor
1ad090b4f4 Test evals 2026-01-20 16:07:56 +00:00
ciaranbor
4f746c1575 Test only barriers 2026-01-20 16:07:56 +00:00
ciaranbor
e1921b24a0 All perform final projection 2026-01-20 16:07:56 +00:00
ciaranbor
c76da7c220 Another barrier 2026-01-20 16:07:56 +00:00
ciaranbor
af27feedc9 More debug 2026-01-20 16:07:56 +00:00
ciaranbor
f088bbad92 Add barriers 2026-01-20 16:07:56 +00:00
ciaranbor
3f1bfc6ee1 Add log 2026-01-20 16:07:56 +00:00
ciaranbor
eaa450e2db Restore distributed logging 2026-01-20 16:07:56 +00:00
ciaranbor
8964137f2b Use bootstrap logger 2026-01-20 16:07:56 +00:00
ciaranbor
984084fe8e Remove logs 2026-01-20 16:07:56 +00:00
ciaranbor
dd2d25951d fix single block receive shape 2026-01-20 16:07:56 +00:00
ciaranbor
9f5f763993 Add debug logs 2026-01-20 16:07:56 +00:00
ciaranbor
244f3a1bb4 Move communication logic to DistributedTransformer wrapper 2026-01-20 16:07:56 +00:00
ciaranbor
efa419b36b Move inference logic to DistribuedFlux1 2026-01-20 16:07:56 +00:00
ciaranbor
ec035fab4e Add DistributedFlux1 class 2026-01-20 16:07:56 +00:00
ciaranbor
48cef402f1 Rename pipeline to pipefusion 2026-01-20 16:07:56 +00:00
ciaranbor
ec68ad19a5 Further refactor 2026-01-20 16:07:56 +00:00
ciaranbor
a2f0da26b6 Refactor warmup 2026-01-20 16:07:56 +00:00
ciaranbor
49ac2e1aeb Manually handle flux1 inference 2026-01-20 16:07:56 +00:00
ciaranbor
78d9c5264d Refactor flux1 image generation 2026-01-20 16:07:56 +00:00
ciaranbor
f5f19415da Use quality parameter to set number of inference steps 2026-01-20 16:07:56 +00:00
ciaranbor
e8c4293f1e Chunk image data transfer 2026-01-20 16:07:56 +00:00
ciaranbor
f278464ff9 Define EXO_MAX_CHUNK_SIZE 2026-01-20 16:07:56 +00:00
ciaranbor
ce06bbb95e Add indexing info to ImageChunk 2026-01-20 16:07:56 +00:00
ciaranbor
86c4d24e23 Remove sharding logs 2026-01-20 16:07:56 +00:00
ciaranbor
551daa2c06 Temp: reduce flux1.schnell storage size 2026-01-20 16:07:55 +00:00
ciaranbor
6ca3042f8e Fix mflux transformer all_gather 2026-01-20 16:07:55 +00:00
ciaranbor
62b653299f Fix world size 2026-01-20 16:07:55 +00:00
ciaranbor
2b4e931873 Fix transition block? 2026-01-20 16:07:55 +00:00
ciaranbor
c67b684552 Implement image generation warmup 2026-01-20 16:07:55 +00:00
ciaranbor
926a157476 Add logs 2026-01-20 16:07:55 +00:00
ciaranbor
d9586028d1 Add spiece.model to default patterns 2026-01-20 16:07:55 +00:00
ciaranbor
e971744960 Just download all files for now 2026-01-20 16:07:55 +00:00
ciaranbor
b2abece9aa Fix get_allow_patterns to include non-indexed safetensors files 2026-01-20 16:07:55 +00:00
ciaranbor
1d5bd3e447 Use half-open layer indexing in get_allow_patterns 2026-01-20 16:07:55 +00:00
ciaranbor
24668c54fc Enable distributed mflux 2026-01-20 16:07:55 +00:00
ciaranbor
aff75d9992 Implement mflux transformer sharding and communication pattern 2026-01-20 16:07:55 +00:00
ciaranbor
49cb8422a1 Update get_allow_patterns to handle sharding components 2026-01-20 16:07:55 +00:00
ciaranbor
19566deffd Namespace both keys and values for component weight maps 2026-01-20 16:07:55 +00:00
ciaranbor
d046475ec4 Add components to Flux.1-schnell MODEL_CARD 2026-01-20 16:07:55 +00:00
ciaranbor
780e0ea425 Add component concept for ModelMetadata 2026-01-20 16:07:55 +00:00
ciaranbor
d301694f9c Fix multiple components weight map key conflicts 2026-01-20 16:07:55 +00:00
ciaranbor
407c0b2418 get_weight_map: handle repos with multiple safetensors.index.json files 2026-01-20 16:07:55 +00:00
ciaranbor
3559e7f43b Add initial image edits spec 2026-01-20 16:07:55 +00:00
ciaranbor
938cf08a5f Add image edits endpoint 2026-01-20 16:07:55 +00:00
ciaranbor
3bc016cec4 Add ImageToImage task 2026-01-20 16:07:55 +00:00
ciaranbor
21a1dc8c55 Allow ModelCards to have multiple tasks 2026-01-20 16:07:55 +00:00
ciaranbor
1ea271788f Fix text generation 2026-01-20 16:07:55 +00:00
ciaranbor
3a71354b1c Rename mlx_generate_image to mflux_generate 2026-01-20 16:07:55 +00:00
ciaranbor
d7b423db0e Initialize mlx or mflux engine based on model task 2026-01-20 16:07:55 +00:00
ciaranbor
b940578e57 Restore warmup for text generation 2026-01-20 16:07:55 +00:00
ciaranbor
10dda7936d Add initialize_mflux function 2026-01-20 16:07:55 +00:00
ciaranbor
addc3ffe15 Move image generation to mflux engine 2026-01-20 16:07:55 +00:00
ciaranbor
27c4f60a91 Just use str for image generation size 2026-01-20 16:07:55 +00:00
ciaranbor
b9c34cdcc4 Use MFlux for image generation 2026-01-20 16:07:55 +00:00
ciaranbor
fd5835271f Add get_model_card function 2026-01-20 16:07:55 +00:00
ciaranbor
77d5c9b38f Add ModelTask enum 2026-01-20 16:07:55 +00:00
ciaranbor
051577d122 ADd flux1-schnell model 2026-01-20 16:07:55 +00:00
ciaranbor
013f66956a Add task field to ModelCard 2026-01-20 16:07:55 +00:00
ciaranbor
eead745cd0 Update mflux version 2026-01-20 16:07:55 +00:00
ciaranbor
6895ea935b Enable recursive repo downloads 2026-01-20 16:07:55 +00:00
ciaranbor
8451cb7d16 Add dummy generate_image implementation 2026-01-20 16:07:55 +00:00
ciaranbor
3c8f5f3464 Use base64 encoded str for image data 2026-01-20 16:07:55 +00:00
ciaranbor
1bf4b42830 Handle ImageGeneration tasks in _pending_tasks 2026-01-20 16:07:55 +00:00
ciaranbor
5ba044d0f8 Add mflux dependency 2026-01-20 16:07:55 +00:00
ciaranbor
c33fd5c9ae Handle ImageGeneration task in runner task processing 2026-01-20 16:07:55 +00:00
ciaranbor
0e789435d4 Handle ImageGeneration command in master command processing 2026-01-20 16:07:55 +00:00
ciaranbor
0ca2e449b9 Add image generation to API 2026-01-20 16:07:55 +00:00
ciaranbor
c99cd5fc87 Add ImageGenerationResponse 2026-01-20 16:07:55 +00:00
ciaranbor
e86caf8f48 Add ImageGeneration task 2026-01-20 16:07:55 +00:00
ciaranbor
2d483a5a77 Add ImageGeneration command 2026-01-20 16:07:55 +00:00
ciaranbor
01db2adc80 Add image generation params and response types 2026-01-20 16:07:55 +00:00
ciaranbor
c10f87b006 Add pillow dependency 2026-01-20 16:07:55 +00:00
ciaranbor
0c6d04e085 Fix mlx stream_generate import 2026-01-20 16:07:55 +00:00
55 changed files with 9894 additions and 2456 deletions
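The commit series above adds image generation and image-editing endpoints (see "Add image generation to API", "Add image edits endpoint", "Document image apis") and wires them into the dashboard UI shown in the diffs below. As a rough, non-authoritative sketch of how a client might call the new generation endpoint — the route, request fields, and response shape here are assumptions modelled on an OpenAI-style images API and are not confirmed by this diff:

// Hypothetical usage sketch (not part of the diff): assumes an OpenAI-style
// images route and a base64 response, as suggested by the commits
// "Add image generation to API" and "Use base64 encoded str for image data".
// The path, field names, and response shape are assumptions.
async function generateImage(prompt: string): Promise<string> {
  const res = await fetch("/v1/images/generations", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      prompt,
      size: "1024x1024",   // matches the size options in ImageParamsPanel
      quality: "medium",
    }),
  });
  if (!res.ok) throw new Error(`Image generation failed: ${res.status}`);
  const data = await res.json();
  // Assumed response shape: { data: [{ b64_json: "..." }] }
  return `data:image/png;base64,${data.data[0].b64_json}`;
}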

View File

@@ -863,6 +863,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -902,6 +903,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1518,6 +1520,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1527,6 +1530,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1939,6 +1943,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2646,6 +2651,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2833,6 +2839,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2977,6 +2984,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -2998,6 +3006,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -1,8 +1,22 @@
<script lang="ts">
import { isLoading, sendMessage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import ChatAttachments from './ChatAttachments.svelte';
import type { ChatUploadedFile } from '$lib/types/files';
import { processUploadedFiles, getAcceptString } from '$lib/types/files';
import {
isLoading,
sendMessage,
generateImage,
editImage,
editingImage,
clearEditingImage,
selectedChatModel,
setSelectedChatModel,
instances,
ttftMs,
tps,
totalTokens,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
import type { ChatUploadedFile } from "$lib/types/files";
import { processUploadedFiles, getAcceptString } from "$lib/types/files";
interface Props {
class?: string;
@@ -10,17 +24,19 @@
showHelperText?: boolean;
autofocus?: boolean;
showModelSelector?: boolean;
modelTasks?: Record<string, string[]>;
}
let {
class: className = '',
placeholder = 'Ask anything',
let {
class: className = "",
placeholder = "Ask anything",
showHelperText = false,
autofocus = true,
showModelSelector = false
showModelSelector = false,
modelTasks = {},
}: Props = $props();
let message = $state('');
let message = $state("");
let textareaRef: HTMLTextAreaElement | undefined = $state();
let fileInputRef: HTMLInputElement | undefined = $state();
let uploadedFiles = $state<ChatUploadedFile[]>([]);
@@ -31,30 +47,82 @@
const currentTtft = $derived(ttftMs());
const currentTps = $derived(tps());
const currentTokens = $derived(totalTokens());
const currentEditingImage = $derived(editingImage());
const isEditMode = $derived(currentEditingImage !== null);
// Custom dropdown state
let isModelDropdownOpen = $state(false);
let dropdownButtonRef: HTMLButtonElement | undefined = $state();
let dropdownPosition = $derived(() => {
if (!dropdownButtonRef || !isModelDropdownOpen) return { top: 0, left: 0, width: 0 };
if (!dropdownButtonRef || !isModelDropdownOpen)
return { top: 0, left: 0, width: 0 };
const rect = dropdownButtonRef.getBoundingClientRect();
return {
top: rect.top,
left: rect.left,
width: rect.width
width: rect.width,
};
});
// Accept all supported file types
const acceptString = getAcceptString(['image', 'text', 'pdf']);
const acceptString = getAcceptString(["image", "text", "pdf"]);
function modelSupportsImageGeneration(modelId: string): boolean {
const tasks = modelTasks[modelId] || [];
return tasks.includes("TextToImage") || tasks.includes("ImageToImage");
}
function modelSupportsTextToImage(modelId: string): boolean {
const tasks = modelTasks[modelId] || [];
return tasks.includes("TextToImage");
}
function modelSupportsOnlyImageEditing(modelId: string): boolean {
const tasks = modelTasks[modelId] || [];
return tasks.includes("ImageToImage") && !tasks.includes("TextToImage");
}
function modelSupportsImageEditing(modelId: string): boolean {
const tasks = modelTasks[modelId] || [];
return tasks.includes("ImageToImage");
}
const isImageModel = $derived(() => {
if (!currentModel) return false;
return modelSupportsTextToImage(currentModel);
});
const isEditOnlyWithoutImage = $derived(
currentModel !== null &&
modelSupportsOnlyImageEditing(currentModel) &&
!isEditMode &&
uploadedFiles.length === 0,
);
// Show edit mode when: explicit edit mode OR (model supports ImageToImage AND files attached)
const shouldShowEditMode = $derived(
isEditMode ||
(currentModel &&
modelSupportsImageEditing(currentModel) &&
uploadedFiles.length > 0),
);
// Extract available models from running instances
const availableModels = $derived(() => {
const models: Array<{id: string, label: string}> = [];
const models: Array<{ id: string; label: string; isImageModel: boolean }> =
[];
for (const [, instance] of Object.entries(instanceData)) {
const modelId = getInstanceModelId(instance);
if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
models.push({ id: modelId, label: modelId.split('/').pop() || modelId });
if (
modelId &&
modelId !== "Unknown" &&
!models.some((m) => m.id === modelId)
) {
models.push({
id: modelId,
label: modelId.split("/").pop() || modelId,
isImageModel: modelSupportsImageGeneration(modelId),
});
}
}
return models;
@@ -66,18 +134,18 @@
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
$effect(() => {
const models = availableModels();
const currentModelIds = new Set(models.map(m => m.id));
const currentModelIds = new Set(models.map((m) => m.id));
if (models.length > 0) {
// Find newly added models (in current but not in previous)
const newModels = models.filter(m => !previousModelIds.has(m.id));
const newModels = models.filter((m) => !previousModelIds.has(m.id));
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
else if (!models.some((m) => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
@@ -87,7 +155,7 @@
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel('');
setSelectedChatModel("");
}
}
@@ -96,13 +164,15 @@
});
function getInstanceModelId(instanceWrapped: unknown): string {
if (!instanceWrapped || typeof instanceWrapped !== 'object') return '';
if (!instanceWrapped || typeof instanceWrapped !== "object") return "";
const keys = Object.keys(instanceWrapped as Record<string, unknown>);
if (keys.length === 1) {
const instance = (instanceWrapped as Record<string, unknown>)[keys[0]] as { shardAssignments?: { modelId?: string } };
return instance?.shardAssignments?.modelId || '';
const instance = (instanceWrapped as Record<string, unknown>)[
keys[0]
] as { shardAssignments?: { modelId?: string } };
return instance?.shardAssignments?.modelId || "";
}
return '';
return "";
}
async function handleFiles(files: File[]) {
@@ -115,33 +185,35 @@
const input = event.target as HTMLInputElement;
if (input.files && input.files.length > 0) {
handleFiles(Array.from(input.files));
input.value = ''; // Reset for next selection
input.value = ""; // Reset for next selection
}
}
function handleFileRemove(fileId: string) {
uploadedFiles = uploadedFiles.filter(f => f.id !== fileId);
uploadedFiles = uploadedFiles.filter((f) => f.id !== fileId);
}
function handlePaste(event: ClipboardEvent) {
if (!event.clipboardData) return;
const files = Array.from(event.clipboardData.items)
.filter(item => item.kind === 'file')
.map(item => item.getAsFile())
.filter((item) => item.kind === "file")
.map((item) => item.getAsFile())
.filter((file): file is File => file !== null);
if (files.length > 0) {
event.preventDefault();
handleFiles(files);
return;
}
// Handle long text paste as file
const text = event.clipboardData.getData('text/plain');
const text = event.clipboardData.getData("text/plain");
if (text.length > 2500) {
event.preventDefault();
const textFile = new File([text], 'pasted-text.txt', { type: 'text/plain' });
const textFile = new File([text], "pasted-text.txt", {
type: "text/plain",
});
handleFiles([textFile]);
}
}
@@ -159,7 +231,7 @@
function handleDrop(event: DragEvent) {
event.preventDefault();
isDragOver = false;
if (event.dataTransfer?.files) {
handleFiles(Array.from(event.dataTransfer.files));
}
@@ -170,8 +242,8 @@
if (event.isComposing || event.keyCode === 229) {
return;
}
if (event.key === 'Enter' && !event.shiftKey) {
if (event.key === "Enter" && !event.shiftKey) {
event.preventDefault();
handleSubmit();
}
@@ -179,29 +251,50 @@
function handleSubmit() {
if ((!message.trim() && uploadedFiles.length === 0) || loading) return;
const content = message.trim();
const files = [...uploadedFiles];
message = '';
message = "";
uploadedFiles = [];
resetTextareaHeight();
sendMessage(content, files);
// Use image editing if in edit mode
if (isEditMode && currentEditingImage && content) {
editImage(content, currentEditingImage.imageDataUrl);
}
// If user attached an image with an ImageToImage model, use edit endpoint
else if (
currentModel &&
modelSupportsImageEditing(currentModel) &&
files.length > 0 &&
content
) {
// Use the first attached image for editing
const imageFile = files[0];
if (imageFile.preview) {
editImage(content, imageFile.preview);
}
} else if (isImageModel() && content) {
// Use image generation for text-to-image models
generateImage(content);
} else {
sendMessage(content, files);
}
// Refocus the textarea after sending
setTimeout(() => textareaRef?.focus(), 10);
}
function handleInput() {
if (!textareaRef) return;
textareaRef.style.height = 'auto';
textareaRef.style.height = Math.min(textareaRef.scrollHeight, 150) + 'px';
textareaRef.style.height = "auto";
textareaRef.style.height = Math.min(textareaRef.scrollHeight, 150) + "px";
}
function resetTextareaHeight() {
if (textareaRef) {
textareaRef.style.height = 'auto';
textareaRef.style.height = "auto";
}
}
@@ -211,13 +304,13 @@
// Track previous loading state to detect when loading completes
let wasLoading = $state(false);
$effect(() => {
if (autofocus && textareaRef) {
setTimeout(() => textareaRef?.focus(), 10);
}
});
// Refocus after loading completes (AI response finished)
$effect(() => {
if (wasLoading && !loading && textareaRef) {
@@ -226,7 +319,9 @@
wasLoading = loading;
});
const canSend = $derived(message.trim().length > 0 || uploadedFiles.length > 0);
const canSend = $derived(
message.trim().length > 0 || uploadedFiles.length > 0,
);
</script>
<!-- Hidden file input -->
@@ -239,69 +334,132 @@
onchange={handleFileInputChange}
/>
<form
onsubmit={(e) => { e.preventDefault(); handleSubmit(); }}
<form
onsubmit={(e) => {
e.preventDefault();
handleSubmit();
}}
class="w-full {className}"
ondragover={handleDragOver}
ondragleave={handleDragLeave}
ondrop={handleDrop}
>
<div
class="relative command-panel rounded overflow-hidden transition-all duration-200 {isDragOver ? 'ring-2 ring-exo-yellow ring-opacity-50' : ''}"
<div
class="relative command-panel rounded overflow-hidden transition-all duration-200 {isDragOver
? 'ring-2 ring-exo-yellow ring-opacity-50'
: ''}"
>
<!-- Top accent line -->
<div class="absolute top-0 left-0 right-0 h-px bg-gradient-to-r from-transparent via-exo-yellow/50 to-transparent"></div>
<div
class="absolute top-0 left-0 right-0 h-px bg-gradient-to-r from-transparent via-exo-yellow/50 to-transparent"
></div>
<!-- Drag overlay -->
{#if isDragOver}
<div class="absolute inset-0 bg-exo-dark-gray/80 z-10 flex items-center justify-center">
<div
class="absolute inset-0 bg-exo-dark-gray/80 z-10 flex items-center justify-center"
>
<div class="text-exo-yellow text-sm font-mono tracking-wider uppercase">
DROP FILES HERE
</div>
</div>
{/if}
<!-- Edit mode banner -->
{#if isEditMode && currentEditingImage}
<div
class="flex items-center gap-3 px-3 py-2 bg-exo-yellow/10 border-b border-exo-yellow/30"
>
<img
src={currentEditingImage.imageDataUrl}
alt="Source for editing"
class="w-10 h-10 object-cover rounded border border-exo-yellow/30"
/>
<div class="flex-1">
<span
class="text-xs font-mono tracking-wider uppercase text-exo-yellow"
>EDITING IMAGE</span
>
</div>
<button
type="button"
onclick={() => clearEditingImage()}
class="px-2 py-1 text-xs font-mono tracking-wider uppercase bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50 rounded hover:bg-exo-medium-gray/50 hover:text-exo-yellow transition-colors cursor-pointer"
>
CANCEL
</button>
</div>
{/if}
<!-- Model selector (when enabled) -->
{#if showModelSelector && availableModels().length > 0}
<div class="flex items-center justify-between gap-2 px-3 py-2 border-b border-exo-medium-gray/30">
<div
class="flex items-center justify-between gap-2 px-3 py-2 border-b border-exo-medium-gray/30"
>
<div class="flex items-center gap-2 flex-1">
<span class="text-xs text-exo-light-gray uppercase tracking-wider flex-shrink-0">MODEL:</span>
<span
class="text-xs text-exo-light-gray uppercase tracking-wider flex-shrink-0"
>MODEL:</span
>
<!-- Custom dropdown -->
<div class="relative flex-1 max-w-xs">
<button
bind:this={dropdownButtonRef}
type="button"
onclick={() => isModelDropdownOpen = !isModelDropdownOpen}
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-1.5 text-xs font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isModelDropdownOpen ? 'border-exo-yellow/70' : ''}"
onclick={() => (isModelDropdownOpen = !isModelDropdownOpen)}
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-1.5 text-xs font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isModelDropdownOpen
? 'border-exo-yellow/70'
: ''}"
>
{#if availableModels().find(m => m.id === currentModel)}
<span class="text-exo-yellow truncate">{availableModels().find(m => m.id === currentModel)?.label}</span>
{#if availableModels().find((m) => m.id === currentModel)}
<span class="text-exo-yellow truncate"
>{availableModels().find((m) => m.id === currentModel)
?.label}</span
>
{:else if availableModels().length > 0}
<span class="text-exo-yellow truncate">{availableModels()[0].label}</span>
<span class="text-exo-yellow truncate"
>{availableModels()[0].label}</span
>
{:else}
<span class="text-exo-light-gray/50">— SELECT MODEL —</span>
{/if}
</button>
<div class="absolute right-2 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isModelDropdownOpen ? 'rotate-180' : ''}">
<svg class="w-3 h-3 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
<div
class="absolute right-2 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isModelDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{#if isModelDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => isModelDropdownOpen = false}
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isModelDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto"
style="bottom: calc(100vh - {dropdownPosition().top}px + 4px); left: {dropdownPosition().left}px; width: {dropdownPosition().width}px;"
style="bottom: calc(100vh - {dropdownPosition()
.top}px + 4px); left: {dropdownPosition()
.left}px; width: {dropdownPosition().width}px;"
>
<div class="py-1">
{#each availableModels() as model}
@@ -311,20 +469,48 @@
setSelectedChatModel(model.id);
isModelDropdownOpen = false;
}}
class="w-full px-3 py-2 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {
currentModel === model.id
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'
}"
class="w-full px-3 py-2 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {currentModel ===
model.id
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if currentModel === model.id}
<svg class="w-3 h-3 flex-shrink-0" fill="currentColor" viewBox="0 0 20 20">
<path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd" />
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span class="truncate">{model.label}</span>
{#if model.isImageModel}
<svg
class="w-3.5 h-3.5 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
aria-label="Image generation model"
>
<rect
x="3"
y="3"
width="18"
height="18"
rx="2"
ry="2"
/>
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
<span class="truncate flex-1">{model.label}</span>
</button>
{/each}
</div>
@@ -336,30 +522,37 @@
<div class="flex items-center gap-4 text-xs font-mono flex-shrink-0">
{#if currentTtft !== null}
<span class="text-exo-light-gray">
<span class="text-white/70">TTFT</span> <span class="text-exo-yellow">{currentTtft.toFixed(1)}ms</span>
<span class="text-white/70">TTFT</span>
<span class="text-exo-yellow">{currentTtft.toFixed(1)}ms</span>
</span>
{/if}
{#if currentTps !== null}
<span class="text-exo-light-gray">
<span class="text-white/70">TPS</span> <span class="text-exo-yellow">{currentTps.toFixed(1)}</span> <span class="text-white/60">tok/s</span>
<span class="text-white/50">({(1000 / currentTps).toFixed(1)} ms/tok)</span>
<span class="text-white/70">TPS</span>
<span class="text-exo-yellow">{currentTps.toFixed(1)}</span>
<span class="text-white/60">tok/s</span>
<span class="text-white/50"
>({(1000 / currentTps).toFixed(1)} ms/tok)</span
>
</span>
{/if}
</div>
{/if}
</div>
{/if}
<!-- Image params panel (shown for image models or edit mode) -->
{#if showModelSelector && (isImageModel() || isEditMode)}
<ImageParamsPanel {isEditMode} />
{/if}
<!-- Attached files preview -->
{#if uploadedFiles.length > 0}
<div class="px-3 pt-3">
<ChatAttachments
files={uploadedFiles}
onRemove={handleFileRemove}
/>
<ChatAttachments files={uploadedFiles} onRemove={handleFileRemove} />
</div>
{/if}
<!-- Input area -->
<div class="flex items-start gap-2 sm:gap-3 py-3 px-3 sm:px-4">
<!-- Attach file button -->
@@ -370,58 +563,130 @@
class="flex items-center justify-center w-7 h-7 rounded text-exo-light-gray hover:text-exo-yellow transition-all disabled:opacity-50 disabled:cursor-not-allowed flex-shrink-0 cursor-pointer"
title="Attach file"
>
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M15.172 7l-6.586 6.586a2 2 0 102.828 2.828l6.414-6.586a4 4 0 00-5.656-5.656l-6.415 6.585a6 6 0 108.486 8.486L20.5 13" />
<svg
class="w-4 h-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M15.172 7l-6.586 6.586a2 2 0 102.828 2.828l6.414-6.586a4 4 0 00-5.656-5.656l-6.415 6.585a6 6 0 108.486 8.486L20.5 13"
/>
</svg>
</button>
<!-- Terminal prompt -->
<span class="text-exo-yellow text-sm font-bold flex-shrink-0 leading-7"></span>
<span class="text-exo-yellow text-sm font-bold flex-shrink-0 leading-7"
></span
>
<textarea
bind:this={textareaRef}
bind:value={message}
onkeydown={handleKeydown}
oninput={handleInput}
onpaste={handlePaste}
{placeholder}
placeholder={isEditOnlyWithoutImage
? "Attach an image to edit..."
: isEditMode
? "Describe how to edit this image..."
: isImageModel()
? "Describe the image you want to generate..."
: placeholder}
disabled={loading}
rows={1}
class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"
style="min-height: 28px; max-height: 150px;"
></textarea>
<button
type="submit"
disabled={!canSend || loading}
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label="Send message"
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<span class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"></span>
<span class="hidden sm:inline">PROCESSING</span>
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<span class="sm:hidden">...</span>
</span>
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</div>
<!-- Bottom accent line -->
<div class="absolute bottom-0 left-0 right-0 h-px bg-gradient-to-r from-transparent via-exo-yellow/30 to-transparent"></div>
<div
class="absolute bottom-0 left-0 right-0 h-px bg-gradient-to-r from-transparent via-exo-yellow/30 to-transparent"
></div>
</div>
{#if showHelperText}
<p class="mt-2 sm:mt-3 text-center text-xs sm:text-xs text-exo-light-gray tracking-[0.1em] sm:tracking-[0.15em] uppercase">
<kbd class="px-1 sm:px-1.5 py-0.5 rounded bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50">ENTER</kbd>
<p
class="mt-2 sm:mt-3 text-center text-xs sm:text-xs text-exo-light-gray tracking-[0.1em] sm:tracking-[0.15em] uppercase"
>
<kbd
class="px-1 sm:px-1.5 py-0.5 rounded bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50"
>ENTER</kbd
>
<span class="mx-0.5 sm:mx-1">TO SEND</span>
<span class="text-exo-medium-gray mx-1 sm:mx-2">|</span>
<kbd class="px-1 sm:px-1.5 py-0.5 rounded bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50">SHIFT+ENTER</kbd>
<kbd
class="px-1 sm:px-1.5 py-0.5 rounded bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50"
>SHIFT+ENTER</kbd
>
<span class="mx-0.5 sm:mx-1">NEW LINE</span>
<span class="text-exo-medium-gray mx-1 sm:mx-2">|</span>
<span class="text-exo-light-gray">DRAG & DROP OR PASTE FILES</span>

View File

@@ -1,12 +1,14 @@
<script lang="ts">
import {
messages,
currentResponse,
import {
messages,
currentResponse,
isLoading,
deleteMessage,
editAndRegenerate,
regenerateLastResponse
regenerateLastResponse,
setEditingImage
} from '$lib/stores/app.svelte';
import type { Message } from '$lib/stores/app.svelte';
import type { MessageAttachment } from '$lib/stores/app.svelte';
import MarkdownContent from './MarkdownContent.svelte';
@@ -365,10 +367,76 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
</div>
{/if}
<!-- Generated Images -->
{#if message.attachments?.some(a => a.type === 'generated-image')}
<div class="mb-3">
{#each message.attachments.filter(a => a.type === 'generated-image') as attachment}
<div class="relative group/img inline-block">
<img
src={attachment.preview}
alt=""
class="max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20"
/>
<!-- Button overlay -->
<div class="absolute top-2 right-2 flex gap-1 opacity-0 group-hover/img:opacity-100 transition-opacity">
<!-- Edit button -->
<button
type="button"
class="p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
onclick={() => {
if (attachment.preview) {
setEditingImage(attachment.preview, message);
}
}}
title="Edit image"
>
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" />
</svg>
</button>
<!-- Download button -->
<button
type="button"
class="p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
onclick={() => {
if (attachment.preview) {
const link = document.createElement('a');
link.href = attachment.preview;
link.download = `generated-image-${Date.now()}.png`;
link.click();
}
}}
title="Download image"
>
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
</svg>
</button>
</div>
</div>
{/each}
</div>
{/if}
<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{#if message.content === 'Generating image...' || message.content === 'Editing image...' || message.content?.startsWith('Generating...') || message.content?.startsWith('Editing...')}
<div class="flex items-center gap-3 text-exo-yellow">
<div class="relative">
<div class="w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"></div>
<svg class="absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
</div>
<span class="font-mono tracking-wider uppercase text-sm">{message.content}</span>
</div>
{:else if message.content || (loading && !message.attachments?.some(a => a.type === 'generated-image'))}
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
{/if}
</div>
</div>

View File

@@ -0,0 +1,471 @@
<script lang="ts">
import {
imageGenerationParams,
setImageGenerationParams,
resetImageGenerationParams,
type ImageGenerationParams,
} from "$lib/stores/app.svelte";
interface Props {
isEditMode?: boolean;
}
let { isEditMode = false }: Props = $props();
let showAdvanced = $state(false);
// Custom dropdown state
let isSizeDropdownOpen = $state(false);
let isQualityDropdownOpen = $state(false);
let sizeButtonRef: HTMLButtonElement | undefined = $state();
let qualityButtonRef: HTMLButtonElement | undefined = $state();
const sizeDropdownPosition = $derived(() => {
if (!sizeButtonRef || !isSizeDropdownOpen) return { top: 0, left: 0, width: 0 };
const rect = sizeButtonRef.getBoundingClientRect();
return { top: rect.top, left: rect.left, width: rect.width };
});
const qualityDropdownPosition = $derived(() => {
if (!qualityButtonRef || !isQualityDropdownOpen) return { top: 0, left: 0, width: 0 };
const rect = qualityButtonRef.getBoundingClientRect();
return { top: rect.top, left: rect.left, width: rect.width };
});
const params = $derived(imageGenerationParams());
const inputFidelityOptions: ImageGenerationParams["inputFidelity"][] = [
"low",
"high",
];
function handleInputFidelityChange(value: ImageGenerationParams["inputFidelity"]) {
setImageGenerationParams({ inputFidelity: value });
}
const sizeOptions: ImageGenerationParams["size"][] = [
"512x512",
"768x768",
"1024x1024",
"1024x768",
"768x1024",
];
const qualityOptions: ImageGenerationParams["quality"][] = [
"low",
"medium",
"high",
];
function selectSize(value: ImageGenerationParams["size"]) {
setImageGenerationParams({ size: value });
isSizeDropdownOpen = false;
}
function selectQuality(value: ImageGenerationParams["quality"]) {
setImageGenerationParams({ quality: value });
isQualityDropdownOpen = false;
}
function handleSeedChange(event: Event) {
const input = event.target as HTMLInputElement;
const value = input.value.trim();
if (value === "") {
setImageGenerationParams({ seed: null });
} else {
const num = parseInt(value, 10);
if (!isNaN(num) && num >= 0) {
setImageGenerationParams({ seed: num });
}
}
}
function handleStepsChange(event: Event) {
const value = parseInt((event.target as HTMLInputElement).value, 10);
setImageGenerationParams({ numInferenceSteps: value });
}
function handleGuidanceChange(event: Event) {
const value = parseFloat((event.target as HTMLInputElement).value);
setImageGenerationParams({ guidance: value });
}
function handleNegativePromptChange(event: Event) {
const value = (event.target as HTMLTextAreaElement).value;
setImageGenerationParams({ negativePrompt: value || null });
}
function clearSteps() {
setImageGenerationParams({ numInferenceSteps: null });
}
function clearGuidance() {
setImageGenerationParams({ guidance: null });
}
function handleReset() {
resetImageGenerationParams();
showAdvanced = false;
}
const hasAdvancedParams = $derived(
params.seed !== null ||
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null && params.negativePrompt.trim() !== ""),
);
</script>
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
<!-- Basic params row -->
<div class="flex items-center gap-3 flex-wrap">
<!-- Size -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => isSizeDropdownOpen = !isSizeDropdownOpen}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen ? 'border-exo-yellow/70' : ''}"
>
{params.size}
</button>
<div class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen ? 'rotate-180' : ''}">
<svg class="w-3 h-3 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
</svg>
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => isSizeDropdownOpen = false}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition().top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {
params.size === size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'
}"
>
{#if params.size === size}
<svg class="w-3 h-3 flex-shrink-0" fill="currentColor" viewBox="0 0 20 20">
<path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd" />
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
</div>
</div>
{/if}
</div>
<!-- Quality -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>QUALITY:</span
>
<div class="relative">
<button
bind:this={qualityButtonRef}
type="button"
onclick={() => isQualityDropdownOpen = !isQualityDropdownOpen}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isQualityDropdownOpen ? 'border-exo-yellow/70' : ''}"
>
{params.quality.toUpperCase()}
</button>
<div class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isQualityDropdownOpen ? 'rotate-180' : ''}">
<svg class="w-3 h-3 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
</svg>
</div>
</div>
{#if isQualityDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => isQualityDropdownOpen = false}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {qualityDropdownPosition().top}px + 4px); left: {qualityDropdownPosition().left}px;"
>
<div class="py-1">
{#each qualityOptions as quality}
<button
type="button"
onclick={() => selectQuality(quality)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {
params.quality === quality
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'
}"
>
{#if params.quality === quality}
<svg class="w-3 h-3 flex-shrink-0" fill="currentColor" viewBox="0 0 20 20">
<path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd" />
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{quality.toUpperCase()}</span>
</button>
{/each}
</div>
</div>
{/if}
</div>
<!-- Input Fidelity (edit mode only) -->
{#if isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>FIDELITY:</span
>
<div class="flex rounded overflow-hidden border border-exo-yellow/30">
{#each inputFidelityOptions as fidelity}
<button
type="button"
onclick={() => handleInputFidelityChange(fidelity)}
class="px-2 py-1 text-xs font-mono uppercase transition-all duration-200 cursor-pointer {
params.inputFidelity === fidelity
? 'bg-exo-yellow text-exo-black'
: 'bg-exo-medium-gray/50 text-exo-light-gray hover:text-exo-yellow'
}"
title={fidelity === 'low' ? 'More creative variation' : 'Closer to original'}
>
{fidelity}
</button>
{/each}
</div>
</div>
{/if}
<!-- Spacer -->
<div class="flex-1"></div>
<!-- Advanced toggle -->
<button
type="button"
onclick={() => (showAdvanced = !showAdvanced)}
class="flex items-center gap-1 text-xs font-mono tracking-wider uppercase transition-colors duration-200 {showAdvanced ||
hasAdvancedParams
? 'text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
<span>ADVANCED</span>
<svg
class="w-3 h-3 transition-transform duration-200 {showAdvanced
? 'rotate-180'
: ''}"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
{#if hasAdvancedParams && !showAdvanced}
<span class="w-1.5 h-1.5 rounded-full bg-exo-yellow"></span>
{/if}
</button>
</div>
<!-- Advanced params section -->
{#if showAdvanced}
<div class="mt-3 pt-3 border-t border-exo-medium-gray/20 space-y-3">
<!-- Row 1: Seed and Steps -->
<div class="flex items-center gap-4 flex-wrap">
<!-- Seed -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SEED:</span
>
<input
type="number"
min="0"
value={params.seed ?? ""}
oninput={handleSeedChange}
placeholder="Random"
class="w-24 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow placeholder:text-exo-light-gray/50 transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
/>
</div>
<!-- Steps Slider -->
<div class="flex items-center gap-1.5 flex-1 min-w-[200px]">
<span
class="text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap"
>STEPS:</span
>
<div class="flex items-center gap-2 flex-1">
<input
type="range"
min="1"
max="100"
value={params.numInferenceSteps ?? 50}
oninput={handleStepsChange}
class="flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow"
/>
<span class="text-xs font-mono text-exo-yellow w-8 text-right">
{params.numInferenceSteps ?? "--"}
</span>
{#if params.numInferenceSteps !== null}
<button
type="button"
onclick={clearSteps}
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
title="Clear"
>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
{/if}
</div>
</div>
</div>
<!-- Row 2: Guidance -->
<div class="flex items-center gap-1.5">
<span
class="text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap"
>GUIDANCE:</span
>
<div class="flex items-center gap-2 flex-1 max-w-xs">
<input
type="range"
min="1"
max="20"
step="0.5"
value={params.guidance ?? 7.5}
oninput={handleGuidanceChange}
class="flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow"
/>
<span class="text-xs font-mono text-exo-yellow w-8 text-right">
{params.guidance !== null ? params.guidance.toFixed(1) : "--"}
</span>
{#if params.guidance !== null}
<button
type="button"
onclick={clearGuidance}
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
title="Clear"
>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
{/if}
</div>
</div>
<!-- Row 3: Negative Prompt -->
<div class="flex flex-col gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>NEGATIVE PROMPT:</span
>
<textarea
value={params.negativePrompt ?? ""}
oninput={handleNegativePromptChange}
placeholder="Things to avoid in the image..."
rows={2}
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1.5 text-xs font-mono text-exo-yellow placeholder:text-exo-light-gray/50 resize-none transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
></textarea>
</div>
<!-- Reset Button -->
<div class="flex justify-end pt-1">
<button
type="button"
onclick={handleReset}
class="text-xs font-mono tracking-wider uppercase text-exo-light-gray hover:text-exo-yellow transition-colors duration-200"
>
RESET TO DEFAULTS
</button>
</div>
</div>
{/if}
</div>
<style>
/* Custom range slider styling */
input[type="range"]::-webkit-slider-thumb {
-webkit-appearance: none;
appearance: none;
width: 12px;
height: 12px;
border-radius: 50%;
background: #ffd700;
cursor: pointer;
border: none;
}
input[type="range"]::-moz-range-thumb {
width: 12px;
height: 12px;
border-radius: 50%;
background: #ffd700;
cursor: pointer;
border: none;
}
/* Hide number input spinners */
input[type="number"]::-webkit-inner-spin-button,
input[type="number"]::-webkit-outer-spin-button {
-webkit-appearance: none;
margin: 0;
}
input[type="number"] {
-moz-appearance: textfield;
}
</style>

View File

@@ -5,3 +5,4 @@ export { default as ChatAttachments } from "./ChatAttachments.svelte";
export { default as ChatSidebar } from "./ChatSidebar.svelte";
export { default as ModelCard } from "./ModelCard.svelte";
export { default as MarkdownContent } from "./MarkdownContent.svelte";
export { default as ImageParamsPanel } from "./ImageParamsPanel.svelte";

View File

File diff suppressed because it is too large

View File

@@ -47,7 +47,30 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number, tasks?: string[], hugging_face_id?: string}>>([]);
// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
const tasks: Record<string, string[]> = {};
for (const model of models) {
if (model.tasks && model.tasks.length > 0) {
// Map by short ID
tasks[model.id] = model.tasks;
// Also map by hugging_face_id from the API response
if (model.hugging_face_id) {
tasks[model.hugging_face_id] = model.tasks;
}
}
}
return tasks;
});
// Helper to check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const model = models.find(m => m.id === modelId || m.hugging_face_id === modelId);
if (!model?.tasks) return false;
return model.tasks.includes('TextToImage') || model.tasks.includes('ImageToImage');
}
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -1270,6 +1293,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
placeholder="Ask anything"
showHelperText={false}
showModelSelector={true}
modelTasks={modelTasks()}
/>
</div>
</div>
@@ -1491,8 +1515,18 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{@const foundModel = models.find(m => m.id === selectedModelId)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="text-exo-light-gray truncate">{foundModel.name || foundModel.id}</span>
{@const isImageModel = modelSupportsImageGeneration(foundModel.id)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="flex items-center gap-2 text-exo-light-gray truncate">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{foundModel.name || foundModel.id}</span>
</span>
<span class="text-white/50 text-xs flex-shrink-0">{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB</span>
</span>
{:else}
@@ -1537,6 +1571,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(model.id)}
<button
type="button"
onclick={() => {
@@ -1556,7 +1591,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
: 'text-white/30 cursor-default'
}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span class="flex-shrink-0 text-xs {modelCanFit ? 'text-white/50' : 'text-red-400/60'}">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
@@ -1753,7 +1797,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-7xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} />
<ChatForm placeholder="Ask anything" showModelSelector={true} modelTasks={modelTasks()} />
</div>
</div>
</div>

View File

@@ -1,6 +1,6 @@
# EXO API Technical Reference
This document describes the REST API exposed by the **EXO ** service, as implemented in:
This document describes the REST API exposed by the **EXO** service, as implemented in:
`src/exo/master/api.py`
@@ -183,7 +183,70 @@ Same schema as `/v1/chat/completions`.
**Response:**
Chat completion plus benchmarking metrics.
## 5. Complete Endpoint Summary
## 5. Image Generation & Editing
### Image Generation
**POST** `/v1/images/generations`
Executes an image generation request using an OpenAI-compatible schema, extended with an optional `advanced_params` object.
**Request body (example):**
```json
{
  "prompt": "a robot playing chess",
  "model": "flux-dev",
  "stream": false
}
```
**Advanced Parameters (`advanced_params`):**
| Parameter | Type | Constraints | Description |
|-----------|------|-------------|-------------|
| `seed` | int | >= 0 | Random seed for reproducible generation |
| `num_inference_steps` | int | 1-100 | Number of denoising steps |
| `guidance` | float | 1.0-20.0 | Classifier-free guidance scale |
| `negative_prompt` | string | - | Text describing what to avoid in the image |
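For illustration, a minimal request sketch using `httpx` (already a project dependency). The endpoint path and the `advanced_params` fields come from the table above; the host/port (`http://localhost:8000`) and all parameter values are assumptions, not part of the API contract.
```python
import httpx

payload = {
    "prompt": "a robot playing chess",
    "model": "flux-dev",
    "n": 1,
    "size": "1024x1024",
    "response_format": "b64_json",
    "advanced_params": {
        "seed": 42,                 # reproducible generation
        "num_inference_steps": 28,  # 1-100
        "guidance": 4.0,            # 1.0-20.0 classifier-free guidance scale
        "negative_prompt": "blurry, low quality",
    },
}

# Image generation can take a while, so disable the client timeout.
resp = httpx.post("http://localhost:8000/v1/images/generations", json=payload, timeout=None)
resp.raise_for_status()

# Response shape (per ImageGenerationResponse): {"created": <unix ts>, "data": [{"b64_json": "...", ...}]}
print(resp.json()["data"][0]["b64_json"][:64])
```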
**Response:**
OpenAI-compatible image generation response.
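When `stream` is `true` and `partial_images` > 0, the endpoint instead returns a server-sent-event stream. Based on the handler added in this change, each event is a `data: {...}` line whose JSON carries `"type": "partial"` or `"type": "final"`, and the stream ends with `data: [DONE]`. A hedged consumption sketch, with the same host/port assumption as above:
```python
import json
import httpx

payload = {
    "prompt": "a robot playing chess",
    "model": "flux-dev",
    "stream": True,
    "partial_images": 3,
}

with httpx.stream(
    "POST", "http://localhost:8000/v1/images/generations", json=payload, timeout=None
) as resp:
    for line in resp.iter_lines():
        # SSE framing: "data: <json>\n\n", terminated by "data: [DONE]"
        if not line.startswith("data: "):
            continue
        body = line[len("data: "):]
        if body == "[DONE]":
            break
        event = json.loads(body)
        if event["type"] == "partial":
            print(f"partial {event['partial_index']}/{event['total_partials']}")
        else:  # "final"
            print(f"final image {event['image_index']} received")
```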
### Benchmarked Image Generation
**POST** `/bench/images/generations`
Same as `/v1/images/generations`, but also returns generation statistics.
**Request body:**
Same schema as `/v1/images/generations`.
**Response:**
Image generation plus benchmarking metrics.
### Image Editing
**POST** `/v1/images/edits`
Executes an image editing request via `multipart/form-data`, using an OpenAI-compatible schema with the same additional `advanced_params` as `/v1/images/generations`.
**Response:**
Same format as `/v1/images/generations`.
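A hedged client sketch with `httpx`: the form field names mirror the FastAPI handler added in this change, while the host/port, file name, and values are illustrative.
```python
import httpx

with open("input.png", "rb") as f:
    resp = httpx.post(
        "http://localhost:8000/v1/images/edits",
        files={"image": ("input.png", f, "image/png")},
        data={
            "prompt": "make it snow",
            "model": "flux-dev",
            "n": "1",
            "size": "1024x1024",
            "response_format": "b64_json",
            "input_fidelity": "high",  # "high" stays closer to the original image
        },
        timeout=None,  # editing can take a while
    )
resp.raise_for_status()
print(resp.json()["data"][0]["b64_json"][:64])
```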
### Benchmarked Image Editing
**POST** `/bench/images/edits`
Same as `/v1/images/edits`, but also returns generation statistics.
**Request body:**
Same schema as `/v1/images/edits`.
**Response:**
Same format as `/bench/images/generations`, including `generation_stats`.
## 6. Complete Endpoint Summary
```
GET /node_id
@@ -203,10 +266,16 @@ GET /v1/models
POST /v1/chat/completions
POST /bench/chat/completions
POST /v1/images/generations
POST /bench/images/generations
POST /v1/images/edits
POST /bench/images/edits
```
## 6. Notes
## 7. Notes
* The `/v1/chat/completions` endpoint is compatible with the OpenAI API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
* The `/v1/chat/completions` endpoint is compatible with the OpenAI Chat API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
* The `/v1/images/generations` and `/v1/images/edits` endpoints are compatible with the OpenAI Images API format (see the client sketch after these notes).
* The instance placement endpoints allow you to plan and preview cluster allocations before actually creating instances.
* The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.
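As a concrete sketch of the first two notes, an existing OpenAI Python client can be pointed at EXO by overriding `base_url`. The host/port and the placeholder API key are assumptions (the client requires a key value even if the server ignores it):
```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

result = client.images.generate(
    model="flux-dev",
    prompt="a robot playing chess",
    size="1024x1024",
    response_format="b64_json",
)
print(len(result.data[0].b64_json or ""))
```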

View File

@@ -24,6 +24,9 @@ dependencies = [
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux>=0.14.2",
"python-multipart>=0.0.21",
]
[project.scripts]

View File

@@ -1,12 +1,14 @@
import base64
import json
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import cast
from typing import Literal, cast
import anyio
from anyio import BrokenResourceError, create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException, Request
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
@@ -17,6 +19,7 @@ from loguru import logger
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
@@ -24,6 +27,8 @@ from exo.shared.models.model_meta import get_model_card
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
BenchImageGenerationResponse,
BenchImageGenerationTaskParams,
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
@@ -34,6 +39,11 @@ from exo.shared.types.api import (
ErrorResponse,
FinishReason,
GenerationStats,
ImageData,
ImageEditsInternalParams,
ImageGenerationResponse,
ImageGenerationStats,
ImageGenerationTaskParams,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -41,14 +51,17 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
SendInputChunk,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
@@ -88,10 +101,13 @@ def chunk_to_response(
async def resolve_model_card(model_id: str) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
else:
return await get_model_card(model_id)
return MODEL_CARDS[model_id]
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
return await get_model_card(model_id)
class API:
@@ -136,6 +152,7 @@ class API:
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._tg: TaskGroup | None = None
def reset(self, new_session_id: SessionId, result_clock: int):
@@ -144,6 +161,7 @@ class API:
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._chat_completion_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)
def unpause(self, result_clock: int):
@@ -191,6 +209,12 @@ class API:
self.chat_completions
)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.post("/v1/images/generations", response_model=None)(
self.image_generations
)
self.app.post("/bench/images/generations")(self.bench_image_generations)
self.app.post("/v1/images/edits", response_model=None)(self.image_edits)
self.app.post("/bench/images/edits")(self.bench_image_edits)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -598,6 +622,379 @@ class API:
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def _validate_image_model(self, model: str) -> ModelId:
"""Validate model exists and return resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await resolve_model_card(model)
resolved_model = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(resolved_model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {resolved_model}"
)
return resolved_model
async def image_generations(
self, payload: ImageGenerationTaskParams
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image generation requests.
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(payload.model)
command = ImageGeneration(
request_params=payload,
)
await self._send(command)
# Check if streaming is requested
if payload.stream and payload.partial_images and payload.partial_images > 0:
return StreamingResponse(
self._generate_image_stream(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
),
media_type="text/event-stream",
)
# Non-streaming: collect all image chunks
return await self._collect_image_generation(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
)
async def _generate_image_stream(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> AsyncGenerator[str, None]:
"""Generate SSE stream of partial and final images."""
# Track chunks: {(image_index, is_partial): {chunk_index: data}}
image_chunks: dict[tuple[int, bool], dict[int, str]] = {}
image_total_chunks: dict[tuple[int, bool], int] = {}
image_metadata: dict[tuple[int, bool], tuple[int | None, int | None]] = {}
images_complete = 0
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
with recv as chunks:
async for chunk in chunks:
key = (chunk.image_index, chunk.is_partial)
if key not in image_chunks:
image_chunks[key] = {}
image_total_chunks[key] = chunk.total_chunks
image_metadata[key] = (
chunk.partial_index,
chunk.total_partials,
)
image_chunks[key][chunk.chunk_index] = chunk.data
# Check if this image is complete
if len(image_chunks[key]) == image_total_chunks[key]:
full_data = "".join(
image_chunks[key][i] for i in range(len(image_chunks[key]))
)
partial_idx, total_partials = image_metadata[key]
if chunk.is_partial:
# Yield partial image event
event_data = {
"type": "partial",
"partial_index": partial_idx,
"total_partials": total_partials,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
else:
# Final image
event_data = {
"type": "final",
"image_index": chunk.image_index,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
images_complete += 1
if images_complete >= num_images:
yield "data: [DONE]\n\n"
break
# Clean up completed image chunks
del image_chunks[key]
del image_total_chunks[key]
del image_metadata[key]
except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]
async def _collect_image_chunks(
self,
command_id: CommandId,
num_images: int,
response_format: str,
capture_stats: bool = False,
) -> tuple[list[ImageData], ImageGenerationStats | None]:
"""Collect image chunks and optionally capture stats."""
# Track chunks per image: {image_index: {chunk_index: data}}
# Only track non-partial (final) images
image_chunks: dict[int, dict[int, str]] = {}
image_total_chunks: dict[int, int] = {}
images_complete = 0
stats: ImageGenerationStats | None = None
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
while images_complete < num_images:
with recv as chunks:
async for chunk in chunks:
if chunk.is_partial:
continue
if chunk.image_index not in image_chunks:
image_chunks[chunk.image_index] = {}
image_total_chunks[chunk.image_index] = chunk.total_chunks
image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
if capture_stats and chunk.stats is not None:
stats = chunk.stats
if (
len(image_chunks[chunk.image_index])
== image_total_chunks[chunk.image_index]
):
images_complete += 1
if images_complete >= num_images:
break
images: list[ImageData] = []
for image_idx in range(num_images):
chunks_dict = image_chunks[image_idx]
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
images.append(
ImageData(
b64_json=full_data if response_format == "b64_json" else None,
url=None,
)
)
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]
async def _collect_image_generation(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> ImageGenerationResponse:
"""Collect all image chunks (non-streaming) and return a single response."""
images, _ = await self._collect_image_chunks(
command_id, num_images, response_format, capture_stats=False
)
return ImageGenerationResponse(data=images)
async def _collect_image_generation_with_stats(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> BenchImageGenerationResponse:
images, stats = await self._collect_image_chunks(
command_id, num_images, response_format, capture_stats=True
)
return BenchImageGenerationResponse(data=images, generation_stats=stats)
async def bench_image_generations(
self, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(payload.model)
payload.stream = False
payload.partial_images = 0
command = ImageGeneration(
request_params=payload,
)
await self._send(command)
return await self._collect_image_generation_with_stats(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
)
async def _send_image_edits_command(
self,
image: UploadFile,
prompt: str,
model: str,
n: int,
size: str,
response_format: Literal["url", "b64_json"],
input_fidelity: Literal["low", "high"],
stream: bool,
partial_images: int,
bench: bool,
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")
image_strength = 0.7 if input_fidelity == "high" else 0.3
data_chunks = [
image_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)
command = ImageEdits(
request_params=ImageEditsInternalParams(
image_data="",
total_input_chunks=total_chunks,
prompt=prompt,
model=resolved_model,
n=n,
size=size,
response_format=response_format,
image_strength=image_strength,
stream=stream,
partial_images=partial_images,
bench=bench,
),
)
logger.info(
f"Sending input image: {len(image_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(data_chunks):
await self._send(
SendInputChunk(
chunk=InputImageChunk(
idx=chunk_index,
model=resolved_model,
command_id=command.command_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
)
)
)
await self._send(command)
return command
async def image_edits(
self,
image: UploadFile = File(...),
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
stream: str = Form("false"),
partial_images: str = Form("0"),
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image editing requests (img2img)."""
# Parse string form values to proper types
stream_bool = stream.lower() in ("true", "1", "yes")
partial_images_int = int(partial_images) if partial_images.isdigit() else 0
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
n=n,
size=size,
response_format=response_format,
input_fidelity=input_fidelity,
stream=stream_bool,
partial_images=partial_images_int,
bench=False,
)
if stream_bool and partial_images_int > 0:
return StreamingResponse(
self._generate_image_stream(
command_id=command.command_id,
num_images=n,
response_format=response_format,
),
media_type="text/event-stream",
)
return await self._collect_image_generation(
command_id=command.command_id,
num_images=n,
response_format=response_format,
)
async def bench_image_edits(
self,
image: UploadFile = File(...),
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
) -> BenchImageGenerationResponse:
"""Handle benchmark image editing requests with generation stats."""
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
n=n,
size=size,
response_format=response_format,
input_fidelity=input_fidelity,
stream=False,
partial_images=0,
bench=True,
)
return await self._collect_image_generation_with_stats(
command_id=command.command_id,
num_images=n,
response_format=response_format,
)
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
@@ -619,6 +1016,7 @@ class API:
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
tasks=[task.value for task in card.tasks],
)
for card in MODEL_CARDS.values()
]
@@ -657,13 +1055,26 @@ class API:
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
assert isinstance(event.chunk, TokenChunk)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
if event.command_id in self._chat_completion_queues:
assert isinstance(event.chunk, TokenChunk)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(
event.command_id, None
)
elif event.command_id in self._image_generation_queues:
assert isinstance(event.chunk, ImageChunk)
queue = self._image_generation_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._image_generation_queues.pop(
event.command_id, None
)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -16,8 +16,11 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskFinished,
TestCommand,
)
@@ -26,6 +29,7 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
NodeGatheredInfo,
NodeTimedOut,
@@ -36,6 +40,12 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion as ChatCompletionTask,
)
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
)
from exo.shared.types.tasks import (
ImageGeneration as ImageGenerationTask,
)
from exo.shared.types.tasks import (
TaskId,
TaskStatus,
@@ -100,13 +110,14 @@ class Master:
async for forwarder_command in commands:
try:
logger.info(f"Executing command: {forwarder_command.command}")
generated_events: list[Event] = []
command = forwarder_command.command
instance_task_counts: dict[InstanceId, int] = {}
match command:
case TestCommand():
pass
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
@@ -147,6 +158,90 @@ class Master:
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageGeneration():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageEdits():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
@@ -176,6 +271,13 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
generated_events.append(
InputChunkReceived(
command_id=chunk.command_id,
chunk=chunk,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -7,7 +7,7 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
@@ -115,6 +115,7 @@ async def test_master():
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
@@ -172,6 +173,7 @@ async def test_master():
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
device_rank=0,
world_size=1,

View File

@@ -10,7 +10,7 @@ from exo.master.tests.conftest import (
create_rdma_connection,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
@@ -50,6 +50,7 @@ def model_card() -> ModelCard:
n_layers=10,
hidden_size=30,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
@@ -169,6 +170,7 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
n_layers=10,
hidden_size=1000,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
)
placements = place_instance(cic, topology, {}, node_memory, node_network)
@@ -195,6 +197,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
n_layers=10,
hidden_size=1000,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
)
placements = place_instance(cic, topology, {}, node_memory, node_network)
@@ -221,6 +224,7 @@ def test_get_instance_placements_one_node_not_fit() -> None:
n_layers=10,
hidden_size=1000,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
)

View File

@@ -12,7 +12,7 @@ from exo.master.tests.conftest import (
create_node_memory,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
@@ -238,6 +238,7 @@ def test_get_shard_assignments(
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
cycles = topology.get_cycles()
@@ -517,6 +518,7 @@ def test_get_shard_assignments_insufficient_memory_raises():
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]

View File

@@ -9,6 +9,7 @@ from exo.shared.types.events import (
ChunkGenerated,
Event,
IndexedEvent,
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
NodeDownloadProgress,
@@ -52,8 +53,8 @@ def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged()
): # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
): # Pass-through events that don't modify state
return state
case InstanceCreated():
return apply_instance_created(event, state)

View File

@@ -45,3 +45,5 @@ LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024

View File

@@ -1,4 +1,7 @@
from pydantic import PositiveInt
from enum import Enum
from typing import Annotated
from pydantic import BeforeValidator, PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
@@ -13,12 +16,38 @@ class ModelId(Id):
return self.split("/")[-1]
class ModelTask(str, Enum):
TextGeneration = "TextGeneration"
TextToImage = "TextToImage"
ImageToImage = "ImageToImage"
def _coerce_model_task(v: object) -> object:
if isinstance(v, str):
return ModelTask(v)
return v
CoercedModelTask = Annotated[ModelTask, BeforeValidator(_coerce_model_task)]
class ComponentInfo(CamelCaseModel):
component_name: str
component_path: str
storage_size: Memory
n_layers: PositiveInt | None
can_shard: bool
safetensors_index_filename: str | None
class ModelCard(CamelCaseModel):
model_id: ModelId
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
tasks: list[CoercedModelTask]
components: list[ComponentInfo] | None = None
MODEL_CARDS: dict[str, ModelCard] = {
@@ -29,6 +58,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"deepseek-v3.1-8bit": ModelCard(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
@@ -36,6 +66,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
@@ -44,6 +75,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"kimi-k2-thinking": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
@@ -51,6 +83,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.1
"llama-3.1-8b": ModelCard(
@@ -59,6 +92,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-8b-8bit": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
@@ -66,6 +100,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-8b-bf16": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
@@ -73,6 +108,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=32,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.1-70b": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
@@ -80,6 +116,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.2
"llama-3.2-1b": ModelCard(
@@ -88,6 +125,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=16,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.2-3b": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
@@ -95,6 +133,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=28,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.2-3b-8bit": ModelCard(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
@@ -102,6 +141,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=28,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.3
"llama-3.3-70b": ModelCard(
@@ -110,6 +150,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.3-70b-8bit": ModelCard(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
@@ -117,6 +158,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"llama-3.3-70b-fp16": ModelCard(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
@@ -124,6 +166,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=80,
hidden_size=8192,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# qwen3
"qwen3-0.6b": ModelCard(
@@ -132,6 +175,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=28,
hidden_size=1024,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"qwen3-0.6b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
@@ -139,6 +183,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=28,
hidden_size=1024,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"qwen3-30b": ModelCard(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
@@ -146,6 +191,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-30b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
@@ -153,6 +199,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
@@ -160,6 +207,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
@@ -167,6 +215,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
@@ -174,6 +223,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
@@ -181,6 +231,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=48,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-235b-a22b-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
@@ -188,6 +239,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=94,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-235b-a22b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
@@ -195,6 +247,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=94,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
@@ -202,6 +255,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=62,
hidden_size=6144,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
@@ -209,6 +263,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=62,
hidden_size=6144,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
@@ -217,6 +272,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=36,
hidden_size=2880,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
@@ -224,6 +280,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=24,
hidden_size=2880,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
@@ -233,6 +290,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=46,
hidden_size=4096,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
"glm-4.5-air-bf16": ModelCard(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
@@ -240,6 +298,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=46,
hidden_size=4096,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
@@ -248,6 +307,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
@@ -255,6 +315,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-8bit-gs32": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
@@ -262,6 +323,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=91,
hidden_size=5120,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# glm 4.7 flash
"glm-4.7-flash-4bit": ModelCard(
@@ -270,6 +332,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-5bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
@@ -277,6 +340,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
@@ -284,6 +348,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"glm-4.7-flash-8bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
@@ -291,6 +356,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=47,
hidden_size=2048,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
@@ -299,6 +365,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=61,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"minimax-m2.1-3bit": ModelCard(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
@@ -306,5 +373,158 @@ MODEL_CARDS: dict[str, ModelCard] = {
n_layers=61,
hidden_size=3072,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"flux1-schnell": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-dev": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image": ModelCard(
model_id=ModelId("Qwen/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
}

View File

@@ -6,7 +6,7 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId, ModelTask
from exo.shared.types.memory import Memory
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
@@ -119,4 +119,7 @@ async def _get_model_card(model_id: str) -> ModelCard:
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.supports_tensor if model_card is not None else False,
tasks=model_card.tasks
if model_card is not None
else [ModelTask.TextGeneration],
)

View File

@@ -7,7 +7,7 @@ import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@@ -37,6 +37,7 @@ def get_pipeline_shard_metadata(
n_layers=32,
hidden_size=1000,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
device_rank=device_rank,
world_size=world_size,

View File

@@ -1,6 +1,8 @@
import time
from typing import Any, Literal
from collections.abc import Generator
from typing import Annotated, Any, Literal
from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
@@ -39,6 +41,7 @@ class ModelListModel(BaseModel):
tags: list[str] = Field(default=[])
storage_size_megabytes: int = Field(default=0)
supports_tensor: bool = Field(default=False)
tasks: list[str] = Field(default=[])
class ModelList(BaseModel):
@@ -137,6 +140,19 @@ class GenerationStats(BaseModel):
peak_memory_usage: Memory
class ImageGenerationStats(BaseModel):
seconds_per_step: float
total_generation_time: float
num_inference_steps: int
num_images: int
image_width: int
image_height: int
peak_memory_usage: Memory
class BenchChatCompletionResponse(ChatCompletionResponse):
generation_stats: GenerationStats | None = None
@@ -213,3 +229,103 @@ class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId
class AdvancedImageParams(BaseModel):
seed: Annotated[int, Field(ge=0)] | None = None
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
guidance: Annotated[float, Field(ge=1.0, le=20.0)] | None = None
negative_prompt: str | None = None
class ImageGenerationTaskParams(BaseModel):
prompt: str
background: str | None = None
model: str
moderation: str | None = None
n: int | None = 1
output_compression: int | None = None
output_format: Literal["png", "jpeg", "webp"] = "png"
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
stream: bool | None = False
style: str | None = "vivid"
user: str | None = None
advanced_params: AdvancedImageParams | None = None
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
bench: bool = True
class ImageEditsTaskParams(BaseModel):
image: UploadFile
prompt: str
background: str | None = None
input_fidelity: float | None = None
mask: UploadFile | None = None
model: str
n: int | None = 1
output_compression: int | None = None
output_format: Literal["png", "jpeg", "webp"] = "png"
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
stream: bool | None = False
user: str | None = None
advanced_params: AdvancedImageParams | None = None
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
class ImageEditsInternalParams(BaseModel):
"""Serializable version of ImageEditsTaskParams for distributed task execution."""
image_data: str = "" # Base64-encoded image (empty when using chunked transfer)
total_input_chunks: int = 0
prompt: str
model: str
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
image_strength: float | None = 0.7
stream: bool = False
partial_images: int | None = 0
advanced_params: AdvancedImageParams | None = None
bench: bool = False
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} chars>"
elif name is not None:
yield name, value
class ImageData(BaseModel):
b64_json: str | None = None
url: str | None = None
revised_prompt: str | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "b64_json" and self.b64_json is not None:
yield name, f"<{len(self.b64_json)} chars>"
elif name is not None:
yield name, value
class ImageGenerationResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
data: list[ImageData]
class BenchImageGenerationResponse(ImageGenerationResponse):
generation_stats: ImageGenerationStats | None = None

View File

@@ -1,10 +1,13 @@
from collections.abc import Generator
from enum import Enum
from typing import Any
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats
from exo.shared.types.api import GenerationStats, ImageGenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .common import CommandId
class ChunkType(str, Enum):
@@ -26,7 +29,35 @@ class TokenChunk(BaseChunk):
class ImageChunk(BaseChunk):
data: bytes
data: str
chunk_index: int
total_chunks: int
image_index: int
is_partial: bool = False
partial_index: int | None = None
total_partials: int | None = None
stats: ImageGenerationStats | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data" and hasattr(value, "__len__"):
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value
class InputImageChunk(BaseChunk):
command_id: CommandId
data: str
chunk_index: int
total_chunks: int
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data" and hasattr(value, "__len__"):
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value
GenerationChunk = TokenChunk | ImageChunk

View File

@@ -1,7 +1,12 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
@@ -20,6 +25,14 @@ class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
class ImageGeneration(BaseCommand):
request_params: ImageGenerationTaskParams
class ImageEdits(BaseCommand):
request_params: ImageEditsInternalParams
class PlaceInstance(BaseCommand):
model_card: ModelCard
sharding: Sharding
@@ -39,6 +52,12 @@ class TaskFinished(BaseCommand):
finished_command_id: CommandId
class SendInputChunk(BaseCommand):
"""Command to send an input image chunk (converted to event by master)."""
chunk: InputImageChunk
class RequestEventLog(BaseCommand):
since_idx: int
@@ -47,10 +66,13 @@ Command = (
TestCommand
| RequestEventLog
| ChatCompletion
| ImageGeneration
| ImageEdits
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskFinished
| SendInputChunk
)

View File

@@ -3,7 +3,7 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
@@ -96,6 +96,11 @@ class ChunkGenerated(BaseEvent):
chunk: GenerationChunk
class InputChunkReceived(BaseEvent):
command_id: CommandId
chunk: InputImageChunk
class TopologyEdgeCreated(BaseEvent):
conn: Connection
@@ -119,6 +124,7 @@ Event = (
| NodeGatheredInfo
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| TopologyEdgeCreated
| TopologyEdgeDeleted
)

View File

@@ -2,7 +2,11 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.common import CommandId, Id
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
@@ -56,6 +60,22 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class ImageEdits(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageEditsInternalParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
@@ -67,5 +87,7 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| ImageGeneration
| ImageEdits
| Shutdown
)

View File

@@ -1,4 +1,7 @@
from exo.shared.types.api import FinishReason, GenerationStats
from collections.abc import Generator
from typing import Any, Literal
from exo.shared.types.api import FinishReason, GenerationStats, ImageGenerationStats
from exo.utils.pydantic_ext import TaggedModel
@@ -18,5 +21,32 @@ class GenerationResponse(BaseRunnerResponse):
stats: GenerationStats | None = None
class ImageGenerationResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
stats: ImageGenerationStats | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} bytes>"
elif name is not None:
yield name, value
class PartialImageResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
partial_index: int
total_partials: int
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} bytes>"
elif name is not None:
yield name, value
class FinishedResponse(BaseRunnerResponse):
pass

View File

@@ -15,6 +15,7 @@ import aiofiles
import aiofiles.os as aios
import aiohttp
import certifi
from huggingface_hub._snapshot_download import snapshot_download
from loguru import logger
from pydantic import (
BaseModel,
@@ -445,12 +446,31 @@ def calculate_repo_progress(
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_file = await download_file_with_retry(
repo_id, revision, "model.safetensors.index.json", target_dir
index_files_dir = snapshot_download(
repo_id=repo_id, local_dir=target_dir, allow_patterns="*.safetensors.index.json"
)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
return index_data.weight_map
index_files = list(Path(index_files_dir).glob("**/*.safetensors.index.json"))
weight_map: dict[str, str] = {}
for index_file in index_files:
relative_dir = index_file.parent.relative_to(index_files_dir)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
if relative_dir != Path("."):
prefixed_weight_map = {
f"{relative_dir}/{key}": str(relative_dir / value)
for key, value in index_data.weight_map.items()
}
weight_map = weight_map | prefixed_weight_map
else:
weight_map = weight_map | index_data.weight_map
return weight_map
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
@@ -555,8 +575,6 @@ async def download_shard(
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_card.model_id), revision, recursive=True
)
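get_weight_map now merges every *.safetensors.index.json discovered by snapshot_download, namespacing tensor names and file paths from subdirectories by their component folder. A hedged, self-contained restatement of the prefixing rule (file names invented):

    from pathlib import Path

    relative_dir = Path("transformer")  # index file found under a component subdirectory
    weight_map = {"blocks.0.attn.q.weight": "model-00001-of-00002.safetensors"}
    prefixed = {
        f"{relative_dir}/{key}": str(relative_dir / value)
        for key, value in weight_map.items()
    }
    # {"transformer/blocks.0.attn.q.weight": "transformer/model-00001-of-00002.safetensors"}
    # A top-level index (relative_dir == Path(".")) is merged without any prefixing.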

View File

@@ -100,26 +100,68 @@ def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list
"*.py",
"tokenizer.model",
"tiktoken.model",
"*/spiece.model",
"*.tiktoken",
"*.txt",
"*.jinja",
]
)
shard_specific_patterns: set[str] = set()
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num <= shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
if shard.model_card.components is not None:
shardable_component = next(
(c for c in shard.model_card.components if c.can_shard), None
)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
if weight_map and shardable_component:
for tensor_name, filename in weight_map.items():
# Strip component prefix from tensor name (added by weight map namespacing)
# E.g., "transformer/blocks.0.weight" -> "blocks.0.weight"
if "/" in tensor_name:
_, tensor_name_no_prefix = tensor_name.split("/", 1)
else:
tensor_name_no_prefix = tensor_name
# Determine which component this file belongs to from filename
component_path = Path(filename).parts[0] if "/" in filename else None
if component_path == shardable_component.component_path.rstrip("/"):
layer_num = extract_layer_num(tensor_name_no_prefix)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
if shard.is_first_layer or shard.is_last_layer:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns = set(["*.safetensors"])
# TODO(ciaran): temporary - Include all files from non-shardable components that have no index file
for component in shard.model_card.components:
if not component.can_shard and component.safetensors_index_filename is None:
component_pattern = f"{component.component_path.rstrip('/')}/*"
shard_specific_patterns.add(component_pattern)
else:
shard_specific_patterns = set(["*.safetensors"])
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
else:
shard_specific_patterns = set(["*.safetensors"])
logger.info(f"get_allow_patterns {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)
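For model cards with components, only the shardable component's files are filtered by layer range; every other component is downloaded in full. A hedged, simplified restatement of the per-file decision (names invented; the first/last-shard special case and index-less components handled above are omitted):

    from pathlib import Path

    shardable_path = "transformer/"
    start_layer, end_layer = 0, 30

    def keep(filename: str, layer_num: int | None) -> bool:
        component = Path(filename).parts[0] if "/" in filename else None
        if component != shardable_path.rstrip("/"):
            return True  # files from other components are always kept
        # Shardable component: keep only tensors whose layer falls in this shard's range.
        return layer_num is not None and start_layer <= layer_num < end_layer

    assert keep("transformer/model-00001.safetensors", 3)
    assert not keep("transformer/model-00003.safetensors", 45)
    assert keep("text_encoder/model.safetensors", None)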

View File

@@ -5,7 +5,7 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
@@ -92,6 +92,7 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
n_layers=1,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
device_rank=0,
world_size=1,

View File

@@ -0,0 +1,12 @@
from exo.worker.engines.image.distributed_model import (
DistributedImageModel,
initialize_image_model,
)
from exo.worker.engines.image.generate import generate_image, warmup_image_generator
__all__ = [
"DistributedImageModel",
"generate_image",
"initialize_image_model",
"warmup_image_generator",
]

View File

@@ -0,0 +1,50 @@
from enum import Enum
from math import ceil
from pydantic import BaseModel
class BlockType(Enum):
JOINT = "joint" # Separate image/text streams
SINGLE = "single" # Concatenated streams
class TransformerBlockConfig(BaseModel):
model_config = {"frozen": True}
block_type: BlockType
count: int
has_separate_text_output: bool # True for joint blocks that output text separately
class ImageModelConfig(BaseModel):
model_family: str
block_configs: tuple[TransformerBlockConfig, ...]
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
num_sync_steps_factor: float # Fraction of steps for sync phase
guidance_scale: float | None = None # None or <= 1.0 disables CFG
@property
def total_blocks(self) -> int:
return sum(bc.count for bc in self.block_configs)
@property
def joint_block_count(self) -> int:
return sum(
bc.count for bc in self.block_configs if bc.block_type == BlockType.JOINT
)
@property
def single_block_count(self) -> int:
return sum(
bc.count for bc in self.block_configs if bc.block_type == BlockType.SINGLE
)
def get_steps_for_quality(self, quality: str) -> int:
return self.default_steps[quality]
def get_num_sync_steps(self, steps: int) -> int:
return ceil(steps * self.num_sync_steps_factor)
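ImageModelConfig maps a quality preset to a step count and derives how many synchronous steps run before pipelined execution. A hedged sketch that composes the two helpers, using the FLUX dev values defined later in this diff (the printed numbers are illustrative):

    from exo.worker.engines.image.config import (
        BlockType,
        ImageModelConfig,
        TransformerBlockConfig,
    )

    cfg = ImageModelConfig(
        model_family="flux",
        block_configs=(
            TransformerBlockConfig(
                block_type=BlockType.JOINT, count=19, has_separate_text_output=True
            ),
            TransformerBlockConfig(
                block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
            ),
        ),
        default_steps={"low": 10, "medium": 25, "high": 50},
        num_sync_steps_factor=0.125,
    )

    steps = cfg.get_steps_for_quality("medium")  # 25
    sync_steps = cfg.get_num_sync_steps(steps)   # ceil(25 * 0.125) == 4
    print(cfg.total_blocks, steps, sync_steps)   # 57 25 4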

View File

@@ -0,0 +1,166 @@
from collections.abc import Generator
from pathlib import Path
from typing import Any, Literal, Optional
import mlx.core as mx
from mflux.models.common.config.config import Config
from PIL import Image
from exo.shared.types.api import AdvancedImageParams
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
create_adapter_for_model,
get_config_for_model,
)
from exo.worker.engines.image.models.base import ModelAdapter
from exo.worker.engines.image.pipeline import DiffusionRunner
from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier
from exo.worker.runner.bootstrap import logger
class DistributedImageModel:
_config: ImageModelConfig
_adapter: ModelAdapter[Any, Any]
_runner: DiffusionRunner
def __init__(
self,
model_id: str,
local_path: Path,
shard_metadata: PipelineShardMetadata,
group: Optional[mx.distributed.Group] = None,
quantize: int | None = None,
):
config = get_config_for_model(model_id)
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
if group is not None:
adapter.slice_transformer_blocks(
start_layer=shard_metadata.start_layer,
end_layer=shard_metadata.end_layer,
)
runner = DiffusionRunner(
config=config,
adapter=adapter,
group=group,
shard_metadata=shard_metadata,
)
if group is not None:
logger.info("Initialized distributed diffusion runner")
mx.eval(adapter.model.parameters())
# TODO(ciaran): Do we need this?
mx.eval(adapter.model)
mx_barrier(group)
logger.info(f"Transformer sharded for rank {group.rank()}")
else:
logger.info("Single-node initialization")
self._config = config
self._adapter = adapter
self._runner = runner
@classmethod
def from_bound_instance(
cls, bound_instance: BoundInstance
) -> "DistributedImageModel":
model_id = bound_instance.bound_shard.model_card.model_id
model_path = build_model_path(model_id)
shard_metadata = bound_instance.bound_shard
if not isinstance(shard_metadata, PipelineShardMetadata):
raise ValueError("Expected PipelineShardMetadata for image generation")
is_distributed = (
len(bound_instance.instance.shard_assignments.node_to_runner) > 1
)
if is_distributed:
logger.info("Starting distributed init for image model")
group = mlx_distributed_init(bound_instance)
else:
group = None
return cls(
model_id=model_id,
local_path=model_path,
shard_metadata=shard_metadata,
group=group,
)
def get_steps_for_quality(self, quality: Literal["low", "medium", "high"]) -> int:
"""Get the number of inference steps for a quality level."""
return self._config.get_steps_for_quality(quality)
def generate(
self,
prompt: str,
height: int,
width: int,
quality: Literal["low", "medium", "high"] = "medium",
seed: int = 2,
image_path: Path | None = None,
partial_images: int = 0,
advanced_params: AdvancedImageParams | None = None,
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
if (
advanced_params is not None
and advanced_params.num_inference_steps is not None
):
steps = advanced_params.num_inference_steps
else:
steps = self._config.get_steps_for_quality(quality)
guidance_override: float | None = None
if advanced_params is not None and advanced_params.guidance is not None:
guidance_override = advanced_params.guidance
negative_prompt: str | None = None
if advanced_params is not None and advanced_params.negative_prompt is not None:
negative_prompt = advanced_params.negative_prompt
# For edit mode: compute dimensions from input image
# This also stores image_paths in the adapter for encode_prompt()
if image_path is not None:
computed_dims = self._adapter.set_image_dimensions(image_path)
if computed_dims is not None:
# Override user-provided dimensions with computed ones
width, height = computed_dims
config = Config(
num_inference_steps=steps,
height=height,
width=width,
image_path=image_path,
model_config=self._adapter.model.model_config,
)
num_sync_steps = self._config.get_num_sync_steps(steps)
for result in self._runner.generate_image(
runtime_config=config,
prompt=prompt,
seed=seed,
partial_images=partial_images,
guidance_override=guidance_override,
negative_prompt=negative_prompt,
num_sync_steps=num_sync_steps,
):
if isinstance(result, tuple):
# Partial image: (GeneratedImage, partial_index, total_partials)
image, partial_idx, total_partials = result
yield (image, partial_idx, total_partials)
else:
logger.info("generated image")
yield result
def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:
return DistributedImageModel.from_bound_instance(bound_instance)
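generate() resolves the step count from advanced_params.num_inference_steps when present, otherwise from the quality preset, and yields partial previews as (image, partial_index, total_partials) tuples before the final PIL image. A hedged usage sketch (model is an already-initialized DistributedImageModel; the prompt, sizes, and parameter values are examples, with field names taken from this diff):

    from exo.shared.types.api import AdvancedImageParams

    params = AdvancedImageParams(num_inference_steps=8, guidance=3.0, seed=7)
    for result in model.generate(
        prompt="a lighthouse at dusk",
        height=768,
        width=768,
        quality="medium",  # ignored here: num_inference_steps overrides the preset
        partial_images=2,
        advanced_params=params,
    ):
        if isinstance(result, tuple):
            image, partial_idx, total_partials = result  # intermediate preview
        else:
            final_image = result  # final PIL.Image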

View File

@@ -0,0 +1,170 @@
import base64
import io
import random
import tempfile
import time
from pathlib import Path
from typing import Generator, Literal
import mlx.core as mx
from PIL import Image
from exo.shared.types.api import (
AdvancedImageParams,
ImageEditsInternalParams,
ImageGenerationStats,
ImageGenerationTaskParams,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.runner_response import (
ImageGenerationResponse,
PartialImageResponse,
)
from exo.worker.engines.image.distributed_model import DistributedImageModel
def parse_size(size_str: str | None) -> tuple[int, int]:
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
if not size_str or size_str == "auto":
size_str = "1024x1024"
try:
parts = size_str.split("x")
if len(parts) == 2:
width, height = int(parts[0]), int(parts[1])
return (width, height)
except (ValueError, AttributeError):
pass
# Default fallback
return (1024, 1024)
def warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:
"""Warmup the image generator with a small image."""
with tempfile.TemporaryDirectory() as tmpdir:
# Create a small dummy image for warmup (needed for edit models)
dummy_image = Image.new("RGB", (256, 256), color=(128, 128, 128))
dummy_path = Path(tmpdir) / "warmup.png"
dummy_image.save(dummy_path)
warmup_params = AdvancedImageParams(num_inference_steps=2)
for result in model.generate(
prompt="Warmup",
height=256,
width=256,
quality="low",
image_path=dummy_path,
advanced_params=warmup_params,
):
if not isinstance(result, tuple):
return result
return None
def generate_image(
model: DistributedImageModel,
task: ImageGenerationTaskParams | ImageEditsInternalParams,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
"""Generate image(s), optionally yielding partial results.
When partial_images > 0 or stream=True, yields PartialImageResponse for
intermediate images, then ImageGenerationResponse for the final image.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0)
ImageGenerationResponse for the final complete image
"""
width, height = parse_size(task.size)
quality: Literal["low", "medium", "high"] = task.quality or "medium"
advanced_params = task.advanced_params
if advanced_params is not None and advanced_params.seed is not None:
seed = advanced_params.seed
else:
seed = random.randint(0, 2**32 - 1)
is_bench = getattr(task, "bench", False)
generation_start_time: float = 0.0
if is_bench:
mx.reset_peak_memory()
generation_start_time = time.perf_counter()
partial_images = task.partial_images or (3 if task.stream else 0)
image_path: Path | None = None
with tempfile.TemporaryDirectory() as tmpdir:
if isinstance(task, ImageEditsInternalParams):
# Decode base64 image data and save to temp file
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
# Iterate over generator results
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
image.save(buffer, format=image_format)
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
)
else:
image = result
stats: ImageGenerationStats | None = None
if is_bench:
generation_end_time = time.perf_counter()
total_generation_time = generation_end_time - generation_start_time
num_inference_steps = model.get_steps_for_quality(quality)
seconds_per_step = (
total_generation_time / num_inference_steps
if num_inference_steps > 0
else 0.0
)
peak_memory_gb = mx.get_peak_memory() / (1024**3)
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=task.n or 1,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
)
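A few illustrative checks for the size parsing and partial-image defaults above (inputs are examples):

    from exo.worker.engines.image.generate import parse_size

    assert parse_size("1024x768") == (1024, 768)
    assert parse_size("auto") == (1024, 1024)        # "auto" falls back to the default
    assert parse_size("not-a-size") == (1024, 1024)  # unparseable sizes also fall back

    # partial_images defaults to 3 when streaming without an explicit count:
    stream, requested = True, None
    partial_images = requested or (3 if stream else 0)  # -> 3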

View File

@@ -0,0 +1,84 @@
from pathlib import Path
from typing import Any, Callable
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter
from exo.worker.engines.image.models.flux import (
FLUX_DEV_CONFIG,
FLUX_SCHNELL_CONFIG,
FluxModelAdapter,
)
from exo.worker.engines.image.models.qwen import (
QWEN_IMAGE_CONFIG,
QWEN_IMAGE_EDIT_CONFIG,
QwenEditModelAdapter,
QwenModelAdapter,
)
__all__: list[str] = []
# Type alias for adapter factory functions
# Factory takes (config, model_id, local_path, quantize) and returns a ModelAdapter
AdapterFactory = Callable[[ImageModelConfig, str, Path, int | None], ModelAdapter[Any, Any]]
# Registry maps model_family string to adapter factory
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
"flux": FluxModelAdapter,
"qwen-edit": QwenEditModelAdapter,
"qwen": QwenModelAdapter,
}
# Config registry: maps model ID patterns to configs
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
"flux.1-dev": FLUX_DEV_CONFIG,
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
"qwen-image": QWEN_IMAGE_CONFIG,
}
def get_config_for_model(model_id: str) -> ImageModelConfig:
"""Get configuration for a model ID.
Args:
model_id: The model identifier (e.g., "black-forest-labs/FLUX.1-schnell")
Returns:
The model configuration
Raises:
ValueError: If no configuration found for model ID
"""
model_id_lower = model_id.lower()
for pattern, config in _CONFIG_REGISTRY.items():
if pattern in model_id_lower:
return config
raise ValueError(f"No configuration found for model: {model_id}")
def create_adapter_for_model(
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
) -> ModelAdapter[Any, Any]:
"""Create a model adapter for the given configuration.
Args:
config: The model configuration
model_id: The model identifier
local_path: Path to the model weights
quantize: Optional quantization bits
Returns:
A ModelAdapter instance
Raises:
ValueError: If no adapter found for model family
"""
factory = _ADAPTER_REGISTRY.get(config.model_family)
if factory is None:
raise ValueError(f"No adapter found for model family: {config.model_family}")
return factory(config, model_id, local_path, quantize)
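Model resolution is a two-step lookup: a substring match on the lowercased model ID selects an ImageModelConfig, then the config's model_family selects the adapter factory. A hedged sketch (the model ID and local path are examples):

    from exo.worker.engines.image.models import (
        create_adapter_for_model,
        get_config_for_model,
    )

    model_id = "Qwen/Qwen-Image-Edit"        # example ID
    config = get_config_for_model(model_id)  # "qwen-image-edit" is checked before "qwen-image"
    # create_adapter_for_model(config, model_id, local_path, quantize=None) would then
    # instantiate the matching adapter and load weights from the given local path.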

View File

@@ -0,0 +1,295 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, TypeVar
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.common.latent_creator.latent_creator import Img2Img, LatentCreator
from mflux.utils.image_util import ImageUtil
from PIL import Image
from exo.worker.engines.image.config import ImageModelConfig
if TYPE_CHECKING:
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
ModelT = TypeVar("ModelT")
TransformerT = TypeVar("TransformerT")
RotaryEmbeddings = mx.array | tuple[mx.array, mx.array]
class PromptData(ABC):
@property
@abstractmethod
def prompt_embeds(self) -> mx.array: ...
@property
@abstractmethod
def pooled_prompt_embeds(self) -> mx.array: ...
@property
@abstractmethod
def negative_prompt_embeds(self) -> mx.array | None: ...
@property
@abstractmethod
def negative_pooled_prompt_embeds(self) -> mx.array | None: ...
@abstractmethod
def get_encoder_hidden_states_mask(
self, positive: bool = True
) -> mx.array | None: ...
@property
@abstractmethod
def cond_image_grid(
self,
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
"""Conditioning image grid dimensions for edit mode.
Returns:
Grid dimensions (edit) or None (standard generation).
"""
...
@property
@abstractmethod
def conditioning_latents(self) -> mx.array | None:
"""Conditioning latents for edit mode.
Returns:
Conditioning latents array for image editing, None for standard generation.
"""
...
@abstractmethod
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
"""Get embeddings for CFG with batch_size=2.
Combines positive and negative embeddings into batched tensors for
a single forward pass. Pads shorter sequences to max length. Attention
mask is used to mask padding.
Returns:
None if model doesn't support CFG, otherwise tuple of:
- batched_embeds: [2, max_seq, hidden] (positive then negative)
- batched_mask: [2, max_seq] attention mask
- batched_pooled: [2, hidden] pooled embeddings or None
- conditioning_latents: [2, latent_seq, latent_dim] or None
TODO(ciaran): type this
"""
...
class ModelAdapter(ABC, Generic[ModelT, TransformerT]):
_config: ImageModelConfig
_model: ModelT
_transformer: TransformerT
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def model(self) -> ModelT:
return self._model
@property
def transformer(self) -> TransformerT:
return self._transformer
@property
@abstractmethod
def hidden_dim(self) -> int: ...
@property
@abstractmethod
def needs_cfg(self) -> bool:
"""Whether this model uses classifier-free guidance."""
...
@abstractmethod
def _get_latent_creator(self) -> type: ...
@abstractmethod
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list["JointBlockWrapper[Any]"]:
"""Create wrapped joint transformer blocks with pipefusion support.
Args:
text_seq_len: Number of text tokens (constant for generation)
encoder_hidden_states_mask: Attention mask for text (Qwen only)
Returns:
List of wrapped joint blocks ready for pipefusion
"""
...
@abstractmethod
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list["SingleBlockWrapper[Any]"]:
"""Create wrapped single transformer blocks with pipefusion support.
Args:
text_seq_len: Number of text tokens (constant for generation)
Returns:
List of wrapped single blocks ready for pipefusion
"""
...
@abstractmethod
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
"""Remove transformer blocks outside the assigned range.
This should be called BEFORE mx.eval() to avoid loading unused weights
in distributed mode.
Args:
start_layer: First layer index (inclusive) assigned to this node
end_layer: Last layer index (exclusive) assigned to this node
"""
...
def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
"""Default implementation: no dimension computation needed.
Override in edit adapters to compute dimensions from input image.
TODO(ciaran): this is a hack
Returns:
None (use user-specified dimensions)
"""
return None
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
"""Create initial latents. Uses model-specific latent creator."""
model: Any = self.model
return LatentCreator.create_for_txt2img_or_img2img(
seed=seed,
height=runtime_config.height,
width=runtime_config.width,
img2img=Img2Img(
vae=model.vae,
latent_creator=self._get_latent_creator(),
sigmas=runtime_config.scheduler.sigmas,
init_time_step=runtime_config.init_time_step,
image_path=runtime_config.image_path,
),
)
def decode_latents(
self,
latents: mx.array,
runtime_config: Config,
seed: int,
prompt: str,
) -> Image.Image:
model: Any = self.model # Allow attribute access on model
latents = self._get_latent_creator().unpack_latents(
latents=latents,
height=runtime_config.height,
width=runtime_config.width,
)
decoded = model.vae.decode(latents)
# TODO(ciaran):
# from mflux.models.common.vae.vae_util import VAEUtil
# VAEUtil.decode(vae=model.vae, latents=latents, tiling_config=self.tiling_config)
generated_image = ImageUtil.to_image(
decoded_latents=decoded,
config=runtime_config,
seed=seed,
prompt=prompt,
quantization=model.bits,
lora_paths=model.lora_paths,
lora_scales=model.lora_scales,
image_path=runtime_config.image_path,
image_strength=runtime_config.image_strength,
generation_time=0,
)
return generated_image.image
@abstractmethod
def encode_prompt(
self, prompt: str, negative_prompt: str | None = None
) -> "PromptData": ...
@abstractmethod
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]: ...
@abstractmethod
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array: ...
@abstractmethod
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> RotaryEmbeddings: ...
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
@abstractmethod
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
"""Apply classifier-free guidance to combine positive/negative predictions.
Only called when needs_cfg is True.
Args:
noise_positive: Noise prediction from positive prompt
noise_negative: Noise prediction from negative prompt
guidance_scale: Guidance strength
Returns:
Guided noise prediction
"""
...
def final_projection(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
transformer: Any = self.transformer
hidden_states = transformer.norm_out(hidden_states, text_embeddings)
return transformer.proj_out(hidden_states)
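apply_guidance is only invoked when needs_cfg is true. As a reference point, a hedged sketch of the textbook classifier-free guidance combine; the Qwen adapters in this diff delegate to QwenImage.compute_guided_noise, which may differ (e.g. norm-preserving CFG), so this is not the exact implementation:

    import mlx.core as mx

    def cfg_combine(
        noise_positive: mx.array, noise_negative: mx.array, guidance_scale: float
    ) -> mx.array:
        # Standard CFG: extrapolate from the unconditional toward the conditional prediction.
        return noise_negative + guidance_scale * (noise_positive - noise_negative)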

View File

@@ -0,0 +1,11 @@
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
from exo.worker.engines.image.models.flux.config import (
FLUX_DEV_CONFIG,
FLUX_SCHNELL_CONFIG,
)
__all__ = [
"FluxModelAdapter",
"FLUX_DEV_CONFIG",
"FLUX_SCHNELL_CONFIG",
]

View File

@@ -0,0 +1,215 @@
from pathlib import Path
from typing import Any
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.variants.txt2img.flux import Flux1
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
ModelAdapter,
PromptData,
RotaryEmbeddings,
)
from exo.worker.engines.image.models.flux.wrappers import (
FluxJointBlockWrapper,
FluxSingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
class FluxPromptData(PromptData):
def __init__(self, prompt_embeds: mx.array, pooled_prompt_embeds: mx.array):
self._prompt_embeds = prompt_embeds
self._pooled_prompt_embeds = pooled_prompt_embeds
@property
def prompt_embeds(self) -> mx.array:
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
return self._pooled_prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array | None:
return None
@property
def negative_pooled_prompt_embeds(self) -> mx.array | None:
return None
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
return None
@property
def cond_image_grid(
self,
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
return None
@property
def conditioning_latents(self) -> mx.array | None:
return None
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
return None
class FluxModelAdapter(ModelAdapter[Flux1, Transformer]):
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = Flux1(
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
model_path=str(local_path),
quantize=quantize,
)
self._transformer = self._model.transformer
@property
def hidden_dim(self) -> int:
return self._transformer.x_embedder.weight.shape[0]
@property
def needs_cfg(self) -> bool:
return False
def _get_latent_creator(self) -> type:
return FluxLatentCreator
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list[JointBlockWrapper[Any]]:
"""Create wrapped joint blocks for Flux."""
return [
FluxJointBlockWrapper(block, text_seq_len)
for block in self._transformer.transformer_blocks
]
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list[SingleBlockWrapper[Any]]:
"""Create wrapped single blocks for Flux."""
return [
FluxSingleBlockWrapper(block, text_seq_len)
for block in self._transformer.single_transformer_blocks
]
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
all_joint = list(self._transformer.transformer_blocks)
all_single = list(self._transformer.single_transformer_blocks)
total_joint_blocks = len(all_joint)
if end_layer <= total_joint_blocks:
# All assigned are joint blocks
joint_start, joint_end = start_layer, end_layer
single_start, single_end = 0, 0
elif start_layer >= total_joint_blocks:
# All assigned are single blocks
joint_start, joint_end = 0, 0
single_start = start_layer - total_joint_blocks
single_end = end_layer - total_joint_blocks
else:
# Spans both joint and single
joint_start, joint_end = start_layer, total_joint_blocks
single_start = 0
single_end = end_layer - total_joint_blocks
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
self._transformer.single_transformer_blocks = all_single[
single_start:single_end
]
def encode_prompt(
self, prompt: str, negative_prompt: str | None = None
) -> FluxPromptData:
del negative_prompt
assert isinstance(self.model.prompt_cache, dict)
assert isinstance(self.model.tokenizers, dict)
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
prompt=prompt,
prompt_cache=self.model.prompt_cache,
t5_tokenizer=self.model.tokenizers["t5"],
clip_tokenizer=self.model.tokenizers["clip"],
t5_text_encoder=self.model.t5_text_encoder,
clip_text_encoder=self.model.clip_text_encoder,
)
return FluxPromptData(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
embedded_hidden = self._transformer.x_embedder(hidden_states)
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None, # Ignored by Flux
) -> mx.array:
if pooled_prompt_embeds is None:
raise ValueError(
"pooled_prompt_embeds is required for Flux text embeddings"
)
# hidden_states is ignored - Flux uses pooled_prompt_embeds instead
return Transformer.compute_text_embeddings(
t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> RotaryEmbeddings:
return Transformer.compute_rotary_embeddings(
prompt_embeds,
self._transformer.pos_embed,
runtime_config,
kontext_image_ids,
)
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
raise NotImplementedError("Flux does not use classifier-free guidance")
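slice_transformer_blocks maps a global, half-open layer range onto Flux's 19 joint blocks followed by 38 single blocks. A hedged worked example of that index arithmetic:

    total_joint_blocks = 19  # Flux: 19 joint + 38 single blocks (57 layers total)

    def split(start_layer: int, end_layer: int) -> tuple[slice, slice]:
        if end_layer <= total_joint_blocks:
            return slice(start_layer, end_layer), slice(0, 0)
        if start_layer >= total_joint_blocks:
            return (
                slice(0, 0),
                slice(start_layer - total_joint_blocks, end_layer - total_joint_blocks),
            )
        return slice(start_layer, total_joint_blocks), slice(0, end_layer - total_joint_blocks)

    assert split(0, 19) == (slice(0, 19), slice(0, 0))    # joint blocks only
    assert split(19, 57) == (slice(0, 0), slice(0, 38))   # single blocks only
    assert split(10, 30) == (slice(10, 19), slice(0, 11)) # spans the joint/single boundary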

View File

@@ -0,0 +1,34 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
FLUX_SCHNELL_CONFIG = ImageModelConfig(
model_family="flux",
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
),
TransformerBlockConfig(
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
),
),
default_steps={"low": 1, "medium": 2, "high": 4},
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
)
FLUX_DEV_CONFIG = ImageModelConfig(
model_family="flux",
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
),
TransformerBlockConfig(
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
)

View File

@@ -0,0 +1,279 @@
from typing import final
import mlx.core as mx
from mflux.models.flux.model.flux_transformer.common.attention_utils import (
AttentionUtils,
)
from mflux.models.flux.model.flux_transformer.joint_transformer_block import (
JointTransformerBlock,
)
from mflux.models.flux.model.flux_transformer.single_transformer_block import (
SingleTransformerBlock,
)
from pydantic import BaseModel, ConfigDict
from exo.worker.engines.image.models.base import RotaryEmbeddings
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
@final
class FluxModulationParams(BaseModel):
model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True)
gate_msa: mx.array
shift_mlp: mx.array
scale_mlp: mx.array
gate_mlp: mx.array
@final
class FluxNormGateState(BaseModel):
model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True)
norm_hidden: mx.array
gate: mx.array
class FluxJointBlockWrapper(JointBlockWrapper[JointTransformerBlock]):
def __init__(self, block: JointTransformerBlock, text_seq_len: int):
super().__init__(block, text_seq_len)
self._num_heads = block.attn.num_heads
self._head_dim = block.attn.head_dimension
# Intermediate state stored between _compute_qkv and _apply_output
self._hidden_mod: FluxModulationParams | None = None
self._context_mod: FluxModulationParams | None = None
def _compute_qkv(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: RotaryEmbeddings,
patch_mode: bool = False,
) -> tuple[mx.array, mx.array, mx.array]:
assert isinstance(rotary_embeddings, mx.array)
attn = self.block.attn
(
norm_hidden,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
) = self.block.norm1(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
)
self._hidden_mod = FluxModulationParams(
gate_msa=gate_msa,
shift_mlp=shift_mlp,
scale_mlp=scale_mlp,
gate_mlp=gate_mlp,
)
(
norm_encoder,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
) = self.block.norm1_context(
hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
)
self._context_mod = FluxModulationParams(
gate_msa=c_gate_msa,
shift_mlp=c_shift_mlp,
scale_mlp=c_scale_mlp,
gate_mlp=c_gate_mlp,
)
img_query, img_key, img_value = AttentionUtils.process_qkv(
hidden_states=norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
hidden_states=norm_encoder,
to_q=attn.add_q_proj,
to_k=attn.add_k_proj,
to_v=attn.add_v_proj,
norm_q=attn.norm_added_q,
norm_k=attn.norm_added_k,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
query = mx.concatenate([txt_query, img_query], axis=2)
key = mx.concatenate([txt_key, img_key], axis=2)
value = mx.concatenate([txt_value, img_value], axis=2)
if patch_mode:
text_rope = rotary_embeddings[:, :, : self._text_seq_len, ...]
patch_img_rope = rotary_embeddings[
:,
:,
self._text_seq_len + self._patch_start : self._text_seq_len
+ self._patch_end,
...,
]
rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
else:
rope = rotary_embeddings
query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=rope)
return query, key, value
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array:
batch_size = query.shape[0]
return AttentionUtils.compute_attention(
query=query,
key=key,
value=value,
batch_size=batch_size,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
) -> tuple[mx.array, mx.array]:
attn = self.block.attn
context_attn_output = attn_out[:, : self._text_seq_len, :]
hidden_attn_output = attn_out[:, self._text_seq_len :, :]
hidden_attn_output = attn.to_out[0](hidden_attn_output)
context_attn_output = attn.to_add_out(context_attn_output)
assert self._hidden_mod is not None
assert self._context_mod is not None
hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=hidden_states,
attn_output=hidden_attn_output,
gate_mlp=self._hidden_mod.gate_mlp,
gate_msa=self._hidden_mod.gate_msa,
scale_mlp=self._hidden_mod.scale_mlp,
shift_mlp=self._hidden_mod.shift_mlp,
norm_layer=self.block.norm2,
ff_layer=self.block.ff,
)
encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=encoder_hidden_states,
attn_output=context_attn_output,
gate_mlp=self._context_mod.gate_mlp,
gate_msa=self._context_mod.gate_msa,
scale_mlp=self._context_mod.scale_mlp,
shift_mlp=self._context_mod.shift_mlp,
norm_layer=self.block.norm2_context,
ff_layer=self.block.ff_context,
)
return encoder_hidden_states, hidden_states
class FluxSingleBlockWrapper(SingleBlockWrapper[SingleTransformerBlock]):
"""Flux-specific single block wrapper with pipefusion support."""
def __init__(self, block: SingleTransformerBlock, text_seq_len: int):
super().__init__(block, text_seq_len)
self._num_heads = block.attn.num_heads
self._head_dim = block.attn.head_dimension
# Intermediate state stored between _compute_qkv and _apply_output
self._norm_state: FluxNormGateState | None = None
def _compute_qkv(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: RotaryEmbeddings,
patch_mode: bool = False,
) -> tuple[mx.array, mx.array, mx.array]:
assert isinstance(rotary_embeddings, mx.array)
attn = self.block.attn
norm_hidden, gate = self.block.norm(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
)
self._norm_state = FluxNormGateState(norm_hidden=norm_hidden, gate=gate)
query, key, value = AttentionUtils.process_qkv(
hidden_states=norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
if patch_mode:
text_rope = rotary_embeddings[:, :, : self._text_seq_len, ...]
patch_img_rope = rotary_embeddings[
:,
:,
self._text_seq_len + self._patch_start : self._text_seq_len
+ self._patch_end,
...,
]
rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
else:
rope = rotary_embeddings
query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=rope)
return query, key, value
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array:
batch_size = query.shape[0]
return AttentionUtils.compute_attention(
query=query,
key=key,
value=value,
batch_size=batch_size,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
residual = hidden_states
assert self._norm_state is not None
output = self.block._apply_feed_forward_and_projection( # pyright: ignore[reportPrivateUsage]
norm_hidden_states=self._norm_state.norm_hidden,
attn_output=attn_out,
gate=self._norm_state.gate,
)
return residual + output

View File

@@ -0,0 +1,13 @@
from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter
from exo.worker.engines.image.models.qwen.config import (
QWEN_IMAGE_CONFIG,
QWEN_IMAGE_EDIT_CONFIG,
)
from exo.worker.engines.image.models.qwen.edit_adapter import QwenEditModelAdapter
__all__ = [
"QwenModelAdapter",
"QwenEditModelAdapter",
"QWEN_IMAGE_CONFIG",
"QWEN_IMAGE_EDIT_CONFIG",
]

View File

@@ -0,0 +1,292 @@
from pathlib import Path
from typing import Any
import mlx.core as mx
from mflux.models.common.config import ModelConfig
from mflux.models.common.config.config import Config
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_text_encoder.qwen_prompt_encoder import (
QwenPromptEncoder,
)
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
ModelAdapter,
PromptData,
RotaryEmbeddings,
)
from exo.worker.engines.image.models.qwen.wrappers import QwenJointBlockWrapper
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
class QwenPromptData(PromptData):
def __init__(
self,
prompt_embeds: mx.array,
prompt_mask: mx.array,
negative_prompt_embeds: mx.array,
negative_prompt_mask: mx.array,
):
self._prompt_embeds = prompt_embeds
self._prompt_mask = prompt_mask
self._negative_prompt_embeds = negative_prompt_embeds
self._negative_prompt_mask = negative_prompt_mask
@property
def prompt_embeds(self) -> mx.array:
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
return self._prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array:
return self._negative_prompt_embeds
@property
def negative_pooled_prompt_embeds(self) -> mx.array:
return self._negative_prompt_embeds
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array:
if positive:
return self._prompt_mask
else:
return self._negative_prompt_mask
@property
def cond_image_grid(
self,
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
return None
@property
def conditioning_latents(self) -> mx.array | None:
return None
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
"""Batch positive and negative embeddings for CFG with batch_size=2.
Pads shorter sequence to max length using zeros for embeddings
and zeros (masked) for attention mask.
Returns:
Tuple of (batched_embeds, batched_mask, None, conditioning_latents)
- batched_embeds: [2, max_seq, hidden]
- batched_mask: [2, max_seq]
- None for pooled (Qwen doesn't use it)
- conditioning_latents: [2, latent_seq, latent_dim] or None
"""
pos_embeds = self._prompt_embeds
neg_embeds = self._negative_prompt_embeds
pos_mask = self._prompt_mask
neg_mask = self._negative_prompt_mask
pos_seq_len = pos_embeds.shape[1]
neg_seq_len = neg_embeds.shape[1]
max_seq_len = max(pos_seq_len, neg_seq_len)
hidden_dim = pos_embeds.shape[2]
if pos_seq_len < max_seq_len:
pad_len = max_seq_len - pos_seq_len
pos_embeds = mx.concatenate(
[
pos_embeds,
mx.zeros((1, pad_len, hidden_dim), dtype=pos_embeds.dtype),
],
axis=1,
)
pos_mask = mx.concatenate(
[pos_mask, mx.zeros((1, pad_len), dtype=pos_mask.dtype)],
axis=1,
)
elif neg_seq_len < max_seq_len:
pad_len = max_seq_len - neg_seq_len
neg_embeds = mx.concatenate(
[
neg_embeds,
mx.zeros((1, pad_len, hidden_dim), dtype=neg_embeds.dtype),
],
axis=1,
)
neg_mask = mx.concatenate(
[neg_mask, mx.zeros((1, pad_len), dtype=neg_mask.dtype)],
axis=1,
)
batched_embeds = mx.concatenate([pos_embeds, neg_embeds], axis=0)
batched_mask = mx.concatenate([pos_mask, neg_mask], axis=0)
# TODO(ciaran): currently None but maybe we will deduplicate with edit
# adapter
cond_latents = self.conditioning_latents
if cond_latents is not None:
cond_latents = mx.concatenate([cond_latents, cond_latents], axis=0)
return batched_embeds, batched_mask, None, cond_latents
class QwenModelAdapter(ModelAdapter[QwenImage, QwenTransformer]):
"""Adapter for Qwen-Image model.
Key differences from Flux:
- Single text encoder (vs dual T5+CLIP)
- 60 joint-style blocks, no single blocks
- 3D RoPE returning ((img_cos, img_sin), (txt_cos, txt_sin))
- Norm-preserving CFG with negative prompts
- Uses attention mask for variable-length text
"""
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = QwenImage(
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
model_path=str(local_path),
quantize=quantize,
)
self._transformer = self._model.transformer
@property
def hidden_dim(self) -> int:
return self._transformer.inner_dim
@property
def needs_cfg(self) -> bool:
gs = self._config.guidance_scale
return gs is not None and gs > 1.0
def _get_latent_creator(self) -> type:
return QwenLatentCreator
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list[JointBlockWrapper[Any]]:
"""Create wrapped joint blocks for Qwen."""
return [
QwenJointBlockWrapper(block, text_seq_len, encoder_hidden_states_mask)
for block in self._transformer.transformer_blocks
]
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list[SingleBlockWrapper[Any]]:
return []
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
self._transformer.transformer_blocks = self._transformer.transformer_blocks[
start_layer:end_layer
]
def encode_prompt(
self, prompt: str, negative_prompt: str | None = None
) -> QwenPromptData:
assert isinstance(self.model.prompt_cache, dict)
assert isinstance(self.model.tokenizers, dict)
if negative_prompt is None or negative_prompt == "":
negative_prompt = " "
prompt_embeds, prompt_mask, neg_embeds, neg_mask = (
QwenPromptEncoder.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
prompt_cache=self.model.prompt_cache,
qwen_tokenizer=self.model.tokenizers["qwen"],
qwen_text_encoder=self.model.text_encoder,
)
)
return QwenPromptData(
prompt_embeds=prompt_embeds,
prompt_mask=prompt_mask,
negative_prompt_embeds=neg_embeds,
negative_prompt_mask=neg_mask,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
embedded_hidden = self._transformer.img_in(hidden_states)
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
# Use hidden_states if provided, otherwise fall back to pooled_prompt_embeds
# (which for Qwen is the same as prompt_embeds)
ref_tensor = (
hidden_states if hidden_states is not None else pooled_prompt_embeds
)
if ref_tensor is None:
raise ValueError(
"Either hidden_states or pooled_prompt_embeds is required "
"for Qwen text embeddings"
)
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
batch_size = ref_tensor.shape[0]
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
return self._transformer.time_text_embed(timestep, ref_tensor)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> RotaryEmbeddings:
if encoder_hidden_states_mask is None:
raise ValueError(
"encoder_hidden_states_mask is required for Qwen RoPE computation"
)
return QwenTransformer._compute_rotary_embeddings( # pyright: ignore[reportPrivateUsage]
encoder_hidden_states_mask=encoder_hidden_states_mask,
pos_embed=self._transformer.pos_embed,
config=runtime_config,
cond_image_grid=cond_image_grid,
)
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
return self._model.compute_guided_noise(
noise=noise_positive,
noise_negative=noise_negative,
guidance=guidance_scale,
)
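get_batched_cfg_data pads the shorter of the positive/negative sequences with zeros (and masks the padding) so both prompts run in a single batch-of-two forward pass. A hedged shape walk-through (sequence and hidden sizes invented):

    import mlx.core as mx

    pos = mx.zeros((1, 40, 3584))  # positive prompt embeddings (lengths invented)
    neg = mx.zeros((1, 12, 3584))  # negative prompt embeddings
    neg = mx.concatenate([neg, mx.zeros((1, 28, 3584), dtype=neg.dtype)], axis=1)
    batched_embeds = mx.concatenate([pos, neg], axis=0)  # (2, 40, 3584): positive then negative
    pos_mask = mx.ones((1, 40))
    neg_mask = mx.concatenate([mx.ones((1, 12)), mx.zeros((1, 28))], axis=1)
    batched_mask = mx.concatenate([pos_mask, neg_mask], axis=0)  # (2, 40): padding is masked out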

View File

@@ -0,0 +1,29 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
QWEN_IMAGE_CONFIG = ImageModelConfig(
model_family="qwen",
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=60, has_separate_text_output=True
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
guidance_scale=3.5, # Set to None or <= 1.0 to disable CFG
)
QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
model_family="qwen-edit",
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=60, has_separate_text_output=True
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125,
guidance_scale=3.5,
)

View File

@@ -0,0 +1,434 @@
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.variants.edit.qwen_edit_util import QwenEditUtil
from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
ModelAdapter,
PromptData,
RotaryEmbeddings,
)
from exo.worker.engines.image.models.qwen.wrappers import QwenJointBlockWrapper
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
@dataclass(frozen=True)
class EditImageDimensions:
vl_width: int
vl_height: int
vae_width: int
vae_height: int
image_paths: list[str]
class QwenEditPromptData(PromptData):
def __init__(
self,
prompt_embeds: mx.array,
prompt_mask: mx.array,
negative_prompt_embeds: mx.array,
negative_prompt_mask: mx.array,
conditioning_latents: mx.array,
qwen_image_ids: mx.array,
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]],
):
self._prompt_embeds = prompt_embeds
self._prompt_mask = prompt_mask
self._negative_prompt_embeds = negative_prompt_embeds
self._negative_prompt_mask = negative_prompt_mask
self._conditioning_latents = conditioning_latents
self._qwen_image_ids = qwen_image_ids
self._cond_image_grid = cond_image_grid
@property
def prompt_embeds(self) -> mx.array:
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
return self._prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array:
return self._negative_prompt_embeds
@property
def negative_pooled_prompt_embeds(self) -> mx.array:
return self._negative_prompt_embeds
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array:
if positive:
return self._prompt_mask
else:
return self._negative_prompt_mask
@property
def cond_image_grid(self) -> tuple[int, int, int] | list[tuple[int, int, int]]:
return self._cond_image_grid
@property
def conditioning_latents(self) -> mx.array:
return self._conditioning_latents
@property
def qwen_image_ids(self) -> mx.array:
return self._qwen_image_ids
@property
def is_edit_mode(self) -> bool:
return True
def get_batched_cfg_data(
self,
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
"""Batch positive and negative embeddings for CFG with batch_size=2.
Pads shorter sequence to max length using zeros for embeddings
and zeros (masked) for attention mask. Duplicates conditioning
latents for both positive and negative passes.
Returns:
Tuple of (batched_embeds, batched_mask, None, batched_cond_latents)
- batched_embeds: [2, max_seq, hidden]
- batched_mask: [2, max_seq]
- None for pooled (Qwen doesn't use it)
- batched_cond_latents: [2, latent_seq, latent_dim]
TODO(ciaran): type this
"""
pos_embeds = self._prompt_embeds
neg_embeds = self._negative_prompt_embeds
pos_mask = self._prompt_mask
neg_mask = self._negative_prompt_mask
pos_seq_len = pos_embeds.shape[1]
neg_seq_len = neg_embeds.shape[1]
max_seq_len = max(pos_seq_len, neg_seq_len)
hidden_dim = pos_embeds.shape[2]
if pos_seq_len < max_seq_len:
pad_len = max_seq_len - pos_seq_len
pos_embeds = mx.concatenate(
[
pos_embeds,
mx.zeros((1, pad_len, hidden_dim), dtype=pos_embeds.dtype),
],
axis=1,
)
pos_mask = mx.concatenate(
[pos_mask, mx.zeros((1, pad_len), dtype=pos_mask.dtype)],
axis=1,
)
if neg_seq_len < max_seq_len:
pad_len = max_seq_len - neg_seq_len
neg_embeds = mx.concatenate(
[
neg_embeds,
mx.zeros((1, pad_len, hidden_dim), dtype=neg_embeds.dtype),
],
axis=1,
)
neg_mask = mx.concatenate(
[neg_mask, mx.zeros((1, pad_len), dtype=neg_mask.dtype)],
axis=1,
)
batched_embeds = mx.concatenate([pos_embeds, neg_embeds], axis=0)
batched_mask = mx.concatenate([pos_mask, neg_mask], axis=0)
batched_cond_latents = mx.concatenate(
[self._conditioning_latents, self._conditioning_latents], axis=0
)
return batched_embeds, batched_mask, None, batched_cond_latents
class QwenEditModelAdapter(ModelAdapter[QwenImageEdit, QwenTransformer]):
"""Adapter for Qwen-Image-Edit model.
Key differences from standard QwenModelAdapter:
- Uses QwenImageEdit model with vision-language components
- Encodes prompts WITH input images via VL tokenizer/encoder
- Creates conditioning latents from input images
- Supports image editing with concatenated latents during diffusion
"""
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = QwenImageEdit(
quantize=quantize,
model_path=str(local_path),
)
self._transformer = self._model.transformer
self._edit_dimensions: EditImageDimensions | None = None
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def model(self) -> QwenImageEdit:
return self._model
@property
def transformer(self) -> QwenTransformer:
return self._transformer
@property
def hidden_dim(self) -> int:
return self._transformer.inner_dim
@property
def needs_cfg(self) -> bool:
gs = self._config.guidance_scale
return gs is not None and gs > 1.0
def _get_latent_creator(self) -> type[QwenLatentCreator]:
return QwenLatentCreator
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list[JointBlockWrapper[Any]]:
"""Create wrapped joint blocks for Qwen Edit."""
return [
QwenJointBlockWrapper(block, text_seq_len, encoder_hidden_states_mask)
for block in self._transformer.transformer_blocks
]
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list[SingleBlockWrapper[Any]]:
"""Qwen has no single blocks."""
return []
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
self._transformer.transformer_blocks = self._transformer.transformer_blocks[
start_layer:end_layer
]
def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
"""Compute and store dimensions from input image.
Also stores image_paths for use in encode_prompt().
Returns:
(output_width, output_height) for runtime config
"""
vl_w, vl_h, vae_w, vae_h, out_w, out_h = self._compute_dimensions_from_image(
image_path
)
self._edit_dimensions = EditImageDimensions(
vl_width=vl_w,
vl_height=vl_h,
vae_width=vae_w,
vae_height=vae_h,
image_paths=[str(image_path)],
)
return out_w, out_h
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
"""Create initial noise latents (pure noise for edit mode)."""
return QwenLatentCreator.create_noise(
seed=seed,
height=runtime_config.height,
width=runtime_config.width,
)
def encode_prompt(
self, prompt: str, negative_prompt: str | None = None
) -> QwenEditPromptData:
dims = self._edit_dimensions
if dims is None:
raise RuntimeError(
"set_image_dimensions() must be called before encode_prompt() "
"for QwenEditModelAdapter"
)
if negative_prompt is None or negative_prompt == "":
negative_prompt = " "
# TODO(ciaran): config is untyped and unused, unsure if Config or RuntimeConfig is intended
(
prompt_embeds,
prompt_mask,
negative_prompt_embeds,
negative_prompt_mask,
) = self._model._encode_prompts_with_images( # pyright: ignore[reportPrivateUsage]
prompt,
negative_prompt,
dims.image_paths,
self._config,
dims.vl_width,
dims.vl_height,
)
(
conditioning_latents,
qwen_image_ids,
cond_h_patches,
cond_w_patches,
num_images,
) = QwenEditUtil.create_image_conditioning_latents(
vae=self._model.vae,
height=dims.vae_height,
width=dims.vae_width,
image_paths=dims.image_paths,
vl_width=dims.vl_width,
vl_height=dims.vl_height,
)
if num_images > 1:
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] = [
(1, cond_h_patches, cond_w_patches) for _ in range(num_images)
]
else:
cond_image_grid = (1, cond_h_patches, cond_w_patches)
return QwenEditPromptData(
prompt_embeds=prompt_embeds,
prompt_mask=prompt_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_mask=negative_prompt_mask,
conditioning_latents=conditioning_latents,
qwen_image_ids=qwen_image_ids,
cond_image_grid=cond_image_grid,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
embedded_hidden = self._transformer.img_in(hidden_states)
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
ref_tensor = (
hidden_states if hidden_states is not None else pooled_prompt_embeds
)
if ref_tensor is None:
raise ValueError(
"Either hidden_states or pooled_prompt_embeds is required "
"for Qwen text embeddings"
)
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
batch_size = ref_tensor.shape[0]
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
return self._transformer.time_text_embed(timestep, ref_tensor)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> RotaryEmbeddings:
if encoder_hidden_states_mask is None:
raise ValueError(
"encoder_hidden_states_mask is required for Qwen RoPE computation"
)
return QwenTransformer._compute_rotary_embeddings( # pyright: ignore[reportPrivateUsage]
encoder_hidden_states_mask=encoder_hidden_states_mask,
pos_embed=self._transformer.pos_embed,
config=runtime_config,
cond_image_grid=cond_image_grid,
)
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
return QwenImage.compute_guided_noise(
noise=noise_positive,
noise_negative=noise_negative,
guidance=guidance_scale,
)
def _compute_dimensions_from_image(
self, image_path: Path
) -> tuple[int, int, int, int, int, int]:
from mflux.utils.image_util import ImageUtil
pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
image_size = pil_image.size
# Vision-language dimensions (384x384 target area)
condition_image_size = 384 * 384
condition_ratio = image_size[0] / image_size[1]
vl_width = math.sqrt(condition_image_size * condition_ratio)
vl_height = vl_width / condition_ratio
vl_width = round(vl_width / 32) * 32
vl_height = round(vl_height / 32) * 32
# VAE dimensions (1024x1024 target area)
vae_image_size = 1024 * 1024
vae_ratio = image_size[0] / image_size[1]
vae_width = math.sqrt(vae_image_size * vae_ratio)
vae_height = vae_width / vae_ratio
vae_width = round(vae_width / 32) * 32
vae_height = round(vae_height / 32) * 32
# Output dimensions from input image aspect ratio
target_area = 1024 * 1024
ratio = image_size[0] / image_size[1]
output_width = math.sqrt(target_area * ratio)
output_height = output_width / ratio
output_width = round(output_width / 32) * 32
output_height = round(output_height / 32) * 32
# Ensure multiple of 16 for VAE
vae_scale_factor = 8
multiple_of = vae_scale_factor * 2
output_width = output_width // multiple_of * multiple_of
output_height = output_height // multiple_of * multiple_of
return (
int(vl_width),
int(vl_height),
int(vae_width),
int(vae_height),
int(output_width),
int(output_height),
)
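# Illustrative sketch (not part of this adapter): the dimension arithmetic
# above applied to a hypothetical 1600x900 input; all numbers are made up for
# demonstration. VL dims target a 384x384 area and output dims a 1024x1024
# area, both rounded to multiples of 32, then floored to multiples of 16.
import math
w, h = 1600, 900
ratio = w / h
vl_w = round(math.sqrt(384 * 384 * ratio) / 32) * 32             # 512
vl_h = round(math.sqrt(384 * 384 * ratio) / ratio / 32) * 32      # 288
out_w = round(math.sqrt(1024 * 1024 * ratio) / 32) * 32           # 1376
out_h = round(math.sqrt(1024 * 1024 * ratio) / ratio / 32) * 32   # 768
out_w, out_h = out_w // 16 * 16, out_h // 16 * 16                 # already multiples of 16
assert (vl_w, vl_h, out_w, out_h) == (512, 288, 1376, 768)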

View File

@@ -0,0 +1,204 @@
from typing import final
import mlx.core as mx
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
QwenTransformerBlock,
)
from pydantic import BaseModel, ConfigDict
from exo.worker.engines.image.models.base import RotaryEmbeddings
from exo.worker.engines.image.pipeline.block_wrapper import JointBlockWrapper
@final
class QwenStreamModulation(BaseModel):
model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True)
mod1: mx.array
mod2: mx.array
gate1: mx.array
class QwenJointBlockWrapper(JointBlockWrapper[QwenTransformerBlock]):
def __init__(
self,
block: QwenTransformerBlock,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
):
super().__init__(block, text_seq_len)
self._encoder_hidden_states_mask = encoder_hidden_states_mask
self._num_heads = block.attn.num_heads
self._head_dim = block.attn.head_dim
# Intermediate state stored between _compute_qkv and _apply_output
self._img_mod: QwenStreamModulation | None = None
self._txt_mod: QwenStreamModulation | None = None
def set_encoder_mask(self, mask: mx.array | None) -> None:
"""Set the encoder hidden states mask for attention."""
self._encoder_hidden_states_mask = mask
def _compute_qkv(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: RotaryEmbeddings,
patch_mode: bool = False,
) -> tuple[mx.array, mx.array, mx.array]:
assert isinstance(rotary_embeddings, tuple)
batch_size = hidden_states.shape[0]
img_seq_len = hidden_states.shape[1]
attn = self.block.attn
img_mod_params = self.block.img_mod_linear(
self.block.img_mod_silu(text_embeddings)
)
txt_mod_params = self.block.txt_mod_linear(
self.block.txt_mod_silu(text_embeddings)
)
img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
img_normed = self.block.img_norm1(hidden_states)
img_modulated, img_gate1 = QwenTransformerBlock._modulate( # pyright: ignore[reportPrivateUsage]
img_normed, img_mod1
)
self._img_mod = QwenStreamModulation(
mod1=img_mod1, mod2=img_mod2, gate1=img_gate1
)
txt_normed = self.block.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = QwenTransformerBlock._modulate( # pyright: ignore[reportPrivateUsage]
txt_normed, txt_mod1
)
self._txt_mod = QwenStreamModulation(
mod1=txt_mod1, mod2=txt_mod2, gate1=txt_gate1
)
img_query = attn.to_q(img_modulated)
img_key = attn.to_k(img_modulated)
img_value = attn.to_v(img_modulated)
txt_query = attn.add_q_proj(txt_modulated)
txt_key = attn.add_k_proj(txt_modulated)
txt_value = attn.add_v_proj(txt_modulated)
img_query = mx.reshape(
img_query, (batch_size, img_seq_len, self._num_heads, self._head_dim)
)
img_key = mx.reshape(
img_key, (batch_size, img_seq_len, self._num_heads, self._head_dim)
)
img_value = mx.reshape(
img_value, (batch_size, img_seq_len, self._num_heads, self._head_dim)
)
txt_query = mx.reshape(
txt_query,
(batch_size, self._text_seq_len, self._num_heads, self._head_dim),
)
txt_key = mx.reshape(
txt_key, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)
)
txt_value = mx.reshape(
txt_value, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)
)
img_query = attn.norm_q(img_query)
img_key = attn.norm_k(img_key)
txt_query = attn.norm_added_q(txt_query)
txt_key = attn.norm_added_k(txt_key)
(img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
if patch_mode:
# Slice image RoPE for patch, keep full text RoPE
img_cos = img_cos[self._patch_start : self._patch_end]
img_sin = img_sin[self._patch_start : self._patch_end]
img_query = QwenAttention._apply_rope_qwen(img_query, img_cos, img_sin) # pyright: ignore[reportPrivateUsage]
img_key = QwenAttention._apply_rope_qwen(img_key, img_cos, img_sin) # pyright: ignore[reportPrivateUsage]
txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin) # pyright: ignore[reportPrivateUsage]
txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin) # pyright: ignore[reportPrivateUsage]
img_query = mx.transpose(img_query, (0, 2, 1, 3))
img_key = mx.transpose(img_key, (0, 2, 1, 3))
img_value = mx.transpose(img_value, (0, 2, 1, 3))
txt_query = mx.transpose(txt_query, (0, 2, 1, 3))
txt_key = mx.transpose(txt_key, (0, 2, 1, 3))
txt_value = mx.transpose(txt_value, (0, 2, 1, 3))
query = mx.concatenate([txt_query, img_query], axis=2)
key = mx.concatenate([txt_key, img_key], axis=2)
value = mx.concatenate([txt_value, img_value], axis=2)
return query, key, value
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array:
attn = self.block.attn
mask = QwenAttention._convert_mask_for_qwen( # pyright: ignore[reportPrivateUsage]
mask=self._encoder_hidden_states_mask,
joint_seq_len=key.shape[2],
txt_seq_len=self._text_seq_len,
)
query_bshd = mx.transpose(query, (0, 2, 1, 3))
key_bshd = mx.transpose(key, (0, 2, 1, 3))
value_bshd = mx.transpose(value, (0, 2, 1, 3))
return attn._compute_attention_qwen( # pyright: ignore[reportPrivateUsage]
query=query_bshd,
key=key_bshd,
value=value_bshd,
mask=mask,
block_idx=None,
)
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
) -> tuple[mx.array, mx.array]:
attn = self.block.attn
assert self._img_mod is not None
assert self._txt_mod is not None
txt_attn_output = attn_out[:, : self._text_seq_len, :]
img_attn_output = attn_out[:, self._text_seq_len :, :]
img_attn_output = attn.attn_to_out[0](img_attn_output)
txt_attn_output = attn.to_add_out(txt_attn_output)
hidden_states = hidden_states + self._img_mod.gate1 * img_attn_output
encoder_hidden_states = (
encoder_hidden_states + self._txt_mod.gate1 * txt_attn_output
)
img_normed2 = self.block.img_norm2(hidden_states)
img_modulated2, img_gate2 = QwenTransformerBlock._modulate( # pyright: ignore[reportPrivateUsage]
img_normed2, self._img_mod.mod2
)
img_mlp_output = self.block.img_ff(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output
txt_normed2 = self.block.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate( # pyright: ignore[reportPrivateUsage]
txt_normed2, self._txt_mod.mod2
)
txt_mlp_output = self.block.txt_ff(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
return encoder_hidden_states, hidden_states
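# Illustrative sketch (not part of this module): the joint attention sequence
# is laid out as [text tokens, image tokens], so the attention output is split
# back at text_seq_len exactly as _apply_output does above. Shapes are made up.
import mlx.core as mx
text_seq_len, img_seq_len, heads, dim = 4, 6, 2, 8
txt_q = mx.random.normal((1, heads, text_seq_len, dim))
img_q = mx.random.normal((1, heads, img_seq_len, dim))
joint_q = mx.concatenate([txt_q, img_q], axis=2)                 # [1, heads, 10, dim]
attn_out = mx.random.normal((1, text_seq_len + img_seq_len, heads * dim))
txt_out = attn_out[:, :text_seq_len, :]                          # text stream update
img_out = attn_out[:, text_seq_len:, :]                          # image stream update
assert joint_q.shape == (1, heads, 10, dim) and img_out.shape == (1, 6, heads * dim)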

View File

@@ -0,0 +1,15 @@
from exo.worker.engines.image.pipeline.block_wrapper import (
BlockWrapperMode,
JointBlockWrapper,
SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
from exo.worker.engines.image.pipeline.runner import DiffusionRunner
__all__ = [
"BlockWrapperMode",
"DiffusionRunner",
"ImagePatchKVCache",
"JointBlockWrapper",
"SingleBlockWrapper",
]

View File

@@ -0,0 +1,303 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Generic, Self, TypeVar
import mlx.core as mx
from exo.worker.engines.image.models.base import RotaryEmbeddings
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
BlockT = TypeVar("BlockT")
class BlockWrapperMode(Enum):
CACHING = "caching" # Sync mode: compute full attention, populate cache
PATCHED = "patched" # Async mode: compute patch attention, use cached KV
class BlockWrapperMixin:
"""Common cache management logic for block wrappers.
This covers:
- KV cache creation and management
- CACHING/PATCHED mode tracking
- Patch range bookkeeping
"""
_text_seq_len: int
_kv_cache: ImagePatchKVCache | None
_mode: BlockWrapperMode
_patch_start: int
_patch_end: int
def _init_cache_state(self, text_seq_len: int) -> None:
self._text_seq_len = text_seq_len
self._kv_cache = None
self._mode = BlockWrapperMode.CACHING
self._patch_start = 0
self._patch_end = 0
def set_patch(
self,
mode: BlockWrapperMode,
patch_start: int = 0,
patch_end: int = 0,
) -> Self:
"""Set mode and patch range.
Args:
mode: CACHING (full attention) or PATCHED (use cached KV)
patch_start: Start token index within image (for PATCHED mode)
patch_end: End token index within image (for PATCHED mode)
Returns:
Self for method chaining
"""
self._mode = mode
self._patch_start = patch_start
self._patch_end = patch_end
return self
def set_text_seq_len(self, text_seq_len: int) -> None:
self._text_seq_len = text_seq_len
def _get_active_cache(self) -> ImagePatchKVCache | None:
return self._kv_cache
def _ensure_cache(self, img_key: mx.array) -> None:
if self._kv_cache is None:
batch, num_heads, img_seq_len, head_dim = img_key.shape
self._kv_cache = ImagePatchKVCache(
batch_size=batch,
num_heads=num_heads,
image_seq_len=img_seq_len,
head_dim=head_dim,
)
def _cache_full_image_kv(self, img_key: mx.array, img_value: mx.array) -> None:
self._ensure_cache(img_key)
cache = self._get_active_cache()
assert cache is not None
cache.update_image_patch(0, img_key.shape[2], img_key, img_value)
def _cache_patch_kv(self, img_key: mx.array, img_value: mx.array) -> None:
cache = self._get_active_cache()
assert cache is not None
cache.update_image_patch(self._patch_start, self._patch_end, img_key, img_value)
def _get_full_kv(
self, text_key: mx.array, text_value: mx.array
) -> tuple[mx.array, mx.array]:
cache = self._get_active_cache()
assert cache is not None
return cache.get_full_kv(text_key, text_value)
def reset_cache(self) -> None:
self._kv_cache = None
class JointBlockWrapper(BlockWrapperMixin, ABC, Generic[BlockT]):
"""Base class for joint transformer block wrappers with pipefusion support.
The wrapper:
- Owns its KV cache (created lazily on first CACHING forward)
- Controls the forward pass flow (CACHING vs PATCHED mode)
- Handles patch slicing and cache operations
"""
block: BlockT
def __init__(self, block: BlockT, text_seq_len: int):
self.block = block
self._init_cache_state(text_seq_len)
def set_encoder_mask(self, mask: mx.array | None) -> None: # noqa: B027
"""Set the encoder hidden states mask for attention.
Override in subclasses that use attention masks; the default is a no-op
for models that don't.
"""
del mask # Unused in base class
def __call__(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: RotaryEmbeddings,
) -> tuple[mx.array, mx.array]:
if self._mode == BlockWrapperMode.CACHING:
return self._forward_caching(
hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
)
return self._forward_patched(
hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
)
def _forward_caching(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: RotaryEmbeddings,
) -> tuple[mx.array, mx.array]:
"""CACHING mode: Full attention, store image K/V in cache."""
query, key, value = self._compute_qkv(
hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
)
img_key = key[:, :, self._text_seq_len :, :]
img_value = value[:, :, self._text_seq_len :, :]
self._cache_full_image_kv(img_key, img_value)
attn_out = self._compute_attention(query, key, value)
return self._apply_output(
attn_out, hidden_states, encoder_hidden_states, text_embeddings
)
def _forward_patched(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: RotaryEmbeddings,
) -> tuple[mx.array, mx.array]:
# hidden_states is already the patch (provided by runner)
patch_hidden = hidden_states
query, key, value = self._compute_qkv(
patch_hidden,
encoder_hidden_states,
text_embeddings,
rotary_embeddings,
patch_mode=True,
)
text_key = key[:, :, : self._text_seq_len, :]
text_value = value[:, :, : self._text_seq_len, :]
img_key = key[:, :, self._text_seq_len :, :]
img_value = value[:, :, self._text_seq_len :, :]
self._cache_patch_kv(img_key, img_value)
full_key, full_value = self._get_full_kv(text_key, text_value)
attn_out = self._compute_attention(query, full_key, full_value)
return self._apply_output(
attn_out, patch_hidden, encoder_hidden_states, text_embeddings
)
@abstractmethod
def _compute_qkv(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: RotaryEmbeddings,
patch_mode: bool = False,
) -> tuple[mx.array, mx.array, mx.array]: ...
@abstractmethod
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array: ...
@abstractmethod
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
) -> tuple[mx.array, mx.array]: ...
class SingleBlockWrapper(BlockWrapperMixin, ABC, Generic[BlockT]):
"""Base class for single-stream transformer block wrappers.
Similar to JointBlockWrapper but for blocks that operate on a single
concatenated [text, image] stream rather than separate streams.
"""
block: BlockT
def __init__(self, block: BlockT, text_seq_len: int):
self.block = block
self._init_cache_state(text_seq_len)
def __call__(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: RotaryEmbeddings,
) -> mx.array:
if self._mode == BlockWrapperMode.CACHING:
return self._forward_caching(
hidden_states, text_embeddings, rotary_embeddings
)
return self._forward_patched(hidden_states, text_embeddings, rotary_embeddings)
def _forward_caching(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: RotaryEmbeddings,
) -> mx.array:
"""CACHING mode: Full attention, store image K/V in cache."""
query, key, value = self._compute_qkv(
hidden_states, text_embeddings, rotary_embeddings
)
img_key = key[:, :, self._text_seq_len :, :]
img_value = value[:, :, self._text_seq_len :, :]
self._cache_full_image_kv(img_key, img_value)
attn_out = self._compute_attention(query, key, value)
return self._apply_output(attn_out, hidden_states, text_embeddings)
def _forward_patched(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: RotaryEmbeddings,
) -> mx.array:
"""PATCHED mode: Compute patch Q/K/V, use cached image K/V for attention."""
query, key, value = self._compute_qkv(
hidden_states, text_embeddings, rotary_embeddings, patch_mode=True
)
text_key = key[:, :, : self._text_seq_len, :]
text_value = value[:, :, : self._text_seq_len, :]
img_key = key[:, :, self._text_seq_len :, :]
img_value = value[:, :, self._text_seq_len :, :]
self._cache_patch_kv(img_key, img_value)
full_key, full_value = self._get_full_kv(text_key, text_value)
attn_out = self._compute_attention(query, full_key, full_value)
return self._apply_output(attn_out, hidden_states, text_embeddings)
@abstractmethod
def _compute_qkv(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: RotaryEmbeddings,
patch_mode: bool = False,
) -> tuple[mx.array, mx.array, mx.array]: ...
@abstractmethod
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array: ...
@abstractmethod
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array: ...

View File

@@ -0,0 +1,72 @@
import mlx.core as mx
class ImagePatchKVCache:
"""KV cache that stores only IMAGE K/V with patch-level updates.
Only caches image K/V since:
- Text K/V is always computed fresh (same for all patches)
- Only image portion needs stale/fresh cache management across patches
"""
def __init__(
self,
batch_size: int,
num_heads: int,
image_seq_len: int,
head_dim: int,
dtype: mx.Dtype = mx.float32,
):
self.batch_size = batch_size
self.num_heads = num_heads
self.image_seq_len = image_seq_len
self.head_dim = head_dim
self._dtype = dtype
self.key_cache = mx.zeros(
(batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
)
self.value_cache = mx.zeros(
(batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
)
def update_image_patch(
self, patch_start: int, patch_end: int, key: mx.array, value: mx.array
) -> None:
"""Update cache with fresh K/V for an image patch slice.
Args:
patch_start: Start token index within image portion (0-indexed)
patch_end: End token index within image portion
key: Fresh key tensor [batch, heads, patch_seq_len, head_dim]
value: Fresh value tensor [batch, heads, patch_seq_len, head_dim]
"""
self.key_cache[:, :, patch_start:patch_end, :] = key
self.value_cache[:, :, patch_start:patch_end, :] = value
def get_full_kv(
self, text_key: mx.array, text_value: mx.array
) -> tuple[mx.array, mx.array]:
"""Return full K/V by concatenating fresh text K/V with cached image K/V.
Args:
text_key: Fresh text key tensor [batch, heads, text_seq_len, head_dim]
text_value: Fresh text value tensor [batch, heads, text_seq_len, head_dim]
Returns:
Tuple of (full_key, full_value) with shape [batch, heads, text+image, head_dim]
"""
full_key = mx.concatenate([text_key, self.key_cache], axis=2)
full_value = mx.concatenate([text_value, self.value_cache], axis=2)
return full_key, full_value
def reset(self) -> None:
"""Reset cache to zeros."""
self.key_cache = mx.zeros(
(self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
dtype=self._dtype,
)
self.value_cache = mx.zeros(
(self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
dtype=self._dtype,
)
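# Illustrative sketch (not part of this module): typical cache lifecycle during
# PipeFusion. Shapes are made up for demonstration.
import mlx.core as mx
cache = ImagePatchKVCache(batch_size=1, num_heads=2, image_seq_len=8, head_dim=4)
# Sync step: write the full image K/V once.
cache.update_image_patch(0, 8, mx.random.normal((1, 2, 8, 4)), mx.random.normal((1, 2, 8, 4)))
# Async step: refresh only image tokens 0..4; the other half stays (slightly stale).
cache.update_image_patch(0, 4, mx.random.normal((1, 2, 4, 4)), mx.random.normal((1, 2, 4, 4)))
# Attention always sees fresh text K/V plus the cached image K/V.
k, v = cache.get_full_kv(mx.random.normal((1, 2, 3, 4)), mx.random.normal((1, 2, 3, 4)))
assert k.shape == (1, 2, 11, 4)  # 3 text tokens + 8 image tokens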

View File

@@ -0,0 +1,968 @@
from math import ceil
from typing import Any, Optional
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.utils.exceptions import StopImageGenerationException
from tqdm import tqdm
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
ModelAdapter,
PromptData,
RotaryEmbeddings,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
BlockWrapperMode,
JointBlockWrapper,
SingleBlockWrapper,
)
def calculate_patch_heights(latent_height: int, num_patches: int) -> tuple[list[int], int]:
patch_height = ceil(latent_height / num_patches)
actual_num_patches = ceil(latent_height / patch_height)
patch_heights = [patch_height] * (actual_num_patches - 1)
last_height = latent_height - patch_height * (actual_num_patches - 1)
patch_heights.append(last_height)
return patch_heights, actual_num_patches
def calculate_token_indices(patch_heights: list[int], latent_width: int) -> list[tuple[int, int]]:
tokens_per_row = latent_width
token_ranges = []
cumulative_height = 0
for h in patch_heights:
start_token = tokens_per_row * cumulative_height
end_token = tokens_per_row * (cumulative_height + h)
token_ranges.append((start_token, end_token))
cumulative_height += h
return token_ranges
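# Illustrative sketch (not part of this module): splitting a hypothetical
# 30-row latent with 64 tokens per row into 4 patches using the helpers above.
assert calculate_patch_heights(30, 4) == ([8, 8, 8, 6], 4)
assert calculate_token_indices([8, 8, 8, 6], 64) == [
(0, 512), (512, 1024), (1024, 1536), (1536, 1920)]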
class DiffusionRunner:
"""Orchestrates the diffusion loop for image generation.
In distributed mode, it implements PipeFusion with:
- Sync pipeline for initial timesteps (full image, all devices in lockstep)
- Async pipeline for later timesteps (patches processed independently)
"""
def __init__(
self,
config: ImageModelConfig,
adapter: ModelAdapter[Any, Any],
group: Optional[mx.distributed.Group],
shard_metadata: PipelineShardMetadata,
num_patches: Optional[int] = None,
):
self.config = config
self.adapter = adapter
self.group = group
if group is None:
self.rank = 0
self.world_size = 1
self.next_rank = 0
self.prev_rank = 0
self.start_layer = 0
self.end_layer = config.total_blocks
else:
self.rank = shard_metadata.device_rank
self.world_size = shard_metadata.world_size
self.next_rank = (self.rank + 1) % self.world_size
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
self.start_layer = shard_metadata.start_layer
self.end_layer = shard_metadata.end_layer
self.num_patches = num_patches if num_patches else max(1, self.world_size)
self.total_joint = config.joint_block_count
self.total_single = config.single_block_count
self.total_layers = config.total_blocks
self._guidance_override: float | None = None
self._compute_assigned_blocks()
def _compute_assigned_blocks(self) -> None:
"""Determine which joint/single blocks this stage owns."""
start = self.start_layer
end = self.end_layer
if end <= self.total_joint:
self.joint_start = start
self.joint_end = end
self.single_start = 0
self.single_end = 0
elif start >= self.total_joint:
self.joint_start = 0
self.joint_end = 0
self.single_start = start - self.total_joint
self.single_end = end - self.total_joint
else:
self.joint_start = start
self.joint_end = self.total_joint
self.single_start = 0
self.single_end = end - self.total_joint
self.has_joint_blocks = self.joint_end > self.joint_start
self.has_single_blocks = self.single_end > self.single_start
self.owns_concat_stage = self.has_joint_blocks and (
self.has_single_blocks or self.end_layer == self.total_joint
)
# Wrappers created lazily on first forward (need text_seq_len)
self.joint_block_wrappers: list[JointBlockWrapper[Any]] | None = None
self.single_block_wrappers: list[SingleBlockWrapper[Any]] | None = None
self._wrappers_initialized = False
self._current_text_seq_len: int | None = None
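# Example of the assignment above (illustrative numbers): with 19 joint and
# 38 single blocks, a stage owning layers [10, 30) gets joint blocks 10..18
# and single blocks 0..10, so it also owns the concat stage that merges the
# text and image streams between the two block types.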
@property
def is_first_stage(self) -> bool:
return self.rank == 0
@property
def is_last_stage(self) -> bool:
return self.rank == self.world_size - 1
@property
def is_distributed(self) -> bool:
return self.group is not None
def _get_effective_guidance_scale(self) -> float | None:
if self._guidance_override is not None:
return self._guidance_override
return self.config.guidance_scale
def _ensure_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> None:
"""Lazily create block wrappers on first forward pass.
Wrappers need text_seq_len which is only known after prompt encoding.
Re-initializes if text_seq_len changes (e.g., warmup vs real generation).
"""
if self._wrappers_initialized and self._current_text_seq_len == text_seq_len:
return
self.joint_block_wrappers = self.adapter.get_joint_block_wrappers(
text_seq_len=text_seq_len,
encoder_hidden_states_mask=encoder_hidden_states_mask,
)
self.single_block_wrappers = self.adapter.get_single_block_wrappers(
text_seq_len=text_seq_len,
)
self._wrappers_initialized = True
self._current_text_seq_len = text_seq_len
def _reset_all_caches(self) -> None:
"""Reset KV caches on all wrappers for a new generation."""
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.reset_cache()
if self.single_block_wrappers:
for wrapper in self.single_block_wrappers:
wrapper.reset_cache()
def _set_text_seq_len(self, text_seq_len: int) -> None:
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_text_seq_len(text_seq_len)
if self.single_block_wrappers:
for wrapper in self.single_block_wrappers:
wrapper.set_text_seq_len(text_seq_len)
def _calculate_capture_steps(
self,
partial_images: int,
init_time_step: int,
num_inference_steps: int,
) -> set[int]:
"""Calculate which timesteps should produce partial images.
Places the first partial after step 1 for fast initial feedback,
then evenly spaces remaining partials with equal gaps between them
and from the last partial to the final image.
Args:
partial_images: Number of partial images to capture
init_time_step: Starting timestep (for img2img this may not be 0)
num_inference_steps: Total inference steps
Returns:
Set of timestep indices to capture
"""
if partial_images <= 0:
return set()
total_steps = num_inference_steps - init_time_step
if total_steps <= 1:
return set()
if partial_images >= total_steps - 1:
return set(range(init_time_step, num_inference_steps - 1))
capture_steps: set[int] = set()
first_capture = init_time_step + 1
capture_steps.add(first_capture)
if partial_images == 1:
return capture_steps
final_step = num_inference_steps - 1
remaining_range = final_step - first_capture
for i in range(1, partial_images):
step_idx = first_capture + int(i * remaining_range / partial_images)
capture_steps.add(step_idx)
return capture_steps
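# Worked example (illustrative numbers): partial_images=3, init_time_step=0,
# num_inference_steps=20 -> capture after step 1 for quick feedback, then
# steps 7 and 13, i.e. capture_steps == {1, 7, 13}; step 19 produces the final image.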
def generate_image(
self,
runtime_config: Config,
prompt: str,
seed: int,
partial_images: int = 0,
guidance_override: float | None = None,
negative_prompt: str | None = None,
num_sync_steps: int = 1,
):
"""Primary entry point for image generation.
Orchestrates the full generation flow:
1. Create initial latents
2. Encode prompt
3. Run diffusion loop (yielding partials if requested)
4. Decode to image
Args:
runtime_config: Runtime generation config (steps, height, width)
prompt: Text prompt
seed: Random seed
partial_images: Number of intermediate images to yield (0 for none)
guidance_override: Optional override for guidance scale (CFG)
negative_prompt: Optional negative prompt used when CFG is enabled
num_sync_steps: Number of initial timesteps run as a sync (full-image) pipeline
Yields:
Partial images as (GeneratedImage, partial_index, total_partials) tuples
Final GeneratedImage
"""
self._guidance_override = guidance_override
latents = self.adapter.create_latents(seed, runtime_config)
prompt_data = self.adapter.encode_prompt(prompt, negative_prompt)
capture_steps = self._calculate_capture_steps(
partial_images=partial_images,
init_time_step=runtime_config.init_time_step,
num_inference_steps=runtime_config.num_inference_steps,
)
diffusion_gen = self._run_diffusion_loop(
latents=latents,
prompt_data=prompt_data,
runtime_config=runtime_config,
seed=seed,
prompt=prompt,
capture_steps=capture_steps,
num_sync_steps=num_sync_steps,
)
partial_index = 0
total_partials = len(capture_steps)
if capture_steps:
try:
while True:
partial_latents, _step = next(diffusion_gen)
if self.is_last_stage:
partial_image = self.adapter.decode_latents(
partial_latents, runtime_config, seed, prompt
)
yield (partial_image, partial_index, total_partials)
partial_index += 1
except StopIteration as e:
latents = e.value
else:
try:
while True:
next(diffusion_gen)
except StopIteration as e:
latents = e.value
if self.is_last_stage:
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt)
def _run_diffusion_loop(
self,
latents: mx.array,
prompt_data: PromptData,
runtime_config: Config,
seed: int,
prompt: str,
num_sync_steps: int,
capture_steps: set[int] | None = None,
):
if capture_steps is None:
capture_steps = set()
self._reset_all_caches()
time_steps = tqdm(range(runtime_config.num_inference_steps))
ctx = self.adapter.model.callbacks.start(
seed=seed, prompt=prompt, config=runtime_config
)
ctx.before_loop(
latents=latents,
)
for t in time_steps:
try:
latents = self._diffusion_step(
t=t,
config=runtime_config,
latents=latents,
prompt_data=prompt_data,
num_sync_steps=num_sync_steps,
)
ctx.in_loop(
t=t,
latents=latents,
)
mx.eval(latents)
if t in capture_steps and self.is_last_stage:
yield (latents, t)
except KeyboardInterrupt: # noqa: PERF203
ctx.interruption(t=t, latents=latents)
raise StopImageGenerationException(
f"Stopping image generation at step {t + 1}/{len(time_steps)}"
) from None
ctx.after_loop(latents=latents)
return latents
def _forward_pass(
self,
latents: mx.array,
prompt_embeds: mx.array,
pooled_prompt_embeds: mx.array,
t: int,
config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
conditioning_latents: mx.array | None = None,
) -> mx.array:
"""Run a single forward pass through the transformer.
Args:
latents: Input latents (already scaled by caller)
prompt_embeds: Text embeddings
pooled_prompt_embeds: Pooled text embeddings (Flux) or placeholder (Qwen)
t: Current timestep
config: Runtime configuration
encoder_hidden_states_mask: Attention mask for text (Qwen)
cond_image_grid: Conditioning image grid dimensions (Qwen edit)
conditioning_latents: Conditioning latents for edit mode
Returns:
Noise prediction tensor
"""
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_hidden_states_mask)
if self.joint_block_wrappers and encoder_hidden_states_mask is not None:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_hidden_states_mask)
scaled_latents = config.scheduler.scale_model_input(latents, t)
# For edit mode: concatenate with conditioning latents
original_latent_tokens = scaled_latents.shape[1]
if conditioning_latents is not None:
scaled_latents = mx.concatenate(
[scaled_latents, conditioning_latents], axis=1
)
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
scaled_latents, prompt_embeds
)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds, hidden_states=hidden_states
)
rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
encoder_hidden_states_mask=encoder_hidden_states_mask,
cond_image_grid=cond_image_grid,
)
assert self.joint_block_wrappers is not None
for wrapper in self.joint_block_wrappers:
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
)
if self.joint_block_wrappers:
hidden_states = self.adapter.merge_streams(
hidden_states, encoder_hidden_states
)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
)
# Extract image portion
hidden_states = hidden_states[:, text_seq_len:, ...]
# For edit mode: extract only the generated portion (exclude conditioning latents)
if conditioning_latents is not None:
hidden_states = hidden_states[:, :original_latent_tokens, ...]
return self.adapter.final_projection(hidden_states, text_embeddings)
def _diffusion_step(
self,
t: int,
config: Config,
latents: mx.array,
prompt_data: PromptData,
num_sync_steps: int,
) -> mx.array:
if self.group is None:
return self._single_node_step(t, config, latents, prompt_data)
elif t < config.init_time_step + num_sync_steps:
return self._sync_pipeline_step(
t,
config,
latents,
prompt_data,
)
else:
return self._async_pipeline_step(
t,
config,
latents,
prompt_data,
is_first_async_step=t == config.init_time_step + num_sync_steps,
)
def _single_node_step(
self,
t: int,
config: Config,
latents: mx.array,
prompt_data: PromptData,
) -> mx.array:
cond_image_grid = prompt_data.cond_image_grid
needs_cfg = self.adapter.needs_cfg
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
pooled_embeds = (
batched_pooled if batched_pooled is not None else prompt_embeds
)
step_latents = mx.concatenate([latents, latents], axis=0)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
cond_latents = prompt_data.conditioning_latents
step_latents = latents
noise = self._forward_pass(
step_latents,
prompt_embeds,
pooled_embeds,
t=t,
config=config,
encoder_hidden_states_mask=encoder_mask,
cond_image_grid=cond_image_grid,
conditioning_latents=cond_latents,
)
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale=guidance_scale
)
return config.scheduler.step(noise=noise, timestep=t, latents=latents)
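# Note on the CFG branch above: with batch_size=2 the positive and negative
# predictions come back stacked along the batch axis, and apply_guidance blends
# them (classically noise_neg + scale * (noise_pos - noise_neg); the exact
# formula is delegated to the model adapter).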
def _create_patches(
self,
latents: mx.array,
config: Config,
) -> tuple[list[mx.array], list[tuple[int, int]]]:
latent_height = config.height // 16
latent_width = config.width // 16
patch_heights, _ = calculate_patch_heights(latent_height, self.num_patches)
token_indices = calculate_token_indices(patch_heights, latent_width)
patch_latents = [latents[:, start:end, :] for start, end in token_indices]
return patch_latents, token_indices
def _run_sync_pass(
self,
t: int,
config: Config,
scaled_hidden_states: mx.array,
prompt_embeds: mx.array,
pooled_prompt_embeds: mx.array,
encoder_hidden_states_mask: mx.array | None,
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] | None,
kontext_image_ids: mx.array | None,
num_img_tokens: int,
original_latent_tokens: int,
conditioning_latents: mx.array | None,
) -> mx.array | None:
hidden_states = scaled_hidden_states
batch_size = hidden_states.shape[0]
text_seq_len = prompt_embeds.shape[1]
hidden_dim = self.adapter.hidden_dim
dtype = scaled_hidden_states.dtype
self._set_text_seq_len(text_seq_len)
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_hidden_states_mask)
encoder_hidden_states: mx.array | None = None
if self.is_first_stage:
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
hidden_states, prompt_embeds
)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
encoder_hidden_states_mask=encoder_hidden_states_mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
if self.has_joint_blocks:
if not self.is_first_stage:
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if self.owns_concat_stage:
assert encoder_hidden_states is not None
concatenated = self.adapter.merge_streams(
hidden_states, encoder_hidden_states
)
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
concatenated = mx.distributed.send(
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
elif self.has_joint_blocks and not self.is_last_stage:
assert encoder_hidden_states is not None
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if not self.is_last_stage:
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
hidden_states = hidden_states[:, text_seq_len:, ...]
if conditioning_latents is not None:
hidden_states = hidden_states[:, :original_latent_tokens, ...]
if self.is_last_stage:
return self.adapter.final_projection(hidden_states, text_embeddings)
return None
def _sync_pipeline_step(
self,
t: int,
config: Config,
hidden_states: mx.array,
prompt_data: PromptData,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
prev_latents = hidden_states
needs_cfg = self.adapter.needs_cfg
cond_image_grid = prompt_data.cond_image_grid
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t)
original_latent_tokens = scaled_hidden_states.shape[1]
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
pooled_embeds = (
batched_pooled if batched_pooled is not None else prompt_embeds
)
step_latents = mx.concatenate(
[scaled_hidden_states, scaled_hidden_states], axis=0
)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
cond_latents = prompt_data.conditioning_latents
step_latents = scaled_hidden_states
if cond_latents is not None:
num_img_tokens = original_latent_tokens + cond_latents.shape[1]
else:
num_img_tokens = original_latent_tokens
if self.is_first_stage and cond_latents is not None:
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_mask)
noise = self._run_sync_pass(
t,
config,
step_latents,
prompt_embeds,
pooled_embeds,
encoder_mask,
cond_image_grid,
kontext_image_ids,
num_img_tokens,
original_latent_tokens,
cond_latents,
)
if self.is_last_stage:
assert noise is not None
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale
)
hidden_states = config.scheduler.step(
noise=noise, timestep=t, latents=prev_latents
)
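# Ring hand-off: the last stage sends the stepped latents back to rank 0, the
# first stage receives them so the next timestep can start there, and
# intermediate stages keep their previous latents.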
if not self.is_first_stage:
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
mx.async_eval(hidden_states)
elif self.is_first_stage:
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
else:
hidden_states = prev_latents
return hidden_states
def _async_pipeline_step(
self,
t: int,
config: Config,
latents: mx.array,
prompt_data: PromptData,
is_first_async_step: bool,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
patch_latents, token_indices = self._create_patches(latents, config)
needs_cfg = self.adapter.needs_cfg
cond_image_grid = prompt_data.cond_image_grid
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, _ = batched_data
pooled_embeds = (
batched_pooled if batched_pooled is not None else prompt_embeds
)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_mask)
self._set_text_seq_len(text_seq_len)
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_mask)
text_embeddings = self.adapter.compute_text_embeddings(t, config, pooled_embeds)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
encoder_hidden_states_mask=encoder_mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
prev_patch_latents = list(patch_latents)
encoder_hidden_states: mx.array | None = None
for patch_idx in range(len(patch_latents)):
patch = patch_latents[patch_idx]
if (
self.is_first_stage
and not self.is_last_stage
and not is_first_async_step
):
patch = mx.distributed.recv_like(
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
noise, encoder_hidden_states = self._run_single_patch_pass(
patch=step_patch,
patch_idx=patch_idx,
token_indices=token_indices[patch_idx],
prompt_embeds=prompt_embeds,
text_embeddings=text_embeddings,
image_rotary_embeddings=image_rotary_embeddings,
encoder_hidden_states=encoder_hidden_states,
)
if self.is_last_stage:
assert noise is not None
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale
)
patch_latents[patch_idx] = config.scheduler.step(
noise=noise,
timestep=t,
latents=prev_patch_latents[patch_idx],
)
if not self.is_first_stage and t != config.num_inference_steps - 1:
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
return mx.concatenate(patch_latents, axis=1)
def _run_single_patch_pass(
self,
patch: mx.array,
patch_idx: int,
token_indices: tuple[int, int],
prompt_embeds: mx.array,
text_embeddings: mx.array,
image_rotary_embeddings: RotaryEmbeddings,
encoder_hidden_states: mx.array | None,
) -> tuple[mx.array | None, mx.array | None]:
"""Process a single patch through the forward pipeline.
Handles stage-to-stage communication (stage i -> stage i+1).
Ring communication (last stage -> first stage) is handled by the caller.
Args:
patch: The patch latents to process
patch_idx: Index of this patch (0-indexed)
token_indices: (start_token, end_token) for this patch
prompt_embeds: Text embeddings (for compute_embeddings on first stage)
text_embeddings: Precomputed text embeddings
image_rotary_embeddings: Precomputed rotary embeddings
encoder_hidden_states: Encoder hidden states (passed between patches)
Returns:
(noise_prediction, encoder_hidden_states) - noise is None for non-last stages
"""
start_token, end_token = token_indices
batch_size = patch.shape[0]
text_seq_len = prompt_embeds.shape[1]
hidden_dim = self.adapter.hidden_dim
if self.has_joint_blocks:
if not self.is_first_stage:
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
if patch_idx == 0:
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
if self.is_first_stage:
patch, encoder_hidden_states = self.adapter.compute_embeddings(
patch, prompt_embeds
)
assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
encoder_hidden_states, patch = wrapper(
hidden_states=patch,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if self.owns_concat_stage:
assert encoder_hidden_states is not None
patch_concat = self.adapter.merge_streams(patch, encoder_hidden_states)
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
patch_concat = mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
elif self.has_joint_blocks and not self.is_last_stage:
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
if patch_idx == 0:
assert encoder_hidden_states is not None
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
patch = wrapper(
hidden_states=patch,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if not self.is_last_stage:
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
noise: mx.array | None = None
if self.is_last_stage:
patch_img_only = patch[:, text_seq_len:, :]
noise = self.adapter.final_projection(patch_img_only, text_embeddings)
return noise, encoder_hidden_states

View File

@@ -13,11 +13,17 @@ from mlx.nn.layers.distributed import (
shard_linear,
sum_gradients,
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
@@ -335,7 +341,33 @@ def tensor_auto_parallel(
except (AttributeError, TypeError, NameError):
pass
if isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, MiniMaxModel):
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
all_to_sharded_linear,
@@ -343,6 +375,15 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GptOssModel):
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -377,6 +418,34 @@ class TensorParallelShardingStrategy(ABC):
) -> nn.Module: ...
class LlamaShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(LlamaModel, model)
for layer in model.layers:
# Force load weights before sharding to avoid FAST_SYNCH deadlock
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.n_heads //= self.N
if layer.self_attn.n_kv_heads is not None:
layer.self_attn.n_kv_heads //= self.N
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
return model
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
@@ -403,6 +472,105 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
else:
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
layer.self_attn.q_b_proj
)
layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
layer.self_attn.kv_b_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
# Shard the MLP
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
else:
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
return model
class ShardedDeepseekV3MoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.block_sparse_moe.switch_mlp.down_proj
)
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model
class QwenShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
@@ -455,3 +623,58 @@ class ShardedQwenMoE(CustomMlxLayer):
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
layer.self_attn.num_key_value_groups = (
layer.self_attn.num_attention_heads
// layer.self_attn.num_key_value_heads
)
layer.self_attn.sinks = layer.self_attn.sinks[
layer.self_attn.num_attention_heads
* self.group.rank() : layer.self_attn.num_attention_heads
* (self.group.rank() + 1)
]
self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model
class ShardedGptOssMoE(CustomMlxLayer):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
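The sink slicing above relies on num_attention_heads having already been divided by self.N, so each rank keeps exactly the sink logits belonging to its local heads. A worked example with illustrative numbers (not taken from the diff):

import mlx.core as mx

sinks = mx.arange(64)            # one sink logit per global attention head
world_size, rank = 4, 2
local_heads = 64 // world_size   # mirrors num_attention_heads //= self.N above
local_sinks = sinks[local_heads * rank : local_heads * (rank + 1)]
assert local_sinks.shape == (16,)   # rank 2 keeps sinks[32:48]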

View File

@@ -1,7 +1,7 @@
from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

View File

@@ -9,13 +9,15 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import ImageEditsInternalParams
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
EventId,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
NodeDownloadProgress,
NodeGatheredInfo,
TaskCreated,
@@ -28,6 +30,7 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
DownloadModel,
ImageEdits,
Shutdown,
Task,
TaskStatus,
@@ -93,6 +96,10 @@ class Worker:
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
async def run(self):
logger.info("Starting Worker")
@@ -157,6 +164,17 @@ class Worker:
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
if cmd_id not in self.input_chunk_buffer:
self.input_chunk_buffer[cmd_id] = {}
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
event.chunk.data
)
async def plan_step(self):
while True:
await anyio.sleep(0.1)
@@ -169,6 +187,8 @@ class Worker:
self.state.instances,
self.state.runners,
self.state.tasks,
self.input_chunk_buffer,
self.input_chunk_counts,
)
if task is None:
continue
@@ -232,6 +252,46 @@ class Worker:
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
chunks = self.input_chunk_buffer.get(cmd_id, {})
assembled = "".join(chunks[i] for i in range(len(chunks)))
logger.info(
f"Assembled input image from {len(chunks)} chunks, "
f"total size: {len(assembled)} bytes"
)
# Create modified task with assembled image data
modified_task = ImageEdits(
task_id=task.task_id,
command_id=task.command_id,
instance_id=task.instance_id,
task_status=task.task_status,
task_params=ImageEditsInternalParams(
image_data=assembled,
total_input_chunks=task.task_params.total_input_chunks,
prompt=task.task_params.prompt,
model=task.task_params.model,
n=task.task_params.n,
quality=task.task_params.quality,
output_format=task.task_params.output_format,
response_format=task.task_params.response_format,
size=task.task_params.size,
image_strength=task.task_params.image_strength,
bench=task.task_params.bench,
stream=task.task_params.stream,
partial_images=task.task_params.partial_images,
advanced_params=task.task_params.advanced_params,
),
)
# Cleanup buffers
if cmd_id in self.input_chunk_buffer:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
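The ImageEdits branch above concatenates the buffered base64 chunks in index order, and it only runs once the planner has confirmed that every chunk arrived (see the gate added to _pending_tasks below). A hypothetical helper, not in the diff, that captures the same assembly logic:

def assemble_chunks(buffer: dict[int, str], expected: int) -> str | None:
    # Chunks are keyed by chunk_index; the planner keeps the ImageEdits task
    # pending until len(buffer) reaches the expected total, so None here simply
    # means "wait for more chunks".
    if len(buffer) < expected:
        return None
    return "".join(buffer[i] for i in range(expected))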

View File

@@ -3,12 +3,14 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadModel,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -49,6 +51,8 @@ def plan(
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_counts: Mapping[CommandId, int] | None = None,
) -> Task | None:
    # Python's short-circuiting `or` evaluates these in order and returns the first non-None task.
return (
@@ -58,7 +62,7 @@ def plan(
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
)
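As the comment above notes, plan() relies on `or` short-circuiting: each planner returns a Task or None, and the first non-None result wins, so lifecycle work (downloads, backend init, model load, warmup) always takes priority over pending user tasks. A minimal illustration with assumed stand-in planners:

from typing import Callable, Optional

Planner = Callable[[], Optional[str]]

def first_task(planners: list[Planner]) -> Optional[str]:
    # Equivalent to chaining the calls with `or`: evaluate in order and return
    # the first planner that produces a task.
    for planner in planners:
        task = planner()
        if task is not None:
            return task
    return None

assert first_task([lambda: None, lambda: "load_model", lambda: "warmup"]) == "load_model"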
@@ -262,14 +266,24 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
if not isinstance(task, ChatCompletion):
# TODO(ciaran): do this better!
if not isinstance(task, (ChatCompletion, ImageGeneration, ImageEdits)):
continue
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
continue
# For ImageEdits tasks, verify all input chunks have been received
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
for runner in runners.values():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue

View File

@@ -1,3 +1,4 @@
import base64
import time
from collections.abc import Generator
from functools import cache
@@ -12,8 +13,11 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
load_harmony_encoding,
)
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.models.model_cards import ModelId, ModelTask
from exo.shared.types.api import ChatCompletionMessageText, ImageGenerationStats
from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -24,6 +28,8 @@ from exo.shared.types.events import (
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -33,6 +39,8 @@ from exo.shared.types.tasks import (
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ImageGenerationResponse,
PartialImageResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -48,7 +56,15 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.image import (
DistributedImageModel,
generate_image,
initialize_image_model,
warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -79,7 +95,7 @@ def main(
setup_start_time = time.time()
model = None
model: Model | DistributedImageModel | None = None
tokenizer = None
group = None
@@ -133,15 +149,25 @@ def main(
)
time.sleep(0.5)
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -151,15 +177,30 @@ def main(
)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
assert not isinstance(model, DistributedImageModel)
assert tokenizer
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert isinstance(model, DistributedImageModel)
image = warmup_image_generator(model=model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
logger.info("warmup completed (non-primary node)")
current_status = RunnerReady()
logger.info("runner ready")
case ChatCompletion(task_params=task_params, command_id=command_id) if (
@@ -173,7 +214,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert model
assert model and not isinstance(model, DistributedImageModel)
assert tokenizer
assert task_params.messages[0].content is not None
@@ -240,6 +281,90 @@ def main(
)
raise
current_status = RunnerReady()
logger.info("runner ready")
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, DistributedImageModel)
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(model=model, task=task_params):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
current_status = RunnerReady()
logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert isinstance(model, DistributedImageModel)
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
image_index = 0
for response in generate_image(model=model, task=task_params):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
@@ -329,6 +454,72 @@ def parse_thinking_models(
yield response
def _send_image_chunk(
encoded_data: str,
command_id: CommandId,
model_id: ModelId,
event_sender: MpSender[Event],
image_index: int,
is_partial: bool,
partial_index: int | None = None,
total_partials: int | None = None,
stats: ImageGenerationStats | None = None,
) -> None:
"""Send base64-encoded image data as chunks via events."""
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)
for chunk_index, chunk_data in enumerate(data_chunks):
# Only include stats on the last chunk of the final image
chunk_stats = (
stats if chunk_index == total_chunks - 1 and not is_partial else None
)
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=chunk_index,
model=model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=image_index,
is_partial=is_partial,
partial_index=partial_index,
total_partials=total_partials,
stats=chunk_stats,
),
)
)
def _process_image_response(
response: ImageGenerationResponse | PartialImageResponse,
command_id: CommandId,
shard_metadata: ShardMetadata,
event_sender: MpSender[Event],
image_index: int,
) -> None:
"""Process a single image response and send chunks."""
encoded_data = base64.b64encode(response.image_data).decode("utf-8")
is_partial = isinstance(response, PartialImageResponse)
# Extract stats from final ImageGenerationResponse if available
stats = response.stats if isinstance(response, ImageGenerationResponse) else None
_send_image_chunk(
encoded_data=encoded_data,
command_id=command_id,
model_id=shard_metadata.model_card.model_id,
event_sender=event_sender,
image_index=response.partial_index if is_partial else image_index,
is_partial=is_partial,
partial_index=response.partial_index if is_partial else None,
total_partials=response.total_partials if is_partial else None,
stats=stats,
)
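_send_image_chunk slices the base64 string into EXO_MAX_CHUNK_SIZE pieces and tags each with chunk_index and total_chunks; a consumer can invert that with a helper along these lines (assumed name, not part of the diff):

import base64

def reassemble_image(chunks: dict[int, str], total_chunks: int) -> bytes | None:
    # Mirror of _send_image_chunk: wait until every chunk_index is present,
    # then concatenate in order and decode the base64 payload.
    if len(chunks) < total_chunks:
        return None
    encoded = "".join(chunks[i] for i in range(total_chunks))
    return base64.b64decode(encoded)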
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import BaseTask, TaskId
@@ -38,6 +38,7 @@ def get_pipeline_shard_metadata(
n_layers=32,
hidden_size=2048,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
device_rank=device_rank,
world_size=world_size,

View File

@@ -11,7 +11,7 @@ import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
@@ -87,6 +87,7 @@ def run_gpt_oss_pipeline_device(
n_layers=24,
hidden_size=2880,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
device_rank=rank,
world_size=world_size,
@@ -156,6 +157,7 @@ def run_gpt_oss_tensor_parallel_device(
n_layers=24,
hidden_size=2880,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
device_rank=rank,
world_size=world_size,
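These fixture updates mirror the runner's dispatch on model_card.tasks earlier in the diff: text models load via the MLX path, while text-to-image and image-to-image models initialise the distributed image backend. A sketch of that dispatch with assumed enum string values (only the member names come from the diff):

from enum import Enum

class ModelTask(Enum):
    # Member names match the diff; the string values here are assumptions.
    TextGeneration = "text_generation"
    TextToImage = "text_to_image"
    ImageToImage = "image_to_image"

def pick_engine(tasks: list[ModelTask]) -> str:
    if ModelTask.TextGeneration in tasks:
        return "mlx_text"       # load_mlx_items path
    if ModelTask.TextToImage in tasks or ModelTask.ImageToImage in tasks:
        return "mflux_image"    # initialize_image_model path
    raise ValueError(f"Unknown model task(s): {tasks}")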

uv.lock generated

View File

File diff suppressed because it is too large (2392 lines changed).