Mirror of https://github.com/exo-explore/exo.git, synced 2026-01-21 12:30:22 -05:00

Compare commits: v1.0.63...leo/add-to (2 commits)
Commits: 826da9512d, 4d9114b9b5
.github/workflows/build-app.yml (vendored): 106 changed lines
@@ -1,16 +1,5 @@
 name: Build EXO macOS DMG
-
-# Release workflow:
-# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
-# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
-# 3. This workflow builds, signs, and notarizes the DMG
-# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
-# 5. DMG and appcast.xml are uploaded to S3
-# 6. The draft GitHub Release is published with the DMG attached
-#
-# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
-# If no draft exists, a release is auto-created with generated notes.

 on:
   workflow_dispatch:
   push:
@@ -22,10 +11,8 @@ on:
 jobs:
   build-macos-app:
     runs-on: "macos-26"
     permissions:
       contents: write
     env:
-      SPARKLE_VERSION: 2.9.0-beta.1
+      SPARKLE_VERSION: 2.8.1
       SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
      SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
      SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
@@ -100,52 +87,6 @@ jobs:
             exit 1
           fi

-      - name: Fetch and validate release notes
-        if: github.ref_type == 'tag'
-        env:
-          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: |
-          # Find draft release by name using gh release list (more reliable with default token)
-          echo "Looking for draft release named '$GITHUB_REF_NAME'..."
-          DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
-
-          if [[ -z "$DRAFT_EXISTS" ]]; then
-            if [[ "$IS_ALPHA" == "true" ]]; then
-              echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
-              echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
-              exit 0
-            fi
-            echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
-            echo "Please create a draft release with release notes before pushing the tag."
-            exit 1
-          fi
-
-          # Fetch full release details via API to get body and ID
-          echo "Found draft release, fetching details..."
-          RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
-
-          # Extract release notes
-          NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
-          if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
-            if [[ "$IS_ALPHA" == "true" ]]; then
-              echo "Draft release has no notes (optional for alphas)"
-              echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
-              exit 0
-            fi
-            echo "ERROR: Draft release exists but has no release notes"
-            echo "Please add release notes to the draft release before pushing the tag."
-            exit 1
-          fi
-
-          # Save release ID for later publishing
-          RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
-          echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
-          echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
-
-          echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
-          echo "$NOTES" > /tmp/release_notes.md
-          echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV

       # ============================================================
       # Install dependencies
       # ============================================================
@@ -363,28 +304,6 @@ jobs:
             $CHANNEL_FLAG \
             .

-      - name: Inject release notes into appcast
-        if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
-        env:
-          RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
-        run: |
-          # Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
-          export NOTES=$(cat "$RELEASE_NOTES_FILE")
-
-          # Insert description after the enclosure tag for this version
-          awk '
-            /<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
-              print
-              print "      <description sparkle:format=\"markdown\"><![CDATA["
-              print ENVIRON["NOTES"]
-              print "      ]]></description>"
-              next
-            }
-            { print }
-          ' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
-
-          echo "Injected markdown release notes for version $RELEASE_VERSION"

       # ============================================================
       # Upload artifacts
       # ============================================================
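The deleted awk step above is a small, self-contained transform: after the <enclosure> line that mentions the release version, it splices in a <description sparkle:format="markdown"> block. A rough Python equivalent, for readers who do not speak awk (the file path is taken from the workflow step; the version string and notes are illustrative):

from pathlib import Path

def inject_notes(appcast: str, version: str, notes: str) -> str:
    """Insert a markdown <description> right after each matching <enclosure> line."""
    out: list[str] = []
    for line in appcast.splitlines():
        out.append(line)
        # Same match the awk used: an enclosure tag whose line mentions the version
        if "<enclosure" in line and version in line:
            out.append('      <description sparkle:format="markdown"><![CDATA[')
            out.append(notes)
            out.append("      ]]></description>")
    return "\n".join(out) + "\n"

path = Path("output/appcast.xml")
path.write_text(inject_notes(path.read_text(), "1.2.3", "## Example notes"))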
@@ -417,26 +336,3 @@ jobs:
             aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
             aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
           fi

-      - name: Publish GitHub Release
-        if: github.ref_type == 'tag'
-        env:
-          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: |
-          DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
-
-          if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
-            # Update the draft release with the tag and upload DMG
-            gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
-              -f tag_name="$GITHUB_REF_NAME" \
-              -F draft=false
-            gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
-            echo "Published release $GITHUB_REF_NAME with DMG attached"
-          else
-            # Alpha without draft release - create one with auto-generated notes
-            gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
-              --title "$GITHUB_REF_NAME" \
-              --generate-notes \
-              --prerelease
-            echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
-          fi
@@ -585,7 +585,7 @@
    repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
    requirement = {
        kind = upToNextMajorVersion;
-       minimumVersion = 2.9.0-beta.1;
+       minimumVersion = 2.8.1;
    };
 };
 /* End XCRemoteSwiftPackageReference section */
@@ -6,8 +6,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/sparkle-project/Sparkle.git",
       "state" : {
-        "revision" : "e641adb41915a8409895e2e30666aa64e487b637",
-        "version" : "2.9.0-beta.1"
+        "revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
+        "version" : "2.8.1"
       }
     }
   ],
@@ -56,11 +56,6 @@ struct ContentView: View {
     }

     private var shouldShowLocalNetworkWarning: Bool {
-        // Show warning if local network is not working and EXO is running.
-        // The checker uses a longer timeout on first launch to allow time for
-        // the permission prompt, so this correctly handles both:
-        // 1. User denied permission on first launch
-        // 2. Permission broke after restart (macOS TCC bug)
         if case .notWorking = localNetworkChecker.status {
             return controller.status != .stopped
         }
@@ -5,8 +5,8 @@ import os.log
 /// Checks if the app's local network permission is actually functional.
 ///
 /// macOS local network permission can appear enabled in System Preferences but not
-/// actually work after a restart. This service uses NWConnection to mDNS multicast
-/// to verify actual connectivity.
+/// actually work after a restart. This service detects this by creating a UDP
+/// connection to the mDNS multicast address (224.0.0.251:5353).
 @MainActor
 final class LocalNetworkChecker: ObservableObject {
     enum Status: Equatable {
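For a feel of what the Swift checker does, here is a hedged Python sketch that pokes the same mDNS multicast endpoint. It is illustrative only: a UDP "connect" never handshakes, so success proves little, but the error-code checks later in this file suggest macOS surfaces a blocked permission as a socket error (EHOSTUNREACH/ECONNRESET) on send:

import socket

def probe_mdns(timeout: float = 3.0) -> bool:
    """Best-effort probe of the mDNS multicast address used by the checker."""
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.settimeout(timeout)
    try:
        s.connect(("224.0.0.251", 5353))  # same endpoint as the Swift code
        s.send(b"\x00")  # may raise EHOSTUNREACH/ECONNRESET when blocked
        return True
    except OSError:
        return False
    finally:
        s.close()

print("local network looks usable:", probe_mdns())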
@@ -35,43 +35,30 @@ final class LocalNetworkChecker: ObservableObject {
     }

     private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
-    private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"

     @Published private(set) var status: Status = .unknown
+    @Published private(set) var lastConnectionState: String = "none"

     private var connection: NWConnection?
     private var checkTask: Task<Void, Never>?

-    /// Whether we've completed at least one check (stored in UserDefaults)
-    private var hasCompletedInitialCheck: Bool {
-        get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
-        set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
-    }
-
     /// Checks if local network access is working.
     func check() {
         checkTask?.cancel()
         status = .checking
+        lastConnectionState = "connecting"

-        // Use longer timeout on first launch to allow time for permission prompt
-        let isFirstCheck = !hasCompletedInitialCheck
-        let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
-
         checkTask = Task { [weak self] in
             guard let self else { return }

-            Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
-            let result = await self.checkConnectivity(timeout: timeout)
+            let result = await self.performCheck()
             self.status = result
-            self.hasCompletedInitialCheck = true

             Self.logger.info("Local network check complete: \(result.displayText)")
         }
     }

-    /// Checks connectivity using NWConnection to mDNS multicast.
-    /// The connection attempt triggers the permission prompt if not yet shown.
-    private func checkConnectivity(timeout: UInt64) async -> Status {
+    private func performCheck() async -> Status {
+        Self.logger.info("Checking local network access via UDP multicast")

         connection?.cancel()
         connection = nil
@@ -97,7 +84,22 @@ final class LocalNetworkChecker: ObservableObject {
                 continuation.resume(returning: status)
             }

-            conn.stateUpdateHandler = { state in
+            conn.stateUpdateHandler = { [weak self] state in
+                let stateStr: String
+                switch state {
+                case .setup: stateStr = "setup"
+                case .preparing: stateStr = "preparing"
+                case .ready: stateStr = "ready"
+                case .waiting(let e): stateStr = "waiting(\(e))"
+                case .failed(let e): stateStr = "failed(\(e))"
+                case .cancelled: stateStr = "cancelled"
+                @unknown default: stateStr = "unknown"
+                }
+
+                Task { @MainActor in
+                    self?.lastConnectionState = stateStr
+                }
+
                 switch state {
                 case .ready:
                     resumeOnce(.working)
@@ -106,7 +108,6 @@ final class LocalNetworkChecker: ObservableObject {
                     if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
                         resumeOnce(.notWorking(reason: "Connection blocked"))
                     }
-                    // Otherwise keep waiting - might be showing permission prompt
                 case .failed(let error):
                     let errorStr = "\(error)"
                     if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
@@ -126,7 +127,7 @@ final class LocalNetworkChecker: ObservableObject {
             conn.start(queue: .main)

             Task {
-                try? await Task.sleep(nanoseconds: timeout)
+                try? await Task.sleep(nanoseconds: 3_000_000_000)
                 let state = conn.state
                 switch state {
                 case .ready:
@@ -241,9 +241,6 @@ class PromptSizer:
         ids = tokenizer.apply_chat_template(
             messages, tokenize=True, add_generation_prompt=True
         )
-        # Fix for transformers 5.x
-        if hasattr(ids, "input_ids"):
-            ids = ids.input_ids
         return int(len(ids))

     return count_fn
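The three deleted lines were a compatibility shim: per the removed comment, newer transformers releases can hand back a BatchEncoding-style object from apply_chat_template instead of a plain token list. A minimal sketch of the guard, with a hypothetical stand-in object:

class FakeBatchEncoding:
    """Stand-in for a BatchEncoding-style return value."""
    input_ids = [1, 2, 3]

def token_count(ids) -> int:
    if hasattr(ids, "input_ids"):  # the removed transformers-5.x guard
        ids = ids.input_ids
    return int(len(ids))

assert token_count([1, 2, 3]) == 3
assert token_count(FakeBatchEncoding()) == 3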
bench/exo_eval.py (new file): 378 lines
@@ -0,0 +1,378 @@
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false, reportMissingTypeStubs=false
"""
exo-eval: Run SWE-bench evaluation against exo using OpenHands SDK (local, no Docker).
"""
from __future__ import annotations

import argparse
import json
import os
import subprocess
import tempfile
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any

from datasets import load_dataset
from loguru import logger
from openhands.sdk import LLM, Agent, Conversation, Tool
from openhands.tools.file_editor import FileEditorTool
from openhands.tools.terminal import TerminalTool


class EvalStatus(str, Enum):
    Resolved = "Resolved"
    Failed = "Failed"
    Error = "Error"
    Timeout = "Timeout"


@dataclass
class EvalResult:
    instance_id: str
    repo: str
    status: EvalStatus
    elapsed_seconds: float
    tests_passed: list[str]
    tests_failed: list[str]
    error_message: str | None = None


def load_swe_bench(
    split: str = "lite",
    limit: int | None = None,
    instance_ids: list[str] | None = None,
) -> list[dict[str, Any]]:
    """Load SWE-bench dataset from HuggingFace."""
    # SWE-bench Lite is a curated 300-instance subset
    dataset_name = (
        "princeton-nlp/SWE-bench_Lite" if split == "lite" else "princeton-nlp/SWE-bench"
    )
    actual_split = "test" if split == "lite" else split

    ds = load_dataset(dataset_name, split=actual_split)
    instances = [dict(row) for row in ds]

    if instance_ids:
        instances = [i for i in instances if i["instance_id"] in instance_ids]

    if limit:
        instances = instances[:limit]

    return instances


def clone_repo_at_commit(repo: str, commit: str, dest: Path) -> None:
    """Clone a repo at a specific commit."""
    repo_url = f"https://github.com/{repo}.git"

    subprocess.run(
        ["git", "clone", "--depth", "1", repo_url, str(dest)],
        check=True,
        capture_output=True,
    )

    subprocess.run(
        ["git", "fetch", "--depth", "1", "origin", commit],
        cwd=dest,
        check=True,
        capture_output=True,
    )

    subprocess.run(
        ["git", "checkout", commit],
        cwd=dest,
        check=True,
        capture_output=True,
    )


def build_agent_prompt(instance: dict[str, Any]) -> str:
    """Build the prompt for the agent."""
    return f"""You are a software engineer fixing a bug in the {instance['repo']} repository.

## Problem Statement
{instance['problem_statement']}

## Instructions
1. Explore the codebase to understand the issue
2. Identify the files that need to be modified
3. Make the necessary changes to fix the issue
4. The fix should be minimal and targeted

You have access to:
- terminal: Run shell commands (git, grep, python, etc.)
- file_editor: View and edit files

Start by exploring the repository structure to understand where the relevant code is.
"""


def parse_fail_to_pass(fail_to_pass_str: str) -> list[str]:
    """Parse the FAIL_TO_PASS field into a list of test names."""
    try:
        return json.loads(fail_to_pass_str)
    except json.JSONDecodeError:
        return [t.strip() for t in fail_to_pass_str.split(",") if t.strip()]


def run_tests(workspace: Path, tests: list[str]) -> tuple[list[str], list[str]]:
    """Run tests and return (passed, failed) lists."""
    passed = []
    failed = []

    for test in tests:
        try:
            result = subprocess.run(
                ["python", "-m", "pytest", "-xvs", test],
                cwd=workspace,
                capture_output=True,
                timeout=300,
            )
            if result.returncode == 0:
                passed.append(test)
            else:
                failed.append(test)
        except subprocess.TimeoutExpired:
            failed.append(test)

    return passed, failed


def run_single_eval(
    instance: dict[str, Any],
    host: str,
    port: int,
    model: str,
    max_turns: int = 30,
    timeout: float = 600.0,
) -> EvalResult:
    """Evaluate a single SWE-bench instance."""
    instance_id = instance["instance_id"]
    repo = instance["repo"]
    base_commit = instance["base_commit"]
    fail_to_pass = parse_fail_to_pass(instance["FAIL_TO_PASS"])

    start_time = time.perf_counter()

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            workspace = Path(tmpdir) / "repo"

            # Clone repo at base commit
            logger.info(f"Cloning {repo} at {base_commit[:8]}...")
            clone_repo_at_commit(repo, base_commit, workspace)

            # Setup OpenHands agent
            llm = LLM(
                model=f"openai/{model}",
                base_url=f"http://{host}:{port}/v1",
                api_key="not-needed",
            )

            agent = Agent(
                llm=llm,
                tools=[
                    Tool(name=TerminalTool.name),
                    Tool(name=FileEditorTool.name),
                ],
            )

            # Run agent
            conversation = Conversation(
                agent=agent,
                workspace=str(workspace),
            )

            logger.info(f"Running agent on {instance_id}...")
            conversation.send_message(build_agent_prompt(instance))

            for _turn in range(max_turns):
                if time.perf_counter() - start_time > timeout:
                    return EvalResult(
                        instance_id=instance_id,
                        repo=repo,
                        status=EvalStatus.Timeout,
                        elapsed_seconds=time.perf_counter() - start_time,
                        tests_passed=[],
                        tests_failed=fail_to_pass,
                    )

                result = conversation.run(max_turns=1)
                if result.done:
                    break

            # Run tests to verify
            logger.info(f"Running tests for {instance_id}...")
            passed, failed = run_tests(workspace, fail_to_pass)

            elapsed = time.perf_counter() - start_time
            status = EvalStatus.Resolved if not failed else EvalStatus.Failed

            return EvalResult(
                instance_id=instance_id,
                repo=repo,
                status=status,
                elapsed_seconds=elapsed,
                tests_passed=passed,
                tests_failed=failed,
            )

    except Exception as e:
        return EvalResult(
            instance_id=instance_id,
            repo=repo,
            status=EvalStatus.Error,
            elapsed_seconds=time.perf_counter() - start_time,
            tests_passed=[],
            tests_failed=[],
            error_message=str(e),
        )


def verify_exo_running(host: str, port: int, model: str) -> str:
    """Verify exo is running and return full model ID."""
    import http.client

    conn = http.client.HTTPConnection(host, port, timeout=10)
    conn.request("GET", "/models")
    resp = conn.getresponse()

    if resp.status != 200:
        raise RuntimeError(f"exo not responding at {host}:{port}")

    data = json.loads(resp.read())
    for m in data.get("data", []):
        if m.get("id") == model or m.get("hugging_face_id") == model:
            return m.get("hugging_face_id") or m.get("id")

    raise ValueError(f"Model '{model}' not found in exo")


def main() -> int:
    ap = argparse.ArgumentParser(
        prog="exo-eval",
        description="Run SWE-bench evaluation against exo (local, no Docker).",
    )

    ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
    ap.add_argument(
        "--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
    )
    ap.add_argument("--model", required=True, help="exo model ID")
    ap.add_argument(
        "--split", default="lite", choices=["lite", "dev", "test", "train"]
    )
    ap.add_argument("--limit", type=int, default=10, help="Max instances")
    ap.add_argument("--instance-ids", nargs="+", help="Specific instance IDs")
    ap.add_argument("--max-turns", type=int, default=30)
    ap.add_argument("--timeout", type=float, default=600.0)
    ap.add_argument("--json-out", default="bench/eval_results.json")
    ap.add_argument("-v", "--verbose", action="store_true")
    ap.add_argument("--dry-run", action="store_true")

    args = ap.parse_args()

    # Load dataset first (doesn't require exo to be running)
    logger.info(f"Loading SWE-bench {args.split} dataset...")
    instances = load_swe_bench(
        split=args.split,
        limit=args.limit,
        instance_ids=args.instance_ids,
    )
    logger.info(f"Loaded {len(instances)} instances")

    if args.dry_run:
        print(f"\nSWE-bench {args.split} instances ({len(instances)}):")
        for inst in instances:
            print(f"  {inst['instance_id']} ({inst['repo']})")
        return 0

    # Verify exo is running
    model_id = verify_exo_running(args.host, args.port, args.model)
    logger.info(f"Using model: {model_id}")

    # Run evaluation
    results: list[EvalResult] = []
    for i, instance in enumerate(instances):
        logger.info(f"[{i+1}/{len(instances)}] {instance['instance_id']}")

        result = run_single_eval(
            instance=instance,
            host=args.host,
            port=args.port,
            model=model_id,
            max_turns=args.max_turns,
            timeout=args.timeout,
        )
        results.append(result)

        logger.info(f"  Status: {result.status.value}")
        if result.tests_passed:
            logger.info(f"  Passed: {len(result.tests_passed)} tests")
        if result.tests_failed:
            logger.info(f"  Failed: {len(result.tests_failed)} tests")
        if result.error_message:
            logger.error(f"  Error: {result.error_message}")

    # Compute summary
    total = len(results)
    resolved = sum(1 for r in results if r.status == EvalStatus.Resolved)
    failed = sum(1 for r in results if r.status == EvalStatus.Failed)
    errors = sum(1 for r in results if r.status == EvalStatus.Error)
    timeouts = sum(1 for r in results if r.status == EvalStatus.Timeout)

    summary = {
        "model": model_id,
        "split": args.split,
        "total": total,
        "resolved": resolved,
        "resolved_rate": resolved / total if total else 0,
        "failed": failed,
        "errors": errors,
        "timeouts": timeouts,
    }

    output = {
        "summary": summary,
        "results": [
            {
                "instance_id": r.instance_id,
                "repo": r.repo,
                "status": r.status.value,
                "elapsed_seconds": r.elapsed_seconds,
                "tests_passed": r.tests_passed,
                "tests_failed": r.tests_failed,
                "error_message": r.error_message,
            }
            for r in results
        ],
    }

    Path(args.json_out).write_text(json.dumps(output, indent=2))
    logger.info(f"Results written to {args.json_out}")

    # Print summary
    print("\n" + "=" * 60)
    print("SWE-bench Evaluation Results")
    print("=" * 60)
    print(f"Model: {model_id}")
    print(f"Split: {args.split}")
    print(f"Total: {total}")
    if total:
        print(f"Resolved: {resolved} ({resolved/total*100:.1f}%)")
    else:
        print("Resolved: 0")
    print(f"Failed: {failed}")
    print(f"Errors: {errors}")
    print(f"Timeouts: {timeouts}")
    print("=" * 60)

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
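Judging from the argparse definitions above, a typical invocation looks like the following (the model ID is a placeholder; exo must already be serving, on localhost:52415 unless --host/--port say otherwise):

python bench/exo_eval.py --model my-model-id --split lite --limit 5
python bench/exo_eval.py --model my-model-id --dry-run   # just list the selected instances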
@@ -23,7 +23,9 @@ dependencies = [
     "tiktoken>=0.12.0", # required for kimi k2 tokenizer
     "hypercorn>=0.18.0",
     "openai-harmony>=0.0.8",
-    "httpx>=0.28.1",
+    "openhands-sdk>=0.1.0", # for exo-eval SWE-bench evaluation
+    "openhands-tools>=0.1.0", # tools for openhands agents
+    "datasets>=3.0.0", # for loading SWE-bench from HuggingFace
 ]

 [project.scripts]
@@ -1,6 +1,6 @@
 import time
 from collections.abc import AsyncGenerator
-from typing import cast
+from typing import Any, cast

 import anyio
 from anyio import create_task_group
@@ -72,7 +72,13 @@ def chunk_to_response(
         choices=[
             StreamingChoiceResponse(
                 index=0,
-                delta=ChatCompletionMessage(role="assistant", content=chunk.text),
+                delta=ChatCompletionMessage(
+                    role="assistant",
+                    content=chunk.text if chunk.text else None,
+                    tool_calls=[tc.model_dump() for tc in chunk.tool_calls]
+                    if chunk.tool_calls
+                    else None,
+                ),
                 finish_reason=chunk.finish_reason,
             )
         ],
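With this change, a streamed chunk that carries a tool call serializes in the familiar OpenAI style: content collapses to None and the tool_calls array is populated. Roughly, with illustrative values:

delta = {
    "role": "assistant",
    "content": None,  # chunk.text was empty, so content is dropped to None
    "tool_calls": [
        {
            "id": "call_0123456789abcdef01234567",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "London"}'},
        }
    ],
}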
@@ -424,6 +430,7 @@ class API:
         text_parts: list[str] = []
         model: str | None = None
         finish_reason: FinishReason | None = None
+        all_tool_calls: list[dict[str, Any]] = []

         async for chunk in self._chat_chunk_stream(command_id):
             if model is None:
@@ -431,6 +438,20 @@ class API:

             text_parts.append(chunk.text)

+            # Collect tool calls
+            if chunk.tool_calls:
+                for tc in chunk.tool_calls:
+                    all_tool_calls.append(
+                        {
+                            "id": tc.id,
+                            "type": tc.type,
+                            "function": {
+                                "name": tc.function.name,
+                                "arguments": tc.function.arguments,
+                            },
+                        }
+                    )
+
             if chunk.finish_reason is not None:
                 finish_reason = chunk.finish_reason
@@ -446,7 +467,8 @@ class API:
                 index=0,
                 message=ChatCompletionMessage(
                     role="assistant",
-                    content=combined_text,
+                    content=combined_text if combined_text else None,
+                    tool_calls=all_tool_calls if all_tool_calls else None,
                 ),
                 finish_reason=finish_reason,
             )
@@ -459,6 +481,7 @@ class API:
         text_parts: list[str] = []
         model: str | None = None
         finish_reason: FinishReason | None = None
+        all_tool_calls: list[dict[str, Any]] = []

         stats: GenerationStats | None = None
@@ -469,6 +492,20 @@ class API:
             text_parts.append(chunk.text)
             stats = chunk.stats or stats

+            # Collect tool calls
+            if chunk.tool_calls:
+                for tc in chunk.tool_calls:
+                    all_tool_calls.append(
+                        {
+                            "id": tc.id,
+                            "type": tc.type,
+                            "function": {
+                                "name": tc.function.name,
+                                "arguments": tc.function.arguments,
+                            },
+                        }
+                    )
+
             if chunk.finish_reason is not None:
                 finish_reason = chunk.finish_reason
@@ -483,7 +520,9 @@ class API:
             ChatCompletionChoice(
                 index=0,
                 message=ChatCompletionMessage(
-                    role="assistant", content=combined_text
+                    role="assistant",
+                    content=combined_text if combined_text else None,
+                    tool_calls=all_tool_calls if all_tool_calls else None,
                 ),
                 finish_reason=finish_reason,
             )
@@ -29,11 +29,6 @@ class _InterceptHandler(logging.Handler):

 def logger_setup(log_file: Path | None, verbosity: int = 0):
     """Set up logging for this process - formatting, file handles, verbosity and output"""

-    logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
-    logging.getLogger("httpx").setLevel(logging.WARNING)
-    logging.getLogger("httpcore").setLevel(logging.WARNING)
-
     logger.remove()

     # replace all stdlib loggers with _InterceptHandlers that log to loguru
@@ -1,4 +1,7 @@
 from enum import Enum
+from typing import Literal

+from pydantic import BaseModel
+
 from exo.shared.types.api import GenerationStats
 from exo.utils.pydantic_ext import TaggedModel
@@ -12,6 +15,17 @@ class ChunkType(str, Enum):
     Image = "Image"


+class ToolCallFunction(BaseModel, frozen=True):
+    name: str
+    arguments: str
+
+
+class ToolCall(BaseModel, frozen=True):
+    id: str
+    type: Literal["function"] = "function"
+    function: ToolCallFunction
+
+
 class BaseChunk(TaggedModel):
     idx: int
     model: ModelId
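A quick sketch of the new models in use (values are hypothetical; model_dump is standard pydantic, and type defaults to "function"):

call = ToolCall(
    id="call_0123456789abcdef01234567",
    function=ToolCallFunction(name="get_weather", arguments='{"city": "London"}'),
)
assert call.model_dump() == {
    "id": "call_0123456789abcdef01234567",
    "type": "function",
    "function": {"name": "get_weather", "arguments": '{"city": "London"}'},
}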
@@ -22,6 +36,7 @@ class TokenChunk(BaseChunk):
     token_id: int
     finish_reason: FinishReason | None = None
     stats: GenerationStats | None = None
+    tool_calls: list[ToolCall] | None = None


 class ImageChunk(BaseChunk):
@@ -1,4 +1,5 @@
 from exo.shared.types.api import FinishReason, GenerationStats
+from exo.shared.types.chunks import ToolCall
 from exo.utils.pydantic_ext import TaggedModel


@@ -16,6 +17,7 @@ class GenerationResponse(BaseRunnerResponse):
     # logprobs: list[float] | None = None  # too big. we can change to be top-k
     finish_reason: FinishReason | None = None
     stats: GenerationStats | None = None
+    tool_calls: list[ToolCall] | None = None


 class FinishedResponse(BaseRunnerResponse):
@@ -1,9 +1,13 @@
+import json
 import time
 from collections.abc import Generator
 from functools import cache
+from typing import Any
+from uuid import uuid4

 import mlx.core as mx
 from mlx_lm.models.gpt_oss import Model as GptOssModel
+from mlx_lm.tokenizer_utils import TokenizerWrapper
 from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]
     HarmonyEncodingName,
     Role,
@@ -12,7 +16,7 @@ from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]
 )

 from exo.shared.types.api import ChatCompletionMessageText
-from exo.shared.types.chunks import TokenChunk
+from exo.shared.types.chunks import TokenChunk, ToolCall, ToolCallFunction
 from exo.shared.types.events import (
     ChunkGenerated,
     Event,
@@ -172,7 +176,10 @@ def main(
         if isinstance(model, GptOssModel):
             mlx_generator = parse_gpt_oss(mlx_generator)

-        # TODO: Add tool call parser here
+        # Parse tool calls to place them in the tool calls section
+        mlx_generator = parse_tool_calls(
+            mlx_generator, tokenizer, task_params.tools
+        )

         for response in mlx_generator:
             match response:
@@ -188,6 +195,7 @@ def main(
                             token_id=response.token,
                             finish_reason=response.finish_reason,
                             stats=response.stats,
+                            tool_calls=response.tool_calls,
                         ),
                     )
                 )
@@ -261,6 +269,98 @@ def parse_gpt_oss(
                 break


+def _generate_tool_call_id() -> str:
+    return f"call_{uuid4().hex[:24]}"
+
+
+def _parse_tool_call_content(
+    content: str,
+    tokenizer: TokenizerWrapper,
+    tools: list[dict[str, Any]] | None,
+) -> ToolCall | None:
+    content = content.strip()
+    if not content:
+        return None
+
+    tool_parser: Any = getattr(tokenizer, "tool_parser", None)
+    if tool_parser is None:
+        logger.warning("No tool_parser available for tokenizer")
+        return None
+
+    try:
+        parsed: dict[str, Any] = tool_parser(content, tools)  # pyright: ignore[reportAny]
+        if parsed and "name" in parsed:
+            arguments: Any = parsed.get("arguments", {})  # pyright: ignore[reportAny]
+            arguments_str: str = (
+                json.dumps(arguments)
+                if not isinstance(arguments, str)
+                else arguments
+            )
+            return ToolCall(
+                id=_generate_tool_call_id(),
+                type="function",
+                function=ToolCallFunction(
+                    name=str(parsed["name"]),  # pyright: ignore[reportAny]
+                    arguments=arguments_str,
+                ),
+            )
+    except Exception as e:
+        logger.warning(f"tool_parser failed: {e}")
+
+    return None
+
+
+def parse_tool_calls(
+    responses: Generator[GenerationResponse],
+    tokenizer: TokenizerWrapper,
+    tools: list[dict[str, Any]] | None,
+) -> Generator[GenerationResponse]:
+    has_tool_calling = getattr(tokenizer, "has_tool_calling", False)
+    if not has_tool_calling or tools is None:
+        yield from responses
+        return
+
+    tool_call_start: str | None = getattr(tokenizer, "tool_call_start", None)
+    tool_call_end: str | None = getattr(tokenizer, "tool_call_end", None)
+
+    if tool_call_start is None or tool_call_end is None:
+        yield from responses
+        return
+
+    in_tool_call = False
+    tool_call_buffer: list[str] = []
+    pending_tool_calls: list[ToolCall] = []
+
+    for response in responses:
+        if response.text == tool_call_start:
+            in_tool_call = True
+            tool_call_buffer = []
+            continue
+
+        if response.text == tool_call_end:
+            in_tool_call = False
+            parsed = _parse_tool_call_content(
+                "".join(tool_call_buffer), tokenizer, tools
+            )
+            if parsed is not None:
+                pending_tool_calls.append(parsed)
+            continue
+
+        if in_tool_call:
+            tool_call_buffer.append(response.text)
+            continue
+
+        if response.finish_reason is None or not pending_tool_calls:
+            yield response
+        else:
+            yield response.model_copy(
+                update={
+                    "finish_reason": "tool_calls",
+                    "tool_calls": pending_tool_calls if pending_tool_calls else None,
+                }
+            )
+
+
 EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
 EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
 EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
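The generator above is a sentinel-scanning buffer: tokens between the tokenizer's tool_call_start/end markers are collected instead of streamed, then parsed once the end marker arrives. A self-contained toy version of the same idea (the marker strings here are hypothetical stand-ins):

from collections.abc import Iterator

def split_tool_calls(
    tokens: Iterator[str], start: str = "<tool_call>", end: str = "</tool_call>"
) -> Iterator[str | list[str]]:
    buffering, buf, calls = False, [], []
    for t in tokens:
        if t == start:      # enter buffering mode, like parse_tool_calls
            buffering, buf = True, []
        elif t == end:      # close the call and stash its raw payload
            buffering = False
            calls.append("".join(buf))
        elif buffering:
            buf.append(t)
        else:
            yield t         # ordinary text streams through untouched
    yield calls             # collected payloads surface at the end

stream = ["Hi ", "<tool_call>", '{"name": "f"}', "</tool_call>", "!"]
print(list(split_tool_calls(iter(stream))))  # ['Hi ', '!', ['{"name": "f"}']]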
@@ -1,64 +1,62 @@
-import anyio
-import httpx
-from anyio import create_task_group
+import http.client
+import time
+
+from anyio import create_task_group, to_thread
 from loguru import logger

 from exo.shared.topology import Topology
 from exo.shared.types.common import NodeId

-REACHABILITY_ATTEMPTS = 3
+BAD_STATUSLINE_ATTEMPTS = 3


 async def check_reachability(
     target_ip: str,
     expected_node_id: NodeId,
     self_node_id: NodeId,
     out: dict[NodeId, set[str]],
-    client: httpx.AsyncClient,
 ) -> None:
     """Check if a node is reachable at the given IP and verify its identity."""
     if ":" in target_ip:
         # TODO: use real IpAddress types
         target_ip = f"[{target_ip}]"
-    url = f"http://{target_ip}:52415/node_id"
-
-    remote_node_id = None
-    last_error = None

-    for _ in range(REACHABILITY_ATTEMPTS):
+    # TODO: use an async http client
+    def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
+        connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
         try:
-            r = await client.get(url)
-            if r.status_code != 200:
-                await anyio.sleep(1)
-                continue
+            connection.request("GET", "/node_id")
+            response = connection.getresponse()
+            if response.status != 200:
+                return None

-            body = r.text.strip().strip('"')
-            if not body:
-                await anyio.sleep(1)
-                continue
+            body = response.read().decode("utf-8").strip()

-            remote_node_id = NodeId(body)
-            break
+            # Strip quotes if present (JSON string response)
+            if body.startswith('"') and body.endswith('"') and len(body) >= 2:
+                body = body[1:-1]

-        # expected failure cases
-        except (
-            httpx.TimeoutException,
-            httpx.NetworkError,
-        ):
-            await anyio.sleep(1)
-
-        # other failures should be logged on last attempt
-        except httpx.HTTPError as e:
-            last_error = e
-            await anyio.sleep(1)
-
-    if last_error is not None:
-        logger.warning(
-            f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
-        )
+            return NodeId(body) or None
+        except OSError:
+            return None
+        except http.client.BadStatusLine:
+            if attempt >= BAD_STATUSLINE_ATTEMPTS:
+                logger.warning(
+                    f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
+                )
+                return None
+            time.sleep(1)
+            return _fetch_remote_node_id(attempt=attempt + 1)
+        except http.client.HTTPException as e:
+            logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
+            return None
+        finally:
+            connection.close()
+
+    remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
     if remote_node_id is None:
         return

     if remote_node_id == self_node_id:
         return

     if remote_node_id != expected_node_id:
         logger.warning(
             f"Discovered node with unexpected node_id; "
@@ -76,33 +74,18 @@ async def check_reachable(
     topology: Topology, self_node_id: NodeId
 ) -> dict[NodeId, set[str]]:
     """Check which nodes are reachable and return their IPs."""

     reachable: dict[NodeId, set[str]] = {}

-    # these are intentionally httpx's defaults so we can tune them later
-    timeout = httpx.Timeout(timeout=5.0)
-    limits = httpx.Limits(
-        max_connections=100,
-        max_keepalive_connections=20,
-        keepalive_expiry=5,
-    )
-
-    async with (
-        httpx.AsyncClient(timeout=timeout, limits=limits) as client,
-        create_task_group() as tg,
-    ):
+    async with create_task_group() as tg:
         for node in topology.list_nodes():
             if not node.node_profile:
                 continue
             if node.node_id == self_node_id:
                 continue
             for iface in node.node_profile.network_interfaces:
                 tg.start_soon(
                     check_reachability,
                     iface.ip_address,
                     node.node_id,
                     self_node_id,
                     reachable,
-                    client,
                 )

     return reachable
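The structural move in this file: the blocking http.client round trip now runs on a worker thread via anyio's to_thread.run_sync, so the event loop is never stalled while a socket waits. A minimal sketch of that pattern (function names are hypothetical):

import time

import anyio
from anyio import to_thread

def blocking_fetch() -> str:
    time.sleep(0.1)  # stands in for the synchronous HTTP round trip
    return "node-id"

async def main() -> None:
    # run_sync ships the callable to a thread and awaits its return value
    result = await to_thread.run_sync(blocking_fetch)
    print(result)

anyio.run(main)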