From 2c5adda28cedac87958778aed318805dfa37b365 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Mon, 22 Jun 2026 20:37:12 +0000 Subject: [PATCH] feat(paged): tile in-kernel decode read + dispatch guard (patch 0010) Increment 2 (robustness): graft the patch-0009 phys(j) block-table read into the CUDA tile kernel (mirror of fattn-vec.cuh) and add a dispatch guard so a present block table (src[5]) routes ONLY to the vec or tile kernel, never to mma/wmma (which ignore the table and would silently read the wrong physical cells). Default route stays vec, the inc-1 byte-validated path. Gates: CPU byte-identical paged-on vs off (Qwen3-0.6B) PASS; GPU vec-paged == stock at -s 1 PASS; the real Qwen3-32B NVFP4 batch decode confirmed dispatching to vec (Q ne=[128,1,64,N]). The tile graft is plumbed for the increment-3 GQA head-group reuse but is EXPERIMENTAL/not byte-validated (LLAMA_KV_PAGED_TILE=1): the GQA-grouped ncols2>1 tile path reads a full nbatch_fa tile unbounded while the compacted paged mask is not padded to cover it. Bounding that path is increment-3 work; the default vec route is unaffected. Assisted-by: Claude:opus-4.8 [Claude Code] Signed-off-by: Ettore Di Giacinto --- ...nd-dispatch-guard-env-LLAMA_KV_PAGED.patch | 269 ++++++++++++++++++ 1 file changed, 269 insertions(+) create mode 100644 backend/cpp/llama-cpp/patches/paged/0010-paged-tile-in-kernel-read-and-dispatch-guard-env-LLAMA_KV_PAGED.patch diff --git a/backend/cpp/llama-cpp/patches/paged/0010-paged-tile-in-kernel-read-and-dispatch-guard-env-LLAMA_KV_PAGED.patch b/backend/cpp/llama-cpp/patches/paged/0010-paged-tile-in-kernel-read-and-dispatch-guard-env-LLAMA_KV_PAGED.patch new file mode 100644 index 000000000..1e6a5a57f --- /dev/null +++ b/backend/cpp/llama-cpp/patches/paged/0010-paged-tile-in-kernel-read-and-dispatch-guard-env-LLAMA_KV_PAGED.patch @@ -0,0 +1,269 @@ +From 9ac56933abd5de4a1f349c811c2d74aab09f7ab1 Mon Sep 17 00:00:00 2001 +From: Ettore Di Giacinto +Date: Mon, 22 Jun 2026 22:36:09 +0200 +Subject: [PATCH] paged tile in-kernel decode read + dispatch guard (env + LLAMA_KV_PAGED) - patch 0010 + +Increment 2 (robustness, ~0 headline ms): make the paged in-kernel decode read +safe against silent mis-routing, and plumb the same read into the tile kernel +for the increment-3 GQA head-group work. + +fattn-tile.cuh: graft the patch-0009 phys(j) block-table read (mirror of +fattn-vec.cuh). Both flash_attn_tile_load_tile overloads, flash_attn_tile_iter_KQ +(K) and flash_attn_tile_iter (V) take an optional per-sequence block table; a row +i is read from base + block_table[row_base + i]*stride instead of base + i*stride. +The table defaults to nullptr (default args + a null bt_seq when src[5] is unset), +so every existing non-paged caller is byte-identical to stock. The mask / KV_max +stay logical (token-position order), as in vec. + +fattn.cu: DISPATCH GUARD. When the block table (src[5]) is present, route ONLY to +the vec or tile kernel and never fall through to the best-kernel switch. The +mma/wmma kernels GGML_UNUSED the table and would silently read the wrong +(contiguous physical) cells; the guard makes that unreachable. The vec dispatcher +GGML_ABORTs for an unsupported D/type rather than mis-reading. Default route is vec +(the inc-1 byte-validated path). LLAMA_KV_PAGED_DISPATCH_LOG=1 prints the routed +kernel once. + +Gates: CPU byte-identical paged-on vs off (Qwen3-0.6B, build-cpu) PASS. GPU +vec-paged == stock at -s 1 PASS. Dispatch confirmed VEC for the real decode shape: +Qwen3-0.6B Q ne=[128,1,16,1] and Qwen3-32B NVFP4 Q ne=[128,1,64,N] both route to +vec, matching the nsys profile (flash_attn_ext_vec). + +The tile graft is plumbed for increment-3 GQA head-group reuse but is EXPERIMENTAL +and NOT yet byte-validated (LLAMA_KV_PAGED_TILE=1). A tile-vs-tile gate shows +tile-paged diverging from tile-stock at the first cross-tile KV depth: the +GQA-grouped (ncols2>1) tile path reads a full nbatch_fa-row tile with +oob_check=false while the compacted paged mask is not padded to cover the tile, so +past-end rows leak. vec bounds its KV walk by KV_max and is unaffected. Bounding +the tile path is increment-3 work; the default vec route and all stock paths are +untouched. + +Assisted-by: Claude:opus-4.8 [Claude Code] +Signed-off-by: Ettore Di Giacinto +--- + ggml/src/ggml-cuda/fattn-tile.cuh | 45 ++++++++++++++++++++----------- + ggml/src/ggml-cuda/fattn.cu | 38 +++++++++++++++++++++++--- + 2 files changed, 64 insertions(+), 19 deletions(-) + +diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh +index 0ff14e6..bb84d61 100644 +--- a/ggml/src/ggml-cuda/fattn-tile.cuh ++++ b/ggml/src/ggml-cuda/fattn-tile.cuh +@@ -373,7 +373,8 @@ static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, + // TODO: deduplicate with mma-f16 + template + static __device__ __forceinline__ void flash_attn_tile_load_tile( +- const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { ++ const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup, ++ const int * const __restrict__ block_table = nullptr, const int row_base = 0) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + +@@ -402,9 +403,11 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( + const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne; + + const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}}; ++ // [paged] remap the row through the block table (nullptr => stock contiguous read). ++ const half2 * const KV_row = block_table ? KV + (int64_t) block_table[row_base + i]*stride_KV : KV + i*stride_KV; + ggml_cuda_memcpy_1( + tile_KV + i*(J/2 + J_padding) + j, +- !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); ++ !oob_check || i < i_sup ? KV_row + j : zero); + } + } + } +@@ -423,7 +426,8 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( + + template + static __device__ __forceinline__ void flash_attn_tile_load_tile( +- const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { ++ const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup, ++ const int * const __restrict__ block_table = nullptr, const int row_base = 0) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + +@@ -453,8 +457,10 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( + + const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}}; + __align__(16) half2 tmp_h2[cpy_ne/2]; ++ // [paged] remap the row through the block table (nullptr => stock contiguous read). ++ const half2 * const KV_row = block_table ? KV + (int64_t) block_table[row_base + i]*stride_KV : KV + i*stride_KV; + ggml_cuda_memcpy_1( +- tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); ++ tmp_h2, !oob_check || i < i_sup ? KV_row + j : zero); + + __align__(16) float2 tmp_f2[cpy_ne/2]; + #pragma unroll +@@ -487,6 +493,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ( + const int k_VKQ_0, + const int k_VKQ_sup, + const int k_KQ_0, ++ const int * const __restrict__ block_table, + float * KQ_acc) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; +@@ -495,8 +502,10 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ( + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + ++ // [paged] when block_table is set K_h2 is the un-offset base; the table supplies the row. ++ const half2 * const K_base = block_table ? (K_h2 + k_KQ_0/2) : (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2); + flash_attn_tile_load_tile +- (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); ++ (K_base, KV_tmp, stride_K2, k_VKQ_sup, block_table, k_VKQ_0); + __syncthreads(); + + #ifdef FAST_FP16_AVAILABLE +@@ -572,7 +581,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter( + T_acc * const VKQ, + const int k_VKQ_0, + const int k_VKQ_max, +- const int col_Q_0) { ++ const int col_Q_0, ++ const int * const __restrict__ block_table) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + +@@ -605,12 +615,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter( + #pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) { + flash_attn_tile_iter_KQ( +- Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); ++ Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, block_table, KQ_acc); + } + if (nbatch_K_last > 0) { + constexpr int k_KQ_0 = DKQ - nbatch_K_last; + flash_attn_tile_iter_KQ( +- Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); ++ Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, block_table, KQ_acc); + } + + // Apply logit softcap + mask, update KQ_max: +@@ -715,8 +725,10 @@ static __device__ __forceinline__ void flash_attn_tile_iter( + static_assert(nbatch_V % np == 0, "bad nbatch_V"); + #pragma unroll + for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { ++ // [paged] when block_table is set V_h2 is the un-offset base; the table supplies the row. ++ const half2 * const V_base = block_table ? V_h2 : (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2); + flash_attn_tile_load_tile +- (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); ++ (V_base, KV_tmp, stride_V2, k_VKQ_sup - k0, block_table, k_VKQ_0 + k0); + __syncthreads(); + + #ifdef FAST_FP16_AVAILABLE +@@ -810,7 +822,6 @@ static __global__ void flash_attn_tile( + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33, + const int * __restrict__ block_table) { +- GGML_UNUSED(block_table); // [paged] block table is honored only by the vec kernel + #ifdef FLASH_ATTN_AVAILABLE + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; +@@ -837,7 +848,7 @@ static __global__ void flash_attn_tile( + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, +- nb31, nb32, nb33); ++ nb31, nb32, nb33, block_table); + NO_DEVICE_CODE; + return; + } +@@ -861,6 +872,10 @@ static __global__ void flash_attn_tile( + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape + ++ // [paged] per-sequence logical->physical block table in token-position order ++ // (mask/KV_max stay logical); nullptr => the stock contiguous read. ++ const int * const __restrict__ bt_seq = block_table ? block_table + (size_t) sequence*ne11 : nullptr; ++ + const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr; + + const int stride_K2 = nb11 / sizeof(half2); +@@ -963,14 +978,14 @@ static __global__ void flash_attn_tile( + constexpr bool oob_check = false; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, +- stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); ++ stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0, bt_seq); + k_VKQ_0 += gridDim.y*nbatch_fa; + } + if (k_VKQ_0 < k_VKQ_max) { + constexpr bool oob_check = true; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, +- stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); ++ stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0, bt_seq); + } + } else { + // Branch without out-of-bounds checks. +@@ -978,7 +993,7 @@ static __global__ void flash_attn_tile( + constexpr bool oob_check = false; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, +- stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); ++ stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0, bt_seq); + } + } + +@@ -1144,7 +1159,7 @@ static __global__ void flash_attn_tile( + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, +- nb31, nb32, nb33); ++ nb31, nb32, nb33, block_table); + NO_DEVICE_CODE; + #endif // FLASH_ATTN_AVAILABLE + } +diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu +index e3771ee..afcafa2 100644 +--- a/ggml/src/ggml-cuda/fattn.cu ++++ b/ggml/src/ggml-cuda/fattn.cu +@@ -575,11 +575,41 @@ size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * d + void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_set_device(ctx.device); + +- // [paged] the block table (src[5]) is only honored by the vec kernel's +- // in-kernel read; force it. build_attn only sets it for a vec-supported +- // 1-token-per-stream decode shape. ++ // [paged] DISPATCH GUARD. The block table (src[5]) is read in-kernel ONLY by ++ // the vec and tile kernels; the mma/wmma kernels GGML_UNUSED it and would ++ // silently read the wrong (contiguous physical) cells. So when a block table ++ // is present we route here and NEVER fall through to the best-kernel switch ++ // below - no decode shape can silently reach an mma/wmma misread. build_attn ++ // only sets src[5] for the 1-token-per-stream decode shape; the vec ++ // dispatcher GGML_ABORTs for an unsupported D/type rather than mis-reading, ++ // and any shape that should not be paged must take the host-side gather path ++ // (LLAMA_KV_PAGED_GATHER=1) instead. ++ // ++ // Default route = vec (inc-1, byte-validated: vec-paged == stock at -s 1 and ++ // CPU byte-identical). LLAMA_KV_PAGED_TILE=1 routes the same shape to the ++ // tile kernel; the tile in-kernel read is plumbed (fattn-tile.cuh) for the ++ // increment-3 GQA head-group reuse, but is EXPERIMENTAL / NOT yet byte- ++ // validated: the GQA-grouped (ncols2>1) tile path reads a full nbatch_fa tile ++ // with oob_check=false while the compacted paged mask is not padded to cover ++ // it, so it diverges from stock. Not for production paged decode until ++ // increment-3 bounds that path; the default vec route is unaffected. + if (dst->src[5] != nullptr) { +- ggml_cuda_flash_attn_ext_vec(ctx, dst); ++ static const bool paged_tile = getenv("LLAMA_KV_PAGED_TILE") != nullptr; ++ if (getenv("LLAMA_KV_PAGED_DISPATCH_LOG") != nullptr) { ++ static bool logged = false; ++ if (!logged) { ++ logged = true; ++ fprintf(stderr, "[paged] decode src[5] set -> routing to %s (Q ne=[%ld,%ld,%ld,%ld])\n", ++ paged_tile ? "TILE(experimental)" : "VEC", ++ (long) dst->src[0]->ne[0], (long) dst->src[0]->ne[1], ++ (long) dst->src[0]->ne[2], (long) dst->src[0]->ne[3]); ++ } ++ } ++ if (paged_tile) { ++ ggml_cuda_flash_attn_ext_tile(ctx, dst); ++ } else { ++ ggml_cuda_flash_attn_ext_vec(ctx, dst); ++ } + return; + } + +-- +2.43.0 +