mirror of https://github.com/exo-explore/exo.git
synced 2026-01-24 22:12:39 -05:00
Compare commits
1 commit
ciaran/ima...runner-can
| Author | SHA1 | Date |
|---|---|---|
| | ea593075d7 | |
@@ -16,7 +16,7 @@
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
-[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
+[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put the cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if the node is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections too, I guess? (See the sketch below.)
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does, but we can't upgrade as it breaks other things.)
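A minimal sketch of the two remove_node orderings described in the note above. Only remove_node and _node_id_to_rx_id_map are named in the note; every other name here is a hypothetical stand-in, not the actual exo implementation.

class Topology:
    def __init__(self) -> None:
        self._node_id_to_rx_id_map: dict[str, int] = {}
        self._connections: dict[str, set[str]] = {}

    def remove_node(self, node_id: str) -> None:
        # Variant under review: check the map first, so an unknown node
        # returns early and any stale connections it left are never cleaned.
        if node_id not in self._node_id_to_rx_id_map:
            return
        self._connections.pop(node_id, None)
        del self._node_id_to_rx_id_map[node_id]

    def remove_node_beta_1(self, node_id: str) -> None:
        # beta_1 ordering: drop connections first, then check the map, so a
        # node missing from the map still gets its stale connections removed.
        self._connections.pop(node_id, None)
        self._node_id_to_rx_id_map.pop(node_id, None)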
@@ -12,7 +12,6 @@
ttftMs,
tps,
totalTokens,
cancelRequest,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -606,15 +605,37 @@
style="min-height: 28px; max-height: 150px;"
></textarea>

{#if loading}
<button
type="button"
onclick={() => cancelRequest()}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/50 text-exo-light-gray border border-exo-medium-gray/50 hover:border-red-500/50 hover:text-red-400 cursor-pointer"
>
<button
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<span class="sm:hidden">...</span>
</span>
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
-class="w-3 h-3"
+class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
@@ -623,81 +644,47 @@
<path
stroke-linecap="round"
stroke-linejoin="round"
-d="M6 18L18 6M6 6l12 12"
+d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
-<span class="hidden sm:inline">CANCEL</span>
-<span class="sm:hidden">X</span>
+<span>EDIT</span>
</span>
</button>
{:else}
<button
type="submit"
disabled={!canSend || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
{/if}
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</div>

<!-- Bottom accent line -->

@@ -464,7 +464,6 @@ class AppStore {
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
private previousNodeIds: Set<string> = new Set();
-private activeAbortController: AbortController | null = null;

constructor() {
if (browser) {
@@ -1747,9 +1746,6 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;

-this.activeAbortController = new AbortController();
-const signal = this.activeAbortController.signal;
-
this.isLoading = true;
this.currentResponse = "";
this.ttftMs = null;
@@ -1884,7 +1880,6 @@ class AppStore {
temperature: 0.7,
stream: true,
}),
-signal,
});

if (!response.ok) {
@@ -1980,9 +1975,6 @@ class AppStore {
this.persistConversation(targetConversationId);
}
} catch (error) {
-if (signal.aborted) {
-return;
-}
console.error("Error sending message:", error);
this.handleStreamingError(
error,
@@ -1991,7 +1983,6 @@ class AppStore {
"Failed to get response",
);
} finally {
-this.activeAbortController = null;
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
@@ -2012,9 +2003,6 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;

-this.activeAbortController = new AbortController();
-const signal = this.activeAbortController.signal;
-
this.isLoading = true;
this.currentResponse = "";

@@ -2100,7 +2088,6 @@ class AppStore {
"Content-Type": "application/json",
},
body: JSON.stringify(requestBody),
-signal,
});

if (!response.ok) {
@@ -2210,19 +2197,6 @@ class AppStore {
},
);
} catch (error) {
-if (signal.aborted) {
-// Clean up the "Generating image..." message on cancellation
-this.updateConversationMessage(
-targetConversationId,
-assistantMessage.id,
-(msg) => {
-msg.content = "Cancelled";
-msg.attachments = [];
-},
-);
-this.syncActiveMessagesIfNeeded(targetConversationId);
-return;
-}
console.error("Error generating image:", error);
this.handleStreamingError(
error,
@@ -2231,7 +2205,6 @@ class AppStore {
"Failed to generate image",
);
} finally {
-this.activeAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
@@ -2255,9 +2228,6 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;

-this.activeAbortController = new AbortController();
-const signal = this.activeAbortController.signal;
-
this.isLoading = true;
this.currentResponse = "";

@@ -2366,7 +2336,6 @@ class AppStore {
const apiResponse = await fetch("/v1/images/edits", {
method: "POST",
body: formData,
-signal,
});

if (!apiResponse.ok) {
@@ -2438,19 +2407,6 @@ class AppStore {
},
);
} catch (error) {
-if (signal.aborted) {
-// Clean up the "Editing image..." message on cancellation
-this.updateConversationMessage(
-targetConversationId,
-assistantMessage.id,
-(msg) => {
-msg.content = "cancelled";
-msg.attachments = [];
-},
-);
-this.syncActiveMessagesIfNeeded(targetConversationId);
-return;
-}
console.error("Error editing image:", error);
this.handleStreamingError(
error,
@@ -2459,24 +2415,11 @@ class AppStore {
"Failed to edit image",
);
} finally {
-this.activeAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
}
-
-/**
-* Cancel an in-flight request by aborting the active fetch
-*/
-cancelRequest(): void {
-if (this.activeAbortController) {
-this.activeAbortController.abort();
-this.activeAbortController = null;
-}
-this.isLoading = false;
-this.currentResponse = "";
-}

/**
* Clear current chat and go back to welcome state
*/
@@ -2613,7 +2556,6 @@ export const editMessage = (messageId: string, newContent: string) =>
export const editAndRegenerate = (messageId: string, newContent: string) =>
appStore.editAndRegenerate(messageId, newContent);
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
-export const cancelRequest = () => appStore.cancelRequest();

// Conversation actions
export const conversations = () => appStore.conversations;

@@ -248,7 +248,7 @@ class Master:
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
-self.state.instances, placement
+self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case PlaceInstance():
@@ -260,7 +260,7 @@ class Master:
self.state.node_network,
)
transition_events = get_transition_events(
-self.state.instances, placement
+self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case CreateInstance():
@@ -270,7 +270,7 @@ class Master:
self.state.instances,
)
transition_events = get_transition_events(
-self.state.instances, placement
+self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
@@ -20,9 +20,15 @@ from exo.shared.types.commands import (
PlaceInstance,
)
from exo.shared.types.common import NodeId
-from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
+from exo.shared.types.events import (
+Event,
+InstanceCreated,
+InstanceDeleted,
+TaskStatusUpdated,
+)
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
+from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -180,6 +186,7 @@
def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],
+tasks: Mapping[TaskId, Task],
) -> Sequence[Event]:
events: list[Event] = []

@@ -195,6 +202,18 @@
# find instances to delete
for instance_id in current_instances:
if instance_id not in target_instances:
+for task in tasks.values():
+if task.instance_id == instance_id and task.task_status in [
+TaskStatus.Pending,
+TaskStatus.Running,
+]:
+events.append(
+TaskStatusUpdated(
+task_status=TaskStatus.Cancelled,
+task_id=task.task_id,
+)
+)
+
events.append(
InstanceDeleted(
instance_id=instance_id,
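For reference, a simplified, self-contained model of the new cancellation pass above. The real exo Task and event types are richer; everything here is a stand-in except the pending/running filter and the Cancelled transition, which mirror the hunk.

from dataclasses import dataclass
from enum import Enum, auto


class TaskStatus(Enum):
    Pending = auto()
    Running = auto()
    Completed = auto()
    Cancelled = auto()


@dataclass
class Task:
    task_id: str
    instance_id: str
    task_status: TaskStatus


def cancellation_events(deleted_instance_id: str, tasks: dict[str, Task]) -> list[tuple[str, TaskStatus]]:
    # Emit (task_id, Cancelled) for every in-flight task on the deleted instance.
    return [
        (t.task_id, TaskStatus.Cancelled)
        for t in tasks.values()
        if t.instance_id == deleted_instance_id
        and t.task_status in (TaskStatus.Pending, TaskStatus.Running)
    ]


tasks = {
    "t1": Task("t1", "inst-a", TaskStatus.Running),
    "t2": Task("t2", "inst-b", TaskStatus.Pending),
}
assert cancellation_events("inst-a", tasks) == [("t1", TaskStatus.Cancelled)]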
@@ -1,4 +1,4 @@
-from collections.abc import Callable, Generator
+from collections.abc import Generator
from pathlib import Path
from typing import Any, Literal, Optional

@@ -109,7 +109,6 @@ class DistributedImageModel:
image_path: Path | None = None,
partial_images: int = 0,
advanced_params: AdvancedImageParams | None = None,
-cancel_checker: Callable[[], bool] | None = None,
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
if (
advanced_params is not None
@@ -154,7 +153,6 @@ class DistributedImageModel:
guidance_override=guidance_override,
negative_prompt=negative_prompt,
num_sync_steps=num_sync_steps,
-cancel_checker=cancel_checker,
):
if isinstance(result, tuple):
# Partial image: (GeneratedImage, partial_index, total_partials)
@@ -3,7 +3,6 @@ import io
import random
import tempfile
import time
-from collections.abc import Callable
from pathlib import Path
from typing import Generator, Literal

@@ -69,18 +68,12 @@ def warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:
def generate_image(
model: DistributedImageModel,
task: ImageGenerationTaskParams | ImageEditsInternalParams,
-cancel_checker: Callable[[], bool] | None = None,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
"""Generate image(s), optionally yielding partial results.

When partial_images > 0 or stream=True, yields PartialImageResponse for
intermediate images, then ImageGenerationResponse for the final image.

-Args:
-model: The distributed image model to use for generation.
-task: The task parameters for image generation or editing.
-cancel_checker: Optional callback to check if generation should be cancelled.
-
Yields:
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
ImageGenerationResponse for final complete images
@@ -130,7 +123,6 @@ def generate_image(
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
-cancel_checker=cancel_checker,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
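A toy sketch of the yield contract documented in that docstring, with simplified stand-in response types (the real exo classes carry image payloads and metadata): intermediate PartialImageResponse values first, then one final ImageGenerationResponse.

from dataclasses import dataclass
from typing import Generator, Union


@dataclass
class PartialImageResponse:
    partial_index: int
    total_partials: int


@dataclass
class ImageGenerationResponse:
    image: bytes


def generate(partials: int) -> Generator[Union[PartialImageResponse, ImageGenerationResponse], None, None]:
    # Partials stream out as the diffusion loop progresses; the final image
    # is always the last item yielded.
    for i in range(partials):
        yield PartialImageResponse(partial_index=i, total_partials=partials)
    yield ImageGenerationResponse(image=b"...final image bytes...")


responses = list(generate(partials=2))
assert len(responses) == 3
assert isinstance(responses[-1], ImageGenerationResponse)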
@@ -1,4 +1,3 @@
-from collections.abc import Callable
from math import ceil
from typing import Any, Optional

@@ -95,8 +94,6 @@ class DiffusionRunner:
self.total_layers = config.total_blocks

self._guidance_override: float | None = None
-self._cancel_checker: Callable[[], bool] | None = None
-self._cancelling = False

self._compute_assigned_blocks()

@@ -151,54 +148,6 @@ class DiffusionRunner:
return self._guidance_override
return self.config.guidance_scale

-def _check_cancellation(self) -> bool:
-if self._cancelling:
-return True
-if (
-self.is_first_stage
-and self._cancel_checker is not None
-and self._cancel_checker()
-):
-self._cancelling = True
-return self._cancelling
-
-def _is_sentinel(self, tensor: mx.array) -> bool:
-return bool(mx.any(mx.isnan(tensor)).item())
-
-def _make_sentinel_like(self, tensor: mx.array) -> mx.array:
-return mx.full(tensor.shape, float("nan"), dtype=tensor.dtype)
-
-def _recv(
-self,
-shape: tuple[int, ...],
-dtype: mx.Dtype,
-src: int,
-) -> mx.array:
-"""Receive data and check for cancellation sentinel."""
-data = mx.distributed.recv(shape, dtype, src, group=self.group)
-mx.eval(data)
-if self._is_sentinel(data):
-self._cancelling = True
-return data
-
-def _recv_like(self, template: mx.array, src: int) -> mx.array:
-"""Receive data matching template and check for cancellation sentinel."""
-data = mx.distributed.recv_like(template, src=src, group=self.group)
-mx.eval(data)
-if self._is_sentinel(data):
-self._cancelling = True
-return data
-
-def _send(self, data: mx.array, dst: int) -> mx.array:
-"""Send data, or sentinel if cancelling."""
-
-if self._cancelling:
-data = self._make_sentinel_like(data)
-
-result = mx.distributed.send(data, dst, group=self.group)
-mx.async_eval(result)
-return result
-
def _ensure_wrappers(
self,
text_seq_len: int,
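The removed helpers signal cancellation in-band by sending a NaN-filled tensor in place of real activations, which downstream ranks detect with an isnan check. A minimal standalone sketch of that sentinel encoding (assumes mlx is installed; it only works for floating-point tensors, which is what the runner exchanges):

import mlx.core as mx


def make_sentinel_like(t: mx.array) -> mx.array:
    # A tensor of NaNs with the same shape/dtype stands in for "cancelled".
    return mx.full(t.shape, float("nan"), dtype=t.dtype)


def is_sentinel(t: mx.array) -> bool:
    # Any NaN marks the whole tensor as a cancellation sentinel.
    return bool(mx.any(mx.isnan(t)).item())


x = mx.zeros((2, 3), dtype=mx.float32)
assert not is_sentinel(x)
assert is_sentinel(make_sentinel_like(x))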
@@ -295,7 +244,6 @@ class DiffusionRunner:
guidance_override: float | None = None,
negative_prompt: str | None = None,
num_sync_steps: int = 1,
-cancel_checker: Callable[[], bool] | None = None,
):
"""Primary entry point for image generation.

@@ -307,21 +255,17 @@ class DiffusionRunner:
5. Decode to image

Args:
-runtime_config: Runtime configuration (steps, height, width)
+settings: Generation config (steps, height, width)
prompt: Text prompt
seed: Random seed
partial_images: Number of intermediate images to yield (0 for none)
guidance_override: Optional override for guidance scale (CFG)
negative_prompt: Optional negative prompt for CFG
num_sync_steps: Number of synchronous pipeline steps
-cancel_checker: Optional callback to check for cancellation

Yields:
Partial images as (GeneratedImage, partial_index, total_partials) tuples
Final GeneratedImage
"""
self._guidance_override = guidance_override
-self._cancel_checker = cancel_checker
latents = self.adapter.create_latents(seed, runtime_config)
prompt_data = self.adapter.encode_prompt(prompt, negative_prompt)

@@ -363,7 +307,7 @@ class DiffusionRunner:
except StopIteration as e:
latents = e.value  # pyright: ignore[reportAny]

-if self.is_last_stage and not self._cancelling:
+if self.is_last_stage:
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt)  # pyright: ignore[reportAny]

def _run_diffusion_loop(
@@ -379,7 +323,6 @@ class DiffusionRunner:
if capture_steps is None:
capture_steps = set()

-self._cancelling = False
self._reset_all_caches()

time_steps = tqdm(range(runtime_config.num_inference_steps))
@@ -402,9 +345,6 @@ class DiffusionRunner:
num_sync_steps=num_sync_steps,
)

-if self._cancelling:
-break
-
ctx.in_loop(  # pyright: ignore[reportAny]
t=t,
latents=latents,
@@ -626,8 +566,6 @@ class DiffusionRunner:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_hidden_states_mask)

-self._check_cancellation()
-
encoder_hidden_states: mx.array | None = None
if self.is_first_stage:
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -647,12 +585,19 @@ class DiffusionRunner:

if self.has_joint_blocks:
if not self.is_first_stage:
-hidden_states = self._recv(
-(batch_size, num_img_tokens, hidden_dim), dtype, self.prev_rank
+hidden_states = mx.distributed.recv(
+(batch_size, num_img_tokens, hidden_dim),
+dtype,
+self.prev_rank,
+group=self.group,
)
-encoder_hidden_states = self._recv(
-(batch_size, text_seq_len, hidden_dim), dtype, self.prev_rank
+encoder_hidden_states = mx.distributed.recv(
+(batch_size, text_seq_len, hidden_dim),
+dtype,
+self.prev_rank,
+group=self.group,
)
+mx.eval(hidden_states, encoder_hidden_states)

assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
@@ -674,20 +619,30 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
-concatenated = self._send(concatenated, self.next_rank)
+concatenated = mx.distributed.send(
+concatenated, self.next_rank, group=self.group
+)
+mx.async_eval(concatenated)

elif self.has_joint_blocks and not self.is_last_stage:
assert encoder_hidden_states is not None
-hidden_states = self._send(hidden_states, self.next_rank)
-encoder_hidden_states = self._send(encoder_hidden_states, self.next_rank)
+hidden_states = mx.distributed.send(
+hidden_states, self.next_rank, group=self.group
+)
+encoder_hidden_states = mx.distributed.send(
+encoder_hidden_states, self.next_rank, group=self.group
+)
+mx.async_eval(hidden_states, encoder_hidden_states)

if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
-hidden_states = self._recv(
+hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
+group=self.group,
)
+mx.eval(hidden_states)

assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
@@ -699,7 +654,10 @@ class DiffusionRunner:
)

if not self.is_last_stage:
-hidden_states = self._send(hidden_states, self.next_rank)
+hidden_states = mx.distributed.send(
+hidden_states, self.next_rank, group=self.group
+)
+mx.async_eval(hidden_states)

hidden_states = hidden_states[:, text_seq_len:, ...]

@@ -783,13 +741,14 @@ class DiffusionRunner:
)

if not self.is_first_stage:
-hidden_states = self._send(hidden_states, 0)
+hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
+mx.async_eval(hidden_states)

elif self.is_first_stage:
-hidden_states = self._recv_like(prev_latents, src=self.world_size - 1)
-
-if self._cancelling:
-return prev_latents
+hidden_states = mx.distributed.recv_like(
+prev_latents, src=self.world_size - 1, group=self.group
+)
+mx.eval(hidden_states)

else:
hidden_states = prev_latents
@@ -849,9 +808,10 @@ class DiffusionRunner:
and not self.is_last_stage
and not is_first_async_step
):
-patch = self._recv_like(patch, src=self.prev_rank)
-
-self._check_cancellation()
+patch = mx.distributed.recv_like(
+patch, src=self.prev_rank, group=self.group
+)
+mx.eval(patch)

step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch

@@ -881,11 +841,11 @@ class DiffusionRunner:
latents=prev_patch_latents[patch_idx],
)

# Ring send back to first stage (except on last timestep)
if not self.is_first_stage and t != config.num_inference_steps - 1:
-patch_latents[patch_idx] = self._send(
-patch_latents[patch_idx], self.next_rank
+patch_latents[patch_idx] = mx.distributed.send(
+patch_latents[patch_idx], self.next_rank, group=self.group
)
+mx.async_eval(patch_latents[patch_idx])

return mx.concatenate(patch_latents, axis=1)

@@ -924,16 +884,22 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
patch_len = patch.shape[1]
-patch = self._recv(
-(batch_size, patch_len, hidden_dim), patch.dtype, self.prev_rank
+patch = mx.distributed.recv(
+(batch_size, patch_len, hidden_dim),
+patch.dtype,
+self.prev_rank,
+group=self.group,
)
+mx.eval(patch)

if patch_idx == 0:
-encoder_hidden_states = self._recv(
+encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_rank,
+group=self.group,
)
+mx.eval(encoder_hidden_states)

if self.is_first_stage:
patch, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -958,25 +924,32 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
-patch_concat = self._send(patch_concat, self.next_rank)
+patch_concat = mx.distributed.send(
+patch_concat, self.next_rank, group=self.group
+)
+mx.async_eval(patch_concat)

elif self.has_joint_blocks and not self.is_last_stage:
-patch = self._send(patch, self.next_rank)
+patch = mx.distributed.send(patch, self.next_rank, group=self.group)
+mx.async_eval(patch)

if patch_idx == 0:
assert encoder_hidden_states is not None
-encoder_hidden_states = self._send(
-encoder_hidden_states, self.next_rank
+encoder_hidden_states = mx.distributed.send(
+encoder_hidden_states, self.next_rank, group=self.group
)
+mx.async_eval(encoder_hidden_states)

if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
patch_len = patch.shape[1]
-patch = self._recv(
+patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
+group=self.group,
)
+mx.eval(patch)

assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
@@ -988,7 +961,8 @@ class DiffusionRunner:
)

if not self.is_last_stage:
-patch = self._send(patch, self.next_rank)
+patch = mx.distributed.send(patch, self.next_rank, group=self.group)
+mx.async_eval(patch)

noise: mx.array | None = None
if self.is_last_stage:
@@ -522,16 +522,6 @@ def mx_any(bool_: bool, group: Group | None) -> bool:
return num_true.item() > 0


-def mx_all(bool_: bool, group: Group | None) -> bool:
-if group is None:
-return bool_
-num_true = mx.distributed.all_sum(
-mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
-)
-mx.eval(num_true)
-return num_true.item() == group.size()
-
-
def mx_barrier(group: Group | None):
if group is None:
return
@@ -278,10 +278,12 @@ def main(
tokenizer.tool_parser,  # pyright: ignore[reportAny]
)

-last_checked = time.perf_counter()
+cancel_every = 5
+tokens_since_last_cancel_check = 0
for response in mlx_generator:
-if (t := time.perf_counter()) - last_checked > 0.1:
-last_checked = t
+tokens_since_last_cancel_check += 1
+if tokens_since_last_cancel_check >= cancel_every:
+tokens_since_last_cancel_check = 0
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
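The new cadence polls the cancel queue once every cancel_every tokens instead of on a 100 ms wall-clock timer. A generic sketch of the pattern; poll_cancelled and the token iterator are stand-ins for cancel_receiver.collect() and the MLX generator, not the exo API.

def stream_with_cancel(token_iter, poll_cancelled, cancel_every: int = 5):
    cancelled = False
    tokens_since_check = 0
    for token in token_iter:
        tokens_since_check += 1
        if tokens_since_check >= cancel_every:
            tokens_since_check = 0
            cancelled = poll_cancelled()  # cheap poll, amortized over N tokens
        if cancelled:
            break
        yield token


# Usage: the flag is checked once per 5 tokens rather than per timer tick.
out = list(stream_with_cancel(iter(range(12)), lambda: False))
assert out == list(range(12))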
@@ -352,16 +354,11 @@ def main(

current_status = RunnerReady()
logger.info("runner ready")
-case ImageGeneration() | ImageEdits() if isinstance(
-current_status, RunnerReady
-):
+case ImageGeneration(
+task_params=task_params, command_id=command_id
+) if isinstance(current_status, RunnerReady):
assert image_model
-task_name = (
-"image generation"
-if isinstance(task, ImageGeneration)
-else "image edits"
-)
-logger.info(f"received {task_name} request: {str(task)[:500]}")
+logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
@@ -371,19 +368,104 @@ def main(
)

try:
-_run_image_task(
-task=task,
-image_model=image_model,
-shard_metadata=shard_metadata,
-event_sender=event_sender,
-cancel_receiver=cancel_receiver,
-cancelled_tasks=cancelled_tasks,
-)
+# Generate images using the image generation backend
+# Track image_index for final images only
+image_index = 0
+for response in generate_image(
+model=image_model, task=task_params
+):
+if (
+shard_metadata.device_rank
+== shard_metadata.world_size - 1
+):
+match response:
+case PartialImageResponse():
+logger.info(
+f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
+)
+_process_image_response(
+response,
+command_id,
+shard_metadata,
+event_sender,
+image_index,
+)
+case ImageGenerationResponse():
+logger.info("sending final ImageChunk")
+_process_image_response(
+response,
+command_id,
+shard_metadata,
+event_sender,
+image_index,
+)
+image_index += 1
# can we make this more explicit?
except Exception as e:
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
-command_id=task.command_id,
+command_id=command_id,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
raise

current_status = RunnerReady()
logger.info("runner ready")
+case ImageEdits(task_params=task_params, command_id=command_id) if (
+isinstance(current_status, RunnerReady)
+):
+assert image_model
+logger.info(f"received image edits request: {str(task)[:500]}")
+current_status = RunnerRunning()
+logger.info("runner running")
+event_sender.send(
+RunnerStatusUpdated(
+runner_id=runner_id, runner_status=current_status
+)
+)
+
+try:
+image_index = 0
+for response in generate_image(
+model=image_model, task=task_params
+):
+if (
+shard_metadata.device_rank
+== shard_metadata.world_size - 1
+):
+match response:
+case PartialImageResponse():
+logger.info(
+f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
+)
+_process_image_response(
+response,
+command_id,
+shard_metadata,
+event_sender,
+image_index,
+)
+case ImageGenerationResponse():
+logger.info("sending final ImageChunk")
+_process_image_response(
+response,
+command_id,
+shard_metadata,
+event_sender,
+image_index,
+)
+image_index += 1
+except Exception as e:
+if shard_metadata.device_rank == shard_metadata.world_size - 1:
+event_sender.send(
+ChunkGenerated(
+command_id=command_id,
+chunk=ErrorChunk(
+model=shard_metadata.model_card.model_id,
+finish_reason="error",
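The reworked dispatch leans on structural pattern matching with a guard, binding task_params and command_id directly in the case pattern. A minimal sketch with stand-in types (the real ImageGeneration/RunnerReady classes live in exo.shared):

from dataclasses import dataclass


@dataclass
class ImageGeneration:
    task_params: dict
    command_id: str


class RunnerReady: ...


class RunnerRunning: ...


def dispatch(task: object, status: object) -> str:
    match task:
        # Keyword patterns destructure the dataclass; the guard keeps busy
        # runners from picking up new work.
        case ImageGeneration(command_id=cid) if isinstance(status, RunnerReady):
            return f"run image generation for {cid}"
        case _:
            return "ignore (busy, or not an image task)"


print(dispatch(ImageGeneration({}, "c1"), RunnerReady()))    # dispatched
print(dispatch(ImageGeneration({}, "c2"), RunnerRunning()))  # ignored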
@@ -524,54 +606,6 @@ def parse_thinking_models(
yield response


-def _run_image_task(
-task: ImageGeneration | ImageEdits,
-image_model: DistributedImageModel,
-shard_metadata: ShardMetadata,
-event_sender: MpSender[Event],
-cancel_receiver: MpReceiver[TaskId],
-cancelled_tasks: set[TaskId],
-) -> None:
-task_id = task.task_id
-command_id = task.command_id
-
-def check_cancelled(task_id: TaskId = task_id) -> bool:
-cancelled_tasks.update(cancel_receiver.collect())
-return (task_id in cancelled_tasks) or (
-TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
-)
-
-image_index = 0
-for response in generate_image(
-model=image_model,
-task=task.task_params,
-cancel_checker=check_cancelled,
-):
-if shard_metadata.device_rank == shard_metadata.world_size - 1:
-match response:
-case PartialImageResponse():
-logger.info(
-f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
-)
-_process_image_response(
-response,
-command_id,
-shard_metadata,
-event_sender,
-image_index,
-)
-case ImageGenerationResponse():
-logger.info("sending final ImageChunk")
-_process_image_response(
-response,
-command_id,
-shard_metadata,
-event_sender,
-image_index,
-)
-image_index += 1
-
-
def _send_image_chunk(
encoded_data: str,
command_id: CommandId,