Implement prefill/decode

This commit is contained in:
Ryuichi Leo Takashige
2026-03-18 11:31:55 +00:00
parent f208586092
commit ba472da84f
9 changed files with 957 additions and 3 deletions

0
bench_overnight.log Normal file
View File

24
bench_overnight.sh Executable file
View File

@@ -0,0 +1,24 @@
#!/bin/bash
# Overnight benchmark driver.
#
# Each step prints a "--- [n/8] ... ---" banner and then, unless the step is
# disabled, runs bench/exo_bench.py for one model with a 700-token prompt and
# a single repeat. `set -e` aborts the whole run at the first failing step.
set -e
export PATH="/opt/homebrew/bin:$PATH"

# run_bench MODEL TG
#   MODEL  model id handed to --model
#   TG     value handed to --tg verbatim (may be a comma-separated list)
run_bench() {
  uv run bench/exo_bench.py --force-download --model "$1" --pp 700 --tg "$2" --repeat 1
}

echo "=== Starting overnight bench runs at $(date) ==="

echo "--- [4/8] Qwen3.5-122B-A10B-GPTQ-Int4 ---"
echo "Skipping because Int 4"
# Disabled: run_bench "Qwen/Qwen3.5-122B-A10B-GPTQ-Int4" 36000

echo "--- [5/8] Qwen3.5-27B-FP8 ---"
# Disabled: run_bench "Qwen/Qwen3.5-27B-FP8" 35133

echo "--- [6/8] GLM-4.7-Flash-bf16 ---"
run_bench "mlx-community/GLM-4.7-Flash-bf16" 29000

echo "--- [7/8] NVIDIA-Nemotron-3-Nano-30B-A3B (23000,1200) ---"
run_bench "mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-BF16" 23000,1200

echo "--- [8/8] Qwen3.5-27B-bf16 ---"
run_bench "mlx-community/Qwen3.5-27B-bf16" 35400

echo "=== All bench runs complete at $(date) ==="

View File

@@ -0,0 +1,89 @@
"""Minimal KVConnector that captures per-layer cache data."""
from dataclasses import dataclass
from typing import Any
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
# Process-wide store filled by CaptureConnector.save_kv_layer, keyed by layer
# name. A value is a list of CPU tensors, a {"keys", "values"} dict of CPU
# tensors, or a single CPU tensor, depending on what the layer handed over.
captured_layers: dict[str, Any] = {}


@dataclass
class CaptureMetadata(KVConnectorMetadata):
    """Field-less connector metadata returned by the capture connector."""
class CaptureConnector(KVConnectorBase_V1):
    """KV connector whose only real work is copying saved layers to the CPU.

    ``save_kv_layer`` stores what it receives in the module-level
    ``captured_layers`` dict. Every other hook is a no-op or returns a
    constant, so the connector never loads KV and never reports external
    cache hits.
    """

    def __init__(self, vllm_config, role, kv_cache_config=None):
        super().__init__(vllm_config, role, kv_cache_config)

    def start_load_kv(self, forward_context, **kwargs):
        """No-op: nothing is ever loaded through this connector."""

    def wait_for_layer_load(self, layer_name):
        """No-op: there is no pending load to wait for."""

    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
        """Copy one layer's cache to the CPU and record it in ``captured_layers``.

        Args:
            layer_name: Key under which the capture is stored.
            kv_layer: Either a list/tuple of tensors (each cloned to CPU as-is)
                or a single tensor holding keys and values.
            attn_metadata: Object that may carry a ``slot_mapping`` tensor.

        Calls whose ``slot_mapping`` has 100 entries or fewer are ignored.
        With a slot mapping, the rows it addresses are gathered from the
        flattened key/value pools and appended to any earlier capture for the
        same layer; without one, the whole tensor is cloned to CPU.
        """
        import time

        slot_mapping = getattr(attn_metadata, 'slot_mapping', None)
        if slot_mapping is not None and slot_mapping.shape[0] <= 100:
            return

        sync_started = time.perf_counter()
        torch.cuda.synchronize()
        t_sync = time.perf_counter() - sync_started

        if isinstance(kv_layer, (list, tuple)):
            captured_layers[layer_name] = [part.cpu().clone() for part in kv_layer]
            return

        if slot_mapping is None:
            captured_layers[layer_name] = kv_layer.cpu().clone()
            return

        # The K/V pair sits on the leading axis when that axis has size 2,
        # otherwise on the second axis.
        if kv_layer.shape[0] == 2:
            k_all, v_all = kv_layer[0], kv_layer[1]
        else:
            k_all, v_all = kv_layer[:, 0], kv_layer[:, 1]

        # Collapse everything but the last two axes so slots index rows directly.
        k_flat = k_all.reshape(-1, *k_all.shape[-2:])
        v_flat = v_all.reshape(-1, *v_all.shape[-2:])

        # Negative slots are clamped for the gather and zeroed afterwards.
        valid = slot_mapping >= 0
        safe_sm = slot_mapping.clamp(min=0)
        keys = k_flat[safe_sm]
        values = v_flat[safe_sm]
        keys[~valid] = 0
        values[~valid] = 0

        earlier = captured_layers.get(layer_name)
        copy_started = time.perf_counter()
        if isinstance(earlier, dict) and "keys" in earlier:
            entry = {
                "keys": torch.cat([earlier["keys"], keys.cpu()], dim=0),
                "values": torch.cat([earlier["values"], values.cpu()], dim=0),
            }
        else:
            entry = {
                "keys": keys.cpu(),
                "values": values.cpu(),
            }
        captured_layers[layer_name] = entry
        t_copy = time.perf_counter() - copy_started

        # Timing is reported only for layer names containing "layers.3.".
        if "layers.3." in layer_name:
            print(f" [attn save] sync={t_sync*1000:.1f}ms copy={t_copy*1000:.1f}ms tokens={keys.shape[0]}")

    def wait_for_save(self):
        """No-op: ``save_kv_layer`` has already finished its copies."""

    def get_num_new_matched_tokens(self, request, num_computed_tokens):
        """Always return ``(0, False)``."""
        return 0, False

    def update_state_after_alloc(self, request, blocks, num_external_tokens):
        """No-op: no per-request state is kept."""

    def build_connector_meta(self, scheduler_output):
        """Return a fresh, empty ``CaptureMetadata``."""
        return CaptureMetadata()

View File

@@ -0,0 +1,155 @@
"""Inspect vLLM KV cache structure per-layer after prefill.
Runs on DGX Spark. Prints per-layer shapes, dtypes, kv_cache_config,
and layer_to_group mapping to understand what vLLM stores for each
model architecture (standard attention, sliding window, GatedDeltaNet).
Usage:
uv run python scripts/disaggregated/inspect_vllm_kv.py --model ~/.local/share/exo/models/openai--gpt-oss-20b
"""
import argparse
import os
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "src"))
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
os.environ["VLLM_BATCH_INVARIANT"] = "1"
from exo.worker.runner.bootstrap import _ensure_cuda_libs
_ensure_cuda_libs()
import torch
def _build_layer_groups(kv_cache_config):
    """Return the KV-cache group index of every layer, in tensor order.

    Walks ``kv_cache_config.kv_cache_tensors`` and, for each name in a
    tensor's ``shared_by``, emits the index of the entry of
    ``kv_cache_config.kv_cache_groups`` whose ``layer_names`` contains it.
    A layer listed in several groups maps to the last one.

    Raises:
        KeyError: if a ``shared_by`` name belongs to no group.
    """
    group_of_layer = {
        layer_name: group_idx
        for group_idx, group_spec in enumerate(kv_cache_config.kv_cache_groups)
        for layer_name in group_spec.layer_names
    }
    return [
        group_of_layer[name]
        for tensor_spec in kv_cache_config.kv_cache_tensors
        for name in tensor_spec.shared_by
    ]
def main():
    """Prefill one prompt through vLLM and print how its KV cache is laid out.

    Reports, in order: the per-layer entries of ``model_runner.kv_caches``,
    the groups and tensors of the scheduler's KV cache config, the
    layer-to-group mapping, the block counts held for the test request, and
    a short summary. Returns early if no model runner is available.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", required=True, help="Path to model")
    cli.add_argument("--prompt", default="Hello, world! How are you today?", help="Prompt to prefill")
    args = cli.parse_args()

    from exo.worker.engines.vllm.vllm_generator import load_vllm_engine
    from exo.worker.engines.vllm.growable_cache import get_model_runner

    print(f"Loading vLLM engine from {args.model}...")
    # The third return value is not used here but is kept referenced.
    engine, _, prefix_cache = load_vllm_engine(
        model_path=args.model,
        model_id=args.model,
        trust_remote_code=True,
    )
    print("Engine loaded.\n")

    from vllm import SamplingParams

    tokenizer = engine.get_tokenizer()
    token_ids = tokenizer.encode(args.prompt, add_special_tokens=False)
    print(f"Prompt: {args.prompt!r}")
    print(f"Token IDs: {len(token_ids)} tokens\n")

    # One generated token is enough to force a full prefill.
    request_id = "inspect-test"
    params = SamplingParams(max_tokens=1, detokenize=False)
    engine.add_request(request_id, {"prompt_token_ids": token_ids}, params)
    while engine.has_unfinished_requests():
        engine.step()

    model_runner = get_model_runner()
    if model_runner is None:
        print("ERROR: model_runner is None")
        return

    rule = "=" * 70

    # --- per-layer tensors -------------------------------------------------
    print(rule)
    print("PER-LAYER KV CACHE TENSORS (model_runner.kv_caches)")
    print(rule)
    kv_caches = model_runner.kv_caches
    for i, kv in enumerate(kv_caches):
        if isinstance(kv, list):
            shapes = [t.shape for t in kv]
            dtypes = [t.dtype for t in kv]
            print(f" Layer {i:3d}: list of {len(kv)} tensors — shapes={shapes}, dtypes={dtypes}")
        elif isinstance(kv, torch.Tensor):
            print(f" Layer {i:3d}: shape={tuple(kv.shape)}, dtype={kv.dtype}, device={kv.device}")
        else:
            print(f" Layer {i:3d}: type={type(kv).__name__}")
    print(f"\n Total layers with KV: {len(kv_caches)}\n")

    # --- cache config ------------------------------------------------------
    engine_core = engine.engine_core.engine_core
    kv_cache_manager = engine_core.scheduler.kv_cache_manager
    kv_cache_config = kv_cache_manager.kv_cache_config
    print(rule)
    print("KV CACHE CONFIG")
    print(rule)
    print(f"\n Number of KV cache groups: {len(kv_cache_config.kv_cache_groups)}")
    for gi, group in enumerate(kv_cache_config.kv_cache_groups):
        layer_names = group.layer_names
        print(f"\n Group {gi}:")
        print(f" Layer names ({len(layer_names)}):")
        for name in layer_names[:5]:
            print(f" {name}")
        if len(layer_names) > 5:
            print(f" ... and {len(layer_names) - 5} more")

    print(f"\n Number of KV cache tensors: {len(kv_cache_config.kv_cache_tensors)}")
    for ti, tensor_spec in enumerate(kv_cache_config.kv_cache_tensors):
        shared = tensor_spec.shared_by[:3]
        if len(tensor_spec.shared_by) > 3:
            extra = f" ... +{len(tensor_spec.shared_by)-3}"
        else:
            extra = ""
        print(f" Tensor {ti}: shared_by={shared}{extra}")

    layer_to_group = _build_layer_groups(kv_cache_config)
    print(f"\n layer_to_group ({len(layer_to_group)} entries): {layer_to_group[:10]}{'...' if len(layer_to_group) > 10 else ''}")

    # --- blocks held by the request ---------------------------------------
    coordinator = kv_cache_manager.coordinator
    null_block = coordinator.block_pool.null_block
    # First request key, across all managers, that starts with our request id.
    internal_id = next(
        (
            str(key)
            for mgr in coordinator.single_type_managers
            for key in mgr.req_to_blocks
            if str(key).startswith(request_id)
        ),
        None,
    )
    if internal_id:
        print(f"\n Request internal_id: {internal_id}")
        for gi, mgr in enumerate(coordinator.single_type_managers):
            blocks = mgr.req_to_blocks.get(internal_id)
            if not blocks:
                print(f" Group {gi}: no blocks")
                continue
            real_blocks = [b for b in blocks if b is not null_block and not b.is_null]
            null_count = len(blocks) - len(real_blocks)
            print(f" Group {gi}: {len(real_blocks)} real blocks, {null_count} null blocks, block_size={mgr.block_size}")

    # --- summary -----------------------------------------------------------
    print("\n" + rule)
    print("SUMMARY")
    print(rule)
    print(f" Model: {args.model}")
    print(f" KV cache layers: {len(kv_caches)}")
    print(f" KV cache groups: {len(kv_cache_config.kv_cache_groups)}")
    print(f" Layer-to-group mapping entries: {len(layer_to_group)}")
    unique_shapes = {tuple(kv.shape) for kv in kv_caches if isinstance(kv, torch.Tensor)}
    print(f" Unique tensor shapes: {unique_shapes}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,258 @@
"""Extract KV cache per-layer from vLLM using a real KVConnector.
Patches vLLM to allow KVConnector on hybrid models (attention + GDN).
Usage:
uv run python scripts/disaggregated/test_kv_extract.py --model ~/.local/share/exo/models/Qwen--Qwen3.5-2B --output /tmp/kv_cache_qwen35/
"""
import argparse
import json
import os
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "src"))
sys.path.insert(0, str(Path(__file__).resolve().parent))
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
os.environ["VLLM_BATCH_INVARIANT"] = "1"
from exo.worker.runner.bootstrap import _ensure_cuda_libs
_ensure_cuda_libs()
import torch
def _patch_vllm_for_connector():
    """Monkey-patch vLLM so ``CaptureConnector`` can be used on hybrid models.

    Three replacements are installed:

    * ``kv_cache_utils.unify_hybrid_kv_cache_specs`` ignores ``ValueError``
      from the original and always returns ``None``.
    * ``Scheduler._connector_finished`` always returns ``(False, None)``.
    * ``KVConnectorFactory._get_connector_class_with_compat`` returns
      ``(CaptureConnector, None)`` when the configured connector name
      contains ``"capture_connector"`` and defers to the original otherwise.
    """
    from vllm.v1.core import kv_cache_utils

    unify_impl = kv_cache_utils.unify_hybrid_kv_cache_specs

    def patched_unify(kv_cache_spec):
        try:
            unify_impl(kv_cache_spec)
        except ValueError:
            # Deliberate: a failed unification is tolerated.
            pass

    kv_cache_utils.unify_hybrid_kv_cache_specs = patched_unify

    from vllm.v1.core.sched import scheduler as sched_mod

    # Looked up but never called; the access still fails early if vLLM no
    # longer has this method.
    original_connector_finished = sched_mod.Scheduler._connector_finished

    def patched_connector_finished(self, request):
        return False, None

    sched_mod.Scheduler._connector_finished = patched_connector_finished

    from capture_connector import CaptureConnector
    from vllm.distributed.kv_transfer.kv_connector import factory

    lookup_impl = factory.KVConnectorFactory._get_connector_class_with_compat

    @classmethod
    def patched_get(cls, kv_transfer_config):
        connector_name = kv_transfer_config.kv_connector or ""
        if "capture_connector" not in connector_name:
            return lookup_impl.__func__(cls, kv_transfer_config)
        return CaptureConnector, None

    factory.KVConnectorFactory._get_connector_class_with_compat = patched_get
def main():
    """Run one prefill through vLLM and write every layer's cache to ``--output``.

    Attention layers are captured by ``CaptureConnector.save_kv_layer``.
    Convolution state is captured by wrapping ``causal_conv1d_fn``, and the
    matching recurrent state is read from ``model_runner.kv_caches`` after
    the prefill. Each layer is saved as ``layer_NNN_*.pt`` and described in
    ``metadata.json``.

    NOTE(review): ``gdn_layer_order`` and ``total_layers`` are hard-coded, so
    the layer attribution presumably only fits one model layout — confirm
    before using this with another model.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", required=True)
    cli.add_argument("--output", required=True)
    _lorem = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus magna felis sollicitudin mauris. Integer in mauris eu nibh euismod gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis risus a elit. Etiam tempor. Ut ullamcorper, ligula ut dictum pharetra, nisi nunc fringilla magna, in commodo elit erat nec turpis. Ut pharetra augue nec augue. Nam elit agna, endrerit sit amet, tincidunt ac, viverra sed, nulla. Donec porta diam eu massa. Quisque diam lorem, interdum vitae, dapibus ac, scelerisque vitae, pede. Donec eget tellus non erat lacinia fermentum. Donec in velit vel ipsum auctor pulvinar. Vestibulum iaculis lacinia est. Proin dictum elementum velit. Fusce euismod consequat ante. Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Pellentesque sed dolor. Aliquam congue fermentum nisl. Mauris accumsan nulla vel diam. Sed in lacus ut enim adipiscing aliquet. Nulla venenatis. In pede mi, aliquet sit amet, euismod in, auctor ut, ligula. Aliquam dapibus tincidunt metus. Praesent justo dolor, lobortis quis, lobortis dignissim, pulvinar ac, lorem. "
    cli.add_argument(
        "--prompt",
        default=_lorem * 21 + "Now answer this question: What is the capital of France and why is it historically significant? Give a detailed answer.",
    )
    args = cli.parse_args()

    out_dir = Path(args.output)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Patches must be in place before the engine is built.
    _patch_vllm_for_connector()

    from vllm.engine.arg_utils import EngineArgs
    from vllm.v1.engine.llm_engine import LLMEngine
    from exo.worker.engines.vllm.growable_cache import patch_vllm, set_prefix_cache
    from exo.worker.engines.mlx.cache import KVPrefixCache

    patch_vllm()
    prefix_cache = KVPrefixCache(group=None)
    set_prefix_cache(prefix_cache)

    engine_args = EngineArgs(
        model=args.model,
        served_model_name=args.model,
        gpu_memory_utilization=0.05,
        trust_remote_code=True,
        load_format="fastsafetensors",
        enable_prefix_caching=False,
        attention_backend="TRITON_ATTN",
        enforce_eager=True,
        disable_log_stats=True,
        kv_transfer_config={
            "kv_connector": "capture_connector:CaptureConnector",
            "kv_role": "kv_both",
        },
    )
    print(f"Loading engine with KVConnector...")
    engine = LLMEngine.from_engine_args(engine_args)
    print("Engine loaded.")

    from exo.worker.engines.vllm.growable_cache import get_model_runner
    from vllm.model_executor.layers.mamba.abstract import MambaBase
    from capture_connector import captured_layers as gdn_captured

    model_runner = get_model_runner()

    from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn as orig_causal_conv1d_fn
    import vllm.model_executor.layers.mamba.ops.causal_conv1d as cc_mod

    # layer index -> {"conv": tensor, "ci": cache index, "rec": tensor}
    gdn_states: dict[int, dict[str, torch.Tensor]] = {}
    # Number of captured conv calls so far; the n-th captured call is
    # attributed to gdn_layer_order[n % len(gdn_layer_order)].
    gdn_calls = 0
    gdn_layer_order = [0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 16, 17, 18, 20, 21, 22]

    def patched_causal_conv1d_fn(*call_args, conv_states=None, cache_indices=None, **kwargs):
        """Call the real kernel, then snapshot the conv state it touched."""
        nonlocal gdn_calls
        result = orig_causal_conv1d_fn(*call_args, conv_states=conv_states, cache_indices=cache_indices, **kwargs)
        if conv_states is None or cache_indices is None:
            return result
        x = call_args[0] if call_args else None
        # Calls whose first argument has 100 rows or fewer are not captured.
        if x is not None and x.shape[0] <= 100:
            return result

        sync_started = time.perf_counter()
        torch.cuda.synchronize()
        t_sync = time.perf_counter() - sync_started

        ci = cache_indices[0].item() if cache_indices.numel() > 0 else 0
        layer_idx = gdn_layer_order[gdn_calls % len(gdn_layer_order)]

        copy_started = time.perf_counter()
        conv_at_ci = conv_states[ci:ci+1].transpose(-1, -2).contiguous().cpu()
        t_copy = time.perf_counter() - copy_started

        layer_state = gdn_states.setdefault(layer_idx, {})
        layer_state["conv"] = conv_at_ci
        layer_state["ci"] = ci
        if gdn_calls < 3:
            print(f" [gdn save] sync={t_sync*1000:.1f}ms copy={t_copy*1000:.1f}ms layer={layer_idx}")
        gdn_calls += 1
        return result

    # Replace the kernel in its home module and in every already-imported
    # module that holds a reference to the original function.
    cc_mod.causal_conv1d_fn = patched_causal_conv1d_fn
    for mod in list(sys.modules.values()):
        if mod is None or mod is cc_mod:
            continue
        if hasattr(mod, 'causal_conv1d_fn') and mod.causal_conv1d_fn is orig_causal_conv1d_fn:
            mod.causal_conv1d_fn = patched_causal_conv1d_fn
    print(f" Patched causal_conv1d_fn")

    from exo.worker.engines.vllm.vllm_generator import VllmBatchEngine
    from exo.shared.types.text_generation import TextGenerationTaskParams, InputMessage
    from exo.shared.types.tasks import TaskId

    batch_engine = VllmBatchEngine(engine=engine, model_id=args.model, prefix_cache=prefix_cache)
    task = TextGenerationTaskParams(
        model=args.model,
        input=[InputMessage(role="user", content=args.prompt)],
        max_completion_tokens=1,
    )
    task_id = batch_engine.submit(task_id=TaskId("extract"), task_params=task, prompt=args.prompt)

    print(f"Running prefill via VllmBatchEngine...")
    prefill_started = time.perf_counter()
    while batch_engine.has_work:
        step_results = batch_engine.step()
        # The first result marks the end of prefill; cancel that task and stop.
        for tid, resp in step_results:
            print(f" Prefill done in {(time.perf_counter()-prefill_started)*1000:.0f}ms")
            batch_engine.cancel([tid])
            break
        if step_results:
            break
    prefill_finished = time.perf_counter()
    print(f"Total: {(prefill_finished-prefill_started)*1000:.0f}ms")

    # Token ids come from the first prompt recorded in the prefix cache.
    prompt_mx = prefix_cache.prompts[0] if prefix_cache.prompts else None
    if prompt_mx is not None:
        token_ids = [int(x) for x in prompt_mx.tolist()]
    else:
        token_ids = []

    from capture_connector import captured_layers

    print(f"\nCaptured {len(captured_layers)} layers via save_kv_layer:")
    for name in sorted(captured_layers.keys()):
        captured = captured_layers[name]
        if isinstance(captured, list):
            print(f" {name}: {[tuple(t.shape) for t in captured]}")
        elif isinstance(captured, torch.Tensor):
            print(f" {name}: {tuple(captured.shape)}")
        else:
            print(f" {name}: {type(captured).__name__}")

    num_tokens = len(token_ids)
    print(f" Chat-templated prompt: {num_tokens} tokens")

    total_layers = 24
    # Remove layer files left over from an earlier run.
    for stale in out_dir.glob("layer_*"):
        stale.unlink()

    metadata = {
        "model": args.model,
        "prompt": args.prompt,
        "num_tokens": num_tokens,
        "token_ids": token_ids,
        "num_layers": total_layers,
        "layers": [],
    }
    print(f"\nSaving {total_layers} layers...")
    torch.cuda.synchronize()

    # Recurrent state: second entry of the layer's cache, at the cache index
    # seen by the conv hook.
    for layer_idx in sorted(gdn_states.keys()):
        ci = gdn_states[layer_idx]["ci"]
        layer_cache = model_runner.kv_caches[layer_idx]
        if isinstance(layer_cache, (list, tuple)) and len(layer_cache) > 1:
            rec_pool = layer_cache[1]
            gdn_states[layer_idx]["rec"] = rec_pool[ci:ci+1].cpu().clone()

    def _names_layer(name, li):
        """True if ``name`` has a ``layers`` part directly followed by ``li``."""
        parts = name.split(".")
        return any(
            part == "layers" and following == str(li)
            for part, following in zip(parts, parts[1:])
        )

    for li in range(total_layers):
        if li in gdn_states:
            state = gdn_states[li]
            conv = state.get("conv")
            rec = state.get("rec")
            torch.save(conv, out_dir / f"layer_{li:03d}_conv.pt")
            if rec is not None:
                torch.save(rec, out_dir / f"layer_{li:03d}_rec.pt")
            metadata["layers"].append({"type": "gdn", "conv": list(conv.shape), "rec": list(rec.shape) if rec is not None else None})
            print(f" Layer {li}: GDN conv={tuple(conv.shape)}, rec={tuple(rec.shape) if rec is not None else 'None'}")
            continue

        # The last matching captured name wins.
        attn_name = None
        for candidate in captured_layers:
            if _names_layer(candidate, li):
                attn_name = candidate

        if attn_name and isinstance(captured_layers[attn_name], dict):
            kv = captured_layers[attn_name]
            torch.save(kv["keys"], out_dir / f"layer_{li:03d}_keys.pt")
            torch.save(kv["values"], out_dir / f"layer_{li:03d}_values.pt")
            if "last_chunk_keys" in kv:
                torch.save(kv["last_chunk_keys"], out_dir / f"layer_{li:03d}_keys_last.pt")
                torch.save(kv["last_chunk_values"], out_dir / f"layer_{li:03d}_values_last.pt")
            metadata["layers"].append({"type": "kv", "keys_shape": list(kv["keys"].shape), "values_shape": list(kv["values"].shape)})
            print(f" Layer {li}: KV keys={tuple(kv['keys'].shape)}, values={tuple(kv['values'].shape)}")
        else:
            metadata["layers"].append({"type": "missing"})
            print(f" Layer {li}: MISSING")

    with open(out_dir / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)
    print(f"\nSaved metadata to {out_dir}/metadata.json")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,265 @@
"""Inject extracted vLLM KV cache into MLX model caches and test decode.
Runs on Mac (Apple Silicon). Loads per-layer KV tensors saved by
test_kv_extract.py, converts to MLX format, injects into MLX caches,
and generates tokens to verify correctness.
Usage:
uv run python scripts/disaggregated/test_kv_inject.py \
--model mlx-community/gpt-oss-20b-MXFP4-Q8 \
--kv-dir /path/to/extracted/kv_cache/ \
--num-tokens 20
"""
import argparse
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "src"))
import mlx.core as mx
import torch
from mlx_lm import load
from mlx_lm.models.cache import ArraysCache, KVCache, RotatingKVCache
def _torch_to_mx(t: torch.Tensor) -> mx.array:
    """Convert a torch tensor to an MLX array via NumPy.

    bfloat16 tensors are widened to float32 before the NumPy conversion and
    cast back to ``mx.bfloat16`` afterwards; every other dtype is converted
    directly.
    """
    host = t.detach().cpu()
    if host.dtype != torch.bfloat16:
        return mx.array(host.numpy())
    return mx.array(host.float().numpy()).astype(mx.bfloat16)
def _to_bhsd(keys: torch.Tensor, values: torch.Tensor, num_tokens: int) -> tuple[mx.array, mx.array]:
    """Turn vLLM key/value tensors into MLX arrays laid out as [1, H, S, D].

    Accepted inputs (the rank of ``keys`` decides for both tensors):
    - 4D [blocks, block_size, H, D]: flattened to [blocks*block_size, H, D],
      then cut to the first ``num_tokens`` rows
    - 3D [S, H, D]: cut to the first ``num_tokens`` rows

    Any other rank is passed to the permute untouched.
    """
    rank = keys.dim()
    if rank == 4:
        keys = keys.reshape(-1, *keys.shape[2:])
        values = values.reshape(-1, *values.shape[2:])
    if rank in (3, 4):
        keys, values = keys[:num_tokens], values[:num_tokens]

    # [S, H, D] -> [H, S, D], then add the leading batch axis.
    converted = tuple(
        _torch_to_mx(tensor.permute(1, 0, 2).unsqueeze(0)) for tensor in (keys, values)
    )
    return converted[0], converted[1]
def _max_abs_diff(a, b):
    """Return the largest absolute element-wise difference, computed in float32."""
    return mx.max(mx.abs(a.astype(mx.float32) - b.astype(mx.float32))).item()


def main():
    """Inject extracted vLLM caches into an MLX model and decode from them.

    Steps:
      1. Read ``metadata.json`` and the per-layer ``.pt`` files from ``--kv-dir``.
      2. Load the MLX model and fill its caches (ArraysCache, KVCache,
         RotatingKVCache) from those files.
      3. Run an MLX-native prefill on the same token ids, round-trip that
         cache through ``TorchKVCache`` and report the maximum difference
         over keys and values.
      4. Compare the injected caches against the native ones.
      5. Greedy-decode ``--num-tokens`` tokens from the injected caches.

    ``--prompt`` is accepted but not read; token ids come from the metadata.

    Raises:
        SystemExit: if the metadata carries no ``token_ids``, which steps
            3 to 5 all need.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="MLX model path/ID")
    parser.add_argument("--kv-dir", required=True, help="Directory with extracted KV tensors")
    parser.add_argument("--num-tokens", type=int, default=500, help="Tokens to generate")
    parser.add_argument("--prompt", default=None, help="Override prompt (must match extraction prompt)")
    args = parser.parse_args()

    kv_dir = Path(args.kv_dir)
    with open(kv_dir / "metadata.json") as f:
        metadata = json.load(f)
    num_extracted_layers = metadata["num_layers"]
    num_tokens = metadata["num_tokens"]
    vllm_token_ids = metadata.get("token_ids", [])
    print(f"Extracted KV: {num_extracted_layers} layers, {num_tokens} tokens")
    if vllm_token_ids:
        print(f" Using vLLM token_ids ({len(vllm_token_ids)} tokens)")
    else:
        print(f" WARNING: No token_ids in metadata")

    print(f"\nLoading MLX model: {args.model}")
    model, tokenizer = load(args.model)
    caches = model.make_cache()
    num_model_layers = len(caches)
    print(f"\nMLX model expects {num_model_layers} cache layers:")
    for i, c in enumerate(caches):
        print(f" Layer {i:3d}: {type(c).__name__}", end="")
        if isinstance(c, RotatingKVCache):
            print(f" (max_size={c.max_size}, keep={c.keep})", end="")
        elif isinstance(c, ArraysCache):
            print(f" (size={len(c.state)})", end="")
        print()

    print(f"\nExtracted {num_extracted_layers} layers from vLLM")
    print("\nInjecting KV cache into MLX caches...")
    injected = 0
    skipped = 0
    # ArraysCache layers for which no file was found, reported in the results.
    unpopulated_arrays = []
    for i, cache in enumerate(caches):
        keys_path = kv_dir / f"layer_{i:03d}_keys.pt"
        values_path = kv_dir / f"layer_{i:03d}_values.pt"

        if isinstance(cache, ArraysCache):
            conv_path = kv_dir / f"layer_{i:03d}_conv.pt"
            rec_path = kv_dir / f"layer_{i:03d}_rec.pt"
            if conv_path.exists():
                conv = torch.load(conv_path, weights_only=True)
                rec = torch.load(rec_path, weights_only=True) if rec_path.exists() else None
                cache.state = [
                    _torch_to_mx(conv),
                    _torch_to_mx(rec) if rec is not None else None,
                ]
                injected += 1
                print(f" Layer {i}: ArraysCache conv={tuple(conv.shape)}, rec={tuple(rec.shape) if rec is not None else 'None'}")
            elif keys_path.exists():
                # Older dumps stored the two states under the keys/values names.
                conv = torch.load(keys_path, weights_only=True)
                rec = torch.load(values_path, weights_only=True)
                cache.state = [_torch_to_mx(conv), _torch_to_mx(rec)]
                injected += 1
                print(f" Layer {i}: ArraysCache (legacy) conv={tuple(conv.shape)}, rec={tuple(rec.shape)}")
            else:
                print(f" Layer {i}: SKIP — ArraysCache, no files")
                skipped += 1
                unpopulated_arrays.append(i)
            continue

        if not keys_path.exists():
            skipped += 1
            continue
        keys_torch = torch.load(keys_path, weights_only=True)
        values_torch = torch.load(values_path, weights_only=True)
        k_mx, v_mx = _to_bhsd(keys_torch, values_torch, num_tokens)
        seq_len = int(k_mx.shape[2])

        if isinstance(cache, RotatingKVCache):
            if seq_len <= cache.max_size:
                cache.keys = k_mx
                cache.values = v_mx
                cache.offset = seq_len
                cache._idx = seq_len
            else:
                # Sequence longer than the window: keep the first `keep`
                # positions plus the most recent (max_size - keep) positions.
                keep = cache.keep
                window = cache.max_size
                cache.keys = mx.concatenate(
                    [k_mx[:, :, :keep, :], k_mx[:, :, -(window - keep):, :]], axis=2
                )
                cache.values = mx.concatenate(
                    [v_mx[:, :, :keep, :], v_mx[:, :, -(window - keep):, :]], axis=2
                )
                cache.offset = seq_len
                cache._idx = keep
            injected += 1
            print(f" Layer {i}: RotatingKVCache (seq_len={seq_len}, max_size={cache.max_size})")
        elif isinstance(cache, KVCache):
            cache.keys = k_mx
            cache.values = v_mx
            cache.offset = seq_len
            injected += 1
        else:
            print(f" Layer {i}: SKIP — {type(cache).__name__}")
            skipped += 1
    print(f"\n Injected: {injected} layers, Skipped: {skipped} layers")

    # Everything below indexes into the token ids, so stop with a clear
    # message instead of an IndexError further down.
    if not vllm_token_ids:
        raise SystemExit(
            "metadata.json has no token_ids; the round-trip, native comparison "
            "and decode steps cannot run without them"
        )

    from exo.worker.engines.vllm.kv_cache import TorchKVCache as TKV

    print(f"\nRound-trip test (MLX → torch → MLX)...")
    rt_caches = model.make_cache()
    rt_tokens = mx.array(vllm_token_ids)
    rt_logits = model(rt_tokens[None], cache=rt_caches)
    mx.eval(rt_logits)
    torch_rt = TKV.from_mlx_cache(rt_caches)
    back_rt = torch_rt.to_mlx_cache()
    rt_max_diff = 0.0
    for nc, bc in zip(rt_caches, back_rt):
        if isinstance(nc, ArraysCache):
            for native_arr, back_arr in zip(nc.state, bc.state):
                if native_arr is not None and back_arr is not None:
                    rt_max_diff = max(rt_max_diff, _max_abs_diff(native_arr, back_arr))
        elif isinstance(nc, (KVCache, RotatingKVCache)) and nc.keys is not None:
            nk, nv = nc.state
            bk, bv = bc.state
            # Values count too: a values-only mismatch must fail the check.
            rt_max_diff = max(rt_max_diff, _max_abs_diff(nk, bk), _max_abs_diff(nv, bv))
    print(f" Round-trip max diff: {rt_max_diff:.4e} ({'PASS' if rt_max_diff < 0.01 else 'FAIL'})")

    print(f"\nComparing with MLX-native prefill...")
    native_caches = rt_caches
    for i in range(num_model_layers):
        nc = native_caches[i]
        ic = caches[i]
        # Always print the first four and the last layer; others only on mismatch.
        always_report = i < 4 or i == num_model_layers - 1
        if isinstance(nc, KVCache) and not isinstance(nc, RotatingKVCache) and nc.keys is not None and ic.keys is not None:
            s = min(nc.offset, ic.offset)
            k_diff = _max_abs_diff(nc.keys[:, :, :s, :], ic.keys[:, :, :s, :])
            v_diff = _max_abs_diff(nc.values[:, :, :s, :], ic.values[:, :, :s, :])
            if k_diff > 0.01 or always_report:
                print(f" Layer {i:3d} KVCache: k_diff={k_diff:.4e}, v_diff={v_diff:.4e}, offset native={nc.offset} injected={ic.offset}")
        elif isinstance(nc, RotatingKVCache):
            # Rotating caches are not compared.
            pass
        elif isinstance(nc, ArraysCache):
            for ai in range(len(nc.state)):
                na = nc.state[ai]
                ia = ic.state[ai]
                if na is not None and ia is not None:
                    diff = _max_abs_diff(na, ia)
                    if diff > 0.01 or always_report:
                        print(f" Layer {i:3d} Arrays[{ai}]: diff={diff:.4e}, native_shape={na.shape}, injected_shape={ia.shape}")

    native_last = mx.array([vllm_token_ids[-1]])
    native_decode_logits = model(native_last[None], cache=native_caches)
    mx.eval(native_decode_logits)
    native_first = mx.argmax(native_decode_logits[:, -1, :], axis=-1)
    print(f" Native decode first token: {native_first.item()}, text: {tokenizer.decode([native_first.item()])!r}")

    print(f"\nDecoding {args.num_tokens} tokens with injected cache...")
    # NOTE(review): the injected caches already cover the whole prompt, yet
    # the last two prompt tokens are fed again here (the native path above
    # feeds one) — confirm this is intended.
    last_tokens = mx.array(vllm_token_ids[-2:])
    logits = model(last_tokens[None], cache=caches)
    mx.eval(logits)
    generated_tokens = []
    token = mx.argmax(logits[:, -1, :], axis=-1)
    mx.eval(token)
    generated_tokens.append(token.item())
    for _ in range(args.num_tokens - 1):
        logits = model(token[None], cache=caches)
        mx.eval(logits)
        token = mx.argmax(logits[:, -1, :], axis=-1)
        mx.eval(token)
        generated_tokens.append(token.item())
    generated_text = tokenizer.decode(generated_tokens)

    print(f"\n{'='*70}")
    print("RESULTS")
    print(f"{'='*70}")
    print(f" Model (vLLM): {metadata['model']}")
    print(f" Model (MLX): {args.model}")
    print(f" Prompt tokens: {num_tokens}")
    print(f" Layers injected: {injected}/{num_model_layers}")
    print(f" Type mismatches: 0")
    print(f" Generated {len(generated_tokens)} tokens")
    print(f" Text: {generated_text!r}")
    if unpopulated_arrays:
        print(f" ArraysCache layers (not populated): {unpopulated_arrays[:10]}{'...' if len(unpopulated_arrays) > 10 else ''}")
    # Rough sanity check only: identical tokens throughout suggest a bad cache.
    if generated_tokens and not all(t == generated_tokens[0] for t in generated_tokens):
        print(f"\n COHERENT OUTPUT: YES (varied tokens)")
    else:
        print(f"\n COHERENT OUTPUT: POSSIBLY NOT (all same token)")


if __name__ == "__main__":
    main()

80
scripts/test_vllm_multi.sh Executable file
View File

@@ -0,0 +1,80 @@
#!/usr/bin/env bash
# Fire NUM_REQUESTS chat-completion requests at HOST:PORT in parallel and
# report time-to-first-token for each.
#
# Usage: test_vllm_multi.sh [HOST] [PORT] [NUM_REQUESTS] [MODEL]
#
# Each request runs in a background subshell whose stdout and stderr go to
# $tmpdir/<i>. Once all have finished, the captured output is printed in
# request order. Exits 1 if any request failed.
set -euo pipefail

HOST="${1:-gx10-de89}"
PORT="${2:-52415}"
NUM_REQUESTS="${3:-4}"
MODEL="${4:-Qwen/Qwen2.5-0.5B-Instruct}"

echo "Sending $NUM_REQUESTS parallel requests to $HOST:$PORT ($MODEL) with ~32k token prompts..."
echo

tmpdir=$(mktemp -d)
# Remove the scratch directory however the script ends, failures included.
trap 'rm -rf "$tmpdir"' EXIT

pids=()
for i in $(seq 1 "$NUM_REQUESTS"); do
  (
    # \$i, \$MODEL, \$HOST and \$PORT are expanded by bash before Python runs.
    python3 -c "
import json, sys, time, urllib.request
import random
random.seed($i * 9999)
topics = [
    'mathematics', 'philosophy', 'religion', 'culture', 'astronomy',
    'biology', 'music', 'architecture', 'literature', 'physics',
    'chemistry', 'geology', 'psychology', 'economics', 'linguistics',
]
random.shuffle(topics)
sentences = []
for j in range(95):
    t1, t2, t3 = topics[j % len(topics)], topics[(j+3) % len(topics)], topics[(j+7) % len(topics)]
    sentences.append(
        f'In the field of {t1}, the number {$i * 1000 + j} holds particular significance '
        f'when examining its relationship to {t2} and {t3}. Scholars have long debated '
        f'whether the patterns observed in iteration {j} of this analysis reveal deeper '
        f'structural connections between seemingly unrelated disciplines. The evidence '
        f'from experiment {$i * 7 + j * 13} suggests that cross-domain numerical '
        f'correlations emerge at scale {j * $i}, challenging conventional assumptions '
        f'about the independence of these fields. ')
prompt = ' '.join(sentences) + f' Summarize the key finding about the number {$i}.'
payload = json.dumps({
    'model': '$MODEL',
    'messages': [{'role': 'user', 'content': prompt}],
    'max_tokens': 1,
    'stream': True,
}).encode()
req = urllib.request.Request(
    'http://$HOST:$PORT/v1/chat/completions',
    data=payload,
    headers={'Content-Type': 'application/json'},
)
t0 = time.perf_counter()
try:
    resp = urllib.request.urlopen(req, timeout=300)
    first_byte = None
    for line in resp:
        if first_byte is None:
            first_byte = time.perf_counter()
        line = line.decode().strip()
        if line.startswith('data: ') and line != 'data: [DONE]':
            break
    ttft = (first_byte or time.perf_counter()) - t0
    prompt_tokens = len(prompt.split()) * 1.3  # rough estimate
    tps = prompt_tokens / ttft
    print(f'request $i: TTFT={ttft:.2f}s ~{int(prompt_tokens)} prompt tokens ~{int(tps)} tok/s prefill')
except Exception as e:
    elapsed = time.perf_counter() - t0
    print(f'request $i: FAILED after {elapsed:.2f}s — {e}', file=sys.stderr)
    sys.exit(1)
" >"$tmpdir/$i" 2>&1
  ) &
  pids+=($!)
done

# Collect every child before reporting. `wait` returns the child's status, so
# it is tested in an `if`: a bare failing `wait` would trip `set -e` and abort
# before any output is shown. The ${pids[@]+...} form keeps `set -u` quiet on
# older bash when no request was started.
failed=0
for pid in ${pids[@]+"${pids[@]}"}; do
  if ! wait "$pid"; then
    failed=$((failed + 1))
  fi
done

for i in $(seq 1 "$NUM_REQUESTS"); do
  cat "$tmpdir/$i"
done

if [ "$failed" -gt 0 ]; then
  echo "$failed of $NUM_REQUESTS requests failed" >&2
  exit 1
fi

View File

@@ -93,7 +93,7 @@ class Worker:
tg.start_soon(self.plan_step)
tg.start_soon(self._event_applier)
tg.start_soon(self._poll_connection_updates)
tg.start_soon(self._update_prefill_endpoints())
tg.start_soon(self._update_prefill_endpoints)
finally:
# Actual shutdown code - waits for all tasks to complete before executing.
logger.info("Stopping Worker")
@@ -147,9 +147,9 @@ class Worker:
candidates.sort(key=lambda i: self._IFACE_PRIORITY.get(i.interface_type, 3))
return candidates[0].ip_address
def _update_prefill_endpoints(self) -> None:
async def _update_prefill_endpoints(self) -> None:
while True:
anyio.sleep(5)
await anyio.sleep(5)
try:
for runner_sup in self.runners.values():
instance = runner_sup.bound_instance.instance
@@ -176,6 +176,7 @@ class Worker:
endpoints.append({"host": ip, "port": port})
safe_model = str(my_model_id).replace("/", "--")
# TODO: Change this to be in the task with a list of optional prefill endpoints.
path = f"/tmp/exo_prefill_endpoints_{safe_model}.json"
with open(path, "w") as f:
json.dump(endpoints, f)

82
tmp/bench_overnight.log Normal file
View File

@@ -0,0 +1,82 @@
=== Starting overnight bench runs at Tue Mar 10 22:27:20 GMT 2026 ===
--- [1/8] Qwen3.5-27B-GPTQ-Int4 ---
2026-03-10 22:27:22.370 | INFO | __main__:main:300 - pp/tg mode: combinations (product) - 2 pairs
You are using a model of type qwen3_5 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
2026-03-10 22:27:23.659 | DEBUG | __main__:main:317 - [exo-bench] loaded tokenizer: mlx-community/Qwen3.5-27B-GPTQ-Int4 for prompt sizer
2026-03-10 22:27:23.661 | DEBUG | __main__:main:339 - exo-bench model: short_id=Qwen3.5-27B-GPTQ-Int4 full_id=mlx-community/Qwen3.5-27B-GPTQ-Int4
2026-03-10 22:27:23.661 | INFO | __main__:main:340 - placements: 1
2026-03-10 22:27:23.661 | INFO | __main__:main:342 - - Pipeline / MlxRing / nodes=1
2026-03-10 22:27:23.661 | INFO | __main__:main:353 - Planning phase: checking downloads...
2026-03-10 22:27:23.670 | INFO | harness:run_planning_phase:415 - Started download on 12D3KooWGXXhpS3kzjfDVuBGX8AeARLjVdAFaDouYJtXDVXkyq7f
2026-03-10 22:27:23.674 | INFO | __main__:main:365 - Download: model already cached
2026-03-10 22:27:23.675 | INFO | __main__:main:377 - ================================================================================
2026-03-10 22:27:23.675 | INFO | __main__:main:378 - PLACEMENT: Pipeline / MlxRing / nodes=1 / instance_id=725d8b5f-cc83-4dd9-8a4e-c6e2bc5b0607
2026-03-10 22:27:31.871 | INFO | __main__:main:409 - --- pp=700 tg=32067 concurrency=1 ---
2026-03-10 22:27:34.896 | INFO | __main__:build:224 - tok=700
2026-03-10 23:00:45.925 | INFO | __main__:main:519 - prompt_tps=11.76 gen_tps=16.13 prompt_tokens=700 gen_tokens=32067 peak_memory=31.81GB
2026-03-10 23:00:47.935 | INFO | __main__:main:409 - --- pp=700 tg=33085 concurrency=1 ---
2026-03-10 23:00:50.948 | INFO | __main__:build:224 - tok=700
2026-03-10 23:35:06.460 | INFO | __main__:main:519 - prompt_tps=11.87 gen_tps=16.12 prompt_tokens=700 gen_tokens=33085 peak_memory=31.89GB
2026-03-10 23:35:08.978 | DEBUG | __main__:main:532 - Deleted instance 725d8b5f-cc83-4dd9-8a4e-c6e2bc5b0607
2026-03-10 23:35:13.985 | DEBUG | __main__:main:541 -
Wrote results JSON: bench/results.json
--- [2/8] NVIDIA-Nemotron-3-Nano-30B-A3B (1120,1330,23100) ---
2026-03-10 23:35:17.156 | INFO | __main__:main:300 - pp/tg mode: combinations (product) - 3 pairs
2026-03-10 23:35:18.698 | DEBUG | __main__:main:317 - [exo-bench] loaded tokenizer: mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-BF16 for prompt sizer
2026-03-10 23:35:18.701 | DEBUG | __main__:main:339 - exo-bench model: short_id=NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-BF16 full_id=mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-BF16
2026-03-10 23:35:18.701 | INFO | __main__:main:340 - placements: 1
2026-03-10 23:35:18.701 | INFO | __main__:main:342 - - Pipeline / MlxRing / nodes=1
2026-03-10 23:35:18.701 | INFO | __main__:main:353 - Planning phase: checking downloads...
2026-03-10 23:35:18.710 | INFO | harness:run_planning_phase:415 - Started download on 12D3KooWGXXhpS3kzjfDVuBGX8AeARLjVdAFaDouYJtXDVXkyq7f
2026-03-10 23:35:18.716 | INFO | __main__:main:365 - Download: model already cached
2026-03-10 23:35:18.716 | INFO | __main__:main:377 - ================================================================================
2026-03-10 23:35:18.716 | INFO | __main__:main:378 - PLACEMENT: Pipeline / MlxRing / nodes=1 / instance_id=fb4d3883-2570-42e9-b3ee-aa1e4a1f8121
2026-03-10 23:35:43.422 | INFO | __main__:main:409 - --- pp=700 tg=1120 concurrency=1 ---
2026-03-10 23:35:46.449 | INFO | __main__:build:224 - tok=700
2026-03-10 23:36:05.631 | INFO | __main__:main:519 - prompt_tps=42.56 gen_tps=62.91 prompt_tokens=700 gen_tokens=1120 peak_memory=64.47GB
2026-03-10 23:36:07.636 | INFO | __main__:main:409 - --- pp=700 tg=1330 concurrency=1 ---
2026-03-10 23:36:10.654 | INFO | __main__:build:224 - tok=700
2026-03-10 23:36:33.088 | INFO | __main__:main:519 - prompt_tps=42.28 gen_tps=62.93 prompt_tokens=700 gen_tokens=1330 peak_memory=64.47GB
2026-03-10 23:36:35.097 | INFO | __main__:main:409 - --- pp=700 tg=23100 concurrency=1 ---
2026-03-10 23:36:38.107 | INFO | __main__:build:224 - tok=700
2026-03-10 23:43:00.943 | INFO | __main__:main:519 - prompt_tps=42.50 gen_tps=60.57 prompt_tokens=700 gen_tokens=23100 peak_memory=64.47GB
2026-03-10 23:43:03.952 | DEBUG | __main__:main:532 - Deleted instance fb4d3883-2570-42e9-b3ee-aa1e4a1f8121
2026-03-10 23:43:08.954 | DEBUG | __main__:main:541 -
Wrote results JSON: bench/results.json
--- [3/8] Qwen3.5-35B-A3B-bf16 ---
2026-03-11 10:33:18.576 | INFO | __main__:main:300 - pp/tg mode: combinations (product) - 4 pairs
2026-03-11 10:33:18.581 | INFO | harness:resolve_model_short_id:187 - Model not in /models, adding from HuggingFace: mlx-community/Qwen3.5-35B-A3B-bf16
You are using a model of type qwen3_5_moe to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
2026-03-11 10:33:22.909 | DEBUG | __main__:main:317 - [exo-bench] loaded tokenizer: mlx-community/Qwen3.5-35B-A3B-bf16 for prompt sizer
2026-03-11 10:33:22.913 | DEBUG | __main__:main:339 - exo-bench model: short_id=Qwen3.5-35B-A3B-bf16 full_id=mlx-community/Qwen3.5-35B-A3B-bf16
2026-03-11 10:33:22.913 | INFO | __main__:main:340 - placements: 1
2026-03-11 10:33:22.913 | INFO | __main__:main:342 - - Pipeline / MlxRing / nodes=1
2026-03-11 10:33:22.913 | INFO | __main__:main:353 - Planning phase: checking downloads...
2026-03-11 10:33:22.923 | INFO | harness:run_planning_phase:415 - Started download on 12D3KooWGXXhpS3kzjfDVuBGX8AeARLjVdAFaDouYJtXDVXkyq7f
2026-03-11 10:44:01.613 | INFO | __main__:main:363 - Download: 638.7s (freshly downloaded)
2026-03-11 10:44:01.613 | INFO | __main__:main:377 - ================================================================================
2026-03-11 10:44:01.613 | INFO | __main__:main:378 - PLACEMENT: Pipeline / MlxRing / nodes=1 / instance_id=3a5aa42f-a0e1-4273-b61e-56b4e34e0c1d
2026-03-11 10:44:27.758 | INFO | __main__:main:409 - --- pp=700 tg=6200 concurrency=1 ---
2026-03-11 10:44:30.785 | INFO | __main__:build:224 - tok=700
2026-03-11 10:46:16.814 | INFO | __main__:main:519 - prompt_tps=39.17 gen_tps=59.35 prompt_tokens=700 gen_tokens=6200 peak_memory=70.10GB
2026-03-11 10:46:18.816 | INFO | __main__:main:409 - --- pp=700 tg=6450 concurrency=1 ---
2026-03-11 10:46:21.834 | INFO | __main__:build:224 - tok=700
2026-03-11 10:48:11.923 | INFO | __main__:main:519 - prompt_tps=38.95 gen_tps=59.50 prompt_tokens=700 gen_tokens=6450 peak_memory=70.10GB
2026-03-11 10:48:13.927 | INFO | __main__:main:409 - --- pp=700 tg=25600 concurrency=1 ---
2026-03-11 10:48:16.940 | INFO | __main__:build:224 - tok=700
2026-03-11 10:55:53.583 | INFO | __main__:main:519 - prompt_tps=39.36 gen_tps=56.27 prompt_tokens=700 gen_tokens=25600 peak_memory=70.34GB
2026-03-11 10:55:55.585 | INFO | __main__:main:409 - --- pp=700 tg=38000 concurrency=1 ---
2026-03-11 10:55:58.608 | INFO | __main__:build:224 - tok=700
2026-03-11 11:07:35.629 | INFO | __main__:main:519 - prompt_tps=39.43 gen_tps=54.66 prompt_tokens=700 gen_tokens=38000 peak_memory=70.61GB
2026-03-11 11:07:38.598 | DEBUG | __main__:main:532 - Deleted instance 3a5aa42f-a0e1-4273-b61e-56b4e34e0c1d
2026-03-11 11:07:43.607 | DEBUG | __main__:main:541 -
Wrote results JSON: bench/results.json