Compare commits

2 commits

826da9512d  Ryuichi Leo Takashige  2026-01-16 12:55:43 +00:00
    Add exo eval

4d9114b9b5  Ryuichi Leo Takashige  2026-01-16 11:36:08 +00:00
    Add ChatCompletion tool calling support
    https://platform.openai.com/docs/api-reference/chat/get

7 changed files with 6058 additions and 751 deletions
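As an illustration of what the second commit enables (not part of the diff itself): a minimal client call against exo's OpenAI-compatible endpoint with a tool attached. It assumes the default port 52415 used elsewhere in this changeset, the openai Python package, and a placeholder model ID and tool schema.

# Hedged sketch: exercising exo's /v1/chat/completions with a tool definition.
# The model ID and the get_weather tool are placeholders, not part of this diff.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:52415/v1", api_key="not-needed")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

resp = client.chat.completions.create(
    model="some-exo-model",  # placeholder; use an ID reported by exo
    messages=[{"role": "user", "content": "What's the weather in Tokyo?"}],
    tools=tools,
)

message = resp.choices[0].message
if message.tool_calls:
    for tc in message.tool_calls:
        print(tc.function.name, tc.function.arguments)
else:
    print(message.content)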

bench/exo_eval.py  (new file, +378)
@@ -0,0 +1,378 @@
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false, reportMissingTypeStubs=false
"""
exo-eval: Run SWE-bench evaluation against exo using OpenHands SDK (local, no Docker).
"""
from __future__ import annotations
import argparse
import json
import os
import subprocess
import tempfile
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
from datasets import load_dataset
from loguru import logger
from openhands.sdk import LLM, Agent, Conversation, Tool
from openhands.tools.file_editor import FileEditorTool
from openhands.tools.terminal import TerminalTool


class EvalStatus(str, Enum):
    Resolved = "Resolved"
    Failed = "Failed"
    Error = "Error"
    Timeout = "Timeout"


@dataclass
class EvalResult:
    instance_id: str
    repo: str
    status: EvalStatus
    elapsed_seconds: float
    tests_passed: list[str]
    tests_failed: list[str]
    error_message: str | None = None


def load_swe_bench(
    split: str = "lite",
    limit: int | None = None,
    instance_ids: list[str] | None = None,
) -> list[dict[str, Any]]:
    """Load SWE-bench dataset from HuggingFace."""
    # SWE-bench Lite is a curated 300-instance subset
    dataset_name = (
        "princeton-nlp/SWE-bench_Lite" if split == "lite" else "princeton-nlp/SWE-bench"
    )
    actual_split = "test" if split == "lite" else split
    ds = load_dataset(dataset_name, split=actual_split)
    instances = [dict(row) for row in ds]
    if instance_ids:
        instances = [i for i in instances if i["instance_id"] in instance_ids]
    if limit:
        instances = instances[:limit]
    return instances


def clone_repo_at_commit(repo: str, commit: str, dest: Path) -> None:
    """Clone a repo at a specific commit."""
    repo_url = f"https://github.com/{repo}.git"
    subprocess.run(
        ["git", "clone", "--depth", "1", repo_url, str(dest)],
        check=True,
        capture_output=True,
    )
    subprocess.run(
        ["git", "fetch", "--depth", "1", "origin", commit],
        cwd=dest,
        check=True,
        capture_output=True,
    )
    subprocess.run(
        ["git", "checkout", commit],
        cwd=dest,
        check=True,
        capture_output=True,
    )


def build_agent_prompt(instance: dict[str, Any]) -> str:
    """Build the prompt for the agent."""
    return f"""You are a software engineer fixing a bug in the {instance['repo']} repository.

## Problem Statement

{instance['problem_statement']}

## Instructions

1. Explore the codebase to understand the issue
2. Identify the files that need to be modified
3. Make the necessary changes to fix the issue
4. The fix should be minimal and targeted

You have access to:
- terminal: Run shell commands (git, grep, python, etc.)
- file_editor: View and edit files

Start by exploring the repository structure to understand where the relevant code is.
"""


def parse_fail_to_pass(fail_to_pass_str: str) -> list[str]:
    """Parse the FAIL_TO_PASS field into a list of test names."""
    try:
        return json.loads(fail_to_pass_str)
    except json.JSONDecodeError:
        return [t.strip() for t in fail_to_pass_str.split(",") if t.strip()]


def run_tests(workspace: Path, tests: list[str]) -> tuple[list[str], list[str]]:
    """Run tests and return (passed, failed) lists."""
    passed = []
    failed = []
    for test in tests:
        try:
            result = subprocess.run(
                ["python", "-m", "pytest", "-xvs", test],
                cwd=workspace,
                capture_output=True,
                timeout=300,
            )
            if result.returncode == 0:
                passed.append(test)
            else:
                failed.append(test)
        except subprocess.TimeoutExpired:
            failed.append(test)
    return passed, failed


def run_single_eval(
    instance: dict[str, Any],
    host: str,
    port: int,
    model: str,
    max_turns: int = 30,
    timeout: float = 600.0,
) -> EvalResult:
    """Evaluate a single SWE-bench instance."""
    instance_id = instance["instance_id"]
    repo = instance["repo"]
    base_commit = instance["base_commit"]
    fail_to_pass = parse_fail_to_pass(instance["FAIL_TO_PASS"])

    start_time = time.perf_counter()
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            workspace = Path(tmpdir) / "repo"

            # Clone repo at base commit
            logger.info(f"Cloning {repo} at {base_commit[:8]}...")
            clone_repo_at_commit(repo, base_commit, workspace)

            # Setup OpenHands agent
            llm = LLM(
                model=f"openai/{model}",
                base_url=f"http://{host}:{port}/v1",
                api_key="not-needed",
            )
            agent = Agent(
                llm=llm,
                tools=[
                    Tool(name=TerminalTool.name),
                    Tool(name=FileEditorTool.name),
                ],
            )

            # Run agent
            conversation = Conversation(
                agent=agent,
                workspace=str(workspace),
            )
            logger.info(f"Running agent on {instance_id}...")
            conversation.send_message(build_agent_prompt(instance))

            for _turn in range(max_turns):
                if time.perf_counter() - start_time > timeout:
                    return EvalResult(
                        instance_id=instance_id,
                        repo=repo,
                        status=EvalStatus.Timeout,
                        elapsed_seconds=time.perf_counter() - start_time,
                        tests_passed=[],
                        tests_failed=fail_to_pass,
                    )
                result = conversation.run(max_turns=1)
                if result.done:
                    break

            # Run tests to verify
            logger.info(f"Running tests for {instance_id}...")
            passed, failed = run_tests(workspace, fail_to_pass)

            elapsed = time.perf_counter() - start_time
            status = EvalStatus.Resolved if not failed else EvalStatus.Failed
            return EvalResult(
                instance_id=instance_id,
                repo=repo,
                status=status,
                elapsed_seconds=elapsed,
                tests_passed=passed,
                tests_failed=failed,
            )
    except Exception as e:
        return EvalResult(
            instance_id=instance_id,
            repo=repo,
            status=EvalStatus.Error,
            elapsed_seconds=time.perf_counter() - start_time,
            tests_passed=[],
            tests_failed=[],
            error_message=str(e),
        )


def verify_exo_running(host: str, port: int, model: str) -> str:
    """Verify exo is running and return full model ID."""
    import http.client

    conn = http.client.HTTPConnection(host, port, timeout=10)
    conn.request("GET", "/models")
    resp = conn.getresponse()
    if resp.status != 200:
        raise RuntimeError(f"exo not responding at {host}:{port}")
    data = json.loads(resp.read())
    for m in data.get("data", []):
        if m.get("id") == model or m.get("hugging_face_id") == model:
            return m.get("hugging_face_id") or m.get("id")
    raise ValueError(f"Model '{model}' not found in exo")


def main() -> int:
    ap = argparse.ArgumentParser(
        prog="exo-eval",
        description="Run SWE-bench evaluation against exo (local, no Docker).",
    )
    ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
    ap.add_argument(
        "--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
    )
    ap.add_argument("--model", required=True, help="exo model ID")
    ap.add_argument(
        "--split", default="lite", choices=["lite", "dev", "test", "train"]
    )
    ap.add_argument("--limit", type=int, default=10, help="Max instances")
    ap.add_argument("--instance-ids", nargs="+", help="Specific instance IDs")
    ap.add_argument("--max-turns", type=int, default=30)
    ap.add_argument("--timeout", type=float, default=600.0)
    ap.add_argument("--json-out", default="bench/eval_results.json")
    ap.add_argument("-v", "--verbose", action="store_true")
    ap.add_argument("--dry-run", action="store_true")
    args = ap.parse_args()

    # Load dataset first (doesn't require exo to be running)
    logger.info(f"Loading SWE-bench {args.split} dataset...")
    instances = load_swe_bench(
        split=args.split,
        limit=args.limit,
        instance_ids=args.instance_ids,
    )
    logger.info(f"Loaded {len(instances)} instances")

    if args.dry_run:
        print(f"\nSWE-bench {args.split} instances ({len(instances)}):")
        for inst in instances:
            print(f" {inst['instance_id']} ({inst['repo']})")
        return 0

    # Verify exo is running
    model_id = verify_exo_running(args.host, args.port, args.model)
    logger.info(f"Using model: {model_id}")

    # Run evaluation
    results: list[EvalResult] = []
    for i, instance in enumerate(instances):
        logger.info(f"[{i+1}/{len(instances)}] {instance['instance_id']}")
        result = run_single_eval(
            instance=instance,
            host=args.host,
            port=args.port,
            model=model_id,
            max_turns=args.max_turns,
            timeout=args.timeout,
        )
        results.append(result)
        logger.info(f" Status: {result.status.value}")
        if result.tests_passed:
            logger.info(f" Passed: {len(result.tests_passed)} tests")
        if result.tests_failed:
            logger.info(f" Failed: {len(result.tests_failed)} tests")
        if result.error_message:
            logger.error(f" Error: {result.error_message}")

    # Compute summary
    total = len(results)
    resolved = sum(1 for r in results if r.status == EvalStatus.Resolved)
    failed = sum(1 for r in results if r.status == EvalStatus.Failed)
    errors = sum(1 for r in results if r.status == EvalStatus.Error)
    timeouts = sum(1 for r in results if r.status == EvalStatus.Timeout)

    summary = {
        "model": model_id,
        "split": args.split,
        "total": total,
        "resolved": resolved,
        "resolved_rate": resolved / total if total else 0,
        "failed": failed,
        "errors": errors,
        "timeouts": timeouts,
    }
    output = {
        "summary": summary,
        "results": [
            {
                "instance_id": r.instance_id,
                "repo": r.repo,
                "status": r.status.value,
                "elapsed_seconds": r.elapsed_seconds,
                "tests_passed": r.tests_passed,
                "tests_failed": r.tests_failed,
                "error_message": r.error_message,
            }
            for r in results
        ],
    }
    Path(args.json_out).write_text(json.dumps(output, indent=2))
    logger.info(f"Results written to {args.json_out}")

    # Print summary
    print("\n" + "=" * 60)
    print("SWE-bench Evaluation Results")
    print("=" * 60)
    print(f"Model: {model_id}")
    print(f"Split: {args.split}")
    print(f"Total: {total}")
    if total:
        print(f"Resolved: {resolved} ({resolved/total*100:.1f}%)")
    else:
        print("Resolved: 0")
    print(f"Failed: {failed}")
    print(f"Errors: {errors}")
    print(f"Timeouts: {timeouts}")
    print("=" * 60)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
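
A small companion sketch (not part of the diff): re-reading the JSON report that main() writes, using the path and key names defined in the script above. It assumes a run has already produced bench/eval_results.json.

# Minimal sketch: re-read the report written by exo_eval.py's main().
# Assumes bench/eval_results.json already exists from a previous run.
import json
from pathlib import Path

report = json.loads(Path("bench/eval_results.json").read_text())
summary = report["summary"]
print(f"{summary['model']} on SWE-bench {summary['split']}: "
      f"{summary['resolved']}/{summary['total']} resolved "
      f"({summary['resolved_rate'] * 100:.1f}%)")

# List the instances that did not resolve, per the results array.
for r in report["results"]:
    if r["status"] != "Resolved":
        print(f"  {r['instance_id']}: {r['status']}")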

pyproject.toml

@@ -23,6 +23,9 @@ dependencies = [
     "tiktoken>=0.12.0", # required for kimi k2 tokenizer
     "hypercorn>=0.18.0",
     "openai-harmony>=0.0.8",
+    "openhands-sdk>=0.1.0", # for exo-eval SWE-bench evaluation
+    "openhands-tools>=0.1.0", # tools for openhands agents
+    "datasets>=3.0.0", # for loading SWE-bench from HuggingFace
 ]

 [project.scripts]


@@ -1,6 +1,6 @@
import time
from collections.abc import AsyncGenerator
from typing import cast
from typing import Any, cast
import anyio
from anyio import create_task_group
@@ -72,7 +72,13 @@ def chunk_to_response(
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
delta=ChatCompletionMessage(
role="assistant",
content=chunk.text if chunk.text else None,
tool_calls=[tc.model_dump() for tc in chunk.tool_calls]
if chunk.tool_calls
else None,
),
finish_reason=chunk.finish_reason,
)
],
@@ -424,6 +430,7 @@ class API:
         text_parts: list[str] = []
         model: str | None = None
         finish_reason: FinishReason | None = None
+        all_tool_calls: list[dict[str, Any]] = []

         async for chunk in self._chat_chunk_stream(command_id):
             if model is None:
@@ -431,6 +438,20 @@ class API:
             text_parts.append(chunk.text)

+            # Collect tool calls
+            if chunk.tool_calls:
+                for tc in chunk.tool_calls:
+                    all_tool_calls.append(
+                        {
+                            "id": tc.id,
+                            "type": tc.type,
+                            "function": {
+                                "name": tc.function.name,
+                                "arguments": tc.function.arguments,
+                            },
+                        }
+                    )
+
             if chunk.finish_reason is not None:
                 finish_reason = chunk.finish_reason
@@ -446,7 +467,8 @@ class API:
                 index=0,
                 message=ChatCompletionMessage(
                     role="assistant",
-                    content=combined_text,
+                    content=combined_text if combined_text else None,
+                    tool_calls=all_tool_calls if all_tool_calls else None,
                 ),
                 finish_reason=finish_reason,
             )
@@ -459,6 +481,7 @@ class API:
         text_parts: list[str] = []
         model: str | None = None
         finish_reason: FinishReason | None = None
+        all_tool_calls: list[dict[str, Any]] = []
         stats: GenerationStats | None = None
@@ -469,6 +492,20 @@ class API:
             text_parts.append(chunk.text)
             stats = chunk.stats or stats

+            # Collect tool calls
+            if chunk.tool_calls:
+                for tc in chunk.tool_calls:
+                    all_tool_calls.append(
+                        {
+                            "id": tc.id,
+                            "type": tc.type,
+                            "function": {
+                                "name": tc.function.name,
+                                "arguments": tc.function.arguments,
+                            },
+                        }
+                    )
+
             if chunk.finish_reason is not None:
                 finish_reason = chunk.finish_reason
@@ -483,7 +520,9 @@ class API:
             ChatCompletionChoice(
                 index=0,
                 message=ChatCompletionMessage(
-                    role="assistant", content=combined_text
+                    role="assistant",
+                    content=combined_text if combined_text else None,
+                    tool_calls=all_tool_calls if all_tool_calls else None,
                 ),
                 finish_reason=finish_reason,
             )
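
For reference (not part of the diff): when the model calls a tool, the non-streaming branch above now yields an assistant message shaped roughly like the dict below. The ID, tool name, and arguments are invented; the field layout mirrors the dicts built in the collection loop.

# Illustrative shape only — the ID, name, and arguments are made up.
example_message = {
    "role": "assistant",
    "content": None,  # combined_text was empty, so content is sent as None
    "tool_calls": [
        {
            "id": "call_0123456789abcdef01234567",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": '{"city": "Tokyo"}',  # arguments are a JSON string
            },
        }
    ],
}
# The choice's finish_reason in this case is expected to be "tool_calls",
# set by parse_tool_calls in the MLX runner further down this diff.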


@@ -1,4 +1,7 @@
 from enum import Enum
+from typing import Literal
+
+from pydantic import BaseModel

 from exo.shared.types.api import GenerationStats
 from exo.utils.pydantic_ext import TaggedModel
@@ -12,6 +15,17 @@ class ChunkType(str, Enum):
     Image = "Image"


+class ToolCallFunction(BaseModel, frozen=True):
+    name: str
+    arguments: str
+
+
+class ToolCall(BaseModel, frozen=True):
+    id: str
+    type: Literal["function"] = "function"
+    function: ToolCallFunction
+
+
 class BaseChunk(TaggedModel):
     idx: int
     model: ModelId
@@ -22,6 +36,7 @@ class TokenChunk(BaseChunk):
     token_id: int
     finish_reason: FinishReason | None = None
     stats: GenerationStats | None = None
+    tool_calls: list[ToolCall] | None = None


 class ImageChunk(BaseChunk):
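
A quick sketch of the new models in isolation (not part of the diff). The import path matches the one used in the runner changes below, so running it requires the exo package from this branch; the ID and arguments are placeholders.

# Standalone check of the new pydantic models; values are placeholders.
from exo.shared.types.chunks import ToolCall, ToolCallFunction

tc = ToolCall(
    id="call_0123456789abcdef01234567",  # placeholder in the call_<24 hex> format used by the runner
    function=ToolCallFunction(name="get_weather", arguments='{"city": "Tokyo"}'),
)

# model_dump() is what the API layer calls when building streaming deltas.
print(tc.model_dump())
# -> {'id': 'call_0123456789abcdef01234567', 'type': 'function',
#     'function': {'name': 'get_weather', 'arguments': '{"city": "Tokyo"}'}}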


@@ -1,4 +1,5 @@
 from exo.shared.types.api import FinishReason, GenerationStats
+from exo.shared.types.chunks import ToolCall
 from exo.utils.pydantic_ext import TaggedModel
@@ -16,6 +17,7 @@ class GenerationResponse(BaseRunnerResponse):
     # logprobs: list[float] | None = None # too big. we can change to be top-k
     finish_reason: FinishReason | None = None
     stats: GenerationStats | None = None
+    tool_calls: list[ToolCall] | None = None


 class FinishedResponse(BaseRunnerResponse):


@@ -1,9 +1,13 @@
+import json
 import time
 from collections.abc import Generator
 from functools import cache
+from typing import Any
+from uuid import uuid4
+
 import mlx.core as mx
 from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.tokenizer_utils import TokenizerWrapper
 from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]
     HarmonyEncodingName,
     Role,
@@ -12,7 +16,7 @@ from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]
 )

 from exo.shared.types.api import ChatCompletionMessageText
-from exo.shared.types.chunks import TokenChunk
+from exo.shared.types.chunks import TokenChunk, ToolCall, ToolCallFunction
 from exo.shared.types.events import (
     ChunkGenerated,
     Event,
@@ -172,7 +176,10 @@
     if isinstance(model, GptOssModel):
         mlx_generator = parse_gpt_oss(mlx_generator)

-    # TODO: Add tool call parser here
+    # Parse tool calls to place them in the tool calls section
+    mlx_generator = parse_tool_calls(
+        mlx_generator, tokenizer, task_params.tools
+    )

     for response in mlx_generator:
         match response:
@@ -188,6 +195,7 @@
                         token_id=response.token,
                         finish_reason=response.finish_reason,
                         stats=response.stats,
+                        tool_calls=response.tool_calls,
                     ),
                 )
             )
@@ -261,6 +269,98 @@
                 break


+def _generate_tool_call_id() -> str:
+    return f"call_{uuid4().hex[:24]}"
+
+
+def _parse_tool_call_content(
+    content: str,
+    tokenizer: TokenizerWrapper,
+    tools: list[dict[str, Any]] | None,
+) -> ToolCall | None:
+    content = content.strip()
+    if not content:
+        return None
+
+    tool_parser: Any = getattr(tokenizer, "tool_parser", None)
+    if tool_parser is None:
+        logger.warning("No tool_parser available for tokenizer")
+        return None
+
+    try:
+        parsed: dict[str, Any] = tool_parser(content, tools)  # pyright: ignore[reportAny]
+        if parsed and "name" in parsed:
+            arguments: Any = parsed.get("arguments", {})  # pyright: ignore[reportAny]
+            arguments_str: str = (
+                json.dumps(arguments)
+                if not isinstance(arguments, str)
+                else arguments
+            )
+            return ToolCall(
+                id=_generate_tool_call_id(),
+                type="function",
+                function=ToolCallFunction(
+                    name=str(parsed["name"]),  # pyright: ignore[reportAny]
+                    arguments=arguments_str,
+                ),
+            )
+    except Exception as e:
+        logger.warning(f"tool_parser failed: {e}")
+
+    return None
+
+
+def parse_tool_calls(
+    responses: Generator[GenerationResponse],
+    tokenizer: TokenizerWrapper,
+    tools: list[dict[str, Any]] | None,
+) -> Generator[GenerationResponse]:
+    has_tool_calling = getattr(tokenizer, "has_tool_calling", False)
+    if not has_tool_calling or tools is None:
+        yield from responses
+        return
+
+    tool_call_start: str | None = getattr(tokenizer, "tool_call_start", None)
+    tool_call_end: str | None = getattr(tokenizer, "tool_call_end", None)
+    if tool_call_start is None or tool_call_end is None:
+        yield from responses
+        return
+
+    in_tool_call = False
+    tool_call_buffer: list[str] = []
+    pending_tool_calls: list[ToolCall] = []
+
+    for response in responses:
+        if response.text == tool_call_start:
+            in_tool_call = True
+            tool_call_buffer = []
+            continue
+
+        if response.text == tool_call_end:
+            in_tool_call = False
+            parsed = _parse_tool_call_content(
+                "".join(tool_call_buffer), tokenizer, tools
+            )
+            if parsed is not None:
+                pending_tool_calls.append(parsed)
+            continue
+
+        if in_tool_call:
+            tool_call_buffer.append(response.text)
+            continue
+
+        if response.finish_reason is None or not pending_tool_calls:
+            yield response
+        else:
+            yield response.model_copy(
+                update={
+                    "finish_reason": "tool_calls",
+                    "tool_calls": pending_tool_calls if pending_tool_calls else None,
+                }
+            )
+
+
 EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
 EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
 EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
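
The parse_tool_calls generator above is a three-state scan: pass tokens through, buffer everything between the tokenizer's tool-call start/end markers, then attach the parsed calls (and a "tool_calls" finish reason) to the final response. The toy below is not exo code; it replays that control flow with a stand-in response class and fake markers.

# Toy replay of the parse_tool_calls control flow; everything here is a stand-in.
from dataclasses import dataclass, replace

TOOL_START, TOOL_END = "<tool_call>", "</tool_call>"  # fake markers


@dataclass
class FakeResponse:
    text: str
    finish_reason: str | None = None
    tool_calls: list[str] | None = None


def scan(stream):
    in_call, buffer, pending = False, [], []
    for r in stream:
        if r.text == TOOL_START:      # enter buffering state
            in_call, buffer = True, []
            continue
        if r.text == TOOL_END:        # leave buffering state, "parse" the buffer
            in_call = False
            pending.append("".join(buffer))  # the real code runs tool_parser here
            continue
        if in_call:                   # accumulate tool-call text, suppress it
            buffer.append(r.text)
            continue
        if r.finish_reason is None or not pending:
            yield r
        else:                         # attach calls to the final response
            yield replace(r, finish_reason="tool_calls", tool_calls=pending)


stream = [
    FakeResponse("Sure, "),
    FakeResponse(TOOL_START),
    FakeResponse('{"name": "get_weather"'),
    FakeResponse(', "arguments": {"city": "Tokyo"}}'),
    FakeResponse(TOOL_END),
    FakeResponse("", finish_reason="stop"),
]
for out in scan(stream):
    print(out)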

uv.lock  (generated, 6260 changed lines)

File diff suppressed because it is too large.