Compare commits

...

261 Commits

Author SHA1 Message Date
ciaranbor
3ae9bc3d19 Add endpoint to list images 2026-01-20 22:23:01 +00:00
ciaranbor
4b21c3c321 Support image url response 2026-01-20 22:02:08 +00:00
ciaranbor
fa15de1e82 Use public snapshot_download 2026-01-20 22:00:37 +00:00
ciaranbor
7c21e096e8 Support jpeg outputs 2026-01-20 20:51:48 +00:00
ciaranbor
2a1b9928fd Error on malformed size parameter 2026-01-20 20:04:31 +00:00
ciaranbor
299048e269 Use ModelId 2026-01-20 18:47:51 +00:00
ciaranbor
418e057a32 Resolve mflux type errors 2026-01-20 18:47:51 +00:00
ciaranbor
87045dda79 Generate mflux type stubs 2026-01-20 18:47:51 +00:00
ciaranbor
4143b98f36 Format FE 2026-01-20 18:47:51 +00:00
ciaranbor
3c598492f2 Prevent running image editing model without an input image 2026-01-20 18:47:51 +00:00
ciaranbor
9f421f6790 Type coercion for ModelTask 2026-01-20 18:47:51 +00:00
ciaranbor
15c63def3d Fallback for resolving model card 2026-01-20 18:47:51 +00:00
ciaranbor
d10220d9d3 Reflect ModelCard simplification 2026-01-20 18:47:51 +00:00
ciaranbor
03415b57d2 Fix text model runtime check 2026-01-20 18:47:51 +00:00
ciaranbor
de14348693 Fix image streaming for editing 2026-01-20 18:47:51 +00:00
ciaranbor
88566521d0 Propagate additional image edit api params 2026-01-20 18:47:50 +00:00
ciaranbor
5c21959908 Support image editing in UI 2026-01-20 18:47:50 +00:00
ciaranbor
e1d4ded4c7 Allow dropdowns to fit values 2026-01-20 18:47:50 +00:00
ciaranbor
835d5428a2 Better param dropdowns 2026-01-20 18:47:50 +00:00
ciaranbor
e6a6776769 Better typing 2026-01-20 18:47:50 +00:00
ciaranbor
d50dd6e35a Restore recv evals 2026-01-20 18:47:50 +00:00
ciaranbor
a65ed8cf41 Remove outputFormat param 2026-01-20 18:47:50 +00:00
ciaranbor
8f448a8af9 Expose image generation settings to UI 2026-01-20 18:47:50 +00:00
ciaranbor
3286f5ff4a DiffusionRunner type errors 2026-01-20 18:47:50 +00:00
ciaranbor
3c609ab667 Correctly handle num_sync_steps for diffusion steps override 2026-01-20 18:47:50 +00:00
ciaranbor
68e60bbcbc Simplify DistributedImageModel 2026-01-20 18:47:50 +00:00
ciaranbor
d2c0f1a6cb Hide internal mflux image type 2026-01-20 18:47:50 +00:00
ciaranbor
3a1571f33a Only run two steps during warmup 2026-01-20 18:47:50 +00:00
ciaranbor
f5a28918c3 Document image apis 2026-01-20 18:47:50 +00:00
ciaranbor
1ab0aa5e1e Clean up ImageModelConfig 2026-01-20 18:47:50 +00:00
ciaranbor
77186fda80 Add AdvancedImageParams to api 2026-01-20 18:47:50 +00:00
ciaranbor
b52ed4803f Capture partial images earlier 2026-01-20 18:47:50 +00:00
ciaranbor
8607f85708 Remove redundant tensor allocations for recv templates 2026-01-20 18:47:50 +00:00
ciaranbor
4c91c3ab81 Remove recv evals, use async_eval for sends 2026-01-20 18:47:50 +00:00
ciaranbor
c322174bff Remove ImageGenerator protocol 2026-01-20 18:47:50 +00:00
ciaranbor
289f058d52 Batch CFG 2026-01-20 18:47:50 +00:00
ciaranbor
4a998b9d05 Add image generation benchmarking endpoints 2026-01-20 18:47:50 +00:00
ciaranbor
94b88da349 Consolidate patched and unpatched qkv computation logic 2026-01-20 18:47:50 +00:00
ciaranbor
924542846a Use mixin for common block wrapper functionality 2026-01-20 18:47:50 +00:00
ciaranbor
db9f038beb Move last rank to first rank comms outside the CFG step 2026-01-20 18:47:50 +00:00
ciaranbor
e0a5e9178b Revert 2026-01-20 18:47:50 +00:00
ciaranbor
21989572c9 Run all positive then all negative 2026-01-20 18:47:50 +00:00
ciaranbor
30e21e29ef Rank 0 shouldn't receive on negative pass 2026-01-20 18:47:50 +00:00
ciaranbor
e1f7500d3b Fix negative pass text_seq_len 2026-01-20 18:47:50 +00:00
ciaranbor
6896887a74 Add distributed CFG support 2026-01-20 18:47:50 +00:00
ciaranbor
3de6cbf2b9 Enable CFG for Qwen-Image 2026-01-20 18:47:50 +00:00
ciaranbor
05faaebf2f Use transformer block wrapper classes 2026-01-20 18:47:50 +00:00
ciaranbor
204a323951 Refactor 2026-01-20 18:47:50 +00:00
ciaranbor
d19cddea2a Fix flux tokenizer 2026-01-20 18:47:50 +00:00
ciaranbor
a8ca6163b2 Reduce image generation and image edits code duplication 2026-01-20 18:47:50 +00:00
ciaranbor
5592412d43 Update mflux to 0.14.2 2026-01-20 18:47:50 +00:00
ciaranbor
125972bc50 Add python-multipart dependency 2026-01-20 18:47:50 +00:00
ciaranbor
3560e7a138 Linting 2026-01-20 18:47:50 +00:00
ciaranbor
1445435512 Start image editing time steps at 0 2026-01-20 18:47:50 +00:00
ciaranbor
29060750b5 Ignore image_strength 2026-01-20 18:47:50 +00:00
ciaranbor
78f22c6547 Handle conditioning latents in sync pipeline 2026-01-20 18:47:50 +00:00
ciaranbor
721318e6cc Use dummy image for editing warmup 2026-01-20 18:47:50 +00:00
ciaranbor
0557b08775 Support streaming for image editing 2026-01-20 18:47:50 +00:00
ciaranbor
df440eeb15 Support image editing in runner 2026-01-20 18:47:50 +00:00
ciaranbor
8efc02822e Add editing features to adapter 2026-01-20 18:47:50 +00:00
ciaranbor
6bfd6d401f Default partial images to 3 if streaming 2026-01-20 18:47:50 +00:00
ciaranbor
cf48c2dd73 Add Qwen-Image model adapter 2026-01-20 18:47:50 +00:00
ciaranbor
99ad01a636 Add Qwen-Image-Edit model config 2026-01-20 18:47:50 +00:00
ciaranbor
edfcddbfee Use image generation in streaming mode in UI 2026-01-20 18:47:50 +00:00
ciaranbor
475aa66fe2 Handle partial image streaming 2026-01-20 18:47:50 +00:00
ciaranbor
7e11aca950 Add streaming params to ImageGenerationTaskParams 2026-01-20 18:47:50 +00:00
ciaranbor
69f7a36a7e Add Qwen-Image-Edit-2509 2026-01-20 18:47:50 +00:00
ciaranbor
4dd4ef696a Handle image editing time steps 2026-01-20 18:47:50 +00:00
ciaranbor
039383b687 Fix time steps 2026-01-20 18:47:50 +00:00
ciaranbor
e75b37f833 Fix image_strength meaning 2026-01-20 18:47:50 +00:00
ciaranbor
08f77df2e9 Truncate image data logs 2026-01-20 18:47:50 +00:00
ciaranbor
ec41250577 Chunk image input 2026-01-20 18:47:50 +00:00
ciaranbor
9e264b24e7 Avoid logging image data 2026-01-20 18:47:50 +00:00
ciaranbor
bb2d4765b9 Support image editing 2026-01-20 18:47:50 +00:00
Sami Khan
b8208d18ad small UI change 2026-01-20 18:47:50 +00:00
Sami Khan
06490c0bcd image gen in dashboard 2026-01-20 18:47:50 +00:00
ciaranbor
73a6a86e19 Better llm model type check 2026-01-20 18:47:50 +00:00
ciaranbor
2d2db66ccd Prune blocks before model load 2026-01-20 18:47:50 +00:00
ciaranbor
56b190af5a Own TODOs 2026-01-20 18:47:50 +00:00
ciaranbor
50e9254e9f Remove double RunnerReady event 2026-01-20 18:47:50 +00:00
ciaranbor
77f25a62a2 Fix hidden_size for image models 2026-01-20 18:47:50 +00:00
ciaranbor
72fe14bbe7 Fix image model cards 2026-01-20 18:47:50 +00:00
ciaranbor
251d71044c Skip decode on non-final ranks 2026-01-20 18:47:50 +00:00
ciaranbor
4271beecf4 Final rank produces image 2026-01-20 18:47:50 +00:00
ciaranbor
0a4beb25bd Increase number of sync steps 2026-01-20 18:47:50 +00:00
ciaranbor
bb2c517348 Change Qwen-Image steps 2026-01-20 18:47:50 +00:00
ciaranbor
9c8e3c048c Fix Qwen-Image latent shapes 2026-01-20 18:47:50 +00:00
ciaranbor
30808799bc Fix joint block patch recv shape for non-zero ranks 2026-01-20 18:47:50 +00:00
ciaranbor
333364d4ac Fix comms issue for models without single blocks 2026-01-20 18:47:50 +00:00
ciaranbor
9dc5e8c72e Support Qwen in DiffusionRunner pipefusion 2026-01-20 18:47:50 +00:00
ciaranbor
b4ecdc82b6 Implement Qwen pipefusion 2026-01-20 18:47:50 +00:00
ciaranbor
a34fb92e31 Add guidance_scale parameter to image model config 2026-01-20 18:47:50 +00:00
ciaranbor
ea427d9451 Move orchestration to DiffusionRunner 2026-01-20 18:47:50 +00:00
ciaranbor
607525587d Add initial QwenModelAdapter 2026-01-20 18:47:50 +00:00
ciaranbor
5db6e85694 Tweak embeddings interface 2026-01-20 18:47:50 +00:00
ciaranbor
9d1b966345 Add Qwen ImageModelConfig 2026-01-20 18:47:50 +00:00
ciaranbor
3adac330c2 Use 10% sync steps 2026-01-20 18:47:50 +00:00
ciaranbor
33507c950b Update FluxModelAdaper for new interface 2026-01-20 18:47:50 +00:00
ciaranbor
b7b946cb4b Register QwenModelAdapter 2026-01-20 18:47:50 +00:00
ciaranbor
d8adbb293b Support multiple forward passes in runner 2026-01-20 18:47:50 +00:00
ciaranbor
2ca20a6ae7 Extend block wrapper parameters 2026-01-20 18:47:50 +00:00
ciaranbor
e3ca51d9c9 Relax adaptor typing 2026-01-20 18:47:50 +00:00
ciaranbor
89760f8a85 Add Qwen-Image model card 2026-01-20 18:47:50 +00:00
ciaranbor
383c8cc330 Clean up dead code 2026-01-20 18:47:50 +00:00
ciaranbor
f6b604c0d9 Add BaseModelAdaptor 2026-01-20 18:47:50 +00:00
ciaranbor
b0727d65d8 Refactor filestructure 2026-01-20 18:47:50 +00:00
ciaranbor
cbf477b2af Treat unified blocks as single blocks (equivalent) 2026-01-20 18:47:50 +00:00
ciaranbor
538c7ae48d Refactor to handle entire denoising process in Diffusion runner 2026-01-20 18:47:50 +00:00
ciaranbor
61f4135c50 Move transformer to adapter 2026-01-20 18:47:50 +00:00
ciaranbor
6b58c4eac4 Move some more logic to adaptor 2026-01-20 18:47:50 +00:00
ciaranbor
c20c9d27f6 Add generic block wrapper 2026-01-20 18:47:50 +00:00
ciaranbor
056e1f6550 Access transformer blocks from adaptor 2026-01-20 18:47:50 +00:00
ciaranbor
38945c9139 Better typing 2026-01-20 18:47:50 +00:00
ciaranbor
c98f5ca7d2 Create wrappers at init time 2026-01-20 18:47:50 +00:00
ciaranbor
6a3b846cc1 Combine model factory and adaptor 2026-01-20 18:47:50 +00:00
ciaranbor
e1aed70321 Implement model factory 2026-01-20 18:47:50 +00:00
ciaranbor
fba8a99be8 Add adaptor registry 2026-01-20 18:47:50 +00:00
ciaranbor
d0de338d2b Remove mflux/generator/generate.py 2026-01-20 18:47:50 +00:00
ciaranbor
9b39c30498 Switch to using DistributedImageModel 2026-01-20 18:47:50 +00:00
ciaranbor
f939cb0f38 Add DistributedImageModel 2026-01-20 18:47:50 +00:00
ciaranbor
271ccbe50b Use new generic wrappers, etc in denoising 2026-01-20 18:47:50 +00:00
ciaranbor
cc8cac7440 Add generic transformer block wrappers 2026-01-20 18:47:50 +00:00
ciaranbor
1160fbf002 Add FluxAdaptor 2026-01-20 18:47:50 +00:00
ciaranbor
7610df2d85 Add ModelAdaptor, derivations implement model specific logic 2026-01-20 18:47:50 +00:00
ciaranbor
fc64a82cd7 Introduce image model config concept 2026-01-20 18:47:50 +00:00
ciaranbor
485c1a5188 Consolidate kv cache patching 2026-01-20 18:47:50 +00:00
ciaranbor
fbb99c3743 Support different configuration comms 2026-01-20 18:47:50 +00:00
ciaranbor
5c9bc7b730 Add ImageGenerator protocol 2026-01-20 18:47:50 +00:00
ciaranbor
4c5be18a81 Force final patch receive order 2026-01-20 18:47:50 +00:00
ciaranbor
55a6667ca3 Remove logs 2026-01-20 18:47:50 +00:00
ciaranbor
e6c5d440b0 Update patch list 2026-01-20 18:47:50 +00:00
ciaranbor
a764ce063d Slight refactor 2026-01-20 18:47:50 +00:00
ciaranbor
32eabb1379 Don't need array for prev patches 2026-01-20 18:47:50 +00:00
ciaranbor
ab45d13a2e Fix send/recv order 2026-01-20 18:47:50 +00:00
ciaranbor
eac12d6bf8 Fix async single transformer block 2026-01-20 18:47:50 +00:00
ciaranbor
c232818847 Use relative rank variables 2026-01-20 18:47:50 +00:00
ciaranbor
66822d8a4e Fix writing patches 2026-01-20 18:47:50 +00:00
ciaranbor
d5f7b67692 Collect final image 2026-01-20 18:47:50 +00:00
ciaranbor
0415f9efab Fix recv_template shape 2026-01-20 18:47:50 +00:00
ciaranbor
1d7bc6436f Add logs 2026-01-20 18:47:50 +00:00
ciaranbor
472f03e8a6 Optimise async pipeline 2026-01-20 18:47:49 +00:00
ciaranbor
64a7b83ccb Add next_rank and prev_rank members 2026-01-20 18:47:49 +00:00
ciaranbor
7e6e361da5 Add _create_patches method 2026-01-20 18:47:49 +00:00
ciaranbor
4b5b7fdab9 Fix shapes 2026-01-20 18:47:49 +00:00
ciaranbor
b5e7898dd7 Reorder comms 2026-01-20 18:47:49 +00:00
ciaranbor
ea6d2cf306 Remove all_gather from sync pipeline, send from final rank to first rank 2026-01-20 18:47:49 +00:00
ciaranbor
0e98e14479 Simplify kv_cache initialization 2026-01-20 18:47:49 +00:00
ciaranbor
1e52ca82cf Fix kv cache 2026-01-20 18:47:49 +00:00
ciaranbor
142adb05a3 Clean up kv caches 2026-01-20 18:47:49 +00:00
ciaranbor
d12fef300e Fix return 2026-01-20 18:47:49 +00:00
ciaranbor
08555b3603 Fix hidden_states shapes 2026-01-20 18:47:49 +00:00
ciaranbor
c9fd53c99c Only perform projection and scheduler step on last rank 2026-01-20 18:47:49 +00:00
ciaranbor
35f93e22d2 Only compute embeddings on rank 0 2026-01-20 18:47:49 +00:00
ciaranbor
805550335f Remove eval 2026-01-20 18:47:49 +00:00
ciaranbor
1d18cb86a0 Remove eval 2026-01-20 18:47:49 +00:00
ciaranbor
7957b0593e Only send encoder_hidden_states with the first patch (once per timestep) 2026-01-20 18:47:49 +00:00
ciaranbor
90c7441a21 Remove redundant text kv cache computation 2026-01-20 18:47:49 +00:00
ciaranbor
15538ccd4f Concatenate before all gather 2026-01-20 18:47:49 +00:00
ciaranbor
7f32e96992 Increase number of sync steps 2026-01-20 18:47:49 +00:00
ciaranbor
b5aba378fa Reinitialise kv_caches between generations 2026-01-20 18:47:49 +00:00
ciaranbor
203b89087b Eliminate double kv cache computation 2026-01-20 18:47:49 +00:00
ciaranbor
1f05d5815a Add kv cache caching wrappers for sync pipeline transformer blocks 2026-01-20 18:47:49 +00:00
ciaranbor
ce0c96f95c Persist kv caches 2026-01-20 18:47:49 +00:00
ciaranbor
cf3ba67b37 Implement naive async pipeline implementation 2026-01-20 18:47:49 +00:00
ciaranbor
6d23adac0f Use wrapper classes for patched transformer logic 2026-01-20 18:47:49 +00:00
ciaranbor
f3bb7ade50 Add patch-aware joint and single attention wrappers 2026-01-20 18:47:49 +00:00
ciaranbor
0da1f9c334 Fix group.size() 2026-01-20 18:47:49 +00:00
ciaranbor
a30c099c92 Add classes to manage kv caches with patch support 2026-01-20 18:47:49 +00:00
ciaranbor
6824e90704 Use heuristic for number of sync steps 2026-01-20 18:47:49 +00:00
ciaranbor
cdadd7f2b5 Generalise number of denoising steps 2026-01-20 18:47:49 +00:00
ciaranbor
d9afd0f431 Add flux1-dev 2026-01-20 18:47:49 +00:00
ciaranbor
5e11eaa582 Move scheduler step to inner pipeline 2026-01-20 18:47:49 +00:00
ciaranbor
60422ea038 Add barrier before all_gather 2026-01-20 18:47:49 +00:00
ciaranbor
c9941cf4aa Fix transformer blocks pruning 2026-01-20 18:47:49 +00:00
ciaranbor
cb2e32db77 Fix image generation api 2026-01-20 18:47:49 +00:00
ciaranbor
a154a48b21 Create queue in try block 2026-01-20 18:47:49 +00:00
ciaranbor
547dbdc5f0 Conform to rebase 2026-01-20 18:47:49 +00:00
ciaranbor
38bb5d05f1 Refactor denoising 2026-01-20 18:47:49 +00:00
ciaranbor
170261f03e Move more logic to DistributedFlux 2026-01-20 18:47:49 +00:00
ciaranbor
59319551d8 Move surrounding logic back to _sync_pipeline 2026-01-20 18:47:49 +00:00
ciaranbor
e053425afa Add patching aware member variables 2026-01-20 18:47:49 +00:00
ciaranbor
b6742db278 Implement sync/async switching logic 2026-01-20 18:47:49 +00:00
ciaranbor
9b4218c2e1 Move current transformer implementation to _sync_pipeline method 2026-01-20 18:47:49 +00:00
ciaranbor
ea11918fde Remove some logs 2026-01-20 18:47:49 +00:00
ciaranbor
6aa02825bb Remove old Flux1 implementation 2026-01-20 18:47:49 +00:00
ciaranbor
ae900878a5 Prune unused transformer blocks 2026-01-20 18:47:49 +00:00
ciaranbor
94e10dca92 Add mx.eval 2026-01-20 18:47:49 +00:00
ciaranbor
68a4c7e7f5 Test evals 2026-01-20 18:47:49 +00:00
ciaranbor
785432d330 Test only barriers 2026-01-20 18:47:49 +00:00
ciaranbor
0ed51090ca All perform final projection 2026-01-20 18:47:49 +00:00
ciaranbor
edeb974a8a Another barrier 2026-01-20 18:47:49 +00:00
ciaranbor
f57910dd5e More debug 2026-01-20 18:47:49 +00:00
ciaranbor
4f97de1cc0 Add barriers 2026-01-20 18:47:49 +00:00
ciaranbor
1245223b11 Add log 2026-01-20 18:47:49 +00:00
ciaranbor
68b5d8964c Restore distributed logging 2026-01-20 18:47:49 +00:00
ciaranbor
3fa1f99393 Use bootstrap logger 2026-01-20 18:47:49 +00:00
ciaranbor
5ae7d01a64 Remove logs 2026-01-20 18:47:49 +00:00
ciaranbor
db4d3bf9aa fix single block receive shape 2026-01-20 18:47:49 +00:00
ciaranbor
ac001b0d81 Add debug logs 2026-01-20 18:47:49 +00:00
ciaranbor
bdd96b9053 Move communication logic to DistributedTransformer wrapper 2026-01-20 18:47:49 +00:00
ciaranbor
44310d680a Move inference logic to DistribuedFlux1 2026-01-20 18:47:49 +00:00
ciaranbor
5ec028a536 Add DistributedFlux1 class 2026-01-20 18:47:49 +00:00
ciaranbor
e3c0c045c4 Rename pipeline to pipefusion 2026-01-20 18:47:49 +00:00
ciaranbor
dbd7a8346a Further refactor 2026-01-20 18:47:49 +00:00
ciaranbor
fd72be4886 Refactor warmup 2026-01-20 18:47:49 +00:00
ciaranbor
51d3432f87 Manually handle flux1 inference 2026-01-20 18:47:49 +00:00
ciaranbor
d75087bd3a Refactor flux1 image generation 2026-01-20 18:47:49 +00:00
ciaranbor
65739cddee Use quality parameter to set number of inference steps 2026-01-20 18:47:49 +00:00
ciaranbor
64c6121f2c Chunk image data transfer 2026-01-20 18:47:49 +00:00
ciaranbor
2eecbca71f Define EXO_MAX_CHUNK_SIZE 2026-01-20 18:47:49 +00:00
ciaranbor
0de3b7ef71 Add indexing info to ImageChunk 2026-01-20 18:47:49 +00:00
ciaranbor
f311d6525f Remove sharding logs 2026-01-20 18:47:49 +00:00
ciaranbor
270cbab8b6 Temp: reduce flux1.schnell storage size 2026-01-20 18:47:49 +00:00
ciaranbor
afc2d8234a Fix mflux transformer all_gather 2026-01-20 18:47:49 +00:00
ciaranbor
2ca928a525 Fix world size 2026-01-20 18:47:49 +00:00
ciaranbor
c3f761df2f Fix transition block? 2026-01-20 18:47:49 +00:00
ciaranbor
59cc3f6138 Implement image generation warmup 2026-01-20 18:47:49 +00:00
ciaranbor
af76cbf35b Add logs 2026-01-20 18:47:49 +00:00
ciaranbor
cb743c383c Add spiece.model to default patterns 2026-01-20 18:47:49 +00:00
ciaranbor
44b087ec39 Just download all files for now 2026-01-20 18:47:49 +00:00
ciaranbor
899bfb3a65 Fix get_allow_patterns to include non-indexed safetensors files 2026-01-20 18:47:49 +00:00
ciaranbor
69ca95e572 Use half-open layer indexing in get_allow_patterns 2026-01-20 18:47:49 +00:00
ciaranbor
d7aca65488 Enable distributed mflux 2026-01-20 18:47:49 +00:00
ciaranbor
95992b87ec Implement mflux transformer sharding and communication pattern 2026-01-20 18:47:49 +00:00
ciaranbor
e3e69b6504 Update get_allow_patterns to handle sharding components 2026-01-20 18:47:49 +00:00
ciaranbor
da123b0c4b Namespace both keys and values for component weight maps 2026-01-20 18:47:49 +00:00
ciaranbor
23e1014972 Add components to Flux.1-schnell MODEL_CARD 2026-01-20 18:47:49 +00:00
ciaranbor
4419271362 Add component concept for ModelMetadata 2026-01-20 18:47:49 +00:00
ciaranbor
3ddd53b3fd Fix multiple components weight map key conflicts 2026-01-20 18:47:49 +00:00
ciaranbor
7dc5449003 get_weight_map: handle repos with multiple safetensors.index.json files 2026-01-20 18:47:49 +00:00
ciaranbor
2302240a45 Add initial image edits spec 2026-01-20 18:47:49 +00:00
ciaranbor
ba4dbf3cb7 Add image edits endpoint 2026-01-20 18:47:49 +00:00
ciaranbor
58eb74be89 Add ImageToImage task 2026-01-20 18:47:49 +00:00
ciaranbor
d05e9643aa Allow ModelCards to have multiple tasks 2026-01-20 18:47:49 +00:00
ciaranbor
7d0bfaf0f9 Fix text generation 2026-01-20 18:47:49 +00:00
ciaranbor
787f85fa71 Rename mlx_generate_image to mflux_generate 2026-01-20 18:47:49 +00:00
ciaranbor
39a167ca25 Initialize mlx or mflux engine based on model task 2026-01-20 18:47:49 +00:00
ciaranbor
47c2721254 Restore warmup for text generation 2026-01-20 18:47:49 +00:00
ciaranbor
9044b3ea99 Add initialize_mflux function 2026-01-20 18:47:49 +00:00
ciaranbor
d6087d5f5f Move image generation to mflux engine 2026-01-20 18:47:49 +00:00
ciaranbor
2dee324175 Just use str for image generation size 2026-01-20 18:47:49 +00:00
ciaranbor
8b93bc5151 Use MFlux for image generation 2026-01-20 18:47:49 +00:00
ciaranbor
3214c014bf Add get_model_card function 2026-01-20 18:47:49 +00:00
ciaranbor
8ed8ff7afb Add ModelTask enum 2026-01-20 18:47:49 +00:00
ciaranbor
6498756540 ADd flux1-schnell model 2026-01-20 18:47:49 +00:00
ciaranbor
9c7ea6a7af Add task field to ModelCard 2026-01-20 18:47:49 +00:00
ciaranbor
a4a965c80c Update mflux version 2026-01-20 18:47:49 +00:00
ciaranbor
99c891ca6d Enable recursive repo downloads 2026-01-20 18:47:49 +00:00
ciaranbor
ff2b67966a Add dummy generate_image implementation 2026-01-20 18:47:49 +00:00
ciaranbor
680d1706a7 Use base64 encoded str for image data 2026-01-20 18:47:49 +00:00
ciaranbor
3a669eb5d9 Handle ImageGeneration tasks in _pending_tasks 2026-01-20 18:47:49 +00:00
ciaranbor
970f62e645 Add mflux dependency 2026-01-20 18:47:49 +00:00
ciaranbor
5c61396260 Handle ImageGeneration task in runner task processing 2026-01-20 18:47:49 +00:00
ciaranbor
db8c8706b4 Handle ImageGeneration command in master command processing 2026-01-20 18:47:49 +00:00
ciaranbor
bdd2cb97f4 Add image generation to API 2026-01-20 18:47:49 +00:00
ciaranbor
cdb9557cc7 Add ImageGenerationResponse 2026-01-20 18:47:49 +00:00
ciaranbor
a618adaf3d Add ImageGeneration task 2026-01-20 18:47:49 +00:00
ciaranbor
1ef9184468 Add ImageGeneration command 2026-01-20 18:47:49 +00:00
ciaranbor
14bb82dc5b Add image generation params and response types 2026-01-20 18:47:49 +00:00
ciaranbor
b824f7a60d Add pillow dependency 2026-01-20 18:47:49 +00:00
ciaranbor
fe9ff17e8b Fix mlx stream_generate import 2026-01-20 18:47:49 +00:00
272 changed files with 12418 additions and 941 deletions

View File

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
import os
if "TOKENIZERS_PARALLELISM" not in os.environ: ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,47 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import PIL.Image
import tqdm
from typing import Protocol
from mflux.models.common.config.config import Config
class BeforeLoopCallback(Protocol):
    """Structural interface for callbacks invoked once before the generation loop."""

    def call_before_loop(
        self,
        seed: int,
        prompt: str,
        latents: mx.array,
        config: Config,
        canny_image: PIL.Image.Image | None = ...,
        depth_image: PIL.Image.Image | None = ...,
    ) -> None: ...
class InLoopCallback(Protocol):
    """Structural interface for callbacks invoked on every loop iteration ``t``."""

    def call_in_loop(
        self,
        t: int,
        seed: int,
        prompt: str,
        latents: mx.array,
        config: Config,
        # Annotation improved: this module does `import tqdm`, so the bare name
        # `tqdm` denoted the module, not the progress-bar class.
        time_steps: tqdm.tqdm,
    ) -> None: ...
class AfterLoopCallback(Protocol):
    """Structural interface for callbacks invoked once after the generation loop."""

    def call_after_loop(
        self, seed: int, prompt: str, latents: mx.array, config: Config
    ) -> None: ...
class InterruptCallback(Protocol):
    """Structural interface for callbacks invoked when generation is interrupted."""

    def call_interrupt(
        self,
        t: int,
        seed: int,
        prompt: str,
        latents: mx.array,
        config: Config,
        # Annotation improved: `import tqdm` makes the bare name the module;
        # the progress bar instance is `tqdm.tqdm`.
        time_steps: tqdm.tqdm,
    ) -> None: ...

View File

@@ -0,0 +1,24 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.callbacks.callback import (
AfterLoopCallback,
BeforeLoopCallback,
InLoopCallback,
InterruptCallback,
)
from mflux.callbacks.generation_context import GenerationContext
from mflux.models.common.config.config import Config
if TYPE_CHECKING: ...
class CallbackRegistry:
    """Stub of the registry that stores and exposes generation callbacks.

    ``start`` produces a :class:`GenerationContext`; the ``*_callbacks``
    accessors return the registered callbacks grouped by protocol.
    """

    def __init__(self) -> None: ...
    # NOTE(review): `callback` is untyped in the stub — presumably any of the
    # four callback protocols imported above; confirm against the source.
    def register(self, callback) -> None: ...
    def start(self, seed: int, prompt: str, config: Config) -> GenerationContext: ...
    def before_loop_callbacks(self) -> list[BeforeLoopCallback]: ...
    def in_loop_callbacks(self) -> list[InLoopCallback]: ...
    def after_loop_callbacks(self) -> list[AfterLoopCallback]: ...
    def interrupt_callbacks(self) -> list[InterruptCallback]: ...

View File

@@ -0,0 +1,29 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import PIL.Image
import tqdm
from typing import TYPE_CHECKING
from mflux.callbacks.callback_registry import CallbackRegistry
from mflux.models.common.config.config import Config
if TYPE_CHECKING: ...
class GenerationContext:
    """Stub of the per-generation context that dispatches registry callbacks
    at the before/in/after/interrupt stages of the loop."""

    def __init__(
        self, registry: CallbackRegistry, seed: int, prompt: str, config: Config
    ) -> None: ...
    def before_loop(
        self,
        latents: mx.array,
        *,
        canny_image: PIL.Image.Image | None = ...,
        depth_image: PIL.Image.Image | None = ...,
    ) -> None: ...
    # Annotations improved below: this module does `import tqdm`, so the bare
    # name `tqdm` denoted the module rather than the progress-bar class.
    def in_loop(self, t: int, latents: mx.array, time_steps: tqdm.tqdm = ...) -> None: ...
    def after_loop(self, latents: mx.array) -> None: ...
    def interruption(
        self, t: int, latents: mx.array, time_steps: tqdm.tqdm = ...
    ) -> None: ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,22 @@
"""
This type stub file was generated by pyright.
"""
import os
# Module-level defaults/constants; all values are elided (`...`) by stub
# generation — consult the real module for the concrete values.
BATTERY_PERCENTAGE_STOP_LIMIT = ...
CONTROLNET_STRENGTH = ...
DEFAULT_DEV_FILL_GUIDANCE = ...
DEFAULT_DEPTH_GUIDANCE = ...
DIMENSION_STEP_PIXELS = ...
GUIDANCE_SCALE = ...
GUIDANCE_SCALE_KONTEXT = ...
IMAGE_STRENGTH = ...
MODEL_CHOICES = ...
MODEL_INFERENCE_STEPS = ...
QUANTIZE_CHOICES = ...
# Cache directory resolution branches on the MFLUX_CACHE_DIR environment
# variable; both concrete values are elided by the stub.
if os.environ.get("MFLUX_CACHE_DIR"):
    MFLUX_CACHE_DIR = ...
else:
    MFLUX_CACHE_DIR = ...
MFLUX_LORA_CACHE_DIR = ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,8 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.config.config import Config
from mflux.models.common.config.model_config import ModelConfig
__all__ = ["Config", "ModelConfig"]

View File

@@ -0,0 +1,66 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from pathlib import Path
from typing import Any
from tqdm import tqdm
from mflux.models.common.config.model_config import ModelConfig
# Module logger; value elided by stub generation.
logger = ...

class Config:
    """Stub of mflux's per-generation configuration (pyright-generated).

    Bundles the image-generation parameters (dimensions, step count,
    guidance, optional input/condition image paths, scheduler choice) and
    exposes them through read-only properties; only ``width`` has a setter.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        num_inference_steps: int = ...,
        height: int = ...,
        width: int = ...,
        guidance: float = ...,
        image_path: Path | str | None = ...,
        image_strength: float | None = ...,
        depth_image_path: Path | str | None = ...,
        redux_image_paths: list[Path | str] | None = ...,
        redux_image_strengths: list[float] | None = ...,
        masked_image_path: Path | str | None = ...,
        controlnet_strength: float | None = ...,
        scheduler: str = ...,
    ) -> None: ...
    @property
    def height(self) -> int: ...
    @property
    def width(self) -> int: ...
    @width.setter
    # Annotation improved from the generator's trailing `# -> None:` comment.
    def width(self, value) -> None:
        ...
    @property
    def image_seq_len(self) -> int: ...
    @property
    def guidance(self) -> float: ...
    @property
    def num_inference_steps(self) -> int: ...
    @property
    def precision(self) -> mx.Dtype: ...
    @property
    def num_train_steps(self) -> int: ...
    @property
    def image_path(self) -> Path | None: ...
    @property
    def image_strength(self) -> float | None: ...
    @property
    def depth_image_path(self) -> Path | None: ...
    @property
    def redux_image_paths(self) -> list[Path] | None: ...
    @property
    def redux_image_strengths(self) -> list[float] | None: ...
    @property
    def masked_image_path(self) -> Path | None: ...
    @property
    def init_time_step(self) -> int: ...
    @property
    def time_steps(self) -> tqdm: ...
    @property
    def controlnet_strength(self) -> float | None: ...
    # NOTE(review): return is `Any` in the stub even though __init__ takes
    # `scheduler: str` — presumably resolved to a scheduler object; confirm.
    @property
    def scheduler(self) -> Any: ...

View File

@@ -0,0 +1,86 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from functools import lru_cache
from typing import Literal
class ModelConfig:
    """Stub of the registry of supported model configurations.

    Each named static factory (``dev``, ``schnell``, ``qwen_image``, …)
    returns a memoized (``lru_cache``) singleton :class:`ModelConfig`;
    ``from_name`` resolves a configuration from a model-name string.
    """

    # Numeric precision used by the models; value elided by the stub.
    precision: mx.Dtype = ...
    def __init__(
        self,
        priority: int,
        aliases: list[str],
        model_name: str,
        base_model: str | None,
        controlnet_model: str | None,
        custom_transformer_model: str | None,
        num_train_steps: int | None,
        max_sequence_length: int | None,
        supports_guidance: bool | None,
        requires_sigma_shift: bool | None,
        transformer_overrides: dict | None = ...,
    ) -> None: ...
    @staticmethod
    @lru_cache
    def dev() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def schnell() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def dev_kontext() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def dev_fill() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def dev_redux() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def dev_depth() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def dev_controlnet_canny() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def schnell_controlnet_canny() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def dev_controlnet_upscaler() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def dev_fill_catvton() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def krea_dev() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def flux2_klein_4b() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def flux2_klein_9b() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def qwen_image() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def qwen_image_edit() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def fibo() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def z_image_turbo() -> ModelConfig: ...
    @staticmethod
    @lru_cache
    def seedvr2_3b() -> ModelConfig: ...
    def x_embedder_input_dim(self) -> int: ...
    def is_canny(self) -> bool: ...
    @staticmethod
    def from_name(
        model_name: str, base_model: Literal["dev", "schnell", "krea-dev"] | None = ...
    ) -> ModelConfig: ...

# Collection of the available model configs; value elided by the stub.
AVAILABLE_MODELS = ...

View File

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,49 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from pathlib import Path
from typing import TYPE_CHECKING, TypeAlias
from mlx import nn
from mflux.models.common.vae.tiling_config import TilingConfig
from mflux.models.fibo.latent_creator.fibo_latent_creator import FiboLatentCreator
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.z_image.latent_creator.z_image_latent_creator import (
ZImageLatentCreator,
)
# Checker-only alias: any of the per-model latent-creator classes.
# Defined under TYPE_CHECKING, so it exists only for static analysis.
if TYPE_CHECKING:
    LatentCreatorType: TypeAlias = type[
        FiboLatentCreator | FluxLatentCreator | QwenLatentCreator | ZImageLatentCreator
    ]
class Img2Img:
    """Stub of the parameter bundle for image-to-image latent initialization
    (VAE, model-specific latent creator, noise sigmas, start step, source image)."""

    def __init__(
        self,
        vae: nn.Module,
        latent_creator: LatentCreatorType,
        sigmas: mx.array,
        init_time_step: int,
        image_path: str | Path | None,
        tiling_config: TilingConfig | None = ...,
    ) -> None: ...
class LatentCreator:
    """Stub of the static helpers that build initial latents for generation."""

    @staticmethod
    def create_for_txt2img_or_img2img(
        seed: int, height: int, width: int, img2img: Img2Img
    ) -> mx.array: ...
    @staticmethod
    def encode_image(
        vae: nn.Module,
        image_path: str | Path,
        height: int,
        width: int,
        tiling_config: TilingConfig | None = ...,
    ) -> mx.array: ...
    # NOTE(review): name suggests linear interpolation between `clean` and
    # `noise` weighted by `sigma` — body elided; confirm against source.
    @staticmethod
    def add_noise_by_interpolation(
        clean: mx.array, noise: mx.array, sigma: float
    ) -> mx.array: ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,13 @@
"""
This type stub file was generated by pyright.
"""
from mlx import nn
from mflux.models.common.lora.layer.linear_lora_layer import LoRALinear
class FusedLoRALinear(nn.Module):
    """Stub of a linear layer that fuses a base (possibly quantized) linear
    with a list of LoRA adapters."""

    def __init__(
        self, base_linear: nn.Linear | nn.QuantizedLinear, loras: list[LoRALinear]
    ) -> None: ...
    # Generator comment indicates an mlx array return.
    def __call__(self, x): # -> array:
        ...

View File

@@ -0,0 +1,22 @@
"""
This type stub file was generated by pyright.
"""
from mlx import nn
class LoRALinear(nn.Module):
    """Stub of a low-rank-adapted linear layer (rank ``r``, scaling ``scale``)."""

    # Alternate constructor wrapping an existing (possibly quantized) linear.
    @staticmethod
    def from_linear(
        linear: nn.Linear | nn.QuantizedLinear, r: int = ..., scale: float = ...
    ): # -> LoRALinear:
        ...
    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        r: int = ...,
        scale: float = ...,
        bias: bool = ...,
    ) -> None: ...
    # Generator comment indicates an mlx array return.
    def __call__(self, x): # -> array:
        ...

View File

@@ -0,0 +1,26 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
from collections.abc import Callable
from dataclasses import dataclass
from mflux.models.common.lora.mapping.lora_mapping import LoRATarget
@dataclass
class PatternMatch:
    """Stub of a resolved mapping from a LoRA weight-name pattern to a
    target parameter path, with an optional tensor transform."""

    source_pattern: str
    target_path: str
    matrix_name: str
    transpose: bool
    transform: Callable[[mx.array], mx.array] | None = ...
class LoRALoader:
    """Stub of the loader that applies LoRA weights onto a transformer."""

    # NOTE(review): the tuple return presumably echoes the effective
    # (paths, scales) actually applied — body elided; confirm in source.
    @staticmethod
    def load_and_apply_lora(
        lora_mapping: list[LoRATarget],
        transformer: nn.Module,
        lora_paths: list[str] | None = ...,
        lora_scales: list[float] | None = ...,
    ) -> tuple[list[str], list[float]]: ...

View File

@@ -0,0 +1,21 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from collections.abc import Callable
from dataclasses import dataclass
from typing import List, Protocol
@dataclass
class LoRATarget:
    """Stub of one LoRA mapping target: a model parameter path plus the
    candidate up/down/alpha weight-name patterns and optional transforms."""

    model_path: str
    possible_up_patterns: List[str]
    possible_down_patterns: List[str]
    possible_alpha_patterns: List[str] = ...
    up_transform: Callable[[mx.array], mx.array] | None = ...
    down_transform: Callable[[mx.array], mx.array] | None = ...
class LoRAMapping(Protocol):
    """Structural interface: a per-model provider of LoRA mapping targets."""

    @staticmethod
    def get_mapping() -> List[LoRATarget]: ...

View File

@@ -0,0 +1,9 @@
"""
This type stub file was generated by pyright.
"""
import mlx.nn as nn
class LoRASaver:
    """Stub of the helper that bakes LoRA weights into a module and strips
    the adapter layers."""

    @staticmethod
    def bake_and_strip_lora(module: nn.Module) -> nn.Module: ...

View File

@@ -0,0 +1,35 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class LoraTransforms:
@staticmethod
def split_q_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_k_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_v_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_q_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_k_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_v_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_q_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_k_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_v_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_mlp_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_q_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_k_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_v_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_mlp_down(tensor: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.resolution.config_resolution import ConfigResolution
from mflux.models.common.resolution.lora_resolution import LoraResolution
from mflux.models.common.resolution.path_resolution import PathResolution
from mflux.models.common.resolution.quantization_resolution import (
QuantizationResolution,
)
__all__ = [
"ConfigResolution",
"LoraResolution",
"PathResolution",
"QuantizationResolution",
]

View File

@@ -0,0 +1,39 @@
"""
This type stub file was generated by pyright.
"""
from enum import Enum
from typing import NamedTuple
class QuantizationAction(Enum):
NONE = ...
STORED = ...
REQUESTED = ...
class PathAction(Enum):
LOCAL = ...
HUGGINGFACE_CACHED = ...
HUGGINGFACE = ...
ERROR = ...
class LoraAction(Enum):
LOCAL = ...
REGISTRY = ...
HUGGINGFACE_COLLECTION_CACHED = ...
HUGGINGFACE_COLLECTION = ...
HUGGINGFACE_REPO_CACHED = ...
HUGGINGFACE_REPO = ...
ERROR = ...
class ConfigAction(Enum):
EXACT_MATCH = ...
EXPLICIT_BASE = ...
INFER_SUBSTRING = ...
ERROR = ...
class Rule(NamedTuple):
priority: int
name: str
check: str
action: QuantizationAction | PathAction | LoraAction | ConfigAction
...

View File

@@ -0,0 +1,14 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.models.common.config.model_config import ModelConfig
if TYPE_CHECKING: ...
logger = ...
class ConfigResolution:
RULES = ...
@staticmethod
def resolve(model_name: str, base_model: str | None = ...) -> ModelConfig: ...

View File

@@ -0,0 +1,21 @@
"""
This type stub file was generated by pyright.
"""
from pathlib import Path
logger = ...
class LoraResolution:
RULES = ...
_registry: dict[str, Path] = ...
@staticmethod
def resolve(path: str) -> str: ...
@staticmethod
def resolve_paths(paths: list[str] | None) -> list[str]: ...
@staticmethod
def resolve_scales(scales: list[float] | None, num_paths: int) -> list[float]: ...
@staticmethod
def get_registry() -> dict[str, Path]: ...
@staticmethod
def discover_files(library_paths: list[Path]) -> dict[str, Path]: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
from pathlib import Path
logger = ...
class PathResolution:
RULES = ...
@staticmethod
def resolve(path: str | None, patterns: list[str] | None = ...) -> Path | None: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
logger = ...
class QuantizationResolution:
RULES = ...
@staticmethod
def resolve(
stored: int | None, requested: int | None
) -> tuple[int | None, str | None]: ...

View File

@@ -0,0 +1,26 @@
"""
This type stub file was generated by pyright.
"""
from .flow_match_euler_discrete_scheduler import FlowMatchEulerDiscreteScheduler
from .linear_scheduler import LinearScheduler
from .seedvr2_euler_scheduler import SeedVR2EulerScheduler
__all__ = [
"LinearScheduler",
"FlowMatchEulerDiscreteScheduler",
"SeedVR2EulerScheduler",
]
class SchedulerModuleNotFound(ValueError): ...
class SchedulerClassNotFound(ValueError): ...
class InvalidSchedulerType(TypeError): ...
SCHEDULER_REGISTRY = ...
def register_contrib(scheduler_object, scheduler_name=...): # -> None:
...
def try_import_external_scheduler(
scheduler_object_path: str,
): # -> type[BaseScheduler]:
...

View File

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from abc import ABC, abstractmethod
class BaseScheduler(ABC):
@property
@abstractmethod
def sigmas(self) -> mx.array: ...
@abstractmethod
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...
def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...

View File

@@ -0,0 +1,26 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import TYPE_CHECKING
from mflux.models.common.config.config import Config
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
if TYPE_CHECKING: ...
class FlowMatchEulerDiscreteScheduler(BaseScheduler):
def __init__(self, config: Config) -> None: ...
@property
def sigmas(self) -> mx.array: ...
@property
def timesteps(self) -> mx.array: ...
def set_image_seq_len(self, image_seq_len: int) -> None: ...
@staticmethod
def get_timesteps_and_sigmas(
image_seq_len: int, num_inference_steps: int, num_train_timesteps: int = ...
) -> tuple[mx.array, mx.array]: ...
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...
def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...

View File

@@ -0,0 +1,20 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import TYPE_CHECKING
from mflux.models.common.config.config import Config
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
if TYPE_CHECKING: ...
class LinearScheduler(BaseScheduler):
def __init__(self, config: Config) -> None: ...
@property
def sigmas(self) -> mx.array: ...
@property
def timesteps(self) -> mx.array: ...
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...

View File

@@ -0,0 +1,20 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import TYPE_CHECKING
from mflux.models.common.config.config import Config
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
if TYPE_CHECKING: ...
class SeedVR2EulerScheduler(BaseScheduler):
def __init__(self, config: Config) -> None: ...
@property
def timesteps(self) -> mx.array: ...
@property
def sigmas(self) -> mx.array: ...
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...

View File

@@ -0,0 +1,24 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.tokenizer.tokenizer import (
BaseTokenizer,
LanguageTokenizer,
Tokenizer,
VisionLanguageTokenizer,
)
from mflux.models.common.tokenizer.tokenizer_loader import TokenizerLoader
from mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput
"""
This type stub file was generated by pyright.
"""
__all__ = [
"Tokenizer",
"BaseTokenizer",
"LanguageTokenizer",
"VisionLanguageTokenizer",
"TokenizerLoader",
"TokenizerOutput",
]

View File

@@ -0,0 +1,74 @@
"""
This type stub file was generated by pyright.
"""
from abc import ABC, abstractmethod
from typing import Protocol, runtime_checkable
from PIL import Image
from transformers import PreTrainedTokenizer
from mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput
"""
This type stub file was generated by pyright.
"""
@runtime_checkable
class Tokenizer(Protocol):
tokenizer: PreTrainedTokenizer
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...
class BaseTokenizer(ABC):
def __init__(
self, tokenizer: PreTrainedTokenizer, max_length: int = ...
) -> None: ...
@abstractmethod
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...
class LanguageTokenizer(BaseTokenizer):
def __init__(
self,
tokenizer: PreTrainedTokenizer,
max_length: int = ...,
padding: str = ...,
return_attention_mask: bool = ...,
template: str | None = ...,
use_chat_template: bool = ...,
chat_template_kwargs: dict | None = ...,
add_special_tokens: bool = ...,
) -> None: ...
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...
class VisionLanguageTokenizer(BaseTokenizer):
def __init__(
self,
tokenizer: PreTrainedTokenizer,
processor,
max_length: int = ...,
template: str | None = ...,
image_token: str = ...,
) -> None: ...
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...

View File

@@ -0,0 +1,22 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.models.common.tokenizer.tokenizer import BaseTokenizer
from mflux.models.common.weights.loading.weight_definition import TokenizerDefinition
"""
This type stub file was generated by pyright.
"""
if TYPE_CHECKING: ...
class TokenizerLoader:
@staticmethod
def load(definition: TokenizerDefinition, model_path: str) -> BaseTokenizer: ...
@staticmethod
def load_all(
definitions: list[TokenizerDefinition],
model_path: str,
max_length_overrides: dict[str, int] | None = ...,
) -> dict[str, BaseTokenizer]: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
"""
This type stub file was generated by pyright.
"""
@dataclass
class TokenizerOutput:
input_ids: mx.array
attention_mask: mx.array
pixel_values: mx.array | None = ...
image_grid_thw: mx.array | None = ...

View File

@@ -0,0 +1,8 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.vae.tiling_config import TilingConfig
from mflux.models.common.vae.vae_tiler import VAETiler
__all__ = ["TilingConfig", "VAETiler"]

View File

@@ -0,0 +1,13 @@
"""
This type stub file was generated by pyright.
"""
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class TilingConfig:
vae_decode_tiles_per_dim: int | None = ...
vae_decode_overlap: int = ...
vae_encode_tiled: bool = ...
vae_encode_tile_size: int = ...
vae_encode_tile_overlap: int = ...

View File

@@ -0,0 +1,27 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import Callable
class VAETiler:
@staticmethod
def encode_image_tiled(
*,
image: mx.array,
encode_fn: Callable[[mx.array], mx.array],
latent_channels: int,
tile_size: tuple[int, int] = ...,
tile_overlap: tuple[int, int] = ...,
spatial_scale: int = ...,
) -> mx.array: ...
@staticmethod
def decode_image_tiled(
*,
latent: mx.array,
decode_fn: Callable[[mx.array], mx.array],
tile_size: tuple[int, int] = ...,
tile_overlap: tuple[int, int] = ...,
spatial_scale: int = ...,
) -> mx.array: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
from mflux.models.common.vae.tiling_config import TilingConfig
class VAEUtil:
@staticmethod
def encode(
vae: nn.Module, image: mx.array, tiling_config: TilingConfig | None = ...
) -> mx.array: ...
@staticmethod
def decode(
vae: nn.Module, latent: mx.array, tiling_config: TilingConfig | None = ...
) -> mx.array: ...

View File

@@ -0,0 +1,18 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights, MetaData
from mflux.models.common.weights.loading.weight_applier import WeightApplier
from mflux.models.common.weights.loading.weight_definition import ComponentDefinition
from mflux.models.common.weights.loading.weight_loader import WeightLoader
from mflux.models.common.weights.saving.model_saver import ModelSaver
__all__ = [
"ComponentDefinition",
"LoadedWeights",
"MetaData",
"ModelSaver",
"WeightApplier",
"WeightLoader",
]

View File

@@ -0,0 +1,18 @@
"""
This type stub file was generated by pyright.
"""
from dataclasses import dataclass
@dataclass
class MetaData:
quantization_level: int | None = ...
mflux_version: str | None = ...
@dataclass
class LoadedWeights:
components: dict[str, dict]
meta_data: MetaData
def __getattr__(self, name: str) -> dict | None: ...
def num_transformer_blocks(self, component_name: str = ...) -> int: ...
def num_single_transformer_blocks(self, component_name: str = ...) -> int: ...

View File

@@ -0,0 +1,30 @@
"""
This type stub file was generated by pyright.
"""
import mlx.nn as nn
from typing import TYPE_CHECKING
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
WeightDefinitionType,
)
if TYPE_CHECKING: ...
class WeightApplier:
@staticmethod
def apply_and_quantize_single(
weights: LoadedWeights,
model: nn.Module,
component: ComponentDefinition,
quantize_arg: int | None,
quantization_predicate=...,
) -> int | None: ...
@staticmethod
def apply_and_quantize(
weights: LoadedWeights,
models: dict[str, nn.Module],
quantize_arg: int | None,
weight_definition: WeightDefinitionType,
) -> int | None: ...

View File

@@ -0,0 +1,73 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
from typing import Callable, List, TYPE_CHECKING, TypeAlias
from mflux.models.common.weights.mapping.weight_mapping import WeightTarget
from mflux.models.common.tokenizer.tokenizer import BaseTokenizer
from mflux.models.depth_pro.weights.depth_pro_weight_definition import (
DepthProWeightDefinition,
)
from mflux.models.fibo.weights.fibo_weight_definition import FIBOWeightDefinition
from mflux.models.fibo_vlm.weights.fibo_vlm_weight_definition import (
FIBOVLMWeightDefinition,
)
from mflux.models.flux.weights.flux_weight_definition import FluxWeightDefinition
from mflux.models.qwen.weights.qwen_weight_definition import QwenWeightDefinition
from mflux.models.seedvr2.weights.seedvr2_weight_definition import (
SeedVR2WeightDefinition,
)
from mflux.models.z_image.weights.z_image_weight_definition import (
ZImageWeightDefinition,
)
"""
This type stub file was generated by pyright.
"""
if TYPE_CHECKING:
WeightDefinitionType: TypeAlias = type[
FluxWeightDefinition
| FIBOWeightDefinition
| FIBOVLMWeightDefinition
| QwenWeightDefinition
| ZImageWeightDefinition
| SeedVR2WeightDefinition
| DepthProWeightDefinition
]
@dataclass
class ComponentDefinition:
name: str
hf_subdir: str
mapping_getter: Callable[[], List[WeightTarget]] | None = ...
model_attr: str | None = ...
num_blocks: int | None = ...
num_layers: int | None = ...
loading_mode: str = ...
precision: mx.Dtype | None = ...
skip_quantization: bool = ...
bulk_transform: Callable[[mx.array], mx.array] | None = ...
weight_subkey: str | None = ...
download_url: str | None = ...
weight_prefix_filters: List[str] | None = ...
weight_files: List[str] | None = ...
@dataclass
class TokenizerDefinition:
name: str
hf_subdir: str
tokenizer_class: str = ...
fallback_subdirs: List[str] | None = ...
download_patterns: List[str] | None = ...
encoder_class: type[BaseTokenizer] | None = ...
max_length: int = ...
padding: str = ...
template: str | None = ...
use_chat_template: bool = ...
chat_template_kwargs: dict | None = ...
add_special_tokens: bool = ...
processor_class: type | None = ...
image_token: str = ...
chat_template: str | None = ...

View File

@@ -0,0 +1,23 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
WeightDefinitionType,
)
if TYPE_CHECKING: ...
logger = ...
class WeightLoader:
@staticmethod
def load_single(
component: ComponentDefinition, repo_id: str, file_pattern: str = ...
) -> LoadedWeights: ...
@staticmethod
def load(
weight_definition: WeightDefinitionType, model_path: str | None = ...
) -> LoadedWeights: ...

View File

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import Dict, List, Optional
from mflux.models.common.weights.mapping.weight_mapping import WeightTarget
class WeightMapper:
@staticmethod
def apply_mapping(
hf_weights: Dict[str, mx.array],
mapping: List[WeightTarget],
num_blocks: Optional[int] = ...,
num_layers: Optional[int] = ...,
) -> Dict: ...

View File

@@ -0,0 +1,23 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
from typing import Callable, List, Optional, Protocol
"""
This type stub file was generated by pyright.
"""
@dataclass
class WeightTarget:
to_pattern: str
from_pattern: List[str]
transform: Optional[Callable[[mx.array], mx.array]] = ...
required: bool = ...
max_blocks: Optional[int] = ...
class WeightMapping(Protocol):
@staticmethod
def get_mapping() -> List[WeightTarget]: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class WeightTransforms:
@staticmethod
def reshape_gamma_to_1d(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_patch_embed(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_conv3d_weight(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_conv2d_weight(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_conv_transpose2d_weight(tensor: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,14 @@
"""
This type stub file was generated by pyright.
"""
from typing import Any, TYPE_CHECKING
from mflux.models.common.weights.loading.weight_definition import WeightDefinitionType
if TYPE_CHECKING: ...
class ModelSaver:
@staticmethod
def save_model(
model: Any, bits: int, base_path: str, weight_definition: WeightDefinitionType
) -> None: ...

View File

@@ -0,0 +1,9 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.depth_pro.model.depth_pro_model import DepthProModel
class DepthProInitializer:
@staticmethod
def init(model: DepthProModel, quantize: int | None = ...) -> None: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class FeatureFusionBlock2d(nn.Module):
def __init__(self, num_features: int, deconv: bool = ...) -> None: ...
def __call__(self, x0: mx.array, x1: mx.array | None = ...) -> mx.array: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class MultiresConvDecoder(nn.Module):
def __init__(self) -> None: ...
def __call__(
self,
x0_latent: mx.array,
x1_latent: mx.array,
x0_features: mx.array,
x1_features: mx.array,
x_global_features: mx.array,
) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class ResidualBlock(nn.Module):
def __init__(self, num_features: int) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,20 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
from pathlib import Path
from PIL import Image
@dataclass
class DepthResult:
depth_image: Image.Image
depth_array: mx.array
min_depth: float
max_depth: float
...
class DepthPro:
def __init__(self, quantize: int | None = ...) -> None: ...
def create_depth_map(self, image_path: str | Path) -> DepthResult: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DepthProModel(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, x0: mx.array, x1: mx.array, x2: mx.array
) -> tuple[mx.array, mx.array]: ...

View File

@@ -0,0 +1,15 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DepthProUtil:
@staticmethod
def split(x: mx.array, overlap_ratio: float = ...) -> mx.array: ...
@staticmethod
def interpolate(x: mx.array, size=..., scale_factor=...): # -> array:
...
@staticmethod
def apply_conv(x: mx.array, conv_module: nn.Module) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class Attention(nn.Module):
def __init__(
self, dim: int = ..., head_dim: int = ..., num_heads: int = ...
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DinoVisionTransformer(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array, mx.array]: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class LayerScale(nn.Module):
def __init__(self, dims: int, init_values: float = ...) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class MLP(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class PatchEmbed(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class TransformerBlock(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DepthProEncoder(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, x0: mx.array, x1: mx.array, x2: mx.array
) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...

View File

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class UpSampleBlock(nn.Module):
def __init__(
self,
dim_in: int = ...,
dim_int: int = ...,
dim_out: int = ...,
upsample_layers: int = ...,
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class FOVHead(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,23 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
TokenizerDefinition,
)
"""
This type stub file was generated by pyright.
"""
class DepthProWeightDefinition:
@staticmethod
def get_components() -> List[ComponentDefinition]: ...
@staticmethod
def get_tokenizers() -> List[TokenizerDefinition]: ...
@staticmethod
def get_download_patterns() -> List[str]: ...
@staticmethod
def quantization_predicate(path: str, module) -> bool: ...

View File

@@ -0,0 +1,13 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
WeightMapping,
WeightTarget,
)
class DepthProWeightMapping(WeightMapping):
@staticmethod
def get_mapping() -> List[WeightTarget]: ...

View File

@@ -0,0 +1,13 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class FiboLatentCreator:
@staticmethod
def create_noise(seed: int, height: int, width: int) -> mx.array: ...
@staticmethod
def pack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...
@staticmethod
def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...

View File

@@ -0,0 +1,23 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
TokenizerDefinition,
)
"""
This type stub file was generated by pyright.
"""
class FIBOWeightDefinition:
@staticmethod
def get_components() -> List[ComponentDefinition]: ...
@staticmethod
def get_tokenizers() -> List[TokenizerDefinition]: ...
@staticmethod
def get_download_patterns() -> List[str]: ...
@staticmethod
def quantization_predicate(path: str, module) -> bool: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
WeightMapping,
WeightTarget,
)
class FIBOWeightMapping(WeightMapping):
@staticmethod
def get_transformer_mapping() -> List[WeightTarget]: ...
@staticmethod
def get_text_encoder_mapping() -> List[WeightTarget]: ...
@staticmethod
def get_vae_mapping() -> List[WeightTarget]: ...

View File

@@ -0,0 +1,8 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.qwen.tokenizer.qwen_image_processor import QwenImageProcessor
class Qwen2VLImageProcessor(QwenImageProcessor):
def __init__(self) -> None: ...

View File

@@ -0,0 +1,28 @@
"""
This type stub file was generated by pyright.
"""
from typing import Optional, Union
from PIL import Image
class Qwen2VLProcessor:
def __init__(self, tokenizer) -> None: ...
def apply_chat_template(
self,
messages,
tokenize: bool = ...,
add_generation_prompt: bool = ...,
return_tensors: Optional[str] = ...,
return_dict: bool = ...,
**kwargs,
): # -> dict[Any, Any]:
...
def __call__(
self,
text: Optional[Union[str, list[str]]] = ...,
images: Optional[Union[Image.Image, list[Image.Image]]] = ...,
padding: bool = ...,
return_tensors: Optional[str] = ...,
**kwargs,
): # -> dict[Any, Any]:
...

View File

@@ -0,0 +1,24 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
TokenizerDefinition,
)
"""
This type stub file was generated by pyright.
"""
QWEN2VL_CHAT_TEMPLATE = ...
class FIBOVLMWeightDefinition:
@staticmethod
def get_components() -> List[ComponentDefinition]: ...
@staticmethod
def get_tokenizers() -> List[TokenizerDefinition]: ...
@staticmethod
def get_download_patterns() -> List[str]: ...
@staticmethod
def quantization_predicate(path: str, module) -> bool: ...

View File

@@ -0,0 +1,15 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
WeightMapping,
WeightTarget,
)
class FIBOVLMWeightMapping(WeightMapping):
@staticmethod
def get_vlm_decoder_mapping(num_layers: int = ...) -> List[WeightTarget]: ...
@staticmethod
def get_vlm_visual_mapping(depth: int = ...) -> List[WeightTarget]: ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,53 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.config import ModelConfig
class FluxInitializer:
@staticmethod
def init(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
custom_transformer=...,
) -> None: ...
@staticmethod
def init_depth(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...
@staticmethod
def init_redux(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...
@staticmethod
def init_controlnet(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...
@staticmethod
def init_concept(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...

View File

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,19 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
"""
This type stub file was generated by pyright.
"""
class FluxLatentCreator:
@staticmethod
def create_noise(seed: int, height: int, width: int) -> mx.array: ...
@staticmethod
def pack_latents(
latents: mx.array, height: int, width: int, num_channels_latents: int = ...
) -> mx.array: ...
@staticmethod
def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...

View File

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPEmbeddings(nn.Module):
def __init__(self, dims: int) -> None: ...
def __call__(self, tokens: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,14 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
"""
This type stub file was generated by pyright.
"""
class CLIPEncoder(nn.Module):
def __init__(self) -> None: ...
def __call__(self, tokens: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPEncoderLayer(nn.Module):
def __init__(self, layer: int) -> None: ...
def __call__(
self, hidden_states: mx.array, causal_attention_mask: mx.array
) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPMLP(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...
@staticmethod
def quick_gelu(input_array: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,18 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPSdpaAttention(nn.Module):
head_dimension = ...
batch_size = ...
num_heads = ...
def __init__(self) -> None: ...
def __call__(
self, hidden_states: mx.array, causal_attention_mask: mx.array
) -> mx.array: ...
@staticmethod
def reshape_and_transpose(x, batch_size, num_heads, head_dim): # -> array:
...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPTextModel(nn.Module):
def __init__(self, dims: int, num_encoder_layers: int) -> None: ...
def __call__(self, tokens: mx.array) -> tuple[mx.array, mx.array]: ...
@staticmethod
def create_causal_attention_mask(input_shape: tuple) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class EncoderCLIP(nn.Module):
def __init__(self, num_encoder_layers: int) -> None: ...
def __call__(
self, tokens: mx.array, causal_attention_mask: mx.array
) -> mx.array: ...

View File

@@ -0,0 +1,25 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mflux.models.common.tokenizer import Tokenizer
from mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (
CLIPEncoder,
)
from mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder
"""
This type stub file was generated by pyright.
"""
class PromptEncoder:
@staticmethod
def encode_prompt(
prompt: str,
prompt_cache: dict[str, tuple[mx.array, mx.array]],
t5_tokenizer: Tokenizer,
clip_tokenizer: Tokenizer,
t5_text_encoder: T5Encoder,
clip_text_encoder: CLIPEncoder,
) -> tuple[mx.array, mx.array]: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5Attention(nn.Module):
    """Pyright-generated stub: interface of the T5 attention sublayer."""

    def __init__(self) -> None: ...
    def __call__(self, hidden_states: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5Block(nn.Module):
    """Pyright-generated stub: interface of one T5 encoder block."""

    # `layer` is presumably the block's index in the encoder stack — confirm in implementation.
    def __init__(self, layer: int) -> None: ...
    def __call__(self, hidden_states: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5DenseReluDense(nn.Module):
    """Pyright-generated stub: interface of the T5 dense-activation-dense projection."""

    def __init__(self) -> None: ...
    def __call__(self, hidden_states: mx.array) -> mx.array: ...
    # Stateless activation helper, hence @staticmethod; elementwise on `input_array`.
    @staticmethod
    def new_gelu(input_array: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,14 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
"""
This type stub file was generated by pyright.
"""
class T5Encoder(nn.Module):
def __init__(self) -> None: ...
def __call__(self, tokens: mx.array): ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5FeedForward(nn.Module):
    """Pyright-generated stub: interface of the T5 feed-forward sublayer."""

    def __init__(self) -> None: ...
    def __call__(self, hidden_states: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5LayerNorm(nn.Module):
    """Pyright-generated stub: interface of the T5 layer-normalization module."""

    def __init__(self) -> None: ...
    def __call__(self, hidden_states: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5SelfAttention(nn.Module):
    """Pyright-generated stub: interface of the T5 self-attention module."""

    def __init__(self) -> None: ...
    def __call__(self, hidden_states: mx.array) -> mx.array: ...
    # Pyright could not infer parameter types for the two reshape helpers below;
    # its inferred return types are preserved in the trailing `# -> array:` comments.
    @staticmethod
    def shape(states): # -> array:
        ...
    @staticmethod
    def un_shape(states): # -> array:
        ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AdaLayerNormContinuous(nn.Module):
    """Pyright-generated stub: interface of continuous adaptive layer norm."""

    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int) -> None: ...
    # Normalizes `x` conditioned on `text_embeddings`.
    def __call__(self, x: mx.array, text_embeddings: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AdaLayerNormZero(nn.Module):
    """Pyright-generated stub: interface of AdaLN-Zero conditioning."""

    def __init__(self) -> None: ...
    # Returns five arrays — presumably normalized states plus shift/scale/gate
    # modulation terms; confirm against the implementation.
    def __call__(
        self, hidden_states: mx.array, text_embeddings: mx.array
    ) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AdaLayerNormZeroSingle(nn.Module):
    """Pyright-generated stub: interface of single-stream AdaLN-Zero conditioning."""

    def __init__(self) -> None: ...
    # Returns two arrays — presumably (normalized states, gate); confirm in implementation.
    def __call__(
        self, hidden_states: mx.array, text_embeddings: mx.array
    ) -> tuple[mx.array, mx.array]: ...

View File

@@ -0,0 +1,41 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AttentionUtils:
    """Pyright-generated stub: static attention helpers (QKV projection, SDPA, RoPE)."""

    # Projects hidden states through the q/k/v linear layers and applies RMS
    # normalization to queries and keys; returns the (q, k, v) triple.
    @staticmethod
    def process_qkv(
        hidden_states: mx.array,
        to_q: nn.Linear,
        to_k: nn.Linear,
        to_v: nn.Linear,
        norm_q: nn.RMSNorm,
        norm_k: nn.RMSNorm,
        num_heads: int,
        head_dim: int,
    ) -> tuple[mx.array, mx.array, mx.array]: ...
    # `mask` default is elided (`...`) by the stub generator; an absent mask
    # presumably means unmasked attention — confirm in implementation.
    @staticmethod
    def compute_attention(
        query: mx.array,
        key: mx.array,
        value: mx.array,
        batch_size: int,
        num_heads: int,
        head_dim: int,
        mask: mx.array | None = ...,
    ) -> mx.array: ...
    # Converts a key-padding mask into an additive attention mask spanning the
    # joint (text + image) sequence; passes None through unchanged per the annotation.
    @staticmethod
    def convert_key_padding_mask_to_additive_mask(
        mask: mx.array | None, joint_seq_len: int, txt_seq_len: int
    ) -> mx.array | None: ...
    # Rotary position embedding applied to queries and keys; `freqs_cis` carries
    # the precomputed rotation frequencies.
    @staticmethod
    def apply_rope(
        xq: mx.array, xk: mx.array, freqs_cis: mx.array
    ) -> tuple[mx.array, mx.array]: ...
    # RoPE variant taking separate cos/sin tables; "bshd" presumably names the
    # (batch, seq, heads, dim) layout — confirm in implementation.
    @staticmethod
    def apply_rope_bshd(
        xq: mx.array, xk: mx.array, cos: mx.array, sin: mx.array
    ) -> tuple[mx.array, mx.array]: ...

Some files were not shown because too many files have changed in this diff Show More