mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-23 05:23:11 -05:00
Compare commits
1 Commits
ciaran/mul
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df240f834d |
@@ -110,36 +110,6 @@
|
||||
setImageGenerationParams({ negativePrompt: value || null });
|
||||
}
|
||||
|
||||
function handleNumImagesChange(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
const value = input.value.trim();
|
||||
if (value === "") {
|
||||
setImageGenerationParams({ numImages: 1 });
|
||||
} else {
|
||||
const num = parseInt(value, 10);
|
||||
if (!isNaN(num) && num >= 1) {
|
||||
setImageGenerationParams({ numImages: num });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleStreamChange(enabled: boolean) {
|
||||
setImageGenerationParams({ stream: enabled });
|
||||
}
|
||||
|
||||
function handlePartialImagesChange(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
const value = input.value.trim();
|
||||
if (value === "") {
|
||||
setImageGenerationParams({ partialImages: 0 });
|
||||
} else {
|
||||
const num = parseInt(value, 10);
|
||||
if (!isNaN(num) && num >= 0) {
|
||||
setImageGenerationParams({ partialImages: num });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function clearSteps() {
|
||||
setImageGenerationParams({ numInferenceSteps: null });
|
||||
}
|
||||
@@ -355,59 +325,6 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Number of Images (not in edit mode) -->
|
||||
{#if !isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>IMAGES:</span
|
||||
>
|
||||
<input
|
||||
type="number"
|
||||
min="1"
|
||||
value={params.numImages}
|
||||
oninput={handleNumImagesChange}
|
||||
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Stream toggle -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>STREAM:</span
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => handleStreamChange(!params.stream)}
|
||||
class="w-8 h-4 rounded-full transition-all duration-200 cursor-pointer relative {params.stream
|
||||
? 'bg-exo-yellow'
|
||||
: 'bg-exo-medium-gray/50 border border-exo-yellow/30'}"
|
||||
title={params.stream ? "Streaming enabled" : "Streaming disabled"}
|
||||
>
|
||||
<div
|
||||
class="absolute top-0.5 w-3 h-3 rounded-full transition-all duration-200 {params.stream
|
||||
? 'right-0.5 bg-exo-black'
|
||||
: 'left-0.5 bg-exo-light-gray'}"
|
||||
></div>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Partial Images (only when streaming) -->
|
||||
{#if params.stream}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>PARTIALS:</span
|
||||
>
|
||||
<input
|
||||
type="number"
|
||||
min="0"
|
||||
value={params.partialImages}
|
||||
oninput={handlePartialImagesChange}
|
||||
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Input Fidelity (edit mode only) -->
|
||||
{#if isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
|
||||
@@ -238,10 +238,6 @@ export interface ImageGenerationParams {
|
||||
size: "512x512" | "768x768" | "1024x1024" | "1024x768" | "768x1024";
|
||||
quality: "low" | "medium" | "high";
|
||||
outputFormat: "png" | "jpeg";
|
||||
numImages: number;
|
||||
// Streaming params
|
||||
stream: boolean;
|
||||
partialImages: number;
|
||||
// Advanced params
|
||||
seed: number | null;
|
||||
numInferenceSteps: number | null;
|
||||
@@ -261,9 +257,6 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
|
||||
size: "1024x1024",
|
||||
quality: "medium",
|
||||
outputFormat: "png",
|
||||
numImages: 1,
|
||||
stream: true,
|
||||
partialImages: 3,
|
||||
seed: null,
|
||||
numInferenceSteps: null,
|
||||
guidance: null,
|
||||
@@ -1816,13 +1809,12 @@ class AppStore {
|
||||
const requestBody: Record<string, unknown> = {
|
||||
model,
|
||||
prompt,
|
||||
n: params.numImages,
|
||||
quality: params.quality,
|
||||
size: params.size,
|
||||
output_format: params.outputFormat,
|
||||
response_format: "b64_json",
|
||||
stream: params.stream,
|
||||
partial_images: params.partialImages,
|
||||
stream: true,
|
||||
partial_images: 3,
|
||||
};
|
||||
|
||||
if (hasAdvancedParams) {
|
||||
@@ -1886,74 +1878,31 @@ class AppStore {
|
||||
if (imageData && idx !== -1) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const imageIndex = parsed.image_index ?? 0;
|
||||
const numImages = params.numImages;
|
||||
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
const progressText =
|
||||
numImages > 1
|
||||
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
|
||||
: `Generating... ${partialNum}/${totalPartials}`;
|
||||
this.messages[idx].content = progressText;
|
||||
|
||||
const partialAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
if (imageIndex === 0) {
|
||||
// First image - safe to replace attachments with partial preview
|
||||
this.messages[idx].attachments = [partialAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep existing finals, show partial at current position
|
||||
const existingAttachments =
|
||||
this.messages[idx].attachments || [];
|
||||
// Keep only the completed final images (up to current imageIndex)
|
||||
const finals = existingAttachments.slice(0, imageIndex);
|
||||
this.messages[idx].attachments = [
|
||||
...finals,
|
||||
partialAttachment,
|
||||
];
|
||||
}
|
||||
this.messages[idx].content =
|
||||
`Generating... ${partialNum}/${totalPartials}`;
|
||||
this.messages[idx].attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image - replace partial at this position
|
||||
const newAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image-${imageIndex + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
if (imageIndex === 0) {
|
||||
// First final image - replace any partial preview
|
||||
this.messages[idx].attachments = [newAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep previous finals, replace partial at current position
|
||||
const existingAttachments =
|
||||
this.messages[idx].attachments || [];
|
||||
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
|
||||
const previousFinals = existingAttachments.slice(
|
||||
0,
|
||||
imageIndex,
|
||||
);
|
||||
this.messages[idx].attachments = [
|
||||
...previousFinals,
|
||||
newAttachment,
|
||||
];
|
||||
}
|
||||
|
||||
// Update progress message for multiple images
|
||||
if (numImages > 1 && imageIndex < numImages - 1) {
|
||||
this.messages[idx].content =
|
||||
`Generating image ${imageIndex + 2}/${numImages}...`;
|
||||
} else {
|
||||
this.messages[idx].content = "";
|
||||
}
|
||||
// Final image
|
||||
this.messages[idx].content = "";
|
||||
this.messages[idx].attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
@@ -2034,8 +1983,8 @@ class AppStore {
|
||||
formData.append("size", params.size);
|
||||
formData.append("output_format", params.outputFormat);
|
||||
formData.append("response_format", "b64_json");
|
||||
formData.append("stream", params.stream ? "1" : "0");
|
||||
formData.append("partial_images", params.partialImages.toString());
|
||||
formData.append("stream", "1"); // Use "1" instead of "true" for reliable FastAPI boolean parsing
|
||||
formData.append("partial_images", "3");
|
||||
formData.append("input_fidelity", params.inputFidelity);
|
||||
|
||||
// Advanced params
|
||||
|
||||
@@ -835,7 +835,6 @@ class API:
|
||||
# Yield partial image event (always use b64_json for partials)
|
||||
event_data = {
|
||||
"type": "partial",
|
||||
"image_index": chunk.image_index,
|
||||
"partial_index": partial_idx,
|
||||
"total_partials": total_partials,
|
||||
"format": str(chunk.format),
|
||||
|
||||
@@ -30,7 +30,6 @@ class ImageGenerationResponse(BaseRunnerResponse):
|
||||
image_data: bytes
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
stats: ImageGenerationStats | None = None
|
||||
image_index: int = 0
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
|
||||
@@ -45,7 +44,6 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
partial_index: int
|
||||
total_partials: int
|
||||
image_index: int = 0
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
|
||||
|
||||
@@ -75,20 +75,19 @@ def generate_image(
|
||||
intermediate images, then ImageGenerationResponse for the final image.
|
||||
|
||||
Yields:
|
||||
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
|
||||
ImageGenerationResponse for final complete images
|
||||
PartialImageResponse for intermediate images (if partial_images > 0)
|
||||
ImageGenerationResponse for the final complete image
|
||||
"""
|
||||
width, height = parse_size(task.size)
|
||||
quality: Literal["low", "medium", "high"] = task.quality or "medium"
|
||||
|
||||
advanced_params = task.advanced_params
|
||||
if advanced_params is not None and advanced_params.seed is not None:
|
||||
base_seed = advanced_params.seed
|
||||
seed = advanced_params.seed
|
||||
else:
|
||||
base_seed = random.randint(0, 2**32 - 1)
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
|
||||
is_bench = getattr(task, "bench", False)
|
||||
num_images = task.n or 1
|
||||
|
||||
generation_start_time: float = 0.0
|
||||
|
||||
@@ -96,11 +95,7 @@ def generate_image(
|
||||
mx.reset_peak_memory()
|
||||
generation_start_time = time.perf_counter()
|
||||
|
||||
partial_images = (
|
||||
task.partial_images
|
||||
if task.partial_images is not None
|
||||
else (3 if task.stream else 0)
|
||||
)
|
||||
partial_images = task.partial_images or (3 if task.stream else 0)
|
||||
|
||||
image_path: Path | None = None
|
||||
|
||||
@@ -110,81 +105,72 @@ def generate_image(
|
||||
image_path = Path(tmpdir) / "input.png"
|
||||
image_path.write_bytes(base64.b64decode(task.image_data))
|
||||
|
||||
for image_num in range(num_images):
|
||||
# Increment seed for each image to ensure unique results
|
||||
current_seed = base_seed + image_num
|
||||
# Iterate over generator results
|
||||
for result in model.generate(
|
||||
prompt=task.prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
quality=quality,
|
||||
seed=seed,
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
advanced_params=advanced_params,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
image, partial_idx, total_partials = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
for result in model.generate(
|
||||
prompt=task.prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
quality=quality,
|
||||
seed=current_seed,
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
advanced_params=advanced_params,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
image, partial_idx, total_partials = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
yield PartialImageResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
partial_index=partial_idx,
|
||||
total_partials=total_partials,
|
||||
)
|
||||
else:
|
||||
image = result
|
||||
|
||||
yield PartialImageResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
partial_index=partial_idx,
|
||||
total_partials=total_partials,
|
||||
image_index=image_num,
|
||||
stats: ImageGenerationStats | None = None
|
||||
if is_bench:
|
||||
generation_end_time = time.perf_counter()
|
||||
total_generation_time = generation_end_time - generation_start_time
|
||||
|
||||
num_inference_steps = model.get_steps_for_quality(quality)
|
||||
|
||||
seconds_per_step = (
|
||||
total_generation_time / num_inference_steps
|
||||
if num_inference_steps > 0
|
||||
else 0.0
|
||||
)
|
||||
else:
|
||||
image = result
|
||||
|
||||
# Only include stats on the final image
|
||||
stats: ImageGenerationStats | None = None
|
||||
if is_bench and image_num == num_images - 1:
|
||||
generation_end_time = time.perf_counter()
|
||||
total_generation_time = (
|
||||
generation_end_time - generation_start_time
|
||||
)
|
||||
peak_memory_gb = mx.get_peak_memory() / (1024**3)
|
||||
|
||||
num_inference_steps = model.get_steps_for_quality(quality)
|
||||
total_steps = num_inference_steps * num_images
|
||||
|
||||
seconds_per_step = (
|
||||
total_generation_time / total_steps
|
||||
if total_steps > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
peak_memory_gb = mx.get_peak_memory() / (1024**3)
|
||||
|
||||
stats = ImageGenerationStats(
|
||||
seconds_per_step=seconds_per_step,
|
||||
total_generation_time=total_generation_time,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_images=num_images,
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
peak_memory_usage=Memory.from_gb(peak_memory_gb),
|
||||
)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield ImageGenerationResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
stats=stats,
|
||||
image_index=image_num,
|
||||
stats = ImageGenerationStats(
|
||||
seconds_per_step=seconds_per_step,
|
||||
total_generation_time=total_generation_time,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_images=task.n or 1,
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
peak_memory_usage=Memory.from_gb(peak_memory_gb),
|
||||
)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield ImageGenerationResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
@@ -365,12 +365,35 @@ def load_tokenizer_for_model_id(
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
|
||||
"""
|
||||
Normalize tool_calls in a message dict.
|
||||
|
||||
OpenAI format has tool_calls[].function.arguments as a JSON string,
|
||||
but some chat templates (e.g., GLM) expect it as a dict.
|
||||
"""
|
||||
tool_calls = msg_dict.get("tool_calls")
|
||||
if not tool_calls or not isinstance(tool_calls, list):
|
||||
return
|
||||
|
||||
for tc in tool_calls: # pyright: ignore[reportUnknownVariableType]
|
||||
if not isinstance(tc, dict):
|
||||
continue
|
||||
func = tc.get("function") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
if not isinstance(func, dict):
|
||||
continue
|
||||
args = func.get("arguments") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
if isinstance(args, str):
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
func["arguments"] = json.loads(args)
|
||||
|
||||
|
||||
def apply_chat_template(
|
||||
tokenizer: TokenizerWrapper,
|
||||
chat_task_data: ChatCompletionTaskParams,
|
||||
) -> str:
|
||||
# Now we can properly access the messages
|
||||
messages = chat_task_data.messages
|
||||
tools = chat_task_data.tools
|
||||
|
||||
formatted_messages: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
@@ -386,15 +409,19 @@ def apply_chat_template(
|
||||
continue
|
||||
|
||||
# Null values are not valid when applying templates in tokenizer
|
||||
formatted_messages.append(
|
||||
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
|
||||
)
|
||||
dumped: dict[str, Any] = message.model_dump()
|
||||
msg_dict: dict[str, Any] = {k: v for k, v in dumped.items() if v is not None} # pyright: ignore[reportAny]
|
||||
|
||||
# Parse tool_calls arguments from JSON string to dict for templates that expect dicts
|
||||
_normalize_tool_calls(msg_dict)
|
||||
|
||||
formatted_messages.append(msg_dict)
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=chat_task_data.tools,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
logger.info(prompt)
|
||||
|
||||
@@ -256,6 +256,10 @@ def main(
|
||||
mlx_generator = filter_kimi_tokens(mlx_generator)
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
# GLM models need patched parser (upstream has bug with None regex match)
|
||||
if "glm" in shard_metadata.model_card.model_id.lower():
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
|
||||
if tokenizer.has_tool_calling:
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
@@ -608,7 +612,7 @@ def _process_image_response(
|
||||
command_id=command_id,
|
||||
model_id=shard_metadata.model_card.model_id,
|
||||
event_sender=event_sender,
|
||||
image_index=response.image_index,
|
||||
image_index=response.partial_index if is_partial else image_index,
|
||||
is_partial=is_partial,
|
||||
partial_index=response.partial_index if is_partial else None,
|
||||
total_partials=response.total_partials if is_partial else None,
|
||||
@@ -645,7 +649,14 @@ def parse_tool_calls(
|
||||
tools = [_validate_single_tool(parsed)]
|
||||
yield ToolCallResponse(tool_calls=tools)
|
||||
|
||||
except (json.JSONDecodeError, ValidationError) as e:
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
ValidationError,
|
||||
ValueError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
# ValueError: our parsers raise this for malformed tool calls
|
||||
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
|
||||
logger.opt(exception=e).warning("tool call parsing failed")
|
||||
# assumption: talking about tool calls, not making a tool call
|
||||
response.text = (
|
||||
@@ -698,11 +709,17 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: Any | None = None):
|
||||
func_name = _func_name_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
func_name = func_name_match.group(1)
|
||||
# strip off the `functions.` prefix, if it exists.
|
||||
func_name = func_name[func_name.find(".") + 1 :]
|
||||
|
||||
func_args = _func_arg_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
|
||||
func_args_match = _func_arg_regex.search(text)
|
||||
if func_args_match is None:
|
||||
raise ValueError(f"Could not parse function args from tool call: {text!r}")
|
||||
func_args = func_args_match.group(1)
|
||||
# the args should be valid json - no need to check against our tools to deserialize
|
||||
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
|
||||
|
||||
@@ -713,6 +730,76 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
|
||||
"""
|
||||
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
|
||||
"""
|
||||
import ast
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
|
||||
_func_arg_regex = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
tool_call_start = "<tool_call>"
|
||||
tool_call_end = "</tool_call>"
|
||||
|
||||
def _is_string_type(
|
||||
tool_name: str,
|
||||
arg_name: str,
|
||||
tools: list[Any] | None,
|
||||
) -> bool:
|
||||
if tools is None:
|
||||
return False
|
||||
for tool in tools: # pyright: ignore[reportAny]
|
||||
func = tool["function"] # pyright: ignore[reportAny]
|
||||
if func["name"] == tool_name:
|
||||
params = func["parameters"] # pyright: ignore[reportAny]
|
||||
if params is None:
|
||||
return False
|
||||
props = params.get("properties", {}) # pyright: ignore[reportAny]
|
||||
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
|
||||
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
|
||||
return arg_type == "string" # pyright: ignore[reportAny]
|
||||
return False
|
||||
|
||||
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
|
||||
try:
|
||||
return json.loads(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return ast.literal_eval(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: list[Any] | None = None):
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
func_name = func_name_match.group(1)
|
||||
|
||||
pairs = _func_arg_regex.findall(text)
|
||||
arg_dct: dict[str, Any] = {}
|
||||
for key, value in pairs: # pyright: ignore[reportAny]
|
||||
arg_key = key.strip() # pyright: ignore[reportAny]
|
||||
arg_val = value.strip() # pyright: ignore[reportAny]
|
||||
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
|
||||
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
|
||||
arg_dct[arg_key] = arg_val
|
||||
return dict(name=func_name, arguments=arg_dct)
|
||||
|
||||
tokenizer._tool_call_start = tool_call_start
|
||||
tokenizer._tool_call_end = tool_call_end
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
|
||||
if (
|
||||
((name := obj.get("name")) is not None)
|
||||
|
||||
Reference in New Issue
Block a user