mirror of https://github.com/exo-explore/exo.git
synced 2026-01-27 15:33:26 -05:00

Compare commits: 1 commit on rust-explo

| Author | SHA1 | Date |
|---|---|---|
| alexcheema | f3207a786d | |
```diff
@@ -240,6 +240,16 @@ def main(
                 prompt=prompt,
             )
 
+    # Kimi-K2 has tool call sections and tool calls that need to be
+    # handled BEFORE parse_gpt_oss, because parse_gpt_oss transforms
+    # the token stream into text deltas which prevents tool call detection
+    is_kimi = "kimi" in shard_metadata.model_card.model_id.lower()
+    if is_kimi:
+        patch_kimi_tokenizer(tokenizer)
+        mlx_generator = filter_kimi_tokens_and_parse_tool_calls(
+            mlx_generator, tokenizer
+        )
+
     # GPT-OSS specific parsing to match other model formats.
     if isinstance(model, GptOssModel):
         mlx_generator = parse_gpt_oss(mlx_generator)
```
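The comment in this hunk carries the key constraint: parse_gpt_oss re-emits the stream as text deltas, so a later stage can no longer detect a special token by comparing `resp.text` to a marker string. A minimal sketch of that failure mode, using stand-in types and a hypothetical marker (the real GenerationResponse and Kimi marker strings live elsewhere in the repo):

```python
# Sketch only: stand-in GenerationResponse; START is a hypothetical marker.
from dataclasses import dataclass
from typing import Generator, Iterable

@dataclass
class GenerationResponse:
    text: str

START = "<|tool_call_begin|>"

def token_stream(texts: Iterable[str]) -> Generator[GenerationResponse, None, None]:
    # One response per token: special tokens arrive whole.
    for t in texts:
        yield GenerationResponse(text=t)

def delta_stream(texts: Iterable[str]) -> Generator[GenerationResponse, None, None]:
    # Mimics parse_gpt_oss: the same content, re-emitted as small text deltas.
    for t in texts:
        for i in range(0, len(t), 4):
            yield GenerationResponse(text=t[i : i + 4])

raw = [START, '{"city": "SF"}']
print(any(r.text == START for r in token_stream(raw)))  # True: marker arrives whole
print(any(r.text == START for r in delta_stream(raw)))  # False: marker split across deltas
```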
```diff
@@ -251,16 +261,13 @@ def main(
             mlx_generator, tokenizer
         )
 
-    # Kimi-K2 has tool call sections - we don't care about them
-    if "kimi" in shard_metadata.model_card.model_id.lower():
-        mlx_generator = filter_kimi_tokens(mlx_generator)
-        patch_kimi_tokenizer(tokenizer)
-
     # GLM models need patched parser (upstream has bug with None regex match)
     if "glm" in shard_metadata.model_card.model_id.lower():
         patch_glm_tokenizer(tokenizer)
 
-    if tokenizer.has_tool_calling:
+    # Generic tool call parsing for non-Kimi models
+    # (Kimi tool calls are parsed earlier, before parse_gpt_oss)
+    if tokenizer.has_tool_calling and not is_kimi:
         assert tokenizer.tool_call_start
         assert tokenizer.tool_call_end
         assert tokenizer.tool_parser  # pyright: ignore[reportAny]
```
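Taken together, the two hunks pin down a fixed wrapping order for the generator stages. A self-contained sketch of that composition, with stub stages standing in for the real ones (names mirror the diff; the bodies just tag the order, and the flags condense the real model checks):

```python
# Stub stages standing in for the real generators in the diff.
from typing import Generator, Iterable

def source() -> Generator[str, None, None]:
    yield "token"

def filter_kimi_tokens_and_parse_tool_calls(g: Iterable[str]) -> Generator[str, None, None]:
    for x in g:
        yield f"kimi({x})"

def parse_gpt_oss(g: Iterable[str]) -> Generator[str, None, None]:
    for x in g:
        yield f"gpt_oss({x})"

def parse_tool_calls(g: Iterable[str]) -> Generator[str, None, None]:
    for x in g:
        yield f"generic({x})"

is_kimi, is_gpt_oss = True, True
gen: Iterable[str] = source()
if is_kimi:
    gen = filter_kimi_tokens_and_parse_tool_calls(gen)  # 1st: token-level Kimi handling
if is_gpt_oss:
    gen = parse_gpt_oss(gen)                            # 2nd: tokens become text deltas
if not is_kimi:
    gen = parse_tool_calls(gen)                         # 3rd: generic text-level parsing
print(list(gen))  # ['gpt_oss(kimi(token))'] (Kimi runs innermost; generic stage skipped)
```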
```diff
@@ -488,26 +495,88 @@ def get_gpt_oss_encoding():
     return encoding
 
 
-def filter_kimi_tokens(
+def filter_kimi_tokens_and_parse_tool_calls(
     responses: Generator[GenerationResponse],
-) -> Generator[GenerationResponse]:
+    tokenizer: TokenizerWrapper,
+) -> Generator[GenerationResponse | ToolCallResponse]:
+    """
+    Filter Kimi-specific tokens and parse tool calls at the raw token level.
+
+    Must be called BEFORE parse_gpt_oss because parse_gpt_oss transforms
+    the token stream into text deltas, which prevents tool call token detection.
+    """
+    in_tool_call = False
+    tool_call_text_parts: list[str] = []
+
     for resp in responses:
+        # Filter out section markers
         if (
             resp.text == "<|tool_calls_section_begin|>"
             or resp.text == "<|tool_calls_section_end|>"
         ):
             continue
+
+        # Handle tool calls if tokenizer supports them
+        if tokenizer.has_tool_calling:
+            # Check for tool call start
+            if resp.text == tokenizer.tool_call_start:
+                in_tool_call = True
+                continue
+
+            # Check for tool call end
+            if in_tool_call and resp.text == tokenizer.tool_call_end:
+                try:
+                    parsed = tokenizer.tool_parser(  # pyright: ignore[reportAny]
+                        "".join(tool_call_text_parts).strip()
+                    )
+                    logger.info(f"parsed kimi {tool_call_text_parts=} into {parsed=}")
+                    if isinstance(parsed, list):
+                        tools = [
+                            _validate_single_tool(tool)  # pyright: ignore[reportUnknownArgumentType]
+                            for tool in parsed  # pyright: ignore[reportUnknownVariableType]
+                        ]
+                    else:
+                        tools = [_validate_single_tool(parsed)]  # pyright: ignore[reportAny]
+                    yield ToolCallResponse(tool_calls=tools)
+                except (
+                    json.JSONDecodeError,
+                    ValidationError,
+                    ValueError,
+                    AttributeError,
+                ) as e:
+                    # Kimi-K2-Thinking emits internal tool calls during thinking
+                    # (format: UUID<|tool_call_argument_begin|>args) that don't match
+                    # the external tool call format (functions.name:0 <|...|> args).
+                    # Silently filter these out instead of showing raw markers.
+                    logger.opt(exception=e).debug(
+                        "kimi internal tool call filtered (not an external tool call)"
+                    )
+
+                in_tool_call = False
+                tool_call_text_parts = []
+                continue
+
+        # Accumulate tool call text
+        if in_tool_call:
+            tool_call_text_parts.append(resp.text)
+            continue
+
         yield resp
 
 
 def parse_gpt_oss(
-    responses: Generator[GenerationResponse],
-) -> Generator[GenerationResponse]:
+    responses: Generator[GenerationResponse | ToolCallResponse],
+) -> Generator[GenerationResponse | ToolCallResponse]:
     encoding = get_gpt_oss_encoding()
     stream = StreamableParser(encoding, role=Role.ASSISTANT)
     thinking = False
 
     for response in responses:
+        # Pass through ToolCallResponses unchanged
+        if isinstance(response, ToolCallResponse):
+            yield response
+            continue
+
         stream.process(response.token)
 
         delta = stream.last_content_delta
```
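The except branch relies on tokenizer.tool_parser rejecting Kimi-K2-Thinking's internal, UUID-prefixed calls. A toy parser sketches why, with both formats inferred from the comment above (the real `tokenizer.tool_parser` may differ):

```python
import json
import re

# Formats inferred from the diff comment; the real parser may differ.
EXTERNAL_RE = re.compile(
    r"functions\.(?P<name>[\w.-]+):(?P<idx>\d+)\s*"
    r"<\|tool_call_argument_begin\|>\s*(?P<args>.*)",
    re.DOTALL,
)

def toy_tool_parser(text: str) -> dict:
    m = EXTERNAL_RE.match(text)
    if m is None:
        raise ValueError(f"not an external tool call: {text[:40]!r}")
    return {"name": m.group("name"), "arguments": json.loads(m.group("args"))}

# External format parses into a structured call:
print(toy_tool_parser('functions.get_weather:0 <|tool_call_argument_begin|> {"city": "SF"}'))

# Internal thinking-phase format (UUID-prefixed) raises, which the except
# branch above downgrades to a debug log instead of leaking raw markers:
try:
    toy_tool_parser('019b2c3d-aaaa-bbbb-cccc-123456789abc<|tool_call_argument_begin|>{"q": 1}')
except ValueError as e:
    print("filtered:", e)
```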
```diff
@@ -532,9 +601,9 @@ def parse_gpt_oss(
 
 
 def parse_thinking_models(
-    responses: Generator[GenerationResponse],
+    responses: Generator[GenerationResponse | ToolCallResponse],
     tokenizer: TokenizerWrapper,
-) -> Generator[GenerationResponse]:
+) -> Generator[GenerationResponse | ToolCallResponse]:
     """
     For models that inject thinking tags in the prompt (like GLM-4.7),
     prepend the thinking tag to the output stream so the frontend
@@ -542,6 +611,11 @@ def parse_thinking_models(
     """
     first = True
     for response in responses:
+        # Pass through ToolCallResponses unchanged
+        if isinstance(response, ToolCallResponse):
+            yield response
+            continue
+
         if first:
             first = False
             yield response.model_copy(
```
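Since filter_kimi_tokens_and_parse_tool_calls can emit ToolCallResponse items near the head of the pipeline, every downstream stage needs the same three-line guard; it appears here, in parse_gpt_oss above, and in parse_tool_calls below. A self-contained illustration of the pattern with stand-in types:

```python
from dataclasses import dataclass
from typing import Generator, Iterable, Union

# Stand-ins for the real exo types.
@dataclass
class GenerationResponse:
    text: str

@dataclass
class ToolCallResponse:
    tool_calls: list

Resp = Union[GenerationResponse, ToolCallResponse]

def upcase_stage(responses: Iterable[Resp]) -> Generator[Resp, None, None]:
    """Toy stage: rewrites text deltas but forwards tool calls in order."""
    for r in responses:
        if isinstance(r, ToolCallResponse):
            yield r  # pass through unchanged
            continue
        yield GenerationResponse(text=r.text.upper())

stream = iter([GenerationResponse("hi"), ToolCallResponse([{"name": "f"}]), GenerationResponse("bye")])
print(list(upcase_stage(stream)))
# [GenerationResponse(text='HI'), ToolCallResponse(...), GenerationResponse(text='BYE')]
```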
```diff
@@ -622,7 +696,7 @@ def _process_image_response(
 
 
 def parse_tool_calls(
-    responses: Generator[GenerationResponse],
+    responses: Generator[GenerationResponse | ToolCallResponse],
     tool_call_start: str,
     tool_call_end: str,
     tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
```
```diff
@@ -630,6 +704,11 @@ def parse_tool_calls(
     in_tool_call = False
     tool_call_text_parts: list[str] = []
     for response in responses:
+        # Pass through ToolCallResponses unchanged
+        if isinstance(response, ToolCallResponse):
+            yield response
+            continue
+
         # assumption: the tool call start is one token
         if response.text == tool_call_start:
             in_tool_call = True
```
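The `# assumption: the tool call start is one token` comment is load-bearing: both this parser and the Kimi filter compare `response.text` against the full marker, which only matches if the tokenizer emits the marker as a single token. One way to sanity-check that assumption for a given model (transformers-style API; the model id and marker are placeholders):

```python
# Sanity check for the single-token assumption; model id and marker are placeholders.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("org/some-model")
marker = "<|tool_call_begin|>"
ids = tok.encode(marker, add_special_tokens=False)
assert len(ids) == 1, (
    f"{marker!r} encodes to {len(ids)} tokens; "
    "a per-token equality check would never see the whole marker"
)
```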