Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-04 11:11:45 -05:00)
Compare commits: ciaran/ima ... david/mla- (52 commits)
Commit SHA1s in this comparison:
6018a9c97c, 07b8405d3e, 5d3b407602, e7a5826aed, ebe279018f, bf67e7d334, 0cd2f6aab4,
ba8a44e6a2, 07c4be157b, 1e1eb8f8a1, 1bc2d9728d, 7823fd7b1a, 05caab0047, bd8f9f2d10,
34fcafa68a, 5152789e00, b734437b2d, 553939fa31, c55cbf6739, 13ee17428e, 1b0d39c0b3,
5e3cd73a9e, 1d1256c769, 77baf9c58e, 022a09b6d9, 0aa708fac4, eb89c2e4b9, 72a5eec3f7,
bd4f0bf048, cd8c01b7c8, 59e991ce15, ffba340e70, 9968abe816, 0e30b0830f, 44453c4c8b,
1290e8ed9f, d93db3d6bf, a25892e8d5, 8798ab52ee, 457debc338, 0cfaea41bc, 18c82443ba,
b9ec8b0a44, 00442b3cfd, aa41da8541, 86e5d7b101, d9ddf90575, 4591301767, 8b0b5e1b88,
bd6287727a, eb53611210, 71bbe5f25b
@@ -5,18 +5,18 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[X] no mx_barrier in generate.py mlx_generate at the end.
[] no mx_barrier in generate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed: "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above: "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed: "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above: "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put the cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking whether the node is in self._node_id_to_rx_id_map; on beta_1 it checks afterwards, so it would remove stale connections, I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up; probably create an issue. The blocker is that the transformers version doesn't support the tokenizer for Glm 4.7; rc-1 does, but we can't upgrade as it breaks other things.)
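The topology.py item above is about operation ordering inside remove_node: membership in _node_id_to_rx_id_map is checked before any connection state is touched. A minimal sketch of that check-before-remove behaviour follows; everything except the _node_id_to_rx_id_map name is an illustrative assumption, not the actual topology.py API.

```python
# Illustrative sketch only: remove_node checks membership first, so calls for
# unknown nodes are no-ops and never touch connection state. The connection
# bookkeeping here is invented for the example; exo's real Topology differs.
class Topology:
    def __init__(self) -> None:
        self._node_id_to_rx_id_map: dict[str, int] = {}
        self._connections: dict[str, set[str]] = {}

    def remove_node(self, node_id: str) -> None:
        if node_id not in self._node_id_to_rx_id_map:
            return  # unknown node: leave existing state untouched
        self._connections.pop(node_id, None)
        for peers in self._connections.values():
            peers.discard(node_id)
        del self._node_id_to_rx_id_map[node_id]
```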
@@ -31,6 +31,35 @@ enum NetworkSetupHelper {
        # Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
        /usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true

        networksetup -listlocations | grep -q exo || {
          networksetup -createlocation exo
        }

        networksetup -switchtolocation exo
        networksetup -listallhardwareports \\
          | awk -F': ' '/Hardware Port: / {print $2}' \\
          | while IFS=":" read -r name; do
          case "$name" in
            "Ethernet Adapter"*)
              ;;
            "Thunderbolt Bridge")
              ;;
            "Thunderbolt "*)
              networksetup -listallnetworkservices \\
                | grep -q "EXO $name" \\
                || networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
                || continue
              networksetup -setdhcp "EXO $name"
              ;;
            *)
              networksetup -listallnetworkservices \\
                | grep -q "$name" \\
                || networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
                || continue
              ;;
          esac
        done

        networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
          networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
        } || true
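After the helper above has run, the EXO-prefixed services it creates can be listed back with the same networksetup command the script uses. A small verification sketch (assumes macOS and that the helper has already executed; it only reads state):

```python
# Sketch: list network services whose names start with "EXO ", i.e. the
# per-Thunderbolt-port services created by the setup script above.
import subprocess

def exo_network_services() -> list[str]:
    out = subprocess.run(
        ["networksetup", "-listallnetworkservices"],
        capture_output=True, text=True, check=True,
    ).stdout
    # The first output line is a legend about disabled services, not a service name.
    return [line for line in out.splitlines()[1:] if line.startswith("EXO ")]

if __name__ == "__main__":
    for svc in exo_network_services():
        print(svc)
```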
bench/__init__.py (new file, 0 lines)
bench/eval_config.toml (new file, 66 lines)
@@ -0,0 +1,66 @@
# exo-eval configuration file
# See bench/exo_eval.py for usage

[eval]
# Eval framework type: "lm_eval" | "swe_bench" | "custom"
type = "lm_eval"
# Require HuggingFace token (default: true)
# Set to false if using only public datasets
require_hf_token = true

# Instance/placement configuration
# Controls how exo sets up the model instance before running evals
[instance]
# Placement strategy: "ring" | "jaccl" | "both"
instance_meta = "jaccl"
# Sharding strategy: "pipeline" | "tensor" | "both"
sharding = "tensor"
# Node constraints
min_nodes = 2
max_nodes = 2

# lm_eval configuration (EleutherAI's lm-evaluation-harness)
[lm_eval]
# Tasks to run (list of task names)
# NOTE: Chat completions API only supports generation-based tasks.
# Loglikelihood tasks (mmlu, hellaswag, arc) require /v1/completions endpoint.
#
# Generation-based tasks (work with chat completions):
# - mmlu_pro, mmlu_generative, mmlu_flan_cot_fewshot, mmlu_flan_cot_zeroshot
# - gsm8k, gsm8k_cot, gsm8k_cot_zeroshot
# - truthfulqa (uses generate_until for some subtasks)
# - humaneval, mbpp (code generation)
#
# Run `lm_eval --tasks list` to see all available tasks
tasks = ["mmlu_pro"]
# Number of few-shot examples (5 is standard for mmlu_pro CoT)
num_fewshot = 5
# Batch size (use 1 for API models, "auto" doesn't work)
batch_size = 1
# Number of concurrent requests (set > 1 to enable parallelism)
# Higher values enable better batching throughput
num_concurrent = 64
# Apply chat template for instruct/chat models (default: true)
apply_chat_template = true
# Use fewshot examples as conversation turns (better for chat models)
fewshot_as_multiturn = true
# Optional: limit samples per task (omit or comment out for no limit)
# limit = 100
# Output path for results
output_path = "bench/eval_results"

# SWE-bench configuration (placeholder)
[swe_bench]
# SWE-bench dataset
dataset = "princeton-nlp/SWE-bench_Lite"
# Maximum workers for parallel execution
max_workers = 8
# Path for prediction outputs
predictions_path = "bench/predictions"

# Custom evaluation script configuration
[custom]
# Path to custom evaluation script
script = "path/to/eval_script.py"
# Arguments to pass to the script
args = ["--arg1", "value1"]
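As a quick sanity check, this file can be loaded exactly the way bench/exo_eval.py's load_config does (a minimal sketch; tomlkit is the parser the harness itself uses, and the path is the one shipped in this change):

```python
# Sketch: parse bench/eval_config.toml like exo_eval.load_config and print
# the selected framework and task list.
from pathlib import Path
import tomlkit

config_path = Path("bench/eval_config.toml")
with open(config_path, encoding="utf-8") as f:
    config = dict(tomlkit.load(f))

print(config.get("eval", {}).get("type", "lm_eval"))     # "lm_eval"
print(config.get("lm_eval", {}).get("tasks", ["mmlu"]))  # ["mmlu_pro"]
```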
bench/exo_eval.py (new file, 679 lines)
@@ -0,0 +1,679 @@
|
||||
#!/usr/bin/env python3
|
||||
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
|
||||
"""
|
||||
exo-eval: Evaluation harness for exo inference system.
|
||||
|
||||
Supports multiple evaluation frameworks via TOML configuration:
|
||||
- lm_eval: Language model evaluation using EleutherAI's lm-evaluation-harness
|
||||
- swe_bench: SWE-bench evaluation (placeholder for future implementation)
|
||||
- custom: Custom evaluation scripts
|
||||
|
||||
Usage:
|
||||
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit
|
||||
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit --dry-run
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
# Add parent directory to path for direct script execution
|
||||
if __name__ == "__main__" and __package__ is None:
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
import tomlkit
|
||||
from huggingface_hub import get_token as get_hf_token
|
||||
from loguru import logger
|
||||
from tomlkit.exceptions import TOMLKitError
|
||||
|
||||
from bench.exo_bench import (
|
||||
ExoClient,
|
||||
ExoHttpError,
|
||||
instance_id_from_instance,
|
||||
nodes_used_in_instance,
|
||||
placement_filter,
|
||||
resolve_model_short_id,
|
||||
sharding_filter,
|
||||
wait_for_instance_gone,
|
||||
wait_for_instance_ready,
|
||||
)
|
||||
|
||||
EvalType = Literal["lm_eval", "swe_bench", "custom"]
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict[str, Any]:
|
||||
"""Load and parse TOML configuration file."""
|
||||
path = Path(config_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return dict(tomlkit.load(f))
|
||||
|
||||
|
||||
def get_eval_type(config: dict[str, Any]) -> EvalType:
|
||||
"""Extract evaluation type from config."""
|
||||
eval_section = config.get("eval", {})
|
||||
eval_type = eval_section.get("type", "lm_eval")
|
||||
if eval_type not in ("lm_eval", "swe_bench", "custom"):
|
||||
raise ValueError(f"Unknown eval type: {eval_type}")
|
||||
return eval_type
|
||||
|
||||
|
||||
def check_hf_token(config: dict[str, Any]) -> bool:
|
||||
"""Check if HuggingFace token is available when required.
|
||||
|
||||
Returns True if token is available or not required, False otherwise.
|
||||
"""
|
||||
eval_section = config.get("eval", {})
|
||||
require_hf_token = eval_section.get("require_hf_token", True)
|
||||
|
||||
if not require_hf_token:
|
||||
return True
|
||||
|
||||
token = get_hf_token()
|
||||
if token is None:
|
||||
logger.error(
|
||||
"HuggingFace token not found. "
|
||||
"Set HF_TOKEN environment variable or run 'huggingface-cli login'. "
|
||||
"To disable this check, set require_hf_token = false in [eval] config."
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info("HuggingFace token found")
|
||||
return True
|
||||
|
||||
|
||||
def select_placement(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
config: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
"""Select a placement based on config preferences."""
|
||||
instance_config = config.get("instance", {})
|
||||
|
||||
# If explicit instance is provided, use it directly
|
||||
if "instance" in instance_config:
|
||||
return instance_config["instance"]
|
||||
|
||||
# Otherwise, select from previews based on preferences
|
||||
instance_meta_pref = instance_config.get("instance_meta", "ring")
|
||||
sharding_pref = instance_config.get("sharding", "pipeline")
|
||||
max_nodes = instance_config.get("max_nodes", 4)
|
||||
min_nodes = instance_config.get("min_nodes", 1)
|
||||
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), instance_meta_pref):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), sharding_pref):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
if min_nodes <= n <= max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
return None
|
||||
|
||||
# Sort by preference: exact match on sharding/meta, then by node count (descending)
|
||||
def sort_key(p: dict[str, Any]) -> tuple[int, int, int]:
|
||||
meta_match = (
|
||||
1 if instance_meta_pref in str(p.get("instance_meta", "")).lower() else 0
|
||||
)
|
||||
sharding_match = 1 if sharding_pref in str(p.get("sharding", "")).lower() else 0
|
||||
n_nodes = nodes_used_in_instance(p["instance"])
|
||||
return (meta_match, sharding_match, n_nodes)
|
||||
|
||||
selected.sort(key=sort_key, reverse=True)
|
||||
return selected[0]
|
||||
|
||||
|
||||
def setup_instance(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
config: dict[str, Any],
|
||||
dry_run: bool,
|
||||
) -> tuple[str | None, dict[str, Any] | None]:
|
||||
"""Create and wait for an instance to be ready. Returns (instance_id, preview)."""
|
||||
preview = select_placement(client, full_model_id, config)
|
||||
|
||||
if preview is None:
|
||||
logger.error("No valid placement found matching config preferences")
|
||||
return None, None
|
||||
|
||||
instance_data = preview.get("instance")
|
||||
instance: dict[str, Any] = (
|
||||
instance_data if isinstance(instance_data, dict) else preview
|
||||
)
|
||||
instance_id = instance_id_from_instance(instance)
|
||||
|
||||
sharding = str(preview.get("sharding", "unknown"))
|
||||
instance_meta = str(preview.get("instance_meta", "unknown"))
|
||||
n_nodes = nodes_used_in_instance(instance)
|
||||
|
||||
logger.info(f"Selected placement: {sharding} / {instance_meta} / nodes={n_nodes}")
|
||||
logger.info(f"Instance ID: {instance_id}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] Would create instance and wait for ready")
|
||||
return instance_id, preview
|
||||
|
||||
# Create instance
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
|
||||
try:
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
logger.info("Instance is ready")
|
||||
time.sleep(1) # Brief pause after ready
|
||||
return instance_id, preview
|
||||
except (RuntimeError, TimeoutError) as e:
|
||||
logger.error(f"Failed to initialize instance: {e}")
|
||||
with contextlib.suppress(ExoHttpError):
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
return None, None
|
||||
|
||||
|
||||
def teardown_instance(client: ExoClient, instance_id: str) -> None:
|
||||
"""Delete an instance and wait for it to be gone."""
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
except ExoHttpError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
except (ConnectionRefusedError, OSError):
|
||||
logger.warning(
|
||||
f"Could not connect to exo to delete instance {instance_id} (server may be down)"
|
||||
)
|
||||
return
|
||||
try:
|
||||
wait_for_instance_gone(client, instance_id)
|
||||
except (ConnectionRefusedError, OSError, TimeoutError):
|
||||
logger.warning("Could not verify instance deletion (server may be down)")
|
||||
return
|
||||
logger.info(f"Instance {instance_id} deleted")
|
||||
|
||||
|
||||
def build_lm_eval_args(
|
||||
config: dict[str, Any],
|
||||
base_url: str,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
limit: int | None,
|
||||
use_completions: bool,
|
||||
) -> list[str]:
|
||||
"""Build command-line arguments for lm_eval."""
|
||||
lm_eval_config = config.get("lm_eval", {})
|
||||
|
||||
# Choose model type based on whether tasks need completions API
|
||||
if use_completions:
|
||||
model_type = "local-completions"
|
||||
endpoint_url = f"{base_url}/v1/completions"
|
||||
else:
|
||||
model_type = "local-chat-completions"
|
||||
endpoint_url = f"{base_url}/v1/chat/completions"
|
||||
|
||||
# Build model_args string with num_concurrent and timeout
|
||||
model_args_parts = [f"model={model}", f"base_url={endpoint_url}"]
|
||||
num_concurrent = lm_eval_config.get("num_concurrent")
|
||||
if num_concurrent is not None and num_concurrent > 1:
|
||||
model_args_parts.append(f"num_concurrent={num_concurrent}")
|
||||
# Use a very long timeout (1 week) to handle large request queues
|
||||
timeout = lm_eval_config.get("timeout", 604800)
|
||||
model_args_parts.append(f"timeout={timeout}")
|
||||
model_args = ",".join(model_args_parts)
|
||||
|
||||
args = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"bench.lm_eval_patched",
|
||||
"--model",
|
||||
model_type,
|
||||
"--model_args",
|
||||
model_args,
|
||||
"--verbosity",
|
||||
"WARNING",
|
||||
]
|
||||
|
||||
# Tasks
|
||||
tasks = lm_eval_config.get("tasks", ["mmlu"])
|
||||
tasks_str = ",".join(tasks) if isinstance(tasks, list) else str(tasks)
|
||||
args.extend(["--tasks", tasks_str])
|
||||
|
||||
# Few-shot
|
||||
num_fewshot = lm_eval_config.get("num_fewshot")
|
||||
if num_fewshot is not None:
|
||||
args.extend(["--num_fewshot", str(num_fewshot)])
|
||||
|
||||
# Batch size (default to 1 for API models, "auto" doesn't work)
|
||||
batch_size = lm_eval_config.get("batch_size", 1)
|
||||
args.extend(["--batch_size", str(batch_size)])
|
||||
|
||||
# Apply chat template for instruct/chat models (default: true)
|
||||
# Only applies to chat completions, but doesn't hurt to include
|
||||
apply_chat_template = lm_eval_config.get("apply_chat_template", True)
|
||||
if apply_chat_template and not use_completions:
|
||||
args.append("--apply_chat_template")
|
||||
|
||||
# Fewshot as multiturn (optional, works with chat template)
|
||||
fewshot_as_multiturn = lm_eval_config.get("fewshot_as_multiturn", False)
|
||||
if fewshot_as_multiturn and not use_completions:
|
||||
args.append("--fewshot_as_multiturn")
|
||||
|
||||
# Limit (command line overrides config)
|
||||
effective_limit = limit if limit is not None else lm_eval_config.get("limit")
|
||||
if effective_limit is not None:
|
||||
args.extend(["--limit", str(effective_limit)])
|
||||
|
||||
# Output path
|
||||
effective_output = output_path or lm_eval_config.get("output_path")
|
||||
if effective_output:
|
||||
args.extend(["--output_path", effective_output])
|
||||
# Log model responses for post-hoc analysis when output is saved
|
||||
args.append("--log_samples")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def run_lm_eval(
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
limit: int | None,
|
||||
dry_run: bool,
|
||||
) -> int:
|
||||
"""Run lm_eval evaluation."""
|
||||
lm_eval_config = config.get("lm_eval", {})
|
||||
tasks = lm_eval_config.get("tasks", ["mmlu"])
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
|
||||
exo_base_url = f"http://{host}:{port}"
|
||||
|
||||
# Build args - use native completions or chat completions endpoint directly
|
||||
args = build_lm_eval_args(
|
||||
config, exo_base_url, model, output_path, limit, use_completions=False
|
||||
)
|
||||
logger.info(f"lm_eval command: {' '.join(args)}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] Would execute the above command")
|
||||
return 0
|
||||
|
||||
try:
|
||||
result = subprocess.run(args, check=False)
|
||||
|
||||
# Print token usage summary from exo
|
||||
try:
|
||||
import httpx
|
||||
|
||||
usage_resp = httpx.get(f"{exo_base_url}/v1/usage", timeout=5)
|
||||
if usage_resp.status_code == 200:
|
||||
usage = usage_resp.json()
|
||||
logger.info("--- Token Usage (Total) ---")
|
||||
logger.info(f" Requests: {usage.get('total_requests', 0)}")
|
||||
logger.info(
|
||||
f" Prompt tokens: {usage.get('total_prompt_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Completion tokens: {usage.get('total_completion_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Reasoning tokens: {usage.get('total_reasoning_tokens', 0)}"
|
||||
)
|
||||
logger.info(f" Total tokens: {usage.get('total_tokens', 0)}")
|
||||
by_model = usage.get("by_model", {})
|
||||
if by_model:
|
||||
for model_name, counters in by_model.items():
|
||||
logger.info(f"--- Token Usage ({model_name}) ---")
|
||||
logger.info(
|
||||
f" Requests: {counters.get('requests', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Prompt tokens: {counters.get('prompt_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Completion tokens: {counters.get('completion_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Reasoning tokens: {counters.get('reasoning_tokens', 0)}"
|
||||
)
|
||||
except Exception:
|
||||
pass # Usage endpoint not available
|
||||
|
||||
return result.returncode
|
||||
except FileNotFoundError:
|
||||
logger.error("lm_eval not found. Install with: uv sync --extra eval")
|
||||
return 1
|
||||
|
||||
|
||||
def run_swe_bench(
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
dry_run: bool,
|
||||
) -> int:
|
||||
"""Run SWE-bench evaluation (placeholder)."""
|
||||
swe_config = config.get("swe_bench", {})
|
||||
|
||||
dataset = swe_config.get("dataset", "princeton-nlp/SWE-bench_Lite")
|
||||
max_workers = swe_config.get("max_workers", 8)
|
||||
predictions_path = output_path or swe_config.get(
|
||||
"predictions_path", "bench/predictions"
|
||||
)
|
||||
|
||||
logger.info("SWE-bench evaluation configuration:")
|
||||
logger.info(f" Dataset: {dataset}")
|
||||
logger.info(f" Model: {model}")
|
||||
logger.info(f" API endpoint: http://{host}:{port}/v1")
|
||||
logger.info(f" Max workers: {max_workers}")
|
||||
logger.info(f" Predictions path: {predictions_path}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] SWE-bench evaluation would be executed")
|
||||
return 0
|
||||
|
||||
logger.warning(
|
||||
"SWE-bench integration is a placeholder. "
|
||||
"Implement swebench inference and evaluation logic as needed."
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
def run_custom_eval(
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
dry_run: bool,
|
||||
) -> int:
|
||||
"""Run custom evaluation script."""
|
||||
custom_config = config.get("custom", {})
|
||||
|
||||
script = custom_config.get("script")
|
||||
if not script:
|
||||
logger.error("No script specified in [custom] config section")
|
||||
return 1
|
||||
|
||||
script_path = Path(script)
|
||||
if not script_path.exists():
|
||||
logger.error(f"Custom script not found: {script}")
|
||||
return 1
|
||||
|
||||
script_args = custom_config.get("args", [])
|
||||
if not isinstance(script_args, list):
|
||||
script_args = [str(script_args)]
|
||||
|
||||
# Build environment with exo connection info
|
||||
env = os.environ.copy()
|
||||
env["EXO_HOST"] = host
|
||||
env["EXO_PORT"] = str(port)
|
||||
env["EXO_MODEL"] = model
|
||||
if output_path:
|
||||
env["EXO_OUTPUT_PATH"] = output_path
|
||||
|
||||
cmd = [sys.executable, str(script_path), *script_args]
|
||||
logger.info(f"Custom eval command: {' '.join(cmd)}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] Would execute the above command")
|
||||
return 0
|
||||
|
||||
result = subprocess.run(cmd, env=env, check=False)
|
||||
return result.returncode
|
||||
|
||||
|
||||
def write_results_metadata(
|
||||
output_path: str,
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
eval_type: EvalType,
|
||||
return_code: int,
|
||||
preview: dict[str, Any] | None,
|
||||
) -> None:
|
||||
"""Write evaluation metadata to a JSON file."""
|
||||
metadata: dict[str, Any] = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"eval_type": eval_type,
|
||||
"model": model,
|
||||
"api_endpoint": f"http://{host}:{port}/v1",
|
||||
"config": config,
|
||||
"return_code": return_code,
|
||||
}
|
||||
|
||||
if preview:
|
||||
metadata["placement"] = {
|
||||
"sharding": preview.get("sharding"),
|
||||
"instance_meta": preview.get("instance_meta"),
|
||||
"instance_id": instance_id_from_instance(preview["instance"])
|
||||
if "instance" in preview
|
||||
else None,
|
||||
}
|
||||
|
||||
output_dir = Path(output_path)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
metadata_path = output_dir / "eval_metadata.json"
|
||||
|
||||
with open(metadata_path, "w", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False, default=str)
|
||||
|
||||
logger.info(f"Wrote evaluation metadata to: {metadata_path}")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Main entry point for exo-eval."""
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-eval",
|
||||
description="Evaluation harness for exo inference system.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--config",
|
||||
required=True,
|
||||
help="Path to TOML configuration file",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--host",
|
||||
default=os.environ.get("EXO_HOST", "localhost"),
|
||||
help="exo API host (default: localhost or EXO_HOST env var)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.environ.get("EXO_PORT", "52415")),
|
||||
help="exo API port (default: 52415 or EXO_PORT env var)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--model",
|
||||
required=True,
|
||||
help="Model name/ID to evaluate",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--output",
|
||||
default=None,
|
||||
help="Output path for results (overrides config)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Limit samples per task (overrides config, lm_eval only)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout",
|
||||
type=float,
|
||||
default=604800.0,
|
||||
help="HTTP timeout in seconds (default: 604800 = 1 week)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-instance-setup",
|
||||
action="store_true",
|
||||
help="Skip instance creation (assume instance already running)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--pipeline",
|
||||
type=int,
|
||||
default=None,
|
||||
metavar="N",
|
||||
help="Use pipeline sharding with exactly N nodes (overrides config)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta",
|
||||
choices=["ring", "jaccl", "both"],
|
||||
default=None,
|
||||
help="Instance meta preference (overrides config)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Print commands without executing",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
logger.info(f"exo-eval starting with config: {args.config}")
|
||||
|
||||
try:
|
||||
config = load_config(args.config)
|
||||
except FileNotFoundError as e:
|
||||
logger.error(str(e))
|
||||
return 1
|
||||
except TOMLKitError as e:
|
||||
logger.error(f"Failed to parse config: {e}")
|
||||
return 1
|
||||
|
||||
eval_type = get_eval_type(config)
|
||||
logger.info(f"Evaluation type: {eval_type}")
|
||||
logger.info(f"Model: {args.model}")
|
||||
logger.info(f"API endpoint: http://{args.host}:{args.port}/v1")
|
||||
|
||||
# Apply CLI overrides to instance config
|
||||
if args.pipeline is not None or args.instance_meta is not None:
|
||||
instance_config = config.setdefault("instance", {})
|
||||
if args.pipeline is not None:
|
||||
instance_config["sharding"] = "pipeline"
|
||||
instance_config["min_nodes"] = args.pipeline
|
||||
instance_config["max_nodes"] = args.pipeline
|
||||
logger.info(f"CLI override: pipeline={args.pipeline} nodes")
|
||||
# Limit concurrency for pipeline to avoid GPU timeouts
|
||||
if args.pipeline >= 2:
|
||||
lm_eval_config = config.setdefault("lm_eval", {})
|
||||
lm_eval_config["num_concurrent"] = 4
|
||||
logger.info("CLI override: num_concurrent=4 (pipeline>=2)")
|
||||
if args.instance_meta is not None:
|
||||
instance_config["instance_meta"] = args.instance_meta
|
||||
logger.info(f"CLI override: instance_meta={args.instance_meta}")
|
||||
|
||||
# Check HuggingFace token if required
|
||||
if not check_hf_token(config):
|
||||
return 1
|
||||
|
||||
# Setup instance and resolve model
|
||||
instance_id: str | None = None
|
||||
preview: dict[str, Any] | None = None
|
||||
client: ExoClient | None = None
|
||||
|
||||
if args.skip_instance_setup:
|
||||
# Use model name as-is when skipping instance setup
|
||||
full_model_id = args.model
|
||||
logger.info(f"Using model: {full_model_id} (instance setup skipped)")
|
||||
else:
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
|
||||
# Resolve model
|
||||
try:
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
logger.info(f"Resolved model: {short_id} -> {full_model_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to resolve model: {e}")
|
||||
return 1
|
||||
|
||||
instance_id, preview = setup_instance(
|
||||
client, full_model_id, config, args.dry_run
|
||||
)
|
||||
if instance_id is None and not args.dry_run:
|
||||
return 1
|
||||
|
||||
try:
|
||||
# Run evaluation
|
||||
if eval_type == "lm_eval":
|
||||
return_code = run_lm_eval(
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
args.output,
|
||||
args.limit,
|
||||
args.dry_run,
|
||||
)
|
||||
elif eval_type == "swe_bench":
|
||||
return_code = run_swe_bench(
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
args.output,
|
||||
args.dry_run,
|
||||
)
|
||||
elif eval_type == "custom":
|
||||
return_code = run_custom_eval(
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
args.output,
|
||||
args.dry_run,
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unknown eval type: {eval_type}")
|
||||
return 1
|
||||
|
||||
# Write metadata if output path specified and not dry-run
|
||||
output_path = args.output or config.get(eval_type, {}).get("output_path")
|
||||
if output_path and not args.dry_run:
|
||||
write_results_metadata(
|
||||
output_path,
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
eval_type,
|
||||
return_code,
|
||||
preview,
|
||||
)
|
||||
|
||||
return return_code
|
||||
|
||||
finally:
|
||||
# Teardown instance
|
||||
if instance_id and client and not args.skip_instance_setup and not args.dry_run:
|
||||
teardown_instance(client, instance_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
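For the defaults shipped in bench/eval_config.toml, build_lm_eval_args above assembles roughly the argument vector below. This is a sketch derived from the code, not its literal output: sys.executable is shown as "python", and the short model id (taken from the module docstring) is replaced by the resolved full id at runtime.

```python
# Roughly what build_lm_eval_args() returns for the shipped defaults
# (use_completions=False, num_concurrent=64, timeout=604800).
example_args = [
    "python", "-m", "bench.lm_eval_patched",
    "--model", "local-chat-completions",
    "--model_args",
    "model=Llama-3.2-1b-Instruct-4bit,"
    "base_url=http://localhost:52415/v1/chat/completions,"
    "num_concurrent=64,timeout=604800",
    "--verbosity", "WARNING",
    "--tasks", "mmlu_pro",
    "--num_fewshot", "5",
    "--batch_size", "1",
    "--apply_chat_template",
    "--fewshot_as_multiturn",
    "--output_path", "bench/eval_results",
    "--log_samples",
]
print(" ".join(example_args))
```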
bench/lm_eval_patched.py (new file, 145 lines)
@@ -0,0 +1,145 @@
|
||||
"""Patched lm_eval runner that fixes bugs in the upstream library.
|
||||
|
||||
Fixes:
|
||||
- UnboundLocalError on `outputs` in TemplateAPI.amodel_call when API returns error
|
||||
- Prevents eval crash on transient API failures (returns None instead of raising)
|
||||
- Compatibility with transformers 5.x (missing AutoModelForVision2Seq)
|
||||
- sock_read timeout causing connection drops with large request queues
|
||||
|
||||
Usage: python -m bench.lm_eval_patched [lm_eval args...]
|
||||
"""
|
||||
|
||||
# ruff: noqa: I001, E402
|
||||
# pyright: reportMissingTypeStubs=false, reportUnknownVariableType=false
|
||||
# pyright: reportUnknownMemberType=false, reportAny=false, reportUnknownArgumentType=false
|
||||
# pyright: reportPrivateUsage=false, reportUnknownLambdaType=false
|
||||
|
||||
# MUST patch transformers BEFORE any lm_eval imports
|
||||
# AutoModelForVision2Seq/AutoModelForImageTextToText were removed in transformers 5.0
|
||||
# Patch the lazy module's __getattr__ to return stubs for missing classes
|
||||
from transformers.utils import import_utils
|
||||
|
||||
_original_getattr = import_utils._LazyModule.__getattr__
|
||||
|
||||
|
||||
def _patched_getattr(self: object, name: str) -> object:
|
||||
if name in ("AutoModelForVision2Seq", "AutoModelForImageTextToText"):
|
||||
return type(name, (), {}) # Return a stub class
|
||||
return _original_getattr(self, name) # type: ignore
|
||||
|
||||
|
||||
import_utils._LazyModule.__getattr__ = _patched_getattr
|
||||
|
||||
import functools
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _patch_amodel_call() -> None:
|
||||
"""Monkey-patch TemplateAPI.amodel_call to handle the unbound `outputs` variable bug."""
|
||||
from lm_eval.models.api_models import TemplateAPI
|
||||
|
||||
original: Any = TemplateAPI.amodel_call
|
||||
|
||||
@functools.wraps(original)
|
||||
async def patched_amodel_call(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
return await original(self, *args, **kwargs)
|
||||
except (UnboundLocalError, Exception):
|
||||
# Return one empty-string result per request in the batch so the
|
||||
# reorderer doesn't assert on missing coverage.
|
||||
messages = kwargs.get("messages") or (args[2] if len(args) > 2 else [])
|
||||
return [""] * max(len(messages), 1)
|
||||
|
||||
TemplateAPI.amodel_call = patched_amodel_call
|
||||
|
||||
|
||||
def _patch_client_timeout() -> None:
|
||||
"""Patch TemplateAPI.get_batched_requests to disable sock_read timeout.
|
||||
|
||||
By default, aiohttp's ClientTimeout can have a sock_read timeout that causes
|
||||
connections to drop if no data is received for a while. With large request
|
||||
queues, requests may wait a long time before processing starts, causing
|
||||
spurious connection drops and retries that pile up requests.
|
||||
"""
|
||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
||||
|
||||
from lm_eval.models.api_models import TemplateAPI
|
||||
|
||||
original_get_batched: Any = TemplateAPI.get_batched_requests
|
||||
|
||||
@functools.wraps(original_get_batched)
|
||||
async def patched_get_batched_requests(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
# Override the timeout to explicitly disable sock_read timeout
|
||||
# This prevents connection drops when requests are queued for a long time
|
||||
original_timeout = getattr(self, "timeout", 604800)
|
||||
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
|
||||
timeout = ClientTimeout(
|
||||
total=original_timeout, sock_read=None, sock_connect=None
|
||||
)
|
||||
|
||||
async with ClientSession(connector=conn, timeout=timeout) as session:
|
||||
# Call the internal async logic with our session
|
||||
return await _run_batched_requests_with_session(
|
||||
self, session, *args, **kwargs
|
||||
)
|
||||
|
||||
async def _run_batched_requests_with_session(
|
||||
self: Any,
|
||||
session: ClientSession,
|
||||
requests: Any,
|
||||
cache_keys: Any = None,
|
||||
ctxlens: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from lm_eval.models.utils import chunks
|
||||
|
||||
eval_logger = logging.getLogger("lm_eval.models.api_models")
|
||||
ctxlens = ctxlens if ctxlens else [None] * len(requests)
|
||||
sem = asyncio.Semaphore(self._concurrent)
|
||||
|
||||
retry_: Any = retry(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential(multiplier=0.5, min=1, max=10),
|
||||
reraise=True,
|
||||
before_sleep=lambda retry_state: eval_logger.info(
|
||||
f"Retry attempt {retry_state.attempt_number}"
|
||||
),
|
||||
)(self.amodel_call)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(
|
||||
retry_(
|
||||
session=session,
|
||||
sem=sem,
|
||||
messages=message,
|
||||
cache_keys=cache_key,
|
||||
ctxlens=ctxlen,
|
||||
gen_kwargs=copy.deepcopy(kwargs.get("gen_kwargs")),
|
||||
**{k: v for k, v in kwargs.items() if k != "gen_kwargs"},
|
||||
)
|
||||
)
|
||||
for message, cache_key, ctxlen in zip(
|
||||
chunks(requests, n=self._batch_size),
|
||||
chunks(cache_keys, n=self._batch_size),
|
||||
chunks(ctxlens, n=self._batch_size),
|
||||
strict=True,
|
||||
)
|
||||
]
|
||||
|
||||
return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
|
||||
|
||||
TemplateAPI.get_batched_requests = patched_get_batched_requests
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_patch_amodel_call()
|
||||
_patch_client_timeout()
|
||||
from lm_eval.__main__ import cli_evaluate
|
||||
|
||||
cli_evaluate()
|
||||
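The timeout configuration that _patch_client_timeout installs can also be seen in isolation. A minimal standalone aiohttp sketch of the same ClientTimeout settings (this is plain aiohttp, not the lm_eval classes; the URL is just exo's local usage endpoint used elsewhere in this change):

```python
# Standalone sketch of the patched timeout policy: a long total timeout with
# sock_read and sock_connect disabled, so long-queued requests do not drop
# the connection mid-wait.
import asyncio
from aiohttp import ClientSession, ClientTimeout, TCPConnector

async def fetch(url: str, concurrency: int = 64) -> int:
    conn = TCPConnector(limit=concurrency)
    timeout = ClientTimeout(total=604800, sock_read=None, sock_connect=None)
    async with ClientSession(connector=conn, timeout=timeout) as session:
        async with session.get(url) as resp:
            return resp.status

if __name__ == "__main__":
    print(asyncio.run(fetch("http://localhost:52415/v1/usage")))
```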
bench/stats_dashboard.html (new file, 290 lines)
@@ -0,0 +1,290 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>exo Usage Stats</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'SF Mono', 'Menlo', monospace;
|
||||
background: #1a1a2e;
|
||||
color: #e0e0e0;
|
||||
padding: 24px;
|
||||
min-height: 100vh;
|
||||
}
|
||||
.header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 24px;
|
||||
padding-bottom: 16px;
|
||||
border-bottom: 1px solid #333;
|
||||
}
|
||||
.header h1 {
|
||||
font-size: 20px;
|
||||
font-weight: 600;
|
||||
color: #fff;
|
||||
}
|
||||
.status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 13px;
|
||||
color: #888;
|
||||
}
|
||||
.status-dot {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
background: #666;
|
||||
}
|
||||
.status-dot.connected { background: #4caf50; }
|
||||
.status-dot.error { background: #f44336; }
|
||||
.config {
|
||||
margin-bottom: 24px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
.config label {
|
||||
font-size: 12px;
|
||||
color: #888;
|
||||
}
|
||||
.config input {
|
||||
background: #252540;
|
||||
border: 1px solid #444;
|
||||
border-radius: 4px;
|
||||
color: #e0e0e0;
|
||||
padding: 4px 8px;
|
||||
font-size: 13px;
|
||||
font-family: inherit;
|
||||
width: 280px;
|
||||
}
|
||||
.section {
|
||||
background: #252540;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.section h2 {
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
color: #aaa;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.stat-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 16px;
|
||||
}
|
||||
.stat-card {
|
||||
background: #1a1a2e;
|
||||
border-radius: 6px;
|
||||
padding: 16px;
|
||||
}
|
||||
.stat-label {
|
||||
font-size: 11px;
|
||||
color: #888;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.stat-value {
|
||||
font-size: 28px;
|
||||
font-weight: 700;
|
||||
color: #fff;
|
||||
}
|
||||
.stat-rate {
|
||||
font-size: 12px;
|
||||
color: #4caf50;
|
||||
margin-top: 4px;
|
||||
}
|
||||
table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 13px;
|
||||
}
|
||||
th {
|
||||
text-align: left;
|
||||
padding: 8px 12px;
|
||||
color: #888;
|
||||
font-weight: 500;
|
||||
border-bottom: 1px solid #333;
|
||||
font-size: 11px;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
td {
|
||||
padding: 8px 12px;
|
||||
border-bottom: 1px solid #2a2a45;
|
||||
}
|
||||
td.num {
|
||||
text-align: right;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
.model-name {
|
||||
color: #7c9eff;
|
||||
max-width: 300px;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.empty-state {
|
||||
color: #666;
|
||||
font-style: italic;
|
||||
padding: 16px 0;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>exo Usage Stats</h1>
|
||||
<div class="status">
|
||||
<div class="status-dot" id="statusDot"></div>
|
||||
<span id="statusText">connecting...</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="config">
|
||||
<label for="baseUrl">Base URL:</label>
|
||||
<input type="text" id="baseUrl" value="http://mac8-1:52415">
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Totals</h2>
|
||||
<div class="stat-grid">
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Requests</div>
|
||||
<div class="stat-value" id="totalRequests">0</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Prompt Tokens</div>
|
||||
<div class="stat-value" id="totalPrompt">0</div>
|
||||
<div class="stat-rate" id="promptRate"></div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Completion Tokens</div>
|
||||
<div class="stat-value" id="totalCompletion">0</div>
|
||||
<div class="stat-rate" id="completionRate"></div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Reasoning Tokens</div>
|
||||
<div class="stat-value" id="totalReasoning">0</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Total Tokens</div>
|
||||
<div class="stat-value" id="totalTokens">0</div>
|
||||
<div class="stat-rate" id="totalRate"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Per-Model Breakdown</h2>
|
||||
<div id="modelTable">
|
||||
<div class="empty-state">No data yet</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
|
||||
function fmt(n) {
|
||||
return n.toLocaleString();
|
||||
}
|
||||
|
||||
// Track first non-zero timestamp for overall average rate
|
||||
let firstSeenTime = null;
|
||||
let firstSeenTokens = { prompt: 0, completion: 0, total: 0 };
|
||||
|
||||
function setRate(id, currentTokens, tokenType) {
|
||||
const el = document.getElementById(id);
|
||||
if (firstSeenTime === null || currentTokens <= firstSeenTokens[tokenType]) {
|
||||
el.textContent = '';
|
||||
return;
|
||||
}
|
||||
const elapsed = (performance.now() / 1000) - firstSeenTime;
|
||||
if (elapsed <= 0) { el.textContent = ''; return; }
|
||||
const delta = currentTokens - firstSeenTokens[tokenType];
|
||||
const avg = delta / elapsed;
|
||||
el.textContent = fmt(Math.round(avg)) + ' tok/s avg';
|
||||
}
|
||||
|
||||
function renderModelTable(byModel) {
|
||||
const container = document.getElementById('modelTable');
|
||||
const models = Object.entries(byModel);
|
||||
if (models.length === 0) {
|
||||
container.innerHTML = '<div class="empty-state">No data yet</div>';
|
||||
return;
|
||||
}
|
||||
let html = '<table><thead><tr>';
|
||||
html += '<th>Model</th><th style="text-align:right">Requests</th>';
|
||||
html += '<th style="text-align:right">Prompt</th>';
|
||||
html += '<th style="text-align:right">Completion</th>';
|
||||
html += '<th style="text-align:right">Reasoning</th>';
|
||||
html += '<th style="text-align:right">Total</th>';
|
||||
html += '</tr></thead><tbody>';
|
||||
for (const [name, counters] of models) {
|
||||
const total = (counters.prompt_tokens || 0) + (counters.completion_tokens || 0);
|
||||
html += '<tr>';
|
||||
html += `<td class="model-name" title="${name}">${name}</td>`;
|
||||
html += `<td class="num">${fmt(counters.requests || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(counters.prompt_tokens || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(counters.completion_tokens || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(counters.reasoning_tokens || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(total)}</td>`;
|
||||
html += '</tr>';
|
||||
}
|
||||
html += '</tbody></table>';
|
||||
container.innerHTML = html;
|
||||
}
|
||||
|
||||
async function poll() {
|
||||
const baseUrl = document.getElementById('baseUrl').value.replace(/\/+$/, '');
|
||||
const dot = document.getElementById('statusDot');
|
||||
const text = document.getElementById('statusText');
|
||||
|
||||
try {
|
||||
const resp = await fetch(baseUrl + '/v1/usage');
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`);
|
||||
const data = await resp.json();
|
||||
|
||||
dot.className = 'status-dot connected';
|
||||
text.textContent = 'connected';
|
||||
|
||||
|
||||
document.getElementById('totalRequests').textContent = fmt(data.total_requests || 0);
|
||||
document.getElementById('totalPrompt').textContent = fmt(data.total_prompt_tokens || 0);
|
||||
document.getElementById('totalCompletion').textContent = fmt(data.total_completion_tokens || 0);
|
||||
document.getElementById('totalReasoning').textContent = fmt(data.total_reasoning_tokens || 0);
|
||||
document.getElementById('totalTokens').textContent = fmt(data.total_tokens || 0);
|
||||
|
||||
// Record first non-zero reading as baseline
|
||||
if (firstSeenTime === null && (data.total_tokens || 0) > 0) {
|
||||
firstSeenTime = performance.now() / 1000;
|
||||
firstSeenTokens = {
|
||||
prompt: data.total_prompt_tokens || 0,
|
||||
completion: data.total_completion_tokens || 0,
|
||||
total: data.total_tokens || 0,
|
||||
};
|
||||
}
|
||||
|
||||
setRate('promptRate', data.total_prompt_tokens || 0, 'prompt');
|
||||
setRate('completionRate', data.total_completion_tokens || 0, 'completion');
|
||||
setRate('totalRate', data.total_tokens || 0, 'total');
|
||||
|
||||
renderModelTable(data.by_model || {});
|
||||
|
||||
} catch (e) {
|
||||
dot.className = 'status-dot error';
|
||||
text.textContent = e.message || 'error';
|
||||
}
|
||||
}
|
||||
|
||||
poll();
|
||||
setInterval(poll, 1000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
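The dashboard above is a 1-second poller over GET /v1/usage; the same counters can be pulled from a script. A sketch using httpx (which the eval harness already depends on); the base URL is the default host/port from exo_eval.py, and the 10-second window is arbitrary:

```python
# Sketch: take two /v1/usage snapshots and report an average completion-token
# rate, mirroring what the dashboard's setRate() computes in the browser.
import time
import httpx

BASE_URL = "http://localhost:52415"  # default exo host/port in exo_eval.py

def snapshot() -> dict:
    return httpx.get(f"{BASE_URL}/v1/usage", timeout=5).json()

first = snapshot()
time.sleep(10)
second = snapshot()

delta = second.get("total_completion_tokens", 0) - first.get("total_completion_tokens", 0)
print(f"requests: {second.get('total_requests', 0)}")
print(f"avg completion tok/s over 10s: {delta / 10:.1f}")
```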
@@ -3,12 +3,28 @@
  perSystem =
    { pkgs, lib, ... }:
    let
      # Filter source to ONLY include package.json and package-lock.json
      # This ensures prettier-svelte only rebuilds when lockfiles change
      dashboardLockfileSrc = lib.cleanSourceWith {
        src = inputs.self;
        filter =
          path: type:
          let
            baseName = builtins.baseNameOf path;
            isDashboardDir = baseName == "dashboard" && type == "directory";
            isPackageFile =
              (lib.hasInfix "/dashboard/" path || lib.hasSuffix "/dashboard" (builtins.dirOf path))
              && (baseName == "package.json" || baseName == "package-lock.json");
          in
          isDashboardDir || isPackageFile;
      };

      # Stub source with lockfiles and minimal files for build to succeed
      # This allows prettier-svelte to avoid rebuilding when dashboard source changes
      dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
        mkdir -p $out
        cp ${inputs.self}/dashboard/package.json $out/
        cp ${inputs.self}/dashboard/package-lock.json $out/
        cp ${dashboardLockfileSrc}/dashboard/package.json $out/
        cp ${dashboardLockfileSrc}/dashboard/package-lock.json $out/
        # Minimal files so vite build succeeds (produces empty output)
        echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
        mkdir -p $out/src
@@ -12,7 +12,6 @@
|
||||
ttftMs,
|
||||
tps,
|
||||
totalTokens,
|
||||
cancelRequest,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import ChatAttachments from "./ChatAttachments.svelte";
|
||||
import ImageParamsPanel from "./ImageParamsPanel.svelte";
|
||||
@@ -606,15 +605,37 @@
|
||||
style="min-height: 28px; max-height: 150px;"
|
||||
></textarea>
|
||||
|
||||
{#if loading}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => cancelRequest()}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/50 text-exo-light-gray border border-exo-medium-gray/50 hover:border-red-500/50 hover:text-red-400 cursor-pointer"
|
||||
>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || loading || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || loading || isEditOnlyWithoutImage
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={shouldShowEditMode
|
||||
? "Edit image"
|
||||
: isImageModel()
|
||||
? "Generate image"
|
||||
: "Send message"}
|
||||
>
|
||||
{#if loading}
|
||||
<span class="inline-flex items-center gap-1 sm:gap-2">
|
||||
<span
|
||||
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
|
||||
></span>
|
||||
<span class="hidden sm:inline"
|
||||
>{shouldShowEditMode
|
||||
? "EDITING"
|
||||
: isImageModel()
|
||||
? "GENERATING"
|
||||
: "PROCESSING"}</span
|
||||
>
|
||||
<span class="sm:hidden">...</span>
|
||||
</span>
|
||||
{:else if shouldShowEditMode}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3 h-3"
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
@@ -623,81 +644,47 @@
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span class="hidden sm:inline">CANCEL</span>
|
||||
<span class="sm:hidden">X</span>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
type="submit"
|
||||
disabled={!canSend || isEditOnlyWithoutImage}
|
||||
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
|
||||
{!canSend || isEditOnlyWithoutImage
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label={shouldShowEditMode
|
||||
? "Edit image"
|
||||
: isImageModel()
|
||||
? "Generate image"
|
||||
: "Send message"}
|
||||
>
|
||||
{#if shouldShowEditMode}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
</button>
|
||||
{/if}
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
|
||||
<circle cx="8.5" cy="8.5" r="1.5" />
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Bottom accent line -->
|
||||
|
||||
@@ -464,7 +464,6 @@ class AppStore {
|
||||
private previewsInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private lastConversationPersistTs = 0;
|
||||
private previousNodeIds: Set<string> = new Set();
|
||||
private activeAbortController: AbortController | null = null;
|
||||
|
||||
constructor() {
|
||||
if (browser) {
|
||||
@@ -1747,9 +1746,6 @@ class AppStore {
|
||||
const targetConversationId = this.activeConversationId;
|
||||
if (!targetConversationId) return;
|
||||
|
||||
this.activeAbortController = new AbortController();
|
||||
const signal = this.activeAbortController.signal;
|
||||
|
||||
this.isLoading = true;
|
||||
this.currentResponse = "";
|
||||
this.ttftMs = null;
|
||||
@@ -1884,7 +1880,6 @@ class AppStore {
|
||||
temperature: 0.7,
|
||||
stream: true,
|
||||
}),
|
||||
signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -1980,9 +1975,6 @@ class AppStore {
|
||||
this.persistConversation(targetConversationId);
|
||||
}
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
return;
|
||||
}
|
||||
console.error("Error sending message:", error);
|
||||
this.handleStreamingError(
|
||||
error,
|
||||
@@ -1991,7 +1983,6 @@ class AppStore {
|
||||
"Failed to get response",
|
||||
);
|
||||
} finally {
|
||||
this.activeAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.saveConversationsToStorage();
|
||||
@@ -2012,9 +2003,6 @@ class AppStore {
|
||||
const targetConversationId = this.activeConversationId;
|
||||
if (!targetConversationId) return;
|
||||
|
||||
this.activeAbortController = new AbortController();
|
||||
const signal = this.activeAbortController.signal;
|
||||
|
||||
this.isLoading = true;
|
||||
this.currentResponse = "";
|
||||
|
||||
@@ -2100,7 +2088,6 @@ class AppStore {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -2210,19 +2197,6 @@ class AppStore {
|
||||
},
|
||||
);
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
// Clean up the "Generating image..." message on cancellation
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "Cancelled";
|
||||
msg.attachments = [];
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
return;
|
||||
}
|
||||
console.error("Error generating image:", error);
|
||||
this.handleStreamingError(
|
||||
error,
|
||||
@@ -2231,7 +2205,6 @@ class AppStore {
|
||||
"Failed to generate image",
|
||||
);
|
||||
} finally {
|
||||
this.activeAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
@@ -2255,9 +2228,6 @@ class AppStore {
|
||||
const targetConversationId = this.activeConversationId;
|
||||
if (!targetConversationId) return;
|
||||
|
||||
this.activeAbortController = new AbortController();
|
||||
const signal = this.activeAbortController.signal;
|
||||
|
||||
this.isLoading = true;
|
||||
this.currentResponse = "";
|
||||
|
||||
@@ -2366,7 +2336,6 @@ class AppStore {
|
||||
const apiResponse = await fetch("/v1/images/edits", {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
signal,
|
||||
});
|
||||
|
||||
if (!apiResponse.ok) {
|
||||
@@ -2438,19 +2407,6 @@ class AppStore {
|
||||
},
|
||||
);
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
// Clean up the "Editing image..." message on cancellation
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "cancelled";
|
||||
msg.attachments = [];
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
return;
|
||||
}
|
||||
console.error("Error editing image:", error);
|
||||
this.handleStreamingError(
|
||||
error,
|
||||
@@ -2459,24 +2415,11 @@ class AppStore {
|
||||
"Failed to edit image",
|
||||
);
|
||||
} finally {
|
||||
this.activeAbortController = null;
|
||||
this.isLoading = false;
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cancel an in-flight request by aborting the active fetch
|
||||
*/
|
||||
cancelRequest(): void {
|
||||
if (this.activeAbortController) {
|
||||
this.activeAbortController.abort();
|
||||
this.activeAbortController = null;
|
||||
}
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear current chat and go back to welcome state
|
||||
*/
|
||||
@@ -2613,7 +2556,6 @@ export const editMessage = (messageId: string, newContent: string) =>
|
||||
export const editAndRegenerate = (messageId: string, newContent: string) =>
|
||||
appStore.editAndRegenerate(messageId, newContent);
|
||||
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
|
||||
export const cancelRequest = () => appStore.cancelRequest();
|
||||
|
||||
// Conversation actions
|
||||
export const conversations = () => appStore.conversations;
|
||||
|
||||
@@ -13,13 +13,14 @@ dependencies = [
    "filelock>=3.18.0",
    "rustworkx>=0.17.1",
    "huggingface-hub>=0.33.4",
    "typer", # for huggingface-cli
    "psutil>=7.0.0",
    "loguru>=0.7.3",
    "exo_pyo3_bindings", # rust bindings
    "anyio==4.11.0",
    "mlx==0.30.3; sys_platform == 'darwin'",
    "mlx[cpu]==0.30.3; sys_platform == 'linux'",
    "mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
    "mlx-lm==0.30.5",
    "tiktoken>=0.12.0", # required for kimi k2 tokenizer
    "hypercorn>=0.18.0",
    "openai-harmony>=0.0.8",
@@ -34,6 +35,7 @@ dependencies = [
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
exo-eval = "bench.exo_eval:main"

# dependencies only required for development
[dependency-groups]
@@ -51,6 +53,9 @@ dev = [
# cuda = [
# "mlx[cuda]==0.26.3",
# ]
eval = [
    "lm_eval[api]",
]

###
# workspace configuration
@@ -66,6 +71,7 @@ exo_pyo3_bindings = { workspace = true }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm.git", branch = "main" }

[build-system]
requires = ["uv_build>=0.8.9,<0.9.0"]

@@ -121,11 +121,20 @@ async def ensure_models_dir() -> Path:


async def delete_model(model_id: ModelId) -> bool:
    model_dir = await ensure_models_dir() / model_id.normalize()
    if not await aios.path.exists(model_dir):
        return False
    await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
    return True
    models_dir = await ensure_models_dir()
    model_dir = models_dir / model_id.normalize()
    cache_dir = models_dir / "caches" / model_id.normalize()

    deleted = False
    if await aios.path.exists(model_dir):
        await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
        deleted = True

    # Also clear cache
    if await aios.path.exists(cache_dir):
        await asyncio.to_thread(shutil.rmtree, cache_dir, ignore_errors=False)

    return deleted


async def seed_models(seed_dir: str | Path):
@@ -151,16 +160,28 @@ async def fetch_file_list_with_cache(
    target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
    await aios.makedirs(target_dir, exist_ok=True)
    cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
    if await aios.path.exists(cache_file):
        async with aiofiles.open(cache_file, "r") as f:
            return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
    file_list = await fetch_file_list_with_retry(
        model_id, revision, recursive=recursive
    )
    await aios.makedirs(cache_file.parent, exist_ok=True)
    async with aiofiles.open(cache_file, "w") as f:
        await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
    return file_list

    # Always try fresh first
    try:
        file_list = await fetch_file_list_with_retry(
            model_id, revision, recursive=recursive
        )
        # Update cache with fresh data
        async with aiofiles.open(cache_file, "w") as f:
            await f.write(
                TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
            )
        return file_list
    except Exception as e:
        # Fetch failed - try cache fallback
        if await aios.path.exists(cache_file):
            logger.warning(
                f"Failed to fetch file list for {model_id}, using cached data: {e}"
            )
            async with aiofiles.open(cache_file, "r") as f:
                return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
        # No cache available, propagate the error
        raise

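The hunk above switches fetch_file_list_with_cache from a cache-first lookup to a fresh-first strategy: fetch from the network, refresh the on-disk cache on success, and fall back to the cached file list only when the fetch raises. A minimal synchronous sketch of the same pattern, with illustrative names (fetch_with_cache_fallback and the fetch callable are not part of the codebase):

import json
import logging
from pathlib import Path
from typing import Callable

logger = logging.getLogger(__name__)


def fetch_with_cache_fallback(
    fetch: Callable[[], list[dict]],  # hypothetical network call
    cache_file: Path,
) -> list[dict]:
    """Try the network first; refresh the cache on success, fall back to it on failure."""
    try:
        data = fetch()
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        cache_file.write_text(json.dumps(data))
        return data
    except Exception as exc:
        if cache_file.exists():
            logger.warning("Fetch failed, using cached data: %s", exc)
            return json.loads(cache_file.read_text())
        raise  # no cache available, propagate the error
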
async def fetch_file_list_with_retry(
@@ -332,8 +353,28 @@ async def _download_file(
    target_dir: Path,
    on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
) -> Path:
    if await aios.path.exists(target_dir / path):
        return target_dir / path
    target_path = target_dir / path

    if await aios.path.exists(target_path):
        local_size = (await aios.stat(target_path)).st_size

        # Try to verify against remote, but allow offline operation
        try:
            remote_size, _ = await file_meta(model_id, revision, path)
            if local_size != remote_size:
                logger.info(
                    f"File {path} size mismatch (local={local_size}, remote={remote_size}), re-downloading"
                )
                await aios.remove(target_path)
            else:
                return target_path
        except Exception as e:
            # Offline or network error - trust local file
            logger.debug(
                f"Could not verify {path} against remote (offline?): {e}, using local file"
            )
            return target_path

    await aios.makedirs((target_dir / path).parent, exist_ok=True)
    length, etag = await file_meta(model_id, revision, path)
    remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -542,17 +583,26 @@ async def download_shard(
    async def on_progress_wrapper(
        file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
    ) -> None:
        start_time = (
            file_progress[file.path].start_time
            if file.path in file_progress
            else time.time()
        )
        downloaded_this_session = (
            file_progress[file.path].downloaded_this_session.in_bytes
            + (curr_bytes - file_progress[file.path].downloaded.in_bytes)
            if file.path in file_progress
            else curr_bytes
        previous_progress = file_progress.get(file.path)

        # Detect re-download: curr_bytes < previous downloaded means file was deleted and restarted
        is_redownload = (
            previous_progress is not None
            and curr_bytes < previous_progress.downloaded.in_bytes
        )

        if is_redownload or previous_progress is None:
            # Fresh download or re-download: reset tracking
            start_time = time.time()
            downloaded_this_session = curr_bytes
        else:
            # Continuing download: accumulate
            start_time = previous_progress.start_time
            downloaded_this_session = (
                previous_progress.downloaded_this_session.in_bytes
                + (curr_bytes - previous_progress.downloaded.in_bytes)
            )

        speed = (
            downloaded_this_session / (time.time() - start_time)
            if time.time() - start_time > 0
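The test file that follows exercises this progress logic; as a quick worked illustration of why the reset matters (all numbers are invented):

# Progress recorded before the file was deleted due to a size mismatch.
prev_downloaded = 1_500_000_000
prev_session = 0

# First callback of the re-download.
curr_bytes = 100_000

is_redownload = curr_bytes < prev_downloaded   # True, so tracking is reset
downloaded_this_session = curr_bytes           # 100_000

# Without the reset, the old accumulation formula would have yielded
# prev_session + (curr_bytes - prev_downloaded) = -1_499_900_000 bytes.
assert downloaded_this_session > 0
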
0
src/exo/download/tests/__init__.py
Normal file
451
src/exo/download/tests/test_download_verification.py
Normal file
@@ -0,0 +1,451 @@
|
||||
"""Tests for download verification and cache behavior."""
|
||||
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import pytest
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from exo.download.download_utils import (
|
||||
delete_model,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.downloads import FileListEntry, RepoFileDownloadProgress
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_id() -> ModelId:
|
||||
return ModelId("test-org/test-model")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
|
||||
"""Set up a temporary models directory for testing."""
|
||||
models_dir = tmp_path / "models"
|
||||
await aios.makedirs(models_dir, exist_ok=True)
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
yield models_dir
|
||||
|
||||
|
||||
class TestFileVerification:
|
||||
"""Tests for file size verification in _download_file."""
|
||||
|
||||
async def test_redownload_when_file_size_changes_upstream(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that files with mismatched sizes are re-downloaded."""
|
||||
# Import inside test to allow patching
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Create a local file with wrong size
|
||||
local_file = target_dir / "test.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"local content") # 13 bytes
|
||||
|
||||
remote_size = 1000 # Different from local
|
||||
remote_hash = "abc123"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(remote_size, remote_hash),
|
||||
) as mock_file_meta,
|
||||
patch(
|
||||
"exo.download.download_utils.create_http_session"
|
||||
) as mock_session_factory,
|
||||
):
|
||||
# Set up mock HTTP response for re-download
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.content.read = AsyncMock( # pyright: ignore[reportAny]
|
||||
side_effect=[b"x" * remote_size, b""]
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=mock_response
|
||||
)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=None
|
||||
)
|
||||
mock_session_factory.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=mock_session
|
||||
)
|
||||
mock_session_factory.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=None
|
||||
)
|
||||
|
||||
# Mock calc_hash to return the expected hash
|
||||
with patch(
|
||||
"exo.download.download_utils.calc_hash",
|
||||
new_callable=AsyncMock,
|
||||
return_value=remote_hash,
|
||||
):
|
||||
await _download_file(model_id, "main", "test.safetensors", target_dir)
|
||||
|
||||
# file_meta should be called twice: once for verification, once for download
|
||||
assert mock_file_meta.call_count == 2
|
||||
|
||||
async def test_skip_download_when_file_size_matches(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that files with matching sizes are not re-downloaded."""
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Create a local file
|
||||
local_file = target_dir / "test.safetensors"
|
||||
local_content = b"local content"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(local_content)
|
||||
|
||||
remote_size = len(local_content) # Same as local
|
||||
remote_hash = "abc123"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(remote_size, remote_hash),
|
||||
) as mock_file_meta,
|
||||
patch(
|
||||
"exo.download.download_utils.create_http_session"
|
||||
) as mock_session_factory,
|
||||
):
|
||||
result = await _download_file(
|
||||
model_id, "main", "test.safetensors", target_dir
|
||||
)
|
||||
|
||||
# Should return immediately without downloading
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_called_once()
|
||||
mock_session_factory.assert_not_called()
|
||||
|
||||
async def test_offline_fallback_uses_local_file(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that local files are used when network is unavailable."""
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Create a local file
|
||||
local_file = target_dir / "test.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"local content")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Network error"),
|
||||
),
|
||||
patch(
|
||||
"exo.download.download_utils.create_http_session"
|
||||
) as mock_session_factory,
|
||||
):
|
||||
result = await _download_file(
|
||||
model_id, "main", "test.safetensors", target_dir
|
||||
)
|
||||
|
||||
# Should return local file without attempting download
|
||||
assert result == local_file
|
||||
mock_session_factory.assert_not_called()
|
||||
|
||||
|
||||
class TestFileListCache:
|
||||
"""Tests for file list caching behavior."""
|
||||
|
||||
async def test_fetch_fresh_and_update_cache(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that fresh data is fetched and cache is updated."""
|
||||
models_dir = tmp_path / "models"
|
||||
|
||||
file_list = [
|
||||
FileListEntry(type="file", path="model.safetensors", size=1000),
|
||||
FileListEntry(type="file", path="config.json", size=100),
|
||||
]
|
||||
|
||||
with (
|
||||
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
|
||||
patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
return_value=file_list,
|
||||
) as mock_fetch,
|
||||
):
|
||||
result = await fetch_file_list_with_cache(model_id, "main")
|
||||
|
||||
assert result == file_list
|
||||
mock_fetch.assert_called_once()
|
||||
|
||||
# Verify cache was written
|
||||
cache_file = (
|
||||
models_dir
|
||||
/ "caches"
|
||||
/ model_id.normalize()
|
||||
/ f"{model_id.normalize()}--main--file_list.json"
|
||||
)
|
||||
assert await aios.path.exists(cache_file)
|
||||
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
cached_data = TypeAdapter(list[FileListEntry]).validate_json(
|
||||
await f.read()
|
||||
)
|
||||
assert cached_data == file_list
|
||||
|
||||
async def test_fallback_to_cache_when_fetch_fails(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that cached data is used when fetch fails."""
|
||||
models_dir = tmp_path / "models"
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Create cache file
|
||||
cached_file_list = [
|
||||
FileListEntry(type="file", path="model.safetensors", size=1000),
|
||||
]
|
||||
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(
|
||||
TypeAdapter(list[FileListEntry]).dump_json(cached_file_list).decode()
|
||||
)
|
||||
|
||||
with (
|
||||
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
|
||||
patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Network error"),
|
||||
),
|
||||
):
|
||||
result = await fetch_file_list_with_cache(model_id, "main")
|
||||
|
||||
assert result == cached_file_list
|
||||
|
||||
async def test_error_propagates_when_no_cache(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that errors propagate when fetch fails and no cache exists."""
|
||||
models_dir = tmp_path / "models"
|
||||
|
||||
with (
|
||||
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
|
||||
patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Network error"),
|
||||
),
|
||||
pytest.raises(Exception, match="Network error"),
|
||||
):
|
||||
await fetch_file_list_with_cache(model_id, "main")
|
||||
|
||||
|
||||
class TestModelDeletion:
|
||||
"""Tests for model deletion including cache cleanup."""
|
||||
|
||||
async def test_delete_model_clears_cache(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that deleting a model also deletes its cache."""
|
||||
models_dir = tmp_path / "models"
|
||||
model_dir = models_dir / model_id.normalize()
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
|
||||
# Create model and cache directories
|
||||
await aios.makedirs(model_dir, exist_ok=True)
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Add some files
|
||||
async with aiofiles.open(model_dir / "model.safetensors", "w") as f:
|
||||
await f.write("model data")
|
||||
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
|
||||
await f.write("[]")
|
||||
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
result = await delete_model(model_id)
|
||||
|
||||
assert result is True
|
||||
assert not await aios.path.exists(model_dir)
|
||||
assert not await aios.path.exists(cache_dir)
|
||||
|
||||
async def test_delete_model_only_cache_exists(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test deleting when only cache exists (model already deleted)."""
|
||||
models_dir = tmp_path / "models"
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
|
||||
# Only create cache directory
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
|
||||
await f.write("[]")
|
||||
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
result = await delete_model(model_id)
|
||||
|
||||
# Returns False because model dir didn't exist
|
||||
assert result is False
|
||||
# But cache should still be cleaned up
|
||||
assert not await aios.path.exists(cache_dir)
|
||||
|
||||
async def test_delete_nonexistent_model(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test deleting a model that doesn't exist."""
|
||||
models_dir = tmp_path / "models"
|
||||
await aios.makedirs(models_dir, exist_ok=True)
|
||||
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
result = await delete_model(model_id)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestProgressResetOnRedownload:
|
||||
"""Tests for progress tracking when files are re-downloaded."""
|
||||
|
||||
async def test_progress_resets_correctly_on_redownload(
|
||||
self, model_id: ModelId
|
||||
) -> None:
|
||||
"""Test that progress tracking resets when a file is re-downloaded.
|
||||
|
||||
When a file is deleted and re-downloaded (due to size mismatch),
|
||||
the progress tracking should reset rather than calculating negative
|
||||
downloaded_this_session values.
|
||||
"""
|
||||
# Simulate file_progress dict as it exists in download_shard
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = {}
|
||||
|
||||
# Initialize with old file progress (simulating existing large file)
|
||||
old_file_size = 1_500_000_000 # 1.5 GB
|
||||
file_progress["model.safetensors"] = RepoFileDownloadProgress(
|
||||
repo_id=model_id,
|
||||
repo_revision="main",
|
||||
file_path="model.safetensors",
|
||||
downloaded=Memory.from_bytes(old_file_size),
|
||||
downloaded_this_session=Memory.from_bytes(0),
|
||||
total=Memory.from_bytes(old_file_size),
|
||||
speed=0,
|
||||
eta=timedelta(0),
|
||||
status="not_started",
|
||||
start_time=time.time() - 10, # Started 10 seconds ago
|
||||
)
|
||||
|
||||
# Simulate the logic from on_progress_wrapper after re-download starts
|
||||
# This is the exact logic from the fixed on_progress_wrapper
|
||||
curr_bytes = 100_000 # 100 KB - new download just started
|
||||
previous_progress = file_progress.get("model.safetensors")
|
||||
|
||||
# Detect re-download: curr_bytes < previous downloaded
|
||||
is_redownload = (
|
||||
previous_progress is not None
|
||||
and curr_bytes < previous_progress.downloaded.in_bytes
|
||||
)
|
||||
|
||||
if is_redownload or previous_progress is None:
|
||||
# Fresh download or re-download: reset tracking
|
||||
start_time = time.time()
|
||||
downloaded_this_session = curr_bytes
|
||||
else:
|
||||
# Continuing download: accumulate
|
||||
start_time = previous_progress.start_time
|
||||
downloaded_this_session = (
|
||||
previous_progress.downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - previous_progress.downloaded.in_bytes)
|
||||
)
|
||||
|
||||
# Key assertions
|
||||
assert is_redownload is True, "Should detect re-download scenario"
|
||||
assert downloaded_this_session == curr_bytes, (
|
||||
"downloaded_this_session should equal curr_bytes on re-download"
|
||||
)
|
||||
assert downloaded_this_session > 0, (
|
||||
"downloaded_this_session should be positive, not negative"
|
||||
)
|
||||
|
||||
# Calculate speed (should be positive)
|
||||
elapsed = time.time() - start_time
|
||||
speed = downloaded_this_session / elapsed if elapsed > 0 else 0
|
||||
assert speed >= 0, "Speed should be non-negative"
|
||||
|
||||
async def test_progress_accumulates_on_continuing_download(
|
||||
self, model_id: ModelId
|
||||
) -> None:
|
||||
"""Test that progress accumulates correctly for continuing downloads.
|
||||
|
||||
When a download continues from where it left off (resume),
|
||||
the progress should accumulate correctly.
|
||||
"""
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = {}
|
||||
|
||||
# Initialize with partial download progress
|
||||
initial_downloaded = 500_000 # 500 KB already downloaded
|
||||
start_time = time.time() - 5 # Started 5 seconds ago
|
||||
file_progress["model.safetensors"] = RepoFileDownloadProgress(
|
||||
repo_id=model_id,
|
||||
repo_revision="main",
|
||||
file_path="model.safetensors",
|
||||
downloaded=Memory.from_bytes(initial_downloaded),
|
||||
downloaded_this_session=Memory.from_bytes(initial_downloaded),
|
||||
total=Memory.from_bytes(1_000_000),
|
||||
speed=100_000,
|
||||
eta=timedelta(seconds=5),
|
||||
status="in_progress",
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
# Progress callback with more bytes downloaded
|
||||
curr_bytes = 600_000 # 600 KB - continuing download
|
||||
previous_progress = file_progress.get("model.safetensors")
|
||||
|
||||
# This is NOT a re-download (curr_bytes > previous downloaded)
|
||||
is_redownload = (
|
||||
previous_progress is not None
|
||||
and curr_bytes < previous_progress.downloaded.in_bytes
|
||||
)
|
||||
|
||||
if is_redownload or previous_progress is None:
|
||||
downloaded_this_session = curr_bytes
|
||||
used_start_time = time.time()
|
||||
else:
|
||||
used_start_time = previous_progress.start_time
|
||||
downloaded_this_session = (
|
||||
previous_progress.downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - previous_progress.downloaded.in_bytes)
|
||||
)
|
||||
|
||||
# Key assertions
|
||||
assert is_redownload is False, (
|
||||
"Should NOT detect re-download for continuing download"
|
||||
)
|
||||
assert used_start_time == start_time, "Should preserve original start_time"
|
||||
expected_session = initial_downloaded + (curr_bytes - initial_downloaded)
|
||||
assert downloaded_this_session == expected_session, (
|
||||
f"Should accumulate: {downloaded_this_session} == {expected_session}"
|
||||
)
|
||||
assert downloaded_this_session == 600_000, (
|
||||
"downloaded_this_session should equal total downloaded so far"
|
||||
)
|
||||
@@ -1,10 +1,11 @@
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated, Literal, cast
|
||||
from typing import Annotated, Any, Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
@@ -42,6 +43,7 @@ from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
CompletionTokensDetails,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteDownloadResponse,
|
||||
@@ -57,6 +59,8 @@ from exo.shared.types.api import (
|
||||
ImageGenerationTaskParams,
|
||||
ImageListItem,
|
||||
ImageListResponse,
|
||||
Logprobs,
|
||||
LogprobsContentItem,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -66,8 +70,10 @@ from exo.shared.types.api import (
|
||||
StartDownloadResponse,
|
||||
StreamingChoiceResponse,
|
||||
ToolCall,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.chunks import (
|
||||
CompletionChunk,
|
||||
ErrorChunk,
|
||||
ImageChunk,
|
||||
InputImageChunk,
|
||||
@@ -88,7 +94,6 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
StartDownload,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
@@ -108,14 +113,43 @@ from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
_THINK_TAG_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
|
||||
|
||||
|
||||
def _strip_think_tags(text: str) -> str:
|
||||
"""Strip <think>...</think> blocks from response text.
|
||||
|
||||
These tags are an artifact of GPT-OSS channel parsing, not part of the
|
||||
model's intended output. The OpenAI API content field should not contain them.
|
||||
"""
|
||||
return _THINK_TAG_RE.sub("", text).lstrip()
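A quick illustration of the helper's behaviour, restating the two lines from the hunk above so the snippet runs on its own (the sample string is invented):

import re

_THINK_TAG_RE = re.compile(r"<think>.*?</think>", re.DOTALL)


def _strip_think_tags(text: str) -> str:
    return _THINK_TAG_RE.sub("", text).lstrip()


raw = "<think>scratch work that should not reach the client</think>The answer is 4."
assert _strip_think_tags(raw) == "The answer is 4."
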
|
||||
|
||||
|
||||
def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None) -> str:
|
||||
return f"image/{image_format or 'png'}"
|
||||
|
||||
|
||||
def _build_logprobs(chunk: TokenChunk) -> Logprobs:
|
||||
"""Convert flat logprob fields to OpenAI Logprobs format."""
|
||||
return Logprobs(
|
||||
content=[
|
||||
LogprobsContentItem(
|
||||
token=chunk.text,
|
||||
logprob=chunk.logprob if chunk.logprob is not None else 0.0,
|
||||
bytes=list(chunk.text.encode("utf-8")),
|
||||
top_logprobs=chunk.top_logprobs or [],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
|
||||
) -> ChatCompletionResponse:
|
||||
logprobs: Logprobs | None = None
|
||||
if isinstance(chunk, TokenChunk) and chunk.logprob is not None:
|
||||
logprobs = _build_logprobs(chunk)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
@@ -136,6 +170,7 @@ def chunk_to_response(
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
],
|
||||
),
|
||||
logprobs=logprobs,
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
],
|
||||
@@ -198,7 +233,8 @@ class API:
|
||||
)
|
||||
|
||||
self._chat_completion_queues: dict[
|
||||
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
|
||||
CommandId,
|
||||
Sender[TokenChunk | ErrorChunk | ToolCallChunk | CompletionChunk],
|
||||
] = {}
|
||||
self._image_generation_queues: dict[
|
||||
CommandId, Sender[ImageChunk | ErrorChunk]
|
||||
@@ -206,6 +242,9 @@ class API:
|
||||
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
# Accumulated usage stats per instance (keyed by model id)
|
||||
self._usage_by_model: dict[str, dict[str, int]] = {}
|
||||
|
||||
def reset(self, new_session_id: SessionId, result_clock: int):
|
||||
logger.info("Resetting API State")
|
||||
self.state = State()
|
||||
@@ -272,6 +311,48 @@ class API:
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
|
||||
self.app.get("/v1/usage")(self.get_usage)
|
||||
|
||||
def get_usage(self) -> dict[str, Any]:
|
||||
"""Return accumulated token usage per model instance."""
|
||||
total_requests = 0
|
||||
total_prompt = 0
|
||||
total_completion = 0
|
||||
total_reasoning = 0
|
||||
for counters in self._usage_by_model.values():
|
||||
total_requests += counters.get("requests", 0)
|
||||
total_prompt += counters.get("prompt_tokens", 0)
|
||||
total_completion += counters.get("completion_tokens", 0)
|
||||
total_reasoning += counters.get("reasoning_tokens", 0)
|
||||
return {
|
||||
"total_requests": total_requests,
|
||||
"total_prompt_tokens": total_prompt,
|
||||
"total_completion_tokens": total_completion,
|
||||
"total_reasoning_tokens": total_reasoning,
|
||||
"total_tokens": total_prompt + total_completion,
|
||||
"by_model": self._usage_by_model,
|
||||
}
|
||||
|
||||
def _accumulate_usage(
|
||||
self,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
reasoning_tokens: int,
|
||||
) -> None:
|
||||
"""Accumulate usage stats for a model instance."""
|
||||
if model not in self._usage_by_model:
|
||||
self._usage_by_model[model] = {
|
||||
"requests": 0,
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"reasoning_tokens": 0,
|
||||
}
|
||||
counters = self._usage_by_model[model]
|
||||
counters["requests"] += 1
|
||||
counters["prompt_tokens"] += prompt_tokens
|
||||
counters["completion_tokens"] += completion_tokens
|
||||
counters["reasoning_tokens"] += reasoning_tokens
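Assuming the node's HTTP API is reachable locally, the new endpoint can be polled as sketched below; the address is a placeholder and the response keys follow get_usage above.

import json
import urllib.request

# Placeholder address; substitute the actual exo API host and port.
with urllib.request.urlopen("http://localhost:8000/v1/usage") as resp:
    usage = json.load(resp)

print(usage["total_requests"], usage["total_tokens"])
for model, counters in usage["by_model"].items():
    print(model, counters["prompt_tokens"], counters["completion_tokens"])
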
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
command = PlaceInstance(
|
||||
@@ -493,30 +574,40 @@ class API:
|
||||
)
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
self, command_id: CommandId, timeout: float = 60000.0
|
||||
) -> AsyncGenerator[TokenChunk | ErrorChunk | ToolCallChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion.
|
||||
|
||||
Args:
|
||||
timeout: Max seconds to wait for the next chunk before aborting.
|
||||
"""
|
||||
|
||||
try:
|
||||
self._chat_completion_queues[command_id], recv = channel[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk
|
||||
TokenChunk | ErrorChunk | ToolCallChunk
|
||||
]()
|
||||
|
||||
with recv as token_chunks:
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
with anyio.fail_after(timeout):
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
f"Chat completion timed out after {timeout}s (command_id={command_id})"
|
||||
)
|
||||
yield ErrorChunk(
|
||||
model=ModelId("unknown"),
|
||||
finish_reason="error",
|
||||
error_message=f"Request timed out after {timeout}s",
|
||||
)
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
command = TaskFinished(finished_command_id=command_id)
|
||||
await self._send(command)
|
||||
if command_id in self._chat_completion_queues:
|
||||
del self._chat_completion_queues[command_id]
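The timeout added to _chat_chunk_stream relies on anyio.fail_after, which cancels the pending receive and raises TimeoutError once the deadline passes. A standalone sketch of that pattern (stream contents and timeouts are illustrative):

import anyio


async def next_item_or_none(receive_stream, timeout: float):
    # Wait at most `timeout` seconds for the next item, mirroring the hunk above.
    try:
        with anyio.fail_after(timeout):
            return await receive_stream.receive()
    except TimeoutError:
        return None  # caller decides how to surface the timeout


async def main() -> None:
    send, receive = anyio.create_memory_object_stream[str](1)
    async with send, receive:
        await send.send("hello")
        print(await next_item_or_none(receive, timeout=1.0))  # -> hello
        print(await next_item_or_none(receive, timeout=0.1))  # -> None (timed out)


anyio.run(main)
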
|
||||
|
||||
@@ -527,7 +618,7 @@ class API:
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
assert not isinstance(chunk, ImageChunk)
|
||||
if chunk.finish_reason == "error":
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=chunk.error_message or "Internal server error",
|
||||
@@ -547,6 +638,15 @@ class API:
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
# Accumulate usage stats from the final chunk
|
||||
if isinstance(chunk, TokenChunk) and chunk.stats is not None:
|
||||
s = chunk.stats
|
||||
self._accumulate_usage(
|
||||
model=chunk.model,
|
||||
prompt_tokens=s.prompt_tokens,
|
||||
completion_tokens=s.generation_tokens,
|
||||
reasoning_tokens=s.reasoning_tokens,
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def _collect_chat_completion(
|
||||
@@ -556,10 +656,14 @@ class API:
|
||||
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
logprobs_items: list[LogprobsContentItem] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
# Skip CompletionChunk - it's for the legacy completions API
|
||||
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
@@ -571,6 +675,16 @@ class API:
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.stats is not None:
|
||||
stats = chunk.stats
|
||||
if chunk.logprob is not None:
|
||||
lp = _build_logprobs(chunk)
|
||||
if lp.content:
|
||||
if len(lp.content) != 1:
|
||||
logger.warning(
|
||||
f"Expected 1 logprobs content item per chunk, got {len(lp.content)}"
|
||||
)
|
||||
logprobs_items.append(lp.content[0])
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
@@ -585,9 +699,33 @@ class API:
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
combined_text = _strip_think_tags("".join(text_parts))
|
||||
assert model is not None
|
||||
|
||||
logprobs: Logprobs | None = None
|
||||
if logprobs_items:
|
||||
logprobs = Logprobs(content=logprobs_items)
|
||||
|
||||
usage: Usage | None = None
|
||||
if stats is not None:
|
||||
completion_tokens = stats.generation_tokens
|
||||
usage = Usage(
|
||||
prompt_tokens=stats.prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=stats.prompt_tokens + completion_tokens,
|
||||
completion_tokens_details=CompletionTokensDetails(
|
||||
reasoning_tokens=stats.reasoning_tokens,
|
||||
)
|
||||
if stats.reasoning_tokens > 0
|
||||
else None,
|
||||
)
|
||||
self._accumulate_usage(
|
||||
model=model or "unknown",
|
||||
prompt_tokens=stats.prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
reasoning_tokens=stats.reasoning_tokens,
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
@@ -600,9 +738,11 @@ class API:
|
||||
content=combined_text,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def _collect_chat_completion_with_stats(
|
||||
@@ -616,7 +756,7 @@ class API:
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
@@ -627,6 +767,7 @@ class API:
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
stats = chunk.stats or stats
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
@@ -637,13 +778,12 @@ class API:
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
stats = chunk.stats or stats
|
||||
stats = chunk.stats or stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
combined_text = _strip_think_tags("".join(text_parts))
|
||||
assert model is not None
|
||||
|
||||
resp = BenchChatCompletionResponse(
|
||||
@@ -694,7 +834,14 @@ class API:
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
return await self._collect_chat_completion(command.command_id)
|
||||
try:
|
||||
return await self._collect_chat_completion(command.command_id)
|
||||
except BaseException:
|
||||
# Ensure task cleanup if handler is cancelled before _chat_chunk_stream's finally runs
|
||||
with contextlib.suppress(Exception):
|
||||
await self._send(TaskFinished(finished_command_id=command.command_id))
|
||||
self._chat_completion_queues.pop(command.command_id, None)
|
||||
raise
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionTaskParams
|
||||
@@ -900,11 +1047,6 @@ class API:
|
||||
del image_metadata[key]
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
@@ -986,11 +1128,6 @@ class API:
|
||||
|
||||
return (images, stats if capture_stats else None)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
|
||||
@@ -13,6 +13,7 @@ from exo.master.placement import (
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Completion,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
@@ -21,7 +22,6 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
SendInputChunk,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
)
|
||||
@@ -36,12 +36,14 @@ from exo.shared.types.events import (
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion as ChatCompletionTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
Completion as CompletionTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ImageEdits as ImageEditsTask,
|
||||
)
|
||||
@@ -160,6 +162,48 @@ class Master:
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case Completion():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
for task in self.state.tasks.values()
|
||||
if task.instance_id == instance.instance_id
|
||||
)
|
||||
instance_task_counts[instance.instance_id] = (
|
||||
task_count
|
||||
)
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
instance_task_counts.keys(),
|
||||
key=lambda instance_id: instance_task_counts[
|
||||
instance_id
|
||||
],
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=CompletionTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
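The Completion branch above routes the task to the instance that currently serves the fewest tasks for the requested model. A toy illustration of that selection (instance ids and counts are invented):

# Hypothetical task counts for instances serving the requested model.
instance_task_counts = {"instance-a": 3, "instance-b": 1, "instance-c": 2}

# Same rule as above: order by active task count, pick the least loaded.
available_instance_ids = sorted(
    instance_task_counts.keys(),
    key=lambda instance_id: instance_task_counts[instance_id],
)
assert available_instance_ids[0] == "instance-b"
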
|
||||
case ImageGeneration():
|
||||
for instance in self.state.instances.values():
|
||||
@@ -248,7 +292,7 @@ class Master:
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case PlaceInstance():
|
||||
@@ -260,7 +304,7 @@ class Master:
|
||||
self.state.node_network,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case CreateInstance():
|
||||
@@ -270,7 +314,7 @@ class Master:
|
||||
self.state.instances,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement, self.state.tasks
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
@@ -280,29 +324,16 @@ class Master:
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
case TaskCancelled():
|
||||
if (
|
||||
task_id := self.command_task_mapping.get(
|
||||
command.cancelled_command_id
|
||||
)
|
||||
) is not None:
|
||||
generated_events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task_id,
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
task_id=self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
)
|
||||
)
|
||||
self.command_task_mapping.pop(
|
||||
task_id = self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
if task_id is not None:
|
||||
generated_events.append(TaskDeleted(task_id=task_id))
|
||||
else:
|
||||
logger.debug(
|
||||
f"TaskFinished for unknown command_id={command.finished_command_id} (already cleaned up)"
|
||||
)
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
for i in range(command.since_idx, len(self._event_log)):
|
||||
|
||||
@@ -20,15 +20,9 @@ from exo.shared.types.commands (
    PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
    Event,
    InstanceCreated,
    InstanceDeleted,
    TaskStatusUpdated,
)
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import (
    Instance,
    InstanceId,
@@ -186,7 +180,6 @@ def delete_instance(
def get_transition_events(
    current_instances: Mapping[InstanceId, Instance],
    target_instances: Mapping[InstanceId, Instance],
    tasks: Mapping[TaskId, Task],
) -> Sequence[Event]:
    events: list[Event] = []

@@ -202,18 +195,6 @@ def get_transition_events(
    # find instances to delete
    for instance_id in current_instances:
        if instance_id not in target_instances:
            for task in tasks.values():
                if task.instance_id == instance_id and task.task_status in [
                    TaskStatus.Pending,
                    TaskStatus.Running,
                ]:
                    events.append(
                        TaskStatusUpdated(
                            task_status=TaskStatus.Cancelled,
                            task_id=task.task_id,
                        )
                    )

            events.append(
                InstanceDeleted(
                    instance_id=instance_id,

@@ -413,9 +413,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
),
|
||||
}
|
||||
|
||||
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
"flux1-schnell": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
model_id=ModelId("exolabs/FLUX.1-schnell"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
@@ -428,7 +428,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
@@ -442,7 +442,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
@@ -457,7 +457,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
model_id=ModelId("exolabs/FLUX.1-dev"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
@@ -470,7 +470,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
@@ -484,7 +484,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
@@ -499,7 +499,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"flux1-krea-dev": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-Krea-dev"),
|
||||
model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
|
||||
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
@@ -541,9 +541,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
model_id=ModelId("exolabs/Qwen-Image"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
n_layers=60,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
@@ -551,10 +551,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
@@ -575,9 +575,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
n_layers=60,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
@@ -585,10 +585,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
@@ -610,6 +610,92 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _generate_image_model_quant_variants(
|
||||
base_name: str,
|
||||
base_card: ModelCard,
|
||||
) -> dict[str, ModelCard]:
|
||||
"""Create quantized variants of an image model card.
|
||||
|
||||
Only the transformer component is quantized; text encoders stay at bf16.
|
||||
Sizes are calculated exactly from the base card's component sizes.
|
||||
"""
|
||||
if base_card.components is None:
|
||||
raise ValueError(f"Image model {base_name} must have components defined")
|
||||
|
||||
# quantizations = [8, 6, 5, 4, 3]
|
||||
quantizations = [8, 4]
|
||||
|
||||
num_transformer_bytes = next(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name == "transformer"
|
||||
)
|
||||
|
||||
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
|
||||
|
||||
remaining_bytes = Memory.from_bytes(
|
||||
sum(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name != "transformer"
|
||||
)
|
||||
)
|
||||
|
||||
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
|
||||
assert base_card.components is not None
|
||||
return [
|
||||
ComponentInfo(
|
||||
component_name=c.component_name,
|
||||
component_path=c.component_path,
|
||||
storage_size=new_size
|
||||
if c.component_name == "transformer"
|
||||
else c.storage_size,
|
||||
n_layers=c.n_layers,
|
||||
can_shard=c.can_shard,
|
||||
safetensors_index_filename=c.safetensors_index_filename,
|
||||
)
|
||||
for c in base_card.components
|
||||
]
|
||||
|
||||
variants = {
|
||||
base_name: ModelCard(
|
||||
model_id=base_card.model_id,
|
||||
storage_size=transformer_bytes + remaining_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(transformer_bytes),
|
||||
)
|
||||
}
|
||||
|
||||
for quant in quantizations:
|
||||
quant_transformer_bytes = Memory.from_bytes(
|
||||
(num_transformer_bytes * quant) // 16
|
||||
)
|
||||
total_bytes = remaining_bytes + quant_transformer_bytes
|
||||
|
||||
model_id = ModelId(base_card.model_id + f"-{quant}bit")
|
||||
|
||||
variants[f"{base_name}-{quant}bit"] = ModelCard(
|
||||
model_id=model_id,
|
||||
storage_size=total_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(quant_transformer_bytes),
|
||||
)
|
||||
|
||||
return variants
|
||||
|
||||
|
||||
_image_model_cards: dict[str, ModelCard] = {}
|
||||
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
|
||||
_image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
|
||||
_IMAGE_MODEL_CARDS = _image_model_cards
|
||||
|
||||
if EXO_ENABLE_IMAGE_MODELS:
|
||||
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
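As a worked example of the size arithmetic in _generate_image_model_quant_variants, using the FLUX.1-schnell figures from the cards above (only the transformer is scaled by quant/16; the remaining components keep their bf16 sizes):

transformer_bytes = 23_782_357_120   # transformer component at bf16
other_bytes = 9_524_621_312          # remaining components, never quantized

quant = 4
quant_transformer_bytes = (transformer_bytes * quant) // 16   # 5_945_589_280
total_bytes = other_bytes + quant_transformer_bytes           # 15_470_210_592
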
|
||||
|
||||
|
||||
@@ -98,6 +98,8 @@ class LogprobsContentItem(BaseModel):

class Logprobs(BaseModel):
    content: list[LogprobsContentItem] | None = None
    # This will always be null for open source models, but exists for OpenAI API
    refusal: list[LogprobsContentItem] | None = None


class PromptTokensDetails(BaseModel):
@@ -150,6 +152,7 @@ class GenerationStats(BaseModel):
    generation_tps: float
    prompt_tokens: int
    generation_tokens: int
    reasoning_tokens: int = 0
    peak_memory_usage: Memory


@@ -170,6 +173,52 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
    generation_stats: GenerationStats | None = None


# Legacy Completions API types (for lm_eval compatibility)
class CompletionLogprobs(BaseModel):
    """Logprobs in the legacy completions format."""

    tokens: list[str]
    token_logprobs: list[float | None]
    top_logprobs: list[dict[str, float]]
    text_offset: list[int]


class CompletionChoice(BaseModel):
    text: str
    index: int
    logprobs: CompletionLogprobs | None = None
    finish_reason: FinishReason | None = None


class CompletionResponse(BaseModel):
    id: str
    object: Literal["text_completion"] = "text_completion"
    created: int
    model: str
    choices: list[CompletionChoice]
    usage: Usage | None = None


class CompletionTaskParams(BaseModel):
    """Parameters for the legacy /v1/completions endpoint."""

    model: str
    # Prompt can be: string, list of strings, list of token IDs, or list of token ID lists
    prompt: str | list[str] | list[int] | list[list[int]]
    max_tokens: int | None = 16
    temperature: float | None = 1.0
    top_p: float | None = 1.0
    n: int | None = 1
    stream: bool = False
    logprobs: int | None = None
    echo: bool = False
    stop: str | list[str] | None = None
    presence_penalty: float | None = None
    frequency_penalty: float | None = None
    seed: int | None = None
    user: str | None = None


class ChatCompletionTaskParams(BaseModel):
    model: str
    frequency_penalty: float | None = None

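For reference, a request against the legacy endpoint described by CompletionTaskParams could look like the sketch below; the host, port and model name are placeholders rather than values taken from this diff.

import json
import urllib.request

payload = {
    "model": "example-model",            # placeholder model id
    "prompt": "The capital of France is",
    "max_tokens": 16,
    "temperature": 0.0,
    "logprobs": 5,
    "echo": True,                        # lm_eval-style scoring of prompt tokens
}

req = urllib.request.Request(
    "http://localhost:8000/v1/completions",  # assumed local API address
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["text"])
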
@@ -2,7 +2,7 @@ from collections.abc import Generator
from typing import Any, Literal

from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats, ImageGenerationStats
from exo.shared.types.api import GenerationStats, ImageGenerationStats, TopLogprobItem
from exo.utils.pydantic_ext import TaggedModel

from .api import FinishReason
@@ -17,6 +17,8 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
    text: str
    token_id: int
    logprob: float | None = None
    top_logprobs: list[TopLogprobItem] | None = None
    finish_reason: Literal["stop", "length", "content_filter"] | None = None
    stats: GenerationStats | None = None

@@ -32,6 +34,17 @@ class ToolCallChunk(BaseChunk):
    stats: GenerationStats | None = None


class CompletionChunk(BaseChunk):
    """Chunk for legacy completions API with full logprobs for all tokens."""

    text: str
    tokens: list[str]
    token_logprobs: list[float | None]
    top_logprobs: list[dict[str, float]]
    text_offset: list[int]
    finish_reason: FinishReason | None = None


class ImageChunk(BaseChunk):
    data: str
    chunk_index: int
@@ -67,4 +80,4 @@ class InputImageChunk(BaseChunk):
        yield name, value


GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
GenerationChunk = TokenChunk | CompletionChunk | ImageChunk | ToolCallChunk | ErrorChunk

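A CompletionChunk for a short echoed prompt could be populated as sketched below; the values are invented and only illustrate how the parallel lists line up.

# Field-by-field illustration of CompletionChunk's payload (not real model output).
completion_chunk = {
    "text": "Hello world",
    "tokens": ["Hello", " world"],
    "token_logprobs": [None, -1.37],   # the first token has no conditioning context
    "top_logprobs": [{}, {" world": -1.37, " there": -2.05}],
    "text_offset": [0, 5],             # character offset of each token within `text`
    "finish_reason": "stop",
}
assert len(completion_chunk["tokens"]) == len(completion_chunk["token_logprobs"])
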
@@ -3,6 +3,7 @@ from pydantic import Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import (
    ChatCompletionTaskParams,
    CompletionTaskParams,
    ImageEditsInternalParams,
    ImageGenerationTaskParams,
)
@@ -25,6 +26,12 @@ class ChatCompletion(BaseCommand):
    request_params: ChatCompletionTaskParams


class Completion(BaseCommand):
    """Legacy completions API command for scoring/generation."""

    request_params: CompletionTaskParams


class ImageGeneration(BaseCommand):
    request_params: ImageGenerationTaskParams

@@ -48,10 +55,6 @@ class DeleteInstance(BaseCommand):
    instance_id: InstanceId


class TaskCancelled(BaseCommand):
    cancelled_command_id: CommandId


class TaskFinished(BaseCommand):
    finished_command_id: CommandId

@@ -83,12 +86,12 @@ Command = (
    TestCommand
    | RequestEventLog
    | ChatCompletion
    | Completion
    | ImageGeneration
    | ImageEdits
    | PlaceInstance
    | CreateInstance
    | DeleteInstance
    | TaskCancelled
    | TaskFinished
    | SendInputChunk
)

src/exo/shared/types/mlx.py (new file, 12 lines)
@@ -0,0 +1,12 @@
"""Shared types for MLX-related functionality."""

from collections.abc import Sequence

from mlx_lm.models.cache import (
    KVCache,
    QuantizedKVCache,
    RotatingKVCache,
)

# This list contains one cache entry per transformer layer
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]
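A small sketch of building a cache list that satisfies KVCacheType; the layer count is illustrative, and the QuantizedKVCache arguments mirror the ones used elsewhere in this changeset.

from mlx_lm.models.cache import KVCache, QuantizedKVCache

from exo.shared.types.mlx import KVCacheType

num_layers = 32  # assumed for illustration
plain_cache: KVCacheType = [KVCache() for _ in range(num_layers)]
quantized_cache: KVCacheType = [
    QuantizedKVCache(group_size=32, bits=4) for _ in range(num_layers)
]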
@@ -4,6 +4,7 @@ from pydantic import Field

from exo.shared.types.api import (
    ChatCompletionTaskParams,
    CompletionTaskParams,
    ImageEditsInternalParams,
    ImageGenerationTaskParams,
)
@@ -24,7 +25,6 @@ class TaskStatus(str, Enum):
    Complete = "Complete"
    TimedOut = "TimedOut"
    Failed = "Failed"
    Cancelled = "Cancelled"


class BaseTask(TaggedModel):
@@ -61,8 +61,14 @@ class ChatCompletion(BaseTask):  # emitted by Master
    error_message: str | None = Field(default=None)


class CancelTask(BaseTask):
    cancelled_task_id: TaskId
class Completion(BaseTask):
    """Legacy completions task for scoring tokens with echo=True."""

    command_id: CommandId
    task_params: CompletionTaskParams

    error_type: str | None = Field(default=None)
    error_message: str | None = Field(default=None)


class ImageGeneration(BaseTask):  # emitted by Master
@@ -92,7 +98,7 @@ Task = (
    | LoadModel
    | StartWarmup
    | ChatCompletion
    | CancelTask
    | Completion
    | ImageGeneration
    | ImageEdits
    | Shutdown

@@ -6,6 +6,7 @@ from exo.shared.types.api import (
    GenerationStats,
    ImageGenerationStats,
    ToolCallItem,
    TopLogprobItem,
)
from exo.utils.pydantic_ext import TaggedModel

@@ -14,14 +15,11 @@ class BaseRunnerResponse(TaggedModel):
    pass


class TokenizedResponse(BaseRunnerResponse):
    prompt_tokens: int


class GenerationResponse(BaseRunnerResponse):
    text: str
    token: int
    # logprobs: list[float] | None = None  # too big. we can change to be top-k
    logprob: float | None = None
    top_logprobs: list[TopLogprobItem] | None = None
    finish_reason: FinishReason | None = None
    stats: GenerationStats | None = None

@@ -194,6 +194,22 @@ class MpReceiver[T]:
            raise EndOfStream from None
        return item

    def receive_with_timeout(self, timeout: float) -> T | None:
        """Receive with timeout, returns None if no message within timeout."""
        if self._state.closed.is_set():
            raise ClosedResourceError

        try:
            item = self._state.buffer.get(block=True, timeout=timeout)
            if isinstance(item, _MpEndOfStream):
                self.close()
                raise EndOfStream
            return item
        except Empty:
            return None
        except ValueError as e:
            raise ClosedResourceError from e

    # nb: this function will not cancel particularly well
    async def receive_async(self) -> T:
        return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))

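A hypothetical polling loop over an MpReceiver (not part of the diff); the receiver object, element type, and poll interval are illustrative.

def drain(receiver: "MpReceiver[int]") -> list[int]:
    items: list[int] = []
    while True:
        try:
            item = receiver.receive_with_timeout(timeout=0.1)
        except EndOfStream:
            break  # the sender closed the channel
        if item is None:
            continue  # timed out with no message; keep polling
        items.append(item)
    return items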
@@ -349,13 +349,8 @@ class InfoGatherer:
    async def _monitor_misc(self):
        if self.misc_poll_interval is None:
            return
        prev = await MiscData.gather()
        await self.info_sender.send(prev)
        while True:
            curr = await MiscData.gather()
            if prev != curr:
                prev = curr
                await self.info_sender.send(curr)
            await self.info_sender.send(await MiscData.gather())
            await anyio.sleep(self.misc_poll_interval)

    async def _monitor_system_profiler_thunderbolt_data(self):
@@ -365,15 +360,12 @@ class InfoGatherer:
        if iface_map is None:
            return

        old_idents = []
        while True:
            data = await ThunderboltConnectivity.gather()
            assert data is not None

            idents = [it for i in data if (it := i.ident(iface_map)) is not None]
            if idents != old_idents:
                await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
                old_idents = idents
            await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))

            conns = [it for i in data if (it := i.conn()) is not None]
            await self.info_sender.send(MacThunderboltConnections(conns=conns))
@@ -398,22 +390,17 @@ class InfoGatherer:
    async def _watch_system_info(self):
        if self.interface_watcher_interval is None:
            return
        old_nics = []
        while True:
            nics = await get_network_interfaces()
            if nics != old_nics:
                old_nics = nics
                await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
            await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
            await anyio.sleep(self.interface_watcher_interval)

    async def _monitor_thunderbolt_bridge_status(self):
        if self.thunderbolt_bridge_poll_interval is None:
            return
        prev: ThunderboltBridgeInfo | None = None
        while True:
            curr = await ThunderboltBridgeInfo.gather()
            if curr is not None and prev != curr:
                prev = curr
            if curr is not None:
                await self.info_sender.send(curr)
            await anyio.sleep(self.thunderbolt_bridge_poll_interval)

@@ -1,4 +1,4 @@
|
||||
from collections.abc import Callable, Generator
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
@@ -109,7 +109,6 @@ class DistributedImageModel:
|
||||
image_path: Path | None = None,
|
||||
partial_images: int = 0,
|
||||
advanced_params: AdvancedImageParams | None = None,
|
||||
cancel_checker: Callable[[], bool] | None = None,
|
||||
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
|
||||
if (
|
||||
advanced_params is not None
|
||||
@@ -154,7 +153,6 @@ class DistributedImageModel:
|
||||
guidance_override=guidance_override,
|
||||
negative_prompt=negative_prompt,
|
||||
num_sync_steps=num_sync_steps,
|
||||
cancel_checker=cancel_checker,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (GeneratedImage, partial_index, total_partials)
|
||||
|
||||
@@ -3,7 +3,6 @@ import io
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Generator, Literal
|
||||
|
||||
@@ -69,18 +68,12 @@ def warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:
|
||||
def generate_image(
|
||||
model: DistributedImageModel,
|
||||
task: ImageGenerationTaskParams | ImageEditsInternalParams,
|
||||
cancel_checker: Callable[[], bool] | None = None,
|
||||
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
|
||||
"""Generate image(s), optionally yielding partial results.
|
||||
|
||||
When partial_images > 0 or stream=True, yields PartialImageResponse for
|
||||
intermediate images, then ImageGenerationResponse for the final image.
|
||||
|
||||
Args:
|
||||
model: The distributed image model to use for generation.
|
||||
task: The task parameters for image generation or editing.
|
||||
cancel_checker: Optional callback to check if generation should be cancelled.
|
||||
|
||||
Yields:
|
||||
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
|
||||
ImageGenerationResponse for final complete images
|
||||
@@ -130,7 +123,6 @@ def generate_image(
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
advanced_params=advanced_params,
|
||||
cancel_checker=cancel_checker,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from collections.abc import Callable
|
||||
from math import ceil
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -95,8 +94,6 @@ class DiffusionRunner:
|
||||
self.total_layers = config.total_blocks
|
||||
|
||||
self._guidance_override: float | None = None
|
||||
self._cancel_checker: Callable[[], bool] | None = None
|
||||
self._cancelling = False
|
||||
|
||||
self._compute_assigned_blocks()
|
||||
|
||||
@@ -151,54 +148,6 @@ class DiffusionRunner:
|
||||
return self._guidance_override
|
||||
return self.config.guidance_scale
|
||||
|
||||
def _check_cancellation(self) -> bool:
|
||||
if self._cancelling:
|
||||
return True
|
||||
if (
|
||||
self.is_first_stage
|
||||
and self._cancel_checker is not None
|
||||
and self._cancel_checker()
|
||||
):
|
||||
self._cancelling = True
|
||||
return self._cancelling
|
||||
|
||||
def _is_sentinel(self, tensor: mx.array) -> bool:
|
||||
return bool(mx.all(mx.isnan(tensor)).item())
|
||||
|
||||
def _make_sentinel_like(self, tensor: mx.array) -> mx.array:
|
||||
return mx.full(tensor.shape, float("nan"), dtype=tensor.dtype)
|
||||
|
||||
def _recv(
|
||||
self,
|
||||
shape: tuple[int, ...],
|
||||
dtype: mx.Dtype,
|
||||
src: int,
|
||||
) -> mx.array:
|
||||
"""Receive data and check for cancellation sentinel."""
|
||||
data = mx.distributed.recv(shape, dtype, src, group=self.group)
|
||||
mx.eval(data)
|
||||
if self._is_sentinel(data):
|
||||
self._cancelling = True
|
||||
return data
|
||||
|
||||
def _recv_like(self, template: mx.array, src: int) -> mx.array:
|
||||
"""Receive data matching template and check for cancellation sentinel."""
|
||||
data = mx.distributed.recv_like(template, src=src, group=self.group)
|
||||
mx.eval(data)
|
||||
if self._is_sentinel(data):
|
||||
self._cancelling = True
|
||||
return data
|
||||
|
||||
def _send(self, data: mx.array, dst: int) -> mx.array:
|
||||
"""Send data, or sentinel if cancelling."""
|
||||
|
||||
if self._cancelling:
|
||||
data = self._make_sentinel_like(data)
|
||||
|
||||
result = mx.distributed.send(data, dst, group=self.group)
|
||||
mx.async_eval(result)
|
||||
return result
|
||||
|
||||
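A standalone sketch of the NaN-sentinel convention that the removed _is_sentinel/_make_sentinel_like helpers implemented: an all-NaN tensor of the expected shape stands in for the real payload to signal "cancelled" to the next pipeline stage. Shapes and dtype below are illustrative.

import mlx.core as mx

payload = mx.zeros((2, 4), dtype=mx.float16)
sentinel = mx.full(payload.shape, float("nan"), dtype=payload.dtype)

assert bool(mx.all(mx.isnan(sentinel)).item())     # detected as a cancellation sentinel
assert not bool(mx.all(mx.isnan(payload)).item())  # a normal payload passes through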
def _ensure_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
@@ -295,7 +244,6 @@ class DiffusionRunner:
|
||||
guidance_override: float | None = None,
|
||||
negative_prompt: str | None = None,
|
||||
num_sync_steps: int = 1,
|
||||
cancel_checker: Callable[[], bool] | None = None,
|
||||
):
|
||||
"""Primary entry point for image generation.
|
||||
|
||||
@@ -307,21 +255,17 @@ class DiffusionRunner:
|
||||
5. Decode to image
|
||||
|
||||
Args:
|
||||
runtime_config: Runtime configuration (steps, height, width)
|
||||
settings: Generation config (steps, height, width)
|
||||
prompt: Text prompt
|
||||
seed: Random seed
|
||||
partial_images: Number of intermediate images to yield (0 for none)
|
||||
guidance_override: Optional override for guidance scale (CFG)
|
||||
negative_prompt: Optional negative prompt for CFG
|
||||
num_sync_steps: Number of synchronous pipeline steps
|
||||
cancel_checker: Optional callback to check for cancellation
|
||||
|
||||
Yields:
|
||||
Partial images as (GeneratedImage, partial_index, total_partials) tuples
|
||||
Final GeneratedImage
|
||||
"""
|
||||
self._guidance_override = guidance_override
|
||||
self._cancel_checker = cancel_checker
|
||||
latents = self.adapter.create_latents(seed, runtime_config)
|
||||
prompt_data = self.adapter.encode_prompt(prompt, negative_prompt)
|
||||
|
||||
@@ -363,7 +307,7 @@ class DiffusionRunner:
|
||||
except StopIteration as e:
|
||||
latents = e.value # pyright: ignore[reportAny]
|
||||
|
||||
if self.is_last_stage and not self._cancelling:
|
||||
if self.is_last_stage:
|
||||
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt) # pyright: ignore[reportAny]
|
||||
|
||||
def _run_diffusion_loop(
|
||||
@@ -379,7 +323,6 @@ class DiffusionRunner:
|
||||
if capture_steps is None:
|
||||
capture_steps = set()
|
||||
|
||||
self._cancelling = False
|
||||
self._reset_all_caches()
|
||||
|
||||
time_steps = tqdm(range(runtime_config.num_inference_steps))
|
||||
@@ -402,13 +345,9 @@ class DiffusionRunner:
|
||||
num_sync_steps=num_sync_steps,
|
||||
)
|
||||
|
||||
if self._cancelling:
|
||||
break
|
||||
|
||||
ctx.in_loop( # pyright: ignore[reportAny]
|
||||
t=t,
|
||||
latents=latents,
|
||||
time_steps=time_steps,
|
||||
)
|
||||
|
||||
mx.eval(latents)
|
||||
@@ -417,7 +356,7 @@ class DiffusionRunner:
|
||||
yield (latents, t)
|
||||
|
||||
except KeyboardInterrupt: # noqa: PERF203
|
||||
ctx.interruption(t=t, latents=latents, time_steps=time_steps) # pyright: ignore[reportAny]
|
||||
ctx.interruption(t=t, latents=latents) # pyright: ignore[reportAny]
|
||||
raise StopImageGenerationException(
|
||||
f"Stopping image generation at step {t + 1}/{len(time_steps)}"
|
||||
) from None
|
||||
@@ -627,8 +566,6 @@ class DiffusionRunner:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(encoder_hidden_states_mask)
|
||||
|
||||
self._check_cancellation()
|
||||
|
||||
encoder_hidden_states: mx.array | None = None
|
||||
if self.is_first_stage:
|
||||
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
@@ -648,12 +585,19 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
hidden_states = self._recv(
|
||||
(batch_size, num_img_tokens, hidden_dim), dtype, self.prev_rank
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
encoder_hidden_states = self._recv(
|
||||
(batch_size, text_seq_len, hidden_dim), dtype, self.prev_rank
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
assert self.joint_block_wrappers is not None
|
||||
assert encoder_hidden_states is not None
|
||||
@@ -675,20 +619,30 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
hidden_states = concatenated
|
||||
else:
|
||||
concatenated = self._send(concatenated, self.next_rank)
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
hidden_states = self._send(hidden_states, self.next_rank)
|
||||
encoder_hidden_states = self._send(encoder_hidden_states, self.next_rank)
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
hidden_states = self._recv(
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
for wrapper in self.single_block_wrappers:
|
||||
@@ -700,7 +654,10 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
hidden_states = self._send(hidden_states, self.next_rank)
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
hidden_states = hidden_states[:, text_seq_len:, ...]
|
||||
|
||||
@@ -784,13 +741,14 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_first_stage:
|
||||
hidden_states = self._send(hidden_states, 0)
|
||||
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
elif self.is_first_stage:
|
||||
hidden_states = self._recv_like(prev_latents, src=self.world_size - 1)
|
||||
|
||||
if self._cancelling:
|
||||
return prev_latents
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
else:
|
||||
hidden_states = prev_latents
|
||||
@@ -850,9 +808,10 @@ class DiffusionRunner:
|
||||
and not self.is_last_stage
|
||||
and not is_first_async_step
|
||||
):
|
||||
patch = self._recv_like(patch, src=self.prev_rank)
|
||||
|
||||
self._check_cancellation()
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
|
||||
|
||||
@@ -883,19 +842,10 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_first_stage and t != config.num_inference_steps - 1:
|
||||
patch_latents[patch_idx] = self._send(
|
||||
patch_latents[patch_idx], self.next_rank
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
)
|
||||
|
||||
# Drain final rank patch sends if cancelling
|
||||
if (
|
||||
self._cancelling
|
||||
and self.is_first_stage
|
||||
and not self.is_last_stage
|
||||
and t != config.num_inference_steps - 1
|
||||
):
|
||||
for patch_idx in range(len(patch_latents)):
|
||||
_ = self._recv_like(patch_latents[patch_idx], src=self.prev_rank)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
|
||||
return mx.concatenate(patch_latents, axis=1)
|
||||
|
||||
@@ -934,16 +884,22 @@ class DiffusionRunner:
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
patch_len = patch.shape[1]
|
||||
patch = self._recv(
|
||||
(batch_size, patch_len, hidden_dim), patch.dtype, self.prev_rank
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
encoder_hidden_states = self._recv(
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
|
||||
if self.is_first_stage:
|
||||
patch, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
@@ -968,25 +924,32 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
patch = patch_concat
|
||||
else:
|
||||
patch_concat = self._send(patch_concat, self.next_rank)
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
patch = self._send(patch, self.next_rank)
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
assert encoder_hidden_states is not None
|
||||
encoder_hidden_states = self._send(
|
||||
encoder_hidden_states, self.next_rank
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
patch_len = patch.shape[1]
|
||||
patch = self._recv(
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
for wrapper in self.single_block_wrappers:
|
||||
@@ -998,7 +961,8 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
patch = self._send(patch, self.next_rank)
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
noise: mx.array | None = None
|
||||
if self.is_last_stage:
|
||||
|
||||
@@ -13,12 +13,17 @@ from mlx.nn.layers.distributed import (
|
||||
shard_linear,
|
||||
sum_gradients,
|
||||
)
|
||||
from mlx_lm.models.base import (
|
||||
scaled_dot_product_attention, # pyright: ignore[reportUnknownVariableType]
|
||||
)
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
|
||||
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
|
||||
from mlx_lm.models.glm4_moe import MoE
|
||||
from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
|
||||
from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
|
||||
from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
@@ -27,7 +32,8 @@ from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
|
||||
|
||||
from exo.shared.logging import logger
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
@@ -100,6 +106,16 @@ class CustomMlxLayer(nn.Module):
|
||||
return getattr(original_layer, name)
|
||||
|
||||
|
||||
class EvalCheckpointLayer(CustomMlxLayer):
|
||||
"""Wraps a layer to force evaluation of its output, breaking up the computation graph
|
||||
to prevent Metal command buffer timeouts with large batches in pipeline parallel."""
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
output = self.original_layer(x, *args, **kwargs)
|
||||
mx.eval(output)
|
||||
return output
|
||||
|
||||
|
||||
class PipelineFirstLayer(CustomMlxLayer):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -137,14 +153,20 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
).arguments.get("cache", None)
|
||||
|
||||
output: mx.array = self.original_layer(x, *args, **kwargs)
|
||||
mx.eval(output)
|
||||
|
||||
if self.r != self.s - 1:
|
||||
output = mx.distributed.send(
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
mx.async_eval(output)
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
output = mx.distributed.all_gather(output, group=self.group)[
|
||||
-output.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -195,6 +217,9 @@ def pipeline_auto_parallel(
|
||||
|
||||
layers = layers[start_layer:end_layer]
|
||||
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
|
||||
# Wrap intermediate layers with eval checkpoints to prevent GPU timeout
|
||||
for i in range(1, len(layers) - 1):
|
||||
layers[i] = EvalCheckpointLayer(layers[i])
|
||||
layers[-1] = PipelineLastLayer(
|
||||
layers[-1],
|
||||
device_rank,
|
||||
@@ -248,14 +273,14 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
|
||||
"cache", None
|
||||
)
|
||||
|
||||
# Evaluate logits before all_gather to break the computation graph
|
||||
# and prevent Metal command buffer timeouts with large batches
|
||||
mx.eval(logits)
|
||||
|
||||
# Add dependency to last cache entry to ensure distributed ops are evaluated
|
||||
if cache is not None:
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
|
||||
|
||||
logits = mx.distributed.all_gather(logits, group=group)[
|
||||
-logits.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return logits
|
||||
|
||||
cls.__call__ = patched_call
|
||||
@@ -334,15 +359,7 @@ def tensor_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
|
||||
try:
|
||||
model.shard(group) # type: ignore
|
||||
return patch_tensor_model(model)
|
||||
except (AttributeError, TypeError, NameError):
|
||||
pass
|
||||
|
||||
if isinstance(model, (LlamaModel, Ministral3Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -351,7 +368,6 @@ def tensor_auto_parallel(
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -367,6 +383,14 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, GLM4MoeLiteModel):
|
||||
tensor_parallel_sharding_strategy = GLM4MoeLiteShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
tensor_parallel_sharding_strategy = QwenShardingStrategy(
|
||||
group,
|
||||
@@ -441,7 +465,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
@@ -496,6 +520,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.kv_b_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
# Store pre-shard head count and group for context parallelism
|
||||
layer.self_attn.context_parallel_total_heads = layer.self_attn.num_heads
|
||||
layer.self_attn._cp_group = self.group
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Shard the MLP
|
||||
@@ -516,6 +543,12 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
mx.eval(layer)
|
||||
|
||||
# Store group for context parallelism
|
||||
if hasattr(model, "model"):
|
||||
model.model._cp_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -533,6 +566,158 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
|
||||
return y
|
||||
|
||||
|
||||
class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(GLM4MoeLiteModel, model)
|
||||
for layer in model.layers: # type: ignore
|
||||
layer = cast(Glm4MoeLiteDecoderLayer, layer)
|
||||
eval_with_timeout(
|
||||
layer.parameters(),
|
||||
timeout_seconds / len(model.layers), # type: ignore
|
||||
on_timeout,
|
||||
)
|
||||
if layer.self_attn.q_lora_rank is None: # type: ignore
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
else:
|
||||
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_b_proj
|
||||
)
|
||||
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Logic from upstream mlx
|
||||
num_heads = layer.self_attn.num_heads
|
||||
sh = self.group.rank() * num_heads
|
||||
eh = sh + num_heads
|
||||
|
||||
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
|
||||
return w[sh:eh]
|
||||
|
||||
layer.self_attn.embed_q.apply(shard_heads)
|
||||
layer.self_attn.unembed_out.apply(shard_heads)
|
||||
|
||||
if isinstance(layer.mlp, Glm4MoeLiteMLP):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
else:
|
||||
if getattr(layer.mlp, "shared_experts", None) is not None:
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.mlp.shared_experts.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.up_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
layer.mlp = ShardedGLM4MoeLiteMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # type: ignore
|
||||
mx.eval(layer)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class WrappedMiniMaxAttention(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
|
||||
super().__init__(layer)
|
||||
self.group = group
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | Any = None,
|
||||
cache: Any | None = None,
|
||||
) -> mx.array:
|
||||
B, L, _ = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
if getattr(self, "use_qk_norm", False):
|
||||
q_dim = queries.shape[-1] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
k_dim = keys.shape[-1] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
N = self.group.size()
|
||||
|
||||
qk = mx.concatenate([queries, keys], axis=-1) # (B, L, q_dim + k_dim)
|
||||
qk = mx.distributed.all_gather(
|
||||
qk, group=self.group
|
||||
) # (N*B, L, q_dim + k_dim)
|
||||
|
||||
# Reshape to separate rank contributions: (N, B, L, q_dim + k_dim)
|
||||
# Then transpose to (B, L, N, q_dim + k_dim) and merge N into feature dim
|
||||
qk = qk.reshape(N, B, L, q_dim + k_dim).transpose(1, 2, 0, 3) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
queries = qk[..., :q_dim].reshape(
|
||||
B, L, -1
|
||||
) # (B, L, N * q_dim) # pyright: ignore[reportUnknownMemberType]
|
||||
keys = qk[..., q_dim:].reshape(
|
||||
B, L, -1
|
||||
) # (B, L, N * k_dim) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
queries = self.q_norm(queries) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
keys = self.k_norm(keys) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
# Split back and take this rank's portion
|
||||
queries = mx.split(queries, N, axis=-1)[self.group.rank()]
|
||||
keys = mx.split(keys, N, axis=-1)[self.group.rank()]
|
||||
|
||||
queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportUnknownArgumentType]
|
||||
0, 2, 1, 3
|
||||
)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportUnknownArgumentType]
|
||||
0, 2, 1, 3
|
||||
)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportAny]
|
||||
keys = self.rope(keys, offset=cache.offset) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportAny]
|
||||
keys, values = cache.update_and_fetch(keys, values) # pyright: ignore[reportAny]
|
||||
else:
|
||||
queries = self.rope(queries) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
keys = self.rope(keys) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=self.scale,
|
||||
mask=mask, # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
return self.o_proj(output) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
@@ -550,9 +735,12 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
self.all_to_sharded_linear_in_place(
|
||||
@@ -566,7 +754,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
)
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
@@ -577,18 +765,32 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(Qwen3MoeModel, model)
|
||||
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
# Shard the self attention
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
if isinstance(layer, Qwen3DecoderLayer) or hasattr(layer, "self_attn"):
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.k_proj
|
||||
)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.v_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(
|
||||
layer.self_attn.o_proj
|
||||
)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
else:
|
||||
assert isinstance(layer, Qwen3NextDecoderLayer) and hasattr(
|
||||
layer, "linear_attn"
|
||||
)
|
||||
# These layers are fast so we don't shard. This may change in future.
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
@@ -607,6 +809,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
@@ -661,7 +864,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@@ -1,39 +1,81 @@
|
||||
# type: ignore
|
||||
# TODO: Fix this file, including types!
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
|
||||
from mlx_lm.models.cache import (
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
trim_prompt_cache,
|
||||
)
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
|
||||
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
|
||||
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
# Fraction of device memory above which LRU eviction kicks in
|
||||
_DEFAULT_MEMORY_THRESHOLD = 0.85
|
||||
_MEMORY_THRESHOLD = float(
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
|
||||
)
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
def __init__(self):
|
||||
# Only one prefix cache per runner.
|
||||
def __init__(self, tokenizer: TokenizerWrapper):
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[list[_BaseCache]] = []
|
||||
self.caches: list[KVCacheType] = []
|
||||
self._last_used: list[int] = [] # monotonic counter of last access per entry
|
||||
self._access_counter: int = 0
|
||||
self._tokenizer: TokenizerWrapper = tokenizer
|
||||
|
||||
def add_kv_cache(
|
||||
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
|
||||
):
|
||||
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
|
||||
def clear(self):
|
||||
"""Clear all cached prompts and caches."""
|
||||
self.prompts.clear()
|
||||
self.caches.clear()
|
||||
self._last_used.clear()
|
||||
|
||||
def add_kv_cache(self, prompt: str, cache: KVCacheType):
|
||||
"""Add a new cache entry. Evicts LRU entries if memory is high."""
|
||||
self._evict_if_needed()
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
self.prompts.append(tokenized_prompt)
|
||||
self.caches.append(deepcopy(cache))
|
||||
self._access_counter += 1
|
||||
self._last_used.append(self._access_counter)
|
||||
logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
|
||||
|
||||
def update_kv_cache(
|
||||
self,
|
||||
index: int,
|
||||
prompt: str,
|
||||
cache: KVCacheType,
|
||||
):
|
||||
"""Update an existing cache entry in-place."""
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
self.prompts[index] = tokenized_prompt
|
||||
self.caches[index] = deepcopy(cache)
|
||||
self._access_counter += 1
|
||||
self._last_used[index] = self._access_counter
|
||||
logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
|
||||
|
||||
def get_kv_cache(
|
||||
self,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt: str,
|
||||
) -> list[_BaseCache]:
|
||||
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
|
||||
) -> tuple[KVCacheType, mx.array, int | None]:
|
||||
"""Get KV cache for prompt, returning remaining tokens to prefill.
|
||||
|
||||
Returns:
|
||||
Tuple of (cache, remaining_tokens, matched_index) where:
|
||||
- cache: KV cache to use for generation
|
||||
- remaining_tokens: tokens that still need prefilling
|
||||
- matched_index: index of the matched entry (None if no match)
|
||||
"""
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
max_length = len(tokenized_prompt)
|
||||
|
||||
best_snapshot_index, best_snapshot_length = None, 0
|
||||
@@ -42,63 +84,127 @@ class KVPrefixCache:
|
||||
length = _get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
|
||||
if length == max_length:
|
||||
return self.caches[i]
|
||||
# Exact match - cached prompt starts with our entire prompt
|
||||
# Trim cache to prompt length - 1, return last token for stream_generate
|
||||
prompt_cache = deepcopy(self.caches[i])
|
||||
cached_length = _cache_length(self.caches[i])
|
||||
tokens_to_trim = cached_length - (max_length - 1)
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
self._access_counter += 1
|
||||
self._last_used[i] = self._access_counter
|
||||
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
|
||||
return prompt_cache, tokenized_prompt[-1:], i
|
||||
|
||||
if length > best_snapshot_length:
|
||||
best_snapshot_index, best_snapshot_length = i, length
|
||||
|
||||
if best_snapshot_index is not None:
|
||||
prompt_cache = deepcopy(self.caches[best_snapshot_index])
|
||||
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
|
||||
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
|
||||
|
||||
else:
|
||||
prompt_cache = make_kv_cache(
|
||||
model,
|
||||
# max_kv_size=MAX_KV_SIZE,
|
||||
# keep=KEEP_KV_SIZE
|
||||
new_tokens = max_length - best_snapshot_length
|
||||
logger.info(
|
||||
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
|
||||
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
|
||||
)
|
||||
|
||||
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
|
||||
prompt_cache = deepcopy(self.caches[best_snapshot_index])
|
||||
|
||||
return prompt_cache
|
||||
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
|
||||
cached_length = _cache_length(self.caches[best_snapshot_index])
|
||||
tokens_to_trim = cached_length - best_snapshot_length
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
|
||||
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
|
||||
tokenizer.bos_token
|
||||
)
|
||||
tokenized_prompt = tokenizer.encode(
|
||||
prompt, add_special_tokens=add_special_tokens
|
||||
)
|
||||
return mx.array(tokenized_prompt)
|
||||
self._access_counter += 1
|
||||
self._last_used[best_snapshot_index] = self._access_counter
|
||||
remaining_tokens = tokenized_prompt[best_snapshot_length:]
|
||||
return prompt_cache, remaining_tokens, best_snapshot_index
|
||||
|
||||
else:
|
||||
prompt_cache = make_kv_cache(model)
|
||||
if len(self.prompts) == 0:
|
||||
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
|
||||
else:
|
||||
logger.info(
|
||||
f"KV cache no prefix match, need to prefill {max_length} tokens"
|
||||
)
|
||||
|
||||
return prompt_cache, tokenized_prompt, None
|
||||
|
||||
def _evict_if_needed(self):
|
||||
"""Evict least recently used entries while memory pressure is high."""
|
||||
if len(self.caches) == 0:
|
||||
return
|
||||
|
||||
active: int = mx.metal.get_active_memory()
|
||||
limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
|
||||
if active < limit * _MEMORY_THRESHOLD:
|
||||
return
|
||||
|
||||
# Evict LRU entries until below threshold or only one entry left
|
||||
while len(self.caches) > 0:
|
||||
lru_index = self._last_used.index(min(self._last_used))
|
||||
evicted_tokens = len(self.prompts[lru_index])
|
||||
self.prompts.pop(lru_index)
|
||||
self.caches.pop(lru_index)
|
||||
self._last_used.pop(lru_index)
|
||||
logger.info(
|
||||
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
|
||||
)
|
||||
|
||||
active = mx.metal.get_active_memory()
|
||||
if active < limit * _MEMORY_THRESHOLD:
|
||||
break
|
||||
|
||||
|
||||
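A sketch of the eviction trigger used by _evict_if_needed above; 0.85 is the default threshold and can be overridden with the EXO_MEMORY_THRESHOLD environment variable, and the printed action stands in for the LRU pops.

import mlx.core as mx

threshold = 0.85  # default _MEMORY_THRESHOLD
active = mx.metal.get_active_memory()
limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
if active >= limit * threshold:
    print("memory pressure: evict least recently used KV cache entries")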
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
"""Encode a prompt string to token array.
|
||||
|
||||
For chat-templated prompts (which have their own structure markers like
|
||||
<|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
|
||||
that would corrupt the prompt structure.
|
||||
"""
|
||||
# Chat templates define their own structure - don't add BOS/EOS
|
||||
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
return mx.array(tokenized_prompt)
|
||||
|
||||
|
||||
def _cache_length(cache: KVCacheType) -> int:
    """Get the number of tokens in a KV cache."""
    # Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
    return max(c.offset for c in cache)  # type: ignore


def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
    """Find the length of the common prefix between two token arrays."""
    n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
    if n == 0:
        return 0

    equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
    equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
    prefix_mask = mx.cumprod(equal)  # stays 1 until first mismatch, then 0 forever
    return int(mx.sum(prefix_mask).item())

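A worked example (token IDs made up) of the cumprod trick used above: the equality mask stays 1 until the first mismatch, so its running product marks exactly the shared prefix.

import mlx.core as mx

prompt = mx.array([5, 7, 9, 2])
cached = mx.array([5, 7, 8, 2])
equal = mx.equal(prompt, cached).astype(mx.int32)  # [1, 1, 0, 1]
prefix_mask = mx.cumprod(equal)                    # [1, 1, 0, 0]
print(int(mx.sum(prefix_mask).item()))             # 2 -> common prefix length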
def prefill(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt: mx.array,
|
||||
cache: list[_BaseCache],
|
||||
) -> None:
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=0,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
pass
|
||||
def make_kv_cache(
|
||||
model: Model, max_kv_size: int | None = None, keep: int = 0
|
||||
) -> KVCacheType:
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
if max_kv_size is None:
|
||||
if KV_CACHE_BITS is None:
|
||||
logger.info("Using default KV cache")
|
||||
return [KVCache() for _ in model.layers]
|
||||
else:
|
||||
logger.info("Using quantized KV cache")
|
||||
return [
|
||||
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
|
||||
for _ in model.layers
|
||||
]
|
||||
else:
|
||||
logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
|
||||
return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
|
||||
|
||||
@@ -4,7 +4,7 @@
KV_GROUP_SIZE: int | None = 32
KV_BITS: int | None = None
ATTENTION_KV_BITS: int | None = 4
MAX_TOKENS: int = 8192
MAX_TOKENS: int = 32168
MAX_KV_SIZE: int | None = 3200
KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"

@@ -1,47 +1,92 @@
|
||||
import time
|
||||
from typing import Any, Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.cache import KVCache, trim_prompt_cache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
# from exo.engines.mlx.cache import KVPrefixCache
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionMessage,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
TopLogprobItem,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
|
||||
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
make_kv_cache,
|
||||
mx_barrier,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
_MIN_PREFIX_HIT_TO_UPDATE = 1000
|
||||
|
||||
def maybe_quantize_kv_cache(
    prompt_cache: list[KVCache | Any],
    quantized_kv_start: int,
    kv_group_size: int,
    kv_bits: int | None,
) -> None:
    if kv_bits is None:
        return
    for e, c in enumerate(prompt_cache):
        if (
            hasattr(c, "to_quantized") and c.offset >= quantized_kv_start  # type: ignore
        ):
            prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)

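A usage sketch for maybe_quantize_kv_cache above; the cache list, start offset, and bit width are illustrative, and passing kv_bits=None leaves the cache untouched. Only layers whose offset has already passed quantized_kv_start get converted in place.

from mlx_lm.models.cache import KVCache

prompt_cache = [KVCache() for _ in range(32)]  # one entry per layer, 32 layers assumed
maybe_quantize_kv_cache(
    prompt_cache,
    quantized_kv_start=2048,  # leave the first 2048 cached tokens unquantized
    kv_group_size=32,         # matches KV_GROUP_SIZE in constants.py
    kv_bits=4,
)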
def prefill(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
) -> float:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
This runs the model over the prompt tokens to populate the cache,
|
||||
then trims off the extra generated token.
|
||||
|
||||
Returns:
|
||||
tokens_per_sec
|
||||
"""
|
||||
num_tokens = len(prompt_tokens)
|
||||
if num_tokens == 0:
|
||||
return 0.0
|
||||
|
||||
logger.debug(f"Prefilling {num_tokens} tokens...")
|
||||
start_time = time.perf_counter()
|
||||
|
||||
def progress_callback(processed: int, total: int) -> None:
|
||||
elapsed = time.time() - start_time
|
||||
tok_per_sec = processed / elapsed if elapsed > 0 else 0
|
||||
logger.debug(
|
||||
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
|
||||
)
|
||||
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt_tokens,
|
||||
max_tokens=1,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=progress_callback,
|
||||
):
|
||||
break # Stop after first iteration - cache is now filled
|
||||
trim_prompt_cache(cast(list[Any], cache), 1)
|
||||
|
||||
elapsed = time.perf_counter() - start_time
|
||||
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
|
||||
logger.debug(
|
||||
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
|
||||
f"({tokens_per_sec:.1f} tok/s)"
|
||||
)
|
||||
return tokens_per_sec
|
||||
|
||||
|
||||
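A usage sketch for the prefill helper above; model, tokenizer, and the prompt text are assumed to be in scope, and make_sampler comes from the imports in this file.

from mlx_lm.sample_utils import make_sampler

sampler = make_sampler(temp=0.0)  # greedy; only used to drive stream_generate
cache = make_kv_cache(model)
tokens = encode_prompt(tokenizer, "You are a helpful assistant.")
tok_per_sec = prefill(model, tokenizer, sampler, tokens, cache)
# The cache now covers the full prompt; generation can resume from it.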
def warmup_inference(
|
||||
@@ -89,6 +134,10 @@ def warmup_inference(
|
||||
|
||||
logger.info("Generated ALL warmup tokens")
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
# At least this version is actively incorrect, as it should use mx_barrier(group)
|
||||
mx_barrier()
|
||||
|
||||
return tokens_generated
|
||||
|
||||
|
||||
@@ -110,11 +159,212 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
|
||||
return eos
|
||||
|
||||
|
||||
def extract_top_logprobs(
|
||||
logprobs_array: mx.array,
|
||||
selected_token: int,
|
||||
tokenizer: TokenizerWrapper,
|
||||
top_k: int | None,
|
||||
) -> tuple[float, list[TopLogprobItem]]:
|
||||
"""Extract the selected token's logprob and top-k alternatives.
|
||||
|
||||
top k an be set to None to return all the logprobs
|
||||
"""
|
||||
selected_logprob = float(logprobs_array[selected_token].item())
|
||||
|
||||
if top_k == 0:
|
||||
return selected_logprob, []
|
||||
|
||||
vocab_size = logprobs_array.shape[0]
|
||||
|
||||
if top_k is None:
|
||||
sorted_indices = mx.argsort(-logprobs_array)
|
||||
mx.eval(sorted_indices)
|
||||
indices_list: list[int] = cast(list[int], sorted_indices.tolist())
|
||||
else:
|
||||
k = min(top_k, vocab_size)
|
||||
top_indices = mx.argpartition(-logprobs_array, kth=k - 1)[:k]
|
||||
top_logprobs_values = logprobs_array[top_indices]
|
||||
sorted_order = mx.argsort(-top_logprobs_values)
|
||||
top_indices = top_indices[sorted_order]
|
||||
mx.eval(top_indices)
|
||||
indices_list = cast(list[int], top_indices.tolist())
|
||||
|
||||
top_logprob_items: list[TopLogprobItem] = []
|
||||
for token_id in indices_list:
|
||||
logprob_value = float(logprobs_array[token_id].item())
|
||||
token_str = tokenizer.decode([token_id])
|
||||
|
||||
top_logprob_items.append(
|
||||
TopLogprobItem(
|
||||
token=token_str,
|
||||
logprob=logprob_value,
|
||||
bytes=list(token_str.encode("utf-8")),
|
||||
)
|
||||
)
|
||||
|
||||
return selected_logprob, top_logprob_items
|
||||
|
||||
|
||||
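A single-call sketch for extract_top_logprobs above; the three-token "vocabulary" is made up, and the tokenizer is assumed to be the TokenizerWrapper already in scope.

import mlx.core as mx

logits = mx.array([2.0, 0.5, -1.0])
logprobs_row = logits - mx.logsumexp(logits)  # normalize to log-probabilities

logprob, top_items = extract_top_logprobs(
    logprobs_array=logprobs_row,
    selected_token=0,
    tokenizer=tokenizer,  # assumed to exist
    top_k=2,              # None would return the full sorted vocabulary
)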
def score_tokens(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
tokens: list[int],
|
||||
top_k: int | None = None,
|
||||
) -> list[tuple[float, list[TopLogprobItem]]]:
|
||||
"""Score a sequence of tokens, returning logprobs for each token.
|
||||
|
||||
This is used for the completions API with echo=True, where we need
|
||||
logprobs for the prompt tokens (not just generated tokens).
|
||||
|
||||
Args:
|
||||
model: The MLX model.
|
||||
tokenizer: The tokenizer.
|
||||
tokens: List of token IDs to score.
|
||||
top_k: Number of top logprobs to return per position.
|
||||
If None, returns all logprobs.
|
||||
|
||||
Returns:
|
||||
List of (token_logprob, top_logprobs) tuples for each token position.
|
||||
The first position has no logprob (no previous context), so returns (0.0, []).
|
||||
"""
|
||||
if len(tokens) == 0:
|
||||
return []
|
||||
|
||||
# First token has no previous context to condition on
|
||||
results: list[tuple[float, list[TopLogprobItem]]] = [(0.0, [])]
|
||||
|
||||
if len(tokens) == 1:
|
||||
return results
|
||||
|
||||
# Create an empty KV cache for the forward pass
|
||||
cache = make_kv_cache(model=model)
|
||||
|
||||
# Convert to MLX array and run forward pass
|
||||
input_tokens = mx.array(tokens[:-1])[None] # All tokens except last, batched
|
||||
|
||||
# Run the model to get logits for all positions
|
||||
# The model returns logits with shape [1, seq_len, vocab_size]
|
||||
logits: mx.array = model(input_tokens, cache=cast(list[KVCache], cache))
|
||||
logits = logits.squeeze(0) # Shape: [seq_len, vocab_size]
|
||||
|
||||
# Convert to log probabilities
|
||||
logprobs_all: mx.array = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||
|
||||
mx.eval(logprobs_all)
|
||||
|
||||
# For each position, extract the logprob of the actual next token
|
||||
for i in range(len(tokens) - 1):
|
||||
next_token = tokens[i + 1]
|
||||
logprobs_at_position: mx.array = logprobs_all[i]
|
||||
|
||||
logprob, top_logprobs_items = extract_top_logprobs(
|
||||
logprobs_array=logprobs_at_position,
|
||||
selected_token=next_token,
|
||||
tokenizer=tokenizer,
|
||||
top_k=top_k,
|
||||
)
|
||||
results.append((logprob, top_logprobs_items))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
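A hedged usage sketch for score_tokens (illustrative, not from the diff; `model` and `tokenizer` are assumed to have been loaded via the worker's normal MLX path): each position gets a (logprob, top_logprobs) pair, with the first position fixed at (0.0, []).

# Assumed: `model` and `tokenizer` were already loaded elsewhere; not part of the diff.
tokens = tokenizer.encode("The capital of France is Paris")
scored = score_tokens(model=model, tokenizer=tokenizer, tokens=tokens, top_k=5)

for token_id, (logprob, top_items) in zip(tokens, scored):
    piece = tokenizer.decode([token_id])
    print(f"{piece!r}: logprob={logprob:.3f}, {len(top_items)} alternatives")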
def score_tokens_batched(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
token_sequences: list[list[int]],
|
||||
top_k: int | None = None,
|
||||
) -> list[list[tuple[float, list[TopLogprobItem]]]]:
|
||||
"""Score multiple token sequences in a single batched forward pass.
|
||||
|
||||
This is significantly faster than calling score_tokens() multiple times
|
||||
because it batches the forward pass across all sequences.
|
||||
|
||||
Args:
|
||||
model: The MLX model.
|
||||
tokenizer: The tokenizer.
|
||||
token_sequences: List of token ID sequences to score.
|
||||
top_k: Number of top logprobs to return per position.
|
||||
|
||||
Returns:
|
||||
List of results for each sequence. Each result is a list of
|
||||
(token_logprob, top_logprobs) tuples for each token position.
|
||||
"""
|
||||
if not token_sequences:
|
||||
return []
|
||||
|
||||
# Handle empty sequences and single-token sequences
|
||||
results: list[list[tuple[float, list[TopLogprobItem]]]] = []
|
||||
non_empty_indices: list[int] = []
|
||||
non_empty_sequences: list[list[int]] = []
|
||||
|
||||
for i, tokens in enumerate(token_sequences):
|
||||
if len(tokens) == 0:
|
||||
results.append([])
|
||||
elif len(tokens) == 1:
|
||||
results.append([(0.0, [])])
|
||||
else:
|
||||
results.append([]) # Placeholder, will be filled later
|
||||
non_empty_indices.append(i)
|
||||
non_empty_sequences.append(tokens)
|
||||
|
||||
if not non_empty_sequences:
|
||||
return results
|
||||
|
||||
# Find max sequence length (excluding last token since we predict it)
|
||||
max_len = max(len(seq) - 1 for seq in non_empty_sequences)
|
||||
|
||||
# Get pad token (use eos_token_id or 0)
|
||||
pad_token_id = getattr(tokenizer, "pad_token_id", None)
|
||||
if pad_token_id is None:
|
||||
pad_token_id = getattr(tokenizer, "eos_token_id", 0)
|
||||
|
||||
# Pad sequences and create attention mask
|
||||
batch_size = len(non_empty_sequences)
|
||||
padded_inputs = mx.full((batch_size, max_len), pad_token_id, dtype=mx.int32)
|
||||
seq_lengths: list[int] = []
|
||||
|
||||
for i, tokens in enumerate(non_empty_sequences):
|
||||
input_len = len(tokens) - 1 # Exclude last token
|
||||
padded_inputs[i, :input_len] = mx.array(tokens[:-1], dtype=mx.int32)
|
||||
seq_lengths.append(input_len)
|
||||
|
||||
# Run batched forward pass (no KV cache for scoring)
|
||||
# The model accepts [batch_size, seq_len] and returns [batch_size, seq_len, vocab_size]
|
||||
logits = model(padded_inputs, cache=None)
|
||||
|
||||
# Convert to log probabilities - logits shape: [batch, seq_len, vocab]
|
||||
logprobs_all = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||
mx.eval(logprobs_all)
|
||||
|
||||
# Extract results for each sequence
|
||||
for batch_idx, (orig_idx, tokens, seq_len) in enumerate(
|
||||
zip(non_empty_indices, non_empty_sequences, seq_lengths, strict=True)
|
||||
):
|
||||
seq_results: list[tuple[float, list[TopLogprobItem]]] = [(0.0, [])]
|
||||
|
||||
for pos in range(seq_len):
|
||||
next_token = tokens[pos + 1]
|
||||
logprobs_at_position: mx.array = logprobs_all[batch_idx, pos]
|
||||
|
||||
logprob, top_logprobs_items = extract_top_logprobs(
|
||||
logprobs_array=logprobs_at_position,
|
||||
selected_token=next_token,
|
||||
tokenizer=tokenizer,
|
||||
top_k=top_k,
|
||||
)
|
||||
seq_results.append((logprob, top_logprobs_items))
|
||||
|
||||
results[orig_idx] = seq_results
|
||||
|
||||
return results
|
||||
|
||||
|
||||
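Similarly, a hedged sketch of the batched variant under the same assumptions about `model` and `tokenizer`; empty and single-token sequences follow the rules documented above.

# Assumed: `model` and `tokenizer` already loaded elsewhere.
sequences = [
    tokenizer.encode("Hello world"),
    tokenizer.encode("Batching amortises the forward pass"),
    [],  # empty sequences come back as empty results
]
batched = score_tokens_batched(
    model=model, tokenizer=tokenizer, token_sequences=sequences, top_k=3
)
for seq, result in zip(sequences, batched):
    print(f"{len(seq)} tokens -> {len(result)} scored positions")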
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
task: ChatCompletionTaskParams,
|
||||
prompt: str,
|
||||
kv_prefix_cache: KVPrefixCache | None = None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -126,7 +376,22 @@ def mlx_generate(
|
||||
if task.seed is not None:
|
||||
mx.random.seed(task.seed)
|
||||
|
||||
caches = make_kv_cache(model=model)
|
||||
# Do not use the prefix cache if we are trying to do benchmarks.
|
||||
if is_bench:
|
||||
kv_prefix_cache = None
|
||||
|
||||
# Use prefix cache if available, otherwise create fresh cache
|
||||
prefix_hit_length = 0
|
||||
matched_index: int | None = None
|
||||
if kv_prefix_cache is None:
|
||||
caches = make_kv_cache(model=model)
|
||||
prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
else:
|
||||
caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
)
|
||||
all_prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
@@ -139,11 +404,23 @@ def mlx_generate(
|
||||
top_p=task.top_p if task.top_p is not None else 1.0,
|
||||
)
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
|
||||
|
||||
# stream_generate starts from the last token
|
||||
last_token = prompt_tokens[-1:]
|
||||
|
||||
# Determine if we need logprobs
|
||||
should_extract_logprobs = task.logprobs is True
|
||||
top_k = task.top_logprobs if task.top_logprobs is not None else 0
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
generated_text_parts: list[str] = []
|
||||
generation_start_time = time.perf_counter()
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
prompt=last_token,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
@@ -153,12 +430,13 @@ def mlx_generate(
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
generated_text_parts.append(out.text)
|
||||
logger.info(out.text)
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(out.prompt_tps),
|
||||
prompt_tps=float(prefill_tps or out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
@@ -172,12 +450,47 @@ def mlx_generate(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
# Extract logprobs if requested
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
if should_extract_logprobs:
|
||||
logprob, top_logprobs = extract_top_logprobs(
|
||||
logprobs_array=out.logprobs,
|
||||
selected_token=out.token,
|
||||
tokenizer=tokenizer,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
token=out.token,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
# Log generation stats
|
||||
generation_elapsed = time.perf_counter() - generation_start_time
|
||||
generated_tokens = len(generated_text_parts)
|
||||
generation_tps = (
|
||||
generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
|
||||
)
|
||||
logger.debug(
|
||||
f"Generation complete: prefill {prompt_tokens} tokens @ "
|
||||
f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
|
||||
f"{generation_tps:.1f} tok/s"
|
||||
)
|
||||
if kv_prefix_cache is not None:
|
||||
full_prompt = prompt + "".join(generated_text_parts)
|
||||
if (
|
||||
matched_index is not None
|
||||
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
|
||||
):
|
||||
kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
|
||||
else:
|
||||
kv_prefix_cache.add_kv_cache(full_prompt, caches)
|
||||
break
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
@@ -18,15 +18,12 @@ try:
|
||||
except ImportError:
|
||||
pass # transformers < 5.0 or bytes_to_unicode not available
|
||||
|
||||
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
CACHE_GROUP_SIZE,
|
||||
KV_CACHE_BITS,
|
||||
TRUST_REMOTE_CODE,
|
||||
)
|
||||
|
||||
@@ -70,6 +67,8 @@ Group = mx.distributed.Group
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
|
||||
|
||||
|
||||
# TODO: Test this
|
||||
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
|
||||
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
return Memory.from_float_kb(
|
||||
(model_shard_meta.end_layer - model_shard_meta.start_layer)
|
||||
@@ -87,6 +86,30 @@ class ModelLoadingTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0),
|
||||
stream=mx.default_stream(mx.Device(mx.cpu)),
|
||||
group=group,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def broadcast_from_zero(value: int, group: Group | None = None):
|
||||
if group is None:
|
||||
return value
|
||||
|
||||
if group.rank() == 0:
|
||||
a = mx.array([value], dtype=mx.int32)
|
||||
else:
|
||||
a = mx.array([0], dtype=mx.int32)
|
||||
|
||||
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
|
||||
mx.eval(m)
|
||||
return int(m.item())
|
||||
|
||||
|
||||
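Illustrative only: broadcast_from_zero works because every non-zero rank contributes 0 to the all_sum, so the sum equals rank 0's value. A sketch of using it so all ranks agree on one random seed, assuming the mx.distributed backend is configured:

import random

import mlx.core as mx

group = mx.distributed.init()  # assumes a configured distributed backend

# Rank 0's draw wins; every other rank contributes 0 inside broadcast_from_zero.
seed = broadcast_from_zero(random.randint(0, 2**31 - 1), group)
mx.random.seed(seed)  # now all ranks sample identically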
class HostList(RootModel[list[str]]):
|
||||
@classmethod
|
||||
def from_hosts(cls, hosts: list[Host]) -> "HostList":
|
||||
@@ -379,7 +402,11 @@ def apply_chat_template(
|
||||
continue
|
||||
|
||||
message.content = "\n".join(c.text for c in message.content).strip()
|
||||
if message.content is None and message.thinking is None:
|
||||
if (
|
||||
message.content is None
|
||||
and message.thinking is None
|
||||
and message.tool_calls is None
|
||||
):
|
||||
continue
|
||||
|
||||
# Null values are not valid when applying templates in tokenizer
|
||||
@@ -436,31 +463,6 @@ class NullKVCache(KVCache):
|
||||
raise NotImplementedError("We should not be setting a NullKVCache.")
|
||||
|
||||
|
||||
def make_kv_cache(
|
||||
model: Model, max_kv_size: int | None = None, keep: int = 0
|
||||
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
if max_kv_size is None:
|
||||
if KV_CACHE_BITS is None:
|
||||
logger.info("Using default KV cache")
|
||||
return [KVCache() for _ in model.layers]
|
||||
else:
|
||||
logger.info("Using quantized KV cache")
|
||||
return [
|
||||
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
|
||||
for _ in model.layers
|
||||
]
|
||||
else:
|
||||
logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
|
||||
return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
|
||||
|
||||
|
||||
def mlx_force_oom(size: int = 40000) -> None:
|
||||
"""
|
||||
Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.
|
||||
@@ -510,23 +512,3 @@ def mlx_cleanup(
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
def mx_any(bool_: bool, group: Group | None) -> bool:
|
||||
if group is None:
|
||||
return bool_
|
||||
num_true = mx.distributed.all_sum(
|
||||
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
mx.eval(num_true)
|
||||
return num_true.item() > 0
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None):
|
||||
if group is None:
|
||||
return
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
|
||||
)
|
||||
)
|
||||
|
||||
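A brief, hedged sketch of mx_any: it ORs a per-rank flag across the group via all_sum, so every rank reaches the same decision.

import mlx.core as mx

group = mx.distributed.init()  # assumes a configured distributed backend

# Hypothetical local condition, e.g. "this rank ran out of work".
local_done = group.rank() == 0

if mx_any(local_done, group):
    print("at least one rank is done; all ranks can stop together")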
@@ -33,7 +33,7 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
ChatCompletion,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
@@ -116,9 +116,8 @@ class Worker:
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
async with create_task_group() as tg:
|
||||
for runner in self.runners.values():
|
||||
tg.start_soon(runner.shutdown)
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
|
||||
async def _forward_info(self, recv: Receiver[GatheredInfo]):
|
||||
with recv as info_stream:
|
||||
@@ -186,8 +185,10 @@ class Worker:
|
||||
self.input_chunk_counts,
|
||||
)
|
||||
if task is None:
|
||||
# Only sleep when there's nothing to do - allows rapid task dispatch
|
||||
await anyio.sleep(0.01)
|
||||
continue
|
||||
logger.info(f"Worker plan: {task.__class__.__name__}")
|
||||
logger.debug(f"Worker plan: {task.__class__.__name__}")
|
||||
assert task.task_status
|
||||
await self.event_sender.send(TaskCreated(task_id=task.task_id, task=task))
|
||||
|
||||
@@ -222,22 +223,15 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
runner = self.runners.pop(runner_id)
|
||||
try:
|
||||
with fail_after(3):
|
||||
await runner.start_task(task)
|
||||
await self.runners.pop(runner_id).start_task(task)
|
||||
except TimeoutError:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
finally:
|
||||
await runner.shutdown()
|
||||
case CancelTask(cancelled_task_id=cancelled_task_id):
|
||||
await self.runners[self._task_to_runner_id(task)].cancel_task(
|
||||
cancelled_task_id
|
||||
)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
@@ -278,6 +272,12 @@ class Worker:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
case ChatCompletion():
|
||||
# Don't wait for acknowledgment for batchable inference tasks
|
||||
# This allows multiple tasks to reach the runner for batching
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
task, wait_for_ack=False
|
||||
)
|
||||
case task:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
|
||||
@@ -360,6 +360,8 @@ class Worker:
|
||||
for event in self.out_for_delivery.copy().values():
|
||||
await self.local_event_sender.send(event)
|
||||
|
||||
## Op Executors
|
||||
|
||||
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
|
||||
"""Creates and stores a new AssignedRunner with initial downloading status."""
|
||||
runner = RunnerSupervisor.create(
|
||||
|
||||
@@ -4,8 +4,8 @@ from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
CancelTask,
|
||||
ChatCompletion,
|
||||
Completion,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
@@ -60,8 +60,7 @@ def plan(
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _cancel_tasks(runners, tasks)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
|
||||
)
|
||||
|
||||
|
||||
@@ -272,12 +271,14 @@ def _pending_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions
|
||||
# for now, just forward chat completions and completions
|
||||
# TODO(ciaran): do this better!
|
||||
if not isinstance(task, (ChatCompletion, ImageGeneration, ImageEdits)):
|
||||
if not isinstance(
|
||||
task, (ChatCompletion, Completion, ImageGeneration, ImageEdits)
|
||||
):
|
||||
continue
|
||||
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
|
||||
continue
|
||||
@@ -286,7 +287,7 @@ def _pending_tasks(
|
||||
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
|
||||
cmd_id = task.command_id
|
||||
expected = task.task_params.total_input_chunks
|
||||
received = len(input_chunk_buffer.get(cmd_id, {}))
|
||||
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
|
||||
if received < expected:
|
||||
continue # Wait for all chunks to arrive
|
||||
|
||||
@@ -294,31 +295,21 @@ def _pending_tasks(
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# the task status _should_ be set to completed by the LAST runner
|
||||
# it is currently set by the first
|
||||
# this is definitely a hack
|
||||
# There is a design issue here: this is a state race in disguise, because the task status is not updated to completed quickly enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
# Skip tasks already sent to runner (waiting for completion)
|
||||
if task.task_id in runner.sent:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributed's expectations.
|
||||
|
||||
# Allow sending tasks when runner is Ready OR Running (for batching)
|
||||
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
):
|
||||
return task
|
||||
|
||||
|
||||
def _cancel_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
if task.task_status != TaskStatus.Cancelled:
|
||||
continue
|
||||
for runner in runners.values():
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
if task.task_id in runner.cancelled:
|
||||
continue
|
||||
return CancelTask(
|
||||
instance_id=task.instance_id, cancelled_task_id=task.task_id
|
||||
)
|
||||
|
||||
src/exo/worker/runner/batched_handler.py (new file, 662 lines)
@@ -0,0 +1,662 @@
|
||||
"""Batched inference handler for processing multiple ChatCompletion requests concurrently."""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import BatchGenerator
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import (
|
||||
GenerationStats,
|
||||
TopLogprobItem,
|
||||
)
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import ChunkGenerated, Event
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletion
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import MAX_TOKENS
|
||||
from exo.worker.engines.mlx.generator.generate import extract_top_logprobs
|
||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
from exo.worker.runner.pipelined_generator import PipelinedGenerator, PipelinedResponse
|
||||
|
||||
# Type alias for the finish_reason values TokenChunk accepts
|
||||
TokenFinishReason = Literal["stop", "length", "content_filter"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingRequest:
|
||||
"""A request waiting to be added to the batch."""
|
||||
|
||||
task: ChatCompletion
|
||||
prompt: str
|
||||
max_tokens: int
|
||||
sampler: Callable[[mx.array], mx.array]
|
||||
should_extract_logprobs: bool
|
||||
top_k: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveRequest:
|
||||
"""A request currently being processed in the batch."""
|
||||
|
||||
command_id: CommandId
|
||||
should_extract_logprobs: bool
|
||||
top_k: int
|
||||
harmony_parser: Any | None = None # StreamableParser for GPT-OSS models
|
||||
in_thinking: bool = False # Currently in thinking/reasoning section
|
||||
tokens_generated: int = 0
|
||||
reasoning_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
start_time: float = field(default_factory=time.perf_counter)
|
||||
|
||||
|
||||
class BatchedInferenceHandler:
|
||||
"""
|
||||
Handles batched inference for multiple ChatCompletion requests.
|
||||
|
||||
Uses MLX-LM's BatchGenerator to process multiple requests concurrently,
|
||||
improving throughput for scenarios with multiple concurrent requests.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
model_id: ModelId,
|
||||
device_rank: int,
|
||||
world_size: int = 1,
|
||||
max_batch_size: int = 32,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.model_id = model_id
|
||||
self.device_rank = device_rank
|
||||
self.world_size = world_size
|
||||
self.max_batch_size = max_batch_size
|
||||
|
||||
# Model-specific thinking/reasoning detection
|
||||
self.is_gpt_oss = isinstance(model, GptOssModel)
|
||||
self._harmony_encoding: Any | None = None
|
||||
if self.is_gpt_oss:
|
||||
self._harmony_encoding = load_harmony_encoding(
|
||||
HarmonyEncodingName.HARMONY_GPT_OSS
|
||||
)
|
||||
logger.info("GPT-OSS model detected, enabling harmony stream parsing")
|
||||
|
||||
# Detect <think></think> tokens from tokenizer (works for any model)
|
||||
self._think_start_token: int | None = None
|
||||
self._think_end_token: int | None = None
|
||||
think_start: int | None = tokenizer.think_start_id # pyright: ignore[reportAny]
|
||||
if not self.is_gpt_oss and think_start is not None:
|
||||
self._think_start_token = think_start
|
||||
self._think_end_token = tokenizer.think_end_id # pyright: ignore[reportAny]
|
||||
logger.info(
|
||||
f"Detected <think></think> tokens ({self._think_start_token}/{self._think_end_token}), enabling reasoning tracking"
|
||||
)
|
||||
|
||||
# Pending requests waiting to be batched
|
||||
self.pending: list[PendingRequest] = []
|
||||
|
||||
# Active batch generator and request tracking
|
||||
self.batch_generator: BatchGenerator | None = None
|
||||
self.pipelined_generator: PipelinedGenerator | None = None
|
||||
self.uid_to_request: dict[int, ActiveRequest] = {}
|
||||
|
||||
# Use pipelined generator for multi-device pipeline parallelism
|
||||
self.use_pipelined = world_size > 1
|
||||
if self.use_pipelined:
|
||||
logger.info(
|
||||
f"Using PipelinedGenerator with {world_size} streams for pipeline overlap"
|
||||
)
|
||||
|
||||
# EOS tokens for the model
|
||||
self.stop_tokens: set[int] = set()
|
||||
eos_ids: list[int] | None = getattr(tokenizer, "eos_token_ids", None)
|
||||
if eos_ids:
|
||||
self.stop_tokens = set(eos_ids)
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
"""Check if there's an active batch being processed."""
|
||||
if self.use_pipelined:
|
||||
return (
|
||||
self.pipelined_generator is not None
|
||||
and self.pipelined_generator.has_active
|
||||
)
|
||||
return self.batch_generator is not None and len(self.uid_to_request) > 0
|
||||
|
||||
@property
|
||||
def has_pending(self) -> bool:
|
||||
"""Check if there are pending requests waiting to be batched."""
|
||||
return len(self.pending) > 0
|
||||
|
||||
@property
|
||||
def current_batch_size(self) -> int:
|
||||
"""Current number of active requests in the batch."""
|
||||
return len(self.uid_to_request)
|
||||
|
||||
def add_request(self, task: ChatCompletion) -> None:
|
||||
"""Add a ChatCompletion request to the pending batch."""
|
||||
task_params = task.task_params
|
||||
|
||||
# Build prompt
|
||||
prompt = apply_chat_template(self.tokenizer, task_params)
|
||||
|
||||
# Determine max tokens
|
||||
max_tokens = task_params.max_tokens or MAX_TOKENS
|
||||
|
||||
# Create sampler for this request
|
||||
sampler = make_sampler(
|
||||
temp=task_params.temperature
|
||||
if task_params.temperature is not None
|
||||
else 0.7,
|
||||
top_p=task_params.top_p if task_params.top_p is not None else 1.0,
|
||||
)
|
||||
|
||||
# Logprobs configuration
|
||||
should_extract_logprobs = task_params.logprobs is True
|
||||
top_k = task_params.top_logprobs if task_params.top_logprobs is not None else 0
|
||||
|
||||
pending_request = PendingRequest(
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
should_extract_logprobs=should_extract_logprobs,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
self.pending.append(pending_request)
|
||||
|
||||
logger.info(
|
||||
f"Added request to batch queue (pending={len(self.pending)}, active={self.current_batch_size})"
|
||||
)
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Start processing pending requests by adding them to the batch/pipelined generator."""
|
||||
if not self.has_pending:
|
||||
return
|
||||
|
||||
# Determine how many requests to flush (up to available slots)
|
||||
available_slots = self.max_batch_size - self.current_batch_size
|
||||
requests_to_flush = self.pending[:available_slots]
|
||||
self.pending = self.pending[available_slots:]
|
||||
|
||||
# Prepare batch data - tokenize prompts
|
||||
tokenized_prompts: list[list[int]] = []
|
||||
max_tokens_list: list[int] = []
|
||||
samplers: list[Callable[[mx.array], mx.array]] = []
|
||||
prompt_token_counts: list[int] = []
|
||||
|
||||
for req in requests_to_flush:
|
||||
tokens = self.tokenizer.encode(req.prompt)
|
||||
tokenized_prompts.append(tokens)
|
||||
max_tokens_list.append(req.max_tokens)
|
||||
samplers.append(req.sampler)
|
||||
prompt_token_counts.append(len(tokens))
|
||||
|
||||
if self.use_pipelined:
|
||||
self._flush_pipelined(
|
||||
requests_to_flush,
|
||||
tokenized_prompts,
|
||||
max_tokens_list,
|
||||
samplers,
|
||||
prompt_token_counts,
|
||||
)
|
||||
else:
|
||||
self._flush_batch(
|
||||
requests_to_flush,
|
||||
tokenized_prompts,
|
||||
max_tokens_list,
|
||||
samplers,
|
||||
prompt_token_counts,
|
||||
)
|
||||
|
||||
def _flush_pipelined(
|
||||
self,
|
||||
requests_to_flush: list[PendingRequest],
|
||||
tokenized_prompts: list[list[int]],
|
||||
max_tokens_list: list[int],
|
||||
samplers: list[Callable[[mx.array], mx.array]],
|
||||
prompt_token_counts: list[int],
|
||||
) -> None:
|
||||
"""Flush using PipelinedGenerator (multi-stream pipeline overlap)."""
|
||||
if self.pipelined_generator is None:
|
||||
logger.info(
|
||||
f"Creating PipelinedGenerator for {len(requests_to_flush)} requests ({self.world_size} streams)"
|
||||
)
|
||||
mx.reset_peak_memory()
|
||||
self.pipelined_generator = PipelinedGenerator(
|
||||
model=self.model,
|
||||
world_size=self.world_size,
|
||||
stop_tokens=self.stop_tokens if self.stop_tokens else None,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Adding {len(requests_to_flush)} requests to PipelinedGenerator"
|
||||
)
|
||||
|
||||
uids = self.pipelined_generator.insert(
|
||||
prompts=tokenized_prompts,
|
||||
max_tokens=max_tokens_list,
|
||||
samplers=samplers,
|
||||
)
|
||||
|
||||
for uid, req, prompt_tokens, tokens in zip(
|
||||
uids, requests_to_flush, prompt_token_counts, tokenized_prompts, strict=True
|
||||
):
|
||||
parser = None
|
||||
if self.is_gpt_oss and self._harmony_encoding is not None:
|
||||
parser = StreamableParser(self._harmony_encoding, role=Role.ASSISTANT) # pyright: ignore[reportAny]
|
||||
# Check if prompt contains <think> token - if so, model is already in thinking mode
|
||||
starts_in_thinking = (
|
||||
self._think_start_token is not None
|
||||
and self._think_start_token in tokens
|
||||
)
|
||||
self.uid_to_request[uid] = ActiveRequest(
|
||||
command_id=req.task.command_id,
|
||||
should_extract_logprobs=req.should_extract_logprobs,
|
||||
top_k=req.top_k,
|
||||
prompt_tokens=prompt_tokens,
|
||||
harmony_parser=parser,
|
||||
in_thinking=starts_in_thinking,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Flushed {len(requests_to_flush)} requests into pipelined generator (active={self.pipelined_generator.active_count}, uids={list(self.uid_to_request.keys())})"
|
||||
)
|
||||
|
||||
def _flush_batch(
|
||||
self,
|
||||
requests_to_flush: list[PendingRequest],
|
||||
tokenized_prompts: list[list[int]],
|
||||
max_tokens_list: list[int],
|
||||
samplers: list[Callable[[mx.array], mx.array]],
|
||||
prompt_token_counts: list[int],
|
||||
) -> None:
|
||||
"""Flush using BatchGenerator (single-stream, for non-pipeline instances)."""
|
||||
if self.batch_generator is None:
|
||||
logger.info(
|
||||
f"Creating new BatchGenerator for {len(requests_to_flush)} requests"
|
||||
)
|
||||
mx.reset_peak_memory()
|
||||
self.batch_generator = BatchGenerator(
|
||||
model=self.model,
|
||||
max_tokens=MAX_TOKENS,
|
||||
stop_tokens=self.stop_tokens if self.stop_tokens else None,
|
||||
prefill_batch_size=1,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Adding {len(requests_to_flush)} requests to existing BatchGenerator"
|
||||
)
|
||||
|
||||
# Insert into batch generator
|
||||
uids: list[int] = self.batch_generator.insert( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
prompts=tokenized_prompts,
|
||||
max_tokens=max_tokens_list,
|
||||
samplers=samplers, # pyright: ignore[reportCallIssue]
|
||||
)
|
||||
|
||||
for uid, req, prompt_tokens, tokens in zip(
|
||||
uids, requests_to_flush, prompt_token_counts, tokenized_prompts, strict=True
|
||||
): # pyright: ignore[reportUnknownArgumentType]
|
||||
parser = None
|
||||
if self.is_gpt_oss and self._harmony_encoding is not None:
|
||||
parser = StreamableParser(self._harmony_encoding, role=Role.ASSISTANT) # pyright: ignore[reportAny]
|
||||
# Check if prompt contains <think> token - if so, model is already in thinking mode
|
||||
starts_in_thinking = (
|
||||
self._think_start_token is not None
|
||||
and self._think_start_token in tokens
|
||||
)
|
||||
self.uid_to_request[uid] = ActiveRequest(
|
||||
command_id=req.task.command_id,
|
||||
should_extract_logprobs=req.should_extract_logprobs,
|
||||
top_k=req.top_k,
|
||||
prompt_tokens=prompt_tokens,
|
||||
harmony_parser=parser,
|
||||
in_thinking=starts_in_thinking,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Flushed {len(requests_to_flush)} requests into batch (active={self.current_batch_size}, uids={list(self.uid_to_request.keys())})"
|
||||
)
|
||||
|
||||
def step(self) -> Generator[Event, None, None]:
|
||||
"""
|
||||
Process one generation step and yield ChunkGenerated events.
|
||||
|
||||
Returns a generator of events for completed tokens across all active requests.
|
||||
"""
|
||||
if self.use_pipelined:
|
||||
yield from self._step_pipelined()
|
||||
return
|
||||
|
||||
if self.batch_generator is None or not self.uid_to_request:
|
||||
return
|
||||
|
||||
# Get next tokens for all active requests
|
||||
# BatchGenerator.next() returns list of Response objects
|
||||
logger.debug(
|
||||
f"BatchGenerator.next() called (active_uids={list(self.uid_to_request.keys())})"
|
||||
)
|
||||
responses: list[Any] = self.batch_generator.next() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
logger.debug(f"BatchGenerator.next() returned {len(responses)} responses") # pyright: ignore[reportUnknownArgumentType]
|
||||
|
||||
completed_uids: list[int] = []
|
||||
|
||||
for response in responses: # pyright: ignore[reportUnknownVariableType]
|
||||
uid: int = response.uid # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
if uid not in self.uid_to_request:
|
||||
logger.warning(f"Received response for unknown uid: {uid}")
|
||||
continue
|
||||
|
||||
active_request = self.uid_to_request[uid]
|
||||
active_request.tokens_generated += 1
|
||||
|
||||
# Extract response fields with explicit typing
|
||||
resp_token: int = response.token # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
resp_finish_reason: str | None = response.finish_reason # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
resp_logprobs: mx.array = response.logprobs # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
# Only emit events from device_rank 0
|
||||
if self.device_rank != 0:
|
||||
if resp_finish_reason is not None:
|
||||
completed_uids.append(uid) # pyright: ignore[reportUnknownArgumentType]
|
||||
continue
|
||||
|
||||
# Decode token to text
|
||||
# Skip emitting EOS token text (e.g., <|eot_id|>)
|
||||
if resp_token in self.stop_tokens:
|
||||
token_text = ""
|
||||
else:
|
||||
token_text = self.tokenizer.decode([resp_token])
|
||||
|
||||
# Handle thinking/reasoning token tracking
|
||||
if active_request.harmony_parser is not None:
|
||||
# GPT-OSS: Use harmony parser for channel-based thinking detection
|
||||
parser = active_request.harmony_parser # pyright: ignore[reportAny]
|
||||
parser.process(resp_token) # pyright: ignore[reportAny]
|
||||
delta: str | None = parser.last_content_delta # pyright: ignore[reportAny]
|
||||
channel: str = parser.current_channel # pyright: ignore[reportAny]
|
||||
|
||||
# Track reasoning tokens (analysis channel = thinking)
|
||||
if channel == "analysis":
|
||||
active_request.reasoning_tokens += 1
|
||||
|
||||
# Handle thinking tag transitions
|
||||
prefix = ""
|
||||
if channel == "analysis" and not active_request.in_thinking:
|
||||
active_request.in_thinking = True
|
||||
prefix = "<think>"
|
||||
elif channel != "analysis" and active_request.in_thinking:
|
||||
active_request.in_thinking = False
|
||||
prefix = "</think>"
|
||||
|
||||
if resp_finish_reason is not None and active_request.in_thinking:
|
||||
# Close thinking tag on finish
|
||||
prefix = "</think>"
|
||||
active_request.in_thinking = False
|
||||
|
||||
effective_delta = delta or ""
|
||||
token_text = (
|
||||
prefix + effective_delta if (prefix or effective_delta) else ""
|
||||
)
|
||||
# Skip empty tokens (channel markers with no content delta)
|
||||
if not token_text and resp_finish_reason is None:
|
||||
continue
|
||||
elif self._think_start_token is not None:
|
||||
# MiniMax: Track <think>/</think> tokens directly
|
||||
if resp_token == self._think_start_token:
|
||||
active_request.in_thinking = True
|
||||
elif resp_token == self._think_end_token:
|
||||
active_request.in_thinking = False
|
||||
elif active_request.in_thinking:
|
||||
active_request.reasoning_tokens += 1
|
||||
|
||||
# Extract logprobs if requested
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
if active_request.should_extract_logprobs:
|
||||
logprob, top_logprobs = extract_top_logprobs(
|
||||
logprobs_array=resp_logprobs, # pyright: ignore[reportUnknownArgumentType]
|
||||
selected_token=resp_token, # pyright: ignore[reportUnknownArgumentType]
|
||||
tokenizer=self.tokenizer,
|
||||
top_k=active_request.top_k,
|
||||
)
|
||||
|
||||
# Build stats for final token
|
||||
stats: GenerationStats | None = None
|
||||
finish_reason: TokenFinishReason | None = None
|
||||
if resp_finish_reason is not None:
|
||||
elapsed_time = time.perf_counter() - active_request.start_time
|
||||
prompt_tps = active_request.prompt_tokens / max(elapsed_time, 0.001)
|
||||
generation_tps = active_request.tokens_generated / max(
|
||||
elapsed_time, 0.001
|
||||
)
|
||||
|
||||
# Get peak memory
|
||||
peak_memory_bytes = 0
|
||||
if mx.metal.is_available():
|
||||
peak_memory_bytes = mx.metal.get_peak_memory()
|
||||
|
||||
stats = GenerationStats(
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tps=generation_tps,
|
||||
prompt_tokens=active_request.prompt_tokens,
|
||||
generation_tokens=active_request.tokens_generated,
|
||||
reasoning_tokens=active_request.reasoning_tokens,
|
||||
peak_memory_usage=Memory.from_bytes(peak_memory_bytes),
|
||||
)
|
||||
|
||||
# Map finish reason to the narrower type TokenChunk expects
|
||||
if resp_finish_reason == "stop":
|
||||
finish_reason = "stop"
|
||||
elif resp_finish_reason == "length":
|
||||
finish_reason = "length"
|
||||
elif resp_finish_reason == "content_filter":
|
||||
finish_reason = "content_filter"
|
||||
else:
|
||||
# Unknown finish reasons default to "stop"
|
||||
logger.warning(
|
||||
f"Unknown finish_reason: {resp_finish_reason}, mapping to 'stop'"
|
||||
)
|
||||
finish_reason = "stop"
|
||||
|
||||
completed_uids.append(uid) # pyright: ignore[reportUnknownArgumentType]
|
||||
|
||||
yield ChunkGenerated(
|
||||
command_id=active_request.command_id,
|
||||
chunk=TokenChunk(
|
||||
model=self.model_id,
|
||||
text=token_text,
|
||||
token_id=resp_token, # pyright: ignore[reportUnknownArgumentType]
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
),
|
||||
)
|
||||
|
||||
# Clean up completed requests
|
||||
for uid in completed_uids:
|
||||
del self.uid_to_request[uid]
|
||||
|
||||
def _step_pipelined(self) -> Generator[Event, None, None]:
|
||||
"""Process one generation step using the multi-stream PipelinedGenerator."""
|
||||
if self.pipelined_generator is None or not self.uid_to_request:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"PipelinedGenerator.next() called (active={self.pipelined_generator.active_count})"
|
||||
)
|
||||
responses: list[PipelinedResponse] = self.pipelined_generator.next()
|
||||
logger.debug(f"PipelinedGenerator.next() returned {len(responses)} responses")
|
||||
|
||||
completed_uids: list[int] = []
|
||||
|
||||
for response in responses:
|
||||
uid = response.uid
|
||||
if uid not in self.uid_to_request:
|
||||
logger.warning(f"Received response for unknown uid: {uid}")
|
||||
continue
|
||||
|
||||
active_request = self.uid_to_request[uid]
|
||||
active_request.tokens_generated += 1
|
||||
|
||||
resp_token: int = response.token
|
||||
resp_finish_reason: str | None = response.finish_reason
|
||||
resp_logprobs: mx.array = response.logprobs
|
||||
|
||||
# Only emit events from device_rank 0
|
||||
if self.device_rank != 0:
|
||||
if resp_finish_reason is not None:
|
||||
completed_uids.append(uid)
|
||||
continue
|
||||
|
||||
# Decode token to text
|
||||
# Skip emitting EOS token text (e.g., <|eot_id|>)
|
||||
if resp_token in self.stop_tokens:
|
||||
token_text = ""
|
||||
else:
|
||||
token_text = self.tokenizer.decode([resp_token])
|
||||
|
||||
# Handle thinking/reasoning token tracking
|
||||
if active_request.harmony_parser is not None:
|
||||
# GPT-OSS: Use harmony parser for channel-based thinking detection
|
||||
parser = active_request.harmony_parser # pyright: ignore[reportAny]
|
||||
parser.process(resp_token) # pyright: ignore[reportAny]
|
||||
delta: str | None = parser.last_content_delta # pyright: ignore[reportAny]
|
||||
channel: str = parser.current_channel # pyright: ignore[reportAny]
|
||||
|
||||
if channel == "analysis":
|
||||
active_request.reasoning_tokens += 1
|
||||
|
||||
prefix = ""
|
||||
if channel == "analysis" and not active_request.in_thinking:
|
||||
active_request.in_thinking = True
|
||||
prefix = "<think>"
|
||||
elif channel != "analysis" and active_request.in_thinking:
|
||||
active_request.in_thinking = False
|
||||
prefix = "</think>"
|
||||
|
||||
if resp_finish_reason is not None and active_request.in_thinking:
|
||||
prefix = "</think>"
|
||||
active_request.in_thinking = False
|
||||
|
||||
effective_delta = delta or ""
|
||||
token_text = (
|
||||
prefix + effective_delta if (prefix or effective_delta) else ""
|
||||
)
|
||||
if not token_text and resp_finish_reason is None:
|
||||
continue
|
||||
elif self._think_start_token is not None:
|
||||
# MiniMax: Track <think>/</think> tokens directly
|
||||
if resp_token == self._think_start_token:
|
||||
active_request.in_thinking = True
|
||||
elif resp_token == self._think_end_token:
|
||||
active_request.in_thinking = False
|
||||
elif active_request.in_thinking:
|
||||
active_request.reasoning_tokens += 1
|
||||
|
||||
# Extract logprobs if requested
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
if active_request.should_extract_logprobs:
|
||||
logprob, top_logprobs = extract_top_logprobs(
|
||||
logprobs_array=resp_logprobs,
|
||||
selected_token=resp_token,
|
||||
tokenizer=self.tokenizer,
|
||||
top_k=active_request.top_k,
|
||||
)
|
||||
|
||||
# Build stats for final token
|
||||
stats: GenerationStats | None = None
|
||||
finish_reason: TokenFinishReason | None = None
|
||||
if resp_finish_reason is not None:
|
||||
elapsed_time = time.perf_counter() - active_request.start_time
|
||||
prompt_tps = active_request.prompt_tokens / max(elapsed_time, 0.001)
|
||||
generation_tps = active_request.tokens_generated / max(
|
||||
elapsed_time, 0.001
|
||||
)
|
||||
|
||||
peak_memory_bytes = 0
|
||||
if mx.metal.is_available():
|
||||
peak_memory_bytes = mx.metal.get_peak_memory()
|
||||
|
||||
stats = GenerationStats(
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tps=generation_tps,
|
||||
prompt_tokens=active_request.prompt_tokens,
|
||||
generation_tokens=active_request.tokens_generated,
|
||||
reasoning_tokens=active_request.reasoning_tokens,
|
||||
peak_memory_usage=Memory.from_bytes(peak_memory_bytes),
|
||||
)
|
||||
|
||||
if resp_finish_reason == "stop":
|
||||
finish_reason = "stop"
|
||||
elif resp_finish_reason == "length":
|
||||
finish_reason = "length"
|
||||
else:
|
||||
finish_reason = "stop"
|
||||
|
||||
completed_uids.append(uid)
|
||||
|
||||
yield ChunkGenerated(
|
||||
command_id=active_request.command_id,
|
||||
chunk=TokenChunk(
|
||||
model=self.model_id,
|
||||
text=token_text,
|
||||
token_id=resp_token,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
),
|
||||
)
|
||||
|
||||
for uid in completed_uids:
|
||||
del self.uid_to_request[uid]
|
||||
|
||||
def emit_error(self, command_id: CommandId, error_message: str) -> Event:
|
||||
"""Create an error event for a failed request."""
|
||||
return ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=self.model_id,
|
||||
finish_reason="error",
|
||||
error_message=error_message,
|
||||
),
|
||||
)
|
||||
|
||||
def _close_generator(self) -> None:
|
||||
"""Close and clean up the batch/pipelined generator."""
|
||||
if self.batch_generator is not None:
|
||||
self.batch_generator.close() # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue]
|
||||
self.batch_generator = None
|
||||
if self.pipelined_generator is not None:
|
||||
self.pipelined_generator.close()
|
||||
self.pipelined_generator = None
|
||||
self.uid_to_request.clear()
|
||||
logger.info("Generator closed")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the handler and clean up resources."""
|
||||
self._close_generator()
|
||||
self.pending.clear()
|
||||
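A hedged end-to-end sketch of driving the handler above (the real runner loop in this PR is more involved; `model`, `tokenizer`, `model_id`, `task` and the `emit` sink are assumed/hypothetical):

handler = BatchedInferenceHandler(
    model=model,
    tokenizer=tokenizer,
    model_id=model_id,
    device_rank=0,
    world_size=1,   # single device -> BatchGenerator path
    max_batch_size=8,
)

handler.add_request(task)  # queue a ChatCompletion task
handler.flush()            # move pending requests into the active batch

while handler.is_active:
    for event in handler.step():  # one token per active request per step
        emit(event)               # hypothetical sink for ChunkGenerated events

handler.close()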
src/exo/worker/runner/batched_scoring_handler.py (new file, 200 lines)
@@ -0,0 +1,200 @@
|
||||
"""Batched scoring handler for processing multiple Completion requests concurrently."""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import TopLogprobItem
|
||||
from exo.shared.types.chunks import CompletionChunk, ErrorChunk
|
||||
from exo.shared.types.events import ChunkGenerated, Event
|
||||
from exo.shared.types.tasks import Completion
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import score_tokens_batched
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingScoringRequest:
|
||||
"""A scoring request waiting to be batched."""
|
||||
|
||||
task: Completion
|
||||
tokens: list[int]
|
||||
prompt_text: str
|
||||
top_k: int | None
|
||||
echo: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchedScoringHandler:
|
||||
"""
|
||||
Handles batched scoring for multiple Completion requests.
|
||||
|
||||
Collects multiple scoring requests and processes them in a single
|
||||
batched forward pass for improved throughput.
|
||||
"""
|
||||
|
||||
model: Model
|
||||
tokenizer: TokenizerWrapper
|
||||
model_id: ModelId
|
||||
device_rank: int
|
||||
max_batch_size: int = 32
|
||||
batch_timeout_ms: int = 10
|
||||
|
||||
pending: list[PendingScoringRequest] = field(default_factory=list)
|
||||
pending_start_time: float | None = None
|
||||
|
||||
@property
|
||||
def has_pending(self) -> bool:
|
||||
"""Check if there are pending requests."""
|
||||
return len(self.pending) > 0
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
task: Completion,
|
||||
tokens: list[int],
|
||||
prompt_text: str,
|
||||
) -> None:
|
||||
"""Add a Completion request to the pending batch."""
|
||||
task_params = task.task_params
|
||||
top_k = task_params.logprobs
|
||||
|
||||
self.pending.append(
|
||||
PendingScoringRequest(
|
||||
task=task,
|
||||
tokens=tokens,
|
||||
prompt_text=prompt_text,
|
||||
top_k=top_k,
|
||||
echo=task_params.echo,
|
||||
)
|
||||
)
|
||||
|
||||
if self.pending_start_time is None:
|
||||
self.pending_start_time = time.perf_counter()
|
||||
|
||||
logger.debug(f"Added scoring request to batch (pending={len(self.pending)})")
|
||||
|
||||
def should_flush(self) -> bool:
|
||||
"""Check if the batch should be flushed."""
|
||||
if not self.has_pending:
|
||||
return False
|
||||
|
||||
# Flush if batch is full
|
||||
if len(self.pending) >= self.max_batch_size:
|
||||
return True
|
||||
|
||||
# Flush if timeout reached
|
||||
if self.pending_start_time is not None:
|
||||
elapsed_ms = (time.perf_counter() - self.pending_start_time) * 1000
|
||||
if elapsed_ms >= self.batch_timeout_ms:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def flush(self) -> list[Event]:
|
||||
"""Process all pending requests and return events."""
|
||||
if not self.has_pending:
|
||||
return []
|
||||
|
||||
requests = self.pending
|
||||
self.pending = []
|
||||
self.pending_start_time = None
|
||||
|
||||
logger.info(f"Processing batch of {len(requests)} scoring requests")
|
||||
|
||||
# Collect all token sequences
|
||||
token_sequences = [req.tokens for req in requests]
|
||||
|
||||
# Get common top_k (use first request's top_k, they should all be the same)
|
||||
top_k = requests[0].top_k if requests else None
|
||||
|
||||
try:
|
||||
# Run batched scoring
|
||||
all_results = score_tokens_batched(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
token_sequences=token_sequences,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
# Generate events for each request
|
||||
events: list[Event] = []
|
||||
for req, logprob_results in zip(requests, all_results, strict=True):
|
||||
if self.device_rank != 0:
|
||||
continue
|
||||
|
||||
event = self._build_completion_event(req, logprob_results)
|
||||
events.append(event)
|
||||
|
||||
logger.info(f"Batch scoring complete ({len(events)} events)")
|
||||
return events
|
||||
|
||||
except Exception as e:
|
||||
# Return error events for all requests
|
||||
logger.error(f"Batch scoring failed: {e}")
|
||||
events = []
|
||||
for req in requests:
|
||||
if self.device_rank == 0:
|
||||
events.append(
|
||||
ChunkGenerated(
|
||||
command_id=req.task.command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=self.model_id,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
def _build_completion_event(
|
||||
self,
|
||||
req: PendingScoringRequest,
|
||||
logprob_results: list[tuple[float, list[TopLogprobItem]]],
|
||||
) -> Event:
|
||||
"""Build a ChunkGenerated event for a completed scoring request."""
|
||||
tokens = req.tokens
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
# Build response in completions format
|
||||
token_strings: list[str] = []
|
||||
token_logprobs: list[float | None] = []
|
||||
top_logprobs: list[dict[str, float]] = []
|
||||
text_offset: list[int] = []
|
||||
|
||||
offset = 0
|
||||
for i, token_id in enumerate(tokens):
|
||||
token_str = tokenizer.decode([token_id])
|
||||
token_strings.append(token_str)
|
||||
|
||||
if i < len(logprob_results):
|
||||
logprob, top_items = logprob_results[i]
|
||||
# First token has no logprob (None in OpenAI format)
|
||||
token_logprobs.append(logprob if i > 0 else None)
|
||||
top_lp_dict = {item.token: item.logprob for item in top_items}
|
||||
top_logprobs.append(top_lp_dict)
|
||||
else:
|
||||
token_logprobs.append(None)
|
||||
top_logprobs.append({})
|
||||
|
||||
text_offset.append(offset)
|
||||
offset += len(token_str)
|
||||
|
||||
return ChunkGenerated(
|
||||
command_id=req.task.command_id,
|
||||
chunk=CompletionChunk(
|
||||
model=self.model_id,
|
||||
text=req.prompt_text if req.echo else "",
|
||||
tokens=token_strings,
|
||||
token_logprobs=token_logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
text_offset=text_offset,
|
||||
finish_reason="stop",
|
||||
),
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up resources."""
|
||||
self.pending.clear()
|
||||
self.pending_start_time = None
|
||||
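Likewise, a minimal sketch of the scoring handler's intended flow (illustrative; `model`, `tokenizer`, `model_id`, the Completion `task` and its `prompt` field, and the `emit` sink are assumed):

scorer = BatchedScoringHandler(
    model=model,
    tokenizer=tokenizer,
    model_id=model_id,
    device_rank=0,
)

prompt_text = task.task_params.prompt  # assumed Completion param
scorer.add_request(task, tokens=tokenizer.encode(prompt_text), prompt_text=prompt_text)

if scorer.should_flush():         # batch full or batch_timeout_ms elapsed
    for event in scorer.flush():  # one CompletionChunk (or ErrorChunk) per request
        emit(event)               # hypothetical event sink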
@@ -3,7 +3,7 @@ import os
|
||||
import loguru
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
@@ -15,7 +15,6 @@ def entrypoint(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
@@ -39,7 +38,7 @@ def entrypoint(
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver, cancel_receiver)
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
except ClosedResourceError:
|
||||
logger.warning("Runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
|
||||
src/exo/worker/runner/pipelined_generator.py (new file, 334 lines)
@@ -0,0 +1,334 @@
|
||||
"""Multi-stream pipelined batch generator for pipeline-parallel inference.
|
||||
|
||||
When a model is split across N ranks (pipeline parallelism), each rank's GPU is idle
|
||||
for (N-1)/N of each step while waiting for other ranks to compute their layers.
|
||||
|
||||
This module fills the pipeline bubble by splitting sequences into N micro-batch groups
|
||||
and processing each group on a different MLX stream. The GPU can overlap one stream's
|
||||
network communication (send/recv/all_gather) with another stream's compute.
|
||||
"""
|
||||
|
||||
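Rough illustration of the overlap idea described above (not part of this file): graphs built on separate MLX streams and evaluated together can be interleaved by the scheduler, overlapping one stream's communication with another's compute. A toy sketch with matmuls standing in for per-micro-batch forward passes:

import mlx.core as mx

world_size = 2
streams = [mx.new_stream(mx.default_device()) for _ in range(world_size)]

a = mx.random.normal((1024, 1024))
outputs = []
for s in streams:
    with mx.stream(s):
        # In the real generator this is one micro-batch's model forward,
        # including any pipeline send/recv; a matmul stands in for it here.
        outputs.append(a @ a)

# Evaluating everything at once lets the streams' work overlap.
mx.eval(*outputs)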
# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false
|
||||
# pyright: reportUnknownArgumentType=false, reportAny=false
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import make_prompt_cache


@dataclass
class MicroBatch:
    """State for one micro-batch group of sequences."""

    uids: list[int]
    y: mx.array  # Last sampled tokens [batch]
    logprobs: list[mx.array]  # Logprobs for each sequence
    max_tokens: list[int]
    num_tokens: list[int]
    cache: list[Any]  # Per-sequence KV caches (each a list of layer caches)
    samplers: list[Callable[[mx.array], mx.array]]
    tokens: list[mx.array]  # All tokens generated so far per sequence

    def __len__(self) -> int:
        return len(self.uids)


@dataclass
class PipelinedResponse:
    """Response from one generation step."""

    uid: int
    token: int
    logprobs: mx.array
    finish_reason: str | None
    cache: list[Any] | None = None


@dataclass
class PendingPrompt:
    """A prompt waiting to be prefilled."""

    uid: int
    tokens: list[int]
    max_tokens: int
    sampler: Callable[[mx.array], mx.array]


class PipelinedGenerator:
    """
    Multi-stream batch generator that fills pipeline bubbles.

    Splits active sequences into `world_size` micro-batch groups, each processed
    on its own MLX stream. During mx.eval(), the GPU overlaps network operations
    on one stream with compute on another.
    """

    def __init__(
        self,
        model: nn.Module,
        world_size: int,
        stop_tokens: set[int] | None = None,
        max_tokens: int = 4096,
    ):
        self.model = model
        self.world_size = world_size
        self.stop_tokens = stop_tokens or set()
        self.max_tokens_default = max_tokens

        # Create one stream per pipeline stage
        self.streams = [mx.new_stream(mx.default_device()) for _ in range(world_size)]

        # Micro-batch groups (one per stream)
        self.micro_batches: list[MicroBatch | None] = [None] * world_size

        # Pending prompts to be inserted
        self.pending_prompts: list[PendingPrompt] = []

        # UID counter
        self._next_uid = 0

    @property
    def active_count(self) -> int:
        """Total number of active sequences across all micro-batches."""
        return sum(len(mb) for mb in self.micro_batches if mb is not None)

    @property
    def has_active(self) -> bool:
        return self.active_count > 0 or len(self.pending_prompts) > 0

    def insert(
        self,
        prompts: list[list[int]],
        max_tokens: list[int],
        samplers: list[Callable[[mx.array], mx.array]],
    ) -> list[int]:
        """Queue prompts for processing. Returns assigned UIDs."""
        uids: list[int] = []
        for prompt, mt, sampler in zip(prompts, max_tokens, samplers, strict=True):
            uid = self._next_uid
            self._next_uid += 1
            self.pending_prompts.append(
                PendingPrompt(uid=uid, tokens=prompt, max_tokens=mt, sampler=sampler)
            )
            uids.append(uid)
        return uids

    def _prefill_group(self, group_idx: int, prompts: list[PendingPrompt]) -> None:
        """Prefill a group of prompts and create a MicroBatch."""
        if not prompts:
            return

        stream = self.streams[group_idx]

        with mx.stream(stream):
            # Create per-sequence caches
            caches = [make_prompt_cache(self.model) for _ in prompts]

            # Tokenize and prefill each sequence
            all_y: list[mx.array] = []
            all_logprobs: list[mx.array] = []
            all_samplers: list[Callable[[mx.array], mx.array]] = []
            all_tokens: list[mx.array] = []

            for prompt_info, cache in zip(prompts, caches, strict=True):
                tokens = mx.array(prompt_info.tokens)
                # Run prefill (process all tokens except last)
                if len(prompt_info.tokens) > 1:
                    self.model(tokens[:-1][None, :], cache=cache)
                    mx.eval([c.state for c in cache])

                # Process last token to get first generation logits
                last_token = tokens[-1:][None, :]
                logits = self.model(last_token, cache=cache)
                logits = logits[:, -1, :]
                logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
                sampled = prompt_info.sampler(logprobs)

                all_y.append(sampled.squeeze(0))
                all_logprobs.append(logprobs.squeeze(0))
                all_samplers.append(prompt_info.sampler)
                all_tokens.append(tokens)

            mx.eval(*all_y, *all_logprobs)

        # Create micro-batch
        batch = MicroBatch(
            uids=[p.uid for p in prompts],
            y=mx.stack(all_y),
            logprobs=all_logprobs,
            max_tokens=[p.max_tokens for p in prompts],
            num_tokens=[0] * len(prompts),
            cache=caches,
            samplers=all_samplers,
            tokens=all_tokens,
        )

        if self.micro_batches[group_idx] is None:
            self.micro_batches[group_idx] = batch
        else:
            # Extend existing micro-batch (would need cache merging - for now replace)
            self.micro_batches[group_idx] = batch

    def _prefill_pending(self) -> None:
        """Distribute pending prompts across micro-batch groups and prefill."""
        if not self.pending_prompts:
            return

        # Distribute round-robin across groups
        groups: list[list[PendingPrompt]] = [[] for _ in range(self.world_size)]
        for i, prompt in enumerate(self.pending_prompts):
            groups[i % self.world_size].append(prompt)
        self.pending_prompts.clear()

        for group_idx, group_prompts in enumerate(groups):
            if group_prompts:
                self._prefill_group(group_idx, group_prompts)

    def _step_all(self) -> None:
        """
        Run one generation step across all micro-batch groups on different streams.

        This is where pipeline overlap happens: each group's model forward pass
        runs on its own stream, and mx.eval() allows the GPU to overlap network
        ops (send/recv/all_gather) from one stream with compute from another.

        Each sequence is processed individually with its own KV cache, but all
        lazy graphs across streams are evaluated together for GPU overlap.
        """
        # Build computation graphs on each stream (lazy, no evaluation yet).
        # Each micro-batch group processes its sequences on its own stream.
        all_sampled: list[mx.array] = []
        all_logprobs: list[mx.array] = []
        # Track which (group_idx, seq_idx) each result corresponds to
        result_map: list[tuple[int, int]] = []

        for i, mb in enumerate(self.micro_batches):
            if mb is None or len(mb) == 0:
                continue

            with mx.stream(self.streams[i]):
                for e in range(len(mb)):
                    # Process each sequence individually with its own cache
                    input_token = mb.y[e : e + 1][None, :]  # [1, 1]

                    # Forward pass (lazy graph construction).
                    # For pipeline models, this includes send/recv/all_gather ops.
                    logits = self.model(input_token, cache=mb.cache[e])
                    logits = logits[:, -1, :]  # [1, vocab]

                    # Compute logprobs
                    logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)

                    # Sample
                    sampled = mb.samplers[e](logprobs)

                    all_sampled.append(sampled.squeeze(0))
                    all_logprobs.append(logprobs.squeeze(0))
                    result_map.append((i, e))

        if not result_map:
            return

        # Evaluate ALL streams together - this is where overlap happens!
        # The GPU can execute stream0's all_gather while computing stream1's layers.
        mx.eval(*all_sampled, *all_logprobs)

        # Update micro-batch states with results.
        # Group results by micro-batch for efficient update.
        group_results: dict[int, list[int]] = {}
        for idx, (group_idx, _seq_idx) in enumerate(result_map):
            group_results.setdefault(group_idx, []).append(idx)

        for group_idx, result_indices in group_results.items():
            mb = self.micro_batches[group_idx]
            assert mb is not None
            group_sampled = [all_sampled[idx] for idx in result_indices]
            group_logprobs = [all_logprobs[idx] for idx in result_indices]
            mb.y = mx.stack(group_sampled)
            mb.logprobs = group_logprobs
            for e, idx in enumerate(result_indices):
                mb.tokens[e] = mx.concatenate([mb.tokens[e], all_sampled[idx][None]])

    def next(self) -> list[PipelinedResponse]:
        """
        Run one generation step and return responses.

        Returns a PipelinedResponse for each active sequence (across all groups).
        Finished sequences are removed from their micro-batch.
        """
        # Prefill any pending prompts first
        self._prefill_pending()

        if not self.has_active:
            return []

        # Run the multi-stream forward pass
        self._step_all()

        # Collect responses and filter completed sequences
        responses: list[PipelinedResponse] = []

        for group_idx, mb in enumerate(self.micro_batches):
            if mb is None or len(mb) == 0:
                continue

            keep_idx: list[int] = []
            end_idx: list[int] = []

            for e in range(len(mb)):
                token = int(mb.y[e].item())
                uid = mb.uids[e]
                num_tok = mb.num_tokens[e] + 1
                max_tok = mb.max_tokens[e]
                mb.num_tokens[e] = num_tok

                if token in self.stop_tokens:
                    finish_reason = "stop"
                    end_idx.append(e)
                elif num_tok >= max_tok:
                    finish_reason = "length"
                    end_idx.append(e)
                else:
                    finish_reason = None
                    keep_idx.append(e)

                responses.append(
                    PipelinedResponse(
                        uid=uid,
                        token=token,
                        logprobs=mb.logprobs[e],
                        finish_reason=finish_reason,
                    )
                )

            # Remove finished sequences
            if end_idx:
                if keep_idx:
                    # Filter the micro-batch to keep only active sequences
                    mb.uids = [mb.uids[i] for i in keep_idx]
                    mb.y = mb.y[mx.array(keep_idx)]
                    mb.logprobs = [mb.logprobs[i] for i in keep_idx]
                    mb.max_tokens = [mb.max_tokens[i] for i in keep_idx]
                    mb.num_tokens = [mb.num_tokens[i] for i in keep_idx]
                    mb.samplers = [mb.samplers[i] for i in keep_idx]
                    mb.tokens = [mb.tokens[i] for i in keep_idx]
                    # Caches are per-sequence, so keep only the caches of surviving
                    # sequences to stay aligned with the filtered lists above.
                    mb.cache = [mb.cache[i] for i in keep_idx]
                else:
                    self.micro_batches[group_idx] = None

        return responses

    def close(self) -> None:
        """Clean up resources."""
        self.micro_batches = [None] * self.world_size
        self.pending_prompts.clear()
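# --- Usage sketch (editor's illustration, not part of the diff above) ----------
# A minimal example of how PipelinedGenerator is meant to be driven. The model,
# sampler construction, eos_id, and handle_token callback are assumptions used
# only for clarity; they are not code from this repository.
#
#   gen = PipelinedGenerator(model, world_size=2, stop_tokens={eos_id})
#   uids = gen.insert(
#       prompts=[tokens_a, tokens_b],            # pre-tokenized prompts
#       max_tokens=[128, 128],
#       samplers=[make_sampler(0.0), make_sampler(0.0)],
#   )
#   while gen.has_active:
#       for resp in gen.next():                  # one decode step across all streams
#           handle_token(resp.uid, resp.token, resp.finish_reason)
#   gen.close()
# --------------------------------------------------------------------------------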
@@ -6,6 +6,7 @@ from functools import cache
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from anyio import EndOfStream, WouldBlock
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
@@ -37,7 +38,6 @@ from exo.shared.types.tasks import (
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
@@ -62,7 +62,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.shared.types.worker.shards import ShardMetadata, TensorShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.image import (
|
||||
DistributedImageModel,
|
||||
@@ -71,23 +71,145 @@ from exo.worker.engines.image import (
|
||||
warmup_image_generator,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import (
|
||||
mlx_generate,
|
||||
warmup_inference,
|
||||
)
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
detect_thinking_prompt_suffix,
|
||||
initialize_mlx,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
mx_any,
|
||||
)
|
||||
from exo.worker.runner.batched_handler import BatchedInferenceHandler
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
# Batching configuration
|
||||
BATCH_ENABLED = True
|
||||
BATCH_MAX_SIZE = 32
|
||||
|
||||
|
||||
def _should_use_serial_processing(
|
||||
task: ChatCompletion,
|
||||
tokenizer: TokenizerWrapper,
|
||||
model: Model,
|
||||
model_id: ModelId,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if a ChatCompletion task requires serial processing.
|
||||
|
||||
Currently always returns False - batch mode handles all cases.
|
||||
Post-processing (GPT-OSS, thinking models, tool calls) can be applied
|
||||
per-request to the individual streams from the batch generator.
|
||||
"""
|
||||
# All tasks can use batch mode - post-processing is per-request
|
||||
return False
|
||||
|
||||
|
||||
def _process_serial_chat_completion(
|
||||
task: ChatCompletion,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
shard_metadata: ShardMetadata,
|
||||
event_sender: MpSender[Event],
|
||||
) -> None:
|
||||
"""Process a ChatCompletion task serially (original implementation)."""
|
||||
task_params = task.task_params
|
||||
command_id = task.command_id
|
||||
device_rank = shard_metadata.device_rank
|
||||
|
||||
if task_params.messages[0].content is not None:
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Build prompt once - used for both generation and thinking detection
|
||||
prompt = apply_chat_template(tokenizer, task_params)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
# prepend the thinking tag that was consumed by the chat template
|
||||
if detect_thinking_prompt_suffix(prompt, tokenizer):
|
||||
mlx_generator = parse_thinking_models(mlx_generator, tokenizer)
|
||||
|
||||
# Kimi-K2 has tool call sections - we don't care about them
|
||||
if "kimi" in shard_metadata.model_card.model_id.lower():
|
||||
mlx_generator = filter_kimi_tokens(mlx_generator)
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
if tokenizer.has_tool_calling:
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
mlx_generator = parse_tool_calls(
|
||||
mlx_generator,
|
||||
tokenizer.tool_call_start,
|
||||
tokenizer.tool_call_end,
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if device_rank == 0 and response.finish_reason == "error":
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
error_message=response.text,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
elif device_rank == 0:
|
||||
assert response.finish_reason not in (
|
||||
"error",
|
||||
"tool_calls",
|
||||
"function_call",
|
||||
)
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
logprob=response.logprob,
|
||||
top_logprobs=response.top_logprobs,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
case ToolCallResponse():
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ToolCallChunk(
|
||||
tool_calls=response.tool_calls,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
):
|
||||
instance, runner_id, shard_metadata = (
|
||||
bound_instance.instance,
|
||||
@@ -102,242 +224,186 @@ def main(
|
||||
time.sleep(timeout)
|
||||
|
||||
setup_start_time = time.time()
|
||||
cancelled_tasks = set[TaskId]()
|
||||
|
||||
# type checker was unhappy with me - splitting these fixed it
|
||||
inference_model: Model | None = None
|
||||
image_model: DistributedImageModel | None = None
|
||||
tokenizer = None
|
||||
model: Model | DistributedImageModel | None = None
|
||||
tokenizer: TokenizerWrapper | None = None
|
||||
group = None
|
||||
kv_prefix_cache: KVPrefixCache | None = None
|
||||
batch_handler: BatchedInferenceHandler | None = None
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
match task:
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
):
|
||||
logger.info("runner connecting")
|
||||
current_status = RunnerConnecting()
|
||||
|
||||
def process_task(task: Task) -> bool:
|
||||
"""
|
||||
Process a single task. Returns True if the runner should continue,
|
||||
False if it should shut down.
|
||||
"""
|
||||
nonlocal current_status, model, tokenizer, group, batch_handler
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
# NOTE: TaskAcknowledged is sent per-case below, AFTER the initial status
|
||||
# update, to avoid a race where the supervisor sees the ack before the
|
||||
# status change and re-dispatches the same lifecycle command.
|
||||
match task:
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
):
|
||||
logger.info("runner connecting")
|
||||
current_status = RunnerConnecting()
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
group = initialize_mlx(bound_instance)
|
||||
|
||||
logger.info("runner connected")
|
||||
current_status = RunnerConnected()
|
||||
|
||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||
case LoadModel() if (
|
||||
isinstance(current_status, RunnerConnected) and group is not None
|
||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("runner loading")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
def on_model_load_timeout() -> None:
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
runner_id=runner_id,
|
||||
runner_status=RunnerFailed(
|
||||
error_message="Model loading timed out"
|
||||
),
|
||||
)
|
||||
)
|
||||
group = initialize_mlx(bound_instance)
|
||||
time.sleep(0.5)
|
||||
|
||||
logger.info("runner connected")
|
||||
current_status = RunnerConnected()
|
||||
|
||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||
case LoadModel() if (
|
||||
isinstance(current_status, RunnerConnected) and group is not None
|
||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("runner loading")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(f"model has_tool_calling={tokenizer.has_tool_calling}")
|
||||
|
||||
def on_model_load_timeout() -> None:
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id,
|
||||
runner_status=RunnerFailed(
|
||||
error_message="Model loading timed out"
|
||||
),
|
||||
)
|
||||
# Initialize batch handler for text generation models
|
||||
if BATCH_ENABLED:
|
||||
# For tensor parallelism, distributed ops are handled inside model layers
|
||||
# so batch handler should use world_size=1 (no pipelining)
|
||||
batch_world_size = (
|
||||
1
|
||||
if isinstance(shard_metadata, TensorShardMetadata)
|
||||
else shard_metadata.world_size
|
||||
)
|
||||
time.sleep(0.5)
|
||||
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
inference_model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
image_model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert inference_model
|
||||
assert tokenizer
|
||||
|
||||
toks = warmup_inference(
|
||||
model=inference_model,
|
||||
batch_handler = BatchedInferenceHandler(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
model_id=shard_metadata.model_card.model_id,
|
||||
device_rank=device_rank,
|
||||
world_size=batch_world_size,
|
||||
max_batch_size=BATCH_MAX_SIZE,
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
f"Batch handler initialized (max_batch_size={BATCH_MAX_SIZE}, world_size={batch_world_size})"
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
assert image_model
|
||||
image = warmup_image_generator(model=image_model)
|
||||
if image is not None:
|
||||
logger.info(f"warmed up by generating {image.size} image")
|
||||
else:
|
||||
logger.info("warmup completed (non-primary node)")
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ChatCompletion(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
logger.info(f"received chat request: {task}")
|
||||
model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
image = warmup_image_generator(model=model)
|
||||
if image is not None:
|
||||
logger.info(f"warmed up by generating {image.size} image")
|
||||
else:
|
||||
logger.info("warmup completed (non-primary node)")
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ChatCompletion(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, (RunnerReady, RunnerRunning))
|
||||
):
|
||||
logger.info(f"received chat request: {task}")
|
||||
assert model and not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
|
||||
# Check if we should use serial processing for this task
|
||||
if not BATCH_ENABLED:
|
||||
logger.debug("Serial mode: BATCH_ENABLED is False")
|
||||
use_serial = True
|
||||
elif batch_handler is None:
|
||||
logger.debug("Serial mode: batch_handler is None")
|
||||
use_serial = True
|
||||
else:
|
||||
use_serial = _should_use_serial_processing(
|
||||
task, tokenizer, model, shard_metadata.model_card.model_id
|
||||
)
|
||||
|
||||
if use_serial:
|
||||
# Serial processing for complex tasks
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
logger.info("runner running (serial mode)")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
assert inference_model
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
try:
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Build prompt once - used for both generation and thinking detection
|
||||
prompt = apply_chat_template(tokenizer, task_params)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=inference_model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
_process_serial_chat_completion(
|
||||
task, model, tokenizer, shard_metadata, event_sender
|
||||
)
|
||||
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
# prepend the thinking tag that was consumed by the chat template
|
||||
if detect_thinking_prompt_suffix(prompt, tokenizer):
|
||||
mlx_generator = parse_thinking_models(
|
||||
mlx_generator, tokenizer
|
||||
)
|
||||
|
||||
# Kimi-K2 has tool call sections - we don't care about them
|
||||
if "kimi" in shard_metadata.model_card.model_id.lower():
|
||||
mlx_generator = filter_kimi_tokens(mlx_generator)
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
# GLM models need patched parser (upstream has bug with None regex match)
|
||||
elif "glm" in shard_metadata.model_card.model_id.lower():
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
elif isinstance(inference_model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
if tokenizer.has_tool_calling and not isinstance(
|
||||
inference_model, GptOssModel
|
||||
):
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
mlx_generator = parse_tool_calls(
|
||||
mlx_generator,
|
||||
tokenizer.tool_call_start,
|
||||
tokenizer.tool_call_end,
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
cancel_every = 5
|
||||
tokens_since_last_cancel_check = 0
|
||||
for response in mlx_generator:
|
||||
tokens_since_last_cancel_check += 1
|
||||
if tokens_since_last_cancel_check >= cancel_every:
|
||||
tokens_since_last_cancel_check = 0
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
want_to_cancel = (task.task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
if mx_any(want_to_cancel, group):
|
||||
break
|
||||
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if (
|
||||
device_rank == 0
|
||||
and response.finish_reason == "error"
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
error_message=response.text,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
elif device_rank == 0:
|
||||
assert response.finish_reason not in (
|
||||
"error",
|
||||
"tool_calls",
|
||||
"function_call",
|
||||
)
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
case ToolCallResponse():
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ToolCallChunk(
|
||||
tool_calls=response.tool_calls,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
@@ -354,38 +420,32 @@ def main(
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ImageGeneration() | ImageEdits() if isinstance(
|
||||
current_status, RunnerReady
|
||||
):
|
||||
assert image_model
|
||||
task_name = (
|
||||
"image generation"
|
||||
if isinstance(task, ImageGeneration)
|
||||
else "image edits"
|
||||
)
|
||||
logger.info(f"received {task_name} request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
# Batch processing for simple tasks
|
||||
assert batch_handler is not None
|
||||
try:
|
||||
_run_image_task(
|
||||
task=task,
|
||||
image_model=image_model,
|
||||
shard_metadata=shard_metadata,
|
||||
event_sender=event_sender,
|
||||
cancel_receiver=cancel_receiver,
|
||||
cancelled_tasks=cancelled_tasks,
|
||||
)
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
batch_handler.add_request(task)
|
||||
|
||||
# Update status to running if not already
|
||||
if not isinstance(current_status, RunnerRunning):
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running (batch mode)")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
# Return True to indicate task was added to batch
|
||||
# (completion will be sent when batch processes)
|
||||
return True
|
||||
except Exception as e:
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=task.command_id,
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
finish_reason="error",
|
||||
@@ -394,35 +454,235 @@ def main(
|
||||
)
|
||||
)
|
||||
raise
|
||||
case ImageGeneration(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image generation request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
try:
|
||||
# Generate images using the image generation backend
|
||||
# Track image_index for final images only
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||
)
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
logger.info("sending final ImageChunk")
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
current_status = RunnerShutdown()
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del inference_model, image_model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
raise
|
||||
|
||||
gc.collect()
|
||||
break
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||
)
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
logger.info("sending final ImageChunk")
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
except Exception as e:
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
raise
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case Shutdown():
|
||||
if batch_handler is not None:
|
||||
batch_handler.close()
|
||||
batch_handler = None
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
current_status = RunnerShutdown()
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
return not isinstance(current_status, RunnerShutdown)
|
||||
|
||||
# Track tasks that were added to batch (need completion after batch processes)
|
||||
batched_task_ids: list[tuple[Task, bool]] = [] # (task, completed)
|
||||
|
||||
with task_receiver as tasks:
|
||||
while True:
|
||||
# Check if batch handler is active and needs processing
|
||||
if batch_handler is not None and (
|
||||
batch_handler.is_active or batch_handler.has_pending
|
||||
):
|
||||
# Drain all available tasks before stepping
|
||||
should_break = False
|
||||
while True:
|
||||
try:
|
||||
task = tasks.receive_nowait()
|
||||
if isinstance(task, ChatCompletion) and isinstance(
|
||||
current_status, (RunnerReady, RunnerRunning)
|
||||
):
|
||||
was_batched = process_task(task)
|
||||
if was_batched:
|
||||
batched_task_ids.append((task, False))
|
||||
else:
|
||||
should_continue = process_task(task)
|
||||
if not should_continue:
|
||||
should_break = True
|
||||
break
|
||||
except WouldBlock:
|
||||
break # No more tasks available
|
||||
except EndOfStream:
|
||||
should_break = True
|
||||
break
|
||||
if should_break:
|
||||
break
|
||||
|
||||
# Flush all pending requests before stepping
|
||||
if batch_handler.has_pending:
|
||||
logger.info(
|
||||
f"Flushing batch (pending={len(batch_handler.pending)}, active={batch_handler.current_batch_size})"
|
||||
)
|
||||
batch_handler.flush()
|
||||
|
||||
# Step generation and emit events
|
||||
if batch_handler.is_active:
|
||||
event_count = 0
|
||||
for event in batch_handler.step():
|
||||
event_sender.send(event)
|
||||
event_count += 1
|
||||
if event_count > 0:
|
||||
logger.debug(f"Emitted {event_count} events from batch")
|
||||
|
||||
# Check for completed batched tasks
|
||||
if not batch_handler.is_active and not batch_handler.has_pending:
|
||||
# All batched tasks completed
|
||||
for task, completed in batched_task_ids:
|
||||
if not completed:
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id,
|
||||
task_status=TaskStatus.Complete,
|
||||
)
|
||||
)
|
||||
batched_task_ids.clear()
|
||||
|
||||
# Return to ready state
|
||||
if isinstance(current_status, RunnerRunning):
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready (batch completed)")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
else:
|
||||
# No active batch - use blocking receive
|
||||
try:
|
||||
task = tasks.receive()
|
||||
should_continue = process_task(task)
|
||||
if not should_continue:
|
||||
break
|
||||
except EndOfStream:
|
||||
break
|
||||
|
||||
# Cleanup
|
||||
if batch_handler is not None:
|
||||
batch_handler.close()
|
||||
del model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
@cache
|
||||
@@ -526,54 +786,6 @@ def parse_thinking_models(
|
||||
yield response
|
||||
|
||||
|
||||
def _run_image_task(
|
||||
task: ImageGeneration | ImageEdits,
|
||||
image_model: DistributedImageModel,
|
||||
shard_metadata: ShardMetadata,
|
||||
event_sender: MpSender[Event],
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
cancelled_tasks: set[TaskId],
|
||||
) -> None:
|
||||
task_id = task.task_id
|
||||
command_id = task.command_id
|
||||
|
||||
def check_cancelled(task_id: TaskId = task_id) -> bool:
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
return (task_id in cancelled_tasks) or (
|
||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
||||
)
|
||||
|
||||
image_index = 0
|
||||
for response in generate_image(
|
||||
model=image_model,
|
||||
task=task.task_params,
|
||||
cancel_checker=check_cancelled,
|
||||
):
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||
)
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
logger.info("sending final ImageChunk")
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
|
||||
|
||||
def _send_image_chunk(
|
||||
encoded_data: str,
|
||||
command_id: CommandId,
|
||||
|
||||
@@ -49,12 +49,13 @@ class RunnerSupervisor:
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_cancel_sender: MpSender[TaskId]
|
||||
_tg: TaskGroup = field(default_factory=create_task_group, init=False)
|
||||
_tg: TaskGroup | None = field(default=None, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
sent: set[TaskId] = field(
|
||||
default_factory=set, init=False
|
||||
) # Tasks sent to runner (not yet completed)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -65,8 +66,8 @@ class RunnerSupervisor:
|
||||
initialize_timeout: float = 400,
|
||||
) -> Self:
|
||||
ev_send, ev_recv = mp_channel[Event]()
|
||||
# A task is kind of a runner command
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||
|
||||
runner_process = Process(
|
||||
target=entrypoint,
|
||||
@@ -74,7 +75,6 @@ class RunnerSupervisor:
|
||||
bound_instance,
|
||||
ev_send,
|
||||
task_recv,
|
||||
cancel_recv,
|
||||
logger,
|
||||
),
|
||||
daemon=True,
|
||||
@@ -89,7 +89,6 @@ class RunnerSupervisor:
|
||||
initialize_timeout=initialize_timeout,
|
||||
_ev_recv=ev_recv,
|
||||
_task_sender=task_sender,
|
||||
_cancel_sender=cancel_sender,
|
||||
_event_sender=event_sender,
|
||||
)
|
||||
|
||||
@@ -97,65 +96,72 @@ class RunnerSupervisor:
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
async with self._tg as tg:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._forward_events)
|
||||
|
||||
with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
|
||||
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
await to_thread.run_sync(self.runner_process.join, 30)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
self._cancel_sender.close()
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
await to_thread.run_sync(self.runner_process.join, 10)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical(
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
async def shutdown(self):
|
||||
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
|
||||
def shutdown(self):
|
||||
assert self._tg
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
async def start_task(self, task: Task, wait_for_ack: bool = True):
|
||||
"""
|
||||
Send a task to the runner.
|
||||
|
||||
Args:
|
||||
task: The task to send.
|
||||
wait_for_ack: If True, wait for TaskAcknowledged before returning.
|
||||
If False, return immediately after sending (for batching).
|
||||
"""
|
||||
if task.task_id in self.completed:
|
||||
logger.info(
|
||||
f"Skipping invalid task {task} as it has already been completed"
|
||||
logger.debug(
|
||||
f"Skipping task {task.task_id} as it has already been completed"
|
||||
)
|
||||
return
|
||||
if task.task_id in self.sent:
|
||||
logger.debug(f"Task {task.task_id} already sent, skipping duplicate")
|
||||
return
|
||||
if task.task_id in self.pending:
|
||||
logger.debug(f"Task {task.task_id} already pending, skipping duplicate")
|
||||
return
|
||||
logger.info(f"Starting task {task}")
|
||||
event = anyio.Event()
|
||||
self.pending[task.task_id] = event
|
||||
self.sent.add(task.task_id)
|
||||
try:
|
||||
self._task_sender.send(task)
|
||||
except ClosedResourceError:
|
||||
logger.warning(f"Task {task} dropped, runner closed communication.")
|
||||
self.sent.discard(task.task_id)
|
||||
return
|
||||
await event.wait()
|
||||
|
||||
async def cancel_task(self, task_id: TaskId):
|
||||
if task_id in self.completed:
|
||||
logger.info(f"Unable to cancel {task_id} as it has been completed")
|
||||
return
|
||||
self.cancelled.add(task_id)
|
||||
await self._cancel_sender.send_async(task_id)
|
||||
if wait_for_ack:
|
||||
await event.wait()
|
||||
logger.info(f"Finished task {task}")
|
||||
|
||||
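# Usage sketch (editor's illustration, not part of the diff): when dispatching a
# burst of ChatCompletion tasks for batching, a caller can skip the per-task ack
# wait and let _forward_events settle acknowledgements as they arrive. The
# `supervisor`, `chat_tasks`, and `final_task` names are assumptions.
#
#   for task in chat_tasks:
#       await supervisor.start_task(task, wait_for_ack=False)
#   await supervisor.start_task(final_task)  # default: block until TaskAcknowledged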
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
@@ -164,7 +170,11 @@ class RunnerSupervisor:
|
||||
if isinstance(event, RunnerStatusUpdated):
|
||||
self.status = event.runner_status
|
||||
if isinstance(event, TaskAcknowledged):
|
||||
self.pending.pop(event.task_id).set()
|
||||
# Use pop with default to handle tasks sent with wait_for_ack=False
|
||||
# that may have already been removed or never added
|
||||
pending_event = self.pending.pop(event.task_id, None)
|
||||
if pending_event:
|
||||
pending_event.set()
|
||||
continue
|
||||
if (
|
||||
isinstance(event, TaskStatusUpdated)
|
||||
@@ -182,6 +192,7 @@ class RunnerSupervisor:
|
||||
),
|
||||
)
|
||||
self.completed.add(event.task_id)
|
||||
self.sent.discard(event.task_id)
|
||||
await self._event_sender.send(event)
|
||||
except (ClosedResourceError, BrokenResourceError) as e:
|
||||
await self._check_runner(e)
|
||||
@@ -221,4 +232,4 @@ class RunnerSupervisor:
|
||||
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
|
||||
)
|
||||
)
|
||||
await self.shutdown()
|
||||
self.shutdown()
|
||||
|
||||
@@ -20,6 +20,7 @@ class FakeRunnerSupervisor:
|
||||
bound_instance: BoundInstance
|
||||
status: RunnerStatus
|
||||
completed: set[TaskId] = field(default_factory=set)
|
||||
sent: set[TaskId] = field(default_factory=set)
|
||||
|
||||
|
||||
class OtherTask(BaseTask):
|
||||
|
||||
src/exo/worker/tests/unittests/test_mlx/test_kv_prefix_cache.py (new file, 545 lines)
@@ -0,0 +1,545 @@
|
||||
# type: ignore
|
||||
import time
|
||||
from typing import cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessage
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
KVPrefixCache,
|
||||
_cache_length,
|
||||
_get_prefix_length,
|
||||
encode_prompt,
|
||||
make_kv_cache,
|
||||
)
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
|
||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
|
||||
from exo.worker.tests.unittests.test_mlx.conftest import (
|
||||
DEFAULT_GPT_OSS_CONFIG,
|
||||
DEFAULT_GPT_OSS_MODEL_ID,
|
||||
)
|
||||
|
||||
|
||||
def _check_model_exists() -> bool:
|
||||
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
|
||||
|
||||
|
||||
class TestGetPrefixLength:
|
||||
def test_identical_arrays(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert _get_prefix_length(a, b) == 5
|
||||
|
||||
def test_no_common_prefix(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([4, 5, 6])
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
def test_partial_prefix(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 7, 8])
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
|
||||
def test_prompt_longer_than_cached(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3])
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
|
||||
def test_cached_longer_than_prompt(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
|
||||
def test_single_token_match(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 5, 6])
|
||||
assert _get_prefix_length(a, b) == 1
|
||||
|
||||
def test_empty_prompt(self):
|
||||
a = mx.array([]).astype(mx.int32)
|
||||
b = mx.array([1, 2, 3])
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
def test_empty_cached(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([]).astype(mx.int32)
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
def test_both_empty(self):
|
||||
a = mx.array([]).astype(mx.int32)
|
||||
b = mx.array([]).astype(mx.int32)
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
|
||||
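# Reference sketch (editor's assumption, not the repository implementation): the
# behaviour exercised by TestGetPrefixLength above is the longest common prefix of
# two token arrays, which could be computed along these lines:
#
#   def _prefix_length(a: mx.array, b: mx.array) -> int:
#       n = min(a.size, b.size)
#       if n == 0:
#           return 0
#       mismatch = a[:n] != b[:n]
#       if not bool(mx.any(mismatch).item()):
#           return n
#       return int(mx.argmax(mismatch).item())  # index of first mismatch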
class TestKVPrefix:
|
||||
@pytest.fixture
|
||||
def mock_tokenizer(self):
|
||||
"""Create a minimal mock tokenizer for tests that don't need real tokenization."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
tokenizer = MagicMock()
|
||||
tokenizer.encode.return_value = [1, 2, 3]
|
||||
return tokenizer
|
||||
|
||||
def test_starts_empty(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
assert len(cache.prompts) == 0
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_empties_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache.prompts.append(mx.array([1, 2, 3]))
|
||||
cache.caches.append([KVCache()])
|
||||
cache.clear()
|
||||
assert len(cache.prompts) == 0
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_on_empty_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache.clear()
|
||||
assert len(cache.prompts) == 0
|
||||
|
||||
|
||||
def _load_gpt_oss() -> tuple[Model, object]:
|
||||
from mlx_lm.utils import load_model
|
||||
|
||||
from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id
|
||||
|
||||
model_path = DEFAULT_GPT_OSS_CONFIG.model_path
|
||||
model_id = ModelId(DEFAULT_GPT_OSS_MODEL_ID)
|
||||
|
||||
model, _ = load_model(model_path, lazy=False)
|
||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
||||
return cast(Model, model), tokenizer
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
not _check_model_exists(),
|
||||
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
|
||||
)
|
||||
class TestKVPrefixCacheWithModel:
|
||||
@pytest.fixture(scope="class")
|
||||
def model_and_tokenizer(self):
|
||||
model, tokenizer = _load_gpt_oss()
|
||||
return model, tokenizer
|
||||
|
||||
def test_prefill_populates_cache(self, model_and_tokenizer):
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Hello!!")],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
# Cache should now hold the prompt tokens
|
||||
assert _cache_length(cache) == len(tokens)
|
||||
|
||||
def test_add_and_get_exact_match(self, model_and_tokenizer):
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Test exact")],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
assert stored_length > 0
|
||||
|
||||
# Retrieve with same prompt: exact match
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
# Exact match returns only last token
|
||||
assert len(remaining_tokens) == 1
|
||||
assert mx.array_equal(remaining_tokens, tokens[-1:])
|
||||
|
||||
def test_add_and_get_prefix_match(self, model_and_tokenizer):
|
||||
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
short_task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Hi")],
|
||||
max_tokens=1,
|
||||
)
|
||||
short_prompt = apply_chat_template(tokenizer, short_task)
|
||||
short_tokens = encode_prompt(tokenizer, short_prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(short_prompt, cache)
|
||||
|
||||
# Query with longer prompt that shares the chat template prefix
|
||||
long_task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[
|
||||
ChatCompletionMessage(role="user", content="Hi there, how are you?")
|
||||
],
|
||||
max_tokens=1,
|
||||
)
|
||||
long_prompt = apply_chat_template(tokenizer, long_task)
|
||||
long_tokens = encode_prompt(tokenizer, long_prompt)
|
||||
|
||||
# The prompts share a prefix (chat template preamble + "Hi")
|
||||
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
|
||||
assert expected_prefix > 0, (
|
||||
"Prompts should share a prefix from the chat template"
|
||||
)
|
||||
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, long_prompt
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
# remaining_tokens should be the suffix after the shared prefix
|
||||
assert len(remaining_tokens) == len(long_tokens) - expected_prefix
|
||||
assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
|
||||
|
||||
def test_stored_cache_not_mutated_after_get_and_generation(
|
||||
self, model_and_tokenizer
|
||||
):
|
||||
"""Getting a cache and then mutating it (as generation does) must not corrupt stored cache."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Mutation test")],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Get cache and mutate it (simulating what generation does)
|
||||
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
assert matched_index == 0
|
||||
|
||||
# Simulate generation: feed many additional tokens through the cache
|
||||
head_dim = result_cache[0].keys.shape[-1]
|
||||
num_heads = result_cache[0].keys.shape[1]
|
||||
extra_keys = mx.random.normal((1, num_heads, 50, head_dim))
|
||||
extra_values = mx.random.normal((1, num_heads, 50, head_dim))
|
||||
for layer_cache in result_cache:
|
||||
layer_cache.update_and_fetch(extra_keys, extra_values)
|
||||
mx.eval([c.keys for c in result_cache])
|
||||
|
||||
# Stored cache must be unchanged
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
|
||||
|
||||
    def test_stored_cache_survives_repeated_get_mutate_cycles(
        self, model_and_tokenizer
    ):
        """Multiple get+mutate cycles (like repeated user requests) must not corrupt cache."""
        model, tokenizer = model_and_tokenizer

        task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content="Repeat test")],
            max_tokens=1,
        )
        prompt = apply_chat_template(tokenizer, task)
        tokens = encode_prompt(tokenizer, prompt)
        cache = make_kv_cache(model)

        prefill(model, tokenizer, make_sampler(0.0), tokens, cache)

        kv_prefix_cache = KVPrefixCache(tokenizer)
        kv_prefix_cache.add_kv_cache(prompt, cache)

        stored_length = _cache_length(kv_prefix_cache.caches[0])

        for i in range(3):
            result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)

            head_dim = result_cache[0].keys.shape[-1]
            num_heads = result_cache[0].keys.shape[1]
            extra = mx.random.normal((1, num_heads, 30, head_dim))
            for layer_cache in result_cache:
                layer_cache.update_and_fetch(extra, extra)
            mx.eval([c.keys for c in result_cache])

            assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
                f"Failed on loop {i}"
            )

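    # --- Illustrative sketch, not part of this diff ---
    # For the two mutation tests above to pass, get_kv_cache has to hand out
    # per-layer cache objects that are independent of the stored ones, so that
    # update_and_fetch on the returned copy never grows the stored entry. The
    # protective pattern, with plain Python lists standing in for KV layers
    # (hypothetical helper):
    @staticmethod
    def _sketch_copy_on_get(stored_layers: list[list[int]]) -> list[list[int]]:
        # Fresh copies: callers may append freely without touching the original.
        return [list(layer) for layer in stored_layers]
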
    def test_mlx_generate_populates_cache(self, model_and_tokenizer):
        """mlx_generate should save the cache after generation completes."""
        model, tokenizer = model_and_tokenizer

        kv_prefix_cache = KVPrefixCache(tokenizer)
        task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content="Hello")],
            max_tokens=5,
        )
        prompt = apply_chat_template(tokenizer, task)
        prompt_tokens = encode_prompt(tokenizer, prompt)

        # Consume the entire generator so the cache-saving code after yield runs
        generated_tokens = 0
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task,
            prompt=prompt,
            kv_prefix_cache=kv_prefix_cache,
        ):
            generated_tokens += 1

        assert len(kv_prefix_cache.prompts) == 1
        assert len(kv_prefix_cache.caches) == 1
        # Cache should contain prompt + generated tokens
        expected_length = len(prompt_tokens) + generated_tokens
        assert _cache_length(kv_prefix_cache.caches[0]) == expected_length

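    # --- Illustrative sketch, not part of this diff ---
    # The test drains mlx_generate completely because any work placed after the
    # final `yield` in a generator only runs once iteration finishes. A tiny
    # stand-alone illustration of that Python behaviour (hypothetical names):
    @staticmethod
    def _sketch_post_yield_runs_after_drain() -> None:
        saved: list[str] = []

        def gen():
            yield "token"
            saved.append("cache saved")  # runs only once the generator is exhausted

        for _ in gen():
            pass
        assert saved == ["cache saved"]
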
    def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
        """Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
        model, tokenizer = model_and_tokenizer

        kv_prefix_cache = KVPrefixCache(tokenizer)
        task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content="Reuse test")],
            max_tokens=5,
        )
        prompt = apply_chat_template(tokenizer, task)
        prompt_tokens = encode_prompt(tokenizer, prompt)

        # First generation populates cache
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task,
            prompt=prompt,
            kv_prefix_cache=kv_prefix_cache,
        ):
            pass

        assert len(kv_prefix_cache.prompts) == 1

        # Second call should find a prefix match (the stored cache contains
        # prompt + generated tokens, which shares the prompt prefix)
        result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
            model, prompt
        )
        # The stored cache is longer than the prompt (it includes generated tokens),
        # so this is a prefix match where our prompt is fully contained
        assert matched_index == 0
        # Exact match: remaining_tokens is just the last token
        assert len(remaining_tokens) == 1
        assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])

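    # --- Illustrative sketch, not part of this diff ---
    # The "remaining_tokens is just the last token" check suggests that on a
    # full prefix hit the returned suffix is never empty: at least one prompt
    # token is kept so the model has input from which to produce the next
    # logits. In plain-Python terms (hypothetical helper, inferred from the
    # assertions in this test and the partial-prefix test above):
    @staticmethod
    def _sketch_remaining_after_hit(prompt: list[int], prefix_hit: int) -> list[int]:
        # Keep the last token when the whole prompt was already cached.
        start = min(prefix_hit, len(prompt) - 1)
        return prompt[start:]
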
    def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
        """With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
        model, tokenizer = model_and_tokenizer

        kv_prefix_cache = KVPrefixCache(tokenizer)

        # Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
        base_text = "The quick brown fox jumps over the lazy dog. "
        base_tokens = tokenizer.encode(base_text)
        repeats = (1200 // len(base_tokens)) + 2
        long_content = base_text * repeats

        task1 = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content=long_content)],
            max_tokens=5,
        )
        prompt1 = apply_chat_template(tokenizer, task1)
        prompt1_tokens = encode_prompt(tokenizer, prompt1)
        assert len(prompt1_tokens) > 1000, (
            "Prompt must exceed _MIN_PREFIX_HIT_TO_UPDATE"
        )

        # First generation populates the cache (must prefill all tokens)
        t0 = time.perf_counter()
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task1,
            prompt=prompt1,
            kv_prefix_cache=kv_prefix_cache,
        ):
            pass
        first_gen_time = time.perf_counter() - t0

        assert len(kv_prefix_cache.prompts) == 1
        first_cache_length = _cache_length(kv_prefix_cache.caches[0])

        # Second generation: same long prompt + extra content (simulating multi-turn)
        task2 = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[
                ChatCompletionMessage(role="user", content=long_content),
                ChatCompletionMessage(role="assistant", content="Sure, I can help."),
                ChatCompletionMessage(role="user", content="Tell me more."),
            ],
            max_tokens=5,
        )
        prompt2 = apply_chat_template(tokenizer, task2)
        prompt2_tokens = encode_prompt(tokenizer, prompt2)

        # Verify the prompts share a long prefix
        prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
        assert prefix_len > 1000, "Prompts must share > 1000 token prefix"

        # Second generation should reuse the cached prefix (only prefill new tokens)
        t0 = time.perf_counter()
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task2,
            prompt=prompt2,
            kv_prefix_cache=kv_prefix_cache,
        ):
            pass
        second_gen_time = time.perf_counter() - t0

        # Second generation should be significantly faster due to prefix cache hit - hopefully not flaky
        assert second_gen_time < first_gen_time * 0.5, (
            f"Expected prefix cache speedup: "
            f"first={first_gen_time:.2f}s, second={second_gen_time:.2f}s"
        )

        # With prefix_hit > 1000, should update in-place (not add a second entry)
        assert len(kv_prefix_cache.prompts) == 1
        # Updated cache should be longer (prompt2 + generated > prompt1 + generated)
        updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
        assert updated_cache_length > first_cache_length

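    # --- Illustrative sketch, not part of this diff ---
    # The test above implies that when the prefix hit exceeds a threshold
    # (_MIN_PREFIX_HIT_TO_UPDATE, apparently around 1000 tokens), the matched
    # entry is updated in place instead of storing a second, largely
    # overlapping copy. The decision, in spirit (hypothetical names; the real
    # constant is not shown in this diff):
    @staticmethod
    def _sketch_should_update_in_place(
        prefix_hit: int, min_prefix_hit: int = 1000
    ) -> bool:
        # A large shared prefix means the new cache mostly duplicates the old
        # one, so replacing the old entry avoids holding both in memory.
        return prefix_hit > min_prefix_hit
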
    def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
        """After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
        model, tokenizer = model_and_tokenizer

        kv_prefix_cache = KVPrefixCache(tokenizer)
        task = ChatCompletionTaskParams(
            model=DEFAULT_GPT_OSS_MODEL_ID,
            messages=[ChatCompletionMessage(role="user", content="Immutable test")],
            max_tokens=5,
        )
        prompt = apply_chat_template(tokenizer, task)

        # First generation populates cache
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task,
            prompt=prompt,
            kv_prefix_cache=kv_prefix_cache,
        ):
            pass

        first_cache_length = _cache_length(kv_prefix_cache.caches[0])

        # Second generation gets the cache and mutates it during generation
        for _response in mlx_generate(
            model=model,
            tokenizer=tokenizer,
            task=task,
            prompt=prompt,
            kv_prefix_cache=kv_prefix_cache,
        ):
            pass

        # The first stored cache must not have been mutated by the second generation
        assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length

    def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
        """Under memory pressure, adding a new cache entry evicts the least recently used one."""
        model, tokenizer = model_and_tokenizer

        kv_prefix_cache = KVPrefixCache(tokenizer)

        # Add three cache entries with different prompts
        prompts = ["First entry", "Second entry", "Third entry"]
        for i, content in enumerate(prompts):
            task = ChatCompletionTaskParams(
                model=DEFAULT_GPT_OSS_MODEL_ID,
                messages=[ChatCompletionMessage(role="user", content=content)],
                max_tokens=1,
            )
            prompt = apply_chat_template(tokenizer, task)
            tokens = encode_prompt(tokenizer, prompt)
            cache = make_kv_cache(model)
            prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
            kv_prefix_cache.add_kv_cache(prompt, cache)
            # Stagger _last_used so LRU order is deterministic
            kv_prefix_cache._last_used[i] = float(i)

        assert len(kv_prefix_cache.prompts) == 3

        # Access the third entry to make it most recently used
        kv_prefix_cache._last_used[2] = 100.0
        # Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next

        # Simulate memory pressure: active memory exceeds threshold
        fake_limit = 1000
        fake_active = int(fake_limit * 0.90)  # Above _MEMORY_THRESHOLD (0.85)

        with (
            patch(
                "exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
                return_value=fake_active,
            ),
            patch(
                "exo.worker.engines.mlx.cache.mx.metal.device_info",
                return_value={"max_recommended_working_set_size": fake_limit},
            ),
        ):
            # Trigger eviction by adding a new entry
            task = ChatCompletionTaskParams(
                model=DEFAULT_GPT_OSS_MODEL_ID,
                messages=[ChatCompletionMessage(role="user", content="New entry")],
                max_tokens=1,
            )
            prompt = apply_chat_template(tokenizer, task)
            tokens = encode_prompt(tokenizer, prompt)
            cache = make_kv_cache(model)
            prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
            kv_prefix_cache.add_kv_cache(prompt, cache)

        # LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
        # Since fake_active stays above threshold after each eviction (we don't change it),
        # all old entries get evicted, leaving only the newly added one
        assert len(kv_prefix_cache.prompts) == 1
        # The surviving entry should be the newly added one
        new_tokens = encode_prompt(tokenizer, prompt)
        assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
            new_tokens
        )

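    # --- Illustrative sketch, not part of this diff ---
    # The eviction test exercises a policy like: while active device memory is
    # above a threshold fraction (0.85 per the comment above) of the
    # recommended working-set size, drop the least recently used entry. A
    # plain-Python rendering of that loop (hypothetical names; the real code
    # reads mx.metal memory stats as patched in the test):
    @staticmethod
    def _sketch_evict_lru(
        last_used: dict[int, float],
        entry_bytes: dict[int, int],
        active_bytes: int,
        limit_bytes: int,
        threshold: float = 0.85,
    ) -> list[int]:
        evicted: list[int] = []
        # Drop the oldest entry until usage falls below the threshold
        # (or nothing is left to evict).
        while last_used and active_bytes > threshold * limit_bytes:
            lru_key = min(last_used, key=last_used.__getitem__)
            evicted.append(lru_key)
            del last_used[lru_key]
            active_bytes -= entry_bytes.pop(lru_key, 0)
        return evicted
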
@@ -118,6 +118,10 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
    # Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
    monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
    monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
    # Force serial processing mode since batch mode requires a real tokenizer
    monkeypatch.setattr(mlx_runner, "_should_use_serial_processing", make_nothin(True))
    # Disable batch handler initialization
    monkeypatch.setattr(mlx_runner, "BATCH_ENABLED", False)

    def fake_generate(*_1: object, **_2: object):
        yield GenerationResponse(token=0, text="hi", finish_reason="stop")

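# --- Illustrative sketch, not part of this diff ---
# make_nothin is not shown in this hunk; from its use above it looks like a
# small factory that builds a callable which ignores its arguments and returns
# a fixed value, which is what the monkeypatched attributes need. A plausible
# shape (hypothetical, inferred purely from usage):
def _sketch_make_nothin(value: object):
    def _nothin(*_args: object, **_kwargs: object) -> object:
        return value

    return _nothin
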
@@ -192,29 +196,30 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
        TaskStatusUpdated(
            task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
        ),
        TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
        # Status update comes before ack to prevent race conditions
        RunnerStatusUpdated(
            runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
        ),
        TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
        TaskStatusUpdated(
            task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
        ),
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
        TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
        TaskAcknowledged(task_id=LOAD_TASK_ID),
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
        TaskAcknowledged(task_id=LOAD_TASK_ID),
        TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
        TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
        TaskAcknowledged(task_id=WARMUP_TASK_ID),
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
        TaskAcknowledged(task_id=WARMUP_TASK_ID),
        TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
        TaskStatusUpdated(
            task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
        ),
        TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
        TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
        expected_chunk,
        TaskStatusUpdated(
            task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
@@ -222,10 +227,10 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
        # CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
        RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
        TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
        TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
        RunnerStatusUpdated(
            runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
        ),
        TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
        TaskStatusUpdated(
            task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
        ),

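# --- Illustrative sketch, not part of this diff ---
# The ordering constraints called out in the expected-event list above
# ("status update before ack", "chat completion completes before the runner is
# Ready again") boil down to relative positions in the observed event stream.
# A generic way to assert such a constraint (hypothetical helper, not from
# this repo):
def _sketch_assert_happens_before(
    events: list[object], first: object, second: object
) -> None:
    assert events.index(first) < events.index(second), (
        f"expected {first!r} to appear before {second!r}"
    )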