Compare commits

1 Commit

Author:  Ryuichi Leo Takashige
SHA1:    b0e6a6ac08
Message: fix gpt oss tool calling
Date:    2026-02-13 21:23:26 +00:00
12 changed files with 81 additions and 222 deletions

View File

@@ -9,6 +9,7 @@
regenerateFromToken,
setEditingImage,
} from "$lib/stores/app.svelte";
import type { Message } from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";

View File

@@ -1,51 +0,0 @@
<script lang="ts">
import type { PrefillProgress } from "$lib/stores/app.svelte";
interface Props {
progress: PrefillProgress;
class?: string;
}
let { progress, class: className = "" }: Props = $props();
const percentage = $derived(
progress.total > 0
? Math.round((progress.processed / progress.total) * 100)
: 0,
);
function formatTokenCount(count: number): string {
if (count >= 1000) {
return `${(count / 1000).toFixed(1)}k`;
}
return count.toString();
}
</script>
<div class="prefill-progress {className}">
<div
class="flex items-center justify-between text-xs text-exo-light-gray mb-1"
>
<span>Processing prompt</span>
<span class="font-mono">
{formatTokenCount(progress.processed)} / {formatTokenCount(
progress.total,
)} tokens
</span>
</div>
<div class="h-1.5 bg-exo-black/60 rounded-full overflow-hidden">
<div
class="h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out"
style="width: {percentage}%"
></div>
</div>
<div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
{percentage}%
</div>
</div>
<style>
.prefill-progress {
width: 100%;
}
</style>

View File

@@ -273,11 +273,6 @@ export interface TokenData {
topLogprobs: TopLogprob[];
}
export interface PrefillProgress {
processed: number;
total: number;
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";

View File

@@ -18,12 +18,7 @@ from exo.shared.types.api import (
StreamingChoiceResponse,
ToolCall,
)
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressData,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
@@ -127,65 +122,55 @@ def chunk_to_response(
 async def generate_chat_stream(
     command_id: CommandId,
-    event_stream: AsyncGenerator[
-        PrefillProgressData | ErrorChunk | ToolCallChunk | TokenChunk, None
-    ],
+    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
 ) -> AsyncGenerator[str, None]:
-    """Generate Chat Completions API streaming events from StreamEvents.
-
-    Handles PrefillProgressData, ErrorChunk, ToolCallChunk, and TokenChunk.
-    """
-    async for event in event_stream:
-        match event:
-            case PrefillProgressData():
-                yield f"event: prefill_progress\ndata: {event.model_dump_json()}\n\n"
-            case ErrorChunk():
-                error_response = ErrorResponse(
-                    error=ErrorInfo(
-                        message=event.error_message or "Internal server error",
-                        type="InternalServerError",
-                        code=500,
-                    )
-                )
-                yield f"data: {error_response.model_dump_json()}\n\n"
-                yield "data: [DONE]\n\n"
-                return
-            case ToolCallChunk():
-                tool_call_deltas = [
-                    ToolCall(
-                        id=tool.id,
-                        index=i,
-                        function=tool,
-                    )
-                    for i, tool in enumerate(event.tool_calls)
-                ]
-                tool_response = ChatCompletionResponse(
-                    id=command_id,
-                    created=int(time.time()),
-                    model=event.model,
-                    choices=[
-                        StreamingChoiceResponse(
-                            index=0,
-                            delta=ChatCompletionMessage(
-                                role="assistant",
-                                tool_calls=tool_call_deltas,
-                            ),
-                            finish_reason="tool_calls",
-                        )
-                    ],
-                )
-                yield f"data: {tool_response.model_dump_json()}\n\n"
-                yield "data: [DONE]\n\n"
-                return
-            case TokenChunk():
-                chunk_response = chunk_to_response(event, command_id)
-                yield f"data: {chunk_response.model_dump_json()}\n\n"
-                if event.finish_reason is not None:
-                    yield "data: [DONE]\n\n"
+    """Generate Chat Completions API streaming events from chunks."""
+    async for chunk in chunk_stream:
+        if isinstance(chunk, ErrorChunk):
+            error_response = ErrorResponse(
+                error=ErrorInfo(
+                    message=chunk.error_message or "Internal server error",
+                    type="InternalServerError",
+                    code=500,
+                )
+            )
+            yield f"data: {error_response.model_dump_json()}\n\n"
+            yield "data: [DONE]\n\n"
+            return
+        if isinstance(chunk, ToolCallChunk):
+            tool_call_deltas = [
+                ToolCall(
+                    id=tool.id,
+                    index=i,
+                    function=tool,
+                )
+                for i, tool in enumerate(chunk.tool_calls)
+            ]
+            tool_response = ChatCompletionResponse(
+                id=command_id,
+                created=int(time.time()),
+                model=chunk.model,
+                choices=[
+                    StreamingChoiceResponse(
+                        index=0,
+                        delta=ChatCompletionMessage(
+                            role="assistant",
+                            tool_calls=tool_call_deltas,
+                        ),
+                        finish_reason="tool_calls",
+                    )
+                ],
+            )
+            yield f"data: {tool_response.model_dump_json()}\n\n"
+            yield "data: [DONE]\n\n"
+            return
+        chunk_response = chunk_to_response(chunk, command_id)
+        yield f"data: {chunk_response.model_dump_json()}\n\n"
+        if chunk.finish_reason is not None:
+            yield "data: [DONE]\n\n"
async def collect_chat_response(
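
Note: since the stream above no longer emits custom "event: prefill_progress" SSE frames, a plain OpenAI-style "data:" loop is enough on the client side. A minimal sketch, assuming a local exo endpoint at http://localhost:52415/v1/chat/completions, an httpx client, and a placeholder model id; none of these names come from this diff.

# Hypothetical client-side consumer; endpoint, port, and model id are assumptions.
import json

import httpx


async def stream_chat(prompt: str) -> None:
    payload = {
        "model": "gpt-oss-20b",  # placeholder model id
        "stream": True,
        "messages": [{"role": "user", "content": prompt}],
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:52415/v1/chat/completions", json=payload
        ) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue  # only standard data frames remain in the stream
                data = line.removeprefix("data: ")
                if data == "[DONE]":
                    break
                delta = json.loads(data)["choices"][0]["delta"]
                # token chunks carry "content"; a tool-call chunk carries "tool_calls"
                print(delta.get("content") or "", end="", flush=True)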

View File

@@ -105,7 +105,6 @@ from exo.shared.types.chunks import (
ErrorChunk,
ImageChunk,
InputImageChunk,
PrefillProgressData,
TokenChunk,
ToolCallChunk,
)
@@ -135,7 +134,6 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
PrefillProgress,
TracesMerged,
)
from exo.shared.types.memory import Memory
@@ -219,8 +217,7 @@ class API:
)
self._text_generation_queues: dict[
CommandId,
Sender[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressData],
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
] = {}
self._image_generation_queues: dict[
CommandId, Sender[ImageChunk | ErrorChunk]
@@ -524,27 +521,22 @@ class API:
instance_id=instance_id,
)
async def _stream_events(
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressData, None
]:
"""Yield stream events for a command.
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
"""Yield chunks for a given command until completion.
This is the internal low-level stream used by all API adapters.
"""
try:
self._text_generation_queues[command_id], recv = channel[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressData
ErrorChunk | ToolCallChunk | TokenChunk
]()
with recv as events:
async for event in events:
yield event
if (
isinstance(event, TokenChunk)
and event.finish_reason is not None
):
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
@@ -561,14 +553,6 @@ class API:
if command_id in self._text_generation_queues:
del self._text_generation_queues[command_id]
async def _chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
"""Yield chunks, filtering out prefill progress events."""
async for event in self._stream_events(command_id):
if not isinstance(event, PrefillProgressData):
yield event
async def _collect_text_generation_with_stats(
self, command_id: CommandId
) -> BenchChatCompletionResponse:
@@ -579,7 +563,7 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chunk_stream(command_id):
async for chunk in self._token_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -651,7 +635,7 @@ class API:
return StreamingResponse(
generate_chat_stream(
command.command_id,
self._stream_events(command.command_id),
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
@@ -661,13 +645,10 @@ class API:
},
)
try:
return await collect_chat_response(
command.command_id,
self._chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
return await collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
)
async def bench_chat_completions(
self, payload: BenchChatCompletionRequest
@@ -1231,7 +1212,7 @@ class API:
generate_claude_stream(
command.command_id,
payload.model,
self._chunk_stream(command.command_id),
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
@@ -1241,14 +1222,11 @@ class API:
},
)
try:
return await collect_claude_response(
command.command_id,
payload.model,
self._chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
return await collect_claude_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
async def openai_responses(
self, payload: ResponsesRequest
@@ -1266,7 +1244,7 @@ class API:
generate_responses_stream(
command.command_id,
payload.model,
self._chunk_stream(command.command_id),
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
@@ -1276,14 +1254,11 @@ class API:
},
)
try:
return await collect_responses_response(
command.command_id,
payload.model,
self._chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
return await collect_responses_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
@@ -1437,20 +1412,6 @@ class API:
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
elif isinstance(event, PrefillProgress):
if queue := self._text_generation_queues.get(
event.command_id, None
):
try:
await queue.send(
PrefillProgressData(
processed_tokens=event.processed_tokens,
total_tokens=event.total_tokens,
)
)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)
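
For reference, the queue-per-command pattern that _token_chunk_stream wraps can be sketched with plain anyio primitives. The channel[...]() helper in the diff is exo's own typed wrapper, so the names and types below are stand-ins, not the project's API.

# Sketch of the per-command fan-out: the event loop looks up a sender by
# command id, and the API handler drains the matching receiver until a chunk
# carries a finish_reason.
from collections.abc import AsyncGenerator

import anyio
from anyio.streams.memory import MemoryObjectSendStream

queues: dict[str, MemoryObjectSendStream] = {}


async def token_chunk_stream(command_id: str) -> AsyncGenerator[object, None]:
    send, recv = anyio.create_memory_object_stream(max_buffer_size=0)
    queues[command_id] = send
    try:
        async with recv:
            async for chunk in recv:
                yield chunk
                if getattr(chunk, "finish_reason", None) is not None:
                    break
    finally:
        queues.pop(command_id, None)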

View File

@@ -15,7 +15,6 @@ from exo.shared.types.events import (
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
PrefillProgress,
RunnerDeleted,
RunnerStatusUpdated,
TaskAcknowledged,
@@ -65,7 +64,6 @@ def event_apply(event: Event, state: State) -> State:
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| PrefillProgress()
| TracesCollected()
| TracesMerged()
): # Pass-through events that don't modify state

View File

@@ -77,13 +77,3 @@ class InputImageChunk(BaseChunk):
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
class PrefillProgressData(TaggedModel):
"""Data class for prefill progress events during streaming."""
processed_tokens: int
total_tokens: int
StreamEvent = TokenChunk | PrefillProgressData

View File

@@ -102,12 +102,6 @@ class InputChunkReceived(BaseEvent):
chunk: InputImageChunk
class PrefillProgress(BaseEvent):
command_id: CommandId
processed_tokens: int
total_tokens: int
class TopologyEdgeCreated(BaseEvent):
conn: Connection
@@ -154,7 +148,6 @@ Event = (
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| PrefillProgress
| TopologyEdgeCreated
| TopologyEdgeDeleted
| TracesCollected

View File

@@ -66,8 +66,3 @@ class ToolCallResponse(BaseRunnerResponse):
class FinishedResponse(BaseRunnerResponse):
pass
class PrefillProgressResponse(BaseRunnerResponse):
processed_tokens: int
total_tokens: int

View File

@@ -160,7 +160,7 @@ def warmup_inference(
max_tokens=50,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=1024,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
@@ -252,7 +252,6 @@ def mlx_generate(
task: TextGenerationTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
on_prefill_progress: Callable[[int, int], None] | None = None,
group: mx.distributed.Group | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
@@ -346,7 +345,6 @@ def mlx_generate(
prefill_step_size=1,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=on_prefill_progress,
),
start=1,
):
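
The warmup hunk above only bumps the prefill chunk size from 1024 to 2048 tokens, while the other hunks drop the progress-callback plumbing that used to fire after each chunk. As a rough illustration of what prefill_step_size controls, here is a conceptual sketch; the function and names are illustrative, not exo's or mlx_lm's internals.

# Illustrative only: the prompt is pushed through the model in fixed-size
# chunks so per-step memory stays bounded; the deleted prompt_progress_callback
# used to report (processed, total) after each chunk.
from collections.abc import Callable


def chunked_prefill(
    tokens: list[int],
    step: int = 2048,
    on_progress: Callable[[int, int], None] | None = None,
) -> None:
    total = len(tokens)
    for start in range(0, total, step):
        # ... feed tokens[start:start + step] through the model to extend the KV cache ...
        if on_progress is not None:
            on_progress(min(start + step, total), total)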

View File

@@ -316,6 +316,8 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
return [154820, 154827, 154829]
elif "glm" in model_id_lower:
return [151336, 151329, 151338]
elif "gpt-oss" in model_id_lower:
return [200002, 200012]
return None
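
The two new ids are Harmony's end-of-turn control tokens: 200002 is <|return|> and 200012 is <|call|>, which is what lets gpt-oss stop cleanly after emitting a tool call. If openai_harmony is installed they can be cross-checked rather than hard-coded; treating the helper's output as exactly these two ids is an assumption about the current o200k_harmony encoding.

# Cross-check of the hard-coded gpt-oss eos ids against openai_harmony.
from openai_harmony import HarmonyEncodingName, load_harmony_encoding

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
print(encoding.stop_tokens_for_assistant_actions())  # expected: [200002, 200012]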

View File

@@ -11,6 +11,7 @@ from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
HarmonyError, # pyright: ignore[reportUnknownVariableType]
Role,
StreamableParser,
load_harmony_encoding,
@@ -26,7 +27,6 @@ from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
PrefillProgress,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
@@ -267,17 +267,6 @@ def main(
assert model and not isinstance(model, DistributedImageModel)
assert tokenizer
# Define callback to send prefill progress events directly
def on_prefill_progress(processed: int, total: int) -> None:
if device_rank == 0:
event_sender.send(
PrefillProgress(
command_id=command_id,
processed_tokens=processed,
total_tokens=total,
)
)
try:
_check_for_debug_prompts(task_params)
@@ -291,7 +280,6 @@ def main(
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
on_prefill_progress=on_prefill_progress,
group=group,
)
@@ -581,7 +569,11 @@ def parse_gpt_oss(
for response in responses:
assert isinstance(response, GenerationResponse)
stream.process(response.token)
try:
stream.process(response.token)
except HarmonyError:
logger.error("Encountered critical Harmony Error, returning early")
return
delta = stream.last_content_delta
ch = stream.current_channel
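
In isolation, the guarded parsing loop introduced above looks roughly like the sketch below: wrap StreamableParser.process in a try/except so a malformed Harmony token sequence ends the parse early instead of raising mid-stream. token_ids is a placeholder for the model's decoded output, and only the try/except mirrors the change in this commit.

# Standalone sketch of the guarded Harmony parsing loop.
from openai_harmony import (
    HarmonyEncodingName,
    HarmonyError,
    Role,
    StreamableParser,
    load_harmony_encoding,
)

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
stream = StreamableParser(encoding, role=Role.ASSISTANT)

token_ids: list[int] = []  # placeholder: token ids produced by the model

for token in token_ids:
    try:
        stream.process(token)
    except HarmonyError:
        # Malformed sequence: stop parsing early instead of crashing the runner.
        break
    if stream.current_channel == "final" and stream.last_content_delta:
        print(stream.last_content_delta, end="")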