Compare commits

...

3 Commits

Author                 SHA1        Message                                                            Date
Ryuichi Leo Takashige  1f2f3571cf  clarify todos                                                      2026-02-19 03:26:37 +00:00
Ryuichi Leo Takashige  efcb6b9393  Add TODOs                                                          2026-02-19 03:24:11 +00:00
Ryuichi Leo Takashige  72fd7f0e1f  Add cancellation button and the ability to cancel during prefill  2026-02-19 03:21:48 +00:00
4 changed files with 153 additions and 98 deletions

View File

@@ -14,6 +14,7 @@
totalTokens,
thinkingEnabled as thinkingEnabledStore,
setConversationThinking,
stopGeneration,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -653,86 +654,92 @@
style="min-height: 28px; max-height: 150px;"
></textarea>
<button
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
{#if loading}
<button
type="button"
onclick={() => stopGeneration()}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-exo-medium-gray hover:text-white"
aria-label="Stop generation"
>
<span class="inline-flex items-center gap-1 sm:gap-2">
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<span class="sm:hidden">...</span>
</span>
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
class="w-3 h-3 sm:w-3.5 sm:h-3.5"
fill="currentColor"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
<rect x="6" y="6" width="12" height="12" rx="1" />
</svg>
<span>EDIT</span>
<span class="hidden sm:inline">Cancel</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</button>
{:else}
<button
type="submit"
disabled={!canSend || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
{/if}
</div>
<!-- Bottom accent line -->

View File

@@ -527,6 +527,9 @@ class AppStore {
totalTokens = $state<number>(0); // Total tokens in current response
prefillProgress = $state<PrefillProgress | null>(null);
// Abort controller for stopping generation
private currentAbortController: AbortController | null = null;
// Topology state
topologyData = $state<TopologyData | null>(null);
instances = $state<Record<string, unknown>>({});
@@ -2281,6 +2284,9 @@ class AppStore {
let firstTokenTime: number | null = null;
let tokenCount = 0;
const abortController = new AbortController();
this.currentAbortController = abortController;
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: {
@@ -2297,6 +2303,7 @@ class AppStore {
enable_thinking: enableThinking,
}),
}),
signal: abortController.signal,
});
if (!response.ok) {
@@ -2451,20 +2458,31 @@ class AppStore {
this.persistConversation(targetConversationId);
}
} catch (error) {
console.error("Error sending message:", error);
this.handleStreamingError(
error,
targetConversationId,
assistantMessage.id,
"Failed to get response",
);
if (error instanceof DOMException && error.name === "AbortError") {
// User stopped generation — not an error
} else {
console.error("Error sending message:", error);
this.handleStreamingError(
error,
targetConversationId,
assistantMessage.id,
"Failed to get response",
);
}
} finally {
this.currentAbortController = null;
this.prefillProgress = null;
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
}
}
stopGeneration(): void {
this.currentAbortController?.abort();
this.currentAbortController = null;
}
/**
* Generate an image using the image generation API
*/
@@ -3109,6 +3127,7 @@ export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
// Actions
export const stopGeneration = () => appStore.stopGeneration();
export const startChat = () => appStore.startChat();
export const sendMessage = (
content: string,

View File

@@ -51,6 +51,10 @@ generation_stream = mx.new_stream(mx.default_device())
_MIN_PREFIX_HIT_RATIO_TO_UPDATE = 0.5
class PrefillCancelled(BaseException):
"""Raised when prefill is cancelled via the progress callback."""
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
@@ -66,7 +70,7 @@ def prefill(
then trims off the extra generated token.
Returns:
tokens_per_sec
(tokens_per_sec, num_tokens, snapshots)
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
@@ -77,6 +81,7 @@ def prefill(
has_ssm = has_non_kv_caches(cache)
snapshots: list[CacheSnapshot] = []
# TODO(evan): kill the callbacks/runner refactor
def progress_callback(processed: int, total: int) -> None:
elapsed = time.perf_counter() - start_time
tok_per_sec = processed / elapsed if elapsed > 0 else 0
@@ -96,19 +101,23 @@ def prefill(
# Use max_tokens=1 because max_tokens=0 does not work.
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=4096,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
try:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=4096,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
except PrefillCancelled:
set_pipeline_prefill(model, is_prefill=False)
raise
set_pipeline_prefill(model, is_prefill=False)
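
The change above hinges on one detail: the prompt-progress callback is the only code that runs between prefill chunks, so raising an exception from inside it is what actually interrupts `stream_generate`, and the `except PrefillCancelled` block restores pipeline state before re-raising. A minimal, self-contained sketch of that callback-exception pattern (the names `fake_stream_generate` and `prefill_with_cancel` are hypothetical stand-ins, not the real exo / mlx_lm APIs):

```python
# Sketch: cancel a chunked prefill by raising from its progress callback.
# fake_stream_generate is a hypothetical stand-in for mlx_lm's stream_generate.
from collections.abc import Callable, Iterator


class PrefillCancelled(BaseException):
    """Raised inside the progress callback to abort prefill between chunks."""


def fake_stream_generate(
    prompt_tokens: list[int],
    chunk_size: int,
    progress_callback: Callable[[int, int], None],
) -> Iterator[int]:
    """Process the prompt in chunks, reporting progress after each chunk,
    then yield a single throwaway token (mirroring max_tokens=1 above)."""
    total = len(prompt_tokens)
    for processed in range(chunk_size, total + chunk_size, chunk_size):
        progress_callback(min(processed, total), total)
    yield 0


def prefill_with_cancel(
    prompt_tokens: list[int], should_cancel: Callable[[], bool]
) -> bool:
    """Return True if prefill completed, False if it was cancelled."""

    def progress_callback(processed: int, total: int) -> None:
        print(f"prefill {processed}/{total}")
        if should_cancel():
            # Raising here unwinds fake_stream_generate at a chunk boundary.
            raise PrefillCancelled()

    try:
        for _ in fake_stream_generate(prompt_tokens, 4, progress_callback):
            break  # cache is (conceptually) filled after the first token
    except PrefillCancelled:
        # The real code resets pipeline prefill state here and re-raises;
        # this sketch just reports the cancellation to the caller.
        return False
    return True


if __name__ == "__main__":
    print(prefill_with_cancel(list(range(10)), lambda: False))  # True
    print(prefill_with_cancel(list(range(10)), lambda: True))   # False
```

Deriving the sentinel from `BaseException` rather than `Exception` appears deliberate, presumably so that broad `except Exception` handlers between the callback and the runner (such as the one in the last file below) cannot swallow the cancellation.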

View File

@@ -82,7 +82,11 @@ from exo.worker.engines.image import (
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.generator.generate import (
PrefillCancelled,
mlx_generate,
warmup_inference,
)
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
detect_thinking_prompt_suffix,
@@ -299,8 +303,16 @@ def main(
assert tokenizer
assert check_for_cancel_every
# Define callback to send prefill progress events directly
def on_prefill_progress(processed: int, total: int) -> None:
# Define callback to send prefill progress events
# and check for cancellation between prefill chunks.
# TODO(evan): kill the callbacks/runner refactor
# Specifically the part where this is literally duplicated code.
def on_prefill_progress(
processed: int,
total: int,
_task_id: TaskId = task.task_id,
_group: mx.distributed.Group | None = group,
) -> None:
if device_rank == 0:
event_sender.send(
PrefillProgress(
@@ -310,6 +322,12 @@ def main(
total_tokens=total,
)
)
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (_task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if mx_any(want_to_cancel, _group):
raise PrefillCancelled()
try:
_check_for_debug_prompts(task_params)
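
Note that the cancellation decision inside `on_prefill_progress` is collective: every rank drains its own cancel queue, but all ranks have to agree (via `mx_any`) before anyone raises `PrefillCancelled`, otherwise one pipeline rank could abandon prefill while its peers block waiting for it. The exo helper itself isn't shown in this diff; a rough sketch of what an `mx_any`-style agreement check could look like on top of `mx.distributed.all_sum` (an assumption, not the actual implementation):

```python
import mlx.core as mx


def any_rank_wants_cancel(local_flag: bool, group: mx.distributed.Group | None) -> bool:
    """Return True on every rank iff at least one rank passed True."""
    if group is None or group.size() == 1:
        return local_flag
    # Sum a 0/1 scalar across ranks; the result is identical on all ranks,
    # so they all raise PrefillCancelled (or not) at the same chunk boundary.
    total = mx.distributed.all_sum(mx.array(1 if local_flag else 0), group=group)
    return bool(total.item() > 0)
```
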
@@ -406,6 +424,8 @@ def main(
)
)
except PrefillCancelled:
logger.info(f"Prefill cancelled for task {task.task_id}")
# can we make this more explicit?
except Exception as e:
if device_rank == 0: