Compare commits

...

1 Commit

Author SHA1 Message Date
Alex Cheema
df240f834d Fix GLM and Kimi tool calling crashes (#1255)
## Motivation

Fixes tool calling crashes with GLM-4.7-Flash and Kimi-K2 models.

Related: #1254

Two distinct issues were causing crashes:
1. **Tool parser crashes** - The upstream GLM47 and Kimi tool parsers
call `.group()` on regex matches without checking for `None`, causing
`AttributeError` when the model outputs malformed tool calls
2. **Chat template crashes** - GLM's chat template expects
`tool_calls[].function.arguments` to be a dict, but the OpenAI format
provides it as a JSON string, causing a `'str object' has no attribute
'items'` error

## Changes

**`src/exo/worker/runner/runner.py`:**
- Add `patch_glm_tokenizer()` - fixed version of mlx_lm's glm47 parser
with None checks
- Fix `patch_kimi_tokenizer()` - add None checks before calling
`.group()` on regex matches
- Add `ValueError` and `AttributeError` to exception handling in
`parse_tool_calls()`

**`src/exo/worker/engines/mlx/utils_mlx.py`:**
- Add `_normalize_tool_calls()` - parses
`tool_calls[].function.arguments` from JSON string to dict for templates
that expect dicts (like GLM-4.7-Flash)

## Why It Works

1. **Parser fixes**: By checking if regex matches are `None` before
calling `.group()`, we can raise a proper `ValueError` instead of
crashing with `AttributeError`

2. **Template fix**: The GLM-4.7-Flash chat template iterates over
arguments with `.items()`:
   ```jinja2
   {% set _args = tc.arguments %}{% for k, v in _args.items() %}
   ```
OpenAI format has `arguments` as a JSON string.
`_normalize_tool_calls()` parses this to a dict before passing to the
template.

## Test Plan

### Manual Testing
- Hardware: Mac with GLM-4.7-Flash-4bit model
- Tested tool calling with GLM model - no longer crashes

### Automated Testing
- Existing tests pass (`uv run pytest`)
- Type checking passes (`uv run basedpyright`)
- Linting passes (`uv run ruff check`)

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-01-23 01:39:59 +00:00
2 changed files with 122 additions and 8 deletions

View File

@@ -365,12 +365,35 @@ def load_tokenizer_for_model_id(
return tokenizer
def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
"""
Normalize tool_calls in a message dict.
OpenAI format has tool_calls[].function.arguments as a JSON string,
but some chat templates (e.g., GLM) expect it as a dict.
"""
tool_calls = msg_dict.get("tool_calls")
if not tool_calls or not isinstance(tool_calls, list):
return
for tc in tool_calls: # pyright: ignore[reportUnknownVariableType]
if not isinstance(tc, dict):
continue
func = tc.get("function") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
if not isinstance(func, dict):
continue
args = func.get("arguments") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
if isinstance(args, str):
with contextlib.suppress(json.JSONDecodeError):
func["arguments"] = json.loads(args)
def apply_chat_template(
tokenizer: TokenizerWrapper,
chat_task_data: ChatCompletionTaskParams,
) -> str:
# Now we can properly access the messages
messages = chat_task_data.messages
tools = chat_task_data.tools
formatted_messages: list[dict[str, Any]] = []
for message in messages:
@@ -386,15 +409,19 @@ def apply_chat_template(
continue
# Null values are not valid when applying templates in tokenizer
formatted_messages.append(
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
)
dumped: dict[str, Any] = message.model_dump()
msg_dict: dict[str, Any] = {k: v for k, v in dumped.items() if v is not None} # pyright: ignore[reportAny]
# Parse tool_calls arguments from JSON string to dict for templates that expect dicts
_normalize_tool_calls(msg_dict)
formatted_messages.append(msg_dict)
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=chat_task_data.tools,
tools=tools,
)
logger.info(prompt)

View File

@@ -256,6 +256,10 @@ def main(
mlx_generator = filter_kimi_tokens(mlx_generator)
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
if "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
@@ -645,7 +649,14 @@ def parse_tool_calls(
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools)
except (json.JSONDecodeError, ValidationError) as e:
except (
json.JSONDecodeError,
ValidationError,
ValueError,
AttributeError,
) as e:
# ValueError: our parsers raise this for malformed tool calls
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
logger.opt(exception=e).warning("tool call parsing failed")
# assumption: talking about tool calls, not making a tool call
response.text = (
@@ -698,11 +709,17 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
return value
def parse_tool_call(text: str, tools: Any | None = None):
func_name = _func_name_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1)
# strip off the `functions.` prefix, if it exists.
func_name = func_name[func_name.find(".") + 1 :]
func_args = _func_arg_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError(f"Could not parse function args from tool call: {text!r}")
func_args = func_args_match.group(1)
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
@@ -713,6 +730,76 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
tokenizer._tool_parser = parse_tool_call
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
    """
    Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.

    The upstream parser calls ``.group()`` directly on ``re.search`` results,
    which raises ``AttributeError`` when the model emits a malformed tool
    call. This version raises ``ValueError`` with the offending text instead,
    so callers can report a parse failure rather than crash.

    Args:
        tokenizer: Tokenizer wrapper to patch. Its private tool-call
            attributes (``_tool_call_start``, ``_tool_call_end``,
            ``_tool_parser``) are overwritten in place.
    """
    import ast
    import json
    from typing import Any

    import regex as re

    # Everything up to the first <arg_key> tag is the function name.
    _func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
    # Each <arg_key>...</arg_key> / <arg_value>...</arg_value> pair is one argument.
    # NOTE(review): the `\\n` alternative matches a literal backslash-n text
    # sequence, not a newline character (which `\s` already covers); kept as
    # upstream has it — confirm against real model output.
    _func_arg_regex = re.compile(
        r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
        re.DOTALL,
    )
    tool_call_start = "<tool_call>"
    tool_call_end = "</tool_call>"

    def _is_string_type(
        tool_name: str,
        arg_name: str,
        tools: list[Any] | None,
    ) -> bool:
        """Return True if the tool schema declares this argument's type as "string"."""
        if tools is None:
            return False
        for tool in tools:  # pyright: ignore[reportAny]
            func = tool["function"]  # pyright: ignore[reportAny]
            if func["name"] == tool_name:
                params = func["parameters"]  # pyright: ignore[reportAny]
                if params is None:
                    return False
                props = params.get("properties", {})  # pyright: ignore[reportAny]
                arg_props = props.get(arg_name, {})  # pyright: ignore[reportAny]
                arg_type = arg_props.get("type", None)  # pyright: ignore[reportAny]
                return arg_type == "string"  # pyright: ignore[reportAny]
        return False

    def _deserialize(value: str) -> Any:  # pyright: ignore[reportAny]
        """Best-effort decode: JSON first, then a Python literal, else the raw string."""
        try:
            return json.loads(value)  # pyright: ignore[reportAny]
        except Exception:
            pass
        try:
            return ast.literal_eval(value)  # pyright: ignore[reportAny]
        except Exception:
            pass
        return value

    def parse_tool_call(text: str, tools: list[Any] | None = None):
        """Parse one GLM tool-call body into ``{"name": ..., "arguments": {...}}``.

        Raises:
            ValueError: if the function name cannot be extracted from *text*.
        """
        func_name_match = _func_name_regex.search(text)
        if func_name_match is None:
            raise ValueError(f"Could not parse function name from tool call: {text!r}")
        func_name = func_name_match.group(1)
        pairs = _func_arg_regex.findall(text)
        arg_dct: dict[str, Any] = {}
        for key, value in pairs:  # pyright: ignore[reportAny]
            arg_key = key.strip()  # pyright: ignore[reportAny]
            arg_val = value.strip()  # pyright: ignore[reportAny]
            # Schema-declared string args keep their raw text; everything
            # else is decoded (numbers, bools, lists, nested objects).
            if not _is_string_type(func_name, arg_key, tools):  # pyright: ignore[reportAny]
                arg_val = _deserialize(arg_val)  # pyright: ignore[reportAny]
            arg_dct[arg_key] = arg_val
        return dict(name=func_name, arguments=arg_dct)

    tokenizer._tool_call_start = tool_call_start
    tokenizer._tool_call_end = tool_call_end
    tokenizer._tool_parser = parse_tool_call
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
if (
((name := obj.get("name")) is not None)