Compare commits

...

5 Commits

Author SHA1 Message Date
ciaranbor
fa80a51f70 Handle cancellation completion in dashboard 2026-01-24 20:09:11 +00:00
ciaranbor
278c02b200 Handle cancellation signal in diffusion runner 2026-01-24 20:06:20 +00:00
ciaranbor
ee31bd7f93 Refactor duplicate image generation and image editing runner logic. Add cancellation checker to inject into model inference 2026-01-24 19:51:40 +00:00
ciaranbor
95310bc3ae Add generation cancellation button to UI 2026-01-24 18:57:17 +00:00
Evan
ea61b59941 api cancellation
closing the http request to the api now
- sends a cancellation from the api
- writes that canellation in the master
- worker plans off the cancellation
- runner observes that cancellation after every generation step (+1
communication per token)
- cancellation happens synchronously to prevent gpu locks
2026-01-24 04:06:02 +00:00
17 changed files with 502 additions and 338 deletions

View File

@@ -5,16 +5,16 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in genreate.py mlx_generate at the end.
[X] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.

View File

@@ -12,6 +12,7 @@
ttftMs,
tps,
totalTokens,
cancelRequest,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -605,37 +606,15 @@
style="min-height: 28px; max-height: 150px;"
></textarea>
<button
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
{#if loading}
<button
type="button"
onclick={() => cancelRequest()}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/50 text-exo-light-gray border border-exo-medium-gray/50 hover:border-red-500/50 hover:text-red-400 cursor-pointer"
>
<span class="inline-flex items-center gap-1 sm:gap-2">
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<span class="sm:hidden">...</span>
</span>
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
@@ -644,47 +623,81 @@
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
<span>EDIT</span>
<span class="hidden sm:inline">CANCEL</span>
<span class="sm:hidden">X</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</button>
{:else}
<button
type="submit"
disabled={!canSend || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
{/if}
</div>
<!-- Bottom accent line -->

View File

@@ -464,6 +464,7 @@ class AppStore {
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
private previousNodeIds: Set<string> = new Set();
private activeAbortController: AbortController | null = null;
constructor() {
if (browser) {
@@ -1746,6 +1747,9 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;
this.activeAbortController = new AbortController();
const signal = this.activeAbortController.signal;
this.isLoading = true;
this.currentResponse = "";
this.ttftMs = null;
@@ -1880,6 +1884,7 @@ class AppStore {
temperature: 0.7,
stream: true,
}),
signal,
});
if (!response.ok) {
@@ -1975,6 +1980,9 @@ class AppStore {
this.persistConversation(targetConversationId);
}
} catch (error) {
if (signal.aborted) {
return;
}
console.error("Error sending message:", error);
this.handleStreamingError(
error,
@@ -1983,6 +1991,7 @@ class AppStore {
"Failed to get response",
);
} finally {
this.activeAbortController = null;
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
@@ -2003,6 +2012,9 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;
this.activeAbortController = new AbortController();
const signal = this.activeAbortController.signal;
this.isLoading = true;
this.currentResponse = "";
@@ -2088,6 +2100,7 @@ class AppStore {
"Content-Type": "application/json",
},
body: JSON.stringify(requestBody),
signal,
});
if (!response.ok) {
@@ -2197,6 +2210,19 @@ class AppStore {
},
);
} catch (error) {
if (signal.aborted) {
// Clean up the "Generating image..." message on cancellation
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "Cancelled";
msg.attachments = [];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
return;
}
console.error("Error generating image:", error);
this.handleStreamingError(
error,
@@ -2205,6 +2231,7 @@ class AppStore {
"Failed to generate image",
);
} finally {
this.activeAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
@@ -2228,6 +2255,9 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;
this.activeAbortController = new AbortController();
const signal = this.activeAbortController.signal;
this.isLoading = true;
this.currentResponse = "";
@@ -2336,6 +2366,7 @@ class AppStore {
const apiResponse = await fetch("/v1/images/edits", {
method: "POST",
body: formData,
signal,
});
if (!apiResponse.ok) {
@@ -2407,6 +2438,19 @@ class AppStore {
},
);
} catch (error) {
if (signal.aborted) {
// Clean up the "Editing image..." message on cancellation
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "cancelled";
msg.attachments = [];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
return;
}
console.error("Error editing image:", error);
this.handleStreamingError(
error,
@@ -2415,11 +2459,24 @@ class AppStore {
"Failed to edit image",
);
} finally {
this.activeAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
}
/**
* Cancel an in-flight request by aborting the active fetch
*/
cancelRequest(): void {
if (this.activeAbortController) {
this.activeAbortController.abort();
this.activeAbortController = null;
}
this.isLoading = false;
this.currentResponse = "";
}
/**
* Clear current chat and go back to welcome state
*/
@@ -2556,6 +2613,7 @@ export const editMessage = (messageId: string, newContent: string) =>
export const editAndRegenerate = (messageId: string, newContent: string) =>
appStore.editAndRegenerate(messageId, newContent);
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
export const cancelRequest = () => appStore.cancelRequest();
// Conversation actions
export const conversations = () => appStore.conversations;

View File

@@ -88,6 +88,7 @@ from exo.shared.types.commands import (
PlaceInstance,
SendInputChunk,
StartDownload,
TaskCancelled,
TaskFinished,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
@@ -508,16 +509,14 @@ class API:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
"""
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._chat_completion_queues:
del self._chat_completion_queues[command_id]
@@ -901,6 +900,11 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -982,6 +986,11 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))

View File

@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskCancelled,
TaskFinished,
TestCommand,
)
@@ -35,6 +36,7 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -278,6 +280,18 @@ class Master:
chunk=chunk,
)
)
case TaskCancelled():
if (
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -286,10 +300,9 @@ class Master:
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
self.command_task_mapping.pop(
command.finished_command_id, None
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):

View File

@@ -48,6 +48,10 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -84,6 +88,7 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskCancelled
| TaskFinished
| SendInputChunk
)

View File

@@ -24,6 +24,7 @@ class TaskStatus(str, Enum):
Complete = "Complete"
TimedOut = "TimedOut"
Failed = "Failed"
Cancelled = "Cancelled"
class BaseTask(TaggedModel):
@@ -60,6 +61,10 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class CancelTask(BaseTask):
cancelled_task_id: TaskId
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -87,6 +92,7 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| CancelTask
| ImageGeneration
| ImageEdits
| Shutdown

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Callable, Generator
from pathlib import Path
from typing import Any, Literal, Optional
@@ -109,6 +109,7 @@ class DistributedImageModel:
image_path: Path | None = None,
partial_images: int = 0,
advanced_params: AdvancedImageParams | None = None,
cancel_checker: Callable[[], bool] | None = None,
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
if (
advanced_params is not None
@@ -153,6 +154,7 @@ class DistributedImageModel:
guidance_override=guidance_override,
negative_prompt=negative_prompt,
num_sync_steps=num_sync_steps,
cancel_checker=cancel_checker,
):
if isinstance(result, tuple):
# Partial image: (GeneratedImage, partial_index, total_partials)

View File

@@ -3,6 +3,7 @@ import io
import random
import tempfile
import time
from collections.abc import Callable
from pathlib import Path
from typing import Generator, Literal
@@ -68,12 +69,18 @@ def warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:
def generate_image(
model: DistributedImageModel,
task: ImageGenerationTaskParams | ImageEditsInternalParams,
cancel_checker: Callable[[], bool] | None = None,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
"""Generate image(s), optionally yielding partial results.
When partial_images > 0 or stream=True, yields PartialImageResponse for
intermediate images, then ImageGenerationResponse for the final image.
Args:
model: The distributed image model to use for generation.
task: The task parameters for image generation or editing.
cancel_checker: Optional callback to check if generation should be cancelled.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
ImageGenerationResponse for final complete images
@@ -123,6 +130,7 @@ def generate_image(
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
cancel_checker=cancel_checker,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)

View File

@@ -1,3 +1,4 @@
from collections.abc import Callable
from math import ceil
from typing import Any, Optional
@@ -94,6 +95,8 @@ class DiffusionRunner:
self.total_layers = config.total_blocks
self._guidance_override: float | None = None
self._cancel_checker: Callable[[], bool] | None = None
self._cancelling = False
self._compute_assigned_blocks()
@@ -148,6 +151,54 @@ class DiffusionRunner:
return self._guidance_override
return self.config.guidance_scale
def _check_cancellation(self) -> bool:
if self._cancelling:
return True
if (
self.is_first_stage
and self._cancel_checker is not None
and self._cancel_checker()
):
self._cancelling = True
return self._cancelling
def _is_sentinel(self, tensor: mx.array) -> bool:
return bool(mx.any(mx.isnan(tensor)).item())
def _make_sentinel_like(self, tensor: mx.array) -> mx.array:
return mx.full(tensor.shape, float("nan"), dtype=tensor.dtype)
def _recv(
self,
shape: tuple[int, ...],
dtype: mx.Dtype,
src: int,
) -> mx.array:
"""Receive data and check for cancellation sentinel."""
data = mx.distributed.recv(shape, dtype, src, group=self.group)
mx.eval(data)
if self._is_sentinel(data):
self._cancelling = True
return data
def _recv_like(self, template: mx.array, src: int) -> mx.array:
"""Receive data matching template and check for cancellation sentinel."""
data = mx.distributed.recv_like(template, src=src, group=self.group)
mx.eval(data)
if self._is_sentinel(data):
self._cancelling = True
return data
def _send(self, data: mx.array, dst: int) -> mx.array:
"""Send data, or sentinel if cancelling."""
if self._cancelling:
data = self._make_sentinel_like(data)
result = mx.distributed.send(data, dst, group=self.group)
mx.async_eval(result)
return result
def _ensure_wrappers(
self,
text_seq_len: int,
@@ -244,6 +295,7 @@ class DiffusionRunner:
guidance_override: float | None = None,
negative_prompt: str | None = None,
num_sync_steps: int = 1,
cancel_checker: Callable[[], bool] | None = None,
):
"""Primary entry point for image generation.
@@ -255,17 +307,21 @@ class DiffusionRunner:
5. Decode to image
Args:
settings: Generation config (steps, height, width)
runtime_config: Runtime configuration (steps, height, width)
prompt: Text prompt
seed: Random seed
partial_images: Number of intermediate images to yield (0 for none)
guidance_override: Optional override for guidance scale (CFG)
negative_prompt: Optional negative prompt for CFG
num_sync_steps: Number of synchronous pipeline steps
cancel_checker: Optional callback to check for cancellation
Yields:
Partial images as (GeneratedImage, partial_index, total_partials) tuples
Final GeneratedImage
"""
self._guidance_override = guidance_override
self._cancel_checker = cancel_checker
latents = self.adapter.create_latents(seed, runtime_config)
prompt_data = self.adapter.encode_prompt(prompt, negative_prompt)
@@ -307,7 +363,7 @@ class DiffusionRunner:
except StopIteration as e:
latents = e.value # pyright: ignore[reportAny]
if self.is_last_stage:
if self.is_last_stage and not self._cancelling:
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt) # pyright: ignore[reportAny]
def _run_diffusion_loop(
@@ -323,6 +379,7 @@ class DiffusionRunner:
if capture_steps is None:
capture_steps = set()
self._cancelling = False
self._reset_all_caches()
time_steps = tqdm(range(runtime_config.num_inference_steps))
@@ -345,6 +402,9 @@ class DiffusionRunner:
num_sync_steps=num_sync_steps,
)
if self._cancelling:
break
ctx.in_loop( # pyright: ignore[reportAny]
t=t,
latents=latents,
@@ -566,6 +626,8 @@ class DiffusionRunner:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_hidden_states_mask)
self._check_cancellation()
encoder_hidden_states: mx.array | None = None
if self.is_first_stage:
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -585,19 +647,12 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
hidden_states = self._recv(
(batch_size, num_img_tokens, hidden_dim), dtype, self.prev_rank
)
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
encoder_hidden_states = self._recv(
(batch_size, text_seq_len, hidden_dim), dtype, self.prev_rank
)
mx.eval(hidden_states, encoder_hidden_states)
assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
@@ -619,30 +674,20 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
concatenated = mx.distributed.send(
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
concatenated = self._send(concatenated, self.next_rank)
elif self.has_joint_blocks and not self.is_last_stage:
assert encoder_hidden_states is not None
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
hidden_states = self._send(hidden_states, self.next_rank)
encoder_hidden_states = self._send(encoder_hidden_states, self.next_rank)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
hidden_states = mx.distributed.recv(
hidden_states = self._recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
@@ -654,10 +699,7 @@ class DiffusionRunner:
)
if not self.is_last_stage:
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
hidden_states = self._send(hidden_states, self.next_rank)
hidden_states = hidden_states[:, text_seq_len:, ...]
@@ -741,14 +783,13 @@ class DiffusionRunner:
)
if not self.is_first_stage:
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
mx.async_eval(hidden_states)
hidden_states = self._send(hidden_states, 0)
elif self.is_first_stage:
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
hidden_states = self._recv_like(prev_latents, src=self.world_size - 1)
if self._cancelling:
return prev_latents
else:
hidden_states = prev_latents
@@ -808,10 +849,9 @@ class DiffusionRunner:
and not self.is_last_stage
and not is_first_async_step
):
patch = mx.distributed.recv_like(
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
patch = self._recv_like(patch, src=self.prev_rank)
self._check_cancellation()
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
@@ -841,11 +881,11 @@ class DiffusionRunner:
latents=prev_patch_latents[patch_idx],
)
# Ring send back to first stage (except on last timestep)
if not self.is_first_stage and t != config.num_inference_steps - 1:
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx], self.next_rank, group=self.group
patch_latents[patch_idx] = self._send(
patch_latents[patch_idx], self.next_rank
)
mx.async_eval(patch_latents[patch_idx])
return mx.concatenate(patch_latents, axis=1)
@@ -884,22 +924,16 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
patch = self._recv(
(batch_size, patch_len, hidden_dim), patch.dtype, self.prev_rank
)
mx.eval(patch)
if patch_idx == 0:
encoder_hidden_states = mx.distributed.recv(
encoder_hidden_states = self._recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
if self.is_first_stage:
patch, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -924,32 +958,25 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
patch_concat = mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
patch_concat = self._send(patch_concat, self.next_rank)
elif self.has_joint_blocks and not self.is_last_stage:
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
patch = self._send(patch, self.next_rank)
if patch_idx == 0:
assert encoder_hidden_states is not None
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
encoder_hidden_states = self._send(
encoder_hidden_states, self.next_rank
)
mx.async_eval(encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
patch_len = patch.shape[1]
patch = mx.distributed.recv(
patch = self._recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
@@ -961,8 +988,7 @@ class DiffusionRunner:
)
if not self.is_last_stage:
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
patch = self._send(patch, self.next_rank)
noise: mx.array | None = None
if self.is_last_stage:

View File

@@ -23,7 +23,6 @@ from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
make_kv_cache,
mx_barrier,
)
from exo.worker.runner.bootstrap import logger
@@ -90,10 +89,6 @@ def warmup_inference(
logger.info("Generated ALL warmup tokens")
# TODO: Do we want an mx_barrier?
# At least this version is actively incorrect, as it should use mx_barrier(group)
mx_barrier()
return tokens_generated
@@ -186,5 +181,3 @@ def mlx_generate(
if out.finish_reason is not None:
break
# TODO: Do we want an mx_barrier?

View File

@@ -70,8 +70,6 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -89,30 +87,6 @@ class ModelLoadingTimeoutError(Exception):
pass
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
stream=mx.default_stream(mx.Device(mx.cpu)),
group=group,
)
)
def broadcast_from_zero(value: int, group: Group | None = None):
if group is None:
return value
if group.rank() == 0:
a = mx.array([value], dtype=mx.int32)
else:
a = mx.array([0], dtype=mx.int32)
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
mx.eval(m)
return int(m.item())
class HostList(RootModel[list[str]]):
@classmethod
def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -536,3 +510,33 @@ def mlx_cleanup(
import gc
gc.collect()
def mx_any(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() > 0
def mx_all(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() == group.size()
def mx_barrier(group: Group | None):
if group is None:
return
mx.eval(
mx.distributed.all_sum(
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)

View File

@@ -33,6 +33,7 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -115,8 +116,9 @@ class Worker:
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
async with create_task_group() as tg:
for runner in self.runners.values():
tg.start_soon(runner.shutdown)
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
@@ -220,15 +222,22 @@ class Worker:
)
)
case Shutdown(runner_id=runner_id):
runner = self.runners.pop(runner_id)
try:
with fail_after(3):
await self.runners.pop(runner_id).start_task(task)
await runner.start_task(task)
except TimeoutError:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
finally:
await runner.shutdown()
case CancelTask(cancelled_task_id=cancelled_task_id):
await self.runners[self._task_to_runner_id(task)].cancel_task(
cancelled_task_id
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -351,8 +360,6 @@ class Worker:
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
## Op Executors
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(

View File

@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
CancelTask,
ChatCompletion,
ConnectToGroup,
CreateRunner,
@@ -59,7 +60,8 @@ def plan(
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
or _cancel_tasks(runners, tasks)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
)
@@ -270,7 +272,7 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
@@ -284,7 +286,7 @@ def _pending_tasks(
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
received = len(input_chunk_buffer.get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
@@ -292,16 +294,31 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
return task
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue
for runner in runners.values():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
if task.task_id in runner.cancelled:
continue
return CancelTask(
instance_id=task.instance_id, cancelled_task_id=task.task_id
)

View File

@@ -3,7 +3,7 @@ import os
import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,6 +15,7 @@ def entrypoint(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -38,7 +39,7 @@ def entrypoint(
try:
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver, cancel_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -37,6 +37,7 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
@@ -77,6 +78,7 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
mx_any,
)
from exo.worker.runner.bootstrap import logger
@@ -85,6 +87,7 @@ def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
instance, runner_id, shard_metadata = (
bound_instance.instance,
@@ -99,8 +102,11 @@ def main(
time.sleep(timeout)
setup_start_time = time.time()
cancelled_tasks = set[TaskId]()
model: Model | DistributedImageModel | None = None
# type checker was unhappy with me - splitting these fixed it
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
tokenizer = None
group = None
@@ -111,6 +117,7 @@ def main(
)
with task_receiver as tasks:
for task in tasks:
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -155,7 +162,7 @@ def main(
time.sleep(0.5)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
model, tokenizer = load_mlx_items(
inference_model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
@@ -165,7 +172,7 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
model = initialize_image_model(bound_instance)
image_model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
@@ -174,8 +181,6 @@ def main(
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -186,11 +191,11 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
assert not isinstance(model, DistributedImageModel)
assert inference_model
assert tokenizer
toks = warmup_inference(
model=model,
model=inference_model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
@@ -202,8 +207,8 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert isinstance(model, DistributedImageModel)
image = warmup_image_generator(model=model)
assert image_model
image = warmup_image_generator(model=image_model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
@@ -222,7 +227,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert model and not isinstance(model, DistributedImageModel)
assert inference_model
assert tokenizer
assert task_params.messages[0].content is not None
@@ -234,7 +239,7 @@ def main(
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=model,
model=inference_model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
@@ -257,11 +262,11 @@ def main(
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
elif isinstance(model, GptOssModel):
elif isinstance(inference_model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
model, GptOssModel
inference_model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
@@ -273,7 +278,17 @@ def main(
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
last_checked = time.perf_counter()
for response in mlx_generator:
if (t := time.perf_counter()) - last_checked > 0.1:
last_checked = t
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if mx_any(want_to_cancel, group):
break
match response:
case GenerationResponse():
if (
@@ -337,72 +352,16 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, DistributedImageModel)
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
try:
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(model=model, task=task_params):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
# can we make this more explicit?
except Exception as e:
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
raise
current_status = RunnerReady()
logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
case ImageGeneration() | ImageEdits() if isinstance(
current_status, RunnerReady
):
assert isinstance(model, DistributedImageModel)
logger.info(f"received image edits request: {str(task)[:500]}")
assert image_model
task_name = (
"image generation"
if isinstance(task, ImageGeneration)
else "image edits"
)
logger.info(f"received {task_name} request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
@@ -412,39 +371,19 @@ def main(
)
try:
image_index = 0
for response in generate_image(model=model, task=task_params):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
_run_image_task(
task=task,
image_model=image_model,
shard_metadata=shard_metadata,
event_sender=event_sender,
cancel_receiver=cancel_receiver,
cancelled_tasks=cancelled_tasks,
)
except Exception as e:
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
command_id=task.command_id,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
finish_reason="error",
@@ -476,7 +415,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
del inference_model, image_model, tokenizer, group
mx.clear_cache()
import gc
@@ -585,6 +524,54 @@ def parse_thinking_models(
yield response
def _run_image_task(
task: ImageGeneration | ImageEdits,
image_model: DistributedImageModel,
shard_metadata: ShardMetadata,
event_sender: MpSender[Event],
cancel_receiver: MpReceiver[TaskId],
cancelled_tasks: set[TaskId],
) -> None:
task_id = task.task_id
command_id = task.command_id
def check_cancelled(task_id: TaskId = task_id) -> bool:
cancelled_tasks.update(cancel_receiver.collect())
return (task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
image_index = 0
for response in generate_image(
model=image_model,
task=task.task_params,
cancel_checker=check_cancelled,
):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
def _send_image_chunk(
encoded_data: str,
command_id: CommandId,

View File

@@ -49,10 +49,12 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_tg: TaskGroup | None = field(default=None, init=False)
_cancel_sender: MpSender[TaskId]
_tg: TaskGroup = field(default_factory=create_task_group, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -63,8 +65,8 @@ class RunnerSupervisor:
initialize_timeout: float = 400,
) -> Self:
ev_send, ev_recv = mp_channel[Event]()
# A task is kind of a runner command
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
runner_process = Process(
target=entrypoint,
@@ -72,6 +74,7 @@ class RunnerSupervisor:
bound_instance,
ev_send,
task_recv,
cancel_recv,
logger,
),
daemon=True,
@@ -86,6 +89,7 @@ class RunnerSupervisor:
initialize_timeout=initialize_timeout,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
)
@@ -93,37 +97,41 @@ class RunnerSupervisor:
async def run(self):
self.runner_process.start()
async with create_task_group() as tg:
self._tg = tg
async with self._tg as tg:
tg.start_soon(self._forward_events)
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
await to_thread.run_sync(self.runner_process.join, 30)
if not self.runner_process.is_alive():
return
with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.close()
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
await to_thread.run_sync(self.runner_process.join, 10)
if not self.runner_process.is_alive():
return
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
def shutdown(self):
assert self._tg
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
async def shutdown(self):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
@@ -131,6 +139,7 @@ class RunnerSupervisor:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -140,7 +149,13 @@ class RunnerSupervisor:
logger.warning(f"Task {task} dropped, runner closed communication.")
return
await event.wait()
logger.info(f"Finished task {task}")
async def cancel_task(self, task_id: TaskId):
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
self.cancelled.add(task_id)
await self._cancel_sender.send_async(task_id)
async def _forward_events(self):
with self._ev_recv as events:
@@ -206,4 +221,4 @@ class RunnerSupervisor:
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
self.shutdown()
await self.shutdown()