Compare commits

2 Commits

Author  SHA1        Message                        Date
Evan    879d178057  add kimi tool parseing         2026-01-21 16:53:38 +00:00
Evan    daa31b4472  implement mlx-lm tool calling  2026-01-21 16:53:38 +00:00
6 changed files with 248 additions and 16 deletions

View File

@@ -2,6 +2,7 @@ import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import cast
from uuid import uuid4
import anyio
from anyio import BrokenResourceError, create_task_group
@@ -43,8 +44,9 @@ from exo.shared.types.api import (
PlacementPreview,
PlacementPreviewResponse,
StreamingChoiceResponse,
ToolCall,
)
-from exo.shared.types.chunks import TokenChunk
+from exo.shared.types.chunks import TokenChunk, ToolCallChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
@@ -73,7 +75,7 @@ from exo.utils.event_buffer import OrderedBuffer
def chunk_to_response(
-chunk: TokenChunk, command_id: CommandId
+chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
@@ -85,6 +87,22 @@ def chunk_to_response(
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
finish_reason=chunk.finish_reason,
)
if isinstance(chunk, TokenChunk)
else StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=[
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
],
),
finish_reason="tool_calls",
)
],
)
@@ -138,7 +156,9 @@ class API:
name="dashboard",
)
-self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
+self._chat_completion_queues: dict[
+CommandId, Sender[TokenChunk | ToolCallChunk]
+] = {}
self._tg: TaskGroup | None = None
def reset(self, new_session_id: SessionId, result_clock: int):
@@ -404,16 +424,21 @@ class API:
async def _chat_chunk_stream(
self, command_id: CommandId
-) -> AsyncGenerator[TokenChunk, None]:
+) -> AsyncGenerator[TokenChunk | ToolCallChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
try:
-self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
+self._chat_completion_queues[command_id], recv = channel[
+TokenChunk | ToolCallChunk
+]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
-if chunk.finish_reason is not None:
+if (
+isinstance(chunk, TokenChunk)
+and chunk.finish_reason is not None
+):
break
except anyio.get_cancelled_exc_class():
@@ -435,7 +460,7 @@ class API:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
-if chunk.finish_reason == "error":
+if isinstance(chunk, TokenChunk) and chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
@@ -454,7 +479,7 @@ class API:
yield f"data: {chunk_response.model_dump_json()}\n\n"
-if chunk.finish_reason is not None:
+if isinstance(chunk, ToolCallChunk) or chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
@@ -467,6 +492,29 @@ class API:
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ToolCallChunk):
finish_reason = "tool_calls"
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model or chunk.model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
tool_calls=[
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
],
),
)
],
)
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -510,6 +558,11 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ToolCallChunk):
raise HTTPException(
status_code=500,
detail="Tool call in bench",
)
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -660,7 +713,7 @@ class API:
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
-assert isinstance(event.chunk, TokenChunk)
+assert isinstance(event.chunk, (TokenChunk, ToolCallChunk))
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
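Putting the streaming changes together: a minimal sketch (not part of the diff) of the delta a client receives once chunk_to_response handles a ToolCallChunk. The id, name, and arguments values are illustrative placeholders; only the field layout follows the StreamingChoiceResponse built above.

# Illustrative payload shape only -- the values are made up.
tool_call_delta = {
    "index": 0,
    "delta": {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call-1234",  # str(uuid4()) in the real code path
                "index": 0,
                "type": "function",
                "function": {"name": "multiply", "arguments": '{"a": 2, "b": 3}'},
            }
        ],
    },
    "finish_reason": "tool_calls",
}
# The SSE stream then terminates with: data: [DONE]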

View File

@@ -51,6 +51,18 @@ class ChatCompletionMessageText(BaseModel):
text: str
class ToolCallItem(BaseModel):
name: str
arguments: str
class ToolCall(BaseModel):
id: str
index: int | None = None
type: Literal["function"] = "function"
function: ToolCallItem
class ChatCompletionMessage(BaseModel):
role: Literal["system", "user", "assistant", "developer", "tool", "function"]
content: (
@@ -58,7 +70,7 @@ class ChatCompletionMessage(BaseModel):
) = None
thinking: str | None = None # Added for GPT-OSS harmony format support
name: str | None = None
-tool_calls: list[dict[str, Any]] | None = None
+tool_calls: list[ToolCall] | None = None
tool_call_id: str | None = None
function_call: dict[str, Any] | None = None
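A small usage sketch of the new models, assuming they are exported from exo.shared.types.api as the other imports in this diff indicate; the function name and arguments are placeholders.

from uuid import uuid4

from exo.shared.types.api import ChatCompletionMessage, ToolCall, ToolCallItem

# arguments is a JSON-encoded string, matching how the runner serializes
# parsed arguments with json.dumps before handing them back to the API.
call = ToolCall(
    id=str(uuid4()),
    index=0,
    function=ToolCallItem(name="multiply", arguments='{"a": 2, "b": 3}'),
)
message = ChatCompletionMessage(role="assistant", tool_calls=[call])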

View File

@@ -5,6 +5,7 @@ from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .worker.runner_response import ToolCallItem
class ChunkType(str, Enum):
@@ -25,8 +26,12 @@ class TokenChunk(BaseChunk):
error_message: str | None = None
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
class ImageChunk(BaseChunk):
data: bytes
-GenerationChunk = TokenChunk | ImageChunk
+GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk
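A brief sketch (not from the PR) of narrowing the widened union, mirroring the isinstance checks the API layer now performs on each chunk.

from exo.shared.types.chunks import GenerationChunk, TokenChunk, ToolCallChunk

def summarize(chunk: GenerationChunk) -> str:
    # TokenChunk carries generated text; ToolCallChunk carries parsed tool calls.
    if isinstance(chunk, TokenChunk):
        return f"token: {chunk.text!r}"
    if isinstance(chunk, ToolCallChunk):
        return "tools: " + ", ".join(tool.name for tool in chunk.tool_calls)
    return "image chunk"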

View File

@@ -1,4 +1,4 @@
-from exo.shared.types.api import FinishReason, GenerationStats
+from exo.shared.types.api import FinishReason, GenerationStats, ToolCallItem
from exo.utils.pydantic_ext import TaggedModel
@@ -18,5 +18,9 @@ class GenerationResponse(BaseRunnerResponse):
stats: GenerationStats | None = None
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
class FinishedResponse(BaseRunnerResponse):
pass

View File

@@ -1,6 +1,8 @@
import json
import time
from collections.abc import Generator
from functools import cache
from typing import Any, Callable, cast
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -11,9 +13,10 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
StreamableParser,
load_harmony_encoding,
)
from pydantic import ValidationError
from exo.shared.types.api import ChatCompletionMessageText
-from exo.shared.types.chunks import TokenChunk
+from exo.shared.types.chunks import TokenChunk, ToolCallChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -33,6 +36,8 @@ from exo.shared.types.tasks import (
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ToolCallItem,
ToolCallResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -136,6 +141,7 @@ def main(
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(f"model has_tool_calling={tokenizer.has_tool_calling}")
current_status = RunnerLoaded()
logger.info("runner loaded")
@@ -202,9 +208,17 @@ def main(
mlx_generator, tokenizer
)
# TODO: Add tool call parser here
# Kimi-K2 has tool call sections - we don't care about them
if "kimi" in shard_metadata.model_card.model_id.lower():
mlx_generator = filter_kimi_tokens(mlx_generator)
patch_kimi_tokenizer(tokenizer)
if tokenizer.has_tool_calling:
mlx_generator = parse_tool_calls(mlx_generator, tokenizer)
chunk_idx = -1
for response in mlx_generator:
chunk_idx += 1
match response:
case GenerationResponse():
if device_rank == 0:
@@ -212,7 +226,7 @@ def main(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
-idx=response.token,
+idx=chunk_idx,
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
@@ -221,6 +235,18 @@ def main(
),
)
)
case ToolCallResponse():
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ToolCallChunk(
idx=chunk_idx,
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
),
)
)
# can we make this more explicit?
except Exception as e:
@@ -276,6 +302,18 @@ def get_gpt_oss_encoding():
return encoding
def filter_kimi_tokens(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
for resp in responses:
if (
resp.text == "<|tool_calls_section_begin|>"
or resp.text == "<|tool_calls_section_end|>"
):
continue
yield resp
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
@@ -329,6 +367,120 @@ def parse_thinking_models(
yield response
def parse_tool_calls(
responses: Generator[GenerationResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse | ToolCallResponse]:
assert tokenizer.tool_call_start is not None
assert tokenizer.tool_call_end is not None
tool_parser = cast(
Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None,
tokenizer.tool_parser,
) # first arg has followup args, but we don't care about them rn
assert tool_parser is not None
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
# assumption: the tool call start is one token
if response.text == tokenizer.tool_call_start:
in_tool_call = True
continue
# assumption: the tool call end is one token
if in_tool_call and response.text == tokenizer.tool_call_end:
try:
# tool_parser returns an arbitrarily nested python dictionary
# we actually don't want the python dictionary, we just want to
# parse the top level { function: ..., arguments: ... } structure
# as we're just gonna hand it back to the api anyway
parsed = tool_parser("".join(tool_call_text_parts).strip())
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if isinstance(parsed, list):
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools)
except (json.JSONDecodeError, ValidationError) as e:
logger.opt(exception=e).warning("tool call parsing failed")
# assumption: talking about tool calls, not making a tool call
response.text = (
tokenizer.tool_call_start
+ "".join(tool_call_text_parts)
+ tokenizer.tool_call_end
)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if in_tool_call:
tool_call_text_parts.append(response.text)
continue
# fallthrough
yield response
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
"""
Version of to-be-upstreamed kimi-k2 tool parser
"""
import ast
import json
from typing import Any
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
_func_name_regex = re.compile(
r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
tool_call_start = "<|tool_call_begin|>"
tool_call_end = "<|tool_call_end|>"
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: Any | None = None):
func_name = _func_name_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
# strip off the `functions.` prefix, if it exists.
func_name = func_name[func_name.find(".") + 1 :]
func_args = _func_arg_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
return dict(name=func_name, arguments=arg_dct) # pyright: ignore[reportAny]
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
if (
((name := obj.get("name")) is not None)
and ((args := obj.get("arguments")) is not None)
and isinstance(name, str)
):
return ToolCallItem(name=name, arguments=json.dumps(args))
else:
raise ValidationError
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
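To make the regex-based Kimi parsing concrete, here is a self-contained sketch that runs the same two patterns over the example format quoted in the comment above; it uses the standard re module, which handles these particular patterns the same way as the regex package imported in the diff.

import json
import re

func_name_regex = re.compile(r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", re.DOTALL)
func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)

text = 'functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}'

name = func_name_regex.search(text).group(1)  # "functions.multiply"
name = name[name.find(".") + 1 :]  # strip the "functions." prefix -> "multiply"
args = json.loads(func_arg_regex.search(text).group(1))  # {"a": 2, "b": 3}

parse_tool_calls would then wrap such a result as ToolCallItem(name=name, arguments=json.dumps(args)) via _validate_single_tool and yield it as a ToolCallResponse.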

View File

@@ -111,7 +111,7 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
@@ -140,6 +140,12 @@ class EventCollector:
pass
class MockTokenizer:
tool_parser = None
tool_call_start = None
tool_call_end = None
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,