Compare commits

...

31 Commits

Author SHA1 Message Date
Alex Cheema
f94870f2c0 Merge remote-tracking branch 'origin/main' into alexcheema/prefill-progress
# Conflicts:
#	dashboard/src/lib/stores/app.svelte.ts
2026-02-05 05:38:14 -08:00
Alex Cheema
18928d4d9e Merge remote-tracking branch 'origin/main' into alexcheema/prefill-progress 2026-02-05 05:30:22 -08:00
Alex Cheema
937da476b0 address PR review comments: use model_dump_json, match/case, prefill_step_size=1024
- Use model_dump_json() for prefill progress SSE instead of hand-crafted JSON
- Convert isinstance chain to match/case in generate_chat_stream
- Set all prefill_step_size to 1024 (was 256 for testing, 2048 default)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 09:04:25 -08:00
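
A minimal sketch of the reviewed shape, assuming pydantic-style models; the real PrefillProgressData and TokenChunk live in exo.shared.types.chunks, and the actual diff appears further down this page:

```python
# Minimal sketch of the match/case + model_dump_json() pattern described above.
# The stand-in models below are assumptions; the real ones are the pydantic-based
# types from exo.shared.types.chunks (see the chat_completions.py diff below).
from collections.abc import AsyncGenerator

from pydantic import BaseModel


class PrefillProgressData(BaseModel):
    processed_tokens: int
    total_tokens: int


class TokenChunk(BaseModel):
    text: str
    finish_reason: str | None = None


async def generate_chat_stream(
    event_stream: AsyncGenerator[PrefillProgressData | TokenChunk, None],
) -> AsyncGenerator[str, None]:
    async for event in event_stream:
        match event:
            case PrefillProgressData():
                # Named SSE event, serialized by pydantic instead of a
                # hand-crafted JSON string.
                yield f"event: prefill_progress\ndata: {event.model_dump_json()}\n\n"
            case TokenChunk():
                yield f"data: {event.model_dump_json()}\n\n"
                if event.finish_reason is not None:
                    yield "data: [DONE]\n\n"
```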
Alex Cheema
1c1c286127 Merge uncertainty-visualization into prefill-progress
Combines prefill progress bar (PrefillProgressData, SSE named events)
with uncertainty visualization features (logprobs, tool calls, error
chunks, traces, image generation).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 06:18:54 -08:00
Alex Cheema
258785be84 Merge remote-tracking branch 'origin/main' into alexcheema/uncertainty-visualization 2026-02-03 06:03:01 -08:00
Alex Cheema
13a6b9819a fix: assistant prefilling for regenerate-from-token and tooltip UX
Support assistant message continuation by popping the last assistant
message before template formatting and appending its content raw,
keeping the turn open without a closing token.

Improve tooltip hover UX: use getClientRects() for correct multi-line
token positioning, add padding to bridge the hover gap, and increase
the hide delay.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 06:00:28 -08:00
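
A hedged sketch of the continuation trick described above; the helper name and the Hugging Face-style tokenizer API are assumptions, not the repo's actual code:

```python
# Illustrative only: pop the trailing assistant message before templating and
# re-append its text raw, so the assistant turn stays open for continuation.
def build_continuation_prompt(tokenizer, messages: list[dict]) -> str:
    prefix = ""
    if messages and messages[-1]["role"] == "assistant":
        # Remove the partial assistant turn so the chat template does not
        # terminate it with an end-of-turn token.
        prefix = messages.pop()["content"]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,  # open a fresh assistant turn
    )
    # Append the partial assistant text verbatim; generation then continues
    # from this prefix instead of starting a new message.
    return prompt + prefix
```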
Alex Cheema
1733d07cb3 fix: enable uncertainty visualization for regular chat messages
The sendMessage method was missing logprobs request params and token
collection, so the heatmap toggle never appeared. Also rename the
top_k parameter to top_logprobs in extract_top_logprobs to avoid
confusion with the sampling top_k parameter.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 05:08:49 -08:00
Alex Cheema
b3e4c9b1e5 fix: populate logprobs in non-streaming chat completions responses
collect_chat_response() was dropping logprobs data from TokenChunks,
so non-streaming requests never returned logprobs even when requested.
Accumulate LogprobsContentItems and attach them to the ChatCompletionChoice.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 04:45:39 -08:00
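
A rough sketch of the accumulation this commit describes; the chunk and choice field names here are assumptions based on the message, not the repo's exact types:

```python
# Hypothetical shape of the fix: gather per-token logprob entries from the
# streamed chunks and attach them to the final (non-streaming) choice.
async def collect_chat_text_and_logprobs(chunk_stream):
    text_parts: list[str] = []
    logprob_items: list[dict] = []
    async for chunk in chunk_stream:
        text_parts.append(chunk.text)
        if getattr(chunk, "logprobs", None):
            # Previously these were silently dropped; keep them instead.
            logprob_items.extend(chunk.logprobs)
    content = "".join(text_parts)
    logprobs = {"content": logprob_items} if logprob_items else None
    return content, logprobs
```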
Alex Cheema
4c74792373 Merge branch 'main' into alexcheema/uncertainty-visualization 2026-02-03 04:44:14 -08:00
Alex Cheema
eadb6de1f7 Merge main into uncertainty-visualization branch
Resolve conflicts by keeping main's structure (TextGenerationTaskParams,
tool calling, KV prefix cache, Claude/OpenAI APIs) and surgically adding
the uncertainty visualization features on top:

- Add logprob/top_logprobs fields to GenerationResponse and TokenChunk
- Add extract_top_logprobs() to MLX generator for per-token logprob extraction
- Build Logprobs in chat completions adapter for streaming responses
- Add SSE headers (Cache-Control, Connection, X-Accel-Buffering) to streaming endpoints
- Add TokenHeatmap component and uncertainty toggle in dashboard
- Add logprobs collection in streaming response handler
- Add regenerateFromToken method for re-generation from specific tokens
- Strip token data from localStorage to avoid storage bloat

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 11:33:24 -08:00
Alex Cheema
7ba2408eed fix: restore dashboard build by using main's app.svelte.ts
The prefill-progress branch's app.svelte.ts was missing image generation
features from main. To fix the dashboard build, restored main's
app.svelte.ts and removed the uncertainty visualization and prefill
progress bar features from ChatMessages.svelte that depended on the
missing exports.

Note: TokenHeatmap and PrefillProgressBar components still exist but
are not currently used. The prefill progress backend code is still in
place and can be re-enabled in the dashboard once app.svelte.ts is
properly updated to include both image generation and prefill/uncertainty
features.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 12:07:16 +00:00
Alex Cheema
ce4d7f4d43 style: simplify prefill progress bar and use exo color palette
- Remove spinner (progress bar is dynamic enough)
- Use exo-yellow for progress bar fill
- Use exo-black/60 for progress bar background
- Use exo-light-gray for text

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 12:00:55 +00:00
Alex Cheema
6727523eab fix: wire prefill progress events to chat completions stream
- Move PrefillProgressData to shared types (chunks.py) to avoid circular imports
- Update generate_chat_stream adapter to handle both TokenChunk and PrefillProgressData
- Use _stream_events instead of _chat_chunk_stream for streaming endpoint
- Prefill progress now properly sent as SSE 'event: prefill_progress' to frontend

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 12:00:55 +00:00
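
On the wire, the named progress event interleaves with ordinary data-only chunks roughly like this (payload values illustrative; token chunks keep the standard Chat Completions delta shape):

```
event: prefill_progress
data: {"processed_tokens": 1024, "total_tokens": 4096}

data: {"id": "...", "object": "chat.completion.chunk", "choices": [{"index": 0, "delta": {"content": "Hello"}}]}

data: [DONE]
```

As the commit above notes, this keeps OpenAI API compatibility: an EventSource-style consumer that only listens for default message events never sees the named prefill_progress event.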
Alex Cheema
a8f81e0495 feat: add prefill progress bar for long prompts
Shows real-time progress during prompt processing (prefill phase).
Progress is sent via SSE named events that maintain OpenAI API compatibility.

- Add PrefillProgress event type and PrefillProgressData dataclass
- Wire prompt_progress_callback through MLX stream_generate
- Send progress events directly from callback for real-time updates
- Add PrefillProgressBar.svelte component
- Parse event: prefill_progress SSE events in dashboard

Note: prefill_step_size temporarily set to 256 for testing (normally 2048)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:59:52 +00:00
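
A minimal sketch of the callback wiring, assuming mlx_lm's stream_generate forwards prompt_progress_callback and prefill_step_size to its prefill loop (as the MLX generator diff below does); emit_event stands in for the repo's event sender:

```python
# Sketch only: stream tokens while reporting prefill progress via a callback.
from mlx_lm import load, stream_generate


def generate_with_prefill_progress(model_path: str, prompt: str, emit_event) -> str:
    model, tokenizer = load(model_path)

    def on_prefill_progress(processed: int, total: int) -> None:
        # Invoked after each prefill step (every prefill_step_size prompt tokens).
        emit_event({"processed_tokens": processed, "total_tokens": total})

    text = ""
    for out in stream_generate(
        model,
        tokenizer,
        prompt,
        max_tokens=256,
        prefill_step_size=1024,
        prompt_progress_callback=on_prefill_progress,
    ):
        text += out.text
    return text
```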
Alex Cheema
ba7148ccec style: format app.svelte.ts with nix fmt 2026-01-22 11:53:43 +00:00
Alex Cheema
a64b8addc6 Fix localStorage quota issues by stripping tokens and auto-pruning
- Strip tokens (logprobs data) from messages before saving to localStorage
  since they're large and not essential for persistence
- Add pruneOldConversations() to automatically remove oldest conversations
  when quota is exceeded
- This prevents QuotaExceededError from crashing the app

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
Alex Cheema
e6599a9408 Fix ReferenceError: controller undefined in sendMessage finally block
Move AbortController creation before the try block in both
sendMessageWithLogprobs and regenerateFromToken functions.
Previously, controller was defined inside the try block but
referenced in the finally block, causing a ReferenceError
if an exception was thrown before the controller was created.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
Alex Cheema
93f4753598 Add SSE headers to properly close streaming connections
Add Cache-Control, Connection: close, and X-Accel-Buffering headers
to all SSE streaming responses. These headers help ensure:
- No caching of streaming responses
- Connection closes when stream ends (instead of keep-alive)
- No proxy buffering that could delay stream closure

This should fix the issue where the frontend stays on "PROCESSING"
even after receiving the complete response.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
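
A sketch of those headers on a Starlette/FastAPI StreamingResponse; the exact values shown (no-cache, close, no) are typical choices inferred from the commit message rather than copied from the repo:

```python
# Sketch: attach SSE-friendly headers so the stream is not cached, not kept
# alive after completion, and not buffered by a reverse proxy.
from fastapi.responses import StreamingResponse


def sse_response(event_generator):
    return StreamingResponse(
        event_generator,
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",   # never cache a live stream
            "Connection": "close",         # close when the stream ends
            "X-Accel-Buffering": "no",     # disable nginx-style proxy buffering
        },
    )
```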
Alex Cheema
75fe505275 Add debug logging to generate_chat_stream
Add logging to help diagnose why streaming might not be ending properly.
This will show when [DONE] is yielded, when return is called, and when
the finally block runs.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
Alex Cheema
d7c044e349 Fix streaming not ending after [DONE] is yielded
Add missing return statement after yielding [DONE] in generate_chat_stream.
Without this, the async generator continues waiting for more chunks from
chunk_stream even though generation is complete, causing the stream to hang
indefinitely. The frontend waits for the stream to close (reader.done) which
never happens, resulting in the chat button staying on "PROCESSING" forever.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
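
A toy illustration of the hang and the fix, assuming a dict-shaped chunk; the real code operates on TokenChunk objects:

```python
# Without the `return`, the generator keeps awaiting chunk_stream after [DONE],
# so the HTTP response never ends and the client never observes reader.done.
async def generate_chat_stream(chunk_stream):
    async for chunk in chunk_stream:
        yield f"data: {chunk}\n\n"
        if chunk.get("finish_reason") is not None:
            yield "data: [DONE]\n\n"
            return  # the fix: close the async generator so the SSE stream ends
```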
Alex Cheema
53b6d56e9f fix: restore extract_top_logprobs function for uncertainty visualization
The extract_top_logprobs function was lost during rebases. This function
processes the out.logprobs array (full vocabulary logprobs from MLX) to
extract the selected token's logprob and top-k alternatives.

The previous code tried to use getattr(out, "logprob", None) which
doesn't exist - mlx_lm returns logprobs as an mx.array, not individual
values.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
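
A hedged sketch of that processing: mlx_lm's GenerationResponse exposes the full-vocabulary log-probabilities as out.logprobs (an mx.array) and the chosen token id as out.token; the function body below is a reconstruction, not the repo's implementation:

```python
# Sketch: pull the selected token's logprob plus the top-k alternatives out of
# the full-vocabulary logprobs array returned by mlx_lm.
import mlx.core as mx


def extract_top_logprobs(
    logprobs: mx.array, token: int, top_logprobs: int = 5
) -> tuple[float, list[tuple[int, float]]]:
    selected = logprobs[token].item()
    # argpartition on the negated array puts the `top_logprobs` largest entries first.
    top_indices = mx.argpartition(-logprobs, kth=top_logprobs - 1)[:top_logprobs]
    alternatives = sorted(
        ((int(i), logprobs[int(i)].item()) for i in top_indices.tolist()),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return selected, alternatives
```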
Alex Cheema
7fe0a61230 fix: remove unsupported logprob params from stream_generate
The mlx_lm.stream_generate already returns logprobs in its output -
we don't need to pass return_logprob or return_top_logprobs kwargs.
The uncertainty visualization feature extracts logprobs from the
existing out.logprobs field.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
Alex Cheema
5a36542631 feat: add uncertainty visualization with token-level logprobs
- Add TokenHeatmap component for visualizing token confidence
- Collect and stream logprobs in generation pipeline
- Add regenerate-from-token feature with continue_from_prefix
- Add AbortController for request cancellation
- Support continue_final_message for seamless prefix continuation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
Alex Cheema
955e0105b3 fix: resolve import and type errors from rebase
- Use claude_request_to_internal instead of old function name
- Fix ModelId imports in runner.py and test files
- Update test_mlx/conftest.py to use ResponsesRequest format
- Remove unused imports

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:36:11 +00:00
Evan
4d1eb1d9bd fix: rebase fix 2026-01-22 11:32:46 +00:00
Alex Cheema
365416c65e style: move inline imports to top of file in api.py
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:32:26 +00:00
Alex Cheema
04af76e10f fix: restore try/except structure in runner.py
Replace non-existent context manager with proper try/except block
and remove unused ModelId import.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:32:04 +00:00
Alex Cheema
a84c3431cd style: fix formatting issues caught by treefmt
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:31:45 +00:00
Alex Cheema
52445b21f6 refactor: use ResponsesRequest as canonical internal type
- Extend ResponsesRequest with fields: top_k, seed, stop, tools
- Remove redundant InternalTaskParams and InputMessage types
- Update all adapters to convert to ResponsesRequest
- Simplify Responses API (no conversion needed - native passthrough)
- Update all imports across codebase and tests

This eliminates type duplication and makes the Responses API
relationship explicit throughout the codebase.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:31:44 +00:00
Alex Cheema
435bd7f6fa refactor: make Responses API the canonical internal format
Restructure the API layer so that OpenAI Responses API is the native
format, with Chat Completions and Claude Messages as adapters on top.

Changes:
- Add new chat_completions.py adapter with streaming/non-streaming support
- Update responses.py with collect_responses_response() for non-streaming
- Update claude.py with collect_claude_response() for non-streaming
- Refactor api.py so all endpoints use adapters uniformly
- Rename _chat_chunk_stream to _token_chunk_stream (generic internal format)
- Remove unused chat_response_to_* converter functions
- Update tests to remove tests for deleted functions

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:30:27 +00:00
Alex Cheema
dd25b5b90e feat: add Claude Messages API and OpenAI Responses API support
Adds two new API endpoints that wrap the existing chat completions:

- /v1/messages - Claude Messages API compatible endpoint
- /v1/responses - OpenAI Responses API compatible endpoint

Both support streaming (SSE) and non-streaming modes with proper
token usage reporting from actual inference stats.

Also adds top_k sampling parameter and stop sequence support to the
MLX inference engine.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:28:49 +00:00
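
For reference, a minimal non-streaming call against the new Claude-compatible endpoint; host, port, model name, and payload values are placeholders:

```python
# Example request only; adjust the base URL and model to your deployment.
import requests

resp = requests.post(
    "http://localhost:8000/v1/messages",
    json={
        "model": "my-model",
        "max_tokens": 128,
        "messages": [{"role": "user", "content": "Hello from the Messages API"}],
    },
    timeout=60,
)
print(resp.json())
```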
11 changed files with 225 additions and 76 deletions

View File

@@ -9,7 +9,6 @@
regenerateFromToken,
setEditingImage,
} from "$lib/stores/app.svelte";
import type { Message } from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";

View File

@@ -0,0 +1,51 @@
<script lang="ts">
import type { PrefillProgress } from "$lib/stores/app.svelte";
interface Props {
progress: PrefillProgress;
class?: string;
}
let { progress, class: className = "" }: Props = $props();
const percentage = $derived(
progress.total > 0
? Math.round((progress.processed / progress.total) * 100)
: 0,
);
function formatTokenCount(count: number): string {
if (count >= 1000) {
return `${(count / 1000).toFixed(1)}k`;
}
return count.toString();
}
</script>
<div class="prefill-progress {className}">
<div
class="flex items-center justify-between text-xs text-exo-light-gray mb-1"
>
<span>Processing prompt</span>
<span class="font-mono">
{formatTokenCount(progress.processed)} / {formatTokenCount(
progress.total,
)} tokens
</span>
</div>
<div class="h-1.5 bg-exo-black/60 rounded-full overflow-hidden">
<div
class="h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out"
style="width: {percentage}%"
></div>
</div>
<div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
{percentage}%
</div>
</div>
<style>
.prefill-progress {
width: 100%;
}
</style>

View File

@@ -255,6 +255,12 @@ export interface TokenData {
topLogprobs: TopLogprob[];
}
export interface PrefillProgress {
processed: number;
total: number;
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";

View File

@@ -19,7 +19,12 @@ from exo.shared.types.api import (
StreamingChoiceResponse,
ToolCall,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressData,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
@@ -122,55 +127,65 @@ def chunk_to_response(
 async def generate_chat_stream(
     command_id: CommandId,
-    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
+    event_stream: AsyncGenerator[
+        PrefillProgressData | ErrorChunk | ToolCallChunk | TokenChunk, None
+    ],
 ) -> AsyncGenerator[str, None]:
-    """Generate Chat Completions API streaming events from chunks."""
-    async for chunk in chunk_stream:
-        if isinstance(chunk, ErrorChunk):
-            error_response = ErrorResponse(
-                error=ErrorInfo(
-                    message=chunk.error_message or "Internal server error",
-                    type="InternalServerError",
-                    code=500,
-                )
-            )
-            yield f"data: {error_response.model_dump_json()}\n\n"
-            yield "data: [DONE]\n\n"
-            return
-        if isinstance(chunk, ToolCallChunk):
-            tool_call_deltas = [
-                ToolCall(
-                    id=str(uuid4()),
-                    index=i,
-                    function=tool,
-                )
-                for i, tool in enumerate(chunk.tool_calls)
-            ]
-            tool_response = ChatCompletionResponse(
-                id=command_id,
-                created=int(time.time()),
-                model=chunk.model,
-                choices=[
-                    StreamingChoiceResponse(
-                        index=0,
-                        delta=ChatCompletionMessage(
-                            role="assistant",
-                            tool_calls=tool_call_deltas,
-                        ),
-                        finish_reason="tool_calls",
-                    )
-                ],
-            )
-            yield f"data: {tool_response.model_dump_json()}\n\n"
-            yield "data: [DONE]\n\n"
-            return
-        chunk_response = chunk_to_response(chunk, command_id)
-        yield f"data: {chunk_response.model_dump_json()}\n\n"
-        if chunk.finish_reason is not None:
-            yield "data: [DONE]\n\n"
+    """Generate Chat Completions API streaming events from StreamEvents.
+
+    Handles PrefillProgressData, ErrorChunk, ToolCallChunk, and TokenChunk.
+    """
+    async for event in event_stream:
+        match event:
+            case PrefillProgressData():
+                yield f"event: prefill_progress\ndata: {event.model_dump_json()}\n\n"
+            case ErrorChunk():
+                error_response = ErrorResponse(
+                    error=ErrorInfo(
+                        message=event.error_message or "Internal server error",
+                        type="InternalServerError",
+                        code=500,
+                    )
+                )
+                yield f"data: {error_response.model_dump_json()}\n\n"
+                yield "data: [DONE]\n\n"
+                return
+            case ToolCallChunk():
+                tool_call_deltas = [
+                    ToolCall(
+                        id=str(uuid4()),
+                        index=i,
+                        function=tool,
+                    )
+                    for i, tool in enumerate(event.tool_calls)
+                ]
+                tool_response = ChatCompletionResponse(
+                    id=command_id,
+                    created=int(time.time()),
+                    model=event.model,
+                    choices=[
+                        StreamingChoiceResponse(
+                            index=0,
+                            delta=ChatCompletionMessage(
+                                role="assistant",
+                                tool_calls=tool_call_deltas,
+                            ),
+                            finish_reason="tool_calls",
+                        )
+                    ],
+                )
+                yield f"data: {tool_response.model_dump_json()}\n\n"
+                yield "data: [DONE]\n\n"
+                return
+            case TokenChunk():
+                chunk_response = chunk_to_response(event, command_id)
+                yield f"data: {chunk_response.model_dump_json()}\n\n"
+                if event.finish_reason is not None:
+                    yield "data: [DONE]\n\n"
async def collect_chat_response(

View File

@@ -103,6 +103,7 @@ from exo.shared.types.chunks import (
ErrorChunk,
ImageChunk,
InputImageChunk,
PrefillProgressData,
TokenChunk,
ToolCallChunk,
)
@@ -132,6 +133,7 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
PrefillProgress,
TracesMerged,
)
from exo.shared.types.memory import Memory
@@ -213,7 +215,8 @@ class API:
)
self._text_generation_queues: dict[
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
CommandId,
Sender[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressData],
] = {}
self._image_generation_queues: dict[
CommandId, Sender[ImageChunk | ErrorChunk]
@@ -510,22 +513,27 @@ class API:
instance_id=instance_id,
)
async def _token_chunk_stream(
async def _stream_events(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
"""Yield chunks for a given command until completion.
) -> AsyncGenerator[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressData, None
]:
"""Yield stream events for a command.
This is the internal low-level stream used by all API adapters.
"""
try:
self._text_generation_queues[command_id], recv = channel[
ErrorChunk | ToolCallChunk | TokenChunk
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressData
]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
with recv as events:
async for event in events:
yield event
if (
isinstance(event, TokenChunk)
and event.finish_reason is not None
):
break
except anyio.get_cancelled_exc_class():
@@ -542,6 +550,14 @@ class API:
if command_id in self._text_generation_queues:
del self._text_generation_queues[command_id]
async def _chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
"""Yield chunks, filtering out prefill progress events."""
async for event in self._stream_events(command_id):
if not isinstance(event, PrefillProgressData):
yield event
async def _collect_text_generation_with_stats(
self, command_id: CommandId
) -> BenchChatCompletionResponse:
@@ -552,7 +568,7 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._token_chunk_stream(command_id):
async for chunk in self._chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -624,7 +640,7 @@ class API:
return StreamingResponse(
generate_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
self._stream_events(command.command_id),
),
media_type="text/event-stream",
headers={
@@ -634,10 +650,13 @@ class API:
},
)
return await collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
)
try:
return await collect_chat_response(
command.command_id,
self._chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
async def bench_chat_completions(
self, payload: BenchChatCompletionRequest
@@ -1185,7 +1204,7 @@ class API:
generate_claude_stream(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
self._chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
@@ -1195,11 +1214,14 @@ class API:
},
)
return await collect_claude_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
try:
return await collect_claude_response(
command.command_id,
payload.model,
self._chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
async def openai_responses(
self, payload: ResponsesRequest
@@ -1217,7 +1239,7 @@ class API:
generate_responses_stream(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
self._chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
@@ -1227,11 +1249,14 @@ class API:
},
)
return await collect_responses_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
try:
return await collect_responses_response(
command.command_id,
payload.model,
self._chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
@@ -1373,6 +1398,20 @@ class API:
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
elif isinstance(event, PrefillProgress):
if queue := self._text_generation_queues.get(
event.command_id, None
):
try:
await queue.send(
PrefillProgressData(
processed_tokens=event.processed_tokens,
total_tokens=event.total_tokens,
)
)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)

View File

@@ -15,6 +15,7 @@ from exo.shared.types.events import (
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
PrefillProgress,
RunnerDeleted,
RunnerStatusUpdated,
TaskAcknowledged,
@@ -61,6 +62,7 @@ def event_apply(event: Event, state: State) -> State:
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| PrefillProgress()
| TracesCollected()
| TracesMerged()
): # Pass-through events that don't modify state

View File

@@ -77,3 +77,13 @@ class InputImageChunk(BaseChunk):
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
class PrefillProgressData(TaggedModel):
"""Data class for prefill progress events during streaming."""
processed_tokens: int
total_tokens: int
StreamEvent = TokenChunk | PrefillProgressData

View File

@@ -102,6 +102,12 @@ class InputChunkReceived(BaseEvent):
chunk: InputImageChunk
class PrefillProgress(BaseEvent):
command_id: CommandId
processed_tokens: int
total_tokens: int
class TopologyEdgeCreated(BaseEvent):
conn: Connection
@@ -148,6 +154,7 @@ Event = (
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| PrefillProgress
| TopologyEdgeCreated
| TopologyEdgeDeleted
| TracesCollected

View File

@@ -66,3 +66,8 @@ class ToolCallResponse(BaseRunnerResponse):
class FinishedResponse(BaseRunnerResponse):
pass
class PrefillProgressResponse(BaseRunnerResponse):
processed_tokens: int
total_tokens: int

View File

@@ -79,7 +79,7 @@ def prefill(
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
prefill_step_size=1024,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
@@ -127,7 +127,7 @@ def warmup_inference(
max_tokens=50,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
prefill_step_size=1024,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
@@ -221,6 +221,7 @@ def mlx_generate(
task: TextGenerationTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
on_prefill_progress: Callable[[int, int], None] | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -292,9 +293,10 @@ def mlx_generate(
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
prefill_step_size=1024,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=on_prefill_progress,
),
start=1,
):

View File

@@ -25,6 +25,7 @@ from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
PrefillProgress,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
@@ -261,6 +262,17 @@ def main(
assert model and not isinstance(model, DistributedImageModel)
assert tokenizer
# Define callback to send prefill progress events directly
def on_prefill_progress(processed: int, total: int) -> None:
if device_rank == 0:
event_sender.send(
PrefillProgress(
command_id=command_id,
processed_tokens=processed,
total_tokens=total,
)
)
try:
_check_for_debug_prompts(task_params)
@@ -274,6 +286,7 @@ def main(
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
on_prefill_progress=on_prefill_progress,
)
# For other thinking models (GLM, etc.), check if we need to