Performance optimizations

This commit is contained in:
Ryuichi Leo Takashige
2026-03-19 11:51:57 +00:00
parent 6fa2cc1265
commit 03ea3cf6cd
21 changed files with 459 additions and 129 deletions

View File

@@ -0,0 +1,24 @@
import sys
sys.path.insert(0, "src")
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import RotatingKVCache, KVCache
model, tok = load("mlx-community/gpt-oss-20b-MXFP4-Q8")
prompt = "Hello " * 2000
tokens = tok.encode(prompt)
print(f"Tokens: {len(tokens)}")
cache = model.make_cache()
token_arr = mx.array([tokens])
logits = model(token_arr, cache=cache)
mx.eval(logits)
for i, c in enumerate(cache[:6]):
if isinstance(c, KVCache) and not isinstance(c, RotatingKVCache) and c.keys is not None:
k = c.keys.astype(mx.float32)
print(f"Layer {i} KVCache: shape={c.keys.shape} offset={c.offset} first=[{float(k[0,0,0,0]):.6f}, {float(k[0,0,0,1]):.6f}] last=[{float(k[0,0,-1,-2]):.6f}, {float(k[0,0,-1,-1]):.6f}]")
elif isinstance(c, RotatingKVCache) and c.keys is not None:
k = c.keys.astype(mx.float32)
print(f"Layer {i} RotatingKV: shape={c.keys.shape} _idx={c._idx} offset={c.offset} first=[{float(k[0,0,0,0]):.6f}, {float(k[0,0,0,1]):.6f}]")

View File

@@ -0,0 +1,12 @@
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import RotatingKVCache
model, tok = load("mlx-community/gpt-oss-20b-MXFP4-Q8")
cache = model.make_cache()
tokens = mx.ones((1, 5000), dtype=mx.int32)
model(tokens, cache=cache)
mx.eval([c.keys for c in cache if c.keys is not None])
for i, c in enumerate(cache[:4]):
if isinstance(c, RotatingKVCache):
print(f"Layer {i}: _idx={c._idx} offset={c.offset} keep={c.keep} max_size={c.max_size} keys={c.keys.shape}")

80
scripts/compare_kv.py Normal file
View File

@@ -0,0 +1,80 @@
import sys
sys.path.insert(0, "src")
import mlx.core as mx
import torch
import socket
from pathlib import Path
import json
from collections import defaultdict
from mlx_lm import load
from mlx_lm.models.cache import RotatingKVCache, KVCache
from exo.disaggregated.protocol import read_header, read_message, KVChunk, Done
from exo.disaggregated.prefill_client import _nhd_to_bhsd, _torch_to_mx
ENDPOINT = sys.argv[1] if len(sys.argv) > 1 else "10.43.0.1:62988"
MODEL = sys.argv[2] if len(sys.argv) > 2 else "mlx-community/Llama-3.2-1B-Instruct-bf16"
MODEL_PATH = sys.argv[3] if len(sys.argv) > 3 else None
model, tok = load(MODEL_PATH or str(Path.home() / ".exo/models" / MODEL.replace("/", "--")))
prompt = "The quick brown fox jumps over the lazy dog. " * 300
tokens = tok.encode(prompt)
print(f"Tokens: {len(tokens)}")
host, port = ENDPOINT.rsplit(":", 1)
sock = socket.create_connection((host, int(port)), timeout=60)
request = json.dumps({"model": MODEL, "token_ids": tokens, "start_pos": 0}).encode() + b"\n"
sock.sendall(request)
stream = sock.makefile("rb", buffering=65536)
header = read_header(stream)
vllm_kv = defaultdict(list)
while True:
msg = read_message(stream, header)
if msg is None or isinstance(msg, Done):
break
if isinstance(msg, KVChunk):
vllm_kv[msg.layer_idx].append((msg.keys, msg.values))
sock.close()
print(f"Received {len(vllm_kv)} layers from vLLM")
if hasattr(model, "make_cache"):
mlx_cache = model.make_cache()
else:
from mlx_lm.models.cache import make_prompt_cache
mlx_cache = make_prompt_cache(model)
token_arr = mx.array([tokens[:-2]])
mlx_logits = model(token_arr, cache=mlx_cache)
mx.eval(mlx_logits)
for i in range(min(6, len(mlx_cache))):
c = mlx_cache[i]
if c.keys is None:
continue
mlx_k = c.keys.astype(mx.float32)
if i not in vllm_kv:
print(f"Layer {i}: no vLLM data")
continue
chunks = vllm_kv[i]
vk = torch.cat([k for k, v in chunks], dim=0) if len(chunks) > 1 else chunks[0][0]
vk_mx = _torch_to_mx(vk.permute(1, 0, 2).unsqueeze(0)).astype(mx.float32)
n = min(mlx_k.shape[2], vk_mx.shape[2])
diff = mx.abs(mlx_k[:, :, :n, :] - vk_mx[:, :, :n, :])
max_diff = mx.max(diff).item()
mean_diff = mx.mean(diff).item()
cache_type = "RotatingKV" if isinstance(c, RotatingKVCache) else "KV"
print(f"Layer {i} ({cache_type}): mlx={mlx_k.shape} vllm={vk_mx.shape} max_diff={max_diff:.6f} mean_diff={mean_diff:.6f}")
a = mlx_k[:, :, :n, :].reshape(-1)
b_vec = vk_mx[:, :, :n, :].reshape(-1)
cos_sim = float(mx.sum(a * b_vec).item()) / (float(mx.sqrt(mx.sum(a * a)).item()) * float(mx.sqrt(mx.sum(b_vec * b_vec)).item()) + 1e-8)
diff_tensor = mx.abs(mlx_k[:, :, :n, :] - vk_mx[:, :, :n, :])
max_idx = mx.argmax(diff_tensor.reshape(-1)).item()
total_elems = diff_tensor.shape[1] * n * diff_tensor.shape[3]
h = (max_idx // (n * diff_tensor.shape[3])) % diff_tensor.shape[1]
s = (max_idx // diff_tensor.shape[3]) % n
d = max_idx % diff_tensor.shape[3]
print(f" cosine_sim={cos_sim:.6f} max_diff={max_diff:.4f} at h={h} pos={s} dim={d}: mlx={float(mlx_k[0,h,s,d].item()):.6f} vllm={float(vk_mx[0,h,s,d].item()):.6f}")
print(f" mean_diff={mean_diff:.6f} p50={float(mx.sort(diff_tensor.reshape(-1))[diff_tensor.size // 2].item()):.6f} p99={float(mx.sort(diff_tensor.reshape(-1))[int(diff_tensor.size * 0.99)].item()):.6f}")

View File

@@ -4,11 +4,9 @@ from dataclasses import dataclass
from typing import Any
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
captured_layers: dict[str, Any] = {}

View File

@@ -20,6 +20,7 @@ os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
os.environ["VLLM_BATCH_INVARIANT"] = "1"
from exo.worker.runner.bootstrap import _ensure_cuda_libs
_ensure_cuda_libs()
import torch
@@ -44,8 +45,8 @@ def main():
parser.add_argument("--prompt", default="Hello, world! How are you today?", help="Prompt to prefill")
args = parser.parse_args()
from exo.worker.engines.vllm.vllm_generator import load_vllm_engine
from exo.worker.engines.vllm.growable_cache import get_model_runner
from exo.worker.engines.vllm.vllm_generator import load_vllm_engine
print(f"Loading vLLM engine from {args.model}...")
engine, _, prefix_cache = load_vllm_engine(

View File

@@ -21,6 +21,7 @@ os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
os.environ["VLLM_BATCH_INVARIANT"] = "1"
from exo.worker.runner.bootstrap import _ensure_cuda_libs
_ensure_cuda_libs()
import torch
@@ -77,8 +78,9 @@ def main():
from vllm.engine.arg_utils import EngineArgs
from vllm.v1.engine.llm_engine import LLMEngine
from exo.worker.engines.vllm.growable_cache import patch_vllm, set_prefix_cache
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.vllm.growable_cache import patch_vllm, set_prefix_cache
patch_vllm()
@@ -101,18 +103,19 @@ def main():
},
)
print(f"Loading engine with KVConnector...")
print("Loading engine with KVConnector...")
engine = LLMEngine.from_engine_args(engine_args)
print("Engine loaded.")
from exo.worker.engines.vllm.growable_cache import get_model_runner
from vllm.model_executor.layers.mamba.abstract import MambaBase
from capture_connector import captured_layers as gdn_captured
model_runner = get_model_runner()
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn as orig_causal_conv1d_fn
import vllm.model_executor.layers.mamba.ops.causal_conv1d as cc_mod
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn as orig_causal_conv1d_fn,
)
gdn_states: dict[int, dict[str, torch.Tensor]] = {}
gdn_call_idx = [0]
@@ -147,11 +150,11 @@ def main():
continue
if hasattr(mod, 'causal_conv1d_fn') and mod.causal_conv1d_fn is orig_causal_conv1d_fn:
mod.causal_conv1d_fn = patched_causal_conv1d_fn
print(f" Patched causal_conv1d_fn")
print(" Patched causal_conv1d_fn")
from exo.worker.engines.vllm.vllm_generator import VllmBatchEngine
from exo.shared.types.text_generation import TextGenerationTaskParams, InputMessage
from exo.shared.types.tasks import TaskId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.worker.engines.vllm.vllm_generator import VllmBatchEngine
batch_engine = VllmBatchEngine(engine=engine, model_id=args.model, prefix_cache=prefix_cache)
@@ -163,7 +166,7 @@ def main():
task_id = batch_engine.submit(task_id=TaskId("extract"), task_params=task, prompt=args.prompt)
print(f"Running prefill via VllmBatchEngine...")
print("Running prefill via VllmBatchEngine...")
t0 = time.perf_counter()
while batch_engine.has_work:
results = batch_engine.step()

View File

@@ -70,7 +70,7 @@ def main():
if vllm_token_ids:
print(f" Using vLLM token_ids ({len(vllm_token_ids)} tokens)")
else:
print(f" WARNING: No token_ids in metadata")
print(" WARNING: No token_ids in metadata")
print(f"\nLoading MLX model: {args.model}")
model, tokenizer = load(args.model)
@@ -161,7 +161,7 @@ def main():
print(f"\n Injected: {injected} layers, Skipped: {skipped} layers")
from exo.worker.engines.vllm.kv_cache import TorchKVCache as TKV
print(f"\nRound-trip test (MLX → torch → MLX)...")
print("\nRound-trip test (MLX → torch → MLX)...")
rt_caches = model.make_cache()
rt_tokens = mx.array(vllm_token_ids)
rt_logits = model(rt_tokens[None], cache=rt_caches)
@@ -184,7 +184,7 @@ def main():
rt_max_diff = max(rt_max_diff, d)
print(f" Round-trip max diff: {rt_max_diff:.4e} ({'PASS' if rt_max_diff < 0.01 else 'FAIL'})")
print(f"\nComparing with MLX-native prefill...")
print("\nComparing with MLX-native prefill...")
native_caches = rt_caches
for i in range(num_model_layers):
@@ -243,12 +243,12 @@ def main():
print(f" Model (MLX): {args.model}")
print(f" Prompt tokens: {num_tokens}")
print(f" Layers injected: {injected}/{num_model_layers}")
print(f" Type mismatches: 0")
print(" Type mismatches: 0")
print(f" Generated {len(generated_tokens)} tokens")
print(f" Text: {generated_text!r}")
if False:
print(f"\n GAPS FOUND:")
print("\n GAPS FOUND:")
for idx, got, expected in type_mismatches:
print(f" Layer {idx}: vLLM gives KV tensors, MLX wants {expected}")
arrays_layers = [i for i, c in enumerate(caches) if isinstance(c, ArraysCache)]
@@ -256,9 +256,9 @@ def main():
print(f" ArraysCache layers (not populated): {arrays_layers[:10]}{'...' if len(arrays_layers) > 10 else ''}")
if generated_tokens and not all(t == generated_tokens[0] for t in generated_tokens):
print(f"\n COHERENT OUTPUT: YES (varied tokens)")
print("\n COHERENT OUTPUT: YES (varied tokens)")
else:
print(f"\n COHERENT OUTPUT: POSSIBLY NOT (all same token)")
print("\n COHERENT OUTPUT: POSSIBLY NOT (all same token)")
if __name__ == "__main__":

View File

@@ -52,8 +52,6 @@ class BatchConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBaseCl
return
layer_idx = int(m.group(1))
torch.cuda.synchronize()
if isinstance(kv_layer, (list, tuple)):
return
@@ -68,10 +66,8 @@ class BatchConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBaseCl
v_flat = v_all.reshape(-1, *v_all.shape[-2:]) # pyright: ignore[reportAny]
valid = slot_mapping >= 0 # pyright: ignore[reportAny]
safe_sm = slot_mapping.clamp(min=0) # pyright: ignore[reportAny]
keys = k_flat[safe_sm] # pyright: ignore[reportAny]
values = v_flat[safe_sm] # pyright: ignore[reportAny]
keys[~valid] = 0
values[~valid] = 0
keys = k_flat[safe_sm][valid] # pyright: ignore[reportAny]
values = v_flat[safe_sm][valid] # pyright: ignore[reportAny]
prev = self.captured_layers.get(layer_idx)
if prev is not None:

View File

@@ -48,28 +48,10 @@ def _inject_kv_cache(cache: KVCache, keys: torch.Tensor, values: torch.Tensor, n
def _inject_rotating_kv_cache(cache: RotatingKVCache, keys: torch.Tensor, values: torch.Tensor, num_tokens: int) -> None:
k_mx, v_mx = _nhd_to_bhsd(keys, values)
seq_len = int(k_mx.shape[2])
if seq_len <= cache.max_size:
cache.keys = k_mx
cache.values = v_mx
cache.offset = seq_len
cache._idx = seq_len
else:
keep = cache.keep
window = cache.max_size
if keep == 0:
cache.keys = k_mx[:, :, -window:, :]
cache.values = v_mx[:, :, -window:, :]
cache._idx = window
else:
sink_keys = k_mx[:, :, :keep, :]
sink_values = v_mx[:, :, :keep, :]
recent_keys = k_mx[:, :, -(window - keep):, :]
recent_values = v_mx[:, :, -(window - keep):, :]
cache.keys = mx.concatenate([sink_keys, recent_keys], axis=2)
cache.values = mx.concatenate([sink_values, recent_values], axis=2)
cache._idx = keep
cache.offset = num_tokens
cache.keys = k_mx
cache.values = v_mx
cache.offset = seq_len
cache._idx = seq_len
def _inject_arrays_cache(cache: ArraysCache, arrays: list[torch.Tensor]) -> None:
@@ -82,6 +64,8 @@ def remote_prefill(
model_id: str,
mlx_model: Model,
on_prefill_progress: Callable[[int, int], None] | None = None,
existing_cache: list[KVCache | RotatingKVCache | ArraysCache] | None = None,
start_pos: int = 0,
) -> tuple[list[KVCache | RotatingKVCache | ArraysCache], int]:
if ":" in endpoint:
host, port_str = endpoint.rsplit(":", 1)
@@ -90,12 +74,12 @@ def remote_prefill(
host = endpoint
port = 8900
logger.info(f"Connecting to prefill server at {host}:{port} ({len(token_ids)} tokens)")
logger.info(f"Connecting to prefill server at {host}:{port} ({len(token_ids)} tokens, start_pos={start_pos})")
t0 = time.perf_counter()
sock = socket.create_connection((host, port), timeout=30)
try:
request = json.dumps({"model": model_id, "token_ids": token_ids}).encode("utf-8") + b"\n"
request = json.dumps({"model": model_id, "token_ids": token_ids, "start_pos": start_pos}).encode("utf-8") + b"\n"
sock.sendall(request)
raw_stream = sock.makefile("rb", buffering=65536)
@@ -147,7 +131,16 @@ def remote_prefill(
finally:
sock.close()
caches: list[KVCache | RotatingKVCache | ArraysCache] = cast(list[KVCache | RotatingKVCache | ArraysCache], mlx_model.make_cache()) # pyright: ignore[reportUnknownMemberType]
if existing_cache is not None and start_pos > 0:
caches = existing_cache
else:
if hasattr(mlx_model, "make_cache"):
caches = cast(list[KVCache | RotatingKVCache | ArraysCache], mlx_model.make_cache()) # pyright: ignore[reportUnknownMemberType]
else:
from mlx_lm.models.cache import make_prompt_cache
caches = cast(list[KVCache | RotatingKVCache | ArraysCache], make_prompt_cache(mlx_model)) # pyright: ignore[reportUnknownMemberType]
final_offset = start_pos + total_tokens
for i, cache in enumerate(caches):
if i in kv_buffers:
@@ -161,19 +154,25 @@ def remote_prefill(
all_values = torch.cat([v for _k, v in chunks], dim=0) # type: ignore
if isinstance(cache, RotatingKVCache):
_inject_rotating_kv_cache(cache, all_keys, all_values, total_tokens) # pyright: ignore[reportUnknownArgumentType]
_inject_rotating_kv_cache(cache, all_keys, all_values, final_offset) # pyright: ignore[reportUnknownArgumentType]
elif isinstance(cache, KVCache):
_inject_kv_cache(cache, all_keys, all_values, total_tokens) # pyright: ignore[reportUnknownArgumentType]
if start_pos > 0 and cache.keys is not None:
k_new, v_new = _nhd_to_bhsd(all_keys, all_values) # pyright: ignore[reportUnknownArgumentType]
cache.keys = mx.concatenate([cache.keys[:, :, :start_pos, :], k_new], axis=2)
cache.values = mx.concatenate([cache.values[:, :, :start_pos, :], v_new], axis=2)
cache.offset = final_offset
else:
_inject_kv_cache(cache, all_keys, all_values, final_offset) # pyright: ignore[reportUnknownArgumentType]
if i in arrays_buffers and isinstance(cache, ArraysCache):
_inject_arrays_cache(cache, arrays_buffers[i])
t_injected = time.perf_counter()
logger.info(
f"Remote prefill: {total_tokens} tokens, "
f"Remote prefill: {total_tokens} new tokens (start_pos={start_pos}, final_offset={final_offset}), "
f"transfer={((t_received - t0) * 1000):.0f}ms, "
f"inject={((t_injected - t_received) * 1000):.0f}ms, "
f"total={((t_injected - t0) * 1000):.0f}ms"
)
return caches, total_tokens
return caches, final_offset

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import contextlib
import json
import queue
import socketserver
import threading
import time
@@ -20,9 +19,12 @@ from exo.disaggregated.protocol import (
if TYPE_CHECKING:
from vllm.v1.engine.llm_engine import LLMEngine
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.vllm.kv_cache import KVLayerState, TorchKVCache
from exo.worker.runner.bootstrap import logger
_engine_ref: LLMEngine | None = None
_prefix_cache_ref: KVPrefixCache | None = None
_overlapping: bool = True
_connector_patched: bool = False
_gdn_patched: bool = False
@@ -79,8 +81,10 @@ def _patch_gdn_capture() -> None:
_gdn_patched = True
try:
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn as orig_fn # type: ignore
import vllm.model_executor.layers.mamba.ops.causal_conv1d as cc_mod # type: ignore
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn as orig_fn, # type: ignore
)
except ImportError:
return
@@ -90,7 +94,6 @@ def _patch_gdn_capture() -> None:
x = args[0] if args else None
if x is not None and x.shape[0] <= 100: # type: ignore
return result
torch.cuda.synchronize()
ci: int = cache_indices[0].item() if cache_indices.numel() > 0 else 0 # type: ignore
idx = _gdn_call_idx[0]
if _gdn_layer_order and idx < len(_gdn_layer_order) * 100:
@@ -106,8 +109,8 @@ def _patch_gdn_capture() -> None:
for mod in list(sys.modules.values()):
if mod is None or mod is cc_mod:
continue
if hasattr(mod, "causal_conv1d_fn") and getattr(mod, "causal_conv1d_fn") is orig_fn:
setattr(mod, "causal_conv1d_fn", patched_fn)
if hasattr(mod, "causal_conv1d_fn") and mod.causal_conv1d_fn is orig_fn:
mod.causal_conv1d_fn = patched_fn
logger.info("Patched causal_conv1d_fn for GDN state capture")
@@ -150,32 +153,56 @@ def _get_layer_info(engine: LLMEngine) -> tuple[int, str, list[dict[str, Any]]]:
return num_layers, dtype_str, layers_info
def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], wfile: Any) -> None: # pyright: ignore[reportAny]
from exo.disaggregated.streaming_connector import StreamingConnector
def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], start_pos: int, wfile: Any) -> None: # pyright: ignore[reportAny]
from exo.worker.engines.vllm.growable_cache import get_model_runner
model_runner = get_model_runner()
assert model_runner is not None
from exo.disaggregated.streaming_connector import get_shared_queue, reset_shared_queue
from exo.disaggregated.streaming_connector import (
get_shared_queue,
reset_shared_queue,
)
reset_shared_queue()
_gdn_states.clear()
_gdn_call_idx[0] = 0
layer_queue = get_shared_queue()
server_cached = 0
cached_torch: TorchKVCache | None = None
if _prefix_cache_ref is not None:
cached_torch, server_cached, _ = _prefix_cache_ref.lookup(token_ids)
skip_tokens = max(0, start_pos - server_cached)
num_layers, dtype_str, layers_info = _get_layer_info(engine)
write_header(wfile, {"num_layers": num_layers, "dtype": dtype_str, "layers": layers_info}) # pyright: ignore[reportAny]
if start_pos < server_cached and cached_torch is not None:
for i, layer in enumerate(cached_torch.layers):
if isinstance(layer, KVLayerState) and layer.keys.numel() > 0:
keys = layer.keys
values = layer.values
if keys.dim() == 4:
keys = keys.reshape(-1, keys.shape[-2], keys.shape[-1])
values = values.reshape(-1, values.shape[-2], values.shape[-1])
keys = keys[start_pos:server_cached]
values = values[start_pos:server_cached]
if keys.shape[0] > 0:
write_kv_chunk(wfile, i, keys, values) # pyright: ignore[reportAny]
logger.info(f"Sent cached KV for positions {start_pos}-{server_cached} from server prefix cache")
from vllm.sampling_params import (
SamplingParams,
)
prefill_token_ids = token_ids[:-2] if len(token_ids) > 2 else token_ids
request_id = f"prefill-{time.monotonic_ns()}"
params = SamplingParams(max_tokens=1, detokenize=False) # pyright: ignore[reportCallIssue]
engine.add_request(request_id, {"prompt_token_ids": token_ids}, params) # pyright: ignore[reportArgumentType]
params = SamplingParams(max_tokens=2, detokenize=False) # pyright: ignore[reportCallIssue]
engine.add_request(request_id, {"prompt_token_ids": prefill_token_ids}, params) # pyright: ignore[reportArgumentType]
chunks_sent = [0]
layer_token_counts: dict[int, int] = {}
def writer_loop() -> None:
while True:
@@ -183,6 +210,18 @@ def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], wfile: Any
if item is None:
break
layer_idx, keys, values = item
prev = layer_token_counts.get(layer_idx, 0)
n = keys.shape[0]
new_total = prev + n
layer_token_counts[layer_idx] = new_total
if new_total <= skip_tokens:
continue
if prev < skip_tokens:
trim = skip_tokens - prev
keys = keys[trim:]
values = values[trim:]
write_kv_chunk(wfile, layer_idx, keys, values) # pyright: ignore[reportAny]
chunks_sent[0] += 1
@@ -193,6 +232,7 @@ def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], wfile: Any
outputs = engine.step()
for output in outputs:
if output.request_id == request_id and output.outputs[0].token_ids:
_save_vllm_prefix_cache(engine, request_id, prefill_token_ids)
engine.abort_request([request_id]) # type: ignore
break
else:
@@ -201,13 +241,18 @@ def _run_prefill_overlapping(engine: LLMEngine, token_ids: list[int], wfile: Any
layer_queue.put(None)
writer_thread.join()
logger.info(f"Overlapping prefill: sent {chunks_sent[0]} KV chunks")
actual_per_layer = max(layer_token_counts.values()) if layer_token_counts else 0
new_tokens_sent = max(0, actual_per_layer - skip_tokens)
cached_tokens_sent = max(0, server_cached - start_pos) if start_pos < server_cached else 0
tokens_sent = cached_tokens_sent + new_tokens_sent
logger.info(f"Overlapping prefill: sent {chunks_sent[0]} chunks, {tokens_sent} tokens (server_cached={server_cached}, skip={skip_tokens})")
_stream_gdn_states(engine, wfile, num_layers, layers_info)
write_done(wfile, len(token_ids)) # pyright: ignore[reportAny]
cached_arrays: list[tuple[int, list[torch.Tensor]]] = []
_stream_gdn_states_and_collect(engine, wfile, num_layers, layers_info, cached_arrays)
write_done(wfile, tokens_sent) # pyright: ignore[reportAny]
def _run_prefill_batch(engine: LLMEngine, token_ids: list[int], wfile: Any) -> None: # pyright: ignore[reportAny]
def _run_prefill_batch(engine: LLMEngine, token_ids: list[int], start_pos: int, wfile: Any) -> None: # pyright: ignore[reportAny]
from exo.worker.engines.vllm.growable_cache import get_model_runner
num_layers, dtype_str, layers_info = _get_layer_info(engine)
@@ -215,20 +260,29 @@ def _run_prefill_batch(engine: LLMEngine, token_ids: list[int], wfile: Any) -> N
model_runner = get_model_runner()
assert model_runner is not None
from exo.disaggregated.batch_connector import clear_shared_captured_layers, get_shared_captured_layers
from exo.disaggregated.batch_connector import (
clear_shared_captured_layers,
get_shared_captured_layers,
)
_gdn_states.clear()
_gdn_call_idx[0] = 0
clear_shared_captured_layers()
captured_layers = get_shared_captured_layers()
server_cached = 0
if _prefix_cache_ref is not None:
_, server_cached, _ = _prefix_cache_ref.lookup(token_ids)
skip_tokens = max(0, start_pos - server_cached)
from vllm.sampling_params import (
SamplingParams,
)
prefill_token_ids = token_ids[:-2] if len(token_ids) > 2 else token_ids
request_id = f"prefill-{time.monotonic_ns()}"
params = SamplingParams(max_tokens=1, detokenize=False) # pyright: ignore[reportCallIssue]
engine.add_request(request_id, {"prompt_token_ids": token_ids}, params) # pyright: ignore[reportArgumentType]
params = SamplingParams(max_tokens=2, detokenize=False) # pyright: ignore[reportCallIssue]
engine.add_request(request_id, {"prompt_token_ids": prefill_token_ids}, params) # pyright: ignore[reportArgumentType]
while engine.has_unfinished_requests():
outputs = engine.step()
@@ -242,17 +296,32 @@ def _run_prefill_batch(engine: LLMEngine, token_ids: list[int], wfile: Any) -> N
write_header(wfile, {"num_layers": num_layers, "dtype": dtype_str, "layers": layers_info}) # pyright: ignore[reportAny]
logger.info(f"Batch prefill: streaming {len(captured_layers)} captured layers")
all_kv: list[tuple[int, torch.Tensor, torch.Tensor]] = []
for layer_idx in sorted(captured_layers.keys()):
layer_data = captured_layers[layer_idx]
write_kv_chunk(wfile, layer_idx, layer_data["keys"], layer_data["values"]) # pyright: ignore[reportAny]
keys = layer_data["keys"]
values = layer_data["values"]
all_kv.append((layer_idx, keys, values))
if keys.shape[0] > skip_tokens:
write_kv_chunk(wfile, layer_idx, keys[skip_tokens:], values[skip_tokens:]) # pyright: ignore[reportAny]
clear_shared_captured_layers()
_stream_gdn_states(engine, wfile, num_layers, layers_info)
write_done(wfile, len(token_ids)) # pyright: ignore[reportAny]
actual_per_layer = max((k.shape[0] for _, k, _ in all_kv), default=0)
tokens_sent = max(0, actual_per_layer - skip_tokens)
logger.info(f"Batch prefill: {len(all_kv)} layers, {tokens_sent} tokens sent (server_cached={server_cached}, skip={skip_tokens}, captured={actual_per_layer})")
cached_arrays: list[tuple[int, list[torch.Tensor]]] = []
_stream_gdn_states_and_collect(engine, wfile, num_layers, layers_info, cached_arrays)
write_done(wfile, tokens_sent) # pyright: ignore[reportAny]
def _stream_gdn_states(_engine: LLMEngine, wfile: Any, num_layers: int, layers_info: list[dict[str, Any]]) -> None: # type: ignore
def _stream_gdn_states_and_collect(
_engine: LLMEngine,
wfile: Any,
num_layers: int,
layers_info: list[dict[str, Any]],
out_arrays: list[tuple[int, list[torch.Tensor]]],
) -> None: # type: ignore
from exo.worker.engines.vllm.growable_cache import get_model_runner
if not _gdn_states:
@@ -282,6 +351,7 @@ def _stream_gdn_states(_engine: LLMEngine, wfile: Any, num_layers: int, layers_i
arrays.append(rec.to(torch.bfloat16))
if arrays:
write_arrays_state(wfile, layer_idx, arrays) # type: ignore
out_arrays.append((layer_idx, arrays))
except Exception:
logger.opt(exception=True).warning(f"Failed to capture GDN state for layer {layer_idx}")
@@ -289,6 +359,101 @@ def _stream_gdn_states(_engine: LLMEngine, wfile: Any, num_layers: int, layers_i
_gdn_call_idx[0] = 0
def _build_torch_cache(kv_chunks: list[tuple[int, torch.Tensor, torch.Tensor]], arrays_chunks: list[tuple[int, list[torch.Tensor]]], num_layers: int) -> TorchKVCache:
from exo.worker.engines.vllm.kv_cache import ArraysLayerState
layers_by_idx: dict[int, KVLayerState | ArraysLayerState] = {}
for layer_idx, keys, values in kv_chunks:
if layer_idx in layers_by_idx:
prev = layers_by_idx[layer_idx]
if isinstance(prev, KVLayerState):
layers_by_idx[layer_idx] = KVLayerState(
keys=torch.cat([prev.keys, keys], dim=0), # type: ignore
values=torch.cat([prev.values, values], dim=0), # type: ignore
)
else:
layers_by_idx[layer_idx] = KVLayerState(keys=keys, values=values)
for layer_idx, arrays in arrays_chunks:
layers_by_idx[layer_idx] = ArraysLayerState(arrays=[a if isinstance(a, torch.Tensor) else None for a in arrays])
ordered: list[KVLayerState | ArraysLayerState] = []
for i in range(num_layers):
if i in layers_by_idx:
ordered.append(layers_by_idx[i])
else:
ordered.append(KVLayerState(keys=torch.empty(0), values=torch.empty(0)))
return TorchKVCache(ordered)
def _save_vllm_prefix_cache(engine: LLMEngine, request_id: str, prefill_token_ids: list[int]) -> None:
if _prefix_cache_ref is None:
logger.info("Server prefix cache: no cache ref")
return
try:
from exo.worker.engines.vllm.vllm_generator import _save_prefix_cache
try:
engine_core = engine.engine_core.engine_core # type: ignore
coordinator = engine_core.scheduler.kv_cache_manager.coordinator # type: ignore
all_keys: list[str] = []
for mgr in coordinator.single_type_managers: # type: ignore
all_keys.extend(str(k) for k in mgr.req_to_blocks) # type: ignore
logger.info(f"Server prefix cache: request_id={request_id}, available_keys={all_keys[:5]}")
except Exception:
pass
before = len(_prefix_cache_ref.prompts)
_save_prefix_cache(engine, _prefix_cache_ref, request_id, prefill_token_ids, len(prefill_token_ids))
after = len(_prefix_cache_ref.prompts)
if after > before:
logger.info(f"Server prefix cache: saved {len(prefill_token_ids)} tokens (entries: {before}{after})")
else:
logger.info(f"Server prefix cache: save had no effect for request_id={request_id}")
except Exception:
logger.opt(exception=True).warning("Failed to save server-side prefix cache")
def _check_cache(token_ids: list[int]) -> TorchKVCache | None:
if _prefix_cache_ref is None:
return None
import mlx.core as mx
prompt_arr = mx.array(token_ids)
best_index: int | None = None
best_length = 0
for i, cached_prompt in enumerate(_prefix_cache_ref.prompts):
prefix_len = min(len(cached_prompt), len(prompt_arr))
if prefix_len == 0:
continue
match_len = int(mx.sum(cached_prompt[:prefix_len] == prompt_arr[:prefix_len]).item()) # pyright: ignore[reportAny]
if match_len == len(token_ids) and match_len == len(cached_prompt) and match_len > best_length:
best_index = i
best_length = match_len
if best_index is None:
return None
cached = _prefix_cache_ref.caches[best_index]
if isinstance(cached, TorchKVCache):
return cached
return None
def _send_cached(torch_cache: TorchKVCache, token_ids: list[int], wfile: Any, engine: LLMEngine) -> None:
num_layers, dtype_str, layers_info = _get_layer_info(engine)
write_header(wfile, {"num_layers": num_layers, "dtype": dtype_str, "layers": layers_info}) # type: ignore
from exo.worker.engines.vllm.kv_cache import ArraysLayerState
for i, layer in enumerate(torch_cache.layers):
if isinstance(layer, KVLayerState) and layer.keys.numel() > 0:
write_kv_chunk(wfile, i, layer.keys, layer.values) # type: ignore
elif isinstance(layer, ArraysLayerState):
arrays = [a for a in layer.arrays if a is not None]
if arrays:
write_arrays_state(wfile, i, arrays) # type: ignore
write_done(wfile, len(token_ids)) # type: ignore
class _PrefillHandler(socketserver.StreamRequestHandler):
def handle(self) -> None:
try:
@@ -297,6 +462,7 @@ class _PrefillHandler(socketserver.StreamRequestHandler):
return
request: dict[str, Any] = json.loads(line.decode("utf-8")) # pyright: ignore[reportAny]
token_ids: list[int] = request["token_ids"] # pyright: ignore[reportAny]
start_pos: int = request.get("start_pos", 0) # pyright: ignore[reportAny]
engine = _engine_ref
if engine is None:
@@ -309,13 +475,13 @@ class _PrefillHandler(socketserver.StreamRequestHandler):
self.wfile.write(error)
return
logger.info(f"Prefill request: {len(token_ids)} tokens, overlapping={_overlapping}")
logger.info(f"Prefill request: {len(token_ids)} tokens, start_pos={start_pos}, overlapping={_overlapping}")
t0 = time.perf_counter()
if _overlapping:
_run_prefill_overlapping(engine, token_ids, self.wfile)
_run_prefill_overlapping(engine, token_ids, start_pos, self.wfile)
else:
_run_prefill_batch(engine, token_ids, self.wfile)
_run_prefill_batch(engine, token_ids, start_pos, self.wfile)
elapsed = time.perf_counter() - t0
logger.info(f"Prefill complete: {len(token_ids)} tokens in {elapsed*1000:.0f}ms ({len(token_ids)/elapsed:.0f} tok/s)")
@@ -328,10 +494,12 @@ def start_prefill_server(
bind_address: str,
port: int,
overlapping: bool = True,
prefix_cache: KVPrefixCache | None = None,
) -> socketserver.ThreadingTCPServer:
global _engine_ref, _overlapping
global _engine_ref, _overlapping, _prefix_cache_ref
_engine_ref = engine
_overlapping = overlapping
_prefix_cache_ref = prefix_cache
_patch_gdn_capture()
_init_gdn_layer_order()

View File

@@ -61,7 +61,9 @@ def _str_to_dtype(s: str) -> torch.dtype:
def _dtype_size(dtype: torch.dtype) -> int:
return {torch.float16: 2, torch.bfloat16: 2, torch.float32: 4}[dtype]
if dtype == torch.bfloat16:
return 4
return {torch.float16: 2, torch.float32: 4}[dtype]
def write_header(stream: BinaryIO, header: dict[str, object]) -> None:
@@ -72,7 +74,7 @@ def write_header(stream: BinaryIO, header: dict[str, object]) -> None:
def _tensor_to_bytes(t: torch.Tensor) -> bytes:
if t.dtype == torch.bfloat16:
return t.contiguous().view(torch.int16).numpy().tobytes() # type: ignore
return t.contiguous().float().numpy().tobytes() # type: ignore
return t.contiguous().numpy().tobytes() # type: ignore
@@ -133,8 +135,8 @@ def read_message(stream: BinaryIO, header: dict[str, object]) -> Message | None:
values_raw = _read_exactly(stream, tensor_bytes)
shape = (num_tokens, n_heads, head_dim)
if dtype == torch.bfloat16:
keys: torch.Tensor = torch.frombuffer(bytearray(keys_raw), dtype=torch.int16).view(torch.bfloat16).reshape(shape).clone() # type: ignore
values: torch.Tensor = torch.frombuffer(bytearray(values_raw), dtype=torch.int16).view(torch.bfloat16).reshape(shape).clone() # type: ignore
keys: torch.Tensor = torch.frombuffer(bytearray(keys_raw), dtype=torch.float32).reshape(shape).to(torch.bfloat16).clone() # type: ignore
values: torch.Tensor = torch.frombuffer(bytearray(values_raw), dtype=torch.float32).reshape(shape).to(torch.bfloat16).clone() # type: ignore
else:
keys = torch.frombuffer(bytearray(keys_raw), dtype=dtype).reshape(shape).clone() # type: ignore
values = torch.frombuffer(bytearray(values_raw), dtype=dtype).reshape(shape).clone() # type: ignore
@@ -157,7 +159,7 @@ def read_message(stream: BinaryIO, header: dict[str, object]) -> Message | None:
total_elems *= d # pyright: ignore[reportAny]
raw = _read_exactly(stream, total_elems * elem_size)
if dtype == torch.bfloat16:
t: torch.Tensor = torch.frombuffer(bytearray(raw), dtype=torch.int16).view(torch.bfloat16).reshape(shape_arr).clone() # type: ignore
t: torch.Tensor = torch.frombuffer(bytearray(raw), dtype=torch.float32).reshape(shape_arr).to(torch.bfloat16).clone() # type: ignore
else:
t = torch.frombuffer(bytearray(raw), dtype=dtype).reshape(shape_arr).clone() # type: ignore
arrays.append(t) # pyright: ignore[reportUnknownArgumentType]

View File

@@ -37,6 +37,8 @@ class StreamingConnectorMetadata(KVConnectorMetadata): # pyright: ignore[report
class StreamingConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBaseClass]
_queue: queue.Queue[tuple[int, torch.Tensor, torch.Tensor] | None]
_save_count: int = 0
def __init__(self, vllm_config: Any, role: KVConnectorRole, kv_cache_config: Any = None) -> None: # type: ignore
super().__init__(vllm_config, role, kv_cache_config) # pyright: ignore[reportUnknownMemberType]
self._queue = _shared_queue
@@ -61,11 +63,14 @@ class StreamingConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBa
return
layer_idx = int(m.group(1))
torch.cuda.synchronize()
if isinstance(kv_layer, (list, tuple)):
return
if self._save_count < 1:
import logging
logging.getLogger("exo").info(f"save_kv_layer: kv_layer.shape={kv_layer.shape} dtype={kv_layer.dtype} slot_mapping.shape={slot_mapping.shape if slot_mapping is not None else None}") # pyright: ignore[reportAny]
self._save_count += 1
if slot_mapping is not None:
if kv_layer.shape[0] == 2: # pyright: ignore[reportAny]
k_all = kv_layer[0] # pyright: ignore[reportAny]
@@ -77,10 +82,11 @@ class StreamingConnector(KVConnectorBase_V1): # pyright: ignore[reportUntypedBa
v_flat = v_all.reshape(-1, *v_all.shape[-2:]) # pyright: ignore[reportAny]
valid = slot_mapping >= 0 # pyright: ignore[reportAny]
safe_sm = slot_mapping.clamp(min=0) # pyright: ignore[reportAny]
keys = k_flat[safe_sm] # pyright: ignore[reportAny]
values = v_flat[safe_sm] # pyright: ignore[reportAny]
keys[~valid] = 0
values[~valid] = 0
keys = k_flat[safe_sm][valid] # pyright: ignore[reportAny]
values = v_flat[safe_sm][valid] # pyright: ignore[reportAny]
if keys.dtype not in (torch.bfloat16, torch.float16, torch.float32): # pyright: ignore[reportAny]
keys = keys.to(torch.bfloat16) # pyright: ignore[reportAny]
values = values.to(torch.bfloat16) # pyright: ignore[reportAny]
self._queue.put((layer_idx, keys.cpu(), values.cpu())) # pyright: ignore[reportAny]
else:
self._queue.put((layer_idx, kv_layer.cpu().clone(), kv_layer.cpu().clone())) # pyright: ignore[reportAny]

View File

@@ -98,6 +98,7 @@ class Master:
from exo.master.placement_utils import (
_find_ip_prioritised as find_ip_prioritised, # pyright: ignore[reportPrivateUsage]
)
from exo.shared.models.model_cards import derive_base_model
endpoints: list[tuple[int, str]] = []
vllm_instance_count = 0
@@ -111,21 +112,14 @@ class Master:
if first_shard is None:
logger.info(f"Prefill routing: VllmInstance {instance.instance_id} has no shards")
continue
if first_shard.model_card.base_model.lower() != decode_model_base.lower():
if derive_base_model(first_shard.model_card.base_model).lower() != decode_model_base.lower():
logger.info(
f"Prefill routing: VllmInstance {instance.instance_id} base_model "
f"{first_shard.model_card.base_model!r} != decode {decode_model_base!r}"
)
continue
active_task_count = sum(
1 for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
and task.task_status in (TaskStatus.Pending, TaskStatus.Running)
)
if active_task_count > 0:
logger.info(f"Prefill routing: VllmInstance {instance.instance_id} busy ({active_task_count} active tasks)")
continue
pass
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
runner_status = self.state.runners.get(runner_id)

View File

@@ -41,7 +41,7 @@ _card_cache: dict[ModelId, "ModelCard"] = {}
import re
_QUANT_SUFFIXES = re.compile(
r"[-_](?:MLX|MXFP[0-9]+|GPTQ|AWQ|GGUF|fp16|bf16|fp8|int[0-9]+|[0-9]+(?:\.[0-9]+)?bit|Q[0-9]+(?:_[A-Z0-9]+)?|gs[0-9]+)(?:[-_](?:MLX|Q[0-9]+|Int[0-9]+|[A-Z0-9]+|gs[0-9]+))*$",
r"[-_ ](?:MLX|MXFP[0-9]+|NVFP[0-9]+|GPTQ|AWQ|GGUF|fp16|bf16|fp8|int[0-9]+|[0-9]+(?:\.[0-9]+)?bit|Q[0-9]+(?:_[A-Z0-9]+)?|gs[0-9]+)(?:[-_ ](?:MLX|Q[0-9]+|Int[0-9]+|[A-Z0-9]+|gs[0-9]+))*$",
re.IGNORECASE,
)
@@ -127,7 +127,8 @@ class ModelCard(CamelCaseModel):
if not self.base_model:
self.base_model = derive_base_model(self.model_id)
else:
self.base_model = _normalize_base_model(self.base_model)
stripped = _QUANT_SUFFIXES.sub("", self.base_model)
self.base_model = _normalize_base_model(stripped)
return self
@field_validator("tasks", mode="before")

View File

@@ -194,10 +194,18 @@ class KVPrefixCache:
# This ensures stream_generate always has at least one token to start with
mlx_cache = self._get_mlx_cache(best_index)
has_ssm = has_non_kv_caches(mlx_cache)
snapshots_available = self._snapshots[best_index] is not None
if is_exact and has_ssm and not snapshots_available:
prompt_cache = deepcopy(mlx_cache)
self._access_counter += 1
self._last_used[best_index] = self._access_counter
remaining = prompt_tokens[best_length:]
return prompt_cache, remaining, best_index
target = (max_length - 1) if is_exact and not has_ssm else best_length
restore_pos, restore_snap = self._get_snapshot(best_index, target)
# No usable snapshot — need fresh cache
if restore_snap is None and has_ssm:
return make_kv_cache(model), prompt_tokens, None

View File

@@ -121,7 +121,9 @@ class ExoBatchGenerator:
matched_index: int | None = None
prompt_tokens = all_prompt_tokens
if self.kv_prefix_cache is not None and not is_bench:
has_prefill_endpoints = bool(task_params.prefill_endpoints) and len(all_prompt_tokens) > 1000
if self.kv_prefix_cache is not None and not is_bench and not has_prefill_endpoints:
cache, remaining_tokens, matched_index = self.kv_prefix_cache.get_kv_cache(
self.model, all_prompt_tokens
)
@@ -168,8 +170,13 @@ class ExoBatchGenerator:
model_id=str(task_params.model),
mlx_model=self.model,
on_prefill_progress=on_prefill_progress,
existing_cache=None,
start_pos=0,
)
cache = injected_cache
from exo.worker.engines.mlx.cache import snapshot_ssm_states
cache_snapshots = [snapshot_ssm_states(cache)]
_prefill_tps = total_tokens / max(time.perf_counter() - t0, 0.001)
used_remote_prefill = True
logger.info(f"Remote prefill: {total_tokens} tokens at {_prefill_tps:.0f} tok/s")

View File

@@ -264,6 +264,7 @@ class TorchKVCache:
first = kv_caches[0]
device = first[0].device if isinstance(first, list) else first.device
block_size = first[0].shape[1] if isinstance(first, list) else first.shape[-3]
for layer_idx, layer in enumerate(self.layers):
if not isinstance(layer, KVLayerState):
continue
@@ -271,14 +272,26 @@ class TorchKVCache:
bt = block_tables[gi]
kv = kv_caches[layer_idx]
k_all, v_all = _split_kv(kv)
n_blocks = min(len(bt), layer.keys.shape[0])
keys = layer.keys
values = layer.values
if keys.dim() == 3:
offset = token_offset_per_group[gi] if token_offset_per_group else 0
if offset > 0:
keys = keys[offset:]
values = values[offset:]
s, h, d = keys.shape
pad = (block_size - s % block_size) % block_size
if pad > 0:
keys = torch.nn.functional.pad(keys, (0, 0, 0, 0, 0, pad))
values = torch.nn.functional.pad(values, (0, 0, 0, 0, 0, pad))
keys = keys.reshape(-1, block_size, h, d)
values = values.reshape(-1, block_size, h, d)
n_blocks = min(len(bt), keys.shape[0])
if n_blocks > 0:
k_all[bt[:n_blocks]] = layer.keys[:n_blocks].to(
device, non_blocking=True
)
v_all[bt[:n_blocks]] = layer.values[:n_blocks].to(
device, non_blocking=True
)
k_all[bt[:n_blocks]] = keys[:n_blocks].to(device, non_blocking=True)
v_all[bt[:n_blocks]] = values[:n_blocks].to(device, non_blocking=True)
torch.cuda.synchronize()
def __iter__(self) -> Iterator[LayerState]:

View File

@@ -213,8 +213,9 @@ def vllm_generate(
tokenizer = engine.get_tokenizer()
stop_ids = _stop_token_ids(tokenizer, model_id)
DEFAULT_PREFILL_STEP_SIZE = 8192
max_batch_tokens: int = (
getattr(engine.model_config, "max_num_batched_tokens", 2048) or 2048
getattr(engine.model_config, "max_num_batched_tokens", DEFAULT_PREFILL_STEP_SIZE) or DEFAULT_PREFILL_STEP_SIZE
) # type: ignore[reportUnknownMemberType]
start_time = time.perf_counter()
first_token_time: float | None = None
@@ -577,21 +578,36 @@ def load_vllm_engine(
"kv_role": "kv_both",
}
engine_args = EngineArgs(
model=model_path,
served_model_name=str(model_id),
gpu_memory_utilization=0.05,
trust_remote_code=trust_remote_code,
load_format="fastsafetensors",
enable_prefix_caching=False,
attention_backend="TRITON_ATTN",
enforce_eager=True,
disable_log_stats=True,
kv_transfer_config=kv_transfer_config, # type: ignore
)
is_nvfp4 = "nvfp4" in model_path.lower() or "nvfp4" in str(model_id).lower()
backends = ["FLASHINFER", "TRITON_ATTN"] if is_nvfp4 else ["FLASH_ATTN", "TRITON_ATTN"]
set_weight_loading_callback(on_layer_loaded)
engine = LLMEngine.from_engine_args(engine_args)
engine: LLMEngine | None = None
for backend in backends:
try:
engine_args = EngineArgs(
model=model_path,
served_model_name=str(model_id),
gpu_memory_utilization=0.05,
trust_remote_code=trust_remote_code,
load_format="fastsafetensors",
enable_prefix_caching=False,
attention_backend=backend,
enforce_eager=True,
disable_log_stats=True,
max_num_batched_tokens=4096,
kv_transfer_config=kv_transfer_config, # type: ignore
)
set_weight_loading_callback(on_layer_loaded)
engine = LLMEngine.from_engine_args(engine_args)
logger.info(f"vLLM engine using attention backend: {backend}")
break
except (ValueError, RuntimeError) as e:
logger.warning(f"Attention backend {backend} failed: {e}, trying next")
continue
if engine is None:
raise RuntimeError(f"No attention backend worked for {model_id}")
tool_parser: ToolParser | None = None
tokenizer = engine.get_tokenizer()

View File

@@ -182,7 +182,6 @@ class Worker:
json.dump(endpoints, f)
except:
logger.warning("Updating prefill endpoints failed")
pass
async def plan_step(self):
while True:

View File

@@ -70,7 +70,7 @@ def entrypoint(
if isinstance(bound_instance.instance, VllmInstance):
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
os.environ["VLLM_BATCH_INVARIANT"] = "1"
# os.environ["VLLM_BATCH_INVARIANT"] = "1"
_ensure_cuda_libs()
from exo.shared.constants import EXO_MODELS_DIR
from exo.worker.runner.llm_inference.runner import Runner, VllmBuilder

View File

@@ -561,7 +561,9 @@ class VllmBuilder(Builder):
tokenizer = TokenizerWrapper(self._engine.get_tokenizer())
max_concurrent = 1 if os.environ.get("EXO_NO_BATCH") else 8
prefill_port = int(os.environ.get("EXO_PREFILL_PORT", "8900"))
from exo.master.placement import random_ephemeral_port
prefill_port = random_ephemeral_port()
overlapping = not os.environ.get("EXO_NO_OVERLAPPING_PREFILL_SENDS")
try:
from exo.disaggregated.prefill_server import start_prefill_server
@@ -571,6 +573,7 @@ class VllmBuilder(Builder):
bind_address="0.0.0.0",
port=prefill_port,
overlapping=overlapping,
prefix_cache=self._prefix_cache,
)
self._prefill_server_port = prefill_port
except Exception: