From ba472da84ffc04b40bc307f69579dd23eea2bb91 Mon Sep 17 00:00:00 2001 From: Ryuichi Leo Takashige Date: Wed, 18 Mar 2026 11:31:55 +0000 Subject: [PATCH] Implement prefill/decode --- bench_overnight.log | 0 bench_overnight.sh | 24 ++ scripts/disaggregated/capture_connector.py | 89 +++++++ scripts/disaggregated/inspect_vllm_kv.py | 155 ++++++++++++ scripts/disaggregated/test_kv_extract.py | 258 ++++++++++++++++++++ scripts/disaggregated/test_kv_inject.py | 265 +++++++++++++++++++++ scripts/test_vllm_multi.sh | 80 +++++++ src/exo/worker/main.py | 7 +- tmp/bench_overnight.log | 82 +++++++ 9 files changed, 957 insertions(+), 3 deletions(-) create mode 100644 bench_overnight.log create mode 100755 bench_overnight.sh create mode 100644 scripts/disaggregated/capture_connector.py create mode 100644 scripts/disaggregated/inspect_vllm_kv.py create mode 100644 scripts/disaggregated/test_kv_extract.py create mode 100644 scripts/disaggregated/test_kv_inject.py create mode 100755 scripts/test_vllm_multi.sh create mode 100644 tmp/bench_overnight.log diff --git a/bench_overnight.log b/bench_overnight.log new file mode 100644 index 000000000..e69de29bb diff --git a/bench_overnight.sh b/bench_overnight.sh new file mode 100755 index 000000000..1428e3430 --- /dev/null +++ b/bench_overnight.sh @@ -0,0 +1,24 @@ +#!/bin/bash +set -e + +export PATH="/opt/homebrew/bin:$PATH" + +echo "=== Starting overnight bench runs at $(date) ===" + +echo "--- [4/8] Qwen3.5-122B-A10B-GPTQ-Int4 ---" +echo "Skipping because Int 4" +#uv run bench/exo_bench.py --force-download --model "Qwen/Qwen3.5-122B-A10B-GPTQ-Int4" --pp 700 --tg 36000 --repeat 1 + +echo "--- [5/8] Qwen3.5-27B-FP8 ---" +#uv run bench/exo_bench.py --force-download --model "Qwen/Qwen3.5-27B-FP8" --pp 700 --tg 35133 --repeat 1 + +echo "--- [6/8] GLM-4.7-Flash-bf16 ---" +uv run bench/exo_bench.py --force-download --model "mlx-community/GLM-4.7-Flash-bf16" --pp 700 --tg 29000 --repeat 1 + +echo "--- [7/8] NVIDIA-Nemotron-3-Nano-30B-A3B 
(23000,1200) ---" +uv run bench/exo_bench.py --force-download --model "mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-BF16" --pp 700 --tg 23000,1200 --repeat 1 + +echo "--- [8/8] Qwen3.5-27B-bf16 ---" +uv run bench/exo_bench.py --force-download --model "mlx-community/Qwen3.5-27B-bf16" --pp 700 --tg 35400 --repeat 1 + +echo "=== All bench runs complete at $(date) ===" diff --git a/scripts/disaggregated/capture_connector.py b/scripts/disaggregated/capture_connector.py new file mode 100644 index 000000000..fc3607dda --- /dev/null +++ b/scripts/disaggregated/capture_connector.py @@ -0,0 +1,89 @@ +"""Minimal KVConnector that captures per-layer cache data.""" + +from dataclasses import dataclass +from typing import Any + +import torch + +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) + +captured_layers: dict[str, Any] = {} + + +@dataclass +class CaptureMetadata(KVConnectorMetadata): + pass + + +class CaptureConnector(KVConnectorBase_V1): + def __init__(self, vllm_config, role, kv_cache_config=None): + super().__init__(vllm_config, role, kv_cache_config) + + def start_load_kv(self, forward_context, **kwargs): + pass + + def wait_for_layer_load(self, layer_name): + pass + + def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs): + import time + slot_mapping = getattr(attn_metadata, 'slot_mapping', None) + if slot_mapping is not None and slot_mapping.shape[0] <= 100: + return + t0 = time.perf_counter() + torch.cuda.synchronize() + t_sync = time.perf_counter() - t0 + if isinstance(kv_layer, (list, tuple)): + captured_layers[layer_name] = [t.cpu().clone() for t in kv_layer] + else: + slot_mapping = getattr(attn_metadata, 'slot_mapping', None) + if slot_mapping is not None: + if kv_layer.shape[0] == 2: + k_all = kv_layer[0] + v_all = kv_layer[1] + else: + k_all = kv_layer[:, 0] + v_all = kv_layer[:, 1] + k_flat = k_all.reshape(-1, *k_all.shape[-2:]) + v_flat = 
v_all.reshape(-1, *v_all.shape[-2:]) + valid = slot_mapping >= 0 + safe_sm = slot_mapping.clamp(min=0) + keys = k_flat[safe_sm] + values = v_flat[safe_sm] + keys[~valid] = 0 + values[~valid] = 0 + prev = captured_layers.get(layer_name) + if isinstance(prev, dict) and "keys" in prev: + t1 = time.perf_counter() + captured_layers[layer_name] = { + "keys": torch.cat([prev["keys"], keys.cpu()], dim=0), + "values": torch.cat([prev["values"], values.cpu()], dim=0), + } + t_copy = time.perf_counter() - t1 + else: + t1 = time.perf_counter() + captured_layers[layer_name] = { + "keys": keys.cpu(), + "values": values.cpu(), + } + t_copy = time.perf_counter() - t1 + if "layers.3." in layer_name: + print(f" [attn save] sync={t_sync*1000:.1f}ms copy={t_copy*1000:.1f}ms tokens={keys.shape[0]}") + else: + captured_layers[layer_name] = kv_layer.cpu().clone() + + def wait_for_save(self): + pass + + def get_num_new_matched_tokens(self, request, num_computed_tokens): + return 0, False + + def update_state_after_alloc(self, request, blocks, num_external_tokens): + pass + + def build_connector_meta(self, scheduler_output): + return CaptureMetadata() diff --git a/scripts/disaggregated/inspect_vllm_kv.py b/scripts/disaggregated/inspect_vllm_kv.py new file mode 100644 index 000000000..b46c49c3a --- /dev/null +++ b/scripts/disaggregated/inspect_vllm_kv.py @@ -0,0 +1,155 @@ +"""Inspect vLLM KV cache structure per-layer after prefill. + +Runs on DGX Spark. Prints per-layer shapes, dtypes, kv_cache_config, +and layer_to_group mapping to understand what vLLM stores for each +model architecture (standard attention, sliding window, GatedDeltaNet). 
+ +Usage: + uv run python scripts/disaggregated/inspect_vllm_kv.py --model ~/.local/share/exo/models/openai--gpt-oss-20b +""" + +import argparse +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "src")) + +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" +os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD" +os.environ["VLLM_BATCH_INVARIANT"] = "1" + +from exo.worker.runner.bootstrap import _ensure_cuda_libs +_ensure_cuda_libs() + +import torch + + +def _build_layer_groups(kv_cache_config): + group_lookup = {} + for group_idx, group_spec in enumerate(kv_cache_config.kv_cache_groups): + for layer_name in group_spec.layer_names: + group_lookup[layer_name] = group_idx + + layer_to_group = [] + for tensor_spec in kv_cache_config.kv_cache_tensors: + for name in tensor_spec.shared_by: + layer_to_group.append(group_lookup[name]) + return layer_to_group + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=True, help="Path to model") + parser.add_argument("--prompt", default="Hello, world! 
How are you today?", help="Prompt to prefill") + args = parser.parse_args() + + from exo.worker.engines.vllm.vllm_generator import load_vllm_engine + from exo.worker.engines.vllm.growable_cache import get_model_runner + + print(f"Loading vLLM engine from {args.model}...") + engine, _, prefix_cache = load_vllm_engine( + model_path=args.model, + model_id=args.model, + trust_remote_code=True, + ) + print("Engine loaded.\n") + + from vllm import SamplingParams + + tokenizer = engine.get_tokenizer() + token_ids = tokenizer.encode(args.prompt, add_special_tokens=False) + print(f"Prompt: {args.prompt!r}") + print(f"Token IDs: {len(token_ids)} tokens\n") + + request_id = "inspect-test" + params = SamplingParams(max_tokens=1, detokenize=False) + engine.add_request(request_id, {"prompt_token_ids": token_ids}, params) + + while engine.has_unfinished_requests(): + engine.step() + + model_runner = get_model_runner() + if model_runner is None: + print("ERROR: model_runner is None") + return + + print("=" * 70) + print("PER-LAYER KV CACHE TENSORS (model_runner.kv_caches)") + print("=" * 70) + kv_caches = model_runner.kv_caches + for i, kv in enumerate(kv_caches): + if isinstance(kv, list): + shapes = [t.shape for t in kv] + dtypes = [t.dtype for t in kv] + print(f" Layer {i:3d}: list of {len(kv)} tensors — shapes={shapes}, dtypes={dtypes}") + elif isinstance(kv, torch.Tensor): + print(f" Layer {i:3d}: shape={tuple(kv.shape)}, dtype={kv.dtype}, device={kv.device}") + else: + print(f" Layer {i:3d}: type={type(kv).__name__}") + print(f"\n Total layers with KV: {len(kv_caches)}\n") + + engine_core = engine.engine_core.engine_core + kv_cache_config = engine_core.scheduler.kv_cache_manager.kv_cache_config + + print("=" * 70) + print("KV CACHE CONFIG") + print("=" * 70) + + print(f"\n Number of KV cache groups: {len(kv_cache_config.kv_cache_groups)}") + for gi, group in enumerate(kv_cache_config.kv_cache_groups): + print(f"\n Group {gi}:") + print(f" Layer names 
({len(group.layer_names)}):") + for name in group.layer_names[:5]: + print(f" {name}") + if len(group.layer_names) > 5: + print(f" ... and {len(group.layer_names) - 5} more") + + print(f"\n Number of KV cache tensors: {len(kv_cache_config.kv_cache_tensors)}") + for ti, tensor_spec in enumerate(kv_cache_config.kv_cache_tensors): + shared = tensor_spec.shared_by[:3] + extra = f" ... +{len(tensor_spec.shared_by)-3}" if len(tensor_spec.shared_by) > 3 else "" + print(f" Tensor {ti}: shared_by={shared}{extra}") + + layer_to_group = _build_layer_groups(kv_cache_config) + print(f"\n layer_to_group ({len(layer_to_group)} entries): {layer_to_group[:10]}{'...' if len(layer_to_group) > 10 else ''}") + + coordinator = engine_core.scheduler.kv_cache_manager.coordinator + null_block = coordinator.block_pool.null_block + + internal_id = None + for mgr in coordinator.single_type_managers: + for key in mgr.req_to_blocks: + if str(key).startswith(request_id): + internal_id = str(key) + break + if internal_id: + break + + if internal_id: + print(f"\n Request internal_id: {internal_id}") + for gi, mgr in enumerate(coordinator.single_type_managers): + blocks = mgr.req_to_blocks.get(internal_id) + if blocks: + real_blocks = [b for b in blocks if b is not null_block and not b.is_null] + null_count = len(blocks) - len(real_blocks) + print(f" Group {gi}: {len(real_blocks)} real blocks, {null_count} null blocks, block_size={mgr.block_size}") + else: + print(f" Group {gi}: no blocks") + + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f" Model: {args.model}") + print(f" KV cache layers: {len(kv_caches)}") + print(f" KV cache groups: {len(kv_cache_config.kv_cache_groups)}") + print(f" Layer-to-group mapping entries: {len(layer_to_group)}") + unique_shapes = set() + for kv in kv_caches: + if isinstance(kv, torch.Tensor): + unique_shapes.add(tuple(kv.shape)) + print(f" Unique tensor shapes: {unique_shapes}") + + +if __name__ == "__main__": + main() diff --git 
a/scripts/disaggregated/test_kv_extract.py b/scripts/disaggregated/test_kv_extract.py new file mode 100644 index 000000000..021f42d42 --- /dev/null +++ b/scripts/disaggregated/test_kv_extract.py @@ -0,0 +1,258 @@ +"""Extract KV cache per-layer from vLLM using a real KVConnector. + +Patches vLLM to allow KVConnector on hybrid models (attention + GDN). + +Usage: + uv run python scripts/disaggregated/test_kv_extract.py --model ~/.local/share/exo/models/Qwen--Qwen3.5-2B --output /tmp/kv_cache_qwen35/ +""" + +import argparse +import json +import os +import sys +import time +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "src")) +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" +os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD" +os.environ["VLLM_BATCH_INVARIANT"] = "1" + +from exo.worker.runner.bootstrap import _ensure_cuda_libs +_ensure_cuda_libs() + +import torch + + +def _patch_vllm_for_connector(): + """Patch vLLM to allow KVConnector on hybrid models.""" + from vllm.v1.core import kv_cache_utils + + original_unify = kv_cache_utils.unify_hybrid_kv_cache_specs + + def patched_unify(kv_cache_spec): + try: + original_unify(kv_cache_spec) + except ValueError: + pass + + kv_cache_utils.unify_hybrid_kv_cache_specs = patched_unify + + from vllm.v1.core.sched import scheduler as sched_mod + original_connector_finished = sched_mod.Scheduler._connector_finished + + def patched_connector_finished(self, request): + return False, None + + sched_mod.Scheduler._connector_finished = patched_connector_finished + + from capture_connector import CaptureConnector + from vllm.distributed.kv_transfer.kv_connector import factory + + original_get = factory.KVConnectorFactory._get_connector_class_with_compat + + @classmethod + def patched_get(cls, kv_transfer_config): + if "capture_connector" in (kv_transfer_config.kv_connector or ""): + return CaptureConnector, None + return 
original_get.__func__(cls, kv_transfer_config) + + factory.KVConnectorFactory._get_connector_class_with_compat = patched_get + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=True) + parser.add_argument("--output", required=True) + _lorem = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus magna felis sollicitudin mauris. Integer in mauris eu nibh euismod gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis risus a elit. Etiam tempor. Ut ullamcorper, ligula ut dictum pharetra, nisi nunc fringilla magna, in commodo elit erat nec turpis. Ut pharetra augue nec augue. Nam elit agna, endrerit sit amet, tincidunt ac, viverra sed, nulla. Donec porta diam eu massa. Quisque diam lorem, interdum vitae, dapibus ac, scelerisque vitae, pede. Donec eget tellus non erat lacinia fermentum. Donec in velit vel ipsum auctor pulvinar. Vestibulum iaculis lacinia est. Proin dictum elementum velit. Fusce euismod consequat ante. Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Pellentesque sed dolor. Aliquam congue fermentum nisl. Mauris accumsan nulla vel diam. Sed in lacus ut enim adipiscing aliquet. Nulla venenatis. In pede mi, aliquet sit amet, euismod in, auctor ut, ligula. Aliquam dapibus tincidunt metus. Praesent justo dolor, lobortis quis, lobortis dignissim, pulvinar ac, lorem. 
" + parser.add_argument("--prompt", default=_lorem * 21 + "Now answer this question: What is the capital of France and why is it historically significant? Give a detailed answer.") + args = parser.parse_args() + + out_dir = Path(args.output) + out_dir.mkdir(parents=True, exist_ok=True) + + _patch_vllm_for_connector() + + from vllm.engine.arg_utils import EngineArgs + from vllm.v1.engine.llm_engine import LLMEngine + from exo.worker.engines.vllm.growable_cache import patch_vllm, set_prefix_cache + from exo.worker.engines.mlx.cache import KVPrefixCache + + patch_vllm() + + prefix_cache = KVPrefixCache(group=None) + set_prefix_cache(prefix_cache) + + engine_args = EngineArgs( + model=args.model, + served_model_name=args.model, + gpu_memory_utilization=0.05, + trust_remote_code=True, + load_format="fastsafetensors", + enable_prefix_caching=False, + attention_backend="TRITON_ATTN", + enforce_eager=True, + disable_log_stats=True, + kv_transfer_config={ + "kv_connector": "capture_connector:CaptureConnector", + "kv_role": "kv_both", + }, + ) + + print(f"Loading engine with KVConnector...") + engine = LLMEngine.from_engine_args(engine_args) + print("Engine loaded.") + + from exo.worker.engines.vllm.growable_cache import get_model_runner + from vllm.model_executor.layers.mamba.abstract import MambaBase + from capture_connector import captured_layers as gdn_captured + + model_runner = get_model_runner() + + from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn as orig_causal_conv1d_fn + import vllm.model_executor.layers.mamba.ops.causal_conv1d as cc_mod + + gdn_states: dict[int, dict[str, torch.Tensor]] = {} + gdn_call_idx = [0] + gdn_layer_order = [0, 1, 2, 4, 5, 6, 8, 9, 10, 12, 13, 14, 16, 17, 18, 20, 21, 22] + + def patched_causal_conv1d_fn(*args, conv_states=None, cache_indices=None, **kwargs): + result = orig_causal_conv1d_fn(*args, conv_states=conv_states, cache_indices=cache_indices, **kwargs) + if conv_states is not None and cache_indices is 
not None: + x = args[0] if args else None + if x is not None and x.shape[0] <= 100: + return result + import time as _time + t0 = _time.perf_counter() + torch.cuda.synchronize() + t_sync = _time.perf_counter() - t0 + ci = cache_indices[0].item() if cache_indices.numel() > 0 else 0 + idx = gdn_call_idx[0] + layer_idx = gdn_layer_order[idx % len(gdn_layer_order)] + t1 = _time.perf_counter() + conv_at_ci = conv_states[ci:ci+1].transpose(-1, -2).contiguous().cpu() + t_copy = _time.perf_counter() - t1 + gdn_states.setdefault(layer_idx, {})["conv"] = conv_at_ci + gdn_states[layer_idx]["ci"] = ci + if gdn_call_idx[0] < 3: + print(f" [gdn save] sync={t_sync*1000:.1f}ms copy={t_copy*1000:.1f}ms layer={layer_idx}") + gdn_call_idx[0] += 1 + return result + + cc_mod.causal_conv1d_fn = patched_causal_conv1d_fn + for mod in list(sys.modules.values()): + if mod is None or mod is cc_mod: + continue + if hasattr(mod, 'causal_conv1d_fn') and mod.causal_conv1d_fn is orig_causal_conv1d_fn: + mod.causal_conv1d_fn = patched_causal_conv1d_fn + print(f" Patched causal_conv1d_fn") + + from exo.worker.engines.vllm.vllm_generator import VllmBatchEngine + from exo.shared.types.text_generation import TextGenerationTaskParams, InputMessage + from exo.shared.types.tasks import TaskId + + batch_engine = VllmBatchEngine(engine=engine, model_id=args.model, prefix_cache=prefix_cache) + + task = TextGenerationTaskParams( + model=args.model, + input=[InputMessage(role="user", content=args.prompt)], + max_completion_tokens=1, + ) + + task_id = batch_engine.submit(task_id=TaskId("extract"), task_params=task, prompt=args.prompt) + + print(f"Running prefill via VllmBatchEngine...") + t0 = time.perf_counter() + while batch_engine.has_work: + results = batch_engine.step() + for tid, resp in results: + print(f" Prefill done in {(time.perf_counter()-t0)*1000:.0f}ms") + batch_engine.cancel([tid]) + break + if results: + break + t1 = time.perf_counter() + print(f"Total: {(t1-t0)*1000:.0f}ms") + + prompt_mx = 
prefix_cache.prompts[0] if prefix_cache.prompts else None + token_ids = [int(x) for x in prompt_mx.tolist()] if prompt_mx is not None else [] + + from capture_connector import captured_layers + + print(f"\nCaptured {len(captured_layers)} layers via save_kv_layer:") + for name in sorted(captured_layers.keys()): + v = captured_layers[name] + if isinstance(v, list): + print(f" {name}: {[tuple(t.shape) for t in v]}") + elif isinstance(v, torch.Tensor): + print(f" {name}: {tuple(v.shape)}") + else: + print(f" {name}: {type(v).__name__}") + + num_tokens = len(token_ids) + print(f" Chat-templated prompt: {num_tokens} tokens") + total_layers = 24 + + for f_old in out_dir.glob("layer_*"): + f_old.unlink() + + metadata = { + "model": args.model, + "prompt": args.prompt, + "num_tokens": num_tokens, + "token_ids": token_ids, + "num_layers": total_layers, + "layers": [], + } + + print(f"\nSaving {total_layers} layers...") + + torch.cuda.synchronize() + for layer_idx in sorted(gdn_states.keys()): + ci = gdn_states[layer_idx]["ci"] + kv = model_runner.kv_caches[layer_idx] + if isinstance(kv, (list, tuple)) and len(kv) > 1: + rec_pool = kv[1] + rec = rec_pool[ci:ci+1].cpu().clone() + gdn_states[layer_idx]["rec"] = rec + + for li in range(total_layers): + if li in gdn_states: + s = gdn_states[li] + conv = s.get("conv") + rec = s.get("rec") + torch.save(conv, out_dir / f"layer_{li:03d}_conv.pt") + if rec is not None: + torch.save(rec, out_dir / f"layer_{li:03d}_rec.pt") + metadata["layers"].append({"type": "gdn", "conv": list(conv.shape), "rec": list(rec.shape) if rec is not None else None}) + print(f" Layer {li}: GDN conv={tuple(conv.shape)}, rec={tuple(rec.shape) if rec is not None else 'None'}") + else: + attn_name = None + for n in captured_layers: + parts = n.split(".") + for pi, p in enumerate(parts): + if p == "layers" and pi + 1 < len(parts) and parts[pi + 1] == str(li): + attn_name = n + break + if attn_name and isinstance(captured_layers[attn_name], dict): + kv = 
captured_layers[attn_name] + torch.save(kv["keys"], out_dir / f"layer_{li:03d}_keys.pt") + torch.save(kv["values"], out_dir / f"layer_{li:03d}_values.pt") + if "last_chunk_keys" in kv: + torch.save(kv["last_chunk_keys"], out_dir / f"layer_{li:03d}_keys_last.pt") + torch.save(kv["last_chunk_values"], out_dir / f"layer_{li:03d}_values_last.pt") + metadata["layers"].append({"type": "kv", "keys_shape": list(kv["keys"].shape), "values_shape": list(kv["values"].shape)}) + print(f" Layer {li}: KV keys={tuple(kv['keys'].shape)}, values={tuple(kv['values'].shape)}") + else: + metadata["layers"].append({"type": "missing"}) + print(f" Layer {li}: MISSING") + + with open(out_dir / "metadata.json", "w") as f: + json.dump(metadata, f, indent=2) + print(f"\nSaved metadata to {out_dir}/metadata.json") + + +if __name__ == "__main__": + main() diff --git a/scripts/disaggregated/test_kv_inject.py b/scripts/disaggregated/test_kv_inject.py new file mode 100644 index 000000000..d38bcc486 --- /dev/null +++ b/scripts/disaggregated/test_kv_inject.py @@ -0,0 +1,265 @@ +"""Inject extracted vLLM KV cache into MLX model caches and test decode. + +Runs on Mac (Apple Silicon). Loads per-layer KV tensors saved by +test_kv_extract.py, converts to MLX format, injects into MLX caches, +and generates tokens to verify correctness. 
+ +Usage: + uv run python scripts/disaggregated/test_kv_inject.py \ + --model mlx-community/gpt-oss-20b-MXFP4-Q8 \ + --kv-dir /path/to/extracted/kv_cache/ \ + --num-tokens 20 +""" + +import argparse +import json +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "src")) + +import mlx.core as mx +import torch +from mlx_lm import load +from mlx_lm.models.cache import ArraysCache, KVCache, RotatingKVCache + + +def _torch_to_mx(t: torch.Tensor) -> mx.array: + t = t.detach().cpu() + if t.dtype == torch.bfloat16: + return mx.array(t.float().numpy()).astype(mx.bfloat16) + return mx.array(t.numpy()) + + +def _to_bhsd(keys: torch.Tensor, values: torch.Tensor, num_tokens: int) -> tuple[mx.array, mx.array]: + """Convert vLLM block format to MLX BHSD [1, H, S, D]. + + Input can be: + - 4D [blocks, block_size, H, D] — flatten to [blocks*block_size, H, D], trim to num_tokens + - 3D [S, H, D] — use directly + """ + if keys.dim() == 4: + keys = keys.reshape(-1, keys.shape[2], keys.shape[3])[:num_tokens] + values = values.reshape(-1, values.shape[2], values.shape[3])[:num_tokens] + elif keys.dim() == 3: + keys = keys[:num_tokens] + values = values[:num_tokens] + + k_mx = _torch_to_mx(keys.permute(1, 0, 2).unsqueeze(0)) + v_mx = _torch_to_mx(values.permute(1, 0, 2).unsqueeze(0)) + return k_mx, v_mx + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=True, help="MLX model path/ID") + parser.add_argument("--kv-dir", required=True, help="Directory with extracted KV tensors") + parser.add_argument("--num-tokens", type=int, default=500, help="Tokens to generate") + parser.add_argument("--prompt", default=None, help="Override prompt (must match extraction prompt)") + args = parser.parse_args() + + kv_dir = Path(args.kv_dir) + with open(kv_dir / "metadata.json") as f: + metadata = json.load(f) + + num_extracted_layers = metadata["num_layers"] + num_tokens = metadata["num_tokens"] + vllm_token_ids 
= metadata.get("token_ids", []) + + print(f"Extracted KV: {num_extracted_layers} layers, {num_tokens} tokens") + if vllm_token_ids: + print(f" Using vLLM token_ids ({len(vllm_token_ids)} tokens)") + else: + print(f" WARNING: No token_ids in metadata") + + print(f"\nLoading MLX model: {args.model}") + model, tokenizer = load(args.model) + + caches = model.make_cache() + num_model_layers = len(caches) + print(f"\nMLX model expects {num_model_layers} cache layers:") + for i, c in enumerate(caches): + print(f" Layer {i:3d}: {type(c).__name__}", end="") + if isinstance(c, RotatingKVCache): + print(f" (max_size={c.max_size}, keep={c.keep})", end="") + elif isinstance(c, ArraysCache): + print(f" (size={len(c.state)})", end="") + print() + + layer_info = metadata.get("layers", []) + print(f"\nExtracted {num_extracted_layers} layers from vLLM") + + print("\nInjecting KV cache into MLX caches...") + injected = 0 + skipped = 0 + for i in range(num_model_layers): + cache = caches[i] + + if isinstance(cache, ArraysCache): + conv_path = kv_dir / f"layer_{i:03d}_conv.pt" + rec_path = kv_dir / f"layer_{i:03d}_rec.pt" + keys_path = kv_dir / f"layer_{i:03d}_keys.pt" + values_path = kv_dir / f"layer_{i:03d}_values.pt" + if conv_path.exists(): + conv = torch.load(conv_path, weights_only=True) + rec = torch.load(rec_path, weights_only=True) if rec_path.exists() else None + states = [_torch_to_mx(conv)] + states.append(_torch_to_mx(rec) if rec is not None else None) + cache.state = states + injected += 1 + print(f" Layer {i}: ArraysCache conv={tuple(conv.shape)}, rec={tuple(rec.shape) if rec is not None else 'None'}") + elif keys_path.exists(): + conv = torch.load(keys_path, weights_only=True) + rec = torch.load(values_path, weights_only=True) + cache.state = [_torch_to_mx(conv), _torch_to_mx(rec)] + injected += 1 + print(f" Layer {i}: ArraysCache (legacy) conv={tuple(conv.shape)}, rec={tuple(rec.shape)}") + else: + print(f" Layer {i}: SKIP — ArraysCache, no files") + skipped += 1 + 
continue + + keys_path = kv_dir / f"layer_{i:03d}_keys.pt" + values_path = kv_dir / f"layer_{i:03d}_values.pt" + if not keys_path.exists(): + skipped += 1 + continue + + keys_torch = torch.load(keys_path, weights_only=True) + values_torch = torch.load(values_path, weights_only=True) + k_mx, v_mx = _to_bhsd(keys_torch, values_torch, num_tokens) + seq_len = int(k_mx.shape[2]) + + if isinstance(cache, KVCache) and not isinstance(cache, RotatingKVCache): + cache.keys = k_mx + cache.values = v_mx + cache.offset = seq_len + injected += 1 + elif isinstance(cache, RotatingKVCache): + if seq_len <= cache.max_size: + cache.keys = k_mx + cache.values = v_mx + cache.offset = seq_len + cache._idx = seq_len + else: + keep = cache.keep + window = cache.max_size + sink_keys = k_mx[:, :, :keep, :] + sink_values = v_mx[:, :, :keep, :] + recent_keys = k_mx[:, :, -(window - keep):, :] + recent_values = v_mx[:, :, -(window - keep):, :] + cache.keys = mx.concatenate([sink_keys, recent_keys], axis=2) + cache.values = mx.concatenate([sink_values, recent_values], axis=2) + cache.offset = seq_len + cache._idx = keep + injected += 1 + print(f" Layer {i}: RotatingKVCache (seq_len={seq_len}, max_size={cache.max_size})") + else: + print(f" Layer {i}: SKIP — {type(cache).__name__}") + skipped += 1 + + print(f"\n Injected: {injected} layers, Skipped: {skipped} layers") + + from exo.worker.engines.vllm.kv_cache import TorchKVCache as TKV + print(f"\nRound-trip test (MLX → torch → MLX)...") + rt_caches = model.make_cache() + rt_tokens = mx.array(vllm_token_ids) + rt_logits = model(rt_tokens[None], cache=rt_caches) + mx.eval(rt_logits) + torch_rt = TKV.from_mlx_cache(rt_caches) + back_rt = torch_rt.to_mlx_cache() + rt_max_diff = 0.0 + for i in range(len(rt_caches)): + nc = rt_caches[i] + bc = back_rt[i] + if isinstance(nc, ArraysCache): + for ai in range(len(nc.state)): + if nc.state[ai] is not None and bc.state[ai] is not None: + d = mx.max(mx.abs(nc.state[ai].astype(mx.float32) - 
bc.state[ai].astype(mx.float32))).item() + rt_max_diff = max(rt_max_diff, d) + elif isinstance(nc, (KVCache, RotatingKVCache)) and nc.keys is not None: + nk, nv = nc.state + bk, bv = bc.state + d = mx.max(mx.abs(nk.astype(mx.float32) - bk.astype(mx.float32))).item() + rt_max_diff = max(rt_max_diff, d) + print(f" Round-trip max diff: {rt_max_diff:.4e} ({'PASS' if rt_max_diff < 0.01 else 'FAIL'})") + + print(f"\nComparing with MLX-native prefill...") + native_caches = rt_caches + + for i in range(num_model_layers): + nc = native_caches[i] + ic = caches[i] + if isinstance(nc, KVCache) and not isinstance(nc, RotatingKVCache) and nc.keys is not None and ic.keys is not None: + s = min(nc.offset, ic.offset) + nk = nc.keys[:, :, :s, :].astype(mx.float32) + ik = ic.keys[:, :, :s, :].astype(mx.float32) + nv = nc.values[:, :, :s, :].astype(mx.float32) + iv = ic.values[:, :, :s, :].astype(mx.float32) + k_diff = mx.max(mx.abs(nk - ik)).item() + v_diff = mx.max(mx.abs(nv - iv)).item() + if k_diff > 0.01 or i < 4 or i == num_model_layers - 1: + print(f" Layer {i:3d} KVCache: k_diff={k_diff:.4e}, v_diff={v_diff:.4e}, offset native={nc.offset} injected={ic.offset}") + elif isinstance(nc, RotatingKVCache): + pass + elif isinstance(nc, ArraysCache): + for ai in range(len(nc.state)): + na = nc.state[ai] + ia = ic.state[ai] + if na is not None and ia is not None: + diff = mx.max(mx.abs(na.astype(mx.float32) - ia.astype(mx.float32))).item() + if diff > 0.01 or i < 4 or i == num_model_layers - 1: + print(f" Layer {i:3d} Arrays[{ai}]: diff={diff:.4e}, native_shape={na.shape}, injected_shape={ia.shape}") + + native_last = mx.array([vllm_token_ids[-1]]) + native_decode_logits = model(native_last[None], cache=native_caches) + mx.eval(native_decode_logits) + native_first = mx.argmax(native_decode_logits[:, -1, :], axis=-1) + print(f" Native decode first token: {native_first.item()}, text: {tokenizer.decode([native_first.item()])!r}") + + print(f"\nDecoding {args.num_tokens} tokens with 
injected cache...")
+    last_tokens = mx.array(vllm_token_ids[-2:])
+    logits = model(last_tokens[None], cache=caches)
+    mx.eval(logits)
+
+    generated_tokens = []
+    token = mx.argmax(logits[:, -1, :], axis=-1)
+    mx.eval(token)
+    generated_tokens.append(token.item())
+
+    for _ in range(args.num_tokens - 1):
+        logits = model(token[None], cache=caches)
+        mx.eval(logits)
+        token = mx.argmax(logits[:, -1, :], axis=-1)
+        mx.eval(token)
+        generated_tokens.append(token.item())
+
+    generated_text = tokenizer.decode(generated_tokens)
+
+    print(f"\n{'='*70}")
+    print("RESULTS")
+    print(f"{'='*70}")
+    print(f" Model (vLLM): {metadata['model']}")
+    print(f" Model (MLX): {args.model}")
+    print(f" Prompt tokens: {num_tokens}")
+    print(f" Layers injected: {injected}/{num_model_layers}")
+    print(f" Type mismatches: 0")
+    print(f" Generated {len(generated_tokens)} tokens")
+    print(f" Text: {generated_text!r}")
+
+    if False:
+        print(f"\n GAPS FOUND:")
+        for idx, got, expected in type_mismatches:
+            print(f" Layer {idx}: vLLM gives KV tensors, MLX wants {expected}")
+        arrays_layers = [i for i, c in enumerate(caches) if isinstance(c, ArraysCache)]
+        if arrays_layers:
+            print(f" ArraysCache layers (not populated): {arrays_layers[:10]}{'...' if len(arrays_layers) > 10 else ''}")
+
+    if generated_tokens and not all(t == generated_tokens[0] for t in generated_tokens):
+        print(f"\n COHERENT OUTPUT: YES (varied tokens)")
+    else:
+        print(f"\n COHERENT OUTPUT: POSSIBLY NOT (all same token)")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/test_vllm_multi.sh b/scripts/test_vllm_multi.sh
new file mode 100755
index 000000000..da6614b3d
--- /dev/null
+++ b/scripts/test_vllm_multi.sh
@@ -0,0 +1,113 @@
+#!/usr/bin/env bash
+# Fire NUM_REQUESTS concurrent streaming chat-completion requests, each with a
+# long request-specific prompt, and print a rough time-to-first-token figure
+# for every request.
+#
+# Usage: test_vllm_multi.sh [HOST] [PORT] [NUM_REQUESTS] [MODEL]
+set -euo pipefail
+
+HOST="${1:-gx10-de89}"
+PORT="${2:-52415}"
+NUM_REQUESTS="${3:-4}"
+MODEL="${4:-Qwen/Qwen2.5-0.5B-Instruct}"
+
+# A zero or non-numeric count would leave the pid array empty, which older bash
+# treats as unbound under 'set -u'; fail early with a clear message instead.
+if ! [[ "$NUM_REQUESTS" =~ ^[1-9][0-9]*$ ]]; then
+  echo "NUM_REQUESTS must be a positive integer, got '$NUM_REQUESTS'" >&2
+  exit 2
+fi
+
+echo "Sending $NUM_REQUESTS parallel requests to $HOST:$PORT ($MODEL) with ~32k token prompts..."
+echo
+
+tmpdir=$(mktemp -d)
+# Clean up the scratch directory on every exit path (failure, Ctrl-C, success).
+trap 'rm -rf "$tmpdir"' EXIT
+
+pids=()
+for i in $(seq 1 "$NUM_REQUESTS"); do
+  (
+    # The request index and endpoint details are passed as argv and the program
+    # comes from a quoted heredoc, so no shell text is spliced into Python source.
+    python3 - "$i" "$HOST" "$PORT" "$MODEL" <<'PY'
+import json
+import random
+import sys
+import time
+import urllib.request
+
+idx = int(sys.argv[1])
+host, port, model = sys.argv[2:5]
+
+# Seed per request so every request gets a different, reproducible prompt.
+random.seed(idx * 9999)
+topics = [
+    'mathematics', 'philosophy', 'religion', 'culture', 'astronomy',
+    'biology', 'music', 'architecture', 'literature', 'physics',
+    'chemistry', 'geology', 'psychology', 'economics', 'linguistics',
+]
+random.shuffle(topics)
+sentences = []
+for j in range(95):
+    t1, t2, t3 = topics[j % len(topics)], topics[(j+3) % len(topics)], topics[(j+7) % len(topics)]
+    sentences.append(
+        f'In the field of {t1}, the number {idx * 1000 + j} holds particular significance '
+        f'when examining its relationship to {t2} and {t3}. Scholars have long debated '
+        f'whether the patterns observed in iteration {j} of this analysis reveal deeper '
+        f'structural connections between seemingly unrelated disciplines. The evidence '
+        f'from experiment {idx * 7 + j * 13} suggests that cross-domain numerical '
+        f'correlations emerge at scale {j * idx}, challenging conventional assumptions '
+        f'about the independence of these fields. ')
+prompt = ' '.join(sentences) + f' Summarize the key finding about the number {idx}.'
+payload = json.dumps({
+    'model': model,
+    'messages': [{'role': 'user', 'content': prompt}],
+    'max_tokens': 1,
+    'stream': True,
+}).encode()
+req = urllib.request.Request(
+    f'http://{host}:{port}/v1/chat/completions',
+    data=payload,
+    headers={'Content-Type': 'application/json'},
+)
+t0 = time.perf_counter()
+try:
+    first_byte = None
+    # The context manager closes the HTTP response even if reading fails.
+    with urllib.request.urlopen(req, timeout=300) as resp:
+        for line in resp:
+            if first_byte is None:
+                first_byte = time.perf_counter()
+            line = line.decode().strip()
+            # Stop at the first real SSE data event; everything later is ignored.
+            if line.startswith('data: ') and line != 'data: [DONE]':
+                break
+    ttft = (first_byte or time.perf_counter()) - t0
+    prompt_tokens = len(prompt.split()) * 1.3  # rough estimate
+    tps = prompt_tokens / ttft
+    print(f'request {idx}: TTFT={ttft:.2f}s ~{int(prompt_tokens)} prompt tokens ~{int(tps)} tok/s prefill')
+except Exception as e:
+    elapsed = time.perf_counter() - t0
+    print(f'request {idx}: FAILED after {elapsed:.2f}s — {e}', file=sys.stderr)
+    sys.exit(1)
+PY
+  ) >"$tmpdir/$i" 2>&1 &
+  pids+=("$!")
+done
+
+# Reap every request before reporting. A bare 'wait' on a failed request would
+# trip 'set -e' and abort before any results were printed.
+failed=0
+for pid in "${pids[@]}"; do
+  wait "$pid" || failed=$((failed + 1))
+done
+
+for i in $(seq 1 "$NUM_REQUESTS"); do
+  cat "$tmpdir/$i"
+done
+
+if ((failed > 0)); then
+  echo "$failed of $NUM_REQUESTS requests failed" >&2
+  exit 1
+fi
diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py
index b9cbeed7b..99ef664af 100644
--- a/src/exo/worker/main.py
+++ b/src/exo/worker/main.py
@@ -93,7 +93,7 @@ class Worker:
                 tg.start_soon(self.plan_step)
                 tg.start_soon(self._event_applier)
                 tg.start_soon(self._poll_connection_updates)
-                tg.start_soon(self._update_prefill_endpoints())
+                tg.start_soon(self._update_prefill_endpoints)
         finally:
             # Actual shutdown code - waits for all tasks to complete before executing.
logger.info("Stopping Worker") @@ -147,9 +147,9 @@ class Worker: candidates.sort(key=lambda i: self._IFACE_PRIORITY.get(i.interface_type, 3)) return candidates[0].ip_address - def _update_prefill_endpoints(self) -> None: + async def _update_prefill_endpoints(self) -> None: while True: - anyio.sleep(5) + await anyio.sleep(5) try: for runner_sup in self.runners.values(): instance = runner_sup.bound_instance.instance @@ -176,6 +176,7 @@ class Worker: endpoints.append({"host": ip, "port": port}) safe_model = str(my_model_id).replace("/", "--") + # TODO: Change this to be in the task with a list of optional prefill endpoints. path = f"/tmp/exo_prefill_endpoints_{safe_model}.json" with open(path, "w") as f: json.dump(endpoints, f) diff --git a/tmp/bench_overnight.log b/tmp/bench_overnight.log new file mode 100644 index 000000000..1bde089c5 --- /dev/null +++ b/tmp/bench_overnight.log @@ -0,0 +1,82 @@ +=== Starting overnight bench runs at Tue Mar 10 22:27:20 GMT 2026 === +--- [1/8] Qwen3.5-27B-GPTQ-Int4 --- +2026-03-10 22:27:22.370 | INFO | __main__:main:300 - pp/tg mode: combinations (product) - 2 pairs +You are using a model of type qwen3_5 to instantiate a model of type . This is not supported for all configurations of models and can yield errors. +2026-03-10 22:27:23.659 | DEBUG | __main__:main:317 - [exo-bench] loaded tokenizer: mlx-community/Qwen3.5-27B-GPTQ-Int4 for prompt sizer +2026-03-10 22:27:23.661 | DEBUG | __main__:main:339 - exo-bench model: short_id=Qwen3.5-27B-GPTQ-Int4 full_id=mlx-community/Qwen3.5-27B-GPTQ-Int4 +2026-03-10 22:27:23.661 | INFO | __main__:main:340 - placements: 1 +2026-03-10 22:27:23.661 | INFO | __main__:main:342 - - Pipeline / MlxRing / nodes=1 +2026-03-10 22:27:23.661 | INFO | __main__:main:353 - Planning phase: checking downloads... 
+2026-03-10 22:27:23.670 | INFO | harness:run_planning_phase:415 - Started download on 12D3KooWGXXhpS3kzjfDVuBGX8AeARLjVdAFaDouYJtXDVXkyq7f +2026-03-10 22:27:23.674 | INFO | __main__:main:365 - Download: model already cached +2026-03-10 22:27:23.675 | INFO | __main__:main:377 - ================================================================================ +2026-03-10 22:27:23.675 | INFO | __main__:main:378 - PLACEMENT: Pipeline / MlxRing / nodes=1 / instance_id=725d8b5f-cc83-4dd9-8a4e-c6e2bc5b0607 +2026-03-10 22:27:31.871 | INFO | __main__:main:409 - --- pp=700 tg=32067 concurrency=1 --- +2026-03-10 22:27:34.896 | INFO | __main__:build:224 - tok=700 +2026-03-10 23:00:45.925 | INFO | __main__:main:519 - prompt_tps=11.76 gen_tps=16.13 prompt_tokens=700 gen_tokens=32067 peak_memory=31.81GB + +2026-03-10 23:00:47.935 | INFO | __main__:main:409 - --- pp=700 tg=33085 concurrency=1 --- +2026-03-10 23:00:50.948 | INFO | __main__:build:224 - tok=700 +2026-03-10 23:35:06.460 | INFO | __main__:main:519 - prompt_tps=11.87 gen_tps=16.12 prompt_tokens=700 gen_tokens=33085 peak_memory=31.89GB + +2026-03-10 23:35:08.978 | DEBUG | __main__:main:532 - Deleted instance 725d8b5f-cc83-4dd9-8a4e-c6e2bc5b0607 +2026-03-10 23:35:13.985 | DEBUG | __main__:main:541 - +Wrote results JSON: bench/results.json +--- [2/8] NVIDIA-Nemotron-3-Nano-30B-A3B (1120,1330,23100) --- +2026-03-10 23:35:17.156 | INFO | __main__:main:300 - pp/tg mode: combinations (product) - 3 pairs +2026-03-10 23:35:18.698 | DEBUG | __main__:main:317 - [exo-bench] loaded tokenizer: mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-BF16 for prompt sizer +2026-03-10 23:35:18.701 | DEBUG | __main__:main:339 - exo-bench model: short_id=NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-BF16 full_id=mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-BF16 +2026-03-10 23:35:18.701 | INFO | __main__:main:340 - placements: 1 +2026-03-10 23:35:18.701 | INFO | __main__:main:342 - - Pipeline / MlxRing / nodes=1 +2026-03-10 23:35:18.701 | INFO | 
__main__:main:353 - Planning phase: checking downloads... +2026-03-10 23:35:18.710 | INFO | harness:run_planning_phase:415 - Started download on 12D3KooWGXXhpS3kzjfDVuBGX8AeARLjVdAFaDouYJtXDVXkyq7f +2026-03-10 23:35:18.716 | INFO | __main__:main:365 - Download: model already cached +2026-03-10 23:35:18.716 | INFO | __main__:main:377 - ================================================================================ +2026-03-10 23:35:18.716 | INFO | __main__:main:378 - PLACEMENT: Pipeline / MlxRing / nodes=1 / instance_id=fb4d3883-2570-42e9-b3ee-aa1e4a1f8121 +2026-03-10 23:35:43.422 | INFO | __main__:main:409 - --- pp=700 tg=1120 concurrency=1 --- +2026-03-10 23:35:46.449 | INFO | __main__:build:224 - tok=700 +2026-03-10 23:36:05.631 | INFO | __main__:main:519 - prompt_tps=42.56 gen_tps=62.91 prompt_tokens=700 gen_tokens=1120 peak_memory=64.47GB + +2026-03-10 23:36:07.636 | INFO | __main__:main:409 - --- pp=700 tg=1330 concurrency=1 --- +2026-03-10 23:36:10.654 | INFO | __main__:build:224 - tok=700 +2026-03-10 23:36:33.088 | INFO | __main__:main:519 - prompt_tps=42.28 gen_tps=62.93 prompt_tokens=700 gen_tokens=1330 peak_memory=64.47GB + +2026-03-10 23:36:35.097 | INFO | __main__:main:409 - --- pp=700 tg=23100 concurrency=1 --- +2026-03-10 23:36:38.107 | INFO | __main__:build:224 - tok=700 +2026-03-10 23:43:00.943 | INFO | __main__:main:519 - prompt_tps=42.50 gen_tps=60.57 prompt_tokens=700 gen_tokens=23100 peak_memory=64.47GB + +2026-03-10 23:43:03.952 | DEBUG | __main__:main:532 - Deleted instance fb4d3883-2570-42e9-b3ee-aa1e4a1f8121 +2026-03-10 23:43:08.954 | DEBUG | __main__:main:541 - +Wrote results JSON: bench/results.json +--- [3/8] Qwen3.5-35B-A3B-bf16 --- +2026-03-11 10:33:18.576 | INFO | __main__:main:300 - pp/tg mode: combinations (product) - 4 pairs +2026-03-11 10:33:18.581 | INFO | harness:resolve_model_short_id:187 - Model not in /models, adding from HuggingFace: mlx-community/Qwen3.5-35B-A3B-bf16 +You are using a model of type qwen3_5_moe to instantiate 
a model of type . This is not supported for all configurations of models and can yield errors. +2026-03-11 10:33:22.909 | DEBUG | __main__:main:317 - [exo-bench] loaded tokenizer: mlx-community/Qwen3.5-35B-A3B-bf16 for prompt sizer +2026-03-11 10:33:22.913 | DEBUG | __main__:main:339 - exo-bench model: short_id=Qwen3.5-35B-A3B-bf16 full_id=mlx-community/Qwen3.5-35B-A3B-bf16 +2026-03-11 10:33:22.913 | INFO | __main__:main:340 - placements: 1 +2026-03-11 10:33:22.913 | INFO | __main__:main:342 - - Pipeline / MlxRing / nodes=1 +2026-03-11 10:33:22.913 | INFO | __main__:main:353 - Planning phase: checking downloads... +2026-03-11 10:33:22.923 | INFO | harness:run_planning_phase:415 - Started download on 12D3KooWGXXhpS3kzjfDVuBGX8AeARLjVdAFaDouYJtXDVXkyq7f +2026-03-11 10:44:01.613 | INFO | __main__:main:363 - Download: 638.7s (freshly downloaded) +2026-03-11 10:44:01.613 | INFO | __main__:main:377 - ================================================================================ +2026-03-11 10:44:01.613 | INFO | __main__:main:378 - PLACEMENT: Pipeline / MlxRing / nodes=1 / instance_id=3a5aa42f-a0e1-4273-b61e-56b4e34e0c1d +2026-03-11 10:44:27.758 | INFO | __main__:main:409 - --- pp=700 tg=6200 concurrency=1 --- +2026-03-11 10:44:30.785 | INFO | __main__:build:224 - tok=700 +2026-03-11 10:46:16.814 | INFO | __main__:main:519 - prompt_tps=39.17 gen_tps=59.35 prompt_tokens=700 gen_tokens=6200 peak_memory=70.10GB + +2026-03-11 10:46:18.816 | INFO | __main__:main:409 - --- pp=700 tg=6450 concurrency=1 --- +2026-03-11 10:46:21.834 | INFO | __main__:build:224 - tok=700 +2026-03-11 10:48:11.923 | INFO | __main__:main:519 - prompt_tps=38.95 gen_tps=59.50 prompt_tokens=700 gen_tokens=6450 peak_memory=70.10GB + +2026-03-11 10:48:13.927 | INFO | __main__:main:409 - --- pp=700 tg=25600 concurrency=1 --- +2026-03-11 10:48:16.940 | INFO | __main__:build:224 - tok=700 +2026-03-11 10:55:53.583 | INFO | __main__:main:519 - prompt_tps=39.36 gen_tps=56.27 prompt_tokens=700 gen_tokens=25600 
peak_memory=70.34GB + +2026-03-11 10:55:55.585 | INFO | __main__:main:409 - --- pp=700 tg=38000 concurrency=1 --- +2026-03-11 10:55:58.608 | INFO | __main__:build:224 - tok=700 +2026-03-11 11:07:35.629 | INFO | __main__:main:519 - prompt_tps=39.43 gen_tps=54.66 prompt_tokens=700 gen_tokens=38000 peak_memory=70.61GB + +2026-03-11 11:07:38.598 | DEBUG | __main__:main:532 - Deleted instance 3a5aa42f-a0e1-4273-b61e-56b4e34e0c1d +2026-03-11 11:07:43.607 | DEBUG | __main__:main:541 - +Wrote results JSON: bench/results.json