mirror of
https://github.com/exo-explore/exo.git
synced 2026-05-24 14:45:41 -04:00
Performance optimizations
This commit is contained in:
24
scripts/check_kv_compat.py
Normal file
24
scripts/check_kv_compat.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import sys
|
||||
sys.path.insert(0, "src")
|
||||
import mlx.core as mx
|
||||
from mlx_lm import load
|
||||
from mlx_lm.models.cache import RotatingKVCache, KVCache
|
||||
|
||||
# Prefill a long repeated prompt through the model, then print the key-cache
# shape, offset and a few sample values for each of the first six layers.
model, tok = load("mlx-community/gpt-oss-20b-MXFP4-Q8")

prompt = "Hello " * 2000
tokens = tok.encode(prompt)
print(f"Tokens: {len(tokens)}")

cache = model.make_cache()
token_arr = mx.array([tokens])
logits = model(token_arr, cache=cache)
# Force evaluation so the cache entries are populated before inspection.
mx.eval(logits)

for i, c in enumerate(cache[:6]):
    # Classify the entry first; `c.keys` is only read for the two cache
    # types this script knows how to report on.
    is_rotating = isinstance(c, RotatingKVCache)
    is_plain = isinstance(c, KVCache) and not is_rotating
    if not (is_plain or is_rotating):
        continue
    if c.keys is None:
        continue

    # Cast to float32 so individual elements can be printed as Python floats.
    k = c.keys.astype(mx.float32)
    if is_plain:
        print(f"Layer {i} KVCache: shape={c.keys.shape} offset={c.offset} first=[{float(k[0,0,0,0]):.6f}, {float(k[0,0,0,1]):.6f}] last=[{float(k[0,0,-1,-2]):.6f}, {float(k[0,0,-1,-1]):.6f}]")
    else:
        print(f"Layer {i} RotatingKV: shape={c.keys.shape} _idx={c._idx} offset={c.offset} first=[{float(k[0,0,0,0]):.6f}, {float(k[0,0,0,1]):.6f}]")
|
||||
12
scripts/check_rotating_cache.py
Normal file
12
scripts/check_rotating_cache.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import mlx.core as mx
|
||||
from mlx_lm import load
|
||||
from mlx_lm.models.cache import RotatingKVCache
|
||||
|
||||
# Run 5000 dummy tokens through the model and print the internal bookkeeping
# fields of any RotatingKVCache found among the first four cache entries.
model, tok = load("mlx-community/gpt-oss-20b-MXFP4-Q8")
cache = model.make_cache()

# A (1, 5000) batch of token id 1; only the cache state is of interest here,
# so the model output is discarded.
tokens = mx.ones((1, 5000), dtype=mx.int32)
model(tokens, cache=cache)

# Evaluate every populated key tensor before reading cache fields.
populated_keys = [c.keys for c in cache if c.keys is not None]
mx.eval(populated_keys)

for i, c in enumerate(cache[:4]):
    if not isinstance(c, RotatingKVCache):
        continue
    print(f"Layer {i}: _idx={c._idx} offset={c.offset} keep={c.keep} max_size={c.max_size} keys={c.keys.shape}")
|
||||
80
scripts/compare_kv.py
Normal file
80
scripts/compare_kv.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import sys
|
||||
sys.path.insert(0, "src")
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
import socket
|
||||
from pathlib import Path
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from mlx_lm import load
|
||||
from mlx_lm.models.cache import RotatingKVCache, KVCache
|
||||
from exo.disaggregated.protocol import read_header, read_message, KVChunk, Done
|
||||
from exo.disaggregated.prefill_client import _nhd_to_bhsd, _torch_to_mx
|
||||
|
||||
# Compare the key cache streamed back by a remote prefill endpoint (labelled
# "vLLM" in the output) with the key cache computed locally by the MLX model
# for the same prompt, for the first six layers.
#
# Usage: compare_kv.py [ENDPOINT host:port] [MODEL] [MODEL_PATH]
ENDPOINT = sys.argv[1] if len(sys.argv) > 1 else "10.43.0.1:62988"
MODEL = sys.argv[2] if len(sys.argv) > 2 else "mlx-community/Llama-3.2-1B-Instruct-bf16"
MODEL_PATH = sys.argv[3] if len(sys.argv) > 3 else None

# Without an explicit path, the model is loaded from ~/.exo/models/<MODEL with "/" replaced by "--">.
model, tok = load(MODEL_PATH or str(Path.home() / ".exo/models" / MODEL.replace("/", "--")))
prompt = "The quick brown fox jumps over the lazy dog. " * 300
tokens = tok.encode(prompt)
print(f"Tokens: {len(tokens)}")

host, port = ENDPOINT.rsplit(":", 1)
# The request is a single newline-terminated JSON object.
request = json.dumps({"model": MODEL, "token_ids": tokens, "start_pos": 0}).encode() + b"\n"

# layer index -> list of (keys, values) chunks in arrival order.
vllm_kv = defaultdict(list)

# Context managers guarantee the socket and its buffered reader are closed
# even if reading the header or a message raises.
with socket.create_connection((host, int(port)), timeout=60) as sock:
    sock.sendall(request)
    with sock.makefile("rb", buffering=65536) as stream:
        header = read_header(stream)
        while True:
            msg = read_message(stream, header)
            # NOTE(review): a None message is treated the same as Done;
            # presumably it signals end of stream -- confirm against protocol.
            if msg is None or isinstance(msg, Done):
                break
            if isinstance(msg, KVChunk):
                vllm_kv[msg.layer_idx].append((msg.keys, msg.values))

print(f"Received {len(vllm_kv)} layers from vLLM")

if hasattr(model, "make_cache"):
    mlx_cache = model.make_cache()
else:
    from mlx_lm.models.cache import make_prompt_cache

    mlx_cache = make_prompt_cache(model)

# NOTE(review): the local prefill drops the last two tokens; the reason is
# not visible in this script -- confirm it matches what the endpoint caches.
token_arr = mx.array([tokens[:-2]])
mlx_logits = model(token_arr, cache=mlx_cache)
mx.eval(mlx_logits)

for i in range(min(6, len(mlx_cache))):
    c = mlx_cache[i]
    if c.keys is None:
        continue
    mlx_k = c.keys.astype(mx.float32)

    if i not in vllm_kv:
        print(f"Layer {i}: no vLLM data")
        continue
    chunks = vllm_kv[i]
    vk = torch.cat([k for k, v in chunks], dim=0) if len(chunks) > 1 else chunks[0][0]
    # assumes each chunk is (seq, heads, dim); the permute + unsqueeze then
    # yields (1, heads, seq, dim) to line up with the MLX keys -- TODO confirm
    vk_mx = _torch_to_mx(vk.permute(1, 0, 2).unsqueeze(0)).astype(mx.float32)

    # Only the sequence positions present on both sides are compared.
    n = min(mlx_k.shape[2], vk_mx.shape[2])
    mlx_part = mlx_k[:, :, :n, :]
    vllm_part = vk_mx[:, :, :n, :]
    diff = mx.abs(mlx_part - vllm_part)
    max_diff = mx.max(diff).item()
    mean_diff = mx.mean(diff).item()
    cache_type = "RotatingKV" if isinstance(c, RotatingKVCache) else "KV"
    print(f"Layer {i} ({cache_type}): mlx={mlx_k.shape} vllm={vk_mx.shape} max_diff={max_diff:.6f} mean_diff={mean_diff:.6f}")

    a = mlx_part.reshape(-1)
    b_vec = vllm_part.reshape(-1)
    # The small epsilon avoids division by zero when either side is all zeros.
    cos_sim = float(mx.sum(a * b_vec).item()) / (float(mx.sqrt(mx.sum(a * a)).item()) * float(mx.sqrt(mx.sum(b_vec * b_vec)).item()) + 1e-8)

    # Locate the largest difference and convert its flat index back into
    # (head, position, dim) coordinates.
    flat_diff = diff.reshape(-1)
    num_heads = diff.shape[1]
    head_dim = diff.shape[3]
    max_idx = mx.argmax(flat_diff).item()
    h = (max_idx // (n * head_dim)) % num_heads
    s = (max_idx // head_dim) % n
    d = max_idx % head_dim
    print(f"  cosine_sim={cos_sim:.6f} max_diff={max_diff:.4f} at h={h} pos={s} dim={d}: mlx={float(mlx_k[0,h,s,d].item()):.6f} vllm={float(vk_mx[0,h,s,d].item()):.6f}")

    # Sort once and read both percentiles from the same sorted array.
    sorted_diff = mx.sort(flat_diff)
    p50 = float(sorted_diff[diff.size // 2].item())
    p99 = float(sorted_diff[int(diff.size * 0.99)].item())
    print(f"  mean_diff={mean_diff:.6f} p50={p50:.6f} p99={p99:.6f}")
|
||||
@@ -4,11 +4,9 @@ from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
|
||||
captured_layers: dict[str, Any] = {}
|
||||
|
||||
@@ -20,6 +20,7 @@ os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = "1"
|
||||
|
||||
from exo.worker.runner.bootstrap import _ensure_cuda_libs
|
||||
|
||||
_ensure_cuda_libs()
|
||||
|
||||
import torch
|
||||
@@ -44,8 +45,8 @@ def main():
|
||||
parser.add_argument("--prompt", default="Hello, world! How are you today?", help="Prompt to prefill")
|
||||
args = parser.parse_args()
|
||||
|
||||
from exo.worker.engines.vllm.vllm_generator import load_vllm_engine
|
||||
from exo.worker.engines.vllm.growable_cache import get_model_runner
|
||||
from exo.worker.engines.vllm.vllm_generator import load_vllm_engine
|
||||
|
||||
print(f"Loading vLLM engine from {args.model}...")
|
||||
engine, _, prefix_cache = load_vllm_engine(
|
||||
|
||||
@@ -21,6 +21,7 @@ os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = "1"
|
||||
|
||||
from exo.worker.runner.bootstrap import _ensure_cuda_libs
|
||||
|
||||
_ensure_cuda_libs()
|
||||
|
||||
import torch
|
||||
@@ -77,8 +78,9 @@ def main():
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
from exo.worker.engines.vllm.growable_cache import patch_vllm, set_prefix_cache
|
||||
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.vllm.growable_cache import patch_vllm, set_prefix_cache
|
||||
|
||||
patch_vllm()
|
||||
|
||||
@@ -101,18 +103,19 @@ def main():
|
||||
},
|
||||
)
|
||||
|
||||
print(f"Loading engine with KVConnector...")
|
||||
print("Loading engine with KVConnector...")
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
print("Engine loaded.")
|
||||
|
||||
|
||||
from exo.worker.engines.vllm.growable_cache import get_model_runner
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from capture_connector import captured_layers as gdn_captured
|
||||
|
||||
model_runner = get_model_runner()
|
||||
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn as orig_causal_conv1d_fn
|
||||
import vllm.model_executor.layers.mamba.ops.causal_conv1d as cc_mod
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn as orig_causal_conv1d_fn,
|
||||
)
|
||||
|
||||
gdn_states: dict[int, dict[str, torch.Tensor]] = {}
|
||||
gdn_call_idx = [0]
|
||||
@@ -147,11 +150,11 @@ def main():
|
||||
continue
|
||||
if hasattr(mod, 'causal_conv1d_fn') and mod.causal_conv1d_fn is orig_causal_conv1d_fn:
|
||||
mod.causal_conv1d_fn = patched_causal_conv1d_fn
|
||||
print(f" Patched causal_conv1d_fn")
|
||||
print(" Patched causal_conv1d_fn")
|
||||
|
||||
from exo.worker.engines.vllm.vllm_generator import VllmBatchEngine
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams, InputMessage
|
||||
from exo.shared.types.tasks import TaskId
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.worker.engines.vllm.vllm_generator import VllmBatchEngine
|
||||
|
||||
batch_engine = VllmBatchEngine(engine=engine, model_id=args.model, prefix_cache=prefix_cache)
|
||||
|
||||
@@ -163,7 +166,7 @@ def main():
|
||||
|
||||
task_id = batch_engine.submit(task_id=TaskId("extract"), task_params=task, prompt=args.prompt)
|
||||
|
||||
print(f"Running prefill via VllmBatchEngine...")
|
||||
print("Running prefill via VllmBatchEngine...")
|
||||
t0 = time.perf_counter()
|
||||
while batch_engine.has_work:
|
||||
results = batch_engine.step()
|
||||
|
||||
@@ -70,7 +70,7 @@ def main():
|
||||
if vllm_token_ids:
|
||||
print(f" Using vLLM token_ids ({len(vllm_token_ids)} tokens)")
|
||||
else:
|
||||
print(f" WARNING: No token_ids in metadata")
|
||||
print(" WARNING: No token_ids in metadata")
|
||||
|
||||
print(f"\nLoading MLX model: {args.model}")
|
||||
model, tokenizer = load(args.model)
|
||||
@@ -161,7 +161,7 @@ def main():
|
||||
print(f"\n Injected: {injected} layers, Skipped: {skipped} layers")
|
||||
|
||||
from exo.worker.engines.vllm.kv_cache import TorchKVCache as TKV
|
||||
print(f"\nRound-trip test (MLX → torch → MLX)...")
|
||||
print("\nRound-trip test (MLX → torch → MLX)...")
|
||||
rt_caches = model.make_cache()
|
||||
rt_tokens = mx.array(vllm_token_ids)
|
||||
rt_logits = model(rt_tokens[None], cache=rt_caches)
|
||||
@@ -184,7 +184,7 @@ def main():
|
||||
rt_max_diff = max(rt_max_diff, d)
|
||||
print(f" Round-trip max diff: {rt_max_diff:.4e} ({'PASS' if rt_max_diff < 0.01 else 'FAIL'})")
|
||||
|
||||
print(f"\nComparing with MLX-native prefill...")
|
||||
print("\nComparing with MLX-native prefill...")
|
||||
native_caches = rt_caches
|
||||
|
||||
for i in range(num_model_layers):
|
||||
@@ -243,12 +243,12 @@ def main():
|
||||
print(f" Model (MLX): {args.model}")
|
||||
print(f" Prompt tokens: {num_tokens}")
|
||||
print(f" Layers injected: {injected}/{num_model_layers}")
|
||||
print(f" Type mismatches: 0")
|
||||
print(" Type mismatches: 0")
|
||||
print(f" Generated {len(generated_tokens)} tokens")
|
||||
print(f" Text: {generated_text!r}")
|
||||
|
||||
if False:
|
||||
print(f"\n GAPS FOUND:")
|
||||
print("\n GAPS FOUND:")
|
||||
for idx, got, expected in type_mismatches:
|
||||
print(f" Layer {idx}: vLLM gives KV tensors, MLX wants {expected}")
|
||||
arrays_layers = [i for i, c in enumerate(caches) if isinstance(c, ArraysCache)]
|
||||
@@ -256,9 +256,9 @@ def main():
|
||||
print(f" ArraysCache layers (not populated): {arrays_layers[:10]}{'...' if len(arrays_layers) > 10 else ''}")
|
||||
|
||||
if generated_tokens and not all(t == generated_tokens[0] for t in generated_tokens):
|
||||
print(f"\n COHERENT OUTPUT: YES (varied tokens)")
|
||||
print("\n COHERENT OUTPUT: YES (varied tokens)")
|
||||
else:
|
||||
print(f"\n COHERENT OUTPUT: POSSIBLY NOT (all same token)")
|
||||
print("\n COHERENT OUTPUT: POSSIBLY NOT (all same token)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user