From 016de1803b2fbbaa53b7e80b4e4e00807764346a Mon Sep 17 00:00:00 2001 From: Ryuichi Leo Takashige Date: Fri, 20 Mar 2026 11:23:52 +0000 Subject: [PATCH] Unnecessary further optimizations 2 --- scripts/compare_kv.py | 107 ++++++++++++- src/exo/disaggregated/batch_connector.py | 15 +- src/exo/disaggregated/prefill_server.py | 140 +++++++++++++++--- src/exo/disaggregated/protocol.py | 25 +++- src/exo/disaggregated/streaming_connector.py | 18 ++- .../worker/engines/mlx/gdn_softplus_patch.py | 58 ++++++++ src/exo/worker/engines/vllm/growable_cache.py | 22 +-- src/exo/worker/engines/vllm/kv_cache.py | 26 +++- src/exo/worker/engines/vllm/vllm_generator.py | 18 ++- src/exo/worker/runner/bootstrap.py | 4 + test_growable_compile.py | 98 ++++++++++++ tmp/spark/spark-build-apple-cdc-ncm.sh | 0 12 files changed, 488 insertions(+), 43 deletions(-) create mode 100644 src/exo/worker/engines/mlx/gdn_softplus_patch.py create mode 100644 test_growable_compile.py mode change 100644 => 100755 tmp/spark/spark-build-apple-cdc-ncm.sh diff --git a/scripts/compare_kv.py b/scripts/compare_kv.py index 0aad5611e..6a5ff57b8 100644 --- a/scripts/compare_kv.py +++ b/scripts/compare_kv.py @@ -1,5 +1,9 @@ import sys sys.path.insert(0, "src") +from exo.worker.engines.mlx.gdn_softplus_patch import patch_gdn_softplus +from exo.worker.engines.mlx.yarn_rope_patch import patch_yarn_rope +patch_gdn_softplus() +patch_yarn_rope() import mlx.core as mx import torch import socket @@ -7,8 +11,8 @@ from pathlib import Path import json from collections import defaultdict from mlx_lm import load -from mlx_lm.models.cache import RotatingKVCache, KVCache -from exo.disaggregated.protocol import read_header, read_message, KVChunk, Done +from mlx_lm.models.cache import ArraysCache, RotatingKVCache, KVCache +from exo.disaggregated.protocol import read_header, read_message, ArraysState, KVChunk, Done from exo.disaggregated.prefill_client import _nhd_to_bhsd, _torch_to_mx ENDPOINT = sys.argv[1] if len(sys.argv) > 1 else "10.43.0.1:62988" @@ -16,7 +20,7 @@ MODEL = sys.argv[2] if len(sys.argv) > 2 else "mlx-community/Llama-3.2-1B-Instru MODEL_PATH = sys.argv[3] if len(sys.argv) > 3 else None model, tok = load(MODEL_PATH or str(Path.home() / ".exo/models" / MODEL.replace("/", "--"))) -prompt = "The quick brown fox jumps over the lazy dog. " * 300 +prompt = "The quick brown fox jumps over the lazy dog. " * 3000 tokens = tok.encode(prompt) print(f"Tokens: {len(tokens)}") @@ -28,15 +32,18 @@ stream = sock.makefile("rb", buffering=65536) header = read_header(stream) vllm_kv = defaultdict(list) +vllm_arrays: dict[int, list[torch.Tensor]] = {} while True: msg = read_message(stream, header) if msg is None or isinstance(msg, Done): break if isinstance(msg, KVChunk): vllm_kv[msg.layer_idx].append((msg.keys, msg.values)) + elif isinstance(msg, ArraysState): + vllm_arrays[msg.layer_idx] = msg.arrays sock.close() -print(f"Received {len(vllm_kv)} layers from vLLM") +print(f"Received {len(vllm_kv)} KV layers, {len(vllm_arrays)} arrays layers from vLLM") if hasattr(model, "make_cache"): mlx_cache = model.make_cache() @@ -49,6 +56,27 @@ mx.eval(mlx_logits) for i in range(min(6, len(mlx_cache))): c = mlx_cache[i] + if isinstance(c, ArraysCache): + if i in vllm_arrays: + vllm_arrs = vllm_arrays[i] + mlx_state = c.state + print(f"Layer {i} (Arrays): mlx_state={len(mlx_state)} arrays, vllm={len(vllm_arrs)} arrays") + for ai, (m_arr, v_arr) in enumerate(zip(mlx_state, vllm_arrs)): + if m_arr is None: + continue + v_mx = _torch_to_mx(v_arr).astype(mx.float32) + m_f = m_arr.astype(mx.float32) + if m_f.shape != v_mx.shape: + print(f" [{ai}] SHAPE MISMATCH mlx={m_f.shape} vllm={v_mx.shape}") + else: + d = mx.abs(m_f - v_mx) + a = m_f.reshape(-1) + b = v_mx.reshape(-1) + cos = float(mx.sum(a * b).item()) / (float(mx.sqrt(mx.sum(a * a)).item()) * float(mx.sqrt(mx.sum(b * b)).item()) + 1e-8) + print(f" [{ai}] cosine_sim={cos:.6f} max_diff={mx.max(d).item():.6f} mean_diff={mx.mean(d).item():.6f} shape={m_f.shape}") + else: + print(f"Layer {i} (Arrays): no vLLM data") + continue if c.keys is None: continue mlx_k = c.keys.astype(mx.float32) @@ -85,3 +113,74 @@ for i in range(min(6, len(mlx_cache))): diffs = [abs(mlx_row[d] - vllm_row[d]) for d in range(D)] top5 = sorted(range(D), key=lambda d: -diffs[d])[:5] print(f" pos={pos} top5 diff dims: {[(d, f'{diffs[d]:.3f}', f'mlx={mlx_row[d]:.3f}', f'vllm={vllm_row[d]:.3f}') for d in top5]}") + +print("\n--- Run 2: cached request ---") +sock2 = socket.create_connection((host, int(port)), timeout=60) +request2 = json.dumps({"model": MODEL, "token_ids": tokens, "start_pos": 0}).encode() + b"\n" +sock2.sendall(request2) +stream2 = sock2.makefile("rb", buffering=65536) + +first_byte = stream2.peek(1)[:1] +if first_byte == b"{": + line2 = stream2.readline() + print(f"Server error: {json.loads(line2.decode())}") + sys.exit(1) + +header2 = read_header(stream2) +vllm_kv2 = defaultdict(list) +vllm_arrays2: dict[int, list[torch.Tensor]] = {} +total_tokens2 = 0 +while True: + msg = read_message(stream2, header2) + if msg is None: + break + if isinstance(msg, KVChunk): + vllm_kv2[msg.layer_idx].append((msg.keys, msg.values)) + elif isinstance(msg, ArraysState): + vllm_arrays2[msg.layer_idx] = msg.arrays + elif isinstance(msg, Done): + total_tokens2 = msg.total_tokens + break +sock2.close() + +kv_tokens2 = 0 +if vllm_kv2: + first_layer = next(iter(vllm_kv2.values())) + kv_tokens2 = sum(k.shape[0] for k, v in first_layer) +print(f"Received {len(vllm_kv2)} KV layers ({kv_tokens2} tokens), {len(vllm_arrays2)} arrays layers, total_tokens={total_tokens2}") + +for i in range(min(6, len(mlx_cache))): + c = mlx_cache[i] + if isinstance(c, ArraysCache): + if i in vllm_arrays2: + vllm_arrs = vllm_arrays2[i] + mlx_state = c.state + for ai, (m_arr, v_arr) in enumerate(zip(mlx_state, vllm_arrs)): + if m_arr is None: + continue + v_mx = _torch_to_mx(v_arr).astype(mx.float32) + m_f = m_arr.astype(mx.float32) + if m_f.shape != v_mx.shape: + print(f"Layer {i} [{ai}] SHAPE MISMATCH mlx={m_f.shape} vllm={v_mx.shape}") + else: + a2 = m_f.reshape(-1) + b2 = v_mx.reshape(-1) + cos2 = float(mx.sum(a2 * b2).item()) / (float(mx.sqrt(mx.sum(a2 * a2)).item()) * float(mx.sqrt(mx.sum(b2 * b2)).item()) + 1e-8) + print(f"Layer {i} (Arrays) [{ai}] cosine_sim={cos2:.6f} shape={m_f.shape}") + continue + if c.keys is None or i not in vllm_kv2: + continue + mlx_k = c.keys.astype(mx.float32) + chunks = vllm_kv2[i] + vk = torch.cat([k for k, v in chunks], dim=0) if len(chunks) > 1 else chunks[0][0] + vk_mx = _torch_to_mx(vk.permute(1, 0, 2).unsqueeze(0)).astype(mx.float32) + n = min(mlx_k.shape[2], vk_mx.shape[2]) + a2 = mlx_k[:, :, :n, :].reshape(-1) + b2 = vk_mx[:, :, :n, :].reshape(-1) + cos2 = float(mx.sum(a2 * b2).item()) / (float(mx.sqrt(mx.sum(a2 * a2)).item()) * float(mx.sqrt(mx.sum(b2 * b2)).item()) + 1e-8) + print(f"Layer {i} (KV) cosine_sim={cos2:.6f} mlx={mlx_k.shape} vllm={vk_mx.shape}") + +if len(vllm_kv2) > 0: + print("PASS") +else: + print("FAIL") diff --git a/src/exo/disaggregated/batch_connector.py b/src/exo/disaggregated/batch_connector.py index 8de7440c9..bbbd9a3cb 100644 --- a/src/exo/disaggregated/batch_connector.py +++ b/src/exo/disaggregated/batch_connector.py @@ -9,19 +9,26 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( # pyright: igno KVConnectorBase_V1, # pyright: ignore[reportUnknownVariableType] KVConnectorMetadata, # pyright: ignore[reportUnknownVariableType] KVConnectorRole, # pyright: ignore[reportUnknownVariableType] + SupportsHMA, # pyright: ignore[reportUnknownVariableType] ) _LAYER_RE = re.compile(r"layers\.(\d+)\.") _shared_captured_layers: dict[int, dict[str, torch.Tensor]] = {} +_shared_captured_arrays: dict[int, list[torch.Tensor]] = {} def get_shared_captured_layers() -> dict[int, dict[str, torch.Tensor]]: return _shared_captured_layers +def get_shared_captured_arrays() -> dict[int, list[torch.Tensor]]: + return _shared_captured_arrays + + def clear_shared_captured_layers() -> None: _shared_captured_layers.clear() + _shared_captured_arrays.clear() @dataclass @@ -29,7 +36,7 @@ class BatchConnectorMetadata(KVConnectorMetadata): # pyright: ignore[reportUnty pass -class BatchConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBaseClass] +class BatchConnector(KVConnectorBase_V1, SupportsHMA): # pyright: ignore[reportUntypedBaseClass] captured_layers: dict[int, dict[str, torch.Tensor]] def __init__(self, vllm_config: Any, role: KVConnectorRole, kv_cache_config: Any = None) -> None: # type: ignore @@ -53,6 +60,9 @@ class BatchConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBaseCl layer_idx = int(m.group(1)) if isinstance(kv_layer, (list, tuple)): + from exo.disaggregated.streaming_connector import _to_bf16 + + _shared_captured_arrays[layer_idx] = [_to_bf16(t).cpu() for t in kv_layer] # pyright: ignore[reportAny] return if slot_mapping is not None: @@ -88,6 +98,9 @@ class BatchConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBaseCl def wait_for_save(self) -> None: pass + def request_finished_all_groups(self, request: Any, block_ids: tuple[list[int], ...]) -> tuple[bool, dict[str, Any] | None]: # pyright: ignore[reportAny] + return False, None + def get_num_new_matched_tokens(self, request: Any, num_computed_tokens: int) -> tuple[int, bool]: # pyright: ignore[reportAny] return 0, False diff --git a/src/exo/disaggregated/prefill_server.py b/src/exo/disaggregated/prefill_server.py index d1c857d2e..9d5ab5e2b 100644 --- a/src/exo/disaggregated/prefill_server.py +++ b/src/exo/disaggregated/prefill_server.py @@ -34,6 +34,7 @@ _gdn_patched: bool = False _gdn_states: dict[int, dict[str, torch.Tensor]] = {} _gdn_layer_order: list[int] = [] _gdn_call_idx: list[int] = [0] +_ssm_call_idx: list[int] = [0] def _patch_vllm_for_connector(connector_class: type[Any]) -> None: # pyright: ignore[reportUnusedFunction] @@ -116,6 +117,46 @@ def _patch_gdn_capture() -> None: mod.causal_conv1d_fn = patched_fn logger.info("Patched causal_conv1d_fn for GDN state capture") + try: + from vllm.model_executor.models import qwen3_next as qn_mod # type: ignore + + orig_chunk = getattr(qn_mod, "fi_chunk_gated_delta_rule", None) # type: ignore + if orig_chunk is None: + return + + def patched_chunk(*args: Any, **kwargs: Any) -> Any: + result = orig_chunk(*args, **kwargs) + output_final_state = kwargs.get("output_final_state", False) + if output_final_state and isinstance(result, tuple) and len(result) == 2: + _, ssm_state = result + idx = _ssm_call_idx[0] + if _gdn_layer_order and idx < len(_gdn_layer_order) * 100: + layer_idx = _gdn_layer_order[idx % len(_gdn_layer_order)] + _gdn_states.setdefault(layer_idx, {})["ssm"] = ssm_state.cpu() # type: ignore + _ssm_call_idx[0] += 1 + return result + + qn_mod.fi_chunk_gated_delta_rule = patched_chunk # type: ignore + + orig_fla_chunk = getattr(qn_mod, "fla_chunk_gated_delta_rule", None) # type: ignore + if orig_fla_chunk is not None: + def patched_fla_chunk(*args: Any, **kwargs: Any) -> Any: + result = orig_fla_chunk(*args, **kwargs) + output_final_state = kwargs.get("output_final_state", False) + if output_final_state and isinstance(result, tuple) and len(result) == 2: + _, ssm_state = result + idx = _ssm_call_idx[0] + if _gdn_layer_order and idx < len(_gdn_layer_order) * 100: + layer_idx = _gdn_layer_order[idx % len(_gdn_layer_order)] + _gdn_states.setdefault(layer_idx, {})["ssm"] = ssm_state.cpu() # type: ignore + _ssm_call_idx[0] += 1 + return result + qn_mod.fla_chunk_gated_delta_rule = patched_fla_chunk # type: ignore + + logger.info("Patched chunk_gated_delta_rule for SSM state capture") + except ImportError: + pass + def _init_gdn_layer_order() -> None: from exo.worker.engines.vllm.growable_cache import get_model_runner @@ -163,6 +204,7 @@ def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], start_pos: assert model_runner is not None from exo.disaggregated.streaming_connector import ( + get_shared_arrays_queue, get_shared_queue, reset_shared_queue, ) @@ -170,16 +212,49 @@ def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], start_pos: reset_shared_queue() _gdn_states.clear() _gdn_call_idx[0] = 0 + _ssm_call_idx[0] = 0 layer_queue = get_shared_queue() + arrays_queue = get_shared_arrays_queue() server_cached = 0 + cached_data: TorchKVCache | None = None if _prefix_cache_ref is not None: - _, server_cached, _ = _prefix_cache_ref.lookup(token_ids) + cached_data, server_cached, _ = _prefix_cache_ref.lookup(token_ids) + if not isinstance(cached_data, TorchKVCache): + cached_data = None + server_cached = 0 skip_tokens = max(0, start_pos - server_cached) num_layers, dtype_str, layers_info = _get_layer_info(engine) write_header(wfile, {"num_layers": num_layers, "dtype": dtype_str, "layers": layers_info}) # pyright: ignore[reportAny] + if cached_data is not None and start_pos < server_cached: + from exo.worker.engines.vllm.kv_cache import ArraysLayerState + kv_sent = 0 + arr_sent = 0 + for i, layer in enumerate(cached_data.layers): + if isinstance(layer, KVLayerState) and layer.keys.numel() > 0: + keys = layer.keys + values = layer.values + if keys.shape != values.shape: + logger.warning(f"Skipping layer {i}: keys={list(keys.shape)} != values={list(values.shape)}") + continue + if keys.dim() == 4: + keys = keys.reshape(-1, keys.shape[-2], keys.shape[-1]) + values = values.reshape(-1, values.shape[-2], values.shape[-1]) + keys = keys[start_pos:server_cached] + values = values[start_pos:server_cached] + if keys.numel() > 0: + write_kv_chunk(wfile, i, keys, values) # pyright: ignore[reportAny] + kv_sent += 1 + + elif isinstance(layer, ArraysLayerState): + arrays = [a for a in layer.arrays if a is not None] + if arrays: + write_arrays_state(wfile, i, arrays) # pyright: ignore[reportAny] + arr_sent += 1 + logger.info(f"Sent cached: {kv_sent} KV, {arr_sent} arrays for positions {start_pos}-{server_cached}") + from vllm.sampling_params import ( SamplingParams, ) @@ -191,6 +266,7 @@ def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], start_pos: chunks_sent = [0] layer_token_counts: dict[int, int] = {} + all_kv_chunks: list[tuple[int, torch.Tensor, torch.Tensor]] = [] def writer_loop() -> None: while True: @@ -198,6 +274,7 @@ def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], start_pos: if item is None: break layer_idx, keys, values = item + all_kv_chunks.append((layer_idx, keys, values)) prev = layer_token_counts.get(layer_idx, 0) n = keys.shape[0] @@ -218,12 +295,10 @@ def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], start_pos: writer_thread = threading.Thread(target=writer_loop, daemon=True) writer_thread.start() - extracted_cache: TorchKVCache | None = None while engine.has_unfinished_requests(): outputs = engine.step() for output in outputs: if output.request_id == request_id and output.outputs[0].token_ids: - extracted_cache = _extract_vllm_cache(engine, request_id, len(prefill_token_ids)) engine.abort_request([request_id]) # type: ignore break else: @@ -233,15 +308,33 @@ def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], start_pos: layer_queue.put(None) writer_thread.join() actual_per_layer = max(layer_token_counts.values()) if layer_token_counts else 0 - tokens_sent = max(0, actual_per_layer - skip_tokens) + cached_tokens_sent = max(0, server_cached - start_pos) if cached_data is not None and start_pos < server_cached else 0 + tokens_sent = cached_tokens_sent + max(0, actual_per_layer - skip_tokens) logger.info(f"Overlapping prefill: sent {chunks_sent[0]} chunks, {tokens_sent} tokens (server_cached={server_cached}, skip={skip_tokens})") + while not arrays_queue.empty(): + item = arrays_queue.get_nowait() + if item is not None: + layer_idx, arrays = item + write_arrays_state(wfile, layer_idx, arrays) # pyright: ignore[reportAny] + + gdn_snapshot: list[tuple[int, list[torch.Tensor]]] = [] + for layer_idx in sorted(_gdn_states.keys()): + state = _gdn_states[layer_idx] + arrs: list[torch.Tensor] = [] + if "conv" in state: + arrs.append(state["conv"]) + if "ssm" in state: + arrs.append(state["ssm"]) + if arrs: + gdn_snapshot.append((layer_idx, arrs)) + cached_arrays: list[tuple[int, list[torch.Tensor]]] = [] _stream_gdn_states_and_collect(engine, wfile, num_layers, layers_info, cached_arrays) write_done(wfile, tokens_sent) # pyright: ignore[reportAny] - if extracted_cache is not None: - threading.Thread(target=_store_prefix_cache, args=(prefill_token_ids, extracted_cache), daemon=True).start() + connector_cache = _build_torch_cache(all_kv_chunks, gdn_snapshot, num_layers) + threading.Thread(target=_store_prefix_cache, args=(prefill_token_ids, connector_cache), daemon=True).start() def _run_prefill_batch(engine: LLMEngine, token_ids: list[int], start_pos: int, wfile: Any) -> None: # pyright: ignore[reportAny] @@ -254,6 +347,7 @@ def _run_prefill_batch(engine: LLMEngine, token_ids: list[int], start_pos: int, from exo.disaggregated.batch_connector import ( clear_shared_captured_layers, + get_shared_captured_arrays, get_shared_captured_layers, ) @@ -261,6 +355,7 @@ def _run_prefill_batch(engine: LLMEngine, token_ids: list[int], start_pos: int, _gdn_call_idx[0] = 0 clear_shared_captured_layers() captured_layers = get_shared_captured_layers() + captured_arrays = get_shared_captured_arrays() server_cached = 0 if _prefix_cache_ref is not None: @@ -276,12 +371,10 @@ def _run_prefill_batch(engine: LLMEngine, token_ids: list[int], start_pos: int, params = SamplingParams(max_tokens=2, detokenize=False) # pyright: ignore[reportCallIssue] engine.add_request(request_id, {"prompt_token_ids": prefill_token_ids}, params) # pyright: ignore[reportArgumentType] - extracted_cache: TorchKVCache | None = None while engine.has_unfinished_requests(): outputs = engine.step() for output in outputs: if output.request_id == request_id and output.outputs[0].token_ids: - extracted_cache = _extract_vllm_cache(engine, request_id, len(prefill_token_ids)) engine.abort_request([request_id]) # type: ignore break else: @@ -298,18 +391,22 @@ def _run_prefill_batch(engine: LLMEngine, token_ids: list[int], start_pos: int, all_kv.append((layer_idx, keys, values)) if keys.shape[0] > skip_tokens: write_kv_chunk(wfile, layer_idx, keys[skip_tokens:], values[skip_tokens:]) # pyright: ignore[reportAny] - clear_shared_captured_layers() actual_per_layer = max((k.shape[0] for _, k, _ in all_kv), default=0) tokens_sent = max(0, actual_per_layer - skip_tokens) logger.info(f"Batch prefill: {len(all_kv)} layers, {tokens_sent} tokens sent (server_cached={server_cached}, skip={skip_tokens}, captured={actual_per_layer})") + batch_arrays: list[tuple[int, list[torch.Tensor]]] = list(captured_arrays.items()) + for layer_idx, arrs in batch_arrays: + write_arrays_state(wfile, layer_idx, arrs) # pyright: ignore[reportAny] + clear_shared_captured_layers() + cached_arrays: list[tuple[int, list[torch.Tensor]]] = [] _stream_gdn_states_and_collect(engine, wfile, num_layers, layers_info, cached_arrays) write_done(wfile, tokens_sent) # pyright: ignore[reportAny] - if extracted_cache is not None: - threading.Thread(target=_store_prefix_cache, args=(prefill_token_ids, extracted_cache), daemon=True).start() + connector_cache = _build_torch_cache(all_kv, batch_arrays, num_layers) + threading.Thread(target=_store_prefix_cache, args=(prefill_token_ids, connector_cache), daemon=True).start() def _stream_gdn_states_and_collect( @@ -334,18 +431,14 @@ def _stream_gdn_states_and_collect( for layer_idx in sorted(_gdn_states.keys()): try: state = _gdn_states[layer_idx] - ci: int = state.get("ci", 0) # type: ignore conv = state.get("conv") - kv = kv_caches[layer_idx] # type: ignore - rec: torch.Tensor | None = None - if isinstance(kv, (list, tuple)) and len(kv) > 1: - rec = kv[1][ci : ci + 1].cpu().clone() # type: ignore + ssm = state.get("ssm") arrays: list[torch.Tensor] = [] if conv is not None: - arrays.append(conv.to(torch.bfloat16)) - if rec is not None: - arrays.append(rec.to(torch.bfloat16)) + arrays.append(conv) + if ssm is not None: + arrays.append(ssm) if arrays: write_arrays_state(wfile, layer_idx, arrays) # type: ignore out_arrays.append((layer_idx, arrays)) @@ -408,13 +501,16 @@ def _extract_vllm_cache(engine: LLMEngine, request_id: str, num_tokens: int) -> null_block = coordinator.block_pool.null_block # type: ignore block_ids_per_group: list[list[int]] = [] token_offset_per_group: list[int] = [] + block_sizes_per_group: list[int] = [] for mgr in coordinator.single_type_managers: # type: ignore blocks = mgr.req_to_blocks.get(internal_id) # type: ignore if not blocks: block_ids_per_group.append([]) token_offset_per_group.append(0) + block_sizes_per_group.append(0) continue block_size: int = mgr.block_size # type: ignore + block_sizes_per_group.append(block_size) num_leading_nulls = 0 for b in blocks: # type: ignore if b is null_block or b.is_null: # type: ignore @@ -432,6 +528,7 @@ def _extract_vllm_cache(engine: LLMEngine, request_id: str, num_tokens: int) -> layer_to_group, num_tokens, token_offset_per_group, + block_sizes_per_group, ) except Exception: logger.opt(exception=True).warning("Failed to extract vLLM cache") @@ -482,13 +579,18 @@ def _send_cached(torch_cache: TorchKVCache, token_ids: list[int], wfile: Any, en write_header(wfile, {"num_layers": num_layers, "dtype": dtype_str, "layers": layers_info}) # type: ignore from exo.worker.engines.vllm.kv_cache import ArraysLayerState + kv_sent = 0 + arr_sent = 0 for i, layer in enumerate(torch_cache.layers): if isinstance(layer, KVLayerState) and layer.keys.numel() > 0: write_kv_chunk(wfile, i, layer.keys, layer.values) # type: ignore + kv_sent += 1 elif isinstance(layer, ArraysLayerState): arrays = [a for a in layer.arrays if a is not None] if arrays: write_arrays_state(wfile, i, arrays) # type: ignore + arr_sent += 1 + logger.info(f"_send_cached: sent {kv_sent} KV layers, {arr_sent} arrays layers") write_done(wfile, len(token_ids)) # type: ignore diff --git a/src/exo/disaggregated/protocol.py b/src/exo/disaggregated/protocol.py index bfa47acf4..ca24ff2c0 100644 --- a/src/exo/disaggregated/protocol.py +++ b/src/exo/disaggregated/protocol.py @@ -77,6 +77,9 @@ def _tensor_to_bytes(t: torch.Tensor) -> bytes: def write_kv_chunk(stream: BinaryIO, layer_idx: int, keys: torch.Tensor, values: torch.Tensor) -> None: + if keys.dim() == 4: + keys = keys.reshape(-1, keys.shape[-2], keys.shape[-1]) + values = values.reshape(-1, values.shape[-2], values.shape[-1]) keys_bytes = _tensor_to_bytes(keys) values_bytes = _tensor_to_bytes(values) num_tokens: int = keys.shape[0] @@ -86,11 +89,18 @@ def write_kv_chunk(stream: BinaryIO, layer_idx: int, keys: torch.Tensor, values: _write_exactly(stream, header + keys_bytes + values_bytes) +def _dtype_to_str(dtype: torch.dtype) -> str: + return {torch.float16: "float16", torch.bfloat16: "bfloat16", torch.float32: "float32"}[dtype] + + def write_arrays_state(stream: BinaryIO, layer_idx: int, arrays: list[torch.Tensor]) -> None: buf = io.BytesIO() buf.write(struct.pack(">BI", MSG_ARRAYS_STATE, layer_idx)) buf.write(struct.pack(">I", len(arrays))) for arr in arrays: + dtype_str = _dtype_to_str(arr.dtype).encode("utf-8") + buf.write(struct.pack(">I", len(dtype_str))) + buf.write(dtype_str) shape: tuple[int, ...] = tuple(arr.shape) buf.write(struct.pack(">I", len(shape))) for dim in shape: @@ -143,10 +153,17 @@ def read_message(stream: BinaryIO, header: dict[str, object]) -> Message | None: num_arrays: int arr_layer_idx, = struct.unpack(">I", _read_exactly(stream, 4)) # pyright: ignore[reportAny] num_arrays, = struct.unpack(">I", _read_exactly(stream, 4)) # pyright: ignore[reportAny] - dtype = _str_to_dtype(str(header["dtype"])) - elem_size = _dtype_size(dtype) + fallback_dtype = _str_to_dtype(str(header["dtype"])) arrays: list[torch.Tensor] = [] for _ in range(num_arrays): + dtype_len_raw = _read_exactly(stream, 4) + dtype_len: int = struct.unpack(">I", dtype_len_raw)[0] # pyright: ignore[reportAny] + if dtype_len > 0 and dtype_len < 20: + dtype_str_bytes = _read_exactly(stream, dtype_len) + arr_dtype = _str_to_dtype(dtype_str_bytes.decode("utf-8")) + else: + arr_dtype = fallback_dtype + elem_size = _dtype_size(arr_dtype) ndim: int ndim, = struct.unpack(">I", _read_exactly(stream, 4)) # pyright: ignore[reportAny] shape_arr = struct.unpack(f">{ndim}I", _read_exactly(stream, ndim * 4)) @@ -154,10 +171,10 @@ def read_message(stream: BinaryIO, header: dict[str, object]) -> Message | None: for d in shape_arr: # pyright: ignore[reportAny] total_elems *= d # pyright: ignore[reportAny] raw = _read_exactly(stream, total_elems * elem_size) - if dtype == torch.bfloat16: + if arr_dtype == torch.bfloat16: t: torch.Tensor = torch.frombuffer(bytearray(raw), dtype=torch.int16).view(torch.bfloat16).reshape(shape_arr).clone() # type: ignore else: - t = torch.frombuffer(bytearray(raw), dtype=dtype).reshape(shape_arr).clone() # type: ignore + t = torch.frombuffer(bytearray(raw), dtype=arr_dtype).reshape(shape_arr).clone() # type: ignore arrays.append(t) # pyright: ignore[reportUnknownArgumentType] return ArraysState(layer_idx=arr_layer_idx, arrays=arrays) diff --git a/src/exo/disaggregated/streaming_connector.py b/src/exo/disaggregated/streaming_connector.py index a3c8b1644..85ec4c640 100644 --- a/src/exo/disaggregated/streaming_connector.py +++ b/src/exo/disaggregated/streaming_connector.py @@ -10,6 +10,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( # pyright: igno KVConnectorBase_V1, # pyright: ignore[reportUnknownVariableType] KVConnectorMetadata, # pyright: ignore[reportUnknownVariableType] KVConnectorRole, # pyright: ignore[reportUnknownVariableType] + SupportsHMA, # pyright: ignore[reportUnknownVariableType] ) _LAYER_RE = re.compile(r"layers\.(\d+)\.") @@ -25,18 +26,28 @@ def _to_bf16(t: torch.Tensor) -> torch.Tensor: return t.to(torch.bfloat16) _shared_queue: queue.Queue[tuple[int, torch.Tensor, torch.Tensor] | None] = queue.Queue() +_shared_arrays_queue: queue.Queue[tuple[int, list[torch.Tensor]] | None] = queue.Queue() def get_shared_queue() -> queue.Queue[tuple[int, torch.Tensor, torch.Tensor] | None]: return _shared_queue +def get_shared_arrays_queue() -> queue.Queue[tuple[int, list[torch.Tensor]] | None]: + return _shared_arrays_queue + + def reset_shared_queue() -> None: while not _shared_queue.empty(): try: _shared_queue.get_nowait() except queue.Empty: break + while not _shared_arrays_queue.empty(): + try: + _shared_arrays_queue.get_nowait() + except queue.Empty: + break @dataclass @@ -44,7 +55,7 @@ class StreamingConnectorMetadata(KVConnectorMetadata): # pyright: ignore[report pass -class StreamingConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBaseClass] +class StreamingConnector(KVConnectorBase_V1, SupportsHMA): # pyright: ignore[reportUntypedBaseClass] _queue: queue.Queue[tuple[int, torch.Tensor, torch.Tensor] | None] _save_count: int = 0 @@ -74,6 +85,8 @@ class StreamingConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBa layer_idx = int(m.group(1)) if isinstance(kv_layer, (list, tuple)): + arrays = [_to_bf16(t).cpu() for t in kv_layer] # pyright: ignore[reportAny] + _shared_arrays_queue.put((layer_idx, arrays)) return if self._save_count < 1: @@ -104,6 +117,9 @@ class StreamingConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBa def finish(self) -> None: self._queue.put(None) + def request_finished_all_groups(self, request: Any, block_ids: tuple[list[int], ...]) -> tuple[bool, dict[str, Any] | None]: # pyright: ignore[reportAny] + return False, None + def get_num_new_matched_tokens(self, request: Any, num_computed_tokens: int) -> tuple[int, bool]: # pyright: ignore[reportAny] return 0, False diff --git a/src/exo/worker/engines/mlx/gdn_softplus_patch.py b/src/exo/worker/engines/mlx/gdn_softplus_patch.py new file mode 100644 index 000000000..3bec0169b --- /dev/null +++ b/src/exo/worker/engines/mlx/gdn_softplus_patch.py @@ -0,0 +1,58 @@ +"""Patch mlx_lm's GDN gated_delta_update to match vLLM's float32 precision. + +vLLM computes both softplus (gating) and sigmoid (beta) in float32. +mlx_lm computes them in bfloat16. The precision difference compounds +through the SSM recurrence over thousands of tokens. +""" + +import sys +from functools import partial +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + + +@partial(mx.compile, shapeless=True) +def _compute_g_f32(A_log: mx.array, a: mx.array, dt_bias: mx.array) -> mx.array: + return mx.exp( + -mx.exp(A_log.astype(mx.float32)) + * nn.softplus((a + dt_bias).astype(mx.float32)) + ) + + +def patch_gdn_softplus() -> None: + from mlx_lm.models import gated_delta + + orig_update = gated_delta.gated_delta_update + orig_ops = gated_delta.gated_delta_ops + orig_kernel = gated_delta.gated_delta_kernel + + def patched_gated_delta_update( + q: mx.array, + k: mx.array, + v: mx.array, + a: mx.array, + b: mx.array, + A_log: mx.array, + dt_bias: mx.array, + state: Optional[mx.array] = None, + mask: Optional[mx.array] = None, + use_kernel: bool = True, + ) -> Tuple[mx.array, mx.array]: + beta = mx.sigmoid(b.astype(mx.float32)).astype(b.dtype) + g = _compute_g_f32(A_log, a, dt_bias) + if state is None: + B, _, Hk, Dk = q.shape + Hv, Dv = v.shape[-2:] + state = mx.zeros((B, Hv, Dv, Dk), dtype=q.dtype) + + return orig_ops(q, k, v, g, beta, state, mask) + + gated_delta.gated_delta_update = patched_gated_delta_update + + for mod in list(sys.modules.values()): + if mod is None or mod is gated_delta: + continue + if getattr(mod, "gated_delta_update", None) is orig_update: + mod.gated_delta_update = patched_gated_delta_update diff --git a/src/exo/worker/engines/vllm/growable_cache.py b/src/exo/worker/engines/vllm/growable_cache.py index 716520d52..bbfc069ba 100644 --- a/src/exo/worker/engines/vllm/growable_cache.py +++ b/src/exo/worker/engines/vllm/growable_cache.py @@ -56,17 +56,18 @@ def _patch_determine_available_memory() -> None: @torch.inference_mode() def patched(self: "Worker") -> int: + real_empty_cache = torch.cuda.empty_cache + torch.cuda.empty_cache = lambda: None # type: ignore try: original(self) - except AssertionError: - logger.warning( - "vLLM memory profiling assertion failed (free memory changed during init, " - "likely another process released GPU memory). Continuing with growable cache." - ) - torch.cuda.empty_cache() + except (AssertionError, Exception): + pass + finally: + torch.cuda.empty_cache = real_empty_cache # type: ignore free_bytes, _ = torch.cuda.mem_get_info() initial = max(int(free_bytes * INITIAL_FRACTION), 1) self._growable_max_kv_bytes = free_bytes + self.available_kv_cache_memory_bytes = initial logger.info( f"Growable KV cache: initial {initial / (1024**3):.2f} GiB " f"(max {free_bytes / (1024**3):.2f} GiB)" @@ -164,12 +165,10 @@ def _try_grow_cache(kv_cache_manager: "object") -> bool: model_runner = kv_cache_manager._growable_model_runner # type: ignore if model_runner is None: - logger.debug("No model_runner reference — cannot grow cache") return False free_bytes, _ = torch.cuda.mem_get_info() if free_bytes < GROWTH_HEADROOM_BYTES: - logger.debug(f"Only {free_bytes / (1024**3):.2f} GiB free — not enough to grow") return False kv_cache_config = model_runner._growable_kv_cache_config # type: ignore @@ -182,7 +181,6 @@ def _try_grow_cache(kv_cache_manager: "object") -> bool: growth_blocks = min(usable_bytes // per_block_bytes, old_num_blocks) if growth_blocks < MIN_GROWTH_BLOCKS: - logger.debug(f"Growth too small ({growth_blocks} blocks)") return False new_num_blocks = old_num_blocks + growth_blocks @@ -193,11 +191,11 @@ def _try_grow_cache(kv_cache_manager: "object") -> bool: ) try: - _grow_tensors(model_runner, kv_cache_config, old_num_blocks, new_num_blocks) - _grow_block_pool(block_pool, old_num_blocks, new_num_blocks) kv_cache_config.num_blocks = new_num_blocks for tensor_spec in kv_cache_config.kv_cache_tensors: tensor_spec.size = int(tensor_spec.size * new_num_blocks / old_num_blocks) + _grow_tensors(model_runner, kv_cache_config, old_num_blocks, new_num_blocks) + _grow_block_pool(block_pool, old_num_blocks, new_num_blocks) logger.info(f"KV cache grown successfully to {new_num_blocks} blocks") return True except Exception: @@ -205,6 +203,8 @@ def _try_grow_cache(kv_cache_manager: "object") -> bool: return False + + def _grow_tensors( model_runner: "object", kv_cache_config: "object", diff --git a/src/exo/worker/engines/vllm/kv_cache.py b/src/exo/worker/engines/vllm/kv_cache.py index 728ef4f5b..77b9e8259 100644 --- a/src/exo/worker/engines/vllm/kv_cache.py +++ b/src/exo/worker/engines/vllm/kv_cache.py @@ -228,6 +228,7 @@ class TorchKVCache: layer_to_group: list[int], num_tokens: int, token_offset_per_group: list[int] | None = None, + block_sizes_per_group: list[int] | None = None, ) -> "TorchKVCache": block_tables = [ torch.tensor(ids, dtype=torch.long) for ids in block_ids_per_group @@ -245,6 +246,18 @@ class TorchKVCache: layers.append(KVLayerState(keys=torch.empty(0), values=torch.empty(0))) continue + if k_all.dim() >= 4 and len(bt) > 0 and block_sizes_per_group is not None: + page_size = k_all.shape[1] + sched_block_size = block_sizes_per_group[gi] + pages_per_block = sched_block_size // page_size + if pages_per_block > 1: + expanded = [] + for b in bt.tolist(): + start_page = b * pages_per_block + end_page = min(start_page + pages_per_block, k_all.shape[0]) + expanded.extend(range(start_page, end_page)) + bt = torch.tensor(expanded, dtype=torch.long) + keys = k_all[bt].to("cpu", non_blocking=True) values = v_all[bt].to("cpu", non_blocking=True) torch.cuda.synchronize() @@ -264,8 +277,18 @@ class TorchKVCache: first = kv_caches[0] device = first[0].device if isinstance(first, list) else first.device - block_size = first[0].shape[1] if isinstance(first, list) else first.shape[-3] for layer_idx, layer in enumerate(self.layers): + if isinstance(layer, ArraysLayerState): + gi = layer_to_group[layer_idx] + bt = block_tables[gi] + kv = kv_caches[layer_idx] + if isinstance(kv, list): + for ti, (stored, target) in enumerate(zip(layer.arrays, kv)): + if stored is not None and target is not None: + n = min(len(bt), stored.shape[0]) + if n > 0: + target[bt[:n]] = stored[:n].to(device, non_blocking=True) + continue if not isinstance(layer, KVLayerState): continue gi = layer_to_group[layer_idx] @@ -275,6 +298,7 @@ class TorchKVCache: keys = layer.keys values = layer.values + block_size = k_all.shape[-3] if k_all.dim() >= 3 else k_all.shape[1] needs_reshape = keys.dim() == 3 and keys.shape[1:] != k_all.shape[1:] if needs_reshape: offset = token_offset_per_group[gi] if token_offset_per_group else 0 diff --git a/src/exo/worker/engines/vllm/vllm_generator.py b/src/exo/worker/engines/vllm/vllm_generator.py index e85471fa8..501dd5557 100644 --- a/src/exo/worker/engines/vllm/vllm_generator.py +++ b/src/exo/worker/engines/vllm/vllm_generator.py @@ -578,8 +578,21 @@ def load_vllm_engine( "kv_role": "kv_both", } + import json + from pathlib import Path + is_nvfp4 = "nvfp4" in model_path.lower() or "nvfp4" in str(model_id).lower() - backends = ["FLASHINFER", "TRITON_ATTN"] if is_nvfp4 else ["FLASH_ATTN", "TRITON_ATTN"] + has_mamba = False + config_path = Path(model_path) / "config.json" + if config_path.exists(): + with open(config_path) as f: + model_config = json.load(f) + text_config = model_config.get("text_config", model_config) + has_mamba = "mamba_ssm_dtype" in text_config or "linear_attention" in (text_config.get("layer_types") or []) + if is_nvfp4 and not has_mamba: + backends = ["FLASHINFER", "FLASH_ATTN", "TRITON_ATTN"] + else: + backends = ["FLASH_ATTN", "TRITON_ATTN"] engine: LLMEngine | None = None for backend in backends: @@ -592,10 +605,11 @@ def load_vllm_engine( load_format="fastsafetensors", enable_prefix_caching=False, attention_backend=backend, - enforce_eager=True, + compilation_config={"cudagraph_mode": "none"}, disable_log_stats=True, max_num_batched_tokens=4096, kv_transfer_config=kv_transfer_config, # type: ignore + disable_hybrid_kv_cache_manager=False, ) set_weight_loading_callback(on_layer_loaded) diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py index db168dbca..e5e608de4 100644 --- a/src/exo/worker/runner/bootstrap.py +++ b/src/exo/worker/runner/bootstrap.py @@ -69,6 +69,10 @@ def entrypoint( patch_yarn_rope() + from exo.worker.engines.mlx.gdn_softplus_patch import patch_gdn_softplus + + patch_gdn_softplus() + # Import main after setting global logger - this lets us just import logger from this module try: if isinstance(bound_instance.instance, VllmInstance): diff --git a/test_growable_compile.py b/test_growable_compile.py new file mode 100644 index 000000000..5d9ae1fc1 --- /dev/null +++ b/test_growable_compile.py @@ -0,0 +1,98 @@ +"""Test hybrid prefix cache: _extract_vllm_cache for attn + captured SSM for mamba.""" +import os, time +os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" +os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD" +from exo.worker.engines.vllm.growable_cache import patch_vllm, set_prefix_cache, get_model_runner +patch_vllm() +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +KVConnectorFactory.register_connector("StreamingConnector", "exo.disaggregated.streaming_connector", "StreamingConnector") +from vllm.engine.arg_utils import EngineArgs +from vllm.sampling_params import SamplingParams +from vllm.v1.engine.llm_engine import LLMEngine + +MODEL = os.path.expanduser("~/.local/share/exo/models/Sehyo--Qwen3.5-35B-A3B-NVFP4") +GEN = 600 +ea = EngineArgs(model=MODEL, served_model_name="test", gpu_memory_utilization=0.05, trust_remote_code=False, + load_format="fastsafetensors", enable_prefix_caching=True, attention_backend="FLASH_ATTN", + compilation_config={"cudagraph_mode": "none"}, disable_log_stats=True, max_num_batched_tokens=4096, + kv_transfer_config={"kv_connector": "StreamingConnector", "kv_role": "kv_both"}, + disable_hybrid_kv_cache_manager=False) +engine = LLMEngine.from_engine_args(ea) +tok = engine.get_tokenizer() +from exo.worker.engines.mlx.cache import KVPrefixCache +pc = KVPrefixCache(group=None) +set_prefix_cache(pc) + +from exo.disaggregated.prefill_server import ( + _patch_gdn_capture, _init_gdn_layer_order, _gdn_states, _gdn_call_idx, _ssm_call_idx, + _extract_vllm_cache, +) +from exo.disaggregated.streaming_connector import reset_shared_queue +_patch_gdn_capture() +_init_gdn_layer_order() +print(f"Engine loaded") + +article = ("The European Union announced sweeping new regulations on artificial intelligence. " * 500) +tids = tok.encode(article)[:22000] +msgs = [{"role": "user", "content": tok.decode(tids) + "\nSummarize the key points of this article."}] +tids = tok.encode(tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)) +ptids = tids[:-2] +print(f"Prompt: {len(ptids)} tokens") + +reset_shared_queue() +_gdn_states.clear() +_gdn_call_idx[0] = 0 +_ssm_call_idx[0] = 0 + +engine.add_request("r1", {"prompt_token_ids": ptids}, SamplingParams(max_tokens=2, temperature=0.7)) +done = False +tc = None +while engine.has_unfinished_requests() and not done: + for out in engine.step(): + if out.outputs and out.outputs[0].token_ids: + tc = _extract_vllm_cache(engine, "r1", len(ptids)) + engine.abort_request(["r1"]) + done = True; break +print(f"Extracted: {tc.num_layers if tc else 'NONE'} layers") + +if tc and _gdn_states: + from exo.worker.engines.vllm.kv_cache import ArraysLayerState, KVLayerState + replaced = 0 + for layer_idx in sorted(_gdn_states.keys()): + state = _gdn_states[layer_idx] + arrays = [] + if "conv" in state: arrays.append(state["conv"]) + if "ssm" in state: arrays.append(state["ssm"]) + if arrays and layer_idx < len(tc.layers): + tc.layers[layer_idx] = ArraysLayerState(arrays=arrays) + replaced += 1 + print(f"Replaced {replaced} GDN layers with clean prefill state") + kv_c = sum(1 for l in tc.layers if isinstance(l, KVLayerState) and l.keys.numel() > 0) + arr_c = sum(1 for l in tc.layers if isinstance(l, ArraysLayerState)) + print(f"Final cache: {kv_c} KV layers, {arr_c} Arrays layers") + +import mlx.core as mx +pc.add_kv_cache(mx.array(ptids), tc, None) +print("Stored hybrid cache") + +engine.add_request("r2", {"prompt_token_ids": ptids}, SamplingParams(max_tokens=GEN, temperature=0.7)) +t2 = time.perf_counter() +prev = 0; text2 = ""; done2 = False +while engine.has_unfinished_requests() and not done2: + for out in engine.step(): + if out.outputs: + prev = len(out.outputs[0].token_ids) + if out.outputs[0].text: text2 = out.outputs[0].text + if out.finished: done2 = True; break +e2 = time.perf_counter() - t2 +print(f"\nRequest 2: {prev} tokens in {e2:.1f}s ({prev/max(e2,0.01):.1f} tok/s)") +print(f"Output: {text2[:500]}") + +keywords = ["regulation", "AI", "high-risk", "compliance", "transparency", "ban", "EU", "framework"] +hits = sum(1 for kw in keywords if kw.lower() in text2.lower()) +print(f"\nKeyword hits: {hits}/{len(keywords)}") +if hits >= 2: + print("PASS") +else: + print(f"FAIL ({hits} hits)") + exit(1) diff --git a/tmp/spark/spark-build-apple-cdc-ncm.sh b/tmp/spark/spark-build-apple-cdc-ncm.sh old mode 100644 new mode 100755