Compare commits

..

262 Commits

Author SHA1 Message Date
ciaranbor
7a6a847577 Review feedback 2026-01-21 17:10:14 +00:00
ciaranbor
71f844c565 Add endpoint to list images 2026-01-21 15:05:55 +00:00
ciaranbor
be4e159499 Support image url response 2026-01-21 15:05:55 +00:00
ciaranbor
5a1b5f82d5 Use public snapshot_download 2026-01-21 15:05:55 +00:00
ciaranbor
cd2bdb9f42 Support jpeg outputs 2026-01-21 15:05:55 +00:00
ciaranbor
45a45c4d3f Error on malformed size parameter 2026-01-21 15:05:55 +00:00
ciaranbor
a79b49fc68 Use ModelId 2026-01-21 15:05:55 +00:00
ciaranbor
3119fe27c7 Resolve mflux type errors 2026-01-21 15:05:55 +00:00
ciaranbor
406cd2ac20 Generate mflux type stubs 2026-01-21 15:05:55 +00:00
ciaranbor
833b765af4 Format FE 2026-01-21 15:04:40 +00:00
ciaranbor
09f7a34086 Prevent running image editing model without an input image 2026-01-21 15:04:40 +00:00
ciaranbor
137700d142 Type coercion for ModelTask 2026-01-21 15:04:40 +00:00
ciaranbor
bf7b1e9e9f Fallback for resolving model card 2026-01-21 15:04:40 +00:00
ciaranbor
2e5ae6cd5e Reflect ModelCard simplification 2026-01-21 15:04:40 +00:00
ciaranbor
328063fd9f Fix text model runtime check 2026-01-21 15:04:40 +00:00
ciaranbor
1cbc40e65d Fix image streaming for editing 2026-01-21 15:04:40 +00:00
ciaranbor
cd4c49ff18 Propagate additional image edit api params 2026-01-21 15:04:40 +00:00
ciaranbor
def96ba0b6 Support image editing in UI 2026-01-21 15:04:40 +00:00
ciaranbor
dfedf2b8bb Allow dropdowns to fit values 2026-01-21 15:04:40 +00:00
ciaranbor
625ef2c35c Better param dropdowns 2026-01-21 15:04:40 +00:00
ciaranbor
2961713943 Better typing 2026-01-21 15:04:40 +00:00
ciaranbor
6e00aedcf6 Restore recv evals 2026-01-21 15:04:40 +00:00
ciaranbor
8ae9fed692 Remove outputFormat param 2026-01-21 15:04:40 +00:00
ciaranbor
e5e15c22a5 Expose image generation settings to UI 2026-01-21 15:04:40 +00:00
ciaranbor
e7a92aea4d DiffusionRunner type errors 2026-01-21 15:04:40 +00:00
ciaranbor
b96419d48d Correctly handle num_sync_steps for diffusion steps override 2026-01-21 15:04:40 +00:00
ciaranbor
3342565474 Simplify DistributedImageModel 2026-01-21 15:04:40 +00:00
ciaranbor
af2fe2c662 Hide internal mflux image type 2026-01-21 15:04:40 +00:00
ciaranbor
8fa15f4fc9 Only run two steps during warmup 2026-01-21 15:04:39 +00:00
ciaranbor
8e848ee71c Document image apis 2026-01-21 15:04:39 +00:00
ciaranbor
69fd194715 Clean up ImageModelConfig 2026-01-21 15:04:39 +00:00
ciaranbor
fa70ade332 Add AdvancedImageParams to api 2026-01-21 15:04:39 +00:00
ciaranbor
153d9ee017 Capture partial images earlier 2026-01-21 15:04:39 +00:00
ciaranbor
c6cc4ed79b Remove redundant tensor allocations for recv templates 2026-01-21 15:04:39 +00:00
ciaranbor
09605a71ba Remove recv evals, use async_eval for sends 2026-01-21 15:04:39 +00:00
ciaranbor
34003549a9 Remove ImageGenerator protocol 2026-01-21 15:04:39 +00:00
ciaranbor
30bab3d51c Batch CFG 2026-01-21 15:04:39 +00:00
ciaranbor
35ad63ab4c Add image generation benchmarking endpoints 2026-01-21 15:04:39 +00:00
ciaranbor
7e1dc3740f Consolidate patched and unpatched qkv computation logic 2026-01-21 15:04:39 +00:00
ciaranbor
cc368c005e Use mixin for common block wrapper functionality 2026-01-21 15:04:39 +00:00
ciaranbor
9e398dbca5 Move last rank to first rank comms outside the CFG step 2026-01-21 15:04:39 +00:00
ciaranbor
aca99f9086 Revert 2026-01-21 15:04:39 +00:00
ciaranbor
06d7bdea74 Run all positive then all negative 2026-01-21 15:04:39 +00:00
ciaranbor
766c690f78 Rank 0 shouldn't receive on negative pass 2026-01-21 15:04:39 +00:00
ciaranbor
26eda8a537 Fix negative pass text_seq_len 2026-01-21 15:04:39 +00:00
ciaranbor
51fad29944 Add distributed CFG support 2026-01-21 15:04:39 +00:00
ciaranbor
9f939e606c Enable CFG for Qwen-Image 2026-01-21 15:04:39 +00:00
ciaranbor
fd0b6f7b47 Use transformer block wrapper classes 2026-01-21 15:04:39 +00:00
ciaranbor
af1d8b2b8a Refactor 2026-01-21 15:04:39 +00:00
ciaranbor
e3b5d5fd89 Fix flux tokenizer 2026-01-21 15:04:39 +00:00
ciaranbor
ca26d865cd Reduce image generation and image edits code duplication 2026-01-21 15:04:39 +00:00
ciaranbor
87e0cff54f Update mflux to 0.14.2 2026-01-21 15:04:39 +00:00
ciaranbor
ad0b49cd98 Linting 2026-01-21 15:04:32 +00:00
ciaranbor
c38e6cac64 Start image editing time steps at 0 2026-01-21 15:04:32 +00:00
ciaranbor
b38ac2efed Ignore image_strength 2026-01-21 15:04:32 +00:00
ciaranbor
f5a5a1b536 Handle conditioning latents in sync pipeline 2026-01-21 15:04:32 +00:00
ciaranbor
e8d9376734 Use dummy image for editing warmup 2026-01-21 15:04:32 +00:00
ciaranbor
311455d13e Support streaming for image editing 2026-01-21 15:04:32 +00:00
ciaranbor
17ad2d5ec1 Support image editing in runner 2026-01-21 15:04:32 +00:00
ciaranbor
c2441122a3 Add editing features to adapter 2026-01-21 15:04:32 +00:00
ciaranbor
deb298823b Default partial images to 3 if streaming 2026-01-21 15:04:32 +00:00
ciaranbor
5b72c5d5cb Add Qwen-Image model adapter 2026-01-21 15:04:32 +00:00
ciaranbor
8dc13a38e9 Add Qwen-Image-Edit model config 2026-01-21 15:04:32 +00:00
ciaranbor
bea3c1ae2f Use image generation in streaming mode in UI 2026-01-21 15:04:32 +00:00
ciaranbor
5afd965c92 Handle partial image streaming 2026-01-21 15:04:32 +00:00
ciaranbor
fe66d36453 Add streaming params to ImageGenerationTaskParams 2026-01-21 15:04:32 +00:00
ciaranbor
bebaf2ed9c Add Qwen-Image-Edit-2509 2026-01-21 15:04:32 +00:00
ciaranbor
3743d24a14 Handle image editing time steps 2026-01-21 15:04:32 +00:00
ciaranbor
f9eaaac75f Fix time steps 2026-01-21 15:04:32 +00:00
ciaranbor
5aaf07d892 Fix image_strength meaning 2026-01-21 15:04:32 +00:00
ciaranbor
71e63d0417 Truncate image data logs 2026-01-21 15:04:32 +00:00
ciaranbor
7f2a5f4394 Chunk image input 2026-01-21 15:04:32 +00:00
ciaranbor
cdcfcd8d4d Avoid logging image data 2026-01-21 15:04:32 +00:00
ciaranbor
5249a7260c Support image editing 2026-01-21 15:04:32 +00:00
Sami Khan
818d0718b0 small UI change 2026-01-21 15:04:32 +00:00
Sami Khan
ce99022832 image gen in dashboard 2026-01-21 15:04:32 +00:00
ciaranbor
b3fa3e5b70 Better llm model type check 2026-01-21 15:04:32 +00:00
ciaranbor
84cd65b626 Prune blocks before model load 2026-01-21 15:04:32 +00:00
ciaranbor
7c81d4c53a Own TODOs 2026-01-21 15:04:32 +00:00
ciaranbor
2fd4bdd34b Remove double RunnerReady event 2026-01-21 15:04:32 +00:00
ciaranbor
6ffb7d158c Fix hidden_size for image models 2026-01-21 15:04:32 +00:00
ciaranbor
72575193c0 Fix image model cards 2026-01-21 15:04:32 +00:00
ciaranbor
c439c3a903 Skip decode on non-final ranks 2026-01-21 15:04:32 +00:00
ciaranbor
c399d90b11 Final rank produces image 2026-01-21 15:04:32 +00:00
ciaranbor
6a397a27ae Increase number of sync steps 2026-01-21 15:04:32 +00:00
ciaranbor
a1db52ea28 Change Qwen-Image steps 2026-01-21 15:04:32 +00:00
ciaranbor
61ef4ed277 Fix Qwen-Image latent shapes 2026-01-21 15:04:32 +00:00
ciaranbor
7501c6605a Fix joint block patch recv shape for non-zero ranks 2026-01-21 15:04:32 +00:00
ciaranbor
7039630323 Fix comms issue for models without single blocks 2026-01-21 15:04:32 +00:00
ciaranbor
d555dd7443 Support Qwen in DiffusionRunner pipefusion 2026-01-21 15:04:32 +00:00
ciaranbor
6dbe8af9c1 Implement Qwen pipefusion 2026-01-21 15:04:32 +00:00
ciaranbor
75d81f3339 Add guidance_scale parameter to image model config 2026-01-21 15:04:32 +00:00
ciaranbor
3ced17a63e Move orchestration to DiffusionRunner 2026-01-21 15:04:32 +00:00
ciaranbor
c3cda7ad17 Add initial QwenModelAdapter 2026-01-21 15:04:32 +00:00
ciaranbor
74bd62d63a Tweak embeddings interface 2026-01-21 15:04:32 +00:00
ciaranbor
ad2c4b7c7b Add Qwen ImageModelConfig 2026-01-21 15:04:32 +00:00
ciaranbor
60cba5cb6a Use 10% sync steps 2026-01-21 15:04:32 +00:00
ciaranbor
3b2c560b2f Update FluxModelAdapter for new interface 2026-01-21 15:04:32 +00:00
ciaranbor
b93dd6679c Register QwenModelAdapter 2026-01-21 15:04:32 +00:00
ciaranbor
8623b338d0 Support multiple forward passes in runner 2026-01-21 15:04:32 +00:00
ciaranbor
39fd443d22 Extend block wrapper parameters 2026-01-21 15:04:32 +00:00
ciaranbor
bf6aae0077 Relax adaptor typing 2026-01-21 15:04:32 +00:00
ciaranbor
d349705410 Add Qwen-Image model card 2026-01-21 15:04:32 +00:00
ciaranbor
76dee083c8 Clean up dead code 2026-01-21 15:04:32 +00:00
ciaranbor
5bd31d90eb Add BaseModelAdaptor 2026-01-21 15:04:32 +00:00
ciaranbor
8ce8be17f9 Refactor file structure 2026-01-21 15:04:32 +00:00
ciaranbor
ff192fc9c0 Treat unified blocks as single blocks (equivalent) 2026-01-21 15:04:32 +00:00
ciaranbor
0990eace17 Refactor to handle entire denoising process in Diffusion runner 2026-01-21 15:04:32 +00:00
ciaranbor
404c370612 Move transformer to adapter 2026-01-21 15:04:32 +00:00
ciaranbor
0c5945256e Move some more logic to adaptor 2026-01-21 15:04:32 +00:00
ciaranbor
f96a0042bf Add generic block wrapper 2026-01-21 15:04:32 +00:00
ciaranbor
9166d041d8 Access transformer blocks from adaptor 2026-01-21 15:04:32 +00:00
ciaranbor
518fdcf366 Better typing 2026-01-21 15:04:32 +00:00
ciaranbor
07fc666a3d Create wrappers at init time 2026-01-21 15:04:32 +00:00
ciaranbor
8b59517a2c Combine model factory and adaptor 2026-01-21 15:04:32 +00:00
ciaranbor
4ca07e47c2 Implement model factory 2026-01-21 15:04:32 +00:00
ciaranbor
a492ae8a6b Add adaptor registry 2026-01-21 15:04:32 +00:00
ciaranbor
3847bf41c8 Remove mflux/generator/generate.py 2026-01-21 15:04:32 +00:00
ciaranbor
b4ba0f1183 Switch to using DistributedImageModel 2026-01-21 15:04:32 +00:00
ciaranbor
7da2feae5d Add DistributedImageModel 2026-01-21 15:04:32 +00:00
ciaranbor
02e96f88f9 Use new generic wrappers, etc in denoising 2026-01-21 15:04:32 +00:00
ciaranbor
387f2534ea Add generic transformer block wrappers 2026-01-21 15:04:32 +00:00
ciaranbor
49eee41276 Add FluxAdaptor 2026-01-21 15:04:32 +00:00
ciaranbor
8deb7c65cf Add ModelAdaptor, derivations implement model specific logic 2026-01-21 15:04:32 +00:00
ciaranbor
7f7f3efff8 Introduce image model config concept 2026-01-21 15:04:32 +00:00
ciaranbor
7dad96c2bd Consolidate kv cache patching 2026-01-21 15:04:32 +00:00
ciaranbor
7934510db0 Support different configuration comms 2026-01-21 15:04:32 +00:00
ciaranbor
7d7f4a6bca Add ImageGenerator protocol 2026-01-21 15:04:32 +00:00
ciaranbor
14ccf73de0 Force final patch receive order 2026-01-21 15:04:32 +00:00
ciaranbor
67ef5cdaf5 Remove logs 2026-01-21 15:04:32 +00:00
ciaranbor
79e6f40a6a Update patch list 2026-01-21 15:04:32 +00:00
ciaranbor
8a26d52b40 Slight refactor 2026-01-21 15:04:32 +00:00
ciaranbor
7488551ac4 Don't need array for prev patches 2026-01-21 15:04:32 +00:00
ciaranbor
2e253bbf01 Fix send/recv order 2026-01-21 15:04:32 +00:00
ciaranbor
8f7b95928e Fix async single transformer block 2026-01-21 15:04:32 +00:00
ciaranbor
40d09b22e5 Use relative rank variables 2026-01-21 15:04:32 +00:00
ciaranbor
ab45113285 Fix writing patches 2026-01-21 15:04:32 +00:00
ciaranbor
38dfec705f Collect final image 2026-01-21 15:04:32 +00:00
ciaranbor
fe06dfd922 Fix recv_template shape 2026-01-21 15:04:32 +00:00
ciaranbor
bfbb02e3b2 Add logs 2026-01-21 15:04:32 +00:00
ciaranbor
c57175ba8a Optimise async pipeline 2026-01-21 15:04:32 +00:00
ciaranbor
bfde99c35e Add next_rank and prev_rank members 2026-01-21 15:04:32 +00:00
ciaranbor
ae054c3adf Add _create_patches method 2026-01-21 15:04:32 +00:00
ciaranbor
435ab44d68 Fix shapes 2026-01-21 15:04:32 +00:00
ciaranbor
6372cb9e37 Reorder comms 2026-01-21 15:04:32 +00:00
ciaranbor
9e6e9306ed Remove all_gather from sync pipeline, send from final rank to first rank 2026-01-21 15:04:32 +00:00
ciaranbor
90856552ee Simplify kv_cache initialization 2026-01-21 15:04:32 +00:00
ciaranbor
cf45cd8887 Fix kv cache 2026-01-21 15:04:32 +00:00
ciaranbor
81d42f0851 Clean up kv caches 2026-01-21 15:04:32 +00:00
ciaranbor
29f8e7b6b6 Fix return 2026-01-21 15:04:32 +00:00
ciaranbor
d37228d891 Fix hidden_states shapes 2026-01-21 15:04:32 +00:00
ciaranbor
c735f86852 Only perform projection and scheduler step on last rank 2026-01-21 15:04:32 +00:00
ciaranbor
bbce54af6d Only compute embeddings on rank 0 2026-01-21 15:04:32 +00:00
ciaranbor
9d036edbef Remove eval 2026-01-21 15:04:32 +00:00
ciaranbor
42cd9ffb3d Remove eval 2026-01-21 15:04:32 +00:00
ciaranbor
6344c7bd58 Only send encoder_hidden_states with the first patch (once per timestep) 2026-01-21 15:04:32 +00:00
ciaranbor
3b2af661fa Remove redundant text kv cache computation 2026-01-21 15:04:32 +00:00
ciaranbor
e4da1c2dd7 Concatenate before all gather 2026-01-21 15:04:32 +00:00
ciaranbor
0616cdd6a8 Increase number of sync steps 2026-01-21 15:04:32 +00:00
ciaranbor
43248fe59a Reinitialise kv_caches between generations 2026-01-21 15:04:32 +00:00
ciaranbor
8601fa1df3 Eliminate double kv cache computation 2026-01-21 15:04:32 +00:00
ciaranbor
4b069037f9 Add kv cache caching wrappers for sync pipeline transformer blocks 2026-01-21 15:04:32 +00:00
ciaranbor
97056e4e02 Persist kv caches 2026-01-21 15:04:32 +00:00
ciaranbor
f65b4a3c8d Implement naive async pipeline implementation 2026-01-21 15:04:32 +00:00
ciaranbor
7c384a86a3 Use wrapper classes for patched transformer logic 2026-01-21 15:04:32 +00:00
ciaranbor
f6ebd8cd7a Add patch-aware joint and single attention wrappers 2026-01-21 15:04:32 +00:00
ciaranbor
6d2e183448 Fix group.size() 2026-01-21 15:04:32 +00:00
ciaranbor
0aaabe2e7d Add classes to manage kv caches with patch support 2026-01-21 15:04:32 +00:00
ciaranbor
ccdb52f4cb Use heuristic for number of sync steps 2026-01-21 15:04:32 +00:00
ciaranbor
32868a04e9 Generalise number of denoising steps 2026-01-21 15:04:32 +00:00
ciaranbor
9271562d37 Add flux1-dev 2026-01-21 15:04:32 +00:00
ciaranbor
88be42543e Move scheduler step to inner pipeline 2026-01-21 15:04:32 +00:00
ciaranbor
d945abf241 Add barrier before all_gather 2026-01-21 15:04:32 +00:00
ciaranbor
516e6b06cd Fix transformer blocks pruning 2026-01-21 15:04:32 +00:00
ciaranbor
4d175c63c7 Fix image generation api 2026-01-21 15:04:32 +00:00
ciaranbor
2f98459833 Create queue in try block 2026-01-21 15:04:32 +00:00
ciaranbor
2627ca0933 Conform to rebase 2026-01-21 15:04:32 +00:00
ciaranbor
f91453f333 Refactor denoising 2026-01-21 15:04:32 +00:00
ciaranbor
e78c18147c Move more logic to DistributedFlux 2026-01-21 15:04:32 +00:00
ciaranbor
1a5571e939 Move surrounding logic back to _sync_pipeline 2026-01-21 15:04:32 +00:00
ciaranbor
53b8af6366 Add patching aware member variables 2026-01-21 15:04:32 +00:00
ciaranbor
64bc62eecc Implement sync/async switching logic 2026-01-21 15:04:32 +00:00
ciaranbor
cb5ce5a130 Move current transformer implementation to _sync_pipeline method 2026-01-21 15:04:32 +00:00
ciaranbor
91974a04d8 Remove some logs 2026-01-21 15:04:32 +00:00
ciaranbor
d9167aed15 Remove old Flux1 implementation 2026-01-21 15:04:32 +00:00
ciaranbor
fa61565b48 Prune unused transformer blocks 2026-01-21 15:04:32 +00:00
ciaranbor
ba7756db1b Add mx.eval 2026-01-21 15:04:32 +00:00
ciaranbor
9062a358a4 Test evals 2026-01-21 15:04:32 +00:00
ciaranbor
1ca8f950c7 Test only barriers 2026-01-21 15:04:32 +00:00
ciaranbor
3c8e2117ad All perform final projection 2026-01-21 15:04:32 +00:00
ciaranbor
c38675b7cb Another barrier 2026-01-21 15:04:32 +00:00
ciaranbor
6bacab895f More debug 2026-01-21 15:04:32 +00:00
ciaranbor
2a338d15f1 Add barriers 2026-01-21 15:04:32 +00:00
ciaranbor
9120f15e04 Add log 2026-01-21 15:04:32 +00:00
ciaranbor
e4e3c838e2 Restore distributed logging 2026-01-21 15:04:32 +00:00
ciaranbor
dad7d453e4 Use bootstrap logger 2026-01-21 15:04:32 +00:00
ciaranbor
bc9e74ef76 Remove logs 2026-01-21 15:04:31 +00:00
ciaranbor
849ad22f27 Fix single block receive shape 2026-01-21 15:04:31 +00:00
ciaranbor
b38cc3d08f Add debug logs 2026-01-21 15:04:31 +00:00
ciaranbor
3988789cc5 Move communication logic to DistributedTransformer wrapper 2026-01-21 15:04:31 +00:00
ciaranbor
d79fd26f57 Move inference logic to DistributedFlux1 2026-01-21 15:04:31 +00:00
ciaranbor
cba3a968d5 Add DistributedFlux1 class 2026-01-21 15:04:31 +00:00
ciaranbor
f12cd35ab9 Rename pipeline to pipefusion 2026-01-21 15:04:31 +00:00
ciaranbor
fdc3d487c8 Further refactor 2026-01-21 15:04:31 +00:00
ciaranbor
2c2f09cb65 Refactor warmup 2026-01-21 15:04:31 +00:00
ciaranbor
046c73247b Manually handle flux1 inference 2026-01-21 15:04:31 +00:00
ciaranbor
3df368f16a Refactor flux1 image generation 2026-01-21 15:04:31 +00:00
ciaranbor
0e5a9287ff Use quality parameter to set number of inference steps 2026-01-21 15:04:31 +00:00
ciaranbor
b1ec12dc00 Chunk image data transfer 2026-01-21 15:04:31 +00:00
ciaranbor
ceb6612170 Define EXO_MAX_CHUNK_SIZE 2026-01-21 15:04:31 +00:00
ciaranbor
7177f2c36d Add indexing info to ImageChunk 2026-01-21 15:04:31 +00:00
ciaranbor
60c488cba3 Remove sharding logs 2026-01-21 15:04:31 +00:00
ciaranbor
1fad55a392 Temp: reduce flux1.schnell storage size 2026-01-21 15:04:31 +00:00
ciaranbor
f1b03cf0b3 Fix mflux transformer all_gather 2026-01-21 15:04:31 +00:00
ciaranbor
a3d6e7719c Fix world size 2026-01-21 15:04:31 +00:00
ciaranbor
62952b31ab Fix transition block? 2026-01-21 15:04:31 +00:00
ciaranbor
56dc9b8a19 Implement image generation warmup 2026-01-21 15:04:31 +00:00
ciaranbor
ccf6922b30 Add logs 2026-01-21 15:04:31 +00:00
ciaranbor
c77a87b331 Add spiece.model to default patterns 2026-01-21 15:04:31 +00:00
ciaranbor
1c9dbb0b6e Just download all files for now 2026-01-21 15:04:31 +00:00
ciaranbor
8e650074ea Fix get_allow_patterns to include non-indexed safetensors files 2026-01-21 15:04:31 +00:00
ciaranbor
19cacbed37 Use half-open layer indexing in get_allow_patterns 2026-01-21 15:04:31 +00:00
ciaranbor
86a48c9789 Enable distributed mflux 2026-01-21 15:04:31 +00:00
ciaranbor
18010489cf Implement mflux transformer sharding and communication pattern 2026-01-21 15:04:31 +00:00
ciaranbor
7c94b61a2b Update get_allow_patterns to handle sharding components 2026-01-21 15:04:31 +00:00
ciaranbor
d200e8a550 Namespace both keys and values for component weight maps 2026-01-21 15:04:31 +00:00
ciaranbor
b1267e233b Add components to Flux.1-schnell MODEL_CARD 2026-01-21 15:04:31 +00:00
ciaranbor
022498210f Add component concept for ModelMetadata 2026-01-21 15:04:31 +00:00
ciaranbor
51228d1516 Fix multiple components weight map key conflicts 2026-01-21 15:04:31 +00:00
ciaranbor
92d1850015 get_weight_map: handle repos with multiple safetensors.index.json files 2026-01-21 15:04:31 +00:00
ciaranbor
070e4ceb00 Add initial image edits spec 2026-01-21 15:04:31 +00:00
ciaranbor
f2ea2b40d3 Add image edits endpoint 2026-01-21 15:04:31 +00:00
ciaranbor
97eef764b1 Add ImageToImage task 2026-01-21 15:04:31 +00:00
ciaranbor
38d67a3e59 Allow ModelCards to have multiple tasks 2026-01-21 15:04:31 +00:00
ciaranbor
d4a1c3eb2a Fix text generation 2026-01-21 15:04:31 +00:00
ciaranbor
1c8bcd6c98 Rename mlx_generate_image to mflux_generate 2026-01-21 15:04:31 +00:00
ciaranbor
53e7a30082 Initialize mlx or mflux engine based on model task 2026-01-21 15:04:31 +00:00
ciaranbor
ebccb3ae71 Restore warmup for text generation 2026-01-21 15:04:31 +00:00
ciaranbor
e432a68bff Add initialize_mflux function 2026-01-21 15:04:31 +00:00
ciaranbor
03ae74f5c4 Move image generation to mflux engine 2026-01-21 15:04:31 +00:00
ciaranbor
012b8d89b0 Just use str for image generation size 2026-01-21 15:04:31 +00:00
ciaranbor
5c7e2ef9e1 Use MFlux for image generation 2026-01-21 15:04:31 +00:00
ciaranbor
0bf640bfe7 Add get_model_card function 2026-01-21 15:04:31 +00:00
ciaranbor
a1e8b09031 Add ModelTask enum 2026-01-21 15:04:31 +00:00
ciaranbor
6b7bd8ba2d Add flux1-schnell model 2026-01-21 15:04:31 +00:00
ciaranbor
02ddea2458 Add task field to ModelCard 2026-01-21 15:04:31 +00:00
ciaranbor
618ee71492 Enable recursive repo downloads 2026-01-21 15:04:18 +00:00
ciaranbor
0ee12d1e1c Add dummy generate_image implementation 2026-01-21 15:04:18 +00:00
ciaranbor
ee99c74f17 Use base64 encoded str for image data 2026-01-21 15:04:18 +00:00
ciaranbor
4677fdbe08 Handle ImageGeneration tasks in _pending_tasks 2026-01-21 15:04:18 +00:00
ciaranbor
93575b92af Handle ImageGeneration task in runner task processing 2026-01-21 15:04:07 +00:00
ciaranbor
31c597732f Handle ImageGeneration command in master command processing 2026-01-21 15:04:07 +00:00
ciaranbor
2f4412641c Add image generation to API 2026-01-21 15:04:07 +00:00
ciaranbor
37c4b0bcf9 Add ImageGenerationResponse 2026-01-21 15:04:07 +00:00
ciaranbor
2ce3166aa3 Add ImageGeneration task 2026-01-21 15:04:07 +00:00
ciaranbor
c9de76c69c Add ImageGeneration command 2026-01-21 15:04:07 +00:00
ciaranbor
0a8d015a02 Add image generation params and response types 2026-01-21 15:04:07 +00:00
ciaranbor
df1e3e7d5a Fix mlx stream_generate import 2026-01-21 15:03:50 +00:00
ciaranbor
61afad4439 Add mflux type stubs 2026-01-21 14:52:00 +00:00
rltakashige
758464703d Fix GPT OSS tensor sharding with upstream MLX LM (#1223)
## Motivation
MLX LM has given GPT OSS a shard method, but MLX itself has not been updated to match.

2026-01-20 18:24:54 +00:00
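
The motivation above states only the version mismatch. As a hedged illustration (the name `shard_if_supported` and the guard itself are invented here; the actual fix in #1223 may work differently), a capability check in Python could look like:

```python
# Hypothetical sketch, not the code from #1223: only invoke the shard()
# method mlx-lm added to GPT OSS when the installed MLX exposes the
# distributed primitives it depends on.
import mlx.core as mx


def shard_if_supported(model, group: mx.distributed.Group) -> bool:
    """Shard the model across `group` if both libraries support it."""
    if hasattr(model, "shard") and hasattr(mx.distributed, "all_sum"):
        model.shard(group)  # assumed signature; mlx-lm's API may differ
        return True
    # Older MLX without matching distributed support: run unsharded.
    return False
```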
rltakashige
9e2179c848 Register original layer in CustomMlxLayer (#1229)
## Motivation
Kimi K2 Thinking Pipeline RDMA was broken before.

## Why It Works
No clue tbh

## Test Plan

### Manual Testing
Kimi K2 Thinking and GPT OSS work at the same time on Pipeline RDMA.
Needs exo bench to check more thoroughly.

### Automated Testing
Layer composition tests still pass.
2026-01-20 18:20:01 +00:00
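
The description doesn't spell out the mechanism. One plausible reading, sketched below under that assumption (the `CustomMlxLayer` internals here are illustrative, not the repo's actual code), is that assigning the wrapped layer as a module attribute registers it as a child, keeping its parameters visible to MLX's module tree during pipeline RDMA:

```python
import mlx.nn as nn


class CustomMlxLayer(nn.Module):
    """Illustrative wrapper; the real class in the repo may differ."""

    def __init__(self, original: nn.Module):
        super().__init__()
        # Attribute assignment registers `original` as a child module,
        # so its parameters are tracked by parameters()/update() and
        # survive weight movement across pipeline stages.
        self.original_layer = original

    def __call__(self, x):
        return self.original_layer(x)
```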
Evan Quiney
22b5d836ef swap all instances of model_id: str for model_id: ModelId (#1221)
This change uses the more strongly typed ModelId and introduces some
convenience methods. It also cleans up some code left over from #1204.

## Changes

`model_id: str -> model_id: ModelId`
`repo_id: str -> model_id: ModelId`

Introduces methods on ModelId, in particular ModelId.normalize() to
replace `/` with `--`.

This PR did introduce some circular imports, so some code has been moved
around to limit them.

## Test Plan

Tests still pass, types still check. As this is about metadata, I
haven't tested inference.
2026-01-20 17:38:06 +00:00
281 changed files with 12809 additions and 1290 deletions

View File

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
import os
if "TOKENIZERS_PARALLELISM" not in os.environ: ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,47 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import PIL.Image
import tqdm
from typing import Protocol
from mflux.models.common.config.config import Config
class BeforeLoopCallback(Protocol):
def call_before_loop(
self,
seed: int,
prompt: str,
latents: mx.array,
config: Config,
canny_image: PIL.Image.Image | None = ...,
depth_image: PIL.Image.Image | None = ...,
) -> None: ...
class InLoopCallback(Protocol):
def call_in_loop(
self,
t: int,
seed: int,
prompt: str,
latents: mx.array,
config: Config,
time_steps: tqdm,
) -> None: ...
class AfterLoopCallback(Protocol):
def call_after_loop(
self, seed: int, prompt: str, latents: mx.array, config: Config
) -> None: ...
class InterruptCallback(Protocol):
def call_interrupt(
self,
t: int,
seed: int,
prompt: str,
latents: mx.array,
config: Config,
time_steps: tqdm,
) -> None: ...

View File

@@ -0,0 +1,24 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.callbacks.callback import (
AfterLoopCallback,
BeforeLoopCallback,
InLoopCallback,
InterruptCallback,
)
from mflux.callbacks.generation_context import GenerationContext
from mflux.models.common.config.config import Config
if TYPE_CHECKING: ...
class CallbackRegistry:
def __init__(self) -> None: ...
def register(self, callback) -> None: ...
def start(self, seed: int, prompt: str, config: Config) -> GenerationContext: ...
def before_loop_callbacks(self) -> list[BeforeLoopCallback]: ...
def in_loop_callbacks(self) -> list[InLoopCallback]: ...
def after_loop_callbacks(self) -> list[AfterLoopCallback]: ...
def interrupt_callbacks(self) -> list[InterruptCallback]: ...

View File

@@ -0,0 +1,29 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import PIL.Image
import tqdm
from typing import TYPE_CHECKING
from mflux.callbacks.callback_registry import CallbackRegistry
from mflux.models.common.config.config import Config
if TYPE_CHECKING: ...
class GenerationContext:
def __init__(
self, registry: CallbackRegistry, seed: int, prompt: str, config: Config
) -> None: ...
def before_loop(
self,
latents: mx.array,
*,
canny_image: PIL.Image.Image | None = ...,
depth_image: PIL.Image.Image | None = ...,
) -> None: ...
def in_loop(self, t: int, latents: mx.array, time_steps: tqdm = ...) -> None: ...
def after_loop(self, latents: mx.array) -> None: ...
def interruption(
self, t: int, latents: mx.array, time_steps: tqdm = ...
) -> None: ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,22 @@
"""
This type stub file was generated by pyright.
"""
import os
BATTERY_PERCENTAGE_STOP_LIMIT = ...
CONTROLNET_STRENGTH = ...
DEFAULT_DEV_FILL_GUIDANCE = ...
DEFAULT_DEPTH_GUIDANCE = ...
DIMENSION_STEP_PIXELS = ...
GUIDANCE_SCALE = ...
GUIDANCE_SCALE_KONTEXT = ...
IMAGE_STRENGTH = ...
MODEL_CHOICES = ...
MODEL_INFERENCE_STEPS = ...
QUANTIZE_CHOICES = ...
if os.environ.get("MFLUX_CACHE_DIR"):
MFLUX_CACHE_DIR = ...
else:
MFLUX_CACHE_DIR = ...
MFLUX_LORA_CACHE_DIR = ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,8 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.config.config import Config
from mflux.models.common.config.model_config import ModelConfig
__all__ = ["Config", "ModelConfig"]

View File

@@ -0,0 +1,66 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from pathlib import Path
from typing import Any
from tqdm import tqdm
from mflux.models.common.config.model_config import ModelConfig
logger = ...
class Config:
def __init__(
self,
model_config: ModelConfig,
num_inference_steps: int = ...,
height: int = ...,
width: int = ...,
guidance: float = ...,
image_path: Path | str | None = ...,
image_strength: float | None = ...,
depth_image_path: Path | str | None = ...,
redux_image_paths: list[Path | str] | None = ...,
redux_image_strengths: list[float] | None = ...,
masked_image_path: Path | str | None = ...,
controlnet_strength: float | None = ...,
scheduler: str = ...,
) -> None: ...
@property
def height(self) -> int: ...
@property
def width(self) -> int: ...
@width.setter
def width(self, value): # -> None:
...
@property
def image_seq_len(self) -> int: ...
@property
def guidance(self) -> float: ...
@property
def num_inference_steps(self) -> int: ...
@property
def precision(self) -> mx.Dtype: ...
@property
def num_train_steps(self) -> int: ...
@property
def image_path(self) -> Path | None: ...
@property
def image_strength(self) -> float | None: ...
@property
def depth_image_path(self) -> Path | None: ...
@property
def redux_image_paths(self) -> list[Path] | None: ...
@property
def redux_image_strengths(self) -> list[float] | None: ...
@property
def masked_image_path(self) -> Path | None: ...
@property
def init_time_step(self) -> int: ...
@property
def time_steps(self) -> tqdm: ...
@property
def controlnet_strength(self) -> float | None: ...
@property
def scheduler(self) -> Any: ...

View File

@@ -0,0 +1,86 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from functools import lru_cache
from typing import Literal
class ModelConfig:
precision: mx.Dtype = ...
def __init__(
self,
priority: int,
aliases: list[str],
model_name: str,
base_model: str | None,
controlnet_model: str | None,
custom_transformer_model: str | None,
num_train_steps: int | None,
max_sequence_length: int | None,
supports_guidance: bool | None,
requires_sigma_shift: bool | None,
transformer_overrides: dict | None = ...,
) -> None: ...
@staticmethod
@lru_cache
def dev() -> ModelConfig: ...
@staticmethod
@lru_cache
def schnell() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_kontext() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_fill() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_redux() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_depth() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_controlnet_canny() -> ModelConfig: ...
@staticmethod
@lru_cache
def schnell_controlnet_canny() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_controlnet_upscaler() -> ModelConfig: ...
@staticmethod
@lru_cache
def dev_fill_catvton() -> ModelConfig: ...
@staticmethod
@lru_cache
def krea_dev() -> ModelConfig: ...
@staticmethod
@lru_cache
def flux2_klein_4b() -> ModelConfig: ...
@staticmethod
@lru_cache
def flux2_klein_9b() -> ModelConfig: ...
@staticmethod
@lru_cache
def qwen_image() -> ModelConfig: ...
@staticmethod
@lru_cache
def qwen_image_edit() -> ModelConfig: ...
@staticmethod
@lru_cache
def fibo() -> ModelConfig: ...
@staticmethod
@lru_cache
def z_image_turbo() -> ModelConfig: ...
@staticmethod
@lru_cache
def seedvr2_3b() -> ModelConfig: ...
def x_embedder_input_dim(self) -> int: ...
def is_canny(self) -> bool: ...
@staticmethod
def from_name(
model_name: str, base_model: Literal["dev", "schnell", "krea-dev"] | None = ...
) -> ModelConfig: ...
AVAILABLE_MODELS = ...

View File

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,49 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from pathlib import Path
from typing import TYPE_CHECKING, TypeAlias
from mlx import nn
from mflux.models.common.vae.tiling_config import TilingConfig
from mflux.models.fibo.latent_creator.fibo_latent_creator import FiboLatentCreator
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.z_image.latent_creator.z_image_latent_creator import (
ZImageLatentCreator,
)
if TYPE_CHECKING:
LatentCreatorType: TypeAlias = type[
FiboLatentCreator | FluxLatentCreator | QwenLatentCreator | ZImageLatentCreator
]
class Img2Img:
def __init__(
self,
vae: nn.Module,
latent_creator: LatentCreatorType,
sigmas: mx.array,
init_time_step: int,
image_path: str | Path | None,
tiling_config: TilingConfig | None = ...,
) -> None: ...
class LatentCreator:
@staticmethod
def create_for_txt2img_or_img2img(
seed: int, height: int, width: int, img2img: Img2Img
) -> mx.array: ...
@staticmethod
def encode_image(
vae: nn.Module,
image_path: str | Path,
height: int,
width: int,
tiling_config: TilingConfig | None = ...,
) -> mx.array: ...
@staticmethod
def add_noise_by_interpolation(
clean: mx.array, noise: mx.array, sigma: float
) -> mx.array: ...

View File

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

View File

@@ -0,0 +1,13 @@
"""
This type stub file was generated by pyright.
"""
from mlx import nn
from mflux.models.common.lora.layer.linear_lora_layer import LoRALinear
class FusedLoRALinear(nn.Module):
def __init__(
self, base_linear: nn.Linear | nn.QuantizedLinear, loras: list[LoRALinear]
) -> None: ...
def __call__(self, x): # -> array:
...

View File

@@ -0,0 +1,22 @@
"""
This type stub file was generated by pyright.
"""
from mlx import nn
class LoRALinear(nn.Module):
@staticmethod
def from_linear(
linear: nn.Linear | nn.QuantizedLinear, r: int = ..., scale: float = ...
): # -> LoRALinear:
...
def __init__(
self,
input_dims: int,
output_dims: int,
r: int = ...,
scale: float = ...,
bias: bool = ...,
) -> None: ...
def __call__(self, x): # -> array:
...

View File

@@ -0,0 +1,26 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
from collections.abc import Callable
from dataclasses import dataclass
from mflux.models.common.lora.mapping.lora_mapping import LoRATarget
@dataclass
class PatternMatch:
source_pattern: str
target_path: str
matrix_name: str
transpose: bool
transform: Callable[[mx.array], mx.array] | None = ...
class LoRALoader:
@staticmethod
def load_and_apply_lora(
lora_mapping: list[LoRATarget],
transformer: nn.Module,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> tuple[list[str], list[float]]: ...

View File

@@ -0,0 +1,21 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from collections.abc import Callable
from dataclasses import dataclass
from typing import List, Protocol
@dataclass
class LoRATarget:
model_path: str
possible_up_patterns: List[str]
possible_down_patterns: List[str]
possible_alpha_patterns: List[str] = ...
up_transform: Callable[[mx.array], mx.array] | None = ...
down_transform: Callable[[mx.array], mx.array] | None = ...
class LoRAMapping(Protocol):
@staticmethod
def get_mapping() -> List[LoRATarget]: ...

View File

@@ -0,0 +1,9 @@
"""
This type stub file was generated by pyright.
"""
import mlx.nn as nn
class LoRASaver:
@staticmethod
def bake_and_strip_lora(module: nn.Module) -> nn.Module: ...

View File

@@ -0,0 +1,35 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class LoraTransforms:
@staticmethod
def split_q_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_k_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_v_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_q_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_k_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_v_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_q_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_k_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_v_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_mlp_up(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_q_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_k_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_v_down(tensor: mx.array) -> mx.array: ...
@staticmethod
def split_single_mlp_down(tensor: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.resolution.config_resolution import ConfigResolution
from mflux.models.common.resolution.lora_resolution import LoraResolution
from mflux.models.common.resolution.path_resolution import PathResolution
from mflux.models.common.resolution.quantization_resolution import (
QuantizationResolution,
)
__all__ = [
"ConfigResolution",
"LoraResolution",
"PathResolution",
"QuantizationResolution",
]

View File

@@ -0,0 +1,39 @@
"""
This type stub file was generated by pyright.
"""
from enum import Enum
from typing import NamedTuple
class QuantizationAction(Enum):
NONE = ...
STORED = ...
REQUESTED = ...
class PathAction(Enum):
LOCAL = ...
HUGGINGFACE_CACHED = ...
HUGGINGFACE = ...
ERROR = ...
class LoraAction(Enum):
LOCAL = ...
REGISTRY = ...
HUGGINGFACE_COLLECTION_CACHED = ...
HUGGINGFACE_COLLECTION = ...
HUGGINGFACE_REPO_CACHED = ...
HUGGINGFACE_REPO = ...
ERROR = ...
class ConfigAction(Enum):
EXACT_MATCH = ...
EXPLICIT_BASE = ...
INFER_SUBSTRING = ...
ERROR = ...
class Rule(NamedTuple):
priority: int
name: str
check: str
action: QuantizationAction | PathAction | LoraAction | ConfigAction
...

View File

@@ -0,0 +1,14 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.models.common.config.model_config import ModelConfig
if TYPE_CHECKING: ...
logger = ...
class ConfigResolution:
RULES = ...
@staticmethod
def resolve(model_name: str, base_model: str | None = ...) -> ModelConfig: ...

View File

@@ -0,0 +1,21 @@
"""
This type stub file was generated by pyright.
"""
from pathlib import Path
logger = ...
class LoraResolution:
RULES = ...
_registry: dict[str, Path] = ...
@staticmethod
def resolve(path: str) -> str: ...
@staticmethod
def resolve_paths(paths: list[str] | None) -> list[str]: ...
@staticmethod
def resolve_scales(scales: list[float] | None, num_paths: int) -> list[float]: ...
@staticmethod
def get_registry() -> dict[str, Path]: ...
@staticmethod
def discover_files(library_paths: list[Path]) -> dict[str, Path]: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
from pathlib import Path
logger = ...
class PathResolution:
RULES = ...
@staticmethod
def resolve(path: str | None, patterns: list[str] | None = ...) -> Path | None: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
logger = ...
class QuantizationResolution:
RULES = ...
@staticmethod
def resolve(
stored: int | None, requested: int | None
) -> tuple[int | None, str | None]: ...

View File

@@ -0,0 +1,26 @@
"""
This type stub file was generated by pyright.
"""
from .flow_match_euler_discrete_scheduler import FlowMatchEulerDiscreteScheduler
from .linear_scheduler import LinearScheduler
from .seedvr2_euler_scheduler import SeedVR2EulerScheduler
__all__ = [
"LinearScheduler",
"FlowMatchEulerDiscreteScheduler",
"SeedVR2EulerScheduler",
]
class SchedulerModuleNotFound(ValueError): ...
class SchedulerClassNotFound(ValueError): ...
class InvalidSchedulerType(TypeError): ...
SCHEDULER_REGISTRY = ...
def register_contrib(scheduler_object, scheduler_name=...): # -> None:
...
def try_import_external_scheduler(
scheduler_object_path: str,
): # -> type[BaseScheduler]:
...

View File

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from abc import ABC, abstractmethod
class BaseScheduler(ABC):
@property
@abstractmethod
def sigmas(self) -> mx.array: ...
@abstractmethod
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...
def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...

View File

@@ -0,0 +1,26 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import TYPE_CHECKING
from mflux.models.common.config.config import Config
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
if TYPE_CHECKING: ...
class FlowMatchEulerDiscreteScheduler(BaseScheduler):
def __init__(self, config: Config) -> None: ...
@property
def sigmas(self) -> mx.array: ...
@property
def timesteps(self) -> mx.array: ...
def set_image_seq_len(self, image_seq_len: int) -> None: ...
@staticmethod
def get_timesteps_and_sigmas(
image_seq_len: int, num_inference_steps: int, num_train_timesteps: int = ...
) -> tuple[mx.array, mx.array]: ...
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...
def scale_model_input(self, latents: mx.array, t: int) -> mx.array: ...

View File

@@ -0,0 +1,20 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import TYPE_CHECKING
from mflux.models.common.config.config import Config
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
if TYPE_CHECKING: ...
class LinearScheduler(BaseScheduler):
def __init__(self, config: Config) -> None: ...
@property
def sigmas(self) -> mx.array: ...
@property
def timesteps(self) -> mx.array: ...
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...

View File

@@ -0,0 +1,20 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import TYPE_CHECKING
from mflux.models.common.config.config import Config
from mflux.models.common.schedulers.base_scheduler import BaseScheduler
if TYPE_CHECKING: ...
class SeedVR2EulerScheduler(BaseScheduler):
def __init__(self, config: Config) -> None: ...
@property
def timesteps(self) -> mx.array: ...
@property
def sigmas(self) -> mx.array: ...
def step(
self, noise: mx.array, timestep: int, latents: mx.array, **kwargs
) -> mx.array: ...

View File

@@ -0,0 +1,24 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.tokenizer.tokenizer import (
BaseTokenizer,
LanguageTokenizer,
Tokenizer,
VisionLanguageTokenizer,
)
from mflux.models.common.tokenizer.tokenizer_loader import TokenizerLoader
from mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput
"""
This type stub file was generated by pyright.
"""
__all__ = [
"Tokenizer",
"BaseTokenizer",
"LanguageTokenizer",
"VisionLanguageTokenizer",
"TokenizerLoader",
"TokenizerOutput",
]

View File

@@ -0,0 +1,74 @@
"""
This type stub file was generated by pyright.
"""
from abc import ABC, abstractmethod
from typing import Protocol, runtime_checkable
from PIL import Image
from transformers import PreTrainedTokenizer
from mflux.models.common.tokenizer.tokenizer_output import TokenizerOutput
"""
This type stub file was generated by pyright.
"""
@runtime_checkable
class Tokenizer(Protocol):
tokenizer: PreTrainedTokenizer
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...
class BaseTokenizer(ABC):
def __init__(
self, tokenizer: PreTrainedTokenizer, max_length: int = ...
) -> None: ...
@abstractmethod
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...
class LanguageTokenizer(BaseTokenizer):
def __init__(
self,
tokenizer: PreTrainedTokenizer,
max_length: int = ...,
padding: str = ...,
return_attention_mask: bool = ...,
template: str | None = ...,
use_chat_template: bool = ...,
chat_template_kwargs: dict | None = ...,
add_special_tokens: bool = ...,
) -> None: ...
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...
class VisionLanguageTokenizer(BaseTokenizer):
def __init__(
self,
tokenizer: PreTrainedTokenizer,
processor,
max_length: int = ...,
template: str | None = ...,
image_token: str = ...,
) -> None: ...
def tokenize(
self,
prompt: str | list[str],
images: list[Image.Image] | None = ...,
max_length: int | None = ...,
**kwargs,
) -> TokenizerOutput: ...

View File

@@ -0,0 +1,22 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.models.common.tokenizer.tokenizer import BaseTokenizer
from mflux.models.common.weights.loading.weight_definition import TokenizerDefinition
"""
This type stub file was generated by pyright.
"""
if TYPE_CHECKING: ...
class TokenizerLoader:
@staticmethod
def load(definition: TokenizerDefinition, model_path: str) -> BaseTokenizer: ...
@staticmethod
def load_all(
definitions: list[TokenizerDefinition],
model_path: str,
max_length_overrides: dict[str, int] | None = ...,
) -> dict[str, BaseTokenizer]: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
"""
This type stub file was generated by pyright.
"""
@dataclass
class TokenizerOutput:
input_ids: mx.array
attention_mask: mx.array
pixel_values: mx.array | None = ...
image_grid_thw: mx.array | None = ...

View File

@@ -0,0 +1,8 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.vae.tiling_config import TilingConfig
from mflux.models.common.vae.vae_tiler import VAETiler
__all__ = ["TilingConfig", "VAETiler"]

View File

@@ -0,0 +1,13 @@
"""
This type stub file was generated by pyright.
"""
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class TilingConfig:
vae_decode_tiles_per_dim: int | None = ...
vae_decode_overlap: int = ...
vae_encode_tiled: bool = ...
vae_encode_tile_size: int = ...
vae_encode_tile_overlap: int = ...

View File

@@ -0,0 +1,27 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import Callable
class VAETiler:
@staticmethod
def encode_image_tiled(
*,
image: mx.array,
encode_fn: Callable[[mx.array], mx.array],
latent_channels: int,
tile_size: tuple[int, int] = ...,
tile_overlap: tuple[int, int] = ...,
spatial_scale: int = ...,
) -> mx.array: ...
@staticmethod
def decode_image_tiled(
*,
latent: mx.array,
decode_fn: Callable[[mx.array], mx.array],
tile_size: tuple[int, int] = ...,
tile_overlap: tuple[int, int] = ...,
spatial_scale: int = ...,
) -> mx.array: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
from mflux.models.common.vae.tiling_config import TilingConfig
class VAEUtil:
@staticmethod
def encode(
vae: nn.Module, image: mx.array, tiling_config: TilingConfig | None = ...
) -> mx.array: ...
@staticmethod
def decode(
vae: nn.Module, latent: mx.array, tiling_config: TilingConfig | None = ...
) -> mx.array: ...

View File

@@ -0,0 +1,18 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights, MetaData
from mflux.models.common.weights.loading.weight_applier import WeightApplier
from mflux.models.common.weights.loading.weight_definition import ComponentDefinition
from mflux.models.common.weights.loading.weight_loader import WeightLoader
from mflux.models.common.weights.saving.model_saver import ModelSaver
__all__ = [
"ComponentDefinition",
"LoadedWeights",
"MetaData",
"ModelSaver",
"WeightApplier",
"WeightLoader",
]

View File

@@ -0,0 +1,18 @@
"""
This type stub file was generated by pyright.
"""
from dataclasses import dataclass
@dataclass
class MetaData:
quantization_level: int | None = ...
mflux_version: str | None = ...
@dataclass
class LoadedWeights:
components: dict[str, dict]
meta_data: MetaData
def __getattr__(self, name: str) -> dict | None: ...
def num_transformer_blocks(self, component_name: str = ...) -> int: ...
def num_single_transformer_blocks(self, component_name: str = ...) -> int: ...

View File

@@ -0,0 +1,30 @@
"""
This type stub file was generated by pyright.
"""
import mlx.nn as nn
from typing import TYPE_CHECKING
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
WeightDefinitionType,
)
if TYPE_CHECKING: ...
class WeightApplier:
@staticmethod
def apply_and_quantize_single(
weights: LoadedWeights,
model: nn.Module,
component: ComponentDefinition,
quantize_arg: int | None,
quantization_predicate=...,
) -> int | None: ...
@staticmethod
def apply_and_quantize(
weights: LoadedWeights,
models: dict[str, nn.Module],
quantize_arg: int | None,
weight_definition: WeightDefinitionType,
) -> int | None: ...

View File

@@ -0,0 +1,73 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
from typing import Callable, List, TYPE_CHECKING, TypeAlias
from mflux.models.common.weights.mapping.weight_mapping import WeightTarget
from mflux.models.common.tokenizer.tokenizer import BaseTokenizer
from mflux.models.depth_pro.weights.depth_pro_weight_definition import (
DepthProWeightDefinition,
)
from mflux.models.fibo.weights.fibo_weight_definition import FIBOWeightDefinition
from mflux.models.fibo_vlm.weights.fibo_vlm_weight_definition import (
FIBOVLMWeightDefinition,
)
from mflux.models.flux.weights.flux_weight_definition import FluxWeightDefinition
from mflux.models.qwen.weights.qwen_weight_definition import QwenWeightDefinition
from mflux.models.seedvr2.weights.seedvr2_weight_definition import (
SeedVR2WeightDefinition,
)
from mflux.models.z_image.weights.z_image_weight_definition import (
ZImageWeightDefinition,
)
"""
This type stub file was generated by pyright.
"""
if TYPE_CHECKING:
WeightDefinitionType: TypeAlias = type[
FluxWeightDefinition
| FIBOWeightDefinition
| FIBOVLMWeightDefinition
| QwenWeightDefinition
| ZImageWeightDefinition
| SeedVR2WeightDefinition
| DepthProWeightDefinition
]
@dataclass
class ComponentDefinition:
name: str
hf_subdir: str
mapping_getter: Callable[[], List[WeightTarget]] | None = ...
model_attr: str | None = ...
num_blocks: int | None = ...
num_layers: int | None = ...
loading_mode: str = ...
precision: mx.Dtype | None = ...
skip_quantization: bool = ...
bulk_transform: Callable[[mx.array], mx.array] | None = ...
weight_subkey: str | None = ...
download_url: str | None = ...
weight_prefix_filters: List[str] | None = ...
weight_files: List[str] | None = ...
@dataclass
class TokenizerDefinition:
name: str
hf_subdir: str
tokenizer_class: str = ...
fallback_subdirs: List[str] | None = ...
download_patterns: List[str] | None = ...
encoder_class: type[BaseTokenizer] | None = ...
max_length: int = ...
padding: str = ...
template: str | None = ...
use_chat_template: bool = ...
chat_template_kwargs: dict | None = ...
add_special_tokens: bool = ...
processor_class: type | None = ...
image_token: str = ...
chat_template: str | None = ...

View File

@@ -0,0 +1,23 @@
"""
This type stub file was generated by pyright.
"""
from typing import TYPE_CHECKING
from mflux.models.common.weights.loading.loaded_weights import LoadedWeights
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
WeightDefinitionType,
)
if TYPE_CHECKING: ...
logger = ...
class WeightLoader:
@staticmethod
def load_single(
component: ComponentDefinition, repo_id: str, file_pattern: str = ...
) -> LoadedWeights: ...
@staticmethod
def load(
weight_definition: WeightDefinitionType, model_path: str | None = ...
) -> LoadedWeights: ...

View File

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from typing import Dict, List, Optional
from mflux.models.common.weights.mapping.weight_mapping import WeightTarget
class WeightMapper:
@staticmethod
def apply_mapping(
hf_weights: Dict[str, mx.array],
mapping: List[WeightTarget],
num_blocks: Optional[int] = ...,
num_layers: Optional[int] = ...,
) -> Dict: ...

View File

@@ -0,0 +1,23 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
from typing import Callable, List, Optional, Protocol
"""
This type stub file was generated by pyright.
"""
@dataclass
class WeightTarget:
to_pattern: str
from_pattern: List[str]
transform: Optional[Callable[[mx.array], mx.array]] = ...
required: bool = ...
max_blocks: Optional[int] = ...
class WeightMapping(Protocol):
@staticmethod
def get_mapping() -> List[WeightTarget]: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class WeightTransforms:
@staticmethod
def reshape_gamma_to_1d(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_patch_embed(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_conv3d_weight(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_conv2d_weight(tensor: mx.array) -> mx.array: ...
@staticmethod
def transpose_conv_transpose2d_weight(tensor: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,14 @@
"""
This type stub file was generated by pyright.
"""
from typing import Any, TYPE_CHECKING
from mflux.models.common.weights.loading.weight_definition import WeightDefinitionType
if TYPE_CHECKING: ...
class ModelSaver:
@staticmethod
def save_model(
model: Any, bits: int, base_path: str, weight_definition: WeightDefinitionType
) -> None: ...

View File

@@ -0,0 +1,9 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.depth_pro.model.depth_pro_model import DepthProModel
class DepthProInitializer:
@staticmethod
def init(model: DepthProModel, quantize: int | None = ...) -> None: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class FeatureFusionBlock2d(nn.Module):
def __init__(self, num_features: int, deconv: bool = ...) -> None: ...
def __call__(self, x0: mx.array, x1: mx.array | None = ...) -> mx.array: ...

View File

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class MultiresConvDecoder(nn.Module):
def __init__(self) -> None: ...
def __call__(
self,
x0_latent: mx.array,
x1_latent: mx.array,
x0_features: mx.array,
x1_features: mx.array,
x_global_features: mx.array,
) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class ResidualBlock(nn.Module):
def __init__(self, num_features: int) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,20 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from dataclasses import dataclass
from pathlib import Path
from PIL import Image
@dataclass
class DepthResult:
depth_image: Image.Image
depth_array: mx.array
min_depth: float
max_depth: float
...
class DepthPro:
def __init__(self, quantize: int | None = ...) -> None: ...
def create_depth_map(self, image_path: str | Path) -> DepthResult: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DepthProModel(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, x0: mx.array, x1: mx.array, x2: mx.array
) -> tuple[mx.array, mx.array]: ...

View File

@@ -0,0 +1,15 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DepthProUtil:
@staticmethod
def split(x: mx.array, overlap_ratio: float = ...) -> mx.array: ...
@staticmethod
def interpolate(x: mx.array, size=..., scale_factor=...): # -> array:
...
@staticmethod
def apply_conv(x: mx.array, conv_module: nn.Module) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class Attention(nn.Module):
def __init__(
self, dim: int = ..., head_dim: int = ..., num_heads: int = ...
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DinoVisionTransformer(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array, mx.array]: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class LayerScale(nn.Module):
def __init__(self, dims: int, init_values: float = ...) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class MLP(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class PatchEmbed(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class TransformerBlock(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class DepthProEncoder(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, x0: mx.array, x1: mx.array, x2: mx.array
) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...

View File

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class UpSampleBlock(nn.Module):
def __init__(
self,
dim_in: int = ...,
dim_int: int = ...,
dim_out: int = ...,
upsample_layers: int = ...,
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
import mlx.nn as nn
class FOVHead(nn.Module):
def __init__(self) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...

View File

@@ -0,0 +1,23 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
TokenizerDefinition,
)
"""
This type stub file was generated by pyright.
"""
class DepthProWeightDefinition:
@staticmethod
def get_components() -> List[ComponentDefinition]: ...
@staticmethod
def get_tokenizers() -> List[TokenizerDefinition]: ...
@staticmethod
def get_download_patterns() -> List[str]: ...
@staticmethod
def quantization_predicate(path: str, module) -> bool: ...

@@ -0,0 +1,13 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
WeightMapping,
WeightTarget,
)
class DepthProWeightMapping(WeightMapping):
@staticmethod
def get_mapping() -> List[WeightTarget]: ...

@@ -0,0 +1,13 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
class FiboLatentCreator:
@staticmethod
def create_noise(seed: int, height: int, width: int) -> mx.array: ...
@staticmethod
def pack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...
@staticmethod
def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...

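The three statics describe the usual latent round trip: sample noise, pack it into the transformer's sequence layout, denoise, then unpack for the VAE decoder. A hedged sketch of that flow, where denoise_step and the step count are hypothetical stand-ins and create_noise is assumed to return unpacked latents:

latents = FiboLatentCreator.create_noise(seed=42, height=1024, width=1024)
packed = FiboLatentCreator.pack_latents(latents, height=1024, width=1024)
for t in range(4):                    # step count chosen arbitrarily here
    packed = denoise_step(packed, t)  # hypothetical transformer call
latents = FiboLatentCreator.unpack_latents(packed, height=1024, width=1024)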

@@ -0,0 +1,23 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
TokenizerDefinition,
)
"""
This type stub file was generated by pyright.
"""
class FIBOWeightDefinition:
@staticmethod
def get_components() -> List[ComponentDefinition]: ...
@staticmethod
def get_tokenizers() -> List[TokenizerDefinition]: ...
@staticmethod
def get_download_patterns() -> List[str]: ...
@staticmethod
def quantization_predicate(path: str, module) -> bool: ...

@@ -0,0 +1,17 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
WeightMapping,
WeightTarget,
)
class FIBOWeightMapping(WeightMapping):
@staticmethod
def get_transformer_mapping() -> List[WeightTarget]: ...
@staticmethod
def get_text_encoder_mapping() -> List[WeightTarget]: ...
@staticmethod
def get_vae_mapping() -> List[WeightTarget]: ...

@@ -0,0 +1,8 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.qwen.tokenizer.qwen_image_processor import QwenImageProcessor
class Qwen2VLImageProcessor(QwenImageProcessor):
def __init__(self) -> None: ...

@@ -0,0 +1,28 @@
"""
This type stub file was generated by pyright.
"""
from typing import Optional, Union
from PIL import Image
class Qwen2VLProcessor:
def __init__(self, tokenizer) -> None: ...
def apply_chat_template(
self,
messages,
tokenize: bool = ...,
add_generation_prompt: bool = ...,
return_tensors: Optional[str] = ...,
return_dict: bool = ...,
**kwargs,
): # -> dict[Any, Any]:
...
def __call__(
self,
text: Optional[Union[str, list[str]]] = ...,
images: Optional[Union[Image.Image, list[Image.Image]]] = ...,
padding: bool = ...,
return_tensors: Optional[str] = ...,
**kwargs,
): # -> dict[Any, Any]:
...

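The processor mirrors the familiar Hugging Face pattern: render a chat template, then tokenize text and images together into a single feature dict. A rough usage sketch; the tokenizer variable and the input image are assumptions based on that convention:

from PIL import Image

processor = Qwen2VLProcessor(tokenizer)   # tokenizer: assumed preloaded
image = Image.open("example.jpg")         # hypothetical input image
inputs = processor(text="Describe this image.", images=image, padding=True)
# Per the annotations above, `inputs` is a dict, presumably holding token
# ids plus pixel values ready for the VLM.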

@@ -0,0 +1,24 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.loading.weight_definition import (
ComponentDefinition,
TokenizerDefinition,
)
"""
This type stub file was generated by pyright.
"""
QWEN2VL_CHAT_TEMPLATE = ...
class FIBOVLMWeightDefinition:
@staticmethod
def get_components() -> List[ComponentDefinition]: ...
@staticmethod
def get_tokenizers() -> List[TokenizerDefinition]: ...
@staticmethod
def get_download_patterns() -> List[str]: ...
@staticmethod
def quantization_predicate(path: str, module) -> bool: ...

@@ -0,0 +1,15 @@
"""
This type stub file was generated by pyright.
"""
from typing import List
from mflux.models.common.weights.mapping.weight_mapping import (
WeightMapping,
WeightTarget,
)
class FIBOVLMWeightMapping(WeightMapping):
@staticmethod
def get_vlm_decoder_mapping(num_layers: int = ...) -> List[WeightTarget]: ...
@staticmethod
def get_vlm_visual_mapping(depth: int = ...) -> List[WeightTarget]: ...

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

@@ -0,0 +1,3 @@
"""
This type stub file was generated by pyright.
"""

@@ -0,0 +1,53 @@
"""
This type stub file was generated by pyright.
"""
from mflux.models.common.config import ModelConfig
class FluxInitializer:
@staticmethod
def init(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
custom_transformer=...,
) -> None: ...
@staticmethod
def init_depth(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...
@staticmethod
def init_redux(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...
@staticmethod
def init_controlnet(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...
@staticmethod
def init_concept(
model,
model_config: ModelConfig,
quantize: int | None,
model_path: str | None = ...,
lora_paths: list[str] | None = ...,
lora_scales: list[float] | None = ...,
) -> None: ...

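All five entry points share one shape: they mutate a bare model in place, resolving weights, optional quantization, and LoRA adapters in a single call. A hedged sketch of driving init; model and config are placeholders for an already-constructed Flux model and ModelConfig:

FluxInitializer.init(
    model,                                 # hypothetical bare Flux model
    model_config=config,                   # an already-constructed ModelConfig
    quantize=8,                            # bit width, or None to skip
    lora_paths=["./my_lora.safetensors"],  # hypothetical adapter path
    lora_scales=[0.8],
)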

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
"""
This type stub file was generated by pyright.
"""

@@ -0,0 +1,19 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
"""
This type stub file was generated by pyright.
"""
class FluxLatentCreator:
@staticmethod
def create_noise(seed: int, height: int, width: int) -> mx.array: ...
@staticmethod
def pack_latents(
latents: mx.array, height: int, width: int, num_channels_latents: int = ...
) -> mx.array: ...
@staticmethod
def unpack_latents(latents: mx.array, height: int, width: int) -> mx.array: ...

@@ -0,0 +1,7 @@
"""
This type stub file was generated by pyright.
"""
"""
This type stub file was generated by pyright.
"""

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPEmbeddings(nn.Module):
def __init__(self, dims: int) -> None: ...
def __call__(self, tokens: mx.array) -> mx.array: ...

@@ -0,0 +1,14 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
"""
This type stub file was generated by pyright.
"""
class CLIPEncoder(nn.Module):
def __init__(self) -> None: ...
def __call__(self, tokens: mx.array) -> mx.array: ...

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPEncoderLayer(nn.Module):
def __init__(self, layer: int) -> None: ...
def __call__(
self, hidden_states: mx.array, causal_attention_mask: mx.array
) -> mx.array: ...

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPMLP(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...
@staticmethod
def quick_gelu(input_array: mx.array) -> mx.array: ...

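quick_gelu is exposed as a static helper; QuickGELU as used in CLIP is conventionally x * sigmoid(1.702 * x), a cheap approximation of exact GELU, so the elided body is very likely equivalent to:

import mlx.core as mx

def quick_gelu(input_array: mx.array) -> mx.array:
    # QuickGELU: x * sigmoid(1.702 * x), the approximation used in CLIP.
    return input_array * mx.sigmoid(1.702 * input_array)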

@@ -0,0 +1,18 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPSdpaAttention(nn.Module):
head_dimension = ...
batch_size = ...
num_heads = ...
def __init__(self) -> None: ...
def __call__(
self, hidden_states: mx.array, causal_attention_mask: mx.array
) -> mx.array: ...
@staticmethod
def reshape_and_transpose(x, batch_size, num_heads, head_dim): # -> array:
...

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class CLIPTextModel(nn.Module):
def __init__(self, dims: int, num_encoder_layers: int) -> None: ...
def __call__(self, tokens: mx.array) -> tuple[mx.array, mx.array]: ...
@staticmethod
def create_causal_attention_mask(input_shape: tuple) -> mx.array: ...

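create_causal_attention_mask takes only the input shape, which points at the standard additive construction: zeros on and below the diagonal, a large negative bias above it. A generic sketch, not necessarily mflux's exact code:

import mlx.core as mx

def create_causal_attention_mask(input_shape: tuple) -> mx.array:
    batch_size, seq_len = input_shape
    # Upper-triangular -inf bias: token i may attend to tokens 0..i only.
    mask = mx.triu(mx.full((seq_len, seq_len), float("-inf")), k=1)
    # Broadcast so the same mask applies across batch and heads.
    return mx.broadcast_to(mask, (batch_size, 1, seq_len, seq_len))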

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class EncoderCLIP(nn.Module):
def __init__(self, num_encoder_layers: int) -> None: ...
def __call__(
self, tokens: mx.array, causal_attention_mask: mx.array
) -> mx.array: ...

@@ -0,0 +1,25 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mflux.models.common.tokenizer import Tokenizer
from mflux.models.flux.model.flux_text_encoder.clip_encoder.clip_encoder import (
CLIPEncoder,
)
from mflux.models.flux.model.flux_text_encoder.t5_encoder.t5_encoder import T5Encoder
"""
This type stub file was generated by pyright.
"""
class PromptEncoder:
@staticmethod
def encode_prompt(
prompt: str,
prompt_cache: dict[str, tuple[mx.array, mx.array]],
t5_tokenizer: Tokenizer,
clip_tokenizer: Tokenizer,
t5_text_encoder: T5Encoder,
clip_text_encoder: CLIPEncoder,
) -> tuple[mx.array, mx.array]: ...

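The prompt_cache parameter makes the contract visible: repeated prompts should skip both tokenizers and both text encoders. A sketch of the intended call pattern, assuming the tokenizers and encoders are already constructed elsewhere:

cache: dict[str, tuple[mx.array, mx.array]] = {}
prompt_embeds, pooled_embeds = PromptEncoder.encode_prompt(
    prompt="a photo of an astronaut",
    prompt_cache=cache,
    t5_tokenizer=t5_tokenizer,        # assumed preloaded Tokenizer
    clip_tokenizer=clip_tokenizer,    # assumed preloaded Tokenizer
    t5_text_encoder=t5_encoder,       # assumed preloaded T5Encoder
    clip_text_encoder=clip_encoder,   # assumed preloaded CLIPEncoder
)
# A second call with the same prompt can be answered from `cache`
# without re-running either encoder.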

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5Attention(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5Block(nn.Module):
def __init__(self, layer: int) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5DenseReluDense(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...
@staticmethod
def new_gelu(input_array: mx.array) -> mx.array: ...

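new_gelu here almost certainly names the tanh-based GELU approximation that T5 implementations commonly use. As a sketch of that formula:

import math
import mlx.core as mx

def new_gelu(input_array: mx.array) -> mx.array:
    # Tanh approximation of GELU:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * input_array * (
        1.0 + mx.tanh(c * (input_array + 0.044715 * input_array**3))
    )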

@@ -0,0 +1,14 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
"""
This type stub file was generated by pyright.
"""
class T5Encoder(nn.Module):
def __init__(self) -> None: ...
def __call__(self, tokens: mx.array): ...

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5FeedForward(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5LayerNorm(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...

@@ -0,0 +1,16 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class T5SelfAttention(nn.Module):
def __init__(self) -> None: ...
def __call__(self, hidden_states: mx.array) -> mx.array: ...
@staticmethod
def shape(states): # -> array:
...
@staticmethod
def un_shape(states): # -> array:
...

@@ -0,0 +1,10 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AdaLayerNormContinuous(nn.Module):
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int) -> None: ...
def __call__(self, x: mx.array, text_embeddings: mx.array) -> mx.array: ...

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AdaLayerNormZero(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, hidden_states: mx.array, text_embeddings: mx.array
) -> tuple[mx.array, mx.array, mx.array, mx.array, mx.array]: ...

@@ -0,0 +1,12 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AdaLayerNormZeroSingle(nn.Module):
def __init__(self) -> None: ...
def __call__(
self, hidden_states: mx.array, text_embeddings: mx.array
) -> tuple[mx.array, mx.array]: ...

@@ -0,0 +1,41 @@
"""
This type stub file was generated by pyright.
"""
import mlx.core as mx
from mlx import nn
class AttentionUtils:
@staticmethod
def process_qkv(
hidden_states: mx.array,
to_q: nn.Linear,
to_k: nn.Linear,
to_v: nn.Linear,
norm_q: nn.RMSNorm,
norm_k: nn.RMSNorm,
num_heads: int,
head_dim: int,
) -> tuple[mx.array, mx.array, mx.array]: ...
@staticmethod
def compute_attention(
query: mx.array,
key: mx.array,
value: mx.array,
batch_size: int,
num_heads: int,
head_dim: int,
mask: mx.array | None = ...,
) -> mx.array: ...
@staticmethod
def convert_key_padding_mask_to_additive_mask(
mask: mx.array | None, joint_seq_len: int, txt_seq_len: int
) -> mx.array | None: ...
@staticmethod
def apply_rope(
xq: mx.array, xk: mx.array, freqs_cis: mx.array
) -> tuple[mx.array, mx.array]: ...
@staticmethod
def apply_rope_bshd(
xq: mx.array, xk: mx.array, cos: mx.array, sin: mx.array
) -> tuple[mx.array, mx.array]: ...

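AttentionUtils spells out the whole attention pipeline in its statics: project and normalise QKV, rotate with RoPE, then run scaled-dot-product attention. A hedged sketch of how a block might chain them; every self.* attribute is a placeholder, not a name taken from this diff:

def attention(self, hidden_states: mx.array, freqs_cis: mx.array) -> mx.array:
    batch_size = hidden_states.shape[0]
    q, k, v = AttentionUtils.process_qkv(
        hidden_states,
        to_q=self.to_q, to_k=self.to_k, to_v=self.to_v,  # assumed nn.Linear
        norm_q=self.norm_q, norm_k=self.norm_k,          # assumed nn.RMSNorm
        num_heads=self.num_heads, head_dim=self.head_dim,
    )
    q, k = AttentionUtils.apply_rope(q, k, freqs_cis)    # rotary embedding
    return AttentionUtils.compute_attention(
        q, k, v, batch_size, self.num_heads, self.head_dim, mask=None
    )
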
Some files were not shown because too many files have changed in this diff.