mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-23 16:19:07 -04:00
paged: in-kernel decode read patch 0009 (kill the gather regression)
Mirror patch 0009 for the paged llama.cpp engine. It removes the patch-0003 per-layer per-step gather (ggml_get_rows of K/V to a contiguous buffer) on the decode step and instead reads paged blocks in-kernel: build_attn passes the physical K/V views plus a position-ordered block table (src[5] of ggml_flash_attn_ext, padded to FATTN_KQ_STRIDE), and the CUDA fattn vec kernel plus the CPU reference map each logical KV index to its physical cell and read in place. KV_max / parallel_blocks / stream_k split-K are unchanged; a nullptr block table is the stock contiguous read (byte-identical, gated by LLAMA_KV_PAGED). Verified on GB10 (sm_121, Qwen3-32B NVFP4, batch 32 / 1024 ctx): the decode step drops from 1279 ms (paged-gather) to 696 ms in-kernel (-46%), reaching stock parity (647 ms). CPU paged vs stock is bit-for-bit identical; GPU stays within the documented batch-shape non-determinism band. Assisted-by: Claude:opus-4.8 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -0,0 +1,609 @@
|
||||
From 59490d82e4d0d4ad05ffb5ca3cccc668f4a75281 Mon Sep 17 00:00:00 2001
|
||||
From: Ettore Di Giacinto <mudler@localai.io>
|
||||
Date: Mon, 22 Jun 2026 20:03:17 +0200
|
||||
Subject: [PATCH] paged in-kernel decode read (env LLAMA_KV_PAGED) - patch 0009
|
||||
|
||||
Replace the per-layer per-step gather (patch 0003: ggml_get_rows of K/V into a
|
||||
contiguous buffer) with an in-kernel paged read on the decode step. build_attn
|
||||
passes the UNMODIFIED physical K/V views plus a block table (src[5] of
|
||||
ggml_flash_attn_ext: an I32 [n_view, n_stream] position-ordered physical-cell
|
||||
index, padded to FATTN_KQ_STRIDE). The CUDA fattn vec kernel and the CPU
|
||||
reference map logical KV index j -> physical cell block_table[seq*ne11+j] and
|
||||
read K_base+cell*nb11 / V_base+cell*nb21 in place, so the get_rows of K and V
|
||||
(the bulk of the gather) is gone. The mask stays a small compacted [n_view]
|
||||
causal mask in the same position order; KV_max / parallel_blocks / stream_k
|
||||
split-K are unchanged. The decode shape is forced onto the vec kernel (the only
|
||||
one wired for the block table); a nullptr block table => the stock contiguous
|
||||
read, byte-identical.
|
||||
|
||||
Token-POSITION ordering keeps the flash-attn reduction order identical to stock,
|
||||
so CPU-paged logits == CPU-stock bit-for-bit (verified: 4-stream FA greedy, 64
|
||||
tokens). On GPU paged(vec) == stock(vec) at batch 1; at batch>1 it stays within
|
||||
the documented vec-vs-mma non-determinism band. Decode step at batch 32 / 1024
|
||||
ctx on GB10 (Qwen3-32B NVFP4): paged-gather 1279 ms -> in-kernel 696 ms (-46%),
|
||||
recovering the gather regression to stock parity (647 ms). Gated behind
|
||||
LLAMA_KV_PAGED; no-op (stock byte-identical) when unset.
|
||||
|
||||
Assisted-by: Claude:opus-4.8 [Claude Code]
|
||||
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
|
||||
---
|
||||
ggml/include/ggml.h | 6 ++
|
||||
ggml/src/ggml-cpu/ops.cpp | 10 ++-
|
||||
ggml/src/ggml-cuda/fattn-common.cuh | 8 +-
|
||||
ggml/src/ggml-cuda/fattn-mma-f16.cuh | 4 +-
|
||||
ggml/src/ggml-cuda/fattn-tile.cuh | 4 +-
|
||||
ggml/src/ggml-cuda/fattn-vec.cuh | 25 +++++--
|
||||
ggml/src/ggml-cuda/fattn-wmma-f16.cu | 4 +-
|
||||
ggml/src/ggml-cuda/fattn.cu | 9 +++
|
||||
ggml/src/ggml.c | 14 ++++
|
||||
src/llama-graph.cpp | 23 ++++--
|
||||
src/llama-graph.h | 3 +-
|
||||
src/llama-kv-cache.cpp | 31 ++++++++
|
||||
src/llama-kv-cache.h | 4 +
|
||||
src/paged-attn.cpp | 107 +++++++++++++++++++++++++++
|
||||
src/paged-attn.h | 18 +++++
|
||||
15 files changed, 248 insertions(+), 22 deletions(-)
|
||||
|
||||
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
|
||||
index d6807b6..823f5a9 100644
|
||||
--- a/ggml/include/ggml.h
|
||||
+++ b/ggml/include/ggml.h
|
||||
@@ -2427,6 +2427,12 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * sinks);
|
||||
|
||||
+ // [paged] optional block table in src[5]: I32 [n_kv_logical, n_stream]; maps each
|
||||
+ // logical KV index to the physical cell within K/V. nullptr => stock contiguous read.
|
||||
+ GGML_API void ggml_flash_attn_ext_set_block_table(
|
||||
+ struct ggml_tensor * a,
|
||||
+ struct ggml_tensor * block_table);
|
||||
+
|
||||
// TODO: needs to be adapted to ggml_flash_attn_ext
|
||||
GGML_API struct ggml_tensor * ggml_flash_attn_back(
|
||||
struct ggml_context * ctx,
|
||||
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
||||
index 74611dc..63c07a2 100644
|
||||
--- a/ggml/src/ggml-cpu/ops.cpp
|
||||
+++ b/ggml/src/ggml-cpu/ops.cpp
|
||||
@@ -8330,6 +8330,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
||||
const ggml_tensor * v = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
+ const ggml_tensor * block_table = dst->src[5]; // [paged] logical->physical cell map (src[5])
|
||||
+ const int32_t * bt = block_table ? (const int32_t *) block_table->data : nullptr;
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
||||
@@ -8449,7 +8451,9 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
||||
|
||||
float s; // KQ value
|
||||
|
||||
- const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
+ // [paged] map the logical KV index ic to its physical cell via the block table.
|
||||
+ const int64_t ic_phys = bt ? (int64_t) bt[ik3*nek1 + ic] : ic;
|
||||
+ const char * k_data = (const char *) k->data + ( ic_phys*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
|
||||
|
||||
s = s*scale; // scale KQ value
|
||||
@@ -8465,7 +8469,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
||||
float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
|
||||
float vs = 1.0f; // post-softmax KQ value, expf(s - M)
|
||||
|
||||
- const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
+ const char * v_data = ((const char *) v->data + (ic_phys*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
|
||||
if (v->type == GGML_TYPE_F16) {
|
||||
if (s > M) {
|
||||
@@ -9021,7 +9025,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
||||
|
||||
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
- bool use_tiled = !use_ref &&
|
||||
+ bool use_tiled = !use_ref && dst->src[5] == nullptr && // [paged] one_chunk honors the block table
|
||||
(q->type == GGML_TYPE_F32 &&
|
||||
kv_is_f32_or_f16 &&
|
||||
k->type == v->type &&
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
|
||||
index 8dfa51a..3c6ddd5 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-common.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
|
||||
@@ -39,7 +39,8 @@ typedef void (* fattn_kernel_t)(
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
- const int32_t nb31, const int32_t nb32, const int64_t nb33);
|
||||
+ const int32_t nb31, const int32_t nb32, const int64_t nb33,
|
||||
+ const int * __restrict__ block_table);
|
||||
|
||||
typedef float (*vec_dot_KQ_t)(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
|
||||
@@ -981,6 +982,8 @@ void launch_fattn(
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
+ const ggml_tensor * block_table = dst->src[5]; // [paged] optional logical->physical map
|
||||
+ const int * bt_ptr = block_table ? (const int *) block_table->data : nullptr;
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
@@ -1217,7 +1220,8 @@ void launch_fattn(
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
|
||||
- mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
|
||||
+ mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
|
||||
+ bt_ptr
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
index 83478a0..0a92cd6 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
|
||||
@@ -1723,7 +1723,9 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
- const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
+ const int32_t nb31, const int32_t nb32, const int64_t nb33,
|
||||
+ const int * __restrict__ block_table) {
|
||||
+ GGML_UNUSED(block_table); // [paged] block table is honored only by the vec kernel
|
||||
ggml_cuda_pdl_sync(); // TODO optimize placement
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE))
|
||||
const char * GGML_CUDA_RESTRICT Q = Q_ptr;
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
index 0a09981..0ff14e6 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
|
||||
@@ -808,7 +808,9 @@ static __global__ void flash_attn_tile(
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
- const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
+ const int32_t nb31, const int32_t nb32, const int64_t nb33,
|
||||
+ const int * __restrict__ block_table) {
|
||||
+ GGML_UNUSED(block_table); // [paged] block table is honored only by the vec kernel
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
const char * GGML_CUDA_RESTRICT Q = Q_ptr;
|
||||
const char * GGML_CUDA_RESTRICT K = K_ptr;
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh
|
||||
index 69dd936..a09e2fb 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-vec.cuh
|
||||
+++ b/ggml/src/ggml-cuda/fattn-vec.cuh
|
||||
@@ -39,7 +39,8 @@ static __global__ void flash_attn_ext_vec(
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
- const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
+ const int32_t nb31, const int32_t nb32, const int64_t nb33,
|
||||
+ const int * __restrict__ block_table) {
|
||||
ggml_cuda_pdl_lc();
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
const char * GGML_CUDA_RESTRICT Q = Q_ptr;
|
||||
@@ -61,7 +62,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
- nb31, nb32, nb33);
|
||||
+ nb31, nb32, nb33, block_table);
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
@@ -110,6 +111,14 @@ static __global__ void flash_attn_ext_vec(
|
||||
K += nb13*sequence + nb12*(head / gqa_ratio);
|
||||
V += nb23*sequence + nb22*(head / gqa_ratio);
|
||||
|
||||
+ // [paged] in-kernel block-table read: logical KV index j -> physical cell
|
||||
+ // block_table[sequence*ne11 + j]; read K0 + cell*nb11 / V0 + cell*nb21. The
|
||||
+ // mask/KV_max stay logical (the table is in token-position order). nullptr =>
|
||||
+ // the stock contiguous read below.
|
||||
+ const char * GGML_CUDA_RESTRICT K0 = K;
|
||||
+ const char * GGML_CUDA_RESTRICT V0 = V;
|
||||
+ const int * GGML_CUDA_RESTRICT bt = block_table ? block_table + (size_t) sequence*ne11 : nullptr;
|
||||
+
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
@@ -267,10 +276,11 @@ static __global__ void flash_attn_ext_vec(
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
|
||||
const int i_KQ = threadIdx.y*WARP_SIZE + (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1))) + i_KQ_0;
|
||||
+ const char * GGML_CUDA_RESTRICT K_blk = bt ? (K0 + (int64_t) bt[k_VKQ_0 + i_KQ]*nb11) : (K + i_KQ*nb11);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
- float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
|
||||
+ float sum = vec_dot_KQ(K_blk, Q_reg[j], Q_i32[j], Q_ds[j]);
|
||||
sum = warp_reduce_sum<nthreads_KQ>(sum);
|
||||
|
||||
if (use_logit_softcap) {
|
||||
@@ -324,6 +334,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
|
||||
const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
|
||||
+ const char * GGML_CUDA_RESTRICT V_blk = bt ? (V0 + (int64_t) bt[k_VKQ_0 + k]*nb21) : (V + k*nb21);
|
||||
|
||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||
half2 KQ_k[ncols];
|
||||
@@ -336,14 +347,14 @@ static __global__ void flash_attn_ext_vec(
|
||||
half2 tmp[V_rows_per_thread/2];
|
||||
if constexpr (type_V == GGML_TYPE_BF16) {
|
||||
float2 tmp_f[V_rows_per_thread/2];
|
||||
- dequantize_V(V + k*nb21, tmp_f,
|
||||
+ dequantize_V(V_blk, tmp_f,
|
||||
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
|
||||
tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]);
|
||||
}
|
||||
} else {
|
||||
- dequantize_V(V + k*nb21, tmp,
|
||||
+ dequantize_V(V_blk, tmp,
|
||||
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
|
||||
}
|
||||
#pragma unroll
|
||||
@@ -363,7 +374,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
|
||||
float2 tmp[V_rows_per_thread/2];
|
||||
- dequantize_V(V + k*nb21, tmp,
|
||||
+ dequantize_V(V_blk, tmp,
|
||||
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
|
||||
@@ -522,7 +533,7 @@ static __global__ void flash_attn_ext_vec(
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
- nb31, nb32, nb33);
|
||||
+ nb31, nb32, nb33, block_table);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
|
||||
index 6850716..5357849 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
|
||||
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
|
||||
@@ -44,7 +44,9 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
- const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
+ const int32_t nb31, const int32_t nb32, const int64_t nb33,
|
||||
+ const int * __restrict__ block_table) {
|
||||
+ GGML_UNUSED(block_table); // [paged] block table is honored only by the vec kernel
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
|
||||
const char * GGML_CUDA_RESTRICT Q = Q_ptr;
|
||||
const char * GGML_CUDA_RESTRICT K = K_ptr;
|
||||
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
|
||||
index d6c501b..e3771ee 100644
|
||||
--- a/ggml/src/ggml-cuda/fattn.cu
|
||||
+++ b/ggml/src/ggml-cuda/fattn.cu
|
||||
@@ -574,6 +574,15 @@ size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * d
|
||||
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
+
|
||||
+ // [paged] the block table (src[5]) is only honored by the vec kernel's
|
||||
+ // in-kernel read; force it. build_attn only sets it for a vec-supported
|
||||
+ // 1-token-per-stream decode shape.
|
||||
+ if (dst->src[5] != nullptr) {
|
||||
+ ggml_cuda_flash_attn_ext_vec(ctx, dst);
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
|
||||
case BEST_FATTN_KERNEL_NONE:
|
||||
GGML_ABORT("fatal error");
|
||||
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
|
||||
index b43016c..adbe52b 100644
|
||||
--- a/ggml/src/ggml.c
|
||||
+++ b/ggml/src/ggml.c
|
||||
@@ -5442,6 +5442,20 @@ void ggml_flash_attn_ext_add_sinks(
|
||||
a->src[4] = sinks;
|
||||
}
|
||||
|
||||
+void ggml_flash_attn_ext_set_block_table(
|
||||
+ struct ggml_tensor * a,
|
||||
+ struct ggml_tensor * block_table) {
|
||||
+ if (!block_table) {
|
||||
+ a->src[5] = NULL;
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
+ GGML_ASSERT(block_table->type == GGML_TYPE_I32);
|
||||
+
|
||||
+ a->src[5] = block_table;
|
||||
+}
|
||||
+
|
||||
// ggml_flash_attn_back
|
||||
|
||||
struct ggml_tensor * ggml_flash_attn_back(
|
||||
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
|
||||
index b59d2a5..abdb48d 100644
|
||||
--- a/src/llama-graph.cpp
|
||||
+++ b/src/llama-graph.cpp
|
||||
@@ -2074,7 +2074,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
- int il) const {
|
||||
+ int il,
|
||||
+ ggml_tensor * block_table) const {
|
||||
const bool v_trans = v->nb[1] > v->nb[2];
|
||||
|
||||
// split the batch into streams if needed
|
||||
@@ -2109,6 +2110,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
|
||||
|
||||
+ if (block_table) {
|
||||
+ ggml_flash_attn_ext_set_block_table(cur, block_table);
|
||||
+ }
|
||||
ggml_flash_attn_ext_add_sinks(cur, sinks);
|
||||
ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
|
||||
|
||||
@@ -2358,12 +2362,19 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||
|
||||
- // [paged 0003] gather K, V and the mask to the sequence's used cells only
|
||||
- // (no-op unless env LLAMA_KV_PAGED is set).
|
||||
- ggml_tensor * kq_mask_g = kq_mask;
|
||||
- paged_attn::gather(ctx0, res, mctx_cur, &k, &v, &kq_mask_g);
|
||||
+ // [paged] decode read: when paging is active and this is a 1-token-per-stream
|
||||
+ // decode step, present K/V as n_gather views + a block table so the fattn
|
||||
+ // kernel reads the sequence's cells in-kernel (no get_rows of K/V). Else
|
||||
+ // fall back to the gather-read (prefill, transposed V, or env off). All a
|
||||
+ // no-op unless env LLAMA_KV_PAGED is set => stock byte-identical.
|
||||
+ ggml_tensor * kq_mask_g = kq_mask;
|
||||
+ ggml_tensor * block_table = nullptr;
|
||||
+ const bool is_decode = (q_cur->ne[2] == k->ne[3]); // 1 query token per stream
|
||||
+ if (!(is_decode && paged_attn::in_kernel_decode(ctx0, res, mctx_cur, &k, &v, &kq_mask_g, &block_table))) {
|
||||
+ paged_attn::gather(ctx0, res, mctx_cur, &k, &v, &kq_mask_g);
|
||||
+ }
|
||||
|
||||
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_g, sinks, v_mla, kq_scale, il);
|
||||
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_g, sinks, v_mla, kq_scale, il, block_table);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (inp->self_v_rot) {
|
||||
diff --git a/src/llama-graph.h b/src/llama-graph.h
|
||||
index 5e8a658..c95ae49 100644
|
||||
--- a/src/llama-graph.h
|
||||
+++ b/src/llama-graph.h
|
||||
@@ -969,7 +969,8 @@ struct llm_graph_context {
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
- int il) const;
|
||||
+ int il,
|
||||
+ ggml_tensor * block_table = nullptr) const; // [paged] optional src[5] block table
|
||||
|
||||
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
|
||||
|
||||
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
|
||||
index 7510ff9..0351f86 100644
|
||||
--- a/src/llama-kv-cache.cpp
|
||||
+++ b/src/llama-kv-cache.cpp
|
||||
@@ -1474,6 +1474,33 @@ void llama_kv_cache::get_gather_idxs(int32_t * dst, uint32_t n_kv, const slot_in
|
||||
}
|
||||
}
|
||||
|
||||
+void llama_kv_cache::get_block_table(int32_t * dst, uint32_t n_blk, uint32_t n_kv, const slot_info & sinfo) const {
|
||||
+ const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
|
||||
+ for (uint32_t j = 0; j < ns; ++j) {
|
||||
+ const auto & cells = v_cells[sinfo.s0 + j];
|
||||
+ const uint32_t n = std::min<uint32_t>(n_kv, cells.size());
|
||||
+ std::vector<std::pair<llama_pos, int32_t>> pc;
|
||||
+ pc.reserve(n);
|
||||
+ int32_t pad = -1;
|
||||
+ for (uint32_t i = 0; i < n; ++i) {
|
||||
+ if (!cells.is_empty(i)) {
|
||||
+ pc.emplace_back(cells.pos_get(i), (int32_t) i);
|
||||
+ } else if (pad < 0) {
|
||||
+ pad = (int32_t) i;
|
||||
+ }
|
||||
+ }
|
||||
+ std::sort(pc.begin(), pc.end());
|
||||
+ int32_t * col = dst + (size_t) j * n_blk;
|
||||
+ for (size_t k = 0; k < pc.size(); ++k) {
|
||||
+ col[k] = pc[k].second;
|
||||
+ }
|
||||
+ const int32_t padv = (pad >= 0) ? pad : (pc.empty() ? 0 : pc.back().second);
|
||||
+ for (uint32_t k = (uint32_t) pc.size(); k < n_blk; ++k) {
|
||||
+ col[k] = padv;
|
||||
+ }
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
|
||||
GGML_UNUSED(sinfo);
|
||||
|
||||
@@ -2773,6 +2800,10 @@ void llama_kv_cache_context::get_gather_idxs(int32_t * dst) const {
|
||||
kv->get_gather_idxs(dst, n_kv, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
+void llama_kv_cache_context::get_block_table(int32_t * dst, uint32_t n_blk) const {
|
||||
+ kv->get_block_table(dst, n_blk, n_kv, sinfos[i_cur]);
|
||||
+}
|
||||
+
|
||||
ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
|
||||
return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
|
||||
}
|
||||
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
|
||||
index f374ac6..e9980b6 100644
|
||||
--- a/src/llama-kv-cache.h
|
||||
+++ b/src/llama-kv-cache.h
|
||||
@@ -176,6 +176,9 @@ public:
|
||||
// gather-read. get_n_gather returns the max count across streams.
|
||||
uint32_t get_n_gather(uint32_t n_kv, const slot_info & sinfo) const;
|
||||
void get_gather_idxs(int32_t * dst, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
+ // [paged inc1] block table [n_blk, n_stream] (position order, padded to n_blk
|
||||
+ // per column with a masked empty cell) for the in-kernel paged read.
|
||||
+ void get_block_table(int32_t * dst, uint32_t n_blk, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
|
||||
@@ -386,6 +389,7 @@ public:
|
||||
// current ubatch's stream).
|
||||
uint32_t get_n_gather() const;
|
||||
void get_gather_idxs(int32_t * dst) const;
|
||||
+ void get_block_table(int32_t * dst, uint32_t n_blk) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
// note: the heads in k_cur and v_cur should be laid out contiguously in memory
|
||||
diff --git a/src/paged-attn.cpp b/src/paged-attn.cpp
|
||||
index ade75e8..8eebeaa 100644
|
||||
--- a/src/paged-attn.cpp
|
||||
+++ b/src/paged-attn.cpp
|
||||
@@ -43,6 +43,25 @@ public:
|
||||
ggml_tensor * idxs;
|
||||
};
|
||||
|
||||
+// Block table filler for the in-kernel paged read: fills an I32 [n_blk, n_stream]
|
||||
+// tensor with each stream's position-ordered cells, padded to n_blk (per column)
|
||||
+// with a masked empty cell, by delegating to the kv-cache context.
|
||||
+class input_block_table : public llm_graph_input_i {
|
||||
+public:
|
||||
+ input_block_table(const llama_kv_cache_context * mctx, ggml_tensor * idxs, uint32_t n_blk)
|
||||
+ : mctx(mctx), idxs(idxs), n_blk(n_blk) {}
|
||||
+
|
||||
+ void set_input(const llama_ubatch * ubatch) override {
|
||||
+ GGML_UNUSED(ubatch);
|
||||
+ GGML_ASSERT(idxs && ggml_backend_buffer_is_host(idxs->buffer));
|
||||
+ mctx->get_block_table((int32_t *) idxs->data, n_blk);
|
||||
+ }
|
||||
+
|
||||
+ const llama_kv_cache_context * mctx;
|
||||
+ ggml_tensor * idxs;
|
||||
+ uint32_t n_blk;
|
||||
+};
|
||||
+
|
||||
} // namespace
|
||||
|
||||
void gather(ggml_context * ctx0,
|
||||
@@ -125,4 +144,92 @@ void gather(ggml_context * ctx0,
|
||||
}
|
||||
}
|
||||
|
||||
+bool in_kernel_decode(ggml_context * ctx0,
|
||||
+ llm_graph_result * res,
|
||||
+ const llama_kv_cache_context * mctx,
|
||||
+ ggml_tensor ** k,
|
||||
+ ggml_tensor ** v,
|
||||
+ ggml_tensor ** kq_mask,
|
||||
+ ggml_tensor ** block_table) {
|
||||
+ if (!active()) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ // Bench escape hatch: LLAMA_KV_PAGED_GATHER=1 forces the old gather-read decode
|
||||
+ // path (for a same-build BEFORE/AFTER decode-step comparison). Dev-only.
|
||||
+ static const bool force_gather = (std::getenv("LLAMA_KV_PAGED_GATHER") != nullptr);
|
||||
+ if (force_gather) {
|
||||
+ return false;
|
||||
+ }
|
||||
+
|
||||
+ ggml_tensor * K = *k;
|
||||
+ ggml_tensor * V = *v;
|
||||
+ ggml_tensor * M = *kq_mask;
|
||||
+
|
||||
+ const int64_t n_stream = K->ne[3];
|
||||
+ GGML_ASSERT(M->ne[3] == n_stream);
|
||||
+
|
||||
+ const int64_t n_gather = (int64_t) mctx->get_n_gather();
|
||||
+ if (n_gather <= 0) {
|
||||
+ // Worst-case reserve / nothing placed yet: keep the dense [0,n_kv) read.
|
||||
+ return false;
|
||||
+ }
|
||||
+
|
||||
+ // The in-kernel read addresses V along its d-major (non-transposed) axis. If
|
||||
+ // the cache stores V transposed, fall back to gather() (which normalizes it).
|
||||
+ if (V->nb[1] > V->nb[2]) {
|
||||
+ return false;
|
||||
+ }
|
||||
+
|
||||
+ if (debug()) {
|
||||
+ static int64_t once = 0;
|
||||
+ if (once++ < 2) {
|
||||
+ fprintf(stderr, "[paged-attn] in-kernel decode n_stream=%lld n_kv=%lld n_gather=%lld\n",
|
||||
+ (long long) n_stream, (long long) K->ne[2], (long long) n_gather);
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ // Block table [n_gather, n_stream]: column s holds stream s's non-empty cells
|
||||
+ // in token-POSITION order (identical to the gather index, so the reduction
|
||||
+ // order matches stock bit-for-bit), padded with a masked empty cell. Filled
|
||||
+ // at set_input from the kv-cache (get_gather_idxs), exactly like the gather.
|
||||
+ // Pad the logical length to FATTN_KQ_STRIDE (256) so the CUDA fattn vec kernel
|
||||
+ // reads fixed 128-wide KV blocks without overrun and the KV_max mask scan
|
||||
+ // engages; padded entries point at a masked empty cell (0 contribution). Stays
|
||||
+ // <= n_kv since n_kv is itself padded to 256 and n_gather <= n_kv.
|
||||
+ int64_t n_view = GGML_PAD(n_gather, 256);
|
||||
+ if (n_view > K->ne[2]) {
|
||||
+ n_view = K->ne[2];
|
||||
+ }
|
||||
+
|
||||
+ ggml_tensor * idx = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_view, n_stream);
|
||||
+ ggml_set_input(idx);
|
||||
+ res->add_input(llm_graph_input_ptr(new input_block_table(mctx, idx, (uint32_t) n_view)));
|
||||
+
|
||||
+ // Present K and V as [d, h, n_view, ns] VIEWS of the full physical window:
|
||||
+ // identical per-cell (nb1,nb2) and per-stream (nb3) strides, only the cell
|
||||
+ // dim shrinks to n_view. NOT materialized - the kernel reads in place.
|
||||
+ *k = ggml_view_4d(ctx0, K, K->ne[0], K->ne[1], n_view, n_stream,
|
||||
+ K->nb[1], K->nb[2], K->nb[3], 0);
|
||||
+ *v = ggml_view_4d(ctx0, V, V->ne[0], V->ne[1], n_view, n_stream,
|
||||
+ V->nb[1], V->nb[2], V->nb[3], 0);
|
||||
+
|
||||
+ // Compact the mask to [n_gather, n_tps, 1, ns] in the same position order so
|
||||
+ // the kernel's logical mask index aligns with the block table. Cheap: the
|
||||
+ // mask is ~(d*h) smaller than K/V, which is why only its get_rows remains.
|
||||
+ {
|
||||
+ ggml_tensor * m = ggml_reshape_3d(ctx0, M, M->ne[0], M->ne[1], n_stream);
|
||||
+ m = ggml_cont(ctx0, ggml_transpose(ctx0, m));
|
||||
+ m = ggml_get_rows(ctx0, m, idx);
|
||||
+ m = ggml_cont(ctx0, ggml_transpose(ctx0, m));
|
||||
+ m = ggml_reshape_4d(ctx0, m, n_view, M->ne[1], 1, n_stream);
|
||||
+ if (M->type != m->type) {
|
||||
+ m = ggml_cast(ctx0, m, M->type);
|
||||
+ }
|
||||
+ *kq_mask = m;
|
||||
+ }
|
||||
+
|
||||
+ *block_table = idx;
|
||||
+ return true;
|
||||
+}
|
||||
+
|
||||
} // namespace paged_attn
|
||||
diff --git a/src/paged-attn.h b/src/paged-attn.h
|
||||
index c5b7bd7..23e2184 100644
|
||||
--- a/src/paged-attn.h
|
||||
+++ b/src/paged-attn.h
|
||||
@@ -37,4 +37,22 @@ void gather(ggml_context * ctx0,
|
||||
ggml_tensor ** v,
|
||||
ggml_tensor ** kq_mask);
|
||||
|
||||
+// [paged inc1] In-kernel paged decode read. Instead of materializing the
|
||||
+// sequence's cells (gather()), present K and V as n_gather-length VIEWS of the
|
||||
+// full physical window and return the position-ordered physical-cell index list
|
||||
+// as a block table (src[5] of ggml_flash_attn_ext). The fattn kernel/op then
|
||||
+// reads K_base + block_table[j]*nb in-kernel, removing the get_rows of K and V
|
||||
+// (the bulk of the gather). On return (true): *k,*v point at the views, *kq_mask
|
||||
+// at the compacted mask, *block_table at the I32 [n_gather, n_stream] index.
|
||||
+// Returns false (leaving *k,*v,*kq_mask untouched) when the in-kernel path does
|
||||
+// not apply - env off, nothing placed, or a transposed V cache - so the caller
|
||||
+// keeps the dense gather()/contiguous read.
|
||||
+bool in_kernel_decode(ggml_context * ctx0,
|
||||
+ llm_graph_result * res,
|
||||
+ const llama_kv_cache_context * mctx,
|
||||
+ ggml_tensor ** k,
|
||||
+ ggml_tensor ** v,
|
||||
+ ggml_tensor ** kq_mask,
|
||||
+ ggml_tensor ** block_table);
|
||||
+
|
||||
} // namespace paged_attn
|
||||
--
|
||||
2.43.0
|
||||
|
||||
Reference in New Issue
Block a user