Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-04 11:11:45 -05:00)

Compare commits: rust-explo ... leo/add-lo (105 commits)
| Author | SHA1 | Date |
|---|---|---|
| | ddc81385fd | |
| | 694be69f30 | |
| | 8a7d01641a | |
| | 5aaf8d8003 | |
| | 8319477913 | |
| | 4a2a2c092e | |
| | bc90ac33d7 | |
| | 42d256d70d | |
| | 07a0622d04 | |
| | 9868aaaf54 | |
| | 90b2c97342 | |
| | af97b836c7 | |
| | fd722f663c | |
| | 4ec4695f84 | |
| | 87caab8647 | |
| | c4b19088da | |
| | b119a81d33 | |
| | 110ef45928 | |
| | 22fa3c8bb0 | |
| | 7608a5e7f4 | |
| | 04cc92a97f | |
| | 163bb83195 | |
| | f082d284a5 | |
| | 5ec1906a98 | |
| | 5bc35b38b0 | |
| | f906db66c5 | |
| | d1ea5f4315 | |
| | d9648194ed | |
| | bc73ef1436 | |
| | 8aeeb46d2f | |
| | edb2015607 | |
| | f613ebdc6c | |
| | e72a1778dd | |
| | eb4c76e758 | |
| | b890c671b8 | |
| | e7f3f47754 | |
| | d935c7a372 | |
| | bd089b30d7 | |
| | 13b397a3c9 | |
| | cf5fddf3f8 | |
| | c9df4ff004 | |
| | 4f7869b91b | |
| | b08ec25ef6 | |
| | f235019c28 | |
| | 68a77f0910 | |
| | 8456e3f74b | |
| | 83e4725415 | |
| | 49dc7a8798 | |
| | dea52342ca | |
| | aae28d8e8b | |
| | a28def8e45 | |
| | 56a9864e19 | |
| | 10afd08427 | |
| | 04a0690746 | |
| | 970717f1bb | |
| | 774eb1756a | |
| | 061e58ce39 | |
| | e8b6ec131b | |
| | 7b4c5d0c6d | |
| | fb3d1e887f | |
| | 2d15e49f4e | |
| | c0f192897c | |
| | 7587cb872c | |
| | bcb07782c1 | |
| | 24a6adf022 | |
| | 5d3b407602 | |
| | e7a5826aed | |
| | ebe279018f | |
| | bf67e7d334 | |
| | 0cd2f6aab4 | |
| | ba8a44e6a2 | |
| | 07c4be157b | |
| | 1e1eb8f8a1 | |
| | 1bc2d9728d | |
| | 7823fd7b1a | |
| | 05caab0047 | |
| | bd8f9f2d10 | |
| | 34fcafa68a | |
| | 5152789e00 | |
| | b734437b2d | |
| | 553939fa31 | |
| | 13ee17428e | |
| | 1b0d39c0b3 | |
| | 5e3cd73a9e | |
| | 1d1256c769 | |
| | 77baf9c58e | |
| | 022a09b6d9 | |
| | 0aa708fac4 | |
| | eb89c2e4b9 | |
| | 72a5eec3f7 | |
| | a25892e8d5 | |
| | 8798ab52ee | |
| | 457debc338 | |
| | 0cfaea41bc | |
| | 18c82443ba | |
| | b9ec8b0a44 | |
| | 00442b3cfd | |
| | aa41da8541 | |
| | 86e5d7b101 | |
| | d9ddf90575 | |
| | 4591301767 | |
| | 8b0b5e1b88 | |
| | bd6287727a | |
| | eb53611210 | |
| | 71bbe5f25b | |
0    bench/__init__.py    Normal file
104  bench/eval_config.toml    Normal file
@@ -0,0 +1,104 @@
# exo-eval configuration file
# See bench/exo_eval.py for usage

[eval]
# Eval framework type: "lm_eval" | "swe_bench" | "custom" | "livecodebench"
type = "livecodebench"
# Require HuggingFace token (default: true)
# Set to false if using only public datasets
require_hf_token = true

# Instance/placement configuration
# Controls how exo sets up the model instance before running evals
[instance]
# Placement strategy: "ring" | "jaccl" | "both"
instance_meta = "jaccl"
# Sharding strategy: "pipeline" | "tensor" | "both"
sharding = "tensor"
# Node constraints
min_nodes = 2
max_nodes = 2

# lm_eval configuration (EleutherAI's lm-evaluation-harness)
[lm_eval]
# Tasks to run (list of task names)
# NOTE: Chat completions API only supports generation-based tasks.
# Loglikelihood tasks (mmlu, hellaswag, arc) require /v1/completions endpoint.
#
# Generation-based tasks (work with chat completions):
# - mmlu_pro, mmlu_generative, mmlu_flan_cot_fewshot, mmlu_flan_cot_zeroshot
# - gsm8k, gsm8k_cot, gsm8k_cot_zeroshot
# - truthfulqa (uses generate_until for some subtasks)
# - humaneval, mbpp (code generation)
#
# Run `lm_eval --tasks list` to see all available tasks
tasks = ["mmlu_pro"]
# Number of few-shot examples (5 is standard for mmlu_pro CoT)
num_fewshot = 5
# Batch size (use 1 for API models, "auto" doesn't work)
batch_size = 1
# Number of concurrent requests (set > 1 to enable parallelism)
# Higher values enable better batching throughput
num_concurrent = 64
# Apply chat template for instruct/chat models (default: true)
apply_chat_template = true
# Use fewshot examples as conversation turns (better for chat models)
fewshot_as_multiturn = true
# Optional: limit samples per task (omit or comment out for no limit)
# limit = 100
# Output path for results
output_path = "bench/eval_results"

# LiveCodeBench configuration
# Contamination-free code generation benchmark
# See: https://livecodebench.github.io/
[livecodebench]
# Evaluation scenario: "codegeneration" | "selfrepair" | "testoutputprediction" | "codeexecution"
scenario = "codegeneration"
# Dataset release version (newer versions have more problems)
# release_v1: May 2023 - March 2024 (400 problems)
# release_v2: May 2023 - May 2024 (511 problems)
# release_v3: May 2023 - July 2024 (612 problems)
# release_v4: May 2023 - September 2024 (713 problems)
# release_v5: May 2023 - January 2025 (880 problems)
release_version = "release_v5"
# Sampling temperature
# - For non-reasoning models: 0 (deterministic)
# - For reasoning/thinking models (Kimi K2.5, DeepSeek R1): use model-recommended (e.g., 1.0)
temperature = 1.0
# Number of samples per problem (1 for pass@1, matches Artificial Analysis)
n_samples = 1
# Max tokens for generation
# - For non-reasoning models: 16384
# - For reasoning/thinking models: use model-recommended (Kimi K2.5 uses 96k)
max_tokens = 96000
# Use code_generation_lite for faster evaluation (default: true)
# Set to false to use full test suite (slower but more thorough)
fast = true
# Run evaluation after generation (computes pass@1, pass@5)
evaluate = true
# Number of parallel API requests
multiprocess = 8
# Cache generated outputs for resumption (disabled by default to avoid stale results)
use_cache = false
# Timeout in seconds (universal for all operations)
timeout = 100000
openai_timeout = 100000
# Output path for results
output_path = "bench/lcb_results"

# SWE-bench configuration (placeholder)
[swe_bench]
# SWE-bench dataset
dataset = "princeton-nlp/SWE-bench_Lite"
# Maximum workers for parallel execution
max_workers = 8
# Path for prediction outputs
predictions_path = "bench/predictions"

# Custom evaluation script configuration
[custom]
# Path to custom evaluation script
script = "path/to/eval_script.py"
# Arguments to pass to the script
args = ["--arg1", "value1"]
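A minimal sketch (not part of this change set) of how a runner like bench/exo_eval.py could consume the file above; the loading code and variable names here are illustrative assumptions, only the key names come from the config itself:

```python
# Hypothetical config loader for bench/eval_config.toml (illustration only).
import tomllib
from pathlib import Path

config = tomllib.loads(Path("bench/eval_config.toml").read_text())

eval_type = config["eval"]["type"]          # e.g. "livecodebench"
instance_cfg = config["instance"]           # placement/sharding/node constraints
framework_cfg = config.get(eval_type, {})   # section matching the chosen framework

print(eval_type, instance_cfg["sharding"], framework_cfg.get("scenario"))
```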
@@ -5,6 +5,7 @@ from __future__ import annotations
import argparse
import contextlib
import http.client
import itertools
import json
import os
import time

@@ -24,7 +25,7 @@ class ExoHttpError(RuntimeError):


class ExoClient:
    def __init__(self, host: str, port: int, timeout_s: float = 600.0):
    def __init__(self, host: str, port: int, timeout_s: float = 100000.0):
        self.host = host
        self.port = port
        self.timeout_s = timeout_s

@@ -180,14 +181,7 @@ def parse_int_list(values: list[str]) -> list[int]:
            part = part.strip()
            if part:
                items.append(int(part))

    seen: set[int] = set()
    out: list[int] = []
    for x in items:
        if x not in seen:
            out.append(x)
            seen.add(x)
    return out
    return items


def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:

@@ -277,12 +271,29 @@ class PromptSizer:
                f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
            )

        content = ""
        tok = self.count_fn(content)
        # Estimate tokens per atom using a sample
        sample_count = 100
        sample_content = self.atom * sample_count
        sample_tokens = self.count_fn(sample_content) - self.base_tokens
        tokens_per_atom = sample_tokens / sample_count

        while tok < target:
            content += self.atom
            tok = self.count_fn(content)
        # Estimate starting point
        needed_tokens = target - self.base_tokens
        estimated_atoms = int(needed_tokens / tokens_per_atom)

        # Binary search to find exact atom count
        low, high = 0, estimated_atoms * 2 + 100
        while low < high:
            mid = (low + high) // 2
            tok = self.count_fn(self.atom * mid)
            if tok < target:
                low = mid + 1
            else:
                high = mid

        content = self.atom * low
        tok = self.count_fn(content)
        logger.info(f"{tok=}")

        if tok != target:
            raise RuntimeError(
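The key change in this hunk is replacing the token-by-token grow loop with an estimate plus binary search. A standalone toy version of the same search (whitespace "tokenizer" and all names invented for illustration; assumes the count function is monotone in the number of atoms):

```python
# Toy illustration of the binary-search prompt sizing above.
def count_fn(text: str) -> int:
    return len(text.split())  # stand-in for a real tokenizer

atom = "word "
target = 1000
base_tokens = count_fn("")  # template overhead (0 in this toy)

sample = 100
tokens_per_atom = (count_fn(atom * sample) - base_tokens) / sample
estimated = int((target - base_tokens) / tokens_per_atom)

low, high = 0, estimated * 2 + 100
while low < high:
    mid = (low + high) // 2
    if count_fn(atom * mid) < target:
        low = mid + 1
    else:
        high = mid

assert count_fn(atom * low) == target  # exact prompt size reached
```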
@@ -348,7 +359,7 @@ def main() -> int:
        help="Warmup runs per placement (uses first pp/tg).",
    )
    ap.add_argument(
        "--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
        "--timeout", type=float, default=100000.0, help="HTTP timeout (seconds)."
    )
    ap.add_argument(
        "--json-out",

@@ -369,6 +380,14 @@ def main() -> int:
        logger.error("--repeat must be >= 1")
        return 2

    # Log pairing mode
    if len(pp_list) == len(tg_list):
        logger.info(f"pp/tg mode: tandem (zip) - {len(pp_list)} pairs")
    else:
        logger.info(
            f"pp/tg mode: combinations (product) - {len(pp_list) * len(tg_list)} pairs"
        )

    client = ExoClient(args.host, args.port, timeout_s=args.timeout)
    short_id, full_model_id = resolve_model_short_id(client, args.model)

@@ -486,60 +505,55 @@ def main() -> int:
                )
                logger.debug(f" warmup {i + 1}/{args.warmup} done")

            for pp in pp_list:
                # if (
                #     pp * n_nodes > 2048
                #     and "ring" in instance_meta.lower()
                #     and "tensor" in sharding.lower()
                # ):
                #     model_card = MODEL_CARDS[short_id]
                #     if model_card.metadata.storage_size > Memory.from_gb(10):
                #         logger.info(
                #             f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
                #         )
                #         continue
                for tg in tg_list:
                    runs: list[dict[str, Any]] = []
                    for r in range(args.repeat):
                        time.sleep(3)
                        try:
                            row, actual_pp_tokens = run_one_completion(
                                client, full_model_id, pp, tg, prompt_sizer
                            )
                        except Exception as e:
                            logger.error(e)
                            continue
                        row.update(
                            {
                                "model_short_id": short_id,
                                "model_id": full_model_id,
                                "placement_sharding": sharding,
                                "placement_instance_meta": instance_meta,
                                "placement_nodes": n_nodes,
                                "instance_id": instance_id,
                                "pp_tokens": actual_pp_tokens,
                                "tg": tg,
                                "repeat_index": r,
                            }
                        )
                        runs.append(row)
                        all_rows.append(row)
            # If pp and tg lists have same length, run in tandem (zip)
            # Otherwise, run all combinations (cartesian product)
            if len(pp_list) == len(tg_list):
                pp_tg_pairs = list(zip(pp_list, tg_list))
            else:
                pp_tg_pairs = list(itertools.product(pp_list, tg_list))

                    if runs:
                        prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
                        gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
                        ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
                        gtok = mean(x["stats"]["generation_tokens"] for x in runs)
                        peak = mean(
                            x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
            for pp, tg in pp_tg_pairs:
                runs: list[dict[str, Any]] = []
                for r in range(args.repeat):
                    time.sleep(3)
                    try:
                        row, actual_pp_tokens = run_one_completion(
                            client, full_model_id, pp, tg, prompt_sizer
                        )
                    except Exception as e:
                        logger.error(e)
                        continue
                    row.update(
                        {
                            "model_short_id": short_id,
                            "model_id": full_model_id,
                            "placement_sharding": sharding,
                            "placement_instance_meta": instance_meta,
                            "placement_nodes": n_nodes,
                            "instance_id": instance_id,
                            "pp_tokens": actual_pp_tokens,
                            "tg": tg,
                            "repeat_index": r,
                        }
                    )
                    runs.append(row)
                    all_rows.append(row)

                        logger.info(
                            f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
                            f"prompt_tokens={ptok} gen_tokens={gtok} "
                            f"peak_memory={format_peak_memory(peak)}\n"
                        )
                        time.sleep(2)
                if runs:
                    prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
                    gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
                    ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
                    gtok = mean(x["stats"]["generation_tokens"] for x in runs)
                    peak = mean(
                        x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
                    )

                    logger.info(
                        f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
                        f"prompt_tokens={ptok} gen_tokens={gtok} "
                        f"peak_memory={format_peak_memory(peak)}\n"
                    )
                    time.sleep(2)
        finally:
            try:
                client.request_json("DELETE", f"/instance/{instance_id}")
1052  bench/exo_eval.py    Normal file (diff suppressed because it is too large)
287   bench/livecodebench_runner.py    Normal file
@@ -0,0 +1,287 @@
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
"""
LiveCodeBench runner wrapper for exo.

This wrapper allows running LiveCodeBench with custom OpenAI-compatible endpoints
by dynamically registering models and configuring the OpenAI client.

Usage:
    python -m bench.livecodebench_runner --model my-model --base-url http://localhost:52415/v1 [lcb args...]

The wrapper:
1. Registers the custom model in LiveCodeBench's model registry
2. Sets up environment variables for the OpenAI client
3. Runs the standard LiveCodeBench runner

Requires LiveCodeBench to be installed:
    git clone https://github.com/LiveCodeBench/LiveCodeBench
    cd LiveCodeBench && uv pip install -e .
"""

from __future__ import annotations

import argparse
import multiprocessing
import os
import signal
import sys
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, NoReturn

if TYPE_CHECKING:
    from typing import Any


def _cleanup_and_exit(exit_code: int = 130) -> NoReturn:
    """Terminate all child processes and exit."""
    # Terminate any active multiprocessing pools
    for child in multiprocessing.active_children():
        child.terminate()
        child.join(timeout=1)
        if child.is_alive():
            child.kill()
    # Force exit to avoid hanging on cleanup
    os._exit(exit_code)


def _signal_handler(signum: int, frame: object) -> NoReturn:
    """Handle interrupt signals by terminating all child processes."""
    _cleanup_and_exit(130)


def get_lcb_directory() -> Path | None:
    """Find the LiveCodeBench installation directory.

    LiveCodeBench uses relative paths like 'lcb_runner/prompts/few_shot_examples/...'
    which require running from the LiveCodeBench directory.
    """
    # Check environment variable first
    if env_path := os.environ.get("LIVECODEBENCH_DIR"):
        lcb_path = Path(env_path)
        if (lcb_path / "lcb_runner" / "prompts" / "few_shot_examples").exists():
            return lcb_path

    # Use importlib to find package location without executing module code
    # This avoids triggering the relative path imports that would fail
    try:
        import importlib.util

        spec = importlib.util.find_spec("lcb_runner")
        if spec and spec.origin:
            # spec.origin is the __init__.py path, go up two levels
            lcb_path = Path(spec.origin).parent.parent
            if (lcb_path / "lcb_runner" / "prompts" / "few_shot_examples").exists():
                return lcb_path
    except (ImportError, ModuleNotFoundError):
        pass

    # Check common locations relative to this script
    script_dir = Path(__file__).parent.parent  # exo/
    common_locations = [
        script_dir / "LiveCodeBench",  # exo/LiveCodeBench
        script_dir.parent / "LiveCodeBench",  # sibling to exo
    ]
    for loc in common_locations:
        if (loc / "lcb_runner" / "prompts" / "few_shot_examples").exists():
            return loc

    return None


def setup_custom_model(model_name: str, base_url: str) -> None:
    """Register a custom model in LiveCodeBench's registry."""
    try:
        from lcb_runner.lm_styles import (  # pyright: ignore[reportMissingImports]
            LanguageModel,
            LanguageModelList,
            LanguageModelStore,
            LMStyle,
        )
    except ImportError as e:
        print(
            "Error: LiveCodeBench not installed. Install with:\n"
            "  git clone https://github.com/LiveCodeBench/LiveCodeBench\n"
            "  cd LiveCodeBench && uv pip install -e .",
            file=sys.stderr,
        )
        raise SystemExit(1) from e

    # Check if model already exists
    if model_name in LanguageModelStore:
        return

    # Create a new model entry using OpenAIChat style
    # This will route through the oai_runner which respects OPENAI_BASE_URL
    custom_model = LanguageModel(
        model_name=model_name,
        model_repr=model_name,
        model_style=LMStyle.OpenAIChat,
        release_date=datetime.now(),
        link=base_url,
    )

    # Add to the model list and store
    LanguageModelList.append(custom_model)
    LanguageModelStore[model_name] = custom_model


def patch_openai_client(base_url: str) -> None:
    """Patch the OpenAI client to use a custom base URL.

    This patches the oai_runner module to use our custom base URL.
    """
    try:
        from lcb_runner.runner import oai_runner  # noqa: I001 # pyright: ignore[reportMissingImports]
    except ImportError as e:
        print(f"Error importing required modules: {e}", file=sys.stderr)
        raise SystemExit(1) from e

    # Store original client creation
    original_init = oai_runner.OpenAI

    def patched_openai(*args: Any, **kwargs: Any) -> Any:
        """Create OpenAI client with custom base_url."""
        # Inject base_url if not already set
        if "base_url" not in kwargs:
            kwargs["base_url"] = base_url
        # Use dummy API key if not set (exo doesn't require auth)
        if "api_key" not in kwargs and not os.getenv("OPENAI_KEY"):
            kwargs["api_key"] = os.getenv("OPENAI_API_KEY", "exo-local")
        return original_init(*args, **kwargs)

    # Apply the patch
    oai_runner.OpenAI = patched_openai


def main() -> int:
    """Main entry point."""
    # Set up signal handlers for clean exit
    signal.signal(signal.SIGINT, _signal_handler)
    signal.signal(signal.SIGTERM, _signal_handler)

    parser = argparse.ArgumentParser(
        description="LiveCodeBench runner wrapper for exo",
        epilog="Additional arguments are passed to lcb_runner.runner.main",
    )
    parser.add_argument(
        "--base-url",
        default=os.environ.get("OPENAI_BASE_URL", "http://localhost:52415/v1"),
        help="OpenAI-compatible API base URL (default: OPENAI_BASE_URL or localhost:52415/v1)",
    )
    parser.add_argument(
        "--model",
        required=True,
        help="Model name to use",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Output directory for results (maps to LiveCodeBench's --custom_output_save_name)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Limit number of problems to evaluate (for testing)",
    )

    # Parse known args, pass rest to LiveCodeBench
    args, remaining = parser.parse_known_args()

    # Set up environment
    os.environ["OPENAI_BASE_URL"] = args.base_url
    if "OPENAI_API_KEY" not in os.environ and "OPENAI_KEY" not in os.environ:
        os.environ["OPENAI_API_KEY"] = "exo-local"
        os.environ["OPENAI_KEY"] = "exo-local"

    # Save original directory for output path resolution
    original_cwd = os.getcwd()

    # Change to LiveCodeBench directory before imports that use relative paths
    # LiveCodeBench uses paths like 'lcb_runner/prompts/few_shot_examples/...'
    lcb_dir = get_lcb_directory()
    if lcb_dir:
        os.chdir(lcb_dir)
    else:
        print(
            "Warning: Could not find LiveCodeBench directory. "
            "Relative path imports may fail.",
            file=sys.stderr,
        )

    # Setup custom model and patch client
    setup_custom_model(args.model, args.base_url)
    patch_openai_client(args.base_url)

    # Build arguments for LiveCodeBench runner
    lcb_args = ["--model", args.model]

    # Resolve output directory to absolute path (relative to original cwd)
    output_base: str | None = None
    if args.output_dir:
        output_base = str(Path(original_cwd) / args.output_dir)

    lcb_args.extend(remaining)

    # Run LiveCodeBench
    try:
        from lcb_runner.runner import main as lcb_main_module  # noqa: I001 # pyright: ignore[reportMissingImports]
        from lcb_runner.utils import path_utils  # noqa: I001 # pyright: ignore[reportMissingImports]

        # Patch output path to use our output directory
        if output_base:
            original_get_output_path = path_utils.get_output_path

            def patched_get_output_path(model_repr: str, runner_args: Any) -> str:
                # Get the original path and replace 'output/' with our base
                original_path = original_get_output_path(model_repr, runner_args)
                # Replace 'output/' prefix with our custom base
                if original_path.startswith("output/"):
                    new_path = str(
                        Path(output_base) / original_path[7:]
                    )  # Skip 'output/'
                else:
                    new_path = str(Path(output_base) / original_path)
                path_utils.ensure_dir(new_path)
                print(f"Saving results to: {new_path}")
                return new_path

            path_utils.get_output_path = patched_get_output_path
            # Also patch in main module since it may have imported directly
            if hasattr(lcb_main_module, "get_output_path"):
                lcb_main_module.get_output_path = patched_get_output_path

        # Patch benchmark loading to support --limit
        # Must patch in the main module since it imports the function directly
        if args.limit is not None:
            original_build = lcb_main_module.build_prompt_benchmark

            def limited_build(*a: Any, **kw: Any) -> Any:
                benchmark, format_prompt = original_build(*a, **kw)
                if args.limit and len(benchmark) > args.limit:
                    print(
                        f"Limiting benchmark from {len(benchmark)} to {args.limit} problems"
                    )
                    benchmark = benchmark[: args.limit]
                return benchmark, format_prompt

            lcb_main_module.build_prompt_benchmark = limited_build

        # Patch sys.argv for argparse in lcb_main
        sys.argv = [sys.argv[0], *lcb_args]
        lcb_main_module.main()
        return 0
    except KeyboardInterrupt:
        print("\nInterrupted by user", file=sys.stderr)
        _cleanup_and_exit(130)
    except SystemExit as e:
        return e.code if isinstance(e.code, int) else 1
    except Exception as e:
        print(f"Error running LiveCodeBench: {e}", file=sys.stderr)
        return 1


if __name__ == "__main__":
    raise SystemExit(main())
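A quick smoke-test invocation using only the flags the wrapper itself defines above (the model name is a placeholder; any further arguments are passed through to lcb_runner unchanged):

```
python -m bench.livecodebench_runner \
  --model my-exo-model \
  --base-url http://localhost:52415/v1 \
  --limit 10
```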
145   bench/lm_eval_patched.py    Normal file
@@ -0,0 +1,145 @@
"""Patched lm_eval runner that fixes bugs in the upstream library.

Fixes:
- UnboundLocalError on `outputs` in TemplateAPI.amodel_call when API returns error
- Prevents eval crash on transient API failures (returns None instead of raising)
- Compatibility with transformers 5.x (missing AutoModelForVision2Seq)
- sock_read timeout causing connection drops with large request queues

Usage: python -m bench.lm_eval_patched [lm_eval args...]
"""

# ruff: noqa: I001, E402
# pyright: reportMissingTypeStubs=false, reportUnknownVariableType=false
# pyright: reportUnknownMemberType=false, reportAny=false, reportUnknownArgumentType=false
# pyright: reportPrivateUsage=false, reportUnknownLambdaType=false

# MUST patch transformers BEFORE any lm_eval imports
# AutoModelForVision2Seq/AutoModelForImageTextToText were removed in transformers 5.0
# Patch the lazy module's __getattr__ to return stubs for missing classes
from transformers.utils import import_utils

_original_getattr = import_utils._LazyModule.__getattr__


def _patched_getattr(self: object, name: str) -> object:
    if name in ("AutoModelForVision2Seq", "AutoModelForImageTextToText"):
        return type(name, (), {})  # Return a stub class
    return _original_getattr(self, name)  # type: ignore


import_utils._LazyModule.__getattr__ = _patched_getattr

import functools
from typing import Any


def _patch_amodel_call() -> None:
    """Monkey-patch TemplateAPI.amodel_call to handle the unbound `outputs` variable bug."""
    from lm_eval.models.api_models import TemplateAPI

    original: Any = TemplateAPI.amodel_call

    @functools.wraps(original)
    async def patched_amodel_call(self: Any, *args: Any, **kwargs: Any) -> Any:
        try:
            return await original(self, *args, **kwargs)
        except (UnboundLocalError, Exception):
            # Return one empty-string result per request in the batch so the
            # reorderer doesn't assert on missing coverage.
            messages = kwargs.get("messages") or (args[2] if len(args) > 2 else [])
            return [""] * max(len(messages), 1)

    TemplateAPI.amodel_call = patched_amodel_call


def _patch_client_timeout() -> None:
    """Patch TemplateAPI.get_batched_requests to disable sock_read timeout.

    By default, aiohttp's ClientTimeout can have a sock_read timeout that causes
    connections to drop if no data is received for a while. With large request
    queues, requests may wait a long time before processing starts, causing
    spurious connection drops and retries that pile up requests.
    """
    from aiohttp import ClientSession, ClientTimeout, TCPConnector

    from lm_eval.models.api_models import TemplateAPI

    original_get_batched: Any = TemplateAPI.get_batched_requests

    @functools.wraps(original_get_batched)
    async def patched_get_batched_requests(self: Any, *args: Any, **kwargs: Any) -> Any:
        # Override the timeout to explicitly disable sock_read timeout
        # This prevents connection drops when requests are queued for a long time
        original_timeout = getattr(self, "timeout", 604800)
        conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
        timeout = ClientTimeout(
            total=original_timeout, sock_read=None, sock_connect=None
        )

        async with ClientSession(connector=conn, timeout=timeout) as session:
            # Call the internal async logic with our session
            return await _run_batched_requests_with_session(
                self, session, *args, **kwargs
            )

    async def _run_batched_requests_with_session(
        self: Any,
        session: ClientSession,
        requests: Any,
        cache_keys: Any = None,
        ctxlens: Any = None,
        **kwargs: Any,
    ) -> Any:
        import asyncio
        import copy
        import logging

        from tqdm.asyncio import tqdm_asyncio
        from tenacity import retry, stop_after_attempt, wait_exponential
        from lm_eval.models.utils import chunks

        eval_logger = logging.getLogger("lm_eval.models.api_models")
        ctxlens = ctxlens if ctxlens else [None] * len(requests)
        sem = asyncio.Semaphore(self._concurrent)

        retry_: Any = retry(
            stop=stop_after_attempt(self.max_retries),
            wait=wait_exponential(multiplier=0.5, min=1, max=10),
            reraise=True,
            before_sleep=lambda retry_state: eval_logger.info(
                f"Retry attempt {retry_state.attempt_number}"
            ),
        )(self.amodel_call)

        tasks = [
            asyncio.create_task(
                retry_(
                    session=session,
                    sem=sem,
                    messages=message,
                    cache_keys=cache_key,
                    ctxlens=ctxlen,
                    gen_kwargs=copy.deepcopy(kwargs.get("gen_kwargs")),
                    **{k: v for k, v in kwargs.items() if k != "gen_kwargs"},
                )
            )
            for message, cache_key, ctxlen in zip(
                chunks(requests, n=self._batch_size),
                chunks(cache_keys, n=self._batch_size),
                chunks(ctxlens, n=self._batch_size),
                strict=True,
            )
        ]

        return await tqdm_asyncio.gather(*tasks, desc="Requesting API")

    TemplateAPI.get_batched_requests = patched_get_batched_requests


if __name__ == "__main__":
    _patch_amodel_call()
    _patch_client_timeout()
    from lm_eval.__main__ import cli_evaluate

    cli_evaluate()
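One plausible way to drive this wrapper against an exo endpoint, assuming lm-evaluation-harness's local-chat-completions API model; the exact flag set depends on the installed lm_eval version, and the host, port, and output path are taken from the config earlier in this diff:

```
python -m bench.lm_eval_patched \
  --model local-chat-completions \
  --model_args base_url=http://localhost:52415/v1/chat/completions,num_concurrent=64 \
  --tasks mmlu_pro --num_fewshot 5 --batch_size 1 \
  --apply_chat_template --fewshot_as_multiturn \
  --output_path bench/eval_results
```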
290   bench/stats_dashboard.html    Normal file
@@ -0,0 +1,290 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>exo Usage Stats</title>
<style>
  * { margin: 0; padding: 0; box-sizing: border-box; }
  body {
    font-family: -apple-system, BlinkMacSystemFont, 'SF Mono', 'Menlo', monospace;
    background: #1a1a2e;
    color: #e0e0e0;
    padding: 24px;
    min-height: 100vh;
  }
  .header {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 24px;
    padding-bottom: 16px;
    border-bottom: 1px solid #333;
  }
  .header h1 {
    font-size: 20px;
    font-weight: 600;
    color: #fff;
  }
  .status {
    display: flex;
    align-items: center;
    gap: 8px;
    font-size: 13px;
    color: #888;
  }
  .status-dot {
    width: 8px;
    height: 8px;
    border-radius: 50%;
    background: #666;
  }
  .status-dot.connected { background: #4caf50; }
  .status-dot.error { background: #f44336; }
  .config {
    margin-bottom: 24px;
    display: flex;
    align-items: center;
    gap: 8px;
  }
  .config label {
    font-size: 12px;
    color: #888;
  }
  .config input {
    background: #252540;
    border: 1px solid #444;
    border-radius: 4px;
    color: #e0e0e0;
    padding: 4px 8px;
    font-size: 13px;
    font-family: inherit;
    width: 280px;
  }
  .section {
    background: #252540;
    border-radius: 8px;
    padding: 20px;
    margin-bottom: 16px;
  }
  .section h2 {
    font-size: 14px;
    font-weight: 600;
    color: #aaa;
    text-transform: uppercase;
    letter-spacing: 0.5px;
    margin-bottom: 16px;
  }
  .stat-grid {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
    gap: 16px;
  }
  .stat-card {
    background: #1a1a2e;
    border-radius: 6px;
    padding: 16px;
  }
  .stat-label {
    font-size: 11px;
    color: #888;
    text-transform: uppercase;
    letter-spacing: 0.5px;
    margin-bottom: 4px;
  }
  .stat-value {
    font-size: 28px;
    font-weight: 700;
    color: #fff;
  }
  .stat-rate {
    font-size: 12px;
    color: #4caf50;
    margin-top: 4px;
  }
  table {
    width: 100%;
    border-collapse: collapse;
    font-size: 13px;
  }
  th {
    text-align: left;
    padding: 8px 12px;
    color: #888;
    font-weight: 500;
    border-bottom: 1px solid #333;
    font-size: 11px;
    text-transform: uppercase;
    letter-spacing: 0.5px;
  }
  td {
    padding: 8px 12px;
    border-bottom: 1px solid #2a2a45;
  }
  td.num {
    text-align: right;
    font-variant-numeric: tabular-nums;
  }
  .model-name {
    color: #7c9eff;
    max-width: 300px;
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
  }
  .empty-state {
    color: #666;
    font-style: italic;
    padding: 16px 0;
  }
</style>
</head>
<body>
<div class="header">
  <h1>exo Usage Stats</h1>
  <div class="status">
    <div class="status-dot" id="statusDot"></div>
    <span id="statusText">connecting...</span>
  </div>
</div>

<div class="config">
  <label for="baseUrl">Base URL:</label>
  <input type="text" id="baseUrl" value="http://mac8-1:52415">
</div>

<div class="section">
  <h2>Totals</h2>
  <div class="stat-grid">
    <div class="stat-card">
      <div class="stat-label">Requests</div>
      <div class="stat-value" id="totalRequests">0</div>
    </div>
    <div class="stat-card">
      <div class="stat-label">Prompt Tokens</div>
      <div class="stat-value" id="totalPrompt">0</div>
      <div class="stat-rate" id="promptRate"></div>
    </div>
    <div class="stat-card">
      <div class="stat-label">Completion Tokens</div>
      <div class="stat-value" id="totalCompletion">0</div>
      <div class="stat-rate" id="completionRate"></div>
    </div>
    <div class="stat-card">
      <div class="stat-label">Reasoning Tokens</div>
      <div class="stat-value" id="totalReasoning">0</div>
    </div>
    <div class="stat-card">
      <div class="stat-label">Total Tokens</div>
      <div class="stat-value" id="totalTokens">0</div>
      <div class="stat-rate" id="totalRate"></div>
    </div>
  </div>
</div>

<div class="section">
  <h2>Per-Model Breakdown</h2>
  <div id="modelTable">
    <div class="empty-state">No data yet</div>
  </div>
</div>

<script>
function fmt(n) {
  return n.toLocaleString();
}

// Track first non-zero timestamp for overall average rate
let firstSeenTime = null;
let firstSeenTokens = { prompt: 0, completion: 0, total: 0 };

function setRate(id, currentTokens, tokenType) {
  const el = document.getElementById(id);
  if (firstSeenTime === null || currentTokens <= firstSeenTokens[tokenType]) {
    el.textContent = '';
    return;
  }
  const elapsed = (performance.now() / 1000) - firstSeenTime;
  if (elapsed <= 0) { el.textContent = ''; return; }
  const delta = currentTokens - firstSeenTokens[tokenType];
  const avg = delta / elapsed;
  el.textContent = fmt(Math.round(avg)) + ' tok/s avg';
}

function renderModelTable(byModel) {
  const container = document.getElementById('modelTable');
  const models = Object.entries(byModel);
  if (models.length === 0) {
    container.innerHTML = '<div class="empty-state">No data yet</div>';
    return;
  }
  let html = '<table><thead><tr>';
  html += '<th>Model</th><th style="text-align:right">Requests</th>';
  html += '<th style="text-align:right">Prompt</th>';
  html += '<th style="text-align:right">Completion</th>';
  html += '<th style="text-align:right">Reasoning</th>';
  html += '<th style="text-align:right">Total</th>';
  html += '</tr></thead><tbody>';
  for (const [name, counters] of models) {
    const total = (counters.prompt_tokens || 0) + (counters.completion_tokens || 0);
    html += '<tr>';
    html += `<td class="model-name" title="${name}">${name}</td>`;
    html += `<td class="num">${fmt(counters.requests || 0)}</td>`;
    html += `<td class="num">${fmt(counters.prompt_tokens || 0)}</td>`;
    html += `<td class="num">${fmt(counters.completion_tokens || 0)}</td>`;
    html += `<td class="num">${fmt(counters.reasoning_tokens || 0)}</td>`;
    html += `<td class="num">${fmt(total)}</td>`;
    html += '</tr>';
  }
  html += '</tbody></table>';
  container.innerHTML = html;
}

async function poll() {
  const baseUrl = document.getElementById('baseUrl').value.replace(/\/+$/, '');
  const dot = document.getElementById('statusDot');
  const text = document.getElementById('statusText');

  try {
    const resp = await fetch(baseUrl + '/v1/usage');
    if (!resp.ok) throw new Error(`HTTP ${resp.status}`);
    const data = await resp.json();

    dot.className = 'status-dot connected';
    text.textContent = 'connected';

    document.getElementById('totalRequests').textContent = fmt(data.total_requests || 0);
    document.getElementById('totalPrompt').textContent = fmt(data.total_prompt_tokens || 0);
    document.getElementById('totalCompletion').textContent = fmt(data.total_completion_tokens || 0);
    document.getElementById('totalReasoning').textContent = fmt(data.total_reasoning_tokens || 0);
    document.getElementById('totalTokens').textContent = fmt(data.total_tokens || 0);

    // Record first non-zero reading as baseline
    if (firstSeenTime === null && (data.total_tokens || 0) > 0) {
      firstSeenTime = performance.now() / 1000;
      firstSeenTokens = {
        prompt: data.total_prompt_tokens || 0,
        completion: data.total_completion_tokens || 0,
        total: data.total_tokens || 0,
      };
    }

    setRate('promptRate', data.total_prompt_tokens || 0, 'prompt');
    setRate('completionRate', data.total_completion_tokens || 0, 'completion');
    setRate('totalRate', data.total_tokens || 0, 'total');

    renderModelTable(data.by_model || {});

  } catch (e) {
    dot.className = 'status-dot error';
    text.textContent = e.message || 'error';
  }
}

poll();
setInterval(poll, 1000);
</script>
</body>
</html>
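The dashboard above simply polls the /v1/usage endpoint added further down in this diff. The same data can be inspected from a shell; the host and port here are examples, the numbers are made up, and the field names come from the get_usage handler shown later:

```
curl -s http://localhost:52415/v1/usage
# {
#   "total_requests": 3,
#   "total_prompt_tokens": 1024,
#   "total_completion_tokens": 256,
#   "total_reasoning_tokens": 0,
#   "total_tokens": 1280,
#   "by_model": {"<model-id>": {"requests": 3, "prompt_tokens": 1024, "completion_tokens": 256, "reasoning_tokens": 0}}
# }
```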
@@ -13,13 +13,14 @@ dependencies = [
    "filelock>=3.18.0",
    "rustworkx>=0.17.1",
    "huggingface-hub>=0.33.4",
    "typer", # for huggingface-cli
    "psutil>=7.0.0",
    "loguru>=0.7.3",
    "exo_pyo3_bindings", # rust bindings
    "anyio==4.11.0",
    "mlx==0.30.3; sys_platform == 'darwin'",
    "mlx[cpu]==0.30.3; sys_platform == 'linux'",
    "mlx-lm==0.30.5",
    "mlx==0.30.4; sys_platform == 'darwin'",
    "mlx[cpu]==0.30.4; sys_platform == 'linux'",
    "mlx-lm",
    "tiktoken>=0.12.0", # required for kimi k2 tokenizer
    "hypercorn>=0.18.0",
    "openai-harmony>=0.0.8",

@@ -34,6 +35,7 @@ dependencies = [
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
exo-eval = "bench.exo_eval:main"

# dependencies only required for development
[dependency-groups]

@@ -51,6 +53,14 @@ dev = [
# cuda = [
#   "mlx[cuda]==0.26.3",
# ]
eval = [
    "lm_eval[api]",
    # LiveCodeBench dependencies (livecodebench itself must be installed manually due to packaging issues)
    # Install with: git clone https://github.com/LiveCodeBench/LiveCodeBench && cd LiveCodeBench && uv pip install -e .
    "openai>=1.59.6",
    "datasets>=2.14.0,<4.0", # LiveCodeBench requires <4.0 due to dataset script deprecation
    "pebble>=5.1.0",
]

###
# workspace configuration

@@ -63,6 +73,7 @@ members = [

[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm.git", branch = "main" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
@@ -155,13 +155,23 @@ async def seed_models(seed_dir: str | Path):


async def fetch_file_list_with_cache(
    model_id: ModelId, revision: str = "main", recursive: bool = False
    model_id: ModelId,
    revision: str = "main",
    recursive: bool = False,
    cache_ttl_seconds: int = 3600,
) -> list[FileListEntry]:
    target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
    await aios.makedirs(target_dir, exist_ok=True)
    cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"

    # Always try fresh first
    # Use cache if it exists and is fresh (< TTL seconds old)
    if await aios.path.exists(cache_file):
        cache_age = time.time() - (await aios.stat(cache_file)).st_mtime
        if cache_age < cache_ttl_seconds:
            async with aiofiles.open(cache_file, "r") as f:
                return TypeAdapter(list[FileListEntry]).validate_json(await f.read())

    # Cache missing or stale - fetch fresh
    try:
        file_list = await fetch_file_list_with_retry(
            model_id, revision, recursive=recursive

@@ -173,7 +183,7 @@ async def fetch_file_list_with_cache(
        )
        return file_list
    except Exception as e:
        # Fetch failed - try cache fallback
        # Fetch failed - try cache fallback (even if stale)
        if await aios.path.exists(cache_file):
            logger.warning(
                f"Failed to fetch file list for {model_id}, using cached data: {e}"

@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:


async def build_base_shard(model_id: ModelId) -> ShardMetadata:
    model_card = await ModelCard.from_hf(model_id)
    model_card = await ModelCard.load(model_id)
    return PipelineShardMetadata(
        model_card=model_card,
        device_rank=0,
@@ -267,6 +267,11 @@ def main():
        os.environ["EXO_FAST_SYNCH"] = "off"
        logger.info("FAST_SYNCH forced OFF")

    # Set EXO_NO_BATCH env var for runner subprocesses
    if args.no_batch:
        os.environ["EXO_NO_BATCH"] = "1"
        logger.info("Batch inference disabled (serial mode)")

    node = anyio.run(Node.create, args)
    anyio.run(node.run)
    logger.info("EXO Shutdown complete")

@@ -282,6 +287,7 @@ class Args(CamelCaseModel):
    no_worker: bool = False
    no_downloads: bool = False
    fast_synch: bool | None = None  # None = auto, True = force on, False = force off
    no_batch: bool = False

    @classmethod
    def parse(cls) -> Self:

@@ -342,6 +348,11 @@ class Args(CamelCaseModel):
            dest="fast_synch",
            help="Force MLX FAST_SYNCH off",
        )
        parser.add_argument(
            "--no-batch",
            action="store_true",
            help="Disable batch inference (use serial processing for benchmarking)",
        )

        args = parser.parse_args()
        return cls(**vars(args))  # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
@@ -1,10 +1,11 @@
import base64
import contextlib
import json
import re
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import Annotated, Literal, cast
from typing import Annotated, Any, Literal, cast
from uuid import uuid4

import anyio

@@ -42,6 +43,7 @@ from exo.shared.types.api import (
    ChatCompletionChoice,
    ChatCompletionMessage,
    ChatCompletionResponse,
    CompletionTokensDetails,
    CreateInstanceParams,
    CreateInstanceResponse,
    DeleteDownloadResponse,

@@ -57,6 +59,8 @@ from exo.shared.types.api import (
    ImageGenerationTaskParams,
    ImageListItem,
    ImageListResponse,
    Logprobs,
    LogprobsContentItem,
    ModelList,
    ModelListModel,
    PlaceInstanceParams,

@@ -66,8 +70,10 @@ from exo.shared.types.api import (
    StartDownloadResponse,
    StreamingChoiceResponse,
    ToolCall,
    Usage,
)
from exo.shared.types.chunks import (
    CompletionChunk,
    ErrorChunk,
    ImageChunk,
    InputImageChunk,

@@ -107,14 +113,43 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer

_THINK_TAG_RE = re.compile(r"<think>.*?</think>", re.DOTALL)


def _strip_think_tags(text: str) -> str:
    """Strip <think>...</think> blocks from response text.

    These tags are an artifact of GPT-OSS channel parsing, not part of the
    model's intended output. The OpenAI API content field should not contain them.
    """
    return _THINK_TAG_RE.sub("", text).lstrip()


def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None) -> str:
    return f"image/{image_format or 'png'}"


def _build_logprobs(chunk: TokenChunk) -> Logprobs:
    """Convert flat logprob fields to OpenAI Logprobs format."""
    return Logprobs(
        content=[
            LogprobsContentItem(
                token=chunk.text,
                logprob=chunk.logprob if chunk.logprob is not None else 0.0,
                bytes=list(chunk.text.encode("utf-8")),
                top_logprobs=chunk.top_logprobs or [],
            )
        ]
    )


def chunk_to_response(
    chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
    logprobs: Logprobs | None = None
    if isinstance(chunk, TokenChunk) and chunk.logprob is not None:
        logprobs = _build_logprobs(chunk)

    return ChatCompletionResponse(
        id=command_id,
        created=int(time.time()),
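A quick illustration of what the _strip_think_tags helper introduced in the hunk above does; the input string is invented for the example:

```python
import re

_THINK_TAG_RE = re.compile(r"<think>.*?</think>", re.DOTALL)

def _strip_think_tags(text: str) -> str:
    return _THINK_TAG_RE.sub("", text).lstrip()

# Hypothetical model output that still carries a leaked reasoning block:
raw = "<think>step 1... step 2...</think>\nThe answer is 42."
print(_strip_think_tags(raw))  # -> "The answer is 42."
```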
@@ -135,6 +170,7 @@ def chunk_to_response(
                    for i, tool in enumerate(chunk.tool_calls)
                    ],
                ),
                logprobs=logprobs,
                finish_reason=chunk.finish_reason,
            )
        ],

@@ -197,7 +233,8 @@ class API:
        )

        self._chat_completion_queues: dict[
            CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
            CommandId,
            Sender[TokenChunk | ErrorChunk | ToolCallChunk | CompletionChunk],
        ] = {}
        self._image_generation_queues: dict[
            CommandId, Sender[ImageChunk | ErrorChunk]

@@ -205,6 +242,9 @@ class API:
        self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
        self._tg: TaskGroup | None = None

        # Accumulated usage stats per instance (keyed by model id)
        self._usage_by_model: dict[str, dict[str, int]] = {}

    def reset(self, new_session_id: SessionId, result_clock: int):
        logger.info("Resetting API State")
        self.state = State()

@@ -271,6 +311,48 @@ class API:
        self.app.get("/events")(lambda: self._event_log)
        self.app.post("/download/start")(self.start_download)
        self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
        self.app.get("/v1/usage")(self.get_usage)

    def get_usage(self) -> dict[str, Any]:
        """Return accumulated token usage per model instance."""
        total_requests = 0
        total_prompt = 0
        total_completion = 0
        total_reasoning = 0
        for counters in self._usage_by_model.values():
            total_requests += counters.get("requests", 0)
            total_prompt += counters.get("prompt_tokens", 0)
            total_completion += counters.get("completion_tokens", 0)
            total_reasoning += counters.get("reasoning_tokens", 0)
        return {
            "total_requests": total_requests,
            "total_prompt_tokens": total_prompt,
            "total_completion_tokens": total_completion,
            "total_reasoning_tokens": total_reasoning,
            "total_tokens": total_prompt + total_completion,
            "by_model": self._usage_by_model,
        }

    def _accumulate_usage(
        self,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
        reasoning_tokens: int,
    ) -> None:
        """Accumulate usage stats for a model instance."""
        if model not in self._usage_by_model:
            self._usage_by_model[model] = {
                "requests": 0,
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "reasoning_tokens": 0,
            }
        counters = self._usage_by_model[model]
        counters["requests"] += 1
        counters["prompt_tokens"] += prompt_tokens
        counters["completion_tokens"] += completion_tokens
        counters["reasoning_tokens"] += reasoning_tokens

    async def place_instance(self, payload: PlaceInstanceParams):
        command = PlaceInstance(

@@ -492,29 +574,37 @@ class API:
        )

    async def _chat_chunk_stream(
        self, command_id: CommandId
    ) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
        """Yield `TokenChunk`s for a given command until completion."""
        self, command_id: CommandId, timeout: float = 60000.0
    ) -> AsyncGenerator[TokenChunk | ErrorChunk | ToolCallChunk, None]:
        """Yield `TokenChunk`s for a given command until completion.

        Args:
            timeout: Max seconds to wait for the next chunk before aborting.
        """

        try:
            self._chat_completion_queues[command_id], recv = channel[
                ErrorChunk | ToolCallChunk | TokenChunk
                TokenChunk | ErrorChunk | ToolCallChunk
            ]()

            with recv as token_chunks:
                async for chunk in token_chunks:
                    yield chunk
                    if chunk.finish_reason is not None:
                        break
                with anyio.fail_after(timeout):
                    async for chunk in token_chunks:
                        yield chunk
                        if chunk.finish_reason is not None:
                            break

        except anyio.get_cancelled_exc_class():
            # TODO: TaskCancelled
            """
            self.command_sender.send_nowait(
                ForwarderCommand(origin=self.node_id, command=command)
            )
            """
            raise
        except TimeoutError:
            logger.warning(
                f"Chat completion timed out after {timeout}s (command_id={command_id})"
            )
            yield ErrorChunk(
                model=ModelId("unknown"),
                finish_reason="error",
                error_message=f"Request timed out after {timeout}s",
            )
        finally:
            command = TaskFinished(finished_command_id=command_id)
            await self._send(command)

@@ -528,7 +618,7 @@ class API:

        async for chunk in self._chat_chunk_stream(command_id):
            assert not isinstance(chunk, ImageChunk)
            if chunk.finish_reason == "error":
            if isinstance(chunk, ErrorChunk):
                error_response = ErrorResponse(
                    error=ErrorInfo(
                        message=chunk.error_message or "Internal server error",

@@ -548,6 +638,15 @@ class API:
            yield f"data: {chunk_response.model_dump_json()}\n\n"

            if chunk.finish_reason is not None:
                # Accumulate usage stats from the final chunk
                if isinstance(chunk, TokenChunk) and chunk.stats is not None:
                    s = chunk.stats
                    self._accumulate_usage(
                        model=chunk.model,
                        prompt_tokens=s.prompt_tokens,
                        completion_tokens=s.generation_tokens,
                        reasoning_tokens=s.reasoning_tokens,
                    )
                yield "data: [DONE]\n\n"

    async def _collect_chat_completion(

@@ -557,10 +656,14 @@ class API:

        text_parts: list[str] = []
        tool_calls: list[ToolCall] = []
        logprobs_items: list[LogprobsContentItem] = []
        model: str | None = None
        finish_reason: FinishReason | None = None
        stats: GenerationStats | None = None

        async for chunk in self._chat_chunk_stream(command_id):
            # Skip CompletionChunk - it's for the legacy completions API

            if isinstance(chunk, ErrorChunk):
                raise HTTPException(
                    status_code=500,

@@ -572,6 +675,16 @@ class API:

            if isinstance(chunk, TokenChunk):
                text_parts.append(chunk.text)
                if chunk.stats is not None:
                    stats = chunk.stats
                if chunk.logprob is not None:
                    lp = _build_logprobs(chunk)
                    if lp.content:
                        if len(lp.content) != 1:
                            logger.warning(
                                f"Expected 1 logprobs content item per chunk, got {len(lp.content)}"
                            )
                        logprobs_items.append(lp.content[0])

            if isinstance(chunk, ToolCallChunk):
                tool_calls.extend(

@@ -586,9 +699,33 @@ class API:
            if chunk.finish_reason is not None:
                finish_reason = chunk.finish_reason

        combined_text = "".join(text_parts)
        combined_text = _strip_think_tags("".join(text_parts))
        assert model is not None

        logprobs: Logprobs | None = None
        if logprobs_items:
            logprobs = Logprobs(content=logprobs_items)

        usage: Usage | None = None
        if stats is not None:
            completion_tokens = stats.generation_tokens
            usage = Usage(
                prompt_tokens=stats.prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=stats.prompt_tokens + completion_tokens,
                completion_tokens_details=CompletionTokensDetails(
                    reasoning_tokens=stats.reasoning_tokens,
                )
                if stats.reasoning_tokens > 0
                else None,
            )
            self._accumulate_usage(
                model=model or "unknown",
                prompt_tokens=stats.prompt_tokens,
                completion_tokens=completion_tokens,
                reasoning_tokens=stats.reasoning_tokens,
            )

        return ChatCompletionResponse(
            id=command_id,
            created=int(time.time()),

@@ -601,9 +738,11 @@ class API:
                        content=combined_text,
                        tool_calls=tool_calls,
                    ),
                    logprobs=logprobs,
                    finish_reason=finish_reason,
                )
            ],
            usage=usage,
        )

    async def _collect_chat_completion_with_stats(

@@ -617,7 +756,7 @@ class API:
        stats: GenerationStats | None = None

        async for chunk in self._chat_chunk_stream(command_id):
            if chunk.finish_reason == "error":
            if isinstance(chunk, ErrorChunk):
                raise HTTPException(
                    status_code=500,
                    detail=chunk.error_message or "Internal server error",

@@ -628,6 +767,7 @@ class API:

            if isinstance(chunk, TokenChunk):
                text_parts.append(chunk.text)
                stats = chunk.stats or stats

            if isinstance(chunk, ToolCallChunk):
                tool_calls.extend(

@@ -638,13 +778,12 @@ class API:
                    )
                    for i, tool in enumerate(chunk.tool_calls)
                )

                stats = chunk.stats or stats
            stats = chunk.stats or stats

            if chunk.finish_reason is not None:
                finish_reason = chunk.finish_reason

        combined_text = "".join(text_parts)
        combined_text = _strip_think_tags("".join(text_parts))
        assert model is not None

        resp = BenchChatCompletionResponse(

@@ -695,7 +834,14 @@ class API:
            media_type="text/event-stream",
        )

        return await self._collect_chat_completion(command.command_id)
        try:
            return await self._collect_chat_completion(command.command_id)
        except BaseException:
            # Ensure task cleanup if handler is cancelled before _chat_chunk_stream's finally runs
            with contextlib.suppress(Exception):
                await self._send(TaskFinished(finished_command_id=command.command_id))
            self._chat_completion_queues.pop(command.command_id, None)
            raise

    async def bench_chat_completions(
        self, payload: BenchChatCompletionTaskParams
@@ -13,6 +13,7 @@ from exo.master.placement import (
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Completion,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
@@ -40,6 +41,9 @@ from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion as ChatCompletionTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
Completion as CompletionTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ImageEdits as ImageEditsTask,
|
||||
)
|
||||
@@ -158,6 +162,48 @@ class Master:
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case Completion():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
for task in self.state.tasks.values()
|
||||
if task.instance_id == instance.instance_id
|
||||
)
|
||||
instance_task_counts[instance.instance_id] = (
|
||||
task_count
|
||||
)
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
instance_task_counts.keys(),
|
||||
key=lambda instance_id: instance_task_counts[
|
||||
instance_id
|
||||
],
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=CompletionTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case ImageGeneration():
|
||||
for instance in self.state.instances.values():
|
||||
@@ -279,17 +325,15 @@ class Master:
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
task_id=self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
)
|
||||
task_id = self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
if command.finished_command_id in self.command_task_mapping:
|
||||
del self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
if task_id is not None:
|
||||
generated_events.append(TaskDeleted(task_id=task_id))
|
||||
else:
|
||||
logger.debug(
|
||||
f"TaskFinished for unknown command_id={command.finished_command_id} (already cleaned up)"
|
||||
)
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
for i in range(command.since_idx, len(self._event_log)):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
@@ -7,7 +7,14 @@ import tomlkit
|
||||
from anyio import Path, open_file
|
||||
from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, PositiveInt, field_validator
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
BaseModel,
|
||||
Field,
|
||||
PositiveInt,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
|
||||
from exo.shared.types.common import ModelId
|
||||
@@ -121,6 +128,22 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"kimi-k2.5": ModelCard(
|
||||
model_id=ModelId("mlx-community/Kimi-K2.5"),
|
||||
storage_size=Memory.from_gb(617),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"kimi-k2.5-4bit": ModelCard(
|
||||
model_id=ModelId("mlx-community/Kimi-K2.5-4bit"),
|
||||
storage_size=Memory.from_gb(606),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# llama-3.1
|
||||
"llama-3.1-8b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
@@ -703,15 +726,18 @@ if EXO_ENABLE_IMAGE_MODELS:
|
||||
class ConfigData(BaseModel):
|
||||
model_config = {"extra": "ignore"} # Allow unknown fields
|
||||
|
||||
# Common field names for number of layers across different architectures
|
||||
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
|
||||
num_layers: Annotated[int, Field(ge=0)] | None = None
|
||||
n_layer: Annotated[int, Field(ge=0)] | None = None
|
||||
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
|
||||
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
|
||||
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
architectures: list[str] | None = None
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
layer_count: int = Field(
|
||||
validation_alias=AliasChoices(
|
||||
"num_hidden_layers",
|
||||
"num_layers",
|
||||
"n_layer",
|
||||
"n_layers",
|
||||
"num_decoder_layers",
|
||||
"decoder_layers",
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_tensor(self) -> bool:
|
||||
@@ -726,25 +752,27 @@ class ConfigData(BaseModel):
|
||||
["GptOssForCausalLM"],
|
||||
]
|
||||
|
||||
@property
|
||||
def layer_count(self) -> int:
|
||||
# Check common field names for layer count
|
||||
layer_fields = [
|
||||
self.num_hidden_layers,
|
||||
self.num_layers,
|
||||
self.n_layer,
|
||||
self.n_layers,
|
||||
self.num_decoder_layers,
|
||||
self.decoder_layers,
|
||||
]
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def defer_to_text_config(cls, data: dict[str, Any]):
|
||||
text_config = data.get("text_config")
|
||||
if text_config is None:
|
||||
return data
|
||||
|
||||
for layer_count in layer_fields:
|
||||
if layer_count is not None:
|
||||
return layer_count
|
||||
for field in [
|
||||
"architectures",
|
||||
"hidden_size",
|
||||
"num_hidden_layers",
|
||||
"num_layers",
|
||||
"n_layer",
|
||||
"n_layers",
|
||||
"num_decoder_layers",
|
||||
"decoder_layers",
|
||||
]:
|
||||
if (val := text_config.get(field)) is not None: # pyright: ignore[reportAny]
|
||||
data[field] = val
|
||||
|
||||
raise ValueError(
|
||||
f"No layer count found in config.json: {self.model_dump_json()}"
|
||||
)
|
||||
return data
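# Illustrative sketch (not part of the diff) of how the AliasChoices-based
# layer_count field and the defer_to_text_config pre-validator above combine:
# a multimodal config that nests its text settings under "text_config" still
# yields a layer count. The dicts below are hypothetical.
cfg = ConfigData.model_validate(
    {
        "architectures": ["Qwen3ForCausalLM"],
        "text_config": {"num_hidden_layers": 36, "hidden_size": 4096},
    }
)
assert cfg.layer_count == 36  # copied up from text_config, matched via the alias list
assert cfg.hidden_size == 4096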
async def get_config_data(model_id: ModelId) -> ConfigData:
|
||||
|
||||
@@ -98,6 +98,8 @@ class LogprobsContentItem(BaseModel):
|
||||
|
||||
class Logprobs(BaseModel):
|
||||
content: list[LogprobsContentItem] | None = None
|
||||
# This will always be null for open source models, but exists for OpenAI API
|
||||
refusal: list[LogprobsContentItem] | None = None
|
||||
|
||||
|
||||
class PromptTokensDetails(BaseModel):
|
||||
@@ -150,6 +152,7 @@ class GenerationStats(BaseModel):
|
||||
generation_tps: float
|
||||
prompt_tokens: int
|
||||
generation_tokens: int
|
||||
reasoning_tokens: int = 0
|
||||
peak_memory_usage: Memory
|
||||
|
||||
|
||||
@@ -170,6 +173,52 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
|
||||
generation_stats: GenerationStats | None = None
|
||||
|
||||
|
||||
# Legacy Completions API types (for lm_eval compatibility)
|
||||
class CompletionLogprobs(BaseModel):
|
||||
"""Logprobs in the legacy completions format."""
|
||||
|
||||
tokens: list[str]
|
||||
token_logprobs: list[float | None]
|
||||
top_logprobs: list[dict[str, float]]
|
||||
text_offset: list[int]
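# Hypothetical example of the legacy payload shape this model describes
# (values are made up): one entry per returned token, with None for a position
# that has no score, such as the first echoed token.
CompletionLogprobs(
    tokens=["The", " capital"],
    token_logprobs=[None, -1.7],
    top_logprobs=[{}, {" capital": -1.7, " city": -2.3}],
    text_offset=[0, 3],
)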
class CompletionChoice(BaseModel):
|
||||
text: str
|
||||
index: int
|
||||
logprobs: CompletionLogprobs | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
id: str
|
||||
object: Literal["text_completion"] = "text_completion"
|
||||
created: int
|
||||
model: str
|
||||
choices: list[CompletionChoice]
|
||||
usage: Usage | None = None
|
||||
|
||||
|
||||
class CompletionTaskParams(BaseModel):
|
||||
"""Parameters for the legacy /v1/completions endpoint."""
|
||||
|
||||
model: str
|
||||
# Prompt can be: string, list of strings, list of token IDs, or list of token ID lists
|
||||
prompt: str | list[str] | list[int] | list[list[int]]
|
||||
max_tokens: int | None = 16
|
||||
temperature: float | None = 1.0
|
||||
top_p: float | None = 1.0
|
||||
n: int | None = 1
|
||||
stream: bool = False
|
||||
logprobs: int | None = None
|
||||
echo: bool = False
|
||||
stop: str | list[str] | None = None
|
||||
presence_penalty: float | None = None
|
||||
frequency_penalty: float | None = None
|
||||
seed: int | None = None
|
||||
user: str | None = None
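# The prompt union above accepts all four shapes (illustrative values; the
# model name is one of the cards defined earlier):
CompletionTaskParams(model="llama-3.1-8b", prompt="Hello")            # single string
CompletionTaskParams(model="llama-3.1-8b", prompt=["Hello", "Hi"])    # batch of strings
CompletionTaskParams(model="llama-3.1-8b", prompt=[1, 2, 3])          # pre-tokenized IDs
CompletionTaskParams(model="llama-3.1-8b", prompt=[[1, 2], [3, 4]])   # batch of ID lists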
class ChatCompletionTaskParams(BaseModel):
|
||||
model: str
|
||||
frequency_penalty: float | None = None
|
||||
@@ -190,10 +239,12 @@ class ChatCompletionTaskParams(BaseModel):
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
parallel_tool_calls: bool | None = None
|
||||
user: str | None = None
|
||||
# Internal flag for benchmark mode - set by API, preserved through serialization
|
||||
bench: bool = False
|
||||
|
||||
|
||||
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
|
||||
pass
|
||||
bench: bool = True
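# Net effect (sketch): callers can branch on the serialized flag itself, e.g.
# `task.bench`, instead of isinstance(task, BenchChatCompletionTaskParams),
# and the choice survives (de)serialization of the params.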
class PlaceInstanceParams(BaseModel):
|
||||
|
||||
@@ -2,7 +2,7 @@ from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats, TopLogprobItem
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
@@ -17,6 +17,8 @@ class BaseChunk(TaggedModel):
|
||||
class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
finish_reason: Literal["stop", "length", "content_filter"] | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
@@ -32,6 +34,17 @@ class ToolCallChunk(BaseChunk):
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class CompletionChunk(BaseChunk):
|
||||
"""Chunk for legacy completions API with full logprobs for all tokens."""
|
||||
|
||||
text: str
|
||||
tokens: list[str]
|
||||
token_logprobs: list[float | None]
|
||||
top_logprobs: list[dict[str, float]]
|
||||
text_offset: list[int]
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
data: str
|
||||
chunk_index: int
|
||||
@@ -67,4 +80,4 @@ class InputImageChunk(BaseChunk):
|
||||
yield name, value
|
||||
|
||||
|
||||
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
|
||||
GenerationChunk = TokenChunk | CompletionChunk | ImageChunk | ToolCallChunk | ErrorChunk
|
||||
|
||||
@@ -3,6 +3,7 @@ from pydantic import Field
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionTaskParams,
|
||||
CompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
@@ -25,6 +26,12 @@ class ChatCompletion(BaseCommand):
|
||||
request_params: ChatCompletionTaskParams
|
||||
|
||||
|
||||
class Completion(BaseCommand):
|
||||
"""Legacy completions API command for scoring/generation."""
|
||||
|
||||
request_params: CompletionTaskParams
|
||||
|
||||
|
||||
class ImageGeneration(BaseCommand):
|
||||
request_params: ImageGenerationTaskParams
|
||||
|
||||
@@ -79,6 +86,7 @@ Command = (
|
||||
TestCommand
|
||||
| RequestEventLog
|
||||
| ChatCompletion
|
||||
| Completion
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| PlaceInstance
|
||||
|
||||
@@ -4,6 +4,7 @@ from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionTaskParams,
|
||||
CompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
@@ -60,6 +61,16 @@ class ChatCompletion(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class Completion(BaseTask):
|
||||
"""Legacy completions task for scoring tokens with echo=True."""
|
||||
|
||||
command_id: CommandId
|
||||
task_params: CompletionTaskParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ImageGeneration(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageGenerationTaskParams
|
||||
@@ -87,6 +98,7 @@ Task = (
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| ChatCompletion
|
||||
| Completion
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| Shutdown
|
||||
|
||||
@@ -6,6 +6,7 @@ from exo.shared.types.api import (
|
||||
GenerationStats,
|
||||
ImageGenerationStats,
|
||||
ToolCallItem,
|
||||
TopLogprobItem,
|
||||
)
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
@@ -14,14 +15,11 @@ class BaseRunnerResponse(TaggedModel):
|
||||
pass
|
||||
|
||||
|
||||
class TokenizedResponse(BaseRunnerResponse):
|
||||
prompt_tokens: int
|
||||
|
||||
|
||||
class GenerationResponse(BaseRunnerResponse):
|
||||
text: str
|
||||
token: int
|
||||
# logprobs: list[float] | None = None # too big. we can change to be top-k
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
@@ -194,6 +194,22 @@ class MpReceiver[T]:
|
||||
raise EndOfStream from None
|
||||
return item
|
||||
|
||||
def receive_with_timeout(self, timeout: float) -> T | None:
|
||||
"""Receive with timeout, returns None if no message within timeout."""
|
||||
if self._state.closed.is_set():
|
||||
raise ClosedResourceError
|
||||
|
||||
try:
|
||||
item = self._state.buffer.get(block=True, timeout=timeout)
|
||||
if isinstance(item, _MpEndOfStream):
|
||||
self.close()
|
||||
raise EndOfStream
|
||||
return item
|
||||
except Empty:
|
||||
return None
|
||||
except ValueError as e:
|
||||
raise ClosedResourceError from e
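# Illustrative usage sketch: poll without blocking indefinitely. `receiver` is
# an assumed MpReceiver[str] wired up elsewhere.
msg = receiver.receive_with_timeout(timeout=0.5)
if msg is None:
    ...  # nothing arrived within 0.5 s; do other work and poll again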
# nb: this function will not cancel particularly well
|
||||
async def receive_async(self) -> T:
|
||||
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
|
||||
|
||||
@@ -13,6 +13,9 @@ from mlx.nn.layers.distributed import (
|
||||
shard_linear,
|
||||
sum_gradients,
|
||||
)
|
||||
from mlx_lm.models.base import (
|
||||
scaled_dot_product_attention, # pyright: ignore[reportUnknownVariableType]
|
||||
)
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
|
||||
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
@@ -23,13 +26,15 @@ from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
|
||||
from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
|
||||
from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
from mlx_lm.models.minimax import Model as MiniMaxModel
|
||||
from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
|
||||
|
||||
from exo.shared.logging import logger
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
@@ -102,6 +107,16 @@ class CustomMlxLayer(nn.Module):
|
||||
return getattr(original_layer, name)
|
||||
|
||||
|
||||
class EvalCheckpointLayer(CustomMlxLayer):
|
||||
"""Wraps a layer to force evaluation of its output, breaking up the computation graph
|
||||
to prevent Metal command buffer timeouts with large batches in pipeline parallel."""
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
output = self.original_layer(x, *args, **kwargs)
|
||||
mx.eval(output)
|
||||
return output
|
||||
|
||||
|
||||
class PipelineFirstLayer(CustomMlxLayer):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -139,11 +154,13 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
).arguments.get("cache", None)
|
||||
|
||||
output: mx.array = self.original_layer(x, *args, **kwargs)
|
||||
mx.eval(output)
|
||||
|
||||
if self.r != self.s - 1:
|
||||
output = mx.distributed.send(
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
mx.async_eval(output)
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
@@ -200,7 +217,13 @@ def pipeline_auto_parallel(
|
||||
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
|
||||
|
||||
layers = layers[start_layer:end_layer]
|
||||
for layer in layers:
|
||||
mx.eval(layer) # type: ignore
|
||||
|
||||
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
|
||||
# Wrap intermediate layers with eval checkpoints to prevent GPU timeout
|
||||
for i in range(1, len(layers) - 1):
|
||||
layers[i] = EvalCheckpointLayer(layers[i])
|
||||
layers[-1] = PipelineLastLayer(
|
||||
layers[-1],
|
||||
device_rank,
|
||||
@@ -254,6 +277,10 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
|
||||
"cache", None
|
||||
)
|
||||
|
||||
# Evaluate logits before all_gather to break the computation graph
|
||||
# and prevent Metal command buffer timeouts with large batches
|
||||
mx.eval(logits)
|
||||
|
||||
# Add dependency to last cache entry to ensure distributed ops are evaluated
|
||||
if cache is not None:
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
|
||||
@@ -344,7 +371,7 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model, KimiK25Model)):
|
||||
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -453,7 +480,7 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
|
||||
# Update DeepSeek V3 specific parameters when layers are shrunk
|
||||
if isinstance(
|
||||
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
|
||||
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel, KimiK25Model)
|
||||
) and hasattr(inner_model_instance, "num_layers"):
|
||||
logger.info(
|
||||
f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
|
||||
@@ -472,6 +499,66 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
|
||||
def _patch_deepseek_for_batching(model: nn.Module) -> None:
|
||||
"""Patch DeepseekV3Model to handle batched total_context in __call__.
|
||||
|
||||
The upstream mlx-lm DeepseekV3Model has a bug where total_context becomes
|
||||
an array (one value per batch item) when batching, but the comparison
|
||||
`total_context >= self._mla_crossover` expects a scalar.
|
||||
|
||||
This patch fixes it by temporarily replacing the cache offset with a scalar
|
||||
(max across batch) before calling the original __call__, then restoring it.
|
||||
"""
|
||||
# Get the inner model (DeepseekV3Model)
|
||||
inner_model: Any = getattr(model, "model", None)
|
||||
if inner_model is None:
|
||||
inner_model = getattr(model, "language_model", None)
|
||||
if inner_model is not None:
|
||||
inner_model = getattr(inner_model, "model", None) # pyright: ignore[reportAny]
|
||||
|
||||
if inner_model is None:
|
||||
return
|
||||
|
||||
# Get the inner model's class and patch __call__
|
||||
inner_cls: Any = inner_model.__class__ # pyright: ignore[reportAny]
|
||||
if hasattr(inner_cls, "_batching_patched"): # pyright: ignore[reportAny]
|
||||
return # Already patched
|
||||
|
||||
original_call: Any = inner_cls.__call__ # pyright: ignore[reportAny]
|
||||
|
||||
def patched_inner_call(
|
||||
self: Any, # pyright: ignore[reportAny]
|
||||
x: mx.array,
|
||||
cache: Any = None, # pyright: ignore[reportAny]
|
||||
) -> mx.array:
|
||||
# Fix the batching bug where cache[0].offset is an array but the
|
||||
# comparison `total_context >= self._mla_crossover` expects a scalar.
|
||||
# We temporarily replace the offset with a scalar (max across batch)
|
||||
# for the crossover check, then restore it after.
|
||||
if cache is not None and len(cache) > 0 and hasattr(self, "_mla_crossover"): # pyright: ignore[reportAny]
|
||||
first_cache = cache[0]
|
||||
original_offset: Any = first_cache.offset # pyright: ignore[reportAny]
|
||||
|
||||
# Check if offset is an array (batched) and needs fixing
|
||||
if hasattr(original_offset, "shape") and original_offset.shape: # pyright: ignore[reportAny]
|
||||
# Use max offset for the crossover decision (conservative choice)
|
||||
scalar_offset = int(mx.max(original_offset).item()) # pyright: ignore[reportAny]
|
||||
first_cache.offset = scalar_offset
|
||||
|
||||
try:
|
||||
result: Any = original_call(self, x, cache) # pyright: ignore[reportAny]
|
||||
finally:
|
||||
# Restore original array offset
|
||||
first_cache.offset = original_offset
|
||||
return result # pyright: ignore[reportAny]
|
||||
|
||||
return original_call(self, x, cache) # pyright: ignore[reportAny]
|
||||
|
||||
inner_cls.__call__ = patched_inner_call
|
||||
inner_cls._batching_patched = True
|
||||
logger.info("Patched DeepseekV3Model for batched inference")
|
||||
|
||||
|
||||
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
@@ -497,6 +584,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.kv_b_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
# Store pre-shard head count and group for context parallelism
|
||||
layer.self_attn.context_parallel_total_heads = layer.self_attn.num_heads
|
||||
layer.self_attn._cp_group = self.group
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Shard the MLP
|
||||
@@ -519,6 +609,10 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
mx.eval(layer)
|
||||
|
||||
# Store group for context parallelism
|
||||
if hasattr(model, "model"):
|
||||
model.model._cp_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -614,6 +708,80 @@ class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
|
||||
return y
|
||||
|
||||
|
||||
class WrappedMiniMaxAttention(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
|
||||
super().__init__(layer)
|
||||
self.group = group
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | Any = None,
|
||||
cache: Any | None = None,
|
||||
) -> mx.array:
|
||||
B, L, _ = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
if getattr(self, "use_qk_norm", False):
|
||||
q_dim = queries.shape[-1] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
k_dim = keys.shape[-1] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
N = self.group.size()
|
||||
|
||||
qk = mx.concatenate([queries, keys], axis=-1) # (B, L, q_dim + k_dim)
|
||||
qk = mx.distributed.all_gather(
|
||||
qk, group=self.group
|
||||
) # (N*B, L, q_dim + k_dim)
|
||||
|
||||
# Reshape to separate rank contributions: (N, B, L, q_dim + k_dim)
|
||||
# Then transpose to (B, L, N, q_dim + k_dim) and merge N into feature dim
|
||||
qk = qk.reshape(N, B, L, q_dim + k_dim).transpose(1, 2, 0, 3) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
queries = qk[..., :q_dim].reshape(
|
||||
B, L, -1
|
||||
) # (B, L, N * q_dim) # pyright: ignore[reportUnknownMemberType]
|
||||
keys = qk[..., q_dim:].reshape(
|
||||
B, L, -1
|
||||
) # (B, L, N * k_dim) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
queries = self.q_norm(queries) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
keys = self.k_norm(keys) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
# Split back and take this rank's portion
|
||||
queries = mx.split(queries, N, axis=-1)[self.group.rank()]
|
||||
keys = mx.split(keys, N, axis=-1)[self.group.rank()]
|
||||
|
||||
queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportUnknownArgumentType]
|
||||
0, 2, 1, 3
|
||||
)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportUnknownArgumentType]
|
||||
0, 2, 1, 3
|
||||
)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportAny]
|
||||
keys = self.rope(keys, offset=cache.offset) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportAny]
|
||||
keys, values = cache.update_and_fetch(keys, values) # pyright: ignore[reportAny]
|
||||
else:
|
||||
queries = self.rope(queries) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
keys = self.rope(keys) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=self.scale,
|
||||
mask=mask, # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
return self.o_proj(output) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
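# Shape walk-through of the qk_norm path above (illustrative, N = 2 ranks,
# B = 1, L = 4, per-rank q_dim = k_dim = 64): all_gather turns the (1, 4, 128)
# concatenated qk into (2, 4, 128); the reshape/transpose gives (1, 4, 2, 128);
# slicing and flattening hands q_norm the full unsharded (1, 4, 128) query
# features; mx.split(..., 2)[rank] then returns this rank's (1, 4, 64) slice.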
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
@@ -622,7 +790,6 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(MiniMaxModel, model)
|
||||
rank = self.group.rank()
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
@@ -633,18 +800,11 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
# Shard qk_norm weights if present (must match sharded head count)
|
||||
if getattr(layer.self_attn, "use_qk_norm", False):
|
||||
layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split( # type: ignore
|
||||
self.N, axis=-1
|
||||
)[rank]
|
||||
layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split( # type: ignore
|
||||
self.N, axis=-1
|
||||
)[rank]
|
||||
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
self.all_to_sharded_linear_in_place(
|
||||
@@ -669,18 +829,32 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(Qwen3MoeModel, model)
|
||||
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
# Shard the self attention
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
if isinstance(layer, Qwen3DecoderLayer) or hasattr(layer, "self_attn"):
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.k_proj
|
||||
)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.v_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(
|
||||
layer.self_attn.o_proj
|
||||
)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
else:
|
||||
assert isinstance(layer, Qwen3NextDecoderLayer) and hasattr(
|
||||
layer, "linear_attn"
|
||||
)
|
||||
# These layers are fast so we don't shard. This may change in the future.
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
|
||||
@@ -3,15 +3,15 @@ from typing import Any, Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.models.cache import trim_prompt_cache
|
||||
from mlx_lm.models.cache import KVCache, trim_prompt_cache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionMessage,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
TopLogprobItem,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
@@ -158,6 +158,206 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
|
||||
return eos
|
||||
|
||||
|
||||
def extract_top_logprobs(
|
||||
logprobs_array: mx.array,
|
||||
selected_token: int,
|
||||
tokenizer: TokenizerWrapper,
|
||||
top_k: int | None,
|
||||
) -> tuple[float, list[TopLogprobItem]]:
|
||||
"""Extract the selected token's logprob and top-k alternatives.
top_k can be set to None to return all logprobs.
"""
|
||||
selected_logprob = float(logprobs_array[selected_token].item())
|
||||
|
||||
if top_k == 0:
|
||||
return selected_logprob, []
|
||||
|
||||
vocab_size = logprobs_array.shape[0]
|
||||
|
||||
if top_k is None:
|
||||
sorted_indices = mx.argsort(-logprobs_array)
|
||||
mx.eval(sorted_indices)
|
||||
indices_list: list[int] = cast(list[int], sorted_indices.tolist())
|
||||
else:
|
||||
k = min(top_k, vocab_size)
|
||||
top_indices = mx.argpartition(-logprobs_array, kth=k - 1)[:k]
|
||||
top_logprobs_values = logprobs_array[top_indices]
|
||||
sorted_order = mx.argsort(-top_logprobs_values)
|
||||
top_indices = top_indices[sorted_order]
|
||||
mx.eval(top_indices)
|
||||
indices_list = cast(list[int], top_indices.tolist())
|
||||
|
||||
top_logprob_items: list[TopLogprobItem] = []
|
||||
for token_id in indices_list:
|
||||
logprob_value = float(logprobs_array[token_id].item())
|
||||
token_str = tokenizer.decode([token_id])
|
||||
|
||||
top_logprob_items.append(
|
||||
TopLogprobItem(
|
||||
token=token_str,
|
||||
logprob=logprob_value,
|
||||
bytes=list(token_str.encode("utf-8")),
|
||||
)
|
||||
)
|
||||
|
||||
return selected_logprob, top_logprob_items
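# Minimal usage sketch (illustrative only; `tokenizer` is assumed to be a
# loaded TokenizerWrapper). Score one sampled token against a toy 4-token
# distribution; the numbers are made up.
toy_logits = mx.array([2.0, 1.0, 0.5, -1.0])
toy_logprobs = toy_logits - mx.logsumexp(toy_logits)
selected_lp, alternatives = extract_top_logprobs(
    logprobs_array=toy_logprobs, selected_token=0, tokenizer=tokenizer, top_k=2
)
# selected_lp is the logprob of token id 0; alternatives holds the two most
# likely tokens as TopLogprobItem entries, sorted by logprob.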
def score_tokens(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
tokens: list[int],
|
||||
top_k: int | None = None,
|
||||
) -> list[tuple[float, list[TopLogprobItem]]]:
|
||||
"""Score a sequence of tokens, returning logprobs for each token.
|
||||
|
||||
This is used for the completions API with echo=True, where we need
|
||||
logprobs for the prompt tokens (not just generated tokens).
|
||||
|
||||
Args:
|
||||
model: The MLX model.
|
||||
tokenizer: The tokenizer.
|
||||
tokens: List of token IDs to score.
|
||||
top_k: Number of top logprobs to return per position.
|
||||
If None, returns all logprobs.
|
||||
|
||||
Returns:
|
||||
List of (token_logprob, top_logprobs) tuples for each token position.
|
||||
The first position has no logprob (no previous context), so returns (0.0, []).
|
||||
"""
|
||||
if len(tokens) == 0:
|
||||
return []
|
||||
|
||||
# First token has no previous context to condition on
|
||||
results: list[tuple[float, list[TopLogprobItem]]] = [(0.0, [])]
|
||||
|
||||
if len(tokens) == 1:
|
||||
return results
|
||||
|
||||
# Create an empty KV cache for the forward pass
|
||||
cache = make_kv_cache(model=model)
|
||||
|
||||
# Convert to MLX array and run forward pass
|
||||
input_tokens = mx.array(tokens[:-1])[None] # All tokens except last, batched
|
||||
|
||||
# Run the model to get logits for all positions
|
||||
# The model returns logits with shape [1, seq_len, vocab_size]
|
||||
logits: mx.array = model(input_tokens, cache=cast(list[KVCache], cache))
|
||||
logits = logits.squeeze(0) # Shape: [seq_len, vocab_size]
|
||||
|
||||
# Convert to log probabilities
|
||||
logprobs_all: mx.array = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||
|
||||
mx.eval(logprobs_all)
|
||||
|
||||
# For each position, extract the logprob of the actual next token
|
||||
for i in range(len(tokens) - 1):
|
||||
next_token = tokens[i + 1]
|
||||
logprobs_at_position: mx.array = logprobs_all[i]
|
||||
|
||||
logprob, top_logprobs_items = extract_top_logprobs(
|
||||
logprobs_array=logprobs_at_position,
|
||||
selected_token=next_token,
|
||||
tokenizer=tokenizer,
|
||||
top_k=top_k,
|
||||
)
|
||||
results.append((logprob, top_logprobs_items))
|
||||
|
||||
return results
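# Illustrative sketch of the echo=True use case this function exists for
# (`model` and `tokenizer` are assumed to be loaded):
ids = tokenizer.encode("The capital of France is Paris")
scored = score_tokens(model=model, tokenizer=tokenizer, tokens=ids, top_k=1)
for tok_id, (lp, top) in zip(ids, scored):
    print(tokenizer.decode([tok_id]), lp, [t.token for t in top])
# The first entry is (0.0, []) because that token has no preceding context.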
def score_tokens_batched(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
token_sequences: list[list[int]],
|
||||
top_k: int | None = None,
|
||||
) -> list[list[tuple[float, list[TopLogprobItem]]]]:
|
||||
"""Score multiple token sequences in a single batched forward pass.
|
||||
|
||||
This is significantly faster than calling score_tokens() multiple times
|
||||
because it batches the forward pass across all sequences.
|
||||
|
||||
Args:
|
||||
model: The MLX model.
|
||||
tokenizer: The tokenizer.
|
||||
token_sequences: List of token ID sequences to score.
|
||||
top_k: Number of top logprobs to return per position.
|
||||
|
||||
Returns:
|
||||
List of results for each sequence. Each result is a list of
|
||||
(token_logprob, top_logprobs) tuples for each token position.
|
||||
"""
|
||||
if not token_sequences:
|
||||
return []
|
||||
|
||||
# Handle empty sequences and single-token sequences
|
||||
results: list[list[tuple[float, list[TopLogprobItem]]]] = []
|
||||
non_empty_indices: list[int] = []
|
||||
non_empty_sequences: list[list[int]] = []
|
||||
|
||||
for i, tokens in enumerate(token_sequences):
|
||||
if len(tokens) == 0:
|
||||
results.append([])
|
||||
elif len(tokens) == 1:
|
||||
results.append([(0.0, [])])
|
||||
else:
|
||||
results.append([]) # Placeholder, will be filled later
|
||||
non_empty_indices.append(i)
|
||||
non_empty_sequences.append(tokens)
|
||||
|
||||
if not non_empty_sequences:
|
||||
return results
|
||||
|
||||
# Find max sequence length (excluding last token since we predict it)
|
||||
max_len = max(len(seq) - 1 for seq in non_empty_sequences)
|
||||
|
||||
# Get pad token (use eos_token_id or 0)
|
||||
pad_token_id = getattr(tokenizer, "pad_token_id", None)
|
||||
if pad_token_id is None:
|
||||
pad_token_id = getattr(tokenizer, "eos_token_id", 0)
|
||||
|
||||
# Pad sequences and create attention mask
|
||||
batch_size = len(non_empty_sequences)
|
||||
padded_inputs = mx.full((batch_size, max_len), pad_token_id, dtype=mx.int32)
|
||||
seq_lengths: list[int] = []
|
||||
|
||||
for i, tokens in enumerate(non_empty_sequences):
|
||||
input_len = len(tokens) - 1 # Exclude last token
|
||||
padded_inputs[i, :input_len] = mx.array(tokens[:-1], dtype=mx.int32)
|
||||
seq_lengths.append(input_len)
|
||||
|
||||
# Run batched forward pass (no KV cache for scoring)
|
||||
# The model accepts [batch_size, seq_len] and returns [batch_size, seq_len, vocab_size]
|
||||
logits = model(padded_inputs, cache=None)
|
||||
|
||||
# Convert to log probabilities - logits shape: [batch, seq_len, vocab]
|
||||
logprobs_all = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||
mx.eval(logprobs_all)
|
||||
|
||||
# Extract results for each sequence
|
||||
for batch_idx, (orig_idx, tokens, seq_len) in enumerate(
|
||||
zip(non_empty_indices, non_empty_sequences, seq_lengths, strict=True)
|
||||
):
|
||||
seq_results: list[tuple[float, list[TopLogprobItem]]] = [(0.0, [])]
|
||||
|
||||
for pos in range(seq_len):
|
||||
next_token = tokens[pos + 1]
|
||||
logprobs_at_position: mx.array = logprobs_all[batch_idx, pos]
|
||||
|
||||
logprob, top_logprobs_items = extract_top_logprobs(
|
||||
logprobs_array=logprobs_at_position,
|
||||
selected_token=next_token,
|
||||
tokenizer=tokenizer,
|
||||
top_k=top_k,
|
||||
)
|
||||
seq_results.append((logprob, top_logprobs_items))
|
||||
|
||||
results[orig_idx] = seq_results
|
||||
|
||||
return results
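# Illustrative sketch: scoring several candidate continuations in one forward
# pass, the lm_eval-style loglikelihood pattern (`model`/`tokenizer` assumed):
candidates = ["2 + 2 = 4", "2 + 2 = 5"]
batch = [tokenizer.encode(text) for text in candidates]
per_sequence = score_tokens_batched(model, tokenizer, batch, top_k=0)
# Summing the per-token logprobs gives a relative score for each candidate.
totals = [sum(lp for lp, _ in seq) for seq in per_sequence]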
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
@@ -167,7 +367,7 @@ def mlx_generate(
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
|
||||
is_bench: bool = task.bench
|
||||
|
||||
# Currently we support chat-completion tasks only.
|
||||
logger.debug(f"task_params: {task}")
|
||||
@@ -209,9 +409,14 @@ def mlx_generate(
|
||||
# stream_generate starts from the last token
|
||||
last_token = prompt_tokens[-1:]
|
||||
|
||||
# Determine if we need logprobs
|
||||
should_extract_logprobs = task.logprobs is True
|
||||
top_k = task.top_logprobs if task.top_logprobs is not None else 0
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
generated_text_parts: list[str] = []
|
||||
generation_start_time = time.perf_counter()
|
||||
total_prompt_tokens = len(prompt_tokens) + prefix_hit_length
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
@@ -233,7 +438,7 @@ def mlx_generate(
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(prefill_tps or out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
@@ -245,9 +450,22 @@ def mlx_generate(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
# Extract logprobs if requested
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
if should_extract_logprobs:
|
||||
logprob, top_logprobs = extract_top_logprobs(
|
||||
logprobs_array=out.logprobs,
|
||||
selected_token=out.token,
|
||||
tokenizer=tokenizer,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
token=out.token,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
@@ -165,12 +165,11 @@ def mlx_distributed_init(
|
||||
|
||||
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
|
||||
|
||||
# TODO: update once upstream fixes
|
||||
logger.info(
|
||||
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
)
|
||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_JACCL_DEVICES"] = coordination_file
|
||||
os.environ["MLX_IBV_DEVICES"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
@@ -259,10 +258,10 @@ def shard_and_load(
|
||||
|
||||
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
|
||||
|
||||
# Estimate timeout based on model size
|
||||
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
|
||||
# Estimate timeout based on model size (5x default for large queued workloads)
|
||||
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
|
||||
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
|
||||
timeout_seconds = base_timeout + model_size_gb / 5
|
||||
timeout_seconds = base_timeout + model_size_gb
|
||||
logger.info(
|
||||
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
|
||||
f"(model size: {model_size_gb:.1f}GB)"
@@ -339,8 +338,35 @@ def load_tokenizer_for_model_id(
|
||||
|
||||
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
|
||||
if "kimi-k2" in model_id_lower:
|
||||
import importlib.util
|
||||
import types
|
||||
|
||||
sys.path.insert(0, str(model_path))
|
||||
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
|
||||
|
||||
# Load tool_declaration_ts first (tokenization_kimi imports it with relative import)
|
||||
tool_decl_path = model_path / "tool_declaration_ts.py"
|
||||
if tool_decl_path.exists():
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"tool_declaration_ts", tool_decl_path
|
||||
)
|
||||
if spec and spec.loader:
|
||||
tool_decl_module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["tool_declaration_ts"] = tool_decl_module
|
||||
spec.loader.exec_module(tool_decl_module)
|
||||
|
||||
# Load tokenization_kimi with patched source (convert relative to absolute import)
|
||||
tok_path = model_path / "tokenization_kimi.py"
|
||||
source = tok_path.read_text()
|
||||
source = source.replace("from .tool_declaration_ts", "from tool_declaration_ts")
|
||||
spec = importlib.util.spec_from_file_location("tokenization_kimi", tok_path)
|
||||
if spec:
|
||||
tok_module = types.ModuleType("tokenization_kimi")
|
||||
tok_module.__file__ = str(tok_path)
|
||||
sys.modules["tokenization_kimi"] = tok_module
|
||||
exec(compile(source, tok_path, "exec"), tok_module.__dict__) # noqa: S102
|
||||
TikTokenTokenizer = tok_module.TikTokenTokenizer # type: ignore[attr-defined] # noqa: N806
|
||||
else:
|
||||
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
|
||||
|
||||
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
@@ -184,8 +185,10 @@ class Worker:
|
||||
self.input_chunk_counts,
|
||||
)
|
||||
if task is None:
|
||||
# Only sleep when there's nothing to do - allows rapid task dispatch
|
||||
await anyio.sleep(0.01)
|
||||
continue
|
||||
logger.info(f"Worker plan: {task.__class__.__name__}")
|
||||
logger.debug(f"Worker plan: {task.__class__.__name__}")
|
||||
assert task.task_status
|
||||
await self.event_sender.send(TaskCreated(task_id=task.task_id, task=task))
|
||||
|
||||
@@ -269,6 +272,13 @@ class Worker:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
case ChatCompletion():
|
||||
# Don't wait for acknowledgment for batchable inference tasks
|
||||
# This allows multiple tasks to reach the runner for batching
|
||||
# For tensor parallel: all nodes send tasks to their runner
|
||||
# so non-coordinator can participate in collective ops
|
||||
runner_id = self._task_to_runner_id(task)
|
||||
await self.runners[runner_id].start_task(task, wait_for_ack=False)
|
||||
case task:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Mapping, Sequence
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
Completion,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
@@ -254,8 +255,12 @@ def _ready_to_warmup(
|
||||
)
|
||||
|
||||
# Rank = 0
|
||||
# For tensor parallel, warmup is skipped so other ranks go directly
|
||||
# to RunnerReady. We need to accept both WarmingUp and Ready states.
|
||||
connecting_rank_ready = device_rank == 0 and all(
|
||||
isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp)
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id, None), (RunnerWarmingUp, RunnerReady)
|
||||
)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
if global_runner_id != runner_id
|
||||
)
|
||||
@@ -273,9 +278,11 @@ def _pending_tasks(
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions
|
||||
# for now, just forward chat completions and completions
|
||||
# TODO(ciaran): do this better!
|
||||
if not isinstance(task, (ChatCompletion, ImageGeneration, ImageEdits)):
|
||||
if not isinstance(
|
||||
task, (ChatCompletion, Completion, ImageGeneration, ImageEdits)
|
||||
):
|
||||
continue
|
||||
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
|
||||
continue
|
||||
@@ -298,9 +305,14 @@ def _pending_tasks(
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# Skip tasks already sent to runner (waiting for completion)
|
||||
if task.task_id in runner.sent:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributed's expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
# Allow sending tasks when runner is Ready OR Running (for batching)
|
||||
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
):
|
||||
|
||||
src/exo/worker/runner/batched_handler.py (new file, 809 additions)
@@ -0,0 +1,809 @@
|
||||
"""Batched inference handler for processing multiple ChatCompletion requests concurrently."""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import BatchGenerator
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import (
|
||||
GenerationStats,
|
||||
TopLogprobItem,
|
||||
)
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import ChunkGenerated, Event
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletion
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import MAX_TOKENS
|
||||
from exo.worker.engines.mlx.generator.generate import extract_top_logprobs
|
||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
from exo.worker.runner.pipelined_generator import PipelinedGenerator, PipelinedResponse
|
||||
|
||||
# Type alias for the finish_reason values TokenChunk accepts
|
||||
TokenFinishReason = Literal["stop", "length", "content_filter"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingRequest:
|
||||
"""A request waiting to be added to the batch."""
|
||||
|
||||
task: ChatCompletion
|
||||
prompt: str
|
||||
max_tokens: int
|
||||
sampler: Callable[[mx.array], mx.array]
|
||||
should_extract_logprobs: bool
|
||||
top_k: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveRequest:
|
||||
"""A request currently being processed in the batch."""
|
||||
|
||||
command_id: CommandId
|
||||
should_extract_logprobs: bool
|
||||
top_k: int
|
||||
harmony_parser: Any | None = None # StreamableParser for GPT-OSS models
|
||||
in_thinking: bool = False # Currently in thinking/reasoning section
|
||||
tokens_generated: int = 0
|
||||
reasoning_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
start_time: float = field(default_factory=time.perf_counter)
|
||||
|
||||
|
||||
class BatchedInferenceHandler:
|
||||
"""
|
||||
Handles batched inference for multiple ChatCompletion requests.
|
||||
|
||||
Uses MLX-LM's BatchGenerator to process multiple requests concurrently,
|
||||
improving throughput when several requests arrive at the same time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
model_id: ModelId,
|
||||
device_rank: int,
|
||||
world_size: int = 1,
|
||||
max_batch_size: int = 32,
|
||||
tensor_parallel_group: mx.distributed.Group | None = None,
|
||||
is_coordinator: bool = True,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.model_id = model_id
|
||||
self.device_rank = device_rank
|
||||
self.world_size = world_size
|
||||
self.max_batch_size = max_batch_size
|
||||
self.tensor_parallel_group = tensor_parallel_group
|
||||
self.is_coordinator = is_coordinator
|
||||
|
||||
# Model-specific thinking/reasoning detection
|
||||
self.is_gpt_oss = isinstance(model, GptOssModel)
|
||||
self._harmony_encoding: Any | None = None
|
||||
if self.is_gpt_oss:
|
||||
self._harmony_encoding = load_harmony_encoding(
|
||||
HarmonyEncodingName.HARMONY_GPT_OSS
|
||||
)
|
||||
logger.info("GPT-OSS model detected, enabling harmony stream parsing")
|
||||
|
||||
# Detect <think></think> tokens from tokenizer (works for any model)
|
||||
self._think_start_token: int | None = None
|
||||
self._think_end_token: int | None = None
|
||||
think_start: int | None = tokenizer.think_start_id # pyright: ignore[reportAny]
|
||||
if not self.is_gpt_oss and think_start is not None:
|
||||
self._think_start_token = think_start
|
||||
self._think_end_token = tokenizer.think_end_id # pyright: ignore[reportAny]
|
||||
logger.info(
|
||||
f"Detected <think></think> tokens ({self._think_start_token}/{self._think_end_token}), enabling reasoning tracking"
|
||||
)
|
||||
|
||||
# Pending requests waiting to be batched
|
||||
self.pending: list[PendingRequest] = []
|
||||
|
||||
# Track active count for non-coordinators (they don't have uid_to_request)
|
||||
self._non_coordinator_active_count: int = 0
|
||||
|
||||
# Active batch generator and request tracking
|
||||
self.batch_generator: BatchGenerator | None = None
|
||||
self.pipelined_generator: PipelinedGenerator | None = None
|
||||
self.uid_to_request: dict[int, ActiveRequest] = {}
|
||||
|
||||
# Use pipelined generator for multi-device pipeline parallelism
|
||||
self.use_pipelined = world_size > 1
|
||||
if self.use_pipelined:
|
||||
logger.info(
|
||||
f"Using PipelinedGenerator with {world_size} streams for pipeline overlap"
|
||||
)
|
||||
|
||||
# EOS tokens for the model
|
||||
self.stop_tokens: set[int] = set()
|
||||
eos_ids: list[int] | None = getattr(tokenizer, "eos_token_ids", None)
|
||||
if eos_ids:
|
||||
self.stop_tokens = set(eos_ids)
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
"""Check if there's an active batch being processed."""
|
||||
if self.use_pipelined:
|
||||
return (
|
||||
self.pipelined_generator is not None
|
||||
and self.pipelined_generator.has_active
|
||||
)
|
||||
if self.batch_generator is None:
|
||||
return False
|
||||
# For non-coordinators, use internal counter (they don't track uid_to_request)
|
||||
if not self.is_coordinator:
|
||||
return self._non_coordinator_active_count > 0
|
||||
return len(self.uid_to_request) > 0
|
||||
|
||||
@property
|
||||
def has_pending(self) -> bool:
|
||||
"""Check if there are pending requests waiting to be batched."""
|
||||
return len(self.pending) > 0
|
||||
|
||||
@property
|
||||
def current_batch_size(self) -> int:
|
||||
"""Current number of active requests in the batch."""
|
||||
return len(self.uid_to_request)
|
||||
|
||||
def add_request(self, task: ChatCompletion) -> None:
|
||||
"""Add a ChatCompletion request to the pending batch."""
|
||||
task_params = task.task_params
|
||||
|
||||
# Build prompt
|
||||
prompt = apply_chat_template(self.tokenizer, task_params)
|
||||
|
||||
# Determine max tokens
|
||||
max_tokens = task_params.max_tokens or MAX_TOKENS
|
||||
|
||||
# Create sampler for this request
|
||||
sampler = make_sampler(
|
||||
temp=task_params.temperature
|
||||
if task_params.temperature is not None
|
||||
else 0.7,
|
||||
top_p=task_params.top_p if task_params.top_p is not None else 1.0,
|
||||
)
|
||||
|
||||
# Logprobs configuration
|
||||
should_extract_logprobs = task_params.logprobs is True
|
||||
top_k = task_params.top_logprobs if task_params.top_logprobs is not None else 0
|
||||
|
||||
pending_request = PendingRequest(
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
should_extract_logprobs=should_extract_logprobs,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
self.pending.append(pending_request)
|
||||
|
||||
logger.info(
|
||||
f"Added request to batch queue (pending={len(self.pending)}, active={self.current_batch_size})"
|
||||
)
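# Illustrative usage sketch (only add_request/flush/has_pending/is_active are
# real attributes here; the surrounding loop and `incoming_tasks` are assumed):
# handler = BatchedInferenceHandler(model, tokenizer, model_id, device_rank=0)
# for chat_task in incoming_tasks:
#     handler.add_request(chat_task)
# if handler.has_pending and not handler.is_active:
#     handler.flush()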
def _broadcast_int(self, value: int) -> int:
|
||||
"""Broadcast an integer from rank 0 to all ranks."""
|
||||
if self.tensor_parallel_group is None:
|
||||
return value
|
||||
arr = mx.array([value if self.is_coordinator else 0], dtype=mx.int32)
|
||||
synced = mx.distributed.all_sum(arr, group=self.tensor_parallel_group)
|
||||
mx.eval(synced)
|
||||
return int(synced.item())
|
||||
|
||||
def _broadcast_tokens(self, tokens_list: list[list[int]]) -> list[list[int]]:
|
||||
"""Broadcast tokenized prompts from rank 0 to all ranks."""
|
||||
if self.tensor_parallel_group is None:
|
||||
return tokens_list
|
||||
|
||||
# Step 1: Broadcast number of sequences
|
||||
num_seqs = self._broadcast_int(len(tokens_list))
|
||||
if num_seqs == 0:
|
||||
return []
|
||||
|
||||
# Step 2: Broadcast length of each sequence
|
||||
lengths: list[int] = []
|
||||
for i in range(num_seqs):
|
||||
length = self._broadcast_int(
|
||||
len(tokens_list[i])
|
||||
if self.is_coordinator and i < len(tokens_list)
|
||||
else 0
|
||||
)
|
||||
lengths.append(length)
|
||||
|
||||
# Step 3: Broadcast flattened tokens
|
||||
total_tokens = sum(lengths)
|
||||
if self.is_coordinator:
|
||||
flat: list[int] = []
|
||||
for seq in tokens_list:
|
||||
flat.extend(seq)
|
||||
flat_arr = mx.array(flat, dtype=mx.int32)
|
||||
else:
|
||||
flat_arr = mx.zeros((total_tokens,), dtype=mx.int32)
|
||||
|
||||
# Broadcast via all_sum (rank 0 contributes, others contribute zeros)
|
||||
synced_flat = mx.distributed.all_sum(flat_arr, group=self.tensor_parallel_group)
|
||||
mx.eval(synced_flat)
|
||||
|
||||
# Unflatten
|
||||
result: list[list[int]] = []
|
||||
offset = 0
|
||||
for length in lengths:
|
||||
seq_arr = synced_flat[offset : offset + length]
|
||||
seq: list[int] = [int(x) for x in seq_arr.tolist()] # type: ignore[union-attr]
|
||||
result.append(seq)
|
||||
offset += length
|
||||
|
||||
return result
|
||||
|
||||
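# Editor's note: a minimal standalone sketch (not part of this diff, and written
# outside the class) of the broadcast trick used by _broadcast_int /
# _broadcast_tokens above: rank 0 contributes the real values, every other rank
# contributes zeros, so the element-wise all_sum equals rank 0's data on every
# rank. Assumes all ranks already agree on the length of `values`.
def broadcast_ints_from_rank0(values: list[int], group: "mx.distributed.Group") -> list[int]:
    arr = mx.array(values if group.rank() == 0 else [0] * len(values), dtype=mx.int32)
    synced = mx.distributed.all_sum(arr, group=group)
    mx.eval(synced)
    return [int(x) for x in synced.tolist()]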
def flush(self) -> None:
|
||||
"""Start processing pending requests by adding them to the batch/pipelined generator."""
|
||||
# Declare variables with types
|
||||
tokenized_prompts: list[list[int]]
|
||||
max_tokens_list: list[int]
|
||||
samplers: list[Callable[[mx.array], mx.array]]
|
||||
prompt_token_counts: list[int]
|
||||
requests_to_flush: list[PendingRequest]
|
||||
|
||||
# For tensor parallel: rank 0 broadcasts batch info, others receive and sync
|
||||
if self.tensor_parallel_group is not None:
|
||||
# Broadcast how many to flush
|
||||
available_slots = self.max_batch_size - self.current_batch_size
|
||||
num_to_flush = self._broadcast_int(
|
||||
min(len(self.pending), available_slots) if self.is_coordinator else 0
|
||||
)
|
||||
|
||||
if num_to_flush == 0:
|
||||
return
|
||||
|
||||
# Get requests and tokenize on rank 0
|
||||
if self.is_coordinator:
|
||||
requests_to_flush = self.pending[:num_to_flush]
|
||||
self.pending = self.pending[num_to_flush:]
|
||||
tokenized_prompts = [
|
||||
self.tokenizer.encode(req.prompt) for req in requests_to_flush
|
||||
]
|
||||
max_tokens_list = [req.max_tokens for req in requests_to_flush]
|
||||
else:
|
||||
requests_to_flush = []
|
||||
tokenized_prompts = []
|
||||
max_tokens_list = []
|
||||
|
||||
# Broadcast tokenized prompts to all ranks
|
||||
tokenized_prompts = self._broadcast_tokens(tokenized_prompts)
|
||||
|
||||
# Broadcast max_tokens
|
||||
synced_max_tokens: list[int] = []
|
||||
for i in range(num_to_flush):
|
||||
mt = self._broadcast_int(
|
||||
max_tokens_list[i]
|
||||
if self.is_coordinator and i < len(max_tokens_list)
|
||||
else 0
|
||||
)
|
||||
synced_max_tokens.append(mt)
|
||||
max_tokens_list = synced_max_tokens
|
||||
|
||||
# Create samplers (same on all ranks since we use temp=0 typically)
|
||||
samplers = [make_sampler(temp=0.0) for _ in range(num_to_flush)]
|
||||
prompt_token_counts = [len(t) for t in tokenized_prompts]
|
||||
|
||||
else:
|
||||
if not self.has_pending:
|
||||
return
|
||||
available_slots = self.max_batch_size - self.current_batch_size
|
||||
requests_to_flush = self.pending[:available_slots]
|
||||
self.pending = self.pending[available_slots:]
|
||||
|
||||
# Prepare batch data - tokenize prompts
|
||||
tokenized_prompts = []
|
||||
max_tokens_list = []
|
||||
samplers = []
|
||||
prompt_token_counts = []
|
||||
|
||||
for req in requests_to_flush:
|
||||
tokens = self.tokenizer.encode(req.prompt)
|
||||
tokenized_prompts.append(tokens)
|
||||
max_tokens_list.append(req.max_tokens)
|
||||
samplers.append(req.sampler)
|
||||
prompt_token_counts.append(len(tokens))
|
||||
|
||||
if self.use_pipelined:
|
||||
self._flush_pipelined(
|
||||
requests_to_flush,
|
||||
tokenized_prompts,
|
||||
max_tokens_list,
|
||||
samplers,
|
||||
prompt_token_counts,
|
||||
)
|
||||
else:
|
||||
self._flush_batch(
|
||||
requests_to_flush,
|
||||
tokenized_prompts,
|
||||
max_tokens_list,
|
||||
samplers,
|
||||
prompt_token_counts,
|
||||
)
|
||||
|
||||
def _flush_pipelined(
|
||||
self,
|
||||
requests_to_flush: list[PendingRequest],
|
||||
tokenized_prompts: list[list[int]],
|
||||
max_tokens_list: list[int],
|
||||
samplers: list[Callable[[mx.array], mx.array]],
|
||||
prompt_token_counts: list[int],
|
||||
) -> None:
|
||||
"""Flush using PipelinedGenerator (multi-stream pipeline overlap)."""
|
||||
if self.pipelined_generator is None:
|
||||
logger.info(
|
||||
f"Creating PipelinedGenerator for {len(requests_to_flush)} requests ({self.world_size} streams)"
|
||||
)
|
||||
mx.reset_peak_memory()
|
||||
self.pipelined_generator = PipelinedGenerator(
|
||||
model=self.model,
|
||||
world_size=self.world_size,
|
||||
stop_tokens=self.stop_tokens if self.stop_tokens else None,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Adding {len(requests_to_flush)} requests to PipelinedGenerator"
|
||||
)
|
||||
|
||||
uids = self.pipelined_generator.insert(
|
||||
prompts=tokenized_prompts,
|
||||
max_tokens=max_tokens_list,
|
||||
samplers=samplers,
|
||||
)
|
||||
|
||||
for uid, req, prompt_tokens, tokens in zip(
|
||||
uids, requests_to_flush, prompt_token_counts, tokenized_prompts, strict=True
|
||||
):
|
||||
parser = None
|
||||
if self.is_gpt_oss and self._harmony_encoding is not None:
|
||||
parser = StreamableParser(self._harmony_encoding, role=Role.ASSISTANT) # pyright: ignore[reportAny]
|
||||
# Check if prompt contains <think> token - if so, model is already in thinking mode
|
||||
starts_in_thinking = (
|
||||
self._think_start_token is not None
|
||||
and self._think_start_token in tokens
|
||||
)
|
||||
self.uid_to_request[uid] = ActiveRequest(
|
||||
command_id=req.task.command_id,
|
||||
should_extract_logprobs=req.should_extract_logprobs,
|
||||
top_k=req.top_k,
|
||||
prompt_tokens=prompt_tokens,
|
||||
harmony_parser=parser,
|
||||
in_thinking=starts_in_thinking,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Flushed {len(requests_to_flush)} requests into pipelined generator (active={self.pipelined_generator.active_count}, uids={list(self.uid_to_request.keys())})"
|
||||
)
|
||||
|
||||
def _flush_batch(
|
||||
self,
|
||||
requests_to_flush: list[PendingRequest],
|
||||
tokenized_prompts: list[list[int]],
|
||||
max_tokens_list: list[int],
|
||||
samplers: list[Callable[[mx.array], mx.array]],
|
||||
prompt_token_counts: list[int],
|
||||
) -> None:
|
||||
"""Flush using BatchGenerator (single-stream, for non-pipeline instances)."""
|
||||
if self.batch_generator is None:
|
||||
logger.info(
|
||||
f"Creating new BatchGenerator for {len(requests_to_flush)} requests"
|
||||
)
|
||||
mx.reset_peak_memory()
|
||||
self.batch_generator = BatchGenerator(
|
||||
model=self.model,
|
||||
max_tokens=MAX_TOKENS,
|
||||
stop_tokens=self.stop_tokens if self.stop_tokens else None,
|
||||
prefill_batch_size=1,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Adding {len(requests_to_flush)} requests to existing BatchGenerator"
|
||||
)
|
||||
|
||||
# Insert into batch generator
|
||||
uids: list[int] = self.batch_generator.insert( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
prompts=tokenized_prompts,
|
||||
max_tokens=max_tokens_list,
|
||||
samplers=samplers, # pyright: ignore[reportCallIssue]
|
||||
)
|
||||
|
||||
# Only coordinator tracks requests (non-coordinators don't have request objects)
|
||||
if self.is_coordinator:
|
||||
for uid, req, prompt_tokens, tokens in zip(
|
||||
uids, # pyright: ignore[reportUnknownArgumentType]
|
||||
requests_to_flush,
|
||||
prompt_token_counts,
|
||||
tokenized_prompts,
|
||||
strict=True,
|
||||
):
|
||||
parser = None
|
||||
if self.is_gpt_oss and self._harmony_encoding is not None:
|
||||
parser = StreamableParser(
|
||||
self._harmony_encoding, # pyright: ignore[reportAny]
|
||||
role=Role.ASSISTANT,
|
||||
)
|
||||
# Check if prompt contains <think> token - if so, model is already in thinking mode
|
||||
starts_in_thinking = (
|
||||
self._think_start_token is not None
|
||||
and self._think_start_token in tokens
|
||||
)
|
||||
self.uid_to_request[uid] = ActiveRequest(
|
||||
command_id=req.task.command_id,
|
||||
should_extract_logprobs=req.should_extract_logprobs,
|
||||
top_k=req.top_k,
|
||||
prompt_tokens=prompt_tokens,
|
||||
harmony_parser=parser,
|
||||
in_thinking=starts_in_thinking,
|
||||
)
|
||||
else:
|
||||
# Non-coordinator: INCREMENT active count (not set) to track all active requests
|
||||
# across multiple flushes. This ensures is_active remains True when new requests
|
||||
# are added while existing ones are still generating.
|
||||
self._non_coordinator_active_count += len(tokenized_prompts)
|
||||
|
||||
# Log the actual active count (different tracking for coordinator vs non-coordinator)
|
||||
actual_active = (
|
||||
self.current_batch_size
|
||||
if self.is_coordinator
|
||||
else self._non_coordinator_active_count
|
||||
)
|
||||
logger.info(
|
||||
f"Flushed {len(tokenized_prompts)} requests into batch (active={actual_active}, is_coordinator={self.is_coordinator})"
|
||||
)
|
||||
|
||||
def step(self) -> Generator[Event, None, None]:
|
||||
"""
|
||||
Process one generation step and yield ChunkGenerated events.
|
||||
|
||||
Returns a generator of events for completed tokens across all active requests.
|
||||
"""
|
||||
if self.use_pipelined:
|
||||
yield from self._step_pipelined()
|
||||
return
|
||||
|
||||
if self.batch_generator is None:
|
||||
return
|
||||
|
||||
# Non-coordinators still need to call next() for model sync but don't emit events
|
||||
if not self.is_coordinator:
|
||||
if self._non_coordinator_active_count > 0:
|
||||
nc_responses: list[Any] = self.batch_generator.next() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
# Track completions to update active count
|
||||
for nc_resp in nc_responses: # pyright: ignore[reportUnknownVariableType]
|
||||
if nc_resp.finish_reason is not None: # pyright: ignore[reportUnknownMemberType]
|
||||
self._non_coordinator_active_count -= 1
|
||||
return
|
||||
|
||||
if not self.uid_to_request:
|
||||
return
|
||||
|
||||
# Get next tokens for all active requests
|
||||
# BatchGenerator.next() returns list of Response objects
|
||||
logger.debug(
|
||||
f"BatchGenerator.next() called (active_uids={list(self.uid_to_request.keys())})"
|
||||
)
|
||||
responses: list[Any] = self.batch_generator.next() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
logger.debug(f"BatchGenerator.next() returned {len(responses)} responses") # pyright: ignore[reportUnknownArgumentType]
|
||||
|
||||
completed_uids: list[int] = []
|
||||
|
||||
for response in responses: # pyright: ignore[reportUnknownVariableType]
|
||||
uid: int = response.uid # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
if uid not in self.uid_to_request:
|
||||
logger.warning(f"Received response for unknown uid: {uid}")
|
||||
continue
|
||||
|
||||
active_request = self.uid_to_request[uid]
|
||||
active_request.tokens_generated += 1
|
||||
|
||||
# Extract response fields with explicit typing
|
||||
resp_token: int = response.token # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
resp_finish_reason: str | None = response.finish_reason # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
resp_logprobs: mx.array = response.logprobs # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
# Only emit events from device_rank 0
|
||||
if self.device_rank != 0:
|
||||
if resp_finish_reason is not None:
|
||||
completed_uids.append(uid) # pyright: ignore[reportUnknownArgumentType]
|
||||
continue
|
||||
|
||||
# Decode token to text
|
||||
# Skip emitting EOS token text (e.g., <|eot_id|>)
|
||||
if resp_token in self.stop_tokens:
|
||||
token_text = ""
|
||||
else:
|
||||
token_text = self.tokenizer.decode([resp_token])
|
||||
|
||||
# Handle thinking/reasoning token tracking
|
||||
if active_request.harmony_parser is not None:
|
||||
# GPT-OSS: Use harmony parser for channel-based thinking detection
|
||||
parser = active_request.harmony_parser # pyright: ignore[reportAny]
|
||||
parser.process(resp_token) # pyright: ignore[reportAny]
|
||||
delta: str | None = parser.last_content_delta # pyright: ignore[reportAny]
|
||||
channel: str = parser.current_channel # pyright: ignore[reportAny]
|
||||
|
||||
# Track reasoning tokens (analysis channel = thinking)
|
||||
if channel == "analysis":
|
||||
active_request.reasoning_tokens += 1
|
||||
|
||||
# Handle thinking tag transitions
|
||||
prefix = ""
|
||||
if channel == "analysis" and not active_request.in_thinking:
|
||||
active_request.in_thinking = True
|
||||
prefix = "<think>"
|
||||
elif channel != "analysis" and active_request.in_thinking:
|
||||
active_request.in_thinking = False
|
||||
prefix = "</think>"
|
||||
|
||||
if resp_finish_reason is not None and active_request.in_thinking:
|
||||
# Close thinking tag on finish
|
||||
prefix = "</think>"
|
||||
active_request.in_thinking = False
|
||||
|
||||
effective_delta = delta or ""
|
||||
token_text = (
|
||||
prefix + effective_delta if (prefix or effective_delta) else ""
|
||||
)
|
||||
# Skip empty tokens (channel markers with no content delta)
|
||||
if not token_text and resp_finish_reason is None:
|
||||
continue
|
||||
elif self._think_start_token is not None:
|
||||
# MiniMax: Track <think>/</think> tokens directly
|
||||
if resp_token == self._think_start_token:
|
||||
active_request.in_thinking = True
|
||||
elif resp_token == self._think_end_token:
|
||||
active_request.in_thinking = False
|
||||
elif active_request.in_thinking:
|
||||
active_request.reasoning_tokens += 1
|
||||
|
||||
# Extract logprobs if requested
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
if active_request.should_extract_logprobs:
|
||||
logprob, top_logprobs = extract_top_logprobs(
|
||||
logprobs_array=resp_logprobs, # pyright: ignore[reportUnknownArgumentType]
|
||||
selected_token=resp_token, # pyright: ignore[reportUnknownArgumentType]
|
||||
tokenizer=self.tokenizer,
|
||||
top_k=active_request.top_k,
|
||||
)
|
||||
|
||||
# Build stats for final token
|
||||
stats: GenerationStats | None = None
|
||||
finish_reason: TokenFinishReason | None = None
|
||||
if resp_finish_reason is not None:
|
||||
elapsed_time = time.perf_counter() - active_request.start_time
|
||||
prompt_tps = active_request.prompt_tokens / max(elapsed_time, 0.001)
|
||||
generation_tps = active_request.tokens_generated / max(
|
||||
elapsed_time, 0.001
|
||||
)
|
||||
|
||||
# Get peak memory
|
||||
peak_memory_bytes = 0
|
||||
if mx.metal.is_available():
|
||||
peak_memory_bytes = mx.metal.get_peak_memory()
|
||||
|
||||
stats = GenerationStats(
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tps=generation_tps,
|
||||
prompt_tokens=active_request.prompt_tokens,
|
||||
generation_tokens=active_request.tokens_generated,
|
||||
reasoning_tokens=active_request.reasoning_tokens,
|
||||
peak_memory_usage=Memory.from_bytes(peak_memory_bytes),
|
||||
)
|
||||
|
||||
# Map finish reason to the narrower type TokenChunk expects
|
||||
if resp_finish_reason == "stop":
|
||||
finish_reason = "stop"
|
||||
elif resp_finish_reason == "length":
|
||||
finish_reason = "length"
|
||||
elif resp_finish_reason == "content_filter":
|
||||
finish_reason = "content_filter"
|
||||
else:
|
||||
# Unknown finish reasons default to "stop"
|
||||
logger.warning(
|
||||
f"Unknown finish_reason: {resp_finish_reason}, mapping to 'stop'"
|
||||
)
|
||||
finish_reason = "stop"
|
||||
|
||||
completed_uids.append(uid) # pyright: ignore[reportUnknownArgumentType]
|
||||
|
||||
yield ChunkGenerated(
|
||||
command_id=active_request.command_id,
|
||||
chunk=TokenChunk(
|
||||
model=self.model_id,
|
||||
text=token_text,
|
||||
token_id=resp_token, # pyright: ignore[reportUnknownArgumentType]
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
),
|
||||
)
|
||||
|
||||
# Clean up completed requests
|
||||
for uid in completed_uids:
|
||||
del self.uid_to_request[uid]
|
||||
|
||||
def _step_pipelined(self) -> Generator[Event, None, None]:
|
||||
"""Process one generation step using the multi-stream PipelinedGenerator."""
|
||||
if self.pipelined_generator is None or not self.uid_to_request:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"PipelinedGenerator.next() called (active={self.pipelined_generator.active_count})"
|
||||
)
|
||||
responses: list[PipelinedResponse] = self.pipelined_generator.next()
|
||||
logger.debug(f"PipelinedGenerator.next() returned {len(responses)} responses")
|
||||
|
||||
completed_uids: list[int] = []
|
||||
|
||||
for response in responses:
|
||||
uid = response.uid
|
||||
if uid not in self.uid_to_request:
|
||||
logger.warning(f"Received response for unknown uid: {uid}")
|
||||
continue
|
||||
|
||||
active_request = self.uid_to_request[uid]
|
||||
active_request.tokens_generated += 1
|
||||
|
||||
resp_token: int = response.token
|
||||
resp_finish_reason: str | None = response.finish_reason
|
||||
resp_logprobs: mx.array = response.logprobs
|
||||
|
||||
# Only emit events from device_rank 0
|
||||
if self.device_rank != 0:
|
||||
if resp_finish_reason is not None:
|
||||
completed_uids.append(uid)
|
||||
continue
|
||||
|
||||
# Decode token to text
|
||||
# Skip emitting EOS token text (e.g., <|eot_id|>)
|
||||
if resp_token in self.stop_tokens:
|
||||
token_text = ""
|
||||
else:
|
||||
token_text = self.tokenizer.decode([resp_token])
|
||||
|
||||
# Handle thinking/reasoning token tracking
|
||||
if active_request.harmony_parser is not None:
|
||||
# GPT-OSS: Use harmony parser for channel-based thinking detection
|
||||
parser = active_request.harmony_parser # pyright: ignore[reportAny]
|
||||
parser.process(resp_token) # pyright: ignore[reportAny]
|
||||
delta: str | None = parser.last_content_delta # pyright: ignore[reportAny]
|
||||
channel: str = parser.current_channel # pyright: ignore[reportAny]
|
||||
|
||||
if channel == "analysis":
|
||||
active_request.reasoning_tokens += 1
|
||||
|
||||
prefix = ""
|
||||
if channel == "analysis" and not active_request.in_thinking:
|
||||
active_request.in_thinking = True
|
||||
prefix = "<think>"
|
||||
elif channel != "analysis" and active_request.in_thinking:
|
||||
active_request.in_thinking = False
|
||||
prefix = "</think>"
|
||||
|
||||
if resp_finish_reason is not None and active_request.in_thinking:
|
||||
prefix = "</think>"
|
||||
active_request.in_thinking = False
|
||||
|
||||
effective_delta = delta or ""
|
||||
token_text = (
|
||||
prefix + effective_delta if (prefix or effective_delta) else ""
|
||||
)
|
||||
if not token_text and resp_finish_reason is None:
|
||||
continue
|
||||
elif self._think_start_token is not None:
|
||||
# MiniMax: Track <think>/</think> tokens directly
|
||||
if resp_token == self._think_start_token:
|
||||
active_request.in_thinking = True
|
||||
elif resp_token == self._think_end_token:
|
||||
active_request.in_thinking = False
|
||||
elif active_request.in_thinking:
|
||||
active_request.reasoning_tokens += 1
|
||||
|
||||
# Extract logprobs if requested
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
if active_request.should_extract_logprobs:
|
||||
logprob, top_logprobs = extract_top_logprobs(
|
||||
logprobs_array=resp_logprobs,
|
||||
selected_token=resp_token,
|
||||
tokenizer=self.tokenizer,
|
||||
top_k=active_request.top_k,
|
||||
)
|
||||
|
||||
# Build stats for final token
|
||||
stats: GenerationStats | None = None
|
||||
finish_reason: TokenFinishReason | None = None
|
||||
if resp_finish_reason is not None:
|
||||
elapsed_time = time.perf_counter() - active_request.start_time
|
||||
prompt_tps = active_request.prompt_tokens / max(elapsed_time, 0.001)
|
||||
generation_tps = active_request.tokens_generated / max(
|
||||
elapsed_time, 0.001
|
||||
)
|
||||
|
||||
peak_memory_bytes = 0
|
||||
if mx.metal.is_available():
|
||||
peak_memory_bytes = mx.metal.get_peak_memory()
|
||||
|
||||
stats = GenerationStats(
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tps=generation_tps,
|
||||
prompt_tokens=active_request.prompt_tokens,
|
||||
generation_tokens=active_request.tokens_generated,
|
||||
reasoning_tokens=active_request.reasoning_tokens,
|
||||
peak_memory_usage=Memory.from_bytes(peak_memory_bytes),
|
||||
)
|
||||
|
||||
if resp_finish_reason == "stop":
|
||||
finish_reason = "stop"
|
||||
elif resp_finish_reason == "length":
|
||||
finish_reason = "length"
|
||||
else:
|
||||
finish_reason = "stop"
|
||||
|
||||
completed_uids.append(uid)
|
||||
|
||||
yield ChunkGenerated(
|
||||
command_id=active_request.command_id,
|
||||
chunk=TokenChunk(
|
||||
model=self.model_id,
|
||||
text=token_text,
|
||||
token_id=resp_token,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
),
|
||||
)
|
||||
|
||||
for uid in completed_uids:
|
||||
del self.uid_to_request[uid]
|
||||
|
||||
def emit_error(self, command_id: CommandId, error_message: str) -> Event:
|
||||
"""Create an error event for a failed request."""
|
||||
return ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=self.model_id,
|
||||
finish_reason="error",
|
||||
error_message=error_message,
|
||||
),
|
||||
)
|
||||
|
||||
def _close_generator(self) -> None:
|
||||
"""Close and clean up the batch/pipelined generator."""
|
||||
if self.batch_generator is not None:
|
||||
self.batch_generator.close() # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue]
|
||||
self.batch_generator = None
|
||||
if self.pipelined_generator is not None:
|
||||
self.pipelined_generator.close()
|
||||
self.pipelined_generator = None
|
||||
self.uid_to_request.clear()
|
||||
logger.info("Generator closed")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the handler and clean up resources."""
|
||||
self._close_generator()
|
||||
self.pending.clear()
|
||||
200
src/exo/worker/runner/batched_scoring_handler.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Batched scoring handler for processing multiple Completion requests concurrently."""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import TopLogprobItem
|
||||
from exo.shared.types.chunks import CompletionChunk, ErrorChunk
|
||||
from exo.shared.types.events import ChunkGenerated, Event
|
||||
from exo.shared.types.tasks import Completion
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import score_tokens_batched
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingScoringRequest:
|
||||
"""A scoring request waiting to be batched."""
|
||||
|
||||
task: Completion
|
||||
tokens: list[int]
|
||||
prompt_text: str
|
||||
top_k: int | None
|
||||
echo: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchedScoringHandler:
|
||||
"""
|
||||
Handles batched scoring for multiple Completion requests.
|
||||
|
||||
Collects multiple scoring requests and processes them in a single
|
||||
batched forward pass for improved throughput.
|
||||
"""
|
||||
|
||||
model: Model
|
||||
tokenizer: TokenizerWrapper
|
||||
model_id: ModelId
|
||||
device_rank: int
|
||||
max_batch_size: int = 32
|
||||
batch_timeout_ms: int = 10
|
||||
|
||||
pending: list[PendingScoringRequest] = field(default_factory=list)
|
||||
pending_start_time: float | None = None
|
||||
|
||||
@property
|
||||
def has_pending(self) -> bool:
|
||||
"""Check if there are pending requests."""
|
||||
return len(self.pending) > 0
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
task: Completion,
|
||||
tokens: list[int],
|
||||
prompt_text: str,
|
||||
) -> None:
|
||||
"""Add a Completion request to the pending batch."""
|
||||
task_params = task.task_params
|
||||
top_k = task_params.logprobs
|
||||
|
||||
self.pending.append(
|
||||
PendingScoringRequest(
|
||||
task=task,
|
||||
tokens=tokens,
|
||||
prompt_text=prompt_text,
|
||||
top_k=top_k,
|
||||
echo=task_params.echo,
|
||||
)
|
||||
)
|
||||
|
||||
if self.pending_start_time is None:
|
||||
self.pending_start_time = time.perf_counter()
|
||||
|
||||
logger.debug(f"Added scoring request to batch (pending={len(self.pending)})")
|
||||
|
||||
def should_flush(self) -> bool:
|
||||
"""Check if the batch should be flushed."""
|
||||
if not self.has_pending:
|
||||
return False
|
||||
|
||||
# Flush if batch is full
|
||||
if len(self.pending) >= self.max_batch_size:
|
||||
return True
|
||||
|
||||
# Flush if timeout reached
|
||||
if self.pending_start_time is not None:
|
||||
elapsed_ms = (time.perf_counter() - self.pending_start_time) * 1000
|
||||
if elapsed_ms >= self.batch_timeout_ms:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def flush(self) -> list[Event]:
|
||||
"""Process all pending requests and return events."""
|
||||
if not self.has_pending:
|
||||
return []
|
||||
|
||||
requests = self.pending
|
||||
self.pending = []
|
||||
self.pending_start_time = None
|
||||
|
||||
logger.info(f"Processing batch of {len(requests)} scoring requests")
|
||||
|
||||
# Collect all token sequences
|
||||
token_sequences = [req.tokens for req in requests]
|
||||
|
||||
# Get common top_k (use first request's top_k, they should all be the same)
|
||||
top_k = requests[0].top_k if requests else None
|
||||
|
||||
try:
|
||||
# Run batched scoring
|
||||
all_results = score_tokens_batched(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
token_sequences=token_sequences,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
# Generate events for each request
|
||||
events: list[Event] = []
|
||||
for req, logprob_results in zip(requests, all_results, strict=True):
|
||||
if self.device_rank != 0:
|
||||
continue
|
||||
|
||||
event = self._build_completion_event(req, logprob_results)
|
||||
events.append(event)
|
||||
|
||||
logger.info(f"Batch scoring complete ({len(events)} events)")
|
||||
return events
|
||||
|
||||
except Exception as e:
|
||||
# Return error events for all requests
|
||||
logger.error(f"Batch scoring failed: {e}")
|
||||
events = []
|
||||
for req in requests:
|
||||
if self.device_rank == 0:
|
||||
events.append(
|
||||
ChunkGenerated(
|
||||
command_id=req.task.command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=self.model_id,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
def _build_completion_event(
|
||||
self,
|
||||
req: PendingScoringRequest,
|
||||
logprob_results: list[tuple[float, list[TopLogprobItem]]],
|
||||
) -> Event:
|
||||
"""Build a ChunkGenerated event for a completed scoring request."""
|
||||
tokens = req.tokens
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
# Build response in completions format
|
||||
token_strings: list[str] = []
|
||||
token_logprobs: list[float | None] = []
|
||||
top_logprobs: list[dict[str, float]] = []
|
||||
text_offset: list[int] = []
|
||||
|
||||
offset = 0
|
||||
for i, token_id in enumerate(tokens):
|
||||
token_str = tokenizer.decode([token_id])
|
||||
token_strings.append(token_str)
|
||||
|
||||
if i < len(logprob_results):
|
||||
logprob, top_items = logprob_results[i]
|
||||
# First token has no logprob (None in OpenAI format)
|
||||
token_logprobs.append(logprob if i > 0 else None)
|
||||
top_lp_dict = {item.token: item.logprob for item in top_items}
|
||||
top_logprobs.append(top_lp_dict)
|
||||
else:
|
||||
token_logprobs.append(None)
|
||||
top_logprobs.append({})
|
||||
|
||||
text_offset.append(offset)
|
||||
offset += len(token_str)
|
||||
|
||||
return ChunkGenerated(
|
||||
command_id=req.task.command_id,
|
||||
chunk=CompletionChunk(
|
||||
model=self.model_id,
|
||||
text=req.prompt_text if req.echo else "",
|
||||
tokens=token_strings,
|
||||
token_logprobs=token_logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
text_offset=text_offset,
|
||||
finish_reason="stop",
|
||||
),
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up resources."""
|
||||
self.pending.clear()
|
||||
self.pending_start_time = None
|
||||
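# Editor's note: a hypothetical driver loop (not part of this diff) showing how
# the handler above is meant to be used. `requests` is assumed to yield
# (task, tokens, prompt_text) tuples prepared by the caller, and `send_event`
# forwards events to the runner's event channel.
def run_scoring_batches(handler: BatchedScoringHandler, requests, send_event) -> None:
    for task, tokens, prompt_text in requests:
        handler.add_request(task, tokens=tokens, prompt_text=prompt_text)
        if handler.should_flush():
            for event in handler.flush():
                send_event(event)
    # Drain anything still pending once the request stream ends.
    for event in handler.flush():
        send_event(event)
    handler.close()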
334
src/exo/worker/runner/pipelined_generator.py
Normal file
@@ -0,0 +1,334 @@
|
||||
"""Multi-stream pipelined batch generator for pipeline-parallel inference.
|
||||
|
||||
When a model is split across N ranks (pipeline parallelism), each rank's GPU is idle
|
||||
for (N-1)/N of each step while waiting for other ranks to compute their layers.
|
||||
|
||||
This module fills the pipeline bubble by splitting sequences into N micro-batch groups
|
||||
and processing each group on a different MLX stream. The GPU can overlap one stream's
|
||||
network communication (send/recv/all_gather) with another stream's compute.
|
||||
"""
|
||||
|
||||
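# Editor's note: a minimal toy (not part of this diff) of the MLX stream pattern
# this module relies on: build lazy graphs on separate streams, then evaluate
# them together so the scheduler is free to overlap their execution.
#
#   import mlx.core as mx
#   s0 = mx.new_stream(mx.default_device())
#   s1 = mx.new_stream(mx.default_device())
#   a = mx.random.normal((512, 512))
#   with mx.stream(s0):
#       out0 = a @ a            # lazy graph on stream 0
#   with mx.stream(s1):
#       out1 = (a + 1.0) @ a    # lazy graph on stream 1
#   mx.eval(out0, out1)         # evaluated together; work on s0/s1 can overlap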
# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false
|
||||
# pyright: reportUnknownArgumentType=false, reportAny=false
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import make_prompt_cache
|
||||
|
||||
|
||||
@dataclass
|
||||
class MicroBatch:
|
||||
"""State for one micro-batch group of sequences."""
|
||||
|
||||
uids: list[int]
|
||||
y: mx.array # Last sampled tokens [batch]
|
||||
logprobs: list[mx.array] # Logprobs for each sequence
|
||||
max_tokens: list[int]
|
||||
num_tokens: list[int]
|
||||
cache: list[Any] # KV cache (list of layer caches)
|
||||
samplers: list[Callable[[mx.array], mx.array]]
|
||||
tokens: list[mx.array] # All tokens generated so far per sequence
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.uids)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelinedResponse:
|
||||
"""Response from one generation step."""
|
||||
|
||||
uid: int
|
||||
token: int
|
||||
logprobs: mx.array
|
||||
finish_reason: str | None
|
||||
cache: list[Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingPrompt:
|
||||
"""A prompt waiting to be prefilled."""
|
||||
|
||||
uid: int
|
||||
tokens: list[int]
|
||||
max_tokens: int
|
||||
sampler: Callable[[mx.array], mx.array]
|
||||
|
||||
|
||||
class PipelinedGenerator:
|
||||
"""
|
||||
Multi-stream batch generator that fills pipeline bubbles.
|
||||
|
||||
Splits active sequences into `world_size` micro-batch groups, each processed
|
||||
on its own MLX stream. During mx.eval(), the GPU overlaps network operations
|
||||
on one stream with compute on another.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
world_size: int,
|
||||
stop_tokens: set[int] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
):
|
||||
self.model = model
|
||||
self.world_size = world_size
|
||||
self.stop_tokens = stop_tokens or set()
|
||||
self.max_tokens_default = max_tokens
|
||||
|
||||
# Create one stream per pipeline stage
|
||||
self.streams = [mx.new_stream(mx.default_device()) for _ in range(world_size)]
|
||||
|
||||
# Micro-batch groups (one per stream)
|
||||
self.micro_batches: list[MicroBatch | None] = [None] * world_size
|
||||
|
||||
# Pending prompts to be inserted
|
||||
self.pending_prompts: list[PendingPrompt] = []
|
||||
|
||||
# UID counter
|
||||
self._next_uid = 0
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
"""Total number of active sequences across all micro-batches."""
|
||||
return sum(len(mb) for mb in self.micro_batches if mb is not None)
|
||||
|
||||
@property
|
||||
def has_active(self) -> bool:
|
||||
return self.active_count > 0 or len(self.pending_prompts) > 0
|
||||
|
||||
def insert(
|
||||
self,
|
||||
prompts: list[list[int]],
|
||||
max_tokens: list[int],
|
||||
samplers: list[Callable[[mx.array], mx.array]],
|
||||
) -> list[int]:
|
||||
"""Queue prompts for processing. Returns assigned UIDs."""
|
||||
uids: list[int] = []
|
||||
for prompt, mt, sampler in zip(prompts, max_tokens, samplers, strict=True):
|
||||
uid = self._next_uid
|
||||
self._next_uid += 1
|
||||
self.pending_prompts.append(
|
||||
PendingPrompt(uid=uid, tokens=prompt, max_tokens=mt, sampler=sampler)
|
||||
)
|
||||
uids.append(uid)
|
||||
return uids
|
||||
|
||||
def _prefill_group(self, group_idx: int, prompts: list[PendingPrompt]) -> None:
|
||||
"""Prefill a group of prompts and create a MicroBatch."""
|
||||
if not prompts:
|
||||
return
|
||||
|
||||
stream = self.streams[group_idx]
|
||||
|
||||
with mx.stream(stream):
|
||||
# Create per-sequence caches
|
||||
caches = [make_prompt_cache(self.model) for _ in prompts]
|
||||
|
||||
# Tokenize and prefill each sequence
|
||||
all_y: list[mx.array] = []
|
||||
all_logprobs: list[mx.array] = []
|
||||
all_samplers: list[Callable[[mx.array], mx.array]] = []
|
||||
all_tokens: list[mx.array] = []
|
||||
|
||||
for prompt_info, cache in zip(prompts, caches, strict=True):
|
||||
tokens = mx.array(prompt_info.tokens)
|
||||
# Run prefill (process all tokens except last)
|
||||
if len(prompt_info.tokens) > 1:
|
||||
self.model(tokens[:-1][None, :], cache=cache)
|
||||
mx.eval([c.state for c in cache])
|
||||
|
||||
# Process last token to get first generation logits
|
||||
last_token = tokens[-1:][None, :]
|
||||
logits = self.model(last_token, cache=cache)
|
||||
logits = logits[:, -1, :]
|
||||
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||
sampled = prompt_info.sampler(logprobs)
|
||||
|
||||
all_y.append(sampled.squeeze(0))
|
||||
all_logprobs.append(logprobs.squeeze(0))
|
||||
all_samplers.append(prompt_info.sampler)
|
||||
all_tokens.append(tokens)
|
||||
|
||||
mx.eval(*all_y, *all_logprobs)
|
||||
|
||||
# Create micro-batch
|
||||
batch = MicroBatch(
|
||||
uids=[p.uid for p in prompts],
|
||||
y=mx.stack(all_y),
|
||||
logprobs=all_logprobs,
|
||||
max_tokens=[p.max_tokens for p in prompts],
|
||||
num_tokens=[0] * len(prompts),
|
||||
cache=caches,
|
||||
samplers=all_samplers,
|
||||
tokens=all_tokens,
|
||||
)
|
||||
|
||||
if self.micro_batches[group_idx] is None:
|
||||
self.micro_batches[group_idx] = batch
|
||||
else:
|
||||
# Extend existing micro-batch (would need cache merging - for now replace)
|
||||
self.micro_batches[group_idx] = batch
|
||||
|
||||
def _prefill_pending(self) -> None:
|
||||
"""Distribute pending prompts across micro-batch groups and prefill."""
|
||||
if not self.pending_prompts:
|
||||
return
|
||||
|
||||
# Distribute round-robin across groups
|
||||
groups: list[list[PendingPrompt]] = [[] for _ in range(self.world_size)]
|
||||
for i, prompt in enumerate(self.pending_prompts):
|
||||
groups[i % self.world_size].append(prompt)
|
||||
self.pending_prompts.clear()
|
||||
|
||||
for group_idx, group_prompts in enumerate(groups):
|
||||
if group_prompts:
|
||||
self._prefill_group(group_idx, group_prompts)
|
||||
|
||||
def _step_all(self) -> None:
|
||||
"""
|
||||
Run one generation step across all micro-batch groups on different streams.
|
||||
|
||||
This is where pipeline overlap happens: each group's model forward pass
|
||||
runs on its own stream, and mx.eval() allows the GPU to overlap network
|
||||
ops (send/recv/all_gather) from one stream with compute from another.
|
||||
|
||||
Each sequence is processed individually with its own KV cache, but all
|
||||
lazy graphs across streams are evaluated together for GPU overlap.
|
||||
"""
|
||||
# Build computation graphs on each stream (lazy, no evaluation yet)
|
||||
# Each micro-batch group processes its sequences on its own stream.
|
||||
all_sampled: list[mx.array] = []
|
||||
all_logprobs: list[mx.array] = []
|
||||
# Track which (group_idx, seq_idx) each result corresponds to
|
||||
result_map: list[tuple[int, int]] = []
|
||||
|
||||
for i, mb in enumerate(self.micro_batches):
|
||||
if mb is None or len(mb) == 0:
|
||||
continue
|
||||
|
||||
with mx.stream(self.streams[i]):
|
||||
for e in range(len(mb)):
|
||||
# Process each sequence individually with its own cache
|
||||
input_token = mb.y[e : e + 1][None, :] # [1, 1]
|
||||
|
||||
# Forward pass (lazy graph construction)
|
||||
# For pipeline models, this includes send/recv/all_gather ops
|
||||
logits = self.model(input_token, cache=mb.cache[e])
|
||||
logits = logits[:, -1, :] # [1, vocab]
|
||||
|
||||
# Compute logprobs
|
||||
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||
|
||||
# Sample
|
||||
sampled = mb.samplers[e](logprobs)
|
||||
|
||||
all_sampled.append(sampled.squeeze(0))
|
||||
all_logprobs.append(logprobs.squeeze(0))
|
||||
result_map.append((i, e))
|
||||
|
||||
if not result_map:
|
||||
return
|
||||
|
||||
# Evaluate ALL streams together - this is where overlap happens!
|
||||
# The GPU can execute stream0's all_gather while computing stream1's layers.
|
||||
mx.eval(*all_sampled, *all_logprobs)
|
||||
|
||||
# Update micro-batch states with results
|
||||
# Group results by micro-batch for efficient update
|
||||
group_results: dict[int, list[int]] = {}
|
||||
for idx, (group_idx, _seq_idx) in enumerate(result_map):
|
||||
group_results.setdefault(group_idx, []).append(idx)
|
||||
|
||||
for group_idx, result_indices in group_results.items():
|
||||
mb = self.micro_batches[group_idx]
|
||||
assert mb is not None
|
||||
group_sampled = [all_sampled[idx] for idx in result_indices]
|
||||
group_logprobs = [all_logprobs[idx] for idx in result_indices]
|
||||
mb.y = mx.stack(group_sampled)
|
||||
mb.logprobs = group_logprobs
|
||||
for e, idx in enumerate(result_indices):
|
||||
mb.tokens[e] = mx.concatenate([mb.tokens[e], all_sampled[idx][None]])
|
||||
|
||||
def next(self) -> list[PipelinedResponse]:
|
||||
"""
|
||||
Run one generation step and return responses.
|
||||
|
||||
Returns a PipelinedResponse for each active sequence (across all groups).
|
||||
Finished sequences are removed from their micro-batch.
|
||||
"""
|
||||
# Prefill any pending prompts first
|
||||
self._prefill_pending()
|
||||
|
||||
if not self.has_active:
|
||||
return []
|
||||
|
||||
# Run the multi-stream forward pass
|
||||
self._step_all()
|
||||
|
||||
# Collect responses and filter completed sequences
|
||||
responses: list[PipelinedResponse] = []
|
||||
|
||||
for group_idx, mb in enumerate(self.micro_batches):
|
||||
if mb is None or len(mb) == 0:
|
||||
continue
|
||||
|
||||
keep_idx: list[int] = []
|
||||
end_idx: list[int] = []
|
||||
|
||||
for e in range(len(mb)):
|
||||
token = int(mb.y[e].item())
|
||||
uid = mb.uids[e]
|
||||
num_tok = mb.num_tokens[e] + 1
|
||||
max_tok = mb.max_tokens[e]
|
||||
mb.num_tokens[e] = num_tok
|
||||
|
||||
if token in self.stop_tokens:
|
||||
finish_reason = "stop"
|
||||
end_idx.append(e)
|
||||
elif num_tok >= max_tok:
|
||||
finish_reason = "length"
|
||||
end_idx.append(e)
|
||||
else:
|
||||
finish_reason = None
|
||||
keep_idx.append(e)
|
||||
|
||||
responses.append(
|
||||
PipelinedResponse(
|
||||
uid=uid,
|
||||
token=token,
|
||||
logprobs=mb.logprobs[e],
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# Remove finished sequences
|
||||
if end_idx:
|
||||
if keep_idx:
|
||||
# Filter the micro-batch to keep only active sequences
|
||||
mb.uids = [mb.uids[i] for i in keep_idx]
|
||||
mb.y = mb.y[mx.array(keep_idx)]
|
||||
mb.logprobs = [mb.logprobs[i] for i in keep_idx]
|
||||
mb.max_tokens = [mb.max_tokens[i] for i in keep_idx]
|
||||
mb.num_tokens = [mb.num_tokens[i] for i in keep_idx]
|
||||
mb.samplers = [mb.samplers[i] for i in keep_idx]
|
||||
mb.tokens = [mb.tokens[i] for i in keep_idx]
|
||||
# Cache filtering: trim batch dimension
|
||||
for c in mb.cache:
|
||||
if hasattr(c, "keys") and c.keys is not None:
|
||||
c.keys = c.keys[mx.array(keep_idx)]
|
||||
c.values = c.values[mx.array(keep_idx)]
|
||||
else:
|
||||
self.micro_batches[group_idx] = None
|
||||
|
||||
return responses
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up resources."""
|
||||
self.micro_batches = [None] * self.world_size
|
||||
self.pending_prompts.clear()
|
||||
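# Editor's note: a hypothetical driver (not part of this diff) for the class
# above. `model`, `world_size`, `stop_tokens` and the tokenized `prompts` are
# assumed to be prepared by the caller; a greedy sampler stands in for
# make_sampler(temp=0.0).
def drain_pipelined(model, world_size, stop_tokens, prompts):
    def greedy(logprobs: mx.array) -> mx.array:
        return mx.argmax(logprobs, axis=-1)

    gen = PipelinedGenerator(model=model, world_size=world_size, stop_tokens=stop_tokens)
    uids = gen.insert(
        prompts=prompts,
        max_tokens=[256] * len(prompts),
        samplers=[greedy] * len(prompts),
    )
    generated: dict[int, list[int]] = {uid: [] for uid in uids}
    while gen.has_active:
        for resp in gen.next():
            generated[resp.uid].append(resp.token)
    gen.close()
    return generated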
@@ -1,11 +1,13 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from anyio import EndOfStream, WouldBlock
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
@@ -61,7 +63,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.shared.types.worker.shards import ShardMetadata, TensorShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.image import (
|
||||
DistributedImageModel,
|
||||
@@ -70,8 +72,10 @@ from exo.worker.engines.image import (
|
||||
warmup_image_generator,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.generator.generate import (
|
||||
mlx_generate,
|
||||
warmup_inference,
|
||||
)
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
detect_thinking_prompt_suffix,
|
||||
@@ -79,8 +83,128 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
)
|
||||
from exo.worker.runner.batched_handler import BatchedInferenceHandler
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
# Batching configuration
|
||||
BATCH_ENABLED = os.environ.get("EXO_NO_BATCH") != "1"
|
||||
BATCH_MAX_SIZE = 32
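# Editor's note: given the check above, batching is on by default; exporting
#   EXO_NO_BATCH=1
# in the runner's environment before launch disables it and routes chat
# completions through the serial path instead.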
|
||||
|
||||
|
||||
def _should_use_serial_processing(
|
||||
task: ChatCompletion,
|
||||
tokenizer: TokenizerWrapper,
|
||||
model: Model,
|
||||
model_id: ModelId,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if a ChatCompletion task requires serial processing.
|
||||
|
||||
Currently always returns False - batch mode handles all cases.
|
||||
Post-processing (GPT-OSS, thinking models, tool calls) can be applied
|
||||
per-request to the individual streams from the batch generator.
|
||||
"""
|
||||
# All tasks can use batch mode - post-processing is per-request
|
||||
return False
|
||||
|
||||
|
||||
def _process_serial_chat_completion(
|
||||
task: ChatCompletion,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
shard_metadata: ShardMetadata,
|
||||
event_sender: MpSender[Event],
|
||||
) -> None:
|
||||
"""Process a ChatCompletion task serially (original implementation)."""
|
||||
task_params = task.task_params
|
||||
command_id = task.command_id
|
||||
device_rank = shard_metadata.device_rank
|
||||
|
||||
if task_params.messages[0].content is not None:
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Build prompt once - used for both generation and thinking detection
|
||||
prompt = apply_chat_template(tokenizer, task_params)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
# prepend the thinking tag that was consumed by the chat template
|
||||
if detect_thinking_prompt_suffix(prompt, tokenizer):
|
||||
mlx_generator = parse_thinking_models(mlx_generator, tokenizer)
|
||||
|
||||
# Kimi-K2 has tool call sections - we don't care about them
|
||||
if "kimi" in shard_metadata.model_card.model_id.lower():
|
||||
mlx_generator = filter_kimi_tokens(mlx_generator)
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
if tokenizer.has_tool_calling:
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
mlx_generator = parse_tool_calls(
|
||||
mlx_generator,
|
||||
tokenizer.tool_call_start,
|
||||
tokenizer.tool_call_end,
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if device_rank == 0 and response.finish_reason == "error":
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
error_message=response.text,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
elif device_rank == 0:
|
||||
assert response.finish_reason not in (
|
||||
"error",
|
||||
"tool_calls",
|
||||
"function_call",
|
||||
)
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
logprob=response.logprob,
|
||||
top_logprobs=response.top_logprobs,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
case ToolCallResponse():
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ToolCallChunk(
|
||||
tool_calls=response.tool_calls,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
bound_instance: BoundInstance,
|
||||
@@ -93,6 +217,11 @@ def main(
|
||||
bound_instance.bound_shard,
|
||||
)
|
||||
device_rank = shard_metadata.device_rank
|
||||
# Determine if this node is the coordinator for tensor parallel
|
||||
# Use sorted node ordering for consistency with main.py
|
||||
node_id = bound_instance.bound_node_id
|
||||
sorted_nodes = sorted(instance.shard_assignments.node_to_runner.keys())
|
||||
is_tp_coordinator = node_id == sorted_nodes[0]
|
||||
logger.info("hello from the runner")
|
||||
if getattr(shard_metadata, "immediate_exception", False):
|
||||
raise Exception("Fake exception - runner failed to spin up.")
|
||||
@@ -102,229 +231,199 @@ def main(
|
||||
setup_start_time = time.time()
|
||||
|
||||
model: Model | DistributedImageModel | None = None
|
||||
tokenizer = None
|
||||
tokenizer: TokenizerWrapper | None = None
|
||||
group = None
|
||||
kv_prefix_cache: KVPrefixCache | None = None
|
||||
batch_handler: BatchedInferenceHandler | None = None
|
||||
is_tensor_parallel = False
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
match task:
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
):
|
||||
logger.info("runner connecting")
|
||||
current_status = RunnerConnecting()
|
||||
|
||||
def process_task(task: Task) -> bool:
|
||||
"""
|
||||
Process a single task. Returns True if the runner should continue,
|
||||
False if it should shut down.
|
||||
"""
|
||||
nonlocal \
|
||||
current_status, \
|
||||
model, \
|
||||
tokenizer, \
|
||||
group, \
|
||||
batch_handler, \
|
||||
is_tensor_parallel
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
# NOTE: TaskAcknowledged is sent per-case below, AFTER the initial status
|
||||
# update, to avoid a race where the supervisor sees the ack before the
|
||||
# status change and re-dispatches the same lifecycle command.
|
||||
match task:
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
):
|
||||
logger.info("runner connecting")
|
||||
current_status = RunnerConnecting()
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
group = initialize_mlx(bound_instance)
|
||||
|
||||
logger.info("runner connected")
|
||||
current_status = RunnerConnected()
|
||||
|
||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||
case LoadModel() if (
|
||||
isinstance(current_status, RunnerConnected) and group is not None
|
||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("runner loading")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
def on_model_load_timeout() -> None:
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
runner_id=runner_id,
|
||||
runner_status=RunnerFailed(
|
||||
error_message="Model loading timed out"
|
||||
),
|
||||
)
|
||||
)
|
||||
group = initialize_mlx(bound_instance)
|
||||
time.sleep(0.5)
|
||||
|
||||
logger.info("runner connected")
|
||||
current_status = RunnerConnected()
|
||||
|
||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||
case LoadModel() if (
|
||||
isinstance(current_status, RunnerConnected) and group is not None
|
||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("runner loading")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(f"model has_tool_calling={tokenizer.has_tool_calling}")
|
||||
|
||||
def on_model_load_timeout() -> None:
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id,
|
||||
runner_status=RunnerFailed(
|
||||
error_message="Model loading timed out"
|
||||
),
|
||||
)
|
||||
)
|
||||
time.sleep(0.5)
|
||||
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
# Initialize batch handler for text generation models
|
||||
is_tensor_parallel = isinstance(shard_metadata, TensorShardMetadata)
|
||||
if BATCH_ENABLED:
|
||||
batch_handler = BatchedInferenceHandler(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
model_id=shard_metadata.model_card.model_id,
|
||||
device_rank=device_rank,
|
||||
world_size=1
|
||||
if is_tensor_parallel
|
||||
else shard_metadata.world_size,
|
||||
max_batch_size=BATCH_MAX_SIZE,
|
||||
tensor_parallel_group=group if is_tensor_parallel else None,
|
||||
is_coordinator=is_tp_coordinator
|
||||
if is_tensor_parallel
|
||||
else True,
|
||||
)
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
f"Batch handler initialized (max_batch_size={BATCH_MAX_SIZE}, tensor_parallel={is_tensor_parallel})"
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
)
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
)
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
|
||||
# For tensor parallel with batch handler, skip explicit warmup.
|
||||
# The batch handler synchronizes all ranks via all_sum in flush(),
|
||||
# so the first real request warms up the model on all ranks simultaneously.
|
||||
# Without a batch handler, warmup must run normally to avoid GPU locks
|
||||
# from mismatched send/recv in serial processing.
|
||||
if is_tensor_parallel and batch_handler is not None:
|
||||
logger.info(
|
||||
"Tensor parallel: skipping warmup (first request will warm up through batch handler)"
|
||||
)
|
||||
toks = 0
|
||||
else:
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
image = warmup_image_generator(model=model)
|
||||
if image is not None:
|
||||
logger.info(f"warmed up by generating {image.size} image")
|
||||
else:
|
||||
logger.info("warmup completed (non-primary node)")
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ChatCompletion(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
logger.info(f"received chat request: {task}")
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
image = warmup_image_generator(model=model)
|
||||
if image is not None:
|
||||
logger.info(f"warmed up by generating {image.size} image")
|
||||
else:
|
||||
logger.info("warmup completed (non-primary node)")
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ChatCompletion(task_params=task_params, command_id=command_id) if (
    isinstance(current_status, (RunnerReady, RunnerRunning))
):
    logger.info(f"received chat request: {task}")
    assert model and not isinstance(model, DistributedImageModel)
    assert tokenizer
    assert task_params.messages[0].content is not None

    # Check if we should use serial processing for this task
    if not BATCH_ENABLED:
        logger.debug("Serial mode: BATCH_ENABLED is False")
        use_serial = True
    elif batch_handler is None:
        logger.debug("Serial mode: batch_handler is None")
        use_serial = True
    else:
        use_serial = _should_use_serial_processing(
            task, tokenizer, model, shard_metadata.model_card.model_id
        )
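    # Whichever branch ran, use_serial now records whether this request is
    # generated inline below (serial path) or handed to the batch handler.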

    if use_serial:
        # Serial processing for complex tasks
        current_status = RunnerRunning()
        logger.info("runner running (serial mode)")
        event_sender.send(
            RunnerStatusUpdated(
                runner_id=runner_id, runner_status=current_status
            )
        )
        event_sender.send(TaskAcknowledged(task_id=task.task_id))

        try:
            _check_for_debug_prompts(task_params.messages[0].content)

            # Build prompt once - used for both generation and thinking detection
            prompt = apply_chat_template(tokenizer, task_params)

            # Generate responses using the actual MLX generation
            mlx_generator = mlx_generate(
                model=model,
                tokenizer=tokenizer,
                task=task_params,
                prompt=prompt,
                kv_prefix_cache=kv_prefix_cache,
            )

            # For other thinking models (GLM, etc.), check if we need to
            # prepend the thinking tag that was consumed by the chat template
            if detect_thinking_prompt_suffix(prompt, tokenizer):
                mlx_generator = parse_thinking_models(
                    mlx_generator, tokenizer
                )

            # Kimi-K2 has tool call sections - we don't care about them
            if "kimi" in shard_metadata.model_card.model_id.lower():
                mlx_generator = filter_kimi_tokens(mlx_generator)
                patch_kimi_tokenizer(tokenizer)

            # GLM models need patched parser (upstream has bug with None regex match)
            elif "glm" in shard_metadata.model_card.model_id.lower():
                patch_glm_tokenizer(tokenizer)

            # GPT-OSS specific parsing to match other model formats.
            elif isinstance(model, GptOssModel):
                mlx_generator = parse_gpt_oss(mlx_generator)

            if tokenizer.has_tool_calling and not isinstance(
                model, GptOssModel
            ):
                assert tokenizer.tool_call_start
                assert tokenizer.tool_call_end
                assert tokenizer.tool_parser  # pyright: ignore[reportAny]
                mlx_generator = parse_tool_calls(
                    mlx_generator,
                    tokenizer.tool_call_start,
                    tokenizer.tool_call_end,
                    tokenizer.tool_parser,  # pyright: ignore[reportAny]
                )

            for response in mlx_generator:
                match response:
                    case GenerationResponse():
                        if (
                            device_rank == 0
                            and response.finish_reason == "error"
                        ):
                            event_sender.send(
                                ChunkGenerated(
                                    command_id=command_id,
                                    chunk=ErrorChunk(
                                        error_message=response.text,
                                        model=shard_metadata.model_card.model_id,
                                    ),
                                )
                            )
                        elif device_rank == 0:
                            assert response.finish_reason not in (
                                "error",
                                "tool_calls",
                                "function_call",
                            )
                            event_sender.send(
                                ChunkGenerated(
                                    command_id=command_id,
                                    chunk=TokenChunk(
                                        model=shard_metadata.model_card.model_id,
                                        text=response.text,
                                        token_id=response.token,
                                        finish_reason=response.finish_reason,
                                        stats=response.stats,
                                    ),
                                )
                            )
                    case ToolCallResponse():
                        if device_rank == 0:
                            event_sender.send(
                                ChunkGenerated(
                                    command_id=command_id,
                                    chunk=ToolCallChunk(
                                        tool_calls=response.tool_calls,
                                        model=shard_metadata.model_card.model_id,
                                    ),
                                )
                            )

        # can we make this more explicit?
        except Exception as e:
            if device_rank == 0:
                event_sender.send(
                    ChunkGenerated(
                        command_id=command_id,
                        chunk=ErrorChunk(
                            model=shard_metadata.model_card.model_id,
                            finish_reason="error",
                            error_message=str(e),
                        ),
                    )
                )
            raise

        current_status = RunnerReady()
        logger.info("runner ready")
    else:
        # Batch processing for simple tasks
        assert batch_handler is not None
        try:
            _check_for_debug_prompts(task_params.messages[0].content)
            # Non-coordinator TP: don't add to batch handler.
            # The batch handler syncs via all_sum in flush();
            # non-coordinator participates through that, not through add_request.
            if is_tensor_parallel and not is_tp_coordinator:
                event_sender.send(TaskAcknowledged(task_id=task.task_id))
                return True
            batch_handler.add_request(task)

            # Update status to running if not already
            if not isinstance(current_status, RunnerRunning):
                current_status = RunnerRunning()
                logger.info("runner running (batch mode)")
                event_sender.send(
                    RunnerStatusUpdated(
                        runner_id=runner_id, runner_status=current_status
                    )
                )
            event_sender.send(TaskAcknowledged(task_id=task.task_id))

            # Return True to indicate task was added to batch
            # (completion will be sent when batch processes)
            return True
        except Exception as e:
            if device_rank == 0:
                event_sender.send(
                    ChunkGenerated(
                        command_id=command_id,
                        chunk=ErrorChunk(
                            model=shard_metadata.model_card.model_id,
                            finish_reason="error",
                            error_message=str(e),
                        ),
                    )
                )
            raise
case ImageGeneration(task_params=task_params, command_id=command_id) if (
    isinstance(current_status, RunnerReady)
):
    assert isinstance(model, DistributedImageModel)
    logger.info(f"received image generation request: {str(task)[:500]}")
    current_status = RunnerRunning()
    logger.info("runner running")
    event_sender.send(
        RunnerStatusUpdated(
            runner_id=runner_id, runner_status=current_status
        )
    )
    event_sender.send(TaskAcknowledged(task_id=task.task_id))

    try:
        # Generate images using the image generation backend
        # Track image_index for final images only
        image_index = 0
        for response in generate_image(model=model, task=task_params):
            if shard_metadata.device_rank == shard_metadata.world_size - 1:
                match response:
                    case PartialImageResponse():
                        logger.info(
                            f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
                        )
                        _process_image_response(
                            response,
                            command_id,
                            shard_metadata,
                            event_sender,
                            image_index,
                        )
                    case ImageGenerationResponse():
                        logger.info("sending final ImageChunk")
                        _process_image_response(
                            response,
                            command_id,
                            shard_metadata,
                            event_sender,
                            image_index,
                        )
                        image_index += 1
    # can we make this more explicit?
    except Exception as e:
        if shard_metadata.device_rank == shard_metadata.world_size - 1:
            event_sender.send(
                ChunkGenerated(
                    command_id=command_id,
                    chunk=ErrorChunk(
                        model=shard_metadata.model_card.model_id,
                        finish_reason="error",
                        error_message=str(e),
                    ),
                )
            )
        raise

    current_status = RunnerReady()
    logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
    isinstance(current_status, RunnerReady)
):
    assert isinstance(model, DistributedImageModel)
    logger.info(f"received image edits request: {str(task)[:500]}")
    current_status = RunnerRunning()
    logger.info("runner running")
    event_sender.send(
        RunnerStatusUpdated(
            runner_id=runner_id, runner_status=current_status
        )
    )
    event_sender.send(TaskAcknowledged(task_id=task.task_id))

    try:
        image_index = 0
        for response in generate_image(model=model, task=task_params):
            if shard_metadata.device_rank == shard_metadata.world_size - 1:
                match response:
                    case PartialImageResponse():
                        logger.info(
                            f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
                        )
                        _process_image_response(
                            response,
                            command_id,
                            shard_metadata,
                            event_sender,
                            image_index,
                        )
                    case ImageGenerationResponse():
                        logger.info("sending final ImageChunk")
                        _process_image_response(
                            response,
                            command_id,
                            shard_metadata,
                            event_sender,
                            image_index,
                        )
                        image_index += 1
    except Exception as e:
        if shard_metadata.device_rank == shard_metadata.world_size - 1:
            event_sender.send(
                ChunkGenerated(
                    command_id=command_id,
                    chunk=ErrorChunk(
                        model=shard_metadata.model_card.model_id,
                        finish_reason="error",
                        error_message=str(e),
                    ),
                )
            )
        raise

    current_status = RunnerReady()
    logger.info("runner ready")
case Shutdown():
    if batch_handler is not None:
        batch_handler.close()
        batch_handler = None
    current_status = RunnerShuttingDown()
    logger.info("runner shutting down")
    event_sender.send(
        RunnerStatusUpdated(
            runner_id=runner_id, runner_status=current_status
        )
    )
    event_sender.send(TaskAcknowledged(task_id=task.task_id))
    current_status = RunnerShutdown()
case _:
    raise ValueError(
        f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
    )

event_sender.send(
    TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
    RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
return not isinstance(current_status, RunnerShutdown)

    # Track tasks that were added to batch (need completion after batch processes)
    batched_task_ids: list[tuple[Task, bool]] = []  # (task, completed)
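    # Completion events for batched tasks are emitted in the loop below, once the
    # batch handler reports no active or pending work.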

    with task_receiver as tasks:
        while True:
            # For tensor parallel: both coordinator and non-coordinator go through
            # the same loop, but only coordinator receives tasks. This ensures
            # flush() all_sum calls are synchronized.
            if batch_handler is not None and (
                batch_handler.is_active
                or batch_handler.has_pending
                or is_tensor_parallel
            ):
                # Drain all available tasks before stepping
                # Non-coordinator won't receive any (main.py doesn't send to it)
                should_break = False
                while True:
                    try:
                        task = tasks.receive_nowait()
                        if isinstance(task, ChatCompletion) and isinstance(
                            current_status, (RunnerReady, RunnerRunning)
                        ):
                            was_batched = process_task(task)
                            if was_batched:
                                batched_task_ids.append((task, False))
                        else:
                            should_continue = process_task(task)
                            if not should_continue:
                                should_break = True
                                break
                    except WouldBlock:
                        break  # No more tasks available
                    except EndOfStream:
                        should_break = True
                        break
                if should_break:
                    break
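                # After draining, the coordinator flushes pending requests into the
                # batch, steps the active batch, and emits completion events once the
                # batch handler has fully drained (see the blocks below).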

                # Flush: for tensor parallel, always call so all ranks sync via all_sum
                # For non-TP, only call when has_pending
                if batch_handler.has_pending or is_tensor_parallel:
                    if batch_handler.has_pending:
                        logger.info(
                            f"Flushing batch (pending={len(batch_handler.pending)}, active={batch_handler.current_batch_size})"
                        )
                    batch_handler.flush()

                # Step generation and emit events
                if batch_handler.is_active:
                    event_count = 0
                    for event in batch_handler.step():
                        event_sender.send(event)
                        event_count += 1
                    if event_count > 0:
                        logger.debug(f"Emitted {event_count} events from batch")

                # Check for completed batched tasks
                if not batch_handler.is_active and not batch_handler.has_pending:
                    # All batched tasks completed
                    for task, completed in batched_task_ids:
                        if not completed:
                            event_sender.send(
                                TaskStatusUpdated(
                                    task_id=task.task_id,
                                    task_status=TaskStatus.Complete,
                                )
                            )
                    batched_task_ids.clear()

                    # Return to ready state
                    if isinstance(current_status, RunnerRunning):
                        current_status = RunnerReady()
                        logger.info("runner ready (batch completed)")
                        event_sender.send(
                            RunnerStatusUpdated(
                                runner_id=runner_id, runner_status=current_status
                            )
                        )
            else:
                # No active batch - use blocking receive
                try:
                    task = tasks.receive()
                    should_continue = process_task(task)
                    if not should_continue:
                        break
                except EndOfStream:
                    break

    # Cleanup
    if batch_handler is not None:
        batch_handler.close()
    del model, tokenizer, group
    mx.clear_cache()
    import gc

    gc.collect()


@cache
@@ -52,6 +52,9 @@ class RunnerSupervisor:
    _tg: TaskGroup | None = field(default=None, init=False)
    status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
    pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
    sent: set[TaskId] = field(
        default_factory=set, init=False
    )  # Tasks sent to runner (not yet completed)
    completed: set[TaskId] = field(default_factory=set, init=False)
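    # Lifecycle introduced by this change: a task_id enters `sent` when dispatched,
    # then moves to `completed` (and leaves `sent`) once a terminal TaskStatusUpdated
    # arrives; start_task() checks pending, sent and completed to drop duplicates.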

    @classmethod

@@ -126,21 +129,39 @@ class RunnerSupervisor:
        assert self._tg
        self._tg.cancel_scope.cancel()

    async def start_task(self, task: Task, wait_for_ack: bool = True):
        """
        Send a task to the runner.

        Args:
            task: The task to send.
            wait_for_ack: If True, wait for TaskAcknowledged before returning.
                If False, return immediately after sending (for batching).
        """
        if task.task_id in self.completed:
            logger.debug(
                f"Skipping task {task.task_id} as it has already been completed"
            )
            return
        if task.task_id in self.sent:
            logger.debug(f"Task {task.task_id} already sent, skipping duplicate")
            return
        if task.task_id in self.pending:
            logger.debug(f"Task {task.task_id} already pending, skipping duplicate")
            return
        logger.info(f"Starting task {task}")
        event = anyio.Event()
        self.pending[task.task_id] = event
        self.sent.add(task.task_id)
        try:
            self._task_sender.send(task)
        except ClosedResourceError:
            logger.warning(f"Task {task} dropped, runner closed communication.")
            self.sent.discard(task.task_id)
            return
        if wait_for_ack:
            await event.wait()
            logger.info(f"Finished task {task}")

    async def _forward_events(self):
        with self._ev_recv as events:
@@ -149,7 +170,11 @@ class RunnerSupervisor:
                if isinstance(event, RunnerStatusUpdated):
                    self.status = event.runner_status
                if isinstance(event, TaskAcknowledged):
                    # Use pop with default to handle tasks sent with wait_for_ack=False
                    # that may have already been removed or never added
                    pending_event = self.pending.pop(event.task_id, None)
                    if pending_event:
                        pending_event.set()
                    continue
                if (
                    isinstance(event, TaskStatusUpdated)
@@ -167,6 +192,7 @@ class RunnerSupervisor:
                    ),
                )
                self.completed.add(event.task_id)
                self.sent.discard(event.task_id)
                await self._event_sender.send(event)
            except (ClosedResourceError, BrokenResourceError) as e:
                await self._check_runner(e)

@@ -20,6 +20,7 @@ class FakeRunnerSupervisor:
    bound_instance: BoundInstance
    status: RunnerStatus
    completed: set[TaskId] = field(default_factory=set)
    sent: set[TaskId] = field(default_factory=set)


class OtherTask(BaseTask):

@@ -5,6 +5,7 @@ from exo.shared.types.worker.runners import (
    RunnerIdle,
    RunnerLoaded,
    RunnerLoading,
    RunnerReady,
    RunnerWarmingUp,
)
from exo.worker.tests.constants import (

@@ -118,6 +118,10 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
    # Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
    monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
    monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
    # Force serial processing mode since batch mode requires a real tokenizer
    monkeypatch.setattr(mlx_runner, "_should_use_serial_processing", make_nothin(True))
    # Disable batch handler initialization
    monkeypatch.setattr(mlx_runner, "BATCH_ENABLED", False)

    def fake_generate(*_1: object, **_2: object):
        yield GenerationResponse(token=0, text="hi", finish_reason="stop")

@@ -192,29 +196,30 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
        TaskStatusUpdated(
            task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
        ),
        # Status update comes before ack to prevent race conditions
        RunnerStatusUpdated(
            runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
        ),
        TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
        TaskStatusUpdated(
            task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
        ),
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
        TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
        TaskAcknowledged(task_id=LOAD_TASK_ID),
        TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
        TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
        TaskAcknowledged(task_id=WARMUP_TASK_ID),
        TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
        TaskStatusUpdated(
            task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
        ),
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
        TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
        expected_chunk,
        TaskStatusUpdated(
            task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
@@ -222,10 +227,10 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
        # CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
        TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
        RunnerStatusUpdated(
            runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
        ),
        TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
        TaskStatusUpdated(
            task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
        ),