From 03ea3cf6cd9bfa5d76c1b99bc2652b057eff0efd Mon Sep 17 00:00:00 2001 From: Ryuichi Leo Takashige Date: Thu, 19 Mar 2026 11:51:57 +0000 Subject: [PATCH] Performance optimizations --- scripts/check_kv_compat.py | 24 ++ scripts/check_rotating_cache.py | 12 + scripts/compare_kv.py | 80 +++++++ scripts/disaggregated/capture_connector.py | 2 - scripts/disaggregated/inspect_vllm_kv.py | 3 +- scripts/disaggregated/test_kv_extract.py | 21 +- scripts/disaggregated/test_kv_inject.py | 14 +- src/exo/disaggregated/batch_connector.py | 8 +- src/exo/disaggregated/prefill_client.py | 57 +++-- src/exo/disaggregated/prefill_server.py | 220 +++++++++++++++--- src/exo/disaggregated/protocol.py | 12 +- src/exo/disaggregated/streaming_connector.py | 18 +- src/exo/master/main.py | 12 +- src/exo/shared/models/model_cards.py | 5 +- src/exo/worker/engines/mlx/cache.py | 10 +- .../engines/mlx/generator/batch_generate.py | 9 +- src/exo/worker/engines/vllm/kv_cache.py | 27 ++- src/exo/worker/engines/vllm/vllm_generator.py | 46 ++-- src/exo/worker/main.py | 1 - src/exo/worker/runner/bootstrap.py | 2 +- src/exo/worker/runner/llm_inference/runner.py | 5 +- 21 files changed, 459 insertions(+), 129 deletions(-) create mode 100644 scripts/check_kv_compat.py create mode 100644 scripts/check_rotating_cache.py create mode 100644 scripts/compare_kv.py diff --git a/scripts/check_kv_compat.py b/scripts/check_kv_compat.py new file mode 100644 index 000000000..b80967731 --- /dev/null +++ b/scripts/check_kv_compat.py @@ -0,0 +1,24 @@ +import sys +sys.path.insert(0, "src") +import mlx.core as mx +from mlx_lm import load +from mlx_lm.models.cache import RotatingKVCache, KVCache + +model, tok = load("mlx-community/gpt-oss-20b-MXFP4-Q8") + +prompt = "Hello " * 2000 +tokens = tok.encode(prompt) +print(f"Tokens: {len(tokens)}") + +cache = model.make_cache() +token_arr = mx.array([tokens]) +logits = model(token_arr, cache=cache) +mx.eval(logits) + +for i, c in enumerate(cache[:6]): + if isinstance(c, KVCache) and not isinstance(c, RotatingKVCache) and c.keys is not None: + k = c.keys.astype(mx.float32) + print(f"Layer {i} KVCache: shape={c.keys.shape} offset={c.offset} first=[{float(k[0,0,0,0]):.6f}, {float(k[0,0,0,1]):.6f}] last=[{float(k[0,0,-1,-2]):.6f}, {float(k[0,0,-1,-1]):.6f}]") + elif isinstance(c, RotatingKVCache) and c.keys is not None: + k = c.keys.astype(mx.float32) + print(f"Layer {i} RotatingKV: shape={c.keys.shape} _idx={c._idx} offset={c.offset} first=[{float(k[0,0,0,0]):.6f}, {float(k[0,0,0,1]):.6f}]") diff --git a/scripts/check_rotating_cache.py b/scripts/check_rotating_cache.py new file mode 100644 index 000000000..6260767bd --- /dev/null +++ b/scripts/check_rotating_cache.py @@ -0,0 +1,12 @@ +import mlx.core as mx +from mlx_lm import load +from mlx_lm.models.cache import RotatingKVCache + +model, tok = load("mlx-community/gpt-oss-20b-MXFP4-Q8") +cache = model.make_cache() +tokens = mx.ones((1, 5000), dtype=mx.int32) +model(tokens, cache=cache) +mx.eval([c.keys for c in cache if c.keys is not None]) +for i, c in enumerate(cache[:4]): + if isinstance(c, RotatingKVCache): + print(f"Layer {i}: _idx={c._idx} offset={c.offset} keep={c.keep} max_size={c.max_size} keys={c.keys.shape}") diff --git a/scripts/compare_kv.py b/scripts/compare_kv.py new file mode 100644 index 000000000..09a5e5083 --- /dev/null +++ b/scripts/compare_kv.py @@ -0,0 +1,80 @@ +import sys +sys.path.insert(0, "src") +import mlx.core as mx +import torch +import socket +from pathlib import Path +import json +from collections import defaultdict +from mlx_lm import load +from mlx_lm.models.cache import RotatingKVCache, KVCache +from exo.disaggregated.protocol import read_header, read_message, KVChunk, Done +from exo.disaggregated.prefill_client import _nhd_to_bhsd, _torch_to_mx + +ENDPOINT = sys.argv[1] if len(sys.argv) > 1 else "10.43.0.1:62988" +MODEL = sys.argv[2] if len(sys.argv) > 2 else "mlx-community/Llama-3.2-1B-Instruct-bf16" +MODEL_PATH = sys.argv[3] if len(sys.argv) > 3 else None + +model, tok = load(MODEL_PATH or str(Path.home() / ".exo/models" / MODEL.replace("/", "--"))) +prompt = "The quick brown fox jumps over the lazy dog. " * 300 +tokens = tok.encode(prompt) +print(f"Tokens: {len(tokens)}") + +host, port = ENDPOINT.rsplit(":", 1) +sock = socket.create_connection((host, int(port)), timeout=60) +request = json.dumps({"model": MODEL, "token_ids": tokens, "start_pos": 0}).encode() + b"\n" +sock.sendall(request) +stream = sock.makefile("rb", buffering=65536) +header = read_header(stream) + +vllm_kv = defaultdict(list) +while True: + msg = read_message(stream, header) + if msg is None or isinstance(msg, Done): + break + if isinstance(msg, KVChunk): + vllm_kv[msg.layer_idx].append((msg.keys, msg.values)) +sock.close() + +print(f"Received {len(vllm_kv)} layers from vLLM") + +if hasattr(model, "make_cache"): + mlx_cache = model.make_cache() +else: + from mlx_lm.models.cache import make_prompt_cache + mlx_cache = make_prompt_cache(model) +token_arr = mx.array([tokens[:-2]]) +mlx_logits = model(token_arr, cache=mlx_cache) +mx.eval(mlx_logits) + +for i in range(min(6, len(mlx_cache))): + c = mlx_cache[i] + if c.keys is None: + continue + mlx_k = c.keys.astype(mx.float32) + + if i not in vllm_kv: + print(f"Layer {i}: no vLLM data") + continue + chunks = vllm_kv[i] + vk = torch.cat([k for k, v in chunks], dim=0) if len(chunks) > 1 else chunks[0][0] + vk_mx = _torch_to_mx(vk.permute(1, 0, 2).unsqueeze(0)).astype(mx.float32) + + n = min(mlx_k.shape[2], vk_mx.shape[2]) + diff = mx.abs(mlx_k[:, :, :n, :] - vk_mx[:, :, :n, :]) + max_diff = mx.max(diff).item() + mean_diff = mx.mean(diff).item() + cache_type = "RotatingKV" if isinstance(c, RotatingKVCache) else "KV" + print(f"Layer {i} ({cache_type}): mlx={mlx_k.shape} vllm={vk_mx.shape} max_diff={max_diff:.6f} mean_diff={mean_diff:.6f}") + + a = mlx_k[:, :, :n, :].reshape(-1) + b_vec = vk_mx[:, :, :n, :].reshape(-1) + cos_sim = float(mx.sum(a * b_vec).item()) / (float(mx.sqrt(mx.sum(a * a)).item()) * float(mx.sqrt(mx.sum(b_vec * b_vec)).item()) + 1e-8) + diff_tensor = mx.abs(mlx_k[:, :, :n, :] - vk_mx[:, :, :n, :]) + max_idx = mx.argmax(diff_tensor.reshape(-1)).item() + total_elems = diff_tensor.shape[1] * n * diff_tensor.shape[3] + h = (max_idx // (n * diff_tensor.shape[3])) % diff_tensor.shape[1] + s = (max_idx // diff_tensor.shape[3]) % n + d = max_idx % diff_tensor.shape[3] + print(f" cosine_sim={cos_sim:.6f} max_diff={max_diff:.4f} at h={h} pos={s} dim={d}: mlx={float(mlx_k[0,h,s,d].item()):.6f} vllm={float(vk_mx[0,h,s,d].item()):.6f}") + print(f" mean_diff={mean_diff:.6f} p50={float(mx.sort(diff_tensor.reshape(-1))[diff_tensor.size // 2].item()):.6f} p99={float(mx.sort(diff_tensor.reshape(-1))[int(diff_tensor.size * 0.99)].item()):.6f}") diff --git a/scripts/disaggregated/capture_connector.py b/scripts/disaggregated/capture_connector.py index fc3607dda..0dacd8726 100644 --- a/scripts/disaggregated/capture_connector.py +++ b/scripts/disaggregated/capture_connector.py @@ -4,11 +4,9 @@ from dataclasses import dataclass from typing import Any import torch - from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, - KVConnectorRole, ) captured_layers: dict[str, Any] = {} diff --git a/scripts/disaggregated/inspect_vllm_kv.py b/scripts/disaggregated/inspect_vllm_kv.py index b46c49c3a..5f23fd509 100644 --- a/scripts/disaggregated/inspect_vllm_kv.py +++ b/scripts/disaggregated/inspect_vllm_kv.py @@ -20,6 +20,7 @@ os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD" os.environ["VLLM_BATCH_INVARIANT"] = "1" from exo.worker.runner.bootstrap import _ensure_cuda_libs + _ensure_cuda_libs() import torch @@ -44,8 +45,8 @@ def main(): parser.add_argument("--prompt", default="Hello, world! How are you today?", help="Prompt to prefill") args = parser.parse_args() - from exo.worker.engines.vllm.vllm_generator import load_vllm_engine from exo.worker.engines.vllm.growable_cache import get_model_runner + from exo.worker.engines.vllm.vllm_generator import load_vllm_engine print(f"Loading vLLM engine from {args.model}...") engine, _, prefix_cache = load_vllm_engine( diff --git a/scripts/disaggregated/test_kv_extract.py b/scripts/disaggregated/test_kv_extract.py index 021f42d42..66a5fca4b 100644 --- a/scripts/disaggregated/test_kv_extract.py +++ b/scripts/disaggregated/test_kv_extract.py @@ -21,6 +21,7 @@ os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD" os.environ["VLLM_BATCH_INVARIANT"] = "1" from exo.worker.runner.bootstrap import _ensure_cuda_libs + _ensure_cuda_libs() import torch @@ -77,8 +78,9 @@ def main(): from vllm.engine.arg_utils import EngineArgs from vllm.v1.engine.llm_engine import LLMEngine - from exo.worker.engines.vllm.growable_cache import patch_vllm, set_prefix_cache + from exo.worker.engines.mlx.cache import KVPrefixCache + from exo.worker.engines.vllm.growable_cache import patch_vllm, set_prefix_cache patch_vllm() @@ -101,18 +103,19 @@ def main(): }, ) - print(f"Loading engine with KVConnector...") + print("Loading engine with KVConnector...") engine = LLMEngine.from_engine_args(engine_args) print("Engine loaded.") + from exo.worker.engines.vllm.growable_cache import get_model_runner - from vllm.model_executor.layers.mamba.abstract import MambaBase - from capture_connector import captured_layers as gdn_captured model_runner = get_model_runner() - from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn as orig_causal_conv1d_fn import vllm.model_executor.layers.mamba.ops.causal_conv1d as cc_mod + from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn as orig_causal_conv1d_fn, + ) gdn_states: dict[int, dict[str, torch.Tensor]] = {} gdn_call_idx = [0] @@ -147,11 +150,11 @@ def main(): continue if hasattr(mod, 'causal_conv1d_fn') and mod.causal_conv1d_fn is orig_causal_conv1d_fn: mod.causal_conv1d_fn = patched_causal_conv1d_fn - print(f" Patched causal_conv1d_fn") + print(" Patched causal_conv1d_fn") - from exo.worker.engines.vllm.vllm_generator import VllmBatchEngine - from exo.shared.types.text_generation import TextGenerationTaskParams, InputMessage from exo.shared.types.tasks import TaskId + from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams + from exo.worker.engines.vllm.vllm_generator import VllmBatchEngine batch_engine = VllmBatchEngine(engine=engine, model_id=args.model, prefix_cache=prefix_cache) @@ -163,7 +166,7 @@ def main(): task_id = batch_engine.submit(task_id=TaskId("extract"), task_params=task, prompt=args.prompt) - print(f"Running prefill via VllmBatchEngine...") + print("Running prefill via VllmBatchEngine...") t0 = time.perf_counter() while batch_engine.has_work: results = batch_engine.step() diff --git a/scripts/disaggregated/test_kv_inject.py b/scripts/disaggregated/test_kv_inject.py index d38bcc486..ed4c8116c 100644 --- a/scripts/disaggregated/test_kv_inject.py +++ b/scripts/disaggregated/test_kv_inject.py @@ -70,7 +70,7 @@ def main(): if vllm_token_ids: print(f" Using vLLM token_ids ({len(vllm_token_ids)} tokens)") else: - print(f" WARNING: No token_ids in metadata") + print(" WARNING: No token_ids in metadata") print(f"\nLoading MLX model: {args.model}") model, tokenizer = load(args.model) @@ -161,7 +161,7 @@ def main(): print(f"\n Injected: {injected} layers, Skipped: {skipped} layers") from exo.worker.engines.vllm.kv_cache import TorchKVCache as TKV - print(f"\nRound-trip test (MLX → torch → MLX)...") + print("\nRound-trip test (MLX → torch → MLX)...") rt_caches = model.make_cache() rt_tokens = mx.array(vllm_token_ids) rt_logits = model(rt_tokens[None], cache=rt_caches) @@ -184,7 +184,7 @@ def main(): rt_max_diff = max(rt_max_diff, d) print(f" Round-trip max diff: {rt_max_diff:.4e} ({'PASS' if rt_max_diff < 0.01 else 'FAIL'})") - print(f"\nComparing with MLX-native prefill...") + print("\nComparing with MLX-native prefill...") native_caches = rt_caches for i in range(num_model_layers): @@ -243,12 +243,12 @@ def main(): print(f" Model (MLX): {args.model}") print(f" Prompt tokens: {num_tokens}") print(f" Layers injected: {injected}/{num_model_layers}") - print(f" Type mismatches: 0") + print(" Type mismatches: 0") print(f" Generated {len(generated_tokens)} tokens") print(f" Text: {generated_text!r}") if False: - print(f"\n GAPS FOUND:") + print("\n GAPS FOUND:") for idx, got, expected in type_mismatches: print(f" Layer {idx}: vLLM gives KV tensors, MLX wants {expected}") arrays_layers = [i for i, c in enumerate(caches) if isinstance(c, ArraysCache)] @@ -256,9 +256,9 @@ def main(): print(f" ArraysCache layers (not populated): {arrays_layers[:10]}{'...' if len(arrays_layers) > 10 else ''}") if generated_tokens and not all(t == generated_tokens[0] for t in generated_tokens): - print(f"\n COHERENT OUTPUT: YES (varied tokens)") + print("\n COHERENT OUTPUT: YES (varied tokens)") else: - print(f"\n COHERENT OUTPUT: POSSIBLY NOT (all same token)") + print("\n COHERENT OUTPUT: POSSIBLY NOT (all same token)") if __name__ == "__main__": diff --git a/src/exo/disaggregated/batch_connector.py b/src/exo/disaggregated/batch_connector.py index eed1024da..8b22d451f 100644 --- a/src/exo/disaggregated/batch_connector.py +++ b/src/exo/disaggregated/batch_connector.py @@ -52,8 +52,6 @@ class BatchConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBaseCl return layer_idx = int(m.group(1)) - torch.cuda.synchronize() - if isinstance(kv_layer, (list, tuple)): return @@ -68,10 +66,8 @@ class BatchConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBaseCl v_flat = v_all.reshape(-1, *v_all.shape[-2:]) # pyright: ignore[reportAny] valid = slot_mapping >= 0 # pyright: ignore[reportAny] safe_sm = slot_mapping.clamp(min=0) # pyright: ignore[reportAny] - keys = k_flat[safe_sm] # pyright: ignore[reportAny] - values = v_flat[safe_sm] # pyright: ignore[reportAny] - keys[~valid] = 0 - values[~valid] = 0 + keys = k_flat[safe_sm][valid] # pyright: ignore[reportAny] + values = v_flat[safe_sm][valid] # pyright: ignore[reportAny] prev = self.captured_layers.get(layer_idx) if prev is not None: diff --git a/src/exo/disaggregated/prefill_client.py b/src/exo/disaggregated/prefill_client.py index 068ba0bee..871eb68a1 100644 --- a/src/exo/disaggregated/prefill_client.py +++ b/src/exo/disaggregated/prefill_client.py @@ -48,28 +48,10 @@ def _inject_kv_cache(cache: KVCache, keys: torch.Tensor, values: torch.Tensor, n def _inject_rotating_kv_cache(cache: RotatingKVCache, keys: torch.Tensor, values: torch.Tensor, num_tokens: int) -> None: k_mx, v_mx = _nhd_to_bhsd(keys, values) seq_len = int(k_mx.shape[2]) - - if seq_len <= cache.max_size: - cache.keys = k_mx - cache.values = v_mx - cache.offset = seq_len - cache._idx = seq_len - else: - keep = cache.keep - window = cache.max_size - if keep == 0: - cache.keys = k_mx[:, :, -window:, :] - cache.values = v_mx[:, :, -window:, :] - cache._idx = window - else: - sink_keys = k_mx[:, :, :keep, :] - sink_values = v_mx[:, :, :keep, :] - recent_keys = k_mx[:, :, -(window - keep):, :] - recent_values = v_mx[:, :, -(window - keep):, :] - cache.keys = mx.concatenate([sink_keys, recent_keys], axis=2) - cache.values = mx.concatenate([sink_values, recent_values], axis=2) - cache._idx = keep - cache.offset = num_tokens + cache.keys = k_mx + cache.values = v_mx + cache.offset = seq_len + cache._idx = seq_len def _inject_arrays_cache(cache: ArraysCache, arrays: list[torch.Tensor]) -> None: @@ -82,6 +64,8 @@ def remote_prefill( model_id: str, mlx_model: Model, on_prefill_progress: Callable[[int, int], None] | None = None, + existing_cache: list[KVCache | RotatingKVCache | ArraysCache] | None = None, + start_pos: int = 0, ) -> tuple[list[KVCache | RotatingKVCache | ArraysCache], int]: if ":" in endpoint: host, port_str = endpoint.rsplit(":", 1) @@ -90,12 +74,12 @@ def remote_prefill( host = endpoint port = 8900 - logger.info(f"Connecting to prefill server at {host}:{port} ({len(token_ids)} tokens)") + logger.info(f"Connecting to prefill server at {host}:{port} ({len(token_ids)} tokens, start_pos={start_pos})") t0 = time.perf_counter() sock = socket.create_connection((host, port), timeout=30) try: - request = json.dumps({"model": model_id, "token_ids": token_ids}).encode("utf-8") + b"\n" + request = json.dumps({"model": model_id, "token_ids": token_ids, "start_pos": start_pos}).encode("utf-8") + b"\n" sock.sendall(request) raw_stream = sock.makefile("rb", buffering=65536) @@ -147,7 +131,16 @@ def remote_prefill( finally: sock.close() - caches: list[KVCache | RotatingKVCache | ArraysCache] = cast(list[KVCache | RotatingKVCache | ArraysCache], mlx_model.make_cache()) # pyright: ignore[reportUnknownMemberType] + if existing_cache is not None and start_pos > 0: + caches = existing_cache + else: + if hasattr(mlx_model, "make_cache"): + caches = cast(list[KVCache | RotatingKVCache | ArraysCache], mlx_model.make_cache()) # pyright: ignore[reportUnknownMemberType] + else: + from mlx_lm.models.cache import make_prompt_cache + caches = cast(list[KVCache | RotatingKVCache | ArraysCache], make_prompt_cache(mlx_model)) # pyright: ignore[reportUnknownMemberType] + + final_offset = start_pos + total_tokens for i, cache in enumerate(caches): if i in kv_buffers: @@ -161,19 +154,25 @@ def remote_prefill( all_values = torch.cat([v for _k, v in chunks], dim=0) # type: ignore if isinstance(cache, RotatingKVCache): - _inject_rotating_kv_cache(cache, all_keys, all_values, total_tokens) # pyright: ignore[reportUnknownArgumentType] + _inject_rotating_kv_cache(cache, all_keys, all_values, final_offset) # pyright: ignore[reportUnknownArgumentType] elif isinstance(cache, KVCache): - _inject_kv_cache(cache, all_keys, all_values, total_tokens) # pyright: ignore[reportUnknownArgumentType] + if start_pos > 0 and cache.keys is not None: + k_new, v_new = _nhd_to_bhsd(all_keys, all_values) # pyright: ignore[reportUnknownArgumentType] + cache.keys = mx.concatenate([cache.keys[:, :, :start_pos, :], k_new], axis=2) + cache.values = mx.concatenate([cache.values[:, :, :start_pos, :], v_new], axis=2) + cache.offset = final_offset + else: + _inject_kv_cache(cache, all_keys, all_values, final_offset) # pyright: ignore[reportUnknownArgumentType] if i in arrays_buffers and isinstance(cache, ArraysCache): _inject_arrays_cache(cache, arrays_buffers[i]) t_injected = time.perf_counter() logger.info( - f"Remote prefill: {total_tokens} tokens, " + f"Remote prefill: {total_tokens} new tokens (start_pos={start_pos}, final_offset={final_offset}), " f"transfer={((t_received - t0) * 1000):.0f}ms, " f"inject={((t_injected - t_received) * 1000):.0f}ms, " f"total={((t_injected - t0) * 1000):.0f}ms" ) - return caches, total_tokens + return caches, final_offset diff --git a/src/exo/disaggregated/prefill_server.py b/src/exo/disaggregated/prefill_server.py index 6d9e6b9b9..02b552635 100644 --- a/src/exo/disaggregated/prefill_server.py +++ b/src/exo/disaggregated/prefill_server.py @@ -2,7 +2,6 @@ from __future__ import annotations import contextlib import json -import queue import socketserver import threading import time @@ -20,9 +19,12 @@ from exo.disaggregated.protocol import ( if TYPE_CHECKING: from vllm.v1.engine.llm_engine import LLMEngine +from exo.worker.engines.mlx.cache import KVPrefixCache +from exo.worker.engines.vllm.kv_cache import KVLayerState, TorchKVCache from exo.worker.runner.bootstrap import logger _engine_ref: LLMEngine | None = None +_prefix_cache_ref: KVPrefixCache | None = None _overlapping: bool = True _connector_patched: bool = False _gdn_patched: bool = False @@ -79,8 +81,10 @@ def _patch_gdn_capture() -> None: _gdn_patched = True try: - from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn as orig_fn # type: ignore import vllm.model_executor.layers.mamba.ops.causal_conv1d as cc_mod # type: ignore + from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn as orig_fn, # type: ignore + ) except ImportError: return @@ -90,7 +94,6 @@ def _patch_gdn_capture() -> None: x = args[0] if args else None if x is not None and x.shape[0] <= 100: # type: ignore return result - torch.cuda.synchronize() ci: int = cache_indices[0].item() if cache_indices.numel() > 0 else 0 # type: ignore idx = _gdn_call_idx[0] if _gdn_layer_order and idx < len(_gdn_layer_order) * 100: @@ -106,8 +109,8 @@ def _patch_gdn_capture() -> None: for mod in list(sys.modules.values()): if mod is None or mod is cc_mod: continue - if hasattr(mod, "causal_conv1d_fn") and getattr(mod, "causal_conv1d_fn") is orig_fn: - setattr(mod, "causal_conv1d_fn", patched_fn) + if hasattr(mod, "causal_conv1d_fn") and mod.causal_conv1d_fn is orig_fn: + mod.causal_conv1d_fn = patched_fn logger.info("Patched causal_conv1d_fn for GDN state capture") @@ -150,32 +153,56 @@ def _get_layer_info(engine: LLMEngine) -> tuple[int, str, list[dict[str, Any]]]: return num_layers, dtype_str, layers_info -def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], wfile: Any) -> None: # pyright: ignore[reportAny] - from exo.disaggregated.streaming_connector import StreamingConnector +def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], start_pos: int, wfile: Any) -> None: # pyright: ignore[reportAny] from exo.worker.engines.vllm.growable_cache import get_model_runner model_runner = get_model_runner() assert model_runner is not None - from exo.disaggregated.streaming_connector import get_shared_queue, reset_shared_queue + from exo.disaggregated.streaming_connector import ( + get_shared_queue, + reset_shared_queue, + ) reset_shared_queue() _gdn_states.clear() _gdn_call_idx[0] = 0 layer_queue = get_shared_queue() + server_cached = 0 + cached_torch: TorchKVCache | None = None + if _prefix_cache_ref is not None: + cached_torch, server_cached, _ = _prefix_cache_ref.lookup(token_ids) + skip_tokens = max(0, start_pos - server_cached) + num_layers, dtype_str, layers_info = _get_layer_info(engine) write_header(wfile, {"num_layers": num_layers, "dtype": dtype_str, "layers": layers_info}) # pyright: ignore[reportAny] + if start_pos < server_cached and cached_torch is not None: + for i, layer in enumerate(cached_torch.layers): + if isinstance(layer, KVLayerState) and layer.keys.numel() > 0: + keys = layer.keys + values = layer.values + if keys.dim() == 4: + keys = keys.reshape(-1, keys.shape[-2], keys.shape[-1]) + values = values.reshape(-1, values.shape[-2], values.shape[-1]) + keys = keys[start_pos:server_cached] + values = values[start_pos:server_cached] + if keys.shape[0] > 0: + write_kv_chunk(wfile, i, keys, values) # pyright: ignore[reportAny] + logger.info(f"Sent cached KV for positions {start_pos}-{server_cached} from server prefix cache") + from vllm.sampling_params import ( SamplingParams, ) + prefill_token_ids = token_ids[:-2] if len(token_ids) > 2 else token_ids request_id = f"prefill-{time.monotonic_ns()}" - params = SamplingParams(max_tokens=1, detokenize=False) # pyright: ignore[reportCallIssue] - engine.add_request(request_id, {"prompt_token_ids": token_ids}, params) # pyright: ignore[reportArgumentType] + params = SamplingParams(max_tokens=2, detokenize=False) # pyright: ignore[reportCallIssue] + engine.add_request(request_id, {"prompt_token_ids": prefill_token_ids}, params) # pyright: ignore[reportArgumentType] chunks_sent = [0] + layer_token_counts: dict[int, int] = {} def writer_loop() -> None: while True: @@ -183,6 +210,18 @@ def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], wfile: Any if item is None: break layer_idx, keys, values = item + + prev = layer_token_counts.get(layer_idx, 0) + n = keys.shape[0] + new_total = prev + n + layer_token_counts[layer_idx] = new_total + + if new_total <= skip_tokens: + continue + if prev < skip_tokens: + trim = skip_tokens - prev + keys = keys[trim:] + values = values[trim:] write_kv_chunk(wfile, layer_idx, keys, values) # pyright: ignore[reportAny] chunks_sent[0] += 1 @@ -193,6 +232,7 @@ def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], wfile: Any outputs = engine.step() for output in outputs: if output.request_id == request_id and output.outputs[0].token_ids: + _save_vllm_prefix_cache(engine, request_id, prefill_token_ids) engine.abort_request([request_id]) # type: ignore break else: @@ -201,13 +241,18 @@ def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], wfile: Any layer_queue.put(None) writer_thread.join() - logger.info(f"Overlapping prefill: sent {chunks_sent[0]} KV chunks") + actual_per_layer = max(layer_token_counts.values()) if layer_token_counts else 0 + new_tokens_sent = max(0, actual_per_layer - skip_tokens) + cached_tokens_sent = max(0, server_cached - start_pos) if start_pos < server_cached else 0 + tokens_sent = cached_tokens_sent + new_tokens_sent + logger.info(f"Overlapping prefill: sent {chunks_sent[0]} chunks, {tokens_sent} tokens (server_cached={server_cached}, skip={skip_tokens})") - _stream_gdn_states(engine, wfile, num_layers, layers_info) - write_done(wfile, len(token_ids)) # pyright: ignore[reportAny] + cached_arrays: list[tuple[int, list[torch.Tensor]]] = [] + _stream_gdn_states_and_collect(engine, wfile, num_layers, layers_info, cached_arrays) + write_done(wfile, tokens_sent) # pyright: ignore[reportAny] -def _run_prefill_batch(engine: LLMEngine, token_ids: list[int], wfile: Any) -> None: # pyright: ignore[reportAny] +def _run_prefill_batch(engine: LLMEngine, token_ids: list[int], start_pos: int, wfile: Any) -> None: # pyright: ignore[reportAny] from exo.worker.engines.vllm.growable_cache import get_model_runner num_layers, dtype_str, layers_info = _get_layer_info(engine) @@ -215,20 +260,29 @@ def _run_prefill_batch(engine: LLMEngine, token_ids: list[int], wfile: Any) -> N model_runner = get_model_runner() assert model_runner is not None - from exo.disaggregated.batch_connector import clear_shared_captured_layers, get_shared_captured_layers + from exo.disaggregated.batch_connector import ( + clear_shared_captured_layers, + get_shared_captured_layers, + ) _gdn_states.clear() _gdn_call_idx[0] = 0 clear_shared_captured_layers() captured_layers = get_shared_captured_layers() + server_cached = 0 + if _prefix_cache_ref is not None: + _, server_cached, _ = _prefix_cache_ref.lookup(token_ids) + skip_tokens = max(0, start_pos - server_cached) + from vllm.sampling_params import ( SamplingParams, ) + prefill_token_ids = token_ids[:-2] if len(token_ids) > 2 else token_ids request_id = f"prefill-{time.monotonic_ns()}" - params = SamplingParams(max_tokens=1, detokenize=False) # pyright: ignore[reportCallIssue] - engine.add_request(request_id, {"prompt_token_ids": token_ids}, params) # pyright: ignore[reportArgumentType] + params = SamplingParams(max_tokens=2, detokenize=False) # pyright: ignore[reportCallIssue] + engine.add_request(request_id, {"prompt_token_ids": prefill_token_ids}, params) # pyright: ignore[reportArgumentType] while engine.has_unfinished_requests(): outputs = engine.step() @@ -242,17 +296,32 @@ def _run_prefill_batch(engine: LLMEngine, token_ids: list[int], wfile: Any) -> N write_header(wfile, {"num_layers": num_layers, "dtype": dtype_str, "layers": layers_info}) # pyright: ignore[reportAny] - logger.info(f"Batch prefill: streaming {len(captured_layers)} captured layers") + all_kv: list[tuple[int, torch.Tensor, torch.Tensor]] = [] for layer_idx in sorted(captured_layers.keys()): layer_data = captured_layers[layer_idx] - write_kv_chunk(wfile, layer_idx, layer_data["keys"], layer_data["values"]) # pyright: ignore[reportAny] + keys = layer_data["keys"] + values = layer_data["values"] + all_kv.append((layer_idx, keys, values)) + if keys.shape[0] > skip_tokens: + write_kv_chunk(wfile, layer_idx, keys[skip_tokens:], values[skip_tokens:]) # pyright: ignore[reportAny] clear_shared_captured_layers() - _stream_gdn_states(engine, wfile, num_layers, layers_info) - write_done(wfile, len(token_ids)) # pyright: ignore[reportAny] + actual_per_layer = max((k.shape[0] for _, k, _ in all_kv), default=0) + tokens_sent = max(0, actual_per_layer - skip_tokens) + logger.info(f"Batch prefill: {len(all_kv)} layers, {tokens_sent} tokens sent (server_cached={server_cached}, skip={skip_tokens}, captured={actual_per_layer})") + + cached_arrays: list[tuple[int, list[torch.Tensor]]] = [] + _stream_gdn_states_and_collect(engine, wfile, num_layers, layers_info, cached_arrays) + write_done(wfile, tokens_sent) # pyright: ignore[reportAny] -def _stream_gdn_states(_engine: LLMEngine, wfile: Any, num_layers: int, layers_info: list[dict[str, Any]]) -> None: # type: ignore +def _stream_gdn_states_and_collect( + _engine: LLMEngine, + wfile: Any, + num_layers: int, + layers_info: list[dict[str, Any]], + out_arrays: list[tuple[int, list[torch.Tensor]]], +) -> None: # type: ignore from exo.worker.engines.vllm.growable_cache import get_model_runner if not _gdn_states: @@ -282,6 +351,7 @@ def _stream_gdn_states(_engine: LLMEngine, wfile: Any, num_layers: int, layers_i arrays.append(rec.to(torch.bfloat16)) if arrays: write_arrays_state(wfile, layer_idx, arrays) # type: ignore + out_arrays.append((layer_idx, arrays)) except Exception: logger.opt(exception=True).warning(f"Failed to capture GDN state for layer {layer_idx}") @@ -289,6 +359,101 @@ def _stream_gdn_states(_engine: LLMEngine, wfile: Any, num_layers: int, layers_i _gdn_call_idx[0] = 0 +def _build_torch_cache(kv_chunks: list[tuple[int, torch.Tensor, torch.Tensor]], arrays_chunks: list[tuple[int, list[torch.Tensor]]], num_layers: int) -> TorchKVCache: + from exo.worker.engines.vllm.kv_cache import ArraysLayerState + + layers_by_idx: dict[int, KVLayerState | ArraysLayerState] = {} + for layer_idx, keys, values in kv_chunks: + if layer_idx in layers_by_idx: + prev = layers_by_idx[layer_idx] + if isinstance(prev, KVLayerState): + layers_by_idx[layer_idx] = KVLayerState( + keys=torch.cat([prev.keys, keys], dim=0), # type: ignore + values=torch.cat([prev.values, values], dim=0), # type: ignore + ) + else: + layers_by_idx[layer_idx] = KVLayerState(keys=keys, values=values) + for layer_idx, arrays in arrays_chunks: + layers_by_idx[layer_idx] = ArraysLayerState(arrays=[a if isinstance(a, torch.Tensor) else None for a in arrays]) + + ordered: list[KVLayerState | ArraysLayerState] = [] + for i in range(num_layers): + if i in layers_by_idx: + ordered.append(layers_by_idx[i]) + else: + ordered.append(KVLayerState(keys=torch.empty(0), values=torch.empty(0))) + return TorchKVCache(ordered) + + +def _save_vllm_prefix_cache(engine: LLMEngine, request_id: str, prefill_token_ids: list[int]) -> None: + if _prefix_cache_ref is None: + logger.info("Server prefix cache: no cache ref") + return + try: + from exo.worker.engines.vllm.vllm_generator import _save_prefix_cache + + try: + engine_core = engine.engine_core.engine_core # type: ignore + coordinator = engine_core.scheduler.kv_cache_manager.coordinator # type: ignore + all_keys: list[str] = [] + for mgr in coordinator.single_type_managers: # type: ignore + all_keys.extend(str(k) for k in mgr.req_to_blocks) # type: ignore + logger.info(f"Server prefix cache: request_id={request_id}, available_keys={all_keys[:5]}") + except Exception: + pass + + before = len(_prefix_cache_ref.prompts) + _save_prefix_cache(engine, _prefix_cache_ref, request_id, prefill_token_ids, len(prefill_token_ids)) + after = len(_prefix_cache_ref.prompts) + if after > before: + logger.info(f"Server prefix cache: saved {len(prefill_token_ids)} tokens (entries: {before} → {after})") + else: + logger.info(f"Server prefix cache: save had no effect for request_id={request_id}") + except Exception: + logger.opt(exception=True).warning("Failed to save server-side prefix cache") + + +def _check_cache(token_ids: list[int]) -> TorchKVCache | None: + if _prefix_cache_ref is None: + return None + import mlx.core as mx + + prompt_arr = mx.array(token_ids) + best_index: int | None = None + best_length = 0 + for i, cached_prompt in enumerate(_prefix_cache_ref.prompts): + prefix_len = min(len(cached_prompt), len(prompt_arr)) + if prefix_len == 0: + continue + match_len = int(mx.sum(cached_prompt[:prefix_len] == prompt_arr[:prefix_len]).item()) # pyright: ignore[reportAny] + if match_len == len(token_ids) and match_len == len(cached_prompt) and match_len > best_length: + best_index = i + best_length = match_len + + if best_index is None: + return None + + cached = _prefix_cache_ref.caches[best_index] + if isinstance(cached, TorchKVCache): + return cached + return None + + +def _send_cached(torch_cache: TorchKVCache, token_ids: list[int], wfile: Any, engine: LLMEngine) -> None: + num_layers, dtype_str, layers_info = _get_layer_info(engine) + write_header(wfile, {"num_layers": num_layers, "dtype": dtype_str, "layers": layers_info}) # type: ignore + from exo.worker.engines.vllm.kv_cache import ArraysLayerState + + for i, layer in enumerate(torch_cache.layers): + if isinstance(layer, KVLayerState) and layer.keys.numel() > 0: + write_kv_chunk(wfile, i, layer.keys, layer.values) # type: ignore + elif isinstance(layer, ArraysLayerState): + arrays = [a for a in layer.arrays if a is not None] + if arrays: + write_arrays_state(wfile, i, arrays) # type: ignore + write_done(wfile, len(token_ids)) # type: ignore + + class _PrefillHandler(socketserver.StreamRequestHandler): def handle(self) -> None: try: @@ -297,6 +462,7 @@ class _PrefillHandler(socketserver.StreamRequestHandler): return request: dict[str, Any] = json.loads(line.decode("utf-8")) # pyright: ignore[reportAny] token_ids: list[int] = request["token_ids"] # pyright: ignore[reportAny] + start_pos: int = request.get("start_pos", 0) # pyright: ignore[reportAny] engine = _engine_ref if engine is None: @@ -309,13 +475,13 @@ class _PrefillHandler(socketserver.StreamRequestHandler): self.wfile.write(error) return - logger.info(f"Prefill request: {len(token_ids)} tokens, overlapping={_overlapping}") + logger.info(f"Prefill request: {len(token_ids)} tokens, start_pos={start_pos}, overlapping={_overlapping}") t0 = time.perf_counter() if _overlapping: - _run_prefill_overlapping(engine, token_ids, self.wfile) + _run_prefill_overlapping(engine, token_ids, start_pos, self.wfile) else: - _run_prefill_batch(engine, token_ids, self.wfile) + _run_prefill_batch(engine, token_ids, start_pos, self.wfile) elapsed = time.perf_counter() - t0 logger.info(f"Prefill complete: {len(token_ids)} tokens in {elapsed*1000:.0f}ms ({len(token_ids)/elapsed:.0f} tok/s)") @@ -328,10 +494,12 @@ def start_prefill_server( bind_address: str, port: int, overlapping: bool = True, + prefix_cache: KVPrefixCache | None = None, ) -> socketserver.ThreadingTCPServer: - global _engine_ref, _overlapping + global _engine_ref, _overlapping, _prefix_cache_ref _engine_ref = engine _overlapping = overlapping + _prefix_cache_ref = prefix_cache _patch_gdn_capture() _init_gdn_layer_order() diff --git a/src/exo/disaggregated/protocol.py b/src/exo/disaggregated/protocol.py index 70ef72ff7..bde6a8a56 100644 --- a/src/exo/disaggregated/protocol.py +++ b/src/exo/disaggregated/protocol.py @@ -61,7 +61,9 @@ def _str_to_dtype(s: str) -> torch.dtype: def _dtype_size(dtype: torch.dtype) -> int: - return {torch.float16: 2, torch.bfloat16: 2, torch.float32: 4}[dtype] + if dtype == torch.bfloat16: + return 4 + return {torch.float16: 2, torch.float32: 4}[dtype] def write_header(stream: BinaryIO, header: dict[str, object]) -> None: @@ -72,7 +74,7 @@ def write_header(stream: BinaryIO, header: dict[str, object]) -> None: def _tensor_to_bytes(t: torch.Tensor) -> bytes: if t.dtype == torch.bfloat16: - return t.contiguous().view(torch.int16).numpy().tobytes() # type: ignore + return t.contiguous().float().numpy().tobytes() # type: ignore return t.contiguous().numpy().tobytes() # type: ignore @@ -133,8 +135,8 @@ def read_message(stream: BinaryIO, header: dict[str, object]) -> Message | None: values_raw = _read_exactly(stream, tensor_bytes) shape = (num_tokens, n_heads, head_dim) if dtype == torch.bfloat16: - keys: torch.Tensor = torch.frombuffer(bytearray(keys_raw), dtype=torch.int16).view(torch.bfloat16).reshape(shape).clone() # type: ignore - values: torch.Tensor = torch.frombuffer(bytearray(values_raw), dtype=torch.int16).view(torch.bfloat16).reshape(shape).clone() # type: ignore + keys: torch.Tensor = torch.frombuffer(bytearray(keys_raw), dtype=torch.float32).reshape(shape).to(torch.bfloat16).clone() # type: ignore + values: torch.Tensor = torch.frombuffer(bytearray(values_raw), dtype=torch.float32).reshape(shape).to(torch.bfloat16).clone() # type: ignore else: keys = torch.frombuffer(bytearray(keys_raw), dtype=dtype).reshape(shape).clone() # type: ignore values = torch.frombuffer(bytearray(values_raw), dtype=dtype).reshape(shape).clone() # type: ignore @@ -157,7 +159,7 @@ def read_message(stream: BinaryIO, header: dict[str, object]) -> Message | None: total_elems *= d # pyright: ignore[reportAny] raw = _read_exactly(stream, total_elems * elem_size) if dtype == torch.bfloat16: - t: torch.Tensor = torch.frombuffer(bytearray(raw), dtype=torch.int16).view(torch.bfloat16).reshape(shape_arr).clone() # type: ignore + t: torch.Tensor = torch.frombuffer(bytearray(raw), dtype=torch.float32).reshape(shape_arr).to(torch.bfloat16).clone() # type: ignore else: t = torch.frombuffer(bytearray(raw), dtype=dtype).reshape(shape_arr).clone() # type: ignore arrays.append(t) # pyright: ignore[reportUnknownArgumentType] diff --git a/src/exo/disaggregated/streaming_connector.py b/src/exo/disaggregated/streaming_connector.py index 24bcf7c2a..b00b2e9c1 100644 --- a/src/exo/disaggregated/streaming_connector.py +++ b/src/exo/disaggregated/streaming_connector.py @@ -37,6 +37,8 @@ class StreamingConnectorMetadata(KVConnectorMetadata): # pyright: ignore[report class StreamingConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBaseClass] _queue: queue.Queue[tuple[int, torch.Tensor, torch.Tensor] | None] + _save_count: int = 0 + def __init__(self, vllm_config: Any, role: KVConnectorRole, kv_cache_config: Any = None) -> None: # type: ignore super().__init__(vllm_config, role, kv_cache_config) # pyright: ignore[reportUnknownMemberType] self._queue = _shared_queue @@ -61,11 +63,14 @@ class StreamingConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBa return layer_idx = int(m.group(1)) - torch.cuda.synchronize() - if isinstance(kv_layer, (list, tuple)): return + if self._save_count < 1: + import logging + logging.getLogger("exo").info(f"save_kv_layer: kv_layer.shape={kv_layer.shape} dtype={kv_layer.dtype} slot_mapping.shape={slot_mapping.shape if slot_mapping is not None else None}") # pyright: ignore[reportAny] + self._save_count += 1 + if slot_mapping is not None: if kv_layer.shape[0] == 2: # pyright: ignore[reportAny] k_all = kv_layer[0] # pyright: ignore[reportAny] @@ -77,10 +82,11 @@ class StreamingConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBa v_flat = v_all.reshape(-1, *v_all.shape[-2:]) # pyright: ignore[reportAny] valid = slot_mapping >= 0 # pyright: ignore[reportAny] safe_sm = slot_mapping.clamp(min=0) # pyright: ignore[reportAny] - keys = k_flat[safe_sm] # pyright: ignore[reportAny] - values = v_flat[safe_sm] # pyright: ignore[reportAny] - keys[~valid] = 0 - values[~valid] = 0 + keys = k_flat[safe_sm][valid] # pyright: ignore[reportAny] + values = v_flat[safe_sm][valid] # pyright: ignore[reportAny] + if keys.dtype not in (torch.bfloat16, torch.float16, torch.float32): # pyright: ignore[reportAny] + keys = keys.to(torch.bfloat16) # pyright: ignore[reportAny] + values = values.to(torch.bfloat16) # pyright: ignore[reportAny] self._queue.put((layer_idx, keys.cpu(), values.cpu())) # pyright: ignore[reportAny] else: self._queue.put((layer_idx, kv_layer.cpu().clone(), kv_layer.cpu().clone())) # pyright: ignore[reportAny] diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 6726ccb2d..f62a67636 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -98,6 +98,7 @@ class Master: from exo.master.placement_utils import ( _find_ip_prioritised as find_ip_prioritised, # pyright: ignore[reportPrivateUsage] ) + from exo.shared.models.model_cards import derive_base_model endpoints: list[tuple[int, str]] = [] vllm_instance_count = 0 @@ -111,21 +112,14 @@ class Master: if first_shard is None: logger.info(f"Prefill routing: VllmInstance {instance.instance_id} has no shards") continue - if first_shard.model_card.base_model.lower() != decode_model_base.lower(): + if derive_base_model(first_shard.model_card.base_model).lower() != decode_model_base.lower(): logger.info( f"Prefill routing: VllmInstance {instance.instance_id} base_model " f"{first_shard.model_card.base_model!r} != decode {decode_model_base!r}" ) continue - active_task_count = sum( - 1 for task in self.state.tasks.values() - if task.instance_id == instance.instance_id - and task.task_status in (TaskStatus.Pending, TaskStatus.Running) - ) - if active_task_count > 0: - logger.info(f"Prefill routing: VllmInstance {instance.instance_id} busy ({active_task_count} active tasks)") - continue + pass for node_id, runner_id in instance.shard_assignments.node_to_runner.items(): runner_status = self.state.runners.get(runner_id) diff --git a/src/exo/shared/models/model_cards.py b/src/exo/shared/models/model_cards.py index 8edccb618..a2a115388 100644 --- a/src/exo/shared/models/model_cards.py +++ b/src/exo/shared/models/model_cards.py @@ -41,7 +41,7 @@ _card_cache: dict[ModelId, "ModelCard"] = {} import re _QUANT_SUFFIXES = re.compile( - r"[-_](?:MLX|MXFP[0-9]+|GPTQ|AWQ|GGUF|fp16|bf16|fp8|int[0-9]+|[0-9]+(?:\.[0-9]+)?bit|Q[0-9]+(?:_[A-Z0-9]+)?|gs[0-9]+)(?:[-_](?:MLX|Q[0-9]+|Int[0-9]+|[A-Z0-9]+|gs[0-9]+))*$", + r"[-_ ](?:MLX|MXFP[0-9]+|NVFP[0-9]+|GPTQ|AWQ|GGUF|fp16|bf16|fp8|int[0-9]+|[0-9]+(?:\.[0-9]+)?bit|Q[0-9]+(?:_[A-Z0-9]+)?|gs[0-9]+)(?:[-_ ](?:MLX|Q[0-9]+|Int[0-9]+|[A-Z0-9]+|gs[0-9]+))*$", re.IGNORECASE, ) @@ -127,7 +127,8 @@ class ModelCard(CamelCaseModel): if not self.base_model: self.base_model = derive_base_model(self.model_id) else: - self.base_model = _normalize_base_model(self.base_model) + stripped = _QUANT_SUFFIXES.sub("", self.base_model) + self.base_model = _normalize_base_model(stripped) return self @field_validator("tasks", mode="before") diff --git a/src/exo/worker/engines/mlx/cache.py b/src/exo/worker/engines/mlx/cache.py index b04a244e0..67b494ae2 100644 --- a/src/exo/worker/engines/mlx/cache.py +++ b/src/exo/worker/engines/mlx/cache.py @@ -194,10 +194,18 @@ class KVPrefixCache: # This ensures stream_generate always has at least one token to start with mlx_cache = self._get_mlx_cache(best_index) has_ssm = has_non_kv_caches(mlx_cache) + snapshots_available = self._snapshots[best_index] is not None + + if is_exact and has_ssm and not snapshots_available: + prompt_cache = deepcopy(mlx_cache) + self._access_counter += 1 + self._last_used[best_index] = self._access_counter + remaining = prompt_tokens[best_length:] + return prompt_cache, remaining, best_index + target = (max_length - 1) if is_exact and not has_ssm else best_length restore_pos, restore_snap = self._get_snapshot(best_index, target) - # No usable snapshot — need fresh cache if restore_snap is None and has_ssm: return make_kv_cache(model), prompt_tokens, None diff --git a/src/exo/worker/engines/mlx/generator/batch_generate.py b/src/exo/worker/engines/mlx/generator/batch_generate.py index 4fb94c912..5d4d99e24 100644 --- a/src/exo/worker/engines/mlx/generator/batch_generate.py +++ b/src/exo/worker/engines/mlx/generator/batch_generate.py @@ -121,7 +121,9 @@ class ExoBatchGenerator: matched_index: int | None = None prompt_tokens = all_prompt_tokens - if self.kv_prefix_cache is not None and not is_bench: + has_prefill_endpoints = bool(task_params.prefill_endpoints) and len(all_prompt_tokens) > 1000 + + if self.kv_prefix_cache is not None and not is_bench and not has_prefill_endpoints: cache, remaining_tokens, matched_index = self.kv_prefix_cache.get_kv_cache( self.model, all_prompt_tokens ) @@ -168,8 +170,13 @@ class ExoBatchGenerator: model_id=str(task_params.model), mlx_model=self.model, on_prefill_progress=on_prefill_progress, + existing_cache=None, + start_pos=0, ) cache = injected_cache + from exo.worker.engines.mlx.cache import snapshot_ssm_states + + cache_snapshots = [snapshot_ssm_states(cache)] _prefill_tps = total_tokens / max(time.perf_counter() - t0, 0.001) used_remote_prefill = True logger.info(f"Remote prefill: {total_tokens} tokens at {_prefill_tps:.0f} tok/s") diff --git a/src/exo/worker/engines/vllm/kv_cache.py b/src/exo/worker/engines/vllm/kv_cache.py index 520bcad50..a9ef7f55e 100644 --- a/src/exo/worker/engines/vllm/kv_cache.py +++ b/src/exo/worker/engines/vllm/kv_cache.py @@ -264,6 +264,7 @@ class TorchKVCache: first = kv_caches[0] device = first[0].device if isinstance(first, list) else first.device + block_size = first[0].shape[1] if isinstance(first, list) else first.shape[-3] for layer_idx, layer in enumerate(self.layers): if not isinstance(layer, KVLayerState): continue @@ -271,14 +272,26 @@ class TorchKVCache: bt = block_tables[gi] kv = kv_caches[layer_idx] k_all, v_all = _split_kv(kv) - n_blocks = min(len(bt), layer.keys.shape[0]) + + keys = layer.keys + values = layer.values + if keys.dim() == 3: + offset = token_offset_per_group[gi] if token_offset_per_group else 0 + if offset > 0: + keys = keys[offset:] + values = values[offset:] + s, h, d = keys.shape + pad = (block_size - s % block_size) % block_size + if pad > 0: + keys = torch.nn.functional.pad(keys, (0, 0, 0, 0, 0, pad)) + values = torch.nn.functional.pad(values, (0, 0, 0, 0, 0, pad)) + keys = keys.reshape(-1, block_size, h, d) + values = values.reshape(-1, block_size, h, d) + + n_blocks = min(len(bt), keys.shape[0]) if n_blocks > 0: - k_all[bt[:n_blocks]] = layer.keys[:n_blocks].to( - device, non_blocking=True - ) - v_all[bt[:n_blocks]] = layer.values[:n_blocks].to( - device, non_blocking=True - ) + k_all[bt[:n_blocks]] = keys[:n_blocks].to(device, non_blocking=True) + v_all[bt[:n_blocks]] = values[:n_blocks].to(device, non_blocking=True) torch.cuda.synchronize() def __iter__(self) -> Iterator[LayerState]: diff --git a/src/exo/worker/engines/vllm/vllm_generator.py b/src/exo/worker/engines/vllm/vllm_generator.py index 8705bc55e..e85471fa8 100644 --- a/src/exo/worker/engines/vllm/vllm_generator.py +++ b/src/exo/worker/engines/vllm/vllm_generator.py @@ -213,8 +213,9 @@ def vllm_generate( tokenizer = engine.get_tokenizer() stop_ids = _stop_token_ids(tokenizer, model_id) + DEFAULT_PREFILL_STEP_SIZE = 8192 max_batch_tokens: int = ( - getattr(engine.model_config, "max_num_batched_tokens", 2048) or 2048 + getattr(engine.model_config, "max_num_batched_tokens", DEFAULT_PREFILL_STEP_SIZE) or DEFAULT_PREFILL_STEP_SIZE ) # type: ignore[reportUnknownMemberType] start_time = time.perf_counter() first_token_time: float | None = None @@ -577,21 +578,36 @@ def load_vllm_engine( "kv_role": "kv_both", } - engine_args = EngineArgs( - model=model_path, - served_model_name=str(model_id), - gpu_memory_utilization=0.05, - trust_remote_code=trust_remote_code, - load_format="fastsafetensors", - enable_prefix_caching=False, - attention_backend="TRITON_ATTN", - enforce_eager=True, - disable_log_stats=True, - kv_transfer_config=kv_transfer_config, # type: ignore - ) + is_nvfp4 = "nvfp4" in model_path.lower() or "nvfp4" in str(model_id).lower() + backends = ["FLASHINFER", "TRITON_ATTN"] if is_nvfp4 else ["FLASH_ATTN", "TRITON_ATTN"] - set_weight_loading_callback(on_layer_loaded) - engine = LLMEngine.from_engine_args(engine_args) + engine: LLMEngine | None = None + for backend in backends: + try: + engine_args = EngineArgs( + model=model_path, + served_model_name=str(model_id), + gpu_memory_utilization=0.05, + trust_remote_code=trust_remote_code, + load_format="fastsafetensors", + enable_prefix_caching=False, + attention_backend=backend, + enforce_eager=True, + disable_log_stats=True, + max_num_batched_tokens=4096, + kv_transfer_config=kv_transfer_config, # type: ignore + ) + + set_weight_loading_callback(on_layer_loaded) + engine = LLMEngine.from_engine_args(engine_args) + logger.info(f"vLLM engine using attention backend: {backend}") + break + except (ValueError, RuntimeError) as e: + logger.warning(f"Attention backend {backend} failed: {e}, trying next") + continue + + if engine is None: + raise RuntimeError(f"No attention backend worked for {model_id}") tool_parser: ToolParser | None = None tokenizer = engine.get_tokenizer() diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index 99ef664af..d2f856545 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -182,7 +182,6 @@ class Worker: json.dump(endpoints, f) except: logger.warning("Updating prefill endpoints failed") - pass async def plan_step(self): while True: diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py index 685ca2df4..35606429b 100644 --- a/src/exo/worker/runner/bootstrap.py +++ b/src/exo/worker/runner/bootstrap.py @@ -70,7 +70,7 @@ def entrypoint( if isinstance(bound_instance.instance, VllmInstance): os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD" - os.environ["VLLM_BATCH_INVARIANT"] = "1" + # os.environ["VLLM_BATCH_INVARIANT"] = "1" _ensure_cuda_libs() from exo.shared.constants import EXO_MODELS_DIR from exo.worker.runner.llm_inference.runner import Runner, VllmBuilder diff --git a/src/exo/worker/runner/llm_inference/runner.py b/src/exo/worker/runner/llm_inference/runner.py index 16bbb1d7a..b1f965592 100644 --- a/src/exo/worker/runner/llm_inference/runner.py +++ b/src/exo/worker/runner/llm_inference/runner.py @@ -561,7 +561,9 @@ class VllmBuilder(Builder): tokenizer = TokenizerWrapper(self._engine.get_tokenizer()) max_concurrent = 1 if os.environ.get("EXO_NO_BATCH") else 8 - prefill_port = int(os.environ.get("EXO_PREFILL_PORT", "8900")) + from exo.master.placement import random_ephemeral_port + + prefill_port = random_ephemeral_port() overlapping = not os.environ.get("EXO_NO_OVERLAPPING_PREFILL_SENDS") try: from exo.disaggregated.prefill_server import start_prefill_server @@ -571,6 +573,7 @@ class VllmBuilder(Builder): bind_address="0.0.0.0", port=prefill_port, overlapping=overlapping, + prefix_cache=self._prefix_cache, ) self._prefill_server_port = prefill_port except Exception: