Mirror of https://github.com/exo-explore/exo.git, synced 2026-01-21 12:30:22 -05:00
Compare commits
2 Commits
tool-calli ... leo/add-to
| Author | SHA1 | Date |
|---|---|---|
| | 826da9512d | |
| | 4d9114b9b5 | |
378 bench/exo_eval.py Normal file
@@ -0,0 +1,378 @@
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false, reportMissingTypeStubs=false
"""
exo-eval: Run SWE-bench evaluation against exo using OpenHands SDK (local, no Docker).
"""
from __future__ import annotations

import argparse
import json
import os
import subprocess
import tempfile
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any

from datasets import load_dataset
from loguru import logger
from openhands.sdk import LLM, Agent, Conversation, Tool
from openhands.tools.file_editor import FileEditorTool
from openhands.tools.terminal import TerminalTool


class EvalStatus(str, Enum):
    Resolved = "Resolved"
    Failed = "Failed"
    Error = "Error"
    Timeout = "Timeout"


@dataclass
class EvalResult:
    instance_id: str
    repo: str
    status: EvalStatus
    elapsed_seconds: float
    tests_passed: list[str]
    tests_failed: list[str]
    error_message: str | None = None


def load_swe_bench(
    split: str = "lite",
    limit: int | None = None,
    instance_ids: list[str] | None = None,
) -> list[dict[str, Any]]:
    """Load SWE-bench dataset from HuggingFace."""
    # SWE-bench Lite is a curated 300-instance subset
    dataset_name = (
        "princeton-nlp/SWE-bench_Lite" if split == "lite" else "princeton-nlp/SWE-bench"
    )
    actual_split = "test" if split == "lite" else split

    ds = load_dataset(dataset_name, split=actual_split)
    instances = [dict(row) for row in ds]

    if instance_ids:
        instances = [i for i in instances if i["instance_id"] in instance_ids]

    if limit:
        instances = instances[:limit]

    return instances


def clone_repo_at_commit(repo: str, commit: str, dest: Path) -> None:
    """Clone a repo at a specific commit."""
    repo_url = f"https://github.com/{repo}.git"

    subprocess.run(
        ["git", "clone", "--depth", "1", repo_url, str(dest)],
        check=True,
        capture_output=True,
    )

    subprocess.run(
        ["git", "fetch", "--depth", "1", "origin", commit],
        cwd=dest,
        check=True,
        capture_output=True,
    )

    subprocess.run(
        ["git", "checkout", commit],
        cwd=dest,
        check=True,
        capture_output=True,
    )


def build_agent_prompt(instance: dict[str, Any]) -> str:
    """Build the prompt for the agent."""
    return f"""You are a software engineer fixing a bug in the {instance['repo']} repository.

## Problem Statement
{instance['problem_statement']}

## Instructions
1. Explore the codebase to understand the issue
2. Identify the files that need to be modified
3. Make the necessary changes to fix the issue
4. The fix should be minimal and targeted

You have access to:
- terminal: Run shell commands (git, grep, python, etc.)
- file_editor: View and edit files

Start by exploring the repository structure to understand where the relevant code is.
"""


def parse_fail_to_pass(fail_to_pass_str: str) -> list[str]:
    """Parse the FAIL_TO_PASS field into a list of test names."""
    try:
        return json.loads(fail_to_pass_str)
    except json.JSONDecodeError:
        return [t.strip() for t in fail_to_pass_str.split(",") if t.strip()]


def run_tests(workspace: Path, tests: list[str]) -> tuple[list[str], list[str]]:
    """Run tests and return (passed, failed) lists."""
    passed = []
    failed = []

    for test in tests:
        try:
            result = subprocess.run(
                ["python", "-m", "pytest", "-xvs", test],
                cwd=workspace,
                capture_output=True,
                timeout=300,
            )
            if result.returncode == 0:
                passed.append(test)
            else:
                failed.append(test)
        except subprocess.TimeoutExpired:
            failed.append(test)

    return passed, failed


def run_single_eval(
    instance: dict[str, Any],
    host: str,
    port: int,
    model: str,
    max_turns: int = 30,
    timeout: float = 600.0,
) -> EvalResult:
    """Evaluate a single SWE-bench instance."""
    instance_id = instance["instance_id"]
    repo = instance["repo"]
    base_commit = instance["base_commit"]
    fail_to_pass = parse_fail_to_pass(instance["FAIL_TO_PASS"])

    start_time = time.perf_counter()

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            workspace = Path(tmpdir) / "repo"

            # Clone repo at base commit
            logger.info(f"Cloning {repo} at {base_commit[:8]}...")
            clone_repo_at_commit(repo, base_commit, workspace)

            # Setup OpenHands agent
            llm = LLM(
                model=f"openai/{model}",
                base_url=f"http://{host}:{port}/v1",
                api_key="not-needed",
            )

            agent = Agent(
                llm=llm,
                tools=[
                    Tool(name=TerminalTool.name),
                    Tool(name=FileEditorTool.name),
                ],
            )

            # Run agent
            conversation = Conversation(
                agent=agent,
                workspace=str(workspace),
            )

            logger.info(f"Running agent on {instance_id}...")
            conversation.send_message(build_agent_prompt(instance))

            for _turn in range(max_turns):
                if time.perf_counter() - start_time > timeout:
                    return EvalResult(
                        instance_id=instance_id,
                        repo=repo,
                        status=EvalStatus.Timeout,
                        elapsed_seconds=time.perf_counter() - start_time,
                        tests_passed=[],
                        tests_failed=fail_to_pass,
                    )

                result = conversation.run(max_turns=1)
                if result.done:
                    break

            # Run tests to verify
            logger.info(f"Running tests for {instance_id}...")
            passed, failed = run_tests(workspace, fail_to_pass)

            elapsed = time.perf_counter() - start_time
            status = EvalStatus.Resolved if not failed else EvalStatus.Failed

            return EvalResult(
                instance_id=instance_id,
                repo=repo,
                status=status,
                elapsed_seconds=elapsed,
                tests_passed=passed,
                tests_failed=failed,
            )

    except Exception as e:
        return EvalResult(
            instance_id=instance_id,
            repo=repo,
            status=EvalStatus.Error,
            elapsed_seconds=time.perf_counter() - start_time,
            tests_passed=[],
            tests_failed=[],
            error_message=str(e),
        )


def verify_exo_running(host: str, port: int, model: str) -> str:
    """Verify exo is running and return full model ID."""
    import http.client

    conn = http.client.HTTPConnection(host, port, timeout=10)
    conn.request("GET", "/models")
    resp = conn.getresponse()

    if resp.status != 200:
        raise RuntimeError(f"exo not responding at {host}:{port}")

    data = json.loads(resp.read())
    for m in data.get("data", []):
        if m.get("id") == model or m.get("hugging_face_id") == model:
            return m.get("hugging_face_id") or m.get("id")

    raise ValueError(f"Model '{model}' not found in exo")


def main() -> int:
    ap = argparse.ArgumentParser(
        prog="exo-eval",
        description="Run SWE-bench evaluation against exo (local, no Docker).",
    )

    ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
    ap.add_argument(
        "--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
    )
    ap.add_argument("--model", required=True, help="exo model ID")
    ap.add_argument(
        "--split", default="lite", choices=["lite", "dev", "test", "train"]
    )
    ap.add_argument("--limit", type=int, default=10, help="Max instances")
    ap.add_argument("--instance-ids", nargs="+", help="Specific instance IDs")
    ap.add_argument("--max-turns", type=int, default=30)
    ap.add_argument("--timeout", type=float, default=600.0)
    ap.add_argument("--json-out", default="bench/eval_results.json")
    ap.add_argument("-v", "--verbose", action="store_true")
    ap.add_argument("--dry-run", action="store_true")

    args = ap.parse_args()

    # Load dataset first (doesn't require exo to be running)
    logger.info(f"Loading SWE-bench {args.split} dataset...")
    instances = load_swe_bench(
        split=args.split,
        limit=args.limit,
        instance_ids=args.instance_ids,
    )
    logger.info(f"Loaded {len(instances)} instances")

    if args.dry_run:
        print(f"\nSWE-bench {args.split} instances ({len(instances)}):")
        for inst in instances:
            print(f" {inst['instance_id']} ({inst['repo']})")
        return 0

    # Verify exo is running
    model_id = verify_exo_running(args.host, args.port, args.model)
    logger.info(f"Using model: {model_id}")

    # Run evaluation
    results: list[EvalResult] = []
    for i, instance in enumerate(instances):
        logger.info(f"[{i+1}/{len(instances)}] {instance['instance_id']}")

        result = run_single_eval(
            instance=instance,
            host=args.host,
            port=args.port,
            model=model_id,
            max_turns=args.max_turns,
            timeout=args.timeout,
        )
        results.append(result)

        logger.info(f" Status: {result.status.value}")
        if result.tests_passed:
            logger.info(f" Passed: {len(result.tests_passed)} tests")
        if result.tests_failed:
            logger.info(f" Failed: {len(result.tests_failed)} tests")
        if result.error_message:
            logger.error(f" Error: {result.error_message}")

    # Compute summary
    total = len(results)
    resolved = sum(1 for r in results if r.status == EvalStatus.Resolved)
    failed = sum(1 for r in results if r.status == EvalStatus.Failed)
    errors = sum(1 for r in results if r.status == EvalStatus.Error)
    timeouts = sum(1 for r in results if r.status == EvalStatus.Timeout)

    summary = {
        "model": model_id,
        "split": args.split,
        "total": total,
        "resolved": resolved,
        "resolved_rate": resolved / total if total else 0,
        "failed": failed,
        "errors": errors,
        "timeouts": timeouts,
    }

    output = {
        "summary": summary,
        "results": [
            {
                "instance_id": r.instance_id,
                "repo": r.repo,
                "status": r.status.value,
                "elapsed_seconds": r.elapsed_seconds,
                "tests_passed": r.tests_passed,
                "tests_failed": r.tests_failed,
                "error_message": r.error_message,
            }
            for r in results
        ],
    }

    Path(args.json_out).write_text(json.dumps(output, indent=2))
    logger.info(f"Results written to {args.json_out}")

    # Print summary
    print("\n" + "=" * 60)
    print("SWE-bench Evaluation Results")
    print("=" * 60)
    print(f"Model: {model_id}")
    print(f"Split: {args.split}")
    print(f"Total: {total}")
    if total:
        print(f"Resolved: {resolved} ({resolved/total*100:.1f}%)")
    else:
        print("Resolved: 0")
    print(f"Failed: {failed}")
    print(f"Errors: {errors}")
    print(f"Timeouts: {timeouts}")
    print("=" * 60)

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
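A quick way to exercise the new script without a running exo node is to call its helpers directly; the sketch below only previews prompts. It assumes it is run from the repository root with `bench/` added to the import path and the `datasets` dependency installed; the two-instance limit is arbitrary.

```python
# Minimal preview of the evaluation inputs (no exo server required).
import sys

sys.path.insert(0, "bench")  # assumption: invoked from the repo root

from exo_eval import build_agent_prompt, load_swe_bench

instances = load_swe_bench(split="lite", limit=2)  # downloads SWE-bench Lite
for inst in instances:
    print(inst["instance_id"], inst["repo"])
    print(build_agent_prompt(inst)[:200], "...")
```

The full run (`python bench/exo_eval.py --model <model-id>`) additionally needs exo serving that model on the configured host and port.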
@@ -23,6 +23,9 @@ dependencies = [
     "tiktoken>=0.12.0", # required for kimi k2 tokenizer
     "hypercorn>=0.18.0",
     "openai-harmony>=0.0.8",
+    "openhands-sdk>=0.1.0", # for exo-eval SWE-bench evaluation
+    "openhands-tools>=0.1.0", # tools for openhands agents
+    "datasets>=3.0.0", # for loading SWE-bench from HuggingFace
 ]

 [project.scripts]
@@ -1,6 +1,6 @@
 import time
 from collections.abc import AsyncGenerator
-from typing import cast
+from typing import Any, cast

 import anyio
 from anyio import create_task_group
@@ -72,7 +72,13 @@ def chunk_to_response(
         choices=[
             StreamingChoiceResponse(
                 index=0,
-                delta=ChatCompletionMessage(role="assistant", content=chunk.text),
+                delta=ChatCompletionMessage(
+                    role="assistant",
+                    content=chunk.text if chunk.text else None,
+                    tool_calls=[tc.model_dump() for tc in chunk.tool_calls]
+                    if chunk.tool_calls
+                    else None,
+                ),
                 finish_reason=chunk.finish_reason,
             )
         ],
@@ -424,6 +430,7 @@ class API:
         text_parts: list[str] = []
         model: str | None = None
         finish_reason: FinishReason | None = None
+        all_tool_calls: list[dict[str, Any]] = []

         async for chunk in self._chat_chunk_stream(command_id):
             if model is None:
@@ -431,6 +438,20 @@ class API:

             text_parts.append(chunk.text)

+            # Collect tool calls
+            if chunk.tool_calls:
+                for tc in chunk.tool_calls:
+                    all_tool_calls.append(
+                        {
+                            "id": tc.id,
+                            "type": tc.type,
+                            "function": {
+                                "name": tc.function.name,
+                                "arguments": tc.function.arguments,
+                            },
+                        }
+                    )
+
             if chunk.finish_reason is not None:
                 finish_reason = chunk.finish_reason
@@ -446,7 +467,8 @@ class API:
                 index=0,
                 message=ChatCompletionMessage(
                     role="assistant",
-                    content=combined_text,
+                    content=combined_text if combined_text else None,
+                    tool_calls=all_tool_calls if all_tool_calls else None,
                 ),
                 finish_reason=finish_reason,
             )
@@ -459,6 +481,7 @@ class API:
         text_parts: list[str] = []
         model: str | None = None
         finish_reason: FinishReason | None = None
+        all_tool_calls: list[dict[str, Any]] = []

         stats: GenerationStats | None = None
@@ -469,6 +492,20 @@ class API:
             text_parts.append(chunk.text)
             stats = chunk.stats or stats

+            # Collect tool calls
+            if chunk.tool_calls:
+                for tc in chunk.tool_calls:
+                    all_tool_calls.append(
+                        {
+                            "id": tc.id,
+                            "type": tc.type,
+                            "function": {
+                                "name": tc.function.name,
+                                "arguments": tc.function.arguments,
+                            },
+                        }
+                    )
+
             if chunk.finish_reason is not None:
                 finish_reason = chunk.finish_reason
@@ -483,7 +520,9 @@ class API:
             ChatCompletionChoice(
                 index=0,
                 message=ChatCompletionMessage(
-                    role="assistant", content=combined_text
+                    role="assistant",
+                    content=combined_text if combined_text else None,
+                    tool_calls=all_tool_calls if all_tool_calls else None,
                 ),
                 finish_reason=finish_reason,
             )
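With these API changes, an assistant turn that consists only of a tool invocation comes back with `content` set to null and the collected calls under `tool_calls`. Roughly, the assembled message takes the following shape (a hand-written illustration; the id and function values are made up):

```python
# Illustrative OpenAI-style message dict, not captured from a real response.
message = {
    "role": "assistant",
    "content": None,  # no plain text when the model only emitted a tool call
    "tool_calls": [
        {
            "id": "call_0123456789abcdef01234567",
            "type": "function",
            "function": {
                "name": "terminal",
                "arguments": "{\"command\": \"ls\"}",
            },
        }
    ],
}
```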
@@ -1,4 +1,7 @@
 from enum import Enum
+from typing import Literal

+from pydantic import BaseModel
+
 from exo.shared.types.api import GenerationStats
 from exo.utils.pydantic_ext import TaggedModel
@@ -12,6 +15,17 @@ class ChunkType(str, Enum):
     Image = "Image"


+class ToolCallFunction(BaseModel, frozen=True):
+    name: str
+    arguments: str
+
+
+class ToolCall(BaseModel, frozen=True):
+    id: str
+    type: Literal["function"] = "function"
+    function: ToolCallFunction
+
+
 class BaseChunk(TaggedModel):
     idx: int
     model: ModelId
@@ -22,6 +36,7 @@ class TokenChunk(BaseChunk):
     token_id: int
     finish_reason: FinishReason | None = None
     stats: GenerationStats | None = None
+    tool_calls: list[ToolCall] | None = None


 class ImageChunk(BaseChunk):
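`ToolCallFunction` and `ToolCall` are small frozen pydantic models, and `model_dump()` on a `ToolCall` produces the same nested dict the API layer forwards. A minimal sketch, assuming an exo checkout is importable and using made-up values:

```python
from exo.shared.types.chunks import ToolCall, ToolCallFunction

tc = ToolCall(
    id="call_example",  # the runner generates ids of the form call_<24 hex chars>
    function=ToolCallFunction(name="list_files", arguments='{"path": "."}'),
)
print(tc.model_dump())
# {'id': 'call_example', 'type': 'function',
#  'function': {'name': 'list_files', 'arguments': '{"path": "."}'}}
```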
@@ -1,4 +1,5 @@
 from exo.shared.types.api import FinishReason, GenerationStats
+from exo.shared.types.chunks import ToolCall
 from exo.utils.pydantic_ext import TaggedModel
@@ -16,6 +17,7 @@ class GenerationResponse(BaseRunnerResponse):
     # logprobs: list[float] | None = None # too big. we can change to be top-k
     finish_reason: FinishReason | None = None
     stats: GenerationStats | None = None
+    tool_calls: list[ToolCall] | None = None


 class FinishedResponse(BaseRunnerResponse):
@@ -1,9 +1,13 @@
+import json
 import time
 from collections.abc import Generator
 from functools import cache
+from typing import Any
+from uuid import uuid4

 import mlx.core as mx
 from mlx_lm.models.gpt_oss import Model as GptOssModel
+from mlx_lm.tokenizer_utils import TokenizerWrapper
 from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]
     HarmonyEncodingName,
     Role,
@@ -12,7 +16,7 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
 )

 from exo.shared.types.api import ChatCompletionMessageText
-from exo.shared.types.chunks import TokenChunk
+from exo.shared.types.chunks import TokenChunk, ToolCall, ToolCallFunction
 from exo.shared.types.events import (
     ChunkGenerated,
     Event,
@@ -172,7 +176,10 @@ def main(
         if isinstance(model, GptOssModel):
             mlx_generator = parse_gpt_oss(mlx_generator)

-        # TODO: Add tool call parser here
+        # Parse tool calls to place them in the tool calls section
+        mlx_generator = parse_tool_calls(
+            mlx_generator, tokenizer, task_params.tools
+        )

         for response in mlx_generator:
             match response:
@@ -188,6 +195,7 @@ def main(
                             token_id=response.token,
                             finish_reason=response.finish_reason,
                             stats=response.stats,
+                            tool_calls=response.tool_calls,
                         ),
                     )
                 )
@@ -261,6 +269,98 @@ def parse_gpt_oss(
             break


+def _generate_tool_call_id() -> str:
+    return f"call_{uuid4().hex[:24]}"
+
+
+def _parse_tool_call_content(
+    content: str,
+    tokenizer: TokenizerWrapper,
+    tools: list[dict[str, Any]] | None,
+) -> ToolCall | None:
+    content = content.strip()
+    if not content:
+        return None
+
+    tool_parser: Any = getattr(tokenizer, "tool_parser", None)
+    if tool_parser is None:
+        logger.warning("No tool_parser available for tokenizer")
+        return None
+
+    try:
+        parsed: dict[str, Any] = tool_parser(content, tools)  # pyright: ignore[reportAny]
+        if parsed and "name" in parsed:
+            arguments: Any = parsed.get("arguments", {})  # pyright: ignore[reportAny]
+            arguments_str: str = (
+                json.dumps(arguments)
+                if not isinstance(arguments, str)
+                else arguments
+            )
+            return ToolCall(
+                id=_generate_tool_call_id(),
+                type="function",
+                function=ToolCallFunction(
+                    name=str(parsed["name"]),  # pyright: ignore[reportAny]
+                    arguments=arguments_str,
+                ),
+            )
+    except Exception as e:
+        logger.warning(f"tool_parser failed: {e}")
+
+    return None
+
+
+def parse_tool_calls(
+    responses: Generator[GenerationResponse],
+    tokenizer: TokenizerWrapper,
+    tools: list[dict[str, Any]] | None,
+) -> Generator[GenerationResponse]:
+    has_tool_calling = getattr(tokenizer, "has_tool_calling", False)
+    if not has_tool_calling or tools is None:
+        yield from responses
+        return
+
+    tool_call_start: str | None = getattr(tokenizer, "tool_call_start", None)
+    tool_call_end: str | None = getattr(tokenizer, "tool_call_end", None)
+
+    if tool_call_start is None or tool_call_end is None:
+        yield from responses
+        return
+
+    in_tool_call = False
+    tool_call_buffer: list[str] = []
+    pending_tool_calls: list[ToolCall] = []
+
+    for response in responses:
+        if response.text == tool_call_start:
+            in_tool_call = True
+            tool_call_buffer = []
+            continue
+
+        if response.text == tool_call_end:
+            in_tool_call = False
+            parsed = _parse_tool_call_content(
+                "".join(tool_call_buffer), tokenizer, tools
+            )
+            if parsed is not None:
+                pending_tool_calls.append(parsed)
+            continue
+
+        if in_tool_call:
+            tool_call_buffer.append(response.text)
+            continue
+
+        if response.finish_reason is None or not pending_tool_calls:
+            yield response
+        else:
+            yield response.model_copy(
+                update={
+                    "finish_reason": "tool_calls",
+                    "tool_calls": pending_tool_calls if pending_tool_calls else None,
+                }
+            )


 EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
 EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
 EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
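`parse_tool_calls` buffers everything the model emits between the tokenizer's tool-call sentinels, parses the buffered text once the closing sentinel arrives, and attaches the parsed calls only to the final chunk (the one carrying a `finish_reason`), rewriting that reason to `tool_calls`. Below is a stripped-down, self-contained sketch of the same buffering pattern, using stand-in types and hard-coded `<tool_call>` sentinels rather than exo's real types and tokenizer attributes (illustrative only, Python 3.10+):

```python
from collections.abc import Iterator
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class FakeChunk:  # stand-in for GenerationResponse
    text: str
    finish_reason: str | None = None
    tool_calls: list[dict] | None = None


def rewrite_tool_calls(
    chunks: Iterator[FakeChunk],
    start: str = "<tool_call>",
    end: str = "</tool_call>",
) -> Iterator[FakeChunk]:
    buffer: list[str] = []
    in_call = False
    pending: list[dict] = []
    for chunk in chunks:
        if chunk.text == start:  # opening sentinel: start buffering
            in_call, buffer = True, []
            continue
        if chunk.text == end:  # closing sentinel: "parse" the buffered text
            in_call = False
            pending.append({"raw": "".join(buffer)})  # exo runs the tokenizer's tool_parser here
            continue
        if in_call:  # swallow tokens that belong to the call
            buffer.append(chunk.text)
            continue
        if chunk.finish_reason is None or not pending:
            yield chunk
        else:  # final chunk: attach calls and rewrite the finish reason
            yield replace(chunk, finish_reason="tool_calls", tool_calls=pending)


stream = [
    FakeChunk("Sure."),
    FakeChunk("<tool_call>"),
    FakeChunk('{"name": "ls", "arguments": {}}'),
    FakeChunk("</tool_call>"),
    FakeChunk("", finish_reason="stop"),
]
for out in rewrite_tool_calls(iter(stream)):
    print(out)
```

As in the real implementation, ordinary text chunks pass through untouched, so streaming clients keep receiving tokens while the tool call is assembled and only the terminal chunk carries the parsed calls.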