Compare commits

..

1 Commits

Author SHA1 Message Date
Evan
ea593075d7 api cancellation
closing the http request to the api now
- sends a cancellation from the api
- writes that canellation in the master
- worker plans off the cancellation
- runner observes that cancellation after every generation step (+1
communication per token)
- cancellation happens synchronously to prevent gpu locks
2026-01-24 21:50:50 +00:00
10 changed files with 262 additions and 326 deletions

View File

@@ -16,7 +16,7 @@
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)

View File

@@ -12,7 +12,6 @@
ttftMs,
tps,
totalTokens,
cancelRequest,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -606,15 +605,37 @@
style="min-height: 28px; max-height: 150px;"
></textarea>
{#if loading}
<button
type="button"
onclick={() => cancelRequest()}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/50 text-exo-light-gray border border-exo-medium-gray/50 hover:border-red-500/50 hover:text-red-400 cursor-pointer"
>
<button
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<span class="sm:hidden">...</span>
</span>
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3 h-3"
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
@@ -623,81 +644,47 @@
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M6 18L18 6M6 6l12 12"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span class="hidden sm:inline">CANCEL</span>
<span class="sm:hidden">X</span>
<span>EDIT</span>
</span>
</button>
{:else}
<button
type="submit"
disabled={!canSend || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
{/if}
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</div>
<!-- Bottom accent line -->

View File

@@ -464,7 +464,6 @@ class AppStore {
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
private previousNodeIds: Set<string> = new Set();
private activeAbortController: AbortController | null = null;
constructor() {
if (browser) {
@@ -1747,9 +1746,6 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;
this.activeAbortController = new AbortController();
const signal = this.activeAbortController.signal;
this.isLoading = true;
this.currentResponse = "";
this.ttftMs = null;
@@ -1884,7 +1880,6 @@ class AppStore {
temperature: 0.7,
stream: true,
}),
signal,
});
if (!response.ok) {
@@ -1980,9 +1975,6 @@ class AppStore {
this.persistConversation(targetConversationId);
}
} catch (error) {
if (signal.aborted) {
return;
}
console.error("Error sending message:", error);
this.handleStreamingError(
error,
@@ -1991,7 +1983,6 @@ class AppStore {
"Failed to get response",
);
} finally {
this.activeAbortController = null;
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
@@ -2012,9 +2003,6 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;
this.activeAbortController = new AbortController();
const signal = this.activeAbortController.signal;
this.isLoading = true;
this.currentResponse = "";
@@ -2100,7 +2088,6 @@ class AppStore {
"Content-Type": "application/json",
},
body: JSON.stringify(requestBody),
signal,
});
if (!response.ok) {
@@ -2210,19 +2197,6 @@ class AppStore {
},
);
} catch (error) {
if (signal.aborted) {
// Clean up the "Generating image..." message on cancellation
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "Cancelled";
msg.attachments = [];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
return;
}
console.error("Error generating image:", error);
this.handleStreamingError(
error,
@@ -2231,7 +2205,6 @@ class AppStore {
"Failed to generate image",
);
} finally {
this.activeAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
@@ -2255,9 +2228,6 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;
this.activeAbortController = new AbortController();
const signal = this.activeAbortController.signal;
this.isLoading = true;
this.currentResponse = "";
@@ -2366,7 +2336,6 @@ class AppStore {
const apiResponse = await fetch("/v1/images/edits", {
method: "POST",
body: formData,
signal,
});
if (!apiResponse.ok) {
@@ -2438,19 +2407,6 @@ class AppStore {
},
);
} catch (error) {
if (signal.aborted) {
// Clean up the "Editing image..." message on cancellation
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "cancelled";
msg.attachments = [];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
return;
}
console.error("Error editing image:", error);
this.handleStreamingError(
error,
@@ -2459,24 +2415,11 @@ class AppStore {
"Failed to edit image",
);
} finally {
this.activeAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
}
/**
* Cancel an in-flight request by aborting the active fetch
*/
cancelRequest(): void {
if (this.activeAbortController) {
this.activeAbortController.abort();
this.activeAbortController = null;
}
this.isLoading = false;
this.currentResponse = "";
}
/**
* Clear current chat and go back to welcome state
*/
@@ -2613,7 +2556,6 @@ export const editMessage = (messageId: string, newContent: string) =>
export const editAndRegenerate = (messageId: string, newContent: string) =>
appStore.editAndRegenerate(messageId, newContent);
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
export const cancelRequest = () => appStore.cancelRequest();
// Conversation actions
export const conversations = () => appStore.conversations;

View File

@@ -248,7 +248,7 @@ class Master:
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case PlaceInstance():
@@ -260,7 +260,7 @@ class Master:
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case CreateInstance():
@@ -270,7 +270,7 @@ class Master:
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):

View File

@@ -20,9 +20,15 @@ from exo.shared.types.commands import (
PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.events import (
Event,
InstanceCreated,
InstanceDeleted,
TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -180,6 +186,7 @@ def delete_instance(
def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],
tasks: Mapping[TaskId, Task],
) -> Sequence[Event]:
events: list[Event] = []
@@ -195,6 +202,18 @@ def get_transition_events(
# find instances to delete
for instance_id in current_instances:
if instance_id not in target_instances:
for task in tasks.values():
if task.instance_id == instance_id and task.task_status in [
TaskStatus.Pending,
TaskStatus.Running,
]:
events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task.task_id,
)
)
events.append(
InstanceDeleted(
instance_id=instance_id,

View File

@@ -1,4 +1,4 @@
from collections.abc import Callable, Generator
from collections.abc import Generator
from pathlib import Path
from typing import Any, Literal, Optional
@@ -109,7 +109,6 @@ class DistributedImageModel:
image_path: Path | None = None,
partial_images: int = 0,
advanced_params: AdvancedImageParams | None = None,
cancel_checker: Callable[[], bool] | None = None,
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
if (
advanced_params is not None
@@ -154,7 +153,6 @@ class DistributedImageModel:
guidance_override=guidance_override,
negative_prompt=negative_prompt,
num_sync_steps=num_sync_steps,
cancel_checker=cancel_checker,
):
if isinstance(result, tuple):
# Partial image: (GeneratedImage, partial_index, total_partials)

View File

@@ -3,7 +3,6 @@ import io
import random
import tempfile
import time
from collections.abc import Callable
from pathlib import Path
from typing import Generator, Literal
@@ -69,18 +68,12 @@ def warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:
def generate_image(
model: DistributedImageModel,
task: ImageGenerationTaskParams | ImageEditsInternalParams,
cancel_checker: Callable[[], bool] | None = None,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
"""Generate image(s), optionally yielding partial results.
When partial_images > 0 or stream=True, yields PartialImageResponse for
intermediate images, then ImageGenerationResponse for the final image.
Args:
model: The distributed image model to use for generation.
task: The task parameters for image generation or editing.
cancel_checker: Optional callback to check if generation should be cancelled.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
ImageGenerationResponse for final complete images
@@ -130,7 +123,6 @@ def generate_image(
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
cancel_checker=cancel_checker,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)

View File

@@ -1,4 +1,3 @@
from collections.abc import Callable
from math import ceil
from typing import Any, Optional
@@ -95,8 +94,6 @@ class DiffusionRunner:
self.total_layers = config.total_blocks
self._guidance_override: float | None = None
self._cancel_checker: Callable[[], bool] | None = None
self._cancelling = False
self._compute_assigned_blocks()
@@ -151,54 +148,6 @@ class DiffusionRunner:
return self._guidance_override
return self.config.guidance_scale
def _check_cancellation(self) -> bool:
if self._cancelling:
return True
if (
self.is_first_stage
and self._cancel_checker is not None
and self._cancel_checker()
):
self._cancelling = True
return self._cancelling
def _is_sentinel(self, tensor: mx.array) -> bool:
return bool(mx.any(mx.isnan(tensor)).item())
def _make_sentinel_like(self, tensor: mx.array) -> mx.array:
return mx.full(tensor.shape, float("nan"), dtype=tensor.dtype)
def _recv(
self,
shape: tuple[int, ...],
dtype: mx.Dtype,
src: int,
) -> mx.array:
"""Receive data and check for cancellation sentinel."""
data = mx.distributed.recv(shape, dtype, src, group=self.group)
mx.eval(data)
if self._is_sentinel(data):
self._cancelling = True
return data
def _recv_like(self, template: mx.array, src: int) -> mx.array:
"""Receive data matching template and check for cancellation sentinel."""
data = mx.distributed.recv_like(template, src=src, group=self.group)
mx.eval(data)
if self._is_sentinel(data):
self._cancelling = True
return data
def _send(self, data: mx.array, dst: int) -> mx.array:
"""Send data, or sentinel if cancelling."""
if self._cancelling:
data = self._make_sentinel_like(data)
result = mx.distributed.send(data, dst, group=self.group)
mx.async_eval(result)
return result
def _ensure_wrappers(
self,
text_seq_len: int,
@@ -295,7 +244,6 @@ class DiffusionRunner:
guidance_override: float | None = None,
negative_prompt: str | None = None,
num_sync_steps: int = 1,
cancel_checker: Callable[[], bool] | None = None,
):
"""Primary entry point for image generation.
@@ -307,21 +255,17 @@ class DiffusionRunner:
5. Decode to image
Args:
runtime_config: Runtime configuration (steps, height, width)
settings: Generation config (steps, height, width)
prompt: Text prompt
seed: Random seed
partial_images: Number of intermediate images to yield (0 for none)
guidance_override: Optional override for guidance scale (CFG)
negative_prompt: Optional negative prompt for CFG
num_sync_steps: Number of synchronous pipeline steps
cancel_checker: Optional callback to check for cancellation
Yields:
Partial images as (GeneratedImage, partial_index, total_partials) tuples
Final GeneratedImage
"""
self._guidance_override = guidance_override
self._cancel_checker = cancel_checker
latents = self.adapter.create_latents(seed, runtime_config)
prompt_data = self.adapter.encode_prompt(prompt, negative_prompt)
@@ -363,7 +307,7 @@ class DiffusionRunner:
except StopIteration as e:
latents = e.value # pyright: ignore[reportAny]
if self.is_last_stage and not self._cancelling:
if self.is_last_stage:
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt) # pyright: ignore[reportAny]
def _run_diffusion_loop(
@@ -379,7 +323,6 @@ class DiffusionRunner:
if capture_steps is None:
capture_steps = set()
self._cancelling = False
self._reset_all_caches()
time_steps = tqdm(range(runtime_config.num_inference_steps))
@@ -402,9 +345,6 @@ class DiffusionRunner:
num_sync_steps=num_sync_steps,
)
if self._cancelling:
break
ctx.in_loop( # pyright: ignore[reportAny]
t=t,
latents=latents,
@@ -626,8 +566,6 @@ class DiffusionRunner:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_hidden_states_mask)
self._check_cancellation()
encoder_hidden_states: mx.array | None = None
if self.is_first_stage:
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -647,12 +585,19 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
hidden_states = self._recv(
(batch_size, num_img_tokens, hidden_dim), dtype, self.prev_rank
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = self._recv(
(batch_size, text_seq_len, hidden_dim), dtype, self.prev_rank
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
@@ -674,20 +619,30 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
concatenated = self._send(concatenated, self.next_rank)
concatenated = mx.distributed.send(
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
elif self.has_joint_blocks and not self.is_last_stage:
assert encoder_hidden_states is not None
hidden_states = self._send(hidden_states, self.next_rank)
encoder_hidden_states = self._send(encoder_hidden_states, self.next_rank)
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
hidden_states = self._recv(
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
@@ -699,7 +654,10 @@ class DiffusionRunner:
)
if not self.is_last_stage:
hidden_states = self._send(hidden_states, self.next_rank)
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
hidden_states = hidden_states[:, text_seq_len:, ...]
@@ -783,13 +741,14 @@ class DiffusionRunner:
)
if not self.is_first_stage:
hidden_states = self._send(hidden_states, 0)
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
mx.async_eval(hidden_states)
elif self.is_first_stage:
hidden_states = self._recv_like(prev_latents, src=self.world_size - 1)
if self._cancelling:
return prev_latents
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
else:
hidden_states = prev_latents
@@ -849,9 +808,10 @@ class DiffusionRunner:
and not self.is_last_stage
and not is_first_async_step
):
patch = self._recv_like(patch, src=self.prev_rank)
self._check_cancellation()
patch = mx.distributed.recv_like(
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
@@ -881,11 +841,11 @@ class DiffusionRunner:
latents=prev_patch_latents[patch_idx],
)
# Ring send back to first stage (except on last timestep)
if not self.is_first_stage and t != config.num_inference_steps - 1:
patch_latents[patch_idx] = self._send(
patch_latents[patch_idx], self.next_rank
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
return mx.concatenate(patch_latents, axis=1)
@@ -924,16 +884,22 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
patch_len = patch.shape[1]
patch = self._recv(
(batch_size, patch_len, hidden_dim), patch.dtype, self.prev_rank
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
if patch_idx == 0:
encoder_hidden_states = self._recv(
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
if self.is_first_stage:
patch, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -958,25 +924,32 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
patch_concat = self._send(patch_concat, self.next_rank)
patch_concat = mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
elif self.has_joint_blocks and not self.is_last_stage:
patch = self._send(patch, self.next_rank)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
if patch_idx == 0:
assert encoder_hidden_states is not None
encoder_hidden_states = self._send(
encoder_hidden_states, self.next_rank
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
patch_len = patch.shape[1]
patch = self._recv(
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
@@ -988,7 +961,8 @@ class DiffusionRunner:
)
if not self.is_last_stage:
patch = self._send(patch, self.next_rank)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
noise: mx.array | None = None
if self.is_last_stage:

View File

@@ -522,16 +522,6 @@ def mx_any(bool_: bool, group: Group | None) -> bool:
return num_true.item() > 0
def mx_all(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() == group.size()
def mx_barrier(group: Group | None):
if group is None:
return

View File

@@ -278,10 +278,12 @@ def main(
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
last_checked = time.perf_counter()
cancel_every = 5
tokens_since_last_cancel_check = 0
for response in mlx_generator:
if (t := time.perf_counter()) - last_checked > 0.1:
last_checked = t
tokens_since_last_cancel_check += 1
if tokens_since_last_cancel_check >= cancel_every:
tokens_since_last_cancel_check = 0
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
@@ -352,16 +354,11 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
case ImageGeneration() | ImageEdits() if isinstance(
current_status, RunnerReady
):
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert image_model
task_name = (
"image generation"
if isinstance(task, ImageGeneration)
else "image edits"
)
logger.info(f"received {task_name} request: {str(task)[:500]}")
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
@@ -371,19 +368,104 @@ def main(
)
try:
_run_image_task(
task=task,
image_model=image_model,
shard_metadata=shard_metadata,
event_sender=event_sender,
cancel_receiver=cancel_receiver,
cancelled_tasks=cancelled_tasks,
)
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
# can we make this more explicit?
except Exception as e:
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=task.command_id,
command_id=command_id,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
raise
current_status = RunnerReady()
logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert image_model
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
try:
image_index = 0
for response in generate_image(
model=image_model, task=task_params
):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
except Exception as e:
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
finish_reason="error",
@@ -524,54 +606,6 @@ def parse_thinking_models(
yield response
def _run_image_task(
task: ImageGeneration | ImageEdits,
image_model: DistributedImageModel,
shard_metadata: ShardMetadata,
event_sender: MpSender[Event],
cancel_receiver: MpReceiver[TaskId],
cancelled_tasks: set[TaskId],
) -> None:
task_id = task.task_id
command_id = task.command_id
def check_cancelled(task_id: TaskId = task_id) -> bool:
cancelled_tasks.update(cancel_receiver.collect())
return (task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
image_index = 0
for response in generate_image(
model=image_model,
task=task.task_params,
cancel_checker=check_cancelled,
):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
def _send_image_chunk(
encoded_data: str,
command_id: CommandId,