Compare commits

..

222 Commits

Author SHA1 Message Date
ciaranbor
9b6a6059b9 Start image editing time steps at 0 2026-01-09 19:42:13 +00:00
ciaranbor
fff08568ba Ignore image_strength 2026-01-09 19:23:14 +00:00
ciaranbor
edde03d74c Handle conditioning latents in sync pipeline 2026-01-09 17:44:52 +00:00
ciaranbor
d67eb9f2af Use dummy image for editing warmup 2026-01-09 17:27:41 +00:00
ciaranbor
bcf1fbd561 Support streaming for image editing 2026-01-09 17:08:15 +00:00
ciaranbor
9067033f20 Support image editing in runner 2026-01-09 16:15:51 +00:00
ciaranbor
17ef7d8838 Add editing features to adapter 2026-01-09 16:15:51 +00:00
ciaranbor
23962ab4e2 Default partial images to 3 if streaming 2026-01-09 16:15:51 +00:00
ciaranbor
59c5de8256 Add Qwen-Image model adapter 2026-01-09 16:15:51 +00:00
ciaranbor
a43a14da1f Add Qwen-Image-Edit model config 2026-01-09 16:15:51 +00:00
ciaranbor
bb46c12878 Use image generation in streaming mode in UI 2026-01-09 16:15:51 +00:00
ciaranbor
79b5316efe Handle partial image streaming 2026-01-09 16:15:51 +00:00
ciaranbor
40d49c2720 Add streaming params to ImageGenerationTaskParams 2026-01-09 16:15:51 +00:00
ciaranbor
f87d06b1f1 Add Qwen-Image-Edit-2509 2026-01-09 16:15:51 +00:00
ciaranbor
617c6ffdcb Handle image editing time steps 2026-01-09 16:15:51 +00:00
ciaranbor
acd246f49e Fix time steps 2026-01-09 16:15:51 +00:00
ciaranbor
d5536c1a2b Remove duplicate RunnerReady event 2026-01-09 16:15:51 +00:00
ciaranbor
2aba3fc9a9 Fix image_strength meaning 2026-01-09 16:15:51 +00:00
ciaranbor
c7cb22d546 Truncate image data logs 2026-01-09 16:15:51 +00:00
ciaranbor
5a429f4ab6 Chunk image input 2026-01-09 16:15:51 +00:00
ciaranbor
42c427a5bb Avoid logging image data 2026-01-09 16:15:51 +00:00
ciaranbor
4b8976be51 Support image editing 2026-01-09 16:15:51 +00:00
Sami Khan
85bf4aba1c small UI change 2026-01-09 16:15:51 +00:00
Sami Khan
de12873b1a image gen in dashboard 2026-01-09 16:15:51 +00:00
ciaranbor
a86bb97d65 Better llm model type check 2026-01-07 10:17:42 +00:00
ciaranbor
a5c6db7145 Prune blocks before model load 2026-01-06 16:47:58 +00:00
ciaranbor
e74345bb09 Own TODOs 2026-01-06 13:30:17 +00:00
ciaranbor
58d1f159b7 Remove double RunnerReady event 2026-01-06 13:25:01 +00:00
ciaranbor
8183225714 Fix hidden_size for image models 2026-01-06 12:33:28 +00:00
ciaranbor
46f957ee5b Fix uv.lock 2026-01-06 12:25:23 +00:00
ciaranbor
a2f52e04e3 Fix image model cards 2026-01-06 11:00:01 +00:00
ciaranbor
859960608b Skip decode on non-final ranks 2026-01-06 10:51:21 +00:00
ciaranbor
80ad016004 Final rank produces image 2026-01-06 10:51:21 +00:00
ciaranbor
d8938e6e72 Increase number of sync steps 2026-01-06 10:51:21 +00:00
ciaranbor
bd6a6cc6d3 Change Qwen-Image steps 2026-01-06 10:51:21 +00:00
ciaranbor
a88d588de4 Fix Qwen-Image latent shapes 2026-01-06 10:51:21 +00:00
ciaranbor
08e8a30fb7 Fix joint block patch recv shape for non-zero ranks 2026-01-06 10:51:21 +00:00
ciaranbor
d926df8f95 Fix comms issue for models without single blocks 2026-01-06 10:51:21 +00:00
ciaranbor
4ff550106d Support Qwen in DiffusionRunner pipefusion 2026-01-06 10:51:21 +00:00
ciaranbor
1d168dfe61 Implement Qwen pipefusion 2026-01-06 10:51:21 +00:00
ciaranbor
f813b9f5e1 Add guidance_scale parameter to image model config 2026-01-06 10:51:21 +00:00
ciaranbor
0f96083d48 Move orchestration to DiffusionRunner 2026-01-06 10:51:21 +00:00
ciaranbor
1a1b394f6d Add initial QwenModelAdapter 2026-01-06 10:51:21 +00:00
ciaranbor
6ab5a9d3d4 Tweak embeddings interface 2026-01-06 10:51:21 +00:00
ciaranbor
90bf4608df Add Qwen ImageModelConfig 2026-01-06 10:51:21 +00:00
ciaranbor
f2a0fdf25c Use 10% sync steps 2026-01-06 10:51:21 +00:00
ciaranbor
f574b3f57e Update FluxModelAdapter for new interface 2026-01-06 10:51:21 +00:00
ciaranbor
5cca9d8493 Register QwenModelAdapter 2026-01-06 10:51:21 +00:00
ciaranbor
3cd421079b Support multiple forward passes in runner 2026-01-06 10:51:21 +00:00
ciaranbor
d9eb4637ee Extend block wrapper parameters 2026-01-06 10:51:21 +00:00
ciaranbor
19f52e80fd Relax adaptor typing 2026-01-06 10:51:21 +00:00
ciaranbor
3f4162b732 Add Qwen-Image model card 2026-01-06 10:51:21 +00:00
ciaranbor
cad86ee76e Clean up dead code 2026-01-06 10:51:21 +00:00
ciaranbor
d7be6a09b0 Add BaseModelAdaptor 2026-01-06 10:51:21 +00:00
ciaranbor
79603e73ed Refactor filestructure 2026-01-06 10:51:21 +00:00
ciaranbor
78901cfe23 Treat unified blocks as single blocks (equivalent) 2026-01-06 10:51:21 +00:00
ciaranbor
c0ac199ab8 Refactor to handle entire denoising process in Diffusion runner 2026-01-06 10:51:21 +00:00
ciaranbor
b70d6abfa2 Move transformer to adapter 2026-01-06 10:51:21 +00:00
ciaranbor
16bfab9bab Move some more logic to adaptor 2026-01-06 10:51:21 +00:00
ciaranbor
28ee6f6370 Add generic block wrapper 2026-01-06 10:51:21 +00:00
ciaranbor
6b299bab8f Access transformer blocks from adaptor 2026-01-06 10:51:21 +00:00
ciaranbor
a3754a60b6 Better typing 2026-01-06 10:51:21 +00:00
ciaranbor
06039f93f5 Create wrappers at init time 2026-01-06 10:51:21 +00:00
ciaranbor
fcfecc9cd8 Combine model factory and adaptor 2026-01-06 10:51:21 +00:00
ciaranbor
ba798ae4f9 Implement model factory 2026-01-06 10:51:21 +00:00
ciaranbor
9a0e1e93a9 Add adaptor registry 2026-01-06 10:51:21 +00:00
ciaranbor
196f504c82 Remove mflux/generator/generate.py 2026-01-06 10:51:21 +00:00
ciaranbor
e3d89b8d63 Switch to using DistributedImageModel 2026-01-06 10:51:21 +00:00
ciaranbor
cb8079525c Add DistributedImageModel 2026-01-06 10:51:21 +00:00
ciaranbor
cb03c62c4a Use new generic wrappers, etc in denoising 2026-01-06 10:51:21 +00:00
ciaranbor
0653668048 Add generic transformer block wrappers 2026-01-06 10:51:21 +00:00
ciaranbor
0054bc4c14 Add FluxAdaptor 2026-01-06 10:51:21 +00:00
ciaranbor
b7b682b7bb Add ModelAdaptor, derivations implement model specific logic 2026-01-06 10:51:21 +00:00
ciaranbor
f7a651c1c1 Introduce image model config concept 2026-01-06 10:51:21 +00:00
ciaranbor
98e8d74cea Consolidate kv cache patching 2026-01-06 10:51:21 +00:00
ciaranbor
27567f8a4e Support different configuration comms 2026-01-06 10:51:21 +00:00
ciaranbor
28227bb45a Add ImageGenerator protocol 2026-01-06 10:51:21 +00:00
ciaranbor
7683d4a21f Force final patch receive order 2026-01-06 10:51:21 +00:00
ciaranbor
0a3cb77a29 Remove logs 2026-01-06 10:51:21 +00:00
ciaranbor
3f5810c1fe Update patch list 2026-01-06 10:51:21 +00:00
ciaranbor
fc62ae1b9b Slight refactor 2026-01-06 10:51:21 +00:00
ciaranbor
ec5bad4254 Don't need array for prev patches 2026-01-06 10:51:21 +00:00
ciaranbor
f9f54be32b Fix send/recv order 2026-01-06 10:51:21 +00:00
ciaranbor
36daf9183f Fix async single transformer block 2026-01-06 10:51:21 +00:00
ciaranbor
5d38ffc77e Use relative rank variables 2026-01-06 10:51:21 +00:00
ciaranbor
1b4851765a Fix writing patches 2026-01-06 10:51:21 +00:00
ciaranbor
8787eaf3df Collect final image 2026-01-06 10:51:21 +00:00
ciaranbor
e1e3aa7a5e Fix recv_template shape 2026-01-06 10:51:21 +00:00
ciaranbor
0fe5239273 Add logs 2026-01-06 10:51:21 +00:00
ciaranbor
7eddf7404b Optimise async pipeline 2026-01-06 10:51:21 +00:00
ciaranbor
5f3bc30f17 Add next_rank and prev_rank members 2026-01-06 10:51:21 +00:00
ciaranbor
90a7e6601d Add _create_patches method 2026-01-06 10:51:21 +00:00
ciaranbor
ce2691c8d3 Fix shapes 2026-01-06 10:51:21 +00:00
ciaranbor
076d2901e8 Reorder comms 2026-01-06 10:51:20 +00:00
ciaranbor
7a733b584c Remove all_gather from sync pipeline, send from final rank to first rank 2026-01-06 10:51:20 +00:00
ciaranbor
94fee6f2d2 Simplify kv_cache initialization 2026-01-06 10:51:20 +00:00
ciaranbor
ef4fe09424 Fix kv cache 2026-01-06 10:51:20 +00:00
ciaranbor
2919bcf21d Clean up kv caches 2026-01-06 10:51:20 +00:00
ciaranbor
dd84cc9ca2 Fix return 2026-01-06 10:51:20 +00:00
ciaranbor
5a74d76d41 Fix hidden_states shapes 2026-01-06 10:51:20 +00:00
ciaranbor
e115814c74 Only perform projection and scheduler step on last rank 2026-01-06 10:51:20 +00:00
ciaranbor
d85432d4f0 Only compute embeddings on rank 0 2026-01-06 10:51:20 +00:00
ciaranbor
da823a2b02 Remove eval 2026-01-06 10:51:20 +00:00
ciaranbor
8576f4252b Remove eval 2026-01-06 10:51:20 +00:00
ciaranbor
7ca0bc5b55 Only send encoder_hidden_states with the first patch (once per timestep) 2026-01-06 10:51:20 +00:00
ciaranbor
db24f052d7 Remove redundant text kv cache computation 2026-01-06 10:51:20 +00:00
ciaranbor
7b8382be10 Concatenate before all gather 2026-01-06 10:51:20 +00:00
ciaranbor
d3685b0eb5 Increase number of sync steps 2026-01-06 10:51:20 +00:00
ciaranbor
93f4bdc5f9 Reinitialise kv_caches between generations 2026-01-06 10:51:20 +00:00
ciaranbor
8eea0327b8 Eliminate double kv cache computation 2026-01-06 10:51:20 +00:00
ciaranbor
085358e5e0 Add kv cache caching wrappers for sync pipeline transformer blocks 2026-01-06 10:51:20 +00:00
ciaranbor
546efe4dd2 Persist kv caches 2026-01-06 10:51:20 +00:00
ciaranbor
4ddfb6e254 Implement naive async pipeline implementation 2026-01-06 10:51:20 +00:00
ciaranbor
12f20fd94e Use wrapper classes for patched transformer logic 2026-01-06 10:51:20 +00:00
ciaranbor
f7ba70d5ae Add patch-aware joint and single attention wrappers 2026-01-06 10:51:20 +00:00
ciaranbor
4ecad10a66 Fix group.size() 2026-01-06 10:51:20 +00:00
ciaranbor
552ae776fe Add classes to manage kv caches with patch support 2026-01-06 10:51:20 +00:00
ciaranbor
6e0a6e8956 Use heuristic for number of sync steps 2026-01-06 10:51:20 +00:00
ciaranbor
e8b0a2124c Generalise number of denoising steps 2026-01-06 10:51:20 +00:00
ciaranbor
129df1ec89 Add flux1-dev 2026-01-06 10:51:20 +00:00
ciaranbor
a87fe26973 Move scheduler step to inner pipeline 2026-01-06 10:51:20 +00:00
ciaranbor
a9ea223dc7 Add barrier before all_gather 2026-01-06 10:51:20 +00:00
ciaranbor
0af3349f2f Fix transformer blocks pruning 2026-01-06 10:51:20 +00:00
ciaranbor
20e3319a3e Fix image generation api 2026-01-06 10:51:20 +00:00
ciaranbor
4c88fac266 Create queue in try block 2026-01-06 10:51:20 +00:00
ciaranbor
e1d916f743 Conform to rebase 2026-01-06 10:51:20 +00:00
ciaranbor
09c9b2e29f Refactor denoising 2026-01-06 10:51:20 +00:00
ciaranbor
b6359a7199 Move more logic to DistributedFlux 2026-01-06 10:51:20 +00:00
ciaranbor
b5a043f676 Move surrounding logic back to _sync_pipeline 2026-01-06 10:51:20 +00:00
ciaranbor
55e690fd49 Add patching aware member variables 2026-01-06 10:51:20 +00:00
ciaranbor
9e4ffb11ec Implement sync/async switching logic 2026-01-06 10:51:20 +00:00
ciaranbor
d665a8d05a Move current transformer implementation to _sync_pipeline method 2026-01-06 10:51:20 +00:00
ciaranbor
cac77816be Remove some logs 2026-01-06 10:51:20 +00:00
ciaranbor
25b9c3369e Remove old Flux1 implementation 2026-01-06 10:51:20 +00:00
ciaranbor
c19c5b4080 Prune unused transformer blocks 2026-01-06 10:51:20 +00:00
ciaranbor
9592f8b6b0 Add mx.eval 2026-01-06 10:51:20 +00:00
ciaranbor
7d7c16ebc1 Test evals 2026-01-06 10:51:20 +00:00
ciaranbor
450d0ba923 Test only barriers 2026-01-06 10:51:20 +00:00
ciaranbor
ea64062362 All perform final projection 2026-01-06 10:51:20 +00:00
ciaranbor
206b12e912 Another barrier 2026-01-06 10:51:20 +00:00
ciaranbor
eecc1da596 More debug 2026-01-06 10:51:20 +00:00
ciaranbor
44e68e4498 Add barriers 2026-01-06 10:51:20 +00:00
ciaranbor
f1548452fa Add log 2026-01-06 10:51:20 +00:00
ciaranbor
97769c82a9 Restore distributed logging 2026-01-06 10:51:20 +00:00
ciaranbor
26e5b03285 Use bootstrap logger 2026-01-06 10:51:20 +00:00
ciaranbor
8f93a1ff78 Remove logs 2026-01-06 10:51:20 +00:00
ciaranbor
e07dcc43b9 fix single block receive shape 2026-01-06 10:51:20 +00:00
ciaranbor
f91d0797fb Add debug logs 2026-01-06 10:51:20 +00:00
ciaranbor
aaeebaf79e Move communication logic to DistributedTransformer wrapper 2026-01-06 10:51:20 +00:00
ciaranbor
c3075a003e Move inference logic to DistributedFlux1 2026-01-06 10:51:20 +00:00
ciaranbor
be796e55ac Add DistributedFlux1 class 2026-01-06 10:51:20 +00:00
ciaranbor
6e0c611f37 Rename pipeline to pipefusion 2026-01-06 10:51:20 +00:00
ciaranbor
88996eddcb Further refactor 2026-01-06 10:51:20 +00:00
ciaranbor
fb4fae51fa Refactor warmup 2026-01-06 10:51:20 +00:00
ciaranbor
dbefc209f5 Manually handle flux1 inference 2026-01-06 10:51:20 +00:00
ciaranbor
e6dd95524c Refactor flux1 image generation 2026-01-06 10:51:20 +00:00
ciaranbor
c2a9e5e53b Use quality parameter to set number of inference steps 2026-01-06 10:51:20 +00:00
ciaranbor
21587898bc Chunk image data transfer 2026-01-06 10:51:20 +00:00
ciaranbor
b6f23d0b01 Define EXO_MAX_CHUNK_SIZE 2026-01-06 10:51:20 +00:00
ciaranbor
f00ba03f4b Add indexing info to ImageChunk 2026-01-06 10:50:56 +00:00
ciaranbor
73e3713296 Remove sharding logs 2026-01-06 10:50:56 +00:00
ciaranbor
ecca6b4d20 Temp: reduce flux1.schnell storage size 2026-01-06 10:50:56 +00:00
ciaranbor
8bac08a236 Fix mflux transformer all_gather 2026-01-06 10:50:34 +00:00
ciaranbor
e7cca752fd Add all_gather -> broadcast todo 2026-01-06 10:50:34 +00:00
ciaranbor
540fe8b278 Fix world size 2026-01-06 10:50:34 +00:00
ciaranbor
2972f4620c Fix transition block? 2026-01-06 10:50:34 +00:00
ciaranbor
0ed81d8afa Implement image generation warmup 2026-01-06 10:50:34 +00:00
ciaranbor
66a24d59b9 Add logs 2026-01-06 10:50:11 +00:00
ciaranbor
5dcc359dba Add spiece.model to default patterns 2026-01-06 10:50:11 +00:00
ciaranbor
c2a4d61865 Just download all files for now 2026-01-06 10:49:43 +00:00
ciaranbor
ba12ee4897 Fix get_allow_patterns to include non-indexed safetensors files 2026-01-06 10:49:43 +00:00
ciaranbor
bcd69a3b01 Use half-open layer indexing in get_allow_patterns 2026-01-06 10:49:43 +00:00
ciaranbor
f5eb5d0338 Enable distributed mflux 2026-01-06 10:49:43 +00:00
ciaranbor
058aff5145 Implement mflux transformer sharding and communication pattern 2026-01-06 10:49:43 +00:00
ciaranbor
5cb0bc6a63 Update get_allow_patterns to handle sharding components 2026-01-06 10:49:43 +00:00
ciaranbor
c3aab450c6 Namespace both keys and values for component weight maps 2026-01-06 10:49:43 +00:00
ciaranbor
cf27673e20 Add components to Flux.1-schnell MODEL_CARD 2026-01-06 10:49:43 +00:00
ciaranbor
96c165e297 Add component concept for ModelMetadata 2026-01-06 10:48:42 +00:00
ciaranbor
2a589177cd Fix multiple components weight map key conflicts 2026-01-06 10:48:26 +00:00
ciaranbor
f782b619b6 get_weight_map: handle repos with multiple safetensors.index.json files 2026-01-06 10:48:26 +00:00
ciaranbor
dc661e4b5e Add initial image edits spec 2026-01-06 10:48:26 +00:00
ciaranbor
8b7d8ef394 Add image edits endpoint 2026-01-06 10:47:44 +00:00
ciaranbor
7dd2b328c8 Add ImageToImage task 2026-01-06 10:45:26 +00:00
ciaranbor
73a165702d Allow ModelCards to have multiple tasks 2026-01-06 10:44:53 +00:00
ciaranbor
0c76978b35 Fix text generation 2026-01-06 10:41:38 +00:00
ciaranbor
25188c845e Rename mlx_generate_image to mflux_generate 2026-01-06 10:41:38 +00:00
ciaranbor
df94169aba Initialize mlx or mflux engine based on model task 2026-01-06 10:39:27 +00:00
ciaranbor
a2d4c0de2a Restore warmup for text generation 2026-01-06 10:17:21 +00:00
ciaranbor
2edbc7e026 Add initialize_mflux function 2026-01-06 10:17:21 +00:00
ciaranbor
8f6e360d21 Move image generation to mflux engine 2026-01-06 10:17:21 +00:00
ciaranbor
085b966a5f Just use str for image generation size 2026-01-06 10:17:21 +00:00
ciaranbor
c64a55bfed Use MFlux for image generation 2026-01-06 10:17:21 +00:00
ciaranbor
fee716faab Add get_model_card function 2026-01-06 10:17:21 +00:00
ciaranbor
b88c89ee9c Add ModelTask enum 2026-01-06 10:17:21 +00:00
ciaranbor
9ef7b913e2 Add flux1-schnell model 2026-01-06 10:08:11 +00:00
ciaranbor
0daa4b36db Add task field to ModelCard 2026-01-06 10:08:11 +00:00
ciaranbor
3c2da43792 Update mflux version 2026-01-06 10:04:51 +00:00
ciaranbor
8c4c53b50a Enable recursive repo downloads 2026-01-06 10:03:18 +00:00
ciaranbor
b2beb4c9cd Add dummy generate_image implementation 2026-01-06 10:03:18 +00:00
ciaranbor
098a11b262 Use base64 encoded str for image data 2026-01-06 10:03:18 +00:00
ciaranbor
bedb9045a0 Handle ImageGeneration tasks in _pending_tasks 2026-01-06 10:03:18 +00:00
ciaranbor
8e23841b4e Add mflux dependency 2026-01-06 10:03:18 +00:00
ciaranbor
4420eac10d Handle ImageGeneration task in runner task processing 2026-01-06 10:02:04 +00:00
ciaranbor
d0772e9e0f Handle ImageGeneration command in master command processing 2026-01-06 10:00:32 +00:00
ciaranbor
8d861168f1 Add image generation to API 2026-01-06 10:00:07 +00:00
ciaranbor
242648dff4 Add ImageGenerationResponse 2026-01-06 09:59:13 +00:00
ciaranbor
9b06b754cb Add ImageGeneration task 2026-01-06 09:59:13 +00:00
ciaranbor
1603984f45 Add ImageGeneration command 2026-01-06 09:58:46 +00:00
ciaranbor
f9418843f8 Add image generation params and response types 2026-01-05 21:51:10 +00:00
ciaranbor
877e7196c3 Add pillow dependency 2026-01-05 21:51:10 +00:00
ciaranbor
db7c4670b9 Fix mlx stream_generate import 2026-01-05 21:18:23 +00:00
madanlalit
4f6fcd9e93 feat(macos-app): add custom namespace UI for cluster isolation
Add Advanced Options section with custom namespace field that allows
users to override EXO_LIBP2P_NAMESPACE environment variable. This
enables splitting machines that can see each other into separate
clusters.

- Added customNamespace property with UserDefaults persistence
- Added Advanced Options collapsible section with text field
- Added Save & Restart button that auto-restarts exo process
- Namespace replaces buildTag when custom value is set
- Falls back to buildTag (version) when namespace is empty
2026-01-05 15:25:00 +01:00
Evan Quiney
839b67f318 [feat] Add an option to disable the worker (#1091)
## Motivation

Workerless machines can be used for networking without running any gpu
jobs - add a cli flag that adds this basic functionality.

## Changes

Adds the --no-worker cli flag

## Test Plan

### Manual Testing

Exo starts as expected

### Automated Testing

None
2026-01-05 12:05:03 +00:00
Drifter4242
47b8e0ce12 feat: remember last launch settings (model, sharding, instance type) (#1028)
## Motivation

Saves the last launch settings, so that the next time you run exo it
will default to the same launch settings.
This is just a small quality of life improvement.

## Changes

When you launch it saves the settings to the web browser local storage.
When it fills out the model list, it reads the settings and sets the
default.

I reviewed, tested and edited the code, but some of the code was written
by Claude Opus. I hope that's ok.

## Why It Works

See above

## Test Plan

### Manual Testing

I have two Mac Studio M3 Ultras, each with 512GB RAM, connected with
Thunderbolt 5. I ran Kimi K2 Thinking with MLX Ring and Tensor Split.
I ran exo multiple times to confirm that the default works.

### Automated Testing

No changes to automated testing.
2026-01-05 11:27:14 +00:00
Evan Quiney
17f9b583a4 Task Deduplication (#1062) 2026-01-03 20:01:49 +00:00
RickyChen / 陳昭儒
844bcc7ce6 fix: prevent form submission during IME composition (#1069)
## Problem
When typing in Chinese (or other IME-based languages like
Japanese/Korean), pressing Enter to select a character from the IME
candidate list would incorrectly submit the message instead of
confirming the character selection.

## Solution
Added IME composition state detection in the `handleKeydown` function in
`ChatForm.svelte`:
- Check `event.isComposing` to detect active IME composition
- Fallback to `event.keyCode === 229` for broader browser compatibility
- Return early when IME is active, allowing normal character selection

## Changes
- Modified `dashboard/src/lib/components/ChatForm.svelte` 
- Added IME composition check before Enter key handling

Co-authored-by: Ricky Chen <rickychen@Rickys-MacBook-Pro.local>
2025-12-31 17:11:04 +00:00
Evan Quiney
c1be5184b2 Fix tests broken by 283c (#1063)
Some tests were broken by #1058 and #1046 - this fixes them.
2025-12-31 01:53:55 +00:00
Alex Cheema
1ec550dff1 Emit download progress on start, and change downloads to be keyed by model_id (#1044)
## Motivation

We added a download page to the dashboard which shows the currently
download status of each model on each node. Users have reported this to
be extremely useful.

However, we don't currently fetch the download progress on start, so it
doesn't show any model's download status.

## Changes

Fetch and emit model download status on start of worker, and
periodically every 5 mins.
Also to support this, I changed download_status to be keyed by model_id
instead of shard, since we want download_status of each model, not each
shard.

## Why It Works

The dashboard already implements the correct functionality, we just
weren't populating the download status in the state. Now it gets
populated and shows correctly.

## Test Plan

### Manual Testing
On a cluster of 2 x 512GB M3 Ultra Mac Studio, I launched an instance
onto one node that hadn't been downloaded. I checked the download page
and it showed the in progress download. I downloaded it to completion,
restarted exo on both nodes, and then opened the download page and it
showed the model as 100% downloaded and other models as 0% that hadn't
been downloaded.

---------

Co-authored-by: Evan <evanev7@gmail.com>
2025-12-31 01:18:10 +00:00
Alex Cheema
283c0e39e4 Placement filters for tensor parallel supports_tensor, tensor dimension and pipeline parallel deepseek v3.1 (#1058)
## Motivation

Certain placements are not valid. Added filters to exclude these placements. There were invalid placement previews being shown in the dashboard which would then fail when the user actually tries to launch an instance with that placement.


## Changes

Three filters added:

1. Certain models do not support tensor parallel at all. Checks `supports_tensor` on the model_meta.
2. For models that do support tensor parallelism, certain tensor parallel sizes are not valid. This check is actually not correct right now but it works fine for now. The actual correct check is more involved.
3. For unknown reasons, deepseek v3.1 (8-bit) does not work with tensor parallelism.

## Why It Works

`place_instance` now raises an `Exception` for invalid placements.

## Test Plan

### Manual Testing
Since `/instance/previews` enumerates all possible placements and runs `place_instance`, I checked the dashboard to see if invalid placements are still shown.
2025-12-31 00:33:40 +00:00
Alex Cheema
35be4c55c3 prioritise mlx jaccl coordinator ip (en0 -> en1 -> non-TB5 -> other) 2025-12-31 00:10:19 +00:00
Alex Cheema
31d4cd8409 set KV_CACHE_BITS to None to disable quantized kv cache 2025-12-31 00:03:30 +00:00
Alex Cheema
8a6da58404 remove mx.set_cache_limit 2025-12-30 23:58:15 +00:00
59 changed files with 8734 additions and 1798 deletions

View File

@@ -20,6 +20,8 @@ struct ContentView: View {
@State private var showDebugInfo = false
@State private var bugReportInFlight = false
@State private var bugReportMessage: String?
@State private var showAdvancedOptions = false
@State private var pendingNamespace: String = ""
var body: some View {
VStack(alignment: .leading, spacing: 12) {
@@ -197,6 +199,8 @@ struct ContentView: View {
updater.checkForUpdates()
}
.padding(.bottom, 8)
advancedOptionsSection
.padding(.bottom, 8)
debugSection
.padding(.bottom, 8)
controlButton(title: "Quit", tint: .secondary) {
@@ -327,6 +331,47 @@ struct ContentView: View {
}
}
private var advancedOptionsSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Advanced Options")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showAdvancedOptions)
}
.animation(nil, value: showAdvancedOptions)
if showAdvancedOptions {
VStack(alignment: .leading, spacing: 8) {
VStack(alignment: .leading, spacing: 4) {
Text("Cluster Namespace")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
TextField("optional", text: $pendingNamespace)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingNamespace = controller.customNamespace
}
Button("Save & Restart") {
controller.customNamespace = pendingNamespace
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingNamespace == controller.customNamespace)
}
}
}
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showAdvancedOptions)
}
private var debugSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {

View File

@@ -2,6 +2,8 @@ import AppKit
import Combine
import Foundation
private let customNamespaceKey = "EXOCustomNamespace"
@MainActor
final class ExoProcessController: ObservableObject {
enum Status: Equatable {
@@ -27,6 +29,13 @@ final class ExoProcessController: ObservableObject {
@Published private(set) var status: Status = .stopped
@Published private(set) var lastError: String?
@Published private(set) var launchCountdownSeconds: Int?
@Published var customNamespace: String = {
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
}() {
didSet {
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
}
}
private var process: Process?
private var runtimeDirectoryURL: URL?
@@ -180,7 +189,7 @@ final class ExoProcessController: ObservableObject {
private func makeEnvironment(for runtimeURL: URL) -> [String: String] {
var environment = ProcessInfo.processInfo.environment
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
environment["EXO_LIBP2P_NAMESPACE"] = buildTag()
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
var paths: [String] = []
if let existing = environment["PATH"], !existing.isEmpty {
@@ -217,6 +226,12 @@ final class ExoProcessController: ObservableObject {
}
return "dev"
}
private func computeNamespace() -> String {
let base = buildTag()
let custom = customNamespace.trimmingCharacters(in: .whitespaces)
return custom.isEmpty ? base : custom
}
}
struct RuntimeError: LocalizedError {

View File

@@ -1,5 +1,5 @@
<script lang="ts">
import { isLoading, sendMessage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import { isLoading, sendMessage, generateImage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import ChatAttachments from './ChatAttachments.svelte';
import type { ChatUploadedFile } from '$lib/types/files';
import { processUploadedFiles, getAcceptString } from '$lib/types/files';
@@ -10,6 +10,7 @@
showHelperText?: boolean;
autofocus?: boolean;
showModelSelector?: boolean;
modelTasks?: Record<string, string[]>;
}
let {
@@ -17,7 +18,8 @@
placeholder = 'Ask anything',
showHelperText = false,
autofocus = true,
showModelSelector = false
showModelSelector = false,
modelTasks = {}
}: Props = $props();
let message = $state('');
@@ -48,13 +50,29 @@
// Accept all supported file types
const acceptString = getAcceptString(['image', 'text', 'pdf']);
// Check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const tasks = modelTasks[modelId] || [];
return tasks.includes('TextToImage') || tasks.includes('ImageToImage');
}
// Check if the currently selected model supports image generation
const isImageModel = $derived(() => {
if (!currentModel) return false;
return modelSupportsImageGeneration(currentModel);
});
// Extract available models from running instances
const availableModels = $derived(() => {
const models: Array<{id: string, label: string}> = [];
const models: Array<{id: string, label: string, isImageModel: boolean}> = [];
for (const [, instance] of Object.entries(instanceData)) {
const modelId = getInstanceModelId(instance);
if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
models.push({ id: modelId, label: modelId.split('/').pop() || modelId });
models.push({
id: modelId,
label: modelId.split('/').pop() || modelId,
isImageModel: modelSupportsImageGeneration(modelId)
});
}
}
return models;
@@ -139,6 +157,11 @@
}
function handleKeydown(event: KeyboardEvent) {
// Prevent form submission during IME composition (e.g., Chinese, Japanese, Korean input)
if (event.isComposing || event.keyCode === 229) {
return;
}
if (event.key === 'Enter' && !event.shiftKey) {
event.preventDefault();
handleSubmit();
@@ -155,7 +178,12 @@
uploadedFiles = [];
resetTextareaHeight();
sendMessage(content, files);
// Use image generation for image models
if (isImageModel() && content) {
generateImage(content);
} else {
sendMessage(content, files);
}
// Refocus the textarea after sending
setTimeout(() => textareaRef?.focus(), 10);
@@ -292,7 +320,14 @@
{:else}
<span class="w-3"></span>
{/if}
<span class="truncate">{model.label}</span>
{#if model.isImageModel}
<svg class="w-3.5 h-3.5 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate flex-1">{model.label}</span>
</button>
{/each}
</div>
@@ -352,7 +387,7 @@
onkeydown={handleKeydown}
oninput={handleInput}
onpaste={handlePaste}
{placeholder}
placeholder={isImageModel() ? 'Describe the image you want to generate...' : placeholder}
disabled={loading}
rows={1}
class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"
@@ -366,14 +401,23 @@
{!canSend || loading
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label="Send message"
aria-label={isImageModel() ? "Generate image" : "Send message"}
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<span class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"></span>
<span class="hidden sm:inline">PROCESSING</span>
<span class="hidden sm:inline">{isImageModel() ? 'GENERATING' : 'PROCESSING'}</span>
<span class="sm:hidden">...</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}

View File

@@ -365,10 +365,58 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
</div>
{/if}
<!-- Generated Images -->
{#if message.attachments?.some(a => a.type === 'generated-image')}
<div class="mb-3">
{#each message.attachments.filter(a => a.type === 'generated-image') as attachment}
<div class="relative group/img inline-block">
<img
src={attachment.preview}
alt=""
class="max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20"
/>
<!-- Download button overlay -->
<button
type="button"
class="absolute top-2 right-2 p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow opacity-0 group-hover/img:opacity-100 transition-opacity hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
onclick={() => {
if (attachment.preview) {
const link = document.createElement('a');
link.href = attachment.preview;
link.download = `generated-image-${Date.now()}.png`;
link.click();
}
}}
title="Download image"
>
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
</svg>
</button>
</div>
{/each}
</div>
{/if}
<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{#if message.content === 'Generating image...'}
<div class="flex items-center gap-3 text-exo-yellow">
<div class="relative">
<div class="w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"></div>
<svg class="absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
</div>
<span class="font-mono tracking-wider uppercase text-sm">Generating image...</span>
</div>
{:else if message.content || (loading && !message.attachments?.some(a => a.type === 'generated-image'))}
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
{/if}
</div>
</div>

View File

File diff suppressed because it is too large Load Diff

View File

@@ -47,10 +47,86 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number, tasks?: string[], hugging_face_id?: string}>>([]);
// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
	const lookup: Record<string, string[]> = {};
	for (const model of models) {
		const taskList = model.tasks;
		if (!taskList || taskList.length === 0) continue;
		// Index under the short ID...
		lookup[model.id] = taskList;
		// ...and under the full HuggingFace ID when the API provides one.
		if (model.hugging_face_id) {
			lookup[model.hugging_face_id] = taskList;
		}
	}
	return lookup;
});
// Helper to check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const model = models.find(m => m.id === modelId || m.hugging_face_id === modelId);
if (!model?.tasks) return false;
return model.tasks.includes('TextToImage') || model.tasks.includes('ImageToImage');
}
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
// Launch defaults persistence
const LAUNCH_DEFAULTS_KEY = 'exo-launch-defaults';
interface LaunchDefaults {
modelId: string | null;
sharding: 'Pipeline' | 'Tensor';
instanceType: InstanceMeta;
minNodes: number;
}
// Snapshot the current launch configuration into localStorage (best-effort).
function saveLaunchDefaults(): void {
	const snapshot: LaunchDefaults = {
		modelId: selectedPreviewModelId(),
		sharding: selectedSharding,
		instanceType: selectedInstanceType,
		minNodes: selectedMinNodes,
	};
	try {
		localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(snapshot));
	} catch (e) {
		// localStorage can throw (private browsing, quota); persistence is optional.
		console.warn('Failed to save launch defaults:', e);
	}
}
// Read the persisted launch configuration; null when absent or unreadable.
function loadLaunchDefaults(): LaunchDefaults | null {
	try {
		const raw = localStorage.getItem(LAUNCH_DEFAULTS_KEY);
		if (!raw) return null;
		return JSON.parse(raw) as LaunchDefaults;
	} catch (e) {
		// Corrupt JSON or storage access failure: fall back to built-in defaults.
		console.warn('Failed to load launch defaults:', e);
		return null;
	}
}
// Re-apply previously saved launch settings, validating each against the
// currently available models and topology size.
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
	const saved = loadLaunchDefaults();
	if (!saved) return;
	// Sharding and instance type are always safe to restore.
	selectedSharding = saved.sharding;
	selectedInstanceType = saved.instanceType;
	// Restore minNodes only when it still fits the current topology (1..maxNodes).
	if (saved.minNodes && saved.minNodes >= 1 && saved.minNodes <= maxNodes) {
		selectedMinNodes = saved.minNodes;
	}
	// Restore the model only when it is still in the offered list.
	if (saved.modelId && availableModels.some((m) => m.id === saved.modelId)) {
		selectPreviewModel(saved.modelId);
	}
}
let selectedInstanceType = $state<InstanceMeta>('MlxRing');
let selectedMinNodes = $state<number>(1);
let minNodesInitialized = $state(false);
@@ -298,6 +374,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const data = await response.json();
// API returns { data: [{ id, name }] } format
models = data.data || [];
// Restore last launch defaults if available
const currentNodeCount = topologyData() ? Object.keys(topologyData()!.nodes).length : 1;
applyLaunchDefaults(models, currentNodeCount);
}
} catch (error) {
console.error('Failed to fetch models:', error);
@@ -988,6 +1067,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
// End of a mouse drag on the min-nodes slider: stop dragging and persist the value.
function handleSliderMouseUp() {
	isDraggingSlider = false;
	saveLaunchDefaults();
}
// Handle touch events for mobile
@@ -1007,6 +1087,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
// Touch counterpart of handleSliderMouseUp for mobile slider drags.
function handleSliderTouchEnd() {
	isDraggingSlider = false;
	saveLaunchDefaults();
}
const nodeCount = $derived(data ? Object.keys(data.nodes).length : 0);
@@ -1192,6 +1273,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
placeholder="Ask anything"
showHelperText={false}
showModelSelector={true}
modelTasks={modelTasks()}
/>
</div>
</div>
@@ -1413,8 +1495,18 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{@const foundModel = models.find(m => m.id === selectedModelId)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="text-exo-light-gray truncate">{foundModel.name || foundModel.id}</span>
{@const isImageModel = modelSupportsImageGeneration(foundModel.id)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="flex items-center gap-2 text-exo-light-gray truncate">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{foundModel.name || foundModel.id}</span>
</span>
<span class="text-white/50 text-xs flex-shrink-0">{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB</span>
</span>
{:else}
@@ -1459,11 +1551,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(model.id)}
<button
type="button"
onclick={() => {
if (modelCanFit) {
selectPreviewModel(model.id);
saveLaunchDefaults();
isModelDropdownOpen = false;
modelDropdownSearch = '';
}
@@ -1477,7 +1571,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
: 'text-white/30 cursor-default'
}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span class="flex-shrink-0 text-xs {modelCanFit ? 'text-white/50' : 'text-red-400/60'}">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
@@ -1497,7 +1600,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-xs text-white/70 font-mono mb-2">Sharding:</div>
<div class="flex gap-2">
<button
onclick={() => selectedSharding = 'Pipeline'}
onclick={() => { selectedSharding = 'Pipeline'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Pipeline' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Pipeline' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1508,7 +1611,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
Pipeline
</button>
<button
onclick={() => selectedSharding = 'Tensor'}
onclick={() => { selectedSharding = 'Tensor'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Tensor' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Tensor' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1526,7 +1629,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-xs text-white/70 font-mono mb-2">Instance Type:</div>
<div class="flex gap-2">
<button
onclick={() => selectedInstanceType = 'MlxRing'}
onclick={() => { selectedInstanceType = 'MlxRing'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxRing' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxRing' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1537,7 +1640,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
MLX Ring
</button>
<button
onclick={() => selectedInstanceType = 'MlxIbv'}
onclick={() => { selectedInstanceType = 'MlxIbv'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxIbv' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxIbv' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1674,7 +1777,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-7xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} />
<ChatForm placeholder="Ask anything" showModelSelector={true} modelTasks={modelTasks()} />
</div>
</div>
</div>

View File

@@ -35,6 +35,8 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux>=0.12.1",
]
[project.scripts]

View File

@@ -28,7 +28,7 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
worker: Worker
worker: Worker | None
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
election_result_receiver: Receiver[ElectionResult]
master: Master | None
@@ -62,15 +62,19 @@ class Node:
else:
api = None
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
)
if not args.no_worker:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
)
else:
worker = None
# We start every node with a master
master = Master(
node_id,
@@ -100,8 +104,9 @@ class Node:
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.worker.run)
tg.start_soon(self.election.run)
if self.worker:
tg.start_soon(self.worker.run)
if self.master:
tg.start_soon(self.master.run)
if self.api:
@@ -209,6 +214,7 @@ class Args(CamelCaseModel):
spawn_api: bool = False
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
@classmethod
def parse(cls) -> Self:
@@ -246,6 +252,10 @@ class Args(CamelCaseModel):
dest="api_port",
default=52415,
)
parser.add_argument(
"--no-worker",
action="store_true",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -1,11 +1,13 @@
import base64
import json
import time
from collections.abc import AsyncGenerator
from typing import cast
from typing import Literal, cast
import anyio
from anyio import create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
@@ -22,9 +24,10 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
ChatCompletionChoice,
@@ -34,6 +37,10 @@ from exo.shared.types.api import (
CreateInstanceResponse,
DeleteInstanceResponse,
FinishReason,
ImageData,
ImageEditsInternalParams,
ImageGenerationResponse,
ImageGenerationTaskParams,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -41,14 +48,17 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
SendInputChunk,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
@@ -84,12 +94,23 @@ def chunk_to_response(
)
async def resolve_model_meta(model_id: str) -> ModelMetadata:
def get_model_card(model_id: str) -> ModelCard | None:
    """Look up a model card by registry key or by full model id.

    Args:
        model_id: Either a short card key of ``MODEL_CARDS`` or a card's
            full ``model_id`` (e.g. a HuggingFace repo id).

    Returns:
        The matching ``ModelCard``, or ``None`` when neither form matches.
    """
    # Fast path: the id is a key of the card registry (single lookup via .get
    # instead of the original `in` test followed by a second indexing).
    card = MODEL_CARDS.get(model_id)
    if card is not None:
        return card
    # Slow path: the id may be a card's full model_id. Iterate values directly;
    # the keys were unused in the original `.items()` loop.
    for card in MODEL_CARDS.values():
        if model_id == card.model_id:
            return card
    # Explicit None for the no-match case (was an implicit fall-through).
    return None
async def resolve_model_meta(model_id: str) -> ModelMetadata:
model_card = get_model_card(model_id)
if model_card is not None:
return model_card.metadata
else:
return await get_model_meta(model_id)
return await get_model_meta(model_id)
class API:
@@ -133,6 +154,7 @@ class API:
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._tg: TaskGroup | None = None
def reset(self, new_session_id: SessionId, result_clock: int):
@@ -141,6 +163,7 @@ class API:
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._chat_completion_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)
def unpause(self, result_clock: int):
@@ -172,6 +195,10 @@ class API:
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
self.app.post("/v1/images/generations", response_model=None)(
self.image_generations
)
self.app.post("/v1/images/edits", response_model=None)(self.image_edits)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -525,6 +552,325 @@ class API:
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
async def image_generations(
    self, payload: ImageGenerationTaskParams
) -> ImageGenerationResponse | StreamingResponse:
    """Handle image generation requests.

    When stream=True and partial_images > 0, returns a StreamingResponse
    with SSE-formatted events for partial and final images. Otherwise all
    image chunks are collected and returned in a single response.

    Raises:
        HTTPException: 404 when no running instance serves the requested model.
    """
    # Normalise the requested model name to its canonical model_id.
    model_meta = await resolve_model_meta(payload.model)
    payload.model = model_meta.model_id
    # Fail fast if no instance is serving this model; also nudge the user
    # to download it.
    if not any(
        instance.shard_assignments.model_id == payload.model
        for instance in self.state.instances.values()
    ):
        await self._trigger_notify_user_to_download_model(payload.model)
        raise HTTPException(
            status_code=404, detail=f"No instance found for model {payload.model}"
        )
    command = ImageGeneration(
        request_params=payload,
    )
    # Send the command first; the stream/collect helpers register the result
    # queue under command.command_id before consuming chunks.
    await self._send(command)
    # Check if streaming is requested
    if payload.stream and payload.partial_images and payload.partial_images > 0:
        return StreamingResponse(
            self._generate_image_stream(
                command_id=command.command_id,
                num_images=payload.n or 1,
                response_format=payload.response_format or "b64_json",
            ),
            media_type="text/event-stream",
        )
    # Non-streaming: collect all image chunks
    return await self._collect_image_generation(
        command_id=command.command_id,
        num_images=payload.n or 1,
        response_format=payload.response_format or "b64_json",
    )
async def _generate_image_stream(
    self,
    command_id: CommandId,
    num_images: int,
    response_format: str,
) -> AsyncGenerator[str, None]:
    """Generate SSE stream of partial and final images.

    Reassembles ImageChunk messages (split to fit the transport's message
    size limit) into complete base64 images, yielding one SSE "partial"
    event per completed partial image and one "final" event per completed
    image, followed by a terminating "[DONE]" sentinel.
    """
    # Track chunks: {(image_index, is_partial): {chunk_index: data}}
    image_chunks: dict[tuple[int, bool], dict[int, str]] = {}
    image_total_chunks: dict[tuple[int, bool], int] = {}
    image_metadata: dict[tuple[int, bool], tuple[int | None, int | None]] = {}
    images_complete = 0
    try:
        # Register a queue under this command id; the event loop routes
        # ChunkGenerated events carrying ImageChunks to it.
        self._image_generation_queues[command_id], recv = channel[ImageChunk]()
        with recv as chunks:
            async for chunk in chunks:
                key = (chunk.image_index, chunk.is_partial)
                if key not in image_chunks:
                    image_chunks[key] = {}
                    image_total_chunks[key] = chunk.total_chunks
                    image_metadata[key] = (
                        chunk.partial_index,
                        chunk.total_partials,
                    )
                image_chunks[key][chunk.chunk_index] = chunk.data
                # Check if this image is complete
                if len(image_chunks[key]) == image_total_chunks[key]:
                    # Reassemble in chunk order.
                    # NOTE(review): assumes chunk_index values are exactly
                    # 0..total_chunks-1 — confirm the producer guarantees this.
                    full_data = "".join(
                        image_chunks[key][i] for i in range(len(image_chunks[key]))
                    )
                    partial_idx, total_partials = image_metadata[key]
                    if chunk.is_partial:
                        # Yield partial image event
                        event_data = {
                            "type": "partial",
                            "partial_index": partial_idx,
                            "total_partials": total_partials,
                            "data": {
                                "b64_json": full_data
                                if response_format == "b64_json"
                                else None,
                            },
                        }
                        yield f"data: {json.dumps(event_data)}\n\n"
                    else:
                        # Final image
                        event_data = {
                            "type": "final",
                            "image_index": chunk.image_index,
                            "data": {
                                "b64_json": full_data
                                if response_format == "b64_json"
                                else None,
                            },
                        }
                        yield f"data: {json.dumps(event_data)}\n\n"
                        images_complete += 1
                        if images_complete >= num_images:
                            yield "data: [DONE]\n\n"
                            # The break skips the per-key cleanup below for the
                            # last image; harmless, the generator ends here.
                            break
                    # Clean up completed image chunks
                    del image_chunks[key]
                    del image_total_chunks[key]
                    del image_metadata[key]
    except anyio.get_cancelled_exc_class():
        # Client disconnects cancel the generator; propagate so anyio can
        # unwind, cleanup still runs in finally.
        raise
    finally:
        # Always tell the master the task is done and drop the routing queue.
        await self._send(TaskFinished(finished_command_id=command_id))
        if command_id in self._image_generation_queues:
            del self._image_generation_queues[command_id]
async def _collect_image_generation(
    self,
    command_id: CommandId,
    num_images: int,
    response_format: str,
) -> ImageGenerationResponse:
    """Collect all image chunks (non-streaming) and return a single response."""
    # Track chunks per image: {image_index: {chunk_index: data}}
    # Only track non-partial (final) images
    image_chunks: dict[int, dict[int, str]] = {}
    image_total_chunks: dict[int, int] = {}
    images_complete = 0
    try:
        # Register a queue under this command id; the event loop routes
        # ImageChunks for this command to it.
        self._image_generation_queues[command_id], recv = channel[ImageChunk]()
        while images_complete < num_images:
            with recv as chunks:
                async for chunk in chunks:
                    # Skip partial images in non-streaming mode
                    if chunk.is_partial:
                        continue
                    if chunk.image_index not in image_chunks:
                        image_chunks[chunk.image_index] = {}
                        image_total_chunks[chunk.image_index] = chunk.total_chunks
                    image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
                    # Check if this image is complete
                    if (
                        len(image_chunks[chunk.image_index])
                        == image_total_chunks[chunk.image_index]
                    ):
                        images_complete += 1
                        if images_complete >= num_images:
                            break
        # Reassemble images in order
        # NOTE(review): assumes every image_index in range(num_images) arrived
        # and chunk indexes are 0..total-1; a gap would raise KeyError here —
        # confirm the worker guarantees contiguous indexes.
        images: list[ImageData] = []
        for image_idx in range(num_images):
            chunks_dict = image_chunks[image_idx]
            full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
            images.append(
                ImageData(
                    b64_json=full_data if response_format == "b64_json" else None,
                    url=None,  # URL format not implemented yet
                )
            )
        return ImageGenerationResponse(data=images)
    except anyio.get_cancelled_exc_class():
        raise
    finally:
        # Always signal task completion and drop the routing queue.
        await self._send(TaskFinished(finished_command_id=command_id))
        if command_id in self._image_generation_queues:
            del self._image_generation_queues[command_id]
async def image_edits(
    self,
    image: UploadFile = File(...),
    prompt: str = Form(...),
    model: str = Form(...),
    n: int = Form(1),
    size: str = Form("1024x1024"),
    response_format: Literal["url", "b64_json"] = Form("b64_json"),
    input_fidelity: Literal["low", "high"] = Form("low"),
    stream: bool = Form(False),
    partial_images: int = Form(0),
) -> ImageGenerationResponse | StreamingResponse:
    """Handle image editing requests (img2img).

    The uploaded image is base64 encoded, split into chunks that fit the
    transport's message size limit, and sent to the worker ahead of the
    ImageEdits command. Results are either streamed as SSE events
    (stream=True with partial_images > 0) or collected into one response.

    Raises:
        HTTPException: 404 when no running instance serves the requested model.
    """
    # Normalise the requested model name to its canonical model_id.
    model_meta = await resolve_model_meta(model)
    resolved_model = model_meta.model_id
    # Fail fast if no instance is serving this model.
    if not any(
        instance.shard_assignments.model_id == resolved_model
        for instance in self.state.instances.values()
    ):
        await self._trigger_notify_user_to_download_model(resolved_model)
        raise HTTPException(
            status_code=404, detail=f"No instance found for model {resolved_model}"
        )
    # Read and base64 encode the uploaded image
    image_content = await image.read()
    image_data = base64.b64encode(image_content).decode("utf-8")
    # Map input_fidelity to image_strength
    image_strength = 0.7 if input_fidelity == "high" else 0.3
    # Split image into chunks to stay under gossipsub message size limit
    data_chunks = [
        image_data[i : i + EXO_MAX_CHUNK_SIZE]
        for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
    ]
    total_chunks = len(data_chunks)
    # Create command first to get command_id
    command = ImageEdits(
        request_params=ImageEditsInternalParams(
            image_data="",  # Empty - will be assembled at worker from chunks
            total_input_chunks=total_chunks,
            prompt=prompt,
            model=resolved_model,
            n=n,
            size=size,
            response_format=response_format,
            image_strength=image_strength,
            stream=stream,
            partial_images=partial_images,
        ),
    )
    # Send input chunks BEFORE the command
    logger.info(
        f"Sending input image: {len(image_data)} bytes in {total_chunks} chunks"
    )
    for chunk_index, chunk_data in enumerate(data_chunks):
        await self._send(
            SendInputChunk(
                chunk=InputImageChunk(
                    idx=chunk_index,
                    model=resolved_model,
                    command_id=command.command_id,
                    data=chunk_data,
                    chunk_index=chunk_index,
                    total_chunks=total_chunks,
                )
            )
        )
    # Now send the main command
    await self._send(command)
    num_images = n
    # Check if streaming is requested
    if stream and partial_images and partial_images > 0:
        return StreamingResponse(
            self._generate_image_stream(
                command_id=command.command_id,
                num_images=num_images,
                response_format=response_format,
            ),
            media_type="text/event-stream",
        )
    # Non-streaming path: reassemble final images from chunks.
    # Track chunks per image: {image_index: {chunk_index: data}}
    image_chunks: dict[int, dict[int, str]] = {}
    image_total_chunks: dict[int, int] = {}
    images_complete = 0
    try:
        # Register a queue so the event loop routes this command's
        # ImageChunks here.
        self._image_generation_queues[command.command_id], recv = channel[
            ImageChunk
        ]()
        while images_complete < num_images:
            with recv as chunks:
                async for chunk in chunks:
                    if chunk.image_index not in image_chunks:
                        image_chunks[chunk.image_index] = {}
                        image_total_chunks[chunk.image_index] = chunk.total_chunks
                    image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
                    if (
                        len(image_chunks[chunk.image_index])
                        == image_total_chunks[chunk.image_index]
                    ):
                        images_complete += 1
                        if images_complete >= num_images:
                            break
        # NOTE(review): assumes every image_index in range(num_images)
        # arrived with contiguous chunk indexes — confirm worker guarantees.
        images: list[ImageData] = []
        for image_idx in range(num_images):
            chunks_dict = image_chunks[image_idx]
            full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
            images.append(
                ImageData(
                    b64_json=full_data if response_format == "b64_json" else None,
                    url=None,  # URL format not implemented yet
                )
            )
        return ImageGenerationResponse(data=images)
    except anyio.get_cancelled_exc_class():
        raise
    finally:
        # Send TaskFinished command
        await self._send(TaskFinished(finished_command_id=command.command_id))
        # Guard the delete (consistent with _collect_image_generation and
        # _generate_image_stream): the key may be absent if queue creation
        # failed or reset() replaced the dict, and a KeyError inside finally
        # would mask the real error.
        if command.command_id in self._image_generation_queues:
            del self._image_generation_queues[command.command_id]
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
@@ -547,6 +893,7 @@ class API:
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
tasks=[task.value for task in card.tasks],
)
for card in MODEL_CARDS.values()
]
@@ -584,14 +931,17 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if (
isinstance(event, ChunkGenerated)
and event.command_id in self._chat_completion_queues
):
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
if isinstance(event, ChunkGenerated):
if event.command_id in self._chat_completion_queues:
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
elif event.command_id in self._image_generation_queues:
assert isinstance(event.chunk, ImageChunk)
await self._image_generation_queues[event.command_id].send(
event.chunk
)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -2,6 +2,7 @@ from datetime import datetime, timedelta, timezone
import anyio
from anyio.abc import TaskGroup
from fastapi.routing import request_response
from loguru import logger
from exo.master.placement import (
@@ -11,13 +12,17 @@ from exo.master.placement import (
place_instance,
)
from exo.shared.apply import apply
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.commands import (
ChatCompletion,
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskFinished,
TestCommand,
)
@@ -26,6 +31,7 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
NodeTimedOut,
TaskCreated,
@@ -35,6 +41,12 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion as ChatCompletionTask,
)
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
)
from exo.shared.types.tasks import (
ImageGeneration as ImageGenerationTask,
)
from exo.shared.types.tasks import (
TaskId,
TaskStatus,
@@ -99,6 +111,7 @@ class Master:
async for forwarder_command in commands:
try:
logger.info(f"Executing command: {forwarder_command.command}")
generated_events: list[Event] = []
command = forwarder_command.command
match command:
@@ -146,6 +159,92 @@ class Master:
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageGeneration():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageEdits():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
@@ -173,6 +272,13 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
generated_events.append(
InputChunkReceived(
command_id=chunk.command_id,
chunk=chunk,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
)
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
@@ -29,6 +30,7 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.shards import Sharding
def random_ephemeral_port() -> int:
@@ -65,6 +67,28 @@ def place_instance(
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_meta.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_meta.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
)
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
smallest_tb_cycles = [

View File

@@ -385,13 +385,14 @@ def get_mlx_jaccl_coordinators(
address in format "X.X.X.X:PORT" per node.
"""
rank_0_node = selected_cycle[0]
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
return ip
logger.warning(

View File

@@ -50,7 +50,7 @@ def model_meta() -> ModelMetadata:
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=10,
hidden_size=30,
supports_tensor=True,
)

View File

@@ -9,6 +9,7 @@ from exo.shared.types.events import (
ChunkGenerated,
Event,
IndexedEvent,
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
NodeCreated,
@@ -40,8 +41,8 @@ def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged()
): # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
): # Pass-through events that don't modify state
return state
case InstanceCreated():
return apply_instance_created(event, state)

View File

@@ -44,3 +44,5 @@ LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024

View File

@@ -1,5 +1,5 @@
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ComponentInfo, ModelId, ModelMetadata, ModelTask
from exo.utils.pydantic_ext import CamelCaseModel
@@ -8,6 +8,7 @@ class ModelCard(CamelCaseModel):
model_id: ModelId
name: str
description: str
tasks: list[ModelTask]
tags: list[str]
metadata: ModelMetadata
@@ -45,6 +46,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
@@ -60,6 +62,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
@@ -133,6 +136,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
@@ -148,6 +152,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
@@ -164,6 +169,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
@@ -179,6 +185,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
@@ -194,6 +201,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
@@ -209,6 +217,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
@@ -225,6 +234,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
@@ -240,6 +250,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
@@ -255,6 +266,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
@@ -271,6 +283,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
@@ -286,6 +299,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
@@ -301,6 +315,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
@@ -317,6 +332,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
@@ -332,6 +348,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
@@ -347,6 +364,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
@@ -362,6 +380,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
@@ -377,6 +396,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
@@ -392,6 +412,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
@@ -407,6 +428,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
@@ -422,6 +444,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
@@ -437,6 +460,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
@@ -452,6 +476,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
@@ -467,6 +492,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
@@ -482,6 +508,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
@@ -498,6 +525,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
@@ -513,6 +541,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
@@ -529,6 +558,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
@@ -544,6 +574,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
@@ -569,4 +600,188 @@ MODEL_CARDS: dict[str, ModelCard] = {
# supports_tensor=True,
# ),
# ),
"flux1-schnell": ModelCard(
short_id="flux1-schnell",
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
name="FLUX.1 [schnell]",
description="""FLUX.1 [schnell] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
pretty_name="FLUX.1 [schnell]",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(23782357120), # + 9524621312),
n_layers=57, # sharded layers
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"flux1-dev": ModelCard(
short_id="flux1-dev",
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
name="FLUX.1 [dev]",
description="""FLUX.1 [dev] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
pretty_name="FLUX.1 [dev]",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57, # sharded layers
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"qwen-image": ModelCard(
short_id="qwen-image",
model_id=ModelId("Qwen/Qwen-Image"),
name="Qwen Image",
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("Qwen/Qwen-Image"),
pretty_name="Qwen Image",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"qwen-image-edit-2509": ModelCard(
short_id="qwen-image-edit-2509",
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
name="Qwen Image Edit 2509",
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
tasks=[ModelTask.ImageToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
pretty_name="Qwen Image Edit 2509",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
}

View File

@@ -1,6 +1,8 @@
import time
from collections.abc import Generator
from typing import Any, Literal
from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
@@ -27,6 +29,7 @@ class ModelListModel(BaseModel):
tags: list[str] = Field(default=[])
storage_size_megabytes: int = Field(default=0)
supports_tensor: bool = Field(default=False)
tasks: list[str] = Field(default=[])
class ModelList(BaseModel):
@@ -181,3 +184,75 @@ class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId
class ImageGenerationTaskParams(BaseModel):
    """Request parameters for an OpenAI-style image generation call.

    Commented-out fields mirror the OpenAI Images API surface but are not
    supported yet.
    """

    prompt: str  # text description of the image to generate
    # background: str | None = None
    model: str  # model short id or full model id
    # moderation: str | None = None
    n: int | None = 1  # number of images to generate
    # output_compression: int | None = None
    output_format: Literal["png", "jpeg", "webp"] = "png"
    partial_images: int | None = 0  # intermediate images to stream (0 = none)
    quality: Literal["high", "medium", "low"] | None = "medium"
    response_format: Literal["url", "b64_json"] | None = "b64_json"
    size: str | None = "1024x1024"  # "WIDTHxHEIGHT" string
    stream: bool | None = False  # stream results incrementally when True
    # style: str | None = "vivid"
    # user: str | None = None
class ImageEditsTaskParams(BaseModel):
    """Multipart form parameters for an OpenAI-style image edit request.

    Not directly serializable across nodes because of the UploadFile; see
    ImageEditsInternalParams for the distributed form.
    """

    image: UploadFile  # input image to edit (multipart upload)
    prompt: str
    # presumably maps to image_strength downstream — confirm against the
    # endpoint that converts this to ImageEditsInternalParams
    input_fidelity: float = 0.7
    model: str
    n: int | None = 1
    quality: Literal["high", "medium", "low"] | None = "medium"
    output_format: Literal["png", "jpeg", "webp"] = "png"
    response_format: Literal["url", "b64_json"] | None = "b64_json"
    size: str | None = "1024x1024"  # "WIDTHxHEIGHT" string
    # user: str | None = None
class ImageEditsInternalParams(BaseModel):
    """Serializable version of ImageEditsTaskParams for distributed task execution."""

    image_data: str = ""  # Base64-encoded image (empty when using chunked transfer)
    total_input_chunks: int = 0
    prompt: str
    model: str
    n: int | None = 1
    quality: Literal["high", "medium", "low"] | None = "medium"
    output_format: Literal["png", "jpeg", "webp"] = "png"
    response_format: Literal["url", "b64_json"] | None = "b64_json"
    size: str | None = "1024x1024"
    image_strength: float = 0.7
    stream: bool = False
    partial_images: int | None = 0

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        # Replace the (potentially huge) base64 payload with a length
        # placeholder so reprs and logs stay readable.
        for field_name, field_value in super().__repr_args__():
            if field_name is None:
                continue
            if field_name == "image_data":
                yield field_name, f"<{len(self.image_data)} chars>"
            else:
                yield field_name, field_value
class ImageData(BaseModel):
    """One generated image in an OpenAI-style image response payload."""

    b64_json: str | None = None
    url: str | None = None
    revised_prompt: str | None = None

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        # Summarise the base64 payload instead of dumping it into logs.
        for key, val in super().__repr_args__():
            if key is None:
                continue
            if key == "b64_json" and val is not None:
                yield key, f"<{len(val)} chars>"
            else:
                yield key, val
class ImageGenerationResponse(BaseModel):
    """OpenAI-style response body for an image generation/edit request."""

    created: int = Field(default_factory=lambda: int(time.time()))  # unix timestamp
    data: list[ImageData]  # one entry per generated image

View File

@@ -1,8 +1,11 @@
from collections.abc import Generator
from enum import Enum
from typing import Any
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .common import CommandId
from .models import ModelId
@@ -23,7 +26,34 @@ class TokenChunk(BaseChunk):
class ImageChunk(BaseChunk):
    """A chunk of a generated image, base64-encoded for transport.

    Large images are split across chunks (chunk_index / total_chunks);
    intermediate images produced during streaming are flagged with
    is_partial and indexed by partial_index / total_partials.
    """

    # FIX: removed the stale `data: bytes` annotation that was left behind
    # when the field switched to base64 text — the duplicate annotation was
    # dead code that misrepresented the payload type.
    data: str  # base64-encoded image payload
    chunk_index: int
    total_chunks: int
    image_index: int  # which generated image this chunk belongs to
    is_partial: bool = False
    partial_index: int | None = None
    total_partials: int | None = None

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        """Replace the payload with a length placeholder in reprs/logs."""
        for name, value in super().__repr_args__():
            if name == "data":
                yield name, f"<{len(self.data)} chars>"
            elif name is not None:
                yield name, value
class InputImageChunk(BaseChunk):
    """A chunk of a user-supplied input image, base64-encoded for transport."""

    command_id: CommandId
    data: str
    chunk_index: int
    total_chunks: int

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        # Keep logs readable: show the payload size, not the payload itself.
        for attr, val in super().__repr_args__():
            if attr == "data":
                val = f"<{len(self.data)} chars>"
            if attr is not None:
                yield attr, val
GenerationChunk = TokenChunk | ImageChunk

View File

@@ -1,6 +1,11 @@
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -20,6 +25,14 @@ class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
class ImageGeneration(BaseCommand):
request_params: ImageGenerationTaskParams
class ImageEdits(BaseCommand):
request_params: ImageEditsInternalParams
class PlaceInstance(BaseCommand):
model_meta: ModelMetadata
sharding: Sharding
@@ -39,6 +52,12 @@ class TaskFinished(BaseCommand):
finished_command_id: CommandId
class SendInputChunk(BaseCommand):
"""Command to send an input image chunk (converted to event by master)."""
chunk: InputImageChunk
class RequestEventLog(BaseCommand):
since_idx: int
@@ -47,10 +66,13 @@ Command = (
TestCommand
| RequestEventLog
| ChatCompletion
| ImageGeneration
| ImageEdits
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskFinished
| SendInputChunk
)

View File

@@ -3,7 +3,7 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
@@ -106,6 +106,11 @@ class ChunkGenerated(BaseEvent):
chunk: GenerationChunk
class InputChunkReceived(BaseEvent):
    """Event recording arrival of an input image chunk for a command.

    Produced from a SendInputChunk command — presumably by the master;
    confirm against the command-handling path.
    """

    command_id: CommandId  # command this input chunk belongs to
    chunk: InputImageChunk
class TopologyEdgeCreated(BaseEvent):
edge: Connection
@@ -131,6 +136,7 @@ Event = (
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| TopologyEdgeCreated
| TopologyEdgeDeleted
)

View File

@@ -1,3 +1,5 @@
from enum import Enum
from pydantic import PositiveInt
from exo.shared.types.common import Id
@@ -9,6 +11,21 @@ class ModelId(Id):
pass
class ModelTask(str, Enum):
    """Inference task categories a model can serve (see ModelCard.tasks)."""

    TextGeneration = "TextGeneration"
    TextToImage = "TextToImage"
    ImageToImage = "ImageToImage"
class ComponentInfo(CamelCaseModel):
    """Metadata for one component of a multi-component (e.g. diffusion) model."""

    component_name: str  # e.g. "text_encoder", "transformer", "vae"
    component_path: str  # repo subdirectory, e.g. "transformer/"
    storage_size: Memory
    n_layers: PositiveInt | None  # None when the component has no layer structure
    can_shard: bool  # whether this component's layers can be split across nodes
    safetensors_index_filename: str | None  # None => weights in a single file
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
@@ -16,3 +33,4 @@ class ModelMetadata(CamelCaseModel):
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
components: list[ComponentInfo] | None = None

View File

@@ -2,7 +2,11 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.common import CommandId, Id
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
@@ -56,6 +60,22 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class ImageGeneration(BaseTask):  # emitted by Master
    """Task instructing a runner to generate image(s) from a text prompt."""

    command_id: CommandId  # originating command
    task_params: ImageGenerationTaskParams
    error_type: str | None = Field(default=None)  # set when the task failed
    error_message: str | None = Field(default=None)
class ImageEdits(BaseTask):  # emitted by Master
    """Task instructing a runner to edit an existing image per a prompt."""

    command_id: CommandId  # originating command
    task_params: ImageEditsInternalParams
    error_type: str | None = Field(default=None)  # set when the task failed
    error_message: str | None = Field(default=None)
class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
@@ -67,5 +87,7 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| ImageGeneration
| ImageEdits
| Shutdown
)

View File

@@ -1,3 +1,6 @@
from collections.abc import Generator
from typing import Any, Literal
from exo.shared.types.api import FinishReason
from exo.utils.pydantic_ext import TaggedModel
@@ -17,5 +20,31 @@ class GenerationResponse(BaseRunnerResponse):
finish_reason: FinishReason | None = None
class ImageGenerationResponse(BaseRunnerResponse):
    """Final generated image emitted by a runner."""

    image_data: bytes
    format: Literal["png", "jpeg", "webp"] = "png"

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        # Avoid dumping raw image bytes into reprs/logs.
        for attr, val in super().__repr_args__():
            if attr == "image_data":
                val = f"<{len(self.image_data)} bytes>"
            if attr is not None:
                yield attr, val
class PartialImageResponse(BaseRunnerResponse):
    """Intermediate image emitted by a runner while diffusion is streaming."""

    image_data: bytes
    format: Literal["png", "jpeg", "webp"] = "png"
    partial_index: int
    total_partials: int

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        # Avoid dumping raw image bytes into reprs/logs.
        for attr, val in super().__repr_args__():
            if attr == "image_data":
                val = f"<{len(self.image_data)} bytes>"
            if attr is not None:
                yield attr, val
class FinishedResponse(BaseRunnerResponse):
    """Signals that the runner has finished producing responses for a task."""

    pass

View File

@@ -53,6 +53,10 @@ class RunnerRunning(BaseRunnerStatus):
pass
class RunnerShuttingDown(BaseRunnerStatus):
pass
class RunnerShutdown(BaseRunnerStatus):
pass
@@ -70,6 +74,7 @@ RunnerStatus = (
| RunnerWarmingUp
| RunnerReady
| RunnerRunning
| RunnerShuttingDown
| RunnerShutdown
| RunnerFailed
)

View File

@@ -9,6 +9,7 @@ from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal
from urllib.parse import urljoin
from huggingface_hub._snapshot_download import snapshot_download
import aiofiles
import aiofiles.os as aios
@@ -441,15 +442,39 @@ def calculate_repo_progress(
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
    """Build a merged tensor-name -> filename map from all safetensors indexes.

    FIX: this span mixed the removed single-index implementation with the
    new multi-index one (leftover lines from the old version); this is the
    coherent new implementation. It downloads every
    ``*.safetensors.index.json`` in the repo (multi-component models such as
    diffusion pipelines keep one index per component directory) and merges
    their weight maps, prefixing keys and values from subdirectory indexes
    with that directory so entries from different components cannot collide.

    Args:
        repo_id: Hugging Face repository id.
        revision: Kept for interface compatibility; the snapshot download
            currently fetches the default revision.

    Returns:
        Mapping from (possibly directory-prefixed) tensor name to the
        safetensors file containing it.
    """
    target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
    await aios.makedirs(target_dir, exist_ok=True)
    # NOTE(review): snapshot_download is synchronous and will block the event
    # loop while index files download — consider offloading to a thread.
    index_files_dir = snapshot_download(
        repo_id=repo_id, local_dir=target_dir, allow_patterns="*.safetensors.index.json"
    )
    index_files = list(Path(index_files_dir).glob("**/*.safetensors.index.json"))
    weight_map: dict[str, str] = {}
    for index_file in index_files:
        relative_dir = index_file.parent.relative_to(index_files_dir)
        async with aiofiles.open(index_file, "r") as f:
            index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
        if relative_dir != Path("."):
            # Namespace component indexes, e.g. "transformer/blocks.0.weight".
            prefixed_weight_map = {
                f"{relative_dir}/{key}": str(relative_dir / value)
                for key, value in index_data.weight_map.items()
            }
            weight_map = weight_map | prefixed_weight_map
        else:
            weight_map = weight_map | index_data.weight_map
    return weight_map
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# TODO: 'Smart' downloads are disabled because:
# (i) We don't handle all kinds of files;
# (ii) We don't have sticky sessions.
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)
@@ -546,8 +571,6 @@ async def download_shard(
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_meta.model_id), revision, recursive=True
)

View File

@@ -100,26 +100,68 @@ def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list
"*.py",
"tokenizer.model",
"tiktoken.model",
"*/spiece.model",
"*.tiktoken",
"*.txt",
"*.jinja",
]
)
shard_specific_patterns: set[str] = set()
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num <= shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
if shard.model_meta.components is not None:
shardable_component = next(
(c for c in shard.model_meta.components if c.can_shard), None
)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
if weight_map and shardable_component:
for tensor_name, filename in weight_map.items():
# Strip component prefix from tensor name (added by weight map namespacing)
# E.g., "transformer/blocks.0.weight" -> "blocks.0.weight"
if "/" in tensor_name:
_, tensor_name_no_prefix = tensor_name.split("/", 1)
else:
tensor_name_no_prefix = tensor_name
# Determine which component this file belongs to from filename
component_path = Path(filename).parts[0] if "/" in filename else None
if component_path == shardable_component.component_path.rstrip("/"):
layer_num = extract_layer_num(tensor_name_no_prefix)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
if shard.is_first_layer or shard.is_last_layer:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns = set(["*.safetensors"])
# TODO(ciaran): temporary - Include all files from non-shardable components that have no index file
for component in shard.model_meta.components:
if not component.can_shard and component.safetensors_index_filename is None:
component_pattern = f"{component.component_path.rstrip('/')}/*"
shard_specific_patterns.add(component_pattern)
else:
shard_specific_patterns = set(["*.safetensors"])
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
else:
shard_specific_patterns = set(["*.safetensors"])
logger.info(f"get_allow_patterns {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)

View File

@@ -0,0 +1,10 @@
"""Public entry points for the image-generation engine package."""
from exo.worker.engines.image.base import ImageGenerator
from exo.worker.engines.image.distributed_model import initialize_image_model
from exo.worker.engines.image.generate import generate_image, warmup_image_generator
# Explicit public API for `from exo.worker.engines.image import *`.
__all__ = [
    "ImageGenerator",
    "generate_image",
    "initialize_image_model",
    "warmup_image_generator",
]

View File

@@ -0,0 +1,50 @@
from collections.abc import Generator
from pathlib import Path
from typing import Literal, Protocol, runtime_checkable
from PIL import Image
@runtime_checkable
class ImageGenerator(Protocol):
    """Structural interface for (possibly distributed) image generators."""
    # 0-based pipeline rank of this stage.
    @property
    def rank(self) -> int: ...
    # True when this process is the first pipeline stage.
    @property
    def is_first_stage(self) -> bool: ...
    def generate(
        self,
        prompt: str,
        height: int,
        width: int,
        quality: Literal["low", "medium", "high"],
        seed: int,
        image_path: Path | None = None,
        partial_images: int = 0,
    ) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
        """Generate an image from a text prompt, or edit an existing image.
        For distributed inference, only the last stage returns images.
        Other stages yield nothing after participating in the pipeline.
        When partial_images > 0, yields intermediate images during diffusion
        as tuples of (image, partial_index, total_partials), then yields
        the final image.
        When partial_images = 0 (default), only yields the final image.
        Args:
            prompt: Text description of the image to generate
            height: Image height in pixels
            width: Image width in pixels
            quality: Generation quality level
            seed: Random seed for reproducibility
            image_path: Optional path to input image for image editing
            partial_images: Number of intermediate images to yield (0 for none)
        Yields:
            Intermediate images as (Image, partial_index, total_partials) tuples
            Final PIL Image (last stage) or nothing (other stages)
        """
        ...

View File

@@ -0,0 +1,74 @@
from enum import Enum
from math import ceil
from pydantic import BaseModel
class BlockType(Enum):
    """Kind of transformer block in the diffusion backbone."""
    JOINT = "joint"  # Separate image/text streams
    SINGLE = "single"  # Concatenated streams
class TransformerBlockConfig(BaseModel):
    """Immutable description of one homogeneous run of transformer blocks."""
    model_config = {"frozen": True}
    # Whether these are joint (dual-stream) or single (merged-stream) blocks.
    block_type: BlockType
    # Number of consecutive blocks of this type.
    count: int
    has_separate_text_output: bool  # True for joint blocks that output text separately
class ImageModelConfig(BaseModel):
    """Frozen, declarative description of an image diffusion model."""
    model_config = {"frozen": True}
    # Model identification
    model_family: str  # "flux", "fibo", "qwen"
    model_variant: str  # "schnell", "dev", etc.
    # Architecture parameters
    hidden_dim: int
    num_heads: int
    head_dim: int
    # Block configuration - ordered sequence of block types
    block_configs: tuple[TransformerBlockConfig, ...]
    # Tokenization parameters
    patch_size: int  # 2 for Flux/Qwen
    vae_scale_factor: int  # 8 for Flux, 16 for others
    # Inference parameters
    default_steps: dict[str, int]  # {"low": X, "medium": Y, "high": Z}
    num_sync_steps_factor: float  # Fraction of steps for sync phase
    # Feature flags
    uses_attention_mask: bool  # True for Fibo
    # CFG (Classifier-Free Guidance) parameters
    guidance_scale: float | None = None  # None or <= 1.0 disables CFG
    def _count_blocks(self, block_type: BlockType | None) -> int:
        """Count configured blocks, optionally restricted to one block type."""
        return sum(
            cfg.count
            for cfg in self.block_configs
            if block_type is None or cfg.block_type == block_type
        )
    @property
    def total_blocks(self) -> int:
        """Total number of transformer blocks."""
        return self._count_blocks(None)
    @property
    def joint_block_count(self) -> int:
        """Number of joint transformer blocks."""
        return self._count_blocks(BlockType.JOINT)
    @property
    def single_block_count(self) -> int:
        """Number of single transformer blocks."""
        return self._count_blocks(BlockType.SINGLE)
    def get_steps_for_quality(self, quality: str) -> int:
        """Get inference steps for a quality level."""
        return self.default_steps[quality]
    def get_num_sync_steps(self, quality: str) -> int:
        """Get number of synchronous steps based on quality."""
        step_count = self.default_steps[quality]
        return ceil(step_count * self.num_sync_steps_factor)

View File

@@ -0,0 +1,228 @@
from collections.abc import Generator
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional
import mlx.core as mx
from mflux.config.config import Config
from PIL import Image
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
create_adapter_for_model,
get_config_for_model,
)
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline import DiffusionRunner
from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier
from exo.worker.runner.bootstrap import logger
class DistributedImageModel:
    """Facade over an mflux image model supporting pipeline-sharded inference.

    Owns a model adapter plus a DiffusionRunner. In distributed mode the
    transformer blocks are sliced down to this node's layer range before
    evaluation, and ranks are synchronized with a barrier before generation.
    Unknown attribute reads/writes are delegated to the wrapped mflux model.
    """
    # __slots__ lets __setattr__ distinguish the wrapper's own fields from
    # attributes that should be forwarded to the wrapped model.
    __slots__ = (
        "_config",
        "_adapter",
        "_group",
        "_shard_metadata",
        "_runner",
    )
    _config: ImageModelConfig
    _adapter: BaseModelAdapter
    _group: Optional[mx.distributed.Group]
    _shard_metadata: PipelineShardMetadata
    _runner: DiffusionRunner
    def __init__(
        self,
        model_id: str,
        local_path: Path,
        shard_metadata: PipelineShardMetadata,
        group: Optional[mx.distributed.Group] = None,
        quantize: int | None = None,
    ):
        """Load the model, optionally slice it to this shard, and build the runner.

        Args:
            model_id: Model identifier used for config/adapter lookup.
            local_path: Directory containing the downloaded model weights.
            shard_metadata: Layer range and rank info for this pipeline stage.
            group: mlx distributed group; None means single-node operation.
            quantize: Optional quantization bit-width passed to the adapter.
        """
        # Get model config and create adapter (adapter owns the model)
        config = get_config_for_model(model_id)
        adapter = create_adapter_for_model(config, model_id, local_path, quantize)
        if group is not None:
            # Drop blocks outside this node's layer range before weights are
            # materialized, so unused layers are never loaded.
            adapter.slice_transformer_blocks(
                start_layer=shard_metadata.start_layer,
                end_layer=shard_metadata.end_layer,
                total_joint_blocks=config.joint_block_count,
                total_single_blocks=config.single_block_count,
            )
        # Create diffusion runner (handles both single-node and distributed modes)
        num_sync_steps = config.get_num_sync_steps("medium") if group else 0
        runner = DiffusionRunner(
            config=config,
            adapter=adapter,
            group=group,
            shard_metadata=shard_metadata,
            num_sync_steps=num_sync_steps,
        )
        if group is not None:
            logger.info("Initialized distributed diffusion runner")
            mx.eval(adapter.model.parameters())
            # TODO(ciaran): Do we need this?
            mx.eval(adapter.model)
            # Synchronize processes before generation to avoid timeout
            mx_barrier(group)
            logger.info(f"Transformer sharded for rank {group.rank()}")
        else:
            logger.info("Single-node initialization")
        # object.__setattr__ is required because this class overrides
        # __setattr__ to forward non-slot names to the wrapped model.
        object.__setattr__(self, "_config", config)
        object.__setattr__(self, "_adapter", adapter)
        object.__setattr__(self, "_group", group)
        object.__setattr__(self, "_shard_metadata", shard_metadata)
        object.__setattr__(self, "_runner", runner)
    @classmethod
    def from_bound_instance(
        cls, bound_instance: BoundInstance
    ) -> "DistributedImageModel":
        """Build the model described by a BoundInstance.

        Initializes the mlx distributed group when the instance spans more
        than one runner; raises ValueError for non-pipeline shard metadata.
        """
        model_id = bound_instance.bound_shard.model_meta.model_id
        model_path = build_model_path(model_id)
        shard_metadata = bound_instance.bound_shard
        if not isinstance(shard_metadata, PipelineShardMetadata):
            raise ValueError("Expected PipelineShardMetadata for image generation")
        is_distributed = (
            len(bound_instance.instance.shard_assignments.node_to_runner) > 1
        )
        if is_distributed:
            logger.info("Starting distributed init for image model")
            group = mlx_distributed_init(bound_instance)
        else:
            group = None
        return cls(
            model_id=model_id,
            local_path=model_path,
            shard_metadata=shard_metadata,
            group=group,
        )
    @property
    def model(self) -> Any:
        """Return the underlying mflux model via the adapter."""
        return self._adapter.model
    @property
    def config(self) -> ImageModelConfig:
        # Static architecture/inference configuration.
        return self._config
    @property
    def adapter(self) -> BaseModelAdapter:
        # Model-family-specific adapter wrapping the mflux model.
        return self._adapter
    @property
    def group(self) -> Optional[mx.distributed.Group]:
        # Distributed group, or None in single-node mode.
        return self._group
    @property
    def shard_metadata(self) -> PipelineShardMetadata:
        # Layer range and rank info for this stage.
        return self._shard_metadata
    @property
    def rank(self) -> int:
        # Pipeline rank of this stage (0-based).
        return self._shard_metadata.device_rank
    @property
    def world_size(self) -> int:
        # Total number of pipeline stages.
        return self._shard_metadata.world_size
    @property
    def is_first_stage(self) -> bool:
        return self._shard_metadata.device_rank == 0
    @property
    def is_last_stage(self) -> bool:
        return self._shard_metadata.device_rank == self._shard_metadata.world_size - 1
    @property
    def is_distributed(self) -> bool:
        return self._shard_metadata.world_size > 1
    @property
    def runner(self) -> DiffusionRunner:
        # Diffusion loop driver (sync + pipelined phases).
        return self._runner
    # Delegate attribute access to the underlying model via the adapter.
    # Guarded with TYPE_CHECKING to prevent type checker complaints
    # while still providing full delegation at runtime.
    if not TYPE_CHECKING:
        def __getattr__(self, name: str) -> Any:
            # Only called for names not found on the wrapper itself.
            return getattr(self._adapter.model, name)
    def __setattr__(self, name: str, value: Any) -> None:
        # Slot names stay on the wrapper; everything else is written through
        # to the wrapped model so callers can treat this as the model.
        if name in (
            "_config",
            "_adapter",
            "_group",
            "_shard_metadata",
            "_runner",
        ):
            object.__setattr__(self, name, value)
        else:
            setattr(self._adapter.model, name, value)
    def generate(
        self,
        prompt: str,
        height: int,
        width: int,
        quality: Literal["low", "medium", "high"] = "medium",
        seed: int = 2,
        image_path: Path | None = None,
        partial_images: int = 0,
    ) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
        """Generate (or edit, when image_path is set) an image.

        Yields (image, partial_index, total_partials) tuples for partial
        results when partial_images > 0, then the final PIL image.
        """
        # Determine number of inference steps based on quality
        steps = self._config.get_steps_for_quality(quality)
        # For edit mode: compute dimensions from input image
        # This also stores image_paths in the adapter for encode_prompt()
        if image_path is not None:
            computed_dims = self._adapter.set_image_dimensions(image_path)
            if computed_dims is not None:
                # Override user-provided dimensions with computed ones
                width, height = computed_dims
        config = Config(
            num_inference_steps=steps,
            height=height,
            width=width,
            image_path=image_path,
        )
        # Generate images via the runner
        for result in self._runner.generate_image(
            settings=config,
            prompt=prompt,
            seed=seed,
            partial_images=partial_images,
        ):
            if isinstance(result, tuple):
                # Partial image: (GeneratedImage, partial_index, total_partials)
                generated_image, partial_idx, total_partials = result
                yield (generated_image.image, partial_idx, total_partials)
            else:
                # Final image: GeneratedImage
                logger.info("generated image")
                yield result.image
def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:
    """Build a DistributedImageModel for the given bound instance."""
    model = DistributedImageModel.from_bound_instance(bound_instance)
    return model

View File

@@ -0,0 +1,120 @@
import base64
import io
import tempfile
from pathlib import Path
from typing import Generator, Literal
from PIL import Image
from exo.shared.types.api import ImageEditsInternalParams, ImageGenerationTaskParams
from exo.shared.types.worker.runner_response import (
ImageGenerationResponse,
PartialImageResponse,
)
from exo.worker.engines.image.base import ImageGenerator
def parse_size(size_str: str | None) -> tuple[int, int]:
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
if not size_str or size_str == "auto":
size_str = "1024x1024"
try:
parts = size_str.split("x")
if len(parts) == 2:
width, height = int(parts[0]), int(parts[1])
return (width, height)
except (ValueError, AttributeError):
pass
# Default fallback
return (1024, 1024)
def warmup_image_generator(model: ImageGenerator) -> Image.Image | None:
    """Run one tiny low-quality generation to warm up the model.

    Returns the produced image on the last pipeline stage, or None when
    this stage yields no final image.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # Edit models need an input image, so synthesize a small grey one.
        warmup_input = Path(tmpdir) / "warmup.png"
        Image.new("RGB", (256, 256), color=(128, 128, 128)).save(warmup_input)
        outputs = model.generate(
            prompt="Warmup",
            height=256,
            width=256,
            quality="low",
            seed=2,
            image_path=warmup_input,
        )
        # Partial results arrive as tuples; the final image is a bare Image.
        return next((out for out in outputs if not isinstance(out, tuple)), None)
def _encode_image(image: Image.Image, output_format: str) -> bytes:
    """Encode a PIL image to bytes in the requested output format."""
    buffer = io.BytesIO()
    image_format = output_format.upper()
    if image_format == "JPG":
        # PIL's canonical name for .jpg output is "JPEG".
        image_format = "JPEG"
    image.save(buffer, format=image_format)
    return buffer.getvalue()


def generate_image(
    model: ImageGenerator,
    task: ImageGenerationTaskParams | ImageEditsInternalParams,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
    """Generate image(s), optionally yielding partial results.
    When partial_images > 0 or stream=True, yields PartialImageResponse for
    intermediate images, then ImageGenerationResponse for the final image.
    Yields:
        PartialImageResponse for intermediate images (if partial_images > 0)
        ImageGenerationResponse for the final complete image
    """
    width, height = parse_size(task.size)
    quality: Literal["low", "medium", "high"] = task.quality or "medium"
    seed = 2  # TODO(ciaran): Randomise when not testing anymore
    # Handle streaming params for both generation and edit tasks
    # (streaming without an explicit count defaults to 3 partials).
    partial_images = task.partial_images or (3 if task.stream else 0)
    image_path: Path | None = None
    with tempfile.TemporaryDirectory() as tmpdir:
        if isinstance(task, ImageEditsInternalParams):
            # Decode base64 image data and save to temp file
            image_path = Path(tmpdir) / "input.png"
            image_path.write_bytes(base64.b64decode(task.image_data))
        # Iterate over generator results
        for result in model.generate(
            prompt=task.prompt,
            height=height,
            width=width,
            quality=quality,
            seed=seed,
            image_path=image_path,
            partial_images=partial_images,
        ):
            if isinstance(result, tuple):
                # Partial image: (Image, partial_index, total_partials)
                image, partial_idx, total_partials = result
                yield PartialImageResponse(
                    image_data=_encode_image(image, task.output_format),
                    format=task.output_format,
                    partial_index=partial_idx,
                    total_partials=total_partials,
                )
            else:
                # Final image
                yield ImageGenerationResponse(
                    image_data=_encode_image(result, task.output_format),
                    format=task.output_format,
                )

View File

@@ -0,0 +1,84 @@
from pathlib import Path
from typing import Callable
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.flux import (
FLUX_DEV_CONFIG,
FLUX_SCHNELL_CONFIG,
FluxModelAdapter,
)
from exo.worker.engines.image.models.qwen import (
QWEN_IMAGE_CONFIG,
QWEN_IMAGE_EDIT_CONFIG,
QwenEditModelAdapter,
QwenModelAdapter,
)
from exo.worker.engines.image.pipeline.adapter import ModelAdapter
# NOTE(review): empty __all__ means star-imports export nothing from this
# module — callers must import names explicitly. Confirm this is intended.
__all__: list[str] = []
# Type alias for adapter factory functions
# Factory takes (config, model_id, local_path, quantize) and returns a ModelAdapter
AdapterFactory = Callable[[ImageModelConfig, str, Path, int | None], ModelAdapter]
# Registry maps model_family string to adapter factory
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
    "flux": FluxModelAdapter,
    "qwen-edit": QwenEditModelAdapter,
    "qwen": QwenModelAdapter,
}
# Config registry: maps model ID patterns to configs
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
    "flux.1-schnell": FLUX_SCHNELL_CONFIG,
    "flux.1-dev": FLUX_DEV_CONFIG,
    "qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG,  # Must come before "qwen-image" for pattern matching
    "qwen-image": QWEN_IMAGE_CONFIG,
}
def get_config_for_model(model_id: str) -> ImageModelConfig:
    """Get configuration for a model ID.

    Patterns are tried longest-first so that a more specific pattern (e.g.
    "qwen-image-edit") always wins over one of its substrings (e.g.
    "qwen-image"), regardless of registry insertion order. Previously this
    silently depended on dict ordering.

    Args:
        model_id: The model identifier (e.g., "black-forest-labs/FLUX.1-schnell")
    Returns:
        The model configuration
    Raises:
        ValueError: If no configuration found for model ID
    """
    model_id_lower = model_id.lower()
    for pattern in sorted(_CONFIG_REGISTRY, key=len, reverse=True):
        if pattern in model_id_lower:
            return _CONFIG_REGISTRY[pattern]
    raise ValueError(f"No configuration found for model: {model_id}")
def create_adapter_for_model(
    config: ImageModelConfig,
    model_id: str,
    local_path: Path,
    quantize: int | None = None,
) -> ModelAdapter:
    """Instantiate the adapter registered for config.model_family.

    Args:
        config: The model configuration
        model_id: The model identifier
        local_path: Path to the model weights
        quantize: Optional quantization bits
    Returns:
        A ModelAdapter instance
    Raises:
        ValueError: If no adapter found for model family
    """
    try:
        factory = _ADAPTER_REGISTRY[config.model_family]
    except KeyError:
        raise ValueError(
            f"No adapter found for model family: {config.model_family}"
        ) from None
    return factory(config, model_id, local_path, quantize)

View File

@@ -0,0 +1,103 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
import mlx.core as mx
from mflux.config.runtime_config import RuntimeConfig
from mflux.models.common.latent_creator.latent_creator import Img2Img, LatentCreator
from mflux.utils.array_util import ArrayUtil
from mflux.utils.image_util import ImageUtil
class BaseModelAdapter(ABC):
    """Base class for model adapters with shared utilities.
    Provides common implementations for latent creation and decoding.
    Subclasses implement model-specific prompt encoding and noise computation.
    """
    def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
        """Create initial latents. Uses model-specific latent creator.
        Delegates to mflux's LatentCreator, which covers both txt2img and
        img2img (driven by runtime_config.image_path / init_time_step).
        """
        return LatentCreator.create_for_txt2img_or_img2img(
            seed=seed,
            height=runtime_config.height,
            width=runtime_config.width,
            img2img=Img2Img(
                vae=self.model.vae,
                latent_creator=self._get_latent_creator(),
                sigmas=runtime_config.scheduler.sigmas,
                init_time_step=runtime_config.init_time_step,
                image_path=runtime_config.image_path,
            ),
        )
    def decode_latents(
        self,
        latents: mx.array,
        runtime_config: RuntimeConfig,
        seed: int,
        prompt: str,
    ) -> Any:
        """Decode latents to image. Shared implementation.
        Unpacks the packed latent sequence back to spatial form, runs the
        VAE decoder, and wraps the result with generation metadata via
        mflux's ImageUtil.
        """
        latents = ArrayUtil.unpack_latents(
            latents=latents,
            height=runtime_config.height,
            width=runtime_config.width,
        )
        decoded = self.model.vae.decode(latents)
        return ImageUtil.to_image(
            decoded_latents=decoded,
            config=runtime_config,
            seed=seed,
            prompt=prompt,
            quantization=self.model.bits,
            lora_paths=self.model.lora_paths,
            lora_scales=self.model.lora_scales,
            image_path=runtime_config.image_path,
            image_strength=runtime_config.image_strength,
            # Wall-clock timing is not tracked at this layer.
            generation_time=0,
        )
    # Abstract methods - subclasses must implement
    @property
    @abstractmethod
    def model(self) -> Any:
        """Return the underlying mflux model."""
        ...
    @abstractmethod
    def _get_latent_creator(self) -> type:
        """Return the latent creator class for this model."""
        ...
    @abstractmethod
    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
        total_joint_blocks: int,
        total_single_blocks: int,
    ):
        """Remove transformer blocks outside the assigned range.
        This should be called BEFORE mx.eval() to avoid loading unused weights
        in distributed mode.
        Args:
            start_layer: First layer index (inclusive) assigned to this node
            end_layer: Last layer index (exclusive) assigned to this node
            total_joint_blocks: Total number of joint blocks in the model
            total_single_blocks: Total number of single blocks in the model
        """
        ...
    def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
        """Default implementation: no dimension computation needed.
        Override in edit adapters to compute dimensions from input image.
        Returns:
            None (use user-specified dimensions)
        """
        return None

View File

@@ -0,0 +1,11 @@
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
from exo.worker.engines.image.models.flux.config import (
FLUX_DEV_CONFIG,
FLUX_SCHNELL_CONFIG,
)
__all__ = [
"FluxModelAdapter",
"FLUX_DEV_CONFIG",
"FLUX_SCHNELL_CONFIG",
]

View File

@@ -0,0 +1,680 @@
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
from mflux.config.model_config import ModelConfig
from mflux.config.runtime_config import RuntimeConfig
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
from mflux.models.flux.model.flux_transformer.common.attention_utils import (
AttentionUtils,
)
from mflux.models.flux.model.flux_transformer.joint_transformer_block import (
JointTransformerBlock,
)
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.variants.txt2img.flux import Flux1
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class FluxPromptData:
    """Container for Flux prompt encoding results."""
    def __init__(self, prompt_embeds: mx.array, pooled_prompt_embeds: mx.array):
        self._embeds = prompt_embeds
        self._pooled = pooled_prompt_embeds
    @property
    def prompt_embeds(self) -> mx.array:
        """Sequence prompt embeddings."""
        return self._embeds
    @property
    def pooled_prompt_embeds(self) -> mx.array:
        """Pooled prompt embedding."""
        return self._pooled
    @property
    def negative_prompt_embeds(self) -> mx.array | None:
        """Flux does not use CFG."""
        return None
    @property
    def negative_pooled_prompt_embeds(self) -> mx.array | None:
        """Flux does not use CFG."""
        return None
    def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
        """Flux has no extra forward kwargs."""
        return {}
    @property
    def conditioning_latents(self) -> mx.array | None:
        """Flux does not use conditioning latents."""
        return None
class FluxModelAdapter(BaseModelAdapter):
    def __init__(
        self,
        config: ImageModelConfig,
        model_id: str,
        local_path: Path,
        quantize: int | None = None,
    ):
        """Load the Flux model from local weights.

        Args:
            config: Architecture/inference configuration for this variant.
            model_id: Identifier used to resolve the mflux ModelConfig.
            local_path: Directory containing the model weights.
            quantize: Optional quantization bit-width (None = unquantized).
        """
        self._config = config
        self._model = Flux1(
            model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
            local_path=str(local_path),
            quantize=quantize,
        )
        # Cache the transformer for direct use in block-level helpers.
        self._transformer = self._model.transformer
    @property
    def config(self) -> ImageModelConfig:
        """Image model configuration for this adapter."""
        return self._config
    @property
    def model(self) -> Flux1:
        """The underlying mflux Flux1 model."""
        return self._model
    @property
    def transformer(self) -> Transformer:
        """The Flux transformer backbone."""
        return self._transformer
    @property
    def hidden_dim(self) -> int:
        # Output dimension of the x_embedder weight, i.e. the transformer width.
        return self._transformer.x_embedder.weight.shape[0]
    def _get_latent_creator(self) -> type:
        """Latent creator class used by BaseModelAdapter.create_latents."""
        return FluxLatentCreator
def encode_prompt(self, prompt: str) -> FluxPromptData:
"""Encode prompt into FluxPromptData."""
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
prompt=prompt,
prompt_cache=self._model.prompt_cache,
t5_tokenizer=self._model.t5_tokenizer,
clip_tokenizer=self._model.clip_tokenizer,
t5_text_encoder=self._model.t5_text_encoder,
clip_text_encoder=self._model.clip_text_encoder,
)
return FluxPromptData(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
)
    @property
    def needs_cfg(self) -> bool:
        """Flux never uses classifier-free guidance."""
        return False
    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        # Present only to satisfy the adapter interface; unreachable because
        # needs_cfg is always False for Flux.
        raise NotImplementedError("Flux does not use classifier-free guidance")
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
embedded_hidden = self._transformer.x_embedder(hidden_states)
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
return embedded_hidden, embedded_encoder
    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: RuntimeConfig,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,  # Ignored by Flux
    ) -> mx.array:
        """Compute the time/text conditioning embedding for step *t*.

        Flux conditions on the pooled prompt embedding; *hidden_states* is
        accepted only for interface compatibility with other adapters.

        Raises:
            ValueError: If pooled_prompt_embeds is not provided.
        """
        if pooled_prompt_embeds is None:
            raise ValueError(
                "pooled_prompt_embeds is required for Flux text embeddings"
            )
        # hidden_states is ignored - Flux uses pooled_prompt_embeds instead
        return Transformer.compute_text_embeddings(
            t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
        )
    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: RuntimeConfig,
        **kwargs: Any,
    ) -> mx.array:
        """Compute RoPE frequencies for the [text, image] token sequence.

        Optional "kontext_image_ids" in kwargs are forwarded to mflux
        (None when absent).
        """
        kontext_image_ids = kwargs.get("kontext_image_ids")
        return Transformer.compute_rotary_embeddings(
            prompt_embeds,
            self._transformer.pos_embed,
            runtime_config,
            kontext_image_ids,
        )
    def apply_joint_block(
        self,
        block: JointBlockInterface,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,  # mx.array for Flux
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
        **kwargs: Any,
    ) -> tuple[mx.array, mx.array]:
        """Run one joint transformer block, dispatching on *mode*.

        CACHING mode processes the full image sequence (optionally filling
        kv_cache); otherwise only the image patch [patch_start, patch_end)
        is processed against cached K/V, so patch bounds and a cache are
        required. Returns (encoder_hidden_states, hidden_states).
        """
        if mode == BlockWrapperMode.CACHING:
            return self._apply_joint_block_caching(
                block=block,
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
            )
        else:
            # Patched mode needs explicit bounds and a populated cache.
            assert patch_start is not None and patch_end is not None
            assert kv_cache is not None
            return self._apply_joint_block_patched(
                block=block,
                patch_hidden=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
                patch_start=patch_start,
                patch_end=patch_end,
            )
    def apply_single_block(
        self,
        block: SingleBlockInterface,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
    ) -> mx.array:
        """Run one single (merged-stream) block, dispatching on *mode*.

        CACHING mode consumes the full merged [text, image] sequence;
        otherwise only the image patch [patch_start, patch_end) is computed
        against cached K/V, so patch bounds and a cache are required.
        """
        if mode == BlockWrapperMode.CACHING:
            return self._apply_single_block_caching(
                block=block,
                hidden_states=hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
            )
        else:
            # Patched mode needs explicit bounds and a populated cache.
            assert patch_start is not None and patch_end is not None
            assert kv_cache is not None
            return self._apply_single_block_patched(
                block=block,
                patch_hidden=hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
                patch_start=patch_start,
                patch_end=patch_end,
            )
def final_projection(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
return self._transformer.proj_out(hidden_states)
    def get_joint_blocks(self) -> list[JointBlockInterface]:
        """Joint (dual-stream) blocks currently held by this adapter/shard."""
        return cast(
            list[JointBlockInterface], list(self._transformer.transformer_blocks)
        )
    def get_single_blocks(self) -> list[SingleBlockInterface]:
        """Single (merged-stream) blocks currently held by this adapter/shard."""
        return cast(
            list[SingleBlockInterface],
            list(self._transformer.single_transformer_blocks),
        )
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
total_joint_blocks: int,
total_single_blocks: int,
) -> None:
if end_layer <= total_joint_blocks:
# All assigned are joint blocks
joint_start, joint_end = start_layer, end_layer
single_start, single_end = 0, 0
elif start_layer >= total_joint_blocks:
# All assigned are single blocks
joint_start, joint_end = 0, 0
single_start = start_layer - total_joint_blocks
single_end = end_layer - total_joint_blocks
else:
# Spans both joint and single
joint_start, joint_end = start_layer, total_joint_blocks
single_start = 0
single_end = end_layer - total_joint_blocks
all_joint = list(self._transformer.transformer_blocks)
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
all_single = list(self._transformer.single_transformer_blocks)
self._transformer.single_transformer_blocks = all_single[
single_start:single_end
]
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
    def _apply_joint_block_caching(
        self,
        block: JointBlockInterface,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache | None,
        text_seq_len: int,
    ) -> tuple[mx.array, mx.array]:
        """Joint block over the full [text, image] sequence.

        Computes attention over all tokens; when kv_cache is provided, the
        image-token K/V (post-RoPE key, pre-RoPE value) are stored so later
        patched steps can reuse them. Returns the updated
        (encoder_hidden_states, hidden_states) pair.
        """
        num_img_tokens = hidden_states.shape[1]
        batch_size = hidden_states.shape[0]
        attn = block.attn
        num_heads = attn.num_heads
        head_dim = attn.head_dimension
        # 1. Compute norms
        norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
            hidden_states=hidden_states,
            text_embeddings=text_embeddings,
        )
        norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
            block.norm1_context(
                hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
            )
        )
        # 2. Compute Q, K, V for full image
        img_query, img_key, img_value = AttentionUtils.process_qkv(
            hidden_states=norm_hidden,
            to_q=attn.to_q,
            to_k=attn.to_k,
            to_v=attn.to_v,
            norm_q=attn.norm_q,
            norm_k=attn.norm_k,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 3. Compute Q, K, V for text
        txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
            hidden_states=norm_encoder,
            to_q=attn.add_q_proj,
            to_k=attn.add_k_proj,
            to_v=attn.add_v_proj,
            norm_q=attn.norm_added_q,
            norm_k=attn.norm_added_k,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 4. Concatenate Q, K, V: [text, image]
        query = mx.concatenate([txt_query, img_query], axis=2)
        key = mx.concatenate([txt_key, img_key], axis=2)
        value = mx.concatenate([txt_value, img_value], axis=2)
        # 5. Apply RoPE
        query, key = AttentionUtils.apply_rope(
            xq=query, xk=key, freqs_cis=rotary_embeddings
        )
        # 6. Store IMAGE K/V in cache for async pipeline
        if kv_cache is not None:
            kv_cache.update_image_patch(
                patch_start=0,
                patch_end=num_img_tokens,
                key=key[:, :, text_seq_len:, :],
                value=value[:, :, text_seq_len:, :],
            )
        # 7. Compute full attention
        attn_output = AttentionUtils.compute_attention(
            query=query,
            key=key,
            value=value,
            batch_size=batch_size,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 8. Extract and project outputs
        # Attention output is ordered [text, image]; split then project each
        # stream through its own output projection.
        context_attn_output = attn_output[:, :text_seq_len, :]
        attn_output = attn_output[:, text_seq_len:, :]
        attn_output = attn.to_out[0](attn_output)
        context_attn_output = attn.to_add_out(context_attn_output)
        # 9. Apply norm and feed forward
        hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
            hidden_states=hidden_states,
            attn_output=attn_output,
            gate_mlp=gate_mlp,
            gate_msa=gate_msa,
            scale_mlp=scale_mlp,
            shift_mlp=shift_mlp,
            norm_layer=block.norm2,
            ff_layer=block.ff,
        )
        encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
            hidden_states=encoder_hidden_states,
            attn_output=context_attn_output,
            gate_mlp=c_gate_mlp,
            gate_msa=c_gate_msa,
            scale_mlp=c_scale_mlp,
            shift_mlp=c_shift_mlp,
            norm_layer=block.norm2_context,
            ff_layer=block.ff_context,
        )
        return encoder_hidden_states, hidden_states
    def _apply_joint_block_patched(
        self,
        block: JointBlockInterface,
        patch_hidden: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache,
        text_seq_len: int,
        patch_start: int,
        patch_end: int,
    ) -> tuple[mx.array, mx.array]:
        """Joint block over a single image patch using cached K/V.

        Only image tokens [patch_start, patch_end) are freshly computed;
        the rest of the image K/V comes from kv_cache (refreshed for this
        patch), while text K/V is recomputed each call. Returns the updated
        (encoder_hidden_states, patch_hidden) pair.
        """
        batch_size = patch_hidden.shape[0]
        attn = block.attn
        num_heads = attn.num_heads
        head_dim = attn.head_dimension
        # 1. Compute norms
        norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
            hidden_states=patch_hidden,
            text_embeddings=text_embeddings,
        )
        norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
            block.norm1_context(
                hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
            )
        )
        # 2. Compute Q, K, V for image patch
        img_query, img_key, img_value = AttentionUtils.process_qkv(
            hidden_states=norm_hidden,
            to_q=attn.to_q,
            to_k=attn.to_k,
            to_v=attn.to_v,
            norm_q=attn.norm_q,
            norm_k=attn.norm_k,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 3. Compute Q, K, V for text
        txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
            hidden_states=norm_encoder,
            to_q=attn.add_q_proj,
            to_k=attn.add_k_proj,
            to_v=attn.add_v_proj,
            norm_q=attn.norm_added_q,
            norm_k=attn.norm_added_k,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 4. Concatenate Q, K, V for patch: [text, patch]
        query = mx.concatenate([txt_query, img_query], axis=2)
        patch_key = mx.concatenate([txt_key, img_key], axis=2)
        patch_value = mx.concatenate([txt_value, img_value], axis=2)
        # 5. Extract RoPE for [text + current_patch]
        # The full rotary table covers [text, all image tokens]; slice out
        # only the frequencies this patch's positions need.
        text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
        patch_img_rope = rotary_embeddings[
            :, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
        ]
        patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
        # 6. Apply RoPE
        query, patch_key = AttentionUtils.apply_rope(
            xq=query, xk=patch_key, freqs_cis=patch_rope
        )
        # 7. Update cache with this patch's IMAGE K/V
        kv_cache.update_image_patch(
            patch_start=patch_start,
            patch_end=patch_end,
            key=patch_key[:, :, text_seq_len:, :],
            value=patch_value[:, :, text_seq_len:, :],
        )
        # 8. Get full K, V from cache
        full_key, full_value = kv_cache.get_full_kv(
            text_key=patch_key[:, :, :text_seq_len, :],
            text_value=patch_value[:, :, :text_seq_len, :],
        )
        # 9. Compute attention
        attn_output = AttentionUtils.compute_attention(
            query=query,
            key=full_key,
            value=full_value,
            batch_size=batch_size,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 10. Extract and project outputs
        context_attn_output = attn_output[:, :text_seq_len, :]
        hidden_attn_output = attn_output[:, text_seq_len:, :]
        hidden_attn_output = attn.to_out[0](hidden_attn_output)
        context_attn_output = attn.to_add_out(context_attn_output)
        # 11. Apply norm and feed forward
        patch_hidden = JointTransformerBlock.apply_norm_and_feed_forward(
            hidden_states=patch_hidden,
            attn_output=hidden_attn_output,
            gate_mlp=gate_mlp,
            gate_msa=gate_msa,
            scale_mlp=scale_mlp,
            shift_mlp=shift_mlp,
            norm_layer=block.norm2,
            ff_layer=block.ff,
        )
        encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
            hidden_states=encoder_hidden_states,
            attn_output=context_attn_output,
            gate_mlp=c_gate_mlp,
            gate_msa=c_gate_msa,
            scale_mlp=c_scale_mlp,
            shift_mlp=c_shift_mlp,
            norm_layer=block.norm2_context,
            ff_layer=block.ff_context,
        )
        return encoder_hidden_states, patch_hidden
def _apply_single_block_caching(
    self,
    block: SingleBlockInterface,
    hidden_states: mx.array,
    text_embeddings: mx.array,
    rotary_embeddings: mx.array,
    kv_cache: ImagePatchKVCache | None,
    text_seq_len: int,
) -> mx.array:
    """Run a single-stream block over the full [text + image] sequence.

    When a cache is supplied, the image portion of K/V is recorded so a
    later patched pass can attend against it without recomputation.
    """
    attn = block.attn
    heads = attn.num_heads
    dim_per_head = attn.head_dimension
    bsz = hidden_states.shape[0]
    img_token_count = hidden_states.shape[1] - text_seq_len

    # Keep the block input for the residual connection at the end.
    skip = hidden_states

    # Modulated normalization conditioned on the time/text embedding.
    modulated, gate = block.norm(
        hidden_states=hidden_states,
        text_embeddings=text_embeddings,
    )

    # Project to Q/K/V and rotate with the full-sequence RoPE.
    query, key, value = AttentionUtils.process_qkv(
        hidden_states=modulated,
        to_q=attn.to_q,
        to_k=attn.to_k,
        to_v=attn.to_v,
        norm_q=attn.norm_q,
        norm_k=attn.norm_k,
        num_heads=heads,
        head_dim=dim_per_head,
    )
    query, key = AttentionUtils.apply_rope(
        xq=query, xk=key, freqs_cis=rotary_embeddings
    )

    # Store only the IMAGE slice of K/V; text K/V is cheap to recompute.
    if kv_cache is not None:
        kv_cache.update_image_patch(
            patch_start=0,
            patch_end=img_token_count,
            key=key[:, :, text_seq_len:, :],
            value=value[:, :, text_seq_len:, :],
        )

    # Full attention over the joint [text + image] sequence.
    attended = AttentionUtils.compute_attention(
        query=query,
        key=key,
        value=value,
        batch_size=bsz,
        num_heads=heads,
        head_dim=dim_per_head,
    )

    # Gated feed-forward + output projection, then the residual add.
    projected = block._apply_feed_forward_and_projection(
        norm_hidden_states=modulated,
        attn_output=attended,
        gate=gate,
    )
    return skip + projected
def _apply_single_block_patched(
    self,
    block: SingleBlockInterface,
    patch_hidden: mx.array,
    text_embeddings: mx.array,
    rotary_embeddings: mx.array,
    kv_cache: ImagePatchKVCache,
    text_seq_len: int,
    patch_start: int,
    patch_end: int,
) -> mx.array:
    """Run a single-stream block on one image patch against cached K/V.

    `patch_hidden` holds [text + one image patch]. RoPE is sliced to match
    that layout, the patch's image K/V is written into `kv_cache`, and the
    patch queries then attend over the full cached [text + image] K/V.
    `patch_start`/`patch_end` index into the image token range (i.e. they
    exclude the leading `text_seq_len` text tokens).
    """
    batch_size = patch_hidden.shape[0]
    attn = block.attn
    num_heads = attn.num_heads
    head_dim = attn.head_dimension
    # Residual connection
    residual = patch_hidden
    # 1. Compute norm
    norm_hidden, gate = block.norm(
        hidden_states=patch_hidden,
        text_embeddings=text_embeddings,
    )
    # 2. Compute Q, K, V
    query, key, value = AttentionUtils.process_qkv(
        hidden_states=norm_hidden,
        to_q=attn.to_q,
        to_k=attn.to_k,
        to_v=attn.to_v,
        norm_q=attn.norm_q,
        norm_k=attn.norm_k,
        num_heads=num_heads,
        head_dim=head_dim,
    )
    # 3. Extract RoPE for [text + current_patch]: image positions are
    # offset by text_seq_len in the full rotary table.
    text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
    patch_img_rope = rotary_embeddings[
        :, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
    ]
    patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
    # 4. Apply RoPE
    query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=patch_rope)
    # 5. Update cache with this patch's IMAGE K/V (tokens after text_seq_len)
    kv_cache.update_image_patch(
        patch_start=patch_start,
        patch_end=patch_end,
        key=key[:, :, text_seq_len:, :],
        value=value[:, :, text_seq_len:, :],
    )
    # 6. Get full K, V from cache: freshly-computed text K/V joined with
    # the cached image K/V for the whole image.
    full_key, full_value = kv_cache.get_full_kv(
        text_key=key[:, :, :text_seq_len, :],
        text_value=value[:, :, :text_seq_len, :],
    )
    # 7. Compute attention (patch queries over the full sequence)
    attn_output = AttentionUtils.compute_attention(
        query=query,
        key=full_key,
        value=full_value,
        batch_size=batch_size,
        num_heads=num_heads,
        head_dim=head_dim,
    )
    # 8. Apply feed forward and projection
    hidden_states = block._apply_feed_forward_and_projection(
        norm_hidden_states=norm_hidden,
        attn_output=attn_output,
        gate=gate,
    )
    return residual + hidden_states

View File

@@ -0,0 +1,48 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
# Flux.1-schnell: few-step variant (defaults span 1-4 steps).
FLUX_SCHNELL_CONFIG = ImageModelConfig(
    model_family="flux",
    model_variant="schnell",
    hidden_dim=3072,  # 24 heads * 128 head_dim
    num_heads=24,
    head_dim=128,
    block_configs=(
        # 19 dual-stream (joint) blocks followed by 38 single-stream blocks.
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=19, has_separate_text_output=True
        ),
        TransformerBlockConfig(
            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
        ),
    ),
    patch_size=2,
    vae_scale_factor=8,
    default_steps={"low": 1, "medium": 2, "high": 4},
    num_sync_steps_factor=0.5,  # 1 sync step for medium (2 steps)
    uses_attention_mask=False,
)

# Flux.1-dev: same architecture as schnell, but needs many more steps.
FLUX_DEV_CONFIG = ImageModelConfig(
    model_family="flux",
    model_variant="dev",
    hidden_dim=3072,  # 24 heads * 128 head_dim
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=19, has_separate_text_output=True
        ),
        TransformerBlockConfig(
            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
        ),
    ),
    patch_size=2,
    vae_scale_factor=8,
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,  # ~3 sync steps for medium (25 steps)
    uses_attention_mask=False,
)

View File

@@ -0,0 +1,13 @@
from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter
from exo.worker.engines.image.models.qwen.config import (
QWEN_IMAGE_CONFIG,
QWEN_IMAGE_EDIT_CONFIG,
)
from exo.worker.engines.image.models.qwen.edit_adapter import QwenEditModelAdapter
# Public API of the Qwen image-model package.
__all__ = [
    "QwenModelAdapter",
    "QwenEditModelAdapter",
    "QWEN_IMAGE_CONFIG",
    "QWEN_IMAGE_EDIT_CONFIG",
]

View File

@@ -0,0 +1,519 @@
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
from mflux.config.model_config import ModelConfig
from mflux.config.runtime_config import RuntimeConfig
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_text_encoder.qwen_prompt_encoder import (
QwenPromptEncoder,
)
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
QwenTransformerBlock,
)
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class QwenPromptData:
    """Holds the four tensors produced by Qwen prompt encoding.

    Satisfies the PromptData protocol. Qwen has no pooled embeddings, so
    the pooled properties simply alias the regular embeddings.
    """

    def __init__(
        self,
        prompt_embeds: mx.array,
        prompt_mask: mx.array,
        negative_prompt_embeds: mx.array,
        negative_prompt_mask: mx.array,
    ):
        self._positive = prompt_embeds
        self.prompt_mask = prompt_mask
        self._negative = negative_prompt_embeds
        self.negative_prompt_mask = negative_prompt_mask

    @property
    def prompt_embeds(self) -> mx.array:
        """Positive text embeddings from the encoder."""
        return self._positive

    @property
    def pooled_prompt_embeds(self) -> mx.array:
        """Alias of prompt_embeds (protocol compliance; no pooled embeds)."""
        return self._positive

    @property
    def negative_prompt_embeds(self) -> mx.array:
        """Negative text embeddings used for CFG."""
        return self._negative

    @property
    def negative_pooled_prompt_embeds(self) -> mx.array:
        """Alias of negative_prompt_embeds (no pooled embeds)."""
        return self._negative

    def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
        """Select the attention mask matching the requested prompt polarity."""
        mask = self.prompt_mask if positive else self.negative_prompt_mask
        return {"encoder_hidden_states_mask": mask}

    @property
    def conditioning_latents(self) -> mx.array | None:
        """Plain text-to-image Qwen has no image conditioning latents."""
        return None
class QwenModelAdapter(BaseModelAdapter):
    """Adapter for Qwen-Image model.

    Key differences from Flux:
    - Single text encoder (vs dual T5+CLIP)
    - 60 joint-style blocks, no single blocks
    - 3D RoPE returning ((img_cos, img_sin), (txt_cos, txt_sin))
    - Norm-preserving CFG with negative prompts
    - Uses attention mask for variable-length text
    """

    def __init__(
        self,
        config: ImageModelConfig,
        model_id: str,
        local_path: Path,
        quantize: int | None = None,
    ):
        """Load the mflux QwenImage model and keep handles to its parts.

        Args:
            config: Engine-level image model configuration.
            model_id: Name resolved via mflux's ModelConfig registry.
            local_path: Directory holding the model weights.
            quantize: Optional quantization bit width (None = full precision).
        """
        self._config = config
        self._model = QwenImage(
            model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
            local_path=str(local_path),
            quantize=quantize,
        )
        self._transformer = self._model.transformer

    @property
    def config(self) -> ImageModelConfig:
        """Engine-level configuration this adapter was built with."""
        return self._config

    @property
    def model(self) -> QwenImage:
        """Underlying mflux QwenImage model."""
        return self._model

    @property
    def transformer(self) -> QwenTransformer:
        """Diffusion transformer backbone."""
        return self._transformer

    @property
    def hidden_dim(self) -> int:
        """Transformer inner (hidden) dimension."""
        return self._transformer.inner_dim

    def _get_latent_creator(self) -> type:
        # Latent creation is delegated to mflux's Qwen-specific creator.
        return QwenLatentCreator

    def encode_prompt(self, prompt: str) -> QwenPromptData:
        """Encode prompt into QwenPromptData.

        Qwen uses classifier-free guidance with explicit negative prompts.
        Returns a QwenPromptData container with all 4 tensors.
        """
        # TODO(ciaran): empty string as default negative prompt
        negative_prompt = ""
        prompt_embeds, prompt_mask, neg_embeds, neg_mask = (
            QwenPromptEncoder.encode_prompt(
                prompt=prompt,
                negative_prompt=negative_prompt,
                prompt_cache=self._model.prompt_cache,
                qwen_tokenizer=self._model.qwen_tokenizer,
                qwen_text_encoder=self._model.text_encoder,
            )
        )
        return QwenPromptData(
            prompt_embeds=prompt_embeds,
            prompt_mask=prompt_mask,
            negative_prompt_embeds=neg_embeds,
            negative_prompt_mask=neg_mask,
        )

    @property
    def needs_cfg(self) -> bool:
        # CFG only applies for guidance scales strictly above 1.0.
        gs = self._config.guidance_scale
        return gs is not None and gs > 1.0

    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        """Combine positive/negative noise via the model's guided-noise rule."""
        return self._model.compute_guided_noise(
            noise=noise_positive,
            noise_negative=noise_negative,
            guidance=guidance_scale,
        )

    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Compute image and text embeddings."""
        # Image embedding
        embedded_hidden = self._transformer.img_in(hidden_states)
        # Text embedding: first normalize, then project
        encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
        embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
        return embedded_hidden, embedded_encoder

    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: RuntimeConfig,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,
    ) -> mx.array:
        """Compute time/text embeddings.

        For Qwen, the time_text_embed only uses hidden_states for:
        - batch_size (shape[0])
        - dtype
        This allows us to pass any tensor (latents, prompt_embeds) as a fallback
        when embedded hidden_states are not yet available.
        """
        # Use hidden_states if provided, otherwise fall back to pooled_prompt_embeds
        # (which for Qwen is the same as prompt_embeds)
        ref_tensor = (
            hidden_states if hidden_states is not None else pooled_prompt_embeds
        )
        if ref_tensor is None:
            raise ValueError(
                "Either hidden_states or pooled_prompt_embeds is required "
                "for Qwen text embeddings"
            )
        timestep = QwenTransformer._compute_timestep(t, runtime_config)  # noqa: SLF001
        batch_size = ref_tensor.shape[0]
        timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
        return self._transformer.time_text_embed(timestep, ref_tensor)

    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: RuntimeConfig,
        **kwargs: Any,
    ) -> Any:
        """Compute 3D rotary embeddings for Qwen.

        Qwen uses video-aware 3D RoPE with separate embeddings for image and text.

        Returns:
            tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]:
                ((img_cos, img_sin), (txt_cos, txt_sin))
        """
        encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
        cond_image_grid = kwargs.get("cond_image_grid")
        if encoder_hidden_states_mask is None:
            raise ValueError(
                "encoder_hidden_states_mask is required for Qwen RoPE computation"
            )
        return QwenTransformer._compute_rotary_embeddings(  # noqa: SLF001
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            pos_embed=self._transformer.pos_embed,
            config=runtime_config,
            cond_image_grid=cond_image_grid,
        )

    def apply_joint_block(
        self,
        block: JointBlockInterface,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,  # tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]] for Qwen
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
        **kwargs: Any,
    ) -> tuple[mx.array, mx.array]:
        """Apply Qwen joint block.

        Caching mode runs the full block over [text + image]; patched mode
        processes one image patch against the KV values stored in `kv_cache`.
        Returns (encoder_hidden_states, hidden_states).
        """
        encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
        block_idx = kwargs.get("block_idx")
        if mode == BlockWrapperMode.CACHING:
            return self._apply_joint_block_caching(
                block=block,
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                block_idx=block_idx,
            )
        else:
            # mode == BlockWrapperMode.PATCHED
            assert patch_start is not None and patch_end is not None
            assert kv_cache is not None
            return self._apply_joint_block_patched(
                block=block,
                patch_hidden=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
                patch_start=patch_start,
                patch_end=patch_end,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                block_idx=block_idx,
            )

    def apply_single_block(
        self,
        block: SingleBlockInterface,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
    ) -> mx.array:
        """Qwen has no single blocks."""
        raise NotImplementedError("Qwen does not have single blocks")

    def final_projection(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> mx.array:
        """Apply final normalization and projection."""
        hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
        return self._transformer.proj_out(hidden_states)

    def get_joint_blocks(self) -> list[JointBlockInterface]:
        """Return the transformer blocks currently held.

        This is all 60 blocks for the full model; it may be a subset after
        slice_transformer_blocks() has been applied.
        """
        return cast(
            list[JointBlockInterface], list(self._transformer.transformer_blocks)
        )

    def get_single_blocks(self) -> list[SingleBlockInterface]:
        """Qwen has no single blocks."""
        return []

    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
        total_joint_blocks: int,
        total_single_blocks: int,
    ) -> None:
        """Keep only blocks [start_layer, end_layer) on this worker.

        total_joint_blocks/total_single_blocks are part of the shared adapter
        interface but unused here: every Qwen block is a joint block.
        """
        all_blocks = list(self._transformer.transformer_blocks)
        assigned_blocks = all_blocks[start_layer:end_layer]
        self._transformer.transformer_blocks = assigned_blocks

    def merge_streams(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
    ) -> mx.array:
        """Merge image and text streams.

        For Qwen, this is called before final projection.
        The streams remain separate through all blocks.
        """
        return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)

    def _apply_joint_block_caching(
        self,
        block: Any,  # QwenTransformerBlock
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
        kv_cache: ImagePatchKVCache | None,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
        block_idx: int | None = None,
    ) -> tuple[mx.array, mx.array]:
        """Apply joint block in caching mode (full attention over the sequence).

        Delegates to the QwenTransformerBlock's forward pass.

        NOTE(review): kv_cache is accepted but never written here — patched
        mode fills it itself via update_image_patch. Confirm the first
        patched step does not read uninitialized cache entries.
        """
        # Call the block directly - it handles all the modulation and attention internally
        return block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            text_embeddings=text_embeddings,
            image_rotary_emb=rotary_embeddings,
            block_idx=block_idx,
        )

    def _apply_joint_block_patched(
        self,
        block: Any,  # QwenTransformerBlock
        patch_hidden: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
        kv_cache: ImagePatchKVCache,
        text_seq_len: int,
        patch_start: int,
        patch_end: int,
        encoder_hidden_states_mask: mx.array | None = None,
        block_idx: int | None = None,
    ) -> tuple[mx.array, mx.array]:
        """Apply joint block on one image patch using the KV cache.

        Re-implements the block's forward pass so the image stream only
        processes [patch_start, patch_end) tokens while attending to the
        full cached image K/V plus the freshly-computed text K/V.
        Returns (encoder_hidden_states, patch_hidden).
        """
        batch_size = patch_hidden.shape[0]
        attn = block.attn
        num_heads = attn.num_heads
        head_dim = attn.head_dim
        # 1. Compute modulation parameters
        img_mod_params = block.img_mod_linear(block.img_mod_silu(text_embeddings))
        txt_mod_params = block.txt_mod_linear(block.txt_mod_silu(text_embeddings))
        img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
        txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
        # 2. Apply normalization and modulation
        img_normed = block.img_norm1(patch_hidden)
        img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)
        txt_normed = block.txt_norm1(encoder_hidden_states)
        txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)
        # 3. Compute Q, K, V for image patch
        img_query = attn.to_q(img_modulated)
        img_key = attn.to_k(img_modulated)
        img_value = attn.to_v(img_modulated)
        # 4. Compute Q, K, V for text
        txt_query = attn.add_q_proj(txt_modulated)
        txt_key = attn.add_k_proj(txt_modulated)
        txt_value = attn.add_v_proj(txt_modulated)
        # 5. Reshape to [B, S, H, D]
        patch_len = patch_hidden.shape[1]
        img_query = mx.reshape(img_query, (batch_size, patch_len, num_heads, head_dim))
        img_key = mx.reshape(img_key, (batch_size, patch_len, num_heads, head_dim))
        img_value = mx.reshape(img_value, (batch_size, patch_len, num_heads, head_dim))
        txt_query = mx.reshape(
            txt_query, (batch_size, text_seq_len, num_heads, head_dim)
        )
        txt_key = mx.reshape(txt_key, (batch_size, text_seq_len, num_heads, head_dim))
        txt_value = mx.reshape(
            txt_value, (batch_size, text_seq_len, num_heads, head_dim)
        )
        # 6. Apply RMSNorm to Q, K
        img_query = attn.norm_q(img_query)
        img_key = attn.norm_k(img_key)
        txt_query = attn.norm_added_q(txt_query)
        txt_key = attn.norm_added_k(txt_key)
        # 7. Extract RoPE for patch: slice image RoPE, keep full text RoPE
        (img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
        patch_img_cos = img_cos[patch_start:patch_end]
        patch_img_sin = img_sin[patch_start:patch_end]
        # 8. Apply RoPE to Q, K
        img_query = QwenAttention._apply_rope_qwen(
            img_query, patch_img_cos, patch_img_sin
        )
        img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
        txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
        txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
        # 9. Transpose to [B, H, S, D] for cache operations
        img_key_bhsd = mx.transpose(img_key, (0, 2, 1, 3))
        img_value_bhsd = mx.transpose(img_value, (0, 2, 1, 3))
        # 10. Update cache with this patch's IMAGE K/V
        kv_cache.update_image_patch(
            patch_start=patch_start,
            patch_end=patch_end,
            key=img_key_bhsd,
            value=img_value_bhsd,
        )
        # 11. Get full K, V from cache (text + full image)
        txt_key_bhsd = mx.transpose(txt_key, (0, 2, 1, 3))
        txt_value_bhsd = mx.transpose(txt_value, (0, 2, 1, 3))
        full_key, full_value = kv_cache.get_full_kv(
            text_key=txt_key_bhsd,
            text_value=txt_value_bhsd,
        )
        # 12. Build query: [text, patch]
        joint_query = mx.concatenate([txt_query, img_query], axis=1)
        # 13. Build attention mask for [text + patch] query attending to [text + full_image] KV
        mask = QwenAttention._convert_mask_for_qwen(
            mask=encoder_hidden_states_mask,
            joint_seq_len=full_key.shape[2],  # text + full_image
            txt_seq_len=text_seq_len,
        )
        # 14. Compute attention
        hidden_states = attn._compute_attention_qwen(
            query=joint_query,
            key=mx.transpose(full_key, (0, 2, 1, 3)),  # Back to [B, S, H, D]
            value=mx.transpose(full_value, (0, 2, 1, 3)),
            mask=mask,
            block_idx=block_idx,
        )
        # 15. Extract text and image attention outputs
        txt_attn_output = hidden_states[:, :text_seq_len, :]
        img_attn_output = hidden_states[:, text_seq_len:, :]
        # 16. Project outputs
        img_attn_output = attn.attn_to_out[0](img_attn_output)
        txt_attn_output = attn.to_add_out(txt_attn_output)
        # 17. Apply residual + gate for attention
        patch_hidden = patch_hidden + img_gate1 * img_attn_output
        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
        # 18. Apply feed-forward for image
        img_normed2 = block.img_norm2(patch_hidden)
        img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
            img_normed2, img_mod2
        )
        img_mlp_output = block.img_ff(img_modulated2)
        patch_hidden = patch_hidden + img_gate2 * img_mlp_output
        # 19. Apply feed-forward for text
        txt_normed2 = block.txt_norm2(encoder_hidden_states)
        txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
            txt_normed2, txt_mod2
        )
        txt_mlp_output = block.txt_ff(txt_modulated2)
        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
        return encoder_hidden_states, patch_hidden

View File

@@ -0,0 +1,49 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
# Qwen-Image has 60 joint-style blocks (no single blocks)
# Architecture: 24 heads * 128 dim = 3072 hidden dim
# VAE uses scale factor of 16 (vs Flux's 8)
QWEN_IMAGE_CONFIG = ImageModelConfig(
    model_family="qwen",
    model_variant="image",
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=60, has_separate_text_output=True
        ),
        # Qwen has no single blocks - all blocks process image and text separately
    ),
    patch_size=2,
    vae_scale_factor=16,
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,  # ~3 sync steps for medium (25 steps)
    uses_attention_mask=True,  # Qwen uses encoder_hidden_states_mask
    guidance_scale=None,  # Set to None or < 1.0 to disable CFG
)

# Qwen-Image-Edit uses the same architecture but different processing pipeline
# Uses vision-language encoding and conditioning latents
QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
    model_family="qwen-edit",
    model_variant="image-edit",
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=60, has_separate_text_output=True
        ),
    ),
    patch_size=2,
    vae_scale_factor=16,
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,  # ~3 sync steps for medium (25 steps)
    uses_attention_mask=True,
    guidance_scale=None,
)

View File

@@ -0,0 +1,671 @@
import math
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
from mflux.config.runtime_config import RuntimeConfig
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
QwenTransformerBlock,
)
from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit
from mflux.models.qwen.variants.edit.utils.qwen_edit_util import QwenEditUtil
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class QwenEditPromptData:
    """Bundle of tensors produced by Qwen-Edit prompt encoding.

    Carries the vision-language text embeddings plus the image-conditioning
    data (latents, spatial ids, grid) consumed during editing. Qwen has no
    pooled embeddings, so the pooled properties alias the regular ones.
    """

    def __init__(
        self,
        prompt_embeds: mx.array,
        prompt_mask: mx.array,
        negative_prompt_embeds: mx.array,
        negative_prompt_mask: mx.array,
        conditioning_latents: mx.array,
        qwen_image_ids: mx.array,
        cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]],
    ):
        self._positive_embeds = prompt_embeds
        self.prompt_mask = prompt_mask
        self._negative_embeds = negative_prompt_embeds
        self.negative_prompt_mask = negative_prompt_mask
        self._cond_latents = conditioning_latents
        self._image_ids = qwen_image_ids
        self._image_grid = cond_image_grid

    @property
    def prompt_embeds(self) -> mx.array:
        """Positive embeddings from the vision-language encoder."""
        return self._positive_embeds

    @property
    def pooled_prompt_embeds(self) -> mx.array:
        """Alias of prompt_embeds (protocol compliance; no pooled embeds)."""
        return self._positive_embeds

    @property
    def negative_prompt_embeds(self) -> mx.array:
        """Negative embeddings used for CFG."""
        return self._negative_embeds

    @property
    def negative_pooled_prompt_embeds(self) -> mx.array:
        """Alias of negative_prompt_embeds (no pooled embeds)."""
        return self._negative_embeds

    @property
    def conditioning_latents(self) -> mx.array:
        """Static image conditioning latents to concatenate with generated latents."""
        return self._cond_latents

    @property
    def qwen_image_ids(self) -> mx.array:
        """Spatial position IDs for conditioning images."""
        return self._image_ids

    @property
    def cond_image_grid(self) -> tuple[int, int, int] | list[tuple[int, int, int]]:
        """Conditioning image grid dimensions."""
        return self._image_grid

    def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
        """Build extra transformer kwargs for the chosen prompt polarity."""
        mask = self.prompt_mask if positive else self.negative_prompt_mask
        return {
            "encoder_hidden_states_mask": mask,
            "qwen_image_ids": self._image_ids,
            "cond_image_grid": self._image_grid,
        }

    @property
    def is_edit_mode(self) -> bool:
        """Edit mode: conditioning latents are present."""
        return True
class QwenEditModelAdapter(BaseModelAdapter):
"""Adapter for Qwen-Image-Edit model.
Key differences from standard QwenModelAdapter:
- Uses QwenImageEdit model with vision-language components
- Encodes prompts WITH input images via VL tokenizer/encoder
- Creates conditioning latents from input images
- Supports image editing with concatenated latents during diffusion
"""
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = QwenImageEdit(
quantize=quantize,
local_path=str(local_path),
)
self._transformer = self._model.transformer
# Store dimensions and image paths (set via set_image_dimensions)
self._vl_width: int | None = None
self._vl_height: int | None = None
self._vae_width: int | None = None
self._vae_height: int | None = None
self._image_paths: list[str] | None = None
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def model(self) -> QwenImageEdit:
return self._model
@property
def transformer(self) -> QwenTransformer:
return self._transformer
@property
def hidden_dim(self) -> int:
return self._transformer.inner_dim
def _get_latent_creator(self) -> type:
return QwenLatentCreator
def _compute_dimensions_from_image(
self, image_path: Path
) -> tuple[int, int, int, int, int, int]:
"""Compute VL and VAE dimensions from input image.
Returns:
(vl_width, vl_height, vae_width, vae_height, output_width, output_height)
"""
from mflux.utils.image_util import ImageUtil
pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
image_size = pil_image.size
# Vision-language dimensions (384x384 target area)
condition_image_size = 384 * 384
condition_ratio = image_size[0] / image_size[1]
vl_width = math.sqrt(condition_image_size * condition_ratio)
vl_height = vl_width / condition_ratio
vl_width = round(vl_width / 32) * 32
vl_height = round(vl_height / 32) * 32
# VAE dimensions (1024x1024 target area)
vae_image_size = 1024 * 1024
vae_ratio = image_size[0] / image_size[1]
vae_width = math.sqrt(vae_image_size * vae_ratio)
vae_height = vae_width / vae_ratio
vae_width = round(vae_width / 32) * 32
vae_height = round(vae_height / 32) * 32
# Output dimensions from input image aspect ratio
target_area = 1024 * 1024
ratio = image_size[0] / image_size[1]
output_width = math.sqrt(target_area * ratio)
output_height = output_width / ratio
output_width = round(output_width / 32) * 32
output_height = round(output_height / 32) * 32
# Ensure multiple of 16 for VAE
vae_scale_factor = 8
multiple_of = vae_scale_factor * 2
output_width = output_width // multiple_of * multiple_of
output_height = output_height // multiple_of * multiple_of
return (
int(vl_width),
int(vl_height),
int(vae_width),
int(vae_height),
int(output_width),
int(output_height),
)
def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
"""Create initial noise latents (pure noise for edit mode)."""
return QwenLatentCreator.create_noise(
seed=seed,
height=runtime_config.height,
width=runtime_config.width,
)
def encode_prompt(self, prompt: str) -> QwenEditPromptData:
"""Encode prompt with input images using vision-language encoder.
Uses stored image_paths from set_image_dimensions() for VL encoding.
Args:
prompt: Text prompt for editing
Returns:
QwenEditPromptData with VL embeddings and conditioning latents
"""
# Ensure image_paths and dimensions were set via set_image_dimensions()
if self._image_paths is None:
raise RuntimeError(
"set_image_dimensions() must be called before encode_prompt() "
"for QwenEditModelAdapter"
)
negative_prompt = ""
image_paths = self._image_paths
# Use stored dimensions (computed from input image)
vl_width = self._vl_width
vl_height = self._vl_height
vae_width = self._vae_width
vae_height = self._vae_height
# Encode prompts with images via vision-language components
tokenizer = self._model.qwen_vl_tokenizer
pos_input_ids, pos_attention_mask, pos_pixel_values, pos_image_grid_thw = (
tokenizer.tokenize_with_image(
prompt, image_paths, vl_width=vl_width, vl_height=vl_height
)
)
pos_hidden_states = self._model.qwen_vl_encoder(
input_ids=pos_input_ids,
attention_mask=pos_attention_mask,
pixel_values=pos_pixel_values,
image_grid_thw=pos_image_grid_thw,
)
mx.eval(pos_hidden_states[0])
mx.eval(pos_hidden_states[1])
# Encode negative prompt with images
neg_input_ids, neg_attention_mask, neg_pixel_values, neg_image_grid_thw = (
tokenizer.tokenize_with_image(
negative_prompt, image_paths, vl_width=vl_width, vl_height=vl_height
)
)
neg_hidden_states = self._model.qwen_vl_encoder(
input_ids=neg_input_ids,
attention_mask=neg_attention_mask,
pixel_values=neg_pixel_values,
image_grid_thw=neg_image_grid_thw,
)
mx.eval(neg_hidden_states[0])
mx.eval(neg_hidden_states[1])
# Create conditioning latents from input images
# Ensure dimensions are set (should have been set via set_image_dimensions)
assert vl_width is not None and vl_height is not None
assert vae_width is not None and vae_height is not None
(
conditioning_latents,
qwen_image_ids,
cond_h_patches,
cond_w_patches,
num_images,
) = QwenEditUtil.create_image_conditioning_latents(
vae=self._model.vae,
height=vae_height,
width=vae_width,
image_paths=image_paths,
vl_width=vl_width,
vl_height=vl_height,
)
# Build cond_image_grid
if num_images > 1:
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] = [
(1, cond_h_patches, cond_w_patches) for _ in range(num_images)
]
else:
cond_image_grid = (1, cond_h_patches, cond_w_patches)
return QwenEditPromptData(
prompt_embeds=pos_hidden_states[0].astype(mx.float16),
prompt_mask=pos_hidden_states[1].astype(mx.float16),
negative_prompt_embeds=neg_hidden_states[0].astype(mx.float16),
negative_prompt_mask=neg_hidden_states[1].astype(mx.float16),
conditioning_latents=conditioning_latents,
qwen_image_ids=qwen_image_ids,
cond_image_grid=cond_image_grid,
)
def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
"""Compute and store dimensions from input image.
Also stores image_paths for use in encode_prompt().
Returns:
(output_width, output_height) for runtime config
"""
vl_w, vl_h, vae_w, vae_h, out_w, out_h = self._compute_dimensions_from_image(
image_path
)
self._vl_width = vl_w
self._vl_height = vl_h
self._vae_width = vae_w
self._vae_height = vae_h
self._image_paths = [str(image_path)]
return out_w, out_h
@property
def needs_cfg(self) -> bool:
gs = self._config.guidance_scale
return gs is not None and gs > 1.0
def apply_guidance(
    self,
    noise_positive: mx.array,
    noise_negative: mx.array,
    guidance_scale: float,
) -> mx.array:
    """Combine positive/negative noise predictions via classifier-free guidance."""
    # Local import mirrors the original: avoids paying for it at module load.
    from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage

    guided = QwenImage.compute_guided_noise(
        noise=noise_positive,
        noise_negative=noise_negative,
        guidance=guidance_scale,
    )
    return guided
def compute_embeddings(
    self,
    hidden_states: mx.array,
    prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
    """Embed image latents (img_in) and normalized text embeddings (txt_in)."""
    transformer = self._transformer
    image_embedded = transformer.img_in(hidden_states)
    text_embedded = transformer.txt_in(transformer.txt_norm(prompt_embeds))
    return image_embedded, text_embedded
def compute_text_embeddings(
    self,
    t: int,
    runtime_config: RuntimeConfig,
    pooled_prompt_embeds: mx.array | None = None,
    hidden_states: mx.array | None = None,
) -> mx.array:
    """Compute the time/text conditioning embeddings for step t.

    Uses hidden_states as the reference tensor when available, otherwise
    pooled_prompt_embeds; at least one must be provided.
    """
    if hidden_states is not None:
        ref_tensor = hidden_states
    elif pooled_prompt_embeds is not None:
        ref_tensor = pooled_prompt_embeds
    else:
        raise ValueError(
            "Either hidden_states or pooled_prompt_embeds is required "
            "for Qwen text embeddings"
        )
    timestep = QwenTransformer._compute_timestep(t, runtime_config)  # noqa: SLF001
    # Broadcast the scalar timestep across the batch dimension.
    timestep = mx.broadcast_to(timestep, (ref_tensor.shape[0],)).astype(mx.float32)
    return self._transformer.time_text_embed(timestep, ref_tensor)
def compute_rotary_embeddings(
    self,
    prompt_embeds: mx.array,
    runtime_config: RuntimeConfig,
    **kwargs: Any,
) -> Any:
    """Compute 3D rotary embeddings for Qwen edit (text + image grid)."""
    mask = kwargs.get("encoder_hidden_states_mask")
    if mask is None:
        raise ValueError(
            "encoder_hidden_states_mask is required for Qwen RoPE computation"
        )
    return QwenTransformer._compute_rotary_embeddings(  # noqa: SLF001
        encoder_hidden_states_mask=mask,
        pos_embed=self._transformer.pos_embed,
        config=runtime_config,
        cond_image_grid=kwargs.get("cond_image_grid"),
    )
def apply_joint_block(
    self,
    block: JointBlockInterface,
    hidden_states: mx.array,
    encoder_hidden_states: mx.array,
    text_embeddings: mx.array,
    rotary_embeddings: Any,
    kv_cache: ImagePatchKVCache | None,
    mode: BlockWrapperMode,
    text_seq_len: int,
    patch_start: int | None = None,
    patch_end: int | None = None,
    **kwargs: Any,
) -> tuple[mx.array, mx.array]:
    """Dispatch a Qwen joint block to the caching or patched implementation."""
    mask = kwargs.get("encoder_hidden_states_mask")
    block_idx = kwargs.get("block_idx")
    if mode == BlockWrapperMode.CACHING:
        # Sync path: full-image attention, populates the KV cache.
        return self._apply_joint_block_caching(
            block=block,
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            text_embeddings=text_embeddings,
            rotary_embeddings=rotary_embeddings,
            kv_cache=kv_cache,
            text_seq_len=text_seq_len,
            encoder_hidden_states_mask=mask,
            block_idx=block_idx,
        )
    # Async path: per-patch attention against cached image K/V.
    assert patch_start is not None and patch_end is not None
    assert kv_cache is not None
    return self._apply_joint_block_patched(
        block=block,
        patch_hidden=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        text_embeddings=text_embeddings,
        rotary_embeddings=rotary_embeddings,
        kv_cache=kv_cache,
        text_seq_len=text_seq_len,
        patch_start=patch_start,
        patch_end=patch_end,
        encoder_hidden_states_mask=mask,
        block_idx=block_idx,
    )
def apply_single_block(
    self,
    block: SingleBlockInterface,
    hidden_states: mx.array,
    text_embeddings: mx.array,
    rotary_embeddings: mx.array,
    kv_cache: ImagePatchKVCache | None,
    mode: BlockWrapperMode,
    text_seq_len: int,
    patch_start: int | None = None,
    patch_end: int | None = None,
) -> mx.array:
    """Qwen's transformer is joint-blocks only; single blocks never execute."""
    raise NotImplementedError("Qwen does not have single blocks")
def final_projection(
    self,
    hidden_states: mx.array,
    text_embeddings: mx.array,
) -> mx.array:
    """Final adaptive norm followed by the output projection."""
    normed = self._transformer.norm_out(hidden_states, text_embeddings)
    projected = self._transformer.proj_out(normed)
    return projected
def get_joint_blocks(self) -> list[JointBlockInterface]:
    """Return every transformer block (Qwen treats all 60 as joint blocks)."""
    blocks = list(self._transformer.transformer_blocks)
    return cast(list[JointBlockInterface], blocks)
def get_single_blocks(self) -> list[SingleBlockInterface]:
    """Qwen's architecture defines no single-stream blocks."""
    return []
def slice_transformer_blocks(
    self,
    start_layer: int,
    end_layer: int,
    total_joint_blocks: int,
    total_single_blocks: int,
) -> None:
    """Keep only the transformer blocks assigned to this pipeline stage.

    The totals are accepted for interface parity but unused here: Qwen has a
    single flat block list, so the [start_layer:end_layer) slice is enough.
    """
    blocks = list(self._transformer.transformer_blocks)
    self._transformer.transformer_blocks = blocks[start_layer:end_layer]
def merge_streams(
    self,
    hidden_states: mx.array,
    encoder_hidden_states: mx.array,
) -> mx.array:
    """Concatenate text tokens then image tokens along the sequence axis."""
    merged = mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
    return merged
def _apply_joint_block_caching(
    self,
    block: Any,
    hidden_states: mx.array,
    encoder_hidden_states: mx.array,
    text_embeddings: mx.array,
    rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
    kv_cache: ImagePatchKVCache | None,
    text_seq_len: int,
    encoder_hidden_states_mask: mx.array | None = None,
    block_idx: int | None = None,
) -> tuple[mx.array, mx.array]:
    """Run the block over the full sequence (sync mode).

    Delegates directly to the underlying mflux block. kv_cache and
    text_seq_len are accepted for signature parity with the patched path
    but are not used here.
    """
    call_kwargs = dict(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        encoder_hidden_states_mask=encoder_hidden_states_mask,
        text_embeddings=text_embeddings,
        image_rotary_emb=rotary_embeddings,
        block_idx=block_idx,
    )
    return block(**call_kwargs)
def _apply_joint_block_patched(
    self,
    block: Any,
    patch_hidden: mx.array,
    encoder_hidden_states: mx.array,
    text_embeddings: mx.array,
    rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
    kv_cache: ImagePatchKVCache,
    text_seq_len: int,
    patch_start: int,
    patch_end: int,
    encoder_hidden_states_mask: mx.array | None = None,
    block_idx: int | None = None,
) -> tuple[mx.array, mx.array]:
    """Apply a joint block to one image patch using cached K/V (async mode).

    Re-implements the block's forward pass at patch granularity:
    computes Q/K/V for the image patch and the (always fresh) text stream,
    writes the patch's K/V into `kv_cache`, attends the [text + patch]
    query against the full cached image K/V plus fresh text K/V, then
    applies both gated feed-forward branches.

    Args:
        block: The mflux joint transformer block whose submodules are used.
        patch_hidden: Image hidden states for this patch only.
        encoder_hidden_states: Full text hidden states.
        text_embeddings: Conditioning embeddings driving the modulation.
        rotary_embeddings: ((img_cos, img_sin), (txt_cos, txt_sin)) RoPE tables.
        kv_cache: Patch-level image KV cache; updated in place for this patch.
        text_seq_len: Number of text tokens.
        patch_start: Start token index of the patch within the image sequence.
        patch_end: End token index of the patch within the image sequence.
        encoder_hidden_states_mask: Text attention mask (may be None).
        block_idx: Block index forwarded to the attention kernel.

    Returns:
        Tuple of (encoder_hidden_states, patch_hidden) after the block.
    """
    batch_size = patch_hidden.shape[0]
    attn = block.attn
    num_heads = attn.num_heads
    head_dim = attn.head_dim
    # Modulation parameters (AdaLN-style: SiLU then linear, split into two halves)
    img_mod_params = block.img_mod_linear(block.img_mod_silu(text_embeddings))
    txt_mod_params = block.txt_mod_linear(block.txt_mod_silu(text_embeddings))
    img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
    txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
    # Normalization and modulation
    img_normed = block.img_norm1(patch_hidden)
    img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)
    txt_normed = block.txt_norm1(encoder_hidden_states)
    txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)
    # Q, K, V for image patch
    img_query = attn.to_q(img_modulated)
    img_key = attn.to_k(img_modulated)
    img_value = attn.to_v(img_modulated)
    # Q, K, V for text
    txt_query = attn.add_q_proj(txt_modulated)
    txt_key = attn.add_k_proj(txt_modulated)
    txt_value = attn.add_v_proj(txt_modulated)
    # Reshape to [B, S, H, D]
    patch_len = patch_hidden.shape[1]
    img_query = mx.reshape(img_query, (batch_size, patch_len, num_heads, head_dim))
    img_key = mx.reshape(img_key, (batch_size, patch_len, num_heads, head_dim))
    img_value = mx.reshape(img_value, (batch_size, patch_len, num_heads, head_dim))
    txt_query = mx.reshape(
        txt_query, (batch_size, text_seq_len, num_heads, head_dim)
    )
    txt_key = mx.reshape(txt_key, (batch_size, text_seq_len, num_heads, head_dim))
    txt_value = mx.reshape(
        txt_value, (batch_size, text_seq_len, num_heads, head_dim)
    )
    # RMSNorm to Q, K
    img_query = attn.norm_q(img_query)
    img_key = attn.norm_k(img_key)
    txt_query = attn.norm_added_q(txt_query)
    txt_key = attn.norm_added_k(txt_key)
    # Extract RoPE for patch (image tables are indexed by absolute token position)
    (img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
    patch_img_cos = img_cos[patch_start:patch_end]
    patch_img_sin = img_sin[patch_start:patch_end]
    # Apply RoPE
    img_query = QwenAttention._apply_rope_qwen(
        img_query, patch_img_cos, patch_img_sin
    )
    img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
    txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
    txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
    # Transpose to [B, H, S, D] (cache layout)
    img_key_bhsd = mx.transpose(img_key, (0, 2, 1, 3))
    img_value_bhsd = mx.transpose(img_value, (0, 2, 1, 3))
    # Update cache: only this patch's slice is overwritten with fresh K/V
    kv_cache.update_image_patch(
        patch_start=patch_start,
        patch_end=patch_end,
        key=img_key_bhsd,
        value=img_value_bhsd,
    )
    # Get full K, V from cache (fresh text K/V + cached image K/V, stale elsewhere)
    txt_key_bhsd = mx.transpose(txt_key, (0, 2, 1, 3))
    txt_value_bhsd = mx.transpose(txt_value, (0, 2, 1, 3))
    full_key, full_value = kv_cache.get_full_kv(
        text_key=txt_key_bhsd,
        text_value=txt_value_bhsd,
    )
    # Build query: text tokens first, then the image patch tokens
    joint_query = mx.concatenate([txt_query, img_query], axis=1)
    # Build attention mask
    mask = QwenAttention._convert_mask_for_qwen(
        mask=encoder_hidden_states_mask,
        joint_seq_len=full_key.shape[2],
        txt_seq_len=text_seq_len,
    )
    # Compute attention (keys/values transposed back to [B, S, H, D])
    hidden_states = attn._compute_attention_qwen(
        query=joint_query,
        key=mx.transpose(full_key, (0, 2, 1, 3)),
        value=mx.transpose(full_value, (0, 2, 1, 3)),
        mask=mask,
        block_idx=block_idx,
    )
    # Extract outputs: text tokens lead the joint sequence
    txt_attn_output = hidden_states[:, :text_seq_len, :]
    img_attn_output = hidden_states[:, text_seq_len:, :]
    # Project
    img_attn_output = attn.attn_to_out[0](img_attn_output)
    txt_attn_output = attn.to_add_out(txt_attn_output)
    # Residual + gate
    patch_hidden = patch_hidden + img_gate1 * img_attn_output
    encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
    # Feed-forward for image
    img_normed2 = block.img_norm2(patch_hidden)
    img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
        img_normed2, img_mod2
    )
    img_mlp_output = block.img_ff(img_modulated2)
    patch_hidden = patch_hidden + img_gate2 * img_mlp_output
    # Feed-forward for text
    txt_normed2 = block.txt_norm2(encoder_hidden_states)
    txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
        txt_normed2, txt_mod2
    )
    txt_mlp_output = block.txt_ff(txt_modulated2)
    encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
    return encoder_hidden_states, patch_hidden

View File

@@ -0,0 +1,23 @@
"""Public surface of the image diffusion pipeline package."""
from exo.worker.engines.image.pipeline.adapter import (
    BlockWrapperMode,
    JointBlockInterface,
    ModelAdapter,
    SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
    JointBlockWrapper,
    SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
from exo.worker.engines.image.pipeline.runner import DiffusionRunner

# Names re-exported as this package's public API.
__all__ = [
    "BlockWrapperMode",
    "DiffusionRunner",
    "ImagePatchKVCache",
    "JointBlockInterface",
    "JointBlockWrapper",
    "ModelAdapter",
    "SingleBlockInterface",
    "SingleBlockWrapper",
]

View File

@@ -0,0 +1,402 @@
from enum import Enum
from pathlib import Path
from typing import Any, Protocol
import mlx.core as mx
from mflux.config.runtime_config import RuntimeConfig
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class AttentionInterface(Protocol):
    """Structural type for an attention module's projections and norms.

    NOTE(review): this protocol declares `head_dimension`, but the Qwen
    adapter reads `attn.head_dim` — confirm which attribute name the
    concrete mflux attention modules actually expose.
    """
    num_heads: int
    head_dimension: int
    to_q: Any
    to_k: Any
    to_v: Any
    norm_q: Any
    norm_k: Any
    to_out: list[Any]
class JointAttentionInterface(AttentionInterface, Protocol):
    """Attention that additionally projects the text ("added") stream."""
    add_q_proj: Any
    add_k_proj: Any
    add_v_proj: Any
    norm_added_q: Any
    norm_added_k: Any
    to_add_out: Any
class JointBlockInterface(Protocol):
    """Dual-stream transformer block processing image and text separately."""
    attn: JointAttentionInterface
    norm1: Any  # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
    norm1_context: (
        Any  # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
    )
    norm2: Any
    norm2_context: Any
    ff: Any
    ff_context: Any
class SingleBlockInterface(Protocol):
    """Single-stream transformer block operating on merged [text + image]."""
    attn: AttentionInterface
    norm: Any  # Callable module: (hidden_states, text_embeddings) -> tuple[2 arrays]

    def _apply_feed_forward_and_projection(
        self, norm_hidden_states: mx.array, attn_output: mx.array, gate: mx.array
    ) -> mx.array:
        """Apply feed forward network and projection."""
        ...
class BlockWrapperMode(Enum):
    """Execution mode for wrapped transformer blocks."""
    CACHING = "caching"  # Sync mode: compute full attention, populate cache
    PATCHED = "patched"  # Async mode: compute patch attention, use cached KV
class PromptData(Protocol):
    """Protocol for encoded prompt data.

    All adapters must return prompt data that conforms to this protocol.
    Model-specific prompt data classes can add additional attributes
    (e.g., attention masks for Qwen).
    """

    @property
    def prompt_embeds(self) -> mx.array:
        """Text embeddings from encoder."""
        ...

    @property
    def pooled_prompt_embeds(self) -> mx.array:
        """Pooled text embeddings (for Flux) or placeholder (for Qwen)."""
        ...

    @property
    def negative_prompt_embeds(self) -> mx.array | None:
        """Negative prompt embeddings for CFG (None if not using CFG)."""
        ...

    @property
    def negative_pooled_prompt_embeds(self) -> mx.array | None:
        """Negative pooled embeddings for CFG (None if not using CFG)."""
        ...

    def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
        """Return model-specific kwargs for the transformer forward pass.

        Args:
            positive: If True, return kwargs for the positive prompt pass.
                If False, return kwargs for the negative prompt pass.

        Returns:
            Dict of extra kwargs (e.g., {"encoder_hidden_states_mask": ...} for Qwen)
        """
        ...

    @property
    def conditioning_latents(self) -> mx.array | None:
        """Conditioning latents for edit mode.

        Returns:
            Conditioning latents array for image editing, None for standard generation.
        """
        ...
class ModelAdapter(Protocol):
    """Structural interface a diffusion model must implement to run in the
    shared DiffusionRunner (single-node or distributed PipeFusion mode).

    Implementations wrap an mflux model (Flux, Fibo, Qwen, ...) and expose
    uniform hooks for embedding, block application, guidance and decoding.
    """

    @property
    def config(self) -> ImageModelConfig:
        """Return the model configuration."""
        ...

    @property
    def model(self) -> Any:
        """Return the underlying mflux model instance (e.g., Flux1, Fibo, Qwen)."""
        ...

    @property
    def transformer(self) -> Any:
        """Return the transformer component of the model."""
        ...

    @property
    def hidden_dim(self) -> int:
        """Return the hidden dimension of the transformer."""
        ...

    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Compute x_embedder and context_embedder outputs.

        Args:
            hidden_states: Input latent states
            prompt_embeds: Text embeddings from encoder

        Returns:
            Tuple of (embedded_hidden_states, embedded_encoder_states)
        """
        ...

    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: RuntimeConfig,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,
    ) -> mx.array:
        """Compute time/text embeddings for conditioning.

        Args:
            t: Current timestep
            runtime_config: Runtime configuration
            pooled_prompt_embeds: Pooled text embeddings (used by Flux)
            hidden_states: Image hidden states

        Returns:
            Text embeddings tensor
        """
        ...

    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: RuntimeConfig,
        **kwargs: Any,
    ) -> Any:
        """Compute rotary position embeddings.

        Args:
            prompt_embeds: Text embeddings
            runtime_config: Runtime configuration
            **kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask for Qwen)

        Returns:
            Flux: mx.array
            Qwen: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]
        """
        ...

    def apply_joint_block(
        self,
        block: JointBlockInterface,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,  # Format varies: mx.array (Flux) or nested tuple (Qwen)
        kv_cache: ImagePatchKVCache | None,
        mode: "BlockWrapperMode",
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
        **kwargs: Any,
    ) -> tuple[mx.array, mx.array]:
        """Apply a joint transformer block.

        Args:
            block: The joint transformer block
            hidden_states: Image hidden states
            encoder_hidden_states: Text hidden states
            text_embeddings: Conditioning embeddings
            rotary_embeddings: Rotary position embeddings (format varies by model)
            kv_cache: KV cache (None if not using cache)
            mode: CACHING or PATCHED mode
            text_seq_len: Text sequence length
            patch_start: Start index for patched mode
            patch_end: End index for patched mode
            **kwargs: Additional model-specific arguments (e.g., encoder_hidden_states_mask,
                block_idx for Qwen)

        Returns:
            Tuple of (encoder_hidden_states, hidden_states)
        """
        ...

    def apply_single_block(
        self,
        block: SingleBlockInterface,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache | None,
        mode: "BlockWrapperMode",
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
    ) -> mx.array:
        """Apply a single transformer block.

        Args:
            block: The single transformer block
            hidden_states: Concatenated [text + image] hidden states
            text_embeddings: Conditioning embeddings
            rotary_embeddings: Rotary position embeddings
            kv_cache: KV cache (None if not using cache)
            mode: CACHING or PATCHED mode
            text_seq_len: Text sequence length
            patch_start: Start index for patched mode
            patch_end: End index for patched mode

        Returns:
            Output hidden states
        """
        ...

    def final_projection(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> mx.array:
        """Apply final norm and projection.

        Args:
            hidden_states: Hidden states (image only, text already removed)
            text_embeddings: Conditioning embeddings

        Returns:
            Projected output
        """
        ...

    def get_joint_blocks(self) -> list[JointBlockInterface]:
        """Get the list of joint transformer blocks from the model."""
        ...

    def get_single_blocks(self) -> list[SingleBlockInterface]:
        """Get the list of single transformer blocks from the model."""
        ...

    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
        total_joint_blocks: int,
        total_single_blocks: int,
    ) -> None:
        """Remove transformer blocks outside the assigned range.

        This should be called BEFORE mx.eval() to avoid loading unused weights
        in distributed mode.

        Args:
            start_layer: First layer index (inclusive) assigned to this node
            end_layer: Last layer index (exclusive) assigned to this node
            total_joint_blocks: Total number of joint blocks in the model
            total_single_blocks: Total number of single blocks in the model
        """
        ...

    def merge_streams(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
    ) -> mx.array:
        """Merge image and text streams for transition to single blocks.

        This is called at the transition point from joint blocks (which process
        image and text separately) to single blocks (which process them
        together). Override to customize the merge strategy.

        Args:
            hidden_states: Image hidden states
            encoder_hidden_states: Text hidden states

        Returns:
            Merged hidden states (default: concatenate [text, image])
        """
        ...

    def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
        """Create initial noise latents for generation.

        Args:
            seed: Random seed
            runtime_config: Runtime configuration with dimensions

        Returns:
            Initial latent tensor
        """
        ...

    def encode_prompt(self, prompt: str) -> PromptData:
        """Encode prompt into model-specific prompt data.

        Args:
            prompt: Text prompt

        Returns:
            PromptData containing embeddings (and model-specific extras)
        """
        ...

    @property
    def needs_cfg(self) -> bool:
        """Whether this model uses classifier-free guidance.

        Returns:
            True if model requires two forward passes with guidance (e.g., Qwen)
            False if model uses a single forward pass (e.g., Flux)
        """
        ...

    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        """Apply classifier-free guidance to combine positive/negative predictions.

        Only called when needs_cfg is True.

        Args:
            noise_positive: Noise prediction from positive prompt
            noise_negative: Noise prediction from negative prompt
            guidance_scale: Guidance strength

        Returns:
            Guided noise prediction
        """
        ...

    def decode_latents(
        self,
        latents: mx.array,
        runtime_config: RuntimeConfig,
        seed: int,
        prompt: str,
    ) -> Any:
        """Decode latents to final image.

        Args:
            latents: Final denoised latents
            runtime_config: Runtime configuration
            seed: Random seed (for metadata)
            prompt: Text prompt (for metadata)

        Returns:
            GeneratedImage result
        """
        ...

    def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
        """Compute and store dimensions from input image for edit mode.

        For edit adapters: computes dimensions from input image aspect ratio,
        stores image paths internally for encode_prompt(), returns (width, height).
        For standard adapters: returns None (use user-specified dimensions).

        Args:
            image_path: Path to the input image

        Returns:
            Tuple of (width, height) if dimensions were computed, None otherwise.
        """
        ...

View File

@@ -0,0 +1,146 @@
from typing import Any
import mlx.core as mx
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
ModelAdapter,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class JointBlockWrapper:
    """Mode-agnostic wrapper around a joint transformer block.

    The actual attention math is delegated to the model adapter, which
    knows how to run the block in CACHING (sync, full image) or PATCHED
    (async, per-patch with cached K/V) mode. Wrappers are constructed once
    at initialization; mode and KV cache are supplied on every call so the
    same wrapper serves both the sync and async pipelines.
    """

    def __init__(
        self,
        block: JointBlockInterface,
        adapter: ModelAdapter,
    ):
        """Remember the wrapped block and the adapter that drives it.

        Args:
            block: The joint transformer block to wrap
            adapter: Model adapter for model-specific operations
        """
        self.block = block
        self.adapter = adapter

    def __call__(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        text_seq_len: int,
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        patch_start: int | None = None,
        patch_end: int | None = None,
        **kwargs: Any,
    ) -> tuple[mx.array, mx.array]:
        """Run the wrapped joint block through the adapter.

        hidden_states is the full image in CACHING mode or a single patch
        in PATCHED mode (patch_start/patch_end required then). Extra
        model-specific kwargs (e.g. encoder_hidden_states_mask, block_idx)
        are passed through untouched.

        Returns:
            Tuple of (encoder_hidden_states, hidden_states)
        """
        delegate = self.adapter.apply_joint_block
        return delegate(
            block=self.block,
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            text_embeddings=text_embeddings,
            rotary_embeddings=rotary_embeddings,
            kv_cache=kv_cache,
            mode=mode,
            text_seq_len=text_seq_len,
            patch_start=patch_start,
            patch_end=patch_end,
            **kwargs,
        )
class SingleBlockWrapper:
    """Mode-agnostic wrapper around a single transformer block.

    Mirrors JointBlockWrapper for the single-stream stage: the adapter
    performs the model-specific attention in CACHING or PATCHED mode.
    Wrappers are created once and reused; mode and KV cache arrive per call.
    """

    def __init__(
        self,
        block: SingleBlockInterface,
        adapter: ModelAdapter,
    ):
        """Remember the wrapped block and the adapter that drives it.

        Args:
            block: The single transformer block to wrap
            adapter: Model adapter for model-specific operations
        """
        self.block = block
        self.adapter = adapter

    def __call__(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        text_seq_len: int,
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        patch_start: int | None = None,
        patch_end: int | None = None,
    ) -> mx.array:
        """Run the wrapped single block through the adapter.

        hidden_states is the merged [text + image] sequence — full in
        CACHING mode, one patch in PATCHED mode (patch_start/patch_end
        required then).

        Returns:
            Output hidden states
        """
        delegate = self.adapter.apply_single_block
        return delegate(
            block=self.block,
            hidden_states=hidden_states,
            text_embeddings=text_embeddings,
            rotary_embeddings=rotary_embeddings,
            kv_cache=kv_cache,
            mode=mode,
            text_seq_len=text_seq_len,
            patch_start=patch_start,
            patch_end=patch_end,
        )

View File

@@ -0,0 +1,72 @@
import mlx.core as mx
class ImagePatchKVCache:
    """KV cache holding only IMAGE keys/values, updated patch-by-patch.

    Text K/V is recomputed fresh on every call (it is identical for all
    patches), so only the image portion needs stale/fresh management here.
    """

    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        image_seq_len: int,
        head_dim: int,
        dtype: mx.Dtype = mx.float32,
    ):
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.image_seq_len = image_seq_len
        self.head_dim = head_dim
        self._dtype = dtype
        self.key_cache = self._fresh_buffer()
        self.value_cache = self._fresh_buffer()

    def _fresh_buffer(self) -> mx.array:
        """Allocate a zeroed [batch, heads, image_seq_len, head_dim] buffer."""
        shape = (self.batch_size, self.num_heads, self.image_seq_len, self.head_dim)
        return mx.zeros(shape, dtype=self._dtype)

    def update_image_patch(
        self, patch_start: int, patch_end: int, key: mx.array, value: mx.array
    ) -> None:
        """Overwrite one image-patch slice of the cache with fresh K/V.

        Args:
            patch_start: Start token index within image portion (0-indexed)
            patch_end: End token index within image portion
            key: Fresh key tensor [batch, heads, patch_seq_len, head_dim]
            value: Fresh value tensor [batch, heads, patch_seq_len, head_dim]
        """
        self.key_cache[:, :, patch_start:patch_end, :] = key
        self.value_cache[:, :, patch_start:patch_end, :] = value

    def get_full_kv(
        self, text_key: mx.array, text_value: mx.array
    ) -> tuple[mx.array, mx.array]:
        """Concatenate fresh text K/V with the cached image K/V.

        Args:
            text_key: Fresh text key tensor [batch, heads, text_seq_len, head_dim]
            text_value: Fresh text value tensor [batch, heads, text_seq_len, head_dim]

        Returns:
            (full_key, full_value), each [batch, heads, text+image, head_dim].
        """
        joined_key = mx.concatenate([text_key, self.key_cache], axis=2)
        joined_value = mx.concatenate([text_value, self.value_cache], axis=2)
        return joined_key, joined_value

    def reset(self) -> None:
        """Zero out both cached buffers."""
        self.key_cache = self._fresh_buffer()
        self.value_cache = self._fresh_buffer()

View File

@@ -0,0 +1,975 @@
from math import ceil
from typing import Any, Optional
import mlx.core as mx
from mflux.callbacks.callbacks import Callbacks
from mflux.config.config import Config
from mflux.config.runtime_config import RuntimeConfig
from mflux.utils.exceptions import StopImageGenerationException
from tqdm import tqdm
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
ModelAdapter,
PromptData,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
def calculate_patch_heights(latent_height: int, num_patches: int):
    """Split latent_height into near-equal horizontal bands.

    Returns (patch_heights, actual_num_patches); the last band absorbs the
    remainder and may be shorter than the others.
    """
    band = ceil(latent_height / num_patches)
    band_count = ceil(latent_height / band)
    heights = [band for _ in range(band_count - 1)]
    heights.append(latent_height - band * (band_count - 1))
    return heights, band_count
def calculate_token_indices(patch_heights: list[int], latent_width: int):
    """Map each horizontal band to its (start_token, end_token) range.

    Every latent row contributes latent_width tokens, so band boundaries
    are cumulative row counts scaled by the row width.
    """
    ranges = []
    row_offset = 0
    for band_height in patch_heights:
        ranges.append(
            (latent_width * row_offset, latent_width * (row_offset + band_height))
        )
        row_offset += band_height
    return ranges
class DiffusionRunner:
"""Orchestrates the diffusion loop for image generation.
This class owns the entire diffusion process, handling both single-node
and distributed (PipeFusion) modes.
In distributed mode, it implements PipeFusion with:
- Sync pipeline for initial timesteps (full image, all devices in lockstep)
- Async pipeline for later timesteps (patches processed independently)
"""
def __init__(
    self,
    config: ImageModelConfig,
    adapter: ModelAdapter,
    group: Optional[mx.distributed.Group],
    shard_metadata: PipelineShardMetadata,
    num_sync_steps: int = 1,
    num_patches: Optional[int] = None,
):
    """Initialize the diffusion runner.

    Args:
        config: Model configuration (architecture, block counts, etc.)
        adapter: Model adapter for model-specific operations
        group: MLX distributed group (None for single-node mode)
        shard_metadata: Pipeline shard metadata with layer assignments
        num_sync_steps: Number of synchronous timesteps before async mode
        num_patches: Number of patches for async mode (defaults to world_size)
    """
    self.config = config
    self.adapter = adapter
    self.group = group
    if group is not None:
        # Distributed: rank, ring neighbors and layer range come from metadata.
        self.rank = shard_metadata.device_rank
        self.world_size = shard_metadata.world_size
        self.next_rank = (self.rank + 1) % self.world_size
        self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
        self.start_layer = shard_metadata.start_layer
        self.end_layer = shard_metadata.end_layer
    else:
        # Single node: owns every layer and is its own neighbor.
        self.rank = 0
        self.world_size = 1
        self.next_rank = 0
        self.prev_rank = 0
        self.start_layer = 0
        self.end_layer = config.total_blocks
    self.num_sync_steps = num_sync_steps
    self.num_patches = num_patches if num_patches else max(1, self.world_size)
    # KV caches are created lazily on the first async timestep and reused.
    self.joint_kv_caches: list[ImagePatchKVCache] | None = None
    self.single_kv_caches: list[ImagePatchKVCache] | None = None
    # Block counts come from the config, keeping the runner model-agnostic.
    self.total_joint = config.joint_block_count
    self.total_single = config.single_block_count
    self.total_layers = config.total_blocks
    self._compute_assigned_blocks()
def _compute_assigned_blocks(self) -> None:
    """Determine which joint/single blocks this stage owns and wrap them."""
    start, end = self.start_layer, self.end_layer
    if end <= self.total_joint:
        # Entirely inside the joint-block region.
        joint_span = (start, end)
        single_span = (0, 0)
    elif start >= self.total_joint:
        # Entirely inside the single-block region.
        joint_span = (0, 0)
        single_span = (start - self.total_joint, end - self.total_joint)
    else:
        # Straddles the joint -> single transition.
        joint_span = (start, self.total_joint)
        single_span = (0, end - self.total_joint)
    self.joint_start, self.joint_end = joint_span
    self.single_start, self.single_end = single_span
    self.has_joint_blocks = self.joint_end > self.joint_start
    self.has_single_blocks = self.single_end > self.single_start
    self.owns_concat_stage = self.has_joint_blocks and (
        self.has_single_blocks or self.end_layer == self.total_joint
    )
    # Wrap blocks once at initialization; wrappers are reused every step.
    self.joint_block_wrappers = [
        JointBlockWrapper(block=b, adapter=self.adapter)
        for b in self.adapter.get_joint_blocks()
    ]
    self.single_block_wrappers = [
        SingleBlockWrapper(block=b, adapter=self.adapter)
        for b in self.adapter.get_single_blocks()
    ]
@property
def is_first_stage(self) -> bool:
    """True when this rank is pipeline stage 0."""
    return self.rank == 0
@property
def is_last_stage(self) -> bool:
    """True when this rank is the final pipeline stage."""
    return self.rank == self.world_size - 1
@property
def is_distributed(self) -> bool:
    """True when running with an MLX distributed group."""
    return self.group is not None
def _calculate_capture_steps(
self,
partial_images: int,
init_time_step: int,
num_inference_steps: int,
) -> set[int]:
"""Calculate which timesteps should produce partial images.
Evenly spaces `partial_images` captures across the diffusion loop.
Does NOT include the final timestep (that's the complete image).
Args:
partial_images: Number of partial images to capture
init_time_step: Starting timestep (for img2img this may not be 0)
num_inference_steps: Total inference steps
Returns:
Set of timestep indices to capture
"""
if partial_images <= 0:
return set()
total_steps = num_inference_steps - init_time_step
if total_steps <= 1:
return set()
if partial_images >= total_steps - 1:
# Capture every step except final
return set(range(init_time_step, num_inference_steps - 1))
# Evenly space partial captures
step_interval = total_steps / (partial_images + 1)
capture_steps: set[int] = set()
for i in range(1, partial_images + 1):
step_idx = int(init_time_step + i * step_interval)
# Ensure we don't capture the final step
if step_idx < num_inference_steps - 1:
capture_steps.add(step_idx)
return capture_steps
def generate_image(
    self,
    settings: Config,
    prompt: str,
    seed: int,
    partial_images: int = 0,
):
    """Primary entry point for image generation.

    Flow: build the runtime config, create seed latents, encode the
    prompt, run the diffusion loop, then decode latents to images.

    When `partial_images` > 0, intermediate images are yielded as
    (GeneratedImage, partial_index, total_partials) tuples before the
    final GeneratedImage. Decoding and yielding only happen on the last
    pipeline stage.

    Args:
        settings: Generation config (steps, height, width)
        prompt: Text prompt
        seed: Random seed
        partial_images: Number of intermediate images to yield (0 for none)

    Yields:
        Partial images as (GeneratedImage, partial_index, total_partials)
        tuples, then the final GeneratedImage.
    """
    runtime_config = RuntimeConfig(settings, self.adapter.model.model_config)
    latents = self.adapter.create_latents(seed, runtime_config)
    prompt_data = self.adapter.encode_prompt(prompt)
    # Decide which timesteps should emit a preview image.
    capture_steps = self._calculate_capture_steps(
        partial_images=partial_images,
        init_time_step=runtime_config.init_time_step,
        num_inference_steps=runtime_config.num_inference_steps,
    )
    diffusion_gen = self._run_diffusion_loop(
        latents=latents,
        prompt_data=prompt_data,
        runtime_config=runtime_config,
        seed=seed,
        prompt=prompt,
        capture_steps=capture_steps,
    )
    # Drain the diffusion generator: it yields partial latents at
    # capture steps (last stage only) and hands back the final latents
    # through StopIteration.value.
    total_partials = len(capture_steps)
    emitted = 0
    while True:
        try:
            partial_latents, _step = next(diffusion_gen)
        except StopIteration as stop:
            latents = stop.value
            break
        if self.is_last_stage:
            partial_image = self.adapter.decode_latents(
                partial_latents, runtime_config, seed, prompt
            )
            yield (partial_image, emitted, total_partials)
            emitted += 1
    # Final, fully denoised image (only on last stage).
    if self.is_last_stage:
        yield self.adapter.decode_latents(latents, runtime_config, seed, prompt)
def _run_diffusion_loop(
    self,
    latents: mx.array,
    prompt_data: PromptData,
    runtime_config: RuntimeConfig,
    seed: int,
    prompt: str,
    capture_steps: set[int] | None = None,
):
    """Execute the diffusion loop, optionally yielding at capture steps.

    When capture_steps is provided and non-empty, this becomes a generator
    that yields (latents, step_index) tuples at the specified timesteps.
    Only the last stage yields (others have incomplete latents).

    Args:
        latents: Initial noise latents
        prompt_data: Encoded prompt data
        runtime_config: RuntimeConfig with scheduler, steps, dimensions
        seed: Random seed (for callbacks)
        prompt: Text prompt (for callbacks)
        capture_steps: Set of timestep indices to capture (None = no captures)

    Yields:
        (latents, step_index) tuples at capture steps (last stage only)

    Returns:
        Final denoised latents ready for VAE decoding
    """
    if capture_steps is None:
        capture_steps = set()
    # tqdm progress bar; also handed to callbacks so they can update it.
    time_steps = tqdm(range(runtime_config.num_inference_steps))
    # Call subscribers for beginning of loop
    Callbacks.before_loop(
        seed=seed,
        prompt=prompt,
        latents=latents,
        config=runtime_config,
    )
    for t in time_steps:
        try:
            latents = self._diffusion_step(
                t=t,
                config=runtime_config,
                latents=latents,
                prompt_data=prompt_data,
            )
            # Call subscribers in-loop
            Callbacks.in_loop(
                t=t,
                seed=seed,
                prompt=prompt,
                latents=latents,
                config=runtime_config,
                time_steps=time_steps,
            )
            # Force evaluation of this step's graph before yielding /
            # moving on (MLX is lazy).
            mx.eval(latents)
            # Yield partial latents at capture steps (only on last stage)
            if t in capture_steps and self.is_last_stage:
                yield (latents, t)
        except KeyboardInterrupt:  # noqa: PERF203
            # Notify subscribers, then convert Ctrl-C into the domain
            # exception the callers expect.
            Callbacks.interruption(
                t=t,
                seed=seed,
                prompt=prompt,
                latents=latents,
                config=runtime_config,
                time_steps=time_steps,
            )
            raise StopImageGenerationException(
                f"Stopping image generation at step {t + 1}/{len(time_steps)}"
            ) from None
    # Call subscribers after loop
    Callbacks.after_loop(
        seed=seed,
        prompt=prompt,
        latents=latents,
        config=runtime_config,
    )
    return latents
def _forward_pass(
    self,
    latents: mx.array,
    prompt_embeds: mx.array,
    pooled_prompt_embeds: mx.array,
    kwargs: dict[str, Any],
) -> mx.array:
    """Run a single forward pass through the transformer.

    Internal method called by adapters via compute_step_noise; returns
    the raw noise prediction without applying the scheduler step.

    For edit mode, conditioning latents are appended to the generated
    latents before the transformer and stripped again afterwards.

    Args:
        latents: Input latents (already scaled by caller)
        prompt_embeds: Text embeddings
        pooled_prompt_embeds: Pooled text embeddings (Flux) or placeholder (Qwen)
        kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask, t)

    Returns:
        Noise prediction tensor

    Raises:
        ValueError: If "config" is missing from kwargs.
    """
    step = kwargs.get("t", 0)
    runtime_config = kwargs.get("config")
    if runtime_config is None:
        raise ValueError("config must be provided in kwargs")
    model_input = runtime_config.scheduler.scale_model_input(latents, step)
    # Edit mode: append the conditioning tokens after the generated ones.
    cond_latents = kwargs.get("conditioning_latents")
    generated_tokens = model_input.shape[1]
    if cond_latents is not None:
        model_input = mx.concatenate([model_input, cond_latents], axis=1)
    hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
        model_input, prompt_embeds
    )
    text_embeddings = self.adapter.compute_text_embeddings(
        step, runtime_config, pooled_prompt_embeds, hidden_states=hidden_states
    )
    rotary_embeddings = self.adapter.compute_rotary_embeddings(
        prompt_embeds, runtime_config, **kwargs
    )
    text_seq_len = prompt_embeds.shape[1]
    # All joint (dual-stream) blocks, in order.
    for idx, joint_wrapper in enumerate(self.joint_block_wrappers):
        encoder_hidden_states, hidden_states = joint_wrapper(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            text_embeddings=text_embeddings,
            rotary_embeddings=rotary_embeddings,
            text_seq_len=text_seq_len,
            kv_cache=None,
            mode=BlockWrapperMode.CACHING,
            block_idx=idx,
            **kwargs,
        )
    # Collapse text and image streams into one before the single blocks.
    if self.joint_block_wrappers:
        hidden_states = self.adapter.merge_streams(
            hidden_states, encoder_hidden_states
        )
    # All single (merged-stream) blocks, in order.
    for single_wrapper in self.single_block_wrappers:
        hidden_states = single_wrapper(
            hidden_states=hidden_states,
            text_embeddings=text_embeddings,
            rotary_embeddings=rotary_embeddings,
            text_seq_len=text_seq_len,
            kv_cache=None,
            mode=BlockWrapperMode.CACHING,
        )
    # Keep only the image tokens (drop the text prefix).
    hidden_states = hidden_states[:, text_seq_len:, ...]
    # Edit mode: keep only the generated portion (drop conditioning tokens).
    if cond_latents is not None:
        hidden_states = hidden_states[:, :generated_tokens, ...]
    return self.adapter.final_projection(hidden_states, text_embeddings)
def _diffusion_step(
    self,
    t: int,
    config: RuntimeConfig,
    latents: mx.array,
    prompt_data: PromptData,
) -> mx.array:
    """Execute a single diffusion step.

    Dispatches to the single-node path when no distributed group is
    configured, to the sync pipeline for the first `num_sync_steps`
    steps, and to the async pipeline for the remainder.
    """
    if self.group is None:
        return self._single_node_step(t, config, latents, prompt_data)
    if t < config.init_time_step + self.num_sync_steps:
        return self._sync_pipeline(t, config, latents, prompt_data)
    return self._async_pipeline_step(t, config, latents, prompt_data)
def _single_node_step(
    self,
    t: int,
    config: RuntimeConfig,
    latents: mx.array,
    prompt_data: PromptData,
) -> mx.array:
    """Execute a single diffusion step on a single node (no distribution)."""
    shared_kwargs: dict[str, Any] = {"t": t, "config": config}
    # Edit mode: forward the conditioning latents to the transformer.
    if prompt_data.conditioning_latents is not None:
        shared_kwargs["conditioning_latents"] = prompt_data.conditioning_latents
    if not self.adapter.needs_cfg:
        # Single forward pass for non-CFG models (e.g., Flux).
        noise = self._forward_pass(
            latents,
            prompt_data.prompt_embeds,
            prompt_data.pooled_prompt_embeds,
            {**shared_kwargs, **prompt_data.get_extra_forward_kwargs()},
        )
    else:
        # CFG models (e.g., Qwen): positive and negative passes, then
        # combine via classifier-free guidance.
        noise_pos = self._forward_pass(
            latents,
            prompt_data.prompt_embeds,
            prompt_data.pooled_prompt_embeds,
            {
                **shared_kwargs,
                **prompt_data.get_extra_forward_kwargs(positive=True),
            },
        )
        noise_neg = self._forward_pass(
            latents,
            prompt_data.negative_prompt_embeds,
            prompt_data.negative_pooled_prompt_embeds,
            {
                **shared_kwargs,
                **prompt_data.get_extra_forward_kwargs(positive=False),
            },
        )
        assert self.config.guidance_scale is not None
        noise = self.adapter.apply_guidance(
            noise_pos, noise_neg, guidance_scale=self.config.guidance_scale
        )
    return config.scheduler.step(model_output=noise, timestep=t, sample=latents)
def _initialize_kv_caches(
    self,
    batch_size: int,
    num_img_tokens: int,
    dtype: mx.Dtype,
) -> None:
    """Allocate fresh KV caches for both sync and async pipelines.

    Note: caches hold IMAGE K/V only. Text K/V is identical for every
    patch, so it is recomputed each call rather than cached.
    """

    def build(count: int) -> list[ImagePatchKVCache]:
        # One cache per wrapped transformer block.
        return [
            ImagePatchKVCache(
                batch_size=batch_size,
                num_heads=self.config.num_heads,
                image_seq_len=num_img_tokens,
                head_dim=self.config.head_dim,
                dtype=dtype,
            )
            for _ in range(count)
        ]

    self.joint_kv_caches = build(len(self.joint_block_wrappers))
    self.single_kv_caches = build(len(self.single_block_wrappers))
def _create_patches(
    self,
    latents: mx.array,
    config: RuntimeConfig,
) -> tuple[list[mx.array], list[tuple[int, int]]]:
    """Split latents into row patches for the async pipeline.

    Returns the patch tensors along with the (start, end) token range
    of each patch in the full sequence.
    """
    # 16 matches the downscale factor in FluxLatentCreator.create_noise.
    grid_height = config.height // 16
    grid_width = config.width // 16
    patch_heights, _ = calculate_patch_heights(grid_height, self.num_patches)
    token_indices = calculate_token_indices(patch_heights, grid_width)
    # Slice the token dimension according to each patch's range.
    patches = [latents[:, lo:hi, :] for lo, hi in token_indices]
    return patches, token_indices
def _sync_pipeline(
    self,
    t: int,
    config: RuntimeConfig,
    hidden_states: mx.array,
    prompt_data: PromptData,
    kontext_image_ids: mx.array | None = None,
) -> mx.array:
    """Run one diffusion step with synchronous pipeline parallelism.

    Each stage runs its assigned joint/single blocks over the FULL token
    sequence, relaying activations stage-to-stage via send/recv. KV
    caches are (re)initialized at the first step and populated along the
    way so the async (patched) pipeline can warm-start from them.

    Args:
        t: Current timestep index.
        config: Runtime config (scheduler, steps, dimensions).
        hidden_states: Latents entering this step.
        prompt_data: Encoded prompt embeddings plus extra forward kwargs.
        kontext_image_ids: Optional ids for rotary embeddings
            (presumably kontext/edit models — confirm with adapter).

    Returns:
        Stepped latents on the first and last stages (the last stage
        computes them and sends them to rank 0); middle stages return
        the unchanged input purely for shape correctness.
    """
    prev_latents = hidden_states
    # Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
    prompt_embeds = prompt_data.prompt_embeds
    pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
    extra_kwargs = prompt_data.get_extra_forward_kwargs()
    hidden_states = config.scheduler.scale_model_input(hidden_states, t)
    # For edit mode: handle conditioning latents
    # All stages need to know the total token count for correct recv templates
    conditioning_latents = prompt_data.conditioning_latents
    original_latent_tokens = hidden_states.shape[1]
    if conditioning_latents is not None:
        num_img_tokens = original_latent_tokens + conditioning_latents.shape[1]
    else:
        num_img_tokens = original_latent_tokens
    # First stage: concatenate conditioning latents before embedding
    if self.is_first_stage and conditioning_latents is not None:
        hidden_states = mx.concatenate(
            [hidden_states, conditioning_latents], axis=1
        )
    # === PHASE 1: Embeddings ===
    if self.is_first_stage:
        hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
            hidden_states, prompt_embeds
        )
    # All stages need these for their blocks
    text_embeddings = self.adapter.compute_text_embeddings(
        t, config, pooled_prompt_embeds
    )
    image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
        prompt_embeds,
        config,
        kontext_image_ids=kontext_image_ids,
        **extra_kwargs,
    )
    # === Initialize KV caches to populate during sync for async warmstart ===
    batch_size = prev_latents.shape[0]
    text_seq_len = prompt_embeds.shape[1]
    hidden_dim = self.adapter.hidden_dim
    if t == config.init_time_step:
        self._initialize_kv_caches(
            batch_size=batch_size,
            num_img_tokens=num_img_tokens,
            dtype=prev_latents.dtype,
        )
    # === PHASE 2: Joint Blocks with Communication and Caching ===
    if self.has_joint_blocks:
        # Receive from previous stage (if not first stage)
        if not self.is_first_stage:
            recv_template = mx.zeros(
                (batch_size, num_img_tokens, hidden_dim), dtype=prev_latents.dtype
            )
            hidden_states = mx.distributed.recv_like(
                recv_template, self.prev_rank, group=self.group
            )
            enc_template = mx.zeros(
                (batch_size, text_seq_len, hidden_dim), dtype=prev_latents.dtype
            )
            encoder_hidden_states = mx.distributed.recv_like(
                enc_template, self.prev_rank, group=self.group
            )
            mx.eval(hidden_states, encoder_hidden_states)
        # Run assigned joint blocks with caching mode
        for block_idx, wrapper in enumerate(self.joint_block_wrappers):
            encoder_hidden_states, hidden_states = wrapper(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=image_rotary_embeddings,
                text_seq_len=text_seq_len,
                kv_cache=self.joint_kv_caches[block_idx],
                mode=BlockWrapperMode.CACHING,
                **extra_kwargs,
            )
    # === PHASE 3: Joint→Single Transition ===
    if self.owns_concat_stage:
        # Merge encoder and hidden states using adapter hook
        concatenated = self.adapter.merge_streams(
            hidden_states, encoder_hidden_states
        )
        if self.has_single_blocks or self.is_last_stage:
            # Keep locally: either for single blocks or final projection
            hidden_states = concatenated
        else:
            # Send concatenated state to next stage (which has single blocks)
            mx.eval(
                mx.distributed.send(concatenated, self.next_rank, group=self.group)
            )
    elif self.has_joint_blocks and not self.is_last_stage:
        # Send joint block outputs to next stage (which has more joint blocks)
        mx.eval(
            mx.distributed.send(hidden_states, self.next_rank, group=self.group),
            mx.distributed.send(
                encoder_hidden_states, self.next_rank, group=self.group
            ),
        )
    # === PHASE 4: Single Blocks with Communication and Caching ===
    if self.has_single_blocks:
        # Receive from previous stage if we didn't do concatenation
        if not self.owns_concat_stage and not self.is_first_stage:
            # Merged stream: text prefix plus all image tokens.
            recv_template = mx.zeros(
                (batch_size, text_seq_len + num_img_tokens, hidden_dim),
                dtype=prev_latents.dtype,
            )
            hidden_states = mx.distributed.recv_like(
                recv_template, self.prev_rank, group=self.group
            )
            mx.eval(hidden_states)
        # Run assigned single blocks with caching mode
        for block_idx, wrapper in enumerate(self.single_block_wrappers):
            hidden_states = wrapper(
                hidden_states=hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=image_rotary_embeddings,
                text_seq_len=text_seq_len,
                kv_cache=self.single_kv_caches[block_idx],
                mode=BlockWrapperMode.CACHING,
            )
        # Send to next stage if not last
        if not self.is_last_stage:
            mx.eval(
                mx.distributed.send(hidden_states, self.next_rank, group=self.group)
            )
    # === PHASE 5: Last Stage - Final Projection + Scheduler ===
    # Extract image portion (remove text embeddings prefix)
    hidden_states = hidden_states[:, text_seq_len:, ...]
    # For edit mode: extract only the generated portion (exclude conditioning latents)
    if conditioning_latents is not None:
        hidden_states = hidden_states[:, :original_latent_tokens, ...]
    if self.is_last_stage:
        hidden_states = self.adapter.final_projection(
            hidden_states, text_embeddings
        )
        hidden_states = config.scheduler.step(
            model_output=hidden_states,
            timestep=t,
            sample=prev_latents,
        )
        # Ship the stepped latents back to rank 0 so every stage starts
        # the next step from the same latents.
        if not self.is_first_stage:
            mx.eval(mx.distributed.send(hidden_states, 0, group=self.group))
    elif self.is_first_stage:
        hidden_states = mx.distributed.recv_like(
            prev_latents, src=self.world_size - 1, group=self.group
        )
        mx.eval(hidden_states)
    else:
        # For shape correctness
        hidden_states = prev_latents
    return hidden_states
def _async_pipeline_step(
    self,
    t: int,
    config: RuntimeConfig,
    latents: mx.array,
    prompt_data: PromptData,
    kontext_image_ids: mx.array | None = None,
) -> mx.array:
    """Run one async-pipeline step: split into patches, process, reassemble."""
    patches, token_indices = self._create_patches(latents, config)
    processed = self._async_pipeline(
        t,
        config,
        patches,
        token_indices,
        prompt_data,
        kontext_image_ids,
    )
    # Stitch the per-patch latents back into one sequence.
    return mx.concatenate(processed, axis=1)
def _async_pipeline(
    self,
    t: int,
    config: RuntimeConfig,
    patch_latents: list[mx.array],
    token_indices: list[tuple[int, int]],
    prompt_data: PromptData,
    kontext_image_ids: mx.array | None = None,
) -> list[mx.array]:
    """Execute async pipeline for all patches.

    Processes each patch through this stage's blocks in PATCHED mode,
    reading/writing the image-token KV caches populated during the sync
    phase, and relaying per-patch activations stage-to-stage. The last
    stage applies the final projection and scheduler step per patch and
    (except on the final timestep) sends the stepped patch onward to
    start the next step.

    NOTE(review): unlike `_sync_pipeline`, there is no conditioning-latent
    handling here — presumably edit-mode runs stay in the sync path;
    confirm against `_diffusion_step`/`num_sync_steps` configuration.

    Args:
        t: Current timestep index.
        config: Runtime config (scheduler, steps, dimensions).
        patch_latents: Per-patch latent tensors from `_create_patches`.
        token_indices: (start, end) token range of each patch.
        prompt_data: Encoded prompt embeddings plus extra forward kwargs.
        kontext_image_ids: Optional ids for rotary embeddings.

    Returns:
        The updated per-patch latents list.
    """
    assert self.joint_kv_caches is not None
    assert self.single_kv_caches is not None
    # Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
    prompt_embeds = prompt_data.prompt_embeds
    pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
    extra_kwargs = prompt_data.get_extra_forward_kwargs()
    text_embeddings = self.adapter.compute_text_embeddings(
        t, config, pooled_prompt_embeds
    )
    image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
        prompt_embeds,
        config,
        kontext_image_ids=kontext_image_ids,
        **extra_kwargs,
    )
    batch_size = patch_latents[0].shape[0]
    text_seq_len = prompt_embeds.shape[1]
    hidden_dim = self.adapter.hidden_dim
    for patch_idx, patch in enumerate(patch_latents):
        # Kept for the scheduler step (sample=) on the last stage.
        patch_prev = patch
        start_token, end_token = token_indices[patch_idx]
        if self.has_joint_blocks:
            # Skip the recv only on the first stage's first async step:
            # it still holds the latents produced by the sync phase.
            if (
                not self.is_first_stage
                or t != config.init_time_step + self.num_sync_steps
            ):
                if self.is_first_stage:
                    # First stage receives latent-space from last stage (scheduler output)
                    recv_template = patch
                else:
                    # Other stages receive hidden-space from previous stage
                    patch_len = patch.shape[1]
                    recv_template = mx.zeros(
                        (batch_size, patch_len, hidden_dim),
                        dtype=patch.dtype,
                    )
                patch = mx.distributed.recv_like(
                    recv_template, src=self.prev_rank, group=self.group
                )
                mx.eval(patch)
                patch_latents[patch_idx] = patch
                # Encoder states are sent once per step (with the first
                # patch), not once per patch.
                if not self.is_first_stage and patch_idx == 0:
                    enc_template = mx.zeros(
                        (batch_size, text_seq_len, hidden_dim),
                        dtype=patch_latents[0].dtype,
                    )
                    encoder_hidden_states = mx.distributed.recv_like(
                        enc_template, src=self.prev_rank, group=self.group
                    )
                    mx.eval(encoder_hidden_states)
            if self.is_first_stage:
                patch, encoder_hidden_states = self.adapter.compute_embeddings(
                    patch, prompt_embeds
                )
            # Run assigned joint blocks with patched mode
            for block_idx, wrapper in enumerate(self.joint_block_wrappers):
                encoder_hidden_states, patch = wrapper(
                    hidden_states=patch,
                    encoder_hidden_states=encoder_hidden_states,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                    text_seq_len=text_seq_len,
                    kv_cache=self.joint_kv_caches[block_idx],
                    mode=BlockWrapperMode.PATCHED,
                    patch_start=start_token,
                    patch_end=end_token,
                    **extra_kwargs,
                )
        if self.owns_concat_stage:
            # Merge text and image streams before the single blocks.
            patch_concat = self.adapter.merge_streams(patch, encoder_hidden_states)
            if self.has_single_blocks or self.is_last_stage:
                # Keep locally: either for single blocks or final projection
                patch = patch_concat
            else:
                mx.eval(
                    mx.distributed.send(
                        patch_concat, self.next_rank, group=self.group
                    )
                )
        elif self.has_joint_blocks and not self.is_last_stage:
            mx.eval(mx.distributed.send(patch, self.next_rank, group=self.group))
            if patch_idx == 0:
                mx.eval(
                    mx.distributed.send(
                        encoder_hidden_states, self.next_rank, group=self.group
                    )
                )
        if self.has_single_blocks:
            if not self.owns_concat_stage and not self.is_first_stage:
                # Merged stream: text prefix plus this patch's tokens.
                recv_template = mx.zeros(
                    [
                        batch_size,
                        text_seq_len + patch_latents[patch_idx].shape[1],
                        hidden_dim,
                    ],
                    dtype=patch_latents[0].dtype,
                )
                patch = mx.distributed.recv_like(
                    recv_template, src=self.prev_rank, group=self.group
                )
                mx.eval(patch)
                patch_latents[patch_idx] = patch
            # Run assigned single blocks with patched mode
            for block_idx, wrapper in enumerate(self.single_block_wrappers):
                patch = wrapper(
                    hidden_states=patch,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                    text_seq_len=text_seq_len,
                    kv_cache=self.single_kv_caches[block_idx],
                    mode=BlockWrapperMode.PATCHED,
                    patch_start=start_token,
                    patch_end=end_token,
                )
            if not self.is_last_stage:
                mx.eval(
                    mx.distributed.send(patch, self.next_rank, group=self.group)
                )
        if self.is_last_stage:
            # Drop the text prefix, project, and take the scheduler step
            # against this patch's previous latents.
            patch_img_only = patch[:, text_seq_len:, :]
            patch_img_only = self.adapter.final_projection(
                patch_img_only, text_embeddings
            )
            patch = config.scheduler.step(
                model_output=patch_img_only,
                timestep=t,
                sample=patch_prev,
            )
            # Feed the stepped patch to the next stage for the next step,
            # except after the final timestep.
            if not self.is_first_stage and t != config.num_inference_steps - 1:
                mx.eval(
                    mx.distributed.send(patch, self.next_rank, group=self.group)
                )
        patch_latents[patch_idx] = patch
    return patch_latents

View File

@@ -103,6 +103,7 @@ class PipelineLastLayer(CustomMlxLayer):
# This change happened upstream - check out mlx github somewhere??
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
# TODO(ciaran): This is overkill
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
return output

View File

@@ -9,7 +9,7 @@ MAX_KV_SIZE: int | None = 3200
KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = 8
KV_CACHE_BITS: int | None = None
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

View File

@@ -1,7 +1,7 @@
from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.tokenizer_utils import TokenizerWrapper

View File

@@ -343,10 +343,6 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
if hasattr(model, "make_cache"):
logger.info(f"Using make_cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")
@@ -399,11 +395,5 @@ def set_wired_limit_for_model(model_size: Memory):
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
)
kv_bytes = int(0.02 * model_bytes)
target_cache = int(1.10 * (model_bytes + kv_bytes))
target_cache = min(target_cache, max_rec_size)
mx.set_cache_limit(target_cache)
mx.set_wired_limit(max_rec_size)
logger.info(
f"Wired limit set to {max_rec_size}. Cache limit set to {target_cache}."
)
logger.info(f"Wired limit set to {max_rec_size}.")

View File

@@ -8,13 +8,15 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.types.api import ImageEditsInternalParams
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
EventId,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
NodeDownloadProgress,
NodeMemoryMeasured,
NodePerformanceMeasured,
@@ -23,12 +25,14 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
DownloadModel,
ImageEdits,
Shutdown,
Task,
TaskStatus,
@@ -83,7 +87,7 @@ class Worker:
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup | None = None
@@ -94,6 +98,10 @@ class Worker:
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
async def run(self):
logger.info("Starting Worker")
@@ -128,6 +136,7 @@ class Worker:
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -171,6 +180,17 @@ class Worker:
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
if cmd_id not in self.input_chunk_buffer:
self.input_chunk_buffer[cmd_id] = {}
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
event.chunk.data
)
async def plan_step(self):
while True:
await anyio.sleep(0.1)
@@ -183,6 +203,8 @@ class Worker:
self.state.instances,
self.state.runners,
self.state.tasks,
self.input_chunk_buffer,
self.input_chunk_counts,
)
if task is None:
continue
@@ -200,11 +222,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard not in self.download_status:
if shard.model_meta.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -217,7 +239,7 @@ class Worker:
progress = DownloadCompleted(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -244,6 +266,42 @@ class Worker:
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
chunks = self.input_chunk_buffer.get(cmd_id, {})
assembled = "".join(chunks[i] for i in range(len(chunks)))
logger.info(
f"Assembled input image from {len(chunks)} chunks, "
f"total size: {len(assembled)} bytes"
)
# Create modified task with assembled image data
modified_task = ImageEdits(
task_id=task.task_id,
command_id=task.command_id,
instance_id=task.instance_id,
task_status=task.task_status,
task_params=ImageEditsInternalParams(
image_data=assembled,
total_input_chunks=task.task_params.total_input_chunks,
prompt=task.task_params.prompt,
model=task.task_params.model,
n=task.task_params.n,
quality=task.task_params.quality,
output_format=task.task_params.output_format,
response_format=task.task_params.response_format,
size=task.task_params.size,
image_strength=task.task_params.image_strength,
),
)
# Cleanup buffers
if cmd_id in self.input_chunk_buffer:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
@@ -349,7 +407,7 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata] = status
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
@@ -363,7 +421,7 @@ class Worker:
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
self.download_status[shard] = status
self.download_status[shard.model_meta.model_id] = status
# Footgun!
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
@@ -384,7 +442,7 @@ class Worker:
progress
),
)
self.download_status[shard] = status
self.download_status[shard.model_meta.model_id] = status
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
)
@@ -444,3 +502,40 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info("Fetching and emitting existing download progress...")
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id, shard_metadata=progress.shard
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -2,12 +2,15 @@
from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadModel,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -34,7 +37,6 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -43,12 +45,14 @@ def plan(
# Runners is expected to be FRESH and so should not come from state
runners: Mapping[RunnerId, RunnerSupervisor],
# DL_status is expected to be FRESH and so should not come from state
download_status: Mapping[ShardMetadata, DownloadProgress],
download_status: Mapping[ModelId, DownloadProgress],
# gdls is not expected to be fresh
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_counts: Mapping[CommandId, int] | None = None,
) -> Task | None:
# Python short circuiting OR logic should evaluate these sequentially.
return (
@@ -58,7 +62,7 @@ def plan(
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
)
@@ -111,13 +115,14 @@ def _create_runner(
def _model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ShardMetadata, DownloadProgress],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
not isinstance(
download_status.get(runner.bound_instance.bound_shard, None),
(DownloadOngoing, DownloadCompleted),
model_id not in download_status
or not isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
)
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
@@ -261,18 +266,38 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
if not isinstance(task, ChatCompletion):
# TODO(ciaran): do this better!
if (
not isinstance(task, ChatCompletion)
and not isinstance(task, ImageGeneration)
and not isinstance(task, ImageEdits)
):
continue
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
continue
# For ImageEdits tasks, verify all input chunks have been received
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
for runner in runners.values():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(

View File

@@ -1,7 +1,10 @@
import base64
import time
from exo.master.api import get_model_card
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -9,9 +12,12 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.models import ModelTask
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -21,6 +27,8 @@ from exo.shared.types.tasks import (
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ImageGenerationResponse,
PartialImageResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -32,10 +40,18 @@ from exo.shared.types.worker.runners import (
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
RunnerWarmingUp,
)
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.engines.image import (
ImageGenerator,
generate_image,
initialize_image_model,
warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
@@ -69,6 +85,10 @@ def main(
sampler = None
group = None
model_card = get_model_card(shard_metadata.model_meta.model_id)
assert model_card
model_tasks = model_card.tasks
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
event_sender.send(
@@ -111,16 +131,26 @@ def main(
)
)
model, tokenizer, sampler = load_mlx_items(
bound_instance, group
)
# TODO(ciaran): switch
if ModelTask.TextGeneration in model_tasks:
model, tokenizer, sampler = load_mlx_items(
bound_instance, group
)
elif (
ModelTask.TextToImage in model_tasks
or ModelTask.ImageToImage in model_tasks
):
model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {model_card.tasks}"
)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
assert sampler
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -130,22 +160,40 @@ def main(
)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
sampler=sampler,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
if ModelTask.TextGeneration in model_tasks:
assert model and isinstance(model, Model)
assert tokenizer
assert sampler
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
sampler=sampler,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
elif (
ModelTask.TextToImage in model_tasks
or ModelTask.ImageToImage in model_tasks
):
assert isinstance(model, ImageGenerator)
image = warmup_image_generator(model=model)
if image is not None:
logger.info(
f"warmed up by generating {image.size} image"
)
else:
logger.info("warmup completed (non-primary node)")
current_status = RunnerReady()
logger.info("runner ready")
case ChatCompletion(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert model
assert model and isinstance(model, Model)
assert tokenizer
assert sampler
logger.info(f"received chat request: {str(task)[:500]}")
@@ -186,14 +234,171 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
logger.info("runner shutting down")
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, ImageGenerator)
logger.info(
f"received image generation request: {str(task)[:500]}"
)
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
break
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(
model=model,
task=task_params,
):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
encoded_data = base64.b64encode(
response.image_data
).decode("utf-8")
# Split into chunks to stay under gossipsub 1MB limit
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
)
]
total_chunks = len(data_chunks)
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}: {len(encoded_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(
data_chunks
):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=chunk_index,
model=shard_metadata.model_meta.model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=response.partial_index,
is_partial=True,
partial_index=response.partial_index,
total_partials=response.total_partials,
),
)
)
case ImageGenerationResponse():
encoded_data = base64.b64encode(
response.image_data
).decode("utf-8")
# Split into chunks to stay under gossipsub 1MB limit
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
)
]
total_chunks = len(data_chunks)
logger.info(
f"sending final ImageChunk: {len(encoded_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(
data_chunks
):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=chunk_index,
model=shard_metadata.model_meta.model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=image_index,
is_partial=False,
),
)
)
image_index += 1
current_status = RunnerReady()
logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert isinstance(model, ImageGenerator)
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
image_index = 0
for response in generate_image(
model=model,
task=task_params,
):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case ImageGenerationResponse():
encoded_data = base64.b64encode(
response.image_data
).decode("utf-8")
# Split into chunks to stay under gossipsub 1MB limit
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
)
]
total_chunks = len(data_chunks)
logger.info(
f"sending ImageChunk: {len(encoded_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(
data_chunks
):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=chunk_index,
model=shard_metadata.model_meta.model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=image_index,
is_partial=False,
),
)
)
image_index += 1
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
current_status = RunnerShutdown()
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
@@ -208,9 +413,8 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
)
if isinstance(current_status, RunnerShutdown):
break
except ClosedResourceError:
logger.warning("runner communication closed unexpectedly")
except Exception as e:

View File

@@ -14,13 +14,23 @@ from anyio import (
from anyio.abc import TaskGroup
from loguru import logger
from exo.shared.types.events import Event, RunnerStatusUpdated, TaskAcknowledged
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerConnecting,
RunnerFailed,
RunnerIdle,
RunnerLoading,
RunnerRunning,
RunnerShuttingDown,
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
@@ -39,10 +49,10 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
# err_path: str
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -77,7 +87,6 @@ class RunnerSupervisor:
_ev_recv=ev_recv,
_task_sender=task_sender,
_event_sender=event_sender,
# err_path=err_path,
)
return self
@@ -118,6 +127,10 @@ class RunnerSupervisor:
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.completed:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -138,6 +151,22 @@ class RunnerSupervisor:
if isinstance(event, TaskAcknowledged):
self.pending.pop(event.task_id).set()
continue
if (
isinstance(event, TaskStatusUpdated)
and event.task_status == TaskStatus.Complete
):
# If a task has just been completed, we should be working on it.
assert isinstance(
self.status,
(
RunnerRunning,
RunnerWarmingUp,
RunnerLoading,
RunnerConnecting,
RunnerShuttingDown,
),
)
self.completed.add(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)

View File

@@ -9,9 +9,11 @@ MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
NODE_C: Final[NodeId] = NodeId("cccccccc-cccc-4ccc-8ccc-cccccccccccc")
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
RUNNER_3_ID: Final[RunnerId] = RunnerId("Runner3")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")

View File

@@ -1,11 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import dataclass, field
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask
from exo.shared.types.tasks import BaseTask, TaskId
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
@@ -21,6 +19,7 @@ from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
class OtherTask(BaseTask):

View File

@@ -1,5 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance
@@ -7,7 +8,6 @@ from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerIdle,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
@@ -46,7 +46,7 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
all_runners = {RUNNER_1_ID: RunnerIdle()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ShardMetadata, DownloadProgress] = {}
download_status: dict[ModelId, DownloadProgress] = {}
result = plan_mod.plan(
node_id=NODE_A,
@@ -94,7 +94,7 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
}
# Global view has completed downloads for both nodes
@@ -140,7 +140,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
# Local status claims the shard is downloaded already
local_download_status = {
shard: DownloadCompleted(shard_metadata=shard, node_id=NODE_A) # type: ignore[reportUnhashable]
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
@@ -192,7 +192,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
}
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],

View File

@@ -12,8 +12,10 @@ from exo.worker.tests.constants import (
MODEL_A_ID,
NODE_A,
NODE_B,
NODE_C,
RUNNER_1_ID,
RUNNER_2_ID,
RUNNER_3_ID,
)
from exo.worker.tests.unittests.conftest import (
FakeRunnerSupervisor,
@@ -24,37 +26,39 @@ from exo.worker.tests.unittests.conftest import (
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
"""
For non-final device_rank shards, StartWarmup should be emitted when all
For non-zero device_rank shards, StartWarmup should be emitted when all
shards in the instance are Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=3)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=3)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=2, world_size=3)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID, NODE_C: RUNNER_3_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1, RUNNER_3_ID: shard2},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_1_ID: local_runner}
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
RUNNER_3_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
node_id=NODE_A,
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_B: []},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
@@ -150,9 +154,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
"""
Rank-zero shard should not start warmup until all non-zero ranks are
already WarmingUp.
For accepting ranks (device_rank != world_size - 1), StartWarmup should be
For accepting ranks (device_rank != 0), StartWarmup should be
emitted when all shards in the instance are Loaded/WarmingUp.
In a 2-node setup, rank 0 is the accepting rank.
In a 2-node setup, rank 1 is the accepting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -163,7 +167,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 0 is the accepting rank
# Rank 1 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
@@ -188,6 +192,23 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
tasks={},
)
assert result is None
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
@@ -280,9 +301,8 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
"""
Connecting rank (device_rank == world_size - 1) should not start warmup
Connecting rank (device_rank == 0) should not start warmup
until all other ranks are already WarmingUp.
In a 2-node setup, rank 1 is the connecting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -295,13 +315,13 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
@@ -309,7 +329,7 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
}
result = plan_mod.plan(
node_id=NODE_B,
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},

View File

@@ -34,6 +34,7 @@ from exo.shared.types.worker.runners import (
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
@@ -199,6 +200,9 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
),
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),

1293
uv.lock generated
View File

File diff suppressed because it is too large Load Diff