mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-19 23:36:30 -05:00
Compare commits
5 Commits
feat/bug-r
...
leo/tmp-te
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
11f9e2e1d8 | ||
|
|
870d53329b | ||
|
|
4dc8654ee7 | ||
|
|
76d7994fa1 | ||
|
|
2c1b8f79b2 |
@@ -3,7 +3,6 @@
|
|||||||
import time
|
import time
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
from exo.shared.types.api import (
|
from exo.shared.types.api import (
|
||||||
ChatCompletionChoice,
|
ChatCompletionChoice,
|
||||||
@@ -141,7 +140,7 @@ async def generate_chat_stream(
|
|||||||
if isinstance(chunk, ToolCallChunk):
|
if isinstance(chunk, ToolCallChunk):
|
||||||
tool_call_deltas = [
|
tool_call_deltas = [
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=str(uuid4()),
|
id=tool.id,
|
||||||
index=i,
|
index=i,
|
||||||
function=tool,
|
function=tool,
|
||||||
)
|
)
|
||||||
@@ -207,7 +206,7 @@ async def collect_chat_response(
|
|||||||
if isinstance(chunk, ToolCallChunk):
|
if isinstance(chunk, ToolCallChunk):
|
||||||
tool_calls.extend(
|
tool_calls.extend(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=str(uuid4()),
|
id=tool.id,
|
||||||
index=i,
|
index=i,
|
||||||
function=tool,
|
function=tool,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
import json
|
import json
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
from exo.shared.types.api import FinishReason
|
from exo.shared.types.api import FinishReason
|
||||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||||
@@ -179,7 +178,7 @@ async def collect_claude_response(
|
|||||||
for tool in chunk.tool_calls:
|
for tool in chunk.tool_calls:
|
||||||
tool_use_blocks.append(
|
tool_use_blocks.append(
|
||||||
ClaudeToolUseBlock(
|
ClaudeToolUseBlock(
|
||||||
id=f"toolu_{uuid4().hex[:24]}",
|
id=f"toolu_{tool.id}",
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
|
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
|
||||||
)
|
)
|
||||||
@@ -264,7 +263,7 @@ async def generate_claude_stream(
|
|||||||
|
|
||||||
# Emit tool_use content blocks
|
# Emit tool_use content blocks
|
||||||
for tool in chunk.tool_calls:
|
for tool in chunk.tool_calls:
|
||||||
tool_id = f"toolu_{uuid4().hex[:24]}"
|
tool_id = f"toolu_{tool.id}"
|
||||||
tool_input_json = tool.arguments
|
tool_input_json = tool.arguments
|
||||||
|
|
||||||
# content_block_start for tool_use
|
# content_block_start for tool_use
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from itertools import count
|
from itertools import count
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||||
from exo.shared.types.common import CommandId
|
from exo.shared.types.common import CommandId
|
||||||
@@ -140,8 +139,8 @@ async def collect_responses_response(
|
|||||||
for tool in chunk.tool_calls:
|
for tool in chunk.tool_calls:
|
||||||
function_call_items.append(
|
function_call_items.append(
|
||||||
ResponseFunctionCallItem(
|
ResponseFunctionCallItem(
|
||||||
id=f"fc_{uuid4().hex[:24]}",
|
id=f"fc_{tool.id}",
|
||||||
call_id=f"call_{uuid4().hex[:24]}",
|
call_id=f"call_{tool.id}",
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
arguments=tool.arguments,
|
arguments=tool.arguments,
|
||||||
)
|
)
|
||||||
@@ -246,8 +245,8 @@ async def generate_responses_stream(
|
|||||||
if isinstance(chunk, ToolCallChunk):
|
if isinstance(chunk, ToolCallChunk):
|
||||||
last_stats = chunk.stats or last_stats
|
last_stats = chunk.stats or last_stats
|
||||||
for tool in chunk.tool_calls:
|
for tool in chunk.tool_calls:
|
||||||
fc_id = f"fc_{uuid4().hex[:24]}"
|
fc_id = f"fc_{tool.id}"
|
||||||
call_id = f"call_{uuid4().hex[:24]}"
|
call_id = f"call_{tool.id}"
|
||||||
|
|
||||||
# response.output_item.added for function_call
|
# response.output_item.added for function_call
|
||||||
fc_item = ResponseFunctionCallItem(
|
fc_item = ResponseFunctionCallItem(
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from pydantic_core import PydanticUseDefault
|
from pydantic_core import PydanticUseDefault
|
||||||
@@ -60,6 +61,7 @@ class ChatCompletionMessageText(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ToolCallItem(BaseModel):
|
class ToolCallItem(BaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||||
name: str
|
name: str
|
||||||
arguments: str
|
arguments: str
|
||||||
|
|
||||||
|
|||||||
@@ -298,6 +298,9 @@ def mlx_generate(
|
|||||||
)
|
)
|
||||||
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
max_stop_len = max((len(s) for s in stop_sequences), default=0)
|
||||||
|
|
||||||
|
mx_barrier(group)
|
||||||
|
logger.info("Ready to prefill")
|
||||||
|
|
||||||
# Prefill cache with all tokens except the last one
|
# Prefill cache with all tokens except the last one
|
||||||
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
|
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
|
||||||
model,
|
model,
|
||||||
|
|||||||
@@ -810,8 +810,9 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
|||||||
|
|
||||||
# kimi has a fixed function naming scheme, with a json formatted arg
|
# kimi has a fixed function naming scheme, with a json formatted arg
|
||||||
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
|
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
|
||||||
|
# Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
|
||||||
_func_name_regex = re.compile(
|
_func_name_regex = re.compile(
|
||||||
r"^\s*(.+):\d+\s*<\|tool_call_argument_begin\|>", re.DOTALL
|
r"^\s*(.+)[:_](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
|
||||||
)
|
)
|
||||||
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
|
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
|
||||||
|
|
||||||
@@ -836,6 +837,7 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
|||||||
if func_name_match is None:
|
if func_name_match is None:
|
||||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||||
func_name = func_name_match.group(1)
|
func_name = func_name_match.group(1)
|
||||||
|
tool_id = func_name_match.group(2)
|
||||||
# strip off the `functions.` prefix, if it exists.
|
# strip off the `functions.` prefix, if it exists.
|
||||||
func_name = func_name[func_name.find(".") + 1 :]
|
func_name = func_name[func_name.find(".") + 1 :]
|
||||||
|
|
||||||
@@ -846,7 +848,7 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
|||||||
# the args should be valid json - no need to check against our tools to deserialize
|
# the args should be valid json - no need to check against our tools to deserialize
|
||||||
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
|
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
|
||||||
|
|
||||||
return dict(name=func_name, arguments=arg_dct) # pyright: ignore[reportAny]
|
return dict(id=tool_id, name=func_name, arguments=arg_dct) # pyright: ignore[reportAny]
|
||||||
|
|
||||||
tokenizer._tool_call_start = tool_call_start
|
tokenizer._tool_call_start = tool_call_start
|
||||||
tokenizer._tool_call_end = tool_call_end
|
tokenizer._tool_call_end = tool_call_end
|
||||||
@@ -929,7 +931,13 @@ def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
|
|||||||
and ((args := obj.get("arguments")) is not None)
|
and ((args := obj.get("arguments")) is not None)
|
||||||
and isinstance(name, str)
|
and isinstance(name, str)
|
||||||
):
|
):
|
||||||
return ToolCallItem(name=name, arguments=json.dumps(args))
|
raw_id: object = obj.get("id")
|
||||||
|
extra = {"id": str(raw_id)} if raw_id is not None else {}
|
||||||
|
return ToolCallItem(
|
||||||
|
**extra,
|
||||||
|
name=name,
|
||||||
|
arguments=json.dumps(args),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValidationError
|
raise ValidationError
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user