Compare commits

...

1 Commit

Author: Alex Cheema
SHA1: f3207a786d
Message: Fix Kimi-K2-Thinking tool calls appearing in thinking output
         Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Date: 2026-01-23 02:15:23 +00:00


@@ -240,6 +240,16 @@ def main(
             prompt=prompt,
         )
+        # Kimi-K2 has tool call sections and tool calls that need to be
+        # handled BEFORE parse_gpt_oss, because parse_gpt_oss transforms
+        # the token stream into text deltas which prevents tool call detection
+        is_kimi = "kimi" in shard_metadata.model_card.model_id.lower()
+        if is_kimi:
+            patch_kimi_tokenizer(tokenizer)
+            mlx_generator = filter_kimi_tokens_and_parse_tool_calls(
+                mlx_generator, tokenizer
+            )
         # GPT-OSS specific parsing to match other model formats.
         if isinstance(model, GptOssModel):
             mlx_generator = parse_gpt_oss(mlx_generator)
@@ -251,16 +261,13 @@ def main(
             mlx_generator, tokenizer
         )
-        # Kimi-K2 has tool call sections - we don't care about them
-        if "kimi" in shard_metadata.model_card.model_id.lower():
-            mlx_generator = filter_kimi_tokens(mlx_generator)
-            patch_kimi_tokenizer(tokenizer)
         # GLM models need patched parser (upstream has bug with None regex match)
         if "glm" in shard_metadata.model_card.model_id.lower():
            patch_glm_tokenizer(tokenizer)
-        if tokenizer.has_tool_calling:
+        # Generic tool call parsing for non-Kimi models
+        # (Kimi tool calls are parsed earlier, before parse_gpt_oss)
+        if tokenizer.has_tool_calling and not is_kimi:
             assert tokenizer.tool_call_start
             assert tokenizer.tool_call_end
             assert tokenizer.tool_parser  # pyright: ignore[reportAny]
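
Taken together, these two hunks encode an ordering constraint: Kimi tool calls must be detected while the stream still consists of raw tokens. A minimal self-contained sketch of why, where Resp, TOOL_CALL_START, and rechunk_to_text_deltas are hypothetical stand-ins for GenerationResponse, the tokenizer's tool-call marker, and what parse_gpt_oss does to the stream:

from dataclasses import dataclass
from typing import Iterator


@dataclass
class Resp:
    text: str


TOOL_CALL_START = "<|tool_call_begin|>"  # assumed marker, one token wide


def detect_start(responses: Iterator[Resp]) -> bool:
    # Token-level detection, as in filter_kimi_tokens_and_parse_tool_calls:
    # the marker must arrive as one intact resp.text to be seen at all.
    return any(r.text == TOOL_CALL_START for r in responses)


def rechunk_to_text_deltas(responses: Iterator[Resp]) -> Iterator[Resp]:
    # Models the effect of parse_gpt_oss: token boundaries are lost and
    # the text re-arrives in arbitrary slices.
    buffer = "".join(r.text for r in responses)
    for i in range(0, len(buffer), 7):
        yield Resp(buffer[i : i + 7])


tokens = [Resp("hello "), Resp(TOOL_CALL_START), Resp("{}")]
assert detect_start(iter(tokens))  # marker intact: detected
assert not detect_start(rechunk_to_text_deltas(iter(tokens)))  # marker split: missed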
@@ -488,26 +495,88 @@ def get_gpt_oss_encoding():
     return encoding
 
 
-def filter_kimi_tokens(
+def filter_kimi_tokens_and_parse_tool_calls(
     responses: Generator[GenerationResponse],
-) -> Generator[GenerationResponse]:
+    tokenizer: TokenizerWrapper,
+) -> Generator[GenerationResponse | ToolCallResponse]:
+    """
+    Filter Kimi-specific tokens and parse tool calls at the raw token level.
+
+    Must be called BEFORE parse_gpt_oss because parse_gpt_oss transforms
+    the token stream into text deltas, which prevents tool call token detection.
+    """
+    in_tool_call = False
+    tool_call_text_parts: list[str] = []
     for resp in responses:
+        # Filter out section markers
         if (
             resp.text == "<|tool_calls_section_begin|>"
             or resp.text == "<|tool_calls_section_end|>"
         ):
             continue
+        # Handle tool calls if tokenizer supports them
+        if tokenizer.has_tool_calling:
+            # Check for tool call start
+            if resp.text == tokenizer.tool_call_start:
+                in_tool_call = True
+                continue
+            # Check for tool call end
+            if in_tool_call and resp.text == tokenizer.tool_call_end:
+                try:
+                    parsed = tokenizer.tool_parser(  # pyright: ignore[reportAny]
+                        "".join(tool_call_text_parts).strip()
+                    )
+                    logger.info(f"parsed kimi {tool_call_text_parts=} into {parsed=}")
+                    if isinstance(parsed, list):
+                        tools = [
+                            _validate_single_tool(tool)  # pyright: ignore[reportUnknownArgumentType]
+                            for tool in parsed  # pyright: ignore[reportUnknownVariableType]
+                        ]
+                    else:
+                        tools = [_validate_single_tool(parsed)]  # pyright: ignore[reportAny]
+                    yield ToolCallResponse(tool_calls=tools)
+                except (
+                    json.JSONDecodeError,
+                    ValidationError,
+                    ValueError,
+                    AttributeError,
+                ) as e:
+                    # Kimi-K2-Thinking emits internal tool calls during thinking
+                    # (format: UUID<|tool_call_argument_begin|>args) that don't match
+                    # the external tool call format (functions.name:0 <|...|> args).
+                    # Silently filter these out instead of showing raw markers.
+                    logger.opt(exception=e).debug(
+                        "kimi internal tool call filtered (not an external tool call)"
+                    )
+                in_tool_call = False
+                tool_call_text_parts = []
+                continue
+            # Accumulate tool call text
+            if in_tool_call:
+                tool_call_text_parts.append(resp.text)
+                continue
         yield resp
 
 
 def parse_gpt_oss(
-    responses: Generator[GenerationResponse],
-) -> Generator[GenerationResponse]:
+    responses: Generator[GenerationResponse | ToolCallResponse],
+) -> Generator[GenerationResponse | ToolCallResponse]:
     encoding = get_gpt_oss_encoding()
     stream = StreamableParser(encoding, role=Role.ASSISTANT)
     thinking = False
     for response in responses:
+        # Pass through ToolCallResponses unchanged
+        if isinstance(response, ToolCallResponse):
+            yield response
+            continue
         stream.process(response.token)
         delta = stream.last_content_delta
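
The core of the new function is an accumulate-between-markers state machine. A hypothetical miniature of it, where Resp, ToolCall, and json.loads stand in for GenerationResponse, ToolCallResponse, and tokenizer.tool_parser; it shows both the happy path and the silent drop of payloads that fail to parse, which is how Kimi-K2-Thinking's internal tool calls disappear from the output:

import json
from dataclasses import dataclass
from typing import Iterator


@dataclass
class Resp:
    text: str


@dataclass
class ToolCall:
    payload: dict


START, END = "<|tool_call_begin|>", "<|tool_call_end|>"  # assumed markers
SECTION = ("<|tool_calls_section_begin|>", "<|tool_calls_section_end|>")


def filter_and_parse(responses: Iterator[Resp]) -> Iterator[Resp | ToolCall]:
    in_call, parts = False, []
    for r in responses:
        if r.text in SECTION:
            continue  # section markers never reach the client
        if r.text == START:
            in_call = True
            continue
        if in_call and r.text == END:
            try:
                yield ToolCall(json.loads("".join(parts).strip()))
            except json.JSONDecodeError:
                pass  # internal/thinking tool call: drop it, don't leak markers
            in_call, parts = False, []
            continue
        if in_call:
            parts.append(r.text)
            continue
        yield r


stream = [Resp("hi "), Resp(START), Resp('{"name": "f"}'), Resp(END), Resp("bye")]
print(list(filter_and_parse(iter(stream))))
# [Resp(text='hi '), ToolCall(payload={'name': 'f'}), Resp(text='bye')]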
@@ -532,9 +601,9 @@ def parse_gpt_oss(
 
 
 def parse_thinking_models(
-    responses: Generator[GenerationResponse],
+    responses: Generator[GenerationResponse | ToolCallResponse],
     tokenizer: TokenizerWrapper,
-) -> Generator[GenerationResponse]:
+) -> Generator[GenerationResponse | ToolCallResponse]:
     """
     For models that inject thinking tags in the prompt (like GLM-4.7),
     prepend the thinking tag to the output stream so the frontend
@@ -542,6 +611,11 @@ def parse_thinking_models(
     """
     first = True
     for response in responses:
+        # Pass through ToolCallResponses unchanged
+        if isinstance(response, ToolCallResponse):
+            yield response
+            continue
         if first:
             first = False
             yield response.model_copy(
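
Every stage downstream of the Kimi filter now carries the same four-line pass-through, so a parsed tool call survives parse_gpt_oss, parse_thinking_models, and parse_tool_calls untouched. A miniature of the parse_thinking_models variant, reusing the Resp/ToolCall stubs and Iterator import from the sketch above (the real code copies the response with model_copy rather than building a new one):

def prepend_thinking(
    responses: Iterator[Resp | ToolCall], tag: str = "<think>"
) -> Iterator[Resp | ToolCall]:
    first = True
    for response in responses:
        # The pass-through sits before the "first" check, so a tool call
        # neither gets the tag nor consumes the first-text slot.
        if isinstance(response, ToolCall):
            yield response
            continue
        if first:
            first = False
            yield Resp(tag + response.text)
            continue
        yield response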
@@ -622,7 +696,7 @@ def _process_image_response(
 
 
 def parse_tool_calls(
-    responses: Generator[GenerationResponse],
+    responses: Generator[GenerationResponse | ToolCallResponse],
     tool_call_start: str,
     tool_call_end: str,
     tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
@@ -630,6 +704,11 @@ def parse_tool_calls(
     in_tool_call = False
     tool_call_text_parts: list[str] = []
     for response in responses:
+        # Pass through ToolCallResponses unchanged
+        if isinstance(response, ToolCallResponse):
+            yield response
+            continue
         # assumption: the tool call start is one token
         if response.text == tool_call_start:
             in_tool_call = True
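
The distinction the except block above relies on comes down to the payload prefix: per the diff's comments, external tool calls lead with functions.<name>:<index>, while thinking-phase internal ones lead with a bare UUID. A hedged sketch of a parser with that behaviour, where the regex is a stand-in for tokenizer.tool_parser and the exact grammar is an assumption:

import re

# Assumed shape of an external Kimi payload, per the comment in the diff:
# "functions.name:0 <|tool_call_argument_begin|> args".
EXTERNAL = re.compile(
    r"^functions\.(?P<name>[\w.]+):(?P<idx>\d+)"
    r"\s*<\|tool_call_argument_begin\|>\s*(?P<args>.*)$",
    re.DOTALL,
)


def parse_external(payload: str) -> dict[str, str]:
    m = EXTERNAL.match(payload.strip())
    if m is None:
        # Internal calls ("UUID<|tool_call_argument_begin|>args") land here,
        # which the streaming stage turns into a silent drop.
        raise ValueError("not an external tool call")
    return {"name": m["name"], "arguments": m["args"]}


print(parse_external('functions.get_weather:0 <|tool_call_argument_begin|> {"city": "SF"}'))
# {'name': 'get_weather', 'arguments': '{"city": "SF"}'}
try:
    parse_external('0199ab12-internal<|tool_call_argument_begin|>{}')
except ValueError:
    pass  # filtered, exactly like the except branch in the diff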