Compare commits

...

5 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
11f9e2e1d8 fix chat completions call_ 2026-02-06 18:44:34 +00:00
Ryuichi Leo Takashige
870d53329b fix tool calling again 2026-02-06 18:26:11 +00:00
Ryuichi Leo Takashige
4dc8654ee7 move 2026-02-06 17:36:53 +00:00
Ryuichi Leo Takashige
76d7994fa1 Fix kimi tool calling id 2026-02-06 17:31:55 +00:00
Ryuichi Leo Takashige
2c1b8f79b2 Fix kimi tool calling id 2026-02-06 17:17:36 +00:00
6 changed files with 24 additions and 14 deletions

View File

@@ -3,7 +3,6 @@
import time import time
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any from typing import Any
from uuid import uuid4
from exo.shared.types.api import ( from exo.shared.types.api import (
ChatCompletionChoice, ChatCompletionChoice,
@@ -141,7 +140,7 @@ async def generate_chat_stream(
if isinstance(chunk, ToolCallChunk): if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [ tool_call_deltas = [
ToolCall( ToolCall(
id=str(uuid4()), id=tool.id,
index=i, index=i,
function=tool, function=tool,
) )
@@ -207,7 +206,7 @@ async def collect_chat_response(
if isinstance(chunk, ToolCallChunk): if isinstance(chunk, ToolCallChunk):
tool_calls.extend( tool_calls.extend(
ToolCall( ToolCall(
id=str(uuid4()), id=tool.id,
index=i, index=i,
function=tool, function=tool,
) )

View File

@@ -3,7 +3,6 @@
import json import json
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any from typing import Any
from uuid import uuid4
from exo.shared.types.api import FinishReason from exo.shared.types.api import FinishReason
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
@@ -179,7 +178,7 @@ async def collect_claude_response(
for tool in chunk.tool_calls: for tool in chunk.tool_calls:
tool_use_blocks.append( tool_use_blocks.append(
ClaudeToolUseBlock( ClaudeToolUseBlock(
id=f"toolu_{uuid4().hex[:24]}", id=f"toolu_{tool.id}",
name=tool.name, name=tool.name,
input=json.loads(tool.arguments), # pyright: ignore[reportAny] input=json.loads(tool.arguments), # pyright: ignore[reportAny]
) )
@@ -264,7 +263,7 @@ async def generate_claude_stream(
# Emit tool_use content blocks # Emit tool_use content blocks
for tool in chunk.tool_calls: for tool in chunk.tool_calls:
tool_id = f"toolu_{uuid4().hex[:24]}" tool_id = f"toolu_{tool.id}"
tool_input_json = tool.arguments tool_input_json = tool.arguments
# content_block_start for tool_use # content_block_start for tool_use

View File

@@ -3,7 +3,6 @@
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from itertools import count from itertools import count
from typing import Any from typing import Any
from uuid import uuid4
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId from exo.shared.types.common import CommandId
@@ -140,8 +139,8 @@ async def collect_responses_response(
for tool in chunk.tool_calls: for tool in chunk.tool_calls:
function_call_items.append( function_call_items.append(
ResponseFunctionCallItem( ResponseFunctionCallItem(
id=f"fc_{uuid4().hex[:24]}", id=f"fc_{tool.id}",
call_id=f"call_{uuid4().hex[:24]}", call_id=f"call_{tool.id}",
name=tool.name, name=tool.name,
arguments=tool.arguments, arguments=tool.arguments,
) )
@@ -246,8 +245,8 @@ async def generate_responses_stream(
if isinstance(chunk, ToolCallChunk): if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats last_stats = chunk.stats or last_stats
for tool in chunk.tool_calls: for tool in chunk.tool_calls:
fc_id = f"fc_{uuid4().hex[:24]}" fc_id = f"fc_{tool.id}"
call_id = f"call_{uuid4().hex[:24]}" call_id = f"call_{tool.id}"
# response.output_item.added for function_call # response.output_item.added for function_call
fc_item = ResponseFunctionCallItem( fc_item = ResponseFunctionCallItem(

View File

@@ -1,6 +1,7 @@
import time import time
from collections.abc import Generator from collections.abc import Generator
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
from uuid import uuid4
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault from pydantic_core import PydanticUseDefault
@@ -60,6 +61,7 @@ class ChatCompletionMessageText(BaseModel):
class ToolCallItem(BaseModel): class ToolCallItem(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
name: str name: str
arguments: str arguments: str

View File

@@ -298,6 +298,9 @@ def mlx_generate(
) )
max_stop_len = max((len(s) for s in stop_sequences), default=0) max_stop_len = max((len(s) for s in stop_sequences), default=0)
mx_barrier(group)
logger.info("Ready to prefill")
# Prefill cache with all tokens except the last one # Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill( prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
model, model,

View File

@@ -810,8 +810,9 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
# kimi has a fixed function naming scheme, with a json formatted arg # kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3} # functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
# Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
_func_name_regex = re.compile( _func_name_regex = re.compile(
r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", re.DOTALL r"^\s*(.+)[:_](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
) )
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL) _func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
@@ -836,6 +837,7 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
if func_name_match is None: if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}") raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1) func_name = func_name_match.group(1)
tool_id = func_name_match.group(2)
# strip off the `functions.` prefix, if it exists. # strip off the `functions.` prefix, if it exists.
func_name = func_name[func_name.find(".") + 1 :] func_name = func_name[func_name.find(".") + 1 :]
@@ -846,7 +848,7 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
# the args should be valid json - no need to check against our tools to deserialize # the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny] arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
return dict(name=func_name, arguments=arg_dct) # pyright: ignore[reportAny] return dict(id=tool_id, name=func_name, arguments=arg_dct) # pyright: ignore[reportAny]
tokenizer._tool_call_start = tool_call_start tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end tokenizer._tool_call_end = tool_call_end
@@ -929,7 +931,13 @@ def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
and ((args := obj.get("arguments")) is not None) and ((args := obj.get("arguments")) is not None)
and isinstance(name, str) and isinstance(name, str)
): ):
return ToolCallItem(name=name, arguments=json.dumps(args)) raw_id: object = obj.get("id")
extra = {"id": str(raw_id)} if raw_id is not None else {}
return ToolCallItem(
**extra,
name=name,
arguments=json.dumps(args),
)
else: else:
raise ValidationError raise ValidationError