Compare commits

...

3 Commits

Author SHA1 Message Date
ciaranbor
31c021aad8 Prevent deleting messages during streamed generation 2026-02-06 20:28:21 +00:00
ciaranbor
9394462e5f Clarify message deletion behaviour 2026-02-06 20:28:21 +00:00
rltakashige
3b2f553a25 Fix kimi tool calling id (#1413)
## Motivation

Kimi produces its own tool id. It gets confused when we generate our own
id.

## Changes

Add id to tool call item and parse Kimi id properly.

## Test Plan

### Manual Testing
<img width="3198" height="522" alt="image"
src="https://github.com/user-attachments/assets/d71ec2be-7f57-49dc-a569-d304cc430f4d"
/>

Long running Kimi K2.5 cluster querying itself through OpenCode running
on the same Kimi K2.5 instance.
2026-02-06 11:33:51 -08:00
7 changed files with 44 additions and 22 deletions

View File

@@ -221,6 +221,7 @@
}
function handleDeleteClick(messageId: string) {
if (loading) return;
deleteConfirmId = messageId;
}
@@ -251,7 +252,7 @@
</script>
<div class="flex flex-col gap-4 sm:gap-6 {className}">
{#each messageList as message (message.id)}
{#each messageList as message, i (message.id)}
<div
class="group flex {message.role === 'user'
? 'justify-end'
@@ -313,9 +314,11 @@
<!-- Delete confirmation -->
<div class="bg-red-500/10 border border-red-500/30 rounded-lg p-3">
<p class="text-xs text-red-400 mb-3">
Delete this message{message.role === "user"
? " and all responses after it"
: ""}?
{#if i === messageList.length - 1}
Delete this message?
{:else}
Delete this message and all messages after it?
{/if}
</p>
<div class="flex gap-2 justify-end">
<button
@@ -713,8 +716,13 @@
<!-- Delete button -->
<button
onclick={() => handleDeleteClick(message.id)}
class="p-1.5 text-exo-light-gray hover:text-red-400 transition-colors rounded hover:bg-red-500/10 cursor-pointer"
title="Delete message"
disabled={loading}
class="p-1.5 transition-colors rounded {loading
? 'text-exo-light-gray/30 cursor-not-allowed'
: 'text-exo-light-gray hover:text-red-400 hover:bg-red-500/10 cursor-pointer'}"
title={loading
? "Cannot delete while generating"
: "Delete message"}
>
<svg
class="w-3.5 h-3.5"

View File

@@ -3,7 +3,6 @@
import time
from collections.abc import AsyncGenerator
from typing import Any
from uuid import uuid4
from exo.shared.types.api import (
ChatCompletionChoice,
@@ -141,7 +140,7 @@ async def generate_chat_stream(
if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [
ToolCall(
id=str(uuid4()),
id=tool.id,
index=i,
function=tool,
)
@@ -207,7 +206,7 @@ async def collect_chat_response(
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=str(uuid4()),
id=tool.id,
index=i,
function=tool,
)

View File

@@ -3,7 +3,6 @@
import json
from collections.abc import AsyncGenerator
from typing import Any
from uuid import uuid4
from exo.shared.types.api import FinishReason
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
@@ -179,7 +178,7 @@ async def collect_claude_response(
for tool in chunk.tool_calls:
tool_use_blocks.append(
ClaudeToolUseBlock(
id=f"toolu_{uuid4().hex[:24]}",
id=f"toolu_{tool.id}",
name=tool.name,
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
)
@@ -264,7 +263,7 @@ async def generate_claude_stream(
# Emit tool_use content blocks
for tool in chunk.tool_calls:
tool_id = f"toolu_{uuid4().hex[:24]}"
tool_id = f"toolu_{tool.id}"
tool_input_json = tool.arguments
# content_block_start for tool_use

View File

@@ -3,7 +3,6 @@
from collections.abc import AsyncGenerator
from itertools import count
from typing import Any
from uuid import uuid4
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
@@ -140,8 +139,8 @@ async def collect_responses_response(
for tool in chunk.tool_calls:
function_call_items.append(
ResponseFunctionCallItem(
id=f"fc_{uuid4().hex[:24]}",
call_id=f"call_{uuid4().hex[:24]}",
id=f"fc_{tool.id}",
call_id=f"call_{tool.id}",
name=tool.name,
arguments=tool.arguments,
)
@@ -246,8 +245,8 @@ async def generate_responses_stream(
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
for tool in chunk.tool_calls:
fc_id = f"fc_{uuid4().hex[:24]}"
call_id = f"call_{uuid4().hex[:24]}"
fc_id = f"fc_{tool.id}"
call_id = f"call_{tool.id}"
# response.output_item.added for function_call
fc_item = ResponseFunctionCallItem(

View File

@@ -1,6 +1,7 @@
import time
from collections.abc import Generator
from typing import Annotated, Any, Literal
from uuid import uuid4
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
@@ -60,6 +61,7 @@ class ChatCompletionMessageText(BaseModel):
class ToolCallItem(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
name: str
arguments: str

View File

@@ -298,6 +298,9 @@ def mlx_generate(
)
max_stop_len = max((len(s) for s in stop_sequences), default=0)
mx_barrier(group)
logger.info("Ready to prefill")
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
model,

View File

@@ -810,8 +810,9 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
# Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
_func_name_regex = re.compile(
r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", re.DOTALL
r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
@@ -835,9 +836,10 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1)
original_func_name = func_name_match.group(1)
tool_id = func_name_match.group(2)
# strip off the `functions.` prefix, if it exists.
func_name = func_name[func_name.find(".") + 1 :]
func_name = original_func_name[original_func_name.find(".") + 1 :]
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
@@ -846,7 +848,11 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
return dict(name=func_name, arguments=arg_dct) # pyright: ignore[reportAny]
return dict(
id=f"{original_func_name}:{tool_id}",
name=func_name,
arguments=arg_dct, # pyright: ignore[reportAny]
)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
@@ -929,7 +935,13 @@ def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
and ((args := obj.get("arguments")) is not None)
and isinstance(name, str)
):
return ToolCallItem(name=name, arguments=json.dumps(args))
raw_id: object = obj.get("id")
extra = {"id": str(raw_id)} if raw_id is not None else {}
return ToolCallItem(
**extra,
name=name,
arguments=json.dumps(args),
)
else:
raise ValidationError