mirror of https://github.com/exo-explore/exo.git
synced 2026-01-20 20:10:10 -05:00

Comparing commits: leo/reduce ... tool-calli (1 commit)

Commit 53aec2fe25
Changes in the OpenAI-compatible chat completions API module (the file path is not preserved in this extract):

@@ -44,7 +44,7 @@ from exo.shared.types.api import (
    PlacementPreviewResponse,
    StreamingChoiceResponse,
)
-from exo.shared.types.chunks import TokenChunk
+from exo.shared.types.chunks import TokenChunk, ToolCallChunk
from exo.shared.types.commands import (
    ChatCompletion,
    Command,

@@ -73,7 +73,7 @@ from exo.utils.event_buffer import OrderedBuffer


def chunk_to_response(
-    chunk: TokenChunk, command_id: CommandId
+    chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
    return ChatCompletionResponse(
        id=command_id,

@@ -85,6 +85,15 @@ def chunk_to_response(
                delta=ChatCompletionMessage(role="assistant", content=chunk.text),
                finish_reason=chunk.finish_reason,
            )
+            if isinstance(chunk, TokenChunk)
+            else StreamingChoiceResponse(
+                index=0,
+                delta=ChatCompletionMessage(
+                    role="assistant",
+                    tool_calls=[tool.model_dump() for tool in chunk.tool_calls],
+                ),
+                finish_reason="tool_calls",
+            )
        ],
    )
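For context on the response shape chunk_to_response now has to produce for tool calls, here is a minimal, self-contained sketch using simplified pydantic stand-ins (these are not exo's actual ChatCompletionResponse/StreamingChoiceResponse models; field names and defaults are assumptions): a TokenChunk maps to a content delta, a ToolCallChunk maps to a tool_calls delta with finish_reason "tool_calls".

# Minimal sketch with simplified stand-in models; exo's real streaming
# response models may have different fields and defaults.
from pydantic import BaseModel


class ToolCall(BaseModel):
    name: str
    arguments: str  # JSON-encoded argument string, OpenAI style


class Delta(BaseModel):
    role: str = "assistant"
    content: str | None = None
    tool_calls: list[ToolCall] | None = None


class Choice(BaseModel):
    index: int = 0
    delta: Delta
    finish_reason: str | None = None


# Text token chunk -> content delta, no finish reason yet.
text_choice = Choice(delta=Delta(content="Hello"))

# Tool call chunk -> tool_calls delta, finished with "tool_calls".
tool_choice = Choice(
    delta=Delta(tool_calls=[ToolCall(name="get_weather", arguments='{"city": "Paris"}')]),
    finish_reason="tool_calls",
)

print(text_choice.model_dump_json())
print(tool_choice.model_dump_json())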
@@ -138,7 +147,9 @@ class API:
            name="dashboard",
        )

-        self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
+        self._chat_completion_queues: dict[
+            CommandId, Sender[TokenChunk | ToolCallChunk]
+        ] = {}
        self._tg: TaskGroup | None = None

    def reset(self, new_session_id: SessionId, result_clock: int):

@@ -404,16 +415,21 @@ class API:

    async def _chat_chunk_stream(
        self, command_id: CommandId
-    ) -> AsyncGenerator[TokenChunk, None]:
+    ) -> AsyncGenerator[TokenChunk | ToolCallChunk, None]:
        """Yield `TokenChunk`s for a given command until completion."""

        try:
-            self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
+            self._chat_completion_queues[command_id], recv = channel[
+                TokenChunk | ToolCallChunk
+            ]()

            with recv as token_chunks:
                async for chunk in token_chunks:
                    yield chunk
-                    if chunk.finish_reason is not None:
+                    if (
+                        isinstance(chunk, TokenChunk)
+                        and chunk.finish_reason is not None
+                    ):
                        break

        except anyio.get_cancelled_exc_class():
@@ -435,7 +451,7 @@ class API:
        """Generate chat completion stream as JSON strings."""

        async for chunk in self._chat_chunk_stream(command_id):
-            if chunk.finish_reason == "error":
+            if isinstance(chunk, TokenChunk) and chunk.finish_reason == "error":
                error_response = ErrorResponse(
                    error=ErrorInfo(
                        message=chunk.error_message or "Internal server error",

@@ -454,7 +470,7 @@ class API:

            yield f"data: {chunk_response.model_dump_json()}\n\n"

-            if chunk.finish_reason is not None:
+            if isinstance(chunk, ToolCallChunk) or chunk.finish_reason is not None:
                yield "data: [DONE]\n\n"

    async def _collect_chat_completion(

@@ -467,6 +483,24 @@ class API:
        finish_reason: FinishReason | None = None

        async for chunk in self._chat_chunk_stream(command_id):
+            if isinstance(chunk, ToolCallChunk):
+                finish_reason = "tool_calls"
+                return ChatCompletionResponse(
+                    id=command_id,
+                    created=int(time.time()),
+                    model=model or chunk.model,
+                    choices=[
+                        ChatCompletionChoice(
+                            index=0,
+                            message=ChatCompletionMessage(
+                                role="assistant",
+                                tool_calls=[
+                                    tool.model_dump() for tool in chunk.tool_calls
+                                ],
+                            ),
+                        )
+                    ],
+                )
            if chunk.finish_reason == "error":
                raise HTTPException(
                    status_code=500,
@@ -510,6 +544,11 @@ class API:
        stats: GenerationStats | None = None

        async for chunk in self._chat_chunk_stream(command_id):
+            if isinstance(chunk, ToolCallChunk):
+                raise HTTPException(
+                    status_code=500,
+                    detail="Tool call in bench",
+                )
            if chunk.finish_reason == "error":
                raise HTTPException(
                    status_code=500,

@@ -660,7 +699,7 @@ class API:
            self._event_log.append(event)
            self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
            if isinstance(event, ChunkGenerated):
-                assert isinstance(event.chunk, TokenChunk)
+                assert isinstance(event.chunk, (TokenChunk, ToolCallChunk))
                queue = self._chat_completion_queues.get(event.command_id)
                if queue is not None:
                    try:
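The _chat_completion_queues change above widens the per-command fan-out channel to carry both chunk types. As a rough illustration of that pattern (one in-memory channel per command id, filled by the event-applying side and drained by the streaming endpoint), here is a self-contained anyio sketch; the consume/produce helpers and dict-shaped chunks are hypothetical stand-ins for exo's API class and chunk models, and it assumes anyio >= 4 for the subscriptable create_memory_object_stream.

# A minimal sketch of the queue-per-command fan-out, not exo's implementation.
import anyio
from anyio.abc import ObjectSendStream


async def main() -> None:
    # One send side per in-flight command, keyed by command id.
    queues: dict[str, ObjectSendStream[dict]] = {}

    send, recv = anyio.create_memory_object_stream[dict](64)
    queues["cmd-1"] = send

    async def consume() -> None:
        # The HTTP handler side: drain chunks until a finish reason arrives.
        async with recv:
            async for chunk in recv:
                print("cmd-1", chunk)
                if chunk.get("finish_reason") is not None:
                    break

    async def produce() -> None:
        # The event-applying side: look up the queue by command id and push.
        async with queues["cmd-1"]:
            await queues["cmd-1"].send({"text": "hi", "finish_reason": None})
            await queues["cmd-1"].send({"text": "", "finish_reason": "stop"})

    async with anyio.create_task_group() as tg:
        tg.start_soon(consume)
        tg.start_soon(produce)


anyio.run(main)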
Change in the placement module (a separate file; only the log level changes):

@@ -120,7 +120,7 @@ def place_instance(
    target_instances = dict(deepcopy(current_instances))

    if len(selected_cycle) == 1:
-        logger.debug(
+        logger.warning(
            "You have likely selected jaccl for a single node instance; falling back to MlxRing"
        )

Changes in exo.shared.types.chunks (module path inferred from the imports above):

@@ -5,6 +5,7 @@ from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel

from .api import FinishReason
+from .worker.runner_response import ToolCallItem


class ChunkType(str, Enum):

@@ -25,8 +26,12 @@ class TokenChunk(BaseChunk):
    error_message: str | None = None


+class ToolCallChunk(BaseChunk):
+    tool_calls: list[ToolCallItem]
+
+
class ImageChunk(BaseChunk):
    data: bytes


-GenerationChunk = TokenChunk | ImageChunk
+GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk
Changes in exo.shared.types.worker.runner_response (module path inferred from the imports above):

@@ -1,5 +1,5 @@
from exo.shared.types.api import FinishReason, GenerationStats
-from exo.utils.pydantic_ext import TaggedModel
+from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel


class BaseRunnerResponse(TaggedModel):

@@ -10,6 +10,11 @@ class TokenizedResponse(BaseRunnerResponse):
    prompt_tokens: int


+class ToolCallItem(CamelCaseModel):
+    arguments: str
+    name: str
+
+
class GenerationResponse(BaseRunnerResponse):
    text: str
    token: int

@@ -18,5 +23,9 @@ class GenerationResponse(BaseRunnerResponse):
    stats: GenerationStats | None = None


+class ToolCallResponse(BaseRunnerResponse):
+    tool_calls: list[ToolCallItem]
+
+
class FinishedResponse(BaseRunnerResponse):
    pass
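parse_tool_calls (further down in this diff) fills ToolCallItem by serializing the tool parser's output with json.dumps and re-reading it with model_validate_json. Here is a small self-contained sketch of that round trip with a stand-in model; exo's real ToolCallItem is a CamelCaseModel, whose field aliasing is not reproduced here, and the sample parser output is an assumption.

# Sketch of the dict -> pydantic validation round trip used by parse_tool_calls.
import json

from pydantic import BaseModel, ValidationError


class StandInToolCallItem(BaseModel):
    name: str
    arguments: str  # JSON-encoded arguments, as in the diff


parsed = {"name": "get_weather", "arguments": '{"city": "Paris"}'}

try:
    item = StandInToolCallItem.model_validate_json(json.dumps(parsed))
    print(item)
except (json.JSONDecodeError, ValidationError) as exc:
    # Malformed parser output is dropped, mirroring the except clause in parse_tool_calls.
    print("invalid tool call:", exc)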
Changes in the MLX runner module (the file defining main, parse_thinking_models, and the new parse_tool_calls; which of the first hunk's import lines are newly added is not recoverable from this extract):

@@ -1,6 +1,8 @@
import json
import time
from collections.abc import Generator
from functools import cache
from typing import Any, Callable, cast

import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel

@@ -11,9 +13,10 @@ from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]
    StreamableParser,
    load_harmony_encoding,
)
+from pydantic import ValidationError

from exo.shared.types.api import ChatCompletionMessageText
-from exo.shared.types.chunks import TokenChunk
+from exo.shared.types.chunks import TokenChunk, ToolCallChunk
from exo.shared.types.events import (
    ChunkGenerated,
    Event,

@@ -33,6 +36,8 @@ from exo.shared.types.tasks import (
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
    GenerationResponse,
+    ToolCallItem,
+    ToolCallResponse,
)
from exo.shared.types.worker.runners import (
    RunnerConnected,
@@ -202,7 +207,13 @@ def main(
                mlx_generator, tokenizer
            )

-            # TODO: Add tool call parser here
+            # TODO: Add call parser here
+            if (
+                tokenizer.tool_parser  # pyright: ignore[reportAny]
+                or tokenizer.tool_call_start
+                or tokenizer.tool_call_end
+            ) is not None:
+                mlx_generator = parse_tool_calls(mlx_generator, tokenizer)

            for response in mlx_generator:
                match response:

@@ -212,7 +223,7 @@ def main(
                            ChunkGenerated(
                                command_id=command_id,
                                chunk=TokenChunk(
-                                    idx=response.token,
+                                    idx=response.token,  # hang on --- is this wrong?? this is totally wrong
                                    model=shard_metadata.model_card.model_id,
                                    text=response.text,
                                    token_id=response.token,

@@ -221,6 +232,18 @@ def main(
                                ),
                            )
                        )
+                    case ToolCallResponse():
+                        if device_rank == 0:
+                            event_sender.send(
+                                ChunkGenerated(
+                                    command_id=command_id,
+                                    chunk=ToolCallChunk(
+                                        idx=4,
+                                        tool_calls=response.tool_calls,
+                                        model=shard_metadata.model_card.model_id,
+                                    ),
+                                )
+                            )

        # can we make this more explicit?
        except Exception as e:
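The new case ToolCallResponse(): arm relies on structural pattern matching over the runner response classes. Here is a minimal stand-alone sketch of that dispatch shape, with hypothetical dataclass stand-ins rather than exo's pydantic models:

# Hypothetical stand-ins for the runner response union, to show the dispatch shape.
from dataclasses import dataclass


@dataclass
class Generation:
    text: str
    token: int


@dataclass
class ToolCalls:
    calls: list[dict]


def handle(response: Generation | ToolCalls) -> str:
    match response:
        case Generation(text=text):
            return f"token chunk: {text!r}"
        case ToolCalls(calls=calls):
            return f"tool call chunk: {len(calls)} call(s)"
    return "unhandled"


print(handle(Generation(text="hi", token=42)))
print(handle(ToolCalls(calls=[{"name": "get_weather", "arguments": "{}"}])))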
@@ -329,6 +352,50 @@ def parse_thinking_models(
        yield response


+def parse_tool_calls(
+    responses: Generator[GenerationResponse],
+    tokenizer: TokenizerWrapper,
+) -> Generator[GenerationResponse | ToolCallResponse]:
+    assert tokenizer.tool_call_start is not None
+    assert tokenizer.tool_call_end is not None
+    tool_parser = cast(
+        Callable[[str], dict[str, Any] | list[dict[str, Any]]] | None,
+        tokenizer.tool_parser,
+    )  # first arg has followup args, but we don't care about them rn
+    assert tool_parser is not None
+    in_tool_call = False
+    tool_call_text_parts: list[str] = []
+    for response in responses:
+        # assumption: the tool call start is one token
+        if response.text == tokenizer.tool_call_start:
+            in_tool_call = True
+            continue
+        if in_tool_call:
+            tool_call_text_parts.append(response.text)
+            continue
+        # assumption: the tool call end is one token
+        if response.text == tokenizer.tool_call_end:
+            try:
+                # tool_parser returns an arbitrarily nested python dictionary
+                # we actually don't want the python dictionary, we just want to
+                # parse the top level { function: ..., arguments: ... } structure
+                # as we're just gonna hand it back to the api anyway
+                parsed = tool_parser("".join(tool_call_text_parts).strip())
+                if isinstance(parsed, list):
+                    tools = [
+                        ToolCallItem.model_validate_json(json.dumps(tool))
+                        for tool in parsed
+                    ]
+                else:
+                    tools = [ToolCallItem.model_validate_json(json.dumps(parsed))]
+                yield ToolCallResponse(tool_calls=tools)
+            except (json.JSONDecodeError, ValidationError):
+                in_tool_call = False
+                tool_call_text_parts = []
+        # fallthrough
+        yield response


EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
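As a companion to parse_tool_calls, here is a simplified, self-contained sketch of the same sentinel-token technique: watch for start/end markers in the text stream, buffer what falls between them, and parse the buffer into tool calls when the end marker appears. The sentinels, the toy JSON parser, and the extract_tool_calls name are illustrative assumptions, not exo's tokenizer API; note that this sketch checks the end marker before the buffering branch, unlike the hunk above.

# Illustrative sketch of sentinel-delimited tool-call extraction; everything
# below is a toy assumption, not exo's tokenizer or parser.
import json
from collections.abc import Iterator

TOOL_CALL_START = "<tool_call>"
TOOL_CALL_END = "</tool_call>"


def extract_tool_calls(pieces: Iterator[str]) -> Iterator[str | list[dict]]:
    in_tool_call = False
    buffer: list[str] = []
    for piece in pieces:
        if piece == TOOL_CALL_START:
            in_tool_call = True
            continue
        if piece == TOOL_CALL_END:
            # End of the span: parse whatever was buffered.
            try:
                parsed = json.loads("".join(buffer).strip())
                yield parsed if isinstance(parsed, list) else [parsed]
            except json.JSONDecodeError:
                pass  # drop malformed tool calls, as the diff does
            in_tool_call = False
            buffer = []
            continue
        if in_tool_call:
            buffer.append(piece)
            continue
        yield piece  # plain text falls through unchanged


stream = [
    "Hello ",
    TOOL_CALL_START,
    '{"name": "get_weather", ',
    '"arguments": "{}"}',
    TOOL_CALL_END,
    "!",
]
print(list(extract_tool_calls(iter(stream))))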