feat(paged): tile in-kernel decode read + dispatch guard (patch 0010)

Increment 2 (robustness): graft the patch-0009 phys(j) block-table read into
the CUDA tile kernel (mirror of fattn-vec.cuh) and add a dispatch guard so a
present block table (src[5]) routes ONLY to the vec or tile kernel, never to
mma/wmma (which ignore the table and would silently read the wrong physical
cells). Default route stays vec, the inc-1 byte-validated path.

Gates: CPU byte-identical paged-on vs off (Qwen3-0.6B) PASS; GPU vec-paged ==
stock at -s 1 PASS; the real Qwen3-32B NVFP4 batch decode confirmed dispatching
to vec (Q ne=[128,1,64,N]). The tile graft is plumbed for the increment-3 GQA
head-group reuse but is EXPERIMENTAL/not byte-validated (LLAMA_KV_PAGED_TILE=1):
the GQA-grouped ncols2>1 tile path reads a full nbatch_fa tile unbounded while
the compacted paged mask is not padded to cover it. Bounding that path is
increment-3 work; the default vec route is unaffected.

Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-06-22 20:37:12 +00:00
parent ee13a94a8c
commit 2c5adda28c

View File

@@ -0,0 +1,269 @@
From 9ac56933abd5de4a1f349c811c2d74aab09f7ab1 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto <mudler@localai.io>
Date: Mon, 22 Jun 2026 22:36:09 +0200
Subject: [PATCH] paged tile in-kernel decode read + dispatch guard (env
LLAMA_KV_PAGED) - patch 0010
Increment 2 (robustness, ~0 headline ms): make the paged in-kernel decode read
safe against silent mis-routing, and plumb the same read into the tile kernel
for the increment-3 GQA head-group work.
fattn-tile.cuh: graft the patch-0009 phys(j) block-table read (mirror of
fattn-vec.cuh). Both flash_attn_tile_load_tile overloads, flash_attn_tile_iter_KQ
(K) and flash_attn_tile_iter (V) take an optional per-sequence block table; a row
i is read from base + block_table[row_base + i]*stride instead of base + i*stride.
The table defaults to nullptr (default args + a null bt_seq when src[5] is unset),
so every existing non-paged caller is byte-identical to stock. The mask / KV_max
stay logical (token-position order), as in vec.
fattn.cu: DISPATCH GUARD. When the block table (src[5]) is present, route ONLY to
the vec or tile kernel and never fall through to the best-kernel switch. The
mma/wmma kernels GGML_UNUSED the table and would silently read the wrong
(contiguous physical) cells; the guard makes that unreachable. The vec dispatcher
GGML_ABORTs for an unsupported D/type rather than mis-reading. Default route is vec
(the inc-1 byte-validated path). LLAMA_KV_PAGED_DISPATCH_LOG=1 prints the routed
kernel once.
Gates: CPU byte-identical paged-on vs off (Qwen3-0.6B, build-cpu) PASS. GPU
vec-paged == stock at -s 1 PASS. Dispatch confirmed VEC for the real decode shape:
Qwen3-0.6B Q ne=[128,1,16,1] and Qwen3-32B NVFP4 Q ne=[128,1,64,N] both route to
vec, matching the nsys profile (flash_attn_ext_vec).
The tile graft is plumbed for increment-3 GQA head-group reuse but is EXPERIMENTAL
and NOT yet byte-validated (LLAMA_KV_PAGED_TILE=1). A tile-vs-tile gate shows
tile-paged diverging from tile-stock at the first cross-tile KV depth: the
GQA-grouped (ncols2>1) tile path reads a full nbatch_fa-row tile with
oob_check=false while the compacted paged mask is not padded to cover the tile, so
past-end rows leak. vec bounds its KV walk by KV_max and is unaffected. Bounding
the tile path is increment-3 work; the default vec route and all stock paths are
untouched.
Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
---
ggml/src/ggml-cuda/fattn-tile.cuh | 45 ++++++++++++++++++++-----------
ggml/src/ggml-cuda/fattn.cu | 38 +++++++++++++++++++++++---
2 files changed, 64 insertions(+), 19 deletions(-)
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
index 0ff14e6..bb84d61 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -373,7 +373,8 @@ static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ,
// TODO: deduplicate with mma-f16
template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
static __device__ __forceinline__ void flash_attn_tile_load_tile(
- const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
+ const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup,
+ const int * const __restrict__ block_table = nullptr, const int row_base = 0) {
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
@@ -402,9 +403,11 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;
const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
+ // [paged] remap the row through the block table (nullptr => stock contiguous read).
+ const half2 * const KV_row = block_table ? KV + (int64_t) block_table[row_base + i]*stride_KV : KV + i*stride_KV;
ggml_cuda_memcpy_1<cpy_nb>(
tile_KV + i*(J/2 + J_padding) + j,
- !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
+ !oob_check || i < i_sup ? KV_row + j : zero);
}
}
}
@@ -423,7 +426,8 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
static __device__ __forceinline__ void flash_attn_tile_load_tile(
- const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
+ const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup,
+ const int * const __restrict__ block_table = nullptr, const int row_base = 0) {
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
@@ -453,8 +457,10 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
__align__(16) half2 tmp_h2[cpy_ne/2];
+ // [paged] remap the row through the block table (nullptr => stock contiguous read).
+ const half2 * const KV_row = block_table ? KV + (int64_t) block_table[row_base + i]*stride_KV : KV + i*stride_KV;
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
- tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
+ tmp_h2, !oob_check || i < i_sup ? KV_row + j : zero);
__align__(16) float2 tmp_f2[cpy_ne/2];
#pragma unroll
@@ -487,6 +493,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
const int k_VKQ_0,
const int k_VKQ_sup,
const int k_KQ_0,
+ const int * const __restrict__ block_table,
float * KQ_acc) {
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
@@ -495,8 +502,10 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
+ // [paged] when block_table is set K_h2 is the un-offset base; the table supplies the row.
+ const half2 * const K_base = block_table ? (K_h2 + k_KQ_0/2) : (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2);
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
- (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
+ (K_base, KV_tmp, stride_K2, k_VKQ_sup, block_table, k_VKQ_0);
__syncthreads();
#ifdef FAST_FP16_AVAILABLE
@@ -572,7 +581,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
T_acc * const VKQ,
const int k_VKQ_0,
const int k_VKQ_max,
- const int col_Q_0) {
+ const int col_Q_0,
+ const int * const __restrict__ block_table) {
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
@@ -605,12 +615,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) {
flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>(
- Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
+ Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, block_table, KQ_acc);
}
if (nbatch_K_last > 0) {
constexpr int k_KQ_0 = DKQ - nbatch_K_last;
flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>(
- Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
+ Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, block_table, KQ_acc);
}
// Apply logit softcap + mask, update KQ_max:
@@ -715,8 +725,10 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
static_assert(nbatch_V % np == 0, "bad nbatch_V");
#pragma unroll
for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
+ // [paged] when block_table is set V_h2 is the un-offset base; the table supplies the row.
+ const half2 * const V_base = block_table ? V_h2 : (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2);
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
- (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
+ (V_base, KV_tmp, stride_V2, k_VKQ_sup - k0, block_table, k_VKQ_0 + k0);
__syncthreads();
#ifdef FAST_FP16_AVAILABLE
@@ -810,7 +822,6 @@ static __global__ void flash_attn_tile(
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33,
const int * __restrict__ block_table) {
- GGML_UNUSED(block_table); // [paged] block table is honored only by the vec kernel
#ifdef FLASH_ATTN_AVAILABLE
const char * GGML_CUDA_RESTRICT Q = Q_ptr;
const char * GGML_CUDA_RESTRICT K = K_ptr;
@@ -837,7 +848,7 @@ static __global__ void flash_attn_tile(
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
- nb31, nb32, nb33);
+ nb31, nb32, nb33, block_table);
NO_DEVICE_CODE;
return;
}
@@ -861,6 +872,10 @@ static __global__ void flash_attn_tile(
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape
+ // [paged] per-sequence logical->physical block table in token-position order
+ // (mask/KV_max stay logical); nullptr => the stock contiguous read.
+ const int * const __restrict__ bt_seq = block_table ? block_table + (size_t) sequence*ne11 : nullptr;
+
const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr;
const int stride_K2 = nb11 / sizeof(half2);
@@ -963,14 +978,14 @@ static __global__ void flash_attn_tile(
constexpr bool oob_check = false;
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
- stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
+ stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0, bt_seq);
k_VKQ_0 += gridDim.y*nbatch_fa;
}
if (k_VKQ_0 < k_VKQ_max) {
constexpr bool oob_check = true;
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
- stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
+ stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0, bt_seq);
}
} else {
// Branch without out-of-bounds checks.
@@ -978,7 +993,7 @@ static __global__ void flash_attn_tile(
constexpr bool oob_check = false;
flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
- stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
+ stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0, bt_seq);
}
}
@@ -1144,7 +1159,7 @@ static __global__ void flash_attn_tile(
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
- nb31, nb32, nb33);
+ nb31, nb32, nb33, block_table);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index e3771ee..afcafa2 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -575,11 +575,41 @@ size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * d
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_set_device(ctx.device);
- // [paged] the block table (src[5]) is only honored by the vec kernel's
- // in-kernel read; force it. build_attn only sets it for a vec-supported
- // 1-token-per-stream decode shape.
+ // [paged] DISPATCH GUARD. The block table (src[5]) is read in-kernel ONLY by
+ // the vec and tile kernels; the mma/wmma kernels GGML_UNUSED it and would
+ // silently read the wrong (contiguous physical) cells. So when a block table
+ // is present we route here and NEVER fall through to the best-kernel switch
+ // below - no decode shape can silently reach an mma/wmma misread. build_attn
+ // only sets src[5] for the 1-token-per-stream decode shape; the vec
+ // dispatcher GGML_ABORTs for an unsupported D/type rather than mis-reading,
+ // and any shape that should not be paged must take the host-side gather path
+ // (LLAMA_KV_PAGED_GATHER=1) instead.
+ //
+ // Default route = vec (inc-1, byte-validated: vec-paged == stock at -s 1 and
+ // CPU byte-identical). LLAMA_KV_PAGED_TILE=1 routes the same shape to the
+ // tile kernel; the tile in-kernel read is plumbed (fattn-tile.cuh) for the
+ // increment-3 GQA head-group reuse, but is EXPERIMENTAL / NOT yet byte-
+ // validated: the GQA-grouped (ncols2>1) tile path reads a full nbatch_fa tile
+ // with oob_check=false while the compacted paged mask is not padded to cover
+ // it, so it diverges from stock. Not for production paged decode until
+ // increment-3 bounds that path; the default vec route is unaffected.
if (dst->src[5] != nullptr) {
- ggml_cuda_flash_attn_ext_vec(ctx, dst);
+ static const bool paged_tile = getenv("LLAMA_KV_PAGED_TILE") != nullptr;
+ if (getenv("LLAMA_KV_PAGED_DISPATCH_LOG") != nullptr) {
+ static bool logged = false;
+ if (!logged) {
+ logged = true;
+ fprintf(stderr, "[paged] decode src[5] set -> routing to %s (Q ne=[%ld,%ld,%ld,%ld])\n",
+ paged_tile ? "TILE(experimental)" : "VEC",
+ (long) dst->src[0]->ne[0], (long) dst->src[0]->ne[1],
+ (long) dst->src[0]->ne[2], (long) dst->src[0]->ne[3]);
+ }
+ }
+ if (paged_tile) {
+ ggml_cuda_flash_attn_ext_tile(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec(ctx, dst);
+ }
return;
}
--
2.43.0