Compare commits

...

215 Commits

Author SHA1 Message Date
ciaranbor
7925cc826a Fix flux tokenizer 2026-01-15 18:39:56 +00:00
ciaranbor
b9c9f53a73 Reduce image generation and image edits code duplication 2026-01-15 18:28:24 +00:00
ciaranbor
7eea39de9d Update mflux to 0.14.2 2026-01-15 17:43:33 +00:00
ciaranbor
a736c3824c Add python-multipart dependency 2026-01-15 16:31:47 +00:00
ciaranbor
58970a7ba2 Register tasks for newly added models 2026-01-15 16:31:47 +00:00
ciaranbor
fe9bbcd3d0 Linting 2026-01-15 16:31:47 +00:00
ciaranbor
b164e62fd7 Start image editing time steps at 0 2026-01-15 16:31:47 +00:00
ciaranbor
6078e9a8a1 Ignore image_strength 2026-01-15 16:31:47 +00:00
ciaranbor
ff52182aff Handle conditioning latents in sync pipeline 2026-01-15 16:31:47 +00:00
ciaranbor
de4efff6ac Use dummy image for editing warmup 2026-01-15 16:31:47 +00:00
ciaranbor
012b87abbf Support streaming for image editing 2026-01-15 16:31:47 +00:00
ciaranbor
68a64b5671 Support image editing in runner 2026-01-15 16:31:47 +00:00
ciaranbor
da8579c3ab Add editing features to adapter 2026-01-15 16:31:47 +00:00
ciaranbor
db222fb69e Default partial images to 3 if streaming 2026-01-15 16:31:47 +00:00
ciaranbor
a4329997c8 Add Qwen-Image model adapter 2026-01-15 16:31:47 +00:00
ciaranbor
3d9f8a5161 Add Qwen-Image-Edit model config 2026-01-15 16:31:47 +00:00
ciaranbor
5095d52a3d Use image generation in streaming mode in UI 2026-01-15 16:31:47 +00:00
ciaranbor
59b92b7c56 Handle partial image streaming 2026-01-15 16:31:47 +00:00
ciaranbor
1be6caacb4 Add streaming params to ImageGenerationTaskParams 2026-01-15 16:31:47 +00:00
ciaranbor
b2cf0390da Add Qwen-Image-Edit-2509 2026-01-15 16:31:47 +00:00
ciaranbor
148ee6c347 Handle image editing time steps 2026-01-15 16:31:47 +00:00
ciaranbor
e40976290d Fix time steps 2026-01-15 16:31:47 +00:00
ciaranbor
69178664b3 Fix image_strength meaning 2026-01-15 16:31:47 +00:00
ciaranbor
b8ef6345ef Truncate image data logs 2026-01-15 16:31:47 +00:00
ciaranbor
fc534e3a0d Chunk image input 2026-01-15 16:31:47 +00:00
ciaranbor
c8cb0b04b4 Avoid logging image data 2026-01-15 16:31:47 +00:00
ciaranbor
77334d60c7 Support image editing 2026-01-15 16:31:47 +00:00
Sami Khan
71d684861c small UI change 2026-01-15 16:31:47 +00:00
Sami Khan
95e277acf7 image gen in dashboard 2026-01-15 16:31:46 +00:00
ciaranbor
63284fd5fe Better llm model type check 2026-01-15 16:31:46 +00:00
ciaranbor
982567db05 Prune blocks before model load 2026-01-15 16:31:46 +00:00
ciaranbor
d72a41e9bc Own TODOs 2026-01-15 16:31:46 +00:00
ciaranbor
8921b786ab Remove double RunnerReady event 2026-01-15 16:31:46 +00:00
ciaranbor
52113075a4 Fix hidden_size for image models 2026-01-15 16:31:46 +00:00
ciaranbor
9d8add1977 Fix image model cards 2026-01-15 16:31:46 +00:00
ciaranbor
daf7fe495b Skip decode on non-final ranks 2026-01-15 16:31:46 +00:00
ciaranbor
b494cc3c11 Final rank produces image 2026-01-15 16:31:46 +00:00
ciaranbor
e9e6f93945 Increase number of sync steps 2026-01-15 16:31:46 +00:00
ciaranbor
3410874ee9 Change Qwen-Image steps 2026-01-15 16:31:46 +00:00
ciaranbor
a6ba92bf6b Fix Qwen-Image latent shapes 2026-01-15 16:31:46 +00:00
ciaranbor
280364d872 Fix joint block patch recv shape for non-zero ranks 2026-01-15 16:31:46 +00:00
ciaranbor
832f687d85 Fix comms issue for models without single blocks 2026-01-15 16:31:46 +00:00
ciaranbor
769162b509 Support Qwen in DiffusionRunner pipefusion 2026-01-15 16:31:46 +00:00
ciaranbor
f80a9789a5 Implement Qwen pipefusion 2026-01-15 16:31:46 +00:00
ciaranbor
c159f2f7b9 Add guidance_scale parameter to image model config 2026-01-15 16:31:46 +00:00
ciaranbor
f4270a6056 Move orchestration to DiffusionRunner 2026-01-15 16:31:46 +00:00
ciaranbor
bd48be8b0e Add initial QwenModelAdapter 2026-01-15 16:31:46 +00:00
ciaranbor
2b556ac7fb Tweak embeddings interface 2026-01-15 16:31:46 +00:00
ciaranbor
3bccde49d0 Add Qwen ImageModelConfig 2026-01-15 16:31:46 +00:00
ciaranbor
1fa952adfc Use 10% sync steps 2026-01-15 16:31:46 +00:00
ciaranbor
f4909aa7c6 Update FluxModelAdaper for new interface 2026-01-15 16:31:46 +00:00
ciaranbor
89a2bd4d18 Register QwenModelAdapter 2026-01-15 16:31:46 +00:00
ciaranbor
623f623297 Support multiple forward passes in runner 2026-01-15 16:31:46 +00:00
ciaranbor
4022d0585b Extend block wrapper parameters 2026-01-15 16:31:46 +00:00
ciaranbor
e2155579f4 Relax adaptor typing 2026-01-15 16:31:46 +00:00
ciaranbor
5a1a124e65 Add Qwen-Image model card 2026-01-15 16:31:46 +00:00
ciaranbor
3121827263 Clean up dead code 2026-01-15 16:31:46 +00:00
ciaranbor
480b72b1b1 Add BaseModelAdaptor 2026-01-15 16:31:46 +00:00
ciaranbor
28986bb678 Refactor filestructure 2026-01-15 16:31:46 +00:00
ciaranbor
c53fc6a16f Treat unified blocks as single blocks (equivalent) 2026-01-15 16:31:46 +00:00
ciaranbor
2e86b0f5a9 Refactor to handle entire denoising process in Diffusion runner 2026-01-15 16:31:46 +00:00
ciaranbor
98f0a29085 Move transformer to adapter 2026-01-15 16:31:46 +00:00
ciaranbor
1c0f2daf3c Move some more logic to adaptor 2026-01-15 16:31:46 +00:00
ciaranbor
814a836db1 Add generic block wrapper 2026-01-15 16:31:46 +00:00
ciaranbor
ef03ef049c Access transformer blocks from adaptor 2026-01-15 16:31:46 +00:00
ciaranbor
ba8567418d Better typing 2026-01-15 16:31:46 +00:00
ciaranbor
801ecf4483 Create wrappers at init time 2026-01-15 16:31:46 +00:00
ciaranbor
87e25961f5 Combine model factory and adaptor 2026-01-15 16:31:46 +00:00
ciaranbor
1d1014eaef Implement model factory 2026-01-15 16:31:46 +00:00
ciaranbor
8bd077de52 Add adaptor registry 2026-01-15 16:31:46 +00:00
ciaranbor
8377da5e22 Remove mflux/generator/generate.py 2026-01-15 16:31:46 +00:00
ciaranbor
4aa3b75000 Switch to using DistributedImageModel 2026-01-15 16:31:46 +00:00
ciaranbor
3d24aab421 Add DistributedImageModel 2026-01-15 16:31:46 +00:00
ciaranbor
04bb688005 Use new generic wrappers, etc in denoising 2026-01-15 16:31:46 +00:00
ciaranbor
1c70cea40c Add generic transformer block wrappers 2026-01-15 16:31:46 +00:00
ciaranbor
5928f369c5 Add FluxAdaptor 2026-01-15 16:31:46 +00:00
ciaranbor
ce1b66e5e6 Add ModelAdaptor, derivations implement model specific logic 2026-01-15 16:31:46 +00:00
ciaranbor
5d503a1ffb Introduce image model config concept 2026-01-15 16:31:46 +00:00
ciaranbor
f94a5ec8df Consolidate kv cache patching 2026-01-15 16:31:46 +00:00
ciaranbor
d45f9d98c0 Support different configuration comms 2026-01-15 16:31:46 +00:00
ciaranbor
7d4faf04fb Add ImageGenerator protocol 2026-01-15 16:31:46 +00:00
ciaranbor
2d4ba878cb Force final patch receive order 2026-01-15 16:31:46 +00:00
ciaranbor
bca5a9ffe3 Remove logs 2026-01-15 16:31:46 +00:00
ciaranbor
6bbd134880 Update patch list 2026-01-15 16:31:46 +00:00
ciaranbor
1414da68ec Slight refactor 2026-01-15 16:31:46 +00:00
ciaranbor
c88156f5ab Don't need array for prev patches 2026-01-15 16:31:46 +00:00
ciaranbor
770982c830 Fix send/recv order 2026-01-15 16:31:46 +00:00
ciaranbor
8d99ed8133 Fix async single transformer block 2026-01-15 16:31:46 +00:00
ciaranbor
c3d8fbc5ed Use relative rank variables 2026-01-15 16:31:46 +00:00
ciaranbor
a72830c301 Fix writing patches 2026-01-15 16:31:46 +00:00
ciaranbor
1fc355a2b1 Collect final image 2026-01-15 16:31:46 +00:00
ciaranbor
19dc8380c6 Fix recv_template shape 2026-01-15 16:31:46 +00:00
ciaranbor
11148923ca Add logs 2026-01-15 16:31:46 +00:00
ciaranbor
329b7d5f36 Optimise async pipeline 2026-01-15 16:31:46 +00:00
ciaranbor
1752aaa44a Add next_rank and prev_rank members 2026-01-15 16:31:46 +00:00
ciaranbor
a3fa833ae4 Add _create_patches method 2026-01-15 16:31:46 +00:00
ciaranbor
4661013cbb Fix shapes 2026-01-15 16:31:46 +00:00
ciaranbor
54e80a314d Reorder comms 2026-01-15 16:31:46 +00:00
ciaranbor
7990d8b1ef Remove all_gather from sync pipeline, send from final rank to first rank 2026-01-15 16:31:46 +00:00
ciaranbor
289bbe3253 Simplify kv_cache initialization 2026-01-15 16:31:46 +00:00
ciaranbor
ea06742295 Fix kv cache 2026-01-15 16:31:46 +00:00
ciaranbor
5bf986db6b Clean up kv caches 2026-01-15 16:31:46 +00:00
ciaranbor
0a9c1f7212 Fix return 2026-01-15 16:31:46 +00:00
ciaranbor
4b84aa5f70 Fix hidden_states shapes 2026-01-15 16:31:46 +00:00
ciaranbor
148f6550ed Only perform projection and scheduler step on last rank 2026-01-15 16:31:46 +00:00
ciaranbor
ecf2f40b4c Only compute embeddings on rank 0 2026-01-15 16:31:46 +00:00
ciaranbor
eea030b8c2 Remove eval 2026-01-15 16:31:46 +00:00
ciaranbor
73c92dfe60 Remove eval 2026-01-15 16:31:46 +00:00
ciaranbor
a6f7c4b822 Only send encoder_hidden_states with the first patch (once per timestep) 2026-01-15 16:31:46 +00:00
ciaranbor
4b81f8a672 Remove redundant text kv cache computation 2026-01-15 16:31:46 +00:00
ciaranbor
6e57d817d1 Concatenate before all gather 2026-01-15 16:31:46 +00:00
ciaranbor
4905107ea2 Increase number of sync steps 2026-01-15 16:31:46 +00:00
ciaranbor
e56c970e74 Reinitialise kv_caches between generations 2026-01-15 16:31:46 +00:00
ciaranbor
5d3bc83a63 Eliminate double kv cache computation 2026-01-15 16:31:46 +00:00
ciaranbor
88356eb0a0 Add kv cache caching wrappers for sync pipeline transformer blocks 2026-01-15 16:31:46 +00:00
ciaranbor
cab296ada7 Persist kv caches 2026-01-15 16:31:46 +00:00
ciaranbor
bfc6650a13 Implement naive async pipeline implementation 2026-01-15 16:31:46 +00:00
ciaranbor
66c091ae88 Use wrapper classes for patched transformer logic 2026-01-15 16:31:45 +00:00
ciaranbor
1ca1a3e490 Add patch-aware joint and single attention wrappers 2026-01-15 16:31:45 +00:00
ciaranbor
b778213792 Fix group.size() 2026-01-15 16:31:45 +00:00
ciaranbor
14a3a5d41c Add classes to manage kv caches with patch support 2026-01-15 16:31:45 +00:00
ciaranbor
bef9589510 Use heuristic for number of sync steps 2026-01-15 16:31:45 +00:00
ciaranbor
d39fbf796d Generalise number of denoising steps 2026-01-15 16:31:45 +00:00
ciaranbor
19ef6ea748 Add flux1-dev 2026-01-15 16:31:45 +00:00
ciaranbor
431ddf947e Move scheduler step to inner pipeline 2026-01-15 16:31:45 +00:00
ciaranbor
b5485bf6ef Add barrier before all_gather 2026-01-15 16:31:45 +00:00
ciaranbor
8325d5b865 Fix transformer blocks pruning 2026-01-15 16:31:45 +00:00
ciaranbor
44de96c15c Fix image generation api 2026-01-15 16:31:45 +00:00
ciaranbor
00c88a1102 Create queue in try block 2026-01-15 16:31:45 +00:00
ciaranbor
594487caed Conform to rebase 2026-01-15 16:31:45 +00:00
ciaranbor
7d9df93b7a Refactor denoising 2026-01-15 16:31:45 +00:00
ciaranbor
692907d2de Move more logic to DistributedFlux 2026-01-15 16:31:45 +00:00
ciaranbor
4018f698a1 Move surrounding logic back to _sync_pipeline 2026-01-15 16:31:45 +00:00
ciaranbor
330c7bb9cf Add patching aware member variables 2026-01-15 16:31:45 +00:00
ciaranbor
c8d54af8b6 Implement sync/async switching logic 2026-01-15 16:31:45 +00:00
ciaranbor
ba798e6bd3 Move current transformer implementation to _sync_pipeline method 2026-01-15 16:31:45 +00:00
ciaranbor
bf25de116a Remove some logs 2026-01-15 16:31:45 +00:00
ciaranbor
64e0dd06a8 Remove old Flux1 implementation 2026-01-15 16:31:45 +00:00
ciaranbor
5dafb7aceb Prune unused transformer blocks 2026-01-15 16:31:45 +00:00
ciaranbor
bbe0b58642 Add mx.eval 2026-01-15 16:31:45 +00:00
ciaranbor
b3233e35f0 Test evals 2026-01-15 16:31:45 +00:00
ciaranbor
887441e666 Test only barriers 2026-01-15 16:31:45 +00:00
ciaranbor
e3231ae22b All perform final projection 2026-01-15 16:31:45 +00:00
ciaranbor
b2918f5e42 Another barrier 2026-01-15 16:31:45 +00:00
ciaranbor
4d9b893d7a More debug 2026-01-15 16:31:45 +00:00
ciaranbor
2494a05790 Add barriers 2026-01-15 16:31:45 +00:00
ciaranbor
9802f27545 Add log 2026-01-15 16:31:45 +00:00
ciaranbor
926b197ea5 Restore distributed logging 2026-01-15 16:31:45 +00:00
ciaranbor
580d1738fc Use bootstrap logger 2026-01-15 16:31:45 +00:00
ciaranbor
a94aacb72b Remove logs 2026-01-15 16:31:45 +00:00
ciaranbor
e6758829c7 fix single block receive shape 2026-01-15 16:31:45 +00:00
ciaranbor
c892352860 Add debug logs 2026-01-15 16:31:45 +00:00
ciaranbor
23048f0fbb Move communication logic to DistributedTransformer wrapper 2026-01-15 16:31:45 +00:00
ciaranbor
3a45e55dcf Move inference logic to DistribuedFlux1 2026-01-15 16:31:45 +00:00
ciaranbor
50ba4a38f1 Add DistributedFlux1 class 2026-01-15 16:31:45 +00:00
ciaranbor
d4f49b9a38 Rename pipeline to pipefusion 2026-01-15 16:31:45 +00:00
ciaranbor
57135bda07 Further refactor 2026-01-15 16:31:45 +00:00
ciaranbor
ab492c76e9 Refactor warmup 2026-01-15 16:31:45 +00:00
ciaranbor
e55b3d496f Manually handle flux1 inference 2026-01-15 16:31:45 +00:00
ciaranbor
c20ad0d5fe Refactor flux1 image generation 2026-01-15 16:31:45 +00:00
ciaranbor
b02fb39747 Use quality parameter to set number of inference steps 2026-01-15 16:31:45 +00:00
ciaranbor
d257abed82 Chunk image data transfer 2026-01-15 16:31:45 +00:00
ciaranbor
e84a14b650 Define EXO_MAX_CHUNK_SIZE 2026-01-15 16:31:45 +00:00
ciaranbor
04128b65a7 Add indexing info to ImageChunk 2026-01-15 16:31:45 +00:00
ciaranbor
3d6e675af8 Remove sharding logs 2026-01-15 16:31:45 +00:00
ciaranbor
7b3320cd0e Temp: reduce flux1.schnell storage size 2026-01-15 16:31:45 +00:00
ciaranbor
1b7208bc04 Fix mflux transformer all_gather 2026-01-15 16:31:45 +00:00
ciaranbor
eef91921f2 Add all_gather -> broadcast todo 2026-01-15 16:31:45 +00:00
ciaranbor
3e8ab46d69 Fix world size 2026-01-15 16:31:45 +00:00
ciaranbor
02f811dd7e Fix transition block? 2026-01-15 16:31:45 +00:00
ciaranbor
1b7eb4abb2 Implement image generation warmup 2026-01-15 16:31:45 +00:00
ciaranbor
c8f27976c9 Add logs 2026-01-15 16:31:45 +00:00
ciaranbor
56e6ae4984 Add spiece.model to default patterns 2026-01-15 16:31:45 +00:00
ciaranbor
bccb2977ec Just download all files for now 2026-01-15 16:31:45 +00:00
ciaranbor
3d38e1977e Fix get_allow_patterns to include non-indexed safetensors files 2026-01-15 16:31:45 +00:00
ciaranbor
beb6371caf Use half-open layer indexing in get_allow_patterns 2026-01-15 16:31:45 +00:00
ciaranbor
6f66b387a8 Enable distributed mflux 2026-01-15 16:31:45 +00:00
ciaranbor
b0b789d971 Implement mflux transformer sharding and communication pattern 2026-01-15 16:31:45 +00:00
ciaranbor
6921df88a1 Update get_allow_patterns to handle sharding components 2026-01-15 16:31:45 +00:00
ciaranbor
b4cd0517c9 Namespace both keys and values for component weight maps 2026-01-15 16:31:45 +00:00
ciaranbor
ece3f207ad Add components to Flux.1-schnell MODEL_CARD 2026-01-15 16:31:45 +00:00
ciaranbor
29575a1fea Add component concept for ModelMetadata 2026-01-15 16:31:45 +00:00
ciaranbor
8d0cdb2b52 Fix multiple components weight map key conflicts 2026-01-15 16:31:45 +00:00
ciaranbor
eeac072a6b get_weight_map: handle repos with multiple safetensors.index.json files 2026-01-15 16:31:45 +00:00
ciaranbor
7ed6b75b41 Add initial image edits spec 2026-01-15 16:31:45 +00:00
ciaranbor
6e00899385 Add image edits endpoint 2026-01-15 16:31:45 +00:00
ciaranbor
ea3bab243a Add ImageToImage task 2026-01-15 16:31:45 +00:00
ciaranbor
497a2c065d Allow ModelCards to have multiple tasks 2026-01-15 16:31:45 +00:00
ciaranbor
4dd1a7c1b6 Fix text generation 2026-01-15 16:31:45 +00:00
ciaranbor
670a0f0c4a Rename mlx_generate_image to mflux_generate 2026-01-15 16:31:45 +00:00
ciaranbor
01a0d6d141 Initialize mlx or mflux engine based on model task 2026-01-15 16:31:45 +00:00
ciaranbor
761d2d82a7 Restore warmup for text generation 2026-01-15 16:31:45 +00:00
ciaranbor
248bea1839 Add initialize_mflux function 2026-01-15 16:31:45 +00:00
ciaranbor
f41d0129e5 Move image generation to mflux engine 2026-01-15 16:31:45 +00:00
ciaranbor
65f9d666b5 Just use str for image generation size 2026-01-15 16:31:45 +00:00
ciaranbor
7b13b361d0 Use MFlux for image generation 2026-01-15 16:31:45 +00:00
ciaranbor
dad82a605c Add get_model_card function 2026-01-15 16:31:45 +00:00
ciaranbor
6573b47abf Add ModelTask enum 2026-01-15 16:31:45 +00:00
ciaranbor
4596a7ac24 ADd flux1-schnell model 2026-01-15 16:31:45 +00:00
ciaranbor
e580b45eb2 Add task field to ModelCard 2026-01-15 16:31:04 +00:00
ciaranbor
def080c7e3 Update mflux version 2026-01-15 16:31:04 +00:00
ciaranbor
806239f14b Enable recursive repo downloads 2026-01-15 16:31:04 +00:00
ciaranbor
154f3561e7 Add dummy generate_image implementation 2026-01-15 16:31:04 +00:00
ciaranbor
07f7601948 Use base64 encoded str for image data 2026-01-15 16:31:03 +00:00
ciaranbor
229bd05473 Handle ImageGeneration tasks in _pending_tasks 2026-01-15 16:31:03 +00:00
ciaranbor
ac0c187aed Add mflux dependency 2026-01-15 16:31:03 +00:00
ciaranbor
083de373db Handle ImageGeneration task in runner task processing 2026-01-15 16:31:03 +00:00
ciaranbor
ae95172e41 Handle ImageGeneration command in master command processing 2026-01-15 16:31:03 +00:00
ciaranbor
a688001446 Add image generation to API 2026-01-15 16:31:03 +00:00
ciaranbor
796f291d85 Add ImageGenerationResponse 2026-01-15 16:31:03 +00:00
ciaranbor
73a09cf98c Add ImageGeneration task 2026-01-15 16:31:03 +00:00
ciaranbor
6c6dfd9ec7 Add ImageGeneration command 2026-01-15 16:31:03 +00:00
ciaranbor
41dbcf0b37 Add image generation params and response types 2026-01-15 16:31:03 +00:00
ciaranbor
b291950c1a Add pillow dependency 2026-01-15 16:31:03 +00:00
ciaranbor
f48b3dd870 Fix mlx stream_generate import 2026-01-15 16:31:03 +00:00
44 changed files with 8849 additions and 2345 deletions

View File

@@ -1,5 +1,5 @@
<script lang="ts">
import { isLoading, sendMessage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import { isLoading, sendMessage, generateImage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import ChatAttachments from './ChatAttachments.svelte';
import type { ChatUploadedFile } from '$lib/types/files';
import { processUploadedFiles, getAcceptString } from '$lib/types/files';
@@ -10,6 +10,7 @@
showHelperText?: boolean;
autofocus?: boolean;
showModelSelector?: boolean;
modelTasks?: Record<string, string[]>;
}
let {
@@ -17,7 +18,8 @@
placeholder = 'Ask anything',
showHelperText = false,
autofocus = true,
showModelSelector = false
showModelSelector = false,
modelTasks = {}
}: Props = $props();
let message = $state('');
@@ -48,13 +50,29 @@
// Accept all supported file types
const acceptString = getAcceptString(['image', 'text', 'pdf']);
// Check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const tasks = modelTasks[modelId] || [];
return tasks.includes('TextToImage') || tasks.includes('ImageToImage');
}
// Check if the currently selected model supports image generation
const isImageModel = $derived(() => {
if (!currentModel) return false;
return modelSupportsImageGeneration(currentModel);
});
// Extract available models from running instances
const availableModels = $derived(() => {
const models: Array<{id: string, label: string}> = [];
const models: Array<{id: string, label: string, isImageModel: boolean}> = [];
for (const [, instance] of Object.entries(instanceData)) {
const modelId = getInstanceModelId(instance);
if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
models.push({ id: modelId, label: modelId.split('/').pop() || modelId });
models.push({
id: modelId,
label: modelId.split('/').pop() || modelId,
isImageModel: modelSupportsImageGeneration(modelId)
});
}
}
return models;
@@ -160,7 +178,12 @@
uploadedFiles = [];
resetTextareaHeight();
sendMessage(content, files);
// Use image generation for image models
if (isImageModel() && content) {
generateImage(content);
} else {
sendMessage(content, files);
}
// Refocus the textarea after sending
setTimeout(() => textareaRef?.focus(), 10);
@@ -297,7 +320,14 @@
{:else}
<span class="w-3"></span>
{/if}
<span class="truncate">{model.label}</span>
{#if model.isImageModel}
<svg class="w-3.5 h-3.5 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate flex-1">{model.label}</span>
</button>
{/each}
</div>
@@ -357,7 +387,7 @@
onkeydown={handleKeydown}
oninput={handleInput}
onpaste={handlePaste}
{placeholder}
placeholder={isImageModel() ? 'Describe the image you want to generate...' : placeholder}
disabled={loading}
rows={1}
class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"
@@ -371,14 +401,23 @@
{!canSend || loading
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label="Send message"
aria-label={isImageModel() ? "Generate image" : "Send message"}
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<span class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"></span>
<span class="hidden sm:inline">PROCESSING</span>
<span class="hidden sm:inline">{isImageModel() ? 'GENERATING' : 'PROCESSING'}</span>
<span class="sm:hidden">...</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}

View File

@@ -365,10 +365,58 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
</div>
{/if}
<!-- Generated Images -->
{#if message.attachments?.some(a => a.type === 'generated-image')}
<div class="mb-3">
{#each message.attachments.filter(a => a.type === 'generated-image') as attachment}
<div class="relative group/img inline-block">
<img
src={attachment.preview}
alt=""
class="max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20"
/>
<!-- Download button overlay -->
<button
type="button"
class="absolute top-2 right-2 p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow opacity-0 group-hover/img:opacity-100 transition-opacity hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
onclick={() => {
if (attachment.preview) {
const link = document.createElement('a');
link.href = attachment.preview;
link.download = `generated-image-${Date.now()}.png`;
link.click();
}
}}
title="Download image"
>
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
</svg>
</button>
</div>
{/each}
</div>
{/if}
<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{#if message.content === 'Generating image...'}
<div class="flex items-center gap-3 text-exo-yellow">
<div class="relative">
<div class="w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"></div>
<svg class="absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
</div>
<span class="font-mono tracking-wider uppercase text-sm">Generating image...</span>
</div>
{:else if message.content || (loading && !message.attachments?.some(a => a.type === 'generated-image'))}
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
{/if}
</div>
</div>

View File

File diff suppressed because it is too large.

View File

@@ -47,7 +47,30 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number, tasks?: string[], hugging_face_id?: string}>>([]);
// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
const tasks: Record<string, string[]> = {};
for (const model of models) {
if (model.tasks && model.tasks.length > 0) {
// Map by short ID
tasks[model.id] = model.tasks;
// Also map by hugging_face_id from the API response
if (model.hugging_face_id) {
tasks[model.hugging_face_id] = model.tasks;
}
}
}
return tasks;
});
// Helper to check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const model = models.find(m => m.id === modelId || m.hugging_face_id === modelId);
if (!model?.tasks) return false;
return model.tasks.includes('TextToImage') || model.tasks.includes('ImageToImage');
}
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -1250,6 +1273,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
placeholder="Ask anything"
showHelperText={false}
showModelSelector={true}
modelTasks={modelTasks()}
/>
</div>
</div>
@@ -1471,8 +1495,18 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{@const foundModel = models.find(m => m.id === selectedModelId)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="text-exo-light-gray truncate">{foundModel.name || foundModel.id}</span>
{@const isImageModel = modelSupportsImageGeneration(foundModel.id)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="flex items-center gap-2 text-exo-light-gray truncate">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{foundModel.name || foundModel.id}</span>
</span>
<span class="text-white/50 text-xs flex-shrink-0">{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB</span>
</span>
{:else}
@@ -1517,6 +1551,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(model.id)}
<button
type="button"
onclick={() => {
@@ -1536,7 +1571,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
: 'text-white/30 cursor-default'
}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span class="flex-shrink-0 text-xs {modelCanFit ? 'text-white/50' : 'text-red-400/60'}">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
@@ -1733,7 +1777,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-7xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} />
<ChatForm placeholder="Ask anything" showModelSelector={true} modelTasks={modelTasks()} />
</div>
</div>
</div>

View File

@@ -23,6 +23,9 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux>=0.14.2",
"python-multipart>=0.0.21",
]
[project.scripts]

View File

@@ -1,11 +1,13 @@
import base64
import json
import time
from collections.abc import AsyncGenerator
from typing import cast
from typing import Literal, cast
import anyio
from anyio import create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
@@ -22,9 +24,10 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
BenchChatCompletionResponse,
@@ -37,6 +40,10 @@ from exo.shared.types.api import (
DeleteInstanceResponse,
FinishReason,
GenerationStats,
ImageData,
ImageEditsInternalParams,
ImageGenerationResponse,
ImageGenerationTaskParams,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -44,14 +51,17 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
SendInputChunk,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
@@ -87,12 +97,23 @@ def chunk_to_response(
)
async def resolve_model_meta(model_id: str) -> ModelMetadata:
def get_model_card(model_id: str) -> ModelCard | None:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for _, model_card in MODEL_CARDS.items():
if model_id == model_card.model_id:
return model_card
async def resolve_model_meta(model_id: str) -> ModelMetadata:
model_card = get_model_card(model_id)
if model_card is not None:
return model_card.metadata
else:
return await get_model_meta(model_id)
return await get_model_meta(model_id)
class API:
@@ -136,6 +157,7 @@ class API:
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._tg: TaskGroup | None = None
def reset(self, new_session_id: SessionId, result_clock: int):
@@ -144,6 +166,7 @@ class API:
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._chat_completion_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)
def unpause(self, result_clock: int):
@@ -176,6 +199,10 @@ class API:
self.chat_completions
)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.post("/v1/images/generations", response_model=None)(
self.image_generations
)
self.app.post("/v1/images/edits", response_model=None)(self.image_edits)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -595,6 +622,282 @@ class API:
)
return response
async def _validate_image_model(self, model: str) -> ModelId:
"""Validate model exists and return resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_meta = await resolve_model_meta(model)
resolved_model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(resolved_model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {resolved_model}"
)
return resolved_model
async def image_generations(
self, payload: ImageGenerationTaskParams
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image generation requests.
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(payload.model)
command = ImageGeneration(
request_params=payload,
)
await self._send(command)
# Check if streaming is requested
if payload.stream and payload.partial_images and payload.partial_images > 0:
return StreamingResponse(
self._generate_image_stream(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
),
media_type="text/event-stream",
)
# Non-streaming: collect all image chunks
return await self._collect_image_generation(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
)
async def _generate_image_stream(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> AsyncGenerator[str, None]:
"""Generate SSE stream of partial and final images."""
# Track chunks: {(image_index, is_partial): {chunk_index: data}}
image_chunks: dict[tuple[int, bool], dict[int, str]] = {}
image_total_chunks: dict[tuple[int, bool], int] = {}
image_metadata: dict[tuple[int, bool], tuple[int | None, int | None]] = {}
images_complete = 0
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
with recv as chunks:
async for chunk in chunks:
key = (chunk.image_index, chunk.is_partial)
if key not in image_chunks:
image_chunks[key] = {}
image_total_chunks[key] = chunk.total_chunks
image_metadata[key] = (
chunk.partial_index,
chunk.total_partials,
)
image_chunks[key][chunk.chunk_index] = chunk.data
# Check if this image is complete
if len(image_chunks[key]) == image_total_chunks[key]:
full_data = "".join(
image_chunks[key][i] for i in range(len(image_chunks[key]))
)
partial_idx, total_partials = image_metadata[key]
if chunk.is_partial:
# Yield partial image event
event_data = {
"type": "partial",
"partial_index": partial_idx,
"total_partials": total_partials,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
else:
# Final image
event_data = {
"type": "final",
"image_index": chunk.image_index,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
images_complete += 1
if images_complete >= num_images:
yield "data: [DONE]\n\n"
break
# Clean up completed image chunks
del image_chunks[key]
del image_total_chunks[key]
del image_metadata[key]
except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]
async def _collect_image_generation(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> ImageGenerationResponse:
"""Collect all image chunks (non-streaming) and return a single response."""
# Track chunks per image: {image_index: {chunk_index: data}}
# Only track non-partial (final) images
image_chunks: dict[int, dict[int, str]] = {}
image_total_chunks: dict[int, int] = {}
images_complete = 0
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
while images_complete < num_images:
with recv as chunks:
async for chunk in chunks:
# Skip partial images in non-streaming mode
if chunk.is_partial:
continue
if chunk.image_index not in image_chunks:
image_chunks[chunk.image_index] = {}
image_total_chunks[chunk.image_index] = chunk.total_chunks
image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
# Check if this image is complete
if (
len(image_chunks[chunk.image_index])
== image_total_chunks[chunk.image_index]
):
images_complete += 1
if images_complete >= num_images:
break
# Reassemble images in order
images: list[ImageData] = []
for image_idx in range(num_images):
chunks_dict = image_chunks[image_idx]
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
images.append(
ImageData(
b64_json=full_data if response_format == "b64_json" else None,
url=None, # URL format not implemented yet
)
)
return ImageGenerationResponse(data=images)
except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]
async def image_edits(
self,
image: UploadFile = File(...),
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
stream: bool = Form(False),
partial_images: int = Form(0),
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image editing requests (img2img)."""
resolved_model = await self._validate_image_model(model)
# Read and base64 encode the uploaded image
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")
# Map input_fidelity to image_strength
image_strength = 0.7 if input_fidelity == "high" else 0.3
# Split image into chunks to stay under gossipsub message size limit
data_chunks = [
image_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)
# Create command first to get command_id
command = ImageEdits(
request_params=ImageEditsInternalParams(
image_data="", # Empty - will be assembled at worker from chunks
total_input_chunks=total_chunks,
prompt=prompt,
model=resolved_model,
n=n,
size=size,
response_format=response_format,
image_strength=image_strength,
stream=stream,
partial_images=partial_images,
),
)
# Send input chunks BEFORE the command
logger.info(
f"Sending input image: {len(image_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(data_chunks):
await self._send(
SendInputChunk(
chunk=InputImageChunk(
idx=chunk_index,
model=resolved_model,
command_id=command.command_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
)
)
)
# Now send the main command
await self._send(command)
num_images = n
# Check if streaming is requested
if stream and partial_images and partial_images > 0:
return StreamingResponse(
self._generate_image_stream(
command_id=command.command_id,
num_images=num_images,
response_format=response_format,
),
media_type="text/event-stream",
)
# Non-streaming: collect all image chunks
return await self._collect_image_generation(
command_id=command.command_id,
num_images=num_images,
response_format=response_format,
)
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
@@ -617,6 +920,7 @@ class API:
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
tasks=[task.value for task in card.tasks],
)
for card in MODEL_CARDS.values()
]
@@ -654,14 +958,17 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if (
isinstance(event, ChunkGenerated)
and event.command_id in self._chat_completion_queues
):
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
if isinstance(event, ChunkGenerated):
if event.command_id in self._chat_completion_queues:
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
elif event.command_id in self._image_generation_queues:
assert isinstance(event.chunk, ImageChunk)
await self._image_generation_queues[event.command_id].send(
event.chunk
)
async def _pause_on_new_election(self):
with self.election_receiver as ems:
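
The streaming path added above emits server-sent events of the form data: {json}, where type is either "partial" or "final", the payload carries data.b64_json, and the stream ends with data: [DONE]. Below is a minimal client sketch consuming that stream; the host/port and the OpenAI-style prompt field are assumptions (the ImageGenerationTaskParams schema is only referenced, not shown, in this diff):

# Minimal client sketch for the new /v1/images/generations streaming endpoint.
# Assumptions not shown in this diff: the API listens on http://localhost:52415
# and the request body takes an OpenAI-style "prompt" field alongside the
# "stream"/"partial_images"/"n"/"response_format" parameters referenced above.
import base64
import json

import requests

payload = {
    "model": "flux1-schnell",          # resolved via get_model_card()
    "prompt": "a lighthouse at dusk",  # assumed field name (OpenAI-style)
    "n": 1,
    "stream": True,
    "partial_images": 3,               # defaults to 3 when streaming, per the commits
    "response_format": "b64_json",
}

with requests.post(
    "http://localhost:52415/v1/images/generations", json=payload, stream=True
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        event = json.loads(data)
        if event["type"] == "partial":
            print(f"partial {event['partial_index']}/{event['total_partials']}")
        elif event["type"] == "final":
            # Final images arrive base64-encoded under data.b64_json.
            with open(f"image-{event['image_index']}.png", "wb") as f:
                f.write(base64.b64decode(event["data"]["b64_json"]))

The /v1/images/edits endpoint accepts the same stream and partial_images knobs, but as multipart form fields sent alongside the uploaded image.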

View File

@@ -16,8 +16,11 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskFinished,
TestCommand,
)
@@ -26,6 +29,7 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
NodeTimedOut,
TaskCreated,
@@ -35,6 +39,12 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion as ChatCompletionTask,
)
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
)
from exo.shared.types.tasks import (
ImageGeneration as ImageGenerationTask,
)
from exo.shared.types.tasks import (
TaskId,
TaskStatus,
@@ -99,13 +109,14 @@ class Master:
async for forwarder_command in commands:
try:
logger.info(f"Executing command: {forwarder_command.command}")
generated_events: list[Event] = []
command = forwarder_command.command
instance_task_counts: dict[InstanceId, int] = {}
match command:
case TestCommand():
pass
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
@@ -146,6 +157,90 @@ class Master:
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageGeneration():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageEdits():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
@@ -173,6 +268,13 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
generated_events.append(
InputChunkReceived(
command_id=chunk.command_id,
chunk=chunk,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
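
Both the ImageGeneration and ImageEdits branches above schedule the task onto the least-loaded instance serving the requested model. The selection logic, pulled out into a standalone sketch (plain strings stand in for the InstanceId/TaskId types and the State fields used by the real code):

# Sketch of the instance-selection logic used by the ImageGeneration/ImageEdits
# branches: count tasks per instance serving the model, pick the least busy one.
def pick_least_loaded_instance(
    instances: dict[str, str],   # instance_id -> model_id served by that instance
    tasks: dict[str, str],       # task_id -> instance_id the task is assigned to
    model_id: str,
) -> str:
    task_counts = {
        instance_id: sum(1 for assigned in tasks.values() if assigned == instance_id)
        for instance_id, served_model in instances.items()
        if served_model == model_id
    }
    if not task_counts:
        raise ValueError(f"No instance found for model {model_id}")
    return min(task_counts, key=lambda instance_id: task_counts[instance_id])

The real branch then wraps the chosen instance in a TaskCreated event with a fresh TaskId, mirroring the existing ChatCompletion path.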

View File

@@ -9,6 +9,7 @@ from exo.shared.types.events import (
ChunkGenerated,
Event,
IndexedEvent,
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
NodeCreated,
@@ -40,8 +41,8 @@ def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged()
): # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
): # Pass-through events that don't modify state
return state
case InstanceCreated():
return apply_instance_created(event, state)

View File

@@ -44,3 +44,5 @@ LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024
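
EXO_MAX_CHUNK_SIZE caps each gossipsub message; the image-edits endpoint above slices the base64-encoded upload into chunks of this size, and the worker reassembles them by chunk_index before decoding. A round-trip sketch of that scheme, with a plain dict standing in for the InputImageChunk messages:

# Round-trip sketch of the input-image chunking scheme. 512 KiB mirrors the
# EXO_MAX_CHUNK_SIZE constant added above; the real code ships each slice as an
# InputImageChunk carrying chunk_index/total_chunks metadata.
EXO_MAX_CHUNK_SIZE = 512 * 1024

def split_into_chunks(image_data: str) -> list[str]:
    return [
        image_data[i : i + EXO_MAX_CHUNK_SIZE]
        for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
    ]

def reassemble(chunks: dict[int, str], total_chunks: int) -> str:
    # Chunks may arrive out of order; join by index once they have all landed.
    assert len(chunks) == total_chunks
    return "".join(chunks[i] for i in range(total_chunks))

# Example: a 1.2 MB base64 payload becomes three chunks (512 KiB, 512 KiB, remainder).
payload = "A" * 1_200_000
parts = split_into_chunks(payload)
assert reassemble(dict(enumerate(parts)), len(parts)) == payload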

View File

@@ -1,5 +1,5 @@
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ComponentInfo, ModelId, ModelMetadata, ModelTask
from exo.utils.pydantic_ext import CamelCaseModel
@@ -8,6 +8,7 @@ class ModelCard(CamelCaseModel):
model_id: ModelId
name: str
description: str
tasks: list[ModelTask]
tags: list[str]
metadata: ModelMetadata
@@ -19,6 +20,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
@@ -34,6 +36,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
@@ -50,6 +53,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
@@ -65,6 +69,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
@@ -81,6 +86,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
@@ -96,6 +102,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
@@ -111,6 +118,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
@@ -126,6 +134,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
@@ -142,6 +151,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
@@ -157,6 +167,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
@@ -172,6 +183,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
@@ -188,6 +200,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
@@ -203,6 +216,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
@@ -218,6 +232,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
@@ -234,6 +249,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
@@ -249,6 +265,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
@@ -264,6 +281,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
@@ -279,6 +297,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
@@ -294,6 +313,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
@@ -309,6 +329,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
@@ -324,6 +345,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
@@ -339,6 +361,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
@@ -354,6 +377,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
@@ -369,6 +393,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
@@ -384,6 +409,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
@@ -399,6 +425,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
@@ -415,6 +442,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
@@ -430,6 +458,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
@@ -447,6 +476,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
@@ -462,6 +492,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
@@ -478,6 +509,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
@@ -493,6 +525,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
@@ -508,6 +541,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
@@ -524,6 +558,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
@@ -539,6 +574,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
@@ -549,4 +585,188 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
"flux1-schnell": ModelCard(
short_id="flux1-schnell",
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
name="FLUX.1 [schnell]",
description="""FLUX.1 [schnell] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
pretty_name="FLUX.1 [schnell]",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(23782357120), # + 9524621312
n_layers=57, # sharded layers
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"flux1-dev": ModelCard(
short_id="flux1-dev",
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
name="FLUX.1 [dev]",
description="""FLUX.1 [dev] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
pretty_name="FLUX.1 [dev]",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57, # sharded layers
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"qwen-image": ModelCard(
short_id="qwen-image",
model_id=ModelId("Qwen/Qwen-Image"),
name="Qwen Image",
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("Qwen/Qwen-Image"),
pretty_name="Qwen Image",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"qwen-image-edit-2509": ModelCard(
short_id="qwen-image-edit-2509",
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
name="Qwen Image Edit 2509",
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
tasks=[ModelTask.ImageToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
pretty_name="Qwen Image Edit 2509",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
}
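
The image model cards above mark exactly one component as shardable (the transformer). The helper below is an illustrative sketch, not part of the repository, showing how a caller might pull that component out of a card; it assumes the MODEL_CARDS dict, ModelCard, and ComponentInfo defined above are importable.

def shardable_component(card: ModelCard) -> ComponentInfo | None:
    """Return the card's single shardable component, if the card declares components."""
    if card.metadata.components is None:
        return None
    return next((c for c in card.metadata.components if c.can_shard), None)

# e.g. shardable_component(MODEL_CARDS["flux1-schnell"]).component_name == "transformer"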

View File

@@ -1,6 +1,8 @@
import time
from collections.abc import Generator
from typing import Any, Literal
from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
@@ -28,6 +30,7 @@ class ModelListModel(BaseModel):
tags: list[str] = Field(default=[])
storage_size_megabytes: int = Field(default=0)
supports_tensor: bool = Field(default=False)
tasks: list[str] = Field(default=[])
class ModelList(BaseModel):
@@ -202,3 +205,75 @@ class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId
class ImageGenerationTaskParams(BaseModel):
prompt: str
# background: str | None = None
model: str
# moderation: str | None = None
n: int | None = 1
# output_compression: int | None = None
output_format: Literal["png", "jpeg", "webp"] = "png"
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
stream: bool | None = False
# style: str | None = "vivid"
# user: str | None = None
class ImageEditsTaskParams(BaseModel):
image: UploadFile
prompt: str
input_fidelity: float = 0.7
model: str
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
# user: str | None = None
class ImageEditsInternalParams(BaseModel):
"""Serializable version of ImageEditsTaskParams for distributed task execution."""
image_data: str = "" # Base64-encoded image (empty when using chunked transfer)
total_input_chunks: int = 0
prompt: str
model: str
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
image_strength: float = 0.7
stream: bool = False
partial_images: int | None = 0
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} chars>"
elif name is not None:
yield name, value
class ImageData(BaseModel):
b64_json: str | None = None
url: str | None = None
revised_prompt: str | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "b64_json" and self.b64_json is not None:
yield name, f"<{len(self.b64_json)} chars>"
elif name is not None:
yield name, value
class ImageGenerationResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
data: list[ImageData]
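
A hedged usage sketch for the request models above (assuming they are importable from exo.shared.types.api): the overridden __repr_args__ keeps base64 payloads out of logs by reporting only their length. The prompt and model values are illustrative.

params = ImageEditsInternalParams(
    image_data="aGVsbG8=" * 1000,  # stand-in base64 payload, 8000 chars
    total_input_chunks=1,
    prompt="replace the sky with a sunset",
    model="qwen-image-edit-2509",
)
# repr(params) renders image_data as "<8000 chars>" instead of the raw base64.
print(repr(params))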

View File

@@ -1,9 +1,12 @@
from collections.abc import Generator
from enum import Enum
from typing import Any
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .common import CommandId
from .models import ModelId
@@ -25,7 +28,34 @@ class TokenChunk(BaseChunk):
class ImageChunk(BaseChunk):
data: bytes
data: str
chunk_index: int
total_chunks: int
image_index: int
is_partial: bool = False
partial_index: int | None = None
total_partials: int | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data" and hasattr(value, "__len__"):
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value
class InputImageChunk(BaseChunk):
command_id: CommandId
data: str
chunk_index: int
total_chunks: int
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data" and hasattr(value, "__len__"):
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value
GenerationChunk = TokenChunk | ImageChunk
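
The chunk types above carry a string payload plus chunk_index/total_chunks so large images can cross the event stream in pieces. Below is a self-contained sketch of that split/reassemble scheme; the chunk size is illustrative and not taken from the repository.

import base64

CHUNK_SIZE = 256 * 1024  # illustrative

def split_payload(image_bytes: bytes) -> list[dict]:
    """Split a binary image into ordered base64 string chunks."""
    b64 = base64.b64encode(image_bytes).decode("ascii")
    chunks = [b64[i : i + CHUNK_SIZE] for i in range(0, len(b64), CHUNK_SIZE)]
    return [
        {"data": c, "chunk_index": i, "total_chunks": len(chunks)}
        for i, c in enumerate(chunks)
    ]

def join_payload(chunks: list[dict]) -> bytes:
    """Reassemble chunks (in any arrival order) back into the original bytes."""
    ordered = sorted(chunks, key=lambda c: c["chunk_index"])
    return base64.b64decode("".join(c["data"] for c in ordered))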

View File

@@ -1,6 +1,11 @@
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -20,6 +25,14 @@ class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
class ImageGeneration(BaseCommand):
request_params: ImageGenerationTaskParams
class ImageEdits(BaseCommand):
request_params: ImageEditsInternalParams
class PlaceInstance(BaseCommand):
model_meta: ModelMetadata
sharding: Sharding
@@ -39,6 +52,12 @@ class TaskFinished(BaseCommand):
finished_command_id: CommandId
class SendInputChunk(BaseCommand):
"""Command to send an input image chunk (converted to event by master)."""
chunk: InputImageChunk
class RequestEventLog(BaseCommand):
since_idx: int
@@ -47,10 +66,13 @@ Command = (
TestCommand
| RequestEventLog
| ChatCompletion
| ImageGeneration
| ImageEdits
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskFinished
| SendInputChunk
)

View File

@@ -3,7 +3,7 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
@@ -106,6 +106,11 @@ class ChunkGenerated(BaseEvent):
chunk: GenerationChunk
class InputChunkReceived(BaseEvent):
command_id: CommandId
chunk: InputImageChunk
class TopologyEdgeCreated(BaseEvent):
edge: Connection
@@ -131,6 +136,7 @@ Event = (
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| TopologyEdgeCreated
| TopologyEdgeDeleted
)

View File

@@ -1,3 +1,5 @@
from enum import Enum
from pydantic import PositiveInt
from exo.shared.types.common import Id
@@ -9,6 +11,21 @@ class ModelId(Id):
pass
class ModelTask(str, Enum):
TextGeneration = "TextGeneration"
TextToImage = "TextToImage"
ImageToImage = "ImageToImage"
class ComponentInfo(CamelCaseModel):
component_name: str
component_path: str
storage_size: Memory
n_layers: PositiveInt | None
can_shard: bool
safetensors_index_filename: str | None
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
@@ -16,3 +33,4 @@ class ModelMetadata(CamelCaseModel):
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
components: list[ComponentInfo] | None = None
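
With the new task values, callers can branch on whether a model generates text or images. A minimal sketch, assuming ModelTask as defined above:

def is_image_model(tasks: list[ModelTask]) -> bool:
    """True if a model card lists a text-to-image or image-to-image task."""
    return ModelTask.TextToImage in tasks or ModelTask.ImageToImage in tasks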

View File

@@ -2,7 +2,11 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.common import CommandId, Id
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
@@ -56,6 +60,22 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class ImageEdits(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageEditsInternalParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
@@ -67,5 +87,7 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| ImageGeneration
| ImageEdits
| Shutdown
)

View File

@@ -1,3 +1,6 @@
from collections.abc import Generator
from typing import Any, Literal
from exo.shared.types.api import FinishReason, GenerationStats
from exo.utils.pydantic_ext import TaggedModel
@@ -18,5 +21,31 @@ class GenerationResponse(BaseRunnerResponse):
stats: GenerationStats | None = None
class ImageGenerationResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} bytes>"
elif name is not None:
yield name, value
class PartialImageResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
partial_index: int
total_partials: int
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} bytes>"
elif name is not None:
yield name, value
class FinishedResponse(BaseRunnerResponse):
pass

View File

@@ -9,6 +9,7 @@ from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal
from urllib.parse import urljoin
from huggingface_hub._snapshot_download import snapshot_download
import aiofiles
import aiofiles.os as aios
@@ -441,12 +442,31 @@ def calculate_repo_progress(
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_file = await download_file_with_retry(
repo_id, revision, "model.safetensors.index.json", target_dir
index_files_dir = snapshot_download(
repo_id=repo_id, local_dir=target_dir, allow_patterns="*.safetensors.index.json"
)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
return index_data.weight_map
index_files = list(Path(index_files_dir).glob("**/*.safetensors.index.json"))
weight_map: dict[str, str] = {}
for index_file in index_files:
relative_dir = index_file.parent.relative_to(index_files_dir)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
if relative_dir != Path("."):
prefixed_weight_map = {
f"{relative_dir}/{key}": str(relative_dir / value)
for key, value in index_data.weight_map.items()
}
weight_map = weight_map | prefixed_weight_map
else:
weight_map = weight_map | index_data.weight_map
return weight_map
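
The merged weight map above namespaces both tensor names and shard filenames with the component subdirectory an index file was found in, so one flat map can cover several components. A self-contained sketch of just that transformation:

from pathlib import Path

def namespace_weight_map(weight_map: dict[str, str], relative_dir: Path) -> dict[str, str]:
    """Prefix tensor names and shard filenames with the component subdirectory."""
    if relative_dir == Path("."):
        return dict(weight_map)
    return {
        f"{relative_dir}/{key}": str(relative_dir / value)
        for key, value in weight_map.items()
    }

# On POSIX paths:
assert namespace_weight_map(
    {"transformer_blocks.0.attn.to_q.weight": "diffusion_pytorch_model-00001.safetensors"},
    Path("transformer"),
) == {
    "transformer/transformer_blocks.0.attn.to_q.weight": "transformer/diffusion_pytorch_model-00001.safetensors"
}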
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
@@ -551,8 +571,6 @@ async def download_shard(
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_meta.model_id), revision, recursive=True
)

View File

@@ -100,26 +100,68 @@ def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list
"*.py",
"tokenizer.model",
"tiktoken.model",
"*/spiece.model",
"*.tiktoken",
"*.txt",
"*.jinja",
]
)
shard_specific_patterns: set[str] = set()
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num <= shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
if shard.model_meta.components is not None:
shardable_component = next(
(c for c in shard.model_meta.components if c.can_shard), None
)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
if weight_map and shardable_component:
for tensor_name, filename in weight_map.items():
# Strip component prefix from tensor name (added by weight map namespacing)
# E.g., "transformer/blocks.0.weight" -> "blocks.0.weight"
if "/" in tensor_name:
_, tensor_name_no_prefix = tensor_name.split("/", 1)
else:
tensor_name_no_prefix = tensor_name
# Determine which component this file belongs to from filename
component_path = Path(filename).parts[0] if "/" in filename else None
if component_path == shardable_component.component_path.rstrip("/"):
layer_num = extract_layer_num(tensor_name_no_prefix)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
if shard.is_first_layer or shard.is_last_layer:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns = set(["*.safetensors"])
# TODO(ciaran): temporary - Include all files from non-shardable components that have no index file
for component in shard.model_meta.components:
if not component.can_shard and component.safetensors_index_filename is None:
component_pattern = f"{component.component_path.rstrip('/')}/*"
shard_specific_patterns.add(component_pattern)
else:
shard_specific_patterns = set(["*.safetensors"])
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
else:
shard_specific_patterns = set(["*.safetensors"])
logger.info(f"get_allow_patterns {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)
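
For component-based models the filter above keeps only the shardable component's files whose layer index falls in [start_layer, end_layer). A self-contained toy illustration, using a simple regex as a stand-in for the repository's extract_layer_num helper (the real helper may differ):

import re

def extract_layer_num_stub(tensor_name: str) -> int | None:
    """Stand-in layer-index extractor; illustrative only."""
    m = re.search(r"\.(\d+)\.", tensor_name)
    return int(m.group(1)) if m else None

weight_map = {
    "transformer/transformer_blocks.0.attn.to_q.weight": "transformer/model-00001.safetensors",
    "transformer/transformer_blocks.30.attn.to_q.weight": "transformer/model-00004.safetensors",
}
start_layer, end_layer = 0, 19  # e.g. the first pipeline stage
wanted = {
    filename
    for name, filename in weight_map.items()
    if (layer := extract_layer_num_stub(name.split("/", 1)[1])) is not None
    and start_layer <= layer < end_layer
}
assert wanted == {"transformer/model-00001.safetensors"}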

View File

@@ -0,0 +1,10 @@
from exo.worker.engines.image.base import ImageGenerator
from exo.worker.engines.image.distributed_model import initialize_image_model
from exo.worker.engines.image.generate import generate_image, warmup_image_generator
__all__ = [
"ImageGenerator",
"generate_image",
"initialize_image_model",
"warmup_image_generator",
]

View File

@@ -0,0 +1,50 @@
from collections.abc import Generator
from pathlib import Path
from typing import Literal, Protocol, runtime_checkable
from PIL import Image
@runtime_checkable
class ImageGenerator(Protocol):
@property
def rank(self) -> int: ...
@property
def is_first_stage(self) -> bool: ...
def generate(
self,
prompt: str,
height: int,
width: int,
quality: Literal["low", "medium", "high"],
seed: int,
image_path: Path | None = None,
partial_images: int = 0,
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
"""Generate an image from a text prompt, or edit an existing image.
For distributed inference, only the last stage returns images.
Other stages yield nothing after participating in the pipeline.
When partial_images > 0, yields intermediate images during diffusion
as tuples of (image, partial_index, total_partials), then yields
the final image.
When partial_images = 0 (default), only yields the final image.
Args:
prompt: Text description of the image to generate
height: Image height in pixels
width: Image width in pixels
quality: Generation quality level
seed: Random seed for reproducibility
image_path: Optional path to input image for image editing
partial_images: Number of intermediate images to yield (0 for none)
Yields:
Intermediate images as (Image, partial_index, total_partials) tuples
Final PIL Image (last stage) or nothing (other stages)
"""
...
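
A hedged consumption sketch for the protocol above: on the last pipeline stage, partial results arrive as (image, partial_index, total_partials) tuples and the final image as a bare PIL Image, while other stages may yield nothing at all. The height, width, seed, and partial_images values are illustrative.

def collect_images(model, prompt: str):
    """Consume an object implementing the ImageGenerator protocol above."""
    final = None
    partials = []
    for result in model.generate(
        prompt=prompt, height=512, width=512, quality="low", seed=2, partial_images=2
    ):
        if isinstance(result, tuple):
            image, idx, total = result
            partials.append((idx, total, image))
        else:
            final = result
    return final, partials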

View File

@@ -0,0 +1,74 @@
from enum import Enum
from math import ceil
from pydantic import BaseModel
class BlockType(Enum):
JOINT = "joint" # Separate image/text streams
SINGLE = "single" # Concatenated streams
class TransformerBlockConfig(BaseModel):
model_config = {"frozen": True}
block_type: BlockType
count: int
has_separate_text_output: bool # True for joint blocks that output text separately
class ImageModelConfig(BaseModel):
model_config = {"frozen": True}
# Model identification
model_family: str # "flux", "fibo", "qwen"
model_variant: str # "schnell", "dev", etc.
# Architecture parameters
hidden_dim: int
num_heads: int
head_dim: int
# Block configuration - ordered sequence of block types
block_configs: tuple[TransformerBlockConfig, ...]
# Tokenization parameters
patch_size: int # 2 for Flux/Qwen
vae_scale_factor: int # 8 for Flux, 16 for others
# Inference parameters
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
num_sync_steps_factor: float # Fraction of steps for sync phase
# Feature flags
uses_attention_mask: bool # True for Fibo
# CFG (Classifier-Free Guidance) parameters
guidance_scale: float | None = None # None or <= 1.0 disables CFG
@property
def total_blocks(self) -> int:
"""Total number of transformer blocks."""
return sum(bc.count for bc in self.block_configs)
@property
def joint_block_count(self) -> int:
"""Number of joint transformer blocks."""
return sum(
bc.count for bc in self.block_configs if bc.block_type == BlockType.JOINT
)
@property
def single_block_count(self) -> int:
"""Number of single transformer blocks."""
return sum(
bc.count for bc in self.block_configs if bc.block_type == BlockType.SINGLE
)
def get_steps_for_quality(self, quality: str) -> int:
"""Get inference steps for a quality level."""
return self.default_steps[quality]
def get_num_sync_steps(self, quality: str) -> int:
"""Get number of synchronous steps based on quality."""
return ceil(self.default_steps[quality] * self.num_sync_steps_factor)
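
A worked example of the two quality helpers above, using the Flux configs defined later in this diff: "medium" on FLUX.1 [dev] is 25 steps with num_sync_steps_factor=0.125, and "medium" on FLUX.1 [schnell] is 2 steps with factor 0.5.

from math import ceil

assert ceil(25 * 0.125) == 4  # FLUX.1 [dev], "medium": 25 steps -> 4 sync steps
assert ceil(2 * 0.5) == 1     # FLUX.1 [schnell], "medium": 2 steps -> 1 sync step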

View File

@@ -0,0 +1,229 @@
from collections.abc import Generator
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional
import mlx.core as mx
from mflux.models.common.config.config import Config
from PIL import Image
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
create_adapter_for_model,
get_config_for_model,
)
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline import DiffusionRunner
from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier
from exo.worker.runner.bootstrap import logger
class DistributedImageModel:
__slots__ = (
"_config",
"_adapter",
"_group",
"_shard_metadata",
"_runner",
)
_config: ImageModelConfig
_adapter: BaseModelAdapter
_group: Optional[mx.distributed.Group]
_shard_metadata: PipelineShardMetadata
_runner: DiffusionRunner
def __init__(
self,
model_id: str,
local_path: Path,
shard_metadata: PipelineShardMetadata,
group: Optional[mx.distributed.Group] = None,
quantize: int | None = None,
):
# Get model config and create adapter (adapter owns the model)
config = get_config_for_model(model_id)
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
if group is not None:
adapter.slice_transformer_blocks(
start_layer=shard_metadata.start_layer,
end_layer=shard_metadata.end_layer,
total_joint_blocks=config.joint_block_count,
total_single_blocks=config.single_block_count,
)
# Create diffusion runner (handles both single-node and distributed modes)
num_sync_steps = config.get_num_sync_steps("medium") if group else 0
runner = DiffusionRunner(
config=config,
adapter=adapter,
group=group,
shard_metadata=shard_metadata,
num_sync_steps=num_sync_steps,
)
if group is not None:
logger.info("Initialized distributed diffusion runner")
mx.eval(adapter.model.parameters())
# TODO(ciaran): Do we need this?
mx.eval(adapter.model)
# Synchronize processes before generation to avoid timeout
mx_barrier(group)
logger.info(f"Transformer sharded for rank {group.rank()}")
else:
logger.info("Single-node initialization")
object.__setattr__(self, "_config", config)
object.__setattr__(self, "_adapter", adapter)
object.__setattr__(self, "_group", group)
object.__setattr__(self, "_shard_metadata", shard_metadata)
object.__setattr__(self, "_runner", runner)
@classmethod
def from_bound_instance(
cls, bound_instance: BoundInstance
) -> "DistributedImageModel":
model_id = bound_instance.bound_shard.model_meta.model_id
model_path = build_model_path(model_id)
shard_metadata = bound_instance.bound_shard
if not isinstance(shard_metadata, PipelineShardMetadata):
raise ValueError("Expected PipelineShardMetadata for image generation")
is_distributed = (
len(bound_instance.instance.shard_assignments.node_to_runner) > 1
)
if is_distributed:
logger.info("Starting distributed init for image model")
group = mlx_distributed_init(bound_instance)
else:
group = None
return cls(
model_id=model_id,
local_path=model_path,
shard_metadata=shard_metadata,
group=group,
)
@property
def model(self) -> Any:
"""Return the underlying mflux model via the adapter."""
return self._adapter.model
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def adapter(self) -> BaseModelAdapter:
return self._adapter
@property
def group(self) -> Optional[mx.distributed.Group]:
return self._group
@property
def shard_metadata(self) -> PipelineShardMetadata:
return self._shard_metadata
@property
def rank(self) -> int:
return self._shard_metadata.device_rank
@property
def world_size(self) -> int:
return self._shard_metadata.world_size
@property
def is_first_stage(self) -> bool:
return self._shard_metadata.device_rank == 0
@property
def is_last_stage(self) -> bool:
return self._shard_metadata.device_rank == self._shard_metadata.world_size - 1
@property
def is_distributed(self) -> bool:
return self._shard_metadata.world_size > 1
@property
def runner(self) -> DiffusionRunner:
return self._runner
# Delegate attribute access to the underlying model via the adapter.
# Guarded with TYPE_CHECKING to prevent type checker complaints
# while still providing full delegation at runtime.
if not TYPE_CHECKING:
def __getattr__(self, name: str) -> Any:
return getattr(self._adapter.model, name)
def __setattr__(self, name: str, value: Any) -> None:
if name in (
"_config",
"_adapter",
"_group",
"_shard_metadata",
"_runner",
):
object.__setattr__(self, name, value)
else:
setattr(self._adapter.model, name, value)
def generate(
self,
prompt: str,
height: int,
width: int,
quality: Literal["low", "medium", "high"] = "medium",
seed: int = 2,
image_path: Path | None = None,
partial_images: int = 0,
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
# Determine number of inference steps based on quality
steps = self._config.get_steps_for_quality(quality)
# For edit mode: compute dimensions from input image
# This also stores image_paths in the adapter for encode_prompt()
if image_path is not None:
computed_dims = self._adapter.set_image_dimensions(image_path)
if computed_dims is not None:
# Override user-provided dimensions with computed ones
width, height = computed_dims
config = Config(
num_inference_steps=steps,
height=height,
width=width,
image_path=image_path,
model_config=self._adapter.model.model_config,
)
# Generate images via the runner
for result in self._runner.generate_image(
runtime_config=config,
prompt=prompt,
seed=seed,
partial_images=partial_images,
):
if isinstance(result, tuple):
# Partial image: (GeneratedImage, partial_index, total_partials)
generated_image, partial_idx, total_partials = result
yield (generated_image.image, partial_idx, total_partials)
else:
# Final image: GeneratedImage
logger.info("generated image")
yield result.image
def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:
"""Initialize DistributedImageModel from a BoundInstance."""
return DistributedImageModel.from_bound_instance(bound_instance)
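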

View File

@@ -0,0 +1,120 @@
import base64
import io
import tempfile
from collections.abc import Generator
from pathlib import Path
from typing import Literal
from PIL import Image
from exo.shared.types.api import ImageEditsInternalParams, ImageGenerationTaskParams
from exo.shared.types.worker.runner_response import (
ImageGenerationResponse,
PartialImageResponse,
)
from exo.worker.engines.image.base import ImageGenerator
def parse_size(size_str: str | None) -> tuple[int, int]:
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
if not size_str or size_str == "auto":
size_str = "1024x1024"
try:
parts = size_str.split("x")
if len(parts) == 2:
width, height = int(parts[0]), int(parts[1])
return (width, height)
except (ValueError, AttributeError):
pass
# Default fallback
return (1024, 1024)
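
A few hedged behaviour checks for parse_size above (assuming it is importable from this module): well-formed "WxH" strings are parsed, while None, "auto", and malformed inputs fall back to 1024x1024.

assert parse_size("512x768") == (512, 768)
assert parse_size(None) == (1024, 1024)          # missing -> default
assert parse_size("auto") == (1024, 1024)        # "auto" -> default
assert parse_size("not-a-size") == (1024, 1024)  # malformed -> default fallback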
def warmup_image_generator(model: ImageGenerator) -> Image.Image | None:
"""Warmup the image generator with a small image."""
with tempfile.TemporaryDirectory() as tmpdir:
# Create a small dummy image for warmup (needed for edit models)
dummy_image = Image.new("RGB", (256, 256), color=(128, 128, 128))
dummy_path = Path(tmpdir) / "warmup.png"
dummy_image.save(dummy_path)
for result in model.generate(
prompt="Warmup",
height=256,
width=256,
quality="low",
seed=2,
image_path=dummy_path,
):
if not isinstance(result, tuple):
return result
return None
def generate_image(
model: ImageGenerator,
task: ImageGenerationTaskParams | ImageEditsInternalParams,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
"""Generate image(s), optionally yielding partial results.
When partial_images > 0 or stream=True, yields PartialImageResponse for
intermediate images, then ImageGenerationResponse for the final image.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0)
ImageGenerationResponse for the final complete image
"""
width, height = parse_size(task.size)
quality: Literal["low", "medium", "high"] = task.quality or "medium"
seed = 2 # TODO(ciaran): Randomise when not testing anymore
# Handle streaming params for both generation and edit tasks
partial_images = task.partial_images or (3 if task.stream else 0)
image_path: Path | None = None
with tempfile.TemporaryDirectory() as tmpdir:
if isinstance(task, ImageEditsInternalParams):
# Decode base64 image data and save to temp file
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
# Iterate over generator results
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=seed,
image_path=image_path,
partial_images=partial_images,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
image.save(buffer, format=image_format)
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
)
else:
# Final image
image = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
)
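
A hedged consumer sketch for generate_image above, assuming the response types imported at the top of this module; the output directory and file naming are illustrative. Partial frames are written with their index, the final frame without one.

from pathlib import Path

def save_responses(responses, out_dir: Path) -> None:
    """Write each streamed frame produced by generate_image to disk."""
    out_dir.mkdir(parents=True, exist_ok=True)
    for response in responses:
        if isinstance(response, PartialImageResponse):
            name = f"partial_{response.partial_index}_of_{response.total_partials}.{response.format}"
        else:
            name = f"final.{response.format}"
        (out_dir / name).write_bytes(response.image_data)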

View File

@@ -0,0 +1,84 @@
from pathlib import Path
from typing import Callable
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.flux import (
FLUX_DEV_CONFIG,
FLUX_SCHNELL_CONFIG,
FluxModelAdapter,
)
from exo.worker.engines.image.models.qwen import (
QWEN_IMAGE_CONFIG,
QWEN_IMAGE_EDIT_CONFIG,
QwenEditModelAdapter,
QwenModelAdapter,
)
from exo.worker.engines.image.pipeline.adapter import ModelAdapter
__all__: list[str] = []
# Type alias for adapter factory functions
# Factory takes (config, model_id, local_path, quantize) and returns a ModelAdapter
AdapterFactory = Callable[[ImageModelConfig, str, Path, int | None], ModelAdapter]
# Registry maps model_family string to adapter factory
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
"flux": FluxModelAdapter,
"qwen-edit": QwenEditModelAdapter,
"qwen": QwenModelAdapter,
}
# Config registry: maps model ID patterns to configs
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
"flux.1-dev": FLUX_DEV_CONFIG,
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
"qwen-image": QWEN_IMAGE_CONFIG,
}
def get_config_for_model(model_id: str) -> ImageModelConfig:
"""Get configuration for a model ID.
Args:
model_id: The model identifier (e.g., "black-forest-labs/FLUX.1-schnell")
Returns:
The model configuration
Raises:
ValueError: If no configuration found for model ID
"""
model_id_lower = model_id.lower()
for pattern, config in _CONFIG_REGISTRY.items():
if pattern in model_id_lower:
return config
raise ValueError(f"No configuration found for model: {model_id}")
def create_adapter_for_model(
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
) -> ModelAdapter:
"""Create a model adapter for the given configuration.
Args:
config: The model configuration
model_id: The model identifier
local_path: Path to the model weights
quantize: Optional quantization bits
Returns:
A ModelAdapter instance
Raises:
ValueError: If no adapter found for model family
"""
factory = _ADAPTER_REGISTRY.get(config.model_family)
if factory is None:
raise ValueError(f"No adapter found for model family: {config.model_family}")
return factory(config, model_id, local_path, quantize)
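
A hedged lookup sketch (assuming this registry module and the config objects it imports are importable): matching is substring-based over the lowercased model ID and order-sensitive, which is why the edit pattern is registered before the broader "qwen-image" pattern above.

config = get_config_for_model("Qwen/Qwen-Image-Edit-2509")
assert config is QWEN_IMAGE_EDIT_CONFIG  # matched before the broader "qwen-image" pattern

config = get_config_for_model("black-forest-labs/FLUX.1-schnell")
assert config is FLUX_SCHNELL_CONFIG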

View File

@@ -0,0 +1,105 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.common.latent_creator.latent_creator import Img2Img, LatentCreator
from mflux.utils.image_util import ImageUtil
class BaseModelAdapter(ABC):
"""Base class for model adapters with shared utilities.
Provides common implementations for latent creation and decoding.
Subclasses implement model-specific prompt encoding and noise computation.
"""
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
"""Create initial latents. Uses model-specific latent creator."""
return LatentCreator.create_for_txt2img_or_img2img(
seed=seed,
height=runtime_config.height,
width=runtime_config.width,
img2img=Img2Img(
vae=self.model.vae,
latent_creator=self._get_latent_creator(),
sigmas=runtime_config.scheduler.sigmas,
init_time_step=runtime_config.init_time_step,
image_path=runtime_config.image_path,
),
)
def decode_latents(
self,
latents: mx.array,
runtime_config: Config,
seed: int,
prompt: str,
) -> Any:
"""Decode latents to image. Shared implementation."""
latents = self._get_latent_creator().unpack_latents(
latents=latents,
height=runtime_config.height,
width=runtime_config.width,
)
decoded = self.model.vae.decode(latents)
# TODO(ciaran):
# from mflux.models.common.vae.vae_util import VAEUtil
# VAEUtil.decode(vae=self.model.vae, latents=latents, tiling_config=self.tiling_config)
return ImageUtil.to_image(
decoded_latents=decoded,
config=runtime_config,
seed=seed,
prompt=prompt,
quantization=self.model.bits,
lora_paths=self.model.lora_paths,
lora_scales=self.model.lora_scales,
image_path=runtime_config.image_path,
image_strength=runtime_config.image_strength,
generation_time=0,
)
# Abstract methods - subclasses must implement
@property
@abstractmethod
def model(self) -> Any:
"""Return the underlying mflux model."""
...
@abstractmethod
def _get_latent_creator(self) -> type:
"""Return the latent creator class for this model."""
...
@abstractmethod
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
total_joint_blocks: int,
total_single_blocks: int,
):
"""Remove transformer blocks outside the assigned range.
This should be called BEFORE mx.eval() to avoid loading unused weights
in distributed mode.
Args:
start_layer: First layer index (inclusive) assigned to this node
end_layer: Last layer index (exclusive) assigned to this node
total_joint_blocks: Total number of joint blocks in the model
total_single_blocks: Total number of single blocks in the model
"""
...
def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
"""Default implementation: no dimension computation needed.
Override in edit adapters to compute dimensions from input image.
Returns:
None (use user-specified dimensions)
"""
return None

View File

@@ -0,0 +1,11 @@
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
from exo.worker.engines.image.models.flux.config import (
FLUX_DEV_CONFIG,
FLUX_SCHNELL_CONFIG,
)
__all__ = [
"FluxModelAdapter",
"FLUX_DEV_CONFIG",
"FLUX_SCHNELL_CONFIG",
]

View File

@@ -0,0 +1,684 @@
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
from mflux.models.flux.model.flux_transformer.common.attention_utils import (
AttentionUtils,
)
from mflux.models.flux.model.flux_transformer.joint_transformer_block import (
JointTransformerBlock,
)
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.variants.txt2img.flux import Flux1
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class FluxPromptData:
"""Container for Flux prompt encoding results."""
def __init__(self, prompt_embeds: mx.array, pooled_prompt_embeds: mx.array):
self._prompt_embeds = prompt_embeds
self._pooled_prompt_embeds = pooled_prompt_embeds
@property
def prompt_embeds(self) -> mx.array:
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
return self._pooled_prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array | None:
"""Flux does not use CFG."""
return None
@property
def negative_pooled_prompt_embeds(self) -> mx.array | None:
"""Flux does not use CFG."""
return None
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
"""Flux has no extra forward kwargs."""
return {}
@property
def conditioning_latents(self) -> mx.array | None:
"""Flux does not use conditioning latents."""
return None
class FluxModelAdapter(BaseModelAdapter):
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = Flux1(
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
model_path=str(local_path),
quantize=quantize,
)
self._transformer = self._model.transformer
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def model(self) -> Flux1:
return self._model
@property
def transformer(self) -> Transformer:
return self._transformer
@property
def hidden_dim(self) -> int:
return self._transformer.x_embedder.weight.shape[0]
def _get_latent_creator(self) -> type:
return FluxLatentCreator
def encode_prompt(self, prompt: str) -> FluxPromptData:
"""Encode prompt into FluxPromptData."""
assert isinstance(self.model.prompt_cache, dict)
assert isinstance(self.model.tokenizers, dict)
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
prompt=prompt,
prompt_cache=self.model.prompt_cache,
t5_tokenizer=self.model.tokenizers["t5"],
clip_tokenizer=self.model.tokenizers["clip"],
t5_text_encoder=self.model.t5_text_encoder,
clip_text_encoder=self.model.clip_text_encoder,
)
return FluxPromptData(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
)
@property
def needs_cfg(self) -> bool:
return False
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
raise NotImplementedError("Flux does not use classifier-free guidance")
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
embedded_hidden = self._transformer.x_embedder(hidden_states)
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None, # Ignored by Flux
) -> mx.array:
if pooled_prompt_embeds is None:
raise ValueError(
"pooled_prompt_embeds is required for Flux text embeddings"
)
# hidden_states is ignored - Flux uses pooled_prompt_embeds instead
return Transformer.compute_text_embeddings(
t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
**kwargs: Any,
) -> mx.array:
kontext_image_ids = kwargs.get("kontext_image_ids")
return Transformer.compute_rotary_embeddings(
prompt_embeds,
self._transformer.pos_embed,
runtime_config,
kontext_image_ids,
)
def apply_joint_block(
self,
block: JointBlockInterface,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any, # mx.array for Flux
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
**kwargs: Any,
) -> tuple[mx.array, mx.array]:
if mode == BlockWrapperMode.CACHING:
return self._apply_joint_block_caching(
block=block,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
)
else:
assert patch_start is not None and patch_end is not None
assert kv_cache is not None
return self._apply_joint_block_patched(
block=block,
patch_hidden=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
patch_start=patch_start,
patch_end=patch_end,
)
def apply_single_block(
self,
block: SingleBlockInterface,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
) -> mx.array:
if mode == BlockWrapperMode.CACHING:
return self._apply_single_block_caching(
block=block,
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
)
else:
assert patch_start is not None and patch_end is not None
assert kv_cache is not None
return self._apply_single_block_patched(
block=block,
patch_hidden=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
patch_start=patch_start,
patch_end=patch_end,
)
def final_projection(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
return self._transformer.proj_out(hidden_states)
def get_joint_blocks(self) -> list[JointBlockInterface]:
return cast(
list[JointBlockInterface], list(self._transformer.transformer_blocks)
)
def get_single_blocks(self) -> list[SingleBlockInterface]:
return cast(
list[SingleBlockInterface],
list(self._transformer.single_transformer_blocks),
)
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
total_joint_blocks: int,
total_single_blocks: int,
) -> None:
if end_layer <= total_joint_blocks:
# All assigned are joint blocks
joint_start, joint_end = start_layer, end_layer
single_start, single_end = 0, 0
elif start_layer >= total_joint_blocks:
# All assigned are single blocks
joint_start, joint_end = 0, 0
single_start = start_layer - total_joint_blocks
single_end = end_layer - total_joint_blocks
else:
# Spans both joint and single
joint_start, joint_end = start_layer, total_joint_blocks
single_start = 0
single_end = end_layer - total_joint_blocks
all_joint = list(self._transformer.transformer_blocks)
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
all_single = list(self._transformer.single_transformer_blocks)
self._transformer.single_transformer_blocks = all_single[
single_start:single_end
]
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
def _apply_joint_block_caching(
self,
block: JointBlockInterface,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache | None,
text_seq_len: int,
) -> tuple[mx.array, mx.array]:
num_img_tokens = hidden_states.shape[1]
batch_size = hidden_states.shape[0]
attn = block.attn
num_heads = attn.num_heads
head_dim = attn.head_dimension
# 1. Compute norms
norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
)
norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
block.norm1_context(
hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
)
)
# 2. Compute Q, K, V for full image
img_query, img_key, img_value = AttentionUtils.process_qkv(
hidden_states=norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=num_heads,
head_dim=head_dim,
)
# 3. Compute Q, K, V for text
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
hidden_states=norm_encoder,
to_q=attn.add_q_proj,
to_k=attn.add_k_proj,
to_v=attn.add_v_proj,
norm_q=attn.norm_added_q,
norm_k=attn.norm_added_k,
num_heads=num_heads,
head_dim=head_dim,
)
# 4. Concatenate Q, K, V: [text, image]
query = mx.concatenate([txt_query, img_query], axis=2)
key = mx.concatenate([txt_key, img_key], axis=2)
value = mx.concatenate([txt_value, img_value], axis=2)
# 5. Apply RoPE
query, key = AttentionUtils.apply_rope(
xq=query, xk=key, freqs_cis=rotary_embeddings
)
# 6. Store IMAGE K/V in cache for async pipeline
if kv_cache is not None:
kv_cache.update_image_patch(
patch_start=0,
patch_end=num_img_tokens,
key=key[:, :, text_seq_len:, :],
value=value[:, :, text_seq_len:, :],
)
# 7. Compute full attention
attn_output = AttentionUtils.compute_attention(
query=query,
key=key,
value=value,
batch_size=batch_size,
num_heads=num_heads,
head_dim=head_dim,
)
# 8. Extract and project outputs
context_attn_output = attn_output[:, :text_seq_len, :]
attn_output = attn_output[:, text_seq_len:, :]
attn_output = attn.to_out[0](attn_output)
context_attn_output = attn.to_add_out(context_attn_output)
# 9. Apply norm and feed forward
hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=hidden_states,
attn_output=attn_output,
gate_mlp=gate_mlp,
gate_msa=gate_msa,
scale_mlp=scale_mlp,
shift_mlp=shift_mlp,
norm_layer=block.norm2,
ff_layer=block.ff,
)
encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=encoder_hidden_states,
attn_output=context_attn_output,
gate_mlp=c_gate_mlp,
gate_msa=c_gate_msa,
scale_mlp=c_scale_mlp,
shift_mlp=c_shift_mlp,
norm_layer=block.norm2_context,
ff_layer=block.ff_context,
)
return encoder_hidden_states, hidden_states
def _apply_joint_block_patched(
self,
block: JointBlockInterface,
patch_hidden: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache,
text_seq_len: int,
patch_start: int,
patch_end: int,
) -> tuple[mx.array, mx.array]:
batch_size = patch_hidden.shape[0]
attn = block.attn
num_heads = attn.num_heads
head_dim = attn.head_dimension
# 1. Compute norms
norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
hidden_states=patch_hidden,
text_embeddings=text_embeddings,
)
norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
block.norm1_context(
hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
)
)
# 2. Compute Q, K, V for image patch
img_query, img_key, img_value = AttentionUtils.process_qkv(
hidden_states=norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=num_heads,
head_dim=head_dim,
)
# 3. Compute Q, K, V for text
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
hidden_states=norm_encoder,
to_q=attn.add_q_proj,
to_k=attn.add_k_proj,
to_v=attn.add_v_proj,
norm_q=attn.norm_added_q,
norm_k=attn.norm_added_k,
num_heads=num_heads,
head_dim=head_dim,
)
# 4. Concatenate Q, K, V for patch: [text, patch]
query = mx.concatenate([txt_query, img_query], axis=2)
patch_key = mx.concatenate([txt_key, img_key], axis=2)
patch_value = mx.concatenate([txt_value, img_value], axis=2)
# 5. Extract RoPE for [text + current_patch]
text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
patch_img_rope = rotary_embeddings[
:, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
]
patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
# 6. Apply RoPE
query, patch_key = AttentionUtils.apply_rope(
xq=query, xk=patch_key, freqs_cis=patch_rope
)
# 7. Update cache with this patch's IMAGE K/V
kv_cache.update_image_patch(
patch_start=patch_start,
patch_end=patch_end,
key=patch_key[:, :, text_seq_len:, :],
value=patch_value[:, :, text_seq_len:, :],
)
# 8. Get full K, V from cache
full_key, full_value = kv_cache.get_full_kv(
text_key=patch_key[:, :, :text_seq_len, :],
text_value=patch_value[:, :, :text_seq_len, :],
)
# 9. Compute attention
attn_output = AttentionUtils.compute_attention(
query=query,
key=full_key,
value=full_value,
batch_size=batch_size,
num_heads=num_heads,
head_dim=head_dim,
)
# 10. Extract and project outputs
context_attn_output = attn_output[:, :text_seq_len, :]
hidden_attn_output = attn_output[:, text_seq_len:, :]
hidden_attn_output = attn.to_out[0](hidden_attn_output)
context_attn_output = attn.to_add_out(context_attn_output)
# 11. Apply norm and feed forward
patch_hidden = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=patch_hidden,
attn_output=hidden_attn_output,
gate_mlp=gate_mlp,
gate_msa=gate_msa,
scale_mlp=scale_mlp,
shift_mlp=shift_mlp,
norm_layer=block.norm2,
ff_layer=block.ff,
)
encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=encoder_hidden_states,
attn_output=context_attn_output,
gate_mlp=c_gate_mlp,
gate_msa=c_gate_msa,
scale_mlp=c_scale_mlp,
shift_mlp=c_shift_mlp,
norm_layer=block.norm2_context,
ff_layer=block.ff_context,
)
return encoder_hidden_states, patch_hidden
def _apply_single_block_caching(
self,
block: SingleBlockInterface,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache | None,
text_seq_len: int,
) -> mx.array:
total_seq_len = hidden_states.shape[1]
num_img_tokens = total_seq_len - text_seq_len
batch_size = hidden_states.shape[0]
attn = block.attn
num_heads = attn.num_heads
head_dim = attn.head_dimension
# Residual connection
residual = hidden_states
# 1. Compute norm
norm_hidden, gate = block.norm(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
)
# 2. Compute Q, K, V
query, key, value = AttentionUtils.process_qkv(
hidden_states=norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=num_heads,
head_dim=head_dim,
)
# 3. Apply RoPE
query, key = AttentionUtils.apply_rope(
xq=query, xk=key, freqs_cis=rotary_embeddings
)
# 4. Store IMAGE K/V in cache
if kv_cache is not None:
kv_cache.update_image_patch(
patch_start=0,
patch_end=num_img_tokens,
key=key[:, :, text_seq_len:, :],
value=value[:, :, text_seq_len:, :],
)
# 5. Compute attention
attn_output = AttentionUtils.compute_attention(
query=query,
key=key,
value=value,
batch_size=batch_size,
num_heads=num_heads,
head_dim=head_dim,
)
# 6. Apply feed forward and projection
hidden_states = block._apply_feed_forward_and_projection(
norm_hidden_states=norm_hidden,
attn_output=attn_output,
gate=gate,
)
return residual + hidden_states
def _apply_single_block_patched(
self,
block: SingleBlockInterface,
patch_hidden: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache,
text_seq_len: int,
patch_start: int,
patch_end: int,
) -> mx.array:
batch_size = patch_hidden.shape[0]
attn = block.attn
num_heads = attn.num_heads
head_dim = attn.head_dimension
# Residual connection
residual = patch_hidden
# 1. Compute norm
norm_hidden, gate = block.norm(
hidden_states=patch_hidden,
text_embeddings=text_embeddings,
)
# 2. Compute Q, K, V
query, key, value = AttentionUtils.process_qkv(
hidden_states=norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=num_heads,
head_dim=head_dim,
)
# 3. Extract RoPE for [text + current_patch]
text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
patch_img_rope = rotary_embeddings[
:, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
]
patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
# 4. Apply RoPE
query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=patch_rope)
# 5. Update cache with this patch's IMAGE K/V
kv_cache.update_image_patch(
patch_start=patch_start,
patch_end=patch_end,
key=key[:, :, text_seq_len:, :],
value=value[:, :, text_seq_len:, :],
)
# 6. Get full K, V from cache
full_key, full_value = kv_cache.get_full_kv(
text_key=key[:, :, :text_seq_len, :],
text_value=value[:, :, :text_seq_len, :],
)
# 7. Compute attention
attn_output = AttentionUtils.compute_attention(
query=query,
key=full_key,
value=full_value,
batch_size=batch_size,
num_heads=num_heads,
head_dim=head_dim,
)
# 8. Apply feed forward and projection
hidden_states = block._apply_feed_forward_and_projection(
norm_hidden_states=norm_hidden,
attn_output=attn_output,
gate=gate,
)
return residual + hidden_states

View File

@@ -0,0 +1,48 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
FLUX_SCHNELL_CONFIG = ImageModelConfig(
model_family="flux",
model_variant="schnell",
hidden_dim=3072,
num_heads=24,
head_dim=128,
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
),
TransformerBlockConfig(
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
),
),
patch_size=2,
vae_scale_factor=8,
default_steps={"low": 1, "medium": 2, "high": 4},
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
uses_attention_mask=False,
)
FLUX_DEV_CONFIG = ImageModelConfig(
model_family="flux",
model_variant="dev",
hidden_dim=3072,
num_heads=24,
head_dim=128,
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
),
TransformerBlockConfig(
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
),
),
patch_size=2,
vae_scale_factor=8,
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ceil(25 * 0.125) = 4 sync steps for medium (25 steps)
uses_attention_mask=False,
)

View File

@@ -0,0 +1,13 @@
from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter
from exo.worker.engines.image.models.qwen.config import (
QWEN_IMAGE_CONFIG,
QWEN_IMAGE_EDIT_CONFIG,
)
from exo.worker.engines.image.models.qwen.edit_adapter import QwenEditModelAdapter
__all__ = [
"QwenModelAdapter",
"QwenEditModelAdapter",
"QWEN_IMAGE_CONFIG",
"QWEN_IMAGE_EDIT_CONFIG",
]

View File

@@ -0,0 +1,523 @@
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
from mflux.models.common.config import ModelConfig
from mflux.models.common.config.config import Config
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_text_encoder.qwen_prompt_encoder import (
QwenPromptEncoder,
)
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
QwenTransformerBlock,
)
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class QwenPromptData:
"""Container for Qwen prompt encoding results.
Implements PromptData protocol with additional Qwen-specific attributes.
"""
def __init__(
self,
prompt_embeds: mx.array,
prompt_mask: mx.array,
negative_prompt_embeds: mx.array,
negative_prompt_mask: mx.array,
):
self._prompt_embeds = prompt_embeds
self.prompt_mask = prompt_mask
self._negative_prompt_embeds = negative_prompt_embeds
self.negative_prompt_mask = negative_prompt_mask
@property
def prompt_embeds(self) -> mx.array:
"""Text embeddings from encoder."""
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
return self._prompt_embeds # Use prompt_embeds as placeholder
@property
def negative_prompt_embeds(self) -> mx.array:
"""Negative prompt embeddings for CFG."""
return self._negative_prompt_embeds
@property
def negative_pooled_prompt_embeds(self) -> mx.array:
"""Placeholder - Qwen doesn't use pooled embeds."""
return self._negative_prompt_embeds
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
"""Return encoder_hidden_states_mask for the appropriate prompt."""
if positive:
return {"encoder_hidden_states_mask": self.prompt_mask}
else:
return {"encoder_hidden_states_mask": self.negative_prompt_mask}
@property
def conditioning_latents(self) -> mx.array | None:
"""Standard Qwen does not use conditioning latents."""
return None
class QwenModelAdapter(BaseModelAdapter):
"""Adapter for Qwen-Image model.
Key differences from Flux:
- Single text encoder (vs dual T5+CLIP)
- 60 joint-style blocks, no single blocks
- 3D RoPE returning ((img_cos, img_sin), (txt_cos, txt_sin))
- Norm-preserving CFG with negative prompts
- Uses attention mask for variable-length text
"""
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = QwenImage(
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
model_path=str(local_path),
quantize=quantize,
)
self._transformer = self._model.transformer
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def model(self) -> QwenImage:
return self._model
@property
def transformer(self) -> QwenTransformer:
return self._transformer
@property
def hidden_dim(self) -> int:
return self._transformer.inner_dim
def _get_latent_creator(self) -> type:
return QwenLatentCreator
def encode_prompt(self, prompt: str) -> QwenPromptData:
"""Encode prompt into QwenPromptData.
Qwen uses classifier-free guidance with explicit negative prompts.
Returns a QwenPromptData container with all 4 tensors.
"""
assert isinstance(self.model.prompt_cache, dict)
assert isinstance(self.model.tokenizers, dict)
# TODO(ciaran): empty string as default negative prompt
negative_prompt = ""
prompt_embeds, prompt_mask, neg_embeds, neg_mask = (
QwenPromptEncoder.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
prompt_cache=self.model.prompt_cache,
qwen_tokenizer=self.model.tokenizers["qwen"],
qwen_text_encoder=self.model.text_encoder,
)
)
return QwenPromptData(
prompt_embeds=prompt_embeds,
prompt_mask=prompt_mask,
negative_prompt_embeds=neg_embeds,
negative_prompt_mask=neg_mask,
)
@property
def needs_cfg(self) -> bool:
gs = self._config.guidance_scale
return gs is not None and gs > 1.0
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
return self._model.compute_guided_noise(
noise=noise_positive,
noise_negative=noise_negative,
guidance=guidance_scale,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
"""Compute image and text embeddings."""
# Image embedding
embedded_hidden = self._transformer.img_in(hidden_states)
# Text embedding: first normalize, then project
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
"""Compute time/text embeddings.
For Qwen, the time_text_embed only uses hidden_states for:
- batch_size (shape[0])
- dtype
This allows us to pass any tensor (latents, prompt_embeds) as a fallback
when embedded hidden_states are not yet available.
"""
# Use hidden_states if provided, otherwise fall back to pooled_prompt_embeds
# (which for Qwen is the same as prompt_embeds)
ref_tensor = (
hidden_states if hidden_states is not None else pooled_prompt_embeds
)
if ref_tensor is None:
raise ValueError(
"Either hidden_states or pooled_prompt_embeds is required "
"for Qwen text embeddings"
)
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
batch_size = ref_tensor.shape[0]
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
return self._transformer.time_text_embed(timestep, ref_tensor)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
**kwargs: Any,
) -> Any:
"""Compute 3D rotary embeddings for Qwen.
Qwen uses video-aware 3D RoPE with separate embeddings for image and text.
Returns:
tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]:
((img_cos, img_sin), (txt_cos, txt_sin))
"""
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
cond_image_grid = kwargs.get("cond_image_grid")
if encoder_hidden_states_mask is None:
raise ValueError(
"encoder_hidden_states_mask is required for Qwen RoPE computation"
)
return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
encoder_hidden_states_mask=encoder_hidden_states_mask,
pos_embed=self._transformer.pos_embed,
config=runtime_config,
cond_image_grid=cond_image_grid,
)
def apply_joint_block(
self,
block: JointBlockInterface,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any, # tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]] for Qwen
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
**kwargs: Any,
) -> tuple[mx.array, mx.array]:
"""Apply Qwen joint block.
For caching mode, we run the full block and optionally populate the KV cache.
For patched mode, we compute fresh K/V for the current patch and reuse cached image K/V for the rest of the image.
"""
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
block_idx = kwargs.get("block_idx")
if mode == BlockWrapperMode.CACHING:
return self._apply_joint_block_caching(
block=block,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
encoder_hidden_states_mask=encoder_hidden_states_mask,
block_idx=block_idx,
)
else:
# mode == BlockWrapperMode.PATCHED
assert patch_start is not None and patch_end is not None
assert kv_cache is not None
return self._apply_joint_block_patched(
block=block,
patch_hidden=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
patch_start=patch_start,
patch_end=patch_end,
encoder_hidden_states_mask=encoder_hidden_states_mask,
block_idx=block_idx,
)
def apply_single_block(
self,
block: SingleBlockInterface,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
) -> mx.array:
"""Qwen has no single blocks."""
raise NotImplementedError("Qwen does not have single blocks")
def final_projection(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
"""Apply final normalization and projection."""
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
return self._transformer.proj_out(hidden_states)
def get_joint_blocks(self) -> list[JointBlockInterface]:
"""Return all 60 transformer blocks."""
return cast(
list[JointBlockInterface], list(self._transformer.transformer_blocks)
)
def get_single_blocks(self) -> list[SingleBlockInterface]:
"""Qwen has no single blocks."""
return []
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
total_joint_blocks: int,
total_single_blocks: int,
) -> None:
all_blocks = list(self._transformer.transformer_blocks)
assigned_blocks = all_blocks[start_layer:end_layer]
self._transformer.transformer_blocks = assigned_blocks
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
"""Merge image and text streams.
For Qwen, this is called before final projection.
The streams remain separate through all blocks.
"""
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
def _apply_joint_block_caching(
self,
block: Any, # QwenTransformerBlock
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
kv_cache: ImagePatchKVCache | None,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
block_idx: int | None = None,
) -> tuple[mx.array, mx.array]:
"""Apply joint block in caching mode (full attention, optionally populate cache).
Delegates to the QwenTransformerBlock's forward pass.
"""
# Call the block directly - it handles all the modulation and attention internally
return block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
text_embeddings=text_embeddings,
image_rotary_emb=rotary_embeddings,
block_idx=block_idx,
)
def _apply_joint_block_patched(
self,
block: Any, # QwenTransformerBlock
patch_hidden: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
kv_cache: ImagePatchKVCache,
text_seq_len: int,
patch_start: int,
patch_end: int,
encoder_hidden_states_mask: mx.array | None = None,
block_idx: int | None = None,
) -> tuple[mx.array, mx.array]:
batch_size = patch_hidden.shape[0]
attn = block.attn
num_heads = attn.num_heads
head_dim = attn.head_dim
# 1. Compute modulation parameters
img_mod_params = block.img_mod_linear(block.img_mod_silu(text_embeddings))
txt_mod_params = block.txt_mod_linear(block.txt_mod_silu(text_embeddings))
img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
# 2. Apply normalization and modulation
img_normed = block.img_norm1(patch_hidden)
img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)
txt_normed = block.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)
# 3. Compute Q, K, V for image patch
img_query = attn.to_q(img_modulated)
img_key = attn.to_k(img_modulated)
img_value = attn.to_v(img_modulated)
# 4. Compute Q, K, V for text
txt_query = attn.add_q_proj(txt_modulated)
txt_key = attn.add_k_proj(txt_modulated)
txt_value = attn.add_v_proj(txt_modulated)
# 5. Reshape to [B, S, H, D]
patch_len = patch_hidden.shape[1]
img_query = mx.reshape(img_query, (batch_size, patch_len, num_heads, head_dim))
img_key = mx.reshape(img_key, (batch_size, patch_len, num_heads, head_dim))
img_value = mx.reshape(img_value, (batch_size, patch_len, num_heads, head_dim))
txt_query = mx.reshape(
txt_query, (batch_size, text_seq_len, num_heads, head_dim)
)
txt_key = mx.reshape(txt_key, (batch_size, text_seq_len, num_heads, head_dim))
txt_value = mx.reshape(
txt_value, (batch_size, text_seq_len, num_heads, head_dim)
)
# 6. Apply RMSNorm to Q, K
img_query = attn.norm_q(img_query)
img_key = attn.norm_k(img_key)
txt_query = attn.norm_added_q(txt_query)
txt_key = attn.norm_added_k(txt_key)
# 7. Extract RoPE for patch: slice image RoPE, keep full text RoPE
(img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
patch_img_cos = img_cos[patch_start:patch_end]
patch_img_sin = img_sin[patch_start:patch_end]
# 8. Apply RoPE to Q, K
img_query = QwenAttention._apply_rope_qwen(
img_query, patch_img_cos, patch_img_sin
)
img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
# 9. Transpose to [B, H, S, D] for cache operations
img_key_bhsd = mx.transpose(img_key, (0, 2, 1, 3))
img_value_bhsd = mx.transpose(img_value, (0, 2, 1, 3))
# 10. Update cache with this patch's IMAGE K/V
kv_cache.update_image_patch(
patch_start=patch_start,
patch_end=patch_end,
key=img_key_bhsd,
value=img_value_bhsd,
)
# 11. Get full K, V from cache (text + full image)
txt_key_bhsd = mx.transpose(txt_key, (0, 2, 1, 3))
txt_value_bhsd = mx.transpose(txt_value, (0, 2, 1, 3))
full_key, full_value = kv_cache.get_full_kv(
text_key=txt_key_bhsd,
text_value=txt_value_bhsd,
)
# 12. Build query: [text, patch]
joint_query = mx.concatenate([txt_query, img_query], axis=1)
# 13. Build attention mask for [text + patch] query attending to [text + full_image] KV
mask = QwenAttention._convert_mask_for_qwen(
mask=encoder_hidden_states_mask,
joint_seq_len=full_key.shape[2], # text + full_image
txt_seq_len=text_seq_len,
)
# 14. Compute attention
hidden_states = attn._compute_attention_qwen(
query=joint_query,
key=mx.transpose(full_key, (0, 2, 1, 3)), # Back to [B, S, H, D]
value=mx.transpose(full_value, (0, 2, 1, 3)),
mask=mask,
block_idx=block_idx,
)
# 15. Extract text and image attention outputs
txt_attn_output = hidden_states[:, :text_seq_len, :]
img_attn_output = hidden_states[:, text_seq_len:, :]
# 16. Project outputs
img_attn_output = attn.attn_to_out[0](img_attn_output)
txt_attn_output = attn.to_add_out(txt_attn_output)
# 17. Apply residual + gate for attention
patch_hidden = patch_hidden + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
# 18. Apply feed-forward for image
img_normed2 = block.img_norm2(patch_hidden)
img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
img_normed2, img_mod2
)
img_mlp_output = block.img_ff(img_modulated2)
patch_hidden = patch_hidden + img_gate2 * img_mlp_output
# 19. Apply feed-forward for text
txt_normed2 = block.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
txt_normed2, txt_mod2
)
txt_mlp_output = block.txt_ff(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
return encoder_hidden_states, patch_hidden

View File

@@ -0,0 +1,49 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
# Qwen-Image has 60 joint-style blocks (no single blocks)
# Architecture: 24 heads * 128 dim = 3072 hidden dim
# VAE uses scale factor of 16 (vs Flux's 8)
QWEN_IMAGE_CONFIG = ImageModelConfig(
model_family="qwen",
model_variant="image",
hidden_dim=3072,
num_heads=24,
head_dim=128,
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=60, has_separate_text_output=True
),
# Qwen has no single blocks - all blocks process image and text separately
),
patch_size=2,
vae_scale_factor=16,
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
uses_attention_mask=True, # Qwen uses encoder_hidden_states_mask
guidance_scale=None, # Set to None or <= 1.0 to disable CFG
)
# Qwen-Image-Edit uses the same architecture but different processing pipeline
# Uses vision-language encoding and conditioning latents
QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
model_family="qwen-edit",
model_variant="image-edit",
hidden_dim=3072,
num_heads=24,
head_dim=128,
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=60, has_separate_text_output=True
),
),
patch_size=2,
vae_scale_factor=16,
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125,
uses_attention_mask=True,
guidance_scale=None,
)
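# Illustrative check mirroring QwenModelAdapter.needs_cfg: CFG (and the second
# forward pass with the negative prompt) only runs when guidance_scale is set and
# greater than 1.0, so both configs above disable CFG by default.
def cfg_enabled(config: ImageModelConfig) -> bool:
    gs = config.guidance_scale
    return gs is not None and gs > 1.0
# cfg_enabled(QWEN_IMAGE_CONFIG)      -> False
# cfg_enabled(QWEN_IMAGE_EDIT_CONFIG) -> False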

View File

@@ -0,0 +1,648 @@
import math
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
QwenTransformerBlock,
)
from mflux.models.qwen.variants.edit.qwen_edit_util import QwenEditUtil
from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class QwenEditPromptData:
"""Container for Qwen edit prompt encoding results.
Includes vision-language encoded embeddings and edit-specific conditioning.
"""
def __init__(
self,
prompt_embeds: mx.array,
prompt_mask: mx.array,
negative_prompt_embeds: mx.array,
negative_prompt_mask: mx.array,
conditioning_latents: mx.array,
qwen_image_ids: mx.array,
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]],
):
self._prompt_embeds = prompt_embeds
self.prompt_mask = prompt_mask
self._negative_prompt_embeds = negative_prompt_embeds
self.negative_prompt_mask = negative_prompt_mask
self._conditioning_latents = conditioning_latents
self._qwen_image_ids = qwen_image_ids
self._cond_image_grid = cond_image_grid
@property
def prompt_embeds(self) -> mx.array:
"""Text embeddings from vision-language encoder."""
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
return self._prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array:
"""Negative prompt embeddings for CFG."""
return self._negative_prompt_embeds
@property
def negative_pooled_prompt_embeds(self) -> mx.array:
"""Placeholder - Qwen doesn't use pooled embeds."""
return self._negative_prompt_embeds
@property
def conditioning_latents(self) -> mx.array:
"""Static image conditioning latents to concatenate with generated latents."""
return self._conditioning_latents
@property
def qwen_image_ids(self) -> mx.array:
"""Spatial position IDs for conditioning images."""
return self._qwen_image_ids
@property
def cond_image_grid(self) -> tuple[int, int, int] | list[tuple[int, int, int]]:
"""Conditioning image grid dimensions."""
return self._cond_image_grid
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
"""Return encoder_hidden_states_mask and edit-specific params."""
if positive:
return {
"encoder_hidden_states_mask": self.prompt_mask,
"qwen_image_ids": self._qwen_image_ids,
"cond_image_grid": self._cond_image_grid,
}
else:
return {
"encoder_hidden_states_mask": self.negative_prompt_mask,
"qwen_image_ids": self._qwen_image_ids,
"cond_image_grid": self._cond_image_grid,
}
@property
def is_edit_mode(self) -> bool:
"""Indicates this is edit mode with conditioning latents."""
return True
class QwenEditModelAdapter(BaseModelAdapter):
"""Adapter for Qwen-Image-Edit model.
Key differences from standard QwenModelAdapter:
- Uses QwenImageEdit model with vision-language components
- Encodes prompts WITH input images via VL tokenizer/encoder
- Creates conditioning latents from input images
- Supports image editing with concatenated latents during diffusion
"""
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = QwenImageEdit(
quantize=quantize,
model_path=str(local_path),
)
self._transformer = self._model.transformer
# Store dimensions and image paths (set via set_image_dimensions)
self._vl_width: int | None = None
self._vl_height: int | None = None
self._vae_width: int | None = None
self._vae_height: int | None = None
self._image_paths: list[str] | None = None
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def model(self) -> QwenImageEdit:
return self._model
@property
def transformer(self) -> QwenTransformer:
return self._transformer
@property
def hidden_dim(self) -> int:
return self._transformer.inner_dim
def _get_latent_creator(self) -> type:
return QwenLatentCreator
def _compute_dimensions_from_image(
self, image_path: Path
) -> tuple[int, int, int, int, int, int]:
"""Compute VL and VAE dimensions from input image.
Returns:
(vl_width, vl_height, vae_width, vae_height, output_width, output_height)
"""
from mflux.utils.image_util import ImageUtil
pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
image_size = pil_image.size
# Vision-language dimensions (384x384 target area)
condition_image_size = 384 * 384
condition_ratio = image_size[0] / image_size[1]
vl_width = math.sqrt(condition_image_size * condition_ratio)
vl_height = vl_width / condition_ratio
vl_width = round(vl_width / 32) * 32
vl_height = round(vl_height / 32) * 32
# VAE dimensions (1024x1024 target area)
vae_image_size = 1024 * 1024
vae_ratio = image_size[0] / image_size[1]
vae_width = math.sqrt(vae_image_size * vae_ratio)
vae_height = vae_width / vae_ratio
vae_width = round(vae_width / 32) * 32
vae_height = round(vae_height / 32) * 32
# Output dimensions from input image aspect ratio
target_area = 1024 * 1024
ratio = image_size[0] / image_size[1]
output_width = math.sqrt(target_area * ratio)
output_height = output_width / ratio
output_width = round(output_width / 32) * 32
output_height = round(output_height / 32) * 32
# Ensure multiple of 16 for VAE
vae_scale_factor = 8
multiple_of = vae_scale_factor * 2
output_width = output_width // multiple_of * multiple_of
output_height = output_height // multiple_of * multiple_of
return (
int(vl_width),
int(vl_height),
int(vae_width),
int(vae_height),
int(output_width),
int(output_height),
)
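# Worked example (illustrative): for a 1600x1200 input (aspect ratio 4:3), the
# rounding above gives VL dims 448x320 (384*384 target area), VAE dims 1184x896
# (1024*1024 target area), and output dims 1184x896 (already multiples of 16,
# so the final floor step leaves them unchanged).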
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
"""Create initial noise latents (pure noise for edit mode)."""
return QwenLatentCreator.create_noise(
seed=seed,
height=runtime_config.height,
width=runtime_config.width,
)
def encode_prompt(self, prompt: str) -> QwenEditPromptData:
"""Encode prompt with input images using vision-language encoder.
Uses stored image_paths from set_image_dimensions() for VL encoding.
Args:
prompt: Text prompt for editing
Returns:
QwenEditPromptData with VL embeddings and conditioning latents
"""
# Ensure image_paths and dimensions were set via set_image_dimensions()
if (
self._image_paths is None
or self._vl_height is None
or self._vl_width is None
or self._vae_height is None
or self._vae_width is None
):
raise RuntimeError(
"set_image_dimensions() must be called before encode_prompt() "
"for QwenEditModelAdapter"
)
negative_prompt = ""
image_paths = self._image_paths
# TODO(ciaran): config is untyped and unused, unsure if Config or RuntimeConfig is intended
(
prompt_embeds,
prompt_mask,
negative_prompt_embeds,
negative_prompt_mask,
) = self._model._encode_prompts_with_images(
prompt,
negative_prompt,
image_paths,
self._config,
self._vl_width,
self._vl_height,
)
(
conditioning_latents,
qwen_image_ids,
cond_h_patches,
cond_w_patches,
num_images,
) = QwenEditUtil.create_image_conditioning_latents(
vae=self._model.vae,
height=self._vae_height,
width=self._vae_width,
image_paths=image_paths,
vl_width=self._vl_width,
vl_height=self._vl_height,
)
# Build cond_image_grid
if num_images > 1:
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] = [
(1, cond_h_patches, cond_w_patches) for _ in range(num_images)
]
else:
cond_image_grid = (1, cond_h_patches, cond_w_patches)
return QwenEditPromptData(
prompt_embeds=prompt_embeds,
prompt_mask=prompt_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_mask=negative_prompt_mask,
conditioning_latents=conditioning_latents,
qwen_image_ids=qwen_image_ids,
cond_image_grid=cond_image_grid,
)
def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
"""Compute and store dimensions from input image.
Also stores image_paths for use in encode_prompt().
Returns:
(output_width, output_height) for runtime config
"""
vl_w, vl_h, vae_w, vae_h, out_w, out_h = self._compute_dimensions_from_image(
image_path
)
self._vl_width = vl_w
self._vl_height = vl_h
self._vae_width = vae_w
self._vae_height = vae_h
self._image_paths = [str(image_path)]
return out_w, out_h
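# Example call order (illustrative; paths and ids are placeholders):
#   adapter = QwenEditModelAdapter(
#       config=QWEN_IMAGE_EDIT_CONFIG, model_id="<model-id>", local_path=Path("<weights>")
#   )
#   out_w, out_h = adapter.set_image_dimensions(Path("<input-image>"))
#   prompt_data = adapter.encode_prompt("make the sky purple")
# encode_prompt() raises RuntimeError if set_image_dimensions() was not called first.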
@property
def needs_cfg(self) -> bool:
gs = self._config.guidance_scale
return gs is not None and gs > 1.0
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
return QwenImage.compute_guided_noise(
noise=noise_positive,
noise_negative=noise_negative,
guidance=guidance_scale,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
"""Compute image and text embeddings."""
embedded_hidden = self._transformer.img_in(hidden_states)
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
"""Compute time/text embeddings."""
ref_tensor = (
hidden_states if hidden_states is not None else pooled_prompt_embeds
)
if ref_tensor is None:
raise ValueError(
"Either hidden_states or pooled_prompt_embeds is required "
"for Qwen text embeddings"
)
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
batch_size = ref_tensor.shape[0]
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
return self._transformer.time_text_embed(timestep, ref_tensor)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
**kwargs: Any,
) -> Any:
"""Compute 3D rotary embeddings for Qwen edit."""
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
cond_image_grid = kwargs.get("cond_image_grid")
if encoder_hidden_states_mask is None:
raise ValueError(
"encoder_hidden_states_mask is required for Qwen RoPE computation"
)
return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
encoder_hidden_states_mask=encoder_hidden_states_mask,
pos_embed=self._transformer.pos_embed,
config=runtime_config,
cond_image_grid=cond_image_grid,
)
def apply_joint_block(
self,
block: JointBlockInterface,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
**kwargs: Any,
) -> tuple[mx.array, mx.array]:
"""Apply Qwen joint block."""
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
block_idx = kwargs.get("block_idx")
if mode == BlockWrapperMode.CACHING:
return self._apply_joint_block_caching(
block=block,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
encoder_hidden_states_mask=encoder_hidden_states_mask,
block_idx=block_idx,
)
else:
assert patch_start is not None and patch_end is not None
assert kv_cache is not None
return self._apply_joint_block_patched(
block=block,
patch_hidden=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
text_seq_len=text_seq_len,
patch_start=patch_start,
patch_end=patch_end,
encoder_hidden_states_mask=encoder_hidden_states_mask,
block_idx=block_idx,
)
def apply_single_block(
self,
block: SingleBlockInterface,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
) -> mx.array:
"""Qwen has no single blocks."""
raise NotImplementedError("Qwen does not have single blocks")
def final_projection(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
"""Apply final normalization and projection."""
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
return self._transformer.proj_out(hidden_states)
def get_joint_blocks(self) -> list[JointBlockInterface]:
"""Return all 60 transformer blocks."""
return cast(
list[JointBlockInterface], list(self._transformer.transformer_blocks)
)
def get_single_blocks(self) -> list[SingleBlockInterface]:
"""Qwen has no single blocks."""
return []
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
total_joint_blocks: int,
total_single_blocks: int,
) -> None:
all_blocks = list(self._transformer.transformer_blocks)
assigned_blocks = all_blocks[start_layer:end_layer]
self._transformer.transformer_blocks = assigned_blocks
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
"""Merge image and text streams."""
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
def _apply_joint_block_caching(
self,
block: Any,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
kv_cache: ImagePatchKVCache | None,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
block_idx: int | None = None,
) -> tuple[mx.array, mx.array]:
"""Apply joint block in caching mode."""
return block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
text_embeddings=text_embeddings,
image_rotary_emb=rotary_embeddings,
block_idx=block_idx,
)
def _apply_joint_block_patched(
self,
block: Any,
patch_hidden: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
kv_cache: ImagePatchKVCache,
text_seq_len: int,
patch_start: int,
patch_end: int,
encoder_hidden_states_mask: mx.array | None = None,
block_idx: int | None = None,
) -> tuple[mx.array, mx.array]:
batch_size = patch_hidden.shape[0]
attn = block.attn
num_heads = attn.num_heads
head_dim = attn.head_dim
# Modulation parameters
img_mod_params = block.img_mod_linear(block.img_mod_silu(text_embeddings))
txt_mod_params = block.txt_mod_linear(block.txt_mod_silu(text_embeddings))
img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
# Normalization and modulation
img_normed = block.img_norm1(patch_hidden)
img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)
txt_normed = block.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)
# Q, K, V for image patch
img_query = attn.to_q(img_modulated)
img_key = attn.to_k(img_modulated)
img_value = attn.to_v(img_modulated)
# Q, K, V for text
txt_query = attn.add_q_proj(txt_modulated)
txt_key = attn.add_k_proj(txt_modulated)
txt_value = attn.add_v_proj(txt_modulated)
# Reshape to [B, S, H, D]
patch_len = patch_hidden.shape[1]
img_query = mx.reshape(img_query, (batch_size, patch_len, num_heads, head_dim))
img_key = mx.reshape(img_key, (batch_size, patch_len, num_heads, head_dim))
img_value = mx.reshape(img_value, (batch_size, patch_len, num_heads, head_dim))
txt_query = mx.reshape(
txt_query, (batch_size, text_seq_len, num_heads, head_dim)
)
txt_key = mx.reshape(txt_key, (batch_size, text_seq_len, num_heads, head_dim))
txt_value = mx.reshape(
txt_value, (batch_size, text_seq_len, num_heads, head_dim)
)
# RMSNorm to Q, K
img_query = attn.norm_q(img_query)
img_key = attn.norm_k(img_key)
txt_query = attn.norm_added_q(txt_query)
txt_key = attn.norm_added_k(txt_key)
# Extract RoPE for patch
(img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
patch_img_cos = img_cos[patch_start:patch_end]
patch_img_sin = img_sin[patch_start:patch_end]
# Apply RoPE
img_query = QwenAttention._apply_rope_qwen(
img_query, patch_img_cos, patch_img_sin
)
img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
# Transpose to [B, H, S, D]
img_key_bhsd = mx.transpose(img_key, (0, 2, 1, 3))
img_value_bhsd = mx.transpose(img_value, (0, 2, 1, 3))
# Update cache
kv_cache.update_image_patch(
patch_start=patch_start,
patch_end=patch_end,
key=img_key_bhsd,
value=img_value_bhsd,
)
# Get full K, V from cache
txt_key_bhsd = mx.transpose(txt_key, (0, 2, 1, 3))
txt_value_bhsd = mx.transpose(txt_value, (0, 2, 1, 3))
full_key, full_value = kv_cache.get_full_kv(
text_key=txt_key_bhsd,
text_value=txt_value_bhsd,
)
# Build query
joint_query = mx.concatenate([txt_query, img_query], axis=1)
# Build attention mask
mask = QwenAttention._convert_mask_for_qwen(
mask=encoder_hidden_states_mask,
joint_seq_len=full_key.shape[2],
txt_seq_len=text_seq_len,
)
# Compute attention
hidden_states = attn._compute_attention_qwen(
query=joint_query,
key=mx.transpose(full_key, (0, 2, 1, 3)),
value=mx.transpose(full_value, (0, 2, 1, 3)),
mask=mask,
block_idx=block_idx,
)
# Extract outputs
txt_attn_output = hidden_states[:, :text_seq_len, :]
img_attn_output = hidden_states[:, text_seq_len:, :]
# Project
img_attn_output = attn.attn_to_out[0](img_attn_output)
txt_attn_output = attn.to_add_out(txt_attn_output)
# Residual + gate
patch_hidden = patch_hidden + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
# Feed-forward for image
img_normed2 = block.img_norm2(patch_hidden)
img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
img_normed2, img_mod2
)
img_mlp_output = block.img_ff(img_modulated2)
patch_hidden = patch_hidden + img_gate2 * img_mlp_output
# Feed-forward for text
txt_normed2 = block.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
txt_normed2, txt_mod2
)
txt_mlp_output = block.txt_ff(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
return encoder_hidden_states, patch_hidden

View File

@@ -0,0 +1,23 @@
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
ModelAdapter,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
from exo.worker.engines.image.pipeline.runner import DiffusionRunner
__all__ = [
"BlockWrapperMode",
"DiffusionRunner",
"ImagePatchKVCache",
"JointBlockInterface",
"JointBlockWrapper",
"ModelAdapter",
"SingleBlockInterface",
"SingleBlockWrapper",
]

View File

@@ -0,0 +1,402 @@
from enum import Enum
from pathlib import Path
from typing import Any, Protocol
import mlx.core as mx
from mflux.models.common.config.config import Config
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class AttentionInterface(Protocol):
num_heads: int
head_dimension: int
to_q: Any
to_k: Any
to_v: Any
norm_q: Any
norm_k: Any
to_out: list[Any]
class JointAttentionInterface(AttentionInterface, Protocol):
add_q_proj: Any
add_k_proj: Any
add_v_proj: Any
norm_added_q: Any
norm_added_k: Any
to_add_out: Any
class JointBlockInterface(Protocol):
attn: JointAttentionInterface
norm1: Any # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
norm1_context: (
Any # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
)
norm2: Any
norm2_context: Any
ff: Any
ff_context: Any
class SingleBlockInterface(Protocol):
attn: AttentionInterface
norm: Any # Callable module: (hidden_states, text_embeddings) -> tuple[2 arrays]
def _apply_feed_forward_and_projection(
self, norm_hidden_states: mx.array, attn_output: mx.array, gate: mx.array
) -> mx.array:
"""Apply feed forward network and projection."""
...
class BlockWrapperMode(Enum):
CACHING = "caching" # Sync mode: compute full attention, populate cache
PATCHED = "patched" # Async mode: compute patch attention, use cached KV
class PromptData(Protocol):
"""Protocol for encoded prompt data.
All adapters must return prompt data that conforms to this protocol.
Model-specific prompt data classes can add additional attributes
(e.g., attention masks for Qwen).
"""
@property
def prompt_embeds(self) -> mx.array:
"""Text embeddings from encoder."""
...
@property
def pooled_prompt_embeds(self) -> mx.array:
"""Pooled text embeddings (for Flux) or placeholder (for Qwen)."""
...
@property
def negative_prompt_embeds(self) -> mx.array | None:
"""Negative prompt embeddings for CFG (None if not using CFG)."""
...
@property
def negative_pooled_prompt_embeds(self) -> mx.array | None:
"""Negative pooled embeddings for CFG (None if not using CFG)."""
...
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
"""Return model-specific kwargs for forward pass.
Args:
positive: If True, return kwargs for positive prompt pass.
If False, return kwargs for negative prompt pass.
Returns:
Dict of extra kwargs (e.g., {"encoder_hidden_states_mask": ...} for Qwen)
"""
...
@property
def conditioning_latents(self) -> mx.array | None:
"""Conditioning latents for edit mode.
Returns:
Conditioning latents array for image editing, None for standard generation.
"""
...
class ModelAdapter(Protocol):
@property
def config(self) -> ImageModelConfig:
"""Return the model configuration."""
...
@property
def model(self) -> Any:
"""Return the underlying mflux model instance (e.g., Flux1, Fibo, Qwen)."""
...
@property
def transformer(self) -> Any:
"""Return the transformer component of the model."""
...
@property
def hidden_dim(self) -> int:
"""Return the hidden dimension of the transformer."""
...
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
"""Compute x_embedder and context_embedder outputs.
Args:
hidden_states: Input latent states
prompt_embeds: Text embeddings from encoder
Returns:
Tuple of (embedded_hidden_states, embedded_encoder_states)
"""
...
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
"""Compute time/text embeddings for conditioning.
Args:
t: Current timestep
runtime_config: Runtime configuration
pooled_prompt_embeds: Pooled text embeddings (used by Flux)
hidden_states: Image hidden states
Returns:
Text embeddings tensor
"""
...
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
**kwargs: Any,
) -> Any:
"""Compute rotary position embeddings.
Args:
prompt_embeds: Text embeddings
runtime_config: Runtime configuration
**kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask for Qwen)
Returns:
Flux: mx.array
Qwen: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]
"""
...
def apply_joint_block(
self,
block: JointBlockInterface,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any, # Format varies: mx.array (Flux) or nested tuple (Qwen)
kv_cache: ImagePatchKVCache | None,
mode: "BlockWrapperMode",
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
**kwargs: Any,
) -> tuple[mx.array, mx.array]:
"""Apply a joint transformer block.
Args:
block: The joint transformer block
hidden_states: Image hidden states
encoder_hidden_states: Text hidden states
text_embeddings: Conditioning embeddings
rotary_embeddings: Rotary position embeddings (format varies by model)
kv_cache: KV cache (None if not using cache)
mode: CACHING or PATCHED mode
text_seq_len: Text sequence length
patch_start: Start index for patched mode
patch_end: End index for patched mode
**kwargs: Additional model-specific arguments (e.g., encoder_hidden_states_mask,
block_idx for Qwen)
Returns:
Tuple of (encoder_hidden_states, hidden_states)
"""
...
def apply_single_block(
self,
block: SingleBlockInterface,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
kv_cache: ImagePatchKVCache | None,
mode: "BlockWrapperMode",
text_seq_len: int,
patch_start: int | None = None,
patch_end: int | None = None,
) -> mx.array:
"""Apply a single transformer block.
Args:
block: The single transformer block
hidden_states: Concatenated [text + image] hidden states
text_embeddings: Conditioning embeddings
rotary_embeddings: Rotary position embeddings
kv_cache: KV cache (None if not using cache)
mode: CACHING or PATCHED mode
text_seq_len: Text sequence length
patch_start: Start index for patched mode
patch_end: End index for patched mode
Returns:
Output hidden states
"""
...
def final_projection(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
"""Apply final norm and projection.
Args:
hidden_states: Hidden states (image only, text already removed)
text_embeddings: Conditioning embeddings
Returns:
Projected output
"""
...
def get_joint_blocks(self) -> list[JointBlockInterface]:
"""Get the list of joint transformer blocks from the model."""
...
def get_single_blocks(self) -> list[SingleBlockInterface]:
"""Get the list of single transformer blocks from the model."""
...
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
total_joint_blocks: int,
total_single_blocks: int,
):
"""Remove transformer blocks outside the assigned range.
This should be called BEFORE mx.eval() to avoid loading unused weights
in distributed mode.
Args:
start_layer: First layer index (inclusive) assigned to this node
end_layer: Last layer index (exclusive) assigned to this node
total_joint_blocks: Total number of joint blocks in the model
total_single_blocks: Total number of single blocks in the model
"""
...
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
"""Merge image and text streams for transition to single blocks.
This is called at the transition point from joint blocks (which process
image and text separately) to single blocks (which process them
together). Override to customize the merge strategy.
Args:
hidden_states: Image hidden states
encoder_hidden_states: Text hidden states
Returns:
Merged hidden states (default: concatenate [text, image])
"""
...
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
"""Create initial noise latents for generation.
Args:
seed: Random seed
runtime_config: Runtime configuration with dimensions
Returns:
Initial latent tensor
"""
...
def encode_prompt(self, prompt: str) -> PromptData:
"""Encode prompt into model-specific prompt data.
Args:
prompt: Text prompt
Returns:
PromptData containing embeddings (and model-specific extras)
"""
...
@property
def needs_cfg(self) -> bool:
"""Whether this model uses classifier-free guidance.
Returns:
True if model requires two forward passes with guidance (e.g., Qwen)
False if model uses a single forward pass (e.g., Flux)
"""
...
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
"""Apply classifier-free guidance to combine positive/negative predictions.
Only called when needs_cfg is True.
Args:
noise_positive: Noise prediction from positive prompt
noise_negative: Noise prediction from negative prompt
guidance_scale: Guidance strength
Returns:
Guided noise prediction
"""
...
def decode_latents(
self,
latents: mx.array,
runtime_config: Config,
seed: int,
prompt: str,
) -> Any:
"""Decode latents to final image.
Args:
latents: Final denoised latents
runtime_config: Runtime configuration
seed: Random seed (for metadata)
prompt: Text prompt (for metadata)
Returns:
GeneratedImage result
"""
...
def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
"""Compute and store dimensions from input image for edit mode.
For edit adapters: computes dimensions from input image aspect ratio,
stores image paths internally for encode_prompt(), returns (width, height).
For standard adapters: returns None (use user-specified dimensions).
Args:
image_path: Path to the input image
Returns:
Tuple of (width, height) if dimensions were computed, None otherwise.
"""
...

View File

@@ -0,0 +1,146 @@
from typing import Any
import mlx.core as mx
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
ModelAdapter,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class JointBlockWrapper:
"""Unified wrapper for joint transformer blocks.
Handles both CACHING (sync) and PATCHED (async) modes by delegating
to the model adapter for model-specific attention computation.
The wrapper is created once at initialization and reused across calls.
Mode and KV cache are passed at call time to support switching between
sync and async pipelines.
"""
def __init__(
self,
block: JointBlockInterface,
adapter: ModelAdapter,
):
"""Initialize the joint block wrapper.
Args:
block: The joint transformer block to wrap
adapter: Model adapter for model-specific operations
"""
self.block = block
self.adapter = adapter
def __call__(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
text_seq_len: int,
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
patch_start: int | None = None,
patch_end: int | None = None,
**kwargs: Any,
) -> tuple[mx.array, mx.array]:
"""Apply the joint block.
Args:
hidden_states: Image hidden states (full or patch depending on mode)
encoder_hidden_states: Text hidden states
text_embeddings: Conditioning embeddings
rotary_embeddings: Rotary position embeddings
text_seq_len: Text sequence length
kv_cache: KV cache for storing/retrieving image K/V (None if not using cache)
mode: CACHING (populate cache) or PATCHED (use cached K/V)
patch_start: Start index for patched mode (required if mode=PATCHED)
patch_end: End index for patched mode (required if mode=PATCHED)
**kwargs: Additional model-specific arguments (e.g., encoder_hidden_states_mask,
block_idx for Qwen)
Returns:
Tuple of (encoder_hidden_states, hidden_states)
"""
return self.adapter.apply_joint_block(
block=self.block,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
mode=mode,
text_seq_len=text_seq_len,
patch_start=patch_start,
patch_end=patch_end,
**kwargs,
)
class SingleBlockWrapper:
"""Unified wrapper for single transformer blocks.
Handles both CACHING (sync) and PATCHED (async) modes by delegating
to the model adapter for model-specific attention computation.
The wrapper is created once at initialization and reused across calls.
Mode and KV cache are passed at call time to support switching between
sync and async pipelines.
"""
def __init__(
self,
block: SingleBlockInterface,
adapter: ModelAdapter,
):
"""Initialize the single block wrapper.
Args:
block: The single transformer block to wrap
adapter: Model adapter for model-specific operations
"""
self.block = block
self.adapter = adapter
def __call__(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
text_seq_len: int,
kv_cache: ImagePatchKVCache | None,
mode: BlockWrapperMode,
patch_start: int | None = None,
patch_end: int | None = None,
) -> mx.array:
"""Apply the single block.
Args:
hidden_states: [text + image] hidden states (full or patch depending on mode)
text_embeddings: Conditioning embeddings
rotary_embeddings: Rotary position embeddings
text_seq_len: Text sequence length
kv_cache: KV cache for storing/retrieving image K/V (None if not using cache)
mode: CACHING (populate cache) or PATCHED (use cached K/V)
patch_start: Start index for patched mode (required if mode=PATCHED)
patch_end: End index for patched mode (required if mode=PATCHED)
Returns:
Output hidden states
"""
return self.adapter.apply_single_block(
block=self.block,
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
kv_cache=kv_cache,
mode=mode,
text_seq_len=text_seq_len,
patch_start=patch_start,
patch_end=patch_end,
)
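# Illustrative use of a joint block wrapper (names are placeholders; the
# DiffusionRunner owns the real call sites):
#   # sync step: full image, CACHING mode
#   enc, hid = joint_wrapper(
#       hidden_states=hid, encoder_hidden_states=enc, text_embeddings=temb,
#       rotary_embeddings=rope, text_seq_len=txt_len, kv_cache=cache,
#       mode=BlockWrapperMode.CACHING,
#   )
#   # async step: one patch, PATCHED mode reusing cached image K/V
#   enc, patch = joint_wrapper(
#       hidden_states=patch, encoder_hidden_states=enc, text_embeddings=temb,
#       rotary_embeddings=rope, text_seq_len=txt_len, kv_cache=cache,
#       mode=BlockWrapperMode.PATCHED, patch_start=start, patch_end=end,
#   )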

View File

@@ -0,0 +1,72 @@
import mlx.core as mx
class ImagePatchKVCache:
"""KV cache that stores only IMAGE K/V with patch-level updates.
Only caches image K/V since:
- Text K/V is always computed fresh (same for all patches)
- Only image portion needs stale/fresh cache management across patches
"""
def __init__(
self,
batch_size: int,
num_heads: int,
image_seq_len: int,
head_dim: int,
dtype: mx.Dtype = mx.float32,
):
self.batch_size = batch_size
self.num_heads = num_heads
self.image_seq_len = image_seq_len
self.head_dim = head_dim
self._dtype = dtype
self.key_cache = mx.zeros(
(batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
)
self.value_cache = mx.zeros(
(batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
)
def update_image_patch(
self, patch_start: int, patch_end: int, key: mx.array, value: mx.array
) -> None:
"""Update cache with fresh K/V for an image patch slice.
Args:
patch_start: Start token index within image portion (0-indexed)
patch_end: End token index within image portion
key: Fresh key tensor [batch, heads, patch_seq_len, head_dim]
value: Fresh value tensor [batch, heads, patch_seq_len, head_dim]
"""
self.key_cache[:, :, patch_start:patch_end, :] = key
self.value_cache[:, :, patch_start:patch_end, :] = value
def get_full_kv(
self, text_key: mx.array, text_value: mx.array
) -> tuple[mx.array, mx.array]:
"""Return full K/V by concatenating fresh text K/V with cached image K/V.
Args:
text_key: Fresh text key tensor [batch, heads, text_seq_len, head_dim]
text_value: Fresh text value tensor [batch, heads, text_seq_len, head_dim]
Returns:
Tuple of (full_key, full_value) with shape [batch, heads, text+image, head_dim]
"""
full_key = mx.concatenate([text_key, self.key_cache], axis=2)
full_value = mx.concatenate([text_value, self.value_cache], axis=2)
return full_key, full_value
def reset(self) -> None:
"""Reset cache to zeros."""
self.key_cache = mx.zeros(
(self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
dtype=self._dtype,
)
self.value_cache = mx.zeros(
(self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
dtype=self._dtype,
)
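# Example (illustrative shapes): cache one image patch, then rebuild full K/V by
# prepending fresh text K/V.
#   cache = ImagePatchKVCache(batch_size=1, num_heads=24, image_seq_len=4096, head_dim=128)
#   cache.update_image_patch(patch_start=0, patch_end=2048, key=k_patch, value=v_patch)
#   full_k, full_v = cache.get_full_kv(text_key=k_txt, text_value=v_txt)
#   # full_k.shape == (1, 24, 77 + 4096, 128) when the text sequence has 77 tokens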

View File

@@ -0,0 +1,956 @@
from math import ceil
from typing import Any, Optional
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.utils.exceptions import StopImageGenerationException
from tqdm import tqdm
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
ModelAdapter,
PromptData,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
def calculate_patch_heights(latent_height: int, num_patches: int) -> tuple[list[int], int]:
patch_height = ceil(latent_height / num_patches)
actual_num_patches = ceil(latent_height / patch_height)
patch_heights = [patch_height] * (actual_num_patches - 1)
last_height = latent_height - patch_height * (actual_num_patches - 1)
patch_heights.append(last_height)
return patch_heights, actual_num_patches
def calculate_token_indices(patch_heights: list[int], latent_width: int) -> list[tuple[int, int]]:
tokens_per_row = latent_width
token_ranges = []
cumulative_height = 0
for h in patch_heights:
start_token = tokens_per_row * cumulative_height
end_token = tokens_per_row * (cumulative_height + h)
token_ranges.append((start_token, end_token))
cumulative_height += h
return token_ranges
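# Worked example (illustrative): latent_height=64, latent_width=64, num_patches=3
#   calculate_patch_heights(64, 3)            -> ([22, 22, 20], 3)
#   calculate_token_indices([22, 22, 20], 64) -> [(0, 1408), (1408, 2816), (2816, 4096)]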
class DiffusionRunner:
"""Orchestrates the diffusion loop for image generation.
This class owns the entire diffusion process, handling both single-node
and distributed (PipeFusion) modes.
In distributed mode, it implements PipeFusion with:
- Sync pipeline for initial timesteps (full image, all devices in lockstep)
- Async pipeline for later timesteps (patches processed independently)
"""
def __init__(
self,
config: ImageModelConfig,
adapter: ModelAdapter,
group: Optional[mx.distributed.Group],
shard_metadata: PipelineShardMetadata,
num_sync_steps: int = 1,
num_patches: Optional[int] = None,
):
"""Initialize the diffusion runner.
Args:
config: Model configuration (architecture, block counts, etc.)
adapter: Model adapter for model-specific operations
group: MLX distributed group (None for single-node mode)
shard_metadata: Pipeline shard metadata with layer assignments
num_sync_steps: Number of synchronous timesteps before async mode
num_patches: Number of patches for async mode (defaults to world_size)
"""
self.config = config
self.adapter = adapter
self.group = group
# Handle single-node vs distributed mode
if group is None:
self.rank = 0
self.world_size = 1
self.next_rank = 0
self.prev_rank = 0
self.start_layer = 0
self.end_layer = config.total_blocks
else:
self.rank = shard_metadata.device_rank
self.world_size = shard_metadata.world_size
self.next_rank = (self.rank + 1) % self.world_size
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
self.start_layer = shard_metadata.start_layer
self.end_layer = shard_metadata.end_layer
self.num_sync_steps = num_sync_steps
self.num_patches = num_patches if num_patches else max(1, self.world_size)
# Persistent KV caches (initialized on first async timestep, reused across timesteps)
self.joint_kv_caches: list[ImagePatchKVCache] | None = None
self.single_kv_caches: list[ImagePatchKVCache] | None = None
# Get block counts from config (model-agnostic)
self.total_joint = config.joint_block_count
self.total_single = config.single_block_count
self.total_layers = config.total_blocks
self._compute_assigned_blocks()
def _compute_assigned_blocks(self) -> None:
"""Determine which joint/single blocks this stage owns."""
start = self.start_layer
end = self.end_layer
if end <= self.total_joint:
# All assigned blocks are joint blocks
self.joint_start = start
self.joint_end = end
self.single_start = 0
self.single_end = 0
elif start >= self.total_joint:
# All assigned blocks are single blocks
self.joint_start = 0
self.joint_end = 0
self.single_start = start - self.total_joint
self.single_end = end - self.total_joint
else:
# Stage spans joint→single transition
self.joint_start = start
self.joint_end = self.total_joint
self.single_start = 0
self.single_end = end - self.total_joint
self.has_joint_blocks = self.joint_end > self.joint_start
self.has_single_blocks = self.single_end > self.single_start
self.owns_concat_stage = self.has_joint_blocks and (
self.has_single_blocks or self.end_layer == self.total_joint
)
joint_blocks = self.adapter.get_joint_blocks()
single_blocks = self.adapter.get_single_blocks()
# Wrap blocks at initialization (reused across all calls)
self.joint_block_wrappers = [
JointBlockWrapper(block=block, adapter=self.adapter)
for block in joint_blocks
]
self.single_block_wrappers = [
SingleBlockWrapper(block=block, adapter=self.adapter)
for block in single_blocks
]
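# Worked example (illustrative, assuming a Flux-like config with 19 joint and 38
# single blocks, 57 total) split across two stages as layers [0, 28) and [28, 57):
#   stage 0: joint blocks 0..19, single blocks 0..9 -> spans the joint->single
#            transition, so owns_concat_stage is True and it calls merge_streams
#   stage 1: no joint blocks, single blocks 9..38 -> receives the merged stream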
@property
def is_first_stage(self) -> bool:
return self.rank == 0
@property
def is_last_stage(self) -> bool:
return self.rank == self.world_size - 1
@property
def is_distributed(self) -> bool:
return self.group is not None
def _calculate_capture_steps(
self,
partial_images: int,
init_time_step: int,
num_inference_steps: int,
) -> set[int]:
"""Calculate which timesteps should produce partial images.
Evenly spaces `partial_images` captures across the diffusion loop.
Does NOT include the final timestep (that's the complete image).
Args:
partial_images: Number of partial images to capture
init_time_step: Starting timestep (for img2img this may not be 0)
num_inference_steps: Total inference steps
Returns:
Set of timestep indices to capture
"""
if partial_images <= 0:
return set()
total_steps = num_inference_steps - init_time_step
if total_steps <= 1:
return set()
if partial_images >= total_steps - 1:
# Capture every step except final
return set(range(init_time_step, num_inference_steps - 1))
# Evenly space partial captures
step_interval = total_steps / (partial_images + 1)
capture_steps: set[int] = set()
for i in range(1, partial_images + 1):
step_idx = int(init_time_step + i * step_interval)
# Ensure we don't capture the final step
if step_idx < num_inference_steps - 1:
capture_steps.add(step_idx)
return capture_steps
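# Worked example (illustrative): with init_time_step=0, num_inference_steps=20 and
# partial_images=3, step_interval = 20 / 4 = 5.0 and the captured steps are
# {5, 10, 15}; step 19 is never captured because it is the final, complete image.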
def generate_image(
self,
runtime_config: Config,
prompt: str,
seed: int,
partial_images: int = 0,
):
"""Primary entry point for image generation.
Orchestrates the full generation flow:
1. Create initial latents
2. Encode prompt
3. Run diffusion loop (yielding partials if requested)
4. Decode to image
When partial_images > 0, yields (GeneratedImage, partial_index, total_partials)
tuples for intermediate images, then yields the final GeneratedImage.
Args:
runtime_config: Generation config (scheduler, steps, height, width)
prompt: Text prompt
seed: Random seed
partial_images: Number of intermediate images to yield (0 for none)
Yields:
Partial images as (GeneratedImage, partial_index, total_partials) tuples
Final GeneratedImage
"""
latents = self.adapter.create_latents(seed, runtime_config)
prompt_data = self.adapter.encode_prompt(prompt)
# Calculate which steps to capture
capture_steps = self._calculate_capture_steps(
partial_images=partial_images,
init_time_step=runtime_config.init_time_step,
num_inference_steps=runtime_config.num_inference_steps,
)
# Run diffusion loop - may yield partial latents
diffusion_gen = self._run_diffusion_loop(
latents=latents,
prompt_data=prompt_data,
runtime_config=runtime_config,
seed=seed,
prompt=prompt,
capture_steps=capture_steps,
)
# Process partial yields and get final latents
partial_index = 0
total_partials = len(capture_steps)
if capture_steps:
# Generator mode - iterate to get partials and final latents
try:
while True:
partial_latents, _step = next(diffusion_gen)
if self.is_last_stage:
partial_image = self.adapter.decode_latents(
partial_latents, runtime_config, seed, prompt
)
yield (partial_image, partial_index, total_partials)
partial_index += 1
except StopIteration as e:
latents = e.value
else:
# No partials - just consume generator to get final latents
try:
while True:
next(diffusion_gen)
except StopIteration as e:
latents = e.value
# Yield final image (only on last stage)
if self.is_last_stage:
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt)
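# Usage sketch (illustrative, hypothetical caller): intermediate yields are
# (image, partial_index, total_partials) tuples, the last yield is the bare final image.
#   for item in runner.generate_image(runtime_config, prompt, seed=0, partial_images=2):
#       if isinstance(item, tuple):
#           preview, idx, total = item  # partial image
#       else:
#           final_image = item  # completed image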
def _run_diffusion_loop(
self,
latents: mx.array,
prompt_data: PromptData,
runtime_config: Config,
seed: int,
prompt: str,
capture_steps: set[int] | None = None,
):
"""Execute the diffusion loop, optionally yielding at capture steps.
When capture_steps is provided and non-empty, this becomes a generator
that yields (latents, step_index) tuples at the specified timesteps.
Only the last stage yields (others have incomplete latents).
Args:
latents: Initial noise latents
prompt_data: Encoded prompt data
runtime_config: RuntimeConfig with scheduler, steps, dimensions
seed: Random seed (for callbacks)
prompt: Text prompt (for callbacks)
capture_steps: Set of timestep indices to capture (None = no captures)
Yields:
(latents, step_index) tuples at capture steps (last stage only)
Returns:
Final denoised latents ready for VAE decoding
"""
if capture_steps is None:
capture_steps = set()
time_steps = tqdm(range(runtime_config.num_inference_steps))
ctx = self.adapter.model.callbacks.start(
seed=seed, prompt=prompt, config=runtime_config
)
ctx.before_loop(
latents=latents,
)
for t in time_steps:
try:
latents = self._diffusion_step(
t=t,
config=runtime_config,
latents=latents,
prompt_data=prompt_data,
)
# Call subscribers in-loop
ctx.in_loop(
t=t,
latents=latents,
)
mx.eval(latents)
# Yield partial latents at capture steps (only on last stage)
if t in capture_steps and self.is_last_stage:
yield (latents, t)
except KeyboardInterrupt: # noqa: PERF203
ctx.interruption(t=t, latents=latents)
raise StopImageGenerationException(
f"Stopping image generation at step {t + 1}/{len(time_steps)}"
) from None
# Call subscribers after loop
ctx.after_loop(latents=latents)
return latents
def _forward_pass(
self,
latents: mx.array,
prompt_embeds: mx.array,
pooled_prompt_embeds: mx.array,
kwargs: dict[str, Any],
) -> mx.array:
"""Run a single forward pass through the transformer.
This is the internal method called by adapters via compute_step_noise.
Returns noise prediction without applying scheduler step.
For edit mode, concatenates conditioning latents with generated latents
before the transformer, and extracts only the generated portion after.
Args:
latents: Input latents (already scaled by caller)
prompt_embeds: Text embeddings
pooled_prompt_embeds: Pooled text embeddings (Flux) or placeholder (Qwen)
kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask, t)
Returns:
Noise prediction tensor
"""
t = kwargs.get("t", 0)
config = kwargs.get("config")
if config is None:
raise ValueError("config must be provided in kwargs")
scaled_latents = config.scheduler.scale_model_input(latents, t)
# For edit mode: concatenate with conditioning latents
conditioning_latents = kwargs.get("conditioning_latents")
original_latent_tokens = scaled_latents.shape[1]
if conditioning_latents is not None:
scaled_latents = mx.concatenate(
[scaled_latents, conditioning_latents], axis=1
)
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
scaled_latents, prompt_embeds
)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds, hidden_states=hidden_states
)
rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds, config, **kwargs
)
text_seq_len = prompt_embeds.shape[1]
# Run through all joint blocks
for block_idx, wrapper in enumerate(self.joint_block_wrappers):
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
text_seq_len=text_seq_len,
kv_cache=None,
mode=BlockWrapperMode.CACHING,
block_idx=block_idx,
**kwargs,
)
# Merge streams
if self.joint_block_wrappers:
hidden_states = self.adapter.merge_streams(
hidden_states, encoder_hidden_states
)
# Run through single blocks
for wrapper in self.single_block_wrappers:
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=rotary_embeddings,
text_seq_len=text_seq_len,
kv_cache=None,
mode=BlockWrapperMode.CACHING,
)
# Extract image portion and project
hidden_states = hidden_states[:, text_seq_len:, ...]
# For edit mode: extract only the generated portion (exclude conditioning latents)
if conditioning_latents is not None:
hidden_states = hidden_states[:, :original_latent_tokens, ...]
return self.adapter.final_projection(hidden_states, text_embeddings)
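# Worked example (illustrative) of the edit-mode bookkeeping above: if the generated
# latents contribute 4096 tokens and conditioning_latents another 4096, the
# transformer attends over 8192 image tokens, but only the first 4096
# (original_latent_tokens) are kept for the noise prediction.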
def _diffusion_step(
self,
t: int,
config: Config,
latents: mx.array,
prompt_data: PromptData,
) -> mx.array:
"""Execute a single diffusion step.
Routes to single-node, sync pipeline, or async pipeline based on
configuration and current timestep.
"""
if self.group is None:
return self._single_node_step(t, config, latents, prompt_data)
elif t < config.init_time_step + self.num_sync_steps:
return self._sync_pipeline(
t,
config,
latents,
prompt_data,
)
else:
return self._async_pipeline_step(
t,
config,
latents,
prompt_data,
)
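# Routing example (illustrative): with init_time_step=0 and num_sync_steps=2,
# timesteps 0-1 run the full-image sync pipeline (which also populates the KV
# caches), and timesteps 2..N-1 run the patch-parallel async pipeline.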
def _single_node_step(
self,
t: int,
config: Config,
latents: mx.array,
prompt_data: PromptData,
) -> mx.array:
"""Execute a single diffusion step on a single node (no distribution)."""
base_kwargs = {"t": t, "config": config}
# For edit mode: include conditioning latents
if prompt_data.conditioning_latents is not None:
base_kwargs["conditioning_latents"] = prompt_data.conditioning_latents
if self.adapter.needs_cfg:
# Two forward passes + guidance for CFG models (e.g., Qwen)
pos_kwargs = {
**base_kwargs,
**prompt_data.get_extra_forward_kwargs(positive=True),
}
noise_pos = self._forward_pass(
latents,
prompt_data.prompt_embeds,
prompt_data.pooled_prompt_embeds,
pos_kwargs,
)
neg_kwargs = {
**base_kwargs,
**prompt_data.get_extra_forward_kwargs(positive=False),
}
noise_neg = self._forward_pass(
latents,
prompt_data.negative_prompt_embeds,
prompt_data.negative_pooled_prompt_embeds,
neg_kwargs,
)
assert self.config.guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale=self.config.guidance_scale
)
else:
# Single forward pass for non-CFG models (e.g., Flux)
kwargs = {**base_kwargs, **prompt_data.get_extra_forward_kwargs()}
noise = self._forward_pass(
latents,
prompt_data.prompt_embeds,
prompt_data.pooled_prompt_embeds,
kwargs,
)
return config.scheduler.step(noise=noise, timestep=t, latents=latents)
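# Note (illustrative, assumption about adapter.apply_guidance): classifier-free
# guidance is conventionally combined as
#   noise = noise_neg + guidance_scale * (noise_pos - noise_neg)
# so a guidance_scale of 1.0 reduces to the positive prediction alone.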
def _initialize_kv_caches(
self,
batch_size: int,
num_img_tokens: int,
dtype: mx.Dtype,
) -> None:
"""Initialize KV caches for both sync and async pipelines.
Note: Caches only store IMAGE K/V, not text K/V. Text K/V is always
computed fresh and doesn't need caching (it's the same for all patches).
"""
self.joint_kv_caches = [
ImagePatchKVCache(
batch_size=batch_size,
num_heads=self.config.num_heads,
image_seq_len=num_img_tokens,
head_dim=self.config.head_dim,
dtype=dtype,
)
for _ in range(len(self.joint_block_wrappers))
]
self.single_kv_caches = [
ImagePatchKVCache(
batch_size=batch_size,
num_heads=self.config.num_heads,
image_seq_len=num_img_tokens,
head_dim=self.config.head_dim,
dtype=dtype,
)
for _ in range(len(self.single_block_wrappers))
]
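# Note (illustrative): during sync steps each cache is filled for the full image
# (BlockWrapperMode.CACHING); during async steps only the [patch_start:patch_end)
# rows are refreshed (BlockWrapperMode.PATCHED), so attention for one patch reuses
# slightly stale K/V from the other patches, PipeFusion-style.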
def _create_patches(
self,
latents: mx.array,
config: Config,
) -> tuple[list[mx.array], list[tuple[int, int]]]:
"""Split latents into patches for async pipeline."""
# Use 16 to match FluxLatentCreator.create_noise formula
latent_height = config.height // 16
latent_width = config.width // 16
patch_heights, _ = calculate_patch_heights(latent_height, self.num_patches)
token_indices = calculate_token_indices(patch_heights, latent_width)
# Split latents into patches
patch_latents = [latents[:, start:end, :] for start, end in token_indices]
return patch_latents, token_indices
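# Worked example (illustrative): for a 1024x1024 generation with num_patches=2,
# latent_height = latent_width = 1024 // 16 = 64; assuming calculate_patch_heights
# returns [32, 32], token_indices is [(0, 2048), (2048, 4096)] and each patch is
# latents[:, start:end, :] with 2048 tokens.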
def _sync_pipeline(
self,
t: int,
config: Config,
hidden_states: mx.array,
prompt_data: PromptData,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
prev_latents = hidden_states
# Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
prompt_embeds = prompt_data.prompt_embeds
pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
extra_kwargs = prompt_data.get_extra_forward_kwargs()
hidden_states = config.scheduler.scale_model_input(hidden_states, t)
# For edit mode: handle conditioning latents
# All stages need to know the total token count for correct recv templates
conditioning_latents = prompt_data.conditioning_latents
original_latent_tokens = hidden_states.shape[1]
if conditioning_latents is not None:
num_img_tokens = original_latent_tokens + conditioning_latents.shape[1]
else:
num_img_tokens = original_latent_tokens
# First stage: concatenate conditioning latents before embedding
if self.is_first_stage and conditioning_latents is not None:
hidden_states = mx.concatenate(
[hidden_states, conditioning_latents], axis=1
)
# === PHASE 1: Embeddings ===
if self.is_first_stage:
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
hidden_states, prompt_embeds
)
# All stages need these for their blocks
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
kontext_image_ids=kontext_image_ids,
**extra_kwargs,
)
# === Initialize KV caches to populate during sync for async warmstart ===
batch_size = prev_latents.shape[0]
text_seq_len = prompt_embeds.shape[1]
hidden_dim = self.adapter.hidden_dim
if t == config.init_time_step:
self._initialize_kv_caches(
batch_size=batch_size,
num_img_tokens=num_img_tokens,
dtype=prev_latents.dtype,
)
# === PHASE 2: Joint Blocks with Communication and Caching ===
if self.has_joint_blocks:
# Receive from previous stage (if not first stage)
if not self.is_first_stage:
recv_template = mx.zeros(
(batch_size, num_img_tokens, hidden_dim), dtype=prev_latents.dtype
)
hidden_states = mx.distributed.recv_like(
recv_template, self.prev_rank, group=self.group
)
enc_template = mx.zeros(
(batch_size, text_seq_len, hidden_dim), dtype=prev_latents.dtype
)
encoder_hidden_states = mx.distributed.recv_like(
enc_template, self.prev_rank, group=self.group
)
mx.eval(hidden_states, encoder_hidden_states)
# Run assigned joint blocks with caching mode
for block_idx, wrapper in enumerate(self.joint_block_wrappers):
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
text_seq_len=text_seq_len,
kv_cache=self.joint_kv_caches[block_idx],
mode=BlockWrapperMode.CACHING,
**extra_kwargs,
)
# === PHASE 3: Joint→Single Transition ===
if self.owns_concat_stage:
# Merge encoder and hidden states using adapter hook
concatenated = self.adapter.merge_streams(
hidden_states, encoder_hidden_states
)
if self.has_single_blocks or self.is_last_stage:
# Keep locally: either for single blocks or final projection
hidden_states = concatenated
else:
# Send concatenated state to next stage (which has single blocks)
mx.eval(
mx.distributed.send(concatenated, self.next_rank, group=self.group)
)
elif self.has_joint_blocks and not self.is_last_stage:
# Send joint block outputs to next stage (which has more joint blocks)
mx.eval(
mx.distributed.send(hidden_states, self.next_rank, group=self.group),
mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
),
)
# === PHASE 4: Single Blocks with Communication and Caching ===
if self.has_single_blocks:
# Receive from previous stage if we didn't do concatenation
if not self.owns_concat_stage and not self.is_first_stage:
recv_template = mx.zeros(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype=prev_latents.dtype,
)
hidden_states = mx.distributed.recv_like(
recv_template, self.prev_rank, group=self.group
)
mx.eval(hidden_states)
# Run assigned single blocks with caching mode
for block_idx, wrapper in enumerate(self.single_block_wrappers):
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
text_seq_len=text_seq_len,
kv_cache=self.single_kv_caches[block_idx],
mode=BlockWrapperMode.CACHING,
)
# Send to next stage if not last
if not self.is_last_stage:
mx.eval(
mx.distributed.send(hidden_states, self.next_rank, group=self.group)
)
# === PHASE 5: Last Stage - Final Projection + Scheduler ===
# Extract image portion (remove text embeddings prefix)
hidden_states = hidden_states[:, text_seq_len:, ...]
# For edit mode: extract only the generated portion (exclude conditioning latents)
if conditioning_latents is not None:
hidden_states = hidden_states[:, :original_latent_tokens, ...]
if self.is_last_stage:
hidden_states = self.adapter.final_projection(
hidden_states, text_embeddings
)
hidden_states = config.scheduler.step(
noise=hidden_states,
timestep=t,
latents=prev_latents,
)
if not self.is_first_stage:
mx.eval(mx.distributed.send(hidden_states, 0, group=self.group))
elif self.is_first_stage:
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
else:
# For shape correctness
hidden_states = prev_latents
return hidden_states
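# Flow example (illustrative) for world_size=2 during a sync step: rank 0 embeds
# the full image, runs its joint blocks in CACHING mode and sends
# (hidden_states, encoder_hidden_states) to rank 1; rank 1 runs its remaining
# blocks, merges streams, applies final_projection and the scheduler step, then
# sends the denoised latents back to rank 0 so both stages begin the next
# timestep from the same latents.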
def _async_pipeline_step(
self,
t: int,
config: Config,
latents: mx.array,
prompt_data: PromptData,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
patch_latents, token_indices = self._create_patches(latents, config)
patch_latents = self._async_pipeline(
t,
config,
patch_latents,
token_indices,
prompt_data,
kontext_image_ids,
)
return mx.concatenate(patch_latents, axis=1)
def _async_pipeline(
self,
t: int,
config: Config,
patch_latents: list[mx.array],
token_indices: list[tuple[int, int]],
prompt_data: PromptData,
kontext_image_ids: mx.array | None = None,
) -> list[mx.array]:
"""Execute async pipeline for all patches."""
assert self.joint_kv_caches is not None
assert self.single_kv_caches is not None
# Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
prompt_embeds = prompt_data.prompt_embeds
pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
extra_kwargs = prompt_data.get_extra_forward_kwargs()
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
kontext_image_ids=kontext_image_ids,
**extra_kwargs,
)
batch_size = patch_latents[0].shape[0]
text_seq_len = prompt_embeds.shape[1]
hidden_dim = self.adapter.hidden_dim
for patch_idx, patch in enumerate(patch_latents):
patch_prev = patch
start_token, end_token = token_indices[patch_idx]
if self.has_joint_blocks:
if (
not self.is_first_stage
or t != config.init_time_step + self.num_sync_steps
):
if self.is_first_stage:
# First stage receives the latent-space patch from the last stage (scheduler output)
recv_template = patch
else:
# Other stages receive hidden-space activations from the previous stage
patch_len = patch.shape[1]
recv_template = mx.zeros(
(batch_size, patch_len, hidden_dim),
dtype=patch.dtype,
)
patch = mx.distributed.recv_like(
recv_template, src=self.prev_rank, group=self.group
)
mx.eval(patch)
patch_latents[patch_idx] = patch
if not self.is_first_stage and patch_idx == 0:
enc_template = mx.zeros(
(batch_size, text_seq_len, hidden_dim),
dtype=patch_latents[0].dtype,
)
encoder_hidden_states = mx.distributed.recv_like(
enc_template, src=self.prev_rank, group=self.group
)
mx.eval(encoder_hidden_states)
if self.is_first_stage:
patch, encoder_hidden_states = self.adapter.compute_embeddings(
patch, prompt_embeds
)
# Run assigned joint blocks with patched mode
for block_idx, wrapper in enumerate(self.joint_block_wrappers):
encoder_hidden_states, patch = wrapper(
hidden_states=patch,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
text_seq_len=text_seq_len,
kv_cache=self.joint_kv_caches[block_idx],
mode=BlockWrapperMode.PATCHED,
patch_start=start_token,
patch_end=end_token,
**extra_kwargs,
)
if self.owns_concat_stage:
patch_concat = self.adapter.merge_streams(patch, encoder_hidden_states)
if self.has_single_blocks or self.is_last_stage:
# Keep locally: either for single blocks or final projection
patch = patch_concat
else:
mx.eval(
mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
)
elif self.has_joint_blocks and not self.is_last_stage:
mx.eval(mx.distributed.send(patch, self.next_rank, group=self.group))
if patch_idx == 0:
mx.eval(
mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
recv_template = mx.zeros(
[
batch_size,
text_seq_len + patch_latents[patch_idx].shape[1],
hidden_dim,
],
dtype=patch_latents[0].dtype,
)
patch = mx.distributed.recv_like(
recv_template, src=self.prev_rank, group=self.group
)
mx.eval(patch)
patch_latents[patch_idx] = patch
# Run assigned single blocks with patched mode
for block_idx, wrapper in enumerate(self.single_block_wrappers):
patch = wrapper(
hidden_states=patch,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
text_seq_len=text_seq_len,
kv_cache=self.single_kv_caches[block_idx],
mode=BlockWrapperMode.PATCHED,
patch_start=start_token,
patch_end=end_token,
)
if not self.is_last_stage:
mx.eval(
mx.distributed.send(patch, self.next_rank, group=self.group)
)
if self.is_last_stage:
patch_img_only = patch[:, text_seq_len:, :]
patch_img_only = self.adapter.final_projection(
patch_img_only, text_embeddings
)
patch = config.scheduler.step(
noise=patch_img_only,
timestep=t,
latents=patch_prev,
)
if not self.is_first_stage and t != config.num_inference_steps - 1:
mx.eval(
mx.distributed.send(patch, self.next_rank, group=self.group)
)
patch_latents[patch_idx] = patch
return patch_latents

View File

@@ -106,6 +106,7 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
# TODO(ciaran): This is overkill
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
return output

View File

@@ -1,7 +1,7 @@
from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

View File

@@ -8,13 +8,15 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.types.api import ImageEditsInternalParams
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
EventId,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
NodeDownloadProgress,
NodeMemoryMeasured,
NodePerformanceMeasured,
@@ -30,6 +32,7 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
DownloadModel,
ImageEdits,
Shutdown,
Task,
TaskStatus,
@@ -95,6 +98,10 @@ class Worker:
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
async def run(self):
logger.info("Starting Worker")
@@ -173,6 +180,17 @@ class Worker:
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
if cmd_id not in self.input_chunk_buffer:
self.input_chunk_buffer[cmd_id] = {}
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
event.chunk.data
)
async def plan_step(self):
while True:
await anyio.sleep(0.1)
@@ -185,6 +203,8 @@ class Worker:
self.state.instances,
self.state.runners,
self.state.tasks,
self.input_chunk_buffer,
self.input_chunk_counts,
)
if task is None:
continue
@@ -248,6 +268,42 @@ class Worker:
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
chunks = self.input_chunk_buffer.get(cmd_id, {})
assembled = "".join(chunks[i] for i in range(len(chunks)))
logger.info(
f"Assembled input image from {len(chunks)} chunks, "
f"total size: {len(assembled)} bytes"
)
# Create modified task with assembled image data
modified_task = ImageEdits(
task_id=task.task_id,
command_id=task.command_id,
instance_id=task.instance_id,
task_status=task.task_status,
task_params=ImageEditsInternalParams(
image_data=assembled,
total_input_chunks=task.task_params.total_input_chunks,
prompt=task.task_params.prompt,
model=task.task_params.model,
n=task.task_params.n,
quality=task.task_params.quality,
output_format=task.task_params.output_format,
response_format=task.task_params.response_format,
size=task.task_params.size,
image_strength=task.task_params.image_strength,
),
)
# Cleanup buffers
if cmd_id in self.input_chunk_buffer:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)

View File

@@ -2,13 +2,15 @@
from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadModel,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -49,6 +51,8 @@ def plan(
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_counts: Mapping[CommandId, int] | None = None,
) -> Task | None:
# Python's short-circuiting `or` evaluates these checks in order, so earlier planners take priority.
return (
@@ -58,7 +62,7 @@ def plan(
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
)
@@ -262,14 +266,24 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
if not isinstance(task, ChatCompletion):
# TODO(ciaran): do this better!
if not isinstance(task, (ChatCompletion, ImageGeneration, ImageEdits)):
continue
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
continue
# For ImageEdits tasks, verify all input chunks have been received
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
for runner in runners.values():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue

View File

@@ -1,9 +1,12 @@
import base64
import time
import mlx.core as mx
from exo.master.api import get_model_card
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -11,9 +14,12 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.models import ModelTask
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -23,6 +29,8 @@ from exo.shared.types.tasks import (
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ImageGenerationResponse,
PartialImageResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -39,6 +47,13 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.image import (
ImageGenerator,
generate_image,
initialize_image_model,
warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
@@ -47,6 +62,10 @@ from exo.worker.engines.mlx.utils_mlx import (
)
from exo.worker.runner.bootstrap import logger
from exo.shared.types.common import CommandId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.shards import ShardMetadata
def main(
bound_instance: BoundInstance,
@@ -70,6 +89,10 @@ def main(
tokenizer = None
group = None
model_card = get_model_card(shard_metadata.model_meta.model_id)
assert model_card
model_tasks = model_card.tasks
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
event_sender.send(
@@ -109,13 +132,22 @@ def main(
)
)
model, tokenizer = load_mlx_items(bound_instance, group)
# TODO(ciaran): switch
if ModelTask.TextGeneration in model_tasks:
model, tokenizer = load_mlx_items(bound_instance, group)
elif (
ModelTask.TextToImage in model_tasks
or ModelTask.ImageToImage in model_tasks
):
model = initialize_image_model(bound_instance)
else:
raise ValueError(f"Unknown model task(s): {model_card.tasks}")
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -125,21 +157,36 @@ def main(
)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
if ModelTask.TextGeneration in model_tasks:
assert model and isinstance(model, Model)
assert tokenizer
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
elif (
ModelTask.TextToImage in model_tasks
or ModelTask.ImageToImage in model_tasks
):
assert isinstance(model, ImageGenerator)
image = warmup_image_generator(model=model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
logger.info("warmup completed (non-primary node)")
current_status = RunnerReady()
logger.info("runner ready")
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert model
assert model and isinstance(model, Model)
assert tokenizer
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
@@ -177,6 +224,81 @@ def main(
# case TokenizedResponse():
# TODO: something here ig
current_status = RunnerReady()
logger.info("runner ready")
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, ImageGenerator)
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(model=model, task=task_params):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
current_status = RunnerReady()
logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert isinstance(model, ImageGenerator)
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
image_index = 0
for response in generate_image(model=model, task=task_params):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
match response:
case ImageGenerationResponse():
logger.info("sending ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
case PartialImageResponse():
pass # Image edits don't support partial images
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
@@ -207,6 +329,63 @@ def main(
break
def _send_image_chunk(
encoded_data: str,
command_id: CommandId,
model_id: ModelId,
event_sender: MpSender[Event],
image_index: int,
is_partial: bool,
partial_index: int | None = None,
total_partials: int | None = None,
) -> None:
"""Send base64-encoded image data as chunks via events."""
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)
for chunk_index, chunk_data in enumerate(data_chunks):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=chunk_index,
model=model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=image_index,
is_partial=is_partial,
partial_index=partial_index,
total_partials=total_partials,
),
)
)
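# Worked example (illustrative): if EXO_MAX_CHUNK_SIZE were 4 and encoded_data were
# "ABCDEFGHIJ", three ImageChunk events would be sent carrying "ABCD", "EFGH" and
# "IJ" with chunk_index 0..2 and total_chunks=3; the worker reassembles them in
# chunk_index order before dispatching the ImageEdits task.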
def _process_image_response(
response: ImageGenerationResponse | PartialImageResponse,
command_id: CommandId,
shard_metadata: ShardMetadata,
event_sender: MpSender[Event],
image_index: int,
) -> None:
"""Process a single image response and send chunks."""
encoded_data = base64.b64encode(response.image_data).decode("utf-8")
is_partial = isinstance(response, PartialImageResponse)
_send_image_chunk(
encoded_data=encoded_data,
command_id=command_id,
model_id=shard_metadata.model_meta.model_id,
event_sender=event_sender,
image_index=response.partial_index if is_partial else image_index,
is_partial=is_partial,
partial_index=response.partial_index if is_partial else None,
total_partials=response.total_partials if is_partial else None,
)
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

uv.lock (generated)

File diff suppressed because it is too large.