Compare commits


1 Commit

Author: Evan | SHA1: 53aec2fe25 | Message: implement mlx-lm tool calling | Date: 2026-01-21 00:27:20 +00:00
4 changed files with 134 additions and 14 deletions

View File

@@ -44,7 +44,7 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.chunks import TokenChunk, ToolCallChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
@@ -73,7 +73,7 @@ from exo.utils.event_buffer import OrderedBuffer
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
@@ -85,6 +85,15 @@ def chunk_to_response(
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
finish_reason=chunk.finish_reason,
)
if isinstance(chunk, TokenChunk)
else StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=[tool.model_dump() for tool in chunk.tool_calls],
),
finish_reason="tool_calls",
)
],
)
@@ -138,7 +147,9 @@ class API:
name="dashboard",
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._chat_completion_queues: dict[
CommandId, Sender[TokenChunk | ToolCallChunk]
] = {}
self._tg: TaskGroup | None = None
def reset(self, new_session_id: SessionId, result_clock: int):
@@ -404,16 +415,21 @@ class API:
async def _chat_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
) -> AsyncGenerator[TokenChunk | ToolCallChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
try:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
self._chat_completion_queues[command_id], recv = channel[
TokenChunk | ToolCallChunk
]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
if (
isinstance(chunk, TokenChunk)
and chunk.finish_reason is not None
):
break
except anyio.get_cancelled_exc_class():
@@ -435,7 +451,7 @@ class API:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
if isinstance(chunk, TokenChunk) and chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
@@ -454,7 +470,7 @@ class API:
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
if isinstance(chunk, ToolCallChunk) or chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
@@ -467,6 +483,24 @@ class API:
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ToolCallChunk):
finish_reason = "tool_calls"
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model or chunk.model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
tool_calls=[
tool.model_dump() for tool in chunk.tool_calls
],
),
)
],
)
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -510,6 +544,11 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ToolCallChunk):
raise HTTPException(
status_code=500,
detail="Tool call in bench",
)
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -660,7 +699,7 @@ class API:
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
assert isinstance(event.chunk, TokenChunk)
assert isinstance(event.chunk, (TokenChunk, ToolCallChunk))
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:

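For illustration, a minimal sketch of the two streaming payload shapes chunk_to_response now yields over SSE, assuming the exo response models serialize to the usual OpenAI streaming fields; the command id, names, and arguments below are invented, and other top-level fields (model, created, and so on) are omitted:

# Illustrative payloads only; ids, names, and arguments are made up.
token_delta = {
    "id": "cmd-123",  # command_id
    "choices": [
        {
            "index": 0,
            "delta": {"role": "assistant", "content": "Hello"},
            "finish_reason": None,
        }
    ],
}

# Emitted when the runner reports a ToolCallChunk: the delta carries the dumped
# ToolCallItem fields and finish_reason is forced to "tool_calls".
tool_call_delta = {
    "id": "cmd-123",
    "choices": [
        {
            "index": 0,
            "delta": {
                "role": "assistant",
                "tool_calls": [
                    {"name": "get_weather", "arguments": '{"city": "Paris"}'}
                ],
            },
            "finish_reason": "tool_calls",
        }
    ],
}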
View File

@@ -5,6 +5,7 @@ from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .worker.runner_response import ToolCallItem
class ChunkType(str, Enum):
@@ -25,8 +26,12 @@ class TokenChunk(BaseChunk):
error_message: str | None = None
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
class ImageChunk(BaseChunk):
data: bytes
GenerationChunk = TokenChunk | ImageChunk
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk

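A minimal construction sketch for the new chunk type, assuming BaseChunk needs only the idx and model fields that the runner call site passes later in this diff; the model id and tool call values are invented:

from exo.shared.types.chunks import ToolCallChunk
from exo.shared.types.worker.runner_response import ToolCallItem

# Hypothetical values; mirrors the ToolCallChunk(...) construction in the runner.
chunk = ToolCallChunk(
    idx=0,
    model="example/model-id",  # placeholder, not a real model card id
    tool_calls=[ToolCallItem(name="get_weather", arguments='{"city": "Paris"}')],
)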
View File

@@ -1,5 +1,5 @@
from exo.shared.types.api import FinishReason, GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
class BaseRunnerResponse(TaggedModel):
@@ -10,6 +10,11 @@ class TokenizedResponse(BaseRunnerResponse):
prompt_tokens: int
class ToolCallItem(CamelCaseModel):
arguments: str
name: str
class GenerationResponse(BaseRunnerResponse):
text: str
token: int
@@ -18,5 +23,9 @@ class GenerationResponse(BaseRunnerResponse):
stats: GenerationStats | None = None
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
class FinishedResponse(BaseRunnerResponse):
pass

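A short round-trip sketch for the runner-side types, assuming the tokenizer's tool parser hands back the top-level {name, arguments} structure described in the runner comments; the values are invented:

import json

from exo.shared.types.worker.runner_response import ToolCallItem, ToolCallResponse

# Hypothetical parser output for a single call; arguments stays a JSON string.
raw = {"name": "get_weather", "arguments": '{"city": "Paris"}'}

item = ToolCallItem.model_validate_json(json.dumps(raw))
assert item.name == "get_weather"
assert item.arguments == '{"city": "Paris"}'

response = ToolCallResponse(tool_calls=[item])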
View File

@@ -1,6 +1,8 @@
import json
import time
from collections.abc import Generator
from functools import cache
from typing import Any, Callable, cast
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -11,9 +13,10 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
StreamableParser,
load_harmony_encoding,
)
from pydantic import ValidationError
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.chunks import TokenChunk, ToolCallChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -33,6 +36,8 @@ from exo.shared.types.tasks import (
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ToolCallItem,
ToolCallResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -202,7 +207,13 @@ def main(
mlx_generator, tokenizer
)
# TODO: Add tool call parser here
# TODO: Add call parser here
if (
    tokenizer.tool_parser is not None  # pyright: ignore[reportAny]
    and tokenizer.tool_call_start is not None
    and tokenizer.tool_call_end is not None
):
mlx_generator = parse_tool_calls(mlx_generator, tokenizer)
for response in mlx_generator:
match response:
@@ -212,7 +223,7 @@ def main(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
idx=response.token,  # FIXME: this uses the token id as the chunk idx, which looks wrong
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
@@ -221,6 +232,18 @@ def main(
),
)
)
case ToolCallResponse():
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ToolCallChunk(
idx=4,  # arbitrary placeholder; see the idx FIXME above
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
),
)
)
# can we make this more explicit?
except Exception as e:
@@ -329,6 +352,50 @@ def parse_thinking_models(
yield response
def parse_tool_calls(
responses: Generator[GenerationResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse | ToolCallResponse]:
assert tokenizer.tool_call_start is not None
assert tokenizer.tool_call_end is not None
tool_parser = cast(
Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None,
tokenizer.tool_parser,
)  # the real parser takes more arguments after the text, but we don't need them here
assert tool_parser is not None
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
    # assumption: the tool call start is one token
    if response.text == tokenizer.tool_call_start:
        in_tool_call = True
        continue
    if in_tool_call:
        # assumption: the tool call end is one token
        if response.text == tokenizer.tool_call_end:
            try:
                # tool_parser returns an arbitrarily nested python dictionary
                # we actually don't want the python dictionary, we just want to
                # parse the top level { function: ..., arguments: ... } structure
                # as we're just gonna hand it back to the api anyway
                parsed = tool_parser("".join(tool_call_text_parts).strip())
                if isinstance(parsed, list):
                    tools = [
                        ToolCallItem.model_validate_json(json.dumps(tool))
                        for tool in parsed
                    ]
                else:
                    tools = [ToolCallItem.model_validate_json(json.dumps(parsed))]
                yield ToolCallResponse(tool_calls=tools)
            except (json.JSONDecodeError, ValidationError):
                pass  # malformed tool call: drop it and resume normal streaming
            in_tool_call = False
            tool_call_text_parts = []
        else:
            tool_call_text_parts.append(response.text)
        continue
    yield response
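To see the accumulate-between-markers idea in isolation, here is a self-contained sketch with the exo types replaced by plain strings and dicts, and the tokenizer's parser replaced by json.loads; the marker strings and the sample stream are stand-ins, not the real tokenizer interface:

import json
from collections.abc import Iterator

TOOL_CALL_START = "<tool_call>"  # stand-in for tokenizer.tool_call_start
TOOL_CALL_END = "</tool_call>"  # stand-in for tokenizer.tool_call_end


def parse_tool_calls_demo(tokens: Iterator[str]) -> Iterator[str | dict]:
    """Buffer text between the start/end markers and emit the parsed tool call."""
    in_tool_call = False
    parts: list[str] = []
    for text in tokens:
        if text == TOOL_CALL_START:
            in_tool_call = True
            continue
        if in_tool_call:
            if text == TOOL_CALL_END:
                try:
                    yield json.loads("".join(parts).strip())
                except json.JSONDecodeError:
                    pass  # malformed call: drop it, keep streaming
                in_tool_call = False
                parts = []
            else:
                parts.append(text)
            continue
        yield text


stream = [
    "Sure, ",
    TOOL_CALL_START,
    '{"name": "get_weather", ',
    '"arguments": "{\\"city\\": \\"Paris\\"}"}',
    TOOL_CALL_END,
]
print(list(parse_tool_calls_demo(iter(stream))))
# ['Sure, ', {'name': 'get_weather', 'arguments': '{"city": "Paris"}'}]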
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"