Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-11 07:29:36 -05:00)

Compare commits — 222 commits, alexcheema … ciaran/ima (range 9b6a6059b9 … 8a6da58404)
@@ -20,6 +20,8 @@ struct ContentView: View {
@State private var showDebugInfo = false
@State private var bugReportInFlight = false
@State private var bugReportMessage: String?
@State private var showAdvancedOptions = false
@State private var pendingNamespace: String = ""

var body: some View {
VStack(alignment: .leading, spacing: 12) {

@@ -197,6 +199,8 @@ struct ContentView: View {
updater.checkForUpdates()
}
.padding(.bottom, 8)
advancedOptionsSection
.padding(.bottom, 8)
debugSection
.padding(.bottom, 8)
controlButton(title: "Quit", tint: .secondary) {

@@ -327,6 +331,47 @@ struct ContentView: View {
}
}

private var advancedOptionsSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Advanced Options")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showAdvancedOptions)
}
.animation(nil, value: showAdvancedOptions)
if showAdvancedOptions {
VStack(alignment: .leading, spacing: 8) {
VStack(alignment: .leading, spacing: 4) {
Text("Cluster Namespace")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
TextField("optional", text: $pendingNamespace)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingNamespace = controller.customNamespace
}
Button("Save & Restart") {
controller.customNamespace = pendingNamespace
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingNamespace == controller.customNamespace)
}

}
}
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showAdvancedOptions)
}

private var debugSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
@@ -2,6 +2,8 @@ import AppKit
import Combine
import Foundation

private let customNamespaceKey = "EXOCustomNamespace"

@MainActor
final class ExoProcessController: ObservableObject {
enum Status: Equatable {

@@ -27,6 +29,13 @@ final class ExoProcessController: ObservableObject {
@Published private(set) var status: Status = .stopped
@Published private(set) var lastError: String?
@Published private(set) var launchCountdownSeconds: Int?
@Published var customNamespace: String = {
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
}() {
didSet {
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
}
}

private var process: Process?
private var runtimeDirectoryURL: URL?

@@ -180,7 +189,7 @@ final class ExoProcessController: ObservableObject {
private func makeEnvironment(for runtimeURL: URL) -> [String: String] {
var environment = ProcessInfo.processInfo.environment
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
environment["EXO_LIBP2P_NAMESPACE"] = buildTag()
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()

var paths: [String] = []
if let existing = environment["PATH"], !existing.isEmpty {

@@ -217,6 +226,12 @@ final class ExoProcessController: ObservableObject {
}
return "dev"
}

private func computeNamespace() -> String {
let base = buildTag()
let custom = customNamespace.trimmingCharacters(in: .whitespaces)
return custom.isEmpty ? base : custom
}
}

struct RuntimeError: LocalizedError {
@@ -1,5 +1,5 @@
<script lang="ts">
import { isLoading, sendMessage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import { isLoading, sendMessage, generateImage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import ChatAttachments from './ChatAttachments.svelte';
import type { ChatUploadedFile } from '$lib/types/files';
import { processUploadedFiles, getAcceptString } from '$lib/types/files';

@@ -10,6 +10,7 @@
showHelperText?: boolean;
autofocus?: boolean;
showModelSelector?: boolean;
modelTasks?: Record<string, string[]>;
}

let {

@@ -17,7 +18,8 @@
placeholder = 'Ask anything',
showHelperText = false,
autofocus = true,
showModelSelector = false
showModelSelector = false,
modelTasks = {}
}: Props = $props();

let message = $state('');

@@ -48,13 +50,29 @@
// Accept all supported file types
const acceptString = getAcceptString(['image', 'text', 'pdf']);

// Check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const tasks = modelTasks[modelId] || [];
return tasks.includes('TextToImage') || tasks.includes('ImageToImage');
}

// Check if the currently selected model supports image generation
const isImageModel = $derived(() => {
if (!currentModel) return false;
return modelSupportsImageGeneration(currentModel);
});

// Extract available models from running instances
const availableModels = $derived(() => {
const models: Array<{id: string, label: string}> = [];
const models: Array<{id: string, label: string, isImageModel: boolean}> = [];
for (const [, instance] of Object.entries(instanceData)) {
const modelId = getInstanceModelId(instance);
if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
models.push({ id: modelId, label: modelId.split('/').pop() || modelId });
models.push({
id: modelId,
label: modelId.split('/').pop() || modelId,
isImageModel: modelSupportsImageGeneration(modelId)
});
}
}
return models;

@@ -139,6 +157,11 @@
}

function handleKeydown(event: KeyboardEvent) {
// Prevent form submission during IME composition (e.g., Chinese, Japanese, Korean input)
if (event.isComposing || event.keyCode === 229) {
return;
}

if (event.key === 'Enter' && !event.shiftKey) {
event.preventDefault();
handleSubmit();

@@ -155,7 +178,12 @@
uploadedFiles = [];
resetTextareaHeight();

sendMessage(content, files);
// Use image generation for image models
if (isImageModel() && content) {
generateImage(content);
} else {
sendMessage(content, files);
}

// Refocus the textarea after sending
setTimeout(() => textareaRef?.focus(), 10);

@@ -292,7 +320,14 @@
{:else}
<span class="w-3"></span>
{/if}
<span class="truncate">{model.label}</span>
{#if model.isImageModel}
<svg class="w-3.5 h-3.5 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate flex-1">{model.label}</span>
</button>
{/each}
</div>

@@ -352,7 +387,7 @@
onkeydown={handleKeydown}
oninput={handleInput}
onpaste={handlePaste}
{placeholder}
placeholder={isImageModel() ? 'Describe the image you want to generate...' : placeholder}
disabled={loading}
rows={1}
class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"

@@ -366,14 +401,23 @@
{!canSend || loading
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label="Send message"
aria-label={isImageModel() ? "Generate image" : "Send message"}
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<span class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"></span>
<span class="hidden sm:inline">PROCESSING</span>
<span class="hidden sm:inline">{isImageModel() ? 'GENERATING' : 'PROCESSING'}</span>
<span class="sm:hidden">...</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
@@ -365,10 +365,58 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
</div>
{/if}

<!-- Generated Images -->
{#if message.attachments?.some(a => a.type === 'generated-image')}
<div class="mb-3">
{#each message.attachments.filter(a => a.type === 'generated-image') as attachment}
<div class="relative group/img inline-block">
<img
src={attachment.preview}
alt=""
class="max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20"
/>
<!-- Download button overlay -->
<button
type="button"
class="absolute top-2 right-2 p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow opacity-0 group-hover/img:opacity-100 transition-opacity hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
onclick={() => {
if (attachment.preview) {
const link = document.createElement('a');
link.href = attachment.preview;
link.download = `generated-image-${Date.now()}.png`;
link.click();
}
}}
title="Download image"
>
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
</svg>
</button>
</div>
{/each}
</div>
{/if}

<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{#if message.content === 'Generating image...'}
<div class="flex items-center gap-3 text-exo-yellow">
<div class="relative">
<div class="w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"></div>
<svg class="absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
</div>
<span class="font-mono tracking-wider uppercase text-sm">Generating image...</span>
</div>
{:else if message.content || (loading && !message.attachments?.some(a => a.type === 'generated-image'))}
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
{/if}
</div>
</div>
File diff suppressed because it is too large
@@ -47,10 +47,86 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);

// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number, tasks?: string[], hugging_face_id?: string}>>([]);

// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
const tasks: Record<string, string[]> = {};
for (const model of models) {
if (model.tasks && model.tasks.length > 0) {
// Map by short ID
tasks[model.id] = model.tasks;
// Also map by hugging_face_id from the API response
if (model.hugging_face_id) {
tasks[model.hugging_face_id] = model.tasks;
}
}
}
return tasks;
});

// Helper to check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const model = models.find(m => m.id === modelId || m.hugging_face_id === modelId);
if (!model?.tasks) return false;
return model.tasks.includes('TextToImage') || model.tasks.includes('ImageToImage');
}
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';

// Launch defaults persistence
const LAUNCH_DEFAULTS_KEY = 'exo-launch-defaults';
interface LaunchDefaults {
modelId: string | null;
sharding: 'Pipeline' | 'Tensor';
instanceType: InstanceMeta;
minNodes: number;
}

function saveLaunchDefaults(): void {
const defaults: LaunchDefaults = {
modelId: selectedPreviewModelId(),
sharding: selectedSharding,
instanceType: selectedInstanceType,
minNodes: selectedMinNodes,
};
try {
localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
} catch (e) {
console.warn('Failed to save launch defaults:', e);
}
}

function loadLaunchDefaults(): LaunchDefaults | null {
try {
const stored = localStorage.getItem(LAUNCH_DEFAULTS_KEY);
if (!stored) return null;
return JSON.parse(stored) as LaunchDefaults;
} catch (e) {
console.warn('Failed to load launch defaults:', e);
return null;
}
}

function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
const defaults = loadLaunchDefaults();
if (!defaults) return;

// Apply sharding and instance type unconditionally
selectedSharding = defaults.sharding;
selectedInstanceType = defaults.instanceType;

// Apply minNodes if valid (between 1 and maxNodes)
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
selectedMinNodes = defaults.minNodes;
}

// Only apply model if it exists in the available models
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
selectPreviewModel(defaults.modelId);
}
}

let selectedInstanceType = $state<InstanceMeta>('MlxRing');
let selectedMinNodes = $state<number>(1);
let minNodesInitialized = $state(false);

@@ -298,6 +374,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const data = await response.json();
// API returns { data: [{ id, name }] } format
models = data.data || [];
// Restore last launch defaults if available
const currentNodeCount = topologyData() ? Object.keys(topologyData()!.nodes).length : 1;
applyLaunchDefaults(models, currentNodeCount);
}
} catch (error) {
console.error('Failed to fetch models:', error);

@@ -988,6 +1067,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

function handleSliderMouseUp() {
isDraggingSlider = false;
saveLaunchDefaults();
}

// Handle touch events for mobile

@@ -1007,6 +1087,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

function handleSliderTouchEnd() {
isDraggingSlider = false;
saveLaunchDefaults();
}

const nodeCount = $derived(data ? Object.keys(data.nodes).length : 0);

@@ -1192,6 +1273,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
placeholder="Ask anything"
showHelperText={false}
showModelSelector={true}
modelTasks={modelTasks()}
/>
</div>
</div>

@@ -1413,8 +1495,18 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{@const foundModel = models.find(m => m.id === selectedModelId)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="text-exo-light-gray truncate">{foundModel.name || foundModel.id}</span>
{@const isImageModel = modelSupportsImageGeneration(foundModel.id)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="flex items-center gap-2 text-exo-light-gray truncate">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{foundModel.name || foundModel.id}</span>
</span>
<span class="text-white/50 text-xs flex-shrink-0">{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB</span>
</span>
{:else}

@@ -1459,11 +1551,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(model.id)}
<button
type="button"
onclick={() => {
if (modelCanFit) {
selectPreviewModel(model.id);
saveLaunchDefaults();
isModelDropdownOpen = false;
modelDropdownSearch = '';
}

@@ -1477,7 +1571,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
: 'text-white/30 cursor-default'
}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span class="flex-shrink-0 text-xs {modelCanFit ? 'text-white/50' : 'text-red-400/60'}">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>

@@ -1497,7 +1600,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-xs text-white/70 font-mono mb-2">Sharding:</div>
<div class="flex gap-2">
<button
onclick={() => selectedSharding = 'Pipeline'}
onclick={() => { selectedSharding = 'Pipeline'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Pipeline' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Pipeline' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">

@@ -1508,7 +1611,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
Pipeline
</button>
<button
onclick={() => selectedSharding = 'Tensor'}
onclick={() => { selectedSharding = 'Tensor'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Tensor' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Tensor' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">

@@ -1526,7 +1629,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-xs text-white/70 font-mono mb-2">Instance Type:</div>
<div class="flex gap-2">
<button
onclick={() => selectedInstanceType = 'MlxRing'}
onclick={() => { selectedInstanceType = 'MlxRing'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxRing' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxRing' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">

@@ -1537,7 +1640,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
MLX Ring
</button>
<button
onclick={() => selectedInstanceType = 'MlxIbv'}
onclick={() => { selectedInstanceType = 'MlxIbv'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxIbv' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxIbv' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">

@@ -1674,7 +1777,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-7xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} />
<ChatForm placeholder="Ask anything" showModelSelector={true} modelTasks={modelTasks()} />
</div>
</div>
</div>
@@ -35,6 +35,8 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux>=0.12.1",
]

[project.scripts]
@@ -28,7 +28,7 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
worker: Worker
worker: Worker | None
election: Election  # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
election_result_receiver: Receiver[ElectionResult]
master: Master | None

@@ -62,15 +62,19 @@ class Node:
else:
api = None

worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
)
if not args.no_worker:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
)
else:
worker = None

# We start every node with a master
master = Master(
node_id,

@@ -100,8 +104,9 @@ class Node:
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.worker.run)
tg.start_soon(self.election.run)
if self.worker:
tg.start_soon(self.worker.run)
if self.master:
tg.start_soon(self.master.run)
if self.api:

@@ -209,6 +214,7 @@ class Args(CamelCaseModel):
spawn_api: bool = False
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False

@classmethod
def parse(cls) -> Self:

@@ -246,6 +252,10 @@ class Args(CamelCaseModel):
dest="api_port",
default=52415,
)
parser.add_argument(
"--no-worker",
action="store_true",
)

args = parser.parse_args()
return cls(**vars(args))  # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
@@ -1,11 +1,13 @@
import base64
import json
import time
from collections.abc import AsyncGenerator
from typing import cast
from typing import Literal, cast

import anyio
from anyio import create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles

@@ -22,9 +24,10 @@ from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]

from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
ChatCompletionChoice,

@@ -34,6 +37,10 @@ from exo.shared.types.api import (
CreateInstanceResponse,
DeleteInstanceResponse,
FinishReason,
ImageData,
ImageEditsInternalParams,
ImageGenerationResponse,
ImageGenerationTaskParams,
ModelList,
ModelListModel,
PlaceInstanceParams,

@@ -41,14 +48,17 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
SendInputChunk,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId

@@ -84,12 +94,23 @@ def chunk_to_response(
)


async def resolve_model_meta(model_id: str) -> ModelMetadata:
def get_model_card(model_id: str) -> ModelCard | None:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card

for _, model_card in MODEL_CARDS.items():
if model_id == model_card.model_id:
return model_card


async def resolve_model_meta(model_id: str) -> ModelMetadata:
model_card = get_model_card(model_id)

if model_card is not None:
return model_card.metadata
else:
return await get_model_meta(model_id)

return await get_model_meta(model_id)


class API:

@@ -133,6 +154,7 @@ class API:
)

self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._tg: TaskGroup | None = None

def reset(self, new_session_id: SessionId, result_clock: int):

@@ -141,6 +163,7 @@ class API:
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._chat_completion_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)

def unpause(self, result_clock: int):

@@ -172,6 +195,10 @@ class API:
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
self.app.post("/v1/images/generations", response_model=None)(
self.image_generations
)
self.app.post("/v1/images/edits", response_model=None)(self.image_edits)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)

@@ -525,6 +552,325 @@ class API:

return await self._collect_chat_completion(command.command_id, parse_gpt_oss)

async def image_generations(
self, payload: ImageGenerationTaskParams
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image generation requests.

When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id

if not any(
instance.shard_assignments.model_id == payload.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(payload.model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {payload.model}"
)

command = ImageGeneration(
request_params=payload,
)
await self._send(command)

# Check if streaming is requested
if payload.stream and payload.partial_images and payload.partial_images > 0:
return StreamingResponse(
self._generate_image_stream(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
),
media_type="text/event-stream",
)

# Non-streaming: collect all image chunks
return await self._collect_image_generation(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
)

async def _generate_image_stream(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> AsyncGenerator[str, None]:
"""Generate SSE stream of partial and final images."""
# Track chunks: {(image_index, is_partial): {chunk_index: data}}
image_chunks: dict[tuple[int, bool], dict[int, str]] = {}
image_total_chunks: dict[tuple[int, bool], int] = {}
image_metadata: dict[tuple[int, bool], tuple[int | None, int | None]] = {}
images_complete = 0

try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()

with recv as chunks:
async for chunk in chunks:
key = (chunk.image_index, chunk.is_partial)

if key not in image_chunks:
image_chunks[key] = {}
image_total_chunks[key] = chunk.total_chunks
image_metadata[key] = (
chunk.partial_index,
chunk.total_partials,
)

image_chunks[key][chunk.chunk_index] = chunk.data

# Check if this image is complete
if len(image_chunks[key]) == image_total_chunks[key]:
full_data = "".join(
image_chunks[key][i] for i in range(len(image_chunks[key]))
)

partial_idx, total_partials = image_metadata[key]

if chunk.is_partial:
# Yield partial image event
event_data = {
"type": "partial",
"partial_index": partial_idx,
"total_partials": total_partials,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
else:
# Final image
event_data = {
"type": "final",
"image_index": chunk.image_index,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
images_complete += 1

if images_complete >= num_images:
yield "data: [DONE]\n\n"
break

# Clean up completed image chunks
del image_chunks[key]
del image_total_chunks[key]
del image_metadata[key]

except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]

async def _collect_image_generation(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> ImageGenerationResponse:
"""Collect all image chunks (non-streaming) and return a single response."""
# Track chunks per image: {image_index: {chunk_index: data}}
# Only track non-partial (final) images
image_chunks: dict[int, dict[int, str]] = {}
image_total_chunks: dict[int, int] = {}
images_complete = 0

try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()

while images_complete < num_images:
with recv as chunks:
async for chunk in chunks:
# Skip partial images in non-streaming mode
if chunk.is_partial:
continue

if chunk.image_index not in image_chunks:
image_chunks[chunk.image_index] = {}
image_total_chunks[chunk.image_index] = chunk.total_chunks

image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data

# Check if this image is complete
if (
len(image_chunks[chunk.image_index])
== image_total_chunks[chunk.image_index]
):
images_complete += 1

if images_complete >= num_images:
break

# Reassemble images in order
images: list[ImageData] = []
for image_idx in range(num_images):
chunks_dict = image_chunks[image_idx]
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
images.append(
ImageData(
b64_json=full_data if response_format == "b64_json" else None,
url=None,  # URL format not implemented yet
)
)

return ImageGenerationResponse(data=images)
except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]
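A minimal client sketch for the new generations endpoint (not part of the diff). It assumes the API node listens on localhost:52415 (the default api_port in this changeset) and that ImageGenerationTaskParams accepts a `prompt` field alongside the fields visible here (model, n, stream, partial_images, response_format); the model id is a placeholder, since no image-model id appears in this excerpt.

```python
# Hypothetical usage sketch; field names beyond model/n/stream/partial_images/
# response_format are assumptions, and MODEL_ID is a placeholder.
import base64
import json
import requests

BASE = "http://localhost:52415"
MODEL_ID = "<your image model id>"  # assumption: an instance serving a TextToImage model

# Non-streaming: the handler collects every ImageChunk and returns one JSON body
# shaped like ImageGenerationResponse, i.e. {"data": [{"b64_json": ..., "url": null}]}.
resp = requests.post(
    f"{BASE}/v1/images/generations",
    json={"model": MODEL_ID, "prompt": "a lighthouse at dusk", "n": 1},
)
resp.raise_for_status()
png_bytes = base64.b64decode(resp.json()["data"][0]["b64_json"])

# Streaming: with stream=True and partial_images > 0 the handler returns an SSE
# stream of {"type": "partial" | "final", ...} events, terminated by "data: [DONE]".
with requests.post(
    f"{BASE}/v1/images/generations",
    json={"model": MODEL_ID, "prompt": "a lighthouse at dusk",
          "stream": True, "partial_images": 2},
    stream=True,
) as stream_resp:
    for line in stream_resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        body = line[len(b"data: "):]
        if body == b"[DONE]":
            break
        event = json.loads(body)
        print(event["type"], len(event["data"]["b64_json"] or ""))
```

The streaming branch mirrors the SSE framing emitted by `_generate_image_stream` above: one `data:` line per partial or final image, closed with `data: [DONE]`.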
async def image_edits(
self,
image: UploadFile = File(...),
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
stream: bool = Form(False),
partial_images: int = Form(0),
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image editing requests (img2img)."""
model_meta = await resolve_model_meta(model)
resolved_model = model_meta.model_id

if not any(
instance.shard_assignments.model_id == resolved_model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(resolved_model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {resolved_model}"
)

# Read and base64 encode the uploaded image
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")

# Map input_fidelity to image_strength
image_strength = 0.7 if input_fidelity == "high" else 0.3

# Split image into chunks to stay under gossipsub message size limit
data_chunks = [
image_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)

# Create command first to get command_id
command = ImageEdits(
request_params=ImageEditsInternalParams(
image_data="",  # Empty - will be assembled at worker from chunks
total_input_chunks=total_chunks,
prompt=prompt,
model=resolved_model,
n=n,
size=size,
response_format=response_format,
image_strength=image_strength,
stream=stream,
partial_images=partial_images,
),
)

# Send input chunks BEFORE the command
logger.info(
f"Sending input image: {len(image_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(data_chunks):
await self._send(
SendInputChunk(
chunk=InputImageChunk(
idx=chunk_index,
model=resolved_model,
command_id=command.command_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
)
)
)

# Now send the main command
await self._send(command)

num_images = n

# Check if streaming is requested
if stream and partial_images and partial_images > 0:
return StreamingResponse(
self._generate_image_stream(
command_id=command.command_id,
num_images=num_images,
response_format=response_format,
),
media_type="text/event-stream",
)

# Track chunks per image: {image_index: {chunk_index: data}}
image_chunks: dict[int, dict[int, str]] = {}
image_total_chunks: dict[int, int] = {}
images_complete = 0

try:
self._image_generation_queues[command.command_id], recv = channel[
ImageChunk
]()

while images_complete < num_images:
with recv as chunks:
async for chunk in chunks:
if chunk.image_index not in image_chunks:
image_chunks[chunk.image_index] = {}
image_total_chunks[chunk.image_index] = chunk.total_chunks

image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data

if (
len(image_chunks[chunk.image_index])
== image_total_chunks[chunk.image_index]
):
images_complete += 1

if images_complete >= num_images:
break

images: list[ImageData] = []
for image_idx in range(num_images):
chunks_dict = image_chunks[image_idx]
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
images.append(
ImageData(
b64_json=full_data if response_format == "b64_json" else None,
url=None,  # URL format not implemented yet
)
)

return ImageGenerationResponse(data=images)
except anyio.get_cancelled_exc_class():
raise
finally:
# Send TaskFinished command
await self._send(TaskFinished(finished_command_id=command.command_id))
del self._image_generation_queues[command.command_id]
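For the edits endpoint, a matching client sketch (not part of the diff) using multipart form data, mirroring the `File(...)`/`Form(...)` parameters above; the host and model id are placeholders.

```python
# Hypothetical usage sketch for POST /v1/images/edits.
import requests

MODEL_ID = "<your image model id>"  # assumption: an instance serving an ImageToImage model

with open("input.png", "rb") as f:
    resp = requests.post(
        "http://localhost:52415/v1/images/edits",
        files={"image": ("input.png", f, "image/png")},
        data={
            "prompt": "make it look like a watercolor painting",
            "model": MODEL_ID,
            "size": "1024x1024",
            "input_fidelity": "high",  # mapped to image_strength=0.7 in the handler above
        },
    )
resp.raise_for_status()
b64_png = resp.json()["data"][0]["b64_json"]
```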
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()

@@ -547,6 +893,7 @@ class API:
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
tasks=[task.value for task in card.tasks],
)
for card in MODEL_CARDS.values()
]

@@ -584,14 +931,17 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if (
isinstance(event, ChunkGenerated)
and event.command_id in self._chat_completion_queues
):
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
if isinstance(event, ChunkGenerated):
if event.command_id in self._chat_completion_queues:
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
elif event.command_id in self._image_generation_queues:
assert isinstance(event.chunk, ImageChunk)
await self._image_generation_queues[event.command_id].send(
event.chunk
)

async def _pause_on_new_election(self):
with self.election_receiver as ems:
@@ -2,6 +2,7 @@ from datetime import datetime, timedelta, timezone

import anyio
from anyio.abc import TaskGroup
from fastapi.routing import request_response
from loguru import logger

from exo.master.placement import (

@@ -11,13 +12,17 @@ from exo.master.placement import (
place_instance,
)
from exo.shared.apply import apply
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.commands import (
ChatCompletion,
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskFinished,
TestCommand,
)

@@ -26,6 +31,7 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
NodeTimedOut,
TaskCreated,

@@ -35,6 +41,12 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion as ChatCompletionTask,
)
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
)
from exo.shared.types.tasks import (
ImageGeneration as ImageGenerationTask,
)
from exo.shared.types.tasks import (
TaskId,
TaskStatus,

@@ -99,6 +111,7 @@ class Master:
async for forwarder_command in commands:
try:
logger.info(f"Executing command: {forwarder_command.command}")

generated_events: list[Event] = []
command = forwarder_command.command
match command:

@@ -146,6 +159,92 @@ class Master:
)
)

self.command_task_mapping[command.command_id] = task_id
case ImageGeneration():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)

if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)

available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)

task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)

self.command_task_mapping[command.command_id] = task_id
case ImageEdits():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)

if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)

available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)

task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)

self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(command, self.state.instances)

@@ -173,6 +272,13 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
generated_events.append(
InputChunkReceived(
command_id=chunk.command_id,
chunk=chunk,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
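The two new command arms share one scheduling rule: among instances already serving the requested model, route the task to the instance with the fewest in-flight tasks. A standalone sketch of that rule follows, with simplified stand-in types rather than the real exo classes.

```python
# Illustrative sketch only; the real code works on exo's State/Instance/Task objects.
from collections import Counter


def pick_instance(instances: dict[str, str], tasks: dict[str, str], model: str) -> str:
    """instances maps instance_id -> model_id; tasks maps task_id -> instance_id."""
    load = Counter(tasks.values())  # in-flight task count per instance
    candidates = [iid for iid, mid in instances.items() if mid == model]
    if not candidates:
        raise ValueError(f"No instance found for model {model}")
    # least-loaded instance wins, like sorted(...)[0] in the match arms above
    return min(candidates, key=lambda iid: load[iid])
```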
@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
)
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,

@@ -29,6 +30,7 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.shards import Sharding


def random_ephemeral_port() -> int:

@@ -65,6 +67,28 @@ def place_instance(
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")

if command.sharding == Sharding.Tensor:
if not command.model_meta.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_meta.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
)

smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)

smallest_tb_cycles = [
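A small illustration of the tensor-sharding filter added above: a candidate cycle of N nodes is kept only if the model's hidden_size divides evenly by N (the TODO in the diff notes this is an approximation, not the exact condition).

```python
# Sketch of the divisibility filter; cycles are simplified to lists of node ids.
def tensor_capable_cycles(cycles: list[list[str]], hidden_size: int) -> list[list[str]]:
    return [cycle for cycle in cycles if hidden_size % len(cycle) == 0]


# With hidden_size=30 (the value used by the updated test fixture later in this diff),
# a 3-node cycle passes but a 4-node cycle is filtered out.
assert tensor_capable_cycles([["a", "b", "c"]], 30) == [["a", "b", "c"]]
assert tensor_capable_cycles([["a", "b", "c", "d"]], 30) == []
```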
@@ -385,13 +385,14 @@ def get_mlx_jaccl_coordinators(
address in format "X.X.X.X:PORT" per node.
"""
rank_0_node = selected_cycle[0]
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")

def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"

for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
return ip

logger.warning(
@@ -50,7 +50,7 @@ def model_meta() -> ModelMetadata:
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=10,
hidden_size=30,
supports_tensor=True,
)
@@ -9,6 +9,7 @@ from exo.shared.types.events import (
ChunkGenerated,
Event,
IndexedEvent,
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
NodeCreated,

@@ -40,8 +41,8 @@ def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged()
):  # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
):  # Pass-through events that don't modify state
return state
case InstanceCreated():
return apply_instance_created(event, state)
@@ -44,3 +44,5 @@ LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
|
||||
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
|
||||
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
|
||||
LIBP2P_COMMANDS_TOPIC = "commands"
|
||||
|
||||
EXO_MAX_CHUNK_SIZE = 512 * 1024
|
||||
|
||||
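A sketch of how the new EXO_MAX_CHUNK_SIZE constant is meant to be used (as in the image_edits handler earlier in this diff): payloads are split into 512 KiB slices so each gossipsub message stays under the size limit, and reassembled by chunk index on the receiving side. The helper names here are illustrative, not exo APIs.

```python
# Illustrative chunking sketch; only the constant value comes from the diff.
EXO_MAX_CHUNK_SIZE = 512 * 1024


def split_chunks(data: str) -> list[str]:
    # Same slicing pattern as the image_edits handler.
    return [data[i:i + EXO_MAX_CHUNK_SIZE] for i in range(0, len(data), EXO_MAX_CHUNK_SIZE)]


def reassemble(chunks: dict[int, str]) -> str:
    # Chunks arrive keyed by chunk_index; order them before joining.
    return "".join(chunks[i] for i in range(len(chunks)))


payload = "x" * (EXO_MAX_CHUNK_SIZE + 10)
parts = split_chunks(payload)
assert len(parts) == 2
assert reassemble(dict(enumerate(parts))) == payload
```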
@@ -1,5 +1,5 @@
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.models import ComponentInfo, ModelId, ModelMetadata, ModelTask
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ class ModelCard(CamelCaseModel):
|
||||
model_id: ModelId
|
||||
name: str
|
||||
description: str
|
||||
tasks: list[ModelTask]
|
||||
tags: list[str]
|
||||
metadata: ModelMetadata
|
||||
|
||||
@@ -45,6 +46,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
name="DeepSeek V3.1 (4-bit)",
|
||||
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
@@ -60,6 +62,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
|
||||
name="DeepSeek V3.1 (8-bit)",
|
||||
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
|
||||
@@ -133,6 +136,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
|
||||
name="Kimi K2 Instruct (4-bit)",
|
||||
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
|
||||
@@ -148,6 +152,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
|
||||
name="Kimi K2 Thinking (4-bit)",
|
||||
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
|
||||
@@ -164,6 +169,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
name="Llama 3.1 8B (4-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
@@ -179,6 +185,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
name="Llama 3.1 8B (8-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
@@ -194,6 +201,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
name="Llama 3.1 8B (BF16)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
@@ -209,6 +217,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
|
||||
name="Llama 3.1 70B (4-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
|
||||
@@ -225,6 +234,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
|
||||
name="Llama 3.2 1B (4-bit)",
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
|
||||
@@ -240,6 +250,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
|
||||
name="Llama 3.2 3B (4-bit)",
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
|
||||
@@ -255,6 +266,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
|
||||
name="Llama 3.2 3B (8-bit)",
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
|
||||
@@ -271,6 +283,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
|
||||
name="Llama 3.3 70B (4-bit)",
|
||||
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
|
||||
@@ -286,6 +299,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
|
||||
name="Llama 3.3 70B (8-bit)",
|
||||
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
|
||||
@@ -301,6 +315,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
|
||||
name="Llama 3.3 70B (FP16)",
|
||||
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
|
||||
@@ -317,6 +332,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
|
||||
name="Qwen3 0.6B (4-bit)",
|
||||
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
|
||||
@@ -332,6 +348,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
|
||||
name="Qwen3 0.6B (8-bit)",
|
||||
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
|
||||
@@ -347,6 +364,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
|
||||
name="Qwen3 30B A3B (4-bit)",
|
||||
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
|
||||
@@ -362,6 +380,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
|
||||
name="Qwen3 30B A3B (8-bit)",
|
||||
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
|
||||
@@ -377,6 +396,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
name="Qwen3 80B A3B (4-bit)",
|
||||
description="""Qwen3 80B""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
@@ -392,6 +412,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
name="Qwen3 80B A3B (8-bit)",
|
||||
description="""Qwen3 80B""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
@@ -407,6 +428,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
name="Qwen3 80B A3B Thinking (4-bit)",
|
||||
description="""Qwen3 80B Reasoning model""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
@@ -422,6 +444,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
name="Qwen3 80B A3B Thinking (8-bit)",
|
||||
description="""Qwen3 80B Reasoning model""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
@@ -437,6 +460,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
|
||||
name="Qwen3 235B A22B (4-bit)",
|
||||
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
|
||||
@@ -452,6 +476,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
|
||||
name="Qwen3 235B A22B (8-bit)",
|
||||
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
|
||||
@@ -467,6 +492,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
|
||||
name="Qwen3 Coder 480B A35B (4-bit)",
|
||||
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
|
||||
@@ -482,6 +508,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
|
||||
name="Qwen3 Coder 480B A35B (8-bit)",
|
||||
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
|
||||
@@ -498,6 +525,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
@@ -513,6 +541,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
@@ -529,6 +558,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
name="GLM 4.5 Air 8bit",
|
||||
description="""GLM 4.5 Air 8bit""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
@@ -544,6 +574,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
name="GLM 4.5 Air bf16",
|
||||
description="""GLM 4.5 Air bf16""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
@@ -569,4 +600,188 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
"flux1-schnell": ModelCard(
|
||||
short_id="flux1-schnell",
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
name="FLUX.1 [schnell]",
|
||||
description="""FLUX.1 [schnell] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
|
||||
tasks=[ModelTask.TextToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
pretty_name="FLUX.1 [schnell]",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(23782357120), # + 9524621312),
|
||||
n_layers=57, # sharded layers
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
short_id="flux1-dev",
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
name="FLUX.1 [dev]",
|
||||
description="""FLUX.1 [dev] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
|
||||
tasks=[ModelTask.TextToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
pretty_name="FLUX.1 [dev]",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57, # sharded layers
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
short_id="qwen-image",
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
name="Qwen Image",
|
||||
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
|
||||
tasks=[ModelTask.TextToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
pretty_name="Qwen Image",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
short_id="qwen-image-edit-2509",
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
name="Qwen Image Edit 2509",
|
||||
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
pretty_name="Qwen Image Edit 2509",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
import time
from collections.abc import Generator
from typing import Any, Literal

from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault

@@ -27,6 +29,7 @@ class ModelListModel(BaseModel):
tags: list[str] = Field(default=[])
storage_size_megabytes: int = Field(default=0)
supports_tensor: bool = Field(default=False)
tasks: list[str] = Field(default=[])


class ModelList(BaseModel):
@@ -181,3 +184,75 @@ class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId


class ImageGenerationTaskParams(BaseModel):
prompt: str
# background: str | None = None
model: str
# moderation: str | None = None
n: int | None = 1
# output_compression: int | None = None
output_format: Literal["png", "jpeg", "webp"] = "png"
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
stream: bool | None = False
# style: str | None = "vivid"
# user: str | None = None


class ImageEditsTaskParams(BaseModel):
image: UploadFile
prompt: str
input_fidelity: float = 0.7
model: str
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
# user: str | None = None


class ImageEditsInternalParams(BaseModel):
"""Serializable version of ImageEditsTaskParams for distributed task execution."""

image_data: str = "" # Base64-encoded image (empty when using chunked transfer)
total_input_chunks: int = 0
prompt: str
model: str
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
image_strength: float = 0.7
stream: bool = False
partial_images: int | None = 0

def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} chars>"
elif name is not None:
yield name, value


class ImageData(BaseModel):
b64_json: str | None = None
url: str | None = None
revised_prompt: str | None = None

def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "b64_json" and value is not None:
yield name, f"<{len(value)} chars>"
elif name is not None:
yield name, value


class ImageGenerationResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
data: list[ImageData]

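ImageEditsTaskParams holds the raw multipart upload, while ImageEditsInternalParams is the serializable form shipped to workers. A sketch of the conversion, assuming the helper name and the input_fidelity-to-image_strength mapping (both default to 0.7, but the mapping itself is not shown in this diff).

    import base64

    from exo.shared.types.api import ImageEditsInternalParams, ImageEditsTaskParams


    async def to_internal_params(params: ImageEditsTaskParams) -> ImageEditsInternalParams:
        # Read the multipart upload and carry it as base64 so the task can be
        # serialized; chunked transfer would instead set total_input_chunks and
        # leave image_data empty.
        raw = await params.image.read()
        return ImageEditsInternalParams(
            image_data=base64.b64encode(raw).decode("ascii"),
            prompt=params.prompt,
            model=params.model,
            n=params.n,
            quality=params.quality,
            output_format=params.output_format,
            response_format=params.response_format,
            size=params.size,
            image_strength=params.input_fidelity,  # assumed mapping
        )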
@@ -1,8 +1,11 @@
from collections.abc import Generator
from enum import Enum
from typing import Any

from exo.utils.pydantic_ext import TaggedModel

from .api import FinishReason
from .common import CommandId
from .models import ModelId


@@ -23,7 +26,34 @@ class TokenChunk(BaseChunk):


class ImageChunk(BaseChunk):
data: bytes
data: str
chunk_index: int
total_chunks: int
image_index: int
is_partial: bool = False
partial_index: int | None = None
total_partials: int | None = None

def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data":
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value


class InputImageChunk(BaseChunk):
command_id: CommandId
data: str
chunk_index: int
total_chunks: int

def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data":
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value


GenerationChunk = TokenChunk | ImageChunk

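ImageChunk carries base64 text rather than raw bytes so it serializes cleanly in events. A sketch of reassembling the final image from its chunks, assuming all non-partial chunks for one image have arrived; the helper is hypothetical.

    import base64

    from exo.shared.types.chunks import ImageChunk


    def reassemble_image(chunks: list[ImageChunk]) -> bytes:
        # Keep only final (non-partial) chunks and order them by chunk_index.
        final_chunks = sorted(
            (c for c in chunks if not c.is_partial), key=lambda c: c.chunk_index
        )
        if not final_chunks or len(final_chunks) != final_chunks[0].total_chunks:
            raise ValueError("incomplete image")
        return base64.b64decode("".join(c.data for c in final_chunks))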
@@ -1,6 +1,11 @@
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import ChatCompletionTaskParams
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
@@ -20,6 +25,14 @@ class ChatCompletion(BaseCommand):
|
||||
request_params: ChatCompletionTaskParams
|
||||
|
||||
|
||||
class ImageGeneration(BaseCommand):
|
||||
request_params: ImageGenerationTaskParams
|
||||
|
||||
|
||||
class ImageEdits(BaseCommand):
|
||||
request_params: ImageEditsInternalParams
|
||||
|
||||
|
||||
class PlaceInstance(BaseCommand):
|
||||
model_meta: ModelMetadata
|
||||
sharding: Sharding
|
||||
@@ -39,6 +52,12 @@ class TaskFinished(BaseCommand):
|
||||
finished_command_id: CommandId
|
||||
|
||||
|
||||
class SendInputChunk(BaseCommand):
|
||||
"""Command to send an input image chunk (converted to event by master)."""
|
||||
|
||||
chunk: InputImageChunk
|
||||
|
||||
|
||||
class RequestEventLog(BaseCommand):
|
||||
since_idx: int
|
||||
|
||||
@@ -47,10 +66,13 @@ Command = (
|
||||
TestCommand
|
||||
| RequestEventLog
|
||||
| ChatCompletion
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| PlaceInstance
|
||||
| CreateInstance
|
||||
| DeleteInstance
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from datetime import datetime
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.topology import Connection, NodePerformanceProfile
|
||||
from exo.shared.types.chunks import GenerationChunk
|
||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
@@ -106,6 +106,11 @@ class ChunkGenerated(BaseEvent):
|
||||
chunk: GenerationChunk
|
||||
|
||||
|
||||
class InputChunkReceived(BaseEvent):
|
||||
command_id: CommandId
|
||||
chunk: InputImageChunk
|
||||
|
||||
|
||||
class TopologyEdgeCreated(BaseEvent):
|
||||
edge: Connection
|
||||
|
||||
@@ -131,6 +136,7 @@ Event = (
|
||||
| NodeMemoryMeasured
|
||||
| NodeDownloadProgress
|
||||
| ChunkGenerated
|
||||
| InputChunkReceived
|
||||
| TopologyEdgeCreated
|
||||
| TopologyEdgeDeleted
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
from enum import Enum

from pydantic import PositiveInt

from exo.shared.types.common import Id
@@ -9,6 +11,21 @@ class ModelId(Id):
pass


class ModelTask(str, Enum):
TextGeneration = "TextGeneration"
TextToImage = "TextToImage"
ImageToImage = "ImageToImage"


class ComponentInfo(CamelCaseModel):
component_name: str
component_path: str
storage_size: Memory
n_layers: PositiveInt | None
can_shard: bool
safetensors_index_filename: str | None


class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
@@ -16,3 +33,4 @@ class ModelMetadata(CamelCaseModel):
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
components: list[ComponentInfo] | None = None

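ComponentInfo lets multi-component models (text encoders, transformer, VAE) declare which parts can be split across nodes. Two hypothetical helpers showing how a consumer might separate the shardable component from the replicated ones, mirroring the next(...) lookup used later in get_allow_patterns.

    from exo.shared.types.models import ComponentInfo, ModelMetadata


    def shardable_component(meta: ModelMetadata) -> ComponentInfo | None:
        # Only components flagged can_shard (e.g. the transformer) are split
        # across nodes; everything else is replicated on each node.
        if meta.components is None:
            return None
        return next((c for c in meta.components if c.can_shard), None)


    def replicated_components(meta: ModelMetadata) -> list[ComponentInfo]:
        if meta.components is None:
            return []
        return [c for c in meta.components if not c.can_shard]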
@@ -2,7 +2,11 @@ from enum import Enum
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import ChatCompletionTaskParams
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id
|
||||
from exo.shared.types.worker.instances import BoundInstance, InstanceId
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
@@ -56,6 +60,22 @@ class ChatCompletion(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ImageGeneration(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageGenerationTaskParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ImageEdits(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageEditsInternalParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class Shutdown(BaseTask): # emitted by Worker
|
||||
runner_id: RunnerId
|
||||
|
||||
@@ -67,5 +87,7 @@ Task = (
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| ChatCompletion
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| Shutdown
|
||||
)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
@@ -17,5 +20,31 @@ class GenerationResponse(BaseRunnerResponse):
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseRunnerResponse):
|
||||
image_data: bytes
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "image_data":
|
||||
yield name, f"<{len(self.image_data)} bytes>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
|
||||
|
||||
class PartialImageResponse(BaseRunnerResponse):
|
||||
image_data: bytes
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
partial_index: int
|
||||
total_partials: int
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "image_data":
|
||||
yield name, f"<{len(self.image_data)} bytes>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
pass
|
||||
|
||||
@@ -53,6 +53,10 @@ class RunnerRunning(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class RunnerShuttingDown(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class RunnerShutdown(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
@@ -70,6 +74,7 @@ RunnerStatus = (
|
||||
| RunnerWarmingUp
|
||||
| RunnerReady
|
||||
| RunnerRunning
|
||||
| RunnerShuttingDown
|
||||
| RunnerShutdown
|
||||
| RunnerFailed
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal
from urllib.parse import urljoin
from huggingface_hub._snapshot_download import snapshot_download

import aiofiles
import aiofiles.os as aios
@@ -441,15 +442,39 @@ def calculate_repo_progress(
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_file = await download_file_with_retry(
repo_id, revision, "model.safetensors.index.json", target_dir

index_files_dir = snapshot_download(
repo_id=repo_id, local_dir=target_dir, allow_patterns="*.safetensors.index.json"
)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
return index_data.weight_map

index_files = list(Path(index_files_dir).glob("**/*.safetensors.index.json"))

weight_map: dict[str, str] = {}

for index_file in index_files:
relative_dir = index_file.parent.relative_to(index_files_dir)

async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())

if relative_dir != Path("."):
prefixed_weight_map = {
f"{relative_dir}/{key}": str(relative_dir / value)
for key, value in index_data.weight_map.items()
}
weight_map = weight_map | prefixed_weight_map
else:
weight_map = weight_map | index_data.weight_map

return weight_map


async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# TODO: 'Smart' downloads are disabled because:
# (i) We don't handle all kinds of files;
# (ii) We don't have sticky sessions.
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)
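The loop above namespaces every tensor and shard file by the directory its index file lives in. An illustration with made-up tensor and file names showing the shape of the merged weight map.

    # Hypothetical contents of two index files found by the glob above.
    text_encoder_2_index = {
        "encoder.block.0.weight": "model-00001-of-00002.safetensors",
    }
    transformer_index = {
        "transformer_blocks.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
    }

    # After prefixing each entry with its component directory, the merged
    # weight map keeps the component path on both keys and values:
    expected = {
        "text_encoder_2/encoder.block.0.weight": "text_encoder_2/model-00001-of-00002.safetensors",
        "transformer/transformer_blocks.0.attn.to_q.weight": "transformer/diffusion_pytorch_model-00001-of-00003.safetensors",
    }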
@@ -546,8 +571,6 @@ async def download_shard(
|
||||
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
|
||||
|
||||
all_start_time = time.time()
|
||||
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
|
||||
# Update: <- This does not seem to be the case. Yay?
|
||||
file_list = await fetch_file_list_with_cache(
|
||||
str(shard.model_meta.model_id), revision, recursive=True
|
||||
)
|
||||
|
||||
@@ -100,26 +100,68 @@ def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"tiktoken.model",
|
||||
"*/spiece.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
"*.jinja",
|
||||
]
|
||||
)
|
||||
shard_specific_patterns: set[str] = set()
|
||||
if weight_map:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
layer_num = extract_layer_num(tensor_name)
|
||||
if (
|
||||
layer_num is not None
|
||||
and shard.start_layer <= layer_num <= shard.end_layer
|
||||
):
|
||||
shard_specific_patterns.add(filename)
|
||||
layer_independent_files = set(
|
||||
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
|
||||
|
||||
if shard.model_meta.components is not None:
|
||||
shardable_component = next(
|
||||
(c for c in shard.model_meta.components if c.can_shard), None
|
||||
)
|
||||
shard_specific_patterns.update(layer_independent_files)
|
||||
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
|
||||
|
||||
if weight_map and shardable_component:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
# Strip component prefix from tensor name (added by weight map namespacing)
|
||||
# E.g., "transformer/blocks.0.weight" -> "blocks.0.weight"
|
||||
if "/" in tensor_name:
|
||||
_, tensor_name_no_prefix = tensor_name.split("/", 1)
|
||||
else:
|
||||
tensor_name_no_prefix = tensor_name
|
||||
|
||||
# Determine which component this file belongs to from filename
|
||||
component_path = Path(filename).parts[0] if "/" in filename else None
|
||||
|
||||
if component_path == shardable_component.component_path.rstrip("/"):
|
||||
layer_num = extract_layer_num(tensor_name_no_prefix)
|
||||
if (
|
||||
layer_num is not None
|
||||
and shard.start_layer <= layer_num < shard.end_layer
|
||||
):
|
||||
shard_specific_patterns.add(filename)
|
||||
|
||||
if shard.is_first_layer or shard.is_last_layer:
|
||||
shard_specific_patterns.add(filename)
|
||||
else:
|
||||
shard_specific_patterns.add(filename)
|
||||
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
|
||||
# TODO(ciaran): temporary - Include all files from non-shardable components that have no index file
|
||||
for component in shard.model_meta.components:
|
||||
if not component.can_shard and component.safetensors_index_filename is None:
|
||||
component_pattern = f"{component.component_path.rstrip('/')}/*"
|
||||
shard_specific_patterns.add(component_pattern)
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
if weight_map:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
layer_num = extract_layer_num(tensor_name)
|
||||
if (
|
||||
layer_num is not None
|
||||
and shard.start_layer <= layer_num < shard.end_layer
|
||||
):
|
||||
shard_specific_patterns.add(filename)
|
||||
layer_independent_files = set(
|
||||
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
|
||||
)
|
||||
shard_specific_patterns.update(layer_independent_files)
|
||||
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
|
||||
logger.info(f"get_allow_patterns {shard=} {shard_specific_patterns=}")
|
||||
return list(default_patterns | shard_specific_patterns)
|
||||
|
||||
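A toy walkthrough of the component-aware branch above, with hypothetical names: the component prefix is stripped before extracting the layer number, and the file's top-level directory decides whether the shardable-component range check applies.

    from pathlib import Path

    # One entry from the namespaced weight map (names hypothetical).
    tensor_name = "transformer/transformer_blocks.7.attn.to_q.weight"
    filename = "transformer/diffusion_pytorch_model-00001-of-00003.safetensors"

    # Strip the component prefix before looking for a layer number...
    _, tensor_name_no_prefix = tensor_name.split("/", 1)

    # ...and key the component check off the file's top-level directory.
    component_path = Path(filename).parts[0]  # "transformer"

    # A shard assigned layers [0, 19) would keep this file, since 0 <= 7 < 19.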
src/exo/worker/engines/image/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
|
||||
from exo.worker.engines.image.base import ImageGenerator
|
||||
from exo.worker.engines.image.distributed_model import initialize_image_model
|
||||
from exo.worker.engines.image.generate import generate_image, warmup_image_generator
|
||||
|
||||
__all__ = [
|
||||
"ImageGenerator",
|
||||
"generate_image",
|
||||
"initialize_image_model",
|
||||
"warmup_image_generator",
|
||||
]
|
||||
src/exo/worker/engines/image/base.py (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ImageGenerator(Protocol):
|
||||
@property
|
||||
def rank(self) -> int: ...
|
||||
|
||||
@property
|
||||
def is_first_stage(self) -> bool: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
quality: Literal["low", "medium", "high"],
|
||||
seed: int,
|
||||
image_path: Path | None = None,
|
||||
partial_images: int = 0,
|
||||
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
|
||||
"""Generate an image from a text prompt, or edit an existing image.
|
||||
|
||||
For distributed inference, only the last stage returns images.
|
||||
Other stages yield nothing after participating in the pipeline.
|
||||
|
||||
When partial_images > 0, yields intermediate images during diffusion
|
||||
as tuples of (image, partial_index, total_partials), then yields
|
||||
the final image.
|
||||
|
||||
When partial_images = 0 (default), only yields the final image.
|
||||
|
||||
Args:
|
||||
prompt: Text description of the image to generate
|
||||
height: Image height in pixels
|
||||
width: Image width in pixels
|
||||
quality: Generation quality level
|
||||
seed: Random seed for reproducibility
|
||||
image_path: Optional path to input image for image editing
|
||||
partial_images: Number of intermediate images to yield (0 for none)
|
||||
|
||||
Yields:
|
||||
Intermediate images as (Image, partial_index, total_partials) tuples
|
||||
Final PIL Image (last stage) or nothing (other stages)
|
||||
"""
|
||||
...
|
||||
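A minimal consumption sketch for the Protocol above, assuming model is any ImageGenerator: tuples are intermediate previews, a bare Image is the final result, and non-last pipeline stages yield nothing.

    from PIL import Image

    from exo.worker.engines.image.base import ImageGenerator


    def run_generation(model: ImageGenerator, prompt: str) -> Image.Image | None:
        final: Image.Image | None = None
        for result in model.generate(
            prompt=prompt, height=512, width=512, quality="medium", seed=0, partial_images=2
        ):
            if isinstance(result, tuple):
                image, idx, total = result  # intermediate preview
                print(f"partial {idx + 1}/{total}: {image.size}")
            else:
                final = result  # only the last pipeline stage yields this
        return final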
src/exo/worker/engines/image/config.py (new file, 74 lines)
@@ -0,0 +1,74 @@
|
||||
from enum import Enum
|
||||
from math import ceil
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BlockType(Enum):
|
||||
JOINT = "joint" # Separate image/text streams
|
||||
SINGLE = "single" # Concatenated streams
|
||||
|
||||
|
||||
class TransformerBlockConfig(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
|
||||
block_type: BlockType
|
||||
count: int
|
||||
has_separate_text_output: bool # True for joint blocks that output text separately
|
||||
|
||||
|
||||
class ImageModelConfig(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
|
||||
# Model identification
|
||||
model_family: str # "flux", "fibo", "qwen"
|
||||
model_variant: str # "schnell", "dev", etc.
|
||||
|
||||
# Architecture parameters
|
||||
hidden_dim: int
|
||||
num_heads: int
|
||||
head_dim: int
|
||||
|
||||
# Block configuration - ordered sequence of block types
|
||||
block_configs: tuple[TransformerBlockConfig, ...]
|
||||
|
||||
# Tokenization parameters
|
||||
patch_size: int # 2 for Flux/Qwen
|
||||
vae_scale_factor: int # 8 for Flux, 16 for others
|
||||
|
||||
# Inference parameters
|
||||
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
|
||||
num_sync_steps_factor: float # Fraction of steps for sync phase
|
||||
|
||||
# Feature flags
|
||||
uses_attention_mask: bool # True for Fibo
|
||||
|
||||
# CFG (Classifier-Free Guidance) parameters
|
||||
guidance_scale: float | None = None # None or <= 1.0 disables CFG
|
||||
|
||||
@property
|
||||
def total_blocks(self) -> int:
|
||||
"""Total number of transformer blocks."""
|
||||
return sum(bc.count for bc in self.block_configs)
|
||||
|
||||
@property
|
||||
def joint_block_count(self) -> int:
|
||||
"""Number of joint transformer blocks."""
|
||||
return sum(
|
||||
bc.count for bc in self.block_configs if bc.block_type == BlockType.JOINT
|
||||
)
|
||||
|
||||
@property
|
||||
def single_block_count(self) -> int:
|
||||
"""Number of single transformer blocks."""
|
||||
return sum(
|
||||
bc.count for bc in self.block_configs if bc.block_type == BlockType.SINGLE
|
||||
)
|
||||
|
||||
def get_steps_for_quality(self, quality: str) -> int:
|
||||
"""Get inference steps for a quality level."""
|
||||
return self.default_steps[quality]
|
||||
|
||||
def get_num_sync_steps(self, quality: str) -> int:
|
||||
"""Get number of synchronous steps based on quality."""
|
||||
return ceil(self.default_steps[quality] * self.num_sync_steps_factor)
|
||||
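To make the fields concrete, a hypothetical FLUX-style config; the numbers are illustrative and are not the values shipped in FLUX_SCHNELL_CONFIG.

    from exo.worker.engines.image.config import (
        BlockType,
        ImageModelConfig,
        TransformerBlockConfig,
    )

    EXAMPLE_CONFIG = ImageModelConfig(
        model_family="flux",
        model_variant="schnell",
        hidden_dim=3072,
        num_heads=24,
        head_dim=128,
        block_configs=(
            TransformerBlockConfig(block_type=BlockType.JOINT, count=19, has_separate_text_output=True),
            TransformerBlockConfig(block_type=BlockType.SINGLE, count=38, has_separate_text_output=False),
        ),
        patch_size=2,
        vae_scale_factor=8,
        default_steps={"low": 2, "medium": 4, "high": 8},
        num_sync_steps_factor=0.25,
        uses_attention_mask=False,
    )

    assert EXAMPLE_CONFIG.total_blocks == 57  # 19 joint + 38 single, as in the model card
    assert EXAMPLE_CONFIG.get_num_sync_steps("medium") == 1  # ceil(4 * 0.25)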
src/exo/worker/engines/image/distributed_model.py (new file, 228 lines)
@@ -0,0 +1,228 @@
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.config import Config
|
||||
from PIL import Image
|
||||
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.download.download_utils import build_model_path
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models import (
|
||||
create_adapter_for_model,
|
||||
get_config_for_model,
|
||||
)
|
||||
from exo.worker.engines.image.models.base import BaseModelAdapter
|
||||
from exo.worker.engines.image.pipeline import DiffusionRunner
|
||||
from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
class DistributedImageModel:
|
||||
__slots__ = (
|
||||
"_config",
|
||||
"_adapter",
|
||||
"_group",
|
||||
"_shard_metadata",
|
||||
"_runner",
|
||||
)
|
||||
|
||||
_config: ImageModelConfig
|
||||
_adapter: BaseModelAdapter
|
||||
_group: Optional[mx.distributed.Group]
|
||||
_shard_metadata: PipelineShardMetadata
|
||||
_runner: DiffusionRunner
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
shard_metadata: PipelineShardMetadata,
|
||||
group: Optional[mx.distributed.Group] = None,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
# Get model config and create adapter (adapter owns the model)
|
||||
config = get_config_for_model(model_id)
|
||||
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
|
||||
|
||||
if group is not None:
|
||||
adapter.slice_transformer_blocks(
|
||||
start_layer=shard_metadata.start_layer,
|
||||
end_layer=shard_metadata.end_layer,
|
||||
total_joint_blocks=config.joint_block_count,
|
||||
total_single_blocks=config.single_block_count,
|
||||
)
|
||||
|
||||
# Create diffusion runner (handles both single-node and distributed modes)
|
||||
num_sync_steps = config.get_num_sync_steps("medium") if group else 0
|
||||
runner = DiffusionRunner(
|
||||
config=config,
|
||||
adapter=adapter,
|
||||
group=group,
|
||||
shard_metadata=shard_metadata,
|
||||
num_sync_steps=num_sync_steps,
|
||||
)
|
||||
|
||||
if group is not None:
|
||||
logger.info("Initialized distributed diffusion runner")
|
||||
|
||||
mx.eval(adapter.model.parameters())
|
||||
|
||||
# TODO(ciaran): Do we need this?
|
||||
mx.eval(adapter.model)
|
||||
|
||||
# Synchronize processes before generation to avoid timeout
|
||||
mx_barrier(group)
|
||||
logger.info(f"Transformer sharded for rank {group.rank()}")
|
||||
else:
|
||||
logger.info("Single-node initialization")
|
||||
|
||||
object.__setattr__(self, "_config", config)
|
||||
object.__setattr__(self, "_adapter", adapter)
|
||||
object.__setattr__(self, "_group", group)
|
||||
object.__setattr__(self, "_shard_metadata", shard_metadata)
|
||||
object.__setattr__(self, "_runner", runner)
|
||||
|
||||
@classmethod
|
||||
def from_bound_instance(
|
||||
cls, bound_instance: BoundInstance
|
||||
) -> "DistributedImageModel":
|
||||
model_id = bound_instance.bound_shard.model_meta.model_id
|
||||
model_path = build_model_path(model_id)
|
||||
|
||||
shard_metadata = bound_instance.bound_shard
|
||||
if not isinstance(shard_metadata, PipelineShardMetadata):
|
||||
raise ValueError("Expected PipelineShardMetadata for image generation")
|
||||
|
||||
is_distributed = (
|
||||
len(bound_instance.instance.shard_assignments.node_to_runner) > 1
|
||||
)
|
||||
|
||||
if is_distributed:
|
||||
logger.info("Starting distributed init for image model")
|
||||
group = mlx_distributed_init(bound_instance)
|
||||
else:
|
||||
group = None
|
||||
|
||||
return cls(
|
||||
model_id=model_id,
|
||||
local_path=model_path,
|
||||
shard_metadata=shard_metadata,
|
||||
group=group,
|
||||
)
|
||||
|
||||
@property
|
||||
def model(self) -> Any:
|
||||
"""Return the underlying mflux model via the adapter."""
|
||||
return self._adapter.model
|
||||
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def adapter(self) -> BaseModelAdapter:
|
||||
return self._adapter
|
||||
|
||||
@property
|
||||
def group(self) -> Optional[mx.distributed.Group]:
|
||||
return self._group
|
||||
|
||||
@property
|
||||
def shard_metadata(self) -> PipelineShardMetadata:
|
||||
return self._shard_metadata
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
return self._shard_metadata.device_rank
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return self._shard_metadata.world_size
|
||||
|
||||
@property
|
||||
def is_first_stage(self) -> bool:
|
||||
return self._shard_metadata.device_rank == 0
|
||||
|
||||
@property
|
||||
def is_last_stage(self) -> bool:
|
||||
return self._shard_metadata.device_rank == self._shard_metadata.world_size - 1
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return self._shard_metadata.world_size > 1
|
||||
|
||||
@property
|
||||
def runner(self) -> DiffusionRunner:
|
||||
return self._runner
|
||||
|
||||
# Delegate attribute access to the underlying model via the adapter.
|
||||
# Guarded with TYPE_CHECKING to prevent type checker complaints
|
||||
# while still providing full delegation at runtime.
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._adapter.model, name)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in (
|
||||
"_config",
|
||||
"_adapter",
|
||||
"_group",
|
||||
"_shard_metadata",
|
||||
"_runner",
|
||||
):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
setattr(self._adapter.model, name, value)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
quality: Literal["low", "medium", "high"] = "medium",
|
||||
seed: int = 2,
|
||||
image_path: Path | None = None,
|
||||
partial_images: int = 0,
|
||||
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
|
||||
# Determine number of inference steps based on quality
|
||||
steps = self._config.get_steps_for_quality(quality)
|
||||
|
||||
# For edit mode: compute dimensions from input image
|
||||
# This also stores image_paths in the adapter for encode_prompt()
|
||||
if image_path is not None:
|
||||
computed_dims = self._adapter.set_image_dimensions(image_path)
|
||||
if computed_dims is not None:
|
||||
# Override user-provided dimensions with computed ones
|
||||
width, height = computed_dims
|
||||
|
||||
config = Config(
|
||||
num_inference_steps=steps,
|
||||
height=height,
|
||||
width=width,
|
||||
image_path=image_path,
|
||||
)
|
||||
|
||||
# Generate images via the runner
|
||||
for result in self._runner.generate_image(
|
||||
settings=config,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
partial_images=partial_images,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (GeneratedImage, partial_index, total_partials)
|
||||
generated_image, partial_idx, total_partials = result
|
||||
yield (generated_image.image, partial_idx, total_partials)
|
||||
else:
|
||||
# Final image: GeneratedImage
|
||||
logger.info("generated image")
|
||||
yield result.image
|
||||
|
||||
|
||||
def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:
|
||||
"""Initialize DistributedImageModel from a BoundInstance."""
|
||||
return DistributedImageModel.from_bound_instance(bound_instance)
|
||||
src/exo/worker/engines/image/generate.py (new file, 120 lines)
@@ -0,0 +1,120 @@
|
||||
import base64
|
||||
import io
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Generator, Literal
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from exo.shared.types.api import ImageEditsInternalParams, ImageGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
ImageGenerationResponse,
|
||||
PartialImageResponse,
|
||||
)
|
||||
from exo.worker.engines.image.base import ImageGenerator
|
||||
|
||||
|
||||
def parse_size(size_str: str | None) -> tuple[int, int]:
|
||||
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
|
||||
if not size_str or size_str == "auto":
|
||||
size_str = "1024x1024"
|
||||
|
||||
try:
|
||||
parts = size_str.split("x")
|
||||
if len(parts) == 2:
|
||||
width, height = int(parts[0]), int(parts[1])
|
||||
return (width, height)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
# Default fallback
|
||||
return (1024, 1024)
|
||||
|
||||
|
||||
def warmup_image_generator(model: ImageGenerator) -> Image.Image | None:
|
||||
"""Warmup the image generator with a small image."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Create a small dummy image for warmup (needed for edit models)
|
||||
dummy_image = Image.new("RGB", (256, 256), color=(128, 128, 128))
|
||||
dummy_path = Path(tmpdir) / "warmup.png"
|
||||
dummy_image.save(dummy_path)
|
||||
|
||||
for result in model.generate(
|
||||
prompt="Warmup",
|
||||
height=256,
|
||||
width=256,
|
||||
quality="low",
|
||||
seed=2,
|
||||
image_path=dummy_path,
|
||||
):
|
||||
if not isinstance(result, tuple):
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def generate_image(
|
||||
model: ImageGenerator,
|
||||
task: ImageGenerationTaskParams | ImageEditsInternalParams,
|
||||
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
|
||||
"""Generate image(s), optionally yielding partial results.
|
||||
|
||||
When partial_images > 0 or stream=True, yields PartialImageResponse for
|
||||
intermediate images, then ImageGenerationResponse for the final image.
|
||||
|
||||
Yields:
|
||||
PartialImageResponse for intermediate images (if partial_images > 0)
|
||||
ImageGenerationResponse for the final complete image
|
||||
"""
|
||||
width, height = parse_size(task.size)
|
||||
quality: Literal["low", "medium", "high"] = task.quality or "medium"
|
||||
seed = 2 # TODO(ciaran): Randomise when not testing anymore
|
||||
|
||||
# Handle streaming params for both generation and edit tasks
|
||||
partial_images = task.partial_images or (3 if task.stream else 0)
|
||||
|
||||
image_path: Path | None = None
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
if isinstance(task, ImageEditsInternalParams):
|
||||
# Decode base64 image data and save to temp file
|
||||
image_path = Path(tmpdir) / "input.png"
|
||||
image_path.write_bytes(base64.b64decode(task.image_data))
|
||||
|
||||
# Iterate over generator results
|
||||
for result in model.generate(
|
||||
prompt=task.prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
quality=quality,
|
||||
seed=seed,
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
image, partial_idx, total_partials = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield PartialImageResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
partial_index=partial_idx,
|
||||
total_partials=total_partials,
|
||||
)
|
||||
else:
|
||||
# Final image
|
||||
image = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield ImageGenerationResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
)
|
||||
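A sketch of how a caller might consume generate_image, assuming the surrounding imports; partial previews are skipped here and only final images are folded into the OpenAI-style ImageData payload.

    import base64

    from exo.shared.types.api import ImageData, ImageGenerationTaskParams
    from exo.shared.types.worker.runner_response import (
        ImageGenerationResponse,
        PartialImageResponse,
    )
    from exo.worker.engines.image import ImageGenerator, generate_image


    def collect_images(model: ImageGenerator, params: ImageGenerationTaskParams) -> list[ImageData]:
        images: list[ImageData] = []
        for response in generate_image(model, params):
            if isinstance(response, PartialImageResponse):
                # Intermediate previews could be streamed to the client here.
                continue
            if isinstance(response, ImageGenerationResponse):
                images.append(
                    ImageData(b64_json=base64.b64encode(response.image_data).decode("ascii"))
                )
        return images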
src/exo/worker/engines/image/models/__init__.py (new file, 84 lines)
@@ -0,0 +1,84 @@
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.flux import (
|
||||
FLUX_DEV_CONFIG,
|
||||
FLUX_SCHNELL_CONFIG,
|
||||
FluxModelAdapter,
|
||||
)
|
||||
from exo.worker.engines.image.models.qwen import (
|
||||
QWEN_IMAGE_CONFIG,
|
||||
QWEN_IMAGE_EDIT_CONFIG,
|
||||
QwenEditModelAdapter,
|
||||
QwenModelAdapter,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.adapter import ModelAdapter
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
# Type alias for adapter factory functions
|
||||
# Factory takes (config, model_id, local_path, quantize) and returns a ModelAdapter
|
||||
AdapterFactory = Callable[[ImageModelConfig, str, Path, int | None], ModelAdapter]
|
||||
|
||||
# Registry maps model_family string to adapter factory
|
||||
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
|
||||
"flux": FluxModelAdapter,
|
||||
"qwen-edit": QwenEditModelAdapter,
|
||||
"qwen": QwenModelAdapter,
|
||||
}
|
||||
|
||||
# Config registry: maps model ID patterns to configs
|
||||
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
|
||||
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
|
||||
"flux.1-dev": FLUX_DEV_CONFIG,
|
||||
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
|
||||
"qwen-image": QWEN_IMAGE_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
def get_config_for_model(model_id: str) -> ImageModelConfig:
|
||||
"""Get configuration for a model ID.
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (e.g., "black-forest-labs/FLUX.1-schnell")
|
||||
|
||||
Returns:
|
||||
The model configuration
|
||||
|
||||
Raises:
|
||||
ValueError: If no configuration found for model ID
|
||||
"""
|
||||
model_id_lower = model_id.lower()
|
||||
|
||||
for pattern, config in _CONFIG_REGISTRY.items():
|
||||
if pattern in model_id_lower:
|
||||
return config
|
||||
|
||||
raise ValueError(f"No configuration found for model: {model_id}")
|
||||
|
||||
|
||||
def create_adapter_for_model(
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
) -> ModelAdapter:
|
||||
"""Create a model adapter for the given configuration.
|
||||
|
||||
Args:
|
||||
config: The model configuration
|
||||
model_id: The model identifier
|
||||
local_path: Path to the model weights
|
||||
quantize: Optional quantization bits
|
||||
|
||||
Returns:
|
||||
A ModelAdapter instance
|
||||
|
||||
Raises:
|
||||
ValueError: If no adapter found for model family
|
||||
"""
|
||||
factory = _ADAPTER_REGISTRY.get(config.model_family)
|
||||
if factory is None:
|
||||
raise ValueError(f"No adapter found for model family: {config.model_family}")
|
||||
return factory(config, model_id, local_path, quantize)
|
||||
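How the two registries are meant to be used together; the local path below is a placeholder, not a real download location.

    from pathlib import Path

    from exo.worker.engines.image.models import (
        create_adapter_for_model,
        get_config_for_model,
    )

    model_id = "black-forest-labs/FLUX.1-schnell"
    config = get_config_for_model(model_id)      # substring match against _CONFIG_REGISTRY
    adapter = create_adapter_for_model(          # factory looked up by config.model_family
        config,
        model_id,
        Path("/models/black-forest-labs--FLUX.1-schnell"),  # placeholder local path
        None,  # no quantization
    )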
src/exo/worker/engines/image/models/base.py (new file, 103 lines)
@@ -0,0 +1,103 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any

import mlx.core as mx
from mflux.config.runtime_config import RuntimeConfig
from mflux.models.common.latent_creator.latent_creator import Img2Img, LatentCreator
from mflux.utils.array_util import ArrayUtil
from mflux.utils.image_util import ImageUtil


class BaseModelAdapter(ABC):
    """Base class for model adapters with shared utilities.

    Provides common implementations for latent creation and decoding.
    Subclasses implement model-specific prompt encoding and noise computation.
    """

    def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
        """Create initial latents. Uses model-specific latent creator."""
        return LatentCreator.create_for_txt2img_or_img2img(
            seed=seed,
            height=runtime_config.height,
            width=runtime_config.width,
            img2img=Img2Img(
                vae=self.model.vae,
                latent_creator=self._get_latent_creator(),
                sigmas=runtime_config.scheduler.sigmas,
                init_time_step=runtime_config.init_time_step,
                image_path=runtime_config.image_path,
            ),
        )

    def decode_latents(
        self,
        latents: mx.array,
        runtime_config: RuntimeConfig,
        seed: int,
        prompt: str,
    ) -> Any:
        """Decode latents to image. Shared implementation."""
        latents = ArrayUtil.unpack_latents(
            latents=latents,
            height=runtime_config.height,
            width=runtime_config.width,
        )
        decoded = self.model.vae.decode(latents)
        return ImageUtil.to_image(
            decoded_latents=decoded,
            config=runtime_config,
            seed=seed,
            prompt=prompt,
            quantization=self.model.bits,
            lora_paths=self.model.lora_paths,
            lora_scales=self.model.lora_scales,
            image_path=runtime_config.image_path,
            image_strength=runtime_config.image_strength,
            generation_time=0,
        )

    # Abstract methods - subclasses must implement

    @property
    @abstractmethod
    def model(self) -> Any:
        """Return the underlying mflux model."""
        ...

    @abstractmethod
    def _get_latent_creator(self) -> type:
        """Return the latent creator class for this model."""
        ...

    @abstractmethod
    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
        total_joint_blocks: int,
        total_single_blocks: int,
    ):
        """Remove transformer blocks outside the assigned range.

        This should be called BEFORE mx.eval() to avoid loading unused weights
        in distributed mode.

        Args:
            start_layer: First layer index (inclusive) assigned to this node
            end_layer: Last layer index (exclusive) assigned to this node
            total_joint_blocks: Total number of joint blocks in the model
            total_single_blocks: Total number of single blocks in the model
        """
        ...

    def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
        """Default implementation: no dimension computation needed.

        Override in edit adapters to compute dimensions from input image.

        Returns:
            None (use user-specified dimensions)
        """
        return None
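For orientation, a stripped-down sketch of what a concrete subclass must supply on top of the shared create_latents/decode_latents helpers. The class and its wiring here are illustrative only; the real adapters follow in flux/adapter.py and qwen/adapter.py below.

from typing import Any

import mlx.core as mx
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator


class MinimalAdapter(BaseModelAdapter):
    """Hypothetical adapter: only the abstract hooks are filled in."""

    def __init__(self, model: Any):
        self._model = model  # an already-constructed mflux model

    @property
    def model(self) -> Any:
        return self._model

    def _get_latent_creator(self) -> type:
        return FluxLatentCreator

    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
        total_joint_blocks: int,
        total_single_blocks: int,
    ) -> None:
        # Keep only the contiguous range of blocks assigned to this node.
        blocks = list(self._model.transformer.transformer_blocks)
        self._model.transformer.transformer_blocks = blocks[start_layer:end_layer]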
11 src/exo/worker/engines/image/models/flux/__init__.py Normal file
@@ -0,0 +1,11 @@
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
from exo.worker.engines.image.models.flux.config import (
    FLUX_DEV_CONFIG,
    FLUX_SCHNELL_CONFIG,
)

__all__ = [
    "FluxModelAdapter",
    "FLUX_DEV_CONFIG",
    "FLUX_SCHNELL_CONFIG",
]
680 src/exo/worker/engines/image/models/flux/adapter.py Normal file
@@ -0,0 +1,680 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.model_config import ModelConfig
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
|
||||
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
|
||||
from mflux.models.flux.model.flux_transformer.common.attention_utils import (
|
||||
AttentionUtils,
|
||||
)
|
||||
from mflux.models.flux.model.flux_transformer.joint_transformer_block import (
|
||||
JointTransformerBlock,
|
||||
)
|
||||
from mflux.models.flux.model.flux_transformer.transformer import Transformer
|
||||
from mflux.models.flux.variants.txt2img.flux import Flux1
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import BaseModelAdapter
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
JointBlockInterface,
|
||||
SingleBlockInterface,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
class FluxPromptData:
|
||||
"""Container for Flux prompt encoding results."""
|
||||
|
||||
def __init__(self, prompt_embeds: mx.array, pooled_prompt_embeds: mx.array):
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self._pooled_prompt_embeds = pooled_prompt_embeds
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
return self._pooled_prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array | None:
|
||||
"""Flux does not use CFG."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array | None:
|
||||
"""Flux does not use CFG."""
|
||||
return None
|
||||
|
||||
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
|
||||
"""Flux has no extra forward kwargs."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
"""Flux does not use conditioning latents."""
|
||||
return None
|
||||
|
||||
|
||||
class FluxModelAdapter(BaseModelAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
self._config = config
|
||||
self._model = Flux1(
|
||||
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
|
||||
local_path=str(local_path),
|
||||
quantize=quantize,
|
||||
)
|
||||
self._transformer = self._model.transformer
|
||||
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def model(self) -> Flux1:
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def transformer(self) -> Transformer:
|
||||
return self._transformer
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
return self._transformer.x_embedder.weight.shape[0]
|
||||
|
||||
def _get_latent_creator(self) -> type:
|
||||
return FluxLatentCreator
|
||||
|
||||
def encode_prompt(self, prompt: str) -> FluxPromptData:
|
||||
"""Encode prompt into FluxPromptData."""
|
||||
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_cache=self._model.prompt_cache,
|
||||
t5_tokenizer=self._model.t5_tokenizer,
|
||||
clip_tokenizer=self._model.clip_tokenizer,
|
||||
t5_text_encoder=self._model.t5_text_encoder,
|
||||
clip_text_encoder=self._model.clip_text_encoder,
|
||||
)
|
||||
return FluxPromptData(
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
@property
|
||||
def needs_cfg(self) -> bool:
|
||||
return False
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
raise NotImplementedError("Flux does not use classifier-free guidance")
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
embedded_hidden = self._transformer.x_embedder(hidden_states)
|
||||
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
|
||||
return embedded_hidden, embedded_encoder
|
||||
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: RuntimeConfig,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None, # Ignored by Flux
|
||||
) -> mx.array:
|
||||
if pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"pooled_prompt_embeds is required for Flux text embeddings"
|
||||
)
|
||||
|
||||
# hidden_states is ignored - Flux uses pooled_prompt_embeds instead
|
||||
return Transformer.compute_text_embeddings(
|
||||
t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
|
||||
)
|
||||
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
**kwargs: Any,
|
||||
) -> mx.array:
|
||||
kontext_image_ids = kwargs.get("kontext_image_ids")
|
||||
return Transformer.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
self._transformer.pos_embed,
|
||||
runtime_config,
|
||||
kontext_image_ids,
|
||||
)
|
||||
|
||||
def apply_joint_block(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: Any, # mx.array for Flux
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
if mode == BlockWrapperMode.CACHING:
|
||||
return self._apply_joint_block_caching(
|
||||
block=block,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
)
|
||||
else:
|
||||
assert patch_start is not None and patch_end is not None
|
||||
assert kv_cache is not None
|
||||
return self._apply_joint_block_patched(
|
||||
block=block,
|
||||
patch_hidden=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
)
|
||||
|
||||
def apply_single_block(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
) -> mx.array:
|
||||
if mode == BlockWrapperMode.CACHING:
|
||||
return self._apply_single_block_caching(
|
||||
block=block,
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
)
|
||||
else:
|
||||
assert patch_start is not None and patch_end is not None
|
||||
assert kv_cache is not None
|
||||
return self._apply_single_block_patched(
|
||||
block=block,
|
||||
patch_hidden=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
)
|
||||
|
||||
def final_projection(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
) -> mx.array:
|
||||
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
|
||||
return self._transformer.proj_out(hidden_states)
|
||||
|
||||
def get_joint_blocks(self) -> list[JointBlockInterface]:
|
||||
return cast(
|
||||
list[JointBlockInterface], list(self._transformer.transformer_blocks)
|
||||
)
|
||||
|
||||
def get_single_blocks(self) -> list[SingleBlockInterface]:
|
||||
return cast(
|
||||
list[SingleBlockInterface],
|
||||
list(self._transformer.single_transformer_blocks),
|
||||
)
|
||||
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
total_joint_blocks: int,
|
||||
total_single_blocks: int,
|
||||
) -> None:
|
||||
if end_layer <= total_joint_blocks:
|
||||
# All assigned are joint blocks
|
||||
joint_start, joint_end = start_layer, end_layer
|
||||
single_start, single_end = 0, 0
|
||||
elif start_layer >= total_joint_blocks:
|
||||
# All assigned are single blocks
|
||||
joint_start, joint_end = 0, 0
|
||||
single_start = start_layer - total_joint_blocks
|
||||
single_end = end_layer - total_joint_blocks
|
||||
else:
|
||||
# Spans both joint and single
|
||||
joint_start, joint_end = start_layer, total_joint_blocks
|
||||
single_start = 0
|
||||
single_end = end_layer - total_joint_blocks
|
||||
|
||||
all_joint = list(self._transformer.transformer_blocks)
|
||||
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
|
||||
|
||||
all_single = list(self._transformer.single_transformer_blocks)
|
||||
self._transformer.single_transformer_blocks = all_single[
|
||||
single_start:single_end
|
||||
]
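A concrete illustration of the three cases above, assuming Flux's 19 joint + 38 single blocks (57 layers total) and a hypothetical node assigned layers [10, 30): the range straddles the joint/single boundary, so the node keeps joint blocks 10..18 and single blocks 0..10.

# Hypothetical assignment: layers [10, 30) out of 19 joint + 38 single blocks.
start_layer, end_layer = 10, 30
total_joint_blocks, total_single_blocks = 19, 38

# start_layer < total_joint_blocks <= end_layer, so the range spans both kinds:
joint_start, joint_end = start_layer, total_joint_blocks       # 10, 19 -> 9 joint blocks kept
single_start, single_end = 0, end_layer - total_joint_blocks   # 0, 11  -> 11 single blocks kept
assert (joint_end - joint_start) + (single_end - single_start) == end_layer - start_layer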
|
||||
|
||||
def merge_streams(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
) -> mx.array:
|
||||
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
|
||||
|
||||
def _apply_joint_block_caching(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
text_seq_len: int,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
num_img_tokens = hidden_states.shape[1]
|
||||
batch_size = hidden_states.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dimension
|
||||
|
||||
# 1. Compute norms
|
||||
norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
|
||||
block.norm1_context(
|
||||
hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Compute Q, K, V for full image
|
||||
img_query, img_key, img_value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_hidden,
|
||||
to_q=attn.to_q,
|
||||
to_k=attn.to_k,
|
||||
to_v=attn.to_v,
|
||||
norm_q=attn.norm_q,
|
||||
norm_k=attn.norm_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 3. Compute Q, K, V for text
|
||||
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_encoder,
|
||||
to_q=attn.add_q_proj,
|
||||
to_k=attn.add_k_proj,
|
||||
to_v=attn.add_v_proj,
|
||||
norm_q=attn.norm_added_q,
|
||||
norm_k=attn.norm_added_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 4. Concatenate Q, K, V: [text, image]
|
||||
query = mx.concatenate([txt_query, img_query], axis=2)
|
||||
key = mx.concatenate([txt_key, img_key], axis=2)
|
||||
value = mx.concatenate([txt_value, img_value], axis=2)
|
||||
|
||||
# 5. Apply RoPE
|
||||
query, key = AttentionUtils.apply_rope(
|
||||
xq=query, xk=key, freqs_cis=rotary_embeddings
|
||||
)
|
||||
|
||||
# 6. Store IMAGE K/V in cache for async pipeline
|
||||
if kv_cache is not None:
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=0,
|
||||
patch_end=num_img_tokens,
|
||||
key=key[:, :, text_seq_len:, :],
|
||||
value=value[:, :, text_seq_len:, :],
|
||||
)
|
||||
|
||||
# 7. Compute full attention
|
||||
attn_output = AttentionUtils.compute_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 8. Extract and project outputs
|
||||
context_attn_output = attn_output[:, :text_seq_len, :]
|
||||
attn_output = attn_output[:, text_seq_len:, :]
|
||||
|
||||
attn_output = attn.to_out[0](attn_output)
|
||||
context_attn_output = attn.to_add_out(context_attn_output)
|
||||
|
||||
# 9. Apply norm and feed forward
|
||||
hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
|
||||
hidden_states=hidden_states,
|
||||
attn_output=attn_output,
|
||||
gate_mlp=gate_mlp,
|
||||
gate_msa=gate_msa,
|
||||
scale_mlp=scale_mlp,
|
||||
shift_mlp=shift_mlp,
|
||||
norm_layer=block.norm2,
|
||||
ff_layer=block.ff,
|
||||
)
|
||||
encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
|
||||
hidden_states=encoder_hidden_states,
|
||||
attn_output=context_attn_output,
|
||||
gate_mlp=c_gate_mlp,
|
||||
gate_msa=c_gate_msa,
|
||||
scale_mlp=c_scale_mlp,
|
||||
shift_mlp=c_shift_mlp,
|
||||
norm_layer=block.norm2_context,
|
||||
ff_layer=block.ff_context,
|
||||
)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
def _apply_joint_block_patched(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
patch_hidden: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache,
|
||||
text_seq_len: int,
|
||||
patch_start: int,
|
||||
patch_end: int,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
batch_size = patch_hidden.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dimension
|
||||
|
||||
# 1. Compute norms
|
||||
norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
|
||||
hidden_states=patch_hidden,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
|
||||
block.norm1_context(
|
||||
hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Compute Q, K, V for image patch
|
||||
img_query, img_key, img_value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_hidden,
|
||||
to_q=attn.to_q,
|
||||
to_k=attn.to_k,
|
||||
to_v=attn.to_v,
|
||||
norm_q=attn.norm_q,
|
||||
norm_k=attn.norm_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 3. Compute Q, K, V for text
|
||||
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_encoder,
|
||||
to_q=attn.add_q_proj,
|
||||
to_k=attn.add_k_proj,
|
||||
to_v=attn.add_v_proj,
|
||||
norm_q=attn.norm_added_q,
|
||||
norm_k=attn.norm_added_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 4. Concatenate Q, K, V for patch: [text, patch]
|
||||
query = mx.concatenate([txt_query, img_query], axis=2)
|
||||
patch_key = mx.concatenate([txt_key, img_key], axis=2)
|
||||
patch_value = mx.concatenate([txt_value, img_value], axis=2)
|
||||
|
||||
# 5. Extract RoPE for [text + current_patch]
|
||||
text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
|
||||
patch_img_rope = rotary_embeddings[
|
||||
:, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
|
||||
]
|
||||
patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
|
||||
|
||||
# 6. Apply RoPE
|
||||
query, patch_key = AttentionUtils.apply_rope(
|
||||
xq=query, xk=patch_key, freqs_cis=patch_rope
|
||||
)
|
||||
|
||||
# 7. Update cache with this patch's IMAGE K/V
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
key=patch_key[:, :, text_seq_len:, :],
|
||||
value=patch_value[:, :, text_seq_len:, :],
|
||||
)
|
||||
|
||||
# 8. Get full K, V from cache
|
||||
full_key, full_value = kv_cache.get_full_kv(
|
||||
text_key=patch_key[:, :, :text_seq_len, :],
|
||||
text_value=patch_value[:, :, :text_seq_len, :],
|
||||
)
|
||||
|
||||
# 9. Compute attention
|
||||
attn_output = AttentionUtils.compute_attention(
|
||||
query=query,
|
||||
key=full_key,
|
||||
value=full_value,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 10. Extract and project outputs
|
||||
context_attn_output = attn_output[:, :text_seq_len, :]
|
||||
hidden_attn_output = attn_output[:, text_seq_len:, :]
|
||||
|
||||
hidden_attn_output = attn.to_out[0](hidden_attn_output)
|
||||
context_attn_output = attn.to_add_out(context_attn_output)
|
||||
|
||||
# 11. Apply norm and feed forward
|
||||
patch_hidden = JointTransformerBlock.apply_norm_and_feed_forward(
|
||||
hidden_states=patch_hidden,
|
||||
attn_output=hidden_attn_output,
|
||||
gate_mlp=gate_mlp,
|
||||
gate_msa=gate_msa,
|
||||
scale_mlp=scale_mlp,
|
||||
shift_mlp=shift_mlp,
|
||||
norm_layer=block.norm2,
|
||||
ff_layer=block.ff,
|
||||
)
|
||||
encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
|
||||
hidden_states=encoder_hidden_states,
|
||||
attn_output=context_attn_output,
|
||||
gate_mlp=c_gate_mlp,
|
||||
gate_msa=c_gate_msa,
|
||||
scale_mlp=c_scale_mlp,
|
||||
shift_mlp=c_shift_mlp,
|
||||
norm_layer=block.norm2_context,
|
||||
ff_layer=block.ff_context,
|
||||
)
|
||||
|
||||
return encoder_hidden_states, patch_hidden
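The RoPE slicing in step 5 is the main index bookkeeping that differs from the caching path: the query covers [text + one image patch], so the matching frequencies are the full text rows plus the patch's rows from the image portion of the table. A small worked example with illustrative sizes (the real sequence lengths depend on resolution and prompt length):

# Illustrative sizes: 512 text tokens, image patch covering image tokens [1024, 2048).
text_seq_len, patch_start, patch_end = 512, 1024, 2048

# The full RoPE table is laid out as [text | image], so this patch's image rows sit at
# absolute positions [text_seq_len + patch_start, text_seq_len + patch_end) = [1536, 2560).
text_rows = (0, text_seq_len)                                        # 512 rows
patch_rows = (text_seq_len + patch_start, text_seq_len + patch_end)  # 1024 rows
# Concatenating the two slices gives 1536 rows, matching the [text, patch] query from step 4.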
|
||||
|
||||
def _apply_single_block_caching(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
text_seq_len: int,
|
||||
) -> mx.array:
|
||||
total_seq_len = hidden_states.shape[1]
|
||||
num_img_tokens = total_seq_len - text_seq_len
|
||||
batch_size = hidden_states.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dimension
|
||||
|
||||
# Residual connection
|
||||
residual = hidden_states
|
||||
|
||||
# 1. Compute norm
|
||||
norm_hidden, gate = block.norm(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
|
||||
# 2. Compute Q, K, V
|
||||
query, key, value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_hidden,
|
||||
to_q=attn.to_q,
|
||||
to_k=attn.to_k,
|
||||
to_v=attn.to_v,
|
||||
norm_q=attn.norm_q,
|
||||
norm_k=attn.norm_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 3. Apply RoPE
|
||||
query, key = AttentionUtils.apply_rope(
|
||||
xq=query, xk=key, freqs_cis=rotary_embeddings
|
||||
)
|
||||
|
||||
# 4. Store IMAGE K/V in cache
|
||||
if kv_cache is not None:
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=0,
|
||||
patch_end=num_img_tokens,
|
||||
key=key[:, :, text_seq_len:, :],
|
||||
value=value[:, :, text_seq_len:, :],
|
||||
)
|
||||
|
||||
# 5. Compute attention
|
||||
attn_output = AttentionUtils.compute_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 6. Apply feed forward and projection
|
||||
hidden_states = block._apply_feed_forward_and_projection(
|
||||
norm_hidden_states=norm_hidden,
|
||||
attn_output=attn_output,
|
||||
gate=gate,
|
||||
)
|
||||
|
||||
return residual + hidden_states
|
||||
|
||||
def _apply_single_block_patched(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
patch_hidden: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache,
|
||||
text_seq_len: int,
|
||||
patch_start: int,
|
||||
patch_end: int,
|
||||
) -> mx.array:
|
||||
batch_size = patch_hidden.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dimension
|
||||
|
||||
# Residual connection
|
||||
residual = patch_hidden
|
||||
|
||||
# 1. Compute norm
|
||||
norm_hidden, gate = block.norm(
|
||||
hidden_states=patch_hidden,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
|
||||
# 2. Compute Q, K, V
|
||||
query, key, value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_hidden,
|
||||
to_q=attn.to_q,
|
||||
to_k=attn.to_k,
|
||||
to_v=attn.to_v,
|
||||
norm_q=attn.norm_q,
|
||||
norm_k=attn.norm_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 3. Extract RoPE for [text + current_patch]
|
||||
text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
|
||||
patch_img_rope = rotary_embeddings[
|
||||
:, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
|
||||
]
|
||||
patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
|
||||
|
||||
# 4. Apply RoPE
|
||||
query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=patch_rope)
|
||||
|
||||
# 5. Update cache with this patch's IMAGE K/V
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
key=key[:, :, text_seq_len:, :],
|
||||
value=value[:, :, text_seq_len:, :],
|
||||
)
|
||||
|
||||
# 6. Get full K, V from cache
|
||||
full_key, full_value = kv_cache.get_full_kv(
|
||||
text_key=key[:, :, :text_seq_len, :],
|
||||
text_value=value[:, :, :text_seq_len, :],
|
||||
)
|
||||
|
||||
# 7. Compute attention
|
||||
attn_output = AttentionUtils.compute_attention(
|
||||
query=query,
|
||||
key=full_key,
|
||||
value=full_value,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 8. Apply feed forward and projection
|
||||
hidden_states = block._apply_feed_forward_and_projection(
|
||||
norm_hidden_states=norm_hidden,
|
||||
attn_output=attn_output,
|
||||
gate=gate,
|
||||
)
|
||||
|
||||
return residual + hidden_states
|
||||
48 src/exo/worker/engines/image/models/flux/config.py Normal file
@@ -0,0 +1,48 @@
from exo.worker.engines.image.config import (
    BlockType,
    ImageModelConfig,
    TransformerBlockConfig,
)

FLUX_SCHNELL_CONFIG = ImageModelConfig(
    model_family="flux",
    model_variant="schnell",
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=19, has_separate_text_output=True
        ),
        TransformerBlockConfig(
            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
        ),
    ),
    patch_size=2,
    vae_scale_factor=8,
    default_steps={"low": 1, "medium": 2, "high": 4},
    num_sync_steps_factor=0.5,  # 1 sync step for medium (2 steps)
    uses_attention_mask=False,
)


FLUX_DEV_CONFIG = ImageModelConfig(
    model_family="flux",
    model_variant="dev",
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=19, has_separate_text_output=True
        ),
        TransformerBlockConfig(
            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
        ),
    ),
    patch_size=2,
    vae_scale_factor=8,
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,  # ~3 sync steps for medium (25 steps)
    uses_attention_mask=False,
)
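The num_sync_steps_factor comments can be checked directly, assuming the pipeline derives the synchronous step count as the factor times the step count for the chosen quality level, rounded and clamped to at least one step (the exact rounding lives in the pipeline code, which is not part of this file):

# Sketch of the arithmetic behind the comments above; not the pipeline implementation.
def approx_sync_steps(factor: float, steps: int) -> int:
    return max(1, round(factor * steps))

approx_sync_steps(0.5, 2)     # schnell, medium -> 1 sync step
approx_sync_steps(0.125, 25)  # dev, medium     -> 3 sync steps (3.125 rounded)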
13 src/exo/worker/engines/image/models/qwen/__init__.py Normal file
@@ -0,0 +1,13 @@
from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter
from exo.worker.engines.image.models.qwen.config import (
    QWEN_IMAGE_CONFIG,
    QWEN_IMAGE_EDIT_CONFIG,
)
from exo.worker.engines.image.models.qwen.edit_adapter import QwenEditModelAdapter

__all__ = [
    "QwenModelAdapter",
    "QwenEditModelAdapter",
    "QWEN_IMAGE_CONFIG",
    "QWEN_IMAGE_EDIT_CONFIG",
]
519 src/exo/worker/engines/image/models/qwen/adapter.py Normal file
@@ -0,0 +1,519 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.model_config import ModelConfig
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
|
||||
from mflux.models.qwen.model.qwen_text_encoder.qwen_prompt_encoder import (
|
||||
QwenPromptEncoder,
|
||||
)
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
|
||||
QwenTransformerBlock,
|
||||
)
|
||||
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import BaseModelAdapter
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
JointBlockInterface,
|
||||
SingleBlockInterface,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
class QwenPromptData:
|
||||
"""Container for Qwen prompt encoding results.
|
||||
|
||||
Implements PromptData protocol with additional Qwen-specific attributes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
prompt_mask: mx.array,
|
||||
negative_prompt_embeds: mx.array,
|
||||
negative_prompt_mask: mx.array,
|
||||
):
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self.prompt_mask = prompt_mask
|
||||
self._negative_prompt_embeds = negative_prompt_embeds
|
||||
self.negative_prompt_mask = negative_prompt_mask
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
"""Text embeddings from encoder."""
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
|
||||
return self._prompt_embeds # Use prompt_embeds as placeholder
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array:
|
||||
"""Negative prompt embeddings for CFG."""
|
||||
return self._negative_prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Placeholder - Qwen doesn't use pooled embeds."""
|
||||
return self._negative_prompt_embeds
|
||||
|
||||
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
|
||||
"""Return encoder_hidden_states_mask for the appropriate prompt."""
|
||||
if positive:
|
||||
return {"encoder_hidden_states_mask": self.prompt_mask}
|
||||
else:
|
||||
return {"encoder_hidden_states_mask": self.negative_prompt_mask}
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
"""Standard Qwen does not use conditioning latents."""
|
||||
return None
|
||||
|
||||
|
||||
class QwenModelAdapter(BaseModelAdapter):
|
||||
"""Adapter for Qwen-Image model.
|
||||
|
||||
Key differences from Flux:
|
||||
- Single text encoder (vs dual T5+CLIP)
|
||||
- 60 joint-style blocks, no single blocks
|
||||
- 3D RoPE returning ((img_cos, img_sin), (txt_cos, txt_sin))
|
||||
- Norm-preserving CFG with negative prompts
|
||||
- Uses attention mask for variable-length text
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
self._config = config
|
||||
self._model = QwenImage(
|
||||
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
|
||||
local_path=str(local_path),
|
||||
quantize=quantize,
|
||||
)
|
||||
self._transformer = self._model.transformer
|
||||
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def model(self) -> QwenImage:
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def transformer(self) -> QwenTransformer:
|
||||
return self._transformer
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
return self._transformer.inner_dim
|
||||
|
||||
def _get_latent_creator(self) -> type:
|
||||
return QwenLatentCreator
|
||||
|
||||
def encode_prompt(self, prompt: str) -> QwenPromptData:
|
||||
"""Encode prompt into QwenPromptData.
|
||||
|
||||
Qwen uses classifier-free guidance with explicit negative prompts.
|
||||
Returns a QwenPromptData container with all 4 tensors.
|
||||
"""
|
||||
# TODO(ciaran): empty string as default negative prompt
|
||||
negative_prompt = ""
|
||||
|
||||
prompt_embeds, prompt_mask, neg_embeds, neg_mask = (
|
||||
QwenPromptEncoder.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_cache=self._model.prompt_cache,
|
||||
qwen_tokenizer=self._model.qwen_tokenizer,
|
||||
qwen_text_encoder=self._model.text_encoder,
|
||||
)
|
||||
)
|
||||
|
||||
return QwenPromptData(
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_mask=prompt_mask,
|
||||
negative_prompt_embeds=neg_embeds,
|
||||
negative_prompt_mask=neg_mask,
|
||||
)
|
||||
|
||||
@property
|
||||
def needs_cfg(self) -> bool:
|
||||
gs = self._config.guidance_scale
|
||||
return gs is not None and gs > 1.0
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
return self._model.compute_guided_noise(
|
||||
noise=noise_positive,
|
||||
noise_negative=noise_negative,
|
||||
guidance=guidance_scale,
|
||||
)
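apply_guidance defers entirely to mflux's compute_guided_noise, which the class docstring describes as norm-preserving CFG. For reference only, textbook classifier-free guidance combines the two noise predictions as below; the norm-preserving variant additionally rescales the result (typically back toward the norm of the positive prediction). This is a sketch of the plain formula, not the mflux implementation:

import mlx.core as mx

def plain_cfg(noise_positive: mx.array, noise_negative: mx.array, guidance: float) -> mx.array:
    # eps = eps_neg + g * (eps_pos - eps_neg); g > 1.0 pushes away from the
    # negative-prompt prediction, which is why needs_cfg checks guidance_scale > 1.0.
    return noise_negative + guidance * (noise_positive - noise_negative)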
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Compute image and text embeddings."""
|
||||
# Image embedding
|
||||
embedded_hidden = self._transformer.img_in(hidden_states)
|
||||
# Text embedding: first normalize, then project
|
||||
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
|
||||
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
|
||||
return embedded_hidden, embedded_encoder
|
||||
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: RuntimeConfig,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
"""Compute time/text embeddings.
|
||||
|
||||
For Qwen, the time_text_embed only uses hidden_states for:
|
||||
- batch_size (shape[0])
|
||||
- dtype
|
||||
|
||||
This allows us to pass any tensor (latents, prompt_embeds) as a fallback
|
||||
when embedded hidden_states are not yet available.
|
||||
"""
|
||||
# Use hidden_states if provided, otherwise fall back to pooled_prompt_embeds
|
||||
# (which for Qwen is the same as prompt_embeds)
|
||||
ref_tensor = (
|
||||
hidden_states if hidden_states is not None else pooled_prompt_embeds
|
||||
)
|
||||
if ref_tensor is None:
|
||||
raise ValueError(
|
||||
"Either hidden_states or pooled_prompt_embeds is required "
|
||||
"for Qwen text embeddings"
|
||||
)
|
||||
|
||||
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
|
||||
batch_size = ref_tensor.shape[0]
|
||||
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
|
||||
return self._transformer.time_text_embed(timestep, ref_tensor)
|
||||
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Compute 3D rotary embeddings for Qwen.
|
||||
|
||||
Qwen uses video-aware 3D RoPE with separate embeddings for image and text.
|
||||
|
||||
Returns:
|
||||
tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]:
|
||||
((img_cos, img_sin), (txt_cos, txt_sin))
|
||||
"""
|
||||
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
|
||||
cond_image_grid = kwargs.get("cond_image_grid")
|
||||
|
||||
if encoder_hidden_states_mask is None:
|
||||
raise ValueError(
|
||||
"encoder_hidden_states_mask is required for Qwen RoPE computation"
|
||||
)
|
||||
|
||||
return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
pos_embed=self._transformer.pos_embed,
|
||||
config=runtime_config,
|
||||
cond_image_grid=cond_image_grid,
|
||||
)
|
||||
|
||||
def apply_joint_block(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: Any, # tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]] for Qwen
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply Qwen joint block.
|
||||
|
||||
For caching mode, we run the full block and optionally populate the KV cache.
|
||||
For patched mode, we use the cached KV values (not yet implemented).
|
||||
"""
|
||||
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
|
||||
block_idx = kwargs.get("block_idx")
|
||||
|
||||
if mode == BlockWrapperMode.CACHING:
|
||||
return self._apply_joint_block_caching(
|
||||
block=block,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
else:
|
||||
# mode == BlockWrapperMode.PATCHED
|
||||
assert patch_start is not None and patch_end is not None
|
||||
assert kv_cache is not None
|
||||
return self._apply_joint_block_patched(
|
||||
block=block,
|
||||
patch_hidden=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
def apply_single_block(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
) -> mx.array:
|
||||
"""Qwen has no single blocks."""
|
||||
raise NotImplementedError("Qwen does not have single blocks")
|
||||
|
||||
def final_projection(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply final normalization and projection."""
|
||||
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
|
||||
return self._transformer.proj_out(hidden_states)
|
||||
|
||||
def get_joint_blocks(self) -> list[JointBlockInterface]:
|
||||
"""Return all 60 transformer blocks."""
|
||||
return cast(
|
||||
list[JointBlockInterface], list(self._transformer.transformer_blocks)
|
||||
)
|
||||
|
||||
def get_single_blocks(self) -> list[SingleBlockInterface]:
|
||||
"""Qwen has no single blocks."""
|
||||
return []
|
||||
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
total_joint_blocks: int,
|
||||
total_single_blocks: int,
|
||||
) -> None:
|
||||
all_blocks = list(self._transformer.transformer_blocks)
|
||||
assigned_blocks = all_blocks[start_layer:end_layer]
|
||||
self._transformer.transformer_blocks = assigned_blocks
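Since Qwen is a single flat stack of 60 joint blocks, slicing is a plain contiguous cut; a hypothetical two-node split for illustration:

# Hypothetical two-node split of Qwen's 60 joint blocks (no single blocks):
#   node 0: adapter.slice_transformer_blocks(0, 30, total_joint_blocks=60, total_single_blocks=0)
#   node 1: adapter.slice_transformer_blocks(30, 60, total_joint_blocks=60, total_single_blocks=0)
# Each node keeps 30 contiguous QwenTransformerBlock instances; total_single_blocks
# is accepted only for interface parity with the Flux adapter and is unused here.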
|
||||
|
||||
def merge_streams(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
) -> mx.array:
|
||||
"""Merge image and text streams.
|
||||
|
||||
For Qwen, this is called before final projection.
|
||||
The streams remain separate through all blocks.
|
||||
"""
|
||||
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
|
||||
|
||||
def _apply_joint_block_caching(
|
||||
self,
|
||||
block: Any, # QwenTransformerBlock
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
text_seq_len: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
block_idx: int | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply joint block in caching mode (full attention, optionally populate cache).
|
||||
|
||||
Delegates to the QwenTransformerBlock's forward pass.
|
||||
"""
|
||||
# Call the block directly - it handles all the modulation and attention internally
|
||||
return block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_emb=rotary_embeddings,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
def _apply_joint_block_patched(
|
||||
self,
|
||||
block: Any, # QwenTransformerBlock
|
||||
patch_hidden: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
|
||||
kv_cache: ImagePatchKVCache,
|
||||
text_seq_len: int,
|
||||
patch_start: int,
|
||||
patch_end: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
block_idx: int | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
batch_size = patch_hidden.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dim
|
||||
|
||||
# 1. Compute modulation parameters
|
||||
img_mod_params = block.img_mod_linear(block.img_mod_silu(text_embeddings))
|
||||
txt_mod_params = block.txt_mod_linear(block.txt_mod_silu(text_embeddings))
|
||||
|
||||
img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
|
||||
txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
|
||||
|
||||
# 2. Apply normalization and modulation
|
||||
img_normed = block.img_norm1(patch_hidden)
|
||||
img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)
|
||||
|
||||
txt_normed = block.txt_norm1(encoder_hidden_states)
|
||||
txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)
|
||||
|
||||
# 3. Compute Q, K, V for image patch
|
||||
img_query = attn.to_q(img_modulated)
|
||||
img_key = attn.to_k(img_modulated)
|
||||
img_value = attn.to_v(img_modulated)
|
||||
|
||||
# 4. Compute Q, K, V for text
|
||||
txt_query = attn.add_q_proj(txt_modulated)
|
||||
txt_key = attn.add_k_proj(txt_modulated)
|
||||
txt_value = attn.add_v_proj(txt_modulated)
|
||||
|
||||
# 5. Reshape to [B, S, H, D]
|
||||
patch_len = patch_hidden.shape[1]
|
||||
img_query = mx.reshape(img_query, (batch_size, patch_len, num_heads, head_dim))
|
||||
img_key = mx.reshape(img_key, (batch_size, patch_len, num_heads, head_dim))
|
||||
img_value = mx.reshape(img_value, (batch_size, patch_len, num_heads, head_dim))
|
||||
|
||||
txt_query = mx.reshape(
|
||||
txt_query, (batch_size, text_seq_len, num_heads, head_dim)
|
||||
)
|
||||
txt_key = mx.reshape(txt_key, (batch_size, text_seq_len, num_heads, head_dim))
|
||||
txt_value = mx.reshape(
|
||||
txt_value, (batch_size, text_seq_len, num_heads, head_dim)
|
||||
)
|
||||
|
||||
# 6. Apply RMSNorm to Q, K
|
||||
img_query = attn.norm_q(img_query)
|
||||
img_key = attn.norm_k(img_key)
|
||||
txt_query = attn.norm_added_q(txt_query)
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
# 7. Extract RoPE for patch: slice image RoPE, keep full text RoPE
|
||||
(img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
|
||||
patch_img_cos = img_cos[patch_start:patch_end]
|
||||
patch_img_sin = img_sin[patch_start:patch_end]
|
||||
|
||||
# 8. Apply RoPE to Q, K
|
||||
img_query = QwenAttention._apply_rope_qwen(
|
||||
img_query, patch_img_cos, patch_img_sin
|
||||
)
|
||||
img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
|
||||
txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
|
||||
txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
|
||||
|
||||
# 9. Transpose to [B, H, S, D] for cache operations
|
||||
img_key_bhsd = mx.transpose(img_key, (0, 2, 1, 3))
|
||||
img_value_bhsd = mx.transpose(img_value, (0, 2, 1, 3))
|
||||
|
||||
# 10. Update cache with this patch's IMAGE K/V
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
key=img_key_bhsd,
|
||||
value=img_value_bhsd,
|
||||
)
|
||||
|
||||
# 11. Get full K, V from cache (text + full image)
|
||||
txt_key_bhsd = mx.transpose(txt_key, (0, 2, 1, 3))
|
||||
txt_value_bhsd = mx.transpose(txt_value, (0, 2, 1, 3))
|
||||
full_key, full_value = kv_cache.get_full_kv(
|
||||
text_key=txt_key_bhsd,
|
||||
text_value=txt_value_bhsd,
|
||||
)
|
||||
|
||||
# 12. Build query: [text, patch]
|
||||
joint_query = mx.concatenate([txt_query, img_query], axis=1)
|
||||
|
||||
# 13. Build attention mask for [text + patch] query attending to [text + full_image] KV
|
||||
mask = QwenAttention._convert_mask_for_qwen(
|
||||
mask=encoder_hidden_states_mask,
|
||||
joint_seq_len=full_key.shape[2], # text + full_image
|
||||
txt_seq_len=text_seq_len,
|
||||
)
|
||||
|
||||
# 14. Compute attention
|
||||
hidden_states = attn._compute_attention_qwen(
|
||||
query=joint_query,
|
||||
key=mx.transpose(full_key, (0, 2, 1, 3)), # Back to [B, S, H, D]
|
||||
value=mx.transpose(full_value, (0, 2, 1, 3)),
|
||||
mask=mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
# 15. Extract text and image attention outputs
|
||||
txt_attn_output = hidden_states[:, :text_seq_len, :]
|
||||
img_attn_output = hidden_states[:, text_seq_len:, :]
|
||||
|
||||
# 16. Project outputs
|
||||
img_attn_output = attn.attn_to_out[0](img_attn_output)
|
||||
txt_attn_output = attn.to_add_out(txt_attn_output)
|
||||
|
||||
# 17. Apply residual + gate for attention
|
||||
patch_hidden = patch_hidden + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
|
||||
# 18. Apply feed-forward for image
|
||||
img_normed2 = block.img_norm2(patch_hidden)
|
||||
img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
|
||||
img_normed2, img_mod2
|
||||
)
|
||||
img_mlp_output = block.img_ff(img_modulated2)
|
||||
patch_hidden = patch_hidden + img_gate2 * img_mlp_output
|
||||
|
||||
# 19. Apply feed-forward for text
|
||||
txt_normed2 = block.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
|
||||
txt_normed2, txt_mod2
|
||||
)
|
||||
txt_mlp_output = block.txt_ff(txt_modulated2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
|
||||
|
||||
return encoder_hidden_states, patch_hidden
|
||||
49 src/exo/worker/engines/image/models/qwen/config.py Normal file
@@ -0,0 +1,49 @@
from exo.worker.engines.image.config import (
    BlockType,
    ImageModelConfig,
    TransformerBlockConfig,
)

# Qwen-Image has 60 joint-style blocks (no single blocks)
# Architecture: 24 heads * 128 dim = 3072 hidden dim
# VAE uses scale factor of 16 (vs Flux's 8)
QWEN_IMAGE_CONFIG = ImageModelConfig(
    model_family="qwen",
    model_variant="image",
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=60, has_separate_text_output=True
        ),
        # Qwen has no single blocks - all blocks process image and text separately
    ),
    patch_size=2,
    vae_scale_factor=16,
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,  # ~3 sync steps for medium (25 steps)
    uses_attention_mask=True,  # Qwen uses encoder_hidden_states_mask
    guidance_scale=None,  # Set to None or < 1.0 to disable CFG
)

# Qwen-Image-Edit uses the same architecture but a different processing pipeline
# Uses vision-language encoding and conditioning latents
QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
    model_family="qwen-edit",
    model_variant="image-edit",
    hidden_dim=3072,
    num_heads=24,
    head_dim=128,
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=60, has_separate_text_output=True
        ),
    ),
    patch_size=2,
    vae_scale_factor=16,
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,
    uses_attention_mask=True,
    guidance_scale=None,
)
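How guidance_scale=None interacts with the adapter's needs_cfg property (defined in qwen/adapter.py above): CFG, and the second forward pass it implies, is only enabled when a scale greater than 1.0 is configured. A small sketch mirroring that check:

def needs_cfg(guidance_scale: float | None) -> bool:
    # Mirrors QwenModelAdapter.needs_cfg: None or <= 1.0 disables the negative pass.
    return guidance_scale is not None and guidance_scale > 1.0

needs_cfg(None)  # False - the configs above ship with CFG disabled
needs_cfg(4.0)   # True  - hypothetical scale; would run positive + negative passes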
671 src/exo/worker/engines/image/models/qwen/edit_adapter.py Normal file
@@ -0,0 +1,671 @@
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
|
||||
QwenTransformerBlock,
|
||||
)
|
||||
from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit
|
||||
from mflux.models.qwen.variants.edit.utils.qwen_edit_util import QwenEditUtil
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import BaseModelAdapter
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
JointBlockInterface,
|
||||
SingleBlockInterface,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
class QwenEditPromptData:
|
||||
"""Container for Qwen edit prompt encoding results.
|
||||
|
||||
Includes vision-language encoded embeddings and edit-specific conditioning.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
prompt_mask: mx.array,
|
||||
negative_prompt_embeds: mx.array,
|
||||
negative_prompt_mask: mx.array,
|
||||
conditioning_latents: mx.array,
|
||||
qwen_image_ids: mx.array,
|
||||
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]],
|
||||
):
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self.prompt_mask = prompt_mask
|
||||
self._negative_prompt_embeds = negative_prompt_embeds
|
||||
self.negative_prompt_mask = negative_prompt_mask
|
||||
self._conditioning_latents = conditioning_latents
|
||||
self._qwen_image_ids = qwen_image_ids
|
||||
self._cond_image_grid = cond_image_grid
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
"""Text embeddings from vision-language encoder."""
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array:
|
||||
"""Negative prompt embeddings for CFG."""
|
||||
return self._negative_prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Placeholder - Qwen doesn't use pooled embeds."""
|
||||
return self._negative_prompt_embeds
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array:
|
||||
"""Static image conditioning latents to concatenate with generated latents."""
|
||||
return self._conditioning_latents
|
||||
|
||||
@property
|
||||
def qwen_image_ids(self) -> mx.array:
|
||||
"""Spatial position IDs for conditioning images."""
|
||||
return self._qwen_image_ids
|
||||
|
||||
@property
|
||||
def cond_image_grid(self) -> tuple[int, int, int] | list[tuple[int, int, int]]:
|
||||
"""Conditioning image grid dimensions."""
|
||||
return self._cond_image_grid
|
||||
|
||||
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
|
||||
"""Return encoder_hidden_states_mask and edit-specific params."""
|
||||
if positive:
|
||||
return {
|
||||
"encoder_hidden_states_mask": self.prompt_mask,
|
||||
"qwen_image_ids": self._qwen_image_ids,
|
||||
"cond_image_grid": self._cond_image_grid,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"encoder_hidden_states_mask": self.negative_prompt_mask,
|
||||
"qwen_image_ids": self._qwen_image_ids,
|
||||
"cond_image_grid": self._cond_image_grid,
|
||||
}
|
||||
|
||||
@property
|
||||
def is_edit_mode(self) -> bool:
|
||||
"""Indicates this is edit mode with conditioning latents."""
|
||||
return True
|
||||
|
||||
|
||||
class QwenEditModelAdapter(BaseModelAdapter):
    """Adapter for Qwen-Image-Edit model.

    Key differences from standard QwenModelAdapter:
    - Uses QwenImageEdit model with vision-language components
    - Encodes prompts WITH input images via VL tokenizer/encoder
    - Creates conditioning latents from input images
    - Supports image editing with concatenated latents during diffusion
    """

    def __init__(
        self,
        config: ImageModelConfig,
        model_id: str,
        local_path: Path,
        quantize: int | None = None,
    ):
        self._config = config
        self._model = QwenImageEdit(
            quantize=quantize,
            local_path=str(local_path),
        )
        self._transformer = self._model.transformer

        # Store dimensions and image paths (set via set_image_dimensions)
        self._vl_width: int | None = None
        self._vl_height: int | None = None
        self._vae_width: int | None = None
        self._vae_height: int | None = None
        self._image_paths: list[str] | None = None

    @property
    def config(self) -> ImageModelConfig:
        return self._config

    @property
    def model(self) -> QwenImageEdit:
        return self._model

    @property
    def transformer(self) -> QwenTransformer:
        return self._transformer

    @property
    def hidden_dim(self) -> int:
        return self._transformer.inner_dim

    def _get_latent_creator(self) -> type:
        return QwenLatentCreator

    def _compute_dimensions_from_image(
        self, image_path: Path
    ) -> tuple[int, int, int, int, int, int]:
        """Compute VL and VAE dimensions from input image.

        Returns:
            (vl_width, vl_height, vae_width, vae_height, output_width, output_height)
        """
        from mflux.utils.image_util import ImageUtil

        pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
        image_size = pil_image.size

        # Vision-language dimensions (384x384 target area)
        condition_image_size = 384 * 384
        condition_ratio = image_size[0] / image_size[1]
        vl_width = math.sqrt(condition_image_size * condition_ratio)
        vl_height = vl_width / condition_ratio
        vl_width = round(vl_width / 32) * 32
        vl_height = round(vl_height / 32) * 32

        # VAE dimensions (1024x1024 target area)
        vae_image_size = 1024 * 1024
        vae_ratio = image_size[0] / image_size[1]
        vae_width = math.sqrt(vae_image_size * vae_ratio)
        vae_height = vae_width / vae_ratio
        vae_width = round(vae_width / 32) * 32
        vae_height = round(vae_height / 32) * 32

        # Output dimensions from input image aspect ratio
        target_area = 1024 * 1024
        ratio = image_size[0] / image_size[1]
        output_width = math.sqrt(target_area * ratio)
        output_height = output_width / ratio
        output_width = round(output_width / 32) * 32
        output_height = round(output_height / 32) * 32

        # Ensure multiple of 16 for VAE
        vae_scale_factor = 8
        multiple_of = vae_scale_factor * 2
        output_width = output_width // multiple_of * multiple_of
        output_height = output_height // multiple_of * multiple_of

        return (
            int(vl_width),
            int(vl_height),
            int(vae_width),
            int(vae_height),
            int(output_width),
            int(output_height),
        )

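    # Illustrative worked example (not part of the original diff): for a hypothetical
    # 1536x1024 input image the aspect ratio is 1.5, so the VL resize target is
    # sqrt(384*384*1.5) by that/1.5 ~= 470x314, which rounds to 480x320 on the
    # 32-pixel grid; the same math on the 1024*1024 target area gives ~1254x836,
    # rounding to 1248x832, and the final floor to a multiple of 16 leaves 1248x832
    # unchanged.
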
    def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
        """Create initial noise latents (pure noise for edit mode)."""
        return QwenLatentCreator.create_noise(
            seed=seed,
            height=runtime_config.height,
            width=runtime_config.width,
        )

    def encode_prompt(self, prompt: str) -> QwenEditPromptData:
        """Encode prompt with input images using vision-language encoder.

        Uses stored image_paths from set_image_dimensions() for VL encoding.

        Args:
            prompt: Text prompt for editing

        Returns:
            QwenEditPromptData with VL embeddings and conditioning latents
        """
        # Ensure image_paths and dimensions were set via set_image_dimensions()
        if self._image_paths is None:
            raise RuntimeError(
                "set_image_dimensions() must be called before encode_prompt() "
                "for QwenEditModelAdapter"
            )

        negative_prompt = ""
        image_paths = self._image_paths

        # Use stored dimensions (computed from input image)
        vl_width = self._vl_width
        vl_height = self._vl_height
        vae_width = self._vae_width
        vae_height = self._vae_height

        # Encode prompts with images via vision-language components
        tokenizer = self._model.qwen_vl_tokenizer
        pos_input_ids, pos_attention_mask, pos_pixel_values, pos_image_grid_thw = (
            tokenizer.tokenize_with_image(
                prompt, image_paths, vl_width=vl_width, vl_height=vl_height
            )
        )

        pos_hidden_states = self._model.qwen_vl_encoder(
            input_ids=pos_input_ids,
            attention_mask=pos_attention_mask,
            pixel_values=pos_pixel_values,
            image_grid_thw=pos_image_grid_thw,
        )
        mx.eval(pos_hidden_states[0])
        mx.eval(pos_hidden_states[1])

        # Encode negative prompt with images
        neg_input_ids, neg_attention_mask, neg_pixel_values, neg_image_grid_thw = (
            tokenizer.tokenize_with_image(
                negative_prompt, image_paths, vl_width=vl_width, vl_height=vl_height
            )
        )

        neg_hidden_states = self._model.qwen_vl_encoder(
            input_ids=neg_input_ids,
            attention_mask=neg_attention_mask,
            pixel_values=neg_pixel_values,
            image_grid_thw=neg_image_grid_thw,
        )
        mx.eval(neg_hidden_states[0])
        mx.eval(neg_hidden_states[1])

        # Create conditioning latents from input images
        # Ensure dimensions are set (should have been set via set_image_dimensions)
        assert vl_width is not None and vl_height is not None
        assert vae_width is not None and vae_height is not None

        (
            conditioning_latents,
            qwen_image_ids,
            cond_h_patches,
            cond_w_patches,
            num_images,
        ) = QwenEditUtil.create_image_conditioning_latents(
            vae=self._model.vae,
            height=vae_height,
            width=vae_width,
            image_paths=image_paths,
            vl_width=vl_width,
            vl_height=vl_height,
        )

        # Build cond_image_grid
        if num_images > 1:
            cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] = [
                (1, cond_h_patches, cond_w_patches) for _ in range(num_images)
            ]
        else:
            cond_image_grid = (1, cond_h_patches, cond_w_patches)

        return QwenEditPromptData(
            prompt_embeds=pos_hidden_states[0].astype(mx.float16),
            prompt_mask=pos_hidden_states[1].astype(mx.float16),
            negative_prompt_embeds=neg_hidden_states[0].astype(mx.float16),
            negative_prompt_mask=neg_hidden_states[1].astype(mx.float16),
            conditioning_latents=conditioning_latents,
            qwen_image_ids=qwen_image_ids,
            cond_image_grid=cond_image_grid,
        )

    def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
        """Compute and store dimensions from input image.

        Also stores image_paths for use in encode_prompt().

        Returns:
            (output_width, output_height) for runtime config
        """
        vl_w, vl_h, vae_w, vae_h, out_w, out_h = self._compute_dimensions_from_image(
            image_path
        )
        self._vl_width = vl_w
        self._vl_height = vl_h
        self._vae_width = vae_w
        self._vae_height = vae_h
        self._image_paths = [str(image_path)]
        return out_w, out_h

    @property
    def needs_cfg(self) -> bool:
        gs = self._config.guidance_scale
        return gs is not None and gs > 1.0

    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage

        return QwenImage.compute_guided_noise(
            noise=noise_positive,
            noise_negative=noise_negative,
            guidance=guidance_scale,
        )

    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Compute image and text embeddings."""
        embedded_hidden = self._transformer.img_in(hidden_states)
        encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
        embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
        return embedded_hidden, embedded_encoder

    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: RuntimeConfig,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,
    ) -> mx.array:
        """Compute time/text embeddings."""
        ref_tensor = (
            hidden_states if hidden_states is not None else pooled_prompt_embeds
        )
        if ref_tensor is None:
            raise ValueError(
                "Either hidden_states or pooled_prompt_embeds is required "
                "for Qwen text embeddings"
            )

        timestep = QwenTransformer._compute_timestep(t, runtime_config)  # noqa: SLF001
        batch_size = ref_tensor.shape[0]
        timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
        return self._transformer.time_text_embed(timestep, ref_tensor)

    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: RuntimeConfig,
        **kwargs: Any,
    ) -> Any:
        """Compute 3D rotary embeddings for Qwen edit."""
        encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
        cond_image_grid = kwargs.get("cond_image_grid")

        if encoder_hidden_states_mask is None:
            raise ValueError(
                "encoder_hidden_states_mask is required for Qwen RoPE computation"
            )

        return QwenTransformer._compute_rotary_embeddings(  # noqa: SLF001
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            pos_embed=self._transformer.pos_embed,
            config=runtime_config,
            cond_image_grid=cond_image_grid,
        )

    def apply_joint_block(
        self,
        block: JointBlockInterface,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
        **kwargs: Any,
    ) -> tuple[mx.array, mx.array]:
        """Apply Qwen joint block."""
        encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
        block_idx = kwargs.get("block_idx")

        if mode == BlockWrapperMode.CACHING:
            return self._apply_joint_block_caching(
                block=block,
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                block_idx=block_idx,
            )
        else:
            assert patch_start is not None and patch_end is not None
            assert kv_cache is not None
            return self._apply_joint_block_patched(
                block=block,
                patch_hidden=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
                kv_cache=kv_cache,
                text_seq_len=text_seq_len,
                patch_start=patch_start,
                patch_end=patch_end,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                block_idx=block_idx,
            )

    def apply_single_block(
        self,
        block: SingleBlockInterface,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
    ) -> mx.array:
        """Qwen has no single blocks."""
        raise NotImplementedError("Qwen does not have single blocks")

    def final_projection(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> mx.array:
        """Apply final normalization and projection."""
        hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
        return self._transformer.proj_out(hidden_states)

    def get_joint_blocks(self) -> list[JointBlockInterface]:
        """Return all 60 transformer blocks."""
        return cast(
            list[JointBlockInterface], list(self._transformer.transformer_blocks)
        )

    def get_single_blocks(self) -> list[SingleBlockInterface]:
        """Qwen has no single blocks."""
        return []

    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
        total_joint_blocks: int,
        total_single_blocks: int,
    ) -> None:
        all_blocks = list(self._transformer.transformer_blocks)
        assigned_blocks = all_blocks[start_layer:end_layer]
        self._transformer.transformer_blocks = assigned_blocks

    def merge_streams(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
    ) -> mx.array:
        """Merge image and text streams."""
        return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)

    def _apply_joint_block_caching(
        self,
        block: Any,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
        kv_cache: ImagePatchKVCache | None,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
        block_idx: int | None = None,
    ) -> tuple[mx.array, mx.array]:
        """Apply joint block in caching mode."""
        return block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            text_embeddings=text_embeddings,
            image_rotary_emb=rotary_embeddings,
            block_idx=block_idx,
        )

    def _apply_joint_block_patched(
        self,
        block: Any,
        patch_hidden: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
        kv_cache: ImagePatchKVCache,
        text_seq_len: int,
        patch_start: int,
        patch_end: int,
        encoder_hidden_states_mask: mx.array | None = None,
        block_idx: int | None = None,
    ) -> tuple[mx.array, mx.array]:
        batch_size = patch_hidden.shape[0]
        attn = block.attn
        num_heads = attn.num_heads
        head_dim = attn.head_dim

        # Modulation parameters
        img_mod_params = block.img_mod_linear(block.img_mod_silu(text_embeddings))
        txt_mod_params = block.txt_mod_linear(block.txt_mod_silu(text_embeddings))

        img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
        txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)

        # Normalization and modulation
        img_normed = block.img_norm1(patch_hidden)
        img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)

        txt_normed = block.txt_norm1(encoder_hidden_states)
        txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)

        # Q, K, V for image patch
        img_query = attn.to_q(img_modulated)
        img_key = attn.to_k(img_modulated)
        img_value = attn.to_v(img_modulated)

        # Q, K, V for text
        txt_query = attn.add_q_proj(txt_modulated)
        txt_key = attn.add_k_proj(txt_modulated)
        txt_value = attn.add_v_proj(txt_modulated)

        # Reshape to [B, S, H, D]
        patch_len = patch_hidden.shape[1]
        img_query = mx.reshape(img_query, (batch_size, patch_len, num_heads, head_dim))
        img_key = mx.reshape(img_key, (batch_size, patch_len, num_heads, head_dim))
        img_value = mx.reshape(img_value, (batch_size, patch_len, num_heads, head_dim))

        txt_query = mx.reshape(
            txt_query, (batch_size, text_seq_len, num_heads, head_dim)
        )
        txt_key = mx.reshape(txt_key, (batch_size, text_seq_len, num_heads, head_dim))
        txt_value = mx.reshape(
            txt_value, (batch_size, text_seq_len, num_heads, head_dim)
        )

        # RMSNorm to Q, K
        img_query = attn.norm_q(img_query)
        img_key = attn.norm_k(img_key)
        txt_query = attn.norm_added_q(txt_query)
        txt_key = attn.norm_added_k(txt_key)

        # Extract RoPE for patch
        (img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
        patch_img_cos = img_cos[patch_start:patch_end]
        patch_img_sin = img_sin[patch_start:patch_end]

        # Apply RoPE
        img_query = QwenAttention._apply_rope_qwen(
            img_query, patch_img_cos, patch_img_sin
        )
        img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
        txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
        txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)

        # Transpose to [B, H, S, D]
        img_key_bhsd = mx.transpose(img_key, (0, 2, 1, 3))
        img_value_bhsd = mx.transpose(img_value, (0, 2, 1, 3))

        # Update cache
        kv_cache.update_image_patch(
            patch_start=patch_start,
            patch_end=patch_end,
            key=img_key_bhsd,
            value=img_value_bhsd,
        )

        # Get full K, V from cache
        txt_key_bhsd = mx.transpose(txt_key, (0, 2, 1, 3))
        txt_value_bhsd = mx.transpose(txt_value, (0, 2, 1, 3))
        full_key, full_value = kv_cache.get_full_kv(
            text_key=txt_key_bhsd,
            text_value=txt_value_bhsd,
        )

        # Build query
        joint_query = mx.concatenate([txt_query, img_query], axis=1)

        # Build attention mask
        mask = QwenAttention._convert_mask_for_qwen(
            mask=encoder_hidden_states_mask,
            joint_seq_len=full_key.shape[2],
            txt_seq_len=text_seq_len,
        )

        # Compute attention
        hidden_states = attn._compute_attention_qwen(
            query=joint_query,
            key=mx.transpose(full_key, (0, 2, 1, 3)),
            value=mx.transpose(full_value, (0, 2, 1, 3)),
            mask=mask,
            block_idx=block_idx,
        )

        # Extract outputs
        txt_attn_output = hidden_states[:, :text_seq_len, :]
        img_attn_output = hidden_states[:, text_seq_len:, :]

        # Project
        img_attn_output = attn.attn_to_out[0](img_attn_output)
        txt_attn_output = attn.to_add_out(txt_attn_output)

        # Residual + gate
        patch_hidden = patch_hidden + img_gate1 * img_attn_output
        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output

        # Feed-forward for image
        img_normed2 = block.img_norm2(patch_hidden)
        img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
            img_normed2, img_mod2
        )
        img_mlp_output = block.img_ff(img_modulated2)
        patch_hidden = patch_hidden + img_gate2 * img_mlp_output

        # Feed-forward for text
        txt_normed2 = block.txt_norm2(encoder_hidden_states)
        txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
            txt_normed2, txt_mod2
        )
        txt_mlp_output = block.txt_ff(txt_modulated2)
        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output

        return encoder_hidden_states, patch_hidden
src/exo/worker/engines/image/pipeline/__init__.py (new file, +23 lines)
@@ -0,0 +1,23 @@
from exo.worker.engines.image.pipeline.adapter import (
    BlockWrapperMode,
    JointBlockInterface,
    ModelAdapter,
    SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
    JointBlockWrapper,
    SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
from exo.worker.engines.image.pipeline.runner import DiffusionRunner

__all__ = [
    "BlockWrapperMode",
    "DiffusionRunner",
    "ImagePatchKVCache",
    "JointBlockInterface",
    "JointBlockWrapper",
    "ModelAdapter",
    "SingleBlockInterface",
    "SingleBlockWrapper",
]
src/exo/worker/engines/image/pipeline/adapter.py (new file, +402 lines)
@@ -0,0 +1,402 @@
from enum import Enum
from pathlib import Path
from typing import Any, Protocol

import mlx.core as mx
from mflux.config.runtime_config import RuntimeConfig

from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache


class AttentionInterface(Protocol):
    num_heads: int
    head_dimension: int
    to_q: Any
    to_k: Any
    to_v: Any
    norm_q: Any
    norm_k: Any
    to_out: list[Any]


class JointAttentionInterface(AttentionInterface, Protocol):
    add_q_proj: Any
    add_k_proj: Any
    add_v_proj: Any
    norm_added_q: Any
    norm_added_k: Any
    to_add_out: Any


class JointBlockInterface(Protocol):
    attn: JointAttentionInterface
    norm1: Any  # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
    norm1_context: (
        Any  # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
    )
    norm2: Any
    norm2_context: Any
    ff: Any
    ff_context: Any


class SingleBlockInterface(Protocol):
    attn: AttentionInterface
    norm: Any  # Callable module: (hidden_states, text_embeddings) -> tuple[2 arrays]

    def _apply_feed_forward_and_projection(
        self, norm_hidden_states: mx.array, attn_output: mx.array, gate: mx.array
    ) -> mx.array:
        """Apply feed forward network and projection."""
        ...


class BlockWrapperMode(Enum):
    CACHING = "caching"  # Sync mode: compute full attention, populate cache
    PATCHED = "patched"  # Async mode: compute patch attention, use cached KV


class PromptData(Protocol):
    """Protocol for encoded prompt data.

    All adapters must return prompt data that conforms to this protocol.
    Model-specific prompt data classes can add additional attributes
    (e.g., attention masks for Qwen).
    """

    @property
    def prompt_embeds(self) -> mx.array:
        """Text embeddings from encoder."""
        ...

    @property
    def pooled_prompt_embeds(self) -> mx.array:
        """Pooled text embeddings (for Flux) or placeholder (for Qwen)."""
        ...

    @property
    def negative_prompt_embeds(self) -> mx.array | None:
        """Negative prompt embeddings for CFG (None if not using CFG)."""
        ...

    @property
    def negative_pooled_prompt_embeds(self) -> mx.array | None:
        """Negative pooled embeddings for CFG (None if not using CFG)."""
        ...

    def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
        """Return model-specific kwargs for forward pass.

        Args:
            positive: If True, return kwargs for positive prompt pass.
                If False, return kwargs for negative prompt pass.

        Returns:
            Dict of extra kwargs (e.g., {"encoder_hidden_states_mask": ...} for Qwen)
        """
        ...

    @property
    def conditioning_latents(self) -> mx.array | None:
        """Conditioning latents for edit mode.

        Returns:
            Conditioning latents array for image editing, None for standard generation.
        """
        ...

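# Illustrative sketch (not part of the original diff): a minimal container that would
# satisfy the PromptData protocol for a model without CFG or extra forward kwargs.
# The class name and the use of a dataclass are assumptions for the example only.
from dataclasses import dataclass


@dataclass
class _ExamplePromptData:
    prompt_embeds: mx.array
    pooled_prompt_embeds: mx.array
    negative_prompt_embeds: mx.array | None = None
    negative_pooled_prompt_embeds: mx.array | None = None
    conditioning_latents: mx.array | None = None

    def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
        return {}
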
class ModelAdapter(Protocol):
    @property
    def config(self) -> ImageModelConfig:
        """Return the model configuration."""
        ...

    @property
    def model(self) -> Any:
        """Return the underlying mflux model instance (e.g., Flux1, Fibo, Qwen)."""
        ...

    @property
    def transformer(self) -> Any:
        """Return the transformer component of the model."""
        ...

    @property
    def hidden_dim(self) -> int:
        """Return the hidden dimension of the transformer."""
        ...

    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Compute x_embedder and context_embedder outputs.

        Args:
            hidden_states: Input latent states
            prompt_embeds: Text embeddings from encoder

        Returns:
            Tuple of (embedded_hidden_states, embedded_encoder_states)
        """
        ...

    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: RuntimeConfig,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,
    ) -> mx.array:
        """Compute time/text embeddings for conditioning.

        Args:
            t: Current timestep
            runtime_config: Runtime configuration
            pooled_prompt_embeds: Pooled text embeddings (used by Flux)
            hidden_states: Image hidden states

        Returns:
            Text embeddings tensor
        """
        ...

    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: RuntimeConfig,
        **kwargs: Any,
    ) -> Any:
        """Compute rotary position embeddings.

        Args:
            prompt_embeds: Text embeddings
            runtime_config: Runtime configuration
            **kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask for Qwen)

        Returns:
            Flux: mx.array
            Qwen: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]
        """
        ...

    def apply_joint_block(
        self,
        block: JointBlockInterface,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,  # Format varies: mx.array (Flux) or nested tuple (Qwen)
        kv_cache: ImagePatchKVCache | None,
        mode: "BlockWrapperMode",
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
        **kwargs: Any,
    ) -> tuple[mx.array, mx.array]:
        """Apply a joint transformer block.

        Args:
            block: The joint transformer block
            hidden_states: Image hidden states
            encoder_hidden_states: Text hidden states
            text_embeddings: Conditioning embeddings
            rotary_embeddings: Rotary position embeddings (format varies by model)
            kv_cache: KV cache (None if not using cache)
            mode: CACHING or PATCHED mode
            text_seq_len: Text sequence length
            patch_start: Start index for patched mode
            patch_end: End index for patched mode
            **kwargs: Additional model-specific arguments (e.g., encoder_hidden_states_mask,
                block_idx for Qwen)

        Returns:
            Tuple of (encoder_hidden_states, hidden_states)
        """
        ...

    def apply_single_block(
        self,
        block: SingleBlockInterface,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        kv_cache: ImagePatchKVCache | None,
        mode: "BlockWrapperMode",
        text_seq_len: int,
        patch_start: int | None = None,
        patch_end: int | None = None,
    ) -> mx.array:
        """Apply a single transformer block.

        Args:
            block: The single transformer block
            hidden_states: Concatenated [text + image] hidden states
            text_embeddings: Conditioning embeddings
            rotary_embeddings: Rotary position embeddings
            kv_cache: KV cache (None if not using cache)
            mode: CACHING or PATCHED mode
            text_seq_len: Text sequence length
            patch_start: Start index for patched mode
            patch_end: End index for patched mode

        Returns:
            Output hidden states
        """
        ...

    def final_projection(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> mx.array:
        """Apply final norm and projection.

        Args:
            hidden_states: Hidden states (image only, text already removed)
            text_embeddings: Conditioning embeddings

        Returns:
            Projected output
        """
        ...

    def get_joint_blocks(self) -> list[JointBlockInterface]:
        """Get the list of joint transformer blocks from the model."""
        ...

    def get_single_blocks(self) -> list[SingleBlockInterface]:
        """Get the list of single transformer blocks from the model."""
        ...

    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
        total_joint_blocks: int,
        total_single_blocks: int,
    ):
        """Remove transformer blocks outside the assigned range.

        This should be called BEFORE mx.eval() to avoid loading unused weights
        in distributed mode.

        Args:
            start_layer: First layer index (inclusive) assigned to this node
            end_layer: Last layer index (exclusive) assigned to this node
            total_joint_blocks: Total number of joint blocks in the model
            total_single_blocks: Total number of single blocks in the model
        """
        ...

    def merge_streams(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
    ) -> mx.array:
        """Merge image and text streams for transition to single blocks.

        This is called at the transition point from joint blocks (which process
        image and text separately) to single blocks (which process them
        together). Override to customize the merge strategy.

        Args:
            hidden_states: Image hidden states
            encoder_hidden_states: Text hidden states

        Returns:
            Merged hidden states (default: concatenate [text, image])
        """
        ...

    def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
        """Create initial noise latents for generation.

        Args:
            seed: Random seed
            runtime_config: Runtime configuration with dimensions

        Returns:
            Initial latent tensor
        """
        ...

    def encode_prompt(self, prompt: str) -> PromptData:
        """Encode prompt into model-specific prompt data.

        Args:
            prompt: Text prompt

        Returns:
            PromptData containing embeddings (and model-specific extras)
        """
        ...

    @property
    def needs_cfg(self) -> bool:
        """Whether this model uses classifier-free guidance.

        Returns:
            True if model requires two forward passes with guidance (e.g., Qwen)
            False if model uses a single forward pass (e.g., Flux)
        """
        ...

    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        """Apply classifier-free guidance to combine positive/negative predictions.

        Only called when needs_cfg is True.

        Args:
            noise_positive: Noise prediction from positive prompt
            noise_negative: Noise prediction from negative prompt
            guidance_scale: Guidance strength

        Returns:
            Guided noise prediction
        """
        ...

    def decode_latents(
        self,
        latents: mx.array,
        runtime_config: RuntimeConfig,
        seed: int,
        prompt: str,
    ) -> Any:
        """Decode latents to final image.

        Args:
            latents: Final denoised latents
            runtime_config: Runtime configuration
            seed: Random seed (for metadata)
            prompt: Text prompt (for metadata)

        Returns:
            GeneratedImage result
        """
        ...

    def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
        """Compute and store dimensions from input image for edit mode.

        For edit adapters: computes dimensions from input image aspect ratio,
        stores image paths internally for encode_prompt(), returns (width, height).

        For standard adapters: returns None (use user-specified dimensions).

        Args:
            image_path: Path to the input image

        Returns:
            Tuple of (width, height) if dimensions were computed, None otherwise.
        """
        ...
src/exo/worker/engines/image/pipeline/block_wrapper.py (new file, +146 lines)
@@ -0,0 +1,146 @@
from typing import Any

import mlx.core as mx

from exo.worker.engines.image.pipeline.adapter import (
    BlockWrapperMode,
    JointBlockInterface,
    ModelAdapter,
    SingleBlockInterface,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache


class JointBlockWrapper:
    """Unified wrapper for joint transformer blocks.

    Handles both CACHING (sync) and PATCHED (async) modes by delegating
    to the model adapter for model-specific attention computation.

    The wrapper is created once at initialization and reused across calls.
    Mode and KV cache are passed at call time to support switching between
    sync and async pipelines.
    """

    def __init__(
        self,
        block: JointBlockInterface,
        adapter: ModelAdapter,
    ):
        """Initialize the joint block wrapper.

        Args:
            block: The joint transformer block to wrap
            adapter: Model adapter for model-specific operations
        """
        self.block = block
        self.adapter = adapter

    def __call__(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        text_seq_len: int,
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        patch_start: int | None = None,
        patch_end: int | None = None,
        **kwargs: Any,
    ) -> tuple[mx.array, mx.array]:
        """Apply the joint block.

        Args:
            hidden_states: Image hidden states (full or patch depending on mode)
            encoder_hidden_states: Text hidden states
            text_embeddings: Conditioning embeddings
            rotary_embeddings: Rotary position embeddings
            text_seq_len: Text sequence length
            kv_cache: KV cache for storing/retrieving image K/V (None if not using cache)
            mode: CACHING (populate cache) or PATCHED (use cached K/V)
            patch_start: Start index for patched mode (required if mode=PATCHED)
            patch_end: End index for patched mode (required if mode=PATCHED)
            **kwargs: Additional model-specific arguments (e.g., encoder_hidden_states_mask,
                block_idx for Qwen)

        Returns:
            Tuple of (encoder_hidden_states, hidden_states)
        """
        return self.adapter.apply_joint_block(
            block=self.block,
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            text_embeddings=text_embeddings,
            rotary_embeddings=rotary_embeddings,
            kv_cache=kv_cache,
            mode=mode,
            text_seq_len=text_seq_len,
            patch_start=patch_start,
            patch_end=patch_end,
            **kwargs,
        )


class SingleBlockWrapper:
    """Unified wrapper for single transformer blocks.

    Handles both CACHING (sync) and PATCHED (async) modes by delegating
    to the model adapter for model-specific attention computation.

    The wrapper is created once at initialization and reused across calls.
    Mode and KV cache are passed at call time to support switching between
    sync and async pipelines.
    """

    def __init__(
        self,
        block: SingleBlockInterface,
        adapter: ModelAdapter,
    ):
        """Initialize the single block wrapper.

        Args:
            block: The single transformer block to wrap
            adapter: Model adapter for model-specific operations
        """
        self.block = block
        self.adapter = adapter

    def __call__(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        text_seq_len: int,
        kv_cache: ImagePatchKVCache | None,
        mode: BlockWrapperMode,
        patch_start: int | None = None,
        patch_end: int | None = None,
    ) -> mx.array:
        """Apply the single block.

        Args:
            hidden_states: [text + image] hidden states (full or patch depending on mode)
            text_embeddings: Conditioning embeddings
            rotary_embeddings: Rotary position embeddings
            text_seq_len: Text sequence length
            kv_cache: KV cache for storing/retrieving image K/V (None if not using cache)
            mode: CACHING (populate cache) or PATCHED (use cached K/V)
            patch_start: Start index for patched mode (required if mode=PATCHED)
            patch_end: End index for patched mode (required if mode=PATCHED)

        Returns:
            Output hidden states
        """
        return self.adapter.apply_single_block(
            block=self.block,
            hidden_states=hidden_states,
            text_embeddings=text_embeddings,
            rotary_embeddings=rotary_embeddings,
            kv_cache=kv_cache,
            mode=mode,
            text_seq_len=text_seq_len,
            patch_start=patch_start,
            patch_end=patch_end,
        )
src/exo/worker/engines/image/pipeline/kv_cache.py (new file, +72 lines)
@@ -0,0 +1,72 @@
import mlx.core as mx


class ImagePatchKVCache:
    """KV cache that stores only IMAGE K/V with patch-level updates.

    Only caches image K/V since:
    - Text K/V is always computed fresh (same for all patches)
    - Only image portion needs stale/fresh cache management across patches
    """

    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        image_seq_len: int,
        head_dim: int,
        dtype: mx.Dtype = mx.float32,
    ):
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.image_seq_len = image_seq_len
        self.head_dim = head_dim
        self._dtype = dtype

        self.key_cache = mx.zeros(
            (batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
        )
        self.value_cache = mx.zeros(
            (batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
        )

    def update_image_patch(
        self, patch_start: int, patch_end: int, key: mx.array, value: mx.array
    ) -> None:
        """Update cache with fresh K/V for an image patch slice.

        Args:
            patch_start: Start token index within image portion (0-indexed)
            patch_end: End token index within image portion
            key: Fresh key tensor [batch, heads, patch_seq_len, head_dim]
            value: Fresh value tensor [batch, heads, patch_seq_len, head_dim]
        """
        self.key_cache[:, :, patch_start:patch_end, :] = key
        self.value_cache[:, :, patch_start:patch_end, :] = value

    def get_full_kv(
        self, text_key: mx.array, text_value: mx.array
    ) -> tuple[mx.array, mx.array]:
        """Return full K/V by concatenating fresh text K/V with cached image K/V.

        Args:
            text_key: Fresh text key tensor [batch, heads, text_seq_len, head_dim]
            text_value: Fresh text value tensor [batch, heads, text_seq_len, head_dim]

        Returns:
            Tuple of (full_key, full_value) with shape [batch, heads, text+image, head_dim]
        """
        full_key = mx.concatenate([text_key, self.key_cache], axis=2)
        full_value = mx.concatenate([text_value, self.value_cache], axis=2)
        return full_key, full_value

    def reset(self) -> None:
        """Reset cache to zeros."""
        self.key_cache = mx.zeros(
            (self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
            dtype=self._dtype,
        )
        self.value_cache = mx.zeros(
            (self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
            dtype=self._dtype,
        )
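
# Illustrative usage sketch (not part of the original diff): how a patched attention
# step might drive ImagePatchKVCache. The shapes (24 heads, head_dim 128, a
# 4096-token image, a 77-token prompt) are assumptions for the example only.
import mlx.core as mx

from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache

cache = ImagePatchKVCache(batch_size=1, num_heads=24, image_seq_len=4096, head_dim=128)
patch_k = mx.zeros((1, 24, 1024, 128))  # fresh K for one 1024-token image patch
patch_v = mx.zeros((1, 24, 1024, 128))
cache.update_image_patch(patch_start=0, patch_end=1024, key=patch_k, value=patch_v)
text_k = mx.zeros((1, 24, 77, 128))  # text K/V is always computed fresh
text_v = mx.zeros((1, 24, 77, 128))
full_k, full_v = cache.get_full_kv(text_key=text_k, text_value=text_v)
# full_k / full_v have shape [1, 24, 77 + 4096, 128]: fresh text K/V followed by the
# cached image K/V (fresh for this patch, possibly stale for the others).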
src/exo/worker/engines/image/pipeline/runner.py (new file, +975 lines)
@@ -0,0 +1,975 @@
from math import ceil
from typing import Any, Optional

import mlx.core as mx
from mflux.callbacks.callbacks import Callbacks
from mflux.config.config import Config
from mflux.config.runtime_config import RuntimeConfig
from mflux.utils.exceptions import StopImageGenerationException
from tqdm import tqdm

from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.pipeline.adapter import (
    BlockWrapperMode,
    ModelAdapter,
    PromptData,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
    JointBlockWrapper,
    SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache


def calculate_patch_heights(latent_height: int, num_patches: int):
    patch_height = ceil(latent_height / num_patches)

    actual_num_patches = ceil(latent_height / patch_height)
    patch_heights = [patch_height] * (actual_num_patches - 1)

    last_height = latent_height - patch_height * (actual_num_patches - 1)
    patch_heights.append(last_height)

    return patch_heights, actual_num_patches


def calculate_token_indices(patch_heights: list[int], latent_width: int):
    tokens_per_row = latent_width

    token_ranges = []
    cumulative_height = 0

    for h in patch_heights:
        start_token = tokens_per_row * cumulative_height
        end_token = tokens_per_row * (cumulative_height + h)

        token_ranges.append((start_token, end_token))
        cumulative_height += h

    return token_ranges

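# Illustrative worked example (not part of the original diff): with latent_height=32,
# latent_width=64 and num_patches=3, calculate_patch_heights returns ([11, 11, 10], 3)
# since ceil(32/3) = 11, and calculate_token_indices then yields
# [(0, 704), (704, 1408), (1408, 2048)] because each latent row spans 64 tokens.
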
class DiffusionRunner:
    """Orchestrates the diffusion loop for image generation.

    This class owns the entire diffusion process, handling both single-node
    and distributed (PipeFusion) modes.

    In distributed mode, it implements PipeFusion with:
    - Sync pipeline for initial timesteps (full image, all devices in lockstep)
    - Async pipeline for later timesteps (patches processed independently)
    """

    def __init__(
        self,
        config: ImageModelConfig,
        adapter: ModelAdapter,
        group: Optional[mx.distributed.Group],
        shard_metadata: PipelineShardMetadata,
        num_sync_steps: int = 1,
        num_patches: Optional[int] = None,
    ):
        """Initialize the diffusion runner.

        Args:
            config: Model configuration (architecture, block counts, etc.)
            adapter: Model adapter for model-specific operations
            group: MLX distributed group (None for single-node mode)
            shard_metadata: Pipeline shard metadata with layer assignments
            num_sync_steps: Number of synchronous timesteps before async mode
            num_patches: Number of patches for async mode (defaults to world_size)
        """
        self.config = config
        self.adapter = adapter
        self.group = group

        # Handle single-node vs distributed mode
        if group is None:
            self.rank = 0
            self.world_size = 1
            self.next_rank = 0
            self.prev_rank = 0
            self.start_layer = 0
            self.end_layer = config.total_blocks
        else:
            self.rank = shard_metadata.device_rank
            self.world_size = shard_metadata.world_size
            self.next_rank = (self.rank + 1) % self.world_size
            self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
            self.start_layer = shard_metadata.start_layer
            self.end_layer = shard_metadata.end_layer

        self.num_sync_steps = num_sync_steps
        self.num_patches = num_patches if num_patches else max(1, self.world_size)

        # Persistent KV caches (initialized on first async timestep, reused across timesteps)
        self.joint_kv_caches: list[ImagePatchKVCache] | None = None
        self.single_kv_caches: list[ImagePatchKVCache] | None = None

        # Get block counts from config (model-agnostic)
        self.total_joint = config.joint_block_count
        self.total_single = config.single_block_count
        self.total_layers = config.total_blocks

        self._compute_assigned_blocks()

    def _compute_assigned_blocks(self) -> None:
        """Determine which joint/single blocks this stage owns."""
        start = self.start_layer
        end = self.end_layer

        if end <= self.total_joint:
            # All assigned blocks are joint blocks
            self.joint_start = start
            self.joint_end = end
            self.single_start = 0
            self.single_end = 0
        elif start >= self.total_joint:
            # All assigned blocks are single blocks
            self.joint_start = 0
            self.joint_end = 0
            self.single_start = start - self.total_joint
            self.single_end = end - self.total_joint
        else:
            # Stage spans joint→single transition
            self.joint_start = start
            self.joint_end = self.total_joint
            self.single_start = 0
            self.single_end = end - self.total_joint

        self.has_joint_blocks = self.joint_end > self.joint_start
        self.has_single_blocks = self.single_end > self.single_start

        self.owns_concat_stage = self.has_joint_blocks and (
            self.has_single_blocks or self.end_layer == self.total_joint
        )

        joint_blocks = self.adapter.get_joint_blocks()
        single_blocks = self.adapter.get_single_blocks()

        # Wrap blocks at initialization (reused across all calls)
        self.joint_block_wrappers = [
            JointBlockWrapper(block=block, adapter=self.adapter)
            for block in joint_blocks
        ]
        self.single_block_wrappers = [
            SingleBlockWrapper(block=block, adapter=self.adapter)
            for block in single_blocks
        ]

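    # Illustrative worked example (not part of the original diff): for a hypothetical
    # model with 19 joint and 38 single blocks split across two stages as layers
    # [0, 28) and [28, 57), stage 0 owns joint blocks 0-18 plus single blocks 0-8 and
    # therefore the concat stage, while stage 1 owns only single blocks 9-37.
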
    @property
    def is_first_stage(self) -> bool:
        return self.rank == 0

    @property
    def is_last_stage(self) -> bool:
        return self.rank == self.world_size - 1

    @property
    def is_distributed(self) -> bool:
        return self.group is not None

    def _calculate_capture_steps(
        self,
        partial_images: int,
        init_time_step: int,
        num_inference_steps: int,
    ) -> set[int]:
        """Calculate which timesteps should produce partial images.

        Evenly spaces `partial_images` captures across the diffusion loop.
        Does NOT include the final timestep (that's the complete image).

        Args:
            partial_images: Number of partial images to capture
            init_time_step: Starting timestep (for img2img this may not be 0)
            num_inference_steps: Total inference steps

        Returns:
            Set of timestep indices to capture
        """
        if partial_images <= 0:
            return set()

        total_steps = num_inference_steps - init_time_step
        if total_steps <= 1:
            return set()

        if partial_images >= total_steps - 1:
            # Capture every step except final
            return set(range(init_time_step, num_inference_steps - 1))

        # Evenly space partial captures
        step_interval = total_steps / (partial_images + 1)
        capture_steps: set[int] = set()
        for i in range(1, partial_images + 1):
            step_idx = int(init_time_step + i * step_interval)
            # Ensure we don't capture the final step
            if step_idx < num_inference_steps - 1:
                capture_steps.add(step_idx)

        return capture_steps

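    # Illustrative worked example (not part of the original diff): with
    # init_time_step=0, num_inference_steps=20 and partial_images=3, the interval is
    # 20 / 4 = 5.0, so captures happen at steps {5, 10, 15}; the final step (19) is
    # never captured because it corresponds to the completed image.
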
def generate_image(
|
||||
self,
|
||||
settings: Config,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
partial_images: int = 0,
|
||||
):
|
||||
"""Primary entry point for image generation.
|
||||
|
||||
Orchestrates the full generation flow:
|
||||
1. Create runtime config
|
||||
2. Create initial latents
|
||||
3. Encode prompt
|
||||
4. Run diffusion loop (yielding partials if requested)
|
||||
5. Decode to image
|
||||
|
||||
When partial_images > 0, yields (GeneratedImage, partial_index, total_partials)
|
||||
tuples for intermediate images, then yields the final GeneratedImage.
|
||||
|
||||
Args:
|
||||
settings: Generation config (steps, height, width)
|
||||
prompt: Text prompt
|
||||
seed: Random seed
|
||||
partial_images: Number of intermediate images to yield (0 for none)
|
||||
|
||||
Yields:
|
||||
Partial images as (GeneratedImage, partial_index, total_partials) tuples
|
||||
Final GeneratedImage
|
||||
"""
|
||||
runtime_config = RuntimeConfig(settings, self.adapter.model.model_config)
|
||||
latents = self.adapter.create_latents(seed, runtime_config)
|
||||
prompt_data = self.adapter.encode_prompt(prompt)
|
||||
|
||||
# Calculate which steps to capture
|
||||
capture_steps = self._calculate_capture_steps(
|
||||
partial_images=partial_images,
|
||||
init_time_step=runtime_config.init_time_step,
|
||||
num_inference_steps=runtime_config.num_inference_steps,
|
||||
)
|
||||
|
||||
# Run diffusion loop - may yield partial latents
|
||||
diffusion_gen = self._run_diffusion_loop(
|
||||
latents=latents,
|
||||
prompt_data=prompt_data,
|
||||
runtime_config=runtime_config,
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
capture_steps=capture_steps,
|
||||
)
|
||||
|
||||
# Process partial yields and get final latents
|
||||
partial_index = 0
|
||||
total_partials = len(capture_steps)
|
||||
|
||||
if capture_steps:
|
||||
# Generator mode - iterate to get partials and final latents
|
||||
try:
|
||||
while True:
|
||||
partial_latents, _step = next(diffusion_gen)
|
||||
if self.is_last_stage:
|
||||
partial_image = self.adapter.decode_latents(
|
||||
partial_latents, runtime_config, seed, prompt
|
||||
)
|
||||
yield (partial_image, partial_index, total_partials)
|
||||
partial_index += 1
|
||||
except StopIteration as e:
|
||||
latents = e.value
|
||||
else:
|
||||
# No partials - just consume generator to get final latents
|
||||
try:
|
||||
while True:
|
||||
next(diffusion_gen)
|
||||
except StopIteration as e:
|
||||
latents = e.value
|
||||
|
||||
# Yield final image (only on last stage)
|
||||
if self.is_last_stage:
|
||||
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt)
|
||||
|
||||
def _run_diffusion_loop(
|
||||
self,
|
||||
latents: mx.array,
|
||||
prompt_data: PromptData,
|
||||
runtime_config: RuntimeConfig,
|
||||
seed: int,
|
||||
prompt: str,
|
||||
capture_steps: set[int] | None = None,
|
||||
):
|
||||
"""Execute the diffusion loop, optionally yielding at capture steps.
|
||||
|
||||
When capture_steps is provided and non-empty, this becomes a generator
|
||||
that yields (latents, step_index) tuples at the specified timesteps.
|
||||
Only the last stage yields (others have incomplete latents).
|
||||
|
||||
Args:
|
||||
latents: Initial noise latents
|
||||
prompt_data: Encoded prompt data
|
||||
runtime_config: RuntimeConfig with scheduler, steps, dimensions
|
||||
seed: Random seed (for callbacks)
|
||||
prompt: Text prompt (for callbacks)
|
||||
capture_steps: Set of timestep indices to capture (None = no captures)
|
||||
|
||||
Yields:
|
||||
(latents, step_index) tuples at capture steps (last stage only)
|
||||
|
||||
Returns:
|
||||
Final denoised latents ready for VAE decoding
|
||||
"""
|
||||
if capture_steps is None:
|
||||
capture_steps = set()
|
||||
|
||||
time_steps = tqdm(range(runtime_config.num_inference_steps))
|
||||
|
||||
# Call subscribers for beginning of loop
|
||||
Callbacks.before_loop(
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
latents=latents,
|
||||
config=runtime_config,
|
||||
)
|
||||
|
||||
for t in time_steps:
|
||||
try:
|
||||
latents = self._diffusion_step(
|
||||
t=t,
|
||||
config=runtime_config,
|
||||
latents=latents,
|
||||
prompt_data=prompt_data,
|
||||
)
|
||||
|
||||
# Call subscribers in-loop
|
||||
Callbacks.in_loop(
|
||||
t=t,
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
latents=latents,
|
||||
config=runtime_config,
|
||||
time_steps=time_steps,
|
||||
)
|
||||
|
||||
mx.eval(latents)
|
||||
|
||||
# Yield partial latents at capture steps (only on last stage)
|
||||
if t in capture_steps and self.is_last_stage:
|
||||
yield (latents, t)
|
||||
|
||||
except KeyboardInterrupt: # noqa: PERF203
|
||||
Callbacks.interruption(
|
||||
t=t,
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
latents=latents,
|
||||
config=runtime_config,
|
||||
time_steps=time_steps,
|
||||
)
|
||||
raise StopImageGenerationException(
|
||||
f"Stopping image generation at step {t + 1}/{len(time_steps)}"
|
||||
) from None
|
||||
|
||||
# Call subscribers after loop
|
||||
Callbacks.after_loop(
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
latents=latents,
|
||||
config=runtime_config,
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
def _forward_pass(
|
||||
self,
|
||||
latents: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
pooled_prompt_embeds: mx.array,
|
||||
kwargs: dict[str, Any],
|
||||
) -> mx.array:
|
||||
"""Run a single forward pass through the transformer.
|
||||
|
||||
This is the internal method called by adapters via compute_step_noise.
|
||||
Returns noise prediction without applying scheduler step.
|
||||
|
||||
For edit mode, concatenates conditioning latents with generated latents
|
||||
before the transformer, and extracts only the generated portion after.
|
||||
|
||||
Args:
|
||||
latents: Input latents (already scaled by caller)
|
||||
prompt_embeds: Text embeddings
|
||||
pooled_prompt_embeds: Pooled text embeddings (Flux) or placeholder (Qwen)
|
||||
kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask, t)
|
||||
|
||||
Returns:
|
||||
Noise prediction tensor
|
||||
"""
|
||||
t = kwargs.get("t", 0)
|
||||
config = kwargs.get("config")
|
||||
if config is None:
|
||||
raise ValueError("config must be provided in kwargs")
|
||||
scaled_latents = config.scheduler.scale_model_input(latents, t)
|
||||
|
||||
# For edit mode: concatenate with conditioning latents
|
||||
conditioning_latents = kwargs.get("conditioning_latents")
|
||||
original_latent_tokens = scaled_latents.shape[1]
|
||||
if conditioning_latents is not None:
|
||||
scaled_latents = mx.concatenate(
|
||||
[scaled_latents, conditioning_latents], axis=1
|
||||
)
|
||||
|
||||
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
scaled_latents, prompt_embeds
|
||||
)
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_prompt_embeds, hidden_states=hidden_states
|
||||
)
|
||||
rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds, config, **kwargs
|
||||
)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
|
||||
# Run through all joint blocks
|
||||
for block_idx, wrapper in enumerate(self.joint_block_wrappers):
|
||||
encoder_hidden_states, hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
text_seq_len=text_seq_len,
|
||||
kv_cache=None,
|
||||
mode=BlockWrapperMode.CACHING,
|
||||
block_idx=block_idx,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Merge streams
|
||||
if self.joint_block_wrappers:
|
||||
hidden_states = self.adapter.merge_streams(
|
||||
hidden_states, encoder_hidden_states
|
||||
)
|
||||
|
||||
# Run through single blocks
|
||||
for wrapper in self.single_block_wrappers:
|
||||
hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
text_seq_len=text_seq_len,
|
||||
kv_cache=None,
|
||||
mode=BlockWrapperMode.CACHING,
|
||||
)
|
||||
|
||||
# Extract image portion and project
|
||||
hidden_states = hidden_states[:, text_seq_len:, ...]
|
||||
|
||||
# For edit mode: extract only the generated portion (exclude conditioning latents)
|
||||
if conditioning_latents is not None:
|
||||
hidden_states = hidden_states[:, :original_latent_tokens, ...]
|
||||
|
||||
return self.adapter.final_projection(hidden_states, text_embeddings)
|
||||
|
||||
    def _diffusion_step(
        self,
        t: int,
        config: RuntimeConfig,
        latents: mx.array,
        prompt_data: PromptData,
    ) -> mx.array:
        """Execute a single diffusion step.

        Routes to single-node, sync pipeline, or async pipeline based on
        configuration and current timestep.
        """
        if self.group is None:
            return self._single_node_step(t, config, latents, prompt_data)
        elif t < config.init_time_step + self.num_sync_steps:
            return self._sync_pipeline(
                t,
                config,
                latents,
                prompt_data,
            )
        else:
            return self._async_pipeline_step(
                t,
                config,
                latents,
                prompt_data,
            )

    def _single_node_step(
        self,
        t: int,
        config: RuntimeConfig,
        latents: mx.array,
        prompt_data: PromptData,
    ) -> mx.array:
        """Execute a single diffusion step on a single node (no distribution)."""
        base_kwargs = {"t": t, "config": config}

        # For edit mode: include conditioning latents
        if prompt_data.conditioning_latents is not None:
            base_kwargs["conditioning_latents"] = prompt_data.conditioning_latents

        if self.adapter.needs_cfg:
            # Two forward passes + guidance for CFG models (e.g., Qwen)
            pos_kwargs = {
                **base_kwargs,
                **prompt_data.get_extra_forward_kwargs(positive=True),
            }
            noise_pos = self._forward_pass(
                latents,
                prompt_data.prompt_embeds,
                prompt_data.pooled_prompt_embeds,
                pos_kwargs,
            )

            neg_kwargs = {
                **base_kwargs,
                **prompt_data.get_extra_forward_kwargs(positive=False),
            }
            noise_neg = self._forward_pass(
                latents,
                prompt_data.negative_prompt_embeds,
                prompt_data.negative_pooled_prompt_embeds,
                neg_kwargs,
            )

            assert self.config.guidance_scale is not None
            noise = self.adapter.apply_guidance(
                noise_pos, noise_neg, guidance_scale=self.config.guidance_scale
            )
        else:
            # Single forward pass for non-CFG models (e.g., Flux)
            kwargs = {**base_kwargs, **prompt_data.get_extra_forward_kwargs()}
            noise = self._forward_pass(
                latents,
                prompt_data.prompt_embeds,
                prompt_data.pooled_prompt_embeds,
                kwargs,
            )

        return config.scheduler.step(model_output=noise, timestep=t, sample=latents)

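The adapter's apply_guidance implementation is not shown in this hunk; the standard classifier-free guidance combination it most likely performs is sketched below (a hypothetical helper, not the adapter's actual code):

import mlx.core as mx

def cfg_combine(noise_pos: mx.array, noise_neg: mx.array, guidance_scale: float) -> mx.array:
    # Push the prediction away from the unconditional branch and toward the
    # text-conditioned branch; guidance_scale == 1.0 reduces to noise_pos.
    return noise_neg + guidance_scale * (noise_pos - noise_neg)
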
    def _initialize_kv_caches(
        self,
        batch_size: int,
        num_img_tokens: int,
        dtype: mx.Dtype,
    ) -> None:
        """Initialize KV caches for both sync and async pipelines.

        Note: Caches only store IMAGE K/V, not text K/V. Text K/V is always
        computed fresh and doesn't need caching (it's the same for all patches).
        """
        self.joint_kv_caches = [
            ImagePatchKVCache(
                batch_size=batch_size,
                num_heads=self.config.num_heads,
                image_seq_len=num_img_tokens,
                head_dim=self.config.head_dim,
                dtype=dtype,
            )
            for _ in range(len(self.joint_block_wrappers))
        ]
        self.single_kv_caches = [
            ImagePatchKVCache(
                batch_size=batch_size,
                num_heads=self.config.num_heads,
                image_seq_len=num_img_tokens,
                head_dim=self.config.head_dim,
                dtype=dtype,
            )
            for _ in range(len(self.single_block_wrappers))
        ]

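ImagePatchKVCache is constructed here but defined elsewhere in the repo. A rough sketch of the behaviour the sync/async split relies on, with field and method names that are assumptions rather than the real interface:

import mlx.core as mx

class ImagePatchKVCacheSketch:
    """Illustrative only: holds K/V for all image tokens so a patched pass
    can refresh just its own token slice while attending over the rest."""

    def __init__(self, batch_size: int, num_heads: int, image_seq_len: int,
                 head_dim: int, dtype: mx.Dtype):
        shape = (batch_size, num_heads, image_seq_len, head_dim)
        self.keys = mx.zeros(shape, dtype=dtype)
        self.values = mx.zeros(shape, dtype=dtype)

    def update_slice(self, start: int, end: int, k: mx.array, v: mx.array) -> None:
        # Overwrite only the tokens covered by the current patch.
        self.keys[:, :, start:end, :] = k
        self.values[:, :, start:end, :] = v
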
    def _create_patches(
        self,
        latents: mx.array,
        config: RuntimeConfig,
    ) -> tuple[list[mx.array], list[tuple[int, int]]]:
        """Split latents into patches for async pipeline."""
        # Use 16 to match FluxLatentCreator.create_noise formula
        latent_height = config.height // 16
        latent_width = config.width // 16

        patch_heights, _ = calculate_patch_heights(latent_height, self.num_patches)
        token_indices = calculate_token_indices(patch_heights, latent_width)

        # Split latents into patches
        patch_latents = [latents[:, start:end, :] for start, end in token_indices]

        return patch_latents, token_indices

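calculate_patch_heights and calculate_token_indices are imported from elsewhere; the sketch below shows the kind of row-wise split they appear to perform (a hypothetical reimplementation, the real helpers may balance rows differently):

def split_rows(latent_height: int, num_patches: int) -> list[int]:
    # Distribute latent rows as evenly as possible across patches.
    base, rem = divmod(latent_height, num_patches)
    return [base + (1 if i < rem else 0) for i in range(num_patches)]

def token_ranges(patch_heights: list[int], latent_width: int) -> list[tuple[int, int]]:
    # Row-major tokens: each latent row contributes latent_width tokens.
    ranges, start = [], 0
    for h in patch_heights:
        end = start + h * latent_width
        ranges.append((start, end))
        start = end
    return ranges

# Example: a 1024x1024 image gives a 64x64 latent grid; with 2 patches the
# rows split as [32, 32] and the token ranges are [(0, 2048), (2048, 4096)].
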
    def _sync_pipeline(
        self,
        t: int,
        config: RuntimeConfig,
        hidden_states: mx.array,
        prompt_data: PromptData,
        kontext_image_ids: mx.array | None = None,
    ) -> mx.array:
        prev_latents = hidden_states

        # Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
        prompt_embeds = prompt_data.prompt_embeds
        pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
        extra_kwargs = prompt_data.get_extra_forward_kwargs()

        hidden_states = config.scheduler.scale_model_input(hidden_states, t)

        # For edit mode: handle conditioning latents
        # All stages need to know the total token count for correct recv templates
        conditioning_latents = prompt_data.conditioning_latents
        original_latent_tokens = hidden_states.shape[1]
        if conditioning_latents is not None:
            num_img_tokens = original_latent_tokens + conditioning_latents.shape[1]
        else:
            num_img_tokens = original_latent_tokens

        # First stage: concatenate conditioning latents before embedding
        if self.is_first_stage and conditioning_latents is not None:
            hidden_states = mx.concatenate(
                [hidden_states, conditioning_latents], axis=1
            )

        # === PHASE 1: Embeddings ===
        if self.is_first_stage:
            hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
                hidden_states, prompt_embeds
            )

        # All stages need these for their blocks
        text_embeddings = self.adapter.compute_text_embeddings(
            t, config, pooled_prompt_embeds
        )
        image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
            prompt_embeds,
            config,
            kontext_image_ids=kontext_image_ids,
            **extra_kwargs,
        )

        # === Initialize KV caches to populate during sync for async warmstart ===
        batch_size = prev_latents.shape[0]
        text_seq_len = prompt_embeds.shape[1]
        hidden_dim = self.adapter.hidden_dim

        if t == config.init_time_step:
            self._initialize_kv_caches(
                batch_size=batch_size,
                num_img_tokens=num_img_tokens,
                dtype=prev_latents.dtype,
            )

        # === PHASE 2: Joint Blocks with Communication and Caching ===
        if self.has_joint_blocks:
            # Receive from previous stage (if not first stage)
            if not self.is_first_stage:
                recv_template = mx.zeros(
                    (batch_size, num_img_tokens, hidden_dim), dtype=prev_latents.dtype
                )
                hidden_states = mx.distributed.recv_like(
                    recv_template, self.prev_rank, group=self.group
                )
                enc_template = mx.zeros(
                    (batch_size, text_seq_len, hidden_dim), dtype=prev_latents.dtype
                )
                encoder_hidden_states = mx.distributed.recv_like(
                    enc_template, self.prev_rank, group=self.group
                )
                mx.eval(hidden_states, encoder_hidden_states)

            # Run assigned joint blocks with caching mode
            for block_idx, wrapper in enumerate(self.joint_block_wrappers):
                encoder_hidden_states, hidden_states = wrapper(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                    text_seq_len=text_seq_len,
                    kv_cache=self.joint_kv_caches[block_idx],
                    mode=BlockWrapperMode.CACHING,
                    **extra_kwargs,
                )

        # === PHASE 3: Joint→Single Transition ===
        if self.owns_concat_stage:
            # Merge encoder and hidden states using adapter hook
            concatenated = self.adapter.merge_streams(
                hidden_states, encoder_hidden_states
            )

            if self.has_single_blocks or self.is_last_stage:
                # Keep locally: either for single blocks or final projection
                hidden_states = concatenated
            else:
                # Send concatenated state to next stage (which has single blocks)
                mx.eval(
                    mx.distributed.send(concatenated, self.next_rank, group=self.group)
                )

        elif self.has_joint_blocks and not self.is_last_stage:
            # Send joint block outputs to next stage (which has more joint blocks)
            mx.eval(
                mx.distributed.send(hidden_states, self.next_rank, group=self.group),
                mx.distributed.send(
                    encoder_hidden_states, self.next_rank, group=self.group
                ),
            )

        # === PHASE 4: Single Blocks with Communication and Caching ===
        if self.has_single_blocks:
            # Receive from previous stage if we didn't do concatenation
            if not self.owns_concat_stage and not self.is_first_stage:
                recv_template = mx.zeros(
                    (batch_size, text_seq_len + num_img_tokens, hidden_dim),
                    dtype=prev_latents.dtype,
                )
                hidden_states = mx.distributed.recv_like(
                    recv_template, self.prev_rank, group=self.group
                )
                mx.eval(hidden_states)

            # Run assigned single blocks with caching mode
            for block_idx, wrapper in enumerate(self.single_block_wrappers):
                hidden_states = wrapper(
                    hidden_states=hidden_states,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                    text_seq_len=text_seq_len,
                    kv_cache=self.single_kv_caches[block_idx],
                    mode=BlockWrapperMode.CACHING,
                )

            # Send to next stage if not last
            if not self.is_last_stage:
                mx.eval(
                    mx.distributed.send(hidden_states, self.next_rank, group=self.group)
                )

        # === PHASE 5: Last Stage - Final Projection + Scheduler ===
        # Extract image portion (remove text embeddings prefix)
        hidden_states = hidden_states[:, text_seq_len:, ...]

        # For edit mode: extract only the generated portion (exclude conditioning latents)
        if conditioning_latents is not None:
            hidden_states = hidden_states[:, :original_latent_tokens, ...]

        if self.is_last_stage:
            hidden_states = self.adapter.final_projection(
                hidden_states, text_embeddings
            )

            hidden_states = config.scheduler.step(
                model_output=hidden_states,
                timestep=t,
                sample=prev_latents,
            )

            if not self.is_first_stage:
                mx.eval(mx.distributed.send(hidden_states, 0, group=self.group))

        elif self.is_first_stage:
            hidden_states = mx.distributed.recv_like(
                prev_latents, src=self.world_size - 1, group=self.group
            )

            mx.eval(hidden_states)

        else:
            # For shape correctness
            hidden_states = prev_latents

        return hidden_states

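Each phase above reduces to the same blocking relay between neighbouring ranks: receive a template-shaped tensor from the previous stage, run the local blocks, then send downstream. A stripped-down sketch of that handshake (stage logic assumed; the real code also forwards encoder states and text embeddings):

import mlx.core as mx

def relay_stage(x: mx.array, rank: int, world_size: int, group) -> mx.array:
    if rank > 0:
        # recv_like only needs a template for shape and dtype.
        x = mx.distributed.recv_like(mx.zeros_like(x), rank - 1, group=group)
        mx.eval(x)
    # ... run this stage's transformer blocks on x ...
    if rank < world_size - 1:
        # MLX is lazy, so the send must be forced with mx.eval.
        mx.eval(mx.distributed.send(x, rank + 1, group=group))
    return x
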
    def _async_pipeline_step(
        self,
        t: int,
        config: RuntimeConfig,
        latents: mx.array,
        prompt_data: PromptData,
        kontext_image_ids: mx.array | None = None,
    ) -> mx.array:
        patch_latents, token_indices = self._create_patches(latents, config)

        patch_latents = self._async_pipeline(
            t,
            config,
            patch_latents,
            token_indices,
            prompt_data,
            kontext_image_ids,
        )

        return mx.concatenate(patch_latents, axis=1)

    def _async_pipeline(
        self,
        t: int,
        config: RuntimeConfig,
        patch_latents: list[mx.array],
        token_indices: list[tuple[int, int]],
        prompt_data: PromptData,
        kontext_image_ids: mx.array | None = None,
    ) -> list[mx.array]:
        """Execute async pipeline for all patches."""
        assert self.joint_kv_caches is not None
        assert self.single_kv_caches is not None

        # Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
        prompt_embeds = prompt_data.prompt_embeds
        pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
        extra_kwargs = prompt_data.get_extra_forward_kwargs()

        text_embeddings = self.adapter.compute_text_embeddings(
            t, config, pooled_prompt_embeds
        )
        image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
            prompt_embeds,
            config,
            kontext_image_ids=kontext_image_ids,
            **extra_kwargs,
        )

        batch_size = patch_latents[0].shape[0]
        text_seq_len = prompt_embeds.shape[1]
        hidden_dim = self.adapter.hidden_dim

        for patch_idx, patch in enumerate(patch_latents):
            patch_prev = patch

            start_token, end_token = token_indices[patch_idx]

            if self.has_joint_blocks:
                if (
                    not self.is_first_stage
                    or t != config.init_time_step + self.num_sync_steps
                ):
                    if self.is_first_stage:
                        # First stage receives latent-space from last stage (scheduler output)
                        recv_template = patch
                    else:
                        # Other stages receive hidden-space from previous stage
                        patch_len = patch.shape[1]
                        recv_template = mx.zeros(
                            (batch_size, patch_len, hidden_dim),
                            dtype=patch.dtype,
                        )
                    patch = mx.distributed.recv_like(
                        recv_template, src=self.prev_rank, group=self.group
                    )
                    mx.eval(patch)
                    patch_latents[patch_idx] = patch

                    if not self.is_first_stage and patch_idx == 0:
                        enc_template = mx.zeros(
                            (batch_size, text_seq_len, hidden_dim),
                            dtype=patch_latents[0].dtype,
                        )
                        encoder_hidden_states = mx.distributed.recv_like(
                            enc_template, src=self.prev_rank, group=self.group
                        )
                        mx.eval(encoder_hidden_states)

                if self.is_first_stage:
                    patch, encoder_hidden_states = self.adapter.compute_embeddings(
                        patch, prompt_embeds
                    )

                # Run assigned joint blocks with patched mode
                for block_idx, wrapper in enumerate(self.joint_block_wrappers):
                    encoder_hidden_states, patch = wrapper(
                        hidden_states=patch,
                        encoder_hidden_states=encoder_hidden_states,
                        text_embeddings=text_embeddings,
                        rotary_embeddings=image_rotary_embeddings,
                        text_seq_len=text_seq_len,
                        kv_cache=self.joint_kv_caches[block_idx],
                        mode=BlockWrapperMode.PATCHED,
                        patch_start=start_token,
                        patch_end=end_token,
                        **extra_kwargs,
                    )

            if self.owns_concat_stage:
                patch_concat = self.adapter.merge_streams(patch, encoder_hidden_states)

                if self.has_single_blocks or self.is_last_stage:
                    # Keep locally: either for single blocks or final projection
                    patch = patch_concat
                else:
                    mx.eval(
                        mx.distributed.send(
                            patch_concat, self.next_rank, group=self.group
                        )
                    )

            elif self.has_joint_blocks and not self.is_last_stage:
                mx.eval(mx.distributed.send(patch, self.next_rank, group=self.group))

                if patch_idx == 0:
                    mx.eval(
                        mx.distributed.send(
                            encoder_hidden_states, self.next_rank, group=self.group
                        )
                    )

            if self.has_single_blocks:
                if not self.owns_concat_stage and not self.is_first_stage:
                    recv_template = mx.zeros(
                        [
                            batch_size,
                            text_seq_len + patch_latents[patch_idx].shape[1],
                            hidden_dim,
                        ],
                        dtype=patch_latents[0].dtype,
                    )

                    patch = mx.distributed.recv_like(
                        recv_template, src=self.prev_rank, group=self.group
                    )
                    mx.eval(patch)
                    patch_latents[patch_idx] = patch

                # Run assigned single blocks with patched mode
                for block_idx, wrapper in enumerate(self.single_block_wrappers):
                    patch = wrapper(
                        hidden_states=patch,
                        text_embeddings=text_embeddings,
                        rotary_embeddings=image_rotary_embeddings,
                        text_seq_len=text_seq_len,
                        kv_cache=self.single_kv_caches[block_idx],
                        mode=BlockWrapperMode.PATCHED,
                        patch_start=start_token,
                        patch_end=end_token,
                    )

                if not self.is_last_stage:
                    mx.eval(
                        mx.distributed.send(patch, self.next_rank, group=self.group)
                    )

            if self.is_last_stage:
                patch_img_only = patch[:, text_seq_len:, :]

                patch_img_only = self.adapter.final_projection(
                    patch_img_only, text_embeddings
                )

                patch = config.scheduler.step(
                    model_output=patch_img_only,
                    timestep=t,
                    sample=patch_prev,
                )

                if not self.is_first_stage and t != config.num_inference_steps - 1:
                    mx.eval(
                        mx.distributed.send(patch, self.next_rank, group=self.group)
                    )

            patch_latents[patch_idx] = patch

        return patch_latents

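Putting the pieces together, a hypothetical driver loop over timesteps would look roughly like this; create_noise and decode_latents are placeholders standing in for the surrounding pipeline code, not functions defined here:

latents = create_noise(config)  # placeholder: initial noise in latent space
for t in range(config.init_time_step, config.num_inference_steps):
    # The first num_sync_steps steps run the sync pipeline and warm the image
    # K/V caches; later steps switch to the async patch pipeline.
    latents = pipeline._diffusion_step(t, config, latents, prompt_data)
image = decode_latents(latents)  # placeholder: VAE decode on the last stage
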
@@ -103,6 +103,7 @@ class PipelineLastLayer(CustomMlxLayer):
        # This change happened upstream - check out mlx github somewhere??
        cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]

        # TODO(ciaran): This is overkill
        output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
        return output


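The mx.depends call in this hunk only adds an ordering edge to the lazy graph: it returns the cache keys unchanged in value, but evaluating them now also forces the layer output to be materialised first. A minimal illustration mirroring the usage above (assumed to accept single arrays, as it does in this file):

import mlx.core as mx

keys = mx.zeros((4,))
output = keys + 1.0              # work that must not be skipped or reordered
keys = mx.depends(keys, output)  # value of `keys` is unchanged, but
mx.eval(keys)                    # evaluating it now also evaluates `output`
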
@@ -9,7 +9,7 @@ MAX_KV_SIZE: int | None = 3200
KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = 8
KV_CACHE_BITS: int | None = None

# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

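Setting KV_CACHE_BITS back to None disables KV-cache quantization. A sketch of how these flags plausibly feed into cache construction (assumed wiring, inferred from the make_kv_cache hunk below, not the repo's exact code):

from mlx_lm.models.cache import KVCache, QuantizedKVCache

def pick_cache(kv_cache_bits: int | None, group_size: int = 64):
    # group_size corresponds to CACHE_GROUP_SIZE above.
    if kv_cache_bits is None:
        return KVCache()  # default full-precision cache
    return QuantizedKVCache(group_size=group_size, bits=kv_cache_bits)
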
@@ -1,7 +1,7 @@
from typing import Any, Callable, Generator, cast, get_args

import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.tokenizer_utils import TokenizerWrapper


@@ -343,10 +343,6 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
    assert hasattr(model, "layers")

    if hasattr(model, "make_cache"):
        logger.info(f"Using make_cache")
        return model.make_cache()  # type: ignore

    if max_kv_size is None:
        if KV_CACHE_BITS is None:
            logger.info("Using default KV cache")
@@ -399,11 +395,5 @@ def set_wired_limit_for_model(model_size: Memory):
            "MB. This can be slow. See the documentation for possible work-arounds: "
            "https://github.com/ml-explore/mlx-lm/tree/main#large-models"
        )
    kv_bytes = int(0.02 * model_bytes)
    target_cache = int(1.10 * (model_bytes + kv_bytes))
    target_cache = min(target_cache, max_rec_size)
    mx.set_cache_limit(target_cache)
    mx.set_wired_limit(max_rec_size)
    logger.info(
        f"Wired limit set to {max_rec_size}. Cache limit set to {target_cache}."
    )
    logger.info(f"Wired limit set to {max_rec_size}.")

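The removed sizing heuristic is simple arithmetic; a worked example with an illustrative model size (the numbers are not from the repo):

model_bytes = 20 * 1024**3                           # e.g. a 20 GiB model
kv_bytes = int(0.02 * model_bytes)                   # ~0.4 GiB reserved for KV
target_cache = int(1.10 * (model_bytes + kv_bytes))  # ~22.4 GiB cache target
# target_cache was then clamped to max_rec_size before mx.set_cache_limit;
# after this change only mx.set_wired_limit(max_rec_size) remains.
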
@@ -8,13 +8,15 @@ from loguru import logger
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.types.api import ImageEditsInternalParams
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
EventId,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
NodeDownloadProgress,
|
||||
NodeMemoryMeasured,
|
||||
NodePerformanceMeasured,
|
||||
@@ -23,12 +25,14 @@ from exo.shared.types.events import (
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
)
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
Shutdown,
|
||||
Task,
|
||||
TaskStatus,
|
||||
@@ -83,7 +87,7 @@ class Worker:
|
||||
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
|
||||
|
||||
self.state: State = State()
|
||||
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
|
||||
self.download_status: dict[ModelId, DownloadProgress] = {}
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
@@ -94,6 +98,10 @@ class Worker:
|
||||
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
|
||||
# Buffer for input image chunks (for image editing)
|
||||
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
|
||||
self.input_chunk_counts: dict[CommandId, int] = {}
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Worker")
|
||||
|
||||
@@ -128,6 +136,7 @@ class Worker:
|
||||
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
|
||||
|
||||
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._connection_message_event_writer)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
@@ -171,6 +180,17 @@ class Worker:
|
||||
for idx, event in indexed_events:
|
||||
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
|
||||
|
||||
# Buffer input image chunks for image editing
|
||||
if isinstance(event, InputChunkReceived):
|
||||
cmd_id = event.command_id
|
||||
if cmd_id not in self.input_chunk_buffer:
|
||||
self.input_chunk_buffer[cmd_id] = {}
|
||||
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
|
||||
|
||||
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
|
||||
event.chunk.data
|
||||
)
|
||||
|
||||
async def plan_step(self):
|
||||
while True:
|
||||
await anyio.sleep(0.1)
|
||||
@@ -183,6 +203,8 @@ class Worker:
|
||||
self.state.instances,
|
||||
self.state.runners,
|
||||
self.state.tasks,
|
||||
self.input_chunk_buffer,
|
||||
self.input_chunk_counts,
|
||||
)
|
||||
if task is None:
|
||||
continue
|
||||
@@ -200,11 +222,11 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case DownloadModel(shard_metadata=shard):
|
||||
if shard not in self.download_status:
|
||||
if shard.model_meta.model_id not in self.download_status:
|
||||
progress = DownloadPending(
|
||||
shard_metadata=shard, node_id=self.node_id
|
||||
)
|
||||
self.download_status[shard] = progress
|
||||
self.download_status[shard.model_meta.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
@@ -217,7 +239,7 @@ class Worker:
|
||||
progress = DownloadCompleted(
|
||||
shard_metadata=shard, node_id=self.node_id
|
||||
)
|
||||
self.download_status[shard] = progress
|
||||
self.download_status[shard.model_meta.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
@@ -244,6 +266,42 @@ class Worker:
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
chunks = self.input_chunk_buffer.get(cmd_id, {})
|
||||
assembled = "".join(chunks[i] for i in range(len(chunks)))
|
||||
logger.info(
|
||||
f"Assembled input image from {len(chunks)} chunks, "
|
||||
f"total size: {len(assembled)} bytes"
|
||||
)
|
||||
# Create modified task with assembled image data
|
||||
modified_task = ImageEdits(
|
||||
task_id=task.task_id,
|
||||
command_id=task.command_id,
|
||||
instance_id=task.instance_id,
|
||||
task_status=task.task_status,
|
||||
task_params=ImageEditsInternalParams(
|
||||
image_data=assembled,
|
||||
total_input_chunks=task.task_params.total_input_chunks,
|
||||
prompt=task.task_params.prompt,
|
||||
model=task.task_params.model,
|
||||
n=task.task_params.n,
|
||||
quality=task.task_params.quality,
|
||||
output_format=task.task_params.output_format,
|
||||
response_format=task.task_params.response_format,
|
||||
size=task.task_params.size,
|
||||
image_strength=task.task_params.image_strength,
|
||||
),
|
||||
)
|
||||
# Cleanup buffers
|
||||
if cmd_id in self.input_chunk_buffer:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
case task:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
|
||||
@@ -349,7 +407,7 @@ class Worker:
|
||||
initial_progress
|
||||
),
|
||||
)
|
||||
self.download_status[task.shard_metadata] = status
|
||||
self.download_status[task.shard_metadata.model_meta.model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
|
||||
last_progress_time = 0.0
|
||||
@@ -363,7 +421,7 @@ class Worker:
|
||||
nonlocal last_progress_time
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
|
||||
self.download_status[shard] = status
|
||||
self.download_status[shard.model_meta.model_id] = status
|
||||
# Footgun!
|
||||
self.event_sender.send_nowait(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
@@ -384,7 +442,7 @@ class Worker:
|
||||
progress
|
||||
),
|
||||
)
|
||||
self.download_status[shard] = status
|
||||
self.download_status[shard.model_meta.model_id] = status
|
||||
self.event_sender.send_nowait(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
@@ -444,3 +502,40 @@ class Worker:
|
||||
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
|
||||
|
||||
await anyio.sleep(10)
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
logger.info("Fetching and emitting existing download progress...")
|
||||
async for (
|
||||
_,
|
||||
progress,
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
self.download_status[progress.shard.model_meta.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
logger.info("Done emitting existing download progress.")
|
||||
await anyio.sleep(5 * 60) # 5 minutes
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting existing download progress: {e}")
|
||||
|
||||
@@ -2,12 +2,15 @@
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
@@ -34,7 +37,6 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
|
||||
@@ -43,12 +45,14 @@ def plan(
|
||||
# Runners is expected to be FRESH and so should not come from state
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
# DL_status is expected to be FRESH and so should not come from state
|
||||
download_status: Mapping[ShardMetadata, DownloadProgress],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
# gdls is not expected to be fresh
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus], # all global
|
||||
tasks: Mapping[TaskId, Task],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
input_chunk_counts: Mapping[CommandId, int] | None = None,
|
||||
) -> Task | None:
|
||||
# Python short circuiting OR logic should evaluate these sequentially.
|
||||
return (
|
||||
@@ -58,7 +62,7 @@ def plan(
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
|
||||
)
|
||||
|
||||
|
||||
@@ -111,13 +115,14 @@ def _create_runner(
|
||||
|
||||
def _model_needs_download(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
download_status: Mapping[ShardMetadata, DownloadProgress],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
) -> DownloadModel | None:
|
||||
for runner in runners.values():
|
||||
model_id = runner.bound_instance.bound_shard.model_meta.model_id
|
||||
if isinstance(runner.status, RunnerIdle) and (
|
||||
not isinstance(
|
||||
download_status.get(runner.bound_instance.bound_shard, None),
|
||||
(DownloadOngoing, DownloadCompleted),
|
||||
model_id not in download_status
|
||||
or not isinstance(
|
||||
download_status[model_id], (DownloadOngoing, DownloadCompleted)
|
||||
)
|
||||
):
|
||||
# We don't invalidate download_status randomly in case a file gets deleted on disk
|
||||
@@ -261,18 +266,38 @@ def _pending_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions
|
||||
if not isinstance(task, ChatCompletion):
|
||||
# TODO(ciaran): do this better!
|
||||
if (
|
||||
not isinstance(task, ChatCompletion)
|
||||
and not isinstance(task, ImageGeneration)
|
||||
and not isinstance(task, ImageEdits)
|
||||
):
|
||||
continue
|
||||
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
|
||||
continue
|
||||
|
||||
# For ImageEdits tasks, verify all input chunks have been received
|
||||
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
|
||||
cmd_id = task.command_id
|
||||
expected = task.task_params.total_input_chunks
|
||||
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
|
||||
if received < expected:
|
||||
continue # Wait for all chunks to arrive
|
||||
|
||||
for runner in runners.values():
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import base64
|
||||
import time
|
||||
|
||||
from exo.master.api import get_model_card
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.chunks import ImageChunk, TokenChunk
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -9,9 +12,12 @@ from exo.shared.types.events import (
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.models import ModelTask
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
@@ -21,6 +27,8 @@ from exo.shared.types.tasks import (
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
ImageGenerationResponse,
|
||||
PartialImageResponse,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
@@ -32,10 +40,18 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerShuttingDown,
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
from exo.worker.engines.image import (
|
||||
ImageGenerator,
|
||||
generate_image,
|
||||
initialize_image_model,
|
||||
warmup_image_generator,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
@@ -69,6 +85,10 @@ def main(
|
||||
sampler = None
|
||||
group = None
|
||||
|
||||
model_card = get_model_card(shard_metadata.model_meta.model_id)
|
||||
assert model_card
|
||||
model_tasks = model_card.tasks
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
event_sender.send(
|
||||
@@ -111,16 +131,26 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
model, tokenizer, sampler = load_mlx_items(
|
||||
bound_instance, group
|
||||
)
|
||||
# TODO(ciaran): switch
|
||||
if ModelTask.TextGeneration in model_tasks:
|
||||
model, tokenizer, sampler = load_mlx_items(
|
||||
bound_instance, group
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in model_tasks
|
||||
or ModelTask.ImageToImage in model_tasks
|
||||
):
|
||||
model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {model_card.tasks}"
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
assert tokenizer
|
||||
assert sampler
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
@@ -130,22 +160,40 @@ def main(
|
||||
)
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
sampler=sampler,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
if ModelTask.TextGeneration in model_tasks:
|
||||
assert model and isinstance(model, Model)
|
||||
assert tokenizer
|
||||
assert sampler
|
||||
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
sampler=sampler,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in model_tasks
|
||||
or ModelTask.ImageToImage in model_tasks
|
||||
):
|
||||
assert isinstance(model, ImageGenerator)
|
||||
image = warmup_image_generator(model=model)
|
||||
if image is not None:
|
||||
logger.info(
|
||||
f"warmed up by generating {image.size} image"
|
||||
)
|
||||
else:
|
||||
logger.info("warmup completed (non-primary node)")
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ChatCompletion(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert model
|
||||
assert model and isinstance(model, Model)
|
||||
assert tokenizer
|
||||
assert sampler
|
||||
logger.info(f"received chat request: {str(task)[:500]}")
|
||||
@@ -186,14 +234,171 @@ def main(
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case Shutdown():
|
||||
logger.info("runner shutting down")
|
||||
case ImageGeneration(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert isinstance(model, ImageGenerator)
|
||||
logger.info(
|
||||
f"received image generation request: {str(task)[:500]}"
|
||||
)
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
# Generate images using the image generation backend
|
||||
# Track image_index for final images only
|
||||
image_index = 0
|
||||
for response in generate_image(
|
||||
model=model,
|
||||
task=task_params,
|
||||
):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
encoded_data = base64.b64encode(
|
||||
response.image_data
|
||||
).decode("utf-8")
|
||||
# Split into chunks to stay under gossipsub 1MB limit
|
||||
data_chunks = [
|
||||
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(
|
||||
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
|
||||
)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}: {len(encoded_data)} bytes in {total_chunks} chunks"
|
||||
)
|
||||
for chunk_index, chunk_data in enumerate(
|
||||
data_chunks
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=chunk_index,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
image_index=response.partial_index,
|
||||
is_partial=True,
|
||||
partial_index=response.partial_index,
|
||||
total_partials=response.total_partials,
|
||||
),
|
||||
)
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
encoded_data = base64.b64encode(
|
||||
response.image_data
|
||||
).decode("utf-8")
|
||||
# Split into chunks to stay under gossipsub 1MB limit
|
||||
data_chunks = [
|
||||
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(
|
||||
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
|
||||
)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
logger.info(
|
||||
f"sending final ImageChunk: {len(encoded_data)} bytes in {total_chunks} chunks"
|
||||
)
|
||||
for chunk_index, chunk_data in enumerate(
|
||||
data_chunks
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=chunk_index,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
image_index=image_index,
|
||||
is_partial=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
image_index += 1
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert isinstance(model, ImageGenerator)
|
||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
image_index = 0
|
||||
for response in generate_image(
|
||||
model=model,
|
||||
task=task_params,
|
||||
):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
):
|
||||
match response:
|
||||
case ImageGenerationResponse():
|
||||
encoded_data = base64.b64encode(
|
||||
response.image_data
|
||||
).decode("utf-8")
|
||||
# Split into chunks to stay under gossipsub 1MB limit
|
||||
data_chunks = [
|
||||
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(
|
||||
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
|
||||
)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
logger.info(
|
||||
f"sending ImageChunk: {len(encoded_data)} bytes in {total_chunks} chunks"
|
||||
)
|
||||
for chunk_index, chunk_data in enumerate(
|
||||
data_chunks
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=chunk_index,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
image_index=image_index,
|
||||
is_partial=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
image_index += 1
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
current_status = RunnerShutdown()
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
@@ -208,9 +413,8 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
break
|
||||
except ClosedResourceError:
|
||||
logger.warning("runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
|
||||
@@ -14,13 +14,23 @@ from anyio import (
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated, TaskAcknowledged
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnecting,
|
||||
RunnerFailed,
|
||||
RunnerIdle,
|
||||
RunnerLoading,
|
||||
RunnerRunning,
|
||||
RunnerShuttingDown,
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
|
||||
@@ -39,10 +49,10 @@ class RunnerSupervisor:
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
# err_path: str
|
||||
_tg: TaskGroup | None = field(default=None, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -77,7 +87,6 @@ class RunnerSupervisor:
|
||||
_ev_recv=ev_recv,
|
||||
_task_sender=task_sender,
|
||||
_event_sender=event_sender,
|
||||
# err_path=err_path,
|
||||
)
|
||||
|
||||
return self
|
||||
@@ -118,6 +127,10 @@ class RunnerSupervisor:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.completed:
|
||||
logger.info(
|
||||
f"Skipping invalid task {task} as it has already been completed"
|
||||
)
|
||||
logger.info(f"Starting task {task}")
|
||||
event = anyio.Event()
|
||||
self.pending[task.task_id] = event
|
||||
@@ -138,6 +151,22 @@ class RunnerSupervisor:
|
||||
if isinstance(event, TaskAcknowledged):
|
||||
self.pending.pop(event.task_id).set()
|
||||
continue
|
||||
if (
|
||||
isinstance(event, TaskStatusUpdated)
|
||||
and event.task_status == TaskStatus.Complete
|
||||
):
|
||||
# If a task has just been completed, we should be working on it.
|
||||
assert isinstance(
|
||||
self.status,
|
||||
(
|
||||
RunnerRunning,
|
||||
RunnerWarmingUp,
|
||||
RunnerLoading,
|
||||
RunnerConnecting,
|
||||
RunnerShuttingDown,
|
||||
),
|
||||
)
|
||||
self.completed.add(event.task_id)
|
||||
await self._event_sender.send(event)
|
||||
except (ClosedResourceError, BrokenResourceError) as e:
|
||||
await self._check_runner(e)
|
||||
|
||||
@@ -9,9 +9,11 @@ MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")

NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
NODE_C: Final[NodeId] = NodeId("cccccccc-cccc-4ccc-8ccc-cccccccccccc")

RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
RUNNER_3_ID: Final[RunnerId] = RunnerId("Runner3")

INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")

@@ -1,11 +1,9 @@
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field

from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask
from exo.shared.types.tasks import BaseTask, TaskId
from exo.shared.types.worker.instances import (
    BoundInstance,
    Instance,
@@ -21,6 +19,7 @@ from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
class FakeRunnerSupervisor:
    bound_instance: BoundInstance
    status: RunnerStatus
    completed: set[TaskId] = field(default_factory=set)


class OtherTask(BaseTask):

@@ -1,5 +1,6 @@
|
||||
import exo.worker.plan as plan_mod
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import LoadModel
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
@@ -7,7 +8,6 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
RunnerIdle,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
@@ -46,7 +46,7 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
|
||||
all_runners = {RUNNER_1_ID: RunnerIdle()}
|
||||
|
||||
# No entry for this shard -> should trigger DownloadModel
|
||||
download_status: dict[ShardMetadata, DownloadProgress] = {}
|
||||
download_status: dict[ModelId, DownloadProgress] = {}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
@@ -94,7 +94,7 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
|
||||
|
||||
# Local node has already marked its shard as downloaded (not actually used by _load_model)
|
||||
local_download_status = {
|
||||
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
|
||||
}
|
||||
|
||||
# Global view has completed downloads for both nodes
|
||||
@@ -140,7 +140,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
|
||||
|
||||
# Local status claims the shard is downloaded already
|
||||
local_download_status = {
|
||||
shard: DownloadCompleted(shard_metadata=shard, node_id=NODE_A) # type: ignore[reportUnhashable]
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
|
||||
}
|
||||
|
||||
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
|
||||
@@ -192,7 +192,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
|
||||
# Only NODE_A's shard is recorded as downloaded globally
|
||||
local_download_status = {
|
||||
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
|
||||
}
|
||||
global_download_status = {
|
||||
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
|
||||
|
||||
@@ -12,8 +12,10 @@ from exo.worker.tests.constants import (
|
||||
MODEL_A_ID,
|
||||
NODE_A,
|
||||
NODE_B,
|
||||
NODE_C,
|
||||
RUNNER_1_ID,
|
||||
RUNNER_2_ID,
|
||||
RUNNER_3_ID,
|
||||
)
|
||||
from exo.worker.tests.unittests.conftest import (
|
||||
FakeRunnerSupervisor,
|
||||
@@ -24,37 +26,39 @@ from exo.worker.tests.unittests.conftest import (
|
||||
|
||||
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
|
||||
"""
|
||||
For non-final device_rank shards, StartWarmup should be emitted when all
|
||||
For non-zero device_rank shards, StartWarmup should be emitted when all
|
||||
shards in the instance are Loaded/WarmingUp.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=3)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=3)
|
||||
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=2, world_size=3)
|
||||
instance = get_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID, NODE_C: RUNNER_3_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1, RUNNER_3_ID: shard2},
|
||||
)
|
||||
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerLoaded()
|
||||
)
|
||||
|
||||
runners = {RUNNER_1_ID: local_runner}
|
||||
runners = {RUNNER_2_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerLoaded(),
|
||||
RUNNER_2_ID: RunnerLoaded(),
|
||||
RUNNER_3_ID: RunnerWarmingUp(),
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_B: []},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
@@ -150,9 +154,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
"""
|
||||
Rank-zero shard should not start warmup until all non-zero ranks are
|
||||
already WarmingUp.
|
||||
For accepting ranks (device_rank != world_size - 1), StartWarmup should be
|
||||
For accepting ranks (device_rank != 0), StartWarmup should be
|
||||
emitted when all shards in the instance are Loaded/WarmingUp.
|
||||
In a 2-node setup, rank 0 is the accepting rank.
|
||||
In a 2-node setup, rank 1 is the accepting rank.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
@@ -163,7 +167,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
|
||||
)
|
||||
|
||||
# Rank 0 is the accepting rank
|
||||
# Rank 1 is the accepting rank
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
@@ -188,6 +192,23 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerLoaded(),
|
||||
RUNNER_2_ID: RunnerWarmingUp(),
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert isinstance(result, StartWarmup)
|
||||
assert result.instance_id == INSTANCE_1_ID
|
||||
|
||||
@@ -280,9 +301,8 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
|
||||
|
||||
def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
|
||||
"""
|
||||
Connecting rank (device_rank == world_size - 1) should not start warmup
|
||||
Connecting rank (device_rank == 0) should not start warmup
|
||||
until all other ranks are already WarmingUp.
|
||||
In a 2-node setup, rank 1 is the connecting rank.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
@@ -295,13 +315,13 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
|
||||
|
||||
# Rank 1 is the connecting rank
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerLoaded()
|
||||
)
|
||||
|
||||
runners = {RUNNER_2_ID: local_runner}
|
||||
runners = {RUNNER_1_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerLoaded(),
|
||||
@@ -309,7 +329,7 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
|
||||
@@ -34,6 +34,7 @@ from exo.shared.types.worker.runners import (
    RunnerReady,
    RunnerRunning,
    RunnerShutdown,
    RunnerShuttingDown,
    RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
@@ -199,6 +200,9 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
        TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
        TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
        RunnerStatusUpdated(
            runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
        ),
        TaskStatusUpdated(
            task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
        ),
