Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-21 12:30:22 -05:00)
Compare commits: 2 commits, main...tool-calli
| Author | SHA1 | Date |
|---|---|---|
| | 879d178057 | |
| | daa31b4472 | |
@@ -2,6 +2,7 @@ import time
 from collections.abc import AsyncGenerator
 from http import HTTPStatus
 from typing import cast
+from uuid import uuid4
 
 import anyio
 from anyio import BrokenResourceError, create_task_group
@@ -43,8 +44,9 @@ from exo.shared.types.api import (
     PlacementPreview,
     PlacementPreviewResponse,
     StreamingChoiceResponse,
+    ToolCall,
 )
-from exo.shared.types.chunks import TokenChunk
+from exo.shared.types.chunks import TokenChunk, ToolCallChunk
 from exo.shared.types.commands import (
     ChatCompletion,
     Command,
@@ -73,7 +75,7 @@ from exo.utils.event_buffer import OrderedBuffer
 
 
 def chunk_to_response(
-    chunk: TokenChunk, command_id: CommandId
+    chunk: TokenChunk | ToolCallChunk, command_id: CommandId
 ) -> ChatCompletionResponse:
     return ChatCompletionResponse(
         id=command_id,
@@ -85,6 +87,22 @@ def chunk_to_response(
                 delta=ChatCompletionMessage(role="assistant", content=chunk.text),
                 finish_reason=chunk.finish_reason,
             )
+            if isinstance(chunk, TokenChunk)
+            else StreamingChoiceResponse(
+                index=0,
+                delta=ChatCompletionMessage(
+                    role="assistant",
+                    tool_calls=[
+                        ToolCall(
+                            id=str(uuid4()),
+                            index=i,
+                            function=tool,
+                        )
+                        for i, tool in enumerate(chunk.tool_calls)
+                    ],
+                ),
+                finish_reason="tool_calls",
+            )
         ],
     )
 
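For a ToolCallChunk, the new else branch emits a streaming choice whose delta carries tool_calls instead of content. Serialized over SSE, the event comes out roughly as follows (the function name and arguments are made-up illustrations; each tool-call id is a fresh uuid4):

data: {"id": "<command_id>", "created": 1700000000, "model": "<model_id>", "choices": [{"index": 0, "delta": {"role": "assistant", "tool_calls": [{"id": "<uuid4>", "index": 0, "type": "function", "function": {"name": "multiply", "arguments": "{\"a\": 2, \"b\": 3}"}}]}, "finish_reason": "tool_calls"}]}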
@@ -138,7 +156,9 @@ class API:
             name="dashboard",
         )
 
-        self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
+        self._chat_completion_queues: dict[
+            CommandId, Sender[TokenChunk | ToolCallChunk]
+        ] = {}
         self._tg: TaskGroup | None = None
 
     def reset(self, new_session_id: SessionId, result_clock: int):
@@ -404,16 +424,21 @@ class API:
 
     async def _chat_chunk_stream(
         self, command_id: CommandId
-    ) -> AsyncGenerator[TokenChunk, None]:
+    ) -> AsyncGenerator[TokenChunk | ToolCallChunk, None]:
        """Yield `TokenChunk`s for a given command until completion."""
 
        try:
-            self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
+            self._chat_completion_queues[command_id], recv = channel[
+                TokenChunk | ToolCallChunk
+            ]()
 
            with recv as token_chunks:
                async for chunk in token_chunks:
                    yield chunk
-                    if chunk.finish_reason is not None:
+                    if (
+                        isinstance(chunk, TokenChunk)
+                        and chunk.finish_reason is not None
+                    ):
                        break
 
        except anyio.get_cancelled_exc_class():
@@ -435,7 +460,7 @@ class API:
        """Generate chat completion stream as JSON strings."""
 
        async for chunk in self._chat_chunk_stream(command_id):
-            if chunk.finish_reason == "error":
+            if isinstance(chunk, TokenChunk) and chunk.finish_reason == "error":
                error_response = ErrorResponse(
                    error=ErrorInfo(
                        message=chunk.error_message or "Internal server error",
@@ -454,7 +479,7 @@ class API:
 
            yield f"data: {chunk_response.model_dump_json()}\n\n"
 
-            if chunk.finish_reason is not None:
+            if isinstance(chunk, ToolCallChunk) or chunk.finish_reason is not None:
                yield "data: [DONE]\n\n"
 
    async def _collect_chat_completion(
@@ -467,6 +492,29 @@ class API:
        finish_reason: FinishReason | None = None
 
        async for chunk in self._chat_chunk_stream(command_id):
+            if isinstance(chunk, ToolCallChunk):
+                finish_reason = "tool_calls"
+                return ChatCompletionResponse(
+                    id=command_id,
+                    created=int(time.time()),
+                    model=model or chunk.model,
+                    choices=[
+                        ChatCompletionChoice(
+                            index=0,
+                            message=ChatCompletionMessage(
+                                role="assistant",
+                                tool_calls=[
+                                    ToolCall(
+                                        id=str(uuid4()),
+                                        index=i,
+                                        function=tool,
+                                    )
+                                    for i, tool in enumerate(chunk.tool_calls)
+                                ],
+                            ),
+                        )
+                    ],
+                )
            if chunk.finish_reason == "error":
                raise HTTPException(
                    status_code=500,
@@ -510,6 +558,11 @@ class API:
        stats: GenerationStats | None = None
 
        async for chunk in self._chat_chunk_stream(command_id):
+            if isinstance(chunk, ToolCallChunk):
+                raise HTTPException(
+                    status_code=500,
+                    detail="Tool call in bench",
+                )
            if chunk.finish_reason == "error":
                raise HTTPException(
                    status_code=500,
@@ -660,7 +713,7 @@ class API:
        self._event_log.append(event)
        self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
        if isinstance(event, ChunkGenerated):
-            assert isinstance(event.chunk, TokenChunk)
+            assert isinstance(event.chunk, (TokenChunk, ToolCallChunk))
            queue = self._chat_completion_queues.get(event.command_id)
            if queue is not None:
                try:
@@ -51,6 +51,18 @@ class ChatCompletionMessageText(BaseModel):
     text: str
 
 
+class ToolCallItem(BaseModel):
+    name: str
+    arguments: str
+
+
+class ToolCall(BaseModel):
+    id: str
+    index: int | None = None
+    type: Literal["function"] = "function"
+    function: ToolCallItem
+
+
 class ChatCompletionMessage(BaseModel):
     role: Literal["system", "user", "assistant", "developer", "tool", "function"]
     content: (
@@ -58,7 +70,7 @@ class ChatCompletionMessage(BaseModel):
     ) = None
     thinking: str | None = None  # Added for GPT-OSS harmony format support
     name: str | None = None
-    tool_calls: list[dict[str, Any]] | None = None
+    tool_calls: list[ToolCall] | None = None
     tool_call_id: str | None = None
     function_call: dict[str, Any] | None = None
 
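A minimal round-trip of the two models added above, re-declared so the snippet runs standalone (pydantic v2 assumed). The point to notice is that arguments stays a JSON-encoded string rather than a parsed object, matching the OpenAI wire format:

from typing import Literal

from pydantic import BaseModel


class ToolCallItem(BaseModel):
    name: str
    arguments: str  # JSON-encoded string, not a parsed dict


class ToolCall(BaseModel):
    id: str
    index: int | None = None
    type: Literal["function"] = "function"
    function: ToolCallItem


call = ToolCall(
    id="call_0",
    index=0,
    function=ToolCallItem(name="multiply", arguments='{"a": 2, "b": 3}'),
)
print(call.model_dump_json())
# {"id":"call_0","index":0,"type":"function","function":{"name":"multiply","arguments":"{\"a\": 2, \"b\": 3}"}}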
@@ -5,6 +5,7 @@ from exo.shared.types.api import GenerationStats
 from exo.utils.pydantic_ext import TaggedModel
 
 from .api import FinishReason
+from .worker.runner_response import ToolCallItem
 
 
 class ChunkType(str, Enum):
@@ -25,8 +26,12 @@ class TokenChunk(BaseChunk):
     error_message: str | None = None
 
 
+class ToolCallChunk(BaseChunk):
+    tool_calls: list[ToolCallItem]
+
+
 class ImageChunk(BaseChunk):
     data: bytes
 
 
-GenerationChunk = TokenChunk | ImageChunk
+GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk
@@ -1,4 +1,4 @@
-from exo.shared.types.api import FinishReason, GenerationStats
+from exo.shared.types.api import FinishReason, GenerationStats, ToolCallItem
 from exo.utils.pydantic_ext import TaggedModel
 
 
@@ -18,5 +18,9 @@ class GenerationResponse(BaseRunnerResponse):
     stats: GenerationStats | None = None
 
 
+class ToolCallResponse(BaseRunnerResponse):
+    tool_calls: list[ToolCallItem]
+
+
 class FinishedResponse(BaseRunnerResponse):
     pass
@@ -1,6 +1,8 @@
+import json
 import time
+from collections.abc import Generator
 from functools import cache
 from typing import Any, Callable, cast
 
 import mlx.core as mx
 from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -11,9 +13,10 @@ from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]
     StreamableParser,
     load_harmony_encoding,
 )
+from pydantic import ValidationError
 
 from exo.shared.types.api import ChatCompletionMessageText
-from exo.shared.types.chunks import TokenChunk
+from exo.shared.types.chunks import TokenChunk, ToolCallChunk
 from exo.shared.types.events import (
     ChunkGenerated,
     Event,
@@ -33,6 +36,8 @@ from exo.shared.types.tasks import (
 from exo.shared.types.worker.instances import BoundInstance
 from exo.shared.types.worker.runner_response import (
     GenerationResponse,
+    ToolCallItem,
+    ToolCallResponse,
 )
 from exo.shared.types.worker.runners import (
     RunnerConnected,
@@ -136,6 +141,7 @@ def main(
     model, tokenizer = load_mlx_items(
         bound_instance, group, on_timeout=on_model_load_timeout
     )
+    logger.info(f"model has_tool_calling={tokenizer.has_tool_calling}")
 
     current_status = RunnerLoaded()
     logger.info("runner loaded")
@@ -202,9 +208,17 @@ def main(
                mlx_generator, tokenizer
            )
 
-            # TODO: Add tool call parser here
+            # Kimi-K2 has tool call sections - we don't care about them
+            if "kimi" in shard_metadata.model_card.model_id.lower():
+                mlx_generator = filter_kimi_tokens(mlx_generator)
+                patch_kimi_tokenizer(tokenizer)
+
+            if tokenizer.has_tool_calling:
+                mlx_generator = parse_tool_calls(mlx_generator, tokenizer)
 
+            chunk_idx = -1
            for response in mlx_generator:
+                chunk_idx += 1
                match response:
                    case GenerationResponse():
                        if device_rank == 0:
@@ -212,7 +226,7 @@ def main(
                                ChunkGenerated(
                                    command_id=command_id,
                                    chunk=TokenChunk(
-                                        idx=response.token,
+                                        idx=chunk_idx,
                                        model=shard_metadata.model_card.model_id,
                                        text=response.text,
                                        token_id=response.token,
@@ -221,6 +235,18 @@ def main(
                                    ),
                                )
                            )
+                    case ToolCallResponse():
+                        if device_rank == 0:
+                            event_sender.send(
+                                ChunkGenerated(
+                                    command_id=command_id,
+                                    chunk=ToolCallChunk(
+                                        idx=chunk_idx,
+                                        tool_calls=response.tool_calls,
+                                        model=shard_metadata.model_card.model_id,
+                                    ),
+                                )
+                            )
 
        # can we make this more explicit?
        except Exception as e:
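The main() changes above compose plain generator stages, so ordering matters: filter_kimi_tokens is installed before parse_tool_calls, and the parser therefore never sees Kimi's section markers. A minimal sketch of that composition pattern (stub stages and made-up names, not exo's real ones):

from collections.abc import Generator


def source() -> Generator[str]:
    yield from ["a", "<|tool_calls_section_begin|>", "b"]


def drop_markers(gen: Generator[str]) -> Generator[str]:
    # stands in for filter_kimi_tokens
    for t in gen:
        if t != "<|tool_calls_section_begin|>":
            yield t


def parse(gen: Generator[str]) -> Generator[str]:
    # stands in for parse_tool_calls; sees only the filtered stream
    for t in gen:
        yield t.upper()


print(list(parse(drop_markers(source()))))  # ['A', 'B']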
@@ -276,6 +302,18 @@ def get_gpt_oss_encoding():
     return encoding
 
 
+def filter_kimi_tokens(
+    responses: Generator[GenerationResponse],
+) -> Generator[GenerationResponse]:
+    for resp in responses:
+        if (
+            resp.text == "<|tool_calls_section_begin|>"
+            or resp.text == "<|tool_calls_section_end|>"
+        ):
+            continue
+        yield resp
+
+
 def parse_gpt_oss(
     responses: Generator[GenerationResponse],
 ) -> Generator[GenerationResponse]:
@@ -329,6 +367,120 @@ def parse_thinking_models(
        yield response
 
 
+def parse_tool_calls(
+    responses: Generator[GenerationResponse],
+    tokenizer: TokenizerWrapper,
+) -> Generator[GenerationResponse | ToolCallResponse]:
+    assert tokenizer.tool_call_start is not None
+    assert tokenizer.tool_call_end is not None
+    tool_parser = cast(
+        Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None,
+        tokenizer.tool_parser,
+    )  # first arg has followup args, but we don't care about them rn
+    assert tool_parser is not None
+    in_tool_call = False
+    tool_call_text_parts: list[str] = []
+    for response in responses:
+        # assumption: the tool call start is one token
+        if response.text == tokenizer.tool_call_start:
+            in_tool_call = True
+            continue
+        # assumption: the tool call end is one token
+        if in_tool_call and response.text == tokenizer.tool_call_end:
+            try:
+                # tool_parser returns an arbitrarily nested python dictionary
+                # we actually don't want the python dictionary, we just want to
+                # parse the top level { function: ..., arguments: ... } structure
+                # as we're just gonna hand it back to the api anyway
+                parsed = tool_parser("".join(tool_call_text_parts).strip())
+                logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
+                if isinstance(parsed, list):
+                    tools = [_validate_single_tool(tool) for tool in parsed]
+                else:
+                    tools = [_validate_single_tool(parsed)]
+                yield ToolCallResponse(tool_calls=tools)
+
+            except (json.JSONDecodeError, ValidationError) as e:
+                logger.opt(exception=e).warning("tool call parsing failed")
+                # assumption: talking about tool calls, not making a tool call
+                response.text = (
+                    tokenizer.tool_call_start
+                    + "".join(tool_call_text_parts)
+                    + tokenizer.tool_call_end
+                )
+                yield response
+
+            in_tool_call = False
+            tool_call_text_parts = []
+            continue
+
+        if in_tool_call:
+            tool_call_text_parts.append(response.text)
+            continue
+        # fallthrough
+        yield response
+
+
+def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
+    """
+    Version of to-be-upstreamed kimi-k2 tool parser
+    """
+    import ast
+    import json
+    from typing import Any
+
+    import regex as re
+
+    # kimi has a fixed function naming scheme, with a json formatted arg
+    # functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
+    _func_name_regex = re.compile(
+        r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", re.DOTALL
+    )
+    _func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
+
+    # kimi has a tool_calls_section - we're leaving this up to the caller to handle
+    tool_call_start = "<|tool_call_begin|>"
+    tool_call_end = "<|tool_call_end|>"
+
+    def _deserialize(value: str) -> Any:  # pyright: ignore[reportAny]
+        try:
+            return json.loads(value)  # pyright: ignore[reportAny]
+        except Exception:
+            pass
+
+        try:
+            return ast.literal_eval(value)  # pyright: ignore[reportAny]
+        except Exception:
+            pass
+        return value
+
+    def parse_tool_call(text: str, tools: Any | None = None):
+        func_name = _func_name_regex.search(text).group(1)  # pyright: ignore[reportOptionalMemberAccess]
+        # strip off the `functions.` prefix, if it exists.
+        func_name = func_name[func_name.find(".") + 1 :]
+
+        func_args = _func_arg_regex.search(text).group(1)  # pyright: ignore[reportOptionalMemberAccess]
+        # the args should be valid json - no need to check against our tools to deserialize
+        arg_dct = _deserialize(func_args)  # pyright: ignore[reportAny]
+
+        return dict(name=func_name, arguments=arg_dct)  # pyright: ignore[reportAny]
+
+    tokenizer._tool_call_start = tool_call_start
+    tokenizer._tool_call_end = tool_call_end
+    tokenizer._tool_parser = parse_tool_call
+
+
+def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
+    if (
+        ((name := obj.get("name")) is not None)
+        and ((args := obj.get("arguments")) is not None)
+        and isinstance(name, str)
+    ):
+        return ToolCallItem(name=name, arguments=json.dumps(args))
+    else:
+        raise ValidationError
+
+
 EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
 EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
 EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
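Two standalone sketches of the machinery above. First, the parse_tool_calls state machine with stub stand-ins for GenerationResponse and the mlx-lm TokenizerWrapper (the real tool_parser is model-specific; plain json.loads is used here for illustration):

import json
from dataclasses import dataclass


@dataclass
class Resp:
    text: str


class Tok:
    tool_call_start = "<tool_call>"
    tool_call_end = "</tool_call>"
    tool_parser = staticmethod(json.loads)


def parse(stream, tok):
    in_call, parts = False, []
    for r in stream:
        if r.text == tok.tool_call_start:  # enter tool-call mode
            in_call = True
        elif in_call and r.text == tok.tool_call_end:  # close and parse
            yield ("tool_call", tok.tool_parser("".join(parts).strip()))
            in_call, parts = False, []
        elif in_call:  # buffer text between the markers
            parts.append(r.text)
        else:  # plain token, passed through
            yield ("token", r.text)


stream = [
    Resp("Hi "),
    Resp("<tool_call>"),
    Resp('{"name": "multiply", "arguments": {"a": 2, "b": 3}}'),
    Resp("</tool_call>"),
]
for item in parse(iter(stream), Tok):
    print(item)
# ('token', 'Hi ') then ('tool_call', {'name': 'multiply', 'arguments': {'a': 2, 'b': 3}})

Second, the Kimi-K2 regexes applied to the example string from the comment in patch_kimi_tokenizer (stdlib re is used here; the diff uses the third-party regex module, but these patterns behave the same in both):

import json
import re

func_name_re = re.compile(r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", re.DOTALL)
func_arg_re = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)

text = 'functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}'
name = func_name_re.search(text).group(1)  # 'functions.multiply'
name = name[name.find(".") + 1 :]  # 'multiply'
args = json.loads(func_arg_re.search(text).group(1))  # {'a': 2, 'b': 3}
print(name, args)  # multiply {'a': 2, 'b': 3}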
@@ -111,7 +111,7 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
 def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
     # initialize_mlx returns a "group" equal to 1
     monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
-    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
+    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
     monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
     monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
     # Mock apply_chat_template since we're using a fake tokenizer (integer 1).
@@ -140,6 +140,12 @@ class EventCollector:
         pass
 
 
+class MockTokenizer:
+    tool_parser = None
+    tool_call_start = None
+    tool_call_end = None
+
+
 def _run(tasks: Iterable[Task]):
     bound_instance = get_bound_mlx_ring_instance(
         instance_id=INSTANCE_1_ID,