feat(paged): chunked parallel-scan GDN prefill kernel (patch 0031)

Adds patch 0031 to the paged llama.cpp series: an FLA-style chunked
parallel-scan prefill kernel for gated DeltaNet (the upstream
gated_delta_net.cu "Add chunked kernel for even faster pre-fill" TODO).
Scope: non-KDA scalar gate, f32 state, final-state-only, homogeneous.

Bit-exact-benign (NEW per-path): test-backend-ops GATED_DELTA_NET 91/91 within
the 1e-7 NMSE gate vs the CPU reference (patch adds 8 S_v=128 prefill cases:
exact-multiple / tail / multi-seq / GQA / permuted); numpy prototype confirms
f32 chunked-vs-sequential NMSE ~1e-13.

OPT-IN, default OFF: GB10's 99KB dynamic-smem opt-in forces C=16 (the 128x128
f32 state is 64KB of the all-shared layout), pinning the kernel to 1 block/SM
with serial dk-reductions. Measured ~761 t/s chunked vs ~971 t/s sequential
(~22%% slower) on q36-27b-nvfp4 prefill, so it defaults OFF (enable with
GDN_CHUNK_MIN=<n>); the backend default is regression-free. Beating the
84.7%-of-peak sequential scan needs tensor-core matmuls / register-resident
state with larger chunks (recorded in README section 5).

Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-06-28 17:09:38 +00:00
parent 11128cb080
commit e610347367
2 changed files with 374 additions and 2 deletions

View File

@@ -86,9 +86,9 @@ orthogonal to the paged allocator.
---
## 3. Patch series (0001-0030)
## 3. Patch series (0001-0031)
28 patches (0005 and 0027 are intentionally unused). "Bit-exact" = greedy md5 /
29 patches (0005 and 0027 are intentionally unused). "Bit-exact" = greedy md5 /
`test-backend-ops` byte-identical to the relevant baseline; the gate methodology
is in section 5.
@@ -149,6 +149,7 @@ These are the dominant decode levers on the Qwen3.6 hybrid models. All bit-exact
| 0024 | **Paged-pool burst-reclaim** - truncate trailing blocks on partial-tail `seq_rm`, defrag the free queue when idle, release blocks on slot completion. Fixes the long-server burst-degradation bug (post-burst prefill collapse 488->44 t/s, restored to 532). Host-side accounting only. | yes |
| 0029 | **Block-table within-step host cache** - the block table is fixed for the whole step; cache it on first build and memcpy it for the other full-attention layers (get_block_table -87%/-91%). | yes, per path (paged-MoE ref `8cb0ce23`) |
| 0030 | **Fused-op backend gate** - the fused GDN / discriminated SSM_CONV ops are CUDA-family + CPU only; force them off on any non-CUDA compute backend so a Vulkan/SYCL/Metal build can't silently run the wrong plain-conv kernel. | yes on CUDA (byte-identical pre-0030); safety gate elsewhere |
| 0031 | **Chunked parallel-scan GDN prefill kernel** (upstream TODO) - FLA-style chunked gated-delta-rule for prefill (non-KDA / f32 / final-state): intra-chunk delta rule solved in parallel (UT-transform + forward subst), inter-chunk recurrence over n_tokens/C steps. **OPT-IN, default OFF** - bit-exact-benign but not yet faster than the tuned sequential scan at the GB10-forced C=16 (see section 5). Enable with `GDN_CHUNK_MIN=<n>`. | NEW per-path (`test-backend-ops` 91/91, <=1e-7 NMSE vs CPU ref) |
> **Dropped: patch 0026 (hybrid per-head bf16 SSM state, `ssm_bf16_tau`).** Once
> the decode fusions (0028 recurrent-state gather-fusion + 0029 block-table cache)
@@ -314,6 +315,23 @@ llama is losing. The MoE GEMM kernel is *not* where the gap lives.
needs bought with a ~5% slower kernel; both kernels are already at the BW floor.
(The "the win was NVFP4-dense-quant, not the Marlin kernel" dense verdict
carries over to MoE.)
- **Chunked parallel-scan GDN prefill (patch 0031): CORRECT, FLAT-to-SLOWER at
C=16; kept OPT-IN.** Implements the upstream "faster pre-fill" TODO - the
FLA-style chunked gated-delta-rule (intra-chunk delta rule solved in parallel
via the UT-transform + forward substitution, inter-chunk recurrence over
n_tokens/C steps). The math is validated equivalent (numpy f32 NMSE ~1e-13;
`test-backend-ops` 91/91 within the 1e-7 NMSE gate, a NEW per-path result).
**But GB10's 99KB dynamic-smem opt-in forces C=16** (the 128x128 f32 state alone
is 64KB of the all-shared layout), which pins the kernel to 1 block/SM and
serial per-thread dk-reductions; measured S_PP (q36-27b-nvfp4, `-npp 512 -ntg 4
-npl 32`) is **~761 t/s chunked vs ~971 t/s sequential (~22% slower)**, also
grid-starved at low n_seqs. So it ships default-OFF (`GDN_CHUNK_MIN=<n>` to
enable). To actually beat the (already 84.7%-of-peak) sequential scan the
follow-up must lift the occupancy ceiling and the serial reductions: either
register-resident state with static-unrolled larger chunks, or tensor-core
(mma/wgmma) matmuls for the KK/QK/KS/QS/PU products and the A-inverse - the
structure FLA/vLLM use. Lesson: at this head dim the win needs tensor cores,
not just chunking.
**Opt-in bf16-SSM fast mode - DROPPED (was patch 0026, `ssm_bf16_tau`).** The
design premise - that bf16 KL error concentrates in long-memory heads and can be

View File

@@ -0,0 +1,354 @@
From c9bf1bd0000000000000000000000000000031aa Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto <mudler@localai.io>
Date: Sun, 28 Jun 2026 12:00:00 +0000
Subject: [PATCH] feat(paged): chunked parallel-scan GDN prefill kernel (patch 0031)
Implements the explicit upstream TODO at gated_delta_net.cu's
launch_gated_delta_net ("Add chunked kernel for even faster pre-fill"). The
stock kernel runs a strictly sequential per-token recurrence (one block per
(head,seq) looping over all n_tokens), so prefill cannot use token-level
parallelism - a confirmed gap versus vLLM, which uses an FLA-style chunked
scan.
What this adds
--------------
A chunked parallel-scan prefill path for gated DeltaNet, gated to the
compile-time subset that matters for Qwen3.6 prefill: non-KDA (scalar gate),
f32 state, final-state-only (keep_rs == false), homogeneous (non-hybrid,
non-bf16-state). One block per (head,seq); thread j owns the j-th v-column.
The sequence is split into chunks of C tokens: the inter-chunk recurrence in
the state S stays sequential (n_tokens/C steps instead of n_tokens), while the
intra-chunk gated delta rule is solved in parallel via the FLA chunked form:
gamma_t = prod_{i<=t} g_i (<=1), d(j,t) = gamma_t / gamma_j in (0,1]
A = I + tril(beta_t d(j,t) (k_t . k_j), -1) [unit lower-tri, C x C]
U = A^{-1} ( beta_t (v_t - gamma_t S0^T k_t) ) (forward substitution)
O_t = gamma_t (S0^T q_t) + sum_{j<=t} d(j,t)(q_t . k_j) u_j (then * scale)
S_C = gamma_C S0 + sum_t d(t,C) k_t u_t^T
This uses the bounded/stable de-gating (pairwise decays d <= 1, gamma <= 1), so
strong-decay tokens underflow to the correct zero rather than to inf - it is
numerically robust even for the adversarial g in [-20, -1e-4] of the op test.
Bit-exactness (NEW per-path)
----------------------------
The chunked form is mathematically equivalent to the sequential recurrence but
reduces in a different FP order, so it is a NEW path (its md5 will not match the
sequential path), gated exactly like the paged-vs-nonpaged precedent. A numpy
prototype confirms f32 chunked-vs-sequential NMSE ~1e-13 (max abs ~1e-7).
test-backend-ops GATED_DELTA_NET is 91/91 (this patch adds 8 S_v=128 prefill
cases: exact-multiple / tail / multi-seq / GQA / permuted), i.e. within the
default 1e-7 NMSE gate versus the CPU reference.
Disposition: OPT-IN, default OFF (no regression)
------------------------------------------------
GB10's max dynamic shared-memory opt-in is 99KB, so the all-shared layout that
keeps the 128x128 state resident forces C=16 (89KB). At C=16, with one block /
SM (the 64KB state dominates shared) and serial per-thread dk-reductions, the
kernel is correct but NOT yet faster than the already-tuned sequential
recurrence: measured S_PP on q36-27b-nvfp4 (llama-batched-bench -npp 512 -ntg 4
-npl 32) is ~761 t/s chunked vs ~971 t/s sequential (~22% slower, also
grid-starved at low n_seqs). It is therefore wired OPT-IN: the default
(no env) keeps the sequential path, and the chunked path is enabled with
GDN_CHUNK_MIN=<token-threshold>. The default backend behaviour is unchanged.
cudaFuncSetAttribute's return is checked (a silent failure when the requested
dynamic smem exceeded the device opt-in left a sticky CUDA error during
bring-up).
Remaining work to make it a win (recorded for the follow-up): break the 1
block/SM occupancy ceiling (the 64KB state in shared) and the serial
dk-reductions - either register-resident state with static-unrolled (larger)
chunks, or tensor-core (mma/wgmma) matmuls for the KK/QK/KS/QS/PU products and
the A-inverse, which is what FLA/vLLM use to beat the sequential scan. See
README section 5 (dev notes / rejected-flat levers).
Assisted-by: Claude:opus-4.8 [Claude Code]
---
ggml/src/ggml-cuda/gated_delta_net.cu | 235 ++++++++++++++++++++++++++++++++++++++++++++++++++++
tests/test-backend-ops.cpp | 8 ++++++++
2 files changed, 243 insertions(+)
diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu
index 830118a..c9bf1bd 100644
--- a/ggml/src/ggml-cuda/gated_delta_net.cu
+++ b/ggml/src/ggml-cuda/gated_delta_net.cu
@@ -1,6 +1,7 @@
#include "gated_delta_net.cuh"
#include "ggml-cuda/common.cuh"
+#include <climits>
#include <cstdlib>
#include <cuda_bf16.h>
#include <type_traits>
@@ -407,6 +408,219 @@ static void launch_gdn_variant(
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K, state_dst_d, ids_d, rs_head, hyb);
}
+// ============================================================================
+// CHUNKED parallel-scan prefill kernel (upstream TODO: "faster pre-fill").
+// Scope: non-KDA (scalar gate), f32 state, final-state-only (keep_rs==false),
+// homogeneous (non-hybrid) path. One block per (head, seq); thread j owns the
+// j-th v-column. The sequence is split into chunks of C tokens; the inter-chunk
+// recurrence in S is sequential (n_tokens/C steps instead of n_tokens), and the
+// intra-chunk gated delta rule is solved in parallel via the FLA chunked form:
+// gamma_t = prod_{i<=t} g_i (<=1), d(j,t) = gamma_t / gamma_j in (0,1]
+// A = I + tril(beta_t d(j,t) (k_t . k_j), -1) [Cc x Cc unit lower-tri]
+// U = A^{-1} ( beta_t (v_t - gamma_t S0^T k_t) ) [Cc x dv] (fwd subst)
+// O_t = gamma_t (S0^T q_t) + sum_{j<=t} d(j,t)(q_t . k_j) u_j (then * scale)
+// S_C = gamma_C S0 + sum_t d(t,C) k_t u_t^T
+// This is the bounded/stable de-gating (pairwise decays d <= 1, gamma <= 1), so
+// strong-decay tokens underflow to the correct zero rather than to inf. The math
+// is equivalent to the sequential recurrence up to FP reduction order (a NEW
+// per-path result, validated benign by test-backend-ops NMSE and greedy output).
+template <int S_v, int C>
+__global__ void gated_delta_net_chunked_cuda(
+ const float * __restrict__ q, const float * __restrict__ k,
+ const float * __restrict__ v, const float * __restrict__ g,
+ const float * __restrict__ beta, const float * __restrict__ curr_state,
+ float * __restrict__ dst,
+ int64_t H, int64_t n_tokens, int64_t n_seqs,
+ int64_t sq1, int64_t sq2, int64_t sq3,
+ int64_t sv1, int64_t sv2, int64_t sv3,
+ int64_t sb1, int64_t sb2, int64_t sb3,
+ uint3 neqk1_magic, uint3 rq3_magic,
+ float scale, float * __restrict__ state_dst,
+ const int32_t * __restrict__ ids, int rs_head) {
+ constexpr int dk = S_v;
+ constexpr int dv = S_v;
+ const int h_idx = blockIdx.x;
+ const int seq = blockIdx.y;
+ const int j = threadIdx.x; // this thread's v-column (0..dv-1)
+
+ const uint32_t iq1 = fastmodulo((uint32_t) h_idx, neqk1_magic);
+ const uint32_t iq3 = fastdiv((uint32_t) seq, rq3_magic);
+
+ extern __shared__ float gdn_smem[];
+ float * Sd = gdn_smem; // [dk*dv] M-layout: Sd[col*dk + i] = S[i][col]
+ float * Kc = Sd + (size_t) dk * dv; // [C*dk] Kc[t*dk + i]
+ float * Qc = Kc + (size_t) C * dk; // [C*dk] Qc[t*dk + i]
+ float * Ud = Qc + (size_t) C * dk; // [dv*C] column-major per thread: Ud[col*C + t]
+ float * Amat = Ud + (size_t) dv * C; // [C*C] A / P scratch, row-major Amat[t*C + t']
+ float * csh = Amat + (size_t) C * C; // [C] cumsum(log-gate)
+ float * gam = csh + C; // [C] gamma_t = exp(cs_t)
+ float * bet = gam + C; // [C] beta_t
+
+ // S0: thread j owns column j (Sd[j*dk + i]); load is a contiguous per-thread copy from the
+ // M-layout cache view (read_state[j*dk + i] = M[j*S_v + i] = S[i][j]). Same identity/gather
+ // plumbing as the sequential kernel (gather of non-identity seqs done by the dispatcher).
+ const bool identity = (ids != nullptr && ids[seq] == rs_head + seq);
+ const float * read_state = (identity ? state_dst : curr_state)
+ + (int64_t) seq * H * dk * dv + (int64_t) h_idx * dk * dv;
+ for (int i = 0; i < dk; i++) {
+ Sd[j * dk + i] = read_state[j * dk + i];
+ }
+
+ const float * q_base = q + iq3 * sq3 + iq1 * sq1; // + t*sq2 + i
+ const float * k_base = k + iq3 * sq3 + iq1 * sq1;
+ const float * v_base = v + seq * sv3 + h_idx * sv1; // + t*sv2 + j
+ const int64_t gb_base = seq * sb3 + h_idx * sb1; // + t*sb2
+
+ float * attn_base = dst + (int64_t) (seq * n_tokens * H + h_idx) * S_v; // + tok*S_v*H + j
+
+ for (int64_t c0 = 0; c0 < n_tokens; c0 += C) {
+ const int Cc = (int) ((n_tokens - c0) < (int64_t) C ? (n_tokens - c0) : (int64_t) C);
+
+ // --- load chunk K,Q (cooperative), beta and the gate prefix (cs, gamma) ---
+ for (int e = j; e < Cc * dk; e += dv) {
+ const int t = e / dk;
+ const int i = e % dk;
+ Kc[t * dk + i] = k_base[(c0 + t) * sq2 + i];
+ Qc[t * dk + i] = q_base[(c0 + t) * sq2 + i];
+ }
+ if (j < Cc) {
+ csh[j] = g[gb_base + (c0 + j) * sb2]; // raw log-gate, prefix-summed below
+ bet[j] = beta[gb_base + (c0 + j) * sb2];
+ }
+ __syncthreads();
+ if (j == 0) {
+ float run = 0.0f;
+ for (int t = 0; t < Cc; t++) {
+ run += csh[t];
+ csh[t] = run; // cs_t = sum_{i<=t} g_i (<= 0)
+ gam[t] = expf(run); // gamma_t (<= 1)
+ }
+ }
+ __syncthreads();
+
+ // --- A = I + tril(beta_t * d(t',t) * (k_t . k_t'), -1) (cooperative over C*C) ---
+ for (int e = j; e < Cc * Cc; e += dv) {
+ const int t = e / Cc;
+ const int tp = e % Cc;
+ float a = 0.0f;
+ if (tp < t) {
+ float kk = 0.0f;
+ for (int i = 0; i < dk; i++) {
+ kk += Kc[t * dk + i] * Kc[tp * dk + i];
+ }
+ const float dd = expf(csh[t] - csh[tp]); // d(tp,t) = gamma_t/gamma_tp
+ a = bet[t] * dd * kk;
+ } else if (tp == t) {
+ a = 1.0f;
+ }
+ Amat[t * Cc + tp] = a;
+ }
+ __syncthreads();
+
+ // --- RHS[t][j] = beta_t (v_t[j] - gamma_t * (S0^T k_t)[j]) -> Ud[j*C + t] ---
+ for (int t = 0; t < Cc; t++) {
+ float ks = 0.0f; // (S0^T k_t)[j] = sum_i S[i][j] k_t[i]
+ for (int i = 0; i < dk; i++) {
+ ks += Sd[j * dk + i] * Kc[t * dk + i];
+ }
+ const float vtj = v_base[(c0 + t) * sv2 + j];
+ Ud[j * C + t] = bet[t] * (vtj - gam[t] * ks);
+ }
+
+ // --- solve A U = RHS in place (unit lower-tri fwd subst); per-thread, no inter-step sync ---
+ for (int t = 1; t < Cc; t++) {
+ float acc = Ud[j * C + t];
+ for (int tp = 0; tp < t; tp++) {
+ acc -= Amat[t * Cc + tp] * Ud[j * C + tp];
+ }
+ Ud[j * C + t] = acc;
+ }
+ __syncthreads(); // U finalized; Amat free for P below (and Ud read across-thread? no, own col)
+
+ // --- P[t][t'] = d(t',t) * (q_t . k_t') for t' <= t (reuse Amat) ---
+ for (int e = j; e < Cc * Cc; e += dv) {
+ const int t = e / Cc;
+ const int tp = e % Cc;
+ float p = 0.0f;
+ if (tp <= t) {
+ float qk = 0.0f;
+ for (int i = 0; i < dk; i++) {
+ qk += Qc[t * dk + i] * Kc[tp * dk + i];
+ }
+ const float dd = expf(csh[t] - csh[tp]);
+ p = dd * qk;
+ }
+ Amat[t * Cc + tp] = p;
+ }
+ __syncthreads();
+
+ // --- O[t][j] = gamma_t (S0^T q_t)[j] + sum_{t'<=t} P[t][t'] U[t'][j] (* scale) ---
+ for (int t = 0; t < Cc; t++) {
+ float qs = 0.0f; // (S0^T q_t)[j] (uses pre-update S)
+ for (int i = 0; i < dk; i++) {
+ qs += Sd[j * dk + i] * Qc[t * dk + i];
+ }
+ float o = gam[t] * qs;
+ for (int tp = 0; tp <= t; tp++) {
+ o += Amat[t * Cc + tp] * Ud[j * C + tp];
+ }
+ attn_base[(c0 + t) * S_v * H + j] = o * scale;
+ }
+
+ // --- S_C[i][j] = gamma_{C-1} S[i][j] + sum_t d(t,C-1) k_t[i] u_t[j] ---
+ const float glast = gam[Cc - 1];
+ const float cslast = csh[Cc - 1];
+ for (int i = 0; i < dk; i++) {
+ float s = glast * Sd[j * dk + i];
+ for (int t = 0; t < Cc; t++) {
+ const float dd = expf(cslast - csh[t]); // d(t, last)
+ s += dd * Kc[t * dk + i] * Ud[j * C + t];
+ }
+ Sd[j * dk + i] = s;
+ }
+ __syncthreads(); // Sd reused as S0 of next chunk; Kc/Qc/Amat reloaded next chunk
+ }
+
+ // --- final-state write-back (M-layout): in-place cache view or f32 op-output scratch ---
+ const int64_t state_out_offset = (int64_t) (seq * H + h_idx) * S_v * S_v;
+ const int64_t attn_score_elems = (int64_t) S_v * H * n_tokens * n_seqs;
+ float * st = (state_dst != nullptr) ? (state_dst + state_out_offset)
+ : (dst + attn_score_elems + state_out_offset);
+ for (int i = 0; i < dk; i++) {
+ st[j * dk + i] = Sd[j * dk + i];
+ }
+}
+
+template <int S_v, int C>
+static void launch_gdn_chunked(
+ const float * q_d, const float * k_d, const float * v_d,
+ const float * g_d, const float * b_d, const float * s_d,
+ float * dst_d, float * state_dst_d, const int32_t * ids_d, int rs_head,
+ int64_t H, int64_t n_tokens, int64_t n_seqs,
+ int64_t sq1, int64_t sq2, int64_t sq3,
+ int64_t sv1, int64_t sv2, int64_t sv3,
+ int64_t sb1, int64_t sb2, int64_t sb3,
+ const uint3 neqk1_magic, const uint3 rq3_magic,
+ float scale, cudaStream_t stream) {
+ const size_t smem = ((size_t) S_v * S_v + (size_t) 2 * C * S_v + (size_t) S_v * C
+ + (size_t) C * C + (size_t) 3 * C) * sizeof(float);
+ static bool attr_set = false;
+ if (!attr_set) {
+ const cudaError_t e = cudaFuncSetAttribute(gated_delta_net_chunked_cuda<S_v, C>,
+ cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smem);
+ if (e != cudaSuccess) {
+ GGML_ABORT("gdn chunked: cudaFuncSetAttribute(maxDynSmem=%zu) failed: %s\n", smem, cudaGetErrorString(e));
+ }
+ attr_set = true;
+ }
+ dim3 grid_dims(H, n_seqs, 1);
+ dim3 block_dims(S_v, 1, 1);
+ gated_delta_net_chunked_cuda<S_v, C><<<grid_dims, block_dims, smem, stream>>>(
+ q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs,
+ sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3,
+ neqk1_magic, rq3_magic, scale, state_dst_d, ids_d, rs_head);
+}
+
template <bool KDA, bool keep_rs_t, bool STATE_BF16, bool HYBRID>
static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
@@ -425,6 +639,27 @@ static void launch_gated_delta_net(
const uint3 neqk1_magic = init_fastdiv_values(neqk1);
const uint3 rq3_magic = init_fastdiv_values(rq3);
+ // Chunked parallel-scan prefill path (upstream TODO at this site). Compile-time subset:
+ // non-KDA scalar gate, f32 state, final-state-only, homogeneous. Gated at runtime on the GDN
+ // head dim (S_v==128) and a prefill token threshold; decode (n_tokens small) keeps the tuned
+ // sequential recurrence. Mathematically equivalent up to FP reduction order (NEW per-path md5;
+ // validated benign by test-backend-ops NMSE + greedy output). Toggle: GDN_CHUNK_OFF / GDN_CHUNK_MIN.
+ if constexpr (!KDA && !keep_rs_t && !STATE_BF16 && !HYBRID) {
+ // OPT-IN: this chunked path is bit-exact-benign (test-backend-ops green) but, at C=16
+ // (forced by GB10 99KB dyn-smem opt-in, all-shared), it is NOT yet faster than the tuned
+ // sequential recurrence on this model (measured ~22%% slower S_PP, grid-starved at low
+ // n_seqs + 1 block/SM occupancy). Default OFF so the backend default is regression-free;
+ // enable for experiments / tuning with GDN_CHUNK_MIN=<token-threshold>. See README section 5 (dev notes / rejected-flat levers).
+ static const int gdn_chunk_min = []{ const char * e = getenv("GDN_CHUNK_MIN"); return e ? atoi(e) : INT_MAX; }();
+ if (S_v == 128 && n_tokens >= gdn_chunk_min) {
+ launch_gdn_chunked<128, 16>(
+ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head,
+ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3,
+ neqk1_magic, rq3_magic, scale, stream);
+ return;
+ }
+ }
+
#define GDN_LAUNCH_ARGS \
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d, ids_d, rs_head, hyb, \
H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, \
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index c0233eb..951bffc 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -9459,6 +9459,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 100, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 200, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 127, 2));
+ // chunked parallel-scan prefill path (S_v==128, n_tokens>=64): exact-multiple, tail, multi-seq, perm
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 64, 1));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 128, 1));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 127, 1));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 256, 1));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 100, 2));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 2, 128, 200, 3));
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 130, 1, 1, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 64, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 33, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 100, 1, 1, false, true));
--
2.43.0