Compare commits

..

3 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
905fd5e900 fix prompt tokens reporting mistake 2026-02-12 19:00:07 +00:00
rltakashige
1f19619e1e Merge branch 'main' into leo/fix-usage-metrics 2026-02-12 18:28:46 +00:00
Ryuichi Leo Takashige
c4da2bc211 Pass usage and generation stats through all adapters correctly 2026-02-12 18:27:33 +00:00
6 changed files with 52 additions and 33 deletions

View File

@@ -17,6 +17,7 @@ from exo.shared.types.api import (
LogprobsContentItem,
StreamingChoiceResponse,
ToolCall,
Usage,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
@@ -125,6 +126,8 @@ async def generate_chat_stream(
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from chunks."""
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
@@ -138,6 +141,8 @@ async def generate_chat_stream(
yield "data: [DONE]\n\n"
return
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [
ToolCall(
@@ -161,12 +166,15 @@ async def generate_chat_stream(
finish_reason="tool_calls",
)
],
usage=last_usage,
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response = chunk_to_response(chunk, command_id)
if chunk.finish_reason is not None:
chunk_response = chunk_response.model_copy(update={"usage": last_usage})
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
@@ -184,6 +192,7 @@ async def collect_chat_response(
model: str | None = None
finish_reason: FinishReason | None = None
error_message: str | None = None
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
@@ -193,6 +202,8 @@ async def collect_chat_response(
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if chunk.logprob is not None:
@@ -241,4 +252,5 @@ async def collect_chat_response(
finish_reason=finish_reason,
)
],
usage=last_usage,
)

View File

@@ -4,7 +4,7 @@ import json
from collections.abc import AsyncGenerator
from typing import Any
from exo.shared.types.api import FinishReason
from exo.shared.types.api import FinishReason, Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.claude_api import (
ClaudeContentBlock,
@@ -166,7 +166,7 @@ async def collect_claude_response(
text_parts: list[str] = []
tool_use_blocks: list[ClaudeToolUseBlock] = []
stop_reason: ClaudeStopReason | None = None
last_stats = None
last_usage: Usage | None = None
error_message: str | None = None
async for chunk in chunk_stream:
@@ -174,6 +174,8 @@ async def collect_claude_response(
error_message = chunk.error_message or "Internal server error"
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
tool_use_blocks.append(
@@ -183,12 +185,10 @@ async def collect_claude_response(
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
)
)
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
continue
text_parts.append(chunk.text)
last_stats = chunk.stats or last_stats
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -208,9 +208,9 @@ async def collect_claude_response(
if not content:
content.append(ClaudeTextBlock(text=""))
# Use actual usage data from stats if available
input_tokens = last_stats.prompt_tokens if last_stats else 0
output_tokens = last_stats.generation_tokens if last_stats else 0
# Use actual usage data if available
input_tokens = last_usage.prompt_tokens if last_usage else 0
output_tokens = last_usage.completion_tokens if last_usage else 0
return ClaudeMessagesResponse(
id=f"msg_{command_id}",
@@ -249,7 +249,7 @@ async def generate_claude_stream(
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_stats = None
last_usage: Usage | None = None
next_block_index = 1 # text block is 0, tool blocks start at 1
async for chunk in chunk_stream:
@@ -257,8 +257,9 @@ async def generate_claude_stream(
# Close text block and bail
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
# Emit tool_use content blocks
@@ -290,7 +291,6 @@ async def generate_claude_stream(
continue
output_tokens += 1 # Count each chunk as one token
last_stats = chunk.stats or last_stats
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
@@ -302,9 +302,9 @@ async def generate_claude_stream(
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
# Use actual token count from stats if available
if last_stats is not None:
output_tokens = last_stats.generation_tokens
# Use actual token count from usage if available
if last_usage is not None:
output_tokens = last_usage.completion_tokens
# content_block_stop for text block
block_stop = ClaudeContentBlockStopEvent(index=0)

View File

@@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator
from itertools import count
from typing import Any
from exo.shared.types.api import Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
@@ -127,7 +128,7 @@ async def collect_responses_response(
item_id = f"item_{command_id}"
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_stats = None
last_usage: Usage | None = None
error_message: str | None = None
async for chunk in chunk_stream:
@@ -135,6 +136,8 @@ async def collect_responses_response(
error_message = chunk.error_message or "Internal server error"
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
function_call_items.append(
@@ -145,22 +148,20 @@ async def collect_responses_response(
arguments=tool.arguments,
)
)
last_stats = chunk.stats or last_stats
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
if error_message is not None:
raise ValueError(error_message)
# Create usage from stats if available
# Create usage from usage data if available
usage = None
if last_stats is not None:
if last_usage is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
input_tokens=last_usage.prompt_tokens,
output_tokens=last_usage.completion_tokens,
total_tokens=last_usage.total_tokens,
)
output: list[ResponseItem] = [
@@ -235,15 +236,16 @@ async def generate_responses_stream(
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_stats = None
last_usage: Usage | None = None
next_output_index = 1 # message item is at 0
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
for tool in chunk.tool_calls:
fc_id = f"fc_{tool.id}"
call_id = f"call_{tool.id}"
@@ -302,7 +304,6 @@ async def generate_responses_stream(
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
@@ -346,13 +347,13 @@ async def generate_responses_stream(
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from stats if available
# Create usage from usage data if available
usage = None
if last_stats is not None:
if last_usage is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
input_tokens=last_usage.prompt_tokens,
output_tokens=last_usage.completion_tokens,
total_tokens=last_usage.total_tokens,
)
# response.completed

View File

@@ -62,6 +62,7 @@ class PartialImageResponse(BaseRunnerResponse):
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
usage: Usage | None
stats: GenerationStats | None = None
class FinishedResponse(BaseRunnerResponse):

View File

@@ -393,10 +393,11 @@ def mlx_generate(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
total_prompt_tokens = len(all_prompt_tokens)
usage = Usage(
prompt_tokens=int(out.prompt_tokens),
prompt_tokens=total_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=int(out.prompt_tokens) + completion_tokens,
total_tokens=total_prompt_tokens + completion_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=prefix_hit_length
),

View File

@@ -364,6 +364,7 @@ def main(
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
usage=response.usage,
stats=response.stats,
),
)
)
@@ -764,7 +765,9 @@ def parse_tool_calls(
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
yield ToolCallResponse(
tool_calls=tools, usage=response.usage, stats=response.stats
)
except (
json.JSONDecodeError,
@@ -795,7 +798,8 @@ def parse_tool_calls(
text=tool_call_start + "".join(tool_call_text_parts),
token=0,
finish_reason=response.finish_reason,
usage=None,
usage=response.usage,
stats=response.stats,
)
continue
# fallthrough