mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-23 05:23:11 -05:00
Compare commits
3 Commits
ciaran/mul
...
v1.0.64
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
55b67e2be2 | ||
|
|
df240f834d | ||
|
|
30cfad9b68 |
@@ -110,36 +110,6 @@
|
||||
setImageGenerationParams({ negativePrompt: value || null });
|
||||
}
|
||||
|
||||
function handleNumImagesChange(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
const value = input.value.trim();
|
||||
if (value === "") {
|
||||
setImageGenerationParams({ numImages: 1 });
|
||||
} else {
|
||||
const num = parseInt(value, 10);
|
||||
if (!isNaN(num) && num >= 1) {
|
||||
setImageGenerationParams({ numImages: num });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleStreamChange(enabled: boolean) {
|
||||
setImageGenerationParams({ stream: enabled });
|
||||
}
|
||||
|
||||
function handlePartialImagesChange(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
const value = input.value.trim();
|
||||
if (value === "") {
|
||||
setImageGenerationParams({ partialImages: 0 });
|
||||
} else {
|
||||
const num = parseInt(value, 10);
|
||||
if (!isNaN(num) && num >= 0) {
|
||||
setImageGenerationParams({ partialImages: num });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function clearSteps() {
|
||||
setImageGenerationParams({ numInferenceSteps: null });
|
||||
}
|
||||
@@ -355,59 +325,6 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Number of Images (not in edit mode) -->
|
||||
{#if !isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>IMAGES:</span
|
||||
>
|
||||
<input
|
||||
type="number"
|
||||
min="1"
|
||||
value={params.numImages}
|
||||
oninput={handleNumImagesChange}
|
||||
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Stream toggle -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>STREAM:</span
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => handleStreamChange(!params.stream)}
|
||||
class="w-8 h-4 rounded-full transition-all duration-200 cursor-pointer relative {params.stream
|
||||
? 'bg-exo-yellow'
|
||||
: 'bg-exo-medium-gray/50 border border-exo-yellow/30'}"
|
||||
title={params.stream ? "Streaming enabled" : "Streaming disabled"}
|
||||
>
|
||||
<div
|
||||
class="absolute top-0.5 w-3 h-3 rounded-full transition-all duration-200 {params.stream
|
||||
? 'right-0.5 bg-exo-black'
|
||||
: 'left-0.5 bg-exo-light-gray'}"
|
||||
></div>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Partial Images (only when streaming) -->
|
||||
{#if params.stream}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>PARTIALS:</span
|
||||
>
|
||||
<input
|
||||
type="number"
|
||||
min="0"
|
||||
value={params.partialImages}
|
||||
oninput={handlePartialImagesChange}
|
||||
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Input Fidelity (edit mode only) -->
|
||||
{#if isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
|
||||
@@ -238,10 +238,6 @@ export interface ImageGenerationParams {
|
||||
size: "512x512" | "768x768" | "1024x1024" | "1024x768" | "768x1024";
|
||||
quality: "low" | "medium" | "high";
|
||||
outputFormat: "png" | "jpeg";
|
||||
numImages: number;
|
||||
// Streaming params
|
||||
stream: boolean;
|
||||
partialImages: number;
|
||||
// Advanced params
|
||||
seed: number | null;
|
||||
numInferenceSteps: number | null;
|
||||
@@ -261,9 +257,6 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
|
||||
size: "1024x1024",
|
||||
quality: "medium",
|
||||
outputFormat: "png",
|
||||
numImages: 1,
|
||||
stream: true,
|
||||
partialImages: 3,
|
||||
seed: null,
|
||||
numInferenceSteps: null,
|
||||
guidance: null,
|
||||
@@ -1816,13 +1809,12 @@ class AppStore {
|
||||
const requestBody: Record<string, unknown> = {
|
||||
model,
|
||||
prompt,
|
||||
n: params.numImages,
|
||||
quality: params.quality,
|
||||
size: params.size,
|
||||
output_format: params.outputFormat,
|
||||
response_format: "b64_json",
|
||||
stream: params.stream,
|
||||
partial_images: params.partialImages,
|
||||
stream: true,
|
||||
partial_images: 3,
|
||||
};
|
||||
|
||||
if (hasAdvancedParams) {
|
||||
@@ -1886,74 +1878,31 @@ class AppStore {
|
||||
if (imageData && idx !== -1) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const imageIndex = parsed.image_index ?? 0;
|
||||
const numImages = params.numImages;
|
||||
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
const progressText =
|
||||
numImages > 1
|
||||
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
|
||||
: `Generating... ${partialNum}/${totalPartials}`;
|
||||
this.messages[idx].content = progressText;
|
||||
|
||||
const partialAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
if (imageIndex === 0) {
|
||||
// First image - safe to replace attachments with partial preview
|
||||
this.messages[idx].attachments = [partialAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep existing finals, show partial at current position
|
||||
const existingAttachments =
|
||||
this.messages[idx].attachments || [];
|
||||
// Keep only the completed final images (up to current imageIndex)
|
||||
const finals = existingAttachments.slice(0, imageIndex);
|
||||
this.messages[idx].attachments = [
|
||||
...finals,
|
||||
partialAttachment,
|
||||
];
|
||||
}
|
||||
this.messages[idx].content =
|
||||
`Generating... ${partialNum}/${totalPartials}`;
|
||||
this.messages[idx].attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image - replace partial at this position
|
||||
const newAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image-${imageIndex + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
if (imageIndex === 0) {
|
||||
// First final image - replace any partial preview
|
||||
this.messages[idx].attachments = [newAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep previous finals, replace partial at current position
|
||||
const existingAttachments =
|
||||
this.messages[idx].attachments || [];
|
||||
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
|
||||
const previousFinals = existingAttachments.slice(
|
||||
0,
|
||||
imageIndex,
|
||||
);
|
||||
this.messages[idx].attachments = [
|
||||
...previousFinals,
|
||||
newAttachment,
|
||||
];
|
||||
}
|
||||
|
||||
// Update progress message for multiple images
|
||||
if (numImages > 1 && imageIndex < numImages - 1) {
|
||||
this.messages[idx].content =
|
||||
`Generating image ${imageIndex + 2}/${numImages}...`;
|
||||
} else {
|
||||
this.messages[idx].content = "";
|
||||
}
|
||||
// Final image
|
||||
this.messages[idx].content = "";
|
||||
this.messages[idx].attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
@@ -2034,8 +1983,8 @@ class AppStore {
|
||||
formData.append("size", params.size);
|
||||
formData.append("output_format", params.outputFormat);
|
||||
formData.append("response_format", "b64_json");
|
||||
formData.append("stream", params.stream ? "1" : "0");
|
||||
formData.append("partial_images", params.partialImages.toString());
|
||||
formData.append("stream", "1"); // Use "1" instead of "true" for reliable FastAPI boolean parsing
|
||||
formData.append("partial_images", "3");
|
||||
formData.append("input_fidelity", params.inputFidelity);
|
||||
|
||||
// Advanced params
|
||||
|
||||
@@ -17,7 +17,7 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.3; sys_platform == 'darwin'",
|
||||
"mlx @ git+https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
|
||||
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
|
||||
@@ -835,7 +835,6 @@ class API:
|
||||
# Yield partial image event (always use b64_json for partials)
|
||||
event_data = {
|
||||
"type": "partial",
|
||||
"image_index": chunk.image_index,
|
||||
"partial_index": partial_idx,
|
||||
"total_partials": total_partials,
|
||||
"format": str(chunk.format),
|
||||
|
||||
@@ -30,7 +30,6 @@ class ImageGenerationResponse(BaseRunnerResponse):
|
||||
image_data: bytes
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
stats: ImageGenerationStats | None = None
|
||||
image_index: int = 0
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
|
||||
@@ -45,7 +44,6 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
partial_index: int
|
||||
total_partials: int
|
||||
image_index: int = 0
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
|
||||
|
||||
@@ -75,20 +75,19 @@ def generate_image(
|
||||
intermediate images, then ImageGenerationResponse for the final image.
|
||||
|
||||
Yields:
|
||||
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
|
||||
ImageGenerationResponse for final complete images
|
||||
PartialImageResponse for intermediate images (if partial_images > 0)
|
||||
ImageGenerationResponse for the final complete image
|
||||
"""
|
||||
width, height = parse_size(task.size)
|
||||
quality: Literal["low", "medium", "high"] = task.quality or "medium"
|
||||
|
||||
advanced_params = task.advanced_params
|
||||
if advanced_params is not None and advanced_params.seed is not None:
|
||||
base_seed = advanced_params.seed
|
||||
seed = advanced_params.seed
|
||||
else:
|
||||
base_seed = random.randint(0, 2**32 - 1)
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
|
||||
is_bench = getattr(task, "bench", False)
|
||||
num_images = task.n or 1
|
||||
|
||||
generation_start_time: float = 0.0
|
||||
|
||||
@@ -96,11 +95,7 @@ def generate_image(
|
||||
mx.reset_peak_memory()
|
||||
generation_start_time = time.perf_counter()
|
||||
|
||||
partial_images = (
|
||||
task.partial_images
|
||||
if task.partial_images is not None
|
||||
else (3 if task.stream else 0)
|
||||
)
|
||||
partial_images = task.partial_images or (3 if task.stream else 0)
|
||||
|
||||
image_path: Path | None = None
|
||||
|
||||
@@ -110,81 +105,72 @@ def generate_image(
|
||||
image_path = Path(tmpdir) / "input.png"
|
||||
image_path.write_bytes(base64.b64decode(task.image_data))
|
||||
|
||||
for image_num in range(num_images):
|
||||
# Increment seed for each image to ensure unique results
|
||||
current_seed = base_seed + image_num
|
||||
# Iterate over generator results
|
||||
for result in model.generate(
|
||||
prompt=task.prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
quality=quality,
|
||||
seed=seed,
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
advanced_params=advanced_params,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
image, partial_idx, total_partials = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
for result in model.generate(
|
||||
prompt=task.prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
quality=quality,
|
||||
seed=current_seed,
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
advanced_params=advanced_params,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
image, partial_idx, total_partials = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
yield PartialImageResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
partial_index=partial_idx,
|
||||
total_partials=total_partials,
|
||||
)
|
||||
else:
|
||||
image = result
|
||||
|
||||
yield PartialImageResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
partial_index=partial_idx,
|
||||
total_partials=total_partials,
|
||||
image_index=image_num,
|
||||
stats: ImageGenerationStats | None = None
|
||||
if is_bench:
|
||||
generation_end_time = time.perf_counter()
|
||||
total_generation_time = generation_end_time - generation_start_time
|
||||
|
||||
num_inference_steps = model.get_steps_for_quality(quality)
|
||||
|
||||
seconds_per_step = (
|
||||
total_generation_time / num_inference_steps
|
||||
if num_inference_steps > 0
|
||||
else 0.0
|
||||
)
|
||||
else:
|
||||
image = result
|
||||
|
||||
# Only include stats on the final image
|
||||
stats: ImageGenerationStats | None = None
|
||||
if is_bench and image_num == num_images - 1:
|
||||
generation_end_time = time.perf_counter()
|
||||
total_generation_time = (
|
||||
generation_end_time - generation_start_time
|
||||
)
|
||||
peak_memory_gb = mx.get_peak_memory() / (1024**3)
|
||||
|
||||
num_inference_steps = model.get_steps_for_quality(quality)
|
||||
total_steps = num_inference_steps * num_images
|
||||
|
||||
seconds_per_step = (
|
||||
total_generation_time / total_steps
|
||||
if total_steps > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
peak_memory_gb = mx.get_peak_memory() / (1024**3)
|
||||
|
||||
stats = ImageGenerationStats(
|
||||
seconds_per_step=seconds_per_step,
|
||||
total_generation_time=total_generation_time,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_images=num_images,
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
peak_memory_usage=Memory.from_gb(peak_memory_gb),
|
||||
)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield ImageGenerationResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
stats=stats,
|
||||
image_index=image_num,
|
||||
stats = ImageGenerationStats(
|
||||
seconds_per_step=seconds_per_step,
|
||||
total_generation_time=total_generation_time,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_images=task.n or 1,
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
peak_memory_usage=Memory.from_gb(peak_memory_gb),
|
||||
)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield ImageGenerationResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
@@ -145,6 +145,10 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
output = mx.distributed.all_gather(output, group=self.group)[
|
||||
-output.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -252,10 +256,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
|
||||
if cache is not None:
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
|
||||
|
||||
logits = mx.distributed.all_gather(logits, group=group)[
|
||||
-logits.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return logits
|
||||
|
||||
cls.__call__ = patched_call
|
||||
|
||||
@@ -170,10 +170,10 @@ def mlx_distributed_init(
|
||||
|
||||
# TODO: update once upstream fixes
|
||||
logger.info(
|
||||
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
)
|
||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_JACCL_DEVICES"] = coordination_file
|
||||
os.environ["MLX_IBV_DEVICES"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
@@ -365,12 +365,35 @@ def load_tokenizer_for_model_id(
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
|
||||
"""
|
||||
Normalize tool_calls in a message dict.
|
||||
|
||||
OpenAI format has tool_calls[].function.arguments as a JSON string,
|
||||
but some chat templates (e.g., GLM) expect it as a dict.
|
||||
"""
|
||||
tool_calls = msg_dict.get("tool_calls")
|
||||
if not tool_calls or not isinstance(tool_calls, list):
|
||||
return
|
||||
|
||||
for tc in tool_calls: # pyright: ignore[reportUnknownVariableType]
|
||||
if not isinstance(tc, dict):
|
||||
continue
|
||||
func = tc.get("function") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
if not isinstance(func, dict):
|
||||
continue
|
||||
args = func.get("arguments") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
if isinstance(args, str):
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
func["arguments"] = json.loads(args)
|
||||
|
||||
|
||||
def apply_chat_template(
|
||||
tokenizer: TokenizerWrapper,
|
||||
chat_task_data: ChatCompletionTaskParams,
|
||||
) -> str:
|
||||
# Now we can properly access the messages
|
||||
messages = chat_task_data.messages
|
||||
tools = chat_task_data.tools
|
||||
|
||||
formatted_messages: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
@@ -386,15 +409,19 @@ def apply_chat_template(
|
||||
continue
|
||||
|
||||
# Null values are not valid when applying templates in tokenizer
|
||||
formatted_messages.append(
|
||||
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
|
||||
)
|
||||
dumped: dict[str, Any] = message.model_dump()
|
||||
msg_dict: dict[str, Any] = {k: v for k, v in dumped.items() if v is not None} # pyright: ignore[reportAny]
|
||||
|
||||
# Parse tool_calls arguments from JSON string to dict for templates that expect dicts
|
||||
_normalize_tool_calls(msg_dict)
|
||||
|
||||
formatted_messages.append(msg_dict)
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=chat_task_data.tools,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
logger.info(prompt)
|
||||
|
||||
@@ -256,6 +256,10 @@ def main(
|
||||
mlx_generator = filter_kimi_tokens(mlx_generator)
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
# GLM models need patched parser (upstream has bug with None regex match)
|
||||
if "glm" in shard_metadata.model_card.model_id.lower():
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
|
||||
if tokenizer.has_tool_calling:
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
@@ -608,7 +612,7 @@ def _process_image_response(
|
||||
command_id=command_id,
|
||||
model_id=shard_metadata.model_card.model_id,
|
||||
event_sender=event_sender,
|
||||
image_index=response.image_index,
|
||||
image_index=response.partial_index if is_partial else image_index,
|
||||
is_partial=is_partial,
|
||||
partial_index=response.partial_index if is_partial else None,
|
||||
total_partials=response.total_partials if is_partial else None,
|
||||
@@ -645,7 +649,14 @@ def parse_tool_calls(
|
||||
tools = [_validate_single_tool(parsed)]
|
||||
yield ToolCallResponse(tool_calls=tools)
|
||||
|
||||
except (json.JSONDecodeError, ValidationError) as e:
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
ValidationError,
|
||||
ValueError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
# ValueError: our parsers raise this for malformed tool calls
|
||||
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
|
||||
logger.opt(exception=e).warning("tool call parsing failed")
|
||||
# assumption: talking about tool calls, not making a tool call
|
||||
response.text = (
|
||||
@@ -698,11 +709,17 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: Any | None = None):
|
||||
func_name = _func_name_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
func_name = func_name_match.group(1)
|
||||
# strip off the `functions.` prefix, if it exists.
|
||||
func_name = func_name[func_name.find(".") + 1 :]
|
||||
|
||||
func_args = _func_arg_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
|
||||
func_args_match = _func_arg_regex.search(text)
|
||||
if func_args_match is None:
|
||||
raise ValueError(f"Could not parse function args from tool call: {text!r}")
|
||||
func_args = func_args_match.group(1)
|
||||
# the args should be valid json - no need to check against our tools to deserialize
|
||||
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
|
||||
|
||||
@@ -713,6 +730,76 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
|
||||
"""
|
||||
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
|
||||
"""
|
||||
import ast
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
|
||||
_func_arg_regex = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
tool_call_start = "<tool_call>"
|
||||
tool_call_end = "</tool_call>"
|
||||
|
||||
def _is_string_type(
|
||||
tool_name: str,
|
||||
arg_name: str,
|
||||
tools: list[Any] | None,
|
||||
) -> bool:
|
||||
if tools is None:
|
||||
return False
|
||||
for tool in tools: # pyright: ignore[reportAny]
|
||||
func = tool["function"] # pyright: ignore[reportAny]
|
||||
if func["name"] == tool_name:
|
||||
params = func["parameters"] # pyright: ignore[reportAny]
|
||||
if params is None:
|
||||
return False
|
||||
props = params.get("properties", {}) # pyright: ignore[reportAny]
|
||||
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
|
||||
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
|
||||
return arg_type == "string" # pyright: ignore[reportAny]
|
||||
return False
|
||||
|
||||
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
|
||||
try:
|
||||
return json.loads(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return ast.literal_eval(value) # pyright: ignore[reportAny]
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: list[Any] | None = None):
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
func_name = func_name_match.group(1)
|
||||
|
||||
pairs = _func_arg_regex.findall(text)
|
||||
arg_dct: dict[str, Any] = {}
|
||||
for key, value in pairs: # pyright: ignore[reportAny]
|
||||
arg_key = key.strip() # pyright: ignore[reportAny]
|
||||
arg_val = value.strip() # pyright: ignore[reportAny]
|
||||
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
|
||||
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
|
||||
arg_dct[arg_key] = arg_val
|
||||
return dict(name=func_name, arguments=arg_dct)
|
||||
|
||||
tokenizer._tool_call_start = tool_call_start
|
||||
tokenizer._tool_call_end = tool_call_end
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
|
||||
if (
|
||||
((name := obj.get("name")) is not None)
|
||||
|
||||
59
uv.lock
generated
59
uv.lock
generated
@@ -376,8 +376,8 @@ dependencies = [
|
||||
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -413,7 +413,7 @@ requires-dist = [
|
||||
{ name = "hypercorn", specifier = ">=0.18.0" },
|
||||
{ name = "loguru", specifier = ">=0.7.3" },
|
||||
{ name = "mflux", specifier = ">=0.14.2" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git" },
|
||||
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
|
||||
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
|
||||
{ name = "openai-harmony", specifier = ">=0.0.8" },
|
||||
@@ -458,16 +458,6 @@ dev = [
|
||||
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomlkit"
|
||||
version = "0.14.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
|
||||
]
|
||||
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.128.0"
|
||||
@@ -1004,8 +994,8 @@ dependencies = [
|
||||
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
|
||||
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1032,18 +1022,12 @@ wheels = [
|
||||
name = "mlx"
|
||||
version = "0.30.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
|
||||
resolution-markers = [
|
||||
"sys_platform == 'linux'",
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/22/42935d593fe82d3b98eb9d60e4620ed99703886635106f89d407c68f33bc/mlx-0.30.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743fac1e4f9e8e46c8262943c643a31139c255cdb256c99ad496958215ccac1e", size = 569344, upload-time = "2026-01-14T01:16:54.847Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/27/f2e7a5236289d45315d0215e8553b4dd7e2faaba3bcb5025b34b25d5ab66/mlx-0.30.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:3b04ae81655aa0e63a6e8f2c749de3bbce64cf5b168ae10f39ed086dfa99e7f8", size = 569345, upload-time = "2026-01-14T01:16:56.564Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/01/41/06b042457f51952456e9bb46b2c6e205ab3a28fc52d6751b5787fdb762b2/mlx-0.30.3-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:ba9b5bdb1e929cc130af72efd7f73508c0f4e526d224489af7ec1c6419564659", size = 569213, upload-time = "2026-01-14T05:52:10.86Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/1e/f62c98fc0d2d878ee4235671f9d406b13cc9240493ba6fcfde2f72c2ff83/mlx-0.30.3-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:dfe5c5b64e55398a22100804abbf9681996b03129e720e36b1727ed704db12b5", size = 617309, upload-time = "2026-01-14T01:16:57.58Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/62/811f064693449de740350d27793ce39343a460305ec8d878c318b80921d0/mlx-0.30.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:a3364924610929936e6aaf13c71106161258e5a5d3f7813a64c07cc2435f9f55", size = 659521, upload-time = "2026-01-14T01:16:58.719Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/e2/6e551bd48fb350fbf0ee4cc5cd09485437d260b8f4937f22d8623e14687a/mlx-0.30.3-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:2c27fd8daaae14ca6cf407fcd236006a6e968f7708c8f61a2709116f2e754852", size = 571920, upload-time = "2026-01-14T01:16:59.683Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/c0/561d1c9d3d12830b0e7fdcbd807585ef20909e398d4bcdbf25e4367543eb/mlx-0.30.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:b755fd4ed4b6a2ae4dee3766b5a2ea52fcbe83ebd1cf018458e18b74139409f3", size = 571921, upload-time = "2026-01-14T01:17:00.868Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/1a/fb573fc2edc22a777fa254ff5c0c886ffd2c88aeb1f21c45778ef170f990/mlx-0.30.3-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:7e352c0369a2f7e54d4f317b434eab3333918ea9edde1c43c61d36386b6f76bf", size = 571732, upload-time = "2026-01-14T05:52:11.893Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/db/d0083e8f2205b3b2dcd9670eb6f0d6c1b7cbfea6b01a1f8bff39142edf44/mlx-0.30.3-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:00ac867f3d003c1477a66a579442c2040ba7ea43ce3c174490d1f8bf379606bd", size = 619635, upload-time = "2026-01-14T01:17:01.812Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/90/ab0b93ff0e76da4fe0e878722c76a308cfb950b044a4676e9617276d8ccd/mlx-0.30.3-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:5be7d0329036f09c6ed003ea3e307e97e3144f20a3e4711b01810d7d5013cf2c", size = 659652, upload-time = "2026-01-14T01:17:02.915Z" },
|
||||
]
|
||||
@@ -1056,6 +1040,14 @@ cuda13 = [
|
||||
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx"
|
||||
version = "0.30.4.dev20260121+fbe306f9"
|
||||
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }
|
||||
resolution-markers = [
|
||||
"sys_platform == 'darwin'",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-cpu"
|
||||
version = "0.30.3"
|
||||
@@ -1086,7 +1078,7 @@ version = "0.30.4"
|
||||
source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
|
||||
dependencies = [
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'" },
|
||||
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
|
||||
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -1094,16 +1086,6 @@ dependencies = [
|
||||
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-metal"
|
||||
version = "0.30.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/63/4d8f6fefb507c028df4454dabfe8d8e0ad2961bb06510b6aca23d2d5b2be/mlx_metal-0.30.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:6276312b02353714c7c6515169569fe1c4bebe3229c8ecf1fdb375a13e78c966", size = 37716245, upload-time = "2026-01-14T01:16:34.838Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/91/1d452e48a4bb4958844fd3bb28ae31b8de110549c009ebec5024ce27ebf3/mlx_metal-0.30.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:c096c0a3428f3f96a06220f97a36f9528b18bc05173f821eb05bc8458e723fa8", size = 37712125, upload-time = "2026-01-14T01:16:38.619Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fe/36/7a3cbca85542b5ca4faf871e35927f43aa0e3fc830ae5b699780fe723677/mlx_metal-0.30.3-py3-none-macosx_26_0_arm64.whl", hash = "sha256:69068533bd1ee8b0379ce5de57ed5fd313577a10ecab58e1332fd1ff7248a75e", size = 46488962, upload-time = "2026-01-14T05:52:04.523Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "more-itertools"
|
||||
version = "10.8.0"
|
||||
@@ -2227,6 +2209,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomlkit"
|
||||
version = "0.14.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167, upload-time = "2026-01-13T01:14:53.304Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310, upload-time = "2026-01-13T01:14:51.965Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.9.1"
|
||||
|
||||
Reference in New Issue
Block a user