Compare commits

..

1 Commits

Author SHA1 Message Date
Alex Cheema
df240f834d Fix GLM and Kimi tool calling crashes (#1255)
## Motivation

Fixes tool calling crashes with GLM-4.7-Flash and Kimi-K2 models.

Related: #1254

Two distinct issues were causing crashes:
1. **Tool parser crashes** - The upstream GLM47 and Kimi tool parsers
call `.group()` on regex matches without checking for `None`, causing
`AttributeError` when the model outputs malformed tool calls
2. **Chat template crashes** - GLM's chat template expects
`tool_calls[].function.arguments` to be a dict, but OpenAI format
provides it as a JSON string, causing `'str' object has no attribute
'items'`

## Changes

**`src/exo/worker/runner/runner.py`:**
- Add `patch_glm_tokenizer()` - fixed version of mlx_lm's glm47 parser
with None checks
- Fix `patch_kimi_tokenizer()` - add None checks before calling
`.group()` on regex matches
- Add `ValueError` and `AttributeError` to exception handling in
`parse_tool_calls()`

**`src/exo/worker/engines/mlx/utils_mlx.py`:**
- Add `_normalize_tool_calls()` - parses
`tool_calls[].function.arguments` from JSON string to dict for templates
that expect dicts (like GLM-4.7-Flash)

## Why It Works

1. **Parser fixes**: By checking if regex matches are `None` before
calling `.group()`, we can raise a proper `ValueError` instead of
crashing with `AttributeError`

2. **Template fix**: The GLM-4.7-Flash chat template iterates over
arguments with `.items()`:
   ```jinja2
   {% set _args = tc.arguments %}{% for k, v in _args.items() %}
   ```
OpenAI format has `arguments` as a JSON string.
`_normalize_tool_calls()` parses this to a dict before passing to the
template.

## Test Plan

### Manual Testing
- Hardware: Mac with GLM-4.7-Flash-4bit model
- Tested tool calling with GLM model - no longer crashes

### Automated Testing
- Existing tests pass (`uv run pytest`)
- Type checking passes (`uv run basedpyright`)
- Linting passes (`uv run ruff check`)

---------

Co-authored-by: Claude <noreply@anthropic.com>
2026-01-23 01:39:59 +00:00
7 changed files with 215 additions and 252 deletions

View File

@@ -110,36 +110,6 @@
setImageGenerationParams({ negativePrompt: value || null });
}
/**
 * Sync the numImages param from the numeric input.
 * A blank field resets to 1; non-numeric or sub-1 values are ignored.
 */
function handleNumImagesChange(event: Event) {
  const raw = (event.target as HTMLInputElement).value.trim();
  if (raw === "") {
    setImageGenerationParams({ numImages: 1 });
    return;
  }
  const parsed = parseInt(raw, 10);
  if (!Number.isNaN(parsed) && parsed >= 1) {
    setImageGenerationParams({ numImages: parsed });
  }
}
/** Toggle streaming of partial image previews on or off. */
function handleStreamChange(isOn: boolean) {
  setImageGenerationParams({ stream: isOn });
}
/**
 * Sync the partialImages param from the numeric input.
 * A blank field resets to 0; non-numeric or negative values are ignored.
 */
function handlePartialImagesChange(event: Event) {
  const raw = (event.target as HTMLInputElement).value.trim();
  if (raw === "") {
    setImageGenerationParams({ partialImages: 0 });
    return;
  }
  const parsed = parseInt(raw, 10);
  if (!Number.isNaN(parsed) && parsed >= 0) {
    setImageGenerationParams({ partialImages: parsed });
  }
}
/** Clear the custom inference-steps override so the model default applies. */
function clearSteps() {
setImageGenerationParams({ numInferenceSteps: null });
}
@@ -355,59 +325,6 @@
</div>
</div>
<!-- Number of Images (not in edit mode) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>IMAGES:</span
>
<input
type="number"
min="1"
value={params.numImages}
oninput={handleNumImagesChange}
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
/>
</div>
{/if}
<!-- Stream toggle -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>STREAM:</span
>
<button
type="button"
onclick={() => handleStreamChange(!params.stream)}
class="w-8 h-4 rounded-full transition-all duration-200 cursor-pointer relative {params.stream
? 'bg-exo-yellow'
: 'bg-exo-medium-gray/50 border border-exo-yellow/30'}"
title={params.stream ? "Streaming enabled" : "Streaming disabled"}
>
<div
class="absolute top-0.5 w-3 h-3 rounded-full transition-all duration-200 {params.stream
? 'right-0.5 bg-exo-black'
: 'left-0.5 bg-exo-light-gray'}"
></div>
</button>
</div>
<!-- Partial Images (only when streaming) -->
{#if params.stream}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>PARTIALS:</span
>
<input
type="number"
min="0"
value={params.partialImages}
oninput={handlePartialImagesChange}
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
/>
</div>
{/if}
<!-- Input Fidelity (edit mode only) -->
{#if isEditMode}
<div class="flex items-center gap-1.5">

View File

@@ -238,10 +238,6 @@ export interface ImageGenerationParams {
size: "512x512" | "768x768" | "1024x1024" | "1024x768" | "768x1024";
quality: "low" | "medium" | "high";
outputFormat: "png" | "jpeg";
numImages: number;
// Streaming params
stream: boolean;
partialImages: number;
// Advanced params
seed: number | null;
numInferenceSteps: number | null;
@@ -261,9 +257,6 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
size: "1024x1024",
quality: "medium",
outputFormat: "png",
numImages: 1,
stream: true,
partialImages: 3,
seed: null,
numInferenceSteps: null,
guidance: null,
@@ -1816,13 +1809,12 @@ class AppStore {
const requestBody: Record<string, unknown> = {
model,
prompt,
n: params.numImages,
quality: params.quality,
size: params.size,
output_format: params.outputFormat,
response_format: "b64_json",
stream: params.stream,
partial_images: params.partialImages,
stream: true,
partial_images: 3,
};
if (hasAdvancedParams) {
@@ -1886,74 +1878,31 @@ class AppStore {
if (imageData && idx !== -1) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
const numImages = params.numImages;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
this.messages[idx].content = progressText;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
this.messages[idx].attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments =
this.messages[idx].attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
this.messages[idx].attachments = [
...finals,
partialAttachment,
];
}
this.messages[idx].content =
`Generating... ${partialNum}/${totalPartials}`;
this.messages[idx].attachments = [
{
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
if (imageIndex === 0) {
// First final image - replace any partial preview
this.messages[idx].attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments =
this.messages[idx].attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
this.messages[idx].attachments = [
...previousFinals,
newAttachment,
];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
this.messages[idx].content =
`Generating image ${imageIndex + 2}/${numImages}...`;
} else {
this.messages[idx].content = "";
}
// Final image
this.messages[idx].content = "";
this.messages[idx].attachments = [
{
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
}
}
} catch {
@@ -2034,8 +1983,8 @@ class AppStore {
formData.append("size", params.size);
formData.append("output_format", params.outputFormat);
formData.append("response_format", "b64_json");
formData.append("stream", params.stream ? "1" : "0");
formData.append("partial_images", params.partialImages.toString());
formData.append("stream", "1"); // Use "1" instead of "true" for reliable FastAPI boolean parsing
formData.append("partial_images", "3");
formData.append("input_fidelity", params.inputFidelity);
// Advanced params

View File

@@ -835,7 +835,6 @@ class API:
# Yield partial image event (always use b64_json for partials)
event_data = {
"type": "partial",
"image_index": chunk.image_index,
"partial_index": partial_idx,
"total_partials": total_partials,
"format": str(chunk.format),

View File

@@ -30,7 +30,6 @@ class ImageGenerationResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
stats: ImageGenerationStats | None = None
image_index: int = 0
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
@@ -45,7 +44,6 @@ class PartialImageResponse(BaseRunnerResponse):
format: Literal["png", "jpeg", "webp"] = "png"
partial_index: int
total_partials: int
image_index: int = 0
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]

View File

@@ -75,20 +75,19 @@ def generate_image(
intermediate images, then ImageGenerationResponse for the final image.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
ImageGenerationResponse for final complete images
PartialImageResponse for intermediate images (if partial_images > 0)
ImageGenerationResponse for the final complete image
"""
width, height = parse_size(task.size)
quality: Literal["low", "medium", "high"] = task.quality or "medium"
advanced_params = task.advanced_params
if advanced_params is not None and advanced_params.seed is not None:
base_seed = advanced_params.seed
seed = advanced_params.seed
else:
base_seed = random.randint(0, 2**32 - 1)
seed = random.randint(0, 2**32 - 1)
is_bench = getattr(task, "bench", False)
num_images = task.n or 1
generation_start_time: float = 0.0
@@ -96,11 +95,7 @@ def generate_image(
mx.reset_peak_memory()
generation_start_time = time.perf_counter()
partial_images = (
task.partial_images
if task.partial_images is not None
else (3 if task.stream else 0)
)
partial_images = task.partial_images or (3 if task.stream else 0)
image_path: Path | None = None
@@ -110,81 +105,72 @@ def generate_image(
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
for image_num in range(num_images):
# Increment seed for each image to ensure unique results
current_seed = base_seed + image_num
# Iterate over generator results
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=current_seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
)
else:
image = result
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
image_index=image_num,
stats: ImageGenerationStats | None = None
if is_bench:
generation_end_time = time.perf_counter()
total_generation_time = generation_end_time - generation_start_time
num_inference_steps = model.get_steps_for_quality(quality)
seconds_per_step = (
total_generation_time / num_inference_steps
if num_inference_steps > 0
else 0.0
)
else:
image = result
# Only include stats on the final image
stats: ImageGenerationStats | None = None
if is_bench and image_num == num_images - 1:
generation_end_time = time.perf_counter()
total_generation_time = (
generation_end_time - generation_start_time
)
peak_memory_gb = mx.get_peak_memory() / (1024**3)
num_inference_steps = model.get_steps_for_quality(quality)
total_steps = num_inference_steps * num_images
seconds_per_step = (
total_generation_time / total_steps
if total_steps > 0
else 0.0
)
peak_memory_gb = mx.get_peak_memory() / (1024**3)
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=num_images,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
image_index=image_num,
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=task.n or 1,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
)

View File

@@ -365,12 +365,35 @@ def load_tokenizer_for_model_id(
return tokenizer
def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
"""
Normalize tool_calls in a message dict.
OpenAI format has tool_calls[].function.arguments as a JSON string,
but some chat templates (e.g., GLM) expect it as a dict.
"""
tool_calls = msg_dict.get("tool_calls")
if not tool_calls or not isinstance(tool_calls, list):
return
for tc in tool_calls: # pyright: ignore[reportUnknownVariableType]
if not isinstance(tc, dict):
continue
func = tc.get("function") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
if not isinstance(func, dict):
continue
args = func.get("arguments") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
if isinstance(args, str):
with contextlib.suppress(json.JSONDecodeError):
func["arguments"] = json.loads(args)
def apply_chat_template(
tokenizer: TokenizerWrapper,
chat_task_data: ChatCompletionTaskParams,
) -> str:
# Now we can properly access the messages
messages = chat_task_data.messages
tools = chat_task_data.tools
formatted_messages: list[dict[str, Any]] = []
for message in messages:
@@ -386,15 +409,19 @@ def apply_chat_template(
continue
# Null values are not valid when applying templates in tokenizer
formatted_messages.append(
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
)
dumped: dict[str, Any] = message.model_dump()
msg_dict: dict[str, Any] = {k: v for k, v in dumped.items() if v is not None} # pyright: ignore[reportAny]
# Parse tool_calls arguments from JSON string to dict for templates that expect dicts
_normalize_tool_calls(msg_dict)
formatted_messages.append(msg_dict)
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=chat_task_data.tools,
tools=tools,
)
logger.info(prompt)

View File

@@ -256,6 +256,10 @@ def main(
mlx_generator = filter_kimi_tokens(mlx_generator)
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
if "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
@@ -608,7 +612,7 @@ def _process_image_response(
command_id=command_id,
model_id=shard_metadata.model_card.model_id,
event_sender=event_sender,
image_index=response.image_index,
image_index=response.partial_index if is_partial else image_index,
is_partial=is_partial,
partial_index=response.partial_index if is_partial else None,
total_partials=response.total_partials if is_partial else None,
@@ -645,7 +649,14 @@ def parse_tool_calls(
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools)
except (json.JSONDecodeError, ValidationError) as e:
except (
json.JSONDecodeError,
ValidationError,
ValueError,
AttributeError,
) as e:
# ValueError: our parsers raise this for malformed tool calls
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
logger.opt(exception=e).warning("tool call parsing failed")
# assumption: talking about tool calls, not making a tool call
response.text = (
@@ -698,11 +709,17 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
return value
def parse_tool_call(text: str, tools: Any | None = None):
func_name = _func_name_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1)
# strip off the `functions.` prefix, if it exists.
func_name = func_name[func_name.find(".") + 1 :]
func_args = _func_arg_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError(f"Could not parse function args from tool call: {text!r}")
func_args = func_args_match.group(1)
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
@@ -713,6 +730,76 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
tokenizer._tool_parser = parse_tool_call
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
    """
    Install a fixed version of mlx_lm's glm47 tool parser on the tokenizer.

    The upstream parser calls .group() on regex matches without checking for
    None, which crashes with AttributeError on malformed tool calls; this
    version raises ValueError instead so callers can recover.
    """
    import ast
    import json
    from typing import Any

    import regex as re

    _func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
    _func_arg_regex = re.compile(
        r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
        re.DOTALL,
    )
    tool_call_start = "<tool_call>"
    tool_call_end = "</tool_call>"

    def _is_string_type(
        tool_name: str,
        arg_name: str,
        tools: list[Any] | None,
    ) -> bool:
        # Look up the declared JSON-schema type of arg_name on the named tool;
        # only args explicitly typed "string" skip deserialization.
        if tools is None:
            return False
        for tool in tools:  # pyright: ignore[reportAny]
            func = tool["function"]  # pyright: ignore[reportAny]
            if func["name"] != tool_name:
                continue
            params = func["parameters"]  # pyright: ignore[reportAny]
            if params is None:
                return False
            properties = params.get("properties", {})  # pyright: ignore[reportAny]
            declared = properties.get(arg_name, {}).get("type", None)  # pyright: ignore[reportAny]
            return declared == "string"  # pyright: ignore[reportAny]
        return False

    def _deserialize(value: str) -> Any:  # pyright: ignore[reportAny]
        # Best-effort decode: JSON first, then a Python literal, else the raw string.
        try:
            return json.loads(value)  # pyright: ignore[reportAny]
        except Exception:
            pass
        try:
            return ast.literal_eval(value)  # pyright: ignore[reportAny]
        except Exception:
            pass
        return value

    def parse_tool_call(text: str, tools: list[Any] | None = None):
        name_match = _func_name_regex.search(text)
        if name_match is None:
            raise ValueError(f"Could not parse function name from tool call: {text!r}")
        func_name = name_match.group(1)
        arg_dct: dict[str, Any] = {}
        for raw_key, raw_value in _func_arg_regex.findall(text):  # pyright: ignore[reportAny]
            arg_key = raw_key.strip()  # pyright: ignore[reportAny]
            arg_val = raw_value.strip()  # pyright: ignore[reportAny]
            if not _is_string_type(func_name, arg_key, tools):  # pyright: ignore[reportAny]
                arg_val = _deserialize(arg_val)  # pyright: ignore[reportAny]
            arg_dct[arg_key] = arg_val
        return dict(name=func_name, arguments=arg_dct)

    tokenizer._tool_call_start = tool_call_start
    tokenizer._tool_call_end = tool_call_end
    tokenizer._tool_parser = parse_tool_call
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
if (
((name := obj.get("name")) is not None)