Compare commits

...

189 Commits

Author SHA1 Message Date
Sami Khan
fc8aaeaf48 small UI change 2026-01-07 04:19:43 +05:00
Sami Khan
24d99ada4a image gen in dashboard 2026-01-07 03:56:42 +05:00
ciaranbor
3bcce669d1 log 2026-01-06 17:03:57 +00:00
ciaranbor
a5c6db7145 Prune blocks before model load 2026-01-06 16:47:58 +00:00
ciaranbor
e74345bb09 Own TODOs 2026-01-06 13:30:17 +00:00
ciaranbor
58d1f159b7 Remove double RunnerReady event 2026-01-06 13:25:01 +00:00
ciaranbor
8183225714 Fix hidden_size for image models 2026-01-06 12:33:28 +00:00
ciaranbor
46f957ee5b Fix uv.lock 2026-01-06 12:25:23 +00:00
ciaranbor
a2f52e04e3 Fix image model cards 2026-01-06 11:00:01 +00:00
ciaranbor
859960608b Skip decode on non-final ranks 2026-01-06 10:51:21 +00:00
ciaranbor
80ad016004 Final rank produces image 2026-01-06 10:51:21 +00:00
ciaranbor
d8938e6e72 Increase number of sync steps 2026-01-06 10:51:21 +00:00
ciaranbor
bd6a6cc6d3 Change Qwen-Image steps 2026-01-06 10:51:21 +00:00
ciaranbor
a88d588de4 Fix Qwen-Image latent shapes 2026-01-06 10:51:21 +00:00
ciaranbor
08e8a30fb7 Fix joint block patch recv shape for non-zero ranks 2026-01-06 10:51:21 +00:00
ciaranbor
d926df8f95 Fix comms issue for models without single blocks 2026-01-06 10:51:21 +00:00
ciaranbor
4ff550106d Support Qwen in DiffusionRunner pipefusion 2026-01-06 10:51:21 +00:00
ciaranbor
1d168dfe61 Implement Qwen pipefusion 2026-01-06 10:51:21 +00:00
ciaranbor
f813b9f5e1 Add guidance_scale parameter to image model config 2026-01-06 10:51:21 +00:00
ciaranbor
0f96083d48 Move orchestration to DiffusionRunner 2026-01-06 10:51:21 +00:00
ciaranbor
1a1b394f6d Add initial QwenModelAdapter 2026-01-06 10:51:21 +00:00
ciaranbor
6ab5a9d3d4 Tweak embeddings interface 2026-01-06 10:51:21 +00:00
ciaranbor
90bf4608df Add Qwen ImageModelConfig 2026-01-06 10:51:21 +00:00
ciaranbor
f2a0fdf25c Use 10% sync steps 2026-01-06 10:51:21 +00:00
ciaranbor
f574b3f57e Update FluxModelAdaper for new interface 2026-01-06 10:51:21 +00:00
ciaranbor
5cca9d8493 Register QwenModelAdapter 2026-01-06 10:51:21 +00:00
ciaranbor
3cd421079b Support multiple forward passes in runner 2026-01-06 10:51:21 +00:00
ciaranbor
d9eb4637ee Extend block wrapper parameters 2026-01-06 10:51:21 +00:00
ciaranbor
19f52e80fd Relax adaptor typing 2026-01-06 10:51:21 +00:00
ciaranbor
3f4162b732 Add Qwen-Image model card 2026-01-06 10:51:21 +00:00
ciaranbor
cad86ee76e Clean up dead code 2026-01-06 10:51:21 +00:00
ciaranbor
d7be6a09b0 Add BaseModelAdaptor 2026-01-06 10:51:21 +00:00
ciaranbor
79603e73ed Refactor filestructure 2026-01-06 10:51:21 +00:00
ciaranbor
78901cfe23 Treat unified blocks as single blocks (equivalent) 2026-01-06 10:51:21 +00:00
ciaranbor
c0ac199ab8 Refactor to handle entire denoising process in Diffusion runner 2026-01-06 10:51:21 +00:00
ciaranbor
b70d6abfa2 Move transformer to adapter 2026-01-06 10:51:21 +00:00
ciaranbor
16bfab9bab Move some more logic to adaptor 2026-01-06 10:51:21 +00:00
ciaranbor
28ee6f6370 Add generic block wrapper 2026-01-06 10:51:21 +00:00
ciaranbor
6b299bab8f Access transformer blocks from adaptor 2026-01-06 10:51:21 +00:00
ciaranbor
a3754a60b6 Better typing 2026-01-06 10:51:21 +00:00
ciaranbor
06039f93f5 Create wrappers at init time 2026-01-06 10:51:21 +00:00
ciaranbor
fcfecc9cd8 Combine model factory and adaptor 2026-01-06 10:51:21 +00:00
ciaranbor
ba798ae4f9 Implement model factory 2026-01-06 10:51:21 +00:00
ciaranbor
9a0e1e93a9 Add adaptor registry 2026-01-06 10:51:21 +00:00
ciaranbor
196f504c82 Remove mflux/generator/generate.py 2026-01-06 10:51:21 +00:00
ciaranbor
e3d89b8d63 Switch to using DistributedImageModel 2026-01-06 10:51:21 +00:00
ciaranbor
cb8079525c Add DistributedImageModel 2026-01-06 10:51:21 +00:00
ciaranbor
cb03c62c4a Use new generic wrappers, etc in denoising 2026-01-06 10:51:21 +00:00
ciaranbor
0653668048 Add generic transformer block wrappers 2026-01-06 10:51:21 +00:00
ciaranbor
0054bc4c14 Add FluxAdaptor 2026-01-06 10:51:21 +00:00
ciaranbor
b7b682b7bb Add ModelAdaptor, derivations implement model specific logic 2026-01-06 10:51:21 +00:00
ciaranbor
f7a651c1c1 Introduce image model config concept 2026-01-06 10:51:21 +00:00
ciaranbor
98e8d74cea Consolidate kv cache patching 2026-01-06 10:51:21 +00:00
ciaranbor
27567f8a4e Support different configuration comms 2026-01-06 10:51:21 +00:00
ciaranbor
28227bb45a Add ImageGenerator protocol 2026-01-06 10:51:21 +00:00
ciaranbor
7683d4a21f Force final patch receive order 2026-01-06 10:51:21 +00:00
ciaranbor
0a3cb77a29 Remove logs 2026-01-06 10:51:21 +00:00
ciaranbor
3f5810c1fe Update patch list 2026-01-06 10:51:21 +00:00
ciaranbor
fc62ae1b9b Slight refactor 2026-01-06 10:51:21 +00:00
ciaranbor
ec5bad4254 Don't need array for prev patches 2026-01-06 10:51:21 +00:00
ciaranbor
f9f54be32b Fix send/recv order 2026-01-06 10:51:21 +00:00
ciaranbor
36daf9183f Fix async single transformer block 2026-01-06 10:51:21 +00:00
ciaranbor
5d38ffc77e Use relative rank variables 2026-01-06 10:51:21 +00:00
ciaranbor
1b4851765a Fix writing patches 2026-01-06 10:51:21 +00:00
ciaranbor
8787eaf3df Collect final image 2026-01-06 10:51:21 +00:00
ciaranbor
e1e3aa7a5e Fix recv_template shape 2026-01-06 10:51:21 +00:00
ciaranbor
0fe5239273 Add logs 2026-01-06 10:51:21 +00:00
ciaranbor
7eddf7404b Optimise async pipeline 2026-01-06 10:51:21 +00:00
ciaranbor
5f3bc30f17 Add next_rank and prev_rank members 2026-01-06 10:51:21 +00:00
ciaranbor
90a7e6601d Add _create_patches method 2026-01-06 10:51:21 +00:00
ciaranbor
ce2691c8d3 Fix shapes 2026-01-06 10:51:21 +00:00
ciaranbor
076d2901e8 Reorder comms 2026-01-06 10:51:20 +00:00
ciaranbor
7a733b584c Remove all_gather from sync pipeline, send from final rank to first rank 2026-01-06 10:51:20 +00:00
ciaranbor
94fee6f2d2 Simplify kv_cache initialization 2026-01-06 10:51:20 +00:00
ciaranbor
ef4fe09424 Fix kv cache 2026-01-06 10:51:20 +00:00
ciaranbor
2919bcf21d Clean up kv caches 2026-01-06 10:51:20 +00:00
ciaranbor
dd84cc9ca2 Fix return 2026-01-06 10:51:20 +00:00
ciaranbor
5a74d76d41 Fix hidden_states shapes 2026-01-06 10:51:20 +00:00
ciaranbor
e115814c74 Only perform projection and scheduler step on last rank 2026-01-06 10:51:20 +00:00
ciaranbor
d85432d4f0 Only compute embeddings on rank 0 2026-01-06 10:51:20 +00:00
ciaranbor
da823a2b02 Remove eval 2026-01-06 10:51:20 +00:00
ciaranbor
8576f4252b Remove eval 2026-01-06 10:51:20 +00:00
ciaranbor
7ca0bc5b55 Only send encoder_hidden_states with the first patch (once per timestep) 2026-01-06 10:51:20 +00:00
ciaranbor
db24f052d7 Remove redundant text kv cache computation 2026-01-06 10:51:20 +00:00
ciaranbor
7b8382be10 Concatenate before all gather 2026-01-06 10:51:20 +00:00
ciaranbor
d3685b0eb5 Increase number of sync steps 2026-01-06 10:51:20 +00:00
ciaranbor
93f4bdc5f9 Reinitialise kv_caches between generations 2026-01-06 10:51:20 +00:00
ciaranbor
8eea0327b8 Eliminate double kv cache computation 2026-01-06 10:51:20 +00:00
ciaranbor
085358e5e0 Add kv cache caching wrappers for sync pipeline transformer blocks 2026-01-06 10:51:20 +00:00
ciaranbor
546efe4dd2 Persist kv caches 2026-01-06 10:51:20 +00:00
ciaranbor
4ddfb6e254 Implement naive async pipeline implementation 2026-01-06 10:51:20 +00:00
ciaranbor
12f20fd94e Use wrapper classes for patched transformer logic 2026-01-06 10:51:20 +00:00
ciaranbor
f7ba70d5ae Add patch-aware joint and single attention wrappers 2026-01-06 10:51:20 +00:00
ciaranbor
4ecad10a66 Fix group.size() 2026-01-06 10:51:20 +00:00
ciaranbor
552ae776fe Add classes to manage kv caches with patch support 2026-01-06 10:51:20 +00:00
ciaranbor
6e0a6e8956 Use heuristic for number of sync steps 2026-01-06 10:51:20 +00:00
ciaranbor
e8b0a2124c Generalise number of denoising steps 2026-01-06 10:51:20 +00:00
ciaranbor
129df1ec89 Add flux1-dev 2026-01-06 10:51:20 +00:00
ciaranbor
a87fe26973 Move scheduler step to inner pipeline 2026-01-06 10:51:20 +00:00
ciaranbor
a9ea223dc7 Add barrier before all_gather 2026-01-06 10:51:20 +00:00
ciaranbor
0af3349f2f Fix transformer blocks pruning 2026-01-06 10:51:20 +00:00
ciaranbor
20e3319a3e Fix image generation api 2026-01-06 10:51:20 +00:00
ciaranbor
4c88fac266 Create queue in try block 2026-01-06 10:51:20 +00:00
ciaranbor
e1d916f743 Conform to rebase 2026-01-06 10:51:20 +00:00
ciaranbor
09c9b2e29f Refactor denoising 2026-01-06 10:51:20 +00:00
ciaranbor
b6359a7199 Move more logic to DistributedFlux 2026-01-06 10:51:20 +00:00
ciaranbor
b5a043f676 Move surrounding logic back to _sync_pipeline 2026-01-06 10:51:20 +00:00
ciaranbor
55e690fd49 Add patching aware member variables 2026-01-06 10:51:20 +00:00
ciaranbor
9e4ffb11ec Implement sync/async switching logic 2026-01-06 10:51:20 +00:00
ciaranbor
d665a8d05a Move current transformer implementation to _sync_pipeline method 2026-01-06 10:51:20 +00:00
ciaranbor
cac77816be Remove some logs 2026-01-06 10:51:20 +00:00
ciaranbor
25b9c3369e Remove old Flux1 implementation 2026-01-06 10:51:20 +00:00
ciaranbor
c19c5b4080 Prune unused transformer blocks 2026-01-06 10:51:20 +00:00
ciaranbor
9592f8b6b0 Add mx.eval 2026-01-06 10:51:20 +00:00
ciaranbor
7d7c16ebc1 Test evals 2026-01-06 10:51:20 +00:00
ciaranbor
450d0ba923 Test only barriers 2026-01-06 10:51:20 +00:00
ciaranbor
ea64062362 All perform final projection 2026-01-06 10:51:20 +00:00
ciaranbor
206b12e912 Another barrier 2026-01-06 10:51:20 +00:00
ciaranbor
eecc1da596 More debug 2026-01-06 10:51:20 +00:00
ciaranbor
44e68e4498 Add barriers 2026-01-06 10:51:20 +00:00
ciaranbor
f1548452fa Add log 2026-01-06 10:51:20 +00:00
ciaranbor
97769c82a9 Restore distributed logging 2026-01-06 10:51:20 +00:00
ciaranbor
26e5b03285 Use bootstrap logger 2026-01-06 10:51:20 +00:00
ciaranbor
8f93a1ff78 Remove logs 2026-01-06 10:51:20 +00:00
ciaranbor
e07dcc43b9 fix single block receive shape 2026-01-06 10:51:20 +00:00
ciaranbor
f91d0797fb Add debug logs 2026-01-06 10:51:20 +00:00
ciaranbor
aaeebaf79e Move communication logic to DistributedTransformer wrapper 2026-01-06 10:51:20 +00:00
ciaranbor
c3075a003e Move inference logic to DistribuedFlux1 2026-01-06 10:51:20 +00:00
ciaranbor
be796e55ac Add DistributedFlux1 class 2026-01-06 10:51:20 +00:00
ciaranbor
6e0c611f37 Rename pipeline to pipefusion 2026-01-06 10:51:20 +00:00
ciaranbor
88996eddcb Further refactor 2026-01-06 10:51:20 +00:00
ciaranbor
fb4fae51fa Refactor warmup 2026-01-06 10:51:20 +00:00
ciaranbor
dbefc209f5 Manually handle flux1 inference 2026-01-06 10:51:20 +00:00
ciaranbor
e6dd95524c Refactor flux1 image generation 2026-01-06 10:51:20 +00:00
ciaranbor
c2a9e5e53b Use quality parameter to set number of inference steps 2026-01-06 10:51:20 +00:00
ciaranbor
21587898bc Chunk image data transfer 2026-01-06 10:51:20 +00:00
ciaranbor
b6f23d0b01 Define EXO_MAX_CHUNK_SIZE 2026-01-06 10:51:20 +00:00
ciaranbor
f00ba03f4b Add indexing info to ImageChunk 2026-01-06 10:50:56 +00:00
ciaranbor
73e3713296 Remove sharding logs 2026-01-06 10:50:56 +00:00
ciaranbor
ecca6b4d20 Temp: reduce flux1.schnell storage size 2026-01-06 10:50:56 +00:00
ciaranbor
8bac08a236 Fix mflux transformer all_gather 2026-01-06 10:50:34 +00:00
ciaranbor
e7cca752fd Add all_gather -> broadcast todo 2026-01-06 10:50:34 +00:00
ciaranbor
540fe8b278 Fix world size 2026-01-06 10:50:34 +00:00
ciaranbor
2972f4620c Fix transition block? 2026-01-06 10:50:34 +00:00
ciaranbor
0ed81d8afa Implement image generation warmup 2026-01-06 10:50:34 +00:00
ciaranbor
66a24d59b9 Add logs 2026-01-06 10:50:11 +00:00
ciaranbor
5dcc359dba Add spiece.model to default patterns 2026-01-06 10:50:11 +00:00
ciaranbor
c2a4d61865 Just download all files for now 2026-01-06 10:49:43 +00:00
ciaranbor
ba12ee4897 Fix get_allow_patterns to include non-indexed safetensors files 2026-01-06 10:49:43 +00:00
ciaranbor
bcd69a3b01 Use half-open layer indexing in get_allow_patterns 2026-01-06 10:49:43 +00:00
ciaranbor
f5eb5d0338 Enable distributed mflux 2026-01-06 10:49:43 +00:00
ciaranbor
058aff5145 Implement mflux transformer sharding and communication pattern 2026-01-06 10:49:43 +00:00
ciaranbor
5cb0bc6a63 Update get_allow_patterns to handle sharding components 2026-01-06 10:49:43 +00:00
ciaranbor
c3aab450c6 Namespace both keys and values for component weight maps 2026-01-06 10:49:43 +00:00
ciaranbor
cf27673e20 Add components to Flux.1-schnell MODEL_CARD 2026-01-06 10:49:43 +00:00
ciaranbor
96c165e297 Add component concept for ModelMetadata 2026-01-06 10:48:42 +00:00
ciaranbor
2a589177cd Fix multiple components weight map key conflicts 2026-01-06 10:48:26 +00:00
ciaranbor
f782b619b6 get_weight_map: handle repos with multiple safetensors.index.json files 2026-01-06 10:48:26 +00:00
ciaranbor
dc661e4b5e Add initial image edits spec 2026-01-06 10:48:26 +00:00
ciaranbor
8b7d8ef394 Add image edits endpoint 2026-01-06 10:47:44 +00:00
ciaranbor
7dd2b328c8 Add ImageToImage task 2026-01-06 10:45:26 +00:00
ciaranbor
73a165702d Allow ModelCards to have multiple tasks 2026-01-06 10:44:53 +00:00
ciaranbor
0c76978b35 Fix text generation 2026-01-06 10:41:38 +00:00
ciaranbor
25188c845e Rename mlx_generate_image to mflux_generate 2026-01-06 10:41:38 +00:00
ciaranbor
df94169aba Initialize mlx or mflux engine based on model task 2026-01-06 10:39:27 +00:00
ciaranbor
a2d4c0de2a Restore warmup for text generation 2026-01-06 10:17:21 +00:00
ciaranbor
2edbc7e026 Add initialize_mflux function 2026-01-06 10:17:21 +00:00
ciaranbor
8f6e360d21 Move image generation to mflux engine 2026-01-06 10:17:21 +00:00
ciaranbor
085b966a5f Just use str for image generation size 2026-01-06 10:17:21 +00:00
ciaranbor
c64a55bfed Use MFlux for image generation 2026-01-06 10:17:21 +00:00
ciaranbor
fee716faab Add get_model_card function 2026-01-06 10:17:21 +00:00
ciaranbor
b88c89ee9c Add ModelTask enum 2026-01-06 10:17:21 +00:00
ciaranbor
9ef7b913e2 ADd flux1-schnell model 2026-01-06 10:08:11 +00:00
ciaranbor
0daa4b36db Add task field to ModelCard 2026-01-06 10:08:11 +00:00
ciaranbor
3c2da43792 Update mflux version 2026-01-06 10:04:51 +00:00
ciaranbor
8c4c53b50a Enable recursive repo downloads 2026-01-06 10:03:18 +00:00
ciaranbor
b2beb4c9cd Add dummy generate_image implementation 2026-01-06 10:03:18 +00:00
ciaranbor
098a11b262 Use base64 encoded str for image data 2026-01-06 10:03:18 +00:00
ciaranbor
bedb9045a0 Handle ImageGeneration tasks in _pending_tasks 2026-01-06 10:03:18 +00:00
ciaranbor
8e23841b4e Add mflux dependency 2026-01-06 10:03:18 +00:00
ciaranbor
4420eac10d Handle ImageGeneration task in runner task processing 2026-01-06 10:02:04 +00:00
ciaranbor
d0772e9e0f Handle ImageGeneration command in master command processing 2026-01-06 10:00:32 +00:00
ciaranbor
8d861168f1 Add image generation to API 2026-01-06 10:00:07 +00:00
ciaranbor
242648dff4 Add ImageGenerationResponse 2026-01-06 09:59:13 +00:00
ciaranbor
9b06b754cb Add ImageGeneration task 2026-01-06 09:59:13 +00:00
ciaranbor
1603984f45 Add ImageGeneration command 2026-01-06 09:58:46 +00:00
ciaranbor
f9418843f8 Add image generation params and response types 2026-01-05 21:51:10 +00:00
ciaranbor
877e7196c3 Add pillow dependency 2026-01-05 21:51:10 +00:00
ciaranbor
db7c4670b9 Fix mlx stream_generate import 2026-01-05 21:18:23 +00:00
40 changed files with 5361 additions and 339 deletions

View File

@@ -1,5 +1,5 @@
<script lang="ts">
import { isLoading, sendMessage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import { isLoading, sendMessage, generateImage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import ChatAttachments from './ChatAttachments.svelte';
import type { ChatUploadedFile } from '$lib/types/files';
import { processUploadedFiles, getAcceptString } from '$lib/types/files';
@@ -10,6 +10,7 @@
showHelperText?: boolean;
autofocus?: boolean;
showModelSelector?: boolean;
modelTasks?: Record<string, string[]>;
}
let {
@@ -17,7 +18,8 @@
placeholder = 'Ask anything',
showHelperText = false,
autofocus = true,
showModelSelector = false
showModelSelector = false,
modelTasks = {}
}: Props = $props();
let message = $state('');
@@ -48,13 +50,29 @@
// Accept all supported file types
const acceptString = getAcceptString(['image', 'text', 'pdf']);
// Check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const tasks = modelTasks[modelId] || [];
return tasks.includes('TextToImage') || tasks.includes('ImageToImage');
}
// Check if the currently selected model supports image generation
const isImageModel = $derived(() => {
if (!currentModel) return false;
return modelSupportsImageGeneration(currentModel);
});
// Extract available models from running instances
const availableModels = $derived(() => {
const models: Array<{id: string, label: string}> = [];
const models: Array<{id: string, label: string, isImageModel: boolean}> = [];
for (const [, instance] of Object.entries(instanceData)) {
const modelId = getInstanceModelId(instance);
if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
models.push({ id: modelId, label: modelId.split('/').pop() || modelId });
models.push({
id: modelId,
label: modelId.split('/').pop() || modelId,
isImageModel: modelSupportsImageGeneration(modelId)
});
}
}
return models;
@@ -160,7 +178,12 @@
uploadedFiles = [];
resetTextareaHeight();
sendMessage(content, files);
// Use image generation for image models
if (isImageModel() && content) {
generateImage(content);
} else {
sendMessage(content, files);
}
// Refocus the textarea after sending
setTimeout(() => textareaRef?.focus(), 10);
@@ -297,7 +320,14 @@
{:else}
<span class="w-3"></span>
{/if}
<span class="truncate">{model.label}</span>
{#if model.isImageModel}
<svg class="w-3.5 h-3.5 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate flex-1">{model.label}</span>
</button>
{/each}
</div>
@@ -357,7 +387,7 @@
onkeydown={handleKeydown}
oninput={handleInput}
onpaste={handlePaste}
{placeholder}
placeholder={isImageModel() ? 'Describe the image you want to generate...' : placeholder}
disabled={loading}
rows={1}
class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"
@@ -371,14 +401,23 @@
{!canSend || loading
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label="Send message"
aria-label={isImageModel() ? "Generate image" : "Send message"}
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<span class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"></span>
<span class="hidden sm:inline">PROCESSING</span>
<span class="hidden sm:inline">{isImageModel() ? 'GENERATING' : 'PROCESSING'}</span>
<span class="sm:hidden">...</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}

View File

@@ -365,10 +365,58 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
</div>
{/if}
<!-- Generated Images -->
{#if message.attachments?.some(a => a.type === 'generated-image')}
<div class="mb-3">
{#each message.attachments.filter(a => a.type === 'generated-image') as attachment}
<div class="relative group/img inline-block">
<img
src={attachment.preview}
alt=""
class="max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20"
/>
<!-- Download button overlay -->
<button
type="button"
class="absolute top-2 right-2 p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow opacity-0 group-hover/img:opacity-100 transition-opacity hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
onclick={() => {
if (attachment.preview) {
const link = document.createElement('a');
link.href = attachment.preview;
link.download = `generated-image-${Date.now()}.png`;
link.click();
}
}}
title="Download image"
>
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
</svg>
</button>
</div>
{/each}
</div>
{/if}
<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{#if message.content === 'Generating image...'}
<div class="flex items-center gap-3 text-exo-yellow">
<div class="relative">
<div class="w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"></div>
<svg class="absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
</div>
<span class="font-mono tracking-wider uppercase text-sm">Generating image...</span>
</div>
{:else if message.content || (loading && !message.attachments?.some(a => a.type === 'generated-image'))}
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
{/if}
</div>
</div>

View File

@@ -163,7 +163,7 @@ interface RawStateResponse {
}
export interface MessageAttachment {
type: 'image' | 'text' | 'file';
type: 'image' | 'text' | 'file' | 'generated-image';
name: string;
content?: string;
preview?: string;
@@ -1413,6 +1413,90 @@ class AppStore {
}
}
/**
* Generate an image using the image generation API
*/
async generateImage(prompt: string, modelId?: string): Promise<void> {
if (!prompt.trim() || this.isLoading) return;
if (!this.hasStartedChat) {
this.startChat();
}
this.isLoading = true;
this.currentResponse = '';
// Add user message
const userMessage: Message = {
id: generateUUID(),
role: 'user',
content: prompt,
timestamp: Date.now()
};
this.messages.push(userMessage);
// Create placeholder for assistant message with generating state
const assistantMessage = this.addMessage('assistant', '');
this.messages[this.messages.length - 1].content = 'Generating image...';
this.updateActiveConversation();
try {
// Determine the model to use
let model = modelId || this.selectedChatModel;
if (!model) {
throw new Error('No model selected. Please select an image generation model.');
}
const response = await fetch('/v1/images/generations', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
model,
prompt,
quality: 'medium',
size: '1024x1024',
response_format: 'b64_json'
})
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`API error: ${response.status} - ${errorText}`);
}
const data = await response.json();
const imageData = data.data?.[0]?.b64_json;
if (!imageData) {
throw new Error('No image data received from the API');
}
// Update the assistant message with the generated image
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = '';
this.messages[idx].attachments = [{
type: 'generated-image',
name: 'generated-image.png',
preview: `data:image/png;base64,${imageData}`,
mimeType: 'image/png'
}];
}
} catch (error) {
console.error('Error generating image:', error);
const idx = this.messages.findIndex(m => m.id === assistantMessage.id);
if (idx !== -1) {
this.messages[idx].content = `Error: ${error instanceof Error ? error.message : 'Failed to generate image'}`;
}
} finally {
this.isLoading = false;
this.updateActiveConversation();
}
}
/**
* Clear current chat and go back to welcome state
*/
@@ -1463,6 +1547,7 @@ export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
// Actions
export const startChat = () => appStore.startChat();
export const sendMessage = (content: string, files?: { id: string; name: string; type: string; textContent?: string; preview?: string }[]) => appStore.sendMessage(content, files);
export const generateImage = (prompt: string, modelId?: string) => appStore.generateImage(prompt, modelId);
export const clearChat = () => appStore.clearChat();
export const setSelectedChatModel = (modelId: string) => appStore.setSelectedModel(modelId);
export const selectPreviewModel = (modelId: string | null) => appStore.selectPreviewModel(modelId);

View File

@@ -47,7 +47,30 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number, tasks?: string[], hugging_face_id?: string}>>([]);
// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
const tasks: Record<string, string[]> = {};
for (const model of models) {
if (model.tasks && model.tasks.length > 0) {
// Map by short ID
tasks[model.id] = model.tasks;
// Also map by hugging_face_id from the API response
if (model.hugging_face_id) {
tasks[model.hugging_face_id] = model.tasks;
}
}
}
return tasks;
});
// Helper to check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const model = models.find(m => m.id === modelId || m.hugging_face_id === modelId);
if (!model?.tasks) return false;
return model.tasks.includes('TextToImage') || model.tasks.includes('ImageToImage');
}
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -1250,6 +1273,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
placeholder="Ask anything"
showHelperText={false}
showModelSelector={true}
modelTasks={modelTasks()}
/>
</div>
</div>
@@ -1471,8 +1495,18 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{@const foundModel = models.find(m => m.id === selectedModelId)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="text-exo-light-gray truncate">{foundModel.name || foundModel.id}</span>
{@const isImageModel = modelSupportsImageGeneration(foundModel.id)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="flex items-center gap-2 text-exo-light-gray truncate">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{foundModel.name || foundModel.id}</span>
</span>
<span class="text-white/50 text-xs flex-shrink-0">{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB</span>
</span>
{:else}
@@ -1517,6 +1551,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(model.id)}
<button
type="button"
onclick={() => {
@@ -1536,7 +1571,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
: 'text-white/30 cursor-default'
}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span class="flex-shrink-0 text-xs {modelCanFit ? 'text-white/50' : 'text-red-400/60'}">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
@@ -1733,7 +1777,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-7xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} />
<ChatForm placeholder="Ask anything" showModelSelector={true} modelTasks={modelTasks()} />
</div>
</div>
</div>

View File

@@ -35,6 +35,8 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux>=0.12.1",
]
[project.scripts]

View File

@@ -24,7 +24,7 @@ from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
ChatCompletionChoice,
@@ -34,6 +34,10 @@ from exo.shared.types.api import (
CreateInstanceResponse,
DeleteInstanceResponse,
FinishReason,
ImageData,
ImageEditsTaskParams,
ImageGenerationResponse,
ImageGenerationTaskParams,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -41,13 +45,15 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
TaskFinished,
)
@@ -84,12 +90,23 @@ def chunk_to_response(
)
async def resolve_model_meta(model_id: str) -> ModelMetadata:
def get_model_card(model_id: str) -> ModelCard | None:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for _, model_card in MODEL_CARDS.items():
if model_id == model_card.model_id:
return model_card
async def resolve_model_meta(model_id: str) -> ModelMetadata:
model_card = get_model_card(model_id)
if model_card is not None:
return model_card.metadata
else:
return await get_model_meta(model_id)
return await get_model_meta(model_id)
class API:
@@ -133,6 +150,7 @@ class API:
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._tg: TaskGroup | None = None
def reset(self, new_session_id: SessionId, result_clock: int):
@@ -141,6 +159,7 @@ class API:
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._chat_completion_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)
def unpause(self, result_clock: int):
@@ -172,6 +191,8 @@ class API:
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
self.app.post("/v1/images/generations")(self.image_generations)
# self.app.post("/v1/images/edits")(self.image_edits)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -525,6 +546,87 @@ class API:
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
async def image_generations(
    self, payload: ImageGenerationTaskParams
) -> ImageGenerationResponse:
    """Handle OpenAI-style image generation requests (non-streaming).

    Resolves the requested model, verifies a running instance serves it,
    dispatches an ImageGeneration command, then collects base64 image
    chunks from the event stream and reassembles them into full images.

    Raises:
        HTTPException: 404 when no instance is serving the model.
    """
    model_meta = await resolve_model_meta(payload.model)
    payload.model = model_meta.model_id

    if not any(
        instance.shard_assignments.model_id == payload.model
        for instance in self.state.instances.values()
    ):
        await self._trigger_notify_user_to_download_model(payload.model)
        raise HTTPException(
            status_code=404, detail=f"No instance found for model {payload.model}"
        )

    command = ImageGeneration(
        request_params=payload,
    )

    # Register the result queue BEFORE dispatching the command so chunks
    # emitted between send and registration cannot be silently dropped
    # (the event loop only forwards chunks to registered queues).
    self._image_generation_queues[command.command_id], recv = channel[ImageChunk]()

    # Collect all image chunks (non-streaming)
    num_images = payload.n or 1
    # Track chunks per image: {image_index: {chunk_index: data}}
    image_chunks: dict[int, dict[int, str]] = {}
    image_total_chunks: dict[int, int] = {}
    images_complete = 0

    try:
        await self._send(command)
        while images_complete < num_images:
            with recv as chunks:
                async for chunk in chunks:
                    if chunk.image_index not in image_chunks:
                        image_chunks[chunk.image_index] = {}
                        image_total_chunks[chunk.image_index] = chunk.total_chunks
                    image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data

                    # An image is complete once all of its chunks arrived.
                    if (
                        len(image_chunks[chunk.image_index])
                        == image_total_chunks[chunk.image_index]
                    ):
                        images_complete += 1
                        if images_complete >= num_images:
                            break

        # Reassemble images in request order, joining chunks by index.
        images: list[ImageData] = []
        for image_idx in range(num_images):
            chunks_dict = image_chunks[image_idx]
            full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
            images.append(
                ImageData(
                    b64_json=full_data
                    if payload.response_format == "b64_json"
                    else None,
                    url=None,  # URL format not implemented yet
                )
            )
        return ImageGenerationResponse(data=images)
    except anyio.get_cancelled_exc_class():
        # TODO(ciaran): TaskCancelled
        """
        self.command_sender.send_nowait(
            ForwarderCommand(origin=self.node_id, command=command)
        )
        """
        raise
    finally:
        # Always signal completion and unregister the queue, even on error.
        await self._send(TaskFinished(finished_command_id=command.command_id))
        del self._image_generation_queues[command.command_id]
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
@@ -547,6 +649,7 @@ class API:
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
tasks=[task.value for task in card.tasks],
)
for card in MODEL_CARDS.values()
]
@@ -584,14 +687,17 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if (
isinstance(event, ChunkGenerated)
and event.command_id in self._chat_completion_queues
):
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
if isinstance(event, ChunkGenerated):
if event.command_id in self._chat_completion_queues:
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
elif event.command_id in self._image_generation_queues:
assert isinstance(event.chunk, ImageChunk)
await self._image_generation_queues[event.command_id].send(
event.chunk
)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -16,6 +16,8 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
RequestEventLog,
TaskFinished,
@@ -35,6 +37,12 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion as ChatCompletionTask,
)
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
)
from exo.shared.types.tasks import (
ImageGeneration as ImageGenerationTask,
)
from exo.shared.types.tasks import (
TaskId,
TaskStatus,
@@ -146,6 +154,94 @@ class Master:
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageGeneration():
# TODO(ciaran): refactor with ChatCompletion
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageEdits():
# TODO(ciaran): refactor with ChatCompletion
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(command, self.state.instances)

View File

@@ -44,3 +44,5 @@ LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024

View File

@@ -1,5 +1,5 @@
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ComponentInfo, ModelId, ModelMetadata, ModelTask
from exo.utils.pydantic_ext import CamelCaseModel
@@ -8,6 +8,7 @@ class ModelCard(CamelCaseModel):
model_id: ModelId
name: str
description: str
tasks: list[ModelTask]
tags: list[str]
metadata: ModelMetadata
@@ -45,6 +46,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
@@ -60,6 +62,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
@@ -133,6 +136,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
@@ -148,6 +152,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
@@ -164,6 +169,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
@@ -179,6 +185,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
@@ -194,6 +201,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
@@ -209,6 +217,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
@@ -225,6 +234,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
@@ -240,6 +250,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
@@ -255,6 +266,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
@@ -271,6 +283,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
@@ -286,6 +299,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
@@ -301,6 +315,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
@@ -317,6 +332,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
@@ -332,6 +348,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
@@ -347,6 +364,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
@@ -362,6 +380,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
@@ -377,6 +396,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
@@ -392,6 +412,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
@@ -407,6 +428,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
@@ -422,6 +444,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
@@ -437,6 +460,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
@@ -452,6 +476,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
@@ -467,6 +492,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
@@ -482,6 +508,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
@@ -498,6 +525,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
@@ -513,6 +541,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
@@ -529,6 +558,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
@@ -544,6 +574,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
@@ -569,4 +600,146 @@ MODEL_CARDS: dict[str, ModelCard] = {
# supports_tensor=True,
# ),
# ),
"flux1-schnell": ModelCard(
short_id="flux1-schnell",
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
name="FLUX.1 [schnell]",
description="""FLUX.1 [schnell] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
pretty_name="FLUX.1 [schnell]",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(23782357120), # + 9524621312),
n_layers=57, # sharded layers
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"flux1-dev": ModelCard(
short_id="flux1-dev",
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
name="FLUX.1 [dev]",
description="""FLUX.1 [dev] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
pretty_name="FLUX.1 [dev]",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57, # sharded layers
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"qwen-image": ModelCard(
short_id="qwen-image",
model_id=ModelId("Qwen/Qwen-Image"),
name="Qwen Image",
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("Qwen/Qwen-Image"),
pretty_name="Qwen Image",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
}

View File

@@ -1,6 +1,7 @@
import time
from typing import Any, Literal
from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
@@ -27,6 +28,7 @@ class ModelListModel(BaseModel):
tags: list[str] = Field(default=[])
storage_size_megabytes: int = Field(default=0)
supports_tensor: bool = Field(default=False)
tasks: list[str] = Field(default=[])
class ModelList(BaseModel):
@@ -181,3 +183,38 @@ class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId
class ImageGenerationTaskParams(BaseModel):
    """Request body for POST /v1/images/generations (OpenAI-compatible)."""

    prompt: str  # text description of the image to generate
    model: str  # model id or registry short id; resolved server-side
    n: int | None = 1  # number of images to generate
    quality: Literal["high", "medium", "low"] | None = "medium"
    output_format: Literal["png", "jpeg", "webp"] = "png"
    response_format: Literal["url", "b64_json"] | None = "b64_json"  # "url" not implemented yet
    size: str | None = "1024x1024"  # "WIDTHxHEIGHT" string — TODO confirm parse site
    user: str | None = None  # opaque end-user identifier (OpenAI compat)
class ImageEditsTaskParams(BaseModel):
    """Request body for POST /v1/images/edits (OpenAI-compatible)."""

    image: UploadFile  # source image to edit
    mask: UploadFile | None  # presumably marks the editable region — confirm against runner
    prompt: str  # text description of the desired edit
    model: str  # model id or registry short id
    n: int | None = 1  # number of edited images to return
    quality: Literal["high", "medium", "low"] | None = "medium"
    output_format: Literal["png", "jpeg", "webp"] = "png"
    response_format: Literal["url", "b64_json"] | None = "b64_json"
    size: str | None = "1024x1024"
    user: str | None = None  # opaque end-user identifier (OpenAI compat)
class ImageData(BaseModel):
    """One generated image inside an ImageGenerationResponse."""

    b64_json: str | None = None  # base64 payload when response_format == "b64_json"
    url: str | None = None  # always None for now — URL responses not implemented
    revised_prompt: str | None = None  # prompt rewrite, if the backend supplies one
class ImageGenerationResponse(BaseModel):
    """OpenAI-compatible response body for image generation requests."""

    created: int = Field(default_factory=lambda: int(time.time()))  # unix timestamp
    data: list[ImageData]  # generated images, in request order

View File

@@ -23,7 +23,10 @@ class TokenChunk(BaseChunk):
class ImageChunk(BaseChunk):
data: bytes
data: str
chunk_index: int
total_chunks: int
image_index: int
GenerationChunk = TokenChunk | ImageChunk

View File

@@ -1,6 +1,10 @@
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsTaskParams,
ImageGenerationTaskParams,
)
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -20,6 +24,14 @@ class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
class ImageGeneration(BaseCommand):
    """Command asking the master to schedule an image generation task."""

    request_params: ImageGenerationTaskParams
class ImageEdits(BaseCommand):
    """Command asking the master to schedule an image editing task."""

    request_params: ImageEditsTaskParams
class PlaceInstance(BaseCommand):
model_meta: ModelMetadata
sharding: Sharding
@@ -47,6 +59,8 @@ Command = (
TestCommand
| RequestEventLog
| ChatCompletion
| ImageGeneration
| ImageEdits
| PlaceInstance
| CreateInstance
| DeleteInstance

View File

@@ -1,3 +1,5 @@
from enum import Enum
from pydantic import PositiveInt
from exo.shared.types.common import Id
@@ -9,6 +11,21 @@ class ModelId(Id):
pass
class ModelTask(str, Enum):
    """Capabilities a model card can advertise (used for UI filtering)."""

    TextGeneration = "TextGeneration"
    TextToImage = "TextToImage"
    ImageToImage = "ImageToImage"
class ComponentInfo(CamelCaseModel):
    """Describes one component of a multi-component model (e.g. a diffusion
    pipeline's text_encoder / transformer / vae subdirectories)."""

    component_name: str  # e.g. "transformer", "vae", "text_encoder"
    component_path: str  # repo-relative directory, e.g. "transformer/"
    storage_size: Memory  # on-disk size of this component's weights
    n_layers: PositiveInt | None  # None when a layer count does not apply (e.g. vae)
    can_shard: bool  # True if this component may be split across nodes
    safetensors_index_filename: str | None  # None for single-file components
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
@@ -16,3 +33,4 @@ class ModelMetadata(CamelCaseModel):
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
components: list[ComponentInfo] | None = None

View File

@@ -2,7 +2,11 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsTaskParams,
ImageGenerationTaskParams,
)
from exo.shared.types.common import CommandId, Id
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
@@ -56,6 +60,22 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class ImageGeneration(BaseTask):  # emitted by Master
    """Task assigning an image generation request to an instance."""

    command_id: CommandId  # originating API command, for result routing
    task_params: ImageGenerationTaskParams
    error_type: str | None = Field(default=None)  # set on failure
    error_message: str | None = Field(default=None)  # set on failure
class ImageEdits(BaseTask):  # emitted by Master
    """Task assigning an image editing request to an instance."""

    command_id: CommandId  # originating API command, for result routing
    task_params: ImageEditsTaskParams
    error_type: str | None = Field(default=None)  # set on failure
    error_message: str | None = Field(default=None)  # set on failure
class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
@@ -67,5 +87,7 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| ImageGeneration
| ImageEdits
| Shutdown
)

View File

@@ -1,3 +1,4 @@
from typing import Literal
from exo.shared.types.api import FinishReason
from exo.utils.pydantic_ext import TaggedModel
@@ -17,5 +18,10 @@ class GenerationResponse(BaseRunnerResponse):
finish_reason: FinishReason | None = None
class ImageGenerationResponse(BaseRunnerResponse):
    """Runner response carrying one fully generated, encoded image."""

    image_data: bytes  # raw encoded image bytes in `format`
    format: Literal["png", "jpeg", "webp"] = "png"  # encoding of image_data
class FinishedResponse(BaseRunnerResponse):
    """Marker response: the runner has finished the current task."""

    pass

View File

@@ -9,6 +9,7 @@ from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal
from urllib.parse import urljoin
from huggingface_hub._snapshot_download import snapshot_download
import aiofiles
import aiofiles.os as aios
@@ -441,12 +442,31 @@ def calculate_repo_progress(
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
    """Build a combined weight map from every *.safetensors.index.json in a repo.

    Index files found in subdirectories (multi-component repos such as
    diffusion pipelines) have both their tensor names and filenames prefixed
    with the subdirectory, so entries from different components cannot collide.

    Args:
        repo_id: Hugging Face repo id, e.g. "Qwen/Qwen-Image".
        revision: Kept for interface compatibility — NOTE(review): no longer
            forwarded to the download; snapshot_download fetches the default
            revision. Confirm whether pinning is still required.

    Returns:
        Mapping of (possibly prefixed) tensor name -> safetensors filename.
    """
    target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
    await aios.makedirs(target_dir, exist_ok=True)

    # NOTE(review): snapshot_download is a blocking call inside an async
    # function — consider offloading to a thread to avoid stalling the loop.
    index_files_dir = snapshot_download(
        repo_id=repo_id, local_dir=target_dir, allow_patterns="*.safetensors.index.json"
    )
    index_files = list(Path(index_files_dir).glob("**/*.safetensors.index.json"))

    weight_map: dict[str, str] = {}
    for index_file in index_files:
        relative_dir = index_file.parent.relative_to(index_files_dir)
        async with aiofiles.open(index_file, "r") as f:
            index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
        if relative_dir != Path("."):
            # Namespace entries from a component subdirectory, e.g.
            # "transformer/blocks.0.weight" -> "transformer/model-00001.safetensors".
            prefixed_weight_map = {
                f"{relative_dir}/{key}": str(relative_dir / value)
                for key, value in index_data.weight_map.items()
            }
            weight_map = weight_map | prefixed_weight_map
        else:
            # Top-level index file: merge as-is.
            weight_map = weight_map | index_data.weight_map
    return weight_map
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
@@ -551,8 +571,6 @@ async def download_shard(
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_meta.model_id), revision, recursive=True
)

View File

@@ -100,26 +100,68 @@ def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list
"*.py",
"tokenizer.model",
"tiktoken.model",
"*/spiece.model",
"*.tiktoken",
"*.txt",
"*.jinja",
]
)
shard_specific_patterns: set[str] = set()
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num <= shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
if shard.model_meta.components is not None:
shardable_component = next(
(c for c in shard.model_meta.components if c.can_shard), None
)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
if weight_map and shardable_component:
for tensor_name, filename in weight_map.items():
# Strip component prefix from tensor name (added by weight map namespacing)
# E.g., "transformer/blocks.0.weight" -> "blocks.0.weight"
if "/" in tensor_name:
_, tensor_name_no_prefix = tensor_name.split("/", 1)
else:
tensor_name_no_prefix = tensor_name
# Determine which component this file belongs to from filename
component_path = Path(filename).parts[0] if "/" in filename else None
if component_path == shardable_component.component_path.rstrip("/"):
layer_num = extract_layer_num(tensor_name_no_prefix)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
if shard.is_first_layer or shard.is_last_layer:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns = set(["*.safetensors"])
# TODO(ciaran): temporary - Include all files from non-shardable components that have no index file
for component in shard.model_meta.components:
if not component.can_shard and component.safetensors_index_filename is None:
component_pattern = f"{component.component_path.rstrip('/')}/*"
shard_specific_patterns.add(component_pattern)
else:
shard_specific_patterns = set(["*.safetensors"])
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
else:
shard_specific_patterns = set(["*.safetensors"])
logger.info(f"get_allow_patterns {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)

View File

@@ -0,0 +1,10 @@
from exo.worker.engines.image.base import ImageGenerator
from exo.worker.engines.image.distributed_model import initialize_image_model
from exo.worker.engines.image.generate import generate_image, warmup_image_generator
__all__ = [
"ImageGenerator",
"generate_image",
"initialize_image_model",
"warmup_image_generator",
]

View File

@@ -0,0 +1,37 @@
from typing import Literal, Optional, Protocol, runtime_checkable
from PIL import Image
@runtime_checkable
class ImageGenerator(Protocol):
    """Structural interface for anything that can turn a prompt into an image."""
    @property
    def rank(self) -> int: ...
    @property
    def is_first_stage(self) -> bool: ...
    def generate(
        self,
        prompt: str,
        height: int,
        width: int,
        quality: Literal["low", "medium", "high"],
        seed: int,
    ) -> Optional[Image.Image]:
        """Generate an image from a text prompt.
        For distributed inference, only one stage returns the image; the
        other stages return None after participating in the pipeline.
        NOTE(review): this docstring says rank 0 returns the image, but
        DistributedImageModel.generate returns it on the LAST stage —
        confirm which rank is the producer and align the docs.
        Args:
            prompt: Text description of the image to generate
            height: Image height in pixels
            width: Image width in pixels
            quality: Generation quality level
            seed: Random seed for reproducibility
        Returns:
            Generated PIL Image (producing stage) or None (other ranks)
        """
        ...

View File

@@ -0,0 +1,74 @@
from enum import Enum
from math import ceil
from pydantic import BaseModel
class BlockType(Enum):
    """Kinds of transformer blocks found in supported diffusion models."""
    JOINT = "joint"  # Separate image/text streams
    SINGLE = "single"  # Concatenated streams
class TransformerBlockConfig(BaseModel):
    """Immutable description of one homogeneous run of transformer blocks."""
    model_config = {"frozen": True}
    block_type: BlockType
    # Number of consecutive blocks of this type.
    count: int
    has_separate_text_output: bool  # True for joint blocks that output text separately
class ImageModelConfig(BaseModel):
    """Frozen, declarative description of one supported image model."""
    model_config = {"frozen": True}
    # Model identification
    model_family: str  # "flux", "fibo", "qwen"
    model_variant: str  # "schnell", "dev", etc.
    # Architecture parameters
    hidden_dim: int
    num_heads: int
    head_dim: int
    # Block configuration - ordered sequence of block types
    block_configs: tuple[TransformerBlockConfig, ...]
    # Tokenization parameters
    patch_size: int  # 2 for Flux/Qwen
    vae_scale_factor: int  # 8 for Flux, 16 for others
    # Inference parameters
    default_steps: dict[str, int]  # {"low": X, "medium": Y, "high": Z}
    num_sync_steps_factor: float  # Fraction of steps for sync phase
    # Feature flags
    uses_attention_mask: bool  # True for Fibo
    # CFG (Classifier-Free Guidance) parameters
    guidance_scale: float | None = None  # None or <= 1.0 disables CFG
    def _blocks_matching(self, wanted: BlockType) -> int:
        # Shared counter behind the joint/single block properties.
        return sum(group.count for group in self.block_configs if group.block_type == wanted)
    @property
    def total_blocks(self) -> int:
        """Total number of transformer blocks."""
        total = 0
        for group in self.block_configs:
            total += group.count
        return total
    @property
    def joint_block_count(self) -> int:
        """Number of joint transformer blocks."""
        return self._blocks_matching(BlockType.JOINT)
    @property
    def single_block_count(self) -> int:
        """Number of single transformer blocks."""
        return self._blocks_matching(BlockType.SINGLE)
    def get_steps_for_quality(self, quality: str) -> int:
        """Get inference steps for a quality level (raises KeyError if unknown)."""
        return self.default_steps[quality]
    def get_num_sync_steps(self, quality: str) -> int:
        """Get number of synchronous steps based on quality."""
        return ceil(self.get_steps_for_quality(quality) * self.num_sync_steps_factor)

View File

@@ -0,0 +1,222 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional
import mlx.core as mx
from mflux.config.config import Config
from PIL import Image
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
create_adapter_for_model,
get_config_for_model,
)
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline import DiffusionRunner
from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier
from exo.worker.runner.bootstrap import logger
class DistributedImageModel:
    """Wraps a model adapter + DiffusionRunner for single-node or pipelined
    multi-rank image generation.

    Attribute access not covered by the slots below is delegated to the
    underlying mflux model via __getattr__/__setattr__.
    """
    __slots__ = (
        "_config",
        "_adapter",
        "_group",
        "_shard_metadata",
        "_runner",
    )
    _config: ImageModelConfig
    _adapter: BaseModelAdapter
    _group: Optional[mx.distributed.Group]
    _shard_metadata: PipelineShardMetadata
    _runner: DiffusionRunner
    def __init__(
        self,
        model_id: str,
        local_path: Path,
        shard_metadata: PipelineShardMetadata,
        group: Optional[mx.distributed.Group] = None,
        quantize: int | None = None,
    ):
        """Load the model, optionally slice its blocks for this rank, and
        build the DiffusionRunner.

        Args:
            model_id: Model identifier used to look up config and adapter.
            local_path: Local directory containing the model weights.
            shard_metadata: Layer range and rank info for this node.
            group: MLX distributed group; None means single-node mode.
            quantize: Optional quantization bit width passed to the adapter.
        """
        # Get model config and create adapter (adapter owns the model)
        config = get_config_for_model(model_id)
        adapter = create_adapter_for_model(config, model_id, local_path, quantize)
        # NOTE(review): banner prints below look like debug leftovers —
        # consider switching them to logger.debug.
        print(
            "=====================================================================\n"
            + f"num layers: {len(adapter.transformer.transformer_blocks)}"
            + "\n====================================================================="
        )
        if group is not None:
            # Drop blocks outside this rank's [start_layer, end_layer) before
            # weights are evaluated, so unused weights are never materialized.
            adapter.slice_transformer_blocks(
                start_layer=shard_metadata.start_layer,
                end_layer=shard_metadata.end_layer,
                total_joint_blocks=config.joint_block_count,
                total_single_blocks=config.single_block_count,
            )
        print(
            "=====================================================================\n"
            + f"num layers: {len(adapter.transformer.transformer_blocks)}"
            + "\n====================================================================="
        )
        # Create diffusion runner (handles both single-node and distributed modes)
        # Sync-step count is derived from the "medium" quality preset here;
        # single-node mode needs no sync steps.
        num_sync_steps = config.get_num_sync_steps("medium") if group else 0
        runner = DiffusionRunner(
            config=config,
            adapter=adapter,
            group=group,
            shard_metadata=shard_metadata,
            num_sync_steps=num_sync_steps,
        )
        if group is not None:
            logger.info("Initialized distributed diffusion runner")
            mx.eval(adapter.model.parameters())
            # TODO(ciaran): Do we need this?
            mx.eval(adapter.model)
            # Synchronize processes before generation to avoid timeout
            mx_barrier(group)
            logger.info(f"Transformer sharded for rank {group.rank()}")
        else:
            logger.info("Single-node initialization")
        # object.__setattr__ bypasses the delegating __setattr__ defined below
        # so the slot attributes land on this instance, not on the model.
        object.__setattr__(self, "_config", config)
        object.__setattr__(self, "_adapter", adapter)
        object.__setattr__(self, "_group", group)
        object.__setattr__(self, "_shard_metadata", shard_metadata)
        object.__setattr__(self, "_runner", runner)
    @classmethod
    def from_bound_instance(
        cls, bound_instance: BoundInstance
    ) -> "DistributedImageModel":
        """Alternate constructor: derive all init arguments from a BoundInstance.

        Raises:
            ValueError: If the bound shard is not PipelineShardMetadata.
        """
        model_id = bound_instance.bound_shard.model_meta.model_id
        model_path = build_model_path(model_id)
        shard_metadata = bound_instance.bound_shard
        if not isinstance(shard_metadata, PipelineShardMetadata):
            raise ValueError("Expected PipelineShardMetadata for image generation")
        # Distributed iff more than one runner participates in the instance.
        is_distributed = (
            len(bound_instance.instance.shard_assignments.node_to_runner) > 1
        )
        if is_distributed:
            logger.info("Starting distributed init for image model")
            group = mlx_distributed_init(bound_instance)
        else:
            group = None
        return cls(
            model_id=model_id,
            local_path=model_path,
            shard_metadata=shard_metadata,
            group=group,
        )
    @property
    def model(self) -> Any:
        """Return the underlying mflux model via the adapter."""
        return self._adapter.model
    @property
    def config(self) -> ImageModelConfig:
        return self._config
    @property
    def adapter(self) -> BaseModelAdapter:
        return self._adapter
    @property
    def group(self) -> Optional[mx.distributed.Group]:
        return self._group
    @property
    def shard_metadata(self) -> PipelineShardMetadata:
        return self._shard_metadata
    @property
    def rank(self) -> int:
        return self._shard_metadata.device_rank
    @property
    def world_size(self) -> int:
        return self._shard_metadata.world_size
    @property
    def is_first_stage(self) -> bool:
        return self._shard_metadata.device_rank == 0
    @property
    def is_last_stage(self) -> bool:
        return self._shard_metadata.device_rank == self._shard_metadata.world_size - 1
    @property
    def is_distributed(self) -> bool:
        return self._shard_metadata.world_size > 1
    @property
    def runner(self) -> DiffusionRunner:
        return self._runner
    # Delegate attribute access to the underlying model via the adapter.
    # Guarded with TYPE_CHECKING to prevent type checker complaints
    # while still providing full delegation at runtime.
    if not TYPE_CHECKING:
        def __getattr__(self, name: str) -> Any:
            # Only called when normal lookup fails, i.e. for non-slot names.
            return getattr(self._adapter.model, name)
        def __setattr__(self, name: str, value: Any) -> None:
            # Slot attributes stay on this wrapper; everything else is
            # forwarded to the wrapped model.
            if name in (
                "_config",
                "_adapter",
                "_group",
                "_shard_metadata",
                "_runner",
            ):
                object.__setattr__(self, name, value)
            else:
                setattr(self._adapter.model, name, value)
    def generate(
        self,
        prompt: str,
        height: int,
        width: int,
        quality: Literal["low", "medium", "high"] = "medium",
        seed: int = 2,
    ) -> Optional[Image.Image]:
        """Run the diffusion pipeline and return the image on the last stage.

        NOTE(review): the ImageGenerator protocol docstring says rank 0
        returns the image; here only is_last_stage does — confirm intent.
        Non-final ranks implicitly return None.
        """
        # Determine number of inference steps based on quality
        steps = self._config.get_steps_for_quality(quality)
        config = Config(num_inference_steps=steps, height=height, width=width)
        image = self._generate_image(settings=config, prompt=prompt, seed=seed)
        logger.info("generated image")
        # Only final rank returns the actual image
        if self.is_last_stage:
            return image.image
    def _generate_image(self, settings: Config, prompt: str, seed: int) -> Any:
        """Generate image by delegating to the runner."""
        return self._runner.generate_image(
            settings=settings,
            prompt=prompt,
            seed=seed,
        )
def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:
    """Build a DistributedImageModel for the given bound instance.

    Thin convenience wrapper around the classmethod constructor.
    """
    model = DistributedImageModel.from_bound_instance(bound_instance)
    return model

View File

@@ -0,0 +1,72 @@
import io
from typing import Generator, Literal
from PIL import Image
from exo.shared.types.api import ImageGenerationTaskParams
from exo.shared.types.worker.runner_response import ImageGenerationResponse
from exo.worker.engines.image.base import ImageGenerator
def parse_size(size_str: str | None) -> tuple[int, int]:
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
if not size_str or size_str == "auto":
size_str = "1024x1024"
try:
parts = size_str.split("x")
if len(parts) == 2:
width, height = int(parts[0]), int(parts[1])
return (width, height)
except (ValueError, AttributeError):
pass
# Default fallback
return (1024, 1024)
def warmup_image_generator(model: ImageGenerator) -> Image.Image | None:
    """Run a tiny, low-quality generation to warm up the model.

    Returns whatever the generator produces for the warmup prompt — an
    image on the producing stage, None on the others.
    """
    warmup_request = dict(
        prompt="Warmup",
        height=256,
        width=256,
        quality="low",
        seed=2,
    )
    return model.generate(**warmup_request)
def generate_image(
    model: ImageGenerator,
    task: ImageGenerationTaskParams,
) -> Generator[ImageGenerationResponse, None, None]:
    """Execute one image-generation task and yield the encoded result.

    Yields a single ImageGenerationResponse on the rank whose generate()
    produced an image; yields nothing on the other ranks.
    """
    # Resolve request parameters, falling back to sensible defaults.
    width, height = parse_size(task.size)
    quality: Literal["low", "medium", "high"] = task.quality or "medium"
    fixed_seed = 2  # TODO(ciaran): Consider adding seed to ImageGenerationTaskParams
    result = model.generate(
        prompt=task.prompt,
        height=height,
        width=width,
        quality=quality,
        seed=fixed_seed,
    )
    # Non-producing ranks receive None and emit nothing.
    if result is None:
        return
    # Encode the PIL image into the requested container format.
    pil_format = task.output_format.upper()
    if pil_format == "JPG":
        pil_format = "JPEG"  # PIL only understands "JPEG"
    buffer = io.BytesIO()
    result.save(buffer, format=pil_format)
    # Send complete image as single response (no artificial chunking)
    yield ImageGenerationResponse(
        image_data=buffer.getvalue(),
        format=task.output_format,
    )

View File

@@ -0,0 +1,80 @@
from pathlib import Path
from typing import Callable
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.flux import (
FLUX_DEV_CONFIG,
FLUX_SCHNELL_CONFIG,
FluxModelAdapter,
)
from exo.worker.engines.image.models.qwen import (
QWEN_IMAGE_CONFIG,
QwenModelAdapter,
)
from exo.worker.engines.image.pipeline.adapter import ModelAdapter
# This module intentionally exports nothing by name; consumers use the
# lookup helpers below.
__all__: list[str] = []
# Type alias for adapter factory functions
# Factory takes (config, model_id, local_path, quantize) and returns a ModelAdapter
AdapterFactory = Callable[[ImageModelConfig, str, Path, int | None], ModelAdapter]
# Registry maps model_family string to adapter factory
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
    "flux": FluxModelAdapter,
    "qwen": QwenModelAdapter,
}
# Config registry: maps model ID patterns to configs
# Patterns are matched by lowercase substring, in insertion order.
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
    "flux.1-schnell": FLUX_SCHNELL_CONFIG,
    "flux.1-dev": FLUX_DEV_CONFIG,
    "qwen-image": QWEN_IMAGE_CONFIG,
}
def get_config_for_model(model_id: str) -> ImageModelConfig:
    """Get configuration for a model ID.

    Patterns are matched as lowercase substrings in registry insertion
    order; the first hit wins.

    Args:
        model_id: The model identifier (e.g., "black-forest-labs/FLUX.1-schnell")

    Returns:
        The model configuration

    Raises:
        ValueError: If no configuration found for model ID
    """
    needle = model_id.lower()
    matched = next(
        (cfg for pattern, cfg in _CONFIG_REGISTRY.items() if pattern in needle),
        None,
    )
    if matched is None:
        raise ValueError(f"No configuration found for model: {model_id}")
    return matched
def create_adapter_for_model(
    config: ImageModelConfig,
    model_id: str,
    local_path: Path,
    quantize: int | None = None,
) -> ModelAdapter:
    """Create a model adapter for the given configuration.

    Args:
        config: The model configuration
        model_id: The model identifier
        local_path: Path to the model weights
        quantize: Optional quantization bits

    Returns:
        A ModelAdapter instance

    Raises:
        ValueError: If no adapter found for model family
    """
    family = config.model_family
    if family not in _ADAPTER_REGISTRY:
        raise ValueError(f"No adapter found for model family: {family}")
    build_adapter = _ADAPTER_REGISTRY[family]
    return build_adapter(config, model_id, local_path, quantize)

View File

@@ -0,0 +1,92 @@
from abc import ABC, abstractmethod
from typing import Any
import mlx.core as mx
from mflux.config.runtime_config import RuntimeConfig
from mflux.models.common.latent_creator.latent_creator import Img2Img, LatentCreator
from mflux.utils.array_util import ArrayUtil
from mflux.utils.image_util import ImageUtil
class BaseModelAdapter(ABC):
    """Base class for model adapters with shared utilities.
    Provides common implementations for latent creation and decoding.
    Subclasses implement model-specific prompt encoding and noise computation.
    """
    def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
        """Create initial latents. Uses model-specific latent creator.

        Supports both txt2img and img2img: when runtime_config.image_path
        is set, the Img2Img helper seeds from the encoded source image.
        """
        return LatentCreator.create_for_txt2img_or_img2img(
            seed=seed,
            height=runtime_config.height,
            width=runtime_config.width,
            img2img=Img2Img(
                vae=self.model.vae,
                latent_creator=self._get_latent_creator(),
                sigmas=runtime_config.scheduler.sigmas,
                init_time_step=runtime_config.init_time_step,
                image_path=runtime_config.image_path,
            ),
        )
    def decode_latents(
        self,
        latents: mx.array,
        runtime_config: RuntimeConfig,
        seed: int,
        prompt: str,
    ) -> Any:
        """Decode latents to image. Shared implementation.

        Unpacks the latent grid to spatial form, runs the VAE decoder,
        and wraps the result (with generation metadata) via ImageUtil.
        """
        latents = ArrayUtil.unpack_latents(
            latents=latents,
            height=runtime_config.height,
            width=runtime_config.width,
        )
        decoded = self.model.vae.decode(latents)
        return ImageUtil.to_image(
            decoded_latents=decoded,
            config=runtime_config,
            seed=seed,
            prompt=prompt,
            quantization=self.model.bits,
            lora_paths=self.model.lora_paths,
            lora_scales=self.model.lora_scales,
            image_path=runtime_config.image_path,
            image_strength=runtime_config.image_strength,
            # Generation time is not tracked here; metadata records 0.
            generation_time=0,
        )
    # Abstract methods - subclasses must implement
    @property
    @abstractmethod
    def model(self) -> Any:
        """Return the underlying mflux model."""
        ...
    @abstractmethod
    def _get_latent_creator(self) -> type:
        """Return the latent creator class for this model."""
        ...
    @abstractmethod
    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
        total_joint_blocks: int,
        total_single_blocks: int,
    ):
        """Remove transformer blocks outside the assigned range.
        This should be called BEFORE mx.eval() to avoid loading unused weights
        in distributed mode.
        Args:
            start_layer: First layer index (inclusive) assigned to this node
            end_layer: Last layer index (exclusive) assigned to this node
            total_joint_blocks: Total number of joint blocks in the model
            total_single_blocks: Total number of single blocks in the model
        """
        ...

View File

@@ -0,0 +1,11 @@
# Public API of the Flux model package: the adapter plus its two presets.
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
from exo.worker.engines.image.models.flux.config import (
    FLUX_DEV_CONFIG,
    FLUX_SCHNELL_CONFIG,
)
__all__ = [
    "FluxModelAdapter",
    "FLUX_DEV_CONFIG",
    "FLUX_SCHNELL_CONFIG",
]

View File

@@ -0,0 +1,675 @@
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
from mflux.config.model_config import ModelConfig
from mflux.config.runtime_config import RuntimeConfig
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
from mflux.models.flux.model.flux_transformer.common.attention_utils import (
AttentionUtils,
)
from mflux.models.flux.model.flux_transformer.joint_transformer_block import (
JointTransformerBlock,
)
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.variants.txt2img.flux import Flux1
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class FluxPromptData:
    """Container for Flux prompt encoding results.

    Holds the sequence embedding plus the pooled embedding. Flux never
    runs classifier-free guidance, so the negative accessors are None and
    there are no extra forward kwargs.
    """
    def __init__(self, prompt_embeds: mx.array, pooled_prompt_embeds: mx.array):
        self._seq_embeds = prompt_embeds
        self._pooled_embeds = pooled_prompt_embeds
    @property
    def prompt_embeds(self) -> mx.array:
        return self._seq_embeds
    @property
    def pooled_prompt_embeds(self) -> mx.array:
        return self._pooled_embeds
    @property
    def negative_prompt_embeds(self) -> mx.array | None:
        """Flux does not use CFG."""
        return None
    @property
    def negative_pooled_prompt_embeds(self) -> mx.array | None:
        """Flux does not use CFG."""
        return None
    def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
        """Flux has no extra forward kwargs."""
        return {}
class FluxModelAdapter(BaseModelAdapter):
    """Adapter exposing mflux's Flux1 model to the shared diffusion pipeline.

    Wraps model loading, prompt encoding, embedding computation, and the
    per-block attention math behind the interface DiffusionRunner uses.
    Each joint/single block can run in two modes:
      - CACHING: the full sequence is processed and image K/V are written
        into an ImagePatchKVCache;
      - patched: a single image patch is processed, its K/V written into
        the cache, and attention computed against the full cached K/V.
    """
    def __init__(
        self,
        config: ImageModelConfig,
        model_id: str,
        local_path: Path,
        quantize: int | None = None,
    ):
        """Load the Flux1 model from local_path (optionally quantized)."""
        self._config = config
        self._model = Flux1(
            model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
            local_path=str(local_path),
            quantize=quantize,
        )
        self._transformer = self._model.transformer
    @property
    def config(self) -> ImageModelConfig:
        """Static model configuration supplied at construction."""
        return self._config
    @property
    def model(self) -> Flux1:
        """The underlying mflux Flux1 model."""
        return self._model
    @property
    def transformer(self) -> Transformer:
        """The Flux transformer backbone (owns the block lists)."""
        return self._transformer
    @property
    def hidden_dim(self) -> int:
        # Output dim of the x_embedder projection == transformer width.
        return self._transformer.x_embedder.weight.shape[0]
    def _get_latent_creator(self) -> type:
        """Latent creator class used by BaseModelAdapter.create_latents."""
        return FluxLatentCreator
    def encode_prompt(self, prompt: str) -> FluxPromptData:
        """Encode prompt into FluxPromptData."""
        prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
            prompt=prompt,
            prompt_cache=self._model.prompt_cache,
            t5_tokenizer=self._model.t5_tokenizer,
            clip_tokenizer=self._model.clip_tokenizer,
            t5_text_encoder=self._model.t5_text_encoder,
            clip_text_encoder=self._model.clip_text_encoder,
        )
        return FluxPromptData(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
        )
    @property
    def needs_cfg(self) -> bool:
        """Flux never runs a negative (classifier-free guidance) pass."""
        return False
    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        """Unsupported for Flux; see needs_cfg."""
        raise NotImplementedError("Flux does not use classifier-free guidance")
    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Project packed latents and text embeddings into transformer space."""
        embedded_hidden = self._transformer.x_embedder(hidden_states)
        embedded_encoder = self._transformer.context_embedder(prompt_embeds)
        return embedded_hidden, embedded_encoder
    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: RuntimeConfig,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,  # Ignored by Flux
    ) -> mx.array:
        """Timestep/text conditioning vector for modulation (AdaLN) layers.

        Raises:
            ValueError: If pooled_prompt_embeds is not provided.
        """
        if pooled_prompt_embeds is None:
            raise ValueError(
                "pooled_prompt_embeds is required for Flux text embeddings"
            )
        # hidden_states is ignored - Flux uses pooled_prompt_embeds instead
        return Transformer.compute_text_embeddings(
            t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
        )
    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: RuntimeConfig,
        **kwargs: Any,
    ) -> mx.array:
        """RoPE frequencies for the [text, image] token sequence.

        kwargs may carry kontext_image_ids for Kontext-style conditioning.
        """
        kontext_image_ids = kwargs.get("kontext_image_ids")
        return Transformer.compute_rotary_embeddings(
            prompt_embeds,
            self._transformer.pos_embed,
            runtime_config,
            kontext_image_ids,
        )
    def apply_joint_block(
        self,
        block: JointBlockInterface,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,  # mx.array for Flux
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
        **kwargs: Any,
    ) -> tuple[mx.array, mx.array]:
        """Dispatch one joint block in CACHING or patched mode.

        In patched mode patch_start/patch_end and kv_cache are mandatory.
        Extra kwargs are accepted for interface compatibility and unused.
        Returns (encoder_hidden_states, hidden_states).
        """
        if mode == BlockWrapperMode.CACHING:
            return self._apply_joint_block_caching(
                block=block,
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
            )
        else:
            assert patch_start is not None and patch_end is not None
            assert kv_cache is not None
            return self._apply_joint_block_patched(
                block=block,
                patch_hidden=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
                patch_start=patch_start,
                patch_end=patch_end,
            )
    def apply_single_block(
        self,
        block: SingleBlockInterface,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
    ) -> mx.array:
        """Dispatch one single (concatenated-stream) block by mode.

        hidden_states carries [text, image] tokens concatenated on axis 1.
        """
        if mode == BlockWrapperMode.CACHING:
            return self._apply_single_block_caching(
                block=block,
                hidden_states=hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
            )
        else:
            assert patch_start is not None and patch_end is not None
            assert kv_cache is not None
            return self._apply_single_block_patched(
                block=block,
                patch_hidden=hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
                patch_start=patch_start,
                patch_end=patch_end,
            )
    def final_projection(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> mx.array:
        """Final modulated norm + projection back to latent channels."""
        hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
        return self._transformer.proj_out(hidden_states)
    def get_joint_blocks(self) -> list[JointBlockInterface]:
        """Joint (dual-stream) blocks currently held by the transformer."""
        return cast(
            list[JointBlockInterface], list(self._transformer.transformer_blocks)
        )
    def get_single_blocks(self) -> list[SingleBlockInterface]:
        """Single (concatenated-stream) blocks currently held by the transformer."""
        return cast(
            list[SingleBlockInterface],
            list(self._transformer.single_transformer_blocks),
        )
    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
        total_joint_blocks: int,
        total_single_blocks: int,
    ) -> None:
        """Keep only the blocks in [start_layer, end_layer) on this rank.

        Global layer numbering places all joint blocks first, then all
        single blocks; the assigned range may fall entirely in one group
        or span the joint/single boundary.
        """
        if end_layer <= total_joint_blocks:
            # All assigned are joint blocks
            joint_start, joint_end = start_layer, end_layer
            single_start, single_end = 0, 0
        elif start_layer >= total_joint_blocks:
            # All assigned are single blocks
            joint_start, joint_end = 0, 0
            single_start = start_layer - total_joint_blocks
            single_end = end_layer - total_joint_blocks
        else:
            # Spans both joint and single
            joint_start, joint_end = start_layer, total_joint_blocks
            single_start = 0
            single_end = end_layer - total_joint_blocks
        all_joint = list(self._transformer.transformer_blocks)
        self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
        all_single = list(self._transformer.single_transformer_blocks)
        self._transformer.single_transformer_blocks = all_single[
            single_start:single_end
        ]
    def merge_streams(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
    ) -> mx.array:
        """Concatenate text then image tokens for the single-block phase."""
        return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
    def _apply_joint_block_caching(
        self,
        block: JointBlockInterface,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache | None,
        text_seq_len: int,
    ) -> tuple[mx.array, mx.array]:
        """Full-sequence joint block; seeds the KV cache with image K/V.

        Returns (encoder_hidden_states, hidden_states) — text stream first.
        """
        num_img_tokens = hidden_states.shape[1]
        batch_size = hidden_states.shape[0]
        attn = block.attn
        num_heads = attn.num_heads
        head_dim = attn.head_dimension
        # 1. Compute norms
        norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
            hidden_states=hidden_states,
            text_embeddings=text_embeddings,
        )
        norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
            block.norm1_context(
                hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
            )
        )
        # 2. Compute Q, K, V for full image
        img_query, img_key, img_value = AttentionUtils.process_qkv(
            hidden_states=norm_hidden,
            to_q=attn.to_q,
            to_k=attn.to_k,
            to_v=attn.to_v,
            norm_q=attn.norm_q,
            norm_k=attn.norm_k,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 3. Compute Q, K, V for text
        txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
            hidden_states=norm_encoder,
            to_q=attn.add_q_proj,
            to_k=attn.add_k_proj,
            to_v=attn.add_v_proj,
            norm_q=attn.norm_added_q,
            norm_k=attn.norm_added_k,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 4. Concatenate Q, K, V: [text, image]
        query = mx.concatenate([txt_query, img_query], axis=2)
        key = mx.concatenate([txt_key, img_key], axis=2)
        value = mx.concatenate([txt_value, img_value], axis=2)
        # 5. Apply RoPE
        query, key = AttentionUtils.apply_rope(
            xq=query, xk=key, freqs_cis=rotary_embeddings
        )
        # 6. Store IMAGE K/V in cache for async pipeline
        # (post-RoPE key, pre-RoPE value; sliced past the text prefix)
        if kv_cache is not None:
            kv_cache.update_image_patch(
                patch_start=0,
                patch_end=num_img_tokens,
                key=key[:, :, text_seq_len:, :],
                value=value[:, :, text_seq_len:, :],
            )
        # 7. Compute full attention
        attn_output = AttentionUtils.compute_attention(
            query=query,
            key=key,
            value=value,
            batch_size=batch_size,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 8. Extract and project outputs
        context_attn_output = attn_output[:, :text_seq_len, :]
        attn_output = attn_output[:, text_seq_len:, :]
        attn_output = attn.to_out[0](attn_output)
        context_attn_output = attn.to_add_out(context_attn_output)
        # 9. Apply norm and feed forward
        hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
            hidden_states=hidden_states,
            attn_output=attn_output,
            gate_mlp=gate_mlp,
            gate_msa=gate_msa,
            scale_mlp=scale_mlp,
            shift_mlp=shift_mlp,
            norm_layer=block.norm2,
            ff_layer=block.ff,
        )
        encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
            hidden_states=encoder_hidden_states,
            attn_output=context_attn_output,
            gate_mlp=c_gate_mlp,
            gate_msa=c_gate_msa,
            scale_mlp=c_scale_mlp,
            shift_mlp=c_shift_mlp,
            norm_layer=block.norm2_context,
            ff_layer=block.ff_context,
        )
        return encoder_hidden_states, hidden_states
    def _apply_joint_block_patched(
        self,
        block: JointBlockInterface,
        patch_hidden: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache,
        text_seq_len: int,
        patch_start: int,
        patch_end: int,
    ) -> tuple[mx.array, mx.array]:
        """Joint block over one image patch, attending to cached full K/V.

        patch_hidden covers image tokens [patch_start, patch_end); the
        patch's fresh K/V replace its slice of the cache before attention.
        Returns (encoder_hidden_states, patch_hidden).
        """
        batch_size = patch_hidden.shape[0]
        attn = block.attn
        num_heads = attn.num_heads
        head_dim = attn.head_dimension
        # 1. Compute norms
        norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
            hidden_states=patch_hidden,
            text_embeddings=text_embeddings,
        )
        norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
            block.norm1_context(
                hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
            )
        )
        # 2. Compute Q, K, V for image patch
        img_query, img_key, img_value = AttentionUtils.process_qkv(
            hidden_states=norm_hidden,
            to_q=attn.to_q,
            to_k=attn.to_k,
            to_v=attn.to_v,
            norm_q=attn.norm_q,
            norm_k=attn.norm_k,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 3. Compute Q, K, V for text
        txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
            hidden_states=norm_encoder,
            to_q=attn.add_q_proj,
            to_k=attn.add_k_proj,
            to_v=attn.add_v_proj,
            norm_q=attn.norm_added_q,
            norm_k=attn.norm_added_k,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 4. Concatenate Q, K, V for patch: [text, patch]
        query = mx.concatenate([txt_query, img_query], axis=2)
        patch_key = mx.concatenate([txt_key, img_key], axis=2)
        patch_value = mx.concatenate([txt_value, img_value], axis=2)
        # 5. Extract RoPE for [text + current_patch]
        text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
        patch_img_rope = rotary_embeddings[
            :, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
        ]
        patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
        # 6. Apply RoPE
        query, patch_key = AttentionUtils.apply_rope(
            xq=query, xk=patch_key, freqs_cis=patch_rope
        )
        # 7. Update cache with this patch's IMAGE K/V
        kv_cache.update_image_patch(
            patch_start=patch_start,
            patch_end=patch_end,
            key=patch_key[:, :, text_seq_len:, :],
            value=patch_value[:, :, text_seq_len:, :],
        )
        # 8. Get full K, V from cache
        full_key, full_value = kv_cache.get_full_kv(
            text_key=patch_key[:, :, :text_seq_len, :],
            text_value=patch_value[:, :, :text_seq_len, :],
        )
        # 9. Compute attention
        attn_output = AttentionUtils.compute_attention(
            query=query,
            key=full_key,
            value=full_value,
            batch_size=batch_size,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 10. Extract and project outputs
        context_attn_output = attn_output[:, :text_seq_len, :]
        hidden_attn_output = attn_output[:, text_seq_len:, :]
        hidden_attn_output = attn.to_out[0](hidden_attn_output)
        context_attn_output = attn.to_add_out(context_attn_output)
        # 11. Apply norm and feed forward
        patch_hidden = JointTransformerBlock.apply_norm_and_feed_forward(
            hidden_states=patch_hidden,
            attn_output=hidden_attn_output,
            gate_mlp=gate_mlp,
            gate_msa=gate_msa,
            scale_mlp=scale_mlp,
            shift_mlp=shift_mlp,
            norm_layer=block.norm2,
            ff_layer=block.ff,
        )
        encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
            hidden_states=encoder_hidden_states,
            attn_output=context_attn_output,
            gate_mlp=c_gate_mlp,
            gate_msa=c_gate_msa,
            scale_mlp=c_scale_mlp,
            shift_mlp=c_shift_mlp,
            norm_layer=block.norm2_context,
            ff_layer=block.ff_context,
        )
        return encoder_hidden_states, patch_hidden
    def _apply_single_block_caching(
        self,
        block: SingleBlockInterface,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache | None,
        text_seq_len: int,
    ) -> mx.array:
        """Full-sequence single block over concatenated [text, image] tokens.

        Seeds the KV cache with the image portion of K/V.
        """
        total_seq_len = hidden_states.shape[1]
        num_img_tokens = total_seq_len - text_seq_len
        batch_size = hidden_states.shape[0]
        attn = block.attn
        num_heads = attn.num_heads
        head_dim = attn.head_dimension
        # Residual connection
        residual = hidden_states
        # 1. Compute norm
        norm_hidden, gate = block.norm(
            hidden_states=hidden_states,
            text_embeddings=text_embeddings,
        )
        # 2. Compute Q, K, V
        query, key, value = AttentionUtils.process_qkv(
            hidden_states=norm_hidden,
            to_q=attn.to_q,
            to_k=attn.to_k,
            to_v=attn.to_v,
            norm_q=attn.norm_q,
            norm_k=attn.norm_k,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 3. Apply RoPE
        query, key = AttentionUtils.apply_rope(
            xq=query, xk=key, freqs_cis=rotary_embeddings
        )
        # 4. Store IMAGE K/V in cache
        if kv_cache is not None:
            kv_cache.update_image_patch(
                patch_start=0,
                patch_end=num_img_tokens,
                key=key[:, :, text_seq_len:, :],
                value=value[:, :, text_seq_len:, :],
            )
        # 5. Compute attention
        attn_output = AttentionUtils.compute_attention(
            query=query,
            key=key,
            value=value,
            batch_size=batch_size,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 6. Apply feed forward and projection
        hidden_states = block._apply_feed_forward_and_projection(
            norm_hidden_states=norm_hidden,
            attn_output=attn_output,
            gate=gate,
        )
        return residual + hidden_states
    def _apply_single_block_patched(
        self,
        block: SingleBlockInterface,
        patch_hidden: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache,
        text_seq_len: int,
        patch_start: int,
        patch_end: int,
    ) -> mx.array:
        """Single block over [text, one image patch], using cached full K/V.

        The patch's fresh K/V replace its slice of the cache before
        attention is computed against the whole image.
        """
        batch_size = patch_hidden.shape[0]
        attn = block.attn
        num_heads = attn.num_heads
        head_dim = attn.head_dimension
        # Residual connection
        residual = patch_hidden
        # 1. Compute norm
        norm_hidden, gate = block.norm(
            hidden_states=patch_hidden,
            text_embeddings=text_embeddings,
        )
        # 2. Compute Q, K, V
        query, key, value = AttentionUtils.process_qkv(
            hidden_states=norm_hidden,
            to_q=attn.to_q,
            to_k=attn.to_k,
            to_v=attn.to_v,
            norm_q=attn.norm_q,
            norm_k=attn.norm_k,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 3. Extract RoPE for [text + current_patch]
        text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
        patch_img_rope = rotary_embeddings[
            :, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
        ]
        patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
        # 4. Apply RoPE
        query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=patch_rope)
        # 5. Update cache with this patch's IMAGE K/V
        kv_cache.update_image_patch(
            patch_start=patch_start,
            patch_end=patch_end,
            key=key[:, :, text_seq_len:, :],
            value=value[:, :, text_seq_len:, :],
        )
        # 6. Get full K, V from cache
        full_key, full_value = kv_cache.get_full_kv(
            text_key=key[:, :, :text_seq_len, :],
            text_value=value[:, :, :text_seq_len, :],
        )
        # 7. Compute attention
        attn_output = AttentionUtils.compute_attention(
            query=query,
            key=full_key,
            value=full_value,
            batch_size=batch_size,
            num_heads=num_heads,
            head_dim=head_dim,
        )
        # 8. Apply feed forward and projection
        hidden_states = block._apply_feed_forward_and_projection(
            norm_hidden_states=norm_hidden,
            attn_output=attn_output,
            gate=gate,
        )
        return residual + hidden_states

View File

@@ -0,0 +1,48 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
# FLUX.1-schnell: few-step variant (1/2/4 steps for low/medium/high),
# 19 joint + 38 single blocks, no CFG attention mask.
FLUX_SCHNELL_CONFIG = ImageModelConfig(
    model_family="flux",
    model_variant="schnell",
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=19, has_separate_text_output=True
        ),
        TransformerBlockConfig(
            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
        ),
    ),
    patch_size=2,
    vae_scale_factor=8,
    default_steps={"low": 1, "medium": 2, "high": 4},
    num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
    uses_attention_mask=False,
)
# FLUX.1-dev: same architecture as schnell (19 joint + 38 single blocks),
# but with higher step-count defaults (10/25/50).
FLUX_DEV_CONFIG = ImageModelConfig(
    model_family="flux",
    model_variant="dev",
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=19, has_separate_text_output=True
        ),
        TransformerBlockConfig(
            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
        ),
    ),
    patch_size=2,
    vae_scale_factor=8,
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
    uses_attention_mask=False,
)

View File

@@ -0,0 +1,7 @@
"""Qwen-Image model package: re-exports the adapter and its configuration."""
from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter
from exo.worker.engines.image.models.qwen.config import QWEN_IMAGE_CONFIG
__all__ = [
    "QwenModelAdapter",
    "QWEN_IMAGE_CONFIG",
]

View File

@@ -0,0 +1,514 @@
from pathlib import Path
from typing import Any, cast
import mlx.core as mx
from mflux.config.model_config import ModelConfig
from mflux.config.runtime_config import RuntimeConfig
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_text_encoder.qwen_prompt_encoder import (
QwenPromptEncoder,
)
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
QwenTransformerBlock,
)
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import BaseModelAdapter
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class QwenPromptData:
    """Encoded prompt tensors for Qwen-Image.

    Satisfies the PromptData protocol. Qwen has no pooled embeddings, so the
    pooled properties simply alias the regular embeddings; the attention
    masks are exposed as plain attributes for mask-aware forward passes.
    """
    def __init__(
        self,
        prompt_embeds: mx.array,
        prompt_mask: mx.array,
        negative_prompt_embeds: mx.array,
        negative_prompt_mask: mx.array,
    ):
        self.prompt_mask = prompt_mask
        self.negative_prompt_mask = negative_prompt_mask
        self._pos_embeds = prompt_embeds
        self._neg_embeds = negative_prompt_embeds
    @property
    def prompt_embeds(self) -> mx.array:
        """Positive-prompt text embeddings."""
        return self._pos_embeds
    @property
    def pooled_prompt_embeds(self) -> mx.array:
        """Protocol placeholder; Qwen has no pooled embeddings, so alias prompt_embeds."""
        return self._pos_embeds
    @property
    def negative_prompt_embeds(self) -> mx.array:
        """Negative-prompt embeddings used for classifier-free guidance."""
        return self._neg_embeds
    @property
    def negative_pooled_prompt_embeds(self) -> mx.array:
        """Protocol placeholder; aliases the negative embeddings."""
        return self._neg_embeds
    def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
        """Return the encoder mask kwarg for the selected (positive/negative) prompt."""
        mask = self.prompt_mask if positive else self.negative_prompt_mask
        return {"encoder_hidden_states_mask": mask}
class QwenModelAdapter(BaseModelAdapter):
    """Adapter for Qwen-Image model.
    Key differences from Flux:
    - Single text encoder (vs dual T5+CLIP)
    - 60 joint-style blocks, no single blocks
    - 3D RoPE returning ((img_cos, img_sin), (txt_cos, txt_sin))
    - Norm-preserving CFG with negative prompts
    - Uses attention mask for variable-length text
    """
    def __init__(
        self,
        config: ImageModelConfig,
        model_id: str,
        local_path: Path,
        quantize: int | None = None,
    ):
        """Load the underlying mflux QwenImage model.

        Args:
            config: Static model configuration (block layout, dims, steps).
            model_id: mflux model name used to resolve the ModelConfig.
            local_path: Local directory containing the model weights.
            quantize: Optional quantization bit-width (None = no quantization).
        """
        self._config = config
        self._model = QwenImage(
            model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
            local_path=str(local_path),
            quantize=quantize,
        )
        self._transformer = self._model.transformer
    @property
    def config(self) -> ImageModelConfig:
        """Static model configuration."""
        return self._config
    @property
    def model(self) -> QwenImage:
        """Underlying mflux QwenImage instance."""
        return self._model
    @property
    def transformer(self) -> QwenTransformer:
        """The QwenTransformer component of the model."""
        return self._transformer
    @property
    def hidden_dim(self) -> int:
        """Transformer inner dimension."""
        return self._transformer.inner_dim
    def _get_latent_creator(self) -> type:
        """Return the latent-creator class used to sample initial noise."""
        return QwenLatentCreator
    def encode_prompt(self, prompt: str) -> QwenPromptData:
        """Encode prompt into QwenPromptData.
        Qwen uses classifier-free guidance with explicit negative prompts.
        Returns a QwenPromptData container with all 4 tensors.
        """
        # TODO(ciaran): empty string as default negative prompt
        negative_prompt = ""
        prompt_embeds, prompt_mask, neg_embeds, neg_mask = (
            QwenPromptEncoder.encode_prompt(
                prompt=prompt,
                negative_prompt=negative_prompt,
                prompt_cache=self._model.prompt_cache,
                qwen_tokenizer=self._model.qwen_tokenizer,
                qwen_text_encoder=self._model.text_encoder,
            )
        )
        return QwenPromptData(
            prompt_embeds=prompt_embeds,
            prompt_mask=prompt_mask,
            negative_prompt_embeds=neg_embeds,
            negative_prompt_mask=neg_mask,
        )
    @property
    def needs_cfg(self) -> bool:
        """True when a guidance_scale > 1.0 is configured (CFG enabled)."""
        gs = self._config.guidance_scale
        return gs is not None and gs > 1.0
    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        """Combine positive/negative noise via the model's guided-noise rule."""
        return self._model.compute_guided_noise(
            noise=noise_positive,
            noise_negative=noise_negative,
            guidance=guidance_scale,
        )
    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Compute image and text embeddings."""
        # Image embedding
        embedded_hidden = self._transformer.img_in(hidden_states)
        # Text embedding: first normalize, then project
        encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
        embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
        return embedded_hidden, embedded_encoder
    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: RuntimeConfig,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,
    ) -> mx.array:
        """Compute time/text embeddings.
        For Qwen, the time_text_embed only uses hidden_states for:
        - batch_size (shape[0])
        - dtype
        This allows us to pass any tensor (latents, prompt_embeds) as a fallback
        when embedded hidden_states are not yet available.
        """
        # Use hidden_states if provided, otherwise fall back to pooled_prompt_embeds
        # (which for Qwen is the same as prompt_embeds)
        ref_tensor = (
            hidden_states if hidden_states is not None else pooled_prompt_embeds
        )
        if ref_tensor is None:
            raise ValueError(
                "Either hidden_states or pooled_prompt_embeds is required "
                "for Qwen text embeddings"
            )
        timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
        batch_size = ref_tensor.shape[0]
        timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
        return self._transformer.time_text_embed(timestep, ref_tensor)
    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: RuntimeConfig,
        **kwargs: Any,
    ) -> Any:
        """Compute 3D rotary embeddings for Qwen.
        Qwen uses video-aware 3D RoPE with separate embeddings for image and text.
        Returns:
            tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]:
                ((img_cos, img_sin), (txt_cos, txt_sin))
        """
        encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
        cond_image_grid = kwargs.get("cond_image_grid")
        if encoder_hidden_states_mask is None:
            raise ValueError(
                "encoder_hidden_states_mask is required for Qwen RoPE computation"
            )
        return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            pos_embed=self._transformer.pos_embed,
            config=runtime_config,
            cond_image_grid=cond_image_grid,
        )
    def apply_joint_block(
        self,
        block: JointBlockInterface,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any, # tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]] for Qwen
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
        **kwargs: Any,
    ) -> tuple[mx.array, mx.array]:
        """Apply Qwen joint block.
        CACHING mode runs the full block (optionally populating the KV cache);
        PATCHED mode processes only the current image patch, reading the rest
        of the image K/V from the cache.
        """
        encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
        block_idx = kwargs.get("block_idx")
        if mode == BlockWrapperMode.CACHING:
            return self._apply_joint_block_caching(
                block=block,
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                block_idx=block_idx,
            )
        else:
            # mode == BlockWrapperMode.PATCHED
            assert patch_start is not None and patch_end is not None
            assert kv_cache is not None
            return self._apply_joint_block_patched(
                block=block,
                patch_hidden=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
                patch_start=patch_start,
                patch_end=patch_end,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                block_idx=block_idx,
            )
    def apply_single_block(
        self,
        block: SingleBlockInterface,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
    ) -> mx.array:
        """Qwen has no single blocks."""
        raise NotImplementedError("Qwen does not have single blocks")
    def final_projection(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> mx.array:
        """Apply final normalization and projection."""
        hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
        return self._transformer.proj_out(hidden_states)
    def get_joint_blocks(self) -> list[JointBlockInterface]:
        """Return all 60 transformer blocks."""
        return cast(
            list[JointBlockInterface], list(self._transformer.transformer_blocks)
        )
    def get_single_blocks(self) -> list[SingleBlockInterface]:
        """Qwen has no single blocks."""
        return []
    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
        total_joint_blocks: int,
        total_single_blocks: int,
    ) -> None:
        """Keep only blocks [start_layer, end_layer) for this pipeline stage.

        The total_* arguments are unused here because Qwen has a single flat
        list of transformer blocks (no joint/single split to translate).
        """
        all_blocks = list(self._transformer.transformer_blocks)
        assigned_blocks = all_blocks[start_layer:end_layer]
        self._transformer.transformer_blocks = assigned_blocks
    def merge_streams(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
    ) -> mx.array:
        """Merge image and text streams.
        For Qwen, this is called before final projection.
        The streams remain separate through all blocks.
        """
        return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
    def _apply_joint_block_caching(
        self,
        block: Any, # QwenTransformerBlock
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
        kv_cache: ImagePatchKVCache | None,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
        block_idx: int | None = None,
    ) -> tuple[mx.array, mx.array]:
        """Apply joint block in caching mode (full attention, optionally populate cache).
        Delegates to the QwenTransformerBlock's forward pass.
        """
        # Call the block directly - it handles all the modulation and attention internally
        return block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            text_embeddings=text_embeddings,
            image_rotary_emb=rotary_embeddings,
            block_idx=block_idx,
        )
    def _apply_joint_block_patched(
        self,
        block: Any, # QwenTransformerBlock
        patch_hidden: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
        kv_cache: ImagePatchKVCache,
        text_seq_len: int,
        patch_start: int,
        patch_end: int,
        encoder_hidden_states_mask: mx.array | None = None,
        block_idx: int | None = None,
    ) -> tuple[mx.array, mx.array]:
        """Apply a joint block to one image patch using cached image K/V.

        Query = [text + patch]; K/V = [text (fresh) + full image (cache)].
        This patch's image K/V is written into the cache before the full K/V
        is read back, so the patch attends to itself freshly and to other
        patches via possibly-stale cache entries.

        Returns:
            Tuple of (encoder_hidden_states, patch_hidden) after attention
            and both feed-forward branches.
        """
        batch_size = patch_hidden.shape[0]
        attn = block.attn
        num_heads = attn.num_heads
        head_dim = attn.head_dim
        # 1. Compute modulation parameters
        img_mod_params = block.img_mod_linear(block.img_mod_silu(text_embeddings))
        txt_mod_params = block.txt_mod_linear(block.txt_mod_silu(text_embeddings))
        img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
        txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
        # 2. Apply normalization and modulation
        img_normed = block.img_norm1(patch_hidden)
        img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)
        txt_normed = block.txt_norm1(encoder_hidden_states)
        txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)
        # 3. Compute Q, K, V for image patch
        img_query = attn.to_q(img_modulated)
        img_key = attn.to_k(img_modulated)
        img_value = attn.to_v(img_modulated)
        # 4. Compute Q, K, V for text
        txt_query = attn.add_q_proj(txt_modulated)
        txt_key = attn.add_k_proj(txt_modulated)
        txt_value = attn.add_v_proj(txt_modulated)
        # 5. Reshape to [B, S, H, D]
        patch_len = patch_hidden.shape[1]
        img_query = mx.reshape(img_query, (batch_size, patch_len, num_heads, head_dim))
        img_key = mx.reshape(img_key, (batch_size, patch_len, num_heads, head_dim))
        img_value = mx.reshape(img_value, (batch_size, patch_len, num_heads, head_dim))
        txt_query = mx.reshape(
            txt_query, (batch_size, text_seq_len, num_heads, head_dim)
        )
        txt_key = mx.reshape(txt_key, (batch_size, text_seq_len, num_heads, head_dim))
        txt_value = mx.reshape(
            txt_value, (batch_size, text_seq_len, num_heads, head_dim)
        )
        # 6. Apply RMSNorm to Q, K
        img_query = attn.norm_q(img_query)
        img_key = attn.norm_k(img_key)
        txt_query = attn.norm_added_q(txt_query)
        txt_key = attn.norm_added_k(txt_key)
        # 7. Extract RoPE for patch: slice image RoPE, keep full text RoPE
        (img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
        patch_img_cos = img_cos[patch_start:patch_end]
        patch_img_sin = img_sin[patch_start:patch_end]
        # 8. Apply RoPE to Q, K
        img_query = QwenAttention._apply_rope_qwen(
            img_query, patch_img_cos, patch_img_sin
        )
        img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
        txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
        txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
        # 9. Transpose to [B, H, S, D] for cache operations
        img_key_bhsd = mx.transpose(img_key, (0, 2, 1, 3))
        img_value_bhsd = mx.transpose(img_value, (0, 2, 1, 3))
        # 10. Update cache with this patch's IMAGE K/V
        kv_cache.update_image_patch(
            patch_start=patch_start,
            patch_end=patch_end,
            key=img_key_bhsd,
            value=img_value_bhsd,
        )
        # 11. Get full K, V from cache (text + full image)
        txt_key_bhsd = mx.transpose(txt_key, (0, 2, 1, 3))
        txt_value_bhsd = mx.transpose(txt_value, (0, 2, 1, 3))
        full_key, full_value = kv_cache.get_full_kv(
            text_key=txt_key_bhsd,
            text_value=txt_value_bhsd,
        )
        # 12. Build query: [text, patch]
        joint_query = mx.concatenate([txt_query, img_query], axis=1)
        # 13. Build attention mask for [text + patch] query attending to [text + full_image] KV
        mask = QwenAttention._convert_mask_for_qwen(
            mask=encoder_hidden_states_mask,
            joint_seq_len=full_key.shape[2], # text + full_image
            txt_seq_len=text_seq_len,
        )
        # 14. Compute attention
        hidden_states = attn._compute_attention_qwen(
            query=joint_query,
            key=mx.transpose(full_key, (0, 2, 1, 3)), # Back to [B, S, H, D]
            value=mx.transpose(full_value, (0, 2, 1, 3)),
            mask=mask,
            block_idx=block_idx,
        )
        # 15. Extract text and image attention outputs
        txt_attn_output = hidden_states[:, :text_seq_len, :]
        img_attn_output = hidden_states[:, text_seq_len:, :]
        # 16. Project outputs
        img_attn_output = attn.attn_to_out[0](img_attn_output)
        txt_attn_output = attn.to_add_out(txt_attn_output)
        # 17. Apply residual + gate for attention
        patch_hidden = patch_hidden + img_gate1 * img_attn_output
        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
        # 18. Apply feed-forward for image
        img_normed2 = block.img_norm2(patch_hidden)
        img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
            img_normed2, img_mod2
        )
        img_mlp_output = block.img_ff(img_modulated2)
        patch_hidden = patch_hidden + img_gate2 * img_mlp_output
        # 19. Apply feed-forward for text
        txt_normed2 = block.txt_norm2(encoder_hidden_states)
        txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
            txt_normed2, txt_mod2
        )
        txt_mlp_output = block.txt_ff(txt_modulated2)
        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
        return encoder_hidden_states, patch_hidden

View File

@@ -0,0 +1,28 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
# Qwen-Image has 60 joint-style blocks (no single blocks)
# Architecture: 24 heads * 128 dim = 3072 hidden dim
# VAE uses scale factor of 16 (vs Flux's 8)
QWEN_IMAGE_CONFIG = ImageModelConfig(
    model_family="qwen",
    model_variant="image",
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=60, has_separate_text_output=True
        ),
        # Qwen has no single blocks - all blocks process image and text separately
    ),
    patch_size=2,
    vae_scale_factor=16,
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
    uses_attention_mask=True, # Qwen uses encoder_hidden_states_mask
    guidance_scale=None, # CFG is enabled only when this is set and > 1.0
)

View File

@@ -0,0 +1,23 @@
"""Public API of the image diffusion pipeline package."""
from exo.worker.engines.image.pipeline.adapter import (
    BlockWrapperMode,
    JointBlockInterface,
    ModelAdapter,
    SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
    JointBlockWrapper,
    SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
from exo.worker.engines.image.pipeline.runner import DiffusionRunner
__all__ = [
    "BlockWrapperMode",
    "DiffusionRunner",
    "ImagePatchKVCache",
    "JointBlockInterface",
    "JointBlockWrapper",
    "ModelAdapter",
    "SingleBlockInterface",
    "SingleBlockWrapper",
]

View File

@@ -0,0 +1,376 @@
from enum import Enum
from typing import Any, Protocol
import mlx.core as mx
from mflux.config.runtime_config import RuntimeConfig
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class AttentionInterface(Protocol):
    """Structural type for a self-attention module (Flux-style attribute names)."""
    num_heads: int
    head_dimension: int
    to_q: Any  # query projection module
    to_k: Any  # key projection module
    to_v: Any  # value projection module
    norm_q: Any  # query normalization module
    norm_k: Any  # key normalization module
    to_out: list[Any]  # output projection module(s)
class JointAttentionInterface(AttentionInterface, Protocol):
    """Attention with an additional text/context stream of projections."""
    add_q_proj: Any  # text-stream query projection
    add_k_proj: Any  # text-stream key projection
    add_v_proj: Any  # text-stream value projection
    norm_added_q: Any  # text-stream query normalization
    norm_added_k: Any  # text-stream key normalization
    to_add_out: Any  # text-stream output projection
class JointBlockInterface(Protocol):
    """Structural type for a dual-stream (image + text) transformer block."""
    attn: JointAttentionInterface
    norm1: Any  # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
    norm1_context: (
        Any  # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
    )
    norm2: Any  # second norm, image stream
    norm2_context: Any  # second norm, text/context stream
    ff: Any  # feed-forward, image stream
    ff_context: Any  # feed-forward, text/context stream
class SingleBlockInterface(Protocol):
    """Structural type for a single-stream block over concatenated [text + image] states."""
    attn: AttentionInterface
    norm: Any  # Callable module: (hidden_states, text_embeddings) -> tuple[2 arrays]
    def _apply_feed_forward_and_projection(
        self, norm_hidden_states: mx.array, attn_output: mx.array, gate: mx.array
    ) -> mx.array:
        """Apply feed forward network and projection."""
        ...
class BlockWrapperMode(Enum):
    """How a wrapped block executes with respect to the image KV cache."""
    CACHING = "caching"  # Sync mode: compute full attention, populate cache
    PATCHED = "patched"  # Async mode: compute patch attention, use cached KV
class PromptData(Protocol):
    """Protocol for encoded prompt data.
    All adapters must return prompt data that conforms to this protocol.
    Model-specific prompt data classes can add additional attributes
    (e.g., attention masks for Qwen).
    Implementations without CFG support may return None from the
    negative-embedding properties.
    """
    @property
    def prompt_embeds(self) -> mx.array:
        """Text embeddings from encoder."""
        ...
    @property
    def pooled_prompt_embeds(self) -> mx.array:
        """Pooled text embeddings (for Flux) or placeholder (for Qwen)."""
        ...
    @property
    def negative_prompt_embeds(self) -> mx.array | None:
        """Negative prompt embeddings for CFG (None if not using CFG)."""
        ...
    @property
    def negative_pooled_prompt_embeds(self) -> mx.array | None:
        """Negative pooled embeddings for CFG (None if not using CFG)."""
        ...
    def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
        """Return model-specific kwargs for forward pass.
        Args:
            positive: If True, return kwargs for positive prompt pass.
                If False, return kwargs for negative prompt pass.
        Returns:
            Dict of extra kwargs (e.g., {"encoder_hidden_states_mask": ...} for Qwen)
        """
        ...
class ModelAdapter(Protocol):
    """Structural interface implemented by every image-model adapter.

    The pipeline drives generation exclusively through this protocol so that
    model-specific details (block math, RoPE format, CFG) stay inside the
    adapters (e.g. Flux, Qwen).
    """
    @property
    def config(self) -> ImageModelConfig:
        """Return the model configuration."""
        ...
    @property
    def model(self) -> Any:
        """Return the underlying mflux model instance (e.g., Flux1, Fibo, Qwen)."""
        ...
    @property
    def transformer(self) -> Any:
        """Return the transformer component of the model."""
        ...
    @property
    def hidden_dim(self) -> int:
        """Return the hidden dimension of the transformer."""
        ...
    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Compute x_embedder and context_embedder outputs.
        Args:
            hidden_states: Input latent states
            prompt_embeds: Text embeddings from encoder
        Returns:
            Tuple of (embedded_hidden_states, embedded_encoder_states)
        """
        ...
    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: RuntimeConfig,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,
    ) -> mx.array:
        """Compute time/text embeddings for conditioning.
        Args:
            t: Current timestep
            runtime_config: Runtime configuration
            pooled_prompt_embeds: Pooled text embeddings (used by Flux)
            hidden_states: Image hidden states
        Returns:
            Text embeddings tensor
        """
        ...
    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: RuntimeConfig,
        **kwargs: Any,
    ) -> Any:
        """Compute rotary position embeddings.
        Args:
            prompt_embeds: Text embeddings
            runtime_config: Runtime configuration
            **kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask for Qwen)
        Returns:
            Flux: mx.array
            Qwen: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]
        """
        ...
    def apply_joint_block(
        self,
        block: JointBlockInterface,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,  # Format varies: mx.array (Flux) or nested tuple (Qwen)
        kv_cache: ImagePatchKVCache | None,
        mode: "BlockWrapperMode",
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
        **kwargs: Any,
    ) -> tuple[mx.array, mx.array]:
        """Apply a joint transformer block.
        Args:
            block: The joint transformer block
            hidden_states: Image hidden states
            encoder_hidden_states: Text hidden states
            text_embeddings: Conditioning embeddings
            rotary_embeddings: Rotary position embeddings (format varies by model)
            kv_cache: KV cache (None if not using cache)
            mode: CACHING or PATCHED mode
            text_seq_len: Text sequence length
            patch_start: Start index for patched mode
            patch_end: End index for patched mode
            **kwargs: Additional model-specific arguments (e.g., encoder_hidden_states_mask,
                block_idx for Qwen)
        Returns:
            Tuple of (encoder_hidden_states, hidden_states)
        """
        ...
    def apply_single_block(
        self,
        block: SingleBlockInterface,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache | None,
        mode: "BlockWrapperMode",
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
    ) -> mx.array:
        """Apply a single transformer block.
        Args:
            block: The single transformer block
            hidden_states: Concatenated [text + image] hidden states
            text_embeddings: Conditioning embeddings
            rotary_embeddings: Rotary position embeddings
            kv_cache: KV cache (None if not using cache)
            mode: CACHING or PATCHED mode
            text_seq_len: Text sequence length
            patch_start: Start index for patched mode
            patch_end: End index for patched mode
        Returns:
            Output hidden states
        """
        ...
    def final_projection(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> mx.array:
        """Apply final norm and projection.
        Args:
            hidden_states: Hidden states (image only, text already removed)
            text_embeddings: Conditioning embeddings
        Returns:
            Projected output
        """
        ...
    def get_joint_blocks(self) -> list[JointBlockInterface]:
        """Get the list of joint transformer blocks from the model."""
        ...
    def get_single_blocks(self) -> list[SingleBlockInterface]:
        """Get the list of single transformer blocks from the model."""
        ...
    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
        total_joint_blocks: int,
        total_single_blocks: int,
    ) -> None:
        """Remove transformer blocks outside the assigned range.
        This should be called BEFORE mx.eval() to avoid loading unused weights
        in distributed mode.
        Args:
            start_layer: First layer index (inclusive) assigned to this node
            end_layer: Last layer index (exclusive) assigned to this node
            total_joint_blocks: Total number of joint blocks in the model
            total_single_blocks: Total number of single blocks in the model
        """
        ...
    def merge_streams(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
    ) -> mx.array:
        """Merge image and text streams for transition to single blocks.
        This is called at the transition point from joint blocks (which process
        image and text separately) to single blocks (which process them
        together). Override to customize the merge strategy.
        Args:
            hidden_states: Image hidden states
            encoder_hidden_states: Text hidden states
        Returns:
            Merged hidden states (default: concatenate [text, image])
        """
        ...
    def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
        """Create initial noise latents for generation.
        Args:
            seed: Random seed
            runtime_config: Runtime configuration with dimensions
        Returns:
            Initial latent tensor
        """
        ...
    def encode_prompt(self, prompt: str) -> PromptData:
        """Encode prompt into model-specific prompt data.
        Args:
            prompt: Text prompt
        Returns:
            PromptData containing embeddings (and model-specific extras)
        """
        ...
    @property
    def needs_cfg(self) -> bool:
        """Whether this model uses classifier-free guidance.
        Returns:
            True if model requires two forward passes with guidance (e.g., Qwen)
            False if model uses a single forward pass (e.g., Flux)
        """
        ...
    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        """Apply classifier-free guidance to combine positive/negative predictions.
        Only called when needs_cfg is True.
        Args:
            noise_positive: Noise prediction from positive prompt
            noise_negative: Noise prediction from negative prompt
            guidance_scale: Guidance strength
        Returns:
            Guided noise prediction
        """
        ...
    def decode_latents(
        self,
        latents: mx.array,
        runtime_config: RuntimeConfig,
        seed: int,
        prompt: str,
    ) -> Any:
        """Decode latents to final image.
        Args:
            latents: Final denoised latents
            runtime_config: Runtime configuration
            seed: Random seed (for metadata)
            prompt: Text prompt (for metadata)
        Returns:
            GeneratedImage result
        """
        ...

View File

@@ -0,0 +1,146 @@
from typing import Any
import mlx.core as mx
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
JointBlockInterface,
ModelAdapter,
SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class JointBlockWrapper:
    """Mode-aware wrapper around a joint (dual-stream) transformer block.

    A wrapper instance holds only the block and the adapter; the execution
    mode (CACHING vs PATCHED) and the KV cache arrive per call, so the same
    instance serves both the sync and async pipelines. All model-specific
    attention math is delegated to the adapter.
    """
    def __init__(
        self,
        block: JointBlockInterface,
        adapter: ModelAdapter,
    ):
        """Bind *block* to *adapter* for later invocation.

        Args:
            block: The joint transformer block to wrap.
            adapter: Model adapter providing the model-specific block logic.
        """
        self.block = block
        self.adapter = adapter
    def __call__(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        text_seq_len: int,
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        patch_start: int | None = None,
        patch_end: int | None = None,
        **kwargs: Any,
    ) -> tuple[mx.array, mx.array]:
        """Run the wrapped joint block via the adapter.

        Args:
            hidden_states: Image hidden states (full sequence or one patch).
            encoder_hidden_states: Text hidden states.
            text_embeddings: Conditioning embeddings.
            rotary_embeddings: Rotary position embeddings.
            text_seq_len: Number of text tokens.
            kv_cache: Image KV cache, or None when caching is disabled.
            mode: CACHING (populate cache) or PATCHED (consume cached K/V).
            patch_start: Patch start index; required in PATCHED mode.
            patch_end: Patch end index; required in PATCHED mode.
            **kwargs: Extra model-specific arguments forwarded untouched
                (e.g. encoder_hidden_states_mask, block_idx for Qwen).

        Returns:
            Tuple of (encoder_hidden_states, hidden_states).
        """
        return self.adapter.apply_joint_block(
            block=self.block,
            mode=mode,
            kv_cache=kv_cache,
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            text_embeddings=text_embeddings,
            rotary_embeddings=rotary_embeddings,
            text_seq_len=text_seq_len,
            patch_start=patch_start,
            patch_end=patch_end,
            **kwargs,
        )
class SingleBlockWrapper:
    """Mode-aware wrapper around a single-stream transformer block.

    Like JointBlockWrapper, the instance stores only the block and adapter;
    mode and KV cache are supplied at call time so one wrapper works for
    both the sync (CACHING) and async (PATCHED) pipelines. Model-specific
    attention math lives in the adapter.
    """
    def __init__(
        self,
        block: SingleBlockInterface,
        adapter: ModelAdapter,
    ):
        """Bind *block* to *adapter* for later invocation.

        Args:
            block: The single transformer block to wrap.
            adapter: Model adapter providing the model-specific block logic.
        """
        self.block = block
        self.adapter = adapter
    def __call__(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        text_seq_len: int,
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        patch_start: int | None = None,
        patch_end: int | None = None,
    ) -> mx.array:
        """Run the wrapped single block via the adapter.

        Args:
            hidden_states: [text + image] hidden states (full or one patch).
            text_embeddings: Conditioning embeddings.
            rotary_embeddings: Rotary position embeddings.
            text_seq_len: Number of text tokens.
            kv_cache: Image KV cache, or None when caching is disabled.
            mode: CACHING (populate cache) or PATCHED (consume cached K/V).
            patch_start: Patch start index; required in PATCHED mode.
            patch_end: Patch end index; required in PATCHED mode.

        Returns:
            Output hidden states.
        """
        return self.adapter.apply_single_block(
            block=self.block,
            mode=mode,
            kv_cache=kv_cache,
            hidden_states=hidden_states,
            text_embeddings=text_embeddings,
            rotary_embeddings=rotary_embeddings,
            text_seq_len=text_seq_len,
            patch_start=patch_start,
            patch_end=patch_end,
        )

View File

@@ -0,0 +1,72 @@
import mlx.core as mx
class ImagePatchKVCache:
    """KV cache that stores only IMAGE K/V with patch-level updates.

    Only caches image K/V since:
    - Text K/V is always computed fresh (same for all patches)
    - Only image portion needs stale/fresh cache management across patches
    """
    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        image_seq_len: int,
        head_dim: int,
        dtype: mx.Dtype = mx.float32,
    ):
        """Allocate zeroed key/value buffers of shape [B, H, image_seq_len, D].

        Args:
            batch_size: Batch dimension of the cached tensors.
            num_heads: Number of attention heads.
            image_seq_len: Total number of image tokens covered by the cache.
            head_dim: Per-head feature dimension.
            dtype: Element dtype of the cache buffers.
        """
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.image_seq_len = image_seq_len
        self.head_dim = head_dim
        self._dtype = dtype
        # Shared allocation helper keeps __init__ and reset() in sync
        # (previously the shape/dtype expression was duplicated in both).
        self.key_cache = self._zeros()
        self.value_cache = self._zeros()
    def _zeros(self) -> mx.array:
        """Return a zeroed [batch, heads, image_seq_len, head_dim] buffer."""
        return mx.zeros(
            (self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
            dtype=self._dtype,
        )
    def update_image_patch(
        self, patch_start: int, patch_end: int, key: mx.array, value: mx.array
    ) -> None:
        """Update cache with fresh K/V for an image patch slice.

        Args:
            patch_start: Start token index within image portion (0-indexed)
            patch_end: End token index within image portion
            key: Fresh key tensor [batch, heads, patch_seq_len, head_dim]
            value: Fresh value tensor [batch, heads, patch_seq_len, head_dim]
        """
        self.key_cache[:, :, patch_start:patch_end, :] = key
        self.value_cache[:, :, patch_start:patch_end, :] = value
    def get_full_kv(
        self, text_key: mx.array, text_value: mx.array
    ) -> tuple[mx.array, mx.array]:
        """Return full K/V by concatenating fresh text K/V with cached image K/V.

        Args:
            text_key: Fresh text key tensor [batch, heads, text_seq_len, head_dim]
            text_value: Fresh text value tensor [batch, heads, text_seq_len, head_dim]
        Returns:
            Tuple of (full_key, full_value) with shape [batch, heads, text+image, head_dim]
        """
        full_key = mx.concatenate([text_key, self.key_cache], axis=2)
        full_value = mx.concatenate([text_value, self.value_cache], axis=2)
        return full_key, full_value
    def reset(self) -> None:
        """Reset both caches to zeros, discarding all cached patch state."""
        self.key_cache = self._zeros()
        self.value_cache = self._zeros()

View File

@@ -0,0 +1,838 @@
from math import ceil
from typing import Any, Optional
import mlx.core as mx
from mflux.callbacks.callbacks import Callbacks
from mflux.config.config import Config
from mflux.config.runtime_config import RuntimeConfig
from mflux.utils.exceptions import StopImageGenerationException
from tqdm import tqdm
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.pipeline.adapter import (
BlockWrapperMode,
ModelAdapter,
PromptData,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
def calculate_patch_heights(
    latent_height: int, num_patches: int
) -> tuple[list[int], int]:
    """Split latent rows into near-equal horizontal patches.

    All patches share the same ceil-divided height except the last, which
    absorbs the remainder so that sum(patch_heights) == latent_height. The
    actual patch count may be smaller than requested when latent_height is
    small relative to num_patches.

    Args:
        latent_height: Total number of latent rows to split (>= 1)
        num_patches: Requested number of patches (>= 1)

    Returns:
        Tuple of (patch_heights, actual_num_patches)

    Raises:
        ValueError: If latent_height or num_patches is less than 1.
    """
    # Fail loudly instead of raising an obscure ZeroDivisionError / producing
    # nonsense for non-positive inputs.
    if latent_height < 1 or num_patches < 1:
        raise ValueError("latent_height and num_patches must both be >= 1")
    patch_height = ceil(latent_height / num_patches)
    actual_num_patches = ceil(latent_height / patch_height)
    patch_heights = [patch_height] * (actual_num_patches - 1)
    # Last patch takes whatever is left over.
    last_height = latent_height - patch_height * (actual_num_patches - 1)
    patch_heights.append(last_height)
    return patch_heights, actual_num_patches
def calculate_token_indices(patch_heights: list[int], latent_width: int):
    """Convert per-patch row heights into (start, end) token ranges.

    Each latent row contributes `latent_width` tokens, so a patch covering
    rows [r, r + h) spans tokens [latent_width * r, latent_width * (r + h)).

    Args:
        patch_heights: Height (in rows) of each patch, in order
        latent_width: Tokens per latent row

    Returns:
        List of half-open (start_token, end_token) tuples, one per patch.
    """
    token_ranges: list[tuple[int, int]] = []
    row_offset = 0
    for height in patch_heights:
        next_offset = row_offset + height
        token_ranges.append((latent_width * row_offset, latent_width * next_offset))
        row_offset = next_offset
    return token_ranges
class DiffusionRunner:
    """Orchestrates the diffusion loop for image generation.

    This class owns the entire diffusion process, handling both single-node
    and distributed (PipeFusion) modes.

    In distributed mode, it implements PipeFusion with:
    - Sync pipeline for initial timesteps (full image, all devices in lockstep)
    - Async pipeline for later timesteps (patches processed independently)
    """

    def __init__(
        self,
        config: ImageModelConfig,
        adapter: ModelAdapter,
        group: Optional[mx.distributed.Group],
        shard_metadata: PipelineShardMetadata,
        num_sync_steps: int = 1,
        num_patches: Optional[int] = None,
    ):
        """Initialize the diffusion runner.

        Args:
            config: Model configuration (architecture, block counts, etc.)
            adapter: Model adapter for model-specific operations
            group: MLX distributed group (None for single-node mode)
            shard_metadata: Pipeline shard metadata with layer assignments
            num_sync_steps: Number of synchronous timesteps before async mode
            num_patches: Number of patches for async mode (defaults to world_size)
        """
        self.config = config
        self.adapter = adapter
        self.group = group
        # Handle single-node vs distributed mode
        if group is None:
            # Single node owns every layer; ring neighbours collapse to self.
            self.rank = 0
            self.world_size = 1
            self.next_rank = 0
            self.prev_rank = 0
            self.start_layer = 0
            self.end_layer = config.total_blocks
        else:
            self.rank = shard_metadata.device_rank
            self.world_size = shard_metadata.world_size
            # Ring topology: activations flow rank -> rank+1, wrapping to 0.
            self.next_rank = (self.rank + 1) % self.world_size
            self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
            self.start_layer = shard_metadata.start_layer
            self.end_layer = shard_metadata.end_layer
        self.num_sync_steps = num_sync_steps
        self.num_patches = num_patches if num_patches else max(1, self.world_size)
        # Persistent KV caches. Populated at t == 0 inside _sync_pipeline and
        # reused by the async pipeline as a warm start across later timesteps.
        self.joint_kv_caches: list[ImagePatchKVCache] | None = None
        self.single_kv_caches: list[ImagePatchKVCache] | None = None
        # Get block counts from config (model-agnostic)
        self.total_joint = config.joint_block_count
        self.total_single = config.single_block_count
        self.total_layers = config.total_blocks
        self._compute_assigned_blocks()

    def _compute_assigned_blocks(self) -> None:
        """Determine which joint/single blocks this stage owns.

        Layer indices [start_layer, end_layer) are global over
        joint blocks followed by single blocks; this maps them onto
        per-kind [joint_start, joint_end) and [single_start, single_end)
        ranges, then wraps the owned blocks for reuse on every call.
        """
        start = self.start_layer
        end = self.end_layer
        if end <= self.total_joint:
            # All assigned blocks are joint blocks
            self.joint_start = start
            self.joint_end = end
            self.single_start = 0
            self.single_end = 0
        elif start >= self.total_joint:
            # All assigned blocks are single blocks
            self.joint_start = 0
            self.joint_end = 0
            self.single_start = start - self.total_joint
            self.single_end = end - self.total_joint
        else:
            # Stage spans joint→single transition
            self.joint_start = start
            self.joint_end = self.total_joint
            self.single_start = 0
            self.single_end = end - self.total_joint
        self.has_joint_blocks = self.joint_end > self.joint_start
        self.has_single_blocks = self.single_end > self.single_start
        # The concat stage is the one that merges text/image streams: it owns
        # the last joint block and either also owns single blocks, or the
        # model has no single blocks at all (end_layer == total_joint).
        self.owns_concat_stage = self.has_joint_blocks and (
            self.has_single_blocks or self.end_layer == self.total_joint
        )
        joint_blocks = self.adapter.get_joint_blocks()
        single_blocks = self.adapter.get_single_blocks()
        # Wrap blocks at initialization (reused across all calls)
        self.joint_block_wrappers = [
            JointBlockWrapper(block=block, adapter=self.adapter)
            for block in joint_blocks
        ]
        self.single_block_wrappers = [
            SingleBlockWrapper(block=block, adapter=self.adapter)
            for block in single_blocks
        ]

    @property
    def is_first_stage(self) -> bool:
        # Rank 0 embeds latents and receives the final scheduler output.
        return self.rank == 0

    @property
    def is_last_stage(self) -> bool:
        # Last rank runs final projection, scheduler step, and VAE decode.
        return self.rank == self.world_size - 1

    @property
    def is_distributed(self) -> bool:
        return self.group is not None

    def generate_image(
        self,
        settings: Config,
        prompt: str,
        seed: int,
    ) -> Any:
        """Primary entry point for image generation.

        Orchestrates the full generation flow:
        1. Create runtime config
        2. Create initial latents
        3. Encode prompt
        4. Run diffusion loop
        5. Decode to image

        Args:
            settings: Generation config (steps, height, width)
            prompt: Text prompt
            seed: Random seed

        Returns:
            GeneratedImage result on the last stage; implicitly None on all
            other ranks (only the final rank decodes latents to an image).
        """
        runtime_config = RuntimeConfig(settings, self.adapter.model.model_config)
        latents = self.adapter.create_latents(seed, runtime_config)
        prompt_data = self.adapter.encode_prompt(prompt)
        latents = self._run_diffusion_loop(
            latents=latents,
            prompt_data=prompt_data,
            runtime_config=runtime_config,
            seed=seed,
            prompt=prompt,
        )
        if self.is_last_stage:
            return self.adapter.decode_latents(latents, runtime_config, seed, prompt)

    def _run_diffusion_loop(
        self,
        latents: mx.array,
        prompt_data: PromptData,
        runtime_config: RuntimeConfig,
        seed: int,
        prompt: str,
    ) -> mx.array:
        """Execute the diffusion loop.

        Args:
            latents: Initial noise latents
            prompt_data: Encoded prompt data
            runtime_config: RuntimeConfig with scheduler, steps, dimensions
            seed: Random seed (for callbacks)
            prompt: Text prompt (for callbacks)

        Returns:
            Final denoised latents ready for VAE decoding

        Raises:
            StopImageGenerationException: On KeyboardInterrupt mid-loop.
        """
        time_steps = tqdm(
            range(runtime_config.init_time_step, runtime_config.num_inference_steps)
        )
        # Call subscribers for beginning of loop
        Callbacks.before_loop(
            seed=seed,
            prompt=prompt,
            latents=latents,
            config=runtime_config,
        )
        for t in time_steps:
            try:
                latents = self._diffusion_step(
                    t=t,
                    config=runtime_config,
                    latents=latents,
                    prompt_data=prompt_data,
                )
                # Call subscribers in-loop
                Callbacks.in_loop(
                    t=t,
                    seed=seed,
                    prompt=prompt,
                    latents=latents,
                    config=runtime_config,
                    time_steps=time_steps,
                )
                # Force materialization each step so lazy MLX graphs don't
                # accumulate across timesteps.
                mx.eval(latents)
            except KeyboardInterrupt:  # noqa: PERF203
                Callbacks.interruption(
                    t=t,
                    seed=seed,
                    prompt=prompt,
                    latents=latents,
                    config=runtime_config,
                    time_steps=time_steps,
                )
                raise StopImageGenerationException(
                    f"Stopping image generation at step {t + 1}/{len(time_steps)}"
                ) from None
        # Call subscribers after loop
        Callbacks.after_loop(
            seed=seed,
            prompt=prompt,
            latents=latents,
            config=runtime_config,
        )
        return latents

    def _forward_pass(
        self,
        latents: mx.array,
        prompt_embeds: mx.array,
        pooled_prompt_embeds: mx.array,
        kwargs: dict[str, Any],
    ) -> mx.array:
        """Run a single forward pass through the transformer.

        This is the internal method called by adapters via compute_step_noise.
        Returns noise prediction without applying scheduler step.

        Args:
            latents: Input latents (already scaled by caller)
            prompt_embeds: Text embeddings
            pooled_prompt_embeds: Pooled text embeddings (Flux) or placeholder (Qwen)
            kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask, t)

        Returns:
            Noise prediction tensor

        Raises:
            ValueError: If kwargs lacks a "config" entry.
        """
        t = kwargs.get("t", 0)
        config = kwargs.get("config")
        if config is None:
            raise ValueError("config must be provided in kwargs")
        scaled_latents = config.scheduler.scale_model_input(latents, t)
        hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
            scaled_latents, prompt_embeds
        )
        text_embeddings = self.adapter.compute_text_embeddings(
            t, config, pooled_prompt_embeds, hidden_states=hidden_states
        )
        rotary_embeddings = self.adapter.compute_rotary_embeddings(
            prompt_embeds, config, **kwargs
        )
        text_seq_len = prompt_embeds.shape[1]
        # Run through all joint blocks
        for block_idx, wrapper in enumerate(self.joint_block_wrappers):
            encoder_hidden_states, hidden_states = wrapper(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                text_seq_len=text_seq_len,
                kv_cache=None,
                mode=BlockWrapperMode.CACHING,
                block_idx=block_idx,
                **kwargs,
            )
        # Merge streams
        if self.joint_block_wrappers:
            hidden_states = self.adapter.merge_streams(
                hidden_states, encoder_hidden_states
            )
        # Run through single blocks
        for wrapper in self.single_block_wrappers:
            hidden_states = wrapper(
                hidden_states=hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                text_seq_len=text_seq_len,
                kv_cache=None,
                mode=BlockWrapperMode.CACHING,
            )
        # Extract image portion and project (drop the text-token prefix).
        hidden_states = hidden_states[:, text_seq_len:, ...]
        return self.adapter.final_projection(hidden_states, text_embeddings)

    def _diffusion_step(
        self,
        t: int,
        config: RuntimeConfig,
        latents: mx.array,
        prompt_data: PromptData,
    ) -> mx.array:
        """Execute a single diffusion step.

        Routes to single-node, sync pipeline, or async pipeline based on
        configuration and current timestep.
        """
        if self.group is None:
            return self._single_node_step(t, config, latents, prompt_data)
        elif t < self.num_sync_steps:
            return self._sync_pipeline(
                t,
                config,
                latents,
                prompt_data,
            )
        else:
            return self._async_pipeline_step(
                t,
                config,
                latents,
                prompt_data,
            )

    def _single_node_step(
        self,
        t: int,
        config: RuntimeConfig,
        latents: mx.array,
        prompt_data: PromptData,
    ) -> mx.array:
        """Execute a single diffusion step on a single node (no distribution)."""
        base_kwargs = {"t": t, "config": config}
        if self.adapter.needs_cfg:
            # Two forward passes + guidance for CFG models (e.g., Qwen)
            pos_kwargs = {
                **base_kwargs,
                **prompt_data.get_extra_forward_kwargs(positive=True),
            }
            noise_pos = self._forward_pass(
                latents,
                prompt_data.prompt_embeds,
                prompt_data.pooled_prompt_embeds,
                pos_kwargs,
            )
            neg_kwargs = {
                **base_kwargs,
                **prompt_data.get_extra_forward_kwargs(positive=False),
            }
            noise_neg = self._forward_pass(
                latents,
                prompt_data.negative_prompt_embeds,
                prompt_data.negative_pooled_prompt_embeds,
                neg_kwargs,
            )
            # CFG models must ship a guidance scale in their config.
            assert self.config.guidance_scale is not None
            noise = self.adapter.apply_guidance(
                noise_pos, noise_neg, guidance_scale=self.config.guidance_scale
            )
        else:
            # Single forward pass for non-CFG models (e.g., Flux)
            kwargs = {**base_kwargs, **prompt_data.get_extra_forward_kwargs()}
            noise = self._forward_pass(
                latents,
                prompt_data.prompt_embeds,
                prompt_data.pooled_prompt_embeds,
                kwargs,
            )
        return config.scheduler.step(model_output=noise, timestep=t, sample=latents)

    def _initialize_kv_caches(
        self,
        batch_size: int,
        num_img_tokens: int,
        dtype: mx.Dtype,
    ) -> None:
        """Initialize KV caches for both sync and async pipelines.

        Note: Caches only store IMAGE K/V, not text K/V. Text K/V is always
        computed fresh and doesn't need caching (it's the same for all patches).
        """
        self.joint_kv_caches = [
            ImagePatchKVCache(
                batch_size=batch_size,
                num_heads=self.config.num_heads,
                image_seq_len=num_img_tokens,
                head_dim=self.config.head_dim,
                dtype=dtype,
            )
            for _ in range(len(self.joint_block_wrappers))
        ]
        self.single_kv_caches = [
            ImagePatchKVCache(
                batch_size=batch_size,
                num_heads=self.config.num_heads,
                image_seq_len=num_img_tokens,
                head_dim=self.config.head_dim,
                dtype=dtype,
            )
            for _ in range(len(self.single_block_wrappers))
        ]

    def _create_patches(
        self,
        latents: mx.array,
        config: RuntimeConfig,
    ) -> tuple[list[mx.array], list[tuple[int, int]]]:
        """Split latents into patches for async pipeline.

        Returns:
            Tuple of (patch_latents, token_indices) where token_indices holds
            the half-open (start, end) token range of each patch.
        """
        # Use 16 to match FluxLatentCreator.create_noise formula
        latent_height = config.height // 16
        latent_width = config.width // 16
        patch_heights, _ = calculate_patch_heights(latent_height, self.num_patches)
        token_indices = calculate_token_indices(patch_heights, latent_width)
        # Split latents into patches
        patch_latents = [latents[:, start:end, :] for start, end in token_indices]
        return patch_latents, token_indices

    def _sync_pipeline(
        self,
        t: int,
        config: RuntimeConfig,
        hidden_states: mx.array,
        prompt_data: PromptData,
        kontext_image_ids: mx.array | None = None,
    ) -> mx.array:
        """One lockstep timestep over the FULL image, populating KV caches.

        All stages run in sequence: each receives from prev_rank, runs its
        assigned blocks in CACHING mode, and sends to next_rank. The last
        stage applies projection + scheduler and sends the new latents back
        to rank 0.
        """
        prev_latents = hidden_states
        # Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
        prompt_embeds = prompt_data.prompt_embeds
        pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
        extra_kwargs = prompt_data.get_extra_forward_kwargs()
        hidden_states = config.scheduler.scale_model_input(hidden_states, t)
        # === PHASE 1: Embeddings ===
        if self.is_first_stage:
            hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
                hidden_states, prompt_embeds
            )
        # All stages need these for their blocks
        text_embeddings = self.adapter.compute_text_embeddings(
            t, config, pooled_prompt_embeds
        )
        image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
            prompt_embeds,
            config,
            kontext_image_ids=kontext_image_ids,
            **extra_kwargs,
        )
        # === Initialize KV caches to populate during sync for async warmstart ===
        batch_size = prev_latents.shape[0]
        num_img_tokens = prev_latents.shape[1]
        text_seq_len = prompt_embeds.shape[1]
        hidden_dim = self.adapter.hidden_dim
        if t == 0:
            # NOTE(review): caches are only created when the loop starts at
            # t == 0 — assumes init_time_step is 0; verify for resumed runs.
            self._initialize_kv_caches(
                batch_size=batch_size,
                num_img_tokens=num_img_tokens,
                dtype=prev_latents.dtype,
            )
        # === PHASE 2: Joint Blocks with Communication and Caching ===
        if self.has_joint_blocks:
            # Receive from previous stage (if not first stage)
            if not self.is_first_stage:
                # recv_like templates only supply shape/dtype for the recv.
                recv_template = mx.zeros(
                    (batch_size, num_img_tokens, hidden_dim), dtype=prev_latents.dtype
                )
                hidden_states = mx.distributed.recv_like(
                    recv_template, self.prev_rank, group=self.group
                )
                enc_template = mx.zeros(
                    (batch_size, text_seq_len, hidden_dim), dtype=prev_latents.dtype
                )
                encoder_hidden_states = mx.distributed.recv_like(
                    enc_template, self.prev_rank, group=self.group
                )
                mx.eval(hidden_states, encoder_hidden_states)
            # Run assigned joint blocks with caching mode
            for block_idx, wrapper in enumerate(self.joint_block_wrappers):
                encoder_hidden_states, hidden_states = wrapper(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                    text_seq_len=text_seq_len,
                    kv_cache=self.joint_kv_caches[block_idx],
                    mode=BlockWrapperMode.CACHING,
                    **extra_kwargs,
                )
        # === PHASE 3: Joint→Single Transition ===
        if self.owns_concat_stage:
            # Merge encoder and hidden states using adapter hook
            concatenated = self.adapter.merge_streams(
                hidden_states, encoder_hidden_states
            )
            if self.has_single_blocks or self.is_last_stage:
                # Keep locally: either for single blocks or final projection
                hidden_states = concatenated
            else:
                # Send concatenated state to next stage (which has single blocks)
                mx.eval(
                    mx.distributed.send(concatenated, self.next_rank, group=self.group)
                )
        elif self.has_joint_blocks and not self.is_last_stage:
            # Send joint block outputs to next stage (which has more joint blocks)
            mx.eval(
                mx.distributed.send(hidden_states, self.next_rank, group=self.group),
                mx.distributed.send(
                    encoder_hidden_states, self.next_rank, group=self.group
                ),
            )
        # === PHASE 4: Single Blocks with Communication and Caching ===
        if self.has_single_blocks:
            # Receive from previous stage if we didn't do concatenation
            if not self.owns_concat_stage and not self.is_first_stage:
                recv_template = mx.zeros(
                    (batch_size, text_seq_len + num_img_tokens, hidden_dim),
                    dtype=prev_latents.dtype,
                )
                hidden_states = mx.distributed.recv_like(
                    recv_template, self.prev_rank, group=self.group
                )
                mx.eval(hidden_states)
            # Run assigned single blocks with caching mode
            for block_idx, wrapper in enumerate(self.single_block_wrappers):
                hidden_states = wrapper(
                    hidden_states=hidden_states,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                    text_seq_len=text_seq_len,
                    kv_cache=self.single_kv_caches[block_idx],
                    mode=BlockWrapperMode.CACHING,
                )
            # Send to next stage if not last
            if not self.is_last_stage:
                mx.eval(
                    mx.distributed.send(hidden_states, self.next_rank, group=self.group)
                )
        # === PHASE 5: Last Stage - Final Projection + Scheduler ===
        # Extract image portion (remove text embeddings prefix)
        hidden_states = hidden_states[:, text_seq_len:, ...]
        if self.is_last_stage:
            hidden_states = self.adapter.final_projection(
                hidden_states, text_embeddings
            )
            hidden_states = config.scheduler.step(
                model_output=hidden_states,
                timestep=t,
                sample=prev_latents,
            )
            if not self.is_first_stage:
                # Ship the scheduler output back to rank 0 for the next step.
                mx.eval(mx.distributed.send(hidden_states, 0, group=self.group))
        elif self.is_first_stage:
            hidden_states = mx.distributed.recv_like(
                prev_latents, src=self.world_size - 1, group=self.group
            )
            mx.eval(hidden_states)
        else:
            # For shape correctness
            hidden_states = prev_latents
        return hidden_states

    def _async_pipeline_step(
        self,
        t: int,
        config: RuntimeConfig,
        latents: mx.array,
        prompt_data: PromptData,
        kontext_image_ids: mx.array | None = None,
    ) -> mx.array:
        """One async (PipeFusion) timestep: patch, process, re-assemble."""
        patch_latents, token_indices = self._create_patches(latents, config)
        patch_latents = self._async_pipeline(
            t,
            config,
            patch_latents,
            token_indices,
            prompt_data,
            kontext_image_ids,
        )
        # Re-join patches along the token axis into full latents.
        return mx.concatenate(patch_latents, axis=1)

    def _async_pipeline(
        self,
        t: int,
        config: RuntimeConfig,
        patch_latents: list[mx.array],
        token_indices: list[tuple[int, int]],
        prompt_data: PromptData,
        kontext_image_ids: mx.array | None = None,
    ) -> list[mx.array]:
        """Execute async pipeline for all patches.

        Patches are streamed through the stage ring independently, using the
        KV caches (PATCHED mode) populated during the sync phase for the
        tokens outside the current patch.
        """
        assert self.joint_kv_caches is not None
        assert self.single_kv_caches is not None
        # Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
        prompt_embeds = prompt_data.prompt_embeds
        pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
        extra_kwargs = prompt_data.get_extra_forward_kwargs()
        text_embeddings = self.adapter.compute_text_embeddings(
            t, config, pooled_prompt_embeds
        )
        image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
            prompt_embeds,
            config,
            kontext_image_ids=kontext_image_ids,
            **extra_kwargs,
        )
        batch_size = patch_latents[0].shape[0]
        text_seq_len = prompt_embeds.shape[1]
        hidden_dim = self.adapter.hidden_dim
        for patch_idx, patch in enumerate(patch_latents):
            patch_prev = patch
            start_token, end_token = token_indices[patch_idx]
            if self.has_joint_blocks:
                # First stage on the very first async step already holds the
                # sync-pipeline output, so it skips the receive.
                if not self.is_first_stage or t != self.num_sync_steps:
                    if self.is_first_stage:
                        # First stage receives latent-space from last stage (scheduler output)
                        recv_template = patch
                    else:
                        # Other stages receive hidden-space from previous stage
                        patch_len = patch.shape[1]
                        recv_template = mx.zeros(
                            (batch_size, patch_len, hidden_dim),
                            dtype=patch.dtype,
                        )
                    patch = mx.distributed.recv_like(
                        recv_template, src=self.prev_rank, group=self.group
                    )
                    mx.eval(patch)
                    patch_latents[patch_idx] = patch
                # Encoder states are patch-independent: only transferred once.
                if not self.is_first_stage and patch_idx == 0:
                    enc_template = mx.zeros(
                        (batch_size, text_seq_len, hidden_dim),
                        dtype=patch_latents[0].dtype,
                    )
                    encoder_hidden_states = mx.distributed.recv_like(
                        enc_template, src=self.prev_rank, group=self.group
                    )
                    mx.eval(encoder_hidden_states)
                if self.is_first_stage:
                    patch, encoder_hidden_states = self.adapter.compute_embeddings(
                        patch, prompt_embeds
                    )
                # Run assigned joint blocks with patched mode
                for block_idx, wrapper in enumerate(self.joint_block_wrappers):
                    encoder_hidden_states, patch = wrapper(
                        hidden_states=patch,
                        encoder_hidden_states=encoder_hidden_states,
                        text_embeddings=text_embeddings,
                        rotary_embeddings=image_rotary_embeddings,
                        text_seq_len=text_seq_len,
                        kv_cache=self.joint_kv_caches[block_idx],
                        mode=BlockWrapperMode.PATCHED,
                        patch_start=start_token,
                        patch_end=end_token,
                        **extra_kwargs,
                    )
            if self.owns_concat_stage:
                patch_concat = self.adapter.merge_streams(patch, encoder_hidden_states)
                if self.has_single_blocks or self.is_last_stage:
                    # Keep locally: either for single blocks or final projection
                    patch = patch_concat
                else:
                    mx.eval(
                        mx.distributed.send(
                            patch_concat, self.next_rank, group=self.group
                        )
                    )
            elif self.has_joint_blocks and not self.is_last_stage:
                mx.eval(mx.distributed.send(patch, self.next_rank, group=self.group))
                if patch_idx == 0:
                    # Forward encoder states once alongside the first patch.
                    mx.eval(
                        mx.distributed.send(
                            encoder_hidden_states, self.next_rank, group=self.group
                        )
                    )
            if self.has_single_blocks:
                if not self.owns_concat_stage and not self.is_first_stage:
                    recv_template = mx.zeros(
                        [
                            batch_size,
                            text_seq_len + patch_latents[patch_idx].shape[1],
                            hidden_dim,
                        ],
                        dtype=patch_latents[0].dtype,
                    )
                    patch = mx.distributed.recv_like(
                        recv_template, src=self.prev_rank, group=self.group
                    )
                    mx.eval(patch)
                    patch_latents[patch_idx] = patch
                # Run assigned single blocks with patched mode
                for block_idx, wrapper in enumerate(self.single_block_wrappers):
                    patch = wrapper(
                        hidden_states=patch,
                        text_embeddings=text_embeddings,
                        rotary_embeddings=image_rotary_embeddings,
                        text_seq_len=text_seq_len,
                        kv_cache=self.single_kv_caches[block_idx],
                        mode=BlockWrapperMode.PATCHED,
                        patch_start=start_token,
                        patch_end=end_token,
                    )
                if not self.is_last_stage:
                    mx.eval(
                        mx.distributed.send(patch, self.next_rank, group=self.group)
                    )
            if self.is_last_stage:
                # Strip text prefix, project, then run the scheduler against
                # this patch's previous latents.
                patch_img_only = patch[:, text_seq_len:, :]
                patch_img_only = self.adapter.final_projection(
                    patch_img_only, text_embeddings
                )
                patch = config.scheduler.step(
                    model_output=patch_img_only,
                    timestep=t,
                    sample=patch_prev,
                )
                # No send after the final timestep: every rank keeps its result.
                if not self.is_first_stage and t != config.num_inference_steps - 1:
                    mx.eval(
                        mx.distributed.send(patch, self.next_rank, group=self.group)
                    )
            patch_latents[patch_idx] = patch
        return patch_latents

View File

@@ -103,6 +103,7 @@ class PipelineLastLayer(CustomMlxLayer):
# This change happened upstream - check out mlx github somewhere??
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
# TODO(ciaran): This is overkill
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
return output

View File

@@ -1,7 +1,7 @@
from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.tokenizer_utils import TokenizerWrapper

View File

@@ -9,6 +9,8 @@ from exo.shared.types.tasks import (
ConnectToGroup,
CreateRunner,
DownloadModel,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -265,7 +267,12 @@ def _pending_tasks(
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
if not isinstance(task, ChatCompletion):
# TODO(ciaran): do this better!
if (
not isinstance(task, ChatCompletion)
and not isinstance(task, ImageGeneration)
and not isinstance(task, ImageEdits)
):
continue
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
continue

View File

@@ -1,7 +1,10 @@
import base64
import time
from exo.master.api import get_model_card
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -9,9 +12,12 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.models import ModelTask
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -21,6 +27,7 @@ from exo.shared.types.tasks import (
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ImageGenerationResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -37,6 +44,12 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
from exo.worker.engines.image import (
ImageGenerator,
generate_image,
initialize_image_model,
warmup_image_generator,
)
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
@@ -70,6 +83,10 @@ def main(
sampler = None
group = None
model_card = get_model_card(shard_metadata.model_meta.model_id)
assert model_card
model_tasks = model_card.tasks
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
event_sender.send(
@@ -112,16 +129,26 @@ def main(
)
)
model, tokenizer, sampler = load_mlx_items(
bound_instance, group
)
# TODO(ciaran): switch
if ModelTask.TextGeneration in model_tasks:
model, tokenizer, sampler = load_mlx_items(
bound_instance, group
)
elif (
ModelTask.TextToImage in model_tasks
or ModelTask.ImageToImage in model_tasks
):
model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {model_card.tasks}"
)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
assert sampler
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -131,22 +158,42 @@ def main(
)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
sampler=sampler,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
if ModelTask.TextGeneration in model_tasks:
# assert isinstance(model, Model) TODO(ciaran): not actually Model
assert model and not isinstance(model, ImageGenerator)
assert tokenizer
assert sampler
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
sampler=sampler,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
elif (
ModelTask.TextToImage in model_tasks
or ModelTask.ImageToImage in model_tasks
):
assert isinstance(model, ImageGenerator)
image = warmup_image_generator(model=model)
if image is not None:
logger.info(
f"warmed up by generating {image.size} image"
)
else:
logger.info("warmup completed (non-primary node)")
current_status = RunnerReady()
logger.info("runner ready")
case ChatCompletion(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert model
# assert isinstance(model, Model) TODO(ciaran): not actually Model
assert model and not isinstance(model, ImageGenerator)
assert tokenizer
assert sampler
logger.info(f"received chat request: {str(task)[:500]}")
@@ -187,6 +234,130 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, ImageGenerator)
logger.info(
f"received image generation request: {str(task)[:500]}"
)
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
# Generate images using the image generation backend
for image_index, response in enumerate(
generate_image(
model=model,
task=task_params,
)
):
match response:
case ImageGenerationResponse():
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
encoded_data = base64.b64encode(
response.image_data
).decode("utf-8")
# Split into chunks to stay under gossipsub 1MB limit
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
)
]
total_chunks = len(data_chunks)
logger.info(
f"sending ImageChunk: {len(encoded_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(
data_chunks
):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=chunk_index,
model=shard_metadata.model_meta.model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=image_index,
),
)
)
current_status = RunnerReady()
logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert isinstance(model, ImageGenerator)
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
for image_index, response in enumerate(
generate_image(
model=model,
task=task_params,
)
):
match response:
case ImageGenerationResponse():
if shard_metadata.device_rank == 0:
encoded_data = base64.b64encode(
response.image_data
).decode("utf-8")
# Split into chunks to stay under gossipsub 1MB limit
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
)
]
total_chunks = len(data_chunks)
logger.info(
f"sending ImageChunk: {len(encoded_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(
data_chunks
):
is_last_chunk = (
chunk_index == total_chunks - 1
)
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=chunk_index,
model=shard_metadata.model_meta.model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=image_index,
),
)
)
current_status = RunnerReady()
logger.info("runner ready")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=RunnerReady()
)
)
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")

1293
uv.lock generated
View File

File diff suppressed because it is too large Load Diff