Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-20 11:58:57 -05:00)

Compare commits: upstream-s ... ciaran/ima (252 commits)
Commits (SHA1):
09a4454427 12d2da7e47 e98affdb4d c33f79a133 4d826b465e 3fbe0b09bb 2c9bfc0d8a f4d7606ae3
fb81b5c0db 39a26c6745 35e62c8087 f107ae0b43 ba2faf3a71 305f4b79d3 40624bf494 1277265ae2
b5d1e78aba c712726778 c4a1f2cde8 4e287ab471 beb39442ab 1684907879 842569655c c437154755
841f89cb0a 8bc12db8d3 0fd4ff1151 9cba51aee6 874ad0e4c0 102681334d bda4c903a4 b13c75671c
b3b3ae6ae6 0fcd26c9d6 d8c0d3e1c5 acfacbf34c 8b39485250 5d5676103e 3af221bf75 e6993ad537
0fc68617fb 983420619e 9381edfaa6 f3a840011d 971b8367ab 2d0335b7cc 89b2471183 e39aec8189
00a828494b 0c2b760179 430ea689b1 9e5c64eb07 d6b44f4ba1 ebcefc1fb3 e7d9a66c15 5280b4d4ae
847ca9c21f 69f6d37c66 c74d59c810 99e8223c0b fd321398f2 c1e8bcb54c 82b0edf277 37a353448f
699b63d1da aaa7dcc0b3 9d1963a891 35f1425490 0cb008328f 796e9ccd9a 7a1c6b2a18 f992389cb3
12cf753614 8673168b21 91bf4b5b5d bf4e114ec9 f39792b5b9 2bc2cd2108 88289b617b 5295873bb2
1de5fb6716 4a47448fbf b381d4b518 8be71107fd c67356f4dd 65e2a24d05 4093078730 76cbcabcab
36944d40ae 85f9bb4b27 2e590c18dd 4e881a6c21 af8d373411 55eb7620bf 0fb299db17 a7870e4904
95c7b26178 8747a74b32 6c9e582ffa ae07a76d99 c41f7adfdc 5d38221731 a1d7eb61b6 dc51132bb7
5c9227ce42 e0cd04c5f3 c2aab343c4 3bcdd46bb1 46181a35ae e29d0b4a0e 633147cb02 3f5b4a43db
9ee7a3e92b 3e40b2beb5 4d6f339e6f 282b63effb cb537b0110 afc1643eb2 02ca7d5a4b 73abda4f17
f492fd4be8 e59b51788b f1b08bdf68 1e74c5ec4f 839e845876 8ab312de44 9e0d646505 6a21dca2e0
b491607a8f 258754a5e8 09ec079be8 3da22204db 60d7ea6265 f37751b31f 0f357e1f9b 9b0a621987
5574eb57e5 ddc67f09cc 27b343316b 99705175ee 8aad72b4d7 2ce3833b17 ac535d5725 9dae2eafd4
d082db113d 48298104a8 49823aa6d5 76eb5171f4 7fa8ebacad cebd3de003 e6cd1291b9 52d9e77bed
3d41745b51 c5c3f43c6c 8ac0686125 4d8a759eb2 5d7c0847c1 c4e916fc02 06ce520f4d 1083808072
686fa6e04c fb0e64f4c0 11f2a5bda5 2b547a18f8 c4d558d550 5d5a1aa561 33d34e376b fa1d6de18c
bdd3863d2d 48b687af45 a91e0f55ca 35ef128f04 34d3b84405 7a4b9e7884 da251c0c47 87b6940580
630af5b0f1 2ac27234fb 1ad090b4f4 4f746c1575 e1921b24a0 c76da7c220 af27feedc9 f088bbad92
3f1bfc6ee1 eaa450e2db 8964137f2b 984084fe8e dd2d25951d 9f5f763993 244f3a1bb4 efa419b36b
ec035fab4e 48cef402f1 ec68ad19a5 a2f0da26b6 49ac2e1aeb 78d9c5264d f5f19415da e8c4293f1e
f278464ff9 ce06bbb95e 86c4d24e23 551daa2c06 6ca3042f8e 62b653299f 2b4e931873 c67b684552
926a157476 d9586028d1 e971744960 b2abece9aa 1d5bd3e447 24668c54fc aff75d9992 49cb8422a1
19566deffd d046475ec4 780e0ea425 d301694f9c 407c0b2418 3559e7f43b 938cf08a5f 3bc016cec4
21a1dc8c55 1ea271788f 3a71354b1c d7b423db0e b940578e57 10dda7936d addc3ffe15 27c4f60a91
b9c34cdcc4 fd5835271f 77d5c9b38f 051577d122 013f66956a eead745cd0 6895ea935b 8451cb7d16
3c8f5f3464 1bf4b42830 5ba044d0f8 c33fd5c9ae 0e789435d4 0ca2e449b9 c99cd5fc87 e86caf8f48
2d483a5a77 01db2adc80 c10f87b006 0c6d04e085
dashboard/package-lock.json (generated, 9 changed lines)
@@ -863,6 +863,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",

@@ -902,6 +903,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",

@@ -1518,6 +1520,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}

@@ -1527,6 +1530,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},

@@ -1939,6 +1943,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}

@@ -2646,6 +2651,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},

@@ -2833,6 +2839,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",

@@ -2977,6 +2984,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"

@@ -2998,6 +3006,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",
@@ -1,8 +1,22 @@
|
||||
<script lang="ts">
|
||||
import { isLoading, sendMessage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
|
||||
import ChatAttachments from './ChatAttachments.svelte';
|
||||
import type { ChatUploadedFile } from '$lib/types/files';
|
||||
import { processUploadedFiles, getAcceptString } from '$lib/types/files';
|
||||
import {
|
||||
isLoading,
|
||||
sendMessage,
|
||||
generateImage,
|
||||
editImage,
|
||||
editingImage,
|
||||
clearEditingImage,
|
||||
selectedChatModel,
|
||||
setSelectedChatModel,
|
||||
instances,
|
||||
ttftMs,
|
||||
tps,
|
||||
totalTokens,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import ChatAttachments from "./ChatAttachments.svelte";
|
||||
import ImageParamsPanel from "./ImageParamsPanel.svelte";
|
||||
import type { ChatUploadedFile } from "$lib/types/files";
|
||||
import { processUploadedFiles, getAcceptString } from "$lib/types/files";
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -10,17 +24,19 @@
|
||||
showHelperText?: boolean;
|
||||
autofocus?: boolean;
|
||||
showModelSelector?: boolean;
|
||||
modelTasks?: Record<string, string[]>;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
placeholder = 'Ask anything',
|
||||
let {
|
||||
class: className = "",
|
||||
placeholder = "Ask anything",
|
||||
showHelperText = false,
|
||||
autofocus = true,
|
||||
showModelSelector = false
|
||||
showModelSelector = false,
|
||||
modelTasks = {},
|
||||
}: Props = $props();
|
||||
|
||||
let message = $state('');
|
||||
let message = $state("");
|
||||
let textareaRef: HTMLTextAreaElement | undefined = $state();
|
||||
let fileInputRef: HTMLInputElement | undefined = $state();
|
||||
let uploadedFiles = $state<ChatUploadedFile[]>([]);
|
||||
@@ -31,30 +47,82 @@
|
||||
const currentTtft = $derived(ttftMs());
|
||||
const currentTps = $derived(tps());
|
||||
const currentTokens = $derived(totalTokens());
|
||||
|
||||
const currentEditingImage = $derived(editingImage());
|
||||
const isEditMode = $derived(currentEditingImage !== null);
|
||||
|
||||
// Custom dropdown state
|
||||
let isModelDropdownOpen = $state(false);
|
||||
let dropdownButtonRef: HTMLButtonElement | undefined = $state();
|
||||
let dropdownPosition = $derived(() => {
|
||||
if (!dropdownButtonRef || !isModelDropdownOpen) return { top: 0, left: 0, width: 0 };
|
||||
if (!dropdownButtonRef || !isModelDropdownOpen)
|
||||
return { top: 0, left: 0, width: 0 };
|
||||
const rect = dropdownButtonRef.getBoundingClientRect();
|
||||
return {
|
||||
top: rect.top,
|
||||
left: rect.left,
|
||||
width: rect.width
|
||||
width: rect.width,
|
||||
};
|
||||
});
|
||||
|
||||
// Accept all supported file types
|
||||
const acceptString = getAcceptString(['image', 'text', 'pdf']);
|
||||
const acceptString = getAcceptString(["image", "text", "pdf"]);
|
||||
|
||||
function modelSupportsImageGeneration(modelId: string): boolean {
|
||||
const tasks = modelTasks[modelId] || [];
|
||||
return tasks.includes("TextToImage") || tasks.includes("ImageToImage");
|
||||
}
|
||||
|
||||
function modelSupportsTextToImage(modelId: string): boolean {
|
||||
const tasks = modelTasks[modelId] || [];
|
||||
return tasks.includes("TextToImage");
|
||||
}
|
||||
|
||||
function modelSupportsOnlyImageEditing(modelId: string): boolean {
|
||||
const tasks = modelTasks[modelId] || [];
|
||||
return tasks.includes("ImageToImage") && !tasks.includes("TextToImage");
|
||||
}
|
||||
|
||||
function modelSupportsImageEditing(modelId: string): boolean {
|
||||
const tasks = modelTasks[modelId] || [];
|
||||
return tasks.includes("ImageToImage");
|
||||
}
|
||||
|
||||
const isImageModel = $derived(() => {
|
||||
if (!currentModel) return false;
|
||||
return modelSupportsTextToImage(currentModel);
|
||||
});
|
||||
|
||||
const isEditOnlyWithoutImage = $derived(
|
||||
currentModel !== null &&
|
||||
modelSupportsOnlyImageEditing(currentModel) &&
|
||||
!isEditMode &&
|
||||
uploadedFiles.length === 0,
|
||||
);
|
||||
|
||||
// Show edit mode when: explicit edit mode OR (model supports ImageToImage AND files attached)
|
||||
const shouldShowEditMode = $derived(
|
||||
isEditMode ||
|
||||
(currentModel &&
|
||||
modelSupportsImageEditing(currentModel) &&
|
||||
uploadedFiles.length > 0),
|
||||
);
|
||||
|
||||
// Extract available models from running instances
|
||||
const availableModels = $derived(() => {
|
||||
const models: Array<{id: string, label: string}> = [];
|
||||
const models: Array<{ id: string; label: string; isImageModel: boolean }> =
|
||||
[];
|
||||
for (const [, instance] of Object.entries(instanceData)) {
|
||||
const modelId = getInstanceModelId(instance);
|
||||
if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
|
||||
models.push({ id: modelId, label: modelId.split('/').pop() || modelId });
|
||||
if (
|
||||
modelId &&
|
||||
modelId !== "Unknown" &&
|
||||
!models.some((m) => m.id === modelId)
|
||||
) {
|
||||
models.push({
|
||||
id: modelId,
|
||||
label: modelId.split("/").pop() || modelId,
|
||||
isImageModel: modelSupportsImageGeneration(modelId),
|
||||
});
|
||||
}
|
||||
}
|
||||
return models;
|
||||
@@ -66,18 +134,18 @@
|
||||
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
|
||||
$effect(() => {
|
||||
const models = availableModels();
|
||||
const currentModelIds = new Set(models.map(m => m.id));
|
||||
const currentModelIds = new Set(models.map((m) => m.id));
|
||||
|
||||
if (models.length > 0) {
|
||||
// Find newly added models (in current but not in previous)
|
||||
const newModels = models.filter(m => !previousModelIds.has(m.id));
|
||||
const newModels = models.filter((m) => !previousModelIds.has(m.id));
|
||||
|
||||
// If no model selected, select the first available
|
||||
if (!currentModel) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
// If current model is stale (no longer has a running instance), reset to first available
|
||||
else if (!models.some(m => m.id === currentModel)) {
|
||||
else if (!models.some((m) => m.id === currentModel)) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
// If a new model was just added, select it
|
||||
@@ -87,7 +155,7 @@
|
||||
} else {
|
||||
// No instances running - clear the selected model
|
||||
if (currentModel) {
|
||||
setSelectedChatModel('');
|
||||
setSelectedChatModel("");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,13 +164,15 @@
|
||||
});
|
||||
|
||||
function getInstanceModelId(instanceWrapped: unknown): string {
|
||||
if (!instanceWrapped || typeof instanceWrapped !== 'object') return '';
|
||||
if (!instanceWrapped || typeof instanceWrapped !== "object") return "";
|
||||
const keys = Object.keys(instanceWrapped as Record<string, unknown>);
|
||||
if (keys.length === 1) {
|
||||
const instance = (instanceWrapped as Record<string, unknown>)[keys[0]] as { shardAssignments?: { modelId?: string } };
|
||||
return instance?.shardAssignments?.modelId || '';
|
||||
const instance = (instanceWrapped as Record<string, unknown>)[
|
||||
keys[0]
|
||||
] as { shardAssignments?: { modelId?: string } };
|
||||
return instance?.shardAssignments?.modelId || "";
|
||||
}
|
||||
return '';
|
||||
return "";
|
||||
}
|
||||
|
||||
async function handleFiles(files: File[]) {
|
||||
@@ -115,33 +185,35 @@
|
||||
const input = event.target as HTMLInputElement;
|
||||
if (input.files && input.files.length > 0) {
|
||||
handleFiles(Array.from(input.files));
|
||||
input.value = ''; // Reset for next selection
|
||||
input.value = ""; // Reset for next selection
|
||||
}
|
||||
}
|
||||
|
||||
function handleFileRemove(fileId: string) {
|
||||
uploadedFiles = uploadedFiles.filter(f => f.id !== fileId);
|
||||
uploadedFiles = uploadedFiles.filter((f) => f.id !== fileId);
|
||||
}
|
||||
|
||||
function handlePaste(event: ClipboardEvent) {
|
||||
if (!event.clipboardData) return;
|
||||
|
||||
|
||||
const files = Array.from(event.clipboardData.items)
|
||||
.filter(item => item.kind === 'file')
|
||||
.map(item => item.getAsFile())
|
||||
.filter((item) => item.kind === "file")
|
||||
.map((item) => item.getAsFile())
|
||||
.filter((file): file is File => file !== null);
|
||||
|
||||
|
||||
if (files.length > 0) {
|
||||
event.preventDefault();
|
||||
handleFiles(files);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// Handle long text paste as file
|
||||
const text = event.clipboardData.getData('text/plain');
|
||||
const text = event.clipboardData.getData("text/plain");
|
||||
if (text.length > 2500) {
|
||||
event.preventDefault();
|
||||
const textFile = new File([text], 'pasted-text.txt', { type: 'text/plain' });
|
||||
const textFile = new File([text], "pasted-text.txt", {
|
||||
type: "text/plain",
|
||||
});
|
||||
handleFiles([textFile]);
|
||||
}
|
||||
}
|
||||
@@ -159,7 +231,7 @@
|
||||
function handleDrop(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
isDragOver = false;
|
||||
|
||||
|
||||
if (event.dataTransfer?.files) {
|
||||
handleFiles(Array.from(event.dataTransfer.files));
|
||||
}
|
||||
@@ -170,8 +242,8 @@
|
||||
if (event.isComposing || event.keyCode === 229) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.key === 'Enter' && !event.shiftKey) {
|
||||
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
handleSubmit();
|
||||
}
|
||||
@@ -179,29 +251,50 @@
|
||||
|
||||
function handleSubmit() {
|
||||
if ((!message.trim() && uploadedFiles.length === 0) || loading) return;
|
||||
|
||||
|
||||
const content = message.trim();
|
||||
const files = [...uploadedFiles];
|
||||
|
||||
message = '';
|
||||
|
||||
message = "";
|
||||
uploadedFiles = [];
|
||||
resetTextareaHeight();
|
||||
|
||||
sendMessage(content, files);
|
||||
|
||||
|
||||
// Use image editing if in edit mode
|
||||
if (isEditMode && currentEditingImage && content) {
|
||||
editImage(content, currentEditingImage.imageDataUrl);
|
||||
}
|
||||
// If user attached an image with an ImageToImage model, use edit endpoint
|
||||
else if (
|
||||
currentModel &&
|
||||
modelSupportsImageEditing(currentModel) &&
|
||||
files.length > 0 &&
|
||||
content
|
||||
) {
|
||||
// Use the first attached image for editing
|
||||
const imageFile = files[0];
|
||||
if (imageFile.preview) {
|
||||
editImage(content, imageFile.preview);
|
||||
}
|
||||
} else if (isImageModel() && content) {
|
||||
// Use image generation for text-to-image models
|
||||
generateImage(content);
|
||||
} else {
|
||||
sendMessage(content, files);
|
||||
}
|
||||
|
||||
// Refocus the textarea after sending
|
||||
setTimeout(() => textareaRef?.focus(), 10);
|
||||
}
|
||||
|
||||
function handleInput() {
|
||||
if (!textareaRef) return;
|
||||
textareaRef.style.height = 'auto';
|
||||
textareaRef.style.height = Math.min(textareaRef.scrollHeight, 150) + 'px';
|
||||
textareaRef.style.height = "auto";
|
||||
textareaRef.style.height = Math.min(textareaRef.scrollHeight, 150) + "px";
|
||||
}
|
||||
|
||||
function resetTextareaHeight() {
|
||||
if (textareaRef) {
|
||||
textareaRef.style.height = 'auto';
|
||||
textareaRef.style.height = "auto";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,13 +304,13 @@
|
||||
|
||||
// Track previous loading state to detect when loading completes
|
||||
let wasLoading = $state(false);
|
||||
|
||||
|
||||
$effect(() => {
|
||||
if (autofocus && textareaRef) {
|
||||
setTimeout(() => textareaRef?.focus(), 10);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
// Refocus after loading completes (AI response finished)
|
||||
$effect(() => {
|
||||
if (wasLoading && !loading && textareaRef) {
|
||||
@@ -226,7 +319,9 @@
|
||||
wasLoading = loading;
|
||||
});
|
||||
|
||||
const canSend = $derived(message.trim().length > 0 || uploadedFiles.length > 0);
|
||||
const canSend = $derived(
|
||||
message.trim().length > 0 || uploadedFiles.length > 0,
|
||||
);
|
||||
</script>
|
||||
|
||||
<!-- Hidden file input -->
|
||||
@@ -239,69 +334,132 @@
|
||||
onchange={handleFileInputChange}
|
||||
/>
|
||||
|
||||
<form
|
||||
onsubmit={(e) => { e.preventDefault(); handleSubmit(); }}
|
||||
<form
|
||||
onsubmit={(e) => {
|
||||
e.preventDefault();
|
||||
handleSubmit();
|
||||
}}
|
||||
class="w-full {className}"
|
||||
ondragover={handleDragOver}
|
||||
ondragleave={handleDragLeave}
|
||||
ondrop={handleDrop}
|
||||
>
|
||||
<div
|
||||
class="relative command-panel rounded overflow-hidden transition-all duration-200 {isDragOver ? 'ring-2 ring-exo-yellow ring-opacity-50' : ''}"
|
||||
<div
|
||||
class="relative command-panel rounded overflow-hidden transition-all duration-200 {isDragOver
|
||||
? 'ring-2 ring-exo-yellow ring-opacity-50'
|
||||
: ''}"
|
||||
>
|
||||
<!-- Top accent line -->
|
||||
<div class="absolute top-0 left-0 right-0 h-px bg-gradient-to-r from-transparent via-exo-yellow/50 to-transparent"></div>
|
||||
|
||||
<div
|
||||
class="absolute top-0 left-0 right-0 h-px bg-gradient-to-r from-transparent via-exo-yellow/50 to-transparent"
|
||||
></div>
|
||||
|
||||
<!-- Drag overlay -->
|
||||
{#if isDragOver}
|
||||
<div class="absolute inset-0 bg-exo-dark-gray/80 z-10 flex items-center justify-center">
|
||||
<div
|
||||
class="absolute inset-0 bg-exo-dark-gray/80 z-10 flex items-center justify-center"
|
||||
>
|
||||
<div class="text-exo-yellow text-sm font-mono tracking-wider uppercase">
|
||||
DROP FILES HERE
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
|
||||
<!-- Edit mode banner -->
|
||||
{#if isEditMode && currentEditingImage}
|
||||
<div
|
||||
class="flex items-center gap-3 px-3 py-2 bg-exo-yellow/10 border-b border-exo-yellow/30"
|
||||
>
|
||||
<img
|
||||
src={currentEditingImage.imageDataUrl}
|
||||
alt="Source for editing"
|
||||
class="w-10 h-10 object-cover rounded border border-exo-yellow/30"
|
||||
/>
|
||||
<div class="flex-1">
|
||||
<span
|
||||
class="text-xs font-mono tracking-wider uppercase text-exo-yellow"
|
||||
>EDITING IMAGE</span
|
||||
>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => clearEditingImage()}
|
||||
class="px-2 py-1 text-xs font-mono tracking-wider uppercase bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50 rounded hover:bg-exo-medium-gray/50 hover:text-exo-yellow transition-colors cursor-pointer"
|
||||
>
|
||||
CANCEL
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Model selector (when enabled) -->
|
||||
{#if showModelSelector && availableModels().length > 0}
|
||||
<div class="flex items-center justify-between gap-2 px-3 py-2 border-b border-exo-medium-gray/30">
|
||||
<div
|
||||
class="flex items-center justify-between gap-2 px-3 py-2 border-b border-exo-medium-gray/30"
|
||||
>
|
||||
<div class="flex items-center gap-2 flex-1">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider flex-shrink-0">MODEL:</span>
|
||||
<span
|
||||
class="text-xs text-exo-light-gray uppercase tracking-wider flex-shrink-0"
|
||||
>MODEL:</span
|
||||
>
|
||||
<!-- Custom dropdown -->
|
||||
<div class="relative flex-1 max-w-xs">
|
||||
<button
|
||||
bind:this={dropdownButtonRef}
|
||||
type="button"
|
||||
onclick={() => isModelDropdownOpen = !isModelDropdownOpen}
|
||||
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-1.5 text-xs font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isModelDropdownOpen ? 'border-exo-yellow/70' : ''}"
|
||||
onclick={() => (isModelDropdownOpen = !isModelDropdownOpen)}
|
||||
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-3 pr-8 py-1.5 text-xs font-mono text-left tracking-wide cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isModelDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
>
|
||||
{#if availableModels().find(m => m.id === currentModel)}
|
||||
<span class="text-exo-yellow truncate">{availableModels().find(m => m.id === currentModel)?.label}</span>
|
||||
{#if availableModels().find((m) => m.id === currentModel)}
|
||||
<span class="text-exo-yellow truncate"
|
||||
>{availableModels().find((m) => m.id === currentModel)
|
||||
?.label}</span
|
||||
>
|
||||
{:else if availableModels().length > 0}
|
||||
<span class="text-exo-yellow truncate">{availableModels()[0].label}</span>
|
||||
<span class="text-exo-yellow truncate"
|
||||
>{availableModels()[0].label}</span
|
||||
>
|
||||
{:else}
|
||||
<span class="text-exo-light-gray/50">— SELECT MODEL —</span>
|
||||
{/if}
|
||||
</button>
|
||||
<div class="absolute right-2 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isModelDropdownOpen ? 'rotate-180' : ''}">
|
||||
<svg class="w-3 h-3 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
|
||||
<div
|
||||
class="absolute right-2 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isModelDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
{#if isModelDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => isModelDropdownOpen = false}
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isModelDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto"
|
||||
style="bottom: calc(100vh - {dropdownPosition().top}px + 4px); left: {dropdownPosition().left}px; width: {dropdownPosition().width}px;"
|
||||
style="bottom: calc(100vh - {dropdownPosition()
|
||||
.top}px + 4px); left: {dropdownPosition()
|
||||
.left}px; width: {dropdownPosition().width}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each availableModels() as model}
|
||||
@@ -311,20 +469,48 @@
|
||||
setSelectedChatModel(model.id);
|
||||
isModelDropdownOpen = false;
|
||||
}}
|
||||
class="w-full px-3 py-2 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {
|
||||
currentModel === model.id
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'
|
||||
}"
|
||||
class="w-full px-3 py-2 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {currentModel ===
|
||||
model.id
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if currentModel === model.id}
|
||||
<svg class="w-3 h-3 flex-shrink-0" fill="currentColor" viewBox="0 0 20 20">
|
||||
<path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd" />
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span class="truncate">{model.label}</span>
|
||||
{#if model.isImageModel}
|
||||
<svg
|
||||
class="w-3.5 h-3.5 flex-shrink-0 text-exo-yellow"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
aria-label="Image generation model"
|
||||
>
|
||||
<rect
|
||||
x="3"
|
||||
y="3"
|
||||
width="18"
|
||||
height="18"
|
||||
rx="2"
|
||||
ry="2"
|
||||
/>
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate flex-1">{model.label}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
@@ -336,30 +522,37 @@
|
||||
<div class="flex items-center gap-4 text-xs font-mono flex-shrink-0">
|
||||
{#if currentTtft !== null}
|
||||
<span class="text-exo-light-gray">
|
||||
<span class="text-white/70">TTFT</span> <span class="text-exo-yellow">{currentTtft.toFixed(1)}ms</span>
|
||||
<span class="text-white/70">TTFT</span>
|
||||
<span class="text-exo-yellow">{currentTtft.toFixed(1)}ms</span>
|
||||
</span>
|
||||
{/if}
|
||||
{#if currentTps !== null}
|
||||
<span class="text-exo-light-gray">
|
||||
<span class="text-white/70">TPS</span> <span class="text-exo-yellow">{currentTps.toFixed(1)}</span> <span class="text-white/60">tok/s</span>
|
||||
<span class="text-white/50">({(1000 / currentTps).toFixed(1)} ms/tok)</span>
|
||||
<span class="text-white/70">TPS</span>
|
||||
<span class="text-exo-yellow">{currentTps.toFixed(1)}</span>
|
||||
<span class="text-white/60">tok/s</span>
|
||||
<span class="text-white/50"
|
||||
>({(1000 / currentTps).toFixed(1)} ms/tok)</span
|
||||
>
|
||||
</span>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
|
||||
<!-- Image params panel (shown for image models or edit mode) -->
|
||||
{#if showModelSelector && (isImageModel() || isEditMode)}
|
||||
<ImageParamsPanel {isEditMode} />
|
||||
{/if}
|
||||
|
||||
<!-- Attached files preview -->
|
||||
{#if uploadedFiles.length > 0}
|
||||
<div class="px-3 pt-3">
|
||||
<ChatAttachments
|
||||
files={uploadedFiles}
|
||||
onRemove={handleFileRemove}
|
||||
/>
|
||||
<ChatAttachments files={uploadedFiles} onRemove={handleFileRemove} />
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
|
||||
<!-- Input area -->
|
||||
<div class="flex items-start gap-2 sm:gap-3 py-3 px-3 sm:px-4">
|
||||
<!-- Attach file button -->
|
||||
@@ -370,58 +563,130 @@
|
||||
class="flex items-center justify-center w-7 h-7 rounded text-exo-light-gray hover:text-exo-yellow transition-all disabled:opacity-50 disabled:cursor-not-allowed flex-shrink-0 cursor-pointer"
|
||||
title="Attach file"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M15.172 7l-6.586 6.586a2 2 0 102.828 2.828l6.414-6.586a4 4 0 00-5.656-5.656l-6.415 6.585a6 6 0 108.486 8.486L20.5 13" />
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M15.172 7l-6.586 6.586a2 2 0 102.828 2.828l6.414-6.586a4 4 0 00-5.656-5.656l-6.415 6.585a6 6 0 108.486 8.486L20.5 13"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
|
||||
|
||||
<!-- Terminal prompt -->
|
||||
<span class="text-exo-yellow text-sm font-bold flex-shrink-0 leading-7">▶</span>
|
||||
|
||||
<span class="text-exo-yellow text-sm font-bold flex-shrink-0 leading-7"
|
||||
>▶</span
|
||||
>
|
||||
|
||||
<textarea
|
||||
bind:this={textareaRef}
|
||||
bind:value={message}
|
||||
onkeydown={handleKeydown}
|
||||
oninput={handleInput}
|
||||
onpaste={handlePaste}
|
||||
{placeholder}
|
||||
placeholder={isEditOnlyWithoutImage
|
||||
? "Attach an image to edit..."
|
||||
: isEditMode
|
||||
? "Describe how to edit this image..."
|
||||
: isImageModel()
|
||||
? "Describe the image you want to generate..."
|
||||
: placeholder}
|
||||
disabled={loading}
|
||||
rows={1}
|
||||
class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"
|
||||
style="min-height: 28px; max-height: 150px;"
|
||||
></textarea>
|
||||
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || loading}
|
||||
disabled={!canSend || loading || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || loading
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label="Send message"
|
||||
{!canSend || loading || isEditOnlyWithoutImage
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={shouldShowEditMode
|
||||
? "Edit image"
|
||||
: isImageModel()
|
||||
? "Generate image"
|
||||
: "Send message"}
|
||||
>
|
||||
{#if loading}
|
||||
<span class="inline-flex items-center gap-1 sm:gap-2">
|
||||
<span class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"></span>
|
||||
<span class="hidden sm:inline">PROCESSING</span>
|
||||
<span
|
||||
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
|
||||
></span>
|
||||
<span class="hidden sm:inline"
|
||||
>{shouldShowEditMode
|
||||
? "EDITING"
|
||||
: isImageModel()
|
||||
? "GENERATING"
|
||||
: "PROCESSING"}</span
|
||||
>
|
||||
<span class="sm:hidden">...</span>
|
||||
</span>
|
||||
{:else if shouldShowEditMode}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Bottom accent line -->
|
||||
<div class="absolute bottom-0 left-0 right-0 h-px bg-gradient-to-r from-transparent via-exo-yellow/30 to-transparent"></div>
|
||||
<div
|
||||
class="absolute bottom-0 left-0 right-0 h-px bg-gradient-to-r from-transparent via-exo-yellow/30 to-transparent"
|
||||
></div>
|
||||
</div>
|
||||
|
||||
|
||||
{#if showHelperText}
|
||||
<p class="mt-2 sm:mt-3 text-center text-xs sm:text-xs text-exo-light-gray tracking-[0.1em] sm:tracking-[0.15em] uppercase">
|
||||
<kbd class="px-1 sm:px-1.5 py-0.5 rounded bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50">ENTER</kbd>
|
||||
<p
|
||||
class="mt-2 sm:mt-3 text-center text-xs sm:text-xs text-exo-light-gray tracking-[0.1em] sm:tracking-[0.15em] uppercase"
|
||||
>
|
||||
<kbd
|
||||
class="px-1 sm:px-1.5 py-0.5 rounded bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50"
|
||||
>ENTER</kbd
|
||||
>
|
||||
<span class="mx-0.5 sm:mx-1">TO SEND</span>
|
||||
<span class="text-exo-medium-gray mx-1 sm:mx-2">|</span>
|
||||
<kbd class="px-1 sm:px-1.5 py-0.5 rounded bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50">SHIFT+ENTER</kbd>
|
||||
<kbd
|
||||
class="px-1 sm:px-1.5 py-0.5 rounded bg-exo-medium-gray/30 text-exo-light-gray border border-exo-medium-gray/50"
|
||||
>SHIFT+ENTER</kbd
|
||||
>
|
||||
<span class="mx-0.5 sm:mx-1">NEW LINE</span>
|
||||
<span class="text-exo-medium-gray mx-1 sm:mx-2">|</span>
|
||||
<span class="text-exo-light-gray">DRAG & DROP OR PASTE FILES</span>
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
messages,
|
||||
currentResponse,
|
||||
import {
|
||||
messages,
|
||||
currentResponse,
|
||||
isLoading,
|
||||
deleteMessage,
|
||||
editAndRegenerate,
|
||||
regenerateLastResponse
|
||||
regenerateLastResponse,
|
||||
setEditingImage
|
||||
} from '$lib/stores/app.svelte';
|
||||
import type { Message } from '$lib/stores/app.svelte';
|
||||
import type { MessageAttachment } from '$lib/stores/app.svelte';
|
||||
import MarkdownContent from './MarkdownContent.svelte';
|
||||
|
||||
@@ -365,10 +367,76 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Generated Images -->
|
||||
{#if message.attachments?.some(a => a.type === 'generated-image')}
|
||||
<div class="mb-3">
|
||||
{#each message.attachments.filter(a => a.type === 'generated-image') as attachment}
|
||||
<div class="relative group/img inline-block">
|
||||
<img
|
||||
src={attachment.preview}
|
||||
alt=""
|
||||
class="max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20"
|
||||
/>
|
||||
<!-- Button overlay -->
|
||||
<div class="absolute top-2 right-2 flex gap-1 opacity-0 group-hover/img:opacity-100 transition-opacity">
|
||||
<!-- Edit button -->
|
||||
<button
|
||||
type="button"
|
||||
class="p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
|
||||
onclick={() => {
|
||||
if (attachment.preview) {
|
||||
setEditingImage(attachment.preview, message);
|
||||
}
|
||||
}}
|
||||
title="Edit image"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z" />
|
||||
</svg>
|
||||
</button>
|
||||
<!-- Download button -->
|
||||
<button
|
||||
type="button"
|
||||
class="p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
|
||||
onclick={() => {
|
||||
if (attachment.preview) {
|
||||
const link = document.createElement('a');
|
||||
link.href = attachment.preview;
|
||||
link.download = `generated-image-${Date.now()}.png`;
|
||||
link.click();
|
||||
}
|
||||
}}
|
||||
title="Download image"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="text-xs text-foreground">
|
||||
<MarkdownContent content={message.content || (loading ? response : '')} />
|
||||
{#if loading && !message.content}
|
||||
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
|
||||
{#if message.content === 'Generating image...' || message.content === 'Editing image...' || message.content?.startsWith('Generating...') || message.content?.startsWith('Editing...')}
|
||||
<div class="flex items-center gap-3 text-exo-yellow">
|
||||
<div class="relative">
|
||||
<div class="w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"></div>
|
||||
<svg class="absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
|
||||
<circle cx="8.5" cy="8.5" r="1.5"/>
|
||||
<polyline points="21 15 16 10 5 21"/>
|
||||
</svg>
|
||||
</div>
|
||||
<span class="font-mono tracking-wider uppercase text-sm">{message.content}</span>
|
||||
</div>
|
||||
{:else if message.content || (loading && !message.attachments?.some(a => a.type === 'generated-image'))}
|
||||
<MarkdownContent content={message.content || (loading ? response : '')} />
|
||||
{#if loading && !message.content}
|
||||
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
dashboard/src/lib/components/ImageParamsPanel.svelte (new file, 471 lines)
@@ -0,0 +1,471 @@
|
||||
<script lang="ts">
|
||||
import {
|
||||
imageGenerationParams,
|
||||
setImageGenerationParams,
|
||||
resetImageGenerationParams,
|
||||
type ImageGenerationParams,
|
||||
} from "$lib/stores/app.svelte";
|
||||
|
||||
interface Props {
|
||||
isEditMode?: boolean;
|
||||
}
|
||||
|
||||
let { isEditMode = false }: Props = $props();
|
||||
|
||||
let showAdvanced = $state(false);
|
||||
|
||||
// Custom dropdown state
|
||||
let isSizeDropdownOpen = $state(false);
|
||||
let isQualityDropdownOpen = $state(false);
|
||||
let sizeButtonRef: HTMLButtonElement | undefined = $state();
|
||||
let qualityButtonRef: HTMLButtonElement | undefined = $state();
|
||||
|
||||
const sizeDropdownPosition = $derived(() => {
|
||||
if (!sizeButtonRef || !isSizeDropdownOpen) return { top: 0, left: 0, width: 0 };
|
||||
const rect = sizeButtonRef.getBoundingClientRect();
|
||||
return { top: rect.top, left: rect.left, width: rect.width };
|
||||
});
|
||||
|
||||
const qualityDropdownPosition = $derived(() => {
|
||||
if (!qualityButtonRef || !isQualityDropdownOpen) return { top: 0, left: 0, width: 0 };
|
||||
const rect = qualityButtonRef.getBoundingClientRect();
|
||||
return { top: rect.top, left: rect.left, width: rect.width };
|
||||
});
|
||||
|
||||
const params = $derived(imageGenerationParams());
|
||||
|
||||
const inputFidelityOptions: ImageGenerationParams["inputFidelity"][] = [
|
||||
"low",
|
||||
"high",
|
||||
];
|
||||
|
||||
function handleInputFidelityChange(value: ImageGenerationParams["inputFidelity"]) {
|
||||
setImageGenerationParams({ inputFidelity: value });
|
||||
}
|
||||
|
||||
const sizeOptions: ImageGenerationParams["size"][] = [
|
||||
"512x512",
|
||||
"768x768",
|
||||
"1024x1024",
|
||||
"1024x768",
|
||||
"768x1024",
|
||||
];
|
||||
|
||||
const qualityOptions: ImageGenerationParams["quality"][] = [
|
||||
"low",
|
||||
"medium",
|
||||
"high",
|
||||
];
|
||||
|
||||
function selectSize(value: ImageGenerationParams["size"]) {
|
||||
setImageGenerationParams({ size: value });
|
||||
isSizeDropdownOpen = false;
|
||||
}
|
||||
|
||||
function selectQuality(value: ImageGenerationParams["quality"]) {
|
||||
setImageGenerationParams({ quality: value });
|
||||
isQualityDropdownOpen = false;
|
||||
}
|
||||
|
||||
function handleSeedChange(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
const value = input.value.trim();
|
||||
if (value === "") {
|
||||
setImageGenerationParams({ seed: null });
|
||||
} else {
|
||||
const num = parseInt(value, 10);
|
||||
if (!isNaN(num) && num >= 0) {
|
||||
setImageGenerationParams({ seed: num });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleStepsChange(event: Event) {
|
||||
const value = parseInt((event.target as HTMLInputElement).value, 10);
|
||||
setImageGenerationParams({ numInferenceSteps: value });
|
||||
}
|
||||
|
||||
function handleGuidanceChange(event: Event) {
|
||||
const value = parseFloat((event.target as HTMLInputElement).value);
|
||||
setImageGenerationParams({ guidance: value });
|
||||
}
|
||||
|
||||
function handleNegativePromptChange(event: Event) {
|
||||
const value = (event.target as HTMLTextAreaElement).value;
|
||||
setImageGenerationParams({ negativePrompt: value || null });
|
||||
}
|
||||
|
||||
function clearSteps() {
|
||||
setImageGenerationParams({ numInferenceSteps: null });
|
||||
}
|
||||
|
||||
function clearGuidance() {
|
||||
setImageGenerationParams({ guidance: null });
|
||||
}
|
||||
|
||||
function handleReset() {
|
||||
resetImageGenerationParams();
|
||||
showAdvanced = false;
|
||||
}
|
||||
|
||||
const hasAdvancedParams = $derived(
|
||||
params.seed !== null ||
|
||||
params.numInferenceSteps !== null ||
|
||||
params.guidance !== null ||
|
||||
(params.negativePrompt !== null && params.negativePrompt.trim() !== ""),
|
||||
);
|
||||
</script>
|
||||
|
||||
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
|
||||
<!-- Basic params row -->
|
||||
<div class="flex items-center gap-3 flex-wrap">
|
||||
<!-- Size -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => isSizeDropdownOpen = !isSizeDropdownOpen}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen ? 'border-exo-yellow/70' : ''}"
|
||||
>
|
||||
{params.size}
|
||||
</button>
|
||||
<div class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen ? 'rotate-180' : ''}">
|
||||
<svg class="w-3 h-3 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => isSizeDropdownOpen = false}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition().top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {
|
||||
params.size === size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'
|
||||
}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg class="w-3 h-3 flex-shrink-0" fill="currentColor" viewBox="0 0 20 20">
|
||||
<path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Quality -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>QUALITY:</span
|
||||
>
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={qualityButtonRef}
|
||||
type="button"
|
||||
onclick={() => isQualityDropdownOpen = !isQualityDropdownOpen}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isQualityDropdownOpen ? 'border-exo-yellow/70' : ''}"
|
||||
>
|
||||
{params.quality.toUpperCase()}
|
||||
</button>
|
||||
<div class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isQualityDropdownOpen ? 'rotate-180' : ''}">
|
||||
<svg class="w-3 h-3 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isQualityDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => isQualityDropdownOpen = false}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {qualityDropdownPosition().top}px + 4px); left: {qualityDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each qualityOptions as quality}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectQuality(quality)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {
|
||||
params.quality === quality
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'
|
||||
}"
|
||||
>
|
||||
{#if params.quality === quality}
|
||||
<svg class="w-3 h-3 flex-shrink-0" fill="currentColor" viewBox="0 0 20 20">
|
||||
<path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd" />
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{quality.toUpperCase()}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Input Fidelity (edit mode only) -->
|
||||
{#if isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>FIDELITY:</span
|
||||
>
|
||||
<div class="flex rounded overflow-hidden border border-exo-yellow/30">
|
||||
{#each inputFidelityOptions as fidelity}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => handleInputFidelityChange(fidelity)}
|
||||
class="px-2 py-1 text-xs font-mono uppercase transition-all duration-200 cursor-pointer {
|
||||
params.inputFidelity === fidelity
|
||||
? 'bg-exo-yellow text-exo-black'
|
||||
: 'bg-exo-medium-gray/50 text-exo-light-gray hover:text-exo-yellow'
|
||||
}"
|
||||
title={fidelity === 'low' ? 'More creative variation' : 'Closer to original'}
|
||||
>
|
||||
{fidelity}
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Spacer -->
|
||||
<div class="flex-1"></div>
|
||||
|
||||
<!-- Advanced toggle -->
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => (showAdvanced = !showAdvanced)}
|
||||
class="flex items-center gap-1 text-xs font-mono tracking-wider uppercase transition-colors duration-200 {showAdvanced ||
|
||||
hasAdvancedParams
|
||||
? 'text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
<span>ADVANCED</span>
|
||||
<svg
|
||||
class="w-3 h-3 transition-transform duration-200 {showAdvanced
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
{#if hasAdvancedParams && !showAdvanced}
|
||||
<span class="w-1.5 h-1.5 rounded-full bg-exo-yellow"></span>
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Advanced params section -->
|
||||
{#if showAdvanced}
|
||||
<div class="mt-3 pt-3 border-t border-exo-medium-gray/20 space-y-3">
|
||||
<!-- Row 1: Seed and Steps -->
|
||||
<div class="flex items-center gap-4 flex-wrap">
|
||||
<!-- Seed -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SEED:</span
|
||||
>
|
||||
<input
|
||||
type="number"
|
||||
min="0"
|
||||
value={params.seed ?? ""}
|
||||
oninput={handleSeedChange}
|
||||
placeholder="Random"
|
||||
class="w-24 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow placeholder:text-exo-light-gray/50 transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Steps Slider -->
|
||||
<div class="flex items-center gap-1.5 flex-1 min-w-[200px]">
|
||||
<span
|
||||
class="text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap"
|
||||
>STEPS:</span
|
||||
>
|
<div class="flex items-center gap-2 flex-1">
<input
type="range"
min="1"
max="100"
value={params.numInferenceSteps ?? 50}
oninput={handleStepsChange}
class="flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow"
/>
<span class="text-xs font-mono text-exo-yellow w-8 text-right">
{params.numInferenceSteps ?? "--"}
</span>
{#if params.numInferenceSteps !== null}
<button
type="button"
onclick={clearSteps}
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
title="Clear"
>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
{/if}
</div>
</div>
</div>

<!-- Row 2: Guidance -->
<div class="flex items-center gap-1.5">
<span
class="text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap"
>GUIDANCE:</span
>
<div class="flex items-center gap-2 flex-1 max-w-xs">
<input
type="range"
min="1"
max="20"
step="0.5"
value={params.guidance ?? 7.5}
oninput={handleGuidanceChange}
class="flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow"
/>
<span class="text-xs font-mono text-exo-yellow w-8 text-right">
{params.guidance !== null ? params.guidance.toFixed(1) : "--"}
</span>
{#if params.guidance !== null}
<button
type="button"
onclick={clearGuidance}
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
title="Clear"
>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
{/if}
</div>
</div>

<!-- Row 3: Negative Prompt -->
<div class="flex flex-col gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>NEGATIVE PROMPT:</span
>
<textarea
value={params.negativePrompt ?? ""}
oninput={handleNegativePromptChange}
placeholder="Things to avoid in the image..."
rows={2}
class="w-full bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1.5 text-xs font-mono text-exo-yellow placeholder:text-exo-light-gray/50 resize-none transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
></textarea>
</div>

<!-- Reset Button -->
<div class="flex justify-end pt-1">
<button
type="button"
onclick={handleReset}
class="text-xs font-mono tracking-wider uppercase text-exo-light-gray hover:text-exo-yellow transition-colors duration-200"
>
RESET TO DEFAULTS
</button>
</div>
</div>
{/if}
</div>

<style>
/* Custom range slider styling */
input[type="range"]::-webkit-slider-thumb {
-webkit-appearance: none;
appearance: none;
width: 12px;
height: 12px;
border-radius: 50%;
background: #ffd700;
cursor: pointer;
border: none;
}

input[type="range"]::-moz-range-thumb {
width: 12px;
height: 12px;
border-radius: 50%;
background: #ffd700;
cursor: pointer;
border: none;
}

/* Hide number input spinners */
input[type="number"]::-webkit-inner-spin-button,
input[type="number"]::-webkit-outer-spin-button {
-webkit-appearance: none;
margin: 0;
}

input[type="number"] {
-moz-appearance: textfield;
}
</style>

@@ -5,3 +5,4 @@ export { default as ChatAttachments } from "./ChatAttachments.svelte";
export { default as ChatSidebar } from "./ChatSidebar.svelte";
export { default as ModelCard } from "./ModelCard.svelte";
export { default as MarkdownContent } from "./MarkdownContent.svelte";
export { default as ImageParamsPanel } from "./ImageParamsPanel.svelte";

File diff suppressed because it is too large
@@ -47,7 +47,30 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);

// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number, tasks?: string[], hugging_face_id?: string}>>([]);

// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
const tasks: Record<string, string[]> = {};
for (const model of models) {
if (model.tasks && model.tasks.length > 0) {
// Map by short ID
tasks[model.id] = model.tasks;
// Also map by hugging_face_id from the API response
if (model.hugging_face_id) {
tasks[model.hugging_face_id] = model.tasks;
}
}
}
return tasks;
});

// Helper to check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const model = models.find(m => m.id === modelId || m.hugging_face_id === modelId);
if (!model?.tasks) return false;
return model.tasks.includes('TextToImage') || model.tasks.includes('ImageToImage');
}
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';

@@ -1270,6 +1293,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
placeholder="Ask anything"
showHelperText={false}
showModelSelector={true}
modelTasks={modelTasks()}
/>
</div>
</div>
@@ -1491,8 +1515,18 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{@const foundModel = models.find(m => m.id === selectedModelId)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="text-exo-light-gray truncate">{foundModel.name || foundModel.id}</span>
{@const isImageModel = modelSupportsImageGeneration(foundModel.id)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="flex items-center gap-2 text-exo-light-gray truncate">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{foundModel.name || foundModel.id}</span>
</span>
<span class="text-white/50 text-xs flex-shrink-0">{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB</span>
</span>
{:else}
@@ -1537,6 +1571,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(model.id)}
<button
type="button"
onclick={() => {
@@ -1556,7 +1591,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
: 'text-white/30 cursor-default'
}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span class="flex-shrink-0 text-xs {modelCanFit ? 'text-white/50' : 'text-red-400/60'}">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
@@ -1753,7 +1797,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-7xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} />
<ChatForm placeholder="Ask anything" showModelSelector={true} modelTasks={modelTasks()} />
</div>
</div>
</div>

docs/api.md
@@ -1,6 +1,6 @@
# EXO API – Technical Reference

This document describes the REST API exposed by the **EXO ** service, as implemented in:
This document describes the REST API exposed by the **EXO** service, as implemented in:

`src/exo/master/api.py`

@@ -183,7 +183,70 @@ Same schema as `/v1/chat/completions`.
**Response:**
Chat completion plus benchmarking metrics.

## 5. Complete Endpoint Summary
## 5. Image Generation & Editing

### Image Generation

**POST** `/v1/images/generations`

Executes an image generation request using an OpenAI-compatible schema with additional `advanced_params`.

**Request body (example):**

```json
{
  "prompt": "a robot playing chess",
  "model": "flux-dev",
  "stream": false
}
```

**Advanced Parameters (`advanced_params`):**

| Parameter | Type | Constraints | Description |
|-----------|------|-------------|-------------|
| `seed` | int | >= 0 | Random seed for reproducible generation |
| `num_inference_steps` | int | 1-100 | Number of denoising steps |
| `guidance` | float | 1.0-20.0 | Classifier-free guidance scale |
| `negative_prompt` | string | - | Text describing what to avoid in the image |

**Response:**
OpenAI-compatible image generation response.

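A minimal client sketch for the request above (not part of the original docs): the base URL/port and the `flux1-dev` model key are assumptions, so substitute your EXO master address and an image model you actually have installed.

```python
# Hedged sketch: generate an image via /v1/images/generations with advanced_params.
import base64
import httpx

BASE_URL = "http://localhost:8000"  # assumption: replace with your EXO master address

payload = {
    "prompt": "a robot playing chess",
    "model": "flux1-dev",  # assumption: any installed image-generation model card
    "n": 1,
    "size": "1024x1024",
    "response_format": "b64_json",
    "advanced_params": {
        "seed": 42,
        "num_inference_steps": 28,
        "guidance": 3.5,
        "negative_prompt": "blurry, low quality",
    },
}

resp = httpx.post(f"{BASE_URL}/v1/images/generations", json=payload, timeout=None)
resp.raise_for_status()

# With response_format="b64_json", each entry in "data" carries the encoded image.
for i, item in enumerate(resp.json()["data"]):
    with open(f"image_{i}.png", "wb") as f:
        f.write(base64.b64decode(item["b64_json"]))
```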
### Benchmarked Image Generation

**POST** `/bench/images/generations`

Same as `/v1/images/generations`, but also returns generation statistics.

**Request body:**
Same schema as `/v1/images/generations`.

**Response:**
Image generation plus benchmarking metrics.

### Image Editing

**POST** `/v1/images/edits`

Executes an image editing request using an OpenAI-compatible schema with additional `advanced_params` (same as `/v1/images/generations`).

**Response:**
Same format as `/v1/images/generations`.

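For reference, the `/v1/images/edits` handler shown later in this diff accepts multipart form data (fields `image`, `prompt`, `model`, `n`, `size`, `response_format`, `input_fidelity`, `stream`, `partial_images`). A hedged request sketch; the base URL, file name, and model key are assumptions:

```python
# Hedged sketch of an image-edit request against /v1/images/edits (multipart form).
import base64
import httpx

BASE_URL = "http://localhost:8000"  # assumption

with open("input.png", "rb") as f:
    files = {"image": ("input.png", f, "image/png")}
    form = {
        "prompt": "make the robot wear a top hat",
        "model": "qwen-image-edit-2509",  # assumption: an installed image-editing model
        "n": "1",
        "size": "1024x1024",
        "response_format": "b64_json",
        "input_fidelity": "high",  # maps to a higher image_strength in the handler
    }
    resp = httpx.post(f"{BASE_URL}/v1/images/edits", data=form, files=files, timeout=None)

resp.raise_for_status()
with open("edited.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["data"][0]["b64_json"]))
```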
### Benchmarked Image Editing

**POST** `/bench/images/edits`

Same as `/v1/images/edits`, but also returns generation statistics.

**Request:**
Same schema as `/v1/images/edits`.

**Response:**
Same format as `/bench/images/generations`, including `generation_stats`.

## 6. Complete Endpoint Summary

```
GET /node_id
@@ -203,10 +266,16 @@ GET /v1/models

POST /v1/chat/completions
POST /bench/chat/completions

POST /v1/images/generations
POST /bench/images/generations
POST /v1/images/edits
POST /bench/images/edits
```

## 6. Notes
## 7. Notes

* The `/v1/chat/completions` endpoint is compatible with the OpenAI API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
* The `/v1/chat/completions` endpoint is compatible with the OpenAI Chat API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
* The `/v1/images/generations` and `/v1/images/edits` endpoints are compatible with the OpenAI Images API format.
* The instance placement endpoints allow you to plan and preview cluster allocations before actually creating instances.
* The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.

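A hedged sketch of the note above about reusing OpenAI clients; the host, port, and model names are placeholders for illustration only:

```python
# Point the official OpenAI Python client at EXO by changing the base URL (sketch).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")  # assumptions

chat = client.chat.completions.create(
    model="llama-3.2-3b",  # assumption: any installed text model card
    messages=[{"role": "user", "content": "Hello from EXO"}],
)
print(chat.choices[0].message.content)

image = client.images.generate(model="flux1-schnell", prompt="a robot playing chess")
print(image.data[0].b64_json is not None)  # EXO defaults to b64_json responses
```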
@@ -24,6 +24,9 @@ dependencies = [
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux>=0.14.2",
"python-multipart>=0.0.21",
]

[project.scripts]

@@ -1,12 +1,14 @@
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from typing import cast
|
||||
from typing import Literal, cast
|
||||
|
||||
import anyio
|
||||
from anyio import BrokenResourceError, create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@@ -17,6 +19,7 @@ from loguru import logger
|
||||
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
@@ -24,6 +27,8 @@ from exo.shared.models.model_meta import get_model_card
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionResponse,
|
||||
BenchChatCompletionTaskParams,
|
||||
BenchImageGenerationResponse,
|
||||
BenchImageGenerationTaskParams,
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
@@ -34,6 +39,11 @@ from exo.shared.types.api import (
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
ImageData,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationResponse,
|
||||
ImageGenerationStats,
|
||||
ImageGenerationTaskParams,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -41,14 +51,17 @@ from exo.shared.types.api import (
|
||||
PlacementPreviewResponse,
|
||||
StreamingChoiceResponse,
|
||||
)
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
@@ -88,10 +101,13 @@ def chunk_to_response(

async def resolve_model_card(model_id: str) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
else:
return await get_model_card(model_id)
return MODEL_CARDS[model_id]

for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card

return await get_model_card(model_id)

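A hedged usage sketch of the resolution order introduced above (short card key first, then full Hugging Face id, then remote lookup); the import path and card keys are taken from elsewhere in this diff and may differ in your checkout:

```python
# Hedged sketch: both the short key and the matching HF id resolve to the same card.
import anyio

from exo.master.api import resolve_model_card  # assumption: module path per docs/api.md

async def demo() -> None:
    by_key = await resolve_model_card("flux1-dev")
    by_hf_id = await resolve_model_card("black-forest-labs/FLUX.1-dev")
    assert by_key.model_id == by_hf_id.model_id
    # Anything else falls through to get_model_card(), which builds a card dynamically.

anyio.run(demo)
```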
class API:
|
||||
@@ -136,6 +152,7 @@ class API:
|
||||
)
|
||||
|
||||
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
|
||||
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
def reset(self, new_session_id: SessionId, result_clock: int):
|
||||
@@ -144,6 +161,7 @@ class API:
|
||||
self.session_id = new_session_id
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self._chat_completion_queues = {}
|
||||
self._image_generation_queues = {}
|
||||
self.unpause(result_clock)
|
||||
|
||||
def unpause(self, result_clock: int):
|
||||
@@ -191,6 +209,12 @@ class API:
|
||||
self.chat_completions
|
||||
)
|
||||
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
|
||||
self.app.post("/v1/images/generations", response_model=None)(
|
||||
self.image_generations
|
||||
)
|
||||
self.app.post("/bench/images/generations")(self.bench_image_generations)
|
||||
self.app.post("/v1/images/edits", response_model=None)(self.image_edits)
|
||||
self.app.post("/bench/images/edits")(self.bench_image_edits)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
|
||||
@@ -598,6 +622,379 @@ class API:
|
||||
response = await self._collect_chat_completion_with_stats(command.command_id)
|
||||
return response
|
||||
|
||||
async def _validate_image_model(self, model: str) -> ModelId:
|
||||
"""Validate model exists and return resolved model ID.
|
||||
|
||||
Raises HTTPException 404 if no instance is found for the model.
|
||||
"""
|
||||
model_card = await resolve_model_card(model)
|
||||
resolved_model = model_card.model_id
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == resolved_model
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(resolved_model)
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"No instance found for model {resolved_model}"
|
||||
)
|
||||
return resolved_model
|
||||
|
||||
async def image_generations(
|
||||
self, payload: ImageGenerationTaskParams
|
||||
) -> ImageGenerationResponse | StreamingResponse:
|
||||
"""Handle image generation requests.
|
||||
|
||||
When stream=True and partial_images > 0, returns a StreamingResponse
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
payload.model = await self._validate_image_model(payload.model)
|
||||
|
||||
command = ImageGeneration(
|
||||
request_params=payload,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
# Check if streaming is requested
|
||||
if payload.stream and payload.partial_images and payload.partial_images > 0:
|
||||
return StreamingResponse(
|
||||
self._generate_image_stream(
|
||||
command_id=command.command_id,
|
||||
num_images=payload.n or 1,
|
||||
response_format=payload.response_format or "b64_json",
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
# Non-streaming: collect all image chunks
|
||||
return await self._collect_image_generation(
|
||||
command_id=command.command_id,
|
||||
num_images=payload.n or 1,
|
||||
response_format=payload.response_format or "b64_json",
|
||||
)
|
||||
|
||||
async def _generate_image_stream(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
num_images: int,
|
||||
response_format: str,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate SSE stream of partial and final images."""
|
||||
# Track chunks: {(image_index, is_partial): {chunk_index: data}}
|
||||
image_chunks: dict[tuple[int, bool], dict[int, str]] = {}
|
||||
image_total_chunks: dict[tuple[int, bool], int] = {}
|
||||
image_metadata: dict[tuple[int, bool], tuple[int | None, int | None]] = {}
|
||||
images_complete = 0
|
||||
|
||||
try:
|
||||
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
|
||||
|
||||
with recv as chunks:
|
||||
async for chunk in chunks:
|
||||
key = (chunk.image_index, chunk.is_partial)
|
||||
|
||||
if key not in image_chunks:
|
||||
image_chunks[key] = {}
|
||||
image_total_chunks[key] = chunk.total_chunks
|
||||
image_metadata[key] = (
|
||||
chunk.partial_index,
|
||||
chunk.total_partials,
|
||||
)
|
||||
|
||||
image_chunks[key][chunk.chunk_index] = chunk.data
|
||||
|
||||
# Check if this image is complete
|
||||
if len(image_chunks[key]) == image_total_chunks[key]:
|
||||
full_data = "".join(
|
||||
image_chunks[key][i] for i in range(len(image_chunks[key]))
|
||||
)
|
||||
|
||||
partial_idx, total_partials = image_metadata[key]
|
||||
|
||||
if chunk.is_partial:
|
||||
# Yield partial image event
|
||||
event_data = {
|
||||
"type": "partial",
|
||||
"partial_index": partial_idx,
|
||||
"total_partials": total_partials,
|
||||
"data": {
|
||||
"b64_json": full_data
|
||||
if response_format == "b64_json"
|
||||
else None,
|
||||
},
|
||||
}
|
||||
yield f"data: {json.dumps(event_data)}\n\n"
|
||||
else:
|
||||
# Final image
|
||||
event_data = {
|
||||
"type": "final",
|
||||
"image_index": chunk.image_index,
|
||||
"data": {
|
||||
"b64_json": full_data
|
||||
if response_format == "b64_json"
|
||||
else None,
|
||||
},
|
||||
}
|
||||
yield f"data: {json.dumps(event_data)}\n\n"
|
||||
images_complete += 1
|
||||
|
||||
if images_complete >= num_images:
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
|
||||
# Clean up completed image chunks
|
||||
del image_chunks[key]
|
||||
del image_total_chunks[key]
|
||||
del image_metadata[key]
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
if command_id in self._image_generation_queues:
|
||||
del self._image_generation_queues[command_id]
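For context, a hedged client-side sketch of consuming the SSE stream produced by this generator; the base URL and model are assumptions, and the event shapes mirror the `"partial"`, `"final"`, and `[DONE]` payloads yielded above:

```python
# Hedged sketch: stream partial and final images from /v1/images/generations.
import json
import httpx

payload = {
    "prompt": "a robot playing chess",
    "model": "flux1-dev",  # assumption
    "stream": True,
    "partial_images": 3,
}

with httpx.stream(
    "POST", "http://localhost:8000/v1/images/generations", json=payload, timeout=None
) as r:
    for line in r.iter_lines():
        if not line.startswith("data: "):
            continue
        body = line[len("data: "):]
        if body == "[DONE]":
            break
        event = json.loads(body)
        if event["type"] == "partial":
            print(f"partial {event['partial_index']}/{event['total_partials']}")
        else:  # "final"
            print(f"final image {event['image_index']} received")
```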
|
||||
|
||||
async def _collect_image_chunks(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
num_images: int,
|
||||
response_format: str,
|
||||
capture_stats: bool = False,
|
||||
) -> tuple[list[ImageData], ImageGenerationStats | None]:
|
||||
"""Collect image chunks and optionally capture stats."""
|
||||
# Track chunks per image: {image_index: {chunk_index: data}}
|
||||
# Only track non-partial (final) images
|
||||
image_chunks: dict[int, dict[int, str]] = {}
|
||||
image_total_chunks: dict[int, int] = {}
|
||||
images_complete = 0
|
||||
stats: ImageGenerationStats | None = None
|
||||
|
||||
try:
|
||||
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
|
||||
|
||||
while images_complete < num_images:
|
||||
with recv as chunks:
|
||||
async for chunk in chunks:
|
||||
if chunk.is_partial:
|
||||
continue
|
||||
|
||||
if chunk.image_index not in image_chunks:
|
||||
image_chunks[chunk.image_index] = {}
|
||||
image_total_chunks[chunk.image_index] = chunk.total_chunks
|
||||
|
||||
image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
|
||||
|
||||
if capture_stats and chunk.stats is not None:
|
||||
stats = chunk.stats
|
||||
|
||||
if (
|
||||
len(image_chunks[chunk.image_index])
|
||||
== image_total_chunks[chunk.image_index]
|
||||
):
|
||||
images_complete += 1
|
||||
|
||||
if images_complete >= num_images:
|
||||
break
|
||||
|
||||
images: list[ImageData] = []
|
||||
for image_idx in range(num_images):
|
||||
chunks_dict = image_chunks[image_idx]
|
||||
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
|
||||
images.append(
|
||||
ImageData(
|
||||
b64_json=full_data if response_format == "b64_json" else None,
|
||||
url=None,
|
||||
)
|
||||
)
|
||||
|
||||
return (images, stats if capture_stats else None)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
if command_id in self._image_generation_queues:
|
||||
del self._image_generation_queues[command_id]
|
||||
|
||||
async def _collect_image_generation(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
num_images: int,
|
||||
response_format: str,
|
||||
) -> ImageGenerationResponse:
|
||||
"""Collect all image chunks (non-streaming) and return a single response."""
|
||||
images, _ = await self._collect_image_chunks(
|
||||
command_id, num_images, response_format, capture_stats=False
|
||||
)
|
||||
return ImageGenerationResponse(data=images)
|
||||
|
||||
async def _collect_image_generation_with_stats(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
num_images: int,
|
||||
response_format: str,
|
||||
) -> BenchImageGenerationResponse:
|
||||
images, stats = await self._collect_image_chunks(
|
||||
command_id, num_images, response_format, capture_stats=True
|
||||
)
|
||||
return BenchImageGenerationResponse(data=images, generation_stats=stats)
|
||||
|
||||
async def bench_image_generations(
|
||||
self, payload: BenchImageGenerationTaskParams
|
||||
) -> BenchImageGenerationResponse:
|
||||
payload.model = await self._validate_image_model(payload.model)
|
||||
|
||||
payload.stream = False
|
||||
payload.partial_images = 0
|
||||
|
||||
command = ImageGeneration(
|
||||
request_params=payload,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
return await self._collect_image_generation_with_stats(
|
||||
command_id=command.command_id,
|
||||
num_images=payload.n or 1,
|
||||
response_format=payload.response_format or "b64_json",
|
||||
)
|
||||
|
||||
async def _send_image_edits_command(
|
||||
self,
|
||||
image: UploadFile,
|
||||
prompt: str,
|
||||
model: str,
|
||||
n: int,
|
||||
size: str,
|
||||
response_format: Literal["url", "b64_json"],
|
||||
input_fidelity: Literal["low", "high"],
|
||||
stream: bool,
|
||||
partial_images: int,
|
||||
bench: bool,
|
||||
) -> ImageEdits:
|
||||
"""Prepare and send an image edits command with chunked image upload."""
|
||||
resolved_model = await self._validate_image_model(model)
|
||||
|
||||
image_content = await image.read()
|
||||
image_data = base64.b64encode(image_content).decode("utf-8")
|
||||
|
||||
image_strength = 0.7 if input_fidelity == "high" else 0.3
|
||||
|
||||
data_chunks = [
|
||||
image_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
|
||||
command = ImageEdits(
|
||||
request_params=ImageEditsInternalParams(
|
||||
image_data="",
|
||||
total_input_chunks=total_chunks,
|
||||
prompt=prompt,
|
||||
model=resolved_model,
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
image_strength=image_strength,
|
||||
stream=stream,
|
||||
partial_images=partial_images,
|
||||
bench=bench,
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Sending input image: {len(image_data)} bytes in {total_chunks} chunks"
|
||||
)
|
||||
for chunk_index, chunk_data in enumerate(data_chunks):
|
||||
await self._send(
|
||||
SendInputChunk(
|
||||
chunk=InputImageChunk(
|
||||
idx=chunk_index,
|
||||
model=resolved_model,
|
||||
command_id=command.command_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
await self._send(command)
|
||||
return command
|
||||
|
||||
async def image_edits(
|
||||
self,
|
||||
image: UploadFile = File(...),
|
||||
prompt: str = Form(...),
|
||||
model: str = Form(...),
|
||||
n: int = Form(1),
|
||||
size: str = Form("1024x1024"),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
stream: str = Form("false"),
|
||||
partial_images: str = Form("0"),
|
||||
) -> ImageGenerationResponse | StreamingResponse:
|
||||
"""Handle image editing requests (img2img)."""
|
||||
# Parse string form values to proper types
|
||||
stream_bool = stream.lower() in ("true", "1", "yes")
|
||||
partial_images_int = int(partial_images) if partial_images.isdigit() else 0
|
||||
|
||||
command = await self._send_image_edits_command(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
input_fidelity=input_fidelity,
|
||||
stream=stream_bool,
|
||||
partial_images=partial_images_int,
|
||||
bench=False,
|
||||
)
|
||||
|
||||
if stream_bool and partial_images_int > 0:
|
||||
return StreamingResponse(
|
||||
self._generate_image_stream(
|
||||
command_id=command.command_id,
|
||||
num_images=n,
|
||||
response_format=response_format,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
return await self._collect_image_generation(
|
||||
command_id=command.command_id,
|
||||
num_images=n,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
async def bench_image_edits(
|
||||
self,
|
||||
image: UploadFile = File(...),
|
||||
prompt: str = Form(...),
|
||||
model: str = Form(...),
|
||||
n: int = Form(1),
|
||||
size: str = Form("1024x1024"),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
) -> BenchImageGenerationResponse:
|
||||
"""Handle benchmark image editing requests with generation stats."""
|
||||
command = await self._send_image_edits_command(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
input_fidelity=input_fidelity,
|
||||
stream=False,
|
||||
partial_images=0,
|
||||
bench=True,
|
||||
)
|
||||
|
||||
return await self._collect_image_generation_with_stats(
|
||||
command_id=command.command_id,
|
||||
num_images=n,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
total_available = Memory()
|
||||
@@ -619,6 +1016,7 @@ class API:
|
||||
tags=[],
|
||||
storage_size_megabytes=int(card.storage_size.in_mb),
|
||||
supports_tensor=card.supports_tensor,
|
||||
tasks=[task.value for task in card.tasks],
|
||||
)
|
||||
for card in MODEL_CARDS.values()
|
||||
]
|
||||
@@ -657,13 +1055,26 @@ class API:
|
||||
self._event_log.append(event)
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
if isinstance(event, ChunkGenerated):
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
queue = self._chat_completion_queues.get(event.command_id)
|
||||
if queue is not None:
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._chat_completion_queues.pop(event.command_id, None)
|
||||
if event.command_id in self._chat_completion_queues:
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
queue = self._chat_completion_queues.get(event.command_id)
|
||||
if queue is not None:
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._chat_completion_queues.pop(
|
||||
event.command_id, None
|
||||
)
|
||||
elif event.command_id in self._image_generation_queues:
|
||||
assert isinstance(event.chunk, ImageChunk)
|
||||
queue = self._image_generation_queues.get(event.command_id)
|
||||
if queue is not None:
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._image_generation_queues.pop(
|
||||
event.command_id, None
|
||||
)
|
||||
|
||||
async def _pause_on_new_election(self):
|
||||
with self.election_receiver as ems:
|
||||
|
||||
@@ -16,8 +16,11 @@ from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
SendInputChunk,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
)
|
||||
@@ -26,6 +29,7 @@ from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
InstanceDeleted,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
@@ -36,6 +40,12 @@ from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion as ChatCompletionTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ImageEdits as ImageEditsTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ImageGeneration as ImageGenerationTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
@@ -100,13 +110,14 @@ class Master:
|
||||
async for forwarder_command in commands:
|
||||
try:
|
||||
logger.info(f"Executing command: {forwarder_command.command}")
|
||||
|
||||
generated_events: list[Event] = []
|
||||
command = forwarder_command.command
|
||||
instance_task_counts: dict[InstanceId, int] = {}
|
||||
match command:
|
||||
case TestCommand():
|
||||
pass
|
||||
case ChatCompletion():
|
||||
instance_task_counts: dict[InstanceId, int] = {}
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
@@ -147,6 +158,90 @@ class Master:
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case ImageGeneration():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
for task in self.state.tasks.values()
|
||||
if task.instance_id == instance.instance_id
|
||||
)
|
||||
instance_task_counts[instance.instance_id] = (
|
||||
task_count
|
||||
)
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
instance_task_counts.keys(),
|
||||
key=lambda instance_id: instance_task_counts[
|
||||
instance_id
|
||||
],
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageGenerationTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case ImageEdits():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
for task in self.state.tasks.values()
|
||||
if task.instance_id == instance.instance_id
|
||||
)
|
||||
instance_task_counts[instance.instance_id] = (
|
||||
task_count
|
||||
)
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
instance_task_counts.keys(),
|
||||
key=lambda instance_id: instance_task_counts[
|
||||
instance_id
|
||||
],
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageEditsTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
@@ -176,6 +271,13 @@ class Master:
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
generated_events.append(
|
||||
InputChunkReceived(
|
||||
command_id=chunk.command_id,
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
|
||||
@@ -7,7 +7,7 @@ from loguru import logger
|
||||
|
||||
from exo.master.main import Master
|
||||
from exo.routing.router import get_node_id_keypair
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
@@ -115,6 +115,7 @@ async def test_master():
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
@@ -172,6 +173,7 @@ async def test_master():
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
|
||||
@@ -10,7 +10,7 @@ from exo.master.tests.conftest import (
|
||||
create_rdma_connection,
|
||||
create_socket_connection,
|
||||
)
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import PlaceInstance
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
@@ -50,6 +50,7 @@ def model_card() -> ModelCard:
|
||||
n_layers=10,
|
||||
hidden_size=30,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
|
||||
|
||||
@@ -169,6 +170,7 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
|
||||
n_layers=10,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
)
|
||||
placements = place_instance(cic, topology, {}, node_memory, node_network)
|
||||
@@ -195,6 +197,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
|
||||
n_layers=10,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
)
|
||||
placements = place_instance(cic, topology, {}, node_memory, node_network)
|
||||
@@ -221,6 +224,7 @@ def test_get_instance_placements_one_node_not_fit() -> None:
|
||||
n_layers=10,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from exo.master.tests.conftest import (
|
||||
create_node_memory,
|
||||
create_socket_connection,
|
||||
)
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -238,6 +238,7 @@ def test_get_shard_assignments(
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
|
||||
cycles = topology.get_cycles()
|
||||
@@ -517,6 +518,7 @@ def test_get_shard_assignments_insufficient_memory_raises():
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
)
|
||||
cycles = topology.get_cycles()
|
||||
selected_cycle = cycles[0]
|
||||
|
||||
@@ -9,6 +9,7 @@ from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
NodeDownloadProgress,
|
||||
@@ -52,8 +53,8 @@ def event_apply(event: Event, state: State) -> State:
|
||||
"""Apply an event to state."""
|
||||
match event:
|
||||
case (
|
||||
TestEvent() | ChunkGenerated() | TaskAcknowledged()
|
||||
): # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored
|
||||
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
|
||||
): # Pass-through events that don't modify state
|
||||
return state
|
||||
case InstanceCreated():
|
||||
return apply_instance_created(event, state)
|
||||
|
||||
@@ -45,3 +45,5 @@ LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
|
||||
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
|
||||
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
|
||||
LIBP2P_COMMANDS_TOPIC = "commands"
|
||||
|
||||
EXO_MAX_CHUNK_SIZE = 512 * 1024
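The new constant caps how much base64 text is forwarded per input chunk; `_send_image_edits_command` (earlier in this diff) slices the encoded upload into pieces of at most this size. A quick sketch of the resulting chunk count:

```python
# Sketch of the chunk count implied by EXO_MAX_CHUNK_SIZE for a base64-encoded upload.
import math

EXO_MAX_CHUNK_SIZE = 512 * 1024  # characters of base64 text per chunk, as defined above

def chunks_needed(raw_bytes: int) -> int:
    b64_len = math.ceil(raw_bytes / 3) * 4  # base64 expands the payload by ~4/3
    return max(1, math.ceil(b64_len / EXO_MAX_CHUNK_SIZE))

print(chunks_needed(2_500_000))  # a ~2.5 MB image -> 7 chunks
```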
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from pydantic import PositiveInt
|
||||
from enum import Enum
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BeforeValidator, PositiveInt
|
||||
|
||||
from exo.shared.types.common import Id
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -13,12 +16,38 @@ class ModelId(Id):
|
||||
return self.split("/")[-1]
|
||||
|
||||
|
||||
class ModelTask(str, Enum):
|
||||
TextGeneration = "TextGeneration"
|
||||
TextToImage = "TextToImage"
|
||||
ImageToImage = "ImageToImage"
|
||||
|
||||
|
||||
def _coerce_model_task(v: object) -> object:
|
||||
if isinstance(v, str):
|
||||
return ModelTask(v)
|
||||
return v
|
||||
|
||||
|
||||
CoercedModelTask = Annotated[ModelTask, BeforeValidator(_coerce_model_task)]
|
||||
|
||||
|
||||
class ComponentInfo(CamelCaseModel):
|
||||
component_name: str
|
||||
component_path: str
|
||||
storage_size: Memory
|
||||
n_layers: PositiveInt | None
|
||||
can_shard: bool
|
||||
safetensors_index_filename: str | None
|
||||
|
||||
|
||||
class ModelCard(CamelCaseModel):
|
||||
model_id: ModelId
|
||||
storage_size: Memory
|
||||
n_layers: PositiveInt
|
||||
hidden_size: PositiveInt
|
||||
supports_tensor: bool
|
||||
tasks: list[CoercedModelTask]
|
||||
components: list[ComponentInfo] | None = None
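A minimal standalone sketch of the `CoercedModelTask` pattern above; the `Task`/`Card` names are illustrative only, and the point is simply that plain strings (e.g. from a serialized card or API response) validate back into the enum:

```python
# Standalone sketch mirroring the coercion pattern defined above (pydantic v2).
from enum import Enum
from typing import Annotated

from pydantic import BaseModel, BeforeValidator

class Task(str, Enum):
    TextGeneration = "TextGeneration"
    TextToImage = "TextToImage"

CoercedTask = Annotated[Task, BeforeValidator(lambda v: Task(v) if isinstance(v, str) else v)]

class Card(BaseModel):
    tasks: list[CoercedTask]

print(Card(tasks=["TextToImage", Task.TextGeneration]).tasks)
# -> [<Task.TextToImage: 'TextToImage'>, <Task.TextGeneration: 'TextGeneration'>]
```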
|
||||
|
||||
|
||||
MODEL_CARDS: dict[str, ModelCard] = {
|
||||
@@ -29,6 +58,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"deepseek-v3.1-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
|
||||
@@ -36,6 +66,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# kimi k2
|
||||
"kimi-k2-instruct-4bit": ModelCard(
|
||||
@@ -44,6 +75,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"kimi-k2-thinking": ModelCard(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
|
||||
@@ -51,6 +83,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# llama-3.1
|
||||
"llama-3.1-8b": ModelCard(
|
||||
@@ -59,6 +92,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.1-8b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
@@ -66,6 +100,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.1-8b-bf16": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
@@ -73,6 +108,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.1-70b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
|
||||
@@ -80,6 +116,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# llama-3.2
|
||||
"llama-3.2-1b": ModelCard(
|
||||
@@ -88,6 +125,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=16,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.2-3b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
|
||||
@@ -95,6 +133,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.2-3b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
|
||||
@@ -102,6 +141,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# llama-3.3
|
||||
"llama-3.3-70b": ModelCard(
|
||||
@@ -110,6 +150,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.3-70b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
|
||||
@@ -117,6 +158,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"llama-3.3-70b-fp16": ModelCard(
|
||||
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
|
||||
@@ -124,6 +166,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# qwen3
|
||||
"qwen3-0.6b": ModelCard(
|
||||
@@ -132,6 +175,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-0.6b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
|
||||
@@ -139,6 +183,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-30b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
|
||||
@@ -146,6 +191,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-30b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
|
||||
@@ -153,6 +199,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-80b-a3B-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
@@ -160,6 +207,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-80b-a3B-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
@@ -167,6 +215,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
@@ -174,6 +223,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
@@ -181,6 +231,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-235b-a22b-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
|
||||
@@ -188,6 +239,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-235b-a22b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
|
||||
@@ -195,6 +247,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-coder-480b-a35b-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
|
||||
@@ -202,6 +255,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"qwen3-coder-480b-a35b-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
|
||||
@@ -209,6 +263,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# gpt-oss
|
||||
"gpt-oss-120b-MXFP4-Q8": ModelCard(
|
||||
@@ -217,6 +272,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=36,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"gpt-oss-20b-MXFP4-Q8": ModelCard(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
@@ -224,6 +280,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# glm 4.5
|
||||
"glm-4.5-air-8bit": ModelCard(
|
||||
@@ -233,6 +290,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.5-air-bf16": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
@@ -240,6 +298,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# glm 4.7
|
||||
"glm-4.7-4bit": ModelCard(
|
||||
@@ -248,6 +307,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-6bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
|
||||
@@ -255,6 +315,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-8bit-gs32": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
|
||||
@@ -262,6 +323,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# glm 4.7 flash
|
||||
"glm-4.7-flash-4bit": ModelCard(
|
||||
@@ -270,6 +332,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=47,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-flash-5bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
|
||||
@@ -277,6 +340,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=47,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-flash-6bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
|
||||
@@ -284,6 +348,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=47,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"glm-4.7-flash-8bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
|
||||
@@ -291,6 +356,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=47,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# minimax-m2
|
||||
"minimax-m2.1-8bit": ModelCard(
|
||||
@@ -299,6 +365,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=61,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"minimax-m2.1-3bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
|
||||
@@ -306,5 +373,158 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
n_layers=61,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"flux1-schnell": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId, ModelTask
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.worker.download.download_utils import (
|
||||
ModelSafetensorsIndex,
|
||||
@@ -119,4 +119,7 @@ async def _get_model_card(model_id: str) -> ModelCard:
|
||||
hidden_size=config_data.hidden_size or 0,
|
||||
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
|
||||
supports_tensor=model_card.supports_tensor if model_card is not None else False,
|
||||
tasks=model_card.tasks
|
||||
if model_card is not None
|
||||
else [ModelTask.TextGeneration],
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
from _pytest.logging import LogCaptureFixture
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
|
||||
|
||||
@@ -37,6 +37,7 @@ def get_pipeline_shard_metadata(
|
||||
n_layers=32,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
device_rank=device_rank,
|
||||
world_size=world_size,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import time
|
||||
from typing import Any, Literal
|
||||
from collections.abc import Generator
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
|
||||
@@ -39,6 +41,7 @@ class ModelListModel(BaseModel):
|
||||
tags: list[str] = Field(default=[])
|
||||
storage_size_megabytes: int = Field(default=0)
|
||||
supports_tensor: bool = Field(default=False)
|
||||
tasks: list[str] = Field(default=[])
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
@@ -137,6 +140,19 @@ class GenerationStats(BaseModel):
|
||||
peak_memory_usage: Memory
|
||||
|
||||
|
||||
class ImageGenerationStats(BaseModel):
|
||||
seconds_per_step: float
|
||||
total_generation_time: float
|
||||
|
||||
num_inference_steps: int
|
||||
num_images: int
|
||||
|
||||
image_width: int
|
||||
image_height: int
|
||||
|
||||
peak_memory_usage: Memory
|
||||
|
||||
|
||||
class BenchChatCompletionResponse(ChatCompletionResponse):
|
||||
generation_stats: GenerationStats | None = None
|
||||
|
||||
@@ -213,3 +229,103 @@ class DeleteInstanceResponse(BaseModel):
|
||||
message: str
|
||||
command_id: CommandId
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class AdvancedImageParams(BaseModel):
|
||||
seed: Annotated[int, Field(ge=0)] | None = None
|
||||
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
|
||||
guidance: Annotated[float, Field(ge=1.0, le=20.0)] | None = None
|
||||
negative_prompt: str | None = None
|
||||
|
||||
|
||||
class ImageGenerationTaskParams(BaseModel):
|
||||
prompt: str
|
||||
background: str | None = None
|
||||
model: str
|
||||
moderation: str | None = None
|
||||
n: int | None = 1
|
||||
output_compression: int | None = None
|
||||
output_format: Literal["png", "jpeg", "webp"] = "png"
|
||||
partial_images: int | None = 0
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: str | None = "1024x1024"
|
||||
stream: bool | None = False
|
||||
style: str | None = "vivid"
|
||||
user: str | None = None
|
||||
advanced_params: AdvancedImageParams | None = None
|
||||
# Internal flag for benchmark mode - set by API, preserved through serialization
|
||||
bench: bool = False
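For reference, a minimal request body against this schema might look like the sketch below (the model id and field values are illustrative, not API defaults):
params = ImageGenerationTaskParams(
    prompt="A watercolor lighthouse at dusk",
    model="Qwen/Qwen-Image",
    n=1,
    size="1024x1024",
    quality="high",
    response_format="b64_json",
    advanced_params=AdvancedImageParams(seed=42, num_inference_steps=30),
)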
|
||||
|
||||
|
||||
class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
|
||||
bench: bool = True
|
||||
|
||||
|
||||
class ImageEditsTaskParams(BaseModel):
|
||||
image: UploadFile
|
||||
prompt: str
|
||||
background: str | None = None
|
||||
input_fidelity: float | None = None
|
||||
mask: UploadFile | None = None
|
||||
model: str
|
||||
n: int | None = 1
|
||||
output_compression: int | None = None
|
||||
output_format: Literal["png", "jpeg", "webp"] = "png"
|
||||
partial_images: int | None = 0
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: str | None = "1024x1024"
|
||||
stream: bool | None = False
|
||||
user: str | None = None
|
||||
advanced_params: AdvancedImageParams | None = None
|
||||
# Internal flag for benchmark mode - set by API, preserved through serialization
|
||||
bench: bool = False
|
||||
|
||||
|
||||
class ImageEditsInternalParams(BaseModel):
|
||||
"""Serializable version of ImageEditsTaskParams for distributed task execution."""
|
||||
|
||||
image_data: str = "" # Base64-encoded image (empty when using chunked transfer)
|
||||
total_input_chunks: int = 0
|
||||
prompt: str
|
||||
model: str
|
||||
n: int | None = 1
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
output_format: Literal["png", "jpeg", "webp"] = "png"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: str | None = "1024x1024"
|
||||
image_strength: float | None = 0.7
|
||||
stream: bool = False
|
||||
partial_images: int | None = 0
|
||||
advanced_params: AdvancedImageParams | None = None
|
||||
bench: bool = False
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "image_data":
|
||||
yield name, f"<{len(self.image_data)} chars>"
|
||||
elif name is not None:
|
||||
yield name, value
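The __repr_args__ override keeps logs readable by reporting only the payload length; a rough illustration (values are made up):
p = ImageEditsInternalParams(
    image_data="aGVsbG8=" * 1000,
    prompt="remove the background",
    model="qwen-image-edit",
)
print(repr(p))  # image_data is rendered as "<8000 chars>" instead of the full base64 string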
|
||||
|
||||
|
||||
class ImageData(BaseModel):
|
||||
b64_json: str | None = None
|
||||
url: str | None = None
|
||||
revised_prompt: str | None = None
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "b64_json" and self.b64_json is not None:
|
||||
yield name, f"<{len(self.b64_json)} chars>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseModel):
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
data: list[ImageData]
|
||||
|
||||
|
||||
class BenchImageGenerationResponse(ImageGenerationResponse):
|
||||
generation_stats: ImageGenerationStats | None = None
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import GenerationStats
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
from .common import CommandId
|
||||
|
||||
|
||||
class ChunkType(str, Enum):
|
||||
@@ -26,7 +29,35 @@ class TokenChunk(BaseChunk):
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
data: bytes
|
||||
data: str
|
||||
chunk_index: int
|
||||
total_chunks: int
|
||||
image_index: int
|
||||
is_partial: bool = False
|
||||
partial_index: int | None = None
|
||||
total_partials: int | None = None
|
||||
stats: ImageGenerationStats | None = None
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "data" and hasattr(value, "__len__"):
|
||||
yield name, f"<{len(self.data)} chars>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
|
||||
|
||||
class InputImageChunk(BaseChunk):
|
||||
command_id: CommandId
|
||||
data: str
|
||||
chunk_index: int
|
||||
total_chunks: int
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "data" and hasattr(value, "__len__"):
|
||||
yield name, f"<{len(self.data)} chars>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
|
||||
|
||||
GenerationChunk = TokenChunk | ImageChunk
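Image payloads travel as base64 text split across several ImageChunk messages; purely as a sketch (the helper name and chunk size are not part of this change), a splitter could look like:
import base64

def split_b64(raw: bytes, chunk_size: int = 256 * 1024) -> list[str]:
    # Encode once, then cut the string into chunk_size-character pieces;
    # each piece becomes the `data` of one ImageChunk and is reassembled
    # in chunk_index order on the receiving side before decoding.
    encoded = base64.b64encode(raw).decode()
    return [encoded[i : i + chunk_size] for i in range(0, len(encoded), chunk_size)]

pieces = split_b64(b"\x89PNG\r\n\x1a\n...")  # total_chunks == len(pieces)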
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.types.api import ChatCompletionTaskParams
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
@@ -20,6 +25,14 @@ class ChatCompletion(BaseCommand):
|
||||
request_params: ChatCompletionTaskParams
|
||||
|
||||
|
||||
class ImageGeneration(BaseCommand):
|
||||
request_params: ImageGenerationTaskParams
|
||||
|
||||
|
||||
class ImageEdits(BaseCommand):
|
||||
request_params: ImageEditsInternalParams
|
||||
|
||||
|
||||
class PlaceInstance(BaseCommand):
|
||||
model_card: ModelCard
|
||||
sharding: Sharding
|
||||
@@ -39,6 +52,12 @@ class TaskFinished(BaseCommand):
|
||||
finished_command_id: CommandId
|
||||
|
||||
|
||||
class SendInputChunk(BaseCommand):
|
||||
"""Command to send an input image chunk (converted to event by master)."""
|
||||
|
||||
chunk: InputImageChunk
|
||||
|
||||
|
||||
class RequestEventLog(BaseCommand):
|
||||
since_idx: int
|
||||
|
||||
@@ -47,10 +66,13 @@ Command = (
|
||||
TestCommand
|
||||
| RequestEventLog
|
||||
| ChatCompletion
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| PlaceInstance
|
||||
| CreateInstance
|
||||
| DeleteInstance
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
)
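Downstream handlers typically dispatch on this union; a minimal sketch (the handling bodies are placeholders, not code from this change):
def handle(command: Command) -> None:
    match command:
        case ImageGeneration(request_params=params):
            print(f"image generation requested: {params.prompt!r}")
        case ImageEdits(request_params=params):
            print(f"image edit requested: {params.prompt!r}")
        case SendInputChunk(chunk=chunk):
            print(f"input chunk {chunk.chunk_index + 1}/{chunk.total_chunks}")
        case _:
            pass  # remaining command types are handled elsewhere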
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from datetime import datetime
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.topology import Connection
|
||||
from exo.shared.types.chunks import GenerationChunk
|
||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
@@ -96,6 +96,11 @@ class ChunkGenerated(BaseEvent):
|
||||
chunk: GenerationChunk
|
||||
|
||||
|
||||
class InputChunkReceived(BaseEvent):
|
||||
command_id: CommandId
|
||||
chunk: InputImageChunk
|
||||
|
||||
|
||||
class TopologyEdgeCreated(BaseEvent):
|
||||
conn: Connection
|
||||
|
||||
@@ -119,6 +124,7 @@ Event = (
|
||||
| NodeGatheredInfo
|
||||
| NodeDownloadProgress
|
||||
| ChunkGenerated
|
||||
| InputChunkReceived
|
||||
| TopologyEdgeCreated
|
||||
| TopologyEdgeDeleted
|
||||
)
|
||||
|
||||
@@ -2,7 +2,11 @@ from enum import Enum
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import ChatCompletionTaskParams
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id
|
||||
from exo.shared.types.worker.instances import BoundInstance, InstanceId
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
@@ -56,6 +60,22 @@ class ChatCompletion(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ImageGeneration(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageGenerationTaskParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ImageEdits(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageEditsInternalParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class Shutdown(BaseTask): # emitted by Worker
|
||||
runner_id: RunnerId
|
||||
|
||||
@@ -67,5 +87,7 @@ Task = (
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| ChatCompletion
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| Shutdown
|
||||
)
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from exo.shared.types.api import FinishReason, GenerationStats
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.types.api import FinishReason, GenerationStats, ImageGenerationStats
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
|
||||
@@ -18,5 +21,32 @@ class GenerationResponse(BaseRunnerResponse):
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseRunnerResponse):
|
||||
image_data: bytes
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
stats: ImageGenerationStats | None = None
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "image_data":
|
||||
yield name, f"<{len(self.image_data)} bytes>"
|
||||
elif name is not None:
|
||||
yield name, value
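The runner keeps raw bytes here; before they are forwarded as ImageChunk events the payload is presumably base64-encoded into a string, roughly:
import base64

def to_chunk_payload(response: ImageGenerationResponse) -> str:
    # Raw bytes from the runner -> base64 str carried by ImageChunk.data
    return base64.b64encode(response.image_data).decode()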
|
||||
|
||||
|
||||
class PartialImageResponse(BaseRunnerResponse):
|
||||
image_data: bytes
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
partial_index: int
|
||||
total_partials: int
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "image_data":
|
||||
yield name, f"<{len(self.image_data)} bytes>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
pass
|
||||
|
||||
@@ -15,6 +15,7 @@ import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import aiohttp
|
||||
import certifi
|
||||
from huggingface_hub._snapshot_download import snapshot_download
|
||||
from loguru import logger
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -445,12 +446,31 @@ def calculate_repo_progress(
|
||||
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
|
||||
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
index_file = await download_file_with_retry(
|
||||
repo_id, revision, "model.safetensors.index.json", target_dir
|
||||
|
||||
index_files_dir = snapshot_download(
|
||||
repo_id=repo_id, local_dir=target_dir, allow_patterns="*.safetensors.index.json"
|
||||
)
|
||||
async with aiofiles.open(index_file, "r") as f:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
|
||||
return index_data.weight_map
|
||||
|
||||
index_files = list(Path(index_files_dir).glob("**/*.safetensors.index.json"))
|
||||
|
||||
weight_map: dict[str, str] = {}
|
||||
|
||||
for index_file in index_files:
|
||||
relative_dir = index_file.parent.relative_to(index_files_dir)
|
||||
|
||||
async with aiofiles.open(index_file, "r") as f:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
|
||||
|
||||
if relative_dir != Path("."):
|
||||
prefixed_weight_map = {
|
||||
f"{relative_dir}/{key}": str(relative_dir / value)
|
||||
for key, value in index_data.weight_map.items()
|
||||
}
|
||||
weight_map = weight_map | prefixed_weight_map
|
||||
else:
|
||||
weight_map = weight_map | index_data.weight_map
|
||||
|
||||
return weight_map
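To illustrate the prefixing above with a hypothetical Qwen-Image style index entry:
from pathlib import Path

relative_dir = Path("transformer")
component_index = {"blocks.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00009.safetensors"}
prefixed = {f"{relative_dir}/{k}": str(relative_dir / v) for k, v in component_index.items()}
# {'transformer/blocks.0.attn.to_q.weight': 'transformer/diffusion_pytorch_model-00001-of-00009.safetensors'}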
|
||||
|
||||
|
||||
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
|
||||
@@ -555,8 +575,6 @@ async def download_shard(
|
||||
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
|
||||
|
||||
all_start_time = time.time()
|
||||
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
|
||||
# Update: <- This does not seem to be the case. Yay?
|
||||
file_list = await fetch_file_list_with_cache(
|
||||
str(shard.model_card.model_id), revision, recursive=True
|
||||
)
|
||||
|
||||
@@ -100,26 +100,68 @@ def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"tiktoken.model",
|
||||
"*/spiece.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
"*.jinja",
|
||||
]
|
||||
)
|
||||
shard_specific_patterns: set[str] = set()
|
||||
if weight_map:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
layer_num = extract_layer_num(tensor_name)
|
||||
if (
|
||||
layer_num is not None
|
||||
and shard.start_layer <= layer_num <= shard.end_layer
|
||||
):
|
||||
shard_specific_patterns.add(filename)
|
||||
layer_independent_files = set(
|
||||
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
|
||||
|
||||
if shard.model_card.components is not None:
|
||||
shardable_component = next(
|
||||
(c for c in shard.model_card.components if c.can_shard), None
|
||||
)
|
||||
shard_specific_patterns.update(layer_independent_files)
|
||||
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
|
||||
|
||||
if weight_map and shardable_component:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
# Strip component prefix from tensor name (added by weight map namespacing)
|
||||
# E.g., "transformer/blocks.0.weight" -> "blocks.0.weight"
|
||||
if "/" in tensor_name:
|
||||
_, tensor_name_no_prefix = tensor_name.split("/", 1)
|
||||
else:
|
||||
tensor_name_no_prefix = tensor_name
|
||||
|
||||
# Determine which component this file belongs to from filename
|
||||
component_path = Path(filename).parts[0] if "/" in filename else None
|
||||
|
||||
if component_path == shardable_component.component_path.rstrip("/"):
|
||||
layer_num = extract_layer_num(tensor_name_no_prefix)
|
||||
if (
|
||||
layer_num is not None
|
||||
and shard.start_layer <= layer_num < shard.end_layer
|
||||
):
|
||||
shard_specific_patterns.add(filename)
|
||||
|
||||
if shard.is_first_layer or shard.is_last_layer:
|
||||
shard_specific_patterns.add(filename)
|
||||
else:
|
||||
shard_specific_patterns.add(filename)
|
||||
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
|
||||
# TODO(ciaran): temporary - Include all files from non-shardable components that have no index file
|
||||
for component in shard.model_card.components:
|
||||
if not component.can_shard and component.safetensors_index_filename is None:
|
||||
component_pattern = f"{component.component_path.rstrip('/')}/*"
|
||||
shard_specific_patterns.add(component_pattern)
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
if weight_map:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
layer_num = extract_layer_num(tensor_name)
|
||||
if (
|
||||
layer_num is not None
|
||||
and shard.start_layer <= layer_num < shard.end_layer
|
||||
):
|
||||
shard_specific_patterns.add(filename)
|
||||
layer_independent_files = set(
|
||||
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
|
||||
)
|
||||
shard_specific_patterns.update(layer_independent_files)
|
||||
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
|
||||
logger.info(f"get_allow_patterns {shard=} {shard_specific_patterns=}")
|
||||
return list(default_patterns | shard_specific_patterns)
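As a rough illustration (file names are hypothetical), a middle pipeline shard of a component-based model would end up requesting patterns along these lines:
expected_patterns = [
    "*.json", "*.py", "tokenizer.model",  # always-downloaded metadata files
    "transformer/diffusion_pytorch_model-00003-of-00009.safetensors",  # weights covering this shard's layers
    "text_encoder/*",  # non-shardable component without an index file
    "vae/*",  # non-shardable component without an index file
]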
|
||||
|
||||
@@ -5,7 +5,7 @@ from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
@@ -92,6 +92,7 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
|
||||
n_layers=1,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
|
||||
src/exo/worker/engines/image/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
|
||||
from exo.worker.engines.image.distributed_model import (
|
||||
DistributedImageModel,
|
||||
initialize_image_model,
|
||||
)
|
||||
from exo.worker.engines.image.generate import generate_image, warmup_image_generator
|
||||
|
||||
__all__ = [
|
||||
"DistributedImageModel",
|
||||
"generate_image",
|
||||
"initialize_image_model",
|
||||
"warmup_image_generator",
|
||||
]
|
||||
src/exo/worker/engines/image/config.py (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
from enum import Enum
|
||||
from math import ceil
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BlockType(Enum):
|
||||
JOINT = "joint" # Separate image/text streams
|
||||
SINGLE = "single" # Concatenated streams
|
||||
|
||||
|
||||
class TransformerBlockConfig(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
|
||||
block_type: BlockType
|
||||
count: int
|
||||
has_separate_text_output: bool # True for joint blocks that output text separately
|
||||
|
||||
|
||||
class ImageModelConfig(BaseModel):
|
||||
model_family: str
|
||||
|
||||
block_configs: tuple[TransformerBlockConfig, ...]
|
||||
|
||||
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
|
||||
num_sync_steps_factor: float # Fraction of steps for sync phase
|
||||
|
||||
guidance_scale: float | None = None # None or <= 1.0 disables CFG
|
||||
|
||||
@property
|
||||
def total_blocks(self) -> int:
|
||||
return sum(bc.count for bc in self.block_configs)
|
||||
|
||||
@property
|
||||
def joint_block_count(self) -> int:
|
||||
return sum(
|
||||
bc.count for bc in self.block_configs if bc.block_type == BlockType.JOINT
|
||||
)
|
||||
|
||||
@property
|
||||
def single_block_count(self) -> int:
|
||||
return sum(
|
||||
bc.count for bc in self.block_configs if bc.block_type == BlockType.SINGLE
|
||||
)
|
||||
|
||||
def get_steps_for_quality(self, quality: str) -> int:
|
||||
return self.default_steps[quality]
|
||||
|
||||
def get_num_sync_steps(self, steps: int) -> int:
|
||||
return ceil(steps * self.num_sync_steps_factor)
|
||||
src/exo/worker/engines/image/distributed_model.py (new file, 166 lines)
@@ -0,0 +1,166 @@
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
from PIL import Image
|
||||
|
||||
from exo.shared.types.api import AdvancedImageParams
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.download.download_utils import build_model_path
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models import (
|
||||
create_adapter_for_model,
|
||||
get_config_for_model,
|
||||
)
|
||||
from exo.worker.engines.image.models.base import ModelAdapter
|
||||
from exo.worker.engines.image.pipeline import DiffusionRunner
|
||||
from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
class DistributedImageModel:
|
||||
_config: ImageModelConfig
|
||||
_adapter: ModelAdapter[Any, Any]
|
||||
_runner: DiffusionRunner
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
shard_metadata: PipelineShardMetadata,
|
||||
group: Optional[mx.distributed.Group] = None,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
config = get_config_for_model(model_id)
|
||||
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
|
||||
|
||||
if group is not None:
|
||||
adapter.slice_transformer_blocks(
|
||||
start_layer=shard_metadata.start_layer,
|
||||
end_layer=shard_metadata.end_layer,
|
||||
)
|
||||
|
||||
runner = DiffusionRunner(
|
||||
config=config,
|
||||
adapter=adapter,
|
||||
group=group,
|
||||
shard_metadata=shard_metadata,
|
||||
)
|
||||
|
||||
if group is not None:
|
||||
logger.info("Initialized distributed diffusion runner")
|
||||
|
||||
mx.eval(adapter.model.parameters())
|
||||
|
||||
# TODO(ciaran): Do we need this?
|
||||
mx.eval(adapter.model)
|
||||
|
||||
mx_barrier(group)
|
||||
logger.info(f"Transformer sharded for rank {group.rank()}")
|
||||
else:
|
||||
logger.info("Single-node initialization")
|
||||
|
||||
self._config = config
|
||||
self._adapter = adapter
|
||||
self._runner = runner
|
||||
|
||||
@classmethod
|
||||
def from_bound_instance(
|
||||
cls, bound_instance: BoundInstance
|
||||
) -> "DistributedImageModel":
|
||||
model_id = bound_instance.bound_shard.model_card.model_id
|
||||
model_path = build_model_path(model_id)
|
||||
|
||||
shard_metadata = bound_instance.bound_shard
|
||||
if not isinstance(shard_metadata, PipelineShardMetadata):
|
||||
raise ValueError("Expected PipelineShardMetadata for image generation")
|
||||
|
||||
is_distributed = (
|
||||
len(bound_instance.instance.shard_assignments.node_to_runner) > 1
|
||||
)
|
||||
|
||||
if is_distributed:
|
||||
logger.info("Starting distributed init for image model")
|
||||
group = mlx_distributed_init(bound_instance)
|
||||
else:
|
||||
group = None
|
||||
|
||||
return cls(
|
||||
model_id=model_id,
|
||||
local_path=model_path,
|
||||
shard_metadata=shard_metadata,
|
||||
group=group,
|
||||
)
|
||||
|
||||
def get_steps_for_quality(self, quality: Literal["low", "medium", "high"]) -> int:
|
||||
"""Get the number of inference steps for a quality level."""
|
||||
return self._config.get_steps_for_quality(quality)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
quality: Literal["low", "medium", "high"] = "medium",
|
||||
seed: int = 2,
|
||||
image_path: Path | None = None,
|
||||
partial_images: int = 0,
|
||||
advanced_params: AdvancedImageParams | None = None,
|
||||
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
|
||||
if (
|
||||
advanced_params is not None
|
||||
and advanced_params.num_inference_steps is not None
|
||||
):
|
||||
steps = advanced_params.num_inference_steps
|
||||
else:
|
||||
steps = self._config.get_steps_for_quality(quality)
|
||||
|
||||
guidance_override: float | None = None
|
||||
if advanced_params is not None and advanced_params.guidance is not None:
|
||||
guidance_override = advanced_params.guidance
|
||||
|
||||
negative_prompt: str | None = None
|
||||
if advanced_params is not None and advanced_params.negative_prompt is not None:
|
||||
negative_prompt = advanced_params.negative_prompt
|
||||
|
||||
# For edit mode: compute dimensions from input image
|
||||
# This also stores image_paths in the adapter for encode_prompt()
|
||||
if image_path is not None:
|
||||
computed_dims = self._adapter.set_image_dimensions(image_path)
|
||||
if computed_dims is not None:
|
||||
# Override user-provided dimensions with computed ones
|
||||
width, height = computed_dims
|
||||
|
||||
config = Config(
|
||||
num_inference_steps=steps,
|
||||
height=height,
|
||||
width=width,
|
||||
image_path=image_path,
|
||||
model_config=self._adapter.model.model_config,
|
||||
)
|
||||
|
||||
num_sync_steps = self._config.get_num_sync_steps(steps)
|
||||
|
||||
for result in self._runner.generate_image(
|
||||
runtime_config=config,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
partial_images=partial_images,
|
||||
guidance_override=guidance_override,
|
||||
negative_prompt=negative_prompt,
|
||||
num_sync_steps=num_sync_steps,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (GeneratedImage, partial_index, total_partials)
|
||||
image, partial_idx, total_partials = result
|
||||
yield (image, partial_idx, total_partials)
|
||||
else:
|
||||
logger.info("generated image")
|
||||
yield result
|
||||
|
||||
|
||||
def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:
|
||||
return DistributedImageModel.from_bound_instance(bound_instance)
|
||||
src/exo/worker/engines/image/generate.py (new file, 170 lines)
@@ -0,0 +1,170 @@
|
||||
import base64
|
||||
import io
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Generator, Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from PIL import Image
|
||||
|
||||
from exo.shared.types.api import (
|
||||
AdvancedImageParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationStats,
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
ImageGenerationResponse,
|
||||
PartialImageResponse,
|
||||
)
|
||||
from exo.worker.engines.image.distributed_model import DistributedImageModel
|
||||
|
||||
|
||||
def parse_size(size_str: str | None) -> tuple[int, int]:
|
||||
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
|
||||
if not size_str or size_str == "auto":
|
||||
size_str = "1024x1024"
|
||||
|
||||
try:
|
||||
parts = size_str.split("x")
|
||||
if len(parts) == 2:
|
||||
width, height = int(parts[0]), int(parts[1])
|
||||
return (width, height)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
# Default fallback
|
||||
return (1024, 1024)
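For example:
assert parse_size("512x768") == (512, 768)
assert parse_size("auto") == (1024, 1024)
assert parse_size("not-a-size") == (1024, 1024)  # falls back to the default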
|
||||
|
||||
|
||||
def warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:
|
||||
"""Warmup the image generator with a small image."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Create a small dummy image for warmup (needed for edit models)
|
||||
dummy_image = Image.new("RGB", (256, 256), color=(128, 128, 128))
|
||||
dummy_path = Path(tmpdir) / "warmup.png"
|
||||
dummy_image.save(dummy_path)
|
||||
|
||||
warmup_params = AdvancedImageParams(num_inference_steps=2)
|
||||
|
||||
for result in model.generate(
|
||||
prompt="Warmup",
|
||||
height=256,
|
||||
width=256,
|
||||
quality="low",
|
||||
image_path=dummy_path,
|
||||
advanced_params=warmup_params,
|
||||
):
|
||||
if not isinstance(result, tuple):
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def generate_image(
|
||||
model: DistributedImageModel,
|
||||
task: ImageGenerationTaskParams | ImageEditsInternalParams,
|
||||
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
|
||||
"""Generate image(s), optionally yielding partial results.
|
||||
|
||||
When partial_images > 0 or stream=True, yields PartialImageResponse for
|
||||
intermediate images, then ImageGenerationResponse for the final image.
|
||||
|
||||
Yields:
|
||||
PartialImageResponse for intermediate images (if partial_images > 0)
|
||||
ImageGenerationResponse for the final complete image
|
||||
"""
|
||||
width, height = parse_size(task.size)
|
||||
quality: Literal["low", "medium", "high"] = task.quality or "medium"
|
||||
|
||||
advanced_params = task.advanced_params
|
||||
if advanced_params is not None and advanced_params.seed is not None:
|
||||
seed = advanced_params.seed
|
||||
else:
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
|
||||
is_bench = getattr(task, "bench", False)
|
||||
|
||||
generation_start_time: float = 0.0
|
||||
|
||||
if is_bench:
|
||||
mx.reset_peak_memory()
|
||||
generation_start_time = time.perf_counter()
|
||||
|
||||
partial_images = task.partial_images or (3 if task.stream else 0)
|
||||
|
||||
image_path: Path | None = None
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
if isinstance(task, ImageEditsInternalParams):
|
||||
# Decode base64 image data and save to temp file
|
||||
image_path = Path(tmpdir) / "input.png"
|
||||
image_path.write_bytes(base64.b64decode(task.image_data))
|
||||
|
||||
# Iterate over generator results
|
||||
for result in model.generate(
|
||||
prompt=task.prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
quality=quality,
|
||||
seed=seed,
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
advanced_params=advanced_params,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
image, partial_idx, total_partials = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield PartialImageResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
partial_index=partial_idx,
|
||||
total_partials=total_partials,
|
||||
)
|
||||
else:
|
||||
image = result
|
||||
|
||||
stats: ImageGenerationStats | None = None
|
||||
if is_bench:
|
||||
generation_end_time = time.perf_counter()
|
||||
total_generation_time = generation_end_time - generation_start_time
|
||||
|
||||
num_inference_steps = model.get_steps_for_quality(quality)
|
||||
|
||||
seconds_per_step = (
|
||||
total_generation_time / num_inference_steps
|
||||
if num_inference_steps > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
peak_memory_gb = mx.get_peak_memory() / (1024**3)
|
||||
|
||||
stats = ImageGenerationStats(
|
||||
seconds_per_step=seconds_per_step,
|
||||
total_generation_time=total_generation_time,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_images=task.n or 1,
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
peak_memory_usage=Memory.from_gb(peak_memory_gb),
|
||||
)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield ImageGenerationResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
stats=stats,
|
||||
)
|
||||
src/exo/worker/engines/image/models/__init__.py (new file, 84 lines)
@@ -0,0 +1,84 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import ModelAdapter
|
||||
from exo.worker.engines.image.models.flux import (
|
||||
FLUX_DEV_CONFIG,
|
||||
FLUX_SCHNELL_CONFIG,
|
||||
FluxModelAdapter,
|
||||
)
|
||||
from exo.worker.engines.image.models.qwen import (
|
||||
QWEN_IMAGE_CONFIG,
|
||||
QWEN_IMAGE_EDIT_CONFIG,
|
||||
QwenEditModelAdapter,
|
||||
QwenModelAdapter,
|
||||
)
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
# Type alias for adapter factory functions
|
||||
# Factory takes (config, model_id, local_path, quantize) and returns a ModelAdapter
|
||||
AdapterFactory = Callable[[ImageModelConfig, str, Path, int | None], ModelAdapter[Any, Any]]
|
||||
|
||||
# Registry maps model_family string to adapter factory
|
||||
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
|
||||
"flux": FluxModelAdapter,
|
||||
"qwen-edit": QwenEditModelAdapter,
|
||||
"qwen": QwenModelAdapter,
|
||||
}
|
||||
|
||||
# Config registry: maps model ID patterns to configs
|
||||
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
|
||||
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
|
||||
"flux.1-dev": FLUX_DEV_CONFIG,
|
||||
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
|
||||
"qwen-image": QWEN_IMAGE_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
def get_config_for_model(model_id: str) -> ImageModelConfig:
|
||||
"""Get configuration for a model ID.
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (e.g., "black-forest-labs/FLUX.1-schnell")
|
||||
|
||||
Returns:
|
||||
The model configuration
|
||||
|
||||
Raises:
|
||||
ValueError: If no configuration found for model ID
|
||||
"""
|
||||
model_id_lower = model_id.lower()
|
||||
|
||||
for pattern, config in _CONFIG_REGISTRY.items():
|
||||
if pattern in model_id_lower:
|
||||
return config
|
||||
|
||||
raise ValueError(f"No configuration found for model: {model_id}")
|
||||
|
||||
|
||||
def create_adapter_for_model(
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
) -> ModelAdapter[Any, Any]:
|
||||
"""Create a model adapter for the given configuration.
|
||||
|
||||
Args:
|
||||
config: The model configuration
|
||||
model_id: The model identifier
|
||||
local_path: Path to the model weights
|
||||
quantize: Optional quantization bits
|
||||
|
||||
Returns:
|
||||
A ModelAdapter instance
|
||||
|
||||
Raises:
|
||||
ValueError: If no adapter found for model family
|
||||
"""
|
||||
factory = _ADAPTER_REGISTRY.get(config.model_family)
|
||||
if factory is None:
|
||||
raise ValueError(f"No adapter found for model family: {config.model_family}")
|
||||
return factory(config, model_id, local_path, quantize)
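Putting the two lookups together (model id and local path are illustrative):
from pathlib import Path

config = get_config_for_model("black-forest-labs/FLUX.1-schnell")  # matches the "flux.1-schnell" pattern
adapter = create_adapter_for_model(config, "black-forest-labs/FLUX.1-schnell", Path("/models/flux-schnell"))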
|
||||
src/exo/worker/engines/image/models/base.py (new file, 295 lines)
@@ -0,0 +1,295 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
from mflux.models.common.latent_creator.latent_creator import Img2Img, LatentCreator
|
||||
from mflux.utils.image_util import ImageUtil
|
||||
from PIL import Image
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from exo.worker.engines.image.pipeline.block_wrapper import (
|
||||
JointBlockWrapper,
|
||||
SingleBlockWrapper,
|
||||
)
|
||||
|
||||
ModelT = TypeVar("ModelT")
|
||||
TransformerT = TypeVar("TransformerT")
|
||||
|
||||
RotaryEmbeddings = mx.array | tuple[mx.array, mx.array]
|
||||
|
||||
|
||||
class PromptData(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def prompt_embeds(self) -> mx.array: ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def pooled_prompt_embeds(self) -> mx.array: ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def negative_prompt_embeds(self) -> mx.array | None: ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array | None: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_encoder_hidden_states_mask(
|
||||
self, positive: bool = True
|
||||
) -> mx.array | None: ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def cond_image_grid(
|
||||
self,
|
||||
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
|
||||
"""Conditioning image grid dimensions for edit mode.
|
||||
|
||||
Returns:
|
||||
Grid dimensions (edit) or None (standard generation).
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
"""Conditioning latents for edit mode.
|
||||
|
||||
Returns:
|
||||
Conditioning latents array for image editing, None for standard generation.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_batched_cfg_data(
|
||||
self,
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
"""Get embeddings for CFG with batch_size=2.
|
||||
|
||||
Combines positive and negative embeddings into batched tensors for
|
||||
a single forward pass. Pads shorter sequences to max length. Attention
|
||||
mask is used to mask padding.
|
||||
|
||||
Returns:
|
||||
None if model doesn't support CFG, otherwise tuple of:
|
||||
- batched_embeds: [2, max_seq, hidden] (positive then negative)
|
||||
- batched_mask: [2, max_seq] attention mask
|
||||
- batched_pooled: [2, hidden] pooled embeddings or None
|
||||
- conditioning_latents: [2, latent_seq, latent_dim] or None
|
||||
TODO(ciaran): type this
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ModelAdapter(ABC, Generic[ModelT, TransformerT]):
|
||||
_config: ImageModelConfig
|
||||
_model: ModelT
|
||||
_transformer: TransformerT
|
||||
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def model(self) -> ModelT:
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def transformer(self) -> TransformerT:
|
||||
return self._transformer
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def hidden_dim(self) -> int: ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def needs_cfg(self) -> bool:
|
||||
"""Whether this model uses classifier-free guidance."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_latent_creator(self) -> type: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_joint_block_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
) -> list["JointBlockWrapper[Any]"]:
|
||||
"""Create wrapped joint transformer blocks with pipefusion support.
|
||||
|
||||
Args:
|
||||
text_seq_len: Number of text tokens (constant for generation)
|
||||
encoder_hidden_states_mask: Attention mask for text (Qwen only)
|
||||
|
||||
Returns:
|
||||
List of wrapped joint blocks ready for pipefusion
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_single_block_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
) -> list["SingleBlockWrapper[Any]"]:
|
||||
"""Create wrapped single transformer blocks with pipefusion support.
|
||||
|
||||
Args:
|
||||
text_seq_len: Number of text tokens (constant for generation)
|
||||
|
||||
Returns:
|
||||
List of wrapped single blocks ready for pipefusion
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
):
|
||||
"""Remove transformer blocks outside the assigned range.
|
||||
|
||||
This should be called BEFORE mx.eval() to avoid loading unused weights
|
||||
in distributed mode.
|
||||
|
||||
Args:
|
||||
start_layer: First layer index (inclusive) assigned to this node
|
||||
end_layer: Last layer index (exclusive) assigned to this node
|
||||
"""
|
||||
...
|
||||
|
||||
def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
|
||||
"""Default implementation: no dimension computation needed.
|
||||
|
||||
Override in edit adapters to compute dimensions from input image.
|
||||
TODO(ciaran): this is a hack
|
||||
|
||||
Returns:
|
||||
None (use user-specified dimensions)
|
||||
"""
|
||||
return None
|
||||
|
||||
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
|
||||
"""Create initial latents. Uses model-specific latent creator."""
|
||||
model: Any = self.model
|
||||
return LatentCreator.create_for_txt2img_or_img2img(
|
||||
seed=seed,
|
||||
height=runtime_config.height,
|
||||
width=runtime_config.width,
|
||||
img2img=Img2Img(
|
||||
vae=model.vae,
|
||||
latent_creator=self._get_latent_creator(),
|
||||
sigmas=runtime_config.scheduler.sigmas,
|
||||
init_time_step=runtime_config.init_time_step,
|
||||
image_path=runtime_config.image_path,
|
||||
),
|
||||
)
|
||||
|
||||
def decode_latents(
|
||||
self,
|
||||
latents: mx.array,
|
||||
runtime_config: Config,
|
||||
seed: int,
|
||||
prompt: str,
|
||||
) -> Image.Image:
|
||||
model: Any = self.model # Allow attribute access on model
|
||||
latents = self._get_latent_creator().unpack_latents(
|
||||
latents=latents,
|
||||
height=runtime_config.height,
|
||||
width=runtime_config.width,
|
||||
)
|
||||
decoded = model.vae.decode(latents)
|
||||
# TODO(ciaran):
|
||||
# from mflux.models.common.vae.vae_util import VAEUtil
|
||||
# VAEUtil.decode(vae=model.vae, latents=latents, tiling_config=self.tiling_config)
|
||||
generated_image = ImageUtil.to_image(
|
||||
decoded_latents=decoded,
|
||||
config=runtime_config,
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
quantization=model.bits,
|
||||
lora_paths=model.lora_paths,
|
||||
lora_scales=model.lora_scales,
|
||||
image_path=runtime_config.image_path,
|
||||
image_strength=runtime_config.image_strength,
|
||||
generation_time=0,
|
||||
)
|
||||
return generated_image.image
|
||||
|
||||
@abstractmethod
|
||||
def encode_prompt(
|
||||
self, prompt: str, negative_prompt: str | None = None
|
||||
) -> "PromptData": ...
|
||||
|
||||
@abstractmethod
|
||||
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]: ...
|
||||
|
||||
@abstractmethod
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: Config,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
@abstractmethod
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: Config,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
cond_image_grid: tuple[int, int, int]
|
||||
| list[tuple[int, int, int]]
|
||||
| None = None,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> RotaryEmbeddings: ...
|
||||
|
||||
def merge_streams(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
) -> mx.array:
|
||||
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
|
||||
|
||||
@abstractmethod
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
"""Apply classifier-free guidance to combine positive/negative predictions.
|
||||
|
||||
Only called when needs_cfg is True.
|
||||
|
||||
Args:
|
||||
noise_positive: Noise prediction from positive prompt
|
||||
noise_negative: Noise prediction from negative prompt
|
||||
guidance_scale: Guidance strength
|
||||
|
||||
Returns:
|
||||
Guided noise prediction
|
||||
"""
|
||||
...
|
||||
|
||||
def final_projection(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
) -> mx.array:
|
||||
transformer: Any = self.transformer
|
||||
hidden_states = transformer.norm_out(hidden_states, text_embeddings)
|
||||
return transformer.proj_out(hidden_states)
|
||||
src/exo/worker/engines/image/models/flux/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
|
||||
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
|
||||
from exo.worker.engines.image.models.flux.config import (
|
||||
FLUX_DEV_CONFIG,
|
||||
FLUX_SCHNELL_CONFIG,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FluxModelAdapter",
|
||||
"FLUX_DEV_CONFIG",
|
||||
"FLUX_SCHNELL_CONFIG",
|
||||
]
|
||||
src/exo/worker/engines/image/models/flux/adapter.py (new file, 215 lines)
@@ -0,0 +1,215 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
from mflux.models.common.config.model_config import ModelConfig
|
||||
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
|
||||
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
|
||||
from mflux.models.flux.model.flux_transformer.transformer import Transformer
|
||||
from mflux.models.flux.variants.txt2img.flux import Flux1
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import (
|
||||
ModelAdapter,
|
||||
PromptData,
|
||||
RotaryEmbeddings,
|
||||
)
|
||||
from exo.worker.engines.image.models.flux.wrappers import (
|
||||
FluxJointBlockWrapper,
|
||||
FluxSingleBlockWrapper,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.block_wrapper import (
|
||||
JointBlockWrapper,
|
||||
SingleBlockWrapper,
|
||||
)
|
||||
|
||||
|
||||
class FluxPromptData(PromptData):
|
||||
def __init__(self, prompt_embeds: mx.array, pooled_prompt_embeds: mx.array):
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self._pooled_prompt_embeds = pooled_prompt_embeds
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
return self._pooled_prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def cond_image_grid(
|
||||
self,
|
||||
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
return None
|
||||
|
||||
def get_batched_cfg_data(
|
||||
self,
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
return None
|
||||
|
||||
|
||||
class FluxModelAdapter(ModelAdapter[Flux1, Transformer]):
|
||||
def __init__(
|
||||
self,
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
self._config = config
|
||||
self._model = Flux1(
|
||||
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
|
||||
model_path=str(local_path),
|
||||
quantize=quantize,
|
||||
)
|
||||
self._transformer = self._model.transformer
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
return self._transformer.x_embedder.weight.shape[0]
|
||||
|
||||
@property
|
||||
def needs_cfg(self) -> bool:
|
||||
return False
|
||||
|
||||
def _get_latent_creator(self) -> type:
|
||||
return FluxLatentCreator
|
||||
|
||||
def get_joint_block_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
) -> list[JointBlockWrapper[Any]]:
|
||||
"""Create wrapped joint blocks for Flux."""
|
||||
return [
|
||||
FluxJointBlockWrapper(block, text_seq_len)
|
||||
for block in self._transformer.transformer_blocks
|
||||
]
|
||||
|
||||
def get_single_block_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
) -> list[SingleBlockWrapper[Any]]:
|
||||
"""Create wrapped single blocks for Flux."""
|
||||
return [
|
||||
FluxSingleBlockWrapper(block, text_seq_len)
|
||||
for block in self._transformer.single_transformer_blocks
|
||||
]
|
||||
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
):
|
||||
all_joint = list(self._transformer.transformer_blocks)
|
||||
all_single = list(self._transformer.single_transformer_blocks)
|
||||
total_joint_blocks = len(all_joint)
|
||||
if end_layer <= total_joint_blocks:
|
||||
# All assigned are joint blocks
|
||||
joint_start, joint_end = start_layer, end_layer
|
||||
single_start, single_end = 0, 0
|
||||
elif start_layer >= total_joint_blocks:
|
||||
# All assigned are single blocks
|
||||
joint_start, joint_end = 0, 0
|
||||
single_start = start_layer - total_joint_blocks
|
||||
single_end = end_layer - total_joint_blocks
|
||||
else:
|
||||
# Spans both joint and single
|
||||
joint_start, joint_end = start_layer, total_joint_blocks
|
||||
single_start = 0
|
||||
single_end = end_layer - total_joint_blocks
|
||||
|
||||
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
|
||||
|
||||
self._transformer.single_transformer_blocks = all_single[
|
||||
single_start:single_end
|
||||
]
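To make the index arithmetic concrete: with Flux's 19 joint and 38 single blocks, a shard assigned layers 10 (inclusive) to 30 (exclusive) spans both lists, and the method computes:
total_joint_blocks = 19          # len(all_joint) for Flux
start_layer, end_layer = 10, 30  # this shard's assignment
# start_layer < 19 < end_layer, so the shard spans both lists:
joint_start, joint_end = 10, 19                               # keeps joint blocks 10..18
single_start, single_end = 0, end_layer - total_joint_blocks  # keeps single blocks 0..10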
|
||||
|
||||
def encode_prompt(
|
||||
self, prompt: str, negative_prompt: str | None = None
|
||||
) -> FluxPromptData:
|
||||
del negative_prompt
|
||||
|
||||
assert isinstance(self.model.prompt_cache, dict)
|
||||
assert isinstance(self.model.tokenizers, dict)
|
||||
|
||||
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_cache=self.model.prompt_cache,
|
||||
t5_tokenizer=self.model.tokenizers["t5"],
|
||||
clip_tokenizer=self.model.tokenizers["clip"],
|
||||
t5_text_encoder=self.model.t5_text_encoder,
|
||||
clip_text_encoder=self.model.clip_text_encoder,
|
||||
)
|
||||
return FluxPromptData(
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
embedded_hidden = self._transformer.x_embedder(hidden_states)
|
||||
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
|
||||
return embedded_hidden, embedded_encoder
|
||||
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: Config,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None, # Ignored by Flux
|
||||
) -> mx.array:
|
||||
if pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"pooled_prompt_embeds is required for Flux text embeddings"
|
||||
)
|
||||
|
||||
# hidden_states is ignored - Flux uses pooled_prompt_embeds instead
|
||||
return Transformer.compute_text_embeddings(
|
||||
t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
|
||||
)
|
||||
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: Config,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
cond_image_grid: tuple[int, int, int]
|
||||
| list[tuple[int, int, int]]
|
||||
| None = None,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> RotaryEmbeddings:
|
||||
return Transformer.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
self._transformer.pos_embed,
|
||||
runtime_config,
|
||||
kontext_image_ids,
|
||||
)
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
raise NotImplementedError("Flux does not use classifier-free guidance")
|
||||
src/exo/worker/engines/image/models/flux/config.py (new file, 34 lines)
@@ -0,0 +1,34 @@
|
||||
from exo.worker.engines.image.config import (
|
||||
BlockType,
|
||||
ImageModelConfig,
|
||||
TransformerBlockConfig,
|
||||
)
|
||||
|
||||
FLUX_SCHNELL_CONFIG = ImageModelConfig(
|
||||
model_family="flux",
|
||||
block_configs=(
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
|
||||
),
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
|
||||
),
|
||||
),
|
||||
default_steps={"low": 1, "medium": 2, "high": 4},
|
||||
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
|
||||
)
|
||||
|
||||
|
||||
FLUX_DEV_CONFIG = ImageModelConfig(
|
||||
model_family="flux",
|
||||
block_configs=(
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
|
||||
),
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps_factor=0.125,  # ceil(25 * 0.125) = 4 sync steps for medium (25 steps)
|
||||
)
|
||||
src/exo/worker/engines/image/models/flux/wrappers.py (new file, 279 lines)
@@ -0,0 +1,279 @@
from typing import final

import mlx.core as mx
from mflux.models.flux.model.flux_transformer.common.attention_utils import (
    AttentionUtils,
)
from mflux.models.flux.model.flux_transformer.joint_transformer_block import (
    JointTransformerBlock,
)
from mflux.models.flux.model.flux_transformer.single_transformer_block import (
    SingleTransformerBlock,
)
from pydantic import BaseModel, ConfigDict

from exo.worker.engines.image.models.base import RotaryEmbeddings
from exo.worker.engines.image.pipeline.block_wrapper import (
    JointBlockWrapper,
    SingleBlockWrapper,
)


@final
class FluxModulationParams(BaseModel):
    model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True)

    gate_msa: mx.array
    shift_mlp: mx.array
    scale_mlp: mx.array
    gate_mlp: mx.array


@final
class FluxNormGateState(BaseModel):
    model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True)

    norm_hidden: mx.array
    gate: mx.array


class FluxJointBlockWrapper(JointBlockWrapper[JointTransformerBlock]):
    def __init__(self, block: JointTransformerBlock, text_seq_len: int):
        super().__init__(block, text_seq_len)
        self._num_heads = block.attn.num_heads
        self._head_dim = block.attn.head_dimension

        # Intermediate state stored between _compute_qkv and _apply_output
        self._hidden_mod: FluxModulationParams | None = None
        self._context_mod: FluxModulationParams | None = None

    def _compute_qkv(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: RotaryEmbeddings,
        patch_mode: bool = False,
    ) -> tuple[mx.array, mx.array, mx.array]:
        assert isinstance(rotary_embeddings, mx.array)

        attn = self.block.attn

        (
            norm_hidden,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.block.norm1(
            hidden_states=hidden_states,
            text_embeddings=text_embeddings,
        )
        self._hidden_mod = FluxModulationParams(
            gate_msa=gate_msa,
            shift_mlp=shift_mlp,
            scale_mlp=scale_mlp,
            gate_mlp=gate_mlp,
        )

        (
            norm_encoder,
            c_gate_msa,
            c_shift_mlp,
            c_scale_mlp,
            c_gate_mlp,
        ) = self.block.norm1_context(
            hidden_states=encoder_hidden_states,
            text_embeddings=text_embeddings,
        )
        self._context_mod = FluxModulationParams(
            gate_msa=c_gate_msa,
            shift_mlp=c_shift_mlp,
            scale_mlp=c_scale_mlp,
            gate_mlp=c_gate_mlp,
        )

        img_query, img_key, img_value = AttentionUtils.process_qkv(
            hidden_states=norm_hidden,
            to_q=attn.to_q,
            to_k=attn.to_k,
            to_v=attn.to_v,
            norm_q=attn.norm_q,
            norm_k=attn.norm_k,
            num_heads=self._num_heads,
            head_dim=self._head_dim,
        )

        txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
            hidden_states=norm_encoder,
            to_q=attn.add_q_proj,
            to_k=attn.add_k_proj,
            to_v=attn.add_v_proj,
            norm_q=attn.norm_added_q,
            norm_k=attn.norm_added_k,
            num_heads=self._num_heads,
            head_dim=self._head_dim,
        )

        query = mx.concatenate([txt_query, img_query], axis=2)
        key = mx.concatenate([txt_key, img_key], axis=2)
        value = mx.concatenate([txt_value, img_value], axis=2)

        if patch_mode:
            text_rope = rotary_embeddings[:, :, : self._text_seq_len, ...]
            patch_img_rope = rotary_embeddings[
                :,
                :,
                self._text_seq_len + self._patch_start : self._text_seq_len
                + self._patch_end,
                ...,
            ]
            rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
        else:
            rope = rotary_embeddings

        query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=rope)

        return query, key, value

    def _compute_attention(
        self, query: mx.array, key: mx.array, value: mx.array
    ) -> mx.array:
        batch_size = query.shape[0]
        return AttentionUtils.compute_attention(
            query=query,
            key=key,
            value=value,
            batch_size=batch_size,
            num_heads=self._num_heads,
            head_dim=self._head_dim,
        )

    def _apply_output(
        self,
        attn_out: mx.array,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> tuple[mx.array, mx.array]:
        attn = self.block.attn

        context_attn_output = attn_out[:, : self._text_seq_len, :]
        hidden_attn_output = attn_out[:, self._text_seq_len :, :]

        hidden_attn_output = attn.to_out[0](hidden_attn_output)
        context_attn_output = attn.to_add_out(context_attn_output)

        assert self._hidden_mod is not None
        assert self._context_mod is not None

        hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
            hidden_states=hidden_states,
            attn_output=hidden_attn_output,
            gate_mlp=self._hidden_mod.gate_mlp,
            gate_msa=self._hidden_mod.gate_msa,
            scale_mlp=self._hidden_mod.scale_mlp,
            shift_mlp=self._hidden_mod.shift_mlp,
            norm_layer=self.block.norm2,
            ff_layer=self.block.ff,
        )
        encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
            hidden_states=encoder_hidden_states,
            attn_output=context_attn_output,
            gate_mlp=self._context_mod.gate_mlp,
            gate_msa=self._context_mod.gate_msa,
            scale_mlp=self._context_mod.scale_mlp,
            shift_mlp=self._context_mod.shift_mlp,
            norm_layer=self.block.norm2_context,
            ff_layer=self.block.ff_context,
        )

        return encoder_hidden_states, hidden_states


class FluxSingleBlockWrapper(SingleBlockWrapper[SingleTransformerBlock]):
    """Flux-specific single block wrapper with pipefusion support."""

    def __init__(self, block: SingleTransformerBlock, text_seq_len: int):
        super().__init__(block, text_seq_len)
        self._num_heads = block.attn.num_heads
        self._head_dim = block.attn.head_dimension

        # Intermediate state stored between _compute_qkv and _apply_output
        self._norm_state: FluxNormGateState | None = None

    def _compute_qkv(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: RotaryEmbeddings,
        patch_mode: bool = False,
    ) -> tuple[mx.array, mx.array, mx.array]:
        assert isinstance(rotary_embeddings, mx.array)

        attn = self.block.attn

        norm_hidden, gate = self.block.norm(
            hidden_states=hidden_states,
            text_embeddings=text_embeddings,
        )
        self._norm_state = FluxNormGateState(norm_hidden=norm_hidden, gate=gate)

        query, key, value = AttentionUtils.process_qkv(
            hidden_states=norm_hidden,
            to_q=attn.to_q,
            to_k=attn.to_k,
            to_v=attn.to_v,
            norm_q=attn.norm_q,
            norm_k=attn.norm_k,
            num_heads=self._num_heads,
            head_dim=self._head_dim,
        )

        if patch_mode:
            text_rope = rotary_embeddings[:, :, : self._text_seq_len, ...]
            patch_img_rope = rotary_embeddings[
                :,
                :,
                self._text_seq_len + self._patch_start : self._text_seq_len
                + self._patch_end,
                ...,
            ]
            rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
        else:
            rope = rotary_embeddings

        query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=rope)

        return query, key, value

    def _compute_attention(
        self, query: mx.array, key: mx.array, value: mx.array
    ) -> mx.array:
        batch_size = query.shape[0]
        return AttentionUtils.compute_attention(
            query=query,
            key=key,
            value=value,
            batch_size=batch_size,
            num_heads=self._num_heads,
            head_dim=self._head_dim,
        )

    def _apply_output(
        self,
        attn_out: mx.array,
        hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> mx.array:
        residual = hidden_states

        assert self._norm_state is not None

        output = self.block._apply_feed_forward_and_projection(  # pyright: ignore[reportPrivateUsage]
            norm_hidden_states=self._norm_state.norm_hidden,
            attn_output=attn_out,
            gate=self._norm_state.gate,
        )

        return residual + output
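Both Flux wrappers slice the rotary embeddings the same way in patch mode: keep the full text positions, then take only the image positions covered by the current patch. A standalone sketch with toy values (the `rope` stand-in array is hypothetical, not the real frequency tensor) illustrates the index arithmetic:

# Illustrative sketch, not part of the diff: patch-mode RoPE slicing with toy values.
import mlx.core as mx

text_seq_len, img_seq_len, patch_start, patch_end = 4, 12, 6, 9
rope = mx.arange(text_seq_len + img_seq_len).reshape(1, 1, -1, 1)  # stand-in freqs

text_rope = rope[:, :, :text_seq_len, ...]
patch_rope = rope[:, :, text_seq_len + patch_start : text_seq_len + patch_end, ...]
combined = mx.concatenate([text_rope, patch_rope], axis=2)
print(combined.squeeze())  # [0 1 2 3 10 11 12]: text positions, then patch positions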
13
src/exo/worker/engines/image/models/qwen/__init__.py
Normal file
@@ -0,0 +1,13 @@
from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter
from exo.worker.engines.image.models.qwen.config import (
    QWEN_IMAGE_CONFIG,
    QWEN_IMAGE_EDIT_CONFIG,
)
from exo.worker.engines.image.models.qwen.edit_adapter import QwenEditModelAdapter

__all__ = [
    "QwenModelAdapter",
    "QwenEditModelAdapter",
    "QWEN_IMAGE_CONFIG",
    "QWEN_IMAGE_EDIT_CONFIG",
]
292
src/exo/worker/engines/image/models/qwen/adapter.py
Normal file
@@ -0,0 +1,292 @@
from pathlib import Path
from typing import Any

import mlx.core as mx
from mflux.models.common.config import ModelConfig
from mflux.models.common.config.config import Config
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_text_encoder.qwen_prompt_encoder import (
    QwenPromptEncoder,
)
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage

from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
    ModelAdapter,
    PromptData,
    RotaryEmbeddings,
)
from exo.worker.engines.image.models.qwen.wrappers import QwenJointBlockWrapper
from exo.worker.engines.image.pipeline.block_wrapper import (
    JointBlockWrapper,
    SingleBlockWrapper,
)


class QwenPromptData(PromptData):
    def __init__(
        self,
        prompt_embeds: mx.array,
        prompt_mask: mx.array,
        negative_prompt_embeds: mx.array,
        negative_prompt_mask: mx.array,
    ):
        self._prompt_embeds = prompt_embeds
        self._prompt_mask = prompt_mask
        self._negative_prompt_embeds = negative_prompt_embeds
        self._negative_prompt_mask = negative_prompt_mask

    @property
    def prompt_embeds(self) -> mx.array:
        return self._prompt_embeds

    @property
    def pooled_prompt_embeds(self) -> mx.array:
        return self._prompt_embeds

    @property
    def negative_prompt_embeds(self) -> mx.array:
        return self._negative_prompt_embeds

    @property
    def negative_pooled_prompt_embeds(self) -> mx.array:
        return self._negative_prompt_embeds

    def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array:
        if positive:
            return self._prompt_mask
        else:
            return self._negative_prompt_mask

    @property
    def cond_image_grid(
        self,
    ) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
        return None

    @property
    def conditioning_latents(self) -> mx.array | None:
        return None

    def get_batched_cfg_data(
        self,
    ) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
        """Batch positive and negative embeddings for CFG with batch_size=2.

        Pads the shorter sequence to max length using zeros for embeddings
        and zeros (masked) for the attention mask.

        Returns:
            Tuple of (batched_embeds, batched_mask, None, conditioning_latents)
            - batched_embeds: [2, max_seq, hidden]
            - batched_mask: [2, max_seq]
            - None for pooled (Qwen doesn't use it)
            - conditioning_latents: [2, latent_seq, latent_dim] or None
        """
        pos_embeds = self._prompt_embeds
        neg_embeds = self._negative_prompt_embeds
        pos_mask = self._prompt_mask
        neg_mask = self._negative_prompt_mask

        pos_seq_len = pos_embeds.shape[1]
        neg_seq_len = neg_embeds.shape[1]
        max_seq_len = max(pos_seq_len, neg_seq_len)
        hidden_dim = pos_embeds.shape[2]

        if pos_seq_len < max_seq_len:
            pad_len = max_seq_len - pos_seq_len
            pos_embeds = mx.concatenate(
                [
                    pos_embeds,
                    mx.zeros((1, pad_len, hidden_dim), dtype=pos_embeds.dtype),
                ],
                axis=1,
            )
            pos_mask = mx.concatenate(
                [pos_mask, mx.zeros((1, pad_len), dtype=pos_mask.dtype)],
                axis=1,
            )

        elif neg_seq_len < max_seq_len:
            pad_len = max_seq_len - neg_seq_len
            neg_embeds = mx.concatenate(
                [
                    neg_embeds,
                    mx.zeros((1, pad_len, hidden_dim), dtype=neg_embeds.dtype),
                ],
                axis=1,
            )
            neg_mask = mx.concatenate(
                [neg_mask, mx.zeros((1, pad_len), dtype=neg_mask.dtype)],
                axis=1,
            )

        batched_embeds = mx.concatenate([pos_embeds, neg_embeds], axis=0)
        batched_mask = mx.concatenate([pos_mask, neg_mask], axis=0)

        # TODO(ciaran): currently None but maybe we will deduplicate with edit
        # adapter
        cond_latents = self.conditioning_latents
        if cond_latents is not None:
            cond_latents = mx.concatenate([cond_latents, cond_latents], axis=0)

        return batched_embeds, batched_mask, None, cond_latents


class QwenModelAdapter(ModelAdapter[QwenImage, QwenTransformer]):
    """Adapter for the Qwen-Image model.

    Key differences from Flux:
    - Single text encoder (vs dual T5+CLIP)
    - 60 joint-style blocks, no single blocks
    - 3D RoPE returning ((img_cos, img_sin), (txt_cos, txt_sin))
    - Norm-preserving CFG with negative prompts
    - Uses attention mask for variable-length text
    """

    def __init__(
        self,
        config: ImageModelConfig,
        model_id: str,
        local_path: Path,
        quantize: int | None = None,
    ):
        self._config = config
        self._model = QwenImage(
            model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
            model_path=str(local_path),
            quantize=quantize,
        )
        self._transformer = self._model.transformer

    @property
    def hidden_dim(self) -> int:
        return self._transformer.inner_dim

    @property
    def needs_cfg(self) -> bool:
        gs = self._config.guidance_scale
        return gs is not None and gs > 1.0

    def _get_latent_creator(self) -> type:
        return QwenLatentCreator

    def get_joint_block_wrappers(
        self,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
    ) -> list[JointBlockWrapper[Any]]:
        """Create wrapped joint blocks for Qwen."""
        return [
            QwenJointBlockWrapper(block, text_seq_len, encoder_hidden_states_mask)
            for block in self._transformer.transformer_blocks
        ]

    def get_single_block_wrappers(
        self,
        text_seq_len: int,
    ) -> list[SingleBlockWrapper[Any]]:
        return []

    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
    ):
        self._transformer.transformer_blocks = self._transformer.transformer_blocks[
            start_layer:end_layer
        ]

    def encode_prompt(
        self, prompt: str, negative_prompt: str | None = None
    ) -> QwenPromptData:
        assert isinstance(self.model.prompt_cache, dict)
        assert isinstance(self.model.tokenizers, dict)

        if negative_prompt is None or negative_prompt == "":
            negative_prompt = " "

        prompt_embeds, prompt_mask, neg_embeds, neg_mask = (
            QwenPromptEncoder.encode_prompt(
                prompt=prompt,
                negative_prompt=negative_prompt,
                prompt_cache=self.model.prompt_cache,
                qwen_tokenizer=self.model.tokenizers["qwen"],
                qwen_text_encoder=self.model.text_encoder,
            )
        )

        return QwenPromptData(
            prompt_embeds=prompt_embeds,
            prompt_mask=prompt_mask,
            negative_prompt_embeds=neg_embeds,
            negative_prompt_mask=neg_mask,
        )

    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        embedded_hidden = self._transformer.img_in(hidden_states)
        encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
        embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
        return embedded_hidden, embedded_encoder

    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: Config,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,
    ) -> mx.array:
        # Use hidden_states if provided, otherwise fall back to pooled_prompt_embeds
        # (which for Qwen is the same as prompt_embeds)
        ref_tensor = (
            hidden_states if hidden_states is not None else pooled_prompt_embeds
        )
        if ref_tensor is None:
            raise ValueError(
                "Either hidden_states or pooled_prompt_embeds is required "
                "for Qwen text embeddings"
            )

        timestep = QwenTransformer._compute_timestep(t, runtime_config)  # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
        batch_size = ref_tensor.shape[0]
        timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
        return self._transformer.time_text_embed(timestep, ref_tensor)

    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: Config,
        encoder_hidden_states_mask: mx.array | None = None,
        cond_image_grid: tuple[int, int, int]
        | list[tuple[int, int, int]]
        | None = None,
        kontext_image_ids: mx.array | None = None,
    ) -> RotaryEmbeddings:
        if encoder_hidden_states_mask is None:
            raise ValueError(
                "encoder_hidden_states_mask is required for Qwen RoPE computation"
            )

        return QwenTransformer._compute_rotary_embeddings(  # pyright: ignore[reportPrivateUsage]
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            pos_embed=self._transformer.pos_embed,
            config=runtime_config,
            cond_image_grid=cond_image_grid,
        )

    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        return self._model.compute_guided_noise(
            noise=noise_positive,
            noise_negative=noise_negative,
            guidance=guidance_scale,
        )
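The CFG batching in get_batched_cfg_data simply zero-pads whichever prompt is shorter and stacks the two along the batch axis. A small standalone sketch with toy shapes (not part of the diff):

# Toy illustration of the padding performed by get_batched_cfg_data above.
import mlx.core as mx

pos_embeds = mx.ones((1, 7, 16))   # [1, pos_seq, hidden]
neg_embeds = mx.ones((1, 5, 16))   # [1, neg_seq, hidden]
pad = mx.zeros((1, 2, 16), dtype=neg_embeds.dtype)

neg_padded = mx.concatenate([neg_embeds, pad], axis=1)      # [1, 7, 16]
batched = mx.concatenate([pos_embeds, neg_padded], axis=0)  # positive row 0, negative row 1
print(batched.shape)  # (2, 7, 16)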
29
src/exo/worker/engines/image/models/qwen/config.py
Normal file
@@ -0,0 +1,29 @@
from exo.worker.engines.image.config import (
    BlockType,
    ImageModelConfig,
    TransformerBlockConfig,
)

QWEN_IMAGE_CONFIG = ImageModelConfig(
    model_family="qwen",
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=60, has_separate_text_output=True
        ),
    ),
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,  # ~3 sync steps for medium (25 steps)
    guidance_scale=3.5,  # Set to None or < 1.0 to disable CFG
)

QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
    model_family="qwen-edit",
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=60, has_separate_text_output=True
        ),
    ),
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,
    guidance_scale=3.5,
)
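A back-of-envelope sketch of how num_sync_steps_factor relates the step presets to the number of synchronous warm-up steps before patch-parallel steps begin. The rounding here is an assumption for illustration; the actual calculation lives in the DiffusionRunner.

# Rough illustration only (rounding rule assumed, not taken from the runner).
default_steps = {"low": 10, "medium": 25, "high": 50}
factor = 0.125
for quality, steps in default_steps.items():
    print(quality, steps, round(steps * factor))  # low -> 1, medium -> 3, high -> 6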
434
src/exo/worker/engines/image/models/qwen/edit_adapter.py
Normal file
@@ -0,0 +1,434 @@
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.variants.edit.qwen_edit_util import QwenEditUtil
from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit

from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
    ModelAdapter,
    PromptData,
    RotaryEmbeddings,
)
from exo.worker.engines.image.models.qwen.wrappers import QwenJointBlockWrapper
from exo.worker.engines.image.pipeline.block_wrapper import (
    JointBlockWrapper,
    SingleBlockWrapper,
)


@dataclass(frozen=True)
class EditImageDimensions:
    vl_width: int
    vl_height: int
    vae_width: int
    vae_height: int
    image_paths: list[str]


class QwenEditPromptData(PromptData):
    def __init__(
        self,
        prompt_embeds: mx.array,
        prompt_mask: mx.array,
        negative_prompt_embeds: mx.array,
        negative_prompt_mask: mx.array,
        conditioning_latents: mx.array,
        qwen_image_ids: mx.array,
        cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]],
    ):
        self._prompt_embeds = prompt_embeds
        self._prompt_mask = prompt_mask
        self._negative_prompt_embeds = negative_prompt_embeds
        self._negative_prompt_mask = negative_prompt_mask
        self._conditioning_latents = conditioning_latents
        self._qwen_image_ids = qwen_image_ids
        self._cond_image_grid = cond_image_grid

    @property
    def prompt_embeds(self) -> mx.array:
        return self._prompt_embeds

    @property
    def pooled_prompt_embeds(self) -> mx.array:
        return self._prompt_embeds

    @property
    def negative_prompt_embeds(self) -> mx.array:
        return self._negative_prompt_embeds

    @property
    def negative_pooled_prompt_embeds(self) -> mx.array:
        return self._negative_prompt_embeds

    def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array:
        if positive:
            return self._prompt_mask
        else:
            return self._negative_prompt_mask

    @property
    def cond_image_grid(self) -> tuple[int, int, int] | list[tuple[int, int, int]]:
        return self._cond_image_grid

    @property
    def conditioning_latents(self) -> mx.array:
        return self._conditioning_latents

    @property
    def qwen_image_ids(self) -> mx.array:
        return self._qwen_image_ids

    @property
    def is_edit_mode(self) -> bool:
        return True

    def get_batched_cfg_data(
        self,
    ) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
        """Batch positive and negative embeddings for CFG with batch_size=2.

        Pads the shorter sequence to max length using zeros for embeddings
        and zeros (masked) for the attention mask. Duplicates conditioning
        latents for both positive and negative passes.

        Returns:
            Tuple of (batched_embeds, batched_mask, None, batched_cond_latents)
            - batched_embeds: [2, max_seq, hidden]
            - batched_mask: [2, max_seq]
            - None for pooled (Qwen doesn't use it)
            - batched_cond_latents: [2, latent_seq, latent_dim]
            TODO(ciaran): type this
        """
        pos_embeds = self._prompt_embeds
        neg_embeds = self._negative_prompt_embeds
        pos_mask = self._prompt_mask
        neg_mask = self._negative_prompt_mask

        pos_seq_len = pos_embeds.shape[1]
        neg_seq_len = neg_embeds.shape[1]
        max_seq_len = max(pos_seq_len, neg_seq_len)
        hidden_dim = pos_embeds.shape[2]

        if pos_seq_len < max_seq_len:
            pad_len = max_seq_len - pos_seq_len
            pos_embeds = mx.concatenate(
                [
                    pos_embeds,
                    mx.zeros((1, pad_len, hidden_dim), dtype=pos_embeds.dtype),
                ],
                axis=1,
            )
            pos_mask = mx.concatenate(
                [pos_mask, mx.zeros((1, pad_len), dtype=pos_mask.dtype)],
                axis=1,
            )

        if neg_seq_len < max_seq_len:
            pad_len = max_seq_len - neg_seq_len
            neg_embeds = mx.concatenate(
                [
                    neg_embeds,
                    mx.zeros((1, pad_len, hidden_dim), dtype=neg_embeds.dtype),
                ],
                axis=1,
            )
            neg_mask = mx.concatenate(
                [neg_mask, mx.zeros((1, pad_len), dtype=neg_mask.dtype)],
                axis=1,
            )

        batched_embeds = mx.concatenate([pos_embeds, neg_embeds], axis=0)
        batched_mask = mx.concatenate([pos_mask, neg_mask], axis=0)

        batched_cond_latents = mx.concatenate(
            [self._conditioning_latents, self._conditioning_latents], axis=0
        )

        return batched_embeds, batched_mask, None, batched_cond_latents


class QwenEditModelAdapter(ModelAdapter[QwenImageEdit, QwenTransformer]):
    """Adapter for the Qwen-Image-Edit model.

    Key differences from the standard QwenModelAdapter:
    - Uses QwenImageEdit model with vision-language components
    - Encodes prompts WITH input images via VL tokenizer/encoder
    - Creates conditioning latents from input images
    - Supports image editing with concatenated latents during diffusion
    """

    def __init__(
        self,
        config: ImageModelConfig,
        model_id: str,
        local_path: Path,
        quantize: int | None = None,
    ):
        self._config = config
        self._model = QwenImageEdit(
            quantize=quantize,
            model_path=str(local_path),
        )
        self._transformer = self._model.transformer

        self._edit_dimensions: EditImageDimensions | None = None

    @property
    def config(self) -> ImageModelConfig:
        return self._config

    @property
    def model(self) -> QwenImageEdit:
        return self._model

    @property
    def transformer(self) -> QwenTransformer:
        return self._transformer

    @property
    def hidden_dim(self) -> int:
        return self._transformer.inner_dim

    @property
    def needs_cfg(self) -> bool:
        gs = self._config.guidance_scale
        return gs is not None and gs > 1.0

    def _get_latent_creator(self) -> type[QwenLatentCreator]:
        return QwenLatentCreator

    def get_joint_block_wrappers(
        self,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
    ) -> list[JointBlockWrapper[Any]]:
        """Create wrapped joint blocks for Qwen Edit."""
        return [
            QwenJointBlockWrapper(block, text_seq_len, encoder_hidden_states_mask)
            for block in self._transformer.transformer_blocks
        ]

    def get_single_block_wrappers(
        self,
        text_seq_len: int,
    ) -> list[SingleBlockWrapper[Any]]:
        """Qwen has no single blocks."""
        return []

    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
    ):
        self._transformer.transformer_blocks = self._transformer.transformer_blocks[
            start_layer:end_layer
        ]

    def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
        """Compute and store dimensions from the input image.

        Also stores image_paths for use in encode_prompt().

        Returns:
            (output_width, output_height) for runtime config
        """
        vl_w, vl_h, vae_w, vae_h, out_w, out_h = self._compute_dimensions_from_image(
            image_path
        )
        self._edit_dimensions = EditImageDimensions(
            vl_width=vl_w,
            vl_height=vl_h,
            vae_width=vae_w,
            vae_height=vae_h,
            image_paths=[str(image_path)],
        )
        return out_w, out_h

    def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
        """Create initial noise latents (pure noise for edit mode)."""
        return QwenLatentCreator.create_noise(
            seed=seed,
            height=runtime_config.height,
            width=runtime_config.width,
        )

    def encode_prompt(
        self, prompt: str, negative_prompt: str | None = None
    ) -> QwenEditPromptData:
        dims = self._edit_dimensions
        if dims is None:
            raise RuntimeError(
                "set_image_dimensions() must be called before encode_prompt() "
                "for QwenEditModelAdapter"
            )

        if negative_prompt is None or negative_prompt == "":
            negative_prompt = " "

        # TODO(ciaran): config is untyped and unused, unsure if Config or RuntimeConfig is intended
        (
            prompt_embeds,
            prompt_mask,
            negative_prompt_embeds,
            negative_prompt_mask,
        ) = self._model._encode_prompts_with_images(  # pyright: ignore[reportPrivateUsage]
            prompt,
            negative_prompt,
            dims.image_paths,
            self._config,
            dims.vl_width,
            dims.vl_height,
        )

        (
            conditioning_latents,
            qwen_image_ids,
            cond_h_patches,
            cond_w_patches,
            num_images,
        ) = QwenEditUtil.create_image_conditioning_latents(
            vae=self._model.vae,
            height=dims.vae_height,
            width=dims.vae_width,
            image_paths=dims.image_paths,
            vl_width=dims.vl_width,
            vl_height=dims.vl_height,
        )

        if num_images > 1:
            cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] = [
                (1, cond_h_patches, cond_w_patches) for _ in range(num_images)
            ]
        else:
            cond_image_grid = (1, cond_h_patches, cond_w_patches)

        return QwenEditPromptData(
            prompt_embeds=prompt_embeds,
            prompt_mask=prompt_mask,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_prompt_mask=negative_prompt_mask,
            conditioning_latents=conditioning_latents,
            qwen_image_ids=qwen_image_ids,
            cond_image_grid=cond_image_grid,
        )

    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        embedded_hidden = self._transformer.img_in(hidden_states)
        encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
        embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
        return embedded_hidden, embedded_encoder

    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: Config,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,
    ) -> mx.array:
        ref_tensor = (
            hidden_states if hidden_states is not None else pooled_prompt_embeds
        )
        if ref_tensor is None:
            raise ValueError(
                "Either hidden_states or pooled_prompt_embeds is required "
                "for Qwen text embeddings"
            )

        timestep = QwenTransformer._compute_timestep(t, runtime_config)  # noqa: SLF001 # pyright: ignore[reportPrivateUsage]
        batch_size = ref_tensor.shape[0]
        timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
        return self._transformer.time_text_embed(timestep, ref_tensor)

    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: Config,
        encoder_hidden_states_mask: mx.array | None = None,
        cond_image_grid: tuple[int, int, int]
        | list[tuple[int, int, int]]
        | None = None,
        kontext_image_ids: mx.array | None = None,
    ) -> RotaryEmbeddings:
        if encoder_hidden_states_mask is None:
            raise ValueError(
                "encoder_hidden_states_mask is required for Qwen RoPE computation"
            )

        return QwenTransformer._compute_rotary_embeddings(  # pyright: ignore[reportPrivateUsage]
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            pos_embed=self._transformer.pos_embed,
            config=runtime_config,
            cond_image_grid=cond_image_grid,
        )

    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage

        return QwenImage.compute_guided_noise(
            noise=noise_positive,
            noise_negative=noise_negative,
            guidance=guidance_scale,
        )

    def _compute_dimensions_from_image(
        self, image_path: Path
    ) -> tuple[int, int, int, int, int, int]:
        from mflux.utils.image_util import ImageUtil

        pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
        image_size = pil_image.size

        # Vision-language dimensions (384x384 target area)
        condition_image_size = 384 * 384
        condition_ratio = image_size[0] / image_size[1]
        vl_width = math.sqrt(condition_image_size * condition_ratio)
        vl_height = vl_width / condition_ratio
        vl_width = round(vl_width / 32) * 32
        vl_height = round(vl_height / 32) * 32

        # VAE dimensions (1024x1024 target area)
        vae_image_size = 1024 * 1024
        vae_ratio = image_size[0] / image_size[1]
        vae_width = math.sqrt(vae_image_size * vae_ratio)
        vae_height = vae_width / vae_ratio
        vae_width = round(vae_width / 32) * 32
        vae_height = round(vae_height / 32) * 32

        # Output dimensions from input image aspect ratio
        target_area = 1024 * 1024
        ratio = image_size[0] / image_size[1]
        output_width = math.sqrt(target_area * ratio)
        output_height = output_width / ratio
        output_width = round(output_width / 32) * 32
        output_height = round(output_height / 32) * 32

        # Ensure multiple of 16 for VAE
        vae_scale_factor = 8
        multiple_of = vae_scale_factor * 2
        output_width = output_width // multiple_of * multiple_of
        output_height = output_height // multiple_of * multiple_of

        return (
            int(vl_width),
            int(vl_height),
            int(vae_width),
            int(vae_height),
            int(output_width),
            int(output_height),
        )
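A worked example of the aspect-preserving resize in _compute_dimensions_from_image, run standalone for a hypothetical 1600x900 input (this sketch is not part of the diff):

# Worked example of the dimension math above for a 1600x900 source image.
import math

w, h = 1600, 900
ratio = w / h                               # ~1.778

vl_side = math.sqrt(384 * 384 * ratio)      # ~512.0
print(round(vl_side / 32) * 32, round((vl_side / ratio) / 32) * 32)    # 512 288

vae_side = math.sqrt(1024 * 1024 * ratio)   # ~1365.3
print(round(vae_side / 32) * 32, round((vae_side / ratio) / 32) * 32)  # 1376 768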
204
src/exo/worker/engines/image/models/qwen/wrappers.py
Normal file
@@ -0,0 +1,204 @@
from typing import final

import mlx.core as mx
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
    QwenTransformerBlock,
)
from pydantic import BaseModel, ConfigDict

from exo.worker.engines.image.models.base import RotaryEmbeddings
from exo.worker.engines.image.pipeline.block_wrapper import JointBlockWrapper


@final
class QwenStreamModulation(BaseModel):
    model_config = ConfigDict(frozen=True, strict=True, arbitrary_types_allowed=True)

    mod1: mx.array
    mod2: mx.array
    gate1: mx.array


class QwenJointBlockWrapper(JointBlockWrapper[QwenTransformerBlock]):
    def __init__(
        self,
        block: QwenTransformerBlock,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
    ):
        super().__init__(block, text_seq_len)
        self._encoder_hidden_states_mask = encoder_hidden_states_mask

        self._num_heads = block.attn.num_heads
        self._head_dim = block.attn.head_dim

        # Intermediate state stored between _compute_qkv and _apply_output
        self._img_mod: QwenStreamModulation | None = None
        self._txt_mod: QwenStreamModulation | None = None

    def set_encoder_mask(self, mask: mx.array | None) -> None:
        """Set the encoder hidden states mask for attention."""
        self._encoder_hidden_states_mask = mask

    def _compute_qkv(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: RotaryEmbeddings,
        patch_mode: bool = False,
    ) -> tuple[mx.array, mx.array, mx.array]:
        assert isinstance(rotary_embeddings, tuple)

        batch_size = hidden_states.shape[0]
        img_seq_len = hidden_states.shape[1]
        attn = self.block.attn

        img_mod_params = self.block.img_mod_linear(
            self.block.img_mod_silu(text_embeddings)
        )
        txt_mod_params = self.block.txt_mod_linear(
            self.block.txt_mod_silu(text_embeddings)
        )

        img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
        txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)

        img_normed = self.block.img_norm1(hidden_states)
        img_modulated, img_gate1 = QwenTransformerBlock._modulate(  # pyright: ignore[reportPrivateUsage]
            img_normed, img_mod1
        )
        self._img_mod = QwenStreamModulation(
            mod1=img_mod1, mod2=img_mod2, gate1=img_gate1
        )

        txt_normed = self.block.txt_norm1(encoder_hidden_states)
        txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(  # pyright: ignore[reportPrivateUsage]
            txt_normed, txt_mod1
        )
        self._txt_mod = QwenStreamModulation(
            mod1=txt_mod1, mod2=txt_mod2, gate1=txt_gate1
        )

        img_query = attn.to_q(img_modulated)
        img_key = attn.to_k(img_modulated)
        img_value = attn.to_v(img_modulated)

        txt_query = attn.add_q_proj(txt_modulated)
        txt_key = attn.add_k_proj(txt_modulated)
        txt_value = attn.add_v_proj(txt_modulated)

        img_query = mx.reshape(
            img_query, (batch_size, img_seq_len, self._num_heads, self._head_dim)
        )
        img_key = mx.reshape(
            img_key, (batch_size, img_seq_len, self._num_heads, self._head_dim)
        )
        img_value = mx.reshape(
            img_value, (batch_size, img_seq_len, self._num_heads, self._head_dim)
        )

        txt_query = mx.reshape(
            txt_query,
            (batch_size, self._text_seq_len, self._num_heads, self._head_dim),
        )
        txt_key = mx.reshape(
            txt_key, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)
        )
        txt_value = mx.reshape(
            txt_value, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)
        )

        img_query = attn.norm_q(img_query)
        img_key = attn.norm_k(img_key)
        txt_query = attn.norm_added_q(txt_query)
        txt_key = attn.norm_added_k(txt_key)

        (img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings

        if patch_mode:
            # Slice image RoPE for patch, keep full text RoPE
            img_cos = img_cos[self._patch_start : self._patch_end]
            img_sin = img_sin[self._patch_start : self._patch_end]

        img_query = QwenAttention._apply_rope_qwen(img_query, img_cos, img_sin)  # pyright: ignore[reportPrivateUsage]
        img_key = QwenAttention._apply_rope_qwen(img_key, img_cos, img_sin)  # pyright: ignore[reportPrivateUsage]
        txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)  # pyright: ignore[reportPrivateUsage]
        txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)  # pyright: ignore[reportPrivateUsage]

        img_query = mx.transpose(img_query, (0, 2, 1, 3))
        img_key = mx.transpose(img_key, (0, 2, 1, 3))
        img_value = mx.transpose(img_value, (0, 2, 1, 3))

        txt_query = mx.transpose(txt_query, (0, 2, 1, 3))
        txt_key = mx.transpose(txt_key, (0, 2, 1, 3))
        txt_value = mx.transpose(txt_value, (0, 2, 1, 3))

        query = mx.concatenate([txt_query, img_query], axis=2)
        key = mx.concatenate([txt_key, img_key], axis=2)
        value = mx.concatenate([txt_value, img_value], axis=2)

        return query, key, value

    def _compute_attention(
        self, query: mx.array, key: mx.array, value: mx.array
    ) -> mx.array:
        attn = self.block.attn

        mask = QwenAttention._convert_mask_for_qwen(  # pyright: ignore[reportPrivateUsage]
            mask=self._encoder_hidden_states_mask,
            joint_seq_len=key.shape[2],
            txt_seq_len=self._text_seq_len,
        )

        query_bshd = mx.transpose(query, (0, 2, 1, 3))
        key_bshd = mx.transpose(key, (0, 2, 1, 3))
        value_bshd = mx.transpose(value, (0, 2, 1, 3))

        return attn._compute_attention_qwen(  # pyright: ignore[reportPrivateUsage]
            query=query_bshd,
            key=key_bshd,
            value=value_bshd,
            mask=mask,
            block_idx=None,
        )

    def _apply_output(
        self,
        attn_out: mx.array,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> tuple[mx.array, mx.array]:
        attn = self.block.attn

        assert self._img_mod is not None
        assert self._txt_mod is not None

        txt_attn_output = attn_out[:, : self._text_seq_len, :]
        img_attn_output = attn_out[:, self._text_seq_len :, :]

        img_attn_output = attn.attn_to_out[0](img_attn_output)
        txt_attn_output = attn.to_add_out(txt_attn_output)

        hidden_states = hidden_states + self._img_mod.gate1 * img_attn_output
        encoder_hidden_states = (
            encoder_hidden_states + self._txt_mod.gate1 * txt_attn_output
        )

        img_normed2 = self.block.img_norm2(hidden_states)
        img_modulated2, img_gate2 = QwenTransformerBlock._modulate(  # pyright: ignore[reportPrivateUsage]
            img_normed2, self._img_mod.mod2
        )
        img_mlp_output = self.block.img_ff(img_modulated2)
        hidden_states = hidden_states + img_gate2 * img_mlp_output

        txt_normed2 = self.block.txt_norm2(encoder_hidden_states)
        txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(  # pyright: ignore[reportPrivateUsage]
            txt_normed2, self._txt_mod.mod2
        )
        txt_mlp_output = self.block.txt_ff(txt_modulated2)
        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output

        return encoder_hidden_states, hidden_states
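The reshape/transpose sequence in _compute_qkv moves each projection from [batch, seq, heads*dim] into the [batch, heads, seq, dim] layout that the text+image concatenation and the image KV cache expect. A toy-shape sketch, not part of the diff:

# Layout sketch with toy shapes.
import mlx.core as mx

batch, seq, heads, dim = 1, 6, 2, 4
x = mx.zeros((batch, seq, heads * dim))
x = mx.reshape(x, (batch, seq, heads, dim))
x = mx.transpose(x, (0, 2, 1, 3))
print(x.shape)  # (1, 2, 6, 4) -> [batch, heads, seq, dim]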
15
src/exo/worker/engines/image/pipeline/__init__.py
Normal file
@@ -0,0 +1,15 @@
from exo.worker.engines.image.pipeline.block_wrapper import (
    BlockWrapperMode,
    JointBlockWrapper,
    SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
from exo.worker.engines.image.pipeline.runner import DiffusionRunner

__all__ = [
    "BlockWrapperMode",
    "DiffusionRunner",
    "ImagePatchKVCache",
    "JointBlockWrapper",
    "SingleBlockWrapper",
]
303
src/exo/worker/engines/image/pipeline/block_wrapper.py
Normal file
@@ -0,0 +1,303 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Generic, Self, TypeVar

import mlx.core as mx

from exo.worker.engines.image.models.base import RotaryEmbeddings
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache

BlockT = TypeVar("BlockT")


class BlockWrapperMode(Enum):
    CACHING = "caching"  # Sync mode: compute full attention, populate cache
    PATCHED = "patched"  # Async mode: compute patch attention, use cached KV


class BlockWrapperMixin:
    """Common cache management logic for block wrappers.

    Including:
    - KV cache creation and management
    - Mode
    - Patch range setting
    """

    _text_seq_len: int
    _kv_cache: ImagePatchKVCache | None
    _mode: BlockWrapperMode
    _patch_start: int
    _patch_end: int

    def _init_cache_state(self, text_seq_len: int) -> None:
        self._text_seq_len = text_seq_len
        self._kv_cache = None
        self._mode = BlockWrapperMode.CACHING
        self._patch_start = 0
        self._patch_end = 0

    def set_patch(
        self,
        mode: BlockWrapperMode,
        patch_start: int = 0,
        patch_end: int = 0,
    ) -> Self:
        """Set mode and patch range.

        Args:
            mode: CACHING (full attention) or PATCHED (use cached KV)
            patch_start: Start token index within image (for PATCHED mode)
            patch_end: End token index within image (for PATCHED mode)

        Returns:
            Self for method chaining
        """
        self._mode = mode
        self._patch_start = patch_start
        self._patch_end = patch_end
        return self

    def set_text_seq_len(self, text_seq_len: int) -> None:
        self._text_seq_len = text_seq_len

    def _get_active_cache(self) -> ImagePatchKVCache | None:
        return self._kv_cache

    def _ensure_cache(self, img_key: mx.array) -> None:
        if self._kv_cache is None:
            batch, num_heads, img_seq_len, head_dim = img_key.shape
            self._kv_cache = ImagePatchKVCache(
                batch_size=batch,
                num_heads=num_heads,
                image_seq_len=img_seq_len,
                head_dim=head_dim,
            )

    def _cache_full_image_kv(self, img_key: mx.array, img_value: mx.array) -> None:
        self._ensure_cache(img_key)
        cache = self._get_active_cache()
        assert cache is not None
        cache.update_image_patch(0, img_key.shape[2], img_key, img_value)

    def _cache_patch_kv(self, img_key: mx.array, img_value: mx.array) -> None:
        cache = self._get_active_cache()
        assert cache is not None
        cache.update_image_patch(self._patch_start, self._patch_end, img_key, img_value)

    def _get_full_kv(
        self, text_key: mx.array, text_value: mx.array
    ) -> tuple[mx.array, mx.array]:
        cache = self._get_active_cache()
        assert cache is not None
        return cache.get_full_kv(text_key, text_value)

    def reset_cache(self) -> None:
        self._kv_cache = None


class JointBlockWrapper(BlockWrapperMixin, ABC, Generic[BlockT]):
    """Base class for joint transformer block wrappers with pipefusion support.

    The wrapper:
    - Owns its KV cache (created lazily on first CACHING forward)
    - Controls the forward pass flow (CACHING vs PATCHED mode)
    - Handles patch slicing and cache operations
    """

    block: BlockT

    def __init__(self, block: BlockT, text_seq_len: int):
        self.block = block
        self._init_cache_state(text_seq_len)

    def set_encoder_mask(self, mask: mx.array | None) -> None:  # noqa: B027
        """Set the encoder hidden states mask for attention.

        Override in subclasses that use attention masks.
        Default is a no-op for models that don't use masks.
        """
        del mask  # Unused in base class

    def __call__(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: RotaryEmbeddings,
    ) -> tuple[mx.array, mx.array]:
        if self._mode == BlockWrapperMode.CACHING:
            return self._forward_caching(
                hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
            )
        return self._forward_patched(
            hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
        )

    def _forward_caching(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: RotaryEmbeddings,
    ) -> tuple[mx.array, mx.array]:
        """CACHING mode: full attention, store image K/V in cache."""
        query, key, value = self._compute_qkv(
            hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
        )

        img_key = key[:, :, self._text_seq_len :, :]
        img_value = value[:, :, self._text_seq_len :, :]
        self._cache_full_image_kv(img_key, img_value)

        attn_out = self._compute_attention(query, key, value)

        return self._apply_output(
            attn_out, hidden_states, encoder_hidden_states, text_embeddings
        )

    def _forward_patched(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: RotaryEmbeddings,
    ) -> tuple[mx.array, mx.array]:
        # hidden_states is already the patch (provided by runner)
        patch_hidden = hidden_states

        query, key, value = self._compute_qkv(
            patch_hidden,
            encoder_hidden_states,
            text_embeddings,
            rotary_embeddings,
            patch_mode=True,
        )

        text_key = key[:, :, : self._text_seq_len, :]
        text_value = value[:, :, : self._text_seq_len, :]
        img_key = key[:, :, self._text_seq_len :, :]
        img_value = value[:, :, self._text_seq_len :, :]

        self._cache_patch_kv(img_key, img_value)
        full_key, full_value = self._get_full_kv(text_key, text_value)

        attn_out = self._compute_attention(query, full_key, full_value)

        return self._apply_output(
            attn_out, patch_hidden, encoder_hidden_states, text_embeddings
        )

    @abstractmethod
    def _compute_qkv(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: RotaryEmbeddings,
        patch_mode: bool = False,
    ) -> tuple[mx.array, mx.array, mx.array]: ...

    @abstractmethod
    def _compute_attention(
        self, query: mx.array, key: mx.array, value: mx.array
    ) -> mx.array: ...

    @abstractmethod
    def _apply_output(
        self,
        attn_out: mx.array,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> tuple[mx.array, mx.array]: ...


class SingleBlockWrapper(BlockWrapperMixin, ABC, Generic[BlockT]):
    """Base class for single-stream transformer block wrappers.

    Similar to JointBlockWrapper but for blocks that operate on a single
    concatenated [text, image] stream rather than separate streams.
    """

    block: BlockT

    def __init__(self, block: BlockT, text_seq_len: int):
        self.block = block
        self._init_cache_state(text_seq_len)

    def __call__(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: RotaryEmbeddings,
    ) -> mx.array:
        if self._mode == BlockWrapperMode.CACHING:
            return self._forward_caching(
                hidden_states, text_embeddings, rotary_embeddings
            )
        return self._forward_patched(hidden_states, text_embeddings, rotary_embeddings)

    def _forward_caching(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: RotaryEmbeddings,
    ) -> mx.array:
        """CACHING mode: full attention, store image K/V in cache."""
        query, key, value = self._compute_qkv(
            hidden_states, text_embeddings, rotary_embeddings
        )

        img_key = key[:, :, self._text_seq_len :, :]
        img_value = value[:, :, self._text_seq_len :, :]
        self._cache_full_image_kv(img_key, img_value)

        attn_out = self._compute_attention(query, key, value)

        return self._apply_output(attn_out, hidden_states, text_embeddings)

    def _forward_patched(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: RotaryEmbeddings,
    ) -> mx.array:
        """PATCHED mode: compute patch Q/K/V, use cached image K/V for attention."""
        query, key, value = self._compute_qkv(
            hidden_states, text_embeddings, rotary_embeddings, patch_mode=True
        )

        text_key = key[:, :, : self._text_seq_len, :]
        text_value = value[:, :, : self._text_seq_len, :]
        img_key = key[:, :, self._text_seq_len :, :]
        img_value = value[:, :, self._text_seq_len :, :]

        self._cache_patch_kv(img_key, img_value)
        full_key, full_value = self._get_full_kv(text_key, text_value)

        attn_out = self._compute_attention(query, full_key, full_value)

        return self._apply_output(attn_out, hidden_states, text_embeddings)

    @abstractmethod
    def _compute_qkv(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: RotaryEmbeddings,
        patch_mode: bool = False,
    ) -> tuple[mx.array, mx.array, mx.array]: ...

    @abstractmethod
    def _compute_attention(
        self, query: mx.array, key: mx.array, value: mx.array
    ) -> mx.array: ...

    @abstractmethod
    def _apply_output(
        self,
        attn_out: mx.array,
        hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> mx.array: ...
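A sketch of how a runner is expected to drive a wrapper between the two modes. The wrapper construction and tensor arguments are elided and the helper below is hypothetical; the real flow lives in the DiffusionRunner further down.

# Hypothetical driving code, for illustration only.
from exo.worker.engines.image.pipeline.block_wrapper import BlockWrapperMode

def denoise_block(wrapper, hidden, encoder_hidden, text_emb, rope, patch_range, sync):
    if sync:
        # Sync step: full image, populates the block's image KV cache
        wrapper.set_patch(BlockWrapperMode.CACHING)
        return wrapper(hidden, encoder_hidden, text_emb, rope)
    # Async step: only one patch is fresh; stale cached KV stands in for the rest
    start, end = patch_range
    wrapper.set_patch(BlockWrapperMode.PATCHED, patch_start=start, patch_end=end)
    return wrapper(hidden, encoder_hidden, text_emb, rope)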
72
src/exo/worker/engines/image/pipeline/kv_cache.py
Normal file
@@ -0,0 +1,72 @@
import mlx.core as mx


class ImagePatchKVCache:
    """KV cache that stores only IMAGE K/V with patch-level updates.

    Only caches image K/V since:
    - Text K/V is always computed fresh (same for all patches)
    - Only the image portion needs stale/fresh cache management across patches
    """

    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        image_seq_len: int,
        head_dim: int,
        dtype: mx.Dtype = mx.float32,
    ):
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.image_seq_len = image_seq_len
        self.head_dim = head_dim
        self._dtype = dtype

        self.key_cache = mx.zeros(
            (batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
        )
        self.value_cache = mx.zeros(
            (batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
        )

    def update_image_patch(
        self, patch_start: int, patch_end: int, key: mx.array, value: mx.array
    ) -> None:
        """Update cache with fresh K/V for an image patch slice.

        Args:
            patch_start: Start token index within image portion (0-indexed)
            patch_end: End token index within image portion
            key: Fresh key tensor [batch, heads, patch_seq_len, head_dim]
            value: Fresh value tensor [batch, heads, patch_seq_len, head_dim]
        """
        self.key_cache[:, :, patch_start:patch_end, :] = key
        self.value_cache[:, :, patch_start:patch_end, :] = value

    def get_full_kv(
        self, text_key: mx.array, text_value: mx.array
    ) -> tuple[mx.array, mx.array]:
        """Return full K/V by concatenating fresh text K/V with cached image K/V.

        Args:
            text_key: Fresh text key tensor [batch, heads, text_seq_len, head_dim]
            text_value: Fresh text value tensor [batch, heads, text_seq_len, head_dim]

        Returns:
            Tuple of (full_key, full_value) with shape [batch, heads, text+image, head_dim]
        """
        full_key = mx.concatenate([text_key, self.key_cache], axis=2)
        full_value = mx.concatenate([text_value, self.value_cache], axis=2)
        return full_key, full_value

    def reset(self) -> None:
        """Reset cache to zeros."""
        self.key_cache = mx.zeros(
            (self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
            dtype=self._dtype,
        )
        self.value_cache = mx.zeros(
            (self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
            dtype=self._dtype,
        )
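A minimal usage sketch of ImagePatchKVCache with toy shapes: cache the whole image once (sync step), then overwrite a single patch and rebuild the joint [text, image] K/V (async step). Not part of the diff.

# Toy-shape sketch of the cache lifecycle.
import mlx.core as mx
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache

cache = ImagePatchKVCache(batch_size=1, num_heads=2, image_seq_len=8, head_dim=4)

img_k = mx.ones((1, 2, 8, 4))
img_v = mx.ones((1, 2, 8, 4))
cache.update_image_patch(0, 8, img_k, img_v)        # sync step: whole image cached

fresh_k = mx.full((1, 2, 3, 4), 2.0)
fresh_v = mx.full((1, 2, 3, 4), 2.0)
cache.update_image_patch(2, 5, fresh_k, fresh_v)    # async step: tokens 2..5 refreshed

txt_k = mx.zeros((1, 2, 5, 4))
txt_v = mx.zeros((1, 2, 5, 4))
full_k, full_v = cache.get_full_kv(txt_k, txt_v)
print(full_k.shape)  # (1, 2, 13, 4): 5 text tokens + 8 cached image tokens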
968
src/exo/worker/engines/image/pipeline/runner.py
Normal file
@@ -0,0 +1,968 @@
from math import ceil
from typing import Any, Optional

import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.utils.exceptions import StopImageGenerationException
from tqdm import tqdm

from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
    ModelAdapter,
    PromptData,
    RotaryEmbeddings,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
    BlockWrapperMode,
    JointBlockWrapper,
    SingleBlockWrapper,
)


def calculate_patch_heights(latent_height: int, num_patches: int):
    patch_height = ceil(latent_height / num_patches)

    actual_num_patches = ceil(latent_height / patch_height)
    patch_heights = [patch_height] * (actual_num_patches - 1)

    last_height = latent_height - patch_height * (actual_num_patches - 1)
    patch_heights.append(last_height)

    return patch_heights, actual_num_patches


def calculate_token_indices(patch_heights: list[int], latent_width: int):
    tokens_per_row = latent_width

    token_ranges = []
    cumulative_height = 0

    for h in patch_heights:
        start_token = tokens_per_row * cumulative_height
        end_token = tokens_per_row * (cumulative_height + h)

        token_ranges.append((start_token, end_token))
        cumulative_height += h

    return token_ranges


class DiffusionRunner:
    """Orchestrates the diffusion loop for image generation.

    In distributed mode, it implements PipeFusion with:
    - Sync pipeline for initial timesteps (full image, all devices in lockstep)
    - Async pipeline for later timesteps (patches processed independently)
    """

    def __init__(
        self,
        config: ImageModelConfig,
        adapter: ModelAdapter[Any, Any],
        group: Optional[mx.distributed.Group],
        shard_metadata: PipelineShardMetadata,
        num_patches: Optional[int] = None,
    ):
        self.config = config
        self.adapter = adapter
        self.group = group

        if group is None:
            self.rank = 0
            self.world_size = 1
            self.next_rank = 0
            self.prev_rank = 0
            self.start_layer = 0
            self.end_layer = config.total_blocks
        else:
            self.rank = shard_metadata.device_rank
            self.world_size = shard_metadata.world_size
            self.next_rank = (self.rank + 1) % self.world_size
            self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
            self.start_layer = shard_metadata.start_layer
            self.end_layer = shard_metadata.end_layer

        self.num_patches = num_patches if num_patches else max(1, self.world_size)

        self.total_joint = config.joint_block_count
        self.total_single = config.single_block_count
        self.total_layers = config.total_blocks

        self._guidance_override: float | None = None

        self._compute_assigned_blocks()

    def _compute_assigned_blocks(self) -> None:
        """Determine which joint/single blocks this stage owns."""
        start = self.start_layer
        end = self.end_layer

        if end <= self.total_joint:
            self.joint_start = start
            self.joint_end = end
            self.single_start = 0
            self.single_end = 0
        elif start >= self.total_joint:
            self.joint_start = 0
            self.joint_end = 0
            self.single_start = start - self.total_joint
            self.single_end = end - self.total_joint
        else:
            self.joint_start = start
            self.joint_end = self.total_joint
            self.single_start = 0
            self.single_end = end - self.total_joint

        self.has_joint_blocks = self.joint_end > self.joint_start
        self.has_single_blocks = self.single_end > self.single_start

        self.owns_concat_stage = self.has_joint_blocks and (
            self.has_single_blocks or self.end_layer == self.total_joint
        )

        # Wrappers created lazily on first forward (need text_seq_len)
        self.joint_block_wrappers: list[JointBlockWrapper[Any]] | None = None
        self.single_block_wrappers: list[SingleBlockWrapper[Any]] | None = None
        self._wrappers_initialized = False
        self._current_text_seq_len: int | None = None

    @property
    def is_first_stage(self) -> bool:
        return self.rank == 0

    @property
    def is_last_stage(self) -> bool:
        return self.rank == self.world_size - 1

    @property
    def is_distributed(self) -> bool:
        return self.group is not None

    def _get_effective_guidance_scale(self) -> float | None:
        if self._guidance_override is not None:
            return self._guidance_override
        return self.config.guidance_scale

    def _ensure_wrappers(
        self,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
    ) -> None:
        """Lazily create block wrappers on the first forward pass.

        Wrappers need text_seq_len, which is only known after prompt encoding.
        Re-initializes if text_seq_len changes (e.g., warmup vs real generation).
        """
        if self._wrappers_initialized and self._current_text_seq_len == text_seq_len:
            return

        self.joint_block_wrappers = self.adapter.get_joint_block_wrappers(
            text_seq_len=text_seq_len,
            encoder_hidden_states_mask=encoder_hidden_states_mask,
        )
        self.single_block_wrappers = self.adapter.get_single_block_wrappers(
            text_seq_len=text_seq_len,
        )
        self._wrappers_initialized = True
        self._current_text_seq_len = text_seq_len

    def _reset_all_caches(self) -> None:
        """Reset KV caches on all wrappers for a new generation."""
        if self.joint_block_wrappers:
            for wrapper in self.joint_block_wrappers:
                wrapper.reset_cache()
        if self.single_block_wrappers:
            for wrapper in self.single_block_wrappers:
                wrapper.reset_cache()

    def _set_text_seq_len(self, text_seq_len: int) -> None:
        if self.joint_block_wrappers:
            for wrapper in self.joint_block_wrappers:
                wrapper.set_text_seq_len(text_seq_len)
        if self.single_block_wrappers:
            for wrapper in self.single_block_wrappers:
                wrapper.set_text_seq_len(text_seq_len)
|
||||
|
||||
def _calculate_capture_steps(
|
||||
self,
|
||||
partial_images: int,
|
||||
init_time_step: int,
|
||||
num_inference_steps: int,
|
||||
) -> set[int]:
|
||||
"""Calculate which timesteps should produce partial images.
|
||||
|
||||
Places the first partial after step 1 for fast initial feedback,
|
||||
then evenly spaces remaining partials with equal gaps between them
|
||||
and from the last partial to the final image.
|
||||
|
||||
Args:
|
||||
partial_images: Number of partial images to capture
|
||||
init_time_step: Starting timestep (for img2img this may not be 0)
|
||||
num_inference_steps: Total inference steps
|
||||
|
||||
Returns:
|
||||
Set of timestep indices to capture
|
||||
"""
|
||||
if partial_images <= 0:
|
||||
return set()
|
||||
|
||||
total_steps = num_inference_steps - init_time_step
|
||||
if total_steps <= 1:
|
||||
return set()
|
||||
|
||||
if partial_images >= total_steps - 1:
|
||||
return set(range(init_time_step, num_inference_steps - 1))
|
||||
|
||||
capture_steps: set[int] = set()
|
||||
|
||||
first_capture = init_time_step + 1
|
||||
capture_steps.add(first_capture)
|
||||
|
||||
if partial_images == 1:
|
||||
return capture_steps
|
||||
|
||||
final_step = num_inference_steps - 1
|
||||
remaining_range = final_step - first_capture
|
||||
|
||||
for i in range(1, partial_images):
|
||||
step_idx = first_capture + int(i * remaining_range / partial_images)
|
||||
capture_steps.add(step_idx)
|
||||
|
||||
return capture_steps
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
runtime_config: Config,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
partial_images: int = 0,
|
||||
guidance_override: float | None = None,
|
||||
negative_prompt: str | None = None,
|
||||
num_sync_steps: int = 1,
|
||||
):
|
||||
"""Primary entry point for image generation.
|
||||
|
||||
Orchestrates the full generation flow:
|
||||
1. Create runtime config
|
||||
2. Create initial latents
|
||||
3. Encode prompt
|
||||
4. Run diffusion loop (yielding partials if requested)
|
||||
5. Decode to image
|
||||
|
||||
Args:
|
||||
settings: Generation config (steps, height, width)
|
||||
prompt: Text prompt
|
||||
seed: Random seed
|
||||
partial_images: Number of intermediate images to yield (0 for none)
|
||||
guidance_override: Optional override for guidance scale (CFG)
|
||||
|
||||
Yields:
|
||||
Partial images as (GeneratedImage, partial_index, total_partials) tuples
|
||||
Final GeneratedImage
|
||||
"""
|
||||
self._guidance_override = guidance_override
|
||||
latents = self.adapter.create_latents(seed, runtime_config)
|
||||
prompt_data = self.adapter.encode_prompt(prompt, negative_prompt)
|
||||
|
||||
capture_steps = self._calculate_capture_steps(
|
||||
partial_images=partial_images,
|
||||
init_time_step=runtime_config.init_time_step,
|
||||
num_inference_steps=runtime_config.num_inference_steps,
|
||||
)
|
||||
|
||||
diffusion_gen = self._run_diffusion_loop(
|
||||
latents=latents,
|
||||
prompt_data=prompt_data,
|
||||
runtime_config=runtime_config,
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
capture_steps=capture_steps,
|
||||
num_sync_steps=num_sync_steps,
|
||||
)
|
||||
|
||||
partial_index = 0
|
||||
total_partials = len(capture_steps)
|
||||
|
||||
if capture_steps:
|
||||
try:
|
||||
while True:
|
||||
partial_latents, _step = next(diffusion_gen)
|
||||
if self.is_last_stage:
|
||||
partial_image = self.adapter.decode_latents(
|
||||
partial_latents, runtime_config, seed, prompt
|
||||
)
|
||||
yield (partial_image, partial_index, total_partials)
|
||||
partial_index += 1
|
||||
except StopIteration as e:
|
||||
latents = e.value
|
||||
else:
|
||||
try:
|
||||
while True:
|
||||
next(diffusion_gen)
|
||||
except StopIteration as e:
|
||||
latents = e.value
|
||||
|
||||
if self.is_last_stage:
|
||||
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt)
|
||||
|
||||
def _run_diffusion_loop(
|
||||
self,
|
||||
latents: mx.array,
|
||||
prompt_data: PromptData,
|
||||
runtime_config: Config,
|
||||
seed: int,
|
||||
prompt: str,
|
||||
num_sync_steps: int,
|
||||
capture_steps: set[int] | None = None,
|
||||
):
|
||||
if capture_steps is None:
|
||||
capture_steps = set()
|
||||
|
||||
self._reset_all_caches()
|
||||
|
||||
time_steps = tqdm(range(runtime_config.num_inference_steps))
|
||||
|
||||
ctx = self.adapter.model.callbacks.start(
|
||||
seed=seed, prompt=prompt, config=runtime_config
|
||||
)
|
||||
|
||||
ctx.before_loop(
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
for t in time_steps:
|
||||
try:
|
||||
latents = self._diffusion_step(
|
||||
t=t,
|
||||
config=runtime_config,
|
||||
latents=latents,
|
||||
prompt_data=prompt_data,
|
||||
num_sync_steps=num_sync_steps,
|
||||
)
|
||||
|
||||
ctx.in_loop(
|
||||
t=t,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
mx.eval(latents)
|
||||
|
||||
if t in capture_steps and self.is_last_stage:
|
||||
yield (latents, t)
|
||||
|
||||
except KeyboardInterrupt: # noqa: PERF203
|
||||
ctx.interruption(t=t, latents=latents)
|
||||
raise StopImageGenerationException(
|
||||
f"Stopping image generation at step {t + 1}/{len(time_steps)}"
|
||||
) from None
|
||||
|
||||
ctx.after_loop(latents=latents)
|
||||
|
||||
return latents
|
||||
|
||||
def _forward_pass(
|
||||
self,
|
||||
latents: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
pooled_prompt_embeds: mx.array,
|
||||
t: int,
|
||||
config: Config,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
cond_image_grid: tuple[int, int, int]
|
||||
| list[tuple[int, int, int]]
|
||||
| None = None,
|
||||
conditioning_latents: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
"""Run a single forward pass through the transformer.
|
||||
Args:
|
||||
latents: Input latents (already scaled by caller)
|
||||
prompt_embeds: Text embeddings
|
||||
pooled_prompt_embeds: Pooled text embeddings (Flux) or placeholder (Qwen)
|
||||
t: Current timestep
|
||||
config: Runtime configuration
|
||||
encoder_hidden_states_mask: Attention mask for text (Qwen)
|
||||
cond_image_grid: Conditioning image grid dimensions (Qwen edit)
|
||||
conditioning_latents: Conditioning latents for edit mode
|
||||
|
||||
Returns:
|
||||
Noise prediction tensor
|
||||
"""
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
|
||||
self._ensure_wrappers(text_seq_len, encoder_hidden_states_mask)
|
||||
|
||||
if self.joint_block_wrappers and encoder_hidden_states_mask is not None:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(encoder_hidden_states_mask)
|
||||
|
||||
scaled_latents = config.scheduler.scale_model_input(latents, t)
|
||||
|
||||
# For edit mode: concatenate with conditioning latents
|
||||
original_latent_tokens = scaled_latents.shape[1]
|
||||
if conditioning_latents is not None:
|
||||
scaled_latents = mx.concatenate(
|
||||
[scaled_latents, conditioning_latents], axis=1
|
||||
)
|
||||
|
||||
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
scaled_latents, prompt_embeds
|
||||
)
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_prompt_embeds, hidden_states=hidden_states
|
||||
)
|
||||
rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
config,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
)
|
||||
|
||||
assert self.joint_block_wrappers is not None
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
encoder_hidden_states, hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
)
|
||||
|
||||
if self.joint_block_wrappers:
|
||||
hidden_states = self.adapter.merge_streams(
|
||||
hidden_states, encoder_hidden_states
|
||||
)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
for wrapper in self.single_block_wrappers:
|
||||
hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
)
|
||||
|
||||
# Extract image portion
|
||||
hidden_states = hidden_states[:, text_seq_len:, ...]
|
||||
|
||||
# For edit mode: extract only the generated portion (exclude conditioning latents)
|
||||
if conditioning_latents is not None:
|
||||
hidden_states = hidden_states[:, :original_latent_tokens, ...]
|
||||
|
||||
return self.adapter.final_projection(hidden_states, text_embeddings)
|
||||
|
||||
def _diffusion_step(
|
||||
self,
|
||||
t: int,
|
||||
config: Config,
|
||||
latents: mx.array,
|
||||
prompt_data: PromptData,
|
||||
num_sync_steps: int,
|
||||
) -> mx.array:
|
||||
if self.group is None:
|
||||
return self._single_node_step(t, config, latents, prompt_data)
|
||||
elif t < config.init_time_step + num_sync_steps:
|
||||
return self._sync_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
)
|
||||
else:
|
||||
return self._async_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
is_first_async_step=t == config.init_time_step + num_sync_steps,
|
||||
)
|
||||
|
||||
def _single_node_step(
|
||||
self,
|
||||
t: int,
|
||||
config: Config,
|
||||
latents: mx.array,
|
||||
prompt_data: PromptData,
|
||||
) -> mx.array:
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
step_latents = mx.concatenate([latents, latents], axis=0)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
cond_latents = prompt_data.conditioning_latents
|
||||
step_latents = latents
|
||||
|
||||
noise = self._forward_pass(
|
||||
step_latents,
|
||||
prompt_embeds,
|
||||
pooled_embeds,
|
||||
t=t,
|
||||
config=config,
|
||||
encoder_hidden_states_mask=encoder_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
conditioning_latents=cond_latents,
|
||||
)
|
||||
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale=guidance_scale
|
||||
)
|
||||
|
||||
return config.scheduler.step(noise=noise, timestep=t, latents=latents)
|
||||
|
||||
def _create_patches(
|
||||
self,
|
||||
latents: mx.array,
|
||||
config: Config,
|
||||
) -> tuple[list[mx.array], list[tuple[int, int]]]:
|
||||
latent_height = config.height // 16
|
||||
latent_width = config.width // 16
|
||||
|
||||
patch_heights, _ = calculate_patch_heights(latent_height, self.num_patches)
|
||||
token_indices = calculate_token_indices(patch_heights, latent_width)
|
||||
|
||||
patch_latents = [latents[:, start:end, :] for start, end in token_indices]
|
||||
|
||||
return patch_latents, token_indices
|
||||
|
||||
def _run_sync_pass(
|
||||
self,
|
||||
t: int,
|
||||
config: Config,
|
||||
scaled_hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
pooled_prompt_embeds: mx.array,
|
||||
encoder_hidden_states_mask: mx.array | None,
|
||||
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] | None,
|
||||
kontext_image_ids: mx.array | None,
|
||||
num_img_tokens: int,
|
||||
original_latent_tokens: int,
|
||||
conditioning_latents: mx.array | None,
|
||||
) -> mx.array | None:
|
||||
hidden_states = scaled_hidden_states
|
||||
batch_size = hidden_states.shape[0]
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
hidden_dim = self.adapter.hidden_dim
|
||||
dtype = scaled_hidden_states.dtype
|
||||
|
||||
self._set_text_seq_len(text_seq_len)
|
||||
|
||||
if self.joint_block_wrappers:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(encoder_hidden_states_mask)
|
||||
|
||||
encoder_hidden_states: mx.array | None = None
|
||||
if self.is_first_stage:
|
||||
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
hidden_states, prompt_embeds
|
||||
)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_prompt_embeds
|
||||
)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
config,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
assert self.joint_block_wrappers is not None
|
||||
assert encoder_hidden_states is not None
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
encoder_hidden_states, hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if self.owns_concat_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
concatenated = self.adapter.merge_streams(
|
||||
hidden_states, encoder_hidden_states
|
||||
)
|
||||
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
hidden_states = concatenated
|
||||
else:
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
hidden_states = hidden_states[:, text_seq_len:, ...]
|
||||
|
||||
if conditioning_latents is not None:
|
||||
hidden_states = hidden_states[:, :original_latent_tokens, ...]
|
||||
|
||||
if self.is_last_stage:
|
||||
return self.adapter.final_projection(hidden_states, text_embeddings)
|
||||
|
||||
return None
|
||||
|
||||
def _sync_pipeline_step(
|
||||
self,
|
||||
t: int,
|
||||
config: Config,
|
||||
hidden_states: mx.array,
|
||||
prompt_data: PromptData,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
prev_latents = hidden_states
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
|
||||
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t)
|
||||
original_latent_tokens = scaled_hidden_states.shape[1]
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
step_latents = mx.concatenate(
|
||||
[scaled_hidden_states, scaled_hidden_states], axis=0
|
||||
)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
cond_latents = prompt_data.conditioning_latents
|
||||
step_latents = scaled_hidden_states
|
||||
|
||||
if cond_latents is not None:
|
||||
num_img_tokens = original_latent_tokens + cond_latents.shape[1]
|
||||
else:
|
||||
num_img_tokens = original_latent_tokens
|
||||
|
||||
if self.is_first_stage and cond_latents is not None:
|
||||
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, encoder_mask)
|
||||
|
||||
noise = self._run_sync_pass(
|
||||
t,
|
||||
config,
|
||||
step_latents,
|
||||
prompt_embeds,
|
||||
pooled_embeds,
|
||||
encoder_mask,
|
||||
cond_image_grid,
|
||||
kontext_image_ids,
|
||||
num_img_tokens,
|
||||
original_latent_tokens,
|
||||
cond_latents,
|
||||
)
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale
|
||||
)
|
||||
|
||||
hidden_states = config.scheduler.step(
|
||||
noise=noise, timestep=t, latents=prev_latents
|
||||
)
|
||||
|
||||
if not self.is_first_stage:
|
||||
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
elif self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
else:
|
||||
hidden_states = prev_latents
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _async_pipeline_step(
|
||||
self,
|
||||
t: int,
|
||||
config: Config,
|
||||
latents: mx.array,
|
||||
prompt_data: PromptData,
|
||||
is_first_async_step: bool,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
patch_latents, token_indices = self._create_patches(latents, config)
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, _ = batched_data
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, encoder_mask)
|
||||
self._set_text_seq_len(text_seq_len)
|
||||
|
||||
if self.joint_block_wrappers:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(encoder_mask)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(t, config, pooled_embeds)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
config,
|
||||
encoder_hidden_states_mask=encoder_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
prev_patch_latents = [p for p in patch_latents]
|
||||
encoder_hidden_states: mx.array | None = None
|
||||
|
||||
for patch_idx in range(len(patch_latents)):
|
||||
patch = patch_latents[patch_idx]
|
||||
|
||||
if (
|
||||
self.is_first_stage
|
||||
and not self.is_last_stage
|
||||
and not is_first_async_step
|
||||
):
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
|
||||
|
||||
noise, encoder_hidden_states = self._run_single_patch_pass(
|
||||
patch=step_patch,
|
||||
patch_idx=patch_idx,
|
||||
token_indices=token_indices[patch_idx],
|
||||
prompt_embeds=prompt_embeds,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_embeddings=image_rotary_embeddings,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale
|
||||
)
|
||||
|
||||
patch_latents[patch_idx] = config.scheduler.step(
|
||||
noise=noise,
|
||||
timestep=t,
|
||||
latents=prev_patch_latents[patch_idx],
|
||||
)
|
||||
|
||||
if not self.is_first_stage and t != config.num_inference_steps - 1:
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
|
||||
return mx.concatenate(patch_latents, axis=1)
|
||||
|
||||
def _run_single_patch_pass(
|
||||
self,
|
||||
patch: mx.array,
|
||||
patch_idx: int,
|
||||
token_indices: tuple[int, int],
|
||||
prompt_embeds: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
image_rotary_embeddings: RotaryEmbeddings,
|
||||
encoder_hidden_states: mx.array | None,
|
||||
) -> tuple[mx.array | None, mx.array | None]:
|
||||
"""Process a single patch through the forward pipeline.
|
||||
|
||||
Handles stage-to-stage communication (stage i -> stage i+1).
|
||||
Ring communication (last stage -> first stage) is handled by the caller.
|
||||
|
||||
Args:
|
||||
patch: The patch latents to process
|
||||
patch_idx: Index of this patch (0-indexed)
|
||||
token_indices: (start_token, end_token) for this patch
|
||||
prompt_embeds: Text embeddings (for compute_embeddings on first stage)
|
||||
text_embeddings: Precomputed text embeddings
|
||||
image_rotary_embeddings: Precomputed rotary embeddings
|
||||
encoder_hidden_states: Encoder hidden states (passed between patches)
|
||||
|
||||
Returns:
|
||||
(noise_prediction, encoder_hidden_states) - noise is None for non-last stages
|
||||
"""
|
||||
start_token, end_token = token_indices
|
||||
batch_size = patch.shape[0]
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
hidden_dim = self.adapter.hidden_dim
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
|
||||
if self.is_first_stage:
|
||||
patch, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
patch, prompt_embeds
|
||||
)
|
||||
|
||||
assert self.joint_block_wrappers is not None
|
||||
assert encoder_hidden_states is not None
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
encoder_hidden_states, patch = wrapper(
|
||||
hidden_states=patch,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if self.owns_concat_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
patch_concat = self.adapter.merge_streams(patch, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
patch = patch_concat
|
||||
else:
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
assert encoder_hidden_states is not None
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
patch = wrapper(
|
||||
hidden_states=patch,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
noise: mx.array | None = None
|
||||
if self.is_last_stage:
|
||||
patch_img_only = patch[:, text_seq_len:, :]
|
||||
noise = self.adapter.final_projection(patch_img_only, text_embeddings)
|
||||
|
||||
return noise, encoder_hidden_states
|
||||
@@ -13,11 +13,17 @@ from mlx.nn.layers.distributed import (
|
||||
shard_linear,
|
||||
sum_gradients,
|
||||
)
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
|
||||
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
|
||||
from mlx_lm.models.glm4_moe import MoE
|
||||
from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
from mlx_lm.models.minimax import Model as MiniMaxModel
|
||||
from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
@@ -335,7 +341,33 @@ def tensor_auto_parallel(
|
||||
except (AttributeError, TypeError, NameError):
|
||||
pass
|
||||
|
||||
if isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
if isinstance(model, (LlamaModel, Ministral3Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, MiniMaxModel):
|
||||
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
tensor_parallel_sharding_strategy = QwenShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -343,6 +375,15 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, GptOssModel):
|
||||
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
@@ -377,6 +418,34 @@ class TensorParallelShardingStrategy(ABC):
|
||||
) -> nn.Module: ...
|
||||
|
||||
|
||||
class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(LlamaModel, model)
|
||||
for layer in model.layers:
|
||||
# Force load weights before sharding to avoid FAST_SYNCH deadlock
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
if layer.self_attn.n_kv_heads is not None:
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
inner_model_instance = _inner_model(model)
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
@@ -403,6 +472,105 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
|
||||
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
# Shard the self attention
|
||||
if layer.self_attn.q_lora_rank is None:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
else:
|
||||
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_b_proj
|
||||
)
|
||||
layer.self_attn.kv_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.kv_b_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Shard the MLP
|
||||
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
else:
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ShardedDeepseekV3MoE(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(MiniMaxModel, model)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
# Shard the self attention
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.up_proj
|
||||
)
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
@@ -455,3 +623,58 @@ class ShardedQwenMoE(CustomMlxLayer):
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(GptOssMoeModel, model)
|
||||
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
layer.self_attn.num_key_value_groups = (
|
||||
layer.self_attn.num_attention_heads
|
||||
// layer.self_attn.num_key_value_heads
|
||||
)
|
||||
|
||||
layer.self_attn.sinks = layer.self_attn.sinks[
|
||||
layer.self_attn.num_attention_heads
|
||||
* self.group.rank() : layer.self_attn.num_attention_heads
|
||||
* (self.group.rank() + 1)
|
||||
]
|
||||
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
|
||||
|
||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ShardedGptOssMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: nn.Module):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
@@ -9,13 +9,15 @@ from loguru import logger
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import ImageEditsInternalParams
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
EventId,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
TaskCreated,
|
||||
@@ -28,6 +30,7 @@ from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
Shutdown,
|
||||
Task,
|
||||
TaskStatus,
|
||||
@@ -93,6 +96,10 @@ class Worker:
|
||||
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
|
||||
# Buffer for input image chunks (for image editing)
|
||||
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
|
||||
self.input_chunk_counts: dict[CommandId, int] = {}
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Worker")
|
||||
|
||||
@@ -157,6 +164,17 @@ class Worker:
|
||||
for idx, event in indexed_events:
|
||||
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
|
||||
|
||||
# Buffer input image chunks for image editing
|
||||
if isinstance(event, InputChunkReceived):
|
||||
cmd_id = event.command_id
|
||||
if cmd_id not in self.input_chunk_buffer:
|
||||
self.input_chunk_buffer[cmd_id] = {}
|
||||
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
|
||||
|
||||
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
|
||||
event.chunk.data
|
||||
)
|
||||
|
||||
async def plan_step(self):
|
||||
while True:
|
||||
await anyio.sleep(0.1)
|
||||
@@ -169,6 +187,8 @@ class Worker:
|
||||
self.state.instances,
|
||||
self.state.runners,
|
||||
self.state.tasks,
|
||||
self.input_chunk_buffer,
|
||||
self.input_chunk_counts,
|
||||
)
|
||||
if task is None:
|
||||
continue
|
||||
@@ -232,6 +252,46 @@ class Worker:
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
chunks = self.input_chunk_buffer.get(cmd_id, {})
|
||||
assembled = "".join(chunks[i] for i in range(len(chunks)))
|
||||
logger.info(
|
||||
f"Assembled input image from {len(chunks)} chunks, "
|
||||
f"total size: {len(assembled)} bytes"
|
||||
)
|
||||
# Create modified task with assembled image data
|
||||
modified_task = ImageEdits(
|
||||
task_id=task.task_id,
|
||||
command_id=task.command_id,
|
||||
instance_id=task.instance_id,
|
||||
task_status=task.task_status,
|
||||
task_params=ImageEditsInternalParams(
|
||||
image_data=assembled,
|
||||
total_input_chunks=task.task_params.total_input_chunks,
|
||||
prompt=task.task_params.prompt,
|
||||
model=task.task_params.model,
|
||||
n=task.task_params.n,
|
||||
quality=task.task_params.quality,
|
||||
output_format=task.task_params.output_format,
|
||||
response_format=task.task_params.response_format,
|
||||
size=task.task_params.size,
|
||||
image_strength=task.task_params.image_strength,
|
||||
bench=task.task_params.bench,
|
||||
stream=task.task_params.stream,
|
||||
partial_images=task.task_params.partial_images,
|
||||
advanced_params=task.task_params.advanced_params,
|
||||
),
|
||||
)
|
||||
# Cleanup buffers
|
||||
if cmd_id in self.input_chunk_buffer:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
case task:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
@@ -49,6 +51,8 @@ def plan(
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus], # all global
|
||||
tasks: Mapping[TaskId, Task],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
input_chunk_counts: Mapping[CommandId, int] | None = None,
|
||||
) -> Task | None:
|
||||
# Python short circuiting OR logic should evaluate these sequentially.
|
||||
return (
|
||||
@@ -58,7 +62,7 @@ def plan(
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
|
||||
)
|
||||
|
||||
|
||||
@@ -262,14 +266,24 @@ def _pending_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions
|
||||
if not isinstance(task, ChatCompletion):
|
||||
# TODO(ciaran): do this better!
|
||||
if not isinstance(task, (ChatCompletion, ImageGeneration, ImageEdits)):
|
||||
continue
|
||||
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
|
||||
continue
|
||||
|
||||
# For ImageEdits tasks, verify all input chunks have been received
|
||||
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
|
||||
cmd_id = task.command_id
|
||||
expected = task.task_params.total_input_chunks
|
||||
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
|
||||
if received < expected:
|
||||
continue # Wait for all chunks to arrive
|
||||
|
||||
for runner in runners.values():
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import base64
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
@@ -12,8 +13,11 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.models.model_cards import ModelId, ModelTask
|
||||
from exo.shared.types.api import ChatCompletionMessageText, ImageGenerationStats
|
||||
from exo.shared.types.chunks import ImageChunk, TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -24,6 +28,8 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
@@ -33,6 +39,8 @@ from exo.shared.types.tasks import (
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
ImageGenerationResponse,
|
||||
PartialImageResponse,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
@@ -48,7 +56,15 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.image import (
|
||||
DistributedImageModel,
|
||||
generate_image,
|
||||
initialize_image_model,
|
||||
warmup_image_generator,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
@@ -79,7 +95,7 @@ def main(
|
||||
|
||||
setup_start_time = time.time()
|
||||
|
||||
model = None
|
||||
model: Model | DistributedImageModel | None = None
|
||||
tokenizer = None
|
||||
group = None
|
||||
|
||||
@@ -133,15 +149,25 @@ def main(
|
||||
)
|
||||
time.sleep(0.5)
|
||||
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
assert tokenizer
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
@@ -151,15 +177,30 @@ def main(
|
||||
)
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
image = warmup_image_generator(model=model)
|
||||
if image is not None:
|
||||
logger.info(f"warmed up by generating {image.size} image")
|
||||
else:
|
||||
logger.info("warmup completed (non-primary node)")
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ChatCompletion(task_params=task_params, command_id=command_id) if (
|
||||
@@ -173,7 +214,7 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
assert model
|
||||
assert model and not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
|
||||
@@ -240,6 +281,90 @@ def main(
|
||||
)
|
||||
raise
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ImageGeneration(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image generation request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
# Generate images using the image generation backend
|
||||
# Track image_index for final images only
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||
)
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
logger.info("sending final ImageChunk")
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||
)
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
logger.info("sending final ImageChunk")
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case Shutdown():
|
||||
@@ -329,6 +454,72 @@ def parse_thinking_models(
|
||||
yield response
|
||||
|
||||
|
||||
def _send_image_chunk(
|
||||
encoded_data: str,
|
||||
command_id: CommandId,
|
||||
model_id: ModelId,
|
||||
event_sender: MpSender[Event],
|
||||
image_index: int,
|
||||
is_partial: bool,
|
||||
partial_index: int | None = None,
|
||||
total_partials: int | None = None,
|
||||
stats: ImageGenerationStats | None = None,
|
||||
) -> None:
|
||||
"""Send base64-encoded image data as chunks via events."""
|
||||
data_chunks = [
|
||||
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
for chunk_index, chunk_data in enumerate(data_chunks):
|
||||
# Only include stats on the last chunk of the final image
|
||||
chunk_stats = (
|
||||
stats if chunk_index == total_chunks - 1 and not is_partial else None
|
||||
)
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=chunk_index,
|
||||
model=model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
image_index=image_index,
|
||||
is_partial=is_partial,
|
||||
partial_index=partial_index,
|
||||
total_partials=total_partials,
|
||||
stats=chunk_stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _process_image_response(
|
||||
response: ImageGenerationResponse | PartialImageResponse,
|
||||
command_id: CommandId,
|
||||
shard_metadata: ShardMetadata,
|
||||
event_sender: MpSender[Event],
|
||||
image_index: int,
|
||||
) -> None:
|
||||
"""Process a single image response and send chunks."""
|
||||
encoded_data = base64.b64encode(response.image_data).decode("utf-8")
|
||||
is_partial = isinstance(response, PartialImageResponse)
|
||||
# Extract stats from final ImageGenerationResponse if available
|
||||
stats = response.stats if isinstance(response, ImageGenerationResponse) else None
|
||||
_send_image_chunk(
|
||||
encoded_data=encoded_data,
|
||||
command_id=command_id,
|
||||
model_id=shard_metadata.model_card.model_id,
|
||||
event_sender=event_sender,
|
||||
image_index=response.partial_index if is_partial else image_index,
|
||||
is_partial=is_partial,
|
||||
partial_index=response.partial_index if is_partial else None,
|
||||
total_partials=response.total_partials if is_partial else None,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
|
||||
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
|
||||
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
|
||||
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import BaseTask, TaskId
|
||||
@@ -38,6 +38,7 @@ def get_pipeline_shard_metadata(
|
||||
n_layers=32,
|
||||
hidden_size=2048,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
device_rank=device_rank,
|
||||
world_size=world_size,
|
||||
|
||||
@@ -11,7 +11,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.types.api import ChatCompletionMessage
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
@@ -87,6 +87,7 @@ def run_gpt_oss_pipeline_device(
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
device_rank=rank,
|
||||
world_size=world_size,
|
||||
@@ -156,6 +157,7 @@ def run_gpt_oss_tensor_parallel_device(
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
device_rank=rank,
|
||||
world_size=world_size,
|
||||
|
||||
Reference in New Issue
Block a user