Compare commits

...

7 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige  86e5d7b101  optimize further and get usage stats  2026-01-22 22:13:00 +00:00
Ryuichi Leo Takashige  d9ddf90575  add token usage stats  2026-01-22 21:04:56 +00:00
Ryuichi Leo Takashige  4591301767  Add a bunch of LLM generated slop  2026-01-22 20:44:40 +00:00
Ryuichi Leo Takashige  8b0b5e1b88  Add completions endpoint  2026-01-22 17:26:52 +00:00
Ryuichi Leo Takashige  bd6287727a  Add basic exo eval  2026-01-22 16:48:12 +00:00
Ryuichi Leo Takashige  eb53611210  Add option to use null top k  2026-01-22 16:44:53 +00:00
Ryuichi Leo Takashige  71bbe5f25b  Review and extract logprob stuff from alexcheema/uncertainty-visualization  2026-01-22 14:51:12 +00:00
23 changed files with 3822 additions and 361 deletions

0
bench/__init__.py Normal file
View File

451
bench/completions_proxy.py Normal file
View File

@@ -0,0 +1,451 @@
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
"""
Proxy that converts /v1/completions requests to /v1/chat/completions.
Used by exo_eval to support lm_eval tasks that require the completions API.
"""
from __future__ import annotations
import asyncio
import json
import socket
from contextlib import asynccontextmanager, contextmanager
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from hypercorn.asyncio import serve
from hypercorn.config import Config
from loguru import logger
if TYPE_CHECKING:
from collections.abc import AsyncIterator
# Tasks that require the completions API (loglikelihood-based)
# These cannot work with chat completions because they need prompt token logprobs
COMPLETIONS_REQUIRED_TASKS: set[str] = {
# Multiple choice / loglikelihood tasks
"arc_challenge",
"arc_easy",
"hellaswag",
"mmlu",
"openbookqa",
"piqa",
"sciq",
"siqa",
"truthfulqa_mc1",
"truthfulqa_mc2",
"winogrande",
"boolq",
"lambada",
"lambada_openai",
"logiqa",
"logiqa2",
# Add more as needed
}
# Task prefixes that indicate completions are required
COMPLETIONS_REQUIRED_PREFIXES: tuple[str, ...] = (
"mmlu_", # mmlu subtasks (but NOT mmlu_pro, mmlu_generative, etc.)
"arc_", # arc subtasks
"hellaswag_",
"winogrande_",
)
# Generation-based tasks that happen to match completions prefixes above.
# These use generate_until (not loglikelihood) and must go through chat completions.
GENERATION_BASED_EXCEPTIONS: set[str] = {
"mmlu_pro",
"mmlu_generative",
"mmlu_flan_cot_fewshot",
"mmlu_flan_cot_zeroshot",
}
def tasks_require_completions(tasks: list[str]) -> bool:
"""Check if any of the tasks require the completions API."""
for task in tasks:
task_lower = task.lower()
if task_lower in GENERATION_BASED_EXCEPTIONS:
continue
if task_lower in COMPLETIONS_REQUIRED_TASKS:
return True
for prefix in COMPLETIONS_REQUIRED_PREFIXES:
if task_lower.startswith(prefix):
return True
return False
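# Illustrative sketch (not part of the original file): how the routing rules above
# classify a few task lists (task names are lm_eval task identifiers):
#   tasks_require_completions(["gsm8k", "humaneval"])    -> False  (generation-based)
#   tasks_require_completions(["hellaswag", "gsm8k"])    -> True   (hellaswag is loglikelihood-based)
#   tasks_require_completions(["mmlu_pro"])              -> False  (generation-based exception)
#   tasks_require_completions(["mmlu_abstract_algebra"]) -> True   (matches the "mmlu_" prefix)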
def find_free_port() -> int:
"""Find a free port to use for the proxy."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def create_proxy_app(upstream_url: str) -> FastAPI:
"""Create a FastAPI app that proxies completions to chat completions."""
app = FastAPI()
def convert_completions_to_chat_request(
completions_req: dict[str, Any],
) -> dict[str, Any]:
"""Convert a /v1/completions request to /v1/chat/completions format."""
prompt = completions_req.get("prompt", "")
# Handle prompt as string or list of strings
if isinstance(prompt, list):
prompt = prompt[0] if prompt else ""
chat_req: dict[str, Any] = {
"model": completions_req.get("model", ""),
"messages": [{"role": "user", "content": prompt}],
"stream": completions_req.get("stream", False),
}
# Map common parameters
for param in (
"max_tokens",
"temperature",
"top_p",
"stop",
"seed",
"presence_penalty",
"frequency_penalty",
):
if param in completions_req:
chat_req[param] = completions_req[param]
# Handle logprobs - completions uses int, chat uses bool + top_logprobs
logprobs = completions_req.get("logprobs")
if logprobs is not None and logprobs > 0:
chat_req["logprobs"] = True
chat_req["top_logprobs"] = logprobs
elif logprobs is not None:
chat_req["logprobs"] = True
return chat_req
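# Illustrative mapping (sketch, not part of the original file): the completions request
#   {"model": "m", "prompt": "Hi", "max_tokens": 1, "logprobs": 5, "echo": True}
# is converted to the chat request
#   {"model": "m", "messages": [{"role": "user", "content": "Hi"}], "stream": False,
#    "max_tokens": 1, "logprobs": True, "top_logprobs": 5}
# Note that "echo" is not forwarded; it is handled when converting the response back.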
def convert_chat_to_completions_response(
chat_resp: dict[str, Any],
echo: bool = False,
prompt: str = "",
) -> dict[str, Any]:
"""Convert a /v1/chat/completions response to /v1/completions format."""
choices = []
for chat_choice in chat_resp.get("choices", []):
message = chat_choice.get("message", {})
text = message.get("content", "") or ""
# Build logprobs in completions format
logprobs_data = None
chat_logprobs = chat_choice.get("logprobs")
if chat_logprobs and chat_logprobs.get("content"):
tokens: list[str] = []
token_logprobs: list[float] = []
top_logprobs: list[dict[str, float]] = []
text_offset: list[int] = []
offset = 0
for item in chat_logprobs["content"]:
tokens.append(item["token"])
token_logprobs.append(item["logprob"])
# Convert top_logprobs list to dict format
top_lp_dict: dict[str, float] = {}
for top_item in item.get("top_logprobs", []):
top_lp_dict[top_item["token"]] = top_item["logprob"]
top_logprobs.append(top_lp_dict)
text_offset.append(offset)
offset += len(item["token"])
logprobs_data = {
"tokens": tokens,
"token_logprobs": token_logprobs,
"top_logprobs": top_logprobs,
"text_offset": text_offset,
}
# If echo was requested, prepend prompt to text
if echo:
text = prompt + text
choices.append(
{
"text": text,
"index": chat_choice.get("index", 0),
"logprobs": logprobs_data,
"finish_reason": chat_choice.get("finish_reason"),
}
)
return {
"id": chat_resp.get("id", ""),
"object": "text_completion",
"created": chat_resp.get("created", 0),
"model": chat_resp.get("model", ""),
"choices": choices,
"usage": chat_resp.get("usage"),
}
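# Illustrative sketch (not part of the original file): for chat logprob content items
# whose tokens are ["He", "llo"], the completions-format logprobs come out as
#   tokens=["He", "llo"], text_offset=[0, 2]
# i.e. text_offset[i] is the character offset of token i within the concatenated text.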
def convert_chat_stream_chunk_to_completions(
chunk: dict[str, Any],
echo: bool = False,
prompt: str = "",
is_first: bool = False,
) -> dict[str, Any]:
"""Convert a streaming chat completion chunk to completions format."""
choices = []
for chat_choice in chunk.get("choices", []):
delta = chat_choice.get("delta", {})
text = delta.get("content", "") or ""
# If echo and first chunk, prepend prompt
if echo and is_first:
text = prompt + text
# Build logprobs in completions format
logprobs_data = None
chat_logprobs = chat_choice.get("logprobs")
if chat_logprobs and chat_logprobs.get("content"):
tokens: list[str] = []
token_logprobs: list[float] = []
top_logprobs: list[dict[str, float]] = []
for item in chat_logprobs["content"]:
tokens.append(item["token"])
token_logprobs.append(item["logprob"])
top_lp_dict: dict[str, float] = {}
for top_item in item.get("top_logprobs", []):
top_lp_dict[top_item["token"]] = top_item["logprob"]
top_logprobs.append(top_lp_dict)
logprobs_data = {
"tokens": tokens,
"token_logprobs": token_logprobs,
"top_logprobs": top_logprobs,
}
choices.append(
{
"text": text,
"index": chat_choice.get("index", 0),
"logprobs": logprobs_data,
"finish_reason": chat_choice.get("finish_reason"),
}
)
return {
"id": chunk.get("id", ""),
"object": "text_completion",
"created": chunk.get("created", 0),
"model": chunk.get("model", ""),
"choices": choices,
}
@app.post("/v1/completions", response_model=None)
async def completions(request: Request):
body = await request.json()
prompt = body.get("prompt", "")
if isinstance(prompt, list):
prompt = prompt[0] if prompt else ""
echo = body.get("echo", False)
stream = body.get("stream", False)
chat_request = convert_completions_to_chat_request(body)
logger.debug(f"Proxying to {upstream_url}/v1/chat/completions")
async with httpx.AsyncClient(timeout=300.0, http2=False) as client:
if stream:
async def generate() -> AsyncGenerator[str, None]:
is_first = True
async with client.stream(
"POST",
f"{upstream_url}/v1/chat/completions",
json=chat_request,
) as response:
async for line in response.aiter_lines():
if line.startswith("data: "):
data = line[6:]
if data == "[DONE]":
yield "data: [DONE]\n\n"
break
try:
chunk = json.loads(data)
converted = (
convert_chat_stream_chunk_to_completions(
chunk,
echo=echo,
prompt=prompt,
is_first=is_first,
)
)
is_first = False
yield f"data: {json.dumps(converted)}\n\n"
except json.JSONDecodeError:
continue
return StreamingResponse(generate(), media_type="text/event-stream")
else:
response = await client.post(
f"{upstream_url}/v1/chat/completions",
json=chat_request,
)
chat_response = response.json()
if "error" in chat_response:
return JSONResponse(chat_response, status_code=response.status_code)
completions_response = convert_chat_to_completions_response(
chat_response, echo=echo, prompt=prompt
)
return JSONResponse(completions_response)
@app.get("/v1/models", response_model=None)
async def models():
async with httpx.AsyncClient() as client:
response = await client.get(f"{upstream_url}/v1/models")
return JSONResponse(response.json())
return app
class CompletionsProxy:
"""Manages a completions proxy server lifecycle."""
def __init__(self, upstream_host: str, upstream_port: int):
self.upstream_url = f"http://{upstream_host}:{upstream_port}"
self.port = find_free_port()
self.host = "127.0.0.1"
self._task: asyncio.Task[None] | None = None
self._shutdown_event: asyncio.Event | None = None
@property
def base_url(self) -> str:
return f"http://{self.host}:{self.port}"
async def start(self) -> None:
"""Start the proxy server in the background."""
app = create_proxy_app(self.upstream_url)
config = Config()
config.bind = [f"{self.host}:{self.port}"]
config.accesslog = None # Suppress access logs
self._shutdown_event = asyncio.Event()
async def run_server() -> None:
await serve(app, config, shutdown_trigger=self._shutdown_event.wait) # type: ignore[arg-type]
self._task = asyncio.create_task(run_server())
# Wait a bit for server to start
await asyncio.sleep(0.5)
logger.info(f"Completions proxy started on {self.base_url}")
async def stop(self) -> None:
"""Stop the proxy server."""
if self._shutdown_event:
self._shutdown_event.set()
if self._task:
try:
await asyncio.wait_for(self._task, timeout=5.0)
except asyncio.TimeoutError:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
logger.info("Completions proxy stopped")
@asynccontextmanager
async def completions_proxy_context(
upstream_host: str, upstream_port: int
) -> AsyncIterator[CompletionsProxy]:
"""Context manager for running the completions proxy."""
proxy = CompletionsProxy(upstream_host, upstream_port)
await proxy.start()
try:
yield proxy
finally:
await proxy.stop()
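# Usage sketch (illustrative, assuming an exo API already listening on localhost:52415):
#
#   async def example() -> None:
#       async with completions_proxy_context("localhost", 52415) as proxy:
#           # lm_eval (or any other client) can now target the proxied endpoint at
#           # f"{proxy.base_url}/v1/completions"
#           print(proxy.base_url)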
@contextmanager
def run_completions_proxy(
upstream_host: str, upstream_port: int
) -> Generator[CompletionsProxy, None, None]:
"""Synchronous context manager that runs proxy in a subprocess."""
import subprocess
import sys
import time
port = find_free_port()
upstream_url = f"http://{upstream_host}:{upstream_port}"
# Start proxy as subprocess
proc = subprocess.Popen(
[
sys.executable,
"-c",
f"""
import asyncio
import sys
from bench.completions_proxy import create_proxy_app
from hypercorn.asyncio import serve
from hypercorn.config import Config
async def main():
print(f"Proxy starting: 127.0.0.1:{port} -> {upstream_url}", file=sys.stderr, flush=True)
app = create_proxy_app("{upstream_url}")
config = Config()
config.bind = ["127.0.0.1:{port}"]
config.accesslog = "-" # Log to stderr
config.errorlog = "-"
await serve(app, config)
asyncio.run(main())
""",
],
stdout=None, # Inherit stdout
stderr=None, # Inherit stderr
)
# Create a proxy object with the right base_url
class ProxyInfo:
def __init__(self, host: str, port: int):
self.host = host
self.port = port
@property
def base_url(self) -> str:
return f"http://{self.host}:{self.port}"
proxy = ProxyInfo("127.0.0.1", port)
# Wait for server to start
time.sleep(1.0)
logger.info(f"Completions proxy started on {proxy.base_url} -> {upstream_url}")
try:
yield proxy # type: ignore[misc]
finally:
proc.terminate()
try:
proc.wait(timeout=5.0)
except subprocess.TimeoutExpired:
proc.kill()
logger.info("Completions proxy stopped")

66
bench/eval_config.toml Normal file
View File

@@ -0,0 +1,66 @@
# exo-eval configuration file
# See bench/exo_eval.py for usage
[eval]
# Eval framework type: "lm_eval" | "swe_bench" | "custom"
type = "lm_eval"
# Require HuggingFace token (default: true)
# Set to false if using only public datasets
require_hf_token = true
# Instance/placement configuration
# Controls how exo sets up the model instance before running evals
[instance]
# Placement strategy: "ring" | "jaccl" | "both"
instance_meta = "ring"
# Sharding strategy: "pipeline" | "tensor" | "both"
sharding = "pipeline"
# Node constraints
min_nodes = 1
max_nodes = 4
# lm_eval configuration (EleutherAI's lm-evaluation-harness)
[lm_eval]
# Tasks to run (list of task names)
# NOTE: Chat completions API only supports generation-based tasks.
# Loglikelihood tasks (mmlu, hellaswag, arc) require /v1/completions endpoint.
#
# Generation-based tasks (work with chat completions):
# - mmlu_pro, mmlu_generative, mmlu_flan_cot_fewshot, mmlu_flan_cot_zeroshot
# - gsm8k, gsm8k_cot, gsm8k_cot_zeroshot
# - truthfulqa (uses generate_until for some subtasks)
# - humaneval, mbpp (code generation)
#
# Run `lm_eval --tasks list` to see all available tasks
tasks = ["mmlu_pro"]
# Number of few-shot examples (5 is standard for mmlu_pro CoT)
num_fewshot = 5
# Batch size (use 1 for API models, "auto" doesn't work)
batch_size = 1
# Number of concurrent requests (set > 1 to enable parallelism)
# Higher values enable better batching throughput
num_concurrent = 64
# Apply chat template for instruct/chat models (default: true)
apply_chat_template = true
# Use fewshot examples as conversation turns (better for chat models)
fewshot_as_multiturn = true
# Optional: limit samples per task (omit or comment out for no limit)
# limit = 100
# Output path for results
output_path = "bench/eval_results"
# SWE-bench configuration (placeholder)
[swe_bench]
# SWE-bench dataset
dataset = "princeton-nlp/SWE-bench_Lite"
# Maximum workers for parallel execution
max_workers = 8
# Path for prediction outputs
predictions_path = "bench/predictions"
# Custom evaluation script configuration
[custom]
# Path to custom evaluation script
script = "path/to/eval_script.py"
# Arguments to pass to the script
args = ["--arg1", "value1"]

629
bench/exo_eval.py Normal file
View File

@@ -0,0 +1,629 @@
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
"""
exo-eval: Evaluation harness for exo inference system.
Supports multiple evaluation frameworks via TOML configuration:
- lm_eval: Language model evaluation using EleutherAI's lm-evaluation-harness
- swe_bench: SWE-bench evaluation (placeholder for future implementation)
- custom: Custom evaluation scripts
Usage:
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit --dry-run
"""
from __future__ import annotations
import argparse
import contextlib
import json
import os
import subprocess
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Literal
# Add parent directory to path for direct script execution
if __name__ == "__main__" and __package__ is None:
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import tomlkit
from huggingface_hub import get_token as get_hf_token
from loguru import logger
from tomlkit.exceptions import TOMLKitError
from bench.completions_proxy import tasks_require_completions
from bench.exo_bench import (
ExoClient,
ExoHttpError,
instance_id_from_instance,
nodes_used_in_instance,
placement_filter,
resolve_model_short_id,
sharding_filter,
wait_for_instance_gone,
wait_for_instance_ready,
)
EvalType = Literal["lm_eval", "swe_bench", "custom"]
def load_config(config_path: str) -> dict[str, Any]:
"""Load and parse TOML configuration file."""
path = Path(config_path)
if not path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")
with open(path, encoding="utf-8") as f:
return dict(tomlkit.load(f))
def get_eval_type(config: dict[str, Any]) -> EvalType:
"""Extract evaluation type from config."""
eval_section = config.get("eval", {})
eval_type = eval_section.get("type", "lm_eval")
if eval_type not in ("lm_eval", "swe_bench", "custom"):
raise ValueError(f"Unknown eval type: {eval_type}")
return eval_type
def check_hf_token(config: dict[str, Any]) -> bool:
"""Check if HuggingFace token is available when required.
Returns True if token is available or not required, False otherwise.
"""
eval_section = config.get("eval", {})
require_hf_token = eval_section.get("require_hf_token", True)
if not require_hf_token:
return True
token = get_hf_token()
if token is None:
logger.error(
"HuggingFace token not found. "
"Set HF_TOKEN environment variable or run 'huggingface-cli login'. "
"To disable this check, set require_hf_token = false in [eval] config."
)
return False
logger.info("HuggingFace token found")
return True
def select_placement(
client: ExoClient,
full_model_id: str,
config: dict[str, Any],
) -> dict[str, Any] | None:
"""Select a placement based on config preferences."""
instance_config = config.get("instance", {})
# If explicit instance is provided, use it directly
if "instance" in instance_config:
return instance_config["instance"]
# Otherwise, select from previews based on preferences
instance_meta_pref = instance_config.get("instance_meta", "ring")
sharding_pref = instance_config.get("sharding", "pipeline")
max_nodes = instance_config.get("max_nodes", 4)
min_nodes = instance_config.get("min_nodes", 1)
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), instance_meta_pref):
continue
if not sharding_filter(str(p.get("sharding", "")), sharding_pref):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
if min_nodes <= n <= max_nodes:
selected.append(p)
if not selected:
return None
# Sort by preference: exact match on sharding/meta, then by node count (descending)
def sort_key(p: dict[str, Any]) -> tuple[int, int, int]:
meta_match = (
1 if instance_meta_pref in str(p.get("instance_meta", "")).lower() else 0
)
sharding_match = 1 if sharding_pref in str(p.get("sharding", "")).lower() else 0
n_nodes = nodes_used_in_instance(p["instance"])
return (meta_match, sharding_match, n_nodes)
selected.sort(key=sort_key, reverse=True)
return selected[0]
def setup_instance(
client: ExoClient,
full_model_id: str,
config: dict[str, Any],
dry_run: bool,
) -> tuple[str | None, dict[str, Any] | None]:
"""Create and wait for an instance to be ready. Returns (instance_id, preview)."""
preview = select_placement(client, full_model_id, config)
if preview is None:
logger.error("No valid placement found matching config preferences")
return None, None
instance_data = preview.get("instance")
instance: dict[str, Any] = (
instance_data if isinstance(instance_data, dict) else preview
)
instance_id = instance_id_from_instance(instance)
sharding = str(preview.get("sharding", "unknown"))
instance_meta = str(preview.get("instance_meta", "unknown"))
n_nodes = nodes_used_in_instance(instance)
logger.info(f"Selected placement: {sharding} / {instance_meta} / nodes={n_nodes}")
logger.info(f"Instance ID: {instance_id}")
if dry_run:
logger.info("[dry-run] Would create instance and wait for ready")
return instance_id, preview
# Create instance
client.request_json("POST", "/instance", body={"instance": instance})
try:
wait_for_instance_ready(client, instance_id)
logger.info("Instance is ready")
time.sleep(1) # Brief pause after ready
return instance_id, preview
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize instance: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
return None, None
def teardown_instance(client: ExoClient, instance_id: str) -> None:
"""Delete an instance and wait for it to be gone."""
try:
client.request_json("DELETE", f"/instance/{instance_id}")
except ExoHttpError as e:
if e.status != 404:
raise
wait_for_instance_gone(client, instance_id)
logger.info(f"Instance {instance_id} deleted")
def build_lm_eval_args(
config: dict[str, Any],
base_url: str,
model: str,
output_path: str | None,
limit: int | None,
use_completions: bool,
) -> list[str]:
"""Build command-line arguments for lm_eval."""
lm_eval_config = config.get("lm_eval", {})
# Choose model type based on whether tasks need completions API
if use_completions:
model_type = "local-completions"
endpoint_url = f"{base_url}/v1/completions"
else:
model_type = "local-chat-completions"
endpoint_url = f"{base_url}/v1/chat/completions"
# Build model_args string with num_concurrent if specified
model_args_parts = [f"model={model}", f"base_url={endpoint_url}"]
num_concurrent = lm_eval_config.get("num_concurrent")
if num_concurrent is not None and num_concurrent > 1:
model_args_parts.append(f"num_concurrent={num_concurrent}")
model_args = ",".join(model_args_parts)
args = [
"lm_eval",
"--model",
model_type,
"--model_args",
model_args,
"--verbosity",
"WARNING",
]
# Tasks
tasks = lm_eval_config.get("tasks", ["mmlu"])
tasks_str = ",".join(tasks) if isinstance(tasks, list) else str(tasks)
args.extend(["--tasks", tasks_str])
# Few-shot
num_fewshot = lm_eval_config.get("num_fewshot")
if num_fewshot is not None:
args.extend(["--num_fewshot", str(num_fewshot)])
# Batch size (defaults to 1 for API models; "auto" doesn't work)
batch_size = lm_eval_config.get("batch_size", 1)
args.extend(["--batch_size", str(batch_size)])
# Apply chat template for instruct/chat models (default: true)
# Only applied when using the chat completions endpoint
apply_chat_template = lm_eval_config.get("apply_chat_template", True)
if apply_chat_template and not use_completions:
args.append("--apply_chat_template")
# Fewshot as multiturn (optional, works with chat template)
fewshot_as_multiturn = lm_eval_config.get("fewshot_as_multiturn", False)
if fewshot_as_multiturn and not use_completions:
args.append("--fewshot_as_multiturn")
# Limit (command line overrides config)
effective_limit = limit if limit is not None else lm_eval_config.get("limit")
if effective_limit is not None:
args.extend(["--limit", str(effective_limit)])
# Output path
effective_output = output_path or lm_eval_config.get("output_path")
if effective_output:
args.extend(["--output_path", effective_output])
# Log model responses for post-hoc analysis when output is saved
args.append("--log_samples")
return args
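# Illustrative output (sketch, not part of the original file): with the defaults from
# bench/eval_config.toml and base_url "http://localhost:52415", the returned argv is roughly
#   lm_eval --model local-chat-completions
#     --model_args model=<model>,base_url=http://localhost:52415/v1/chat/completions,num_concurrent=64
#     --verbosity WARNING --tasks mmlu_pro --num_fewshot 5 --batch_size 1
#     --apply_chat_template --fewshot_as_multiturn --output_path bench/eval_results --log_samples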
def run_lm_eval(
config: dict[str, Any],
host: str,
port: int,
model: str,
output_path: str | None,
limit: int | None,
dry_run: bool,
) -> int:
"""Run lm_eval evaluation."""
lm_eval_config = config.get("lm_eval", {})
tasks = lm_eval_config.get("tasks", ["mmlu"])
if isinstance(tasks, str):
tasks = [tasks]
# Check if tasks require the completions API
use_completions = tasks_require_completions(tasks)
if use_completions:
logger.info(
"Tasks require completions API - using native /v1/completions endpoint"
)
exo_base_url = f"http://{host}:{port}"
# Build args - use native completions or chat completions endpoint directly
args = build_lm_eval_args(
config, exo_base_url, model, output_path, limit, use_completions=use_completions
)
logger.info(f"lm_eval command: {' '.join(args)}")
if dry_run:
logger.info("[dry-run] Would execute the above command")
return 0
try:
result = subprocess.run(args, check=False)
# Print token usage summary from exo
try:
import httpx
usage_resp = httpx.get(f"{exo_base_url}/v1/usage", timeout=5)
if usage_resp.status_code == 200:
usage = usage_resp.json()
logger.info("--- Token Usage (Total) ---")
logger.info(f" Requests: {usage.get('total_requests', 0)}")
logger.info(f" Prompt tokens: {usage.get('total_prompt_tokens', 0)}")
logger.info(f" Completion tokens: {usage.get('total_completion_tokens', 0)}")
logger.info(f" Reasoning tokens: {usage.get('total_reasoning_tokens', 0)}")
logger.info(f" Total tokens: {usage.get('total_tokens', 0)}")
by_model = usage.get("by_model", {})
if by_model:
for model_name, counters in by_model.items():
logger.info(f"--- Token Usage ({model_name}) ---")
logger.info(f" Requests: {counters.get('requests', 0)}")
logger.info(f" Prompt tokens: {counters.get('prompt_tokens', 0)}")
logger.info(f" Completion tokens: {counters.get('completion_tokens', 0)}")
logger.info(f" Reasoning tokens: {counters.get('reasoning_tokens', 0)}")
except Exception:
pass # Usage endpoint not available
return result.returncode
except FileNotFoundError:
logger.error("lm_eval not found. Install with: uv sync --extra eval")
return 1
def run_swe_bench(
config: dict[str, Any],
host: str,
port: int,
model: str,
output_path: str | None,
dry_run: bool,
) -> int:
"""Run SWE-bench evaluation (placeholder)."""
swe_config = config.get("swe_bench", {})
dataset = swe_config.get("dataset", "princeton-nlp/SWE-bench_Lite")
max_workers = swe_config.get("max_workers", 8)
predictions_path = output_path or swe_config.get(
"predictions_path", "bench/predictions"
)
logger.info("SWE-bench evaluation configuration:")
logger.info(f" Dataset: {dataset}")
logger.info(f" Model: {model}")
logger.info(f" API endpoint: http://{host}:{port}/v1")
logger.info(f" Max workers: {max_workers}")
logger.info(f" Predictions path: {predictions_path}")
if dry_run:
logger.info("[dry-run] SWE-bench evaluation would be executed")
return 0
logger.warning(
"SWE-bench integration is a placeholder. "
"Implement swebench inference and evaluation logic as needed."
)
return 0
def run_custom_eval(
config: dict[str, Any],
host: str,
port: int,
model: str,
output_path: str | None,
dry_run: bool,
) -> int:
"""Run custom evaluation script."""
custom_config = config.get("custom", {})
script = custom_config.get("script")
if not script:
logger.error("No script specified in [custom] config section")
return 1
script_path = Path(script)
if not script_path.exists():
logger.error(f"Custom script not found: {script}")
return 1
script_args = custom_config.get("args", [])
if not isinstance(script_args, list):
script_args = [str(script_args)]
# Build environment with exo connection info
env = os.environ.copy()
env["EXO_HOST"] = host
env["EXO_PORT"] = str(port)
env["EXO_MODEL"] = model
if output_path:
env["EXO_OUTPUT_PATH"] = output_path
cmd = [sys.executable, str(script_path), *script_args]
logger.info(f"Custom eval command: {' '.join(cmd)}")
if dry_run:
logger.info("[dry-run] Would execute the above command")
return 0
result = subprocess.run(cmd, env=env, check=False)
return result.returncode
def write_results_metadata(
output_path: str,
config: dict[str, Any],
host: str,
port: int,
model: str,
eval_type: EvalType,
return_code: int,
preview: dict[str, Any] | None,
) -> None:
"""Write evaluation metadata to a JSON file."""
metadata: dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"eval_type": eval_type,
"model": model,
"api_endpoint": f"http://{host}:{port}/v1",
"config": config,
"return_code": return_code,
}
if preview:
metadata["placement"] = {
"sharding": preview.get("sharding"),
"instance_meta": preview.get("instance_meta"),
"instance_id": instance_id_from_instance(preview["instance"])
if "instance" in preview
else None,
}
output_dir = Path(output_path)
output_dir.mkdir(parents=True, exist_ok=True)
metadata_path = output_dir / "eval_metadata.json"
with open(metadata_path, "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2, ensure_ascii=False, default=str)
logger.info(f"Wrote evaluation metadata to: {metadata_path}")
def main() -> int:
"""Main entry point for exo-eval."""
ap = argparse.ArgumentParser(
prog="exo-eval",
description="Evaluation harness for exo inference system.",
)
ap.add_argument(
"--config",
required=True,
help="Path to TOML configuration file",
)
ap.add_argument(
"--host",
default=os.environ.get("EXO_HOST", "localhost"),
help="exo API host (default: localhost or EXO_HOST env var)",
)
ap.add_argument(
"--port",
type=int,
default=int(os.environ.get("EXO_PORT", "52415")),
help="exo API port (default: 52415 or EXO_PORT env var)",
)
ap.add_argument(
"--model",
required=True,
help="Model name/ID to evaluate",
)
ap.add_argument(
"--output",
default=None,
help="Output path for results (overrides config)",
)
ap.add_argument(
"--limit",
type=int,
default=None,
help="Limit samples per task (overrides config, lm_eval only)",
)
ap.add_argument(
"--timeout",
type=float,
default=600.0,
help="HTTP timeout in seconds (default: 600)",
)
ap.add_argument(
"--skip-instance-setup",
action="store_true",
help="Skip instance creation (assume instance already running)",
)
ap.add_argument(
"--dry-run",
action="store_true",
help="Print commands without executing",
)
args = ap.parse_args()
logger.info(f"exo-eval starting with config: {args.config}")
try:
config = load_config(args.config)
except FileNotFoundError as e:
logger.error(str(e))
return 1
except TOMLKitError as e:
logger.error(f"Failed to parse config: {e}")
return 1
eval_type = get_eval_type(config)
logger.info(f"Evaluation type: {eval_type}")
logger.info(f"Model: {args.model}")
logger.info(f"API endpoint: http://{args.host}:{args.port}/v1")
# Check HuggingFace token if required
if not check_hf_token(config):
return 1
# Setup instance and resolve model
instance_id: str | None = None
preview: dict[str, Any] | None = None
client: ExoClient | None = None
if args.skip_instance_setup:
# Use model name as-is when skipping instance setup
full_model_id = args.model
logger.info(f"Using model: {full_model_id} (instance setup skipped)")
else:
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
# Resolve model
try:
short_id, full_model_id = resolve_model_short_id(client, args.model)
logger.info(f"Resolved model: {short_id} -> {full_model_id}")
except Exception as e:
logger.error(f"Failed to resolve model: {e}")
return 1
instance_id, preview = setup_instance(
client, full_model_id, config, args.dry_run
)
if instance_id is None and not args.dry_run:
return 1
try:
# Run evaluation
if eval_type == "lm_eval":
return_code = run_lm_eval(
config,
args.host,
args.port,
full_model_id,
args.output,
args.limit,
args.dry_run,
)
elif eval_type == "swe_bench":
return_code = run_swe_bench(
config,
args.host,
args.port,
full_model_id,
args.output,
args.dry_run,
)
elif eval_type == "custom":
return_code = run_custom_eval(
config,
args.host,
args.port,
full_model_id,
args.output,
args.dry_run,
)
else:
logger.error(f"Unknown eval type: {eval_type}")
return 1
# Write metadata if output path specified and not dry-run
output_path = args.output or config.get(eval_type, {}).get("output_path")
if output_path and not args.dry_run:
write_results_metadata(
output_path,
config,
args.host,
args.port,
full_model_id,
eval_type,
return_code,
preview,
)
return return_code
finally:
# Teardown instance
if instance_id and client and not args.skip_instance_setup and not args.dry_run:
teardown_instance(client, instance_id)
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -13,6 +13,7 @@ dependencies = [
"filelock>=3.18.0",
"rustworkx>=0.17.1",
"huggingface-hub>=0.33.4",
"typer", # for huggingface-cli
"psutil>=7.0.0",
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
@@ -34,6 +35,7 @@ dependencies = [
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
exo-eval = "bench.exo_eval:main"
# dependencies only required for development
[dependency-groups]
@@ -51,6 +53,9 @@ dev = [
# cuda = [
# "mlx[cuda]==0.26.3",
# ]
eval = [
"lm_eval[api]",
]
###
# workspace configuration

View File

@@ -1,9 +1,10 @@
import base64
import json
import re
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import Literal, cast
from typing import Any, Literal, cast
from uuid import uuid4
import anyio
@@ -37,6 +38,11 @@ from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
CompletionChoice,
CompletionLogprobs,
CompletionResponse,
CompletionTaskParams,
CompletionTokensDetails,
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
@@ -51,6 +57,8 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
ImageListItem,
ImageListResponse,
Logprobs,
LogprobsContentItem,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -58,8 +66,10 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StreamingChoiceResponse,
ToolCall,
Usage,
)
from exo.shared.types.chunks import (
CompletionChunk,
ErrorChunk,
ImageChunk,
InputImageChunk,
@@ -69,6 +79,7 @@ from exo.shared.types.chunks import (
from exo.shared.types.commands import (
ChatCompletion,
Command,
Completion,
CreateInstance,
DeleteInstance,
ForwarderCommand,
@@ -95,14 +106,43 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
_THINK_TAG_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
def _strip_think_tags(text: str) -> str:
"""Strip <think>...</think> blocks from response text.
These tags are an artifact of GPT-OSS channel parsing, not part of the
model's intended output. The OpenAI API content field should not contain them.
"""
return _THINK_TAG_RE.sub("", text).lstrip()
def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None) -> str:
return f"image/{image_format or 'png'}"
def _build_logprobs(chunk: TokenChunk) -> Logprobs:
"""Convert flat logprob fields to OpenAI Logprobs format."""
return Logprobs(
content=[
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob if chunk.logprob is not None else 0.0,
bytes=list(chunk.text.encode("utf-8")),
top_logprobs=chunk.top_logprobs or [],
)
]
)
def chunk_to_response(
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
logprobs: Logprobs | None = None
if isinstance(chunk, TokenChunk) and chunk.logprob is not None:
logprobs = _build_logprobs(chunk)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
@@ -123,6 +163,7 @@ def chunk_to_response(
for i, tool in enumerate(chunk.tool_calls)
],
),
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
],
@@ -183,7 +224,8 @@ class API:
)
self._chat_completion_queues: dict[
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
CommandId,
Sender[TokenChunk | ErrorChunk | ToolCallChunk | CompletionChunk],
] = {}
self._image_generation_queues: dict[
CommandId, Sender[ImageChunk | ErrorChunk]
@@ -191,6 +233,9 @@ class API:
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
self._tg: TaskGroup | None = None
# Accumulated usage stats per instance (keyed by model id)
self._usage_by_model: dict[str, dict[str, int]] = {}
def reset(self, new_session_id: SessionId, result_clock: int):
logger.info("Resetting API State")
self.state = State()
@@ -244,6 +289,7 @@ class API:
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
self.app.post("/v1/completions", response_model=None)(self.completions)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.post("/v1/images/generations", response_model=None)(
self.image_generations
@@ -255,6 +301,42 @@ class API:
self.app.get("/images/{image_id}")(self.get_image)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
self.app.get("/v1/usage")(self.get_usage)
def get_usage(self) -> dict[str, Any]:
"""Return accumulated token usage per model instance."""
total_requests = 0
total_prompt = 0
total_completion = 0
total_reasoning = 0
for counters in self._usage_by_model.values():
total_requests += counters.get("requests", 0)
total_prompt += counters.get("prompt_tokens", 0)
total_completion += counters.get("completion_tokens", 0)
total_reasoning += counters.get("reasoning_tokens", 0)
return {
"total_requests": total_requests,
"total_prompt_tokens": total_prompt,
"total_completion_tokens": total_completion,
"total_reasoning_tokens": total_reasoning,
"total_tokens": total_prompt + total_completion,
"by_model": self._usage_by_model,
}
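# Illustrative response shape for GET /v1/usage (sketch, values are made up):
#   {"total_requests": 42, "total_prompt_tokens": 120000, "total_completion_tokens": 8000,
#    "total_reasoning_tokens": 0, "total_tokens": 128000,
#    "by_model": {"<model-id>": {"requests": 42, "prompt_tokens": 120000,
#                                "completion_tokens": 8000, "reasoning_tokens": 0}}}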
def _accumulate_usage(self, model: str, prompt_tokens: int, completion_tokens: int, reasoning_tokens: int) -> None:
"""Accumulate usage stats for a model instance."""
if model not in self._usage_by_model:
self._usage_by_model[model] = {
"requests": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"reasoning_tokens": 0,
}
counters = self._usage_by_model[model]
counters["requests"] += 1
counters["prompt_tokens"] += prompt_tokens
counters["completion_tokens"] += completion_tokens
counters["reasoning_tokens"] += reasoning_tokens
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -463,12 +545,12 @@ class API:
async def _chat_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
) -> AsyncGenerator[TokenChunk | ErrorChunk | ToolCallChunk | CompletionChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
try:
self._chat_completion_queues[command_id], recv = channel[
ErrorChunk | ToolCallChunk | TokenChunk
TokenChunk | ErrorChunk | ToolCallChunk | CompletionChunk
]()
with recv as token_chunks:
@@ -497,8 +579,12 @@ class API:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
# Skip CompletionChunk - it's for the legacy completions API
if isinstance(chunk, CompletionChunk):
continue
assert not isinstance(chunk, ImageChunk)
if chunk.finish_reason == "error":
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
@@ -518,6 +604,15 @@ class API:
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
# Accumulate usage stats from the final chunk
if isinstance(chunk, TokenChunk) and chunk.stats is not None:
s = chunk.stats
self._accumulate_usage(
model=chunk.model,
prompt_tokens=s.prompt_tokens,
completion_tokens=s.generation_tokens,
reasoning_tokens=s.reasoning_tokens,
)
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
@@ -527,10 +622,16 @@ class API:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
logprobs_items: list[LogprobsContentItem] = []
model: str | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
# Skip CompletionChunk - it's for the legacy completions API
if isinstance(chunk, CompletionChunk):
continue
if isinstance(chunk, ErrorChunk):
raise HTTPException(
status_code=500,
@@ -542,6 +643,16 @@ class API:
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if chunk.stats is not None:
stats = chunk.stats
if chunk.logprob is not None:
lp = _build_logprobs(chunk)
if lp.content:
if len(lp.content) != 1:
logger.warning(
f"Expected 1 logprobs content item per chunk, got {len(lp.content)}"
)
logprobs_items.append(lp.content[0])
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
@@ -556,9 +667,31 @@ class API:
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = "".join(text_parts)
combined_text = _strip_think_tags("".join(text_parts))
assert model is not None
logprobs: Logprobs | None = None
if logprobs_items:
logprobs = Logprobs(content=logprobs_items)
usage: Usage | None = None
if stats is not None:
completion_tokens = stats.generation_tokens
usage = Usage(
prompt_tokens=stats.prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=stats.prompt_tokens + completion_tokens,
completion_tokens_details=CompletionTokensDetails(
reasoning_tokens=stats.reasoning_tokens,
) if stats.reasoning_tokens > 0 else None,
)
self._accumulate_usage(
model=model or "unknown",
prompt_tokens=stats.prompt_tokens,
completion_tokens=completion_tokens,
reasoning_tokens=stats.reasoning_tokens,
)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
@@ -571,9 +704,11 @@ class API:
content=combined_text,
tool_calls=tool_calls,
),
logprobs=logprobs,
finish_reason=finish_reason,
)
],
usage=usage,
)
async def _collect_chat_completion_with_stats(
@@ -587,7 +722,11 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
# Skip CompletionChunk - it's for the legacy completions API
if isinstance(chunk, CompletionChunk):
continue
if isinstance(chunk, ErrorChunk):
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
@@ -598,6 +737,7 @@ class API:
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
stats = chunk.stats or stats
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
@@ -608,13 +748,12 @@ class API:
)
for i, tool in enumerate(chunk.tool_calls)
)
stats = chunk.stats or stats
stats = chunk.stats or stats
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = "".join(text_parts)
combined_text = _strip_think_tags("".join(text_parts))
assert model is not None
resp = BenchChatCompletionResponse(
@@ -667,6 +806,88 @@ class API:
return await self._collect_chat_completion(command.command_id)
async def completions(
self, payload: CompletionTaskParams
) -> CompletionResponse:
"""Handle legacy completions API for lm_eval compatibility."""
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == payload.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(payload.model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {payload.model}"
)
command = Completion(request_params=payload)
await self._send(command)
return await self._collect_completion(command.command_id)
async def _collect_completion(self, command_id: CommandId) -> CompletionResponse:
"""Collect completion response chunks into a single response."""
text = ""
tokens: list[str] = []
token_logprobs: list[float | None] = []
top_logprobs: list[dict[str, float]] = []
text_offset: list[int] = []
finish_reason: FinishReason | None = None
model = ""
try:
self._chat_completion_queues[command_id], recv = channel[
TokenChunk | ErrorChunk | ToolCallChunk | CompletionChunk
]()
with recv as chunks:
async for chunk in chunks:
if isinstance(chunk, CompletionChunk):
text = chunk.text
tokens = chunk.tokens
token_logprobs = chunk.token_logprobs
top_logprobs = chunk.top_logprobs
text_offset = chunk.text_offset
finish_reason = chunk.finish_reason
model = chunk.model
elif isinstance(chunk, ErrorChunk):
raise HTTPException(
status_code=500, detail=chunk.error_message
)
if chunk.finish_reason is not None:
break
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
if command_id in self._chat_completion_queues:
del self._chat_completion_queues[command_id]
logprobs_data: CompletionLogprobs | None = None
if tokens:
logprobs_data = CompletionLogprobs(
tokens=tokens,
token_logprobs=token_logprobs,
top_logprobs=top_logprobs,
text_offset=text_offset,
)
return CompletionResponse(
id=f"cmpl-{uuid4()}",
created=int(time.time()),
model=model,
choices=[
CompletionChoice(
text=text,
index=0,
logprobs=logprobs_data,
finish_reason=finish_reason,
)
],
)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:

View File

@@ -13,6 +13,7 @@ from exo.master.placement import (
from exo.shared.apply import apply
from exo.shared.types.commands import (
ChatCompletion,
Completion,
CreateInstance,
DeleteInstance,
ForwarderCommand,
@@ -40,6 +41,9 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion as ChatCompletionTask,
)
from exo.shared.types.tasks import (
Completion as CompletionTask,
)
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
)
@@ -158,6 +162,48 @@ class Master:
)
)
self.command_task_mapping[command.command_id] = task_id
case Completion():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=CompletionTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageGeneration():
for instance in self.state.instances.values():

View File

@@ -97,6 +97,8 @@ class LogprobsContentItem(BaseModel):
class Logprobs(BaseModel):
content: list[LogprobsContentItem] | None = None
# This will always be null for open-source models, but exists for parity with the OpenAI API
refusal: list[LogprobsContentItem] | None = None
class PromptTokensDetails(BaseModel):
@@ -149,6 +151,7 @@ class GenerationStats(BaseModel):
generation_tps: float
prompt_tokens: int
generation_tokens: int
reasoning_tokens: int = 0
peak_memory_usage: Memory
@@ -169,6 +172,52 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
generation_stats: GenerationStats | None = None
# Legacy Completions API types (for lm_eval compatibility)
class CompletionLogprobs(BaseModel):
"""Logprobs in the legacy completions format."""
tokens: list[str]
token_logprobs: list[float | None]
top_logprobs: list[dict[str, float]]
text_offset: list[int]
class CompletionChoice(BaseModel):
text: str
index: int
logprobs: CompletionLogprobs | None = None
finish_reason: FinishReason | None = None
class CompletionResponse(BaseModel):
id: str
object: Literal["text_completion"] = "text_completion"
created: int
model: str
choices: list[CompletionChoice]
usage: Usage | None = None
class CompletionTaskParams(BaseModel):
"""Parameters for the legacy /v1/completions endpoint."""
model: str
# Prompt can be: string, list of strings, list of token IDs, or list of token ID lists
prompt: str | list[str] | list[int] | list[list[int]]
max_tokens: int | None = 16
temperature: float | None = 1.0
top_p: float | None = 1.0
n: int | None = 1
stream: bool = False
logprobs: int | None = None
echo: bool = False
stop: str | list[str] | None = None
presence_penalty: float | None = None
frequency_penalty: float | None = None
seed: int | None = None
user: str | None = None
class ChatCompletionTaskParams(BaseModel):
model: str
frequency_penalty: float | None = None

View File

@@ -2,7 +2,7 @@ from collections.abc import Generator
from typing import Any, Literal
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats, ImageGenerationStats
from exo.shared.types.api import GenerationStats, ImageGenerationStats, TopLogprobItem
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -17,6 +17,8 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
text: str
token_id: int
logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
finish_reason: Literal["stop", "length", "content_filter"] | None = None
stats: GenerationStats | None = None
@@ -32,6 +34,17 @@ class ToolCallChunk(BaseChunk):
stats: GenerationStats | None = None
class CompletionChunk(BaseChunk):
"""Chunk for legacy completions API with full logprobs for all tokens."""
text: str
tokens: list[str]
token_logprobs: list[float | None]
top_logprobs: list[dict[str, float]]
text_offset: list[int]
finish_reason: FinishReason | None = None
class ImageChunk(BaseChunk):
data: str
chunk_index: int
@@ -67,4 +80,4 @@ class InputImageChunk(BaseChunk):
yield name, value
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
GenerationChunk = TokenChunk | CompletionChunk | ImageChunk | ToolCallChunk | ErrorChunk

View File

@@ -3,6 +3,7 @@ from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import (
ChatCompletionTaskParams,
CompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
@@ -25,6 +26,12 @@ class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
class Completion(BaseCommand):
"""Legacy completions API command for scoring/generation."""
request_params: CompletionTaskParams
class ImageGeneration(BaseCommand):
request_params: ImageGenerationTaskParams
@@ -66,6 +73,7 @@ Command = (
TestCommand
| RequestEventLog
| ChatCompletion
| Completion
| ImageGeneration
| ImageEdits
| PlaceInstance

View File

@@ -4,6 +4,7 @@ from pydantic import Field
from exo.shared.types.api import (
ChatCompletionTaskParams,
CompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
@@ -60,6 +61,16 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class Completion(BaseTask):
"""Legacy completions task for scoring tokens with echo=True."""
command_id: CommandId
task_params: CompletionTaskParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -87,6 +98,7 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| Completion
| ImageGeneration
| ImageEdits
| Shutdown

View File

@@ -6,6 +6,7 @@ from exo.shared.types.api import (
GenerationStats,
ImageGenerationStats,
ToolCallItem,
TopLogprobItem,
)
from exo.utils.pydantic_ext import TaggedModel
@@ -14,14 +15,11 @@ class BaseRunnerResponse(TaggedModel):
pass
class TokenizedResponse(BaseRunnerResponse):
prompt_tokens: int
class GenerationResponse(BaseRunnerResponse):
text: str
token: int
# logprobs: list[float] | None = None # too big. we can change to be top-k
logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None

View File

@@ -194,6 +194,22 @@ class MpReceiver[T]:
raise EndOfStream from None
return item
def receive_with_timeout(self, timeout: float) -> T | None:
"""Receive with timeout, returns None if no message within timeout."""
if self._state.closed.is_set():
raise ClosedResourceError
try:
item = self._state.buffer.get(block=True, timeout=timeout)
if isinstance(item, _MpEndOfStream):
self.close()
raise EndOfStream
return item
except Empty:
return None
except ValueError as e:
raise ClosedResourceError from e
# nb: this function will not cancel particularly well
async def receive_async(self) -> T:
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))

View File

@@ -12,12 +12,11 @@ from exo.shared.types.api import (
ChatCompletionMessage,
FinishReason,
GenerationStats,
TopLogprobItem,
)
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
@@ -115,6 +114,206 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
return eos
def extract_top_logprobs(
logprobs_array: mx.array,
selected_token: int,
tokenizer: TokenizerWrapper,
top_k: int | None,
) -> tuple[float, list[TopLogprobItem]]:
"""Extract the selected token's logprob and top-k alternatives.
top_k can be set to None to return logprobs for the entire vocabulary.
"""
selected_logprob = float(logprobs_array[selected_token].item())
if top_k == 0:
return selected_logprob, []
vocab_size = logprobs_array.shape[0]
if top_k is None:
sorted_indices = mx.argsort(-logprobs_array)
mx.eval(sorted_indices)
indices_list: list[int] = cast(list[int], sorted_indices.tolist())
else:
k = min(top_k, vocab_size)
top_indices = mx.argpartition(-logprobs_array, kth=k - 1)[:k]
top_logprobs_values = logprobs_array[top_indices]
sorted_order = mx.argsort(-top_logprobs_values)
top_indices = top_indices[sorted_order]
mx.eval(top_indices)
indices_list = cast(list[int], top_indices.tolist())
top_logprob_items: list[TopLogprobItem] = []
for token_id in indices_list:
logprob_value = float(logprobs_array[token_id].item())
token_str = tokenizer.decode([token_id])
top_logprob_items.append(
TopLogprobItem(
token=token_str,
logprob=logprob_value,
bytes=list(token_str.encode("utf-8")),
)
)
return selected_logprob, top_logprob_items
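# Illustrative sketch (not part of the original file): with top_k=2 and a sampled token id of 7,
#   extract_top_logprobs(logprobs_array=lp, selected_token=7, tokenizer=tokenizer, top_k=2)
# returns (lp[7], [<two TopLogprobItem entries>]), i.e. the selected token's logprob plus the
# two highest-probability tokens overall (which may include the selected token), sorted by
# logprob in descending order. top_k=0 skips the alternatives entirely.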
def score_tokens(
model: Model,
tokenizer: TokenizerWrapper,
tokens: list[int],
top_k: int | None = None,
) -> list[tuple[float, list[TopLogprobItem]]]:
"""Score a sequence of tokens, returning logprobs for each token.
This is used for the completions API with echo=True, where we need
logprobs for the prompt tokens (not just generated tokens).
Args:
model: The MLX model.
tokenizer: The tokenizer.
tokens: List of token IDs to score.
top_k: Number of top logprobs to return per position.
If None, returns all logprobs.
Returns:
List of (token_logprob, top_logprobs) tuples for each token position.
The first position has no logprob (no previous context), so returns (0.0, []).
"""
if len(tokens) == 0:
return []
# First token has no previous context to condition on
results: list[tuple[float, list[TopLogprobItem]]] = [(0.0, [])]
if len(tokens) == 1:
return results
# Create an empty KV cache for the forward pass
cache = make_kv_cache(model=model)
# Convert to MLX array and run forward pass
input_tokens = mx.array(tokens[:-1])[None] # All tokens except last, batched
# Run the model to get logits for all positions
# The model returns logits with shape [1, seq_len, vocab_size]
logits: mx.array = model(input_tokens, cache=cast(list[KVCache], cache))
logits = logits.squeeze(0) # Shape: [seq_len, vocab_size]
# Convert to log probabilities
logprobs_all: mx.array = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
mx.eval(logprobs_all)
# For each position, extract the logprob of the actual next token
for i in range(len(tokens) - 1):
next_token = tokens[i + 1]
logprobs_at_position: mx.array = logprobs_all[i]
logprob, top_logprobs_items = extract_top_logprobs(
logprobs_array=logprobs_at_position,
selected_token=next_token,
tokenizer=tokenizer,
top_k=top_k,
)
results.append((logprob, top_logprobs_items))
return results
def score_tokens_batched(
model: Model,
tokenizer: TokenizerWrapper,
token_sequences: list[list[int]],
top_k: int | None = None,
) -> list[list[tuple[float, list[TopLogprobItem]]]]:
"""Score multiple token sequences in a single batched forward pass.
This is significantly faster than calling score_tokens() multiple times
because it batches the forward pass across all sequences.
Args:
model: The MLX model.
tokenizer: The tokenizer.
token_sequences: List of token ID sequences to score.
top_k: Number of top logprobs to return per position.
Returns:
List of results for each sequence. Each result is a list of
(token_logprob, top_logprobs) tuples for each token position.
"""
if not token_sequences:
return []
# Handle empty sequences and single-token sequences
results: list[list[tuple[float, list[TopLogprobItem]]]] = []
non_empty_indices: list[int] = []
non_empty_sequences: list[list[int]] = []
for i, tokens in enumerate(token_sequences):
if len(tokens) == 0:
results.append([])
elif len(tokens) == 1:
results.append([(0.0, [])])
else:
results.append([]) # Placeholder, will be filled later
non_empty_indices.append(i)
non_empty_sequences.append(tokens)
if not non_empty_sequences:
return results
# Find max sequence length (excluding last token since we predict it)
max_len = max(len(seq) - 1 for seq in non_empty_sequences)
# Get pad token (use eos_token_id or 0)
pad_token_id = getattr(tokenizer, "pad_token_id", None)
if pad_token_id is None:
pad_token_id = getattr(tokenizer, "eos_token_id", 0)
# Pad sequences and create attention mask
batch_size = len(non_empty_sequences)
padded_inputs = mx.full((batch_size, max_len), pad_token_id, dtype=mx.int32)
seq_lengths: list[int] = []
for i, tokens in enumerate(non_empty_sequences):
input_len = len(tokens) - 1 # Exclude last token
padded_inputs[i, :input_len] = mx.array(tokens[:-1], dtype=mx.int32)
seq_lengths.append(input_len)
# Run batched forward pass (no KV cache for scoring)
# The model accepts [batch_size, seq_len] and returns [batch_size, seq_len, vocab_size]
logits = model(padded_inputs, cache=None)
# Convert to log probabilities - logits shape: [batch, seq_len, vocab]
logprobs_all = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
mx.eval(logprobs_all)
# Extract results for each sequence
for batch_idx, (orig_idx, tokens, seq_len) in enumerate(
zip(non_empty_indices, non_empty_sequences, seq_lengths, strict=True)
):
seq_results: list[tuple[float, list[TopLogprobItem]]] = [(0.0, [])]
for pos in range(seq_len):
next_token = tokens[pos + 1]
logprobs_at_position: mx.array = logprobs_all[batch_idx, pos]
logprob, top_logprobs_items = extract_top_logprobs(
logprobs_array=logprobs_at_position,
selected_token=next_token,
tokenizer=tokenizer,
top_k=top_k,
)
seq_results.append((logprob, top_logprobs_items))
results[orig_idx] = seq_results
return results
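# Illustrative usage (not part of the diff): ranking multiple-choice candidates with a
# single batched forward pass. `model`, `tokenizer`, `prompt`, and `choices` are assumed
# to come from the caller; summing continuation logprobs mirrors the single-sequence
# sketch above and shares its tokenization-boundary caveat.
def rank_choices(model, tokenizer, prompt: str, choices: list[str]) -> int:
    context_len = len(tokenizer.encode(prompt))
    sequences = [tokenizer.encode(prompt + choice) for choice in choices]
    batched = score_tokens_batched(
        model=model, tokenizer=tokenizer, token_sequences=sequences, top_k=0
    )
    scores = [
        sum(logprob for logprob, _ in seq_results[max(context_len, 1):])
        for seq_results in batched
    ]
    return scores.index(max(scores))  # index of the highest-likelihood choice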
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
@@ -144,6 +343,10 @@ def mlx_generate(
top_p=task.top_p if task.top_p is not None else 1.0,
)
# Determine if we need logprobs
should_extract_logprobs = task.logprobs is True
top_k = task.top_logprobs if task.top_logprobs is not None else 0
max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,
@@ -177,9 +380,22 @@ def mlx_generate(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
# Extract logprobs if requested
logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
if should_extract_logprobs:
logprob, top_logprobs = extract_top_logprobs(
logprobs_array=out.logprobs,
selected_token=out.token,
tokenizer=tokenizer,
top_k=top_k,
)
yield GenerationResponse(
text=out.text,
token=out.token,
logprob=logprob,
top_logprobs=top_logprobs,
finish_reason=cast(FinishReason | None, out.finish_reason),
stats=stats,
)
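# Illustrative only (not part of the diff): how an API layer might map the logprob fields
# yielded above into an OpenAI-style chat logprobs entry. The dict shape is an assumption
# based on the public chat.completions format, not code from this repository.
def to_logprob_entry(resp: GenerationResponse) -> dict[str, object]:
    return {
        "token": resp.text,
        "logprob": resp.logprob,
        "top_logprobs": [
            {"token": item.token, "logprob": item.logprob}
            for item in (resp.top_logprobs or [])
        ],
    }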


@@ -28,6 +28,8 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion,
Completion,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -177,7 +179,6 @@ class Worker:
async def plan_step(self):
while True:
await anyio.sleep(0.1)
# 3. based on the updated state, we plan & execute an operation.
task: Task | None = plan(
self.node_id,
@@ -191,8 +192,10 @@ class Worker:
self.input_chunk_counts,
)
if task is None:
# Only sleep when there's nothing to do - allows rapid task dispatch
await anyio.sleep(0.01)
continue
logger.info(f"Worker plan: {task.__class__.__name__}")
logger.debug(f"Worker plan: {task.__class__.__name__}")
assert task.task_status
await self.event_sender.send(TaskCreated(task_id=task.task_id, task=task))
@@ -292,6 +295,12 @@ class Worker:
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
case ChatCompletion() | Completion():
# Don't wait for acknowledgment for batchable inference tasks
# This allows multiple tasks to reach the runner for batching
await self.runners[self._task_to_runner_id(task)].start_task(
task, wait_for_ack=False
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)


@@ -6,6 +6,7 @@ from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
ChatCompletion,
Completion,
ConnectToGroup,
CreateRunner,
DownloadModel,
@@ -269,9 +270,9 @@ def _pending_tasks(
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
# for now, just forward chat completions and completions
# TODO(ciaran): do this better!
if not isinstance(task, (ChatCompletion, ImageGeneration, ImageEdits)):
if not isinstance(task, (ChatCompletion, Completion, ImageGeneration, ImageEdits)):
continue
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
continue
@@ -294,9 +295,14 @@ def _pending_tasks(
if task.task_id in runner.completed:
continue
# Skip tasks already sent to runner (waiting for completion)
if task.task_id in runner.sent:
continue
# TODO: Check ordering aligns with MLX distributed's expectations.
if isinstance(runner.status, RunnerReady) and all(
# Allow sending tasks when runner is Ready OR Running (for batching)
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):


@@ -0,0 +1,409 @@
"""Batched inference handler for processing multiple ChatCompletion requests concurrently."""
import time
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import Any, Callable, Literal
import mlx.core as mx
from mlx_lm.generate import BatchGenerator
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import (
GenerationStats,
TopLogprobItem,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import ChunkGenerated, Event
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletion
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import MAX_TOKENS
from exo.worker.engines.mlx.generator.generate import extract_top_logprobs
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.runner.bootstrap import logger
# Type alias for the finish_reason values TokenChunk accepts
TokenFinishReason = Literal["stop", "length", "content_filter"]
@dataclass
class PendingRequest:
"""A request waiting to be added to the batch."""
task: ChatCompletion
prompt: str
max_tokens: int
sampler: Callable[[mx.array], mx.array]
should_extract_logprobs: bool
top_k: int
@dataclass
class ActiveRequest:
"""A request currently being processed in the batch."""
command_id: CommandId
should_extract_logprobs: bool
top_k: int
gpt_oss_parser: Any | None = None # StreamableParser for GPT-OSS models
gpt_oss_thinking: bool = False
tokens_generated: int = 0
reasoning_tokens: int = 0
prompt_tokens: int = 0
start_time: float = field(default_factory=time.perf_counter)
class BatchedInferenceHandler:
"""
Handles batched inference for multiple ChatCompletion requests.
Uses MLX-LM's BatchGenerator to process multiple requests concurrently,
improving throughput for scenarios with multiple concurrent requests.
"""
def __init__(
self,
model: Model,
tokenizer: TokenizerWrapper,
model_id: ModelId,
device_rank: int,
max_batch_size: int = 8,
batch_timeout_ms: int = 50,
):
self.model = model
self.tokenizer = tokenizer
self.model_id = model_id
self.device_rank = device_rank
self.max_batch_size = max_batch_size
self.batch_timeout_ms = batch_timeout_ms
# GPT-OSS model detection
self.is_gpt_oss = isinstance(model, GptOssModel)
self._gpt_oss_encoding: Any | None = None
if self.is_gpt_oss:
self._gpt_oss_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
logger.info("GPT-OSS model detected, enabling per-request stream parsing")
# Pending requests waiting to be batched
self.pending: list[PendingRequest] = []
self.pending_start_time: float | None = None
# Active batch generator and request tracking
self.batch_generator: BatchGenerator | None = None
self.uid_to_request: dict[int, ActiveRequest] = {}
# EOS tokens for the model
self.stop_tokens: set[int] = set()
eos_ids: list[int] | None = getattr(tokenizer, "eos_token_ids", None)
if eos_ids:
self.stop_tokens = set(eos_ids)
@property
def is_active(self) -> bool:
"""Check if there's an active batch being processed."""
return self.batch_generator is not None and len(self.uid_to_request) > 0
@property
def has_pending(self) -> bool:
"""Check if there are pending requests waiting to be batched."""
return len(self.pending) > 0
@property
def current_batch_size(self) -> int:
"""Current number of active requests in the batch."""
return len(self.uid_to_request)
def add_request(self, task: ChatCompletion) -> None:
"""Add a ChatCompletion request to the pending batch."""
task_params = task.task_params
# Build prompt
prompt = apply_chat_template(self.tokenizer, task_params)
# Determine max tokens
max_tokens = task_params.max_tokens or MAX_TOKENS
# Create sampler for this request
sampler = make_sampler(
temp=task_params.temperature if task_params.temperature is not None else 0.7,
top_p=task_params.top_p if task_params.top_p is not None else 1.0,
)
# Logprobs configuration
should_extract_logprobs = task_params.logprobs is True
top_k = task_params.top_logprobs if task_params.top_logprobs is not None else 0
pending_request = PendingRequest(
task=task,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
should_extract_logprobs=should_extract_logprobs,
top_k=top_k,
)
self.pending.append(pending_request)
if self.pending_start_time is None:
self.pending_start_time = time.perf_counter()
logger.info(
f"Added request to batch queue (pending={len(self.pending)}, active={self.current_batch_size})"
)
def should_flush(self) -> bool:
"""
Determine if the pending batch should be flushed.
Returns True if:
- We have pending requests AND (batch is full OR timeout reached)
"""
if not self.has_pending:
return False
# Check if batch is full
available_slots = self.max_batch_size - self.current_batch_size
if len(self.pending) >= available_slots:
return True
# Check timeout
if self.pending_start_time is not None:
elapsed_ms = (time.perf_counter() - self.pending_start_time) * 1000
if elapsed_ms >= self.batch_timeout_ms:
return True
return False
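    # Worked example (illustrative): with the default max_batch_size=8 and 3 requests
    # already active, available_slots is 5, so the 5th pending request triggers an
    # immediate flush; with fewer pending requests the batch is flushed once
    # batch_timeout_ms (50 ms by default) has elapsed since the first one was queued.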
def flush(self) -> None:
"""Start processing pending requests by adding them to the BatchGenerator."""
if not self.has_pending:
return
# Determine how many requests to flush (up to available slots)
available_slots = self.max_batch_size - self.current_batch_size
requests_to_flush = self.pending[:available_slots]
self.pending = self.pending[available_slots:]
if len(self.pending) == 0:
self.pending_start_time = None
# Create batch generator if not exists
if self.batch_generator is None:
logger.info(f"Creating new BatchGenerator for {len(requests_to_flush)} requests")
mx.reset_peak_memory()
self.batch_generator = BatchGenerator(
model=self.model,
max_tokens=MAX_TOKENS,
stop_tokens=self.stop_tokens if self.stop_tokens else None,
prefill_batch_size=min(len(requests_to_flush), 8),
)
else:
logger.info(f"Adding {len(requests_to_flush)} requests to existing BatchGenerator")
# Prepare batch data - tokenize prompts since BatchGenerator expects token IDs
tokenized_prompts: list[list[int]] = []
max_tokens_list: list[int] = []
samplers: list[Callable[[mx.array], mx.array]] = []
prompt_token_counts: list[int] = []
for req in requests_to_flush:
tokens = self.tokenizer.encode(req.prompt)
tokenized_prompts.append(tokens)
max_tokens_list.append(req.max_tokens)
samplers.append(req.sampler)
prompt_token_counts.append(len(tokens))
# Insert into batch generator
# Note: BatchGenerator.insert() accepts samplers param at runtime but pyright doesn't see it
uids: list[int] = self.batch_generator.insert( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
prompts=tokenized_prompts,
max_tokens=max_tokens_list,
samplers=samplers, # pyright: ignore[reportCallIssue]
)
# Track active requests
for uid, req, prompt_tokens in zip(uids, requests_to_flush, prompt_token_counts, strict=True): # pyright: ignore[reportUnknownArgumentType]
parser = None
if self.is_gpt_oss and self._gpt_oss_encoding is not None:
parser = StreamableParser(self._gpt_oss_encoding, role=Role.ASSISTANT) # pyright: ignore[reportAny]
self.uid_to_request[uid] = ActiveRequest(
command_id=req.task.command_id,
should_extract_logprobs=req.should_extract_logprobs,
top_k=req.top_k,
prompt_tokens=prompt_tokens,
gpt_oss_parser=parser,
)
logger.info(f"Flushed {len(requests_to_flush)} requests into batch (active={self.current_batch_size}, uids={list(self.uid_to_request.keys())})")
def step(self) -> Generator[Event, None, None]:
"""
Process one generation step and yield ChunkGenerated events.
Yields one ChunkGenerated event per token produced in this step, across all active requests.
"""
if self.batch_generator is None or not self.uid_to_request:
return
# Get next tokens for all active requests
# BatchGenerator.next() returns list of Response objects
logger.debug(f"BatchGenerator.next() called (active_uids={list(self.uid_to_request.keys())})")
responses: list[Any] = self.batch_generator.next() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
logger.debug(f"BatchGenerator.next() returned {len(responses)} responses") # pyright: ignore[reportUnknownArgumentType]
completed_uids: list[int] = []
for response in responses: # pyright: ignore[reportUnknownVariableType]
uid: int = response.uid # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
if uid not in self.uid_to_request:
logger.warning(f"Received response for unknown uid: {uid}")
continue
active_request = self.uid_to_request[uid]
active_request.tokens_generated += 1
# Extract response fields with explicit typing
resp_token: int = response.token # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
resp_finish_reason: str | None = response.finish_reason # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
resp_logprobs: mx.array = response.logprobs # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
# Only emit events from device_rank 0
if self.device_rank != 0:
if resp_finish_reason is not None:
completed_uids.append(uid) # pyright: ignore[reportUnknownArgumentType]
continue
# Decode token to text, applying GPT-OSS parsing if needed
token_text = self.tokenizer.decode([resp_token])
if active_request.gpt_oss_parser is not None:
parser = active_request.gpt_oss_parser # pyright: ignore[reportAny]
parser.process(resp_token) # pyright: ignore[reportAny]
delta: str | None = parser.last_content_delta # pyright: ignore[reportAny]
channel: str = parser.current_channel # pyright: ignore[reportAny]
# Track reasoning tokens (analysis channel = thinking)
if channel == "analysis":
active_request.reasoning_tokens += 1
# Handle thinking tag transitions
prefix = ""
if channel == "analysis" and not active_request.gpt_oss_thinking:
active_request.gpt_oss_thinking = True
prefix = "<think>"
elif channel != "analysis" and active_request.gpt_oss_thinking:
active_request.gpt_oss_thinking = False
prefix = "</think>"
if resp_finish_reason is not None and active_request.gpt_oss_thinking:
# Close thinking tag on finish
prefix = "</think>"
active_request.gpt_oss_thinking = False
effective_delta = delta or ""
token_text = prefix + effective_delta if (prefix or effective_delta) else ""
# Skip empty tokens (channel markers with no content delta)
if not token_text and resp_finish_reason is None:
continue
# Extract logprobs if requested
logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
if active_request.should_extract_logprobs:
logprob, top_logprobs = extract_top_logprobs(
logprobs_array=resp_logprobs, # pyright: ignore[reportUnknownArgumentType]
selected_token=resp_token, # pyright: ignore[reportUnknownArgumentType]
tokenizer=self.tokenizer,
top_k=active_request.top_k,
)
# Build stats for final token
stats: GenerationStats | None = None
finish_reason: TokenFinishReason | None = None
if resp_finish_reason is not None:
elapsed_time = time.perf_counter() - active_request.start_time
prompt_tps = active_request.prompt_tokens / max(elapsed_time, 0.001)
generation_tps = active_request.tokens_generated / max(elapsed_time, 0.001)
# Get peak memory
peak_memory_bytes = 0
if mx.metal.is_available():
peak_memory_bytes = mx.metal.get_peak_memory()
stats = GenerationStats(
prompt_tps=prompt_tps,
generation_tps=generation_tps,
prompt_tokens=active_request.prompt_tokens,
generation_tokens=active_request.tokens_generated,
reasoning_tokens=active_request.reasoning_tokens,
peak_memory_usage=Memory.from_bytes(peak_memory_bytes),
)
# Map finish reason to the narrower type TokenChunk expects
if resp_finish_reason == "stop":
finish_reason = "stop"
elif resp_finish_reason == "length":
finish_reason = "length"
elif resp_finish_reason == "content_filter":
finish_reason = "content_filter"
else:
# Unknown finish reasons default to "stop"
logger.warning(f"Unknown finish_reason: {resp_finish_reason}, mapping to 'stop'")
finish_reason = "stop"
completed_uids.append(uid) # pyright: ignore[reportUnknownArgumentType]
yield ChunkGenerated(
command_id=active_request.command_id,
chunk=TokenChunk(
model=self.model_id,
text=token_text,
token_id=resp_token, # pyright: ignore[reportUnknownArgumentType]
logprob=logprob,
top_logprobs=top_logprobs,
finish_reason=finish_reason,
stats=stats,
),
)
# Clean up completed requests
for uid in completed_uids:
del self.uid_to_request[uid]
# Close batch generator if no more active requests
if not self.uid_to_request and not self.pending:
self._close_generator()
def emit_error(self, command_id: CommandId, error_message: str) -> Event:
"""Create an error event for a failed request."""
return ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=self.model_id,
finish_reason="error",
error_message=error_message,
),
)
def _close_generator(self) -> None:
"""Close and clean up the batch generator."""
if self.batch_generator is not None:
self.batch_generator.close() # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue]
self.batch_generator = None
self.uid_to_request.clear()
logger.info("Batch generator closed")
def close(self) -> None:
"""Close the handler and clean up resources."""
self._close_generator()
self.pending.clear()
self.pending_start_time = None
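# A sketch (not the runner's actual loop) of how this handler is intended to be driven:
# queue incoming ChatCompletion tasks, flush once the batch fills or the timeout expires,
# then step the BatchGenerator and forward the resulting events. `incoming_tasks` and
# `emit` are hypothetical stand-ins for the runner's task queue and event sender.
def drive_batched_inference(handler: BatchedInferenceHandler, incoming_tasks, emit) -> None:
    for task in incoming_tasks:
        handler.add_request(task)
    while handler.has_pending or handler.is_active:
        if handler.should_flush():
            handler.flush()
        for event in handler.step():  # one ChunkGenerated per decoded token
            emit(event)
    handler.close()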


@@ -0,0 +1,200 @@
"""Batched scoring handler for processing multiple Completion requests concurrently."""
import time
from dataclasses import dataclass, field
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import TopLogprobItem
from exo.shared.types.chunks import CompletionChunk, ErrorChunk
from exo.shared.types.events import ChunkGenerated, Event
from exo.shared.types.tasks import Completion
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import score_tokens_batched
from exo.worker.runner.bootstrap import logger
@dataclass
class PendingScoringRequest:
"""A scoring request waiting to be batched."""
task: Completion
tokens: list[int]
prompt_text: str
top_k: int | None
echo: bool
@dataclass
class BatchedScoringHandler:
"""
Handles batched scoring for multiple Completion requests.
Collects multiple scoring requests and processes them in a single
batched forward pass for improved throughput.
"""
model: Model
tokenizer: TokenizerWrapper
model_id: ModelId
device_rank: int
max_batch_size: int = 32
batch_timeout_ms: int = 10
pending: list[PendingScoringRequest] = field(default_factory=list)
pending_start_time: float | None = None
@property
def has_pending(self) -> bool:
"""Check if there are pending requests."""
return len(self.pending) > 0
def add_request(
self,
task: Completion,
tokens: list[int],
prompt_text: str,
) -> None:
"""Add a Completion request to the pending batch."""
task_params = task.task_params
top_k = task_params.logprobs
self.pending.append(
PendingScoringRequest(
task=task,
tokens=tokens,
prompt_text=prompt_text,
top_k=top_k,
echo=task_params.echo,
)
)
if self.pending_start_time is None:
self.pending_start_time = time.perf_counter()
logger.debug(f"Added scoring request to batch (pending={len(self.pending)})")
def should_flush(self) -> bool:
"""Check if the batch should be flushed."""
if not self.has_pending:
return False
# Flush if batch is full
if len(self.pending) >= self.max_batch_size:
return True
# Flush if timeout reached
if self.pending_start_time is not None:
elapsed_ms = (time.perf_counter() - self.pending_start_time) * 1000
if elapsed_ms >= self.batch_timeout_ms:
return True
return False
def flush(self) -> list[Event]:
"""Process all pending requests and return events."""
if not self.has_pending:
return []
requests = self.pending
self.pending = []
self.pending_start_time = None
logger.info(f"Processing batch of {len(requests)} scoring requests")
# Collect all token sequences
token_sequences = [req.tokens for req in requests]
# Use the first request's top_k; requests batched together are expected to share the same value
top_k = requests[0].top_k if requests else None
try:
# Run batched scoring
all_results = score_tokens_batched(
model=self.model,
tokenizer=self.tokenizer,
token_sequences=token_sequences,
top_k=top_k,
)
# Generate events for each request
events: list[Event] = []
for req, logprob_results in zip(requests, all_results, strict=True):
if self.device_rank != 0:
continue
event = self._build_completion_event(req, logprob_results)
events.append(event)
logger.info(f"Batch scoring complete ({len(events)} events)")
return events
except Exception as e:
# Return error events for all requests
logger.error(f"Batch scoring failed: {e}")
events = []
for req in requests:
if self.device_rank == 0:
events.append(
ChunkGenerated(
command_id=req.task.command_id,
chunk=ErrorChunk(
model=self.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
return events
def _build_completion_event(
self,
req: PendingScoringRequest,
logprob_results: list[tuple[float, list[TopLogprobItem]]],
) -> Event:
"""Build a ChunkGenerated event for a completed scoring request."""
tokens = req.tokens
tokenizer = self.tokenizer
# Build response in completions format
token_strings: list[str] = []
token_logprobs: list[float | None] = []
top_logprobs: list[dict[str, float]] = []
text_offset: list[int] = []
offset = 0
for i, token_id in enumerate(tokens):
token_str = tokenizer.decode([token_id])
token_strings.append(token_str)
if i < len(logprob_results):
logprob, top_items = logprob_results[i]
# First token has no logprob (None in OpenAI format)
token_logprobs.append(logprob if i > 0 else None)
top_lp_dict = {item.token: item.logprob for item in top_items}
top_logprobs.append(top_lp_dict)
else:
token_logprobs.append(None)
top_logprobs.append({})
text_offset.append(offset)
offset += len(token_str)
return ChunkGenerated(
command_id=req.task.command_id,
chunk=CompletionChunk(
model=self.model_id,
text=req.prompt_text if req.echo else "",
tokens=token_strings,
token_logprobs=token_logprobs,
top_logprobs=top_logprobs,
text_offset=text_offset,
finish_reason="stop",
),
)
def close(self) -> None:
"""Clean up resources."""
self.pending.clear()
self.pending_start_time = None
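# A sketch (not the runner's actual code) of driving the scoring handler: queue each
# Completion task with its pre-tokenized prompt, then flush one batched forward pass
# and forward the resulting events. `pending_completions` and `emit` are hypothetical.
def drive_batched_scoring(handler: BatchedScoringHandler, pending_completions, emit) -> None:
    for task, prompt_text in pending_completions:
        tokens = handler.tokenizer.encode(prompt_text)
        handler.add_request(task=task, tokens=tokens, prompt_text=prompt_text)
    while handler.has_pending:
        if handler.should_flush():  # full batch, or batch_timeout_ms elapsed
            for event in handler.flush():  # one ChunkGenerated per scored request
                emit(event)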


File diff suppressed because it is too large.


@@ -52,6 +52,7 @@ class RunnerSupervisor:
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
sent: set[TaskId] = field(default_factory=set, init=False) # Tasks sent to runner (not yet completed)
completed: set[TaskId] = field(default_factory=set, init=False)
@classmethod
@@ -126,21 +127,39 @@ class RunnerSupervisor:
assert self._tg
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
async def start_task(self, task: Task, wait_for_ack: bool = True):
"""
Send a task to the runner.
Args:
task: The task to send.
wait_for_ack: If True, wait for TaskAcknowledged before returning.
If False, return immediately after sending (for batching).
"""
if task.task_id in self.completed:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
logger.debug(
f"Skipping task {task.task_id} as it has already been completed"
)
return
if task.task_id in self.sent:
logger.debug(f"Task {task.task_id} already sent, skipping duplicate")
return
if task.task_id in self.pending:
logger.debug(f"Task {task.task_id} already pending, skipping duplicate")
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
self.sent.add(task.task_id)
try:
self._task_sender.send(task)
except ClosedResourceError:
logger.warning(f"Task {task} dropped, runner closed communication.")
self.sent.discard(task.task_id)
return
await event.wait()
logger.info(f"Finished task {task}")
if wait_for_ack:
await event.wait()
logger.info(f"Finished task {task}")
async def _forward_events(self):
with self._ev_recv as events:
@@ -149,7 +168,11 @@ class RunnerSupervisor:
if isinstance(event, RunnerStatusUpdated):
self.status = event.runner_status
if isinstance(event, TaskAcknowledged):
self.pending.pop(event.task_id).set()
# Use pop with default to handle tasks sent with wait_for_ack=False
# that may have already been removed or never added
pending_event = self.pending.pop(event.task_id, None)
if pending_event:
pending_event.set()
continue
if (
isinstance(event, TaskStatusUpdated)
@@ -167,6 +190,7 @@ class RunnerSupervisor:
),
)
self.completed.add(event.task_id)
self.sent.discard(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)
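# Illustrative only: how a caller such as the Worker uses the new wait_for_ack flag.
# Batchable inference tasks are fired without awaiting TaskAcknowledged so several can
# reach the runner and be batched together; everything else keeps the original behaviour.
# ChatCompletion and Completion are the task types referenced elsewhere in this diff.
async def dispatch(supervisor: "RunnerSupervisor", tasks: "list[Task]") -> None:
    for task in tasks:
        if isinstance(task, (ChatCompletion, Completion)):
            await supervisor.start_task(task, wait_for_ack=False)
        else:
            await supervisor.start_task(task)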


@@ -20,6 +20,7 @@ class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
sent: set[TaskId] = field(default_factory=set)
class OtherTask(BaseTask):


@@ -118,6 +118,10 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
# Force serial processing mode since batch mode requires a real tokenizer
monkeypatch.setattr(mlx_runner, "_should_use_serial_processing", make_nothin(True))
# Disable batch handler initialization
monkeypatch.setattr(mlx_runner, "BATCH_ENABLED", False)
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
@@ -192,29 +196,30 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
# Status update comes before ack to prevent race conditions
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
),
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=LOAD_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
TaskAcknowledged(task_id=LOAD_TASK_ID),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=WARMUP_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
TaskAcknowledged(task_id=WARMUP_TASK_ID),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
@@ -222,10 +227,10 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),

uv.lock (generated): 794 changed lines

File diff suppressed because it is too large.