Compare commits

..

2 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
826da9512d Add exo eval 2026-01-16 12:55:43 +00:00
Ryuichi Leo Takashige
4d9114b9b5 Add ChatCompletion tool calling support
https://platform.openai.com/docs/api-reference/chat/get
2026-01-16 11:36:08 +00:00
15 changed files with 5433 additions and 262 deletions

View File

@@ -1,16 +1,5 @@
name: Build EXO macOS DMG
# Release workflow:
# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
# 3. This workflow builds, signs, and notarizes the DMG
# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
# 5. DMG and appcast.xml are uploaded to S3
# 6. The draft GitHub Release is published with the DMG attached
#
# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
# If no draft exists, a release is auto-created with generated notes.
on:
workflow_dispatch:
push:
@@ -22,10 +11,8 @@ on:
jobs:
build-macos-app:
runs-on: "macos-26"
permissions:
contents: write
env:
SPARKLE_VERSION: 2.9.0-beta.1
SPARKLE_VERSION: 2.8.1
SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
@@ -100,52 +87,6 @@ jobs:
exit 1
fi
- name: Fetch and validate release notes
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Find draft release by name using gh release list (more reliable with default token)
echo "Looking for draft release named '$GITHUB_REF_NAME'..."
DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
if [[ -z "$DRAFT_EXISTS" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
echo "Please create a draft release with release notes before pushing the tag."
exit 1
fi
# Fetch full release details via API to get body and ID
echo "Found draft release, fetching details..."
RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
# Extract release notes
NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "Draft release has no notes (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: Draft release exists but has no release notes"
echo "Please add release notes to the draft release before pushing the tag."
exit 1
fi
# Save release ID for later publishing
RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
echo "$NOTES" > /tmp/release_notes.md
echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
# ============================================================
# Install dependencies
# ============================================================
@@ -363,28 +304,6 @@ jobs:
$CHANNEL_FLAG \
.
- name: Inject release notes into appcast
if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
env:
RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
run: |
# Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
export NOTES=$(cat "$RELEASE_NOTES_FILE")
# Insert description after the enclosure tag for this version
awk '
/<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
print
print " <description sparkle:format=\"markdown\"><![CDATA["
print ENVIRON["NOTES"]
print " ]]></description>"
next
}
{ print }
' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
echo "Injected markdown release notes for version $RELEASE_VERSION"
# ============================================================
# Upload artifacts
# ============================================================
@@ -417,26 +336,3 @@ jobs:
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
fi
- name: Publish GitHub Release
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
# Update the draft release with the tag and upload DMG
gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
-f tag_name="$GITHUB_REF_NAME" \
-F draft=false
gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
echo "Published release $GITHUB_REF_NAME with DMG attached"
else
# Alpha without draft release - create one with auto-generated notes
gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
--title "$GITHUB_REF_NAME" \
--generate-notes \
--prerelease
echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
fi

View File

@@ -585,7 +585,7 @@
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.9.0-beta.1;
minimumVersion = 2.8.1;
};
};
/* End XCRemoteSwiftPackageReference section */

View File

@@ -6,8 +6,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/sparkle-project/Sparkle.git",
"state" : {
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
"version" : "2.9.0-beta.1"
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
"version" : "2.8.1"
}
}
],

View File

@@ -56,11 +56,6 @@ struct ContentView: View {
}
private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}

View File

@@ -5,8 +5,8 @@ import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
@@ -35,43 +35,30 @@ final class LocalNetworkChecker: ObservableObject {
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
lastConnectionState = "connecting"
checkTask = Task { [weak self] in
guard let self else { return }
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
let result = await self.performCheck()
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
/// Checks connectivity using NWConnection to mDNS multicast.
/// The connection attempt triggers the permission prompt if not yet shown.
private func checkConnectivity(timeout: UInt64) async -> Status {
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
connection?.cancel()
connection = nil
@@ -97,7 +84,22 @@ final class LocalNetworkChecker: ObservableObject {
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { state in
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
switch state {
case .ready:
resumeOnce(.working)
@@ -106,7 +108,6 @@ final class LocalNetworkChecker: ObservableObject {
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
// Otherwise keep waiting - might be showing permission prompt
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
@@ -126,7 +127,7 @@ final class LocalNetworkChecker: ObservableObject {
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: timeout)
try? await Task.sleep(nanoseconds: 3_000_000_000)
let state = conn.state
switch state {
case .ready:

View File

@@ -241,9 +241,6 @@ class PromptSizer:
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
# Fix for transformers 5.x
if hasattr(ids, "input_ids"):
ids = ids.input_ids
return int(len(ids))
return count_fn

378
bench/exo_eval.py Normal file
View File

@@ -0,0 +1,378 @@
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false, reportMissingTypeStubs=false
"""
exo-eval: Run SWE-bench evaluation against exo using OpenHands SDK (local, no Docker).
"""
from __future__ import annotations
import argparse
import json
import os
import subprocess
import tempfile
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any
from datasets import load_dataset
from loguru import logger
from openhands.sdk import LLM, Agent, Conversation, Tool
from openhands.tools.file_editor import FileEditorTool
from openhands.tools.terminal import TerminalTool
class EvalStatus(str, Enum):
    """Final outcome recorded for one evaluated SWE-bench instance.

    Members mix in ``str``, so each one compares equal to its string value.
    """

    # No FAIL_TO_PASS test ended up in the failed list.
    Resolved = "Resolved"
    # At least one FAIL_TO_PASS test failed or timed out.
    Failed = "Failed"
    # An exception was raised while evaluating the instance.
    Error = "Error"
    # The wall-clock budget was exhausted before the agent finished.
    Timeout = "Timeout"
@dataclass
class EvalResult:
    """Record describing how a single SWE-bench instance fared."""

    # SWE-bench identifier of the evaluated instance.
    instance_id: str
    # Repository the instance belongs to.
    repo: str
    # Overall outcome of the evaluation.
    status: EvalStatus
    # Wall-clock seconds measured from the start of the evaluation.
    elapsed_seconds: float
    # Test identifiers that passed / did not pass when re-run.
    tests_passed: list[str]
    tests_failed: list[str]
    # Text of the exception when one was caught; None otherwise.
    error_message: str | None = None
def load_swe_bench(
    split: str = "lite",
    limit: int | None = None,
    instance_ids: list[str] | None = None,
) -> list[dict[str, Any]]:
    """Load SWE-bench instances from HuggingFace.

    Args:
        split: ``"lite"`` loads the ``test`` split of
            ``princeton-nlp/SWE-bench_Lite``; any other value is used as the
            split name of ``princeton-nlp/SWE-bench``.
        limit: Maximum number of instances to return, applied after the
            ``instance_ids`` filter. ``None`` or ``0`` disables the limit.
        instance_ids: When non-empty, only instances whose ``instance_id`` is
            listed are kept.

    Returns:
        The selected dataset rows as plain dictionaries, in dataset order.
    """
    is_lite = split == "lite"
    # SWE-bench Lite is a curated 300-instance subset
    dataset_name = (
        "princeton-nlp/SWE-bench_Lite" if is_lite else "princeton-nlp/SWE-bench"
    )
    actual_split = "test" if is_lite else split

    ds = load_dataset(dataset_name, split=actual_split)
    instances = [dict(row) for row in ds]

    if instance_ids:
        # A set gives O(1) membership checks instead of rescanning the list
        # of requested ids for every dataset row.
        wanted = set(instance_ids)
        instances = [i for i in instances if i["instance_id"] in wanted]

    if limit:
        instances = instances[:limit]

    return instances
def clone_repo_at_commit(repo: str, commit: str, dest: Path) -> None:
    """Check out ``commit`` of the GitHub repository ``repo`` into ``dest``.

    Runs three git commands in order: a depth-1 clone into ``dest``, a depth-1
    fetch of ``commit`` from ``origin``, and a checkout of ``commit``.

    Raises:
        subprocess.CalledProcessError: If any git command exits non-zero.
    """
    origin_url = f"https://github.com/{repo}.git"

    # (argv, working directory) pairs. The clone creates ``dest`` and therefore
    # runs in the caller's working directory; the rest run inside the clone.
    steps = [
        (["git", "clone", "--depth", "1", origin_url, str(dest)], None),
        (["git", "fetch", "--depth", "1", "origin", commit], dest),
        (["git", "checkout", commit], dest),
    ]
    for argv, workdir in steps:
        subprocess.run(argv, cwd=workdir, check=True, capture_output=True)
def build_agent_prompt(instance: dict[str, Any]) -> str:
    """Render the task prompt sent to the agent for one instance.

    Reads the ``repo`` and ``problem_statement`` keys of ``instance`` and
    embeds them in a fixed set of instructions.
    """
    repo_name = instance["repo"]
    issue_text = instance["problem_statement"]
    prompt_lines = [
        f"You are a software engineer fixing a bug in the {repo_name} repository.",
        "## Problem Statement",
        f"{issue_text}",
        "## Instructions",
        "1. Explore the codebase to understand the issue",
        "2. Identify the files that need to be modified",
        "3. Make the necessary changes to fix the issue",
        "4. The fix should be minimal and targeted",
        "You have access to:",
        "- terminal: Run shell commands (git, grep, python, etc.)",
        "- file_editor: View and edit files",
        "Start by exploring the repository structure to understand where the relevant code is.",
    ]
    # The prompt ends with a newline after the last line.
    return "\n".join(prompt_lines) + "\n"
def parse_fail_to_pass(fail_to_pass_str: str) -> list[str]:
    """Parse the FAIL_TO_PASS field into a list of test names.

    Accepted shapes, tried in this order:

    * an already-decoded list/tuple (returned as a list of strings);
    * a JSON array (each element converted to ``str``);
    * a JSON string literal (treated as a single test name);
    * anything else: split on commas, dropping empty entries.

    Args:
        fail_to_pass_str: Raw FAIL_TO_PASS value of a dataset row.

    Returns:
        The test names; always a list, possibly empty.
    """
    if fail_to_pass_str is None:
        return []
    # Defensive: tolerate a value that was decoded before reaching us instead
    # of letting json.loads raise TypeError on it.
    if isinstance(fail_to_pass_str, (list, tuple)):
        return [str(t) for t in fail_to_pass_str]

    try:
        decoded = json.loads(fail_to_pass_str)
    except json.JSONDecodeError:
        decoded = None

    if isinstance(decoded, list):
        return [str(t) for t in decoded]
    if isinstance(decoded, str):
        # A bare JSON string is one test name, not an iterable of characters.
        return [decoded] if decoded.strip() else []

    # Not JSON (or a JSON scalar/object): fall back to comma-separated names so
    # the declared ``list[str]`` return type always holds.
    return [t.strip() for t in fail_to_pass_str.split(",") if t.strip()]
def run_tests(workspace: Path, tests: list[str]) -> tuple[list[str], list[str]]:
"""Run tests and return (passed, failed) lists."""
passed = []
failed = []
for test in tests:
try:
result = subprocess.run(
["python", "-m", "pytest", "-xvs", test],
cwd=workspace,
capture_output=True,
timeout=300,
)
if result.returncode == 0:
passed.append(test)
else:
failed.append(test)
except subprocess.TimeoutExpired:
failed.append(test)
return passed, failed
def run_single_eval(
    instance: dict[str, Any],
    host: str,
    port: int,
    model: str,
    max_turns: int = 30,
    timeout: float = 600.0,
) -> EvalResult:
    """Evaluate a single SWE-bench instance.

    Clones the instance's repository at its base commit into a temporary
    directory, lets an OpenHands agent work on it for at most ``max_turns``
    single-turn runs, then re-runs the instance's FAIL_TO_PASS tests.

    Args:
        instance: Dataset row; ``instance_id``, ``repo``, ``base_commit`` and
            ``FAIL_TO_PASS`` are read here.
        host: Host of the exo API.
        port: Port of the exo API.
        model: Model identifier, sent to the LLM client as ``openai/<model>``.
        max_turns: Upper bound on agent turns.
        timeout: Wall-clock budget in seconds, checked before each turn.

    Returns:
        An ``EvalResult``. Status is ``Timeout`` if the budget ran out before a
        turn, ``Error`` if any exception was raised inside the evaluation,
        otherwise ``Resolved`` when no test failed and ``Failed`` when one did.
    """
    instance_id = instance["instance_id"]
    repo = instance["repo"]
    base_commit = instance["base_commit"]
    fail_to_pass = parse_fail_to_pass(instance["FAIL_TO_PASS"])

    started_at = time.perf_counter()

    def _outcome(
        status: EvalStatus,
        passed: list[str],
        failed: list[str],
        error: str | None = None,
    ) -> EvalResult:
        # Elapsed time is measured at the moment the result is built.
        return EvalResult(
            instance_id=instance_id,
            repo=repo,
            status=status,
            elapsed_seconds=time.perf_counter() - started_at,
            tests_passed=passed,
            tests_failed=failed,
            error_message=error,
        )

    try:
        with tempfile.TemporaryDirectory() as scratch_dir:
            workspace = Path(scratch_dir) / "repo"

            # Clone repo at base commit
            logger.info(f"Cloning {repo} at {base_commit[:8]}...")
            clone_repo_at_commit(repo, base_commit, workspace)

            # Setup OpenHands agent
            llm = LLM(
                model=f"openai/{model}",
                base_url=f"http://{host}:{port}/v1",
                api_key="not-needed",
            )
            agent_tools = [
                Tool(name=TerminalTool.name),
                Tool(name=FileEditorTool.name),
            ]
            agent = Agent(llm=llm, tools=agent_tools)

            # Run agent
            conversation = Conversation(agent=agent, workspace=str(workspace))
            logger.info(f"Running agent on {instance_id}...")
            conversation.send_message(build_agent_prompt(instance))

            for _ in range(max_turns):
                out_of_time = time.perf_counter() - started_at > timeout
                if out_of_time:
                    # Nothing was verified, so every target test is reported
                    # as failed.
                    return _outcome(EvalStatus.Timeout, [], fail_to_pass)
                step = conversation.run(max_turns=1)
                if step.done:
                    break

            # Run tests to verify
            logger.info(f"Running tests for {instance_id}...")
            passed, failed = run_tests(workspace, fail_to_pass)
            verdict = EvalStatus.Failed if failed else EvalStatus.Resolved
            return _outcome(verdict, passed, failed)
    except Exception as e:
        return _outcome(EvalStatus.Error, [], [], str(e))
def verify_exo_running(host: str, port: int, model: str) -> str:
    """Verify exo is running and return the full model ID.

    Issues ``GET /models`` against ``host:port`` and looks for an entry whose
    ``id`` or ``hugging_face_id`` equals ``model``.

    Args:
        host: Host of the exo API.
        port: Port of the exo API.
        model: Model identifier to look up.

    Returns:
        The matching entry's ``hugging_face_id`` when present and non-empty,
        otherwise its ``id``.

    Raises:
        RuntimeError: If the endpoint answers with a status other than 200.
        ValueError: If no listed model matches ``model``.
    """
    import http.client

    conn = http.client.HTTPConnection(host, port, timeout=10)
    try:
        conn.request("GET", "/models")
        resp = conn.getresponse()
        if resp.status != 200:
            raise RuntimeError(f"exo not responding at {host}:{port}")
        data = json.loads(resp.read())
    finally:
        # Release the socket on every path, including errors.
        conn.close()

    for m in data.get("data", []):
        if m.get("id") == model or m.get("hugging_face_id") == model:
            return m.get("hugging_face_id") or m.get("id")

    raise ValueError(f"Model '{model}' not found in exo")
def main() -> int:
    """Command-line entry point for ``exo-eval``.

    Loads the requested SWE-bench instances; with ``--dry-run`` it only lists
    them. Otherwise it checks that exo serves the requested model, evaluates
    every instance in turn, writes a summary plus per-instance results as JSON
    to ``--json-out`` and prints the summary.

    Returns:
        The process exit code; ``0`` whenever the run completes.
    """
    ap = argparse.ArgumentParser(
        prog="exo-eval",
        description="Run SWE-bench evaluation against exo (local, no Docker).",
    )
    ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
    ap.add_argument(
        "--port", type=int, default=int(os.environ.get("EXO_PORT", "52415"))
    )
    ap.add_argument("--model", required=True, help="exo model ID")
    ap.add_argument(
        "--split", default="lite", choices=["lite", "dev", "test", "train"]
    )
    ap.add_argument("--limit", type=int, default=10, help="Max instances")
    ap.add_argument("--instance-ids", nargs="+", help="Specific instance IDs")
    ap.add_argument("--max-turns", type=int, default=30)
    ap.add_argument("--timeout", type=float, default=600.0)
    ap.add_argument("--json-out", default="bench/eval_results.json")
    # NOTE(review): --verbose is accepted but not read anywhere in this function.
    ap.add_argument("-v", "--verbose", action="store_true")
    ap.add_argument("--dry-run", action="store_true")
    args = ap.parse_args()

    # Load dataset first (doesn't require exo to be running)
    logger.info(f"Loading SWE-bench {args.split} dataset...")
    instances = load_swe_bench(
        split=args.split,
        limit=args.limit,
        instance_ids=args.instance_ids,
    )
    logger.info(f"Loaded {len(instances)} instances")

    if args.dry_run:
        print(f"\nSWE-bench {args.split} instances ({len(instances)}):")
        for inst in instances:
            print(f" {inst['instance_id']} ({inst['repo']})")
        return 0

    # Verify exo is running
    model_id = verify_exo_running(args.host, args.port, args.model)
    logger.info(f"Using model: {model_id}")

    # Run evaluation
    results: list[EvalResult] = []
    for i, instance in enumerate(instances):
        logger.info(f"[{i+1}/{len(instances)}] {instance['instance_id']}")
        result = run_single_eval(
            instance=instance,
            host=args.host,
            port=args.port,
            model=model_id,
            max_turns=args.max_turns,
            timeout=args.timeout,
        )
        results.append(result)
        logger.info(f" Status: {result.status.value}")
        if result.tests_passed:
            logger.info(f" Passed: {len(result.tests_passed)} tests")
        if result.tests_failed:
            logger.info(f" Failed: {len(result.tests_failed)} tests")
        if result.error_message:
            logger.error(f" Error: {result.error_message}")

    # Compute summary
    total = len(results)
    resolved = sum(1 for r in results if r.status == EvalStatus.Resolved)
    failed = sum(1 for r in results if r.status == EvalStatus.Failed)
    errors = sum(1 for r in results if r.status == EvalStatus.Error)
    timeouts = sum(1 for r in results if r.status == EvalStatus.Timeout)

    summary = {
        "model": model_id,
        "split": args.split,
        "total": total,
        "resolved": resolved,
        "resolved_rate": resolved / total if total else 0,
        "failed": failed,
        "errors": errors,
        "timeouts": timeouts,
    }
    output = {
        "summary": summary,
        "results": [
            {
                "instance_id": r.instance_id,
                "repo": r.repo,
                "status": r.status.value,
                "elapsed_seconds": r.elapsed_seconds,
                "tests_passed": r.tests_passed,
                "tests_failed": r.tests_failed,
                "error_message": r.error_message,
            }
            for r in results
        ],
    }

    # Create the output directory first: a missing parent would otherwise make
    # write_text raise and throw away the results of the whole run.
    out_path = Path(args.json_out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(output, indent=2))
    logger.info(f"Results written to {args.json_out}")

    # Print summary
    print("\n" + "=" * 60)
    print("SWE-bench Evaluation Results")
    print("=" * 60)
    print(f"Model: {model_id}")
    print(f"Split: {args.split}")
    print(f"Total: {total}")
    if total:
        print(f"Resolved: {resolved} ({resolved/total*100:.1f}%)")
    else:
        print("Resolved: 0")
    print(f"Failed: {failed}")
    print(f"Errors: {errors}")
    print(f"Timeouts: {timeouts}")
    print("=" * 60)

    return 0


if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -23,7 +23,9 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"openhands-sdk>=0.1.0", # for exo-eval SWE-bench evaluation
"openhands-tools>=0.1.0", # tools for openhands agents
"datasets>=3.0.0", # for loading SWE-bench from HuggingFace
]
[project.scripts]

View File

@@ -1,6 +1,6 @@
import time
from collections.abc import AsyncGenerator
from typing import cast
from typing import Any, cast
import anyio
from anyio import create_task_group
@@ -72,7 +72,13 @@ def chunk_to_response(
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
delta=ChatCompletionMessage(
role="assistant",
content=chunk.text if chunk.text else None,
tool_calls=[tc.model_dump() for tc in chunk.tool_calls]
if chunk.tool_calls
else None,
),
finish_reason=chunk.finish_reason,
)
],
@@ -424,6 +430,7 @@ class API:
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
all_tool_calls: list[dict[str, Any]] = []
async for chunk in self._chat_chunk_stream(command_id):
if model is None:
@@ -431,6 +438,20 @@ class API:
text_parts.append(chunk.text)
# Collect tool calls
if chunk.tool_calls:
for tc in chunk.tool_calls:
all_tool_calls.append(
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
@@ -446,7 +467,8 @@ class API:
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
content=combined_text if combined_text else None,
tool_calls=all_tool_calls if all_tool_calls else None,
),
finish_reason=finish_reason,
)
@@ -459,6 +481,7 @@ class API:
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
all_tool_calls: list[dict[str, Any]] = []
stats: GenerationStats | None = None
@@ -469,6 +492,20 @@ class API:
text_parts.append(chunk.text)
stats = chunk.stats or stats
# Collect tool calls
if chunk.tool_calls:
for tc in chunk.tool_calls:
all_tool_calls.append(
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
@@ -483,7 +520,9 @@ class API:
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant", content=combined_text
role="assistant",
content=combined_text if combined_text else None,
tool_calls=all_tool_calls if all_tool_calls else None,
),
finish_reason=finish_reason,
)

View File

@@ -29,11 +29,6 @@ class _InterceptHandler(logging.Handler):
def logger_setup(log_file: Path | None, verbosity: int = 0):
"""Set up logging for this process - formatting, file handles, verbosity and output"""
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logger.remove()
# replace all stdlib loggers with _InterceptHandlers that log to loguru

View File

@@ -1,4 +1,7 @@
from enum import Enum
from typing import Literal
from pydantic import BaseModel
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
@@ -12,6 +15,17 @@ class ChunkType(str, Enum):
Image = "Image"
class ToolCallFunction(BaseModel, frozen=True):
    """Function part of a tool call: which function, and with what arguments."""

    # Name of the function to invoke.
    name: str
    # Arguments for the call, carried as a string.
    arguments: str
class ToolCall(BaseModel, frozen=True):
    """A single tool invocation attached to a generated chunk."""

    # Identifier of this call.
    id: str
    # Kind of call; "function" is the only accepted value.
    type: Literal["function"] = "function"
    # The function being called and its arguments.
    function: ToolCallFunction
class BaseChunk(TaggedModel):
idx: int
model: ModelId
@@ -22,6 +36,7 @@ class TokenChunk(BaseChunk):
token_id: int
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
tool_calls: list[ToolCall] | None = None
class ImageChunk(BaseChunk):

View File

@@ -1,4 +1,5 @@
from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.chunks import ToolCall
from exo.utils.pydantic_ext import TaggedModel
@@ -16,6 +17,7 @@ class GenerationResponse(BaseRunnerResponse):
# logprobs: list[float] | None = None # too big. we can change to be top-k
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
tool_calls: list[ToolCall] | None = None
class FinishedResponse(BaseRunnerResponse):

View File

@@ -1,9 +1,13 @@
import json
import time
from collections.abc import Generator
from functools import cache
from typing import Any
from uuid import uuid4
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
@@ -12,7 +16,7 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
)
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.chunks import TokenChunk, ToolCall, ToolCallFunction
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -172,7 +176,10 @@ def main(
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# TODO: Add tool call parser here
# Parse tool calls to place them in the tool calls section
mlx_generator = parse_tool_calls(
mlx_generator, tokenizer, task_params.tools
)
for response in mlx_generator:
match response:
@@ -188,6 +195,7 @@ def main(
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
tool_calls=response.tool_calls,
),
)
)
@@ -261,6 +269,98 @@ def parse_gpt_oss(
break
def _generate_tool_call_id() -> str:
return f"call_{uuid4().hex[:24]}"
def _parse_tool_call_content(
    content: str,
    tokenizer: TokenizerWrapper,
    tools: list[dict[str, Any]] | None,
) -> ToolCall | None:
    """Turn the text found between tool-call markers into a ``ToolCall``.

    Args:
        content: Raw text of the tool call; surrounding whitespace is ignored.
        tokenizer: Tokenizer whose ``tool_parser`` attribute, if present, is
            called as ``tool_parser(content, tools)``.
        tools: Tool definitions forwarded to the parser.

    Returns:
        A ``ToolCall`` with a freshly generated id, or ``None`` when the text
        is empty, the tokenizer has no ``tool_parser``, the parser result is
        empty or lacks a ``"name"`` key, or anything raises along the way.
    """
    body = content.strip()
    if not body:
        return None

    parser: Any = getattr(tokenizer, "tool_parser", None)
    if parser is None:
        logger.warning("No tool_parser available for tokenizer")
        return None

    try:
        call: dict[str, Any] = parser(body, tools)  # pyright: ignore[reportAny]
        if not call or "name" not in call:
            return None

        raw_args: Any = call.get("arguments", {})  # pyright: ignore[reportAny]
        # Arguments travel as a string: pass strings through untouched and
        # JSON-encode everything else.
        if isinstance(raw_args, str):
            encoded_args: str = raw_args
        else:
            encoded_args = json.dumps(raw_args)

        return ToolCall(
            id=_generate_tool_call_id(),
            type="function",
            function=ToolCallFunction(
                name=str(call["name"]),  # pyright: ignore[reportAny]
                arguments=encoded_args,
            ),
        )
    except Exception as e:
        logger.warning(f"tool_parser failed: {e}")
        return None
def parse_tool_calls(
    responses: Generator[GenerationResponse],
    tokenizer: TokenizerWrapper,
    tools: list[dict[str, Any]] | None,
) -> Generator[GenerationResponse]:
    """Lift tool calls out of the generated text stream.

    Responses whose text equals the tokenizer's ``tool_call_start`` or
    ``tool_call_end`` marker, and everything between the two, are withheld
    from the output. Each completed span is parsed into a ``ToolCall``. The
    response that carries a ``finish_reason`` is always forwarded; when tool
    calls were collected it is rewritten to ``finish_reason="tool_calls"``
    with the calls attached.

    The stream passes through unchanged when ``tools`` is ``None`` or the
    tokenizer lacks ``has_tool_calling``, ``tool_call_start`` or
    ``tool_call_end``.

    Args:
        responses: Upstream generation responses.
        tokenizer: Tokenizer providing the tool-calling attributes.
        tools: Tool definitions from the request, or ``None``.

    Yields:
        The filtered responses, in order.
    """
    has_tool_calling = getattr(tokenizer, "has_tool_calling", False)
    if not has_tool_calling or tools is None:
        yield from responses
        return

    tool_call_start: str | None = getattr(tokenizer, "tool_call_start", None)
    tool_call_end: str | None = getattr(tokenizer, "tool_call_end", None)
    if tool_call_start is None or tool_call_end is None:
        yield from responses
        return

    in_tool_call = False
    tool_call_buffer: list[str] = []
    pending_tool_calls: list[ToolCall] = []

    for response in responses:
        # ``consumed`` marks text that belongs to tool-call markup and must not
        # reach the client as message content.
        consumed = True
        if response.text == tool_call_start:
            in_tool_call = True
            tool_call_buffer = []
        elif response.text == tool_call_end:
            in_tool_call = False
            parsed = _parse_tool_call_content(
                "".join(tool_call_buffer), tokenizer, tools
            )
            if parsed is not None:
                pending_tool_calls.append(parsed)
        elif in_tool_call:
            tool_call_buffer.append(response.text)
        else:
            consumed = False

        if response.finish_reason is None:
            if not consumed:
                yield response
            continue

        # Terminal response: never drop it, even when it is a marker or falls
        # inside an unterminated tool call. Otherwise downstream consumers
        # would never see a finish_reason or the collected tool calls.
        update: dict[str, Any] = {}
        if consumed:
            update["text"] = ""
        if pending_tool_calls:
            update["finish_reason"] = "tool_calls"
            update["tool_calls"] = pending_tool_calls
        yield response.model_copy(update=update) if update else response
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -1,64 +1,62 @@
import anyio
import httpx
from anyio import create_task_group
import http.client
import time
from anyio import create_task_group, to_thread
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
REACHABILITY_ATTEMPTS = 3
BAD_STATUSLINE_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
client: httpx.AsyncClient,
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
if ":" in target_ip:
# TODO: use real IpAddress types
target_ip = f"[{target_ip}]"
url = f"http://{target_ip}:52415/node_id"
remote_node_id = None
last_error = None
for _ in range(REACHABILITY_ATTEMPTS):
# TODO: use an async http client
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
try:
r = await client.get(url)
if r.status_code != 200:
await anyio.sleep(1)
continue
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None
body = r.text.strip().strip('"')
if not body:
await anyio.sleep(1)
continue
body = response.read().decode("utf-8").strip()
remote_node_id = NodeId(body)
break
# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]
# expected failure cases
except (
httpx.TimeoutException,
httpx.NetworkError,
):
await anyio.sleep(1)
# other failures should be logged on last attempt
except httpx.HTTPError as e:
last_error = e
await anyio.sleep(1)
if last_error is not None:
logger.warning(
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
return NodeId(body) or None
except OSError:
return None
except http.client.BadStatusLine:
if attempt >= BAD_STATUSLINE_ATTEMPTS:
logger.warning(
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
)
return None
time.sleep(1)
return _fetch_remote_node_id(attempt=attempt + 1)
except http.client.HTTPException as e:
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
return None
finally:
connection.close()
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id is None:
return
if remote_node_id == self_node_id:
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
@@ -76,33 +74,18 @@ async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
# these are intentionally httpx's defaults so we can tune them later
timeout = httpx.Timeout(timeout=5.0)
limits = httpx.Limits(
max_connections=100,
max_keepalive_connections=20,
keepalive_expiry=5,
)
async with (
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
async with create_task_group() as tg:
for node in topology.list_nodes():
if not node.node_profile:
continue
if node.node_id == self_node_id:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node.node_id,
self_node_id,
reachable,
client,
)
return reachable

4876
uv.lock generated
View File

File diff suppressed because it is too large Load Diff