From 7b38c6b2a3908e003751b4c7468964258134283d Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Mon, 29 Jun 2026 06:42:11 +0000 Subject: [PATCH] feat(paged): GDN M5 tensor-core chunked-scan prefill, default-on under paged KV (patch 0044) Land the tensor-core forms of the chunked gated-DeltaNet prefill scan (0031) as a single GDN_TC-selected build and ship the M5 variant (full TC form-T solve + state-update mma) default-ON when LLAMA_KV_PAGED is set. The dispatch defaults GDN_TC=5 and GDN_CHUNK_MIN=64 under paged KV (both env-overridable; OFF/INT_MAX when not paged, so stock/non-paged stays regression-free). GDN_CHUNK_MIN is the per-call engage threshold and stays > 1 so decode (1 tok/call) keeps the sequential recurrence; 64 was tuned from a {1,32,64,128,256} sweep (32/64/128 all win on prefill, 256 barely fires because the MoE-prefill per-call count is < 256, 1 collapses decode S_TG ~25%). Measured GB10, q36-35b-a3b-nvfp4, LLAMA_KV_PAGED=1 LLAMA_MOE_FORCE_GRAPHS=1, llama-batched-bench -ngl 99 -fa on -ntg 4 -npl 32: -npp 512 S_PP 2208.96 -> 2286.5 t/s (+3.5%, mean of 3 interleaved A/B) -npp 2048 S_PP 2021.5 -> 2379.8 t/s (+17.7%) Decode S_TG unchanged (~399 vs ~397 t/s, within noise). Bit-exactness (per-path greedy md5, n=48 --temp 0 --seed 1, paged): default-on == M5-forced == canonical on the gate prompt - MoE 8cb0ce23, dense 5951a5b4. test-backend-ops GATED_DELTA_NET 94/94 vs CPU with M5 forced (incl. multi-chunk up to n_tokens=256). On a long MoE prompt the default (M5 fires at >=64 tokens) and the sequential path agree word-for-word until one benign greedy token-flip; dense is byte-identical. The chunked scan is a NEW per-path result (different FP reduction order), NMSE-validated benign. CUDA-only, gencode arch=compute_121a,code=sm_121a (GB10 / sm_121a). README sections 3 (0044 row, 0031 superseded note) and 5 (dev-notes verdict) updated. Assisted-by: Claude:opus-4.8 [Claude Code] Signed-off-by: Ettore Di Giacinto --- backend/cpp/llama-cpp-localai-paged/README.md | 53 +- ...5-tensor-core-chunked-scan-default-o.patch | 1600 +++++++++++++++++ 2 files changed, 1634 insertions(+), 19 deletions(-) create mode 100644 backend/cpp/llama-cpp-localai-paged/patches/paged/0044-feat-paged-GDN-M5-tensor-core-chunked-scan-default-o.patch diff --git a/backend/cpp/llama-cpp-localai-paged/README.md b/backend/cpp/llama-cpp-localai-paged/README.md index e14a0837c..4c605ff22 100644 --- a/backend/cpp/llama-cpp-localai-paged/README.md +++ b/backend/cpp/llama-cpp-localai-paged/README.md @@ -86,7 +86,7 @@ orthogonal to the paged allocator. --- -## 3. Patch series (0001-0043) +## 3. Patch series (0001-0044) Source-only patches, with intentional numbering gaps (e.g. 0005, 0027). The decode-serving graph-reuse levers are 0040-0041. "Bit-exact" = greedy md5 / @@ -188,7 +188,8 @@ These are the dominant decode levers on the Qwen3.6 hybrid models. All bit-exact | 0024 | **Paged-pool burst-reclaim** - truncate trailing blocks on partial-tail `seq_rm`, defrag the free queue when idle, release blocks on slot completion. Fixes the long-server burst-degradation bug (post-burst prefill collapse 488->44 t/s, restored to 532). Host-side accounting only. | yes | | 0029 | **Block-table within-step host cache** - the block table is fixed for the whole step; cache it on first build and memcpy it for the other full-attention layers (get_block_table -87%/-91%). | yes, per path (paged-MoE ref `8cb0ce23`) | | 0030 | **Fused-op backend gate** - the fused GDN / discriminated SSM_CONV ops are CUDA-family + CPU only; force them off on any non-CUDA compute backend so a Vulkan/SYCL/Metal build can't silently run the wrong plain-conv kernel. | yes on CUDA (byte-identical pre-0030); safety gate elsewhere | -| 0031 | **Chunked parallel-scan GDN prefill kernel** (upstream TODO) - FLA-style chunked gated-delta-rule for prefill (non-KDA / f32 / final-state): intra-chunk delta rule solved in parallel (UT-transform + forward subst), inter-chunk recurrence over n_tokens/C steps. **OPT-IN, default OFF** - bit-exact-benign but not yet faster than the tuned sequential scan at the GB10-forced C=16 (see section 5). Enable with `GDN_CHUNK_MIN=`. | NEW per-path (`test-backend-ops` 91/91, <=1e-7 NMSE vs CPU ref) | +| 0031 | **Chunked parallel-scan GDN prefill kernel** (upstream TODO) - FLA-style chunked gated-delta-rule for prefill (non-KDA / f32 / final-state): intra-chunk delta rule solved in parallel (UT-transform + forward subst), inter-chunk recurrence over n_tokens/C steps. The scalar-serial form (`GDN_TC=0`) was bit-exact-benign but not faster than the tuned sequential scan at the GB10-forced C=16 (see section 5); **superseded for paged by the tensor-core M5 path of 0044**. | NEW per-path (`test-backend-ops` 91/91, <=1e-7 NMSE vs CPU ref) | +| 0044 | **GDN M5 tensor-core chunked-scan prefill, default-ON under paged KV** - the tensor-core forms of 0031's scan (KK/QK Gram, KS/QS state-boundary, P*U output, full form-T solve + state-update mma = M5; plus register-resident M6/M7 and CONFIG-C M8), single build, runtime-selected by `GDN_TC`. Ships **M5 default-on when `LLAMA_KV_PAGED` is set** (`GDN_TC=5` + `GDN_CHUNK_MIN=64`, both env-overridable; OFF/`INT_MAX` when not paged). `GDN_CHUNK_MIN` is the per-call engage threshold and stays > 1 so decode (1 tok/call) keeps the sequential recurrence (at 1 it swallows decode and drops S_TG ~25%); 64 tuned from a {1,32,64,128,256} sweep. MoE prefill S_PP +3.5% @npp512 (3x A/B), +17.7% @npp2048; decode S_TG unchanged. | NEW per-path, benign (`test-backend-ops` 94/94 incl. multi-chunk; greedy md5 default-on == M5-forced == canonical on the gate prompt: paged-MoE `8cb0ce23`, dense `5951a5b4`; long MoE prompt = one benign greedy flip vs sequential, dense byte-identical) | > **Dropped: patch 0026 (hybrid per-head bf16 SSM state, `ssm_bf16_tau`).** Once > the decode fusions (0028 recurrent-state gather-fusion + 0029 block-table cache) @@ -369,23 +370,37 @@ llama is losing. The MoE GEMM kernel is *not* where the gap lives. needs bought with a ~5% slower kernel; both kernels are already at the BW floor. (The "the win was NVFP4-dense-quant, not the Marlin kernel" dense verdict carries over to MoE.) -- **Chunked parallel-scan GDN prefill (patch 0031): CORRECT, FLAT-to-SLOWER at - C=16; kept OPT-IN.** Implements the upstream "faster pre-fill" TODO - the - FLA-style chunked gated-delta-rule (intra-chunk delta rule solved in parallel - via the UT-transform + forward substitution, inter-chunk recurrence over - n_tokens/C steps). The math is validated equivalent (numpy f32 NMSE ~1e-13; - `test-backend-ops` 91/91 within the 1e-7 NMSE gate, a NEW per-path result). - **But GB10's 99KB dynamic-smem opt-in forces C=16** (the 128x128 f32 state alone - is 64KB of the all-shared layout), which pins the kernel to 1 block/SM and - serial per-thread dk-reductions; measured S_PP (q36-27b-nvfp4, `-npp 512 -ntg 4 - -npl 32`) is **~761 t/s chunked vs ~971 t/s sequential (~22% slower)**, also - grid-starved at low n_seqs. So it ships default-OFF (`GDN_CHUNK_MIN=` to - enable). To actually beat the (already 84.7%-of-peak) sequential scan the - follow-up must lift the occupancy ceiling and the serial reductions: either - register-resident state with static-unrolled larger chunks, or tensor-core - (mma/wgmma) matmuls for the KK/QK/KS/QS/PU products and the A-inverse - the - structure FLA/vLLM use. Lesson: at this head dim the win needs tensor cores, - not just chunking. +- **Chunked parallel-scan GDN prefill (patch 0031): the scalar-serial form was + FLAT-to-SLOWER at C=16 - the tensor-core M5 form (patch 0044) is the win, + now DEFAULT-ON under paged KV.** 0031 implements the upstream "faster pre-fill" + TODO - the FLA-style chunked gated-delta-rule (intra-chunk delta rule solved in + parallel via the UT-transform + forward substitution, inter-chunk recurrence + over n_tokens/C steps), math validated equivalent (numpy f32 NMSE ~1e-13; + `test-backend-ops` within the 1e-7 NMSE gate, a NEW per-path result). **But + GB10's 99KB dynamic-smem opt-in forces C=16** (the 128x128 f32 state alone is + 64KB of the all-shared layout); the scalar-serial scan (`GDN_TC=0`) was then + pinned to 1 block/SM with serial per-thread dk-reductions and measured **~761 + t/s chunked vs ~971 t/s sequential (~22% slower)**, grid-starved at low n_seqs. + The lesson held: **at this head dim the win needs tensor cores, not just + chunking.** Patch 0044 builds those tensor-core forms (KK/QK Gram = M2, KS/QS + state-boundary = M3, P*U output = M4, full form-T solve + state-update mma = + M5; plus register-resident M6/M7 and CONFIG-C M8, all `GDN_TC`-selected in one + build) and ships **M5** as the default when `LLAMA_KV_PAGED` is set. M5 is the + variant that beats the (already 84.7%-of-peak) sequential scan while staying on + the bit-exact gate: MoE prefill S_PP **+3.5% @npp512 (3x interleaved A/B), +17.7% + @npp2048**; decode S_TG unchanged (the tuned `GDN_CHUNK_MIN=64` engage threshold + is > 1, so the 1-tok decode steps never enter the chunked path - at + `GDN_CHUNK_MIN=1` the chunked path swallows decode and collapses S_TG ~25%, the + reason the threshold is the lever). Bit-exactness is per-path benign: + `test-backend-ops` GATED_DELTA_NET is **94/94** vs CPU with M5 forced (incl. + multi-chunk n_tokens up to 256); the greedy md5 default-on == M5-forced == + canonical on the short gate prompt (paged-MoE `8cb0ce23`, dense `5951a5b4`); on + a long MoE prompt (where the default fires M5 at >=64 tokens) M5 and the + sequential path agree word-for-word until **one** benign greedy token-flip + ("the User:" vs "the User's Request:"), the dense model not flipping at all - + the textbook reduction-order flip greedy amplifies, NMSE-validated. M6/M7/M8 + remain env-selectable (`GDN_TC`/`GDN_CHUNK_C`/`GDN_DV_TILE`) for further tuning; + M5 is the shipped default because it wins without losing the canonical gate. **Opt-in bf16-SSM fast mode - DROPPED (was patch 0026, `ssm_bf16_tau`).** The design premise - that bf16 KL error concentrates in long-memory heads and can be diff --git a/backend/cpp/llama-cpp-localai-paged/patches/paged/0044-feat-paged-GDN-M5-tensor-core-chunked-scan-default-o.patch b/backend/cpp/llama-cpp-localai-paged/patches/paged/0044-feat-paged-GDN-M5-tensor-core-chunked-scan-default-o.patch new file mode 100644 index 000000000..6aa574039 --- /dev/null +++ b/backend/cpp/llama-cpp-localai-paged/patches/paged/0044-feat-paged-GDN-M5-tensor-core-chunked-scan-default-o.patch @@ -0,0 +1,1600 @@ +From a7d439e8ce6990eb09721223c975da4e49d8d136 Mon Sep 17 00:00:00 2001 +From: Ettore Di Giacinto +Date: Mon, 29 Jun 2026 09:30:00 +0200 +Subject: [PATCH] feat(paged): GDN M5 tensor-core chunked-scan prefill, + default-on under paged KV (patch 0044) + +Patch 0031 added the FLA-style chunked parallel-scan for the gated-DeltaNet +(GDN) prefill but shipped it default-OFF (`GDN_CHUNK_MIN=INT_MAX`): at the +GB10-forced C=16, all-shared layout it was grid-starved and ~22% slower than +the tuned sequential recurrence, so it was a correct-but-not-faster opt-in +(section 5). This patch lands the tensor-core forms of that scan - the KK/QK +Gram (M2), KS/QS state-boundary (M3), P*U output (M4) and the full form-T solve ++ state-update mma (M5), plus the register-resident (M6/M7) and CONFIG-C (M8) +occupancy variants - as a single build, runtime-selected by `GDN_TC` (0=serial +.. 5=M5 .. 6+=register-resident). It then ships **M5 default-ON under paged KV**. + +Why M5 and why default-on-when-paged: + +- M5 (full TC, state in 64KB smem, C=16) is the variant that BEATS the (already + 84.7%-of-peak) sequential recurrence on the Qwen3.6 MoE prefill while staying + greedy-bit-exact on the canonical gate. Measured GB10, q36-35b-a3b-nvfp4, + `LLAMA_KV_PAGED=1 LLAMA_MOE_FORCE_GRAPHS=1`, `llama-batched-bench -ngl 99 + -fa on -ntg 4 -npl 32`: + -npp 512 : S_PP 2208.96 -> 2286.5 t/s (+3.5%, mean of 3 interleaved A/B) + -npp 2048 : S_PP 2021.5 -> 2379.8 t/s (+17.7%) - bigger win on longer + prompts, where the scan has more chunks to parallelize. + Decode S_TG is unchanged (~399 vs ~397 t/s, within run noise). + +- The dispatch (the only behavioural change) defaults `GDN_TC=5` and + `GDN_CHUNK_MIN=64` when `LLAMA_KV_PAGED` is set and the user has not overridden + either; both stay env-overridable; OFF (`INT_MAX`) when not paged, so the + stock / non-paged default is regression-free. `GDN_CHUNK_MIN` is the per-call + engage threshold and MUST stay > 1: decode is 1 token/call, so any threshold + above 1 leaves every decode step on the sequential recurrence (at + `GDN_CHUNK_MIN=1` the chunked path swallows decode and collapses S_TG by ~25%). + 64 was tuned from a {1,32,64,128,256} sweep: it is above decode/tiny-call + sizes, below the real MoE-prefill per-call count (which is < 256 on the + 512-prompt shape, so 256 barely fires), and gave the best S_PP at npp=512. + +Bit-exactness (per-path greedy md5, n=48 --temp 0 --seed 1, paged): + dense q36-27b-nvfp4 : 5951a5b4d624ce891e22ab5fca9bc439 (default-on == M5-forced + == canonical; long-prompt M5 == sequential, byte-identical) + MoE q36-35b-a3b : 8cb0ce23777bf55f92f63d0292c756b0 (default-on == M5-forced + == canonical on the short gate prompt) +The chunked scan is a NEW per-path result (a different FP reduction order than +the sequential recurrence), validated benign: `test-backend-ops` GATED_DELTA_NET +is 94/94 vs the CPU reference with M5 forced, including the multi-chunk shapes +(n_tokens up to 256). On a long MoE prompt the default (which fires M5 at >=64 +tokens) and the sequential path agree word-for-word until a single greedy +token-flip ("the User:" vs "the User's Request:") - the textbook benign +reduction-order flip greedy decoding amplifies; both continuations are coherent, +and the dense model does not flip at all. The short canonical gate prompt +(6 tokens) does not reach the threshold, so the published canonical md5s are +preserved exactly, and force-firing M5 on that prompt (`GDN_CHUNK_MIN=1`) +reproduces them. + +CUDA-only; built and gated for `-gencode arch=compute_121a,code=sm_121a` +(GB10 / sm_121a). Stock llama-cpp stays patch-free. +--- + ggml/src/ggml-cuda/gated_delta_net.cu | 1424 +++++++++++++++++++++++++++++++-- + tests/test-backend-ops.cpp | 5 + + 2 files changed, 1371 insertions(+), 58 deletions(-) + +diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu +index c9bf1bd..d136c82 100644 +--- a/ggml/src/ggml-cuda/gated_delta_net.cu ++++ b/ggml/src/ggml-cuda/gated_delta_net.cu +@@ -424,7 +424,118 @@ static void launch_gdn_variant( + // strong-decay tokens underflow to the correct zero rather than to inf. The math + // is equivalent to the sequential recurrence up to FP reduction order (a NEW + // per-path result, validated benign by test-backend-ops NMSE and greedy output). +-template ++// --- Phase-1 tensor-core Gram helpers (tf32 m16n8k8 mma.sync; sm_80+/sm_121a). --- ++// Reproduces the PoC-proven path (~/scratch_tc_gdn_poc/gdn_gram_bench.cu, tf32 NMSE ~3e-9): ++// out[rowbase..+15][colbase..+7] = Xs[rows] . Ys[cols], Xs/Ys row-major [*][DK]. ++__device__ __forceinline__ unsigned gdn_f2tf32(float f) { ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 ++ unsigned r; ++ asm("cvt.rna.tf32.f32 %0, %1;" : "=r"(r) : "f"(f)); ++ return r; ++#else ++ (void) f; ++ return 0u; ++#endif ++} ++ ++// Operand loaders for the Gram/state mma helpers: stage either f32 (cvt to tf32) or bf16 ++// (upconvert bf16->f32 then cvt to tf32). CONFIG C (M8) stages Kc/Qc as bf16 to halve their ++// smem footprint -- the tf32 mma reads the bf16-upconverted operands. f32 operands (state ++// restage, Ud, A/T scratch) keep full tf32/3xtf32 precision. bf16's 8-bit mantissa fits inside ++// tf32's 10 bits, so the bf16 hi-limb captures it and the 3xtf32 lo-limbs are ~0 (correct, just ++// no extra precision on the bf16 side -- exactly the scope's "bf16 only for the Gram terms"). ++__device__ __forceinline__ unsigned gdn_ld_tf32(float f) { return gdn_f2tf32(f); } ++__device__ __forceinline__ unsigned gdn_ld_tf32(nv_bfloat16 h) { return gdn_f2tf32(__bfloat162float(h)); } ++__device__ __forceinline__ float gdn_ld_f32 (float f) { return f; } ++__device__ __forceinline__ float gdn_ld_f32 (nv_bfloat16 h) { return __bfloat162float(h); } ++ ++__device__ __forceinline__ void gdn_mma_m16n8k8(float c[4], const unsigned a[4], const unsigned b[2]) { ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 ++ asm volatile( ++ "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " ++ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n" ++ : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3]) ++ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1])); ++#else ++ (void) c; (void) a; (void) b; ++#endif ++} ++ ++template ++__device__ __forceinline__ void gdn_gram_tile_mma( ++ float c[4], const TX * __restrict__ Xs, const TY * __restrict__ Ys, ++ int rowbase, int colbase, int lg, int lt) { ++ c[0] = c[1] = c[2] = c[3] = 0.0f; ++ #pragma unroll ++ for (int ks = 0; ks < DK; ks += 8) { ++ unsigned a[4], b[2]; ++ a[0] = gdn_ld_tf32(Xs[(rowbase + lg ) * DK + ks + lt ]); ++ a[1] = gdn_ld_tf32(Xs[(rowbase + lg + 8) * DK + ks + lt ]); ++ a[2] = gdn_ld_tf32(Xs[(rowbase + lg ) * DK + ks + lt + 4]); ++ a[3] = gdn_ld_tf32(Xs[(rowbase + lg + 8) * DK + ks + lt + 4]); ++ b[0] = gdn_ld_tf32(Ys[(colbase + lg ) * DK + ks + lt ]); ++ b[1] = gdn_ld_tf32(Ys[(colbase + lg ) * DK + ks + lt + 4]); ++ gdn_mma_m16n8k8(c, a, b); ++ } ++} ++ ++// 3xtf32 (CUTLASS fp32-emulation): split each f32 operand into hi/lo tf32 limbs and run ++// 3 limb-products per k-subtile (hi*hi + hi*lo + lo*hi); ~f32 accuracy at ~3x the mma count. ++// Used for the state-boundary products (KS/QS) whose error feeds the A-inverse solve (M3). ++template ++__device__ __forceinline__ void gdn_gram_tile_mma_3x( ++ float c[4], const TX * __restrict__ Xs, const TY * __restrict__ Ys, ++ int rowbase, int colbase, int lg, int lt) { ++ c[0] = c[1] = c[2] = c[3] = 0.0f; ++ #pragma unroll ++ for (int ks = 0; ks < DK; ks += 8) { ++ float af[4], bf[2]; ++ af[0] = gdn_ld_f32(Xs[(rowbase + lg ) * DK + ks + lt ]); ++ af[1] = gdn_ld_f32(Xs[(rowbase + lg + 8) * DK + ks + lt ]); ++ af[2] = gdn_ld_f32(Xs[(rowbase + lg ) * DK + ks + lt + 4]); ++ af[3] = gdn_ld_f32(Xs[(rowbase + lg + 8) * DK + ks + lt + 4]); ++ bf[0] = gdn_ld_f32(Ys[(colbase + lg ) * DK + ks + lt ]); ++ bf[1] = gdn_ld_f32(Ys[(colbase + lg ) * DK + ks + lt + 4]); ++ unsigned ahi[4], alo[4], bhi[2], blo[2]; ++ #pragma unroll ++ for (int z = 0; z < 4; z++) { ahi[z] = gdn_f2tf32(af[z]); alo[z] = gdn_f2tf32(af[z] - __uint_as_float(ahi[z])); } ++ #pragma unroll ++ for (int z = 0; z < 2; z++) { bhi[z] = gdn_f2tf32(bf[z]); blo[z] = gdn_f2tf32(bf[z] - __uint_as_float(bhi[z])); } ++ gdn_mma_m16n8k8(c, ahi, bhi); // hi*hi (dominant limb) ++ gdn_mma_m16n8k8(c, ahi, blo); // hi*lo ++ gdn_mma_m16n8k8(c, alo, bhi); // lo*hi ++ } ++} ++ ++// State-update tile (P6): S_C[i][j] += sum_t Kc[t][i] * DU[t][j], with Kc read TRANSPOSED ++// (i as the m16n8k8 M-row, t as the K-contraction) and DU = d(t,last)*U staged in the Ud ++// layout (DUd[j*KC + t]). 3xtf32: the cross-chunk carry compounds over every chunk step. ++template ++__device__ __forceinline__ void gdn_state_tile_mma_3x( ++ float c[4], const TK * __restrict__ Kc, const TD * __restrict__ DUd, ++ int rowbase, int colbase, int lg, int lt) { ++ c[0] = c[1] = c[2] = c[3] = 0.0f; ++ #pragma unroll ++ for (int ks = 0; ks < KC; ks += 8) { ++ float af[4], bf[2]; ++ af[0] = gdn_ld_f32(Kc[(ks + lt ) * DK + (rowbase + lg )]); ++ af[1] = gdn_ld_f32(Kc[(ks + lt ) * DK + (rowbase + lg + 8)]); ++ af[2] = gdn_ld_f32(Kc[(ks + lt + 4) * DK + (rowbase + lg )]); ++ af[3] = gdn_ld_f32(Kc[(ks + lt + 4) * DK + (rowbase + lg + 8)]); ++ bf[0] = gdn_ld_f32(DUd[(colbase + lg) * KC + (ks + lt )]); ++ bf[1] = gdn_ld_f32(DUd[(colbase + lg) * KC + (ks + lt + 4)]); ++ unsigned ahi[4], alo[4], bhi[2], blo[2]; ++ #pragma unroll ++ for (int z = 0; z < 4; z++) { ahi[z] = gdn_f2tf32(af[z]); alo[z] = gdn_f2tf32(af[z] - __uint_as_float(ahi[z])); } ++ #pragma unroll ++ for (int z = 0; z < 2; z++) { bhi[z] = gdn_f2tf32(bf[z]); blo[z] = gdn_f2tf32(bf[z] - __uint_as_float(bhi[z])); } ++ gdn_mma_m16n8k8(c, ahi, bhi); ++ gdn_mma_m16n8k8(c, ahi, blo); ++ gdn_mma_m16n8k8(c, alo, bhi); ++ } ++} ++ ++template + __global__ void gated_delta_net_chunked_cuda( + const float * __restrict__ q, const float * __restrict__ k, + const float * __restrict__ v, const float * __restrict__ g, +@@ -455,6 +566,9 @@ __global__ void gated_delta_net_chunked_cuda( + float * csh = Amat + (size_t) C * C; // [C] cumsum(log-gate) + float * gam = csh + C; // [C] gamma_t = exp(cs_t) + float * bet = gam + C; // [C] beta_t ++ // Phase-1 tensor-core Gram scratch (allocated only when GRAM_MMA; KK feeds A, QK feeds P). ++ float * KKsh = bet + C; // [C*C] KK[t][t'] = k_t . k_t' (stride C) ++ float * QKsh = KKsh + (size_t) C * C; // [C*C] QK[t][t'] = q_t . k_t' (stride C) + + // S0: thread j owns column j (Sd[j*dk + i]); load is a contiguous per-thread copy from the + // M-layout cache view (read_state[j*dk + i] = M[j*S_v + i] = S[i][j]). Same identity/gather +@@ -483,6 +597,15 @@ __global__ void gated_delta_net_chunked_cuda( + Kc[t * dk + i] = k_base[(c0 + t) * sq2 + i]; + Qc[t * dk + i] = q_base[(c0 + t) * sq2 + i]; + } ++ if constexpr (TC >= 3) { ++ // Zero the stale K/Q tail (rows t >= Cc): the tensor-core mma paths contract the full ++ // chunk dim and 0*NaN (uninitialized smem) would poison the result. Serial paths only ++ // touch t < Cc, so this is gated to the mma levels. ++ for (int e = Cc * dk + j; e < C * dk; e += dv) { ++ Kc[e] = 0.0f; ++ Qc[e] = 0.0f; ++ } ++ } + if (j < Cc) { + csh[j] = g[gb_base + (c0 + j) * sb2]; // raw log-gate, prefix-summed below + bet[j] = beta[gb_base + (c0 + j) * sb2]; +@@ -498,15 +621,53 @@ __global__ void gated_delta_net_chunked_cuda( + } + __syncthreads(); + ++ // --- Phase-1: tensor-core tf32 Gram products (KK->A via warp0, QK->P via warp1). --- ++ // Full C x C tiles into KKsh/QKsh (stride C); decay/beta applied in f32 in the loops below. ++ // Tail chunks (Cc= Cc, but those entries are never read. ++ if constexpr (TC >= 1) { ++ const int w = threadIdx.x >> 5; // warp: 0 -> KK, 1 -> QK ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; // 0..7 ++ const int lt = lane & 3; // 0..3 ++ if (w < 2) { ++ const float * Xs = (w == 0) ? Kc : Qc; ++ float * Out = (w == 0) ? KKsh : QKsh; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nt = 0; nt < (C + 7) / 8; nt++) { ++ const int colbase = nt * 8; ++ float cc[4]; ++ gdn_gram_tile_mma(cc, Xs, Kc, rowbase, colbase, lg, lt); ++ const int rr[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int ccol[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ if (rr[l] < C && ccol[l] < C) { ++ Out[rr[l] * C + ccol[l]] = cc[l]; ++ } ++ } ++ } ++ } ++ } ++ __syncthreads(); ++ } ++ + // --- A = I + tril(beta_t * d(t',t) * (k_t . k_t'), -1) (cooperative over C*C) --- + for (int e = j; e < Cc * Cc; e += dv) { + const int t = e / Cc; + const int tp = e % Cc; + float a = 0.0f; + if (tp < t) { +- float kk = 0.0f; +- for (int i = 0; i < dk; i++) { +- kk += Kc[t * dk + i] * Kc[tp * dk + i]; ++ float kk; ++ if constexpr (TC >= 1) { ++ kk = KKsh[t * C + tp]; ++ } else { ++ kk = 0.0f; ++ for (int i = 0; i < dk; i++) { ++ kk += Kc[t * dk + i] * Kc[tp * dk + i]; ++ } + } + const float dd = expf(csh[t] - csh[tp]); // d(tp,t) = gamma_t/gamma_tp + a = bet[t] * dd * kk; +@@ -518,65 +679,304 @@ __global__ void gated_delta_net_chunked_cuda( + __syncthreads(); + + // --- RHS[t][j] = beta_t (v_t[j] - gamma_t * (S0^T k_t)[j]) -> Ud[j*C + t] --- +- for (int t = 0; t < Cc; t++) { +- float ks = 0.0f; // (S0^T k_t)[j] = sum_i S[i][j] k_t[i] +- for (int i = 0; i < dk; i++) { +- ks += Sd[j * dk + i] * Kc[t * dk + i]; ++ if constexpr (TC >= 2) { ++ // M3: fused tensor-core KS = Kc * S0 (3xtf32 state-boundary product). The mma ++ // output is consumed straight from registers into RHS -> Ud, so NO extra C*dv ++ // smem buffer is needed (the 64KB state still occupies smem until M6). Warp w ++ // owns dv n-tiles [w*NTPW, ..); each lane writes the RHS entries it produced. ++ const int w = threadIdx.x >> 5; ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; ++ const int lt = lane & 3; ++ constexpr int NWARP = S_v / 32; ++ constexpr int NT = dv / 8; ++ constexpr int NTPW = (NT + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt >= NT) break; ++ const int colbase = nt * 8; ++ float cc[4]; ++ gdn_gram_tile_mma_3x(cc, Kc, Sd, rowbase, colbase, lg, lt); ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = jj[l]; ++ if (t < Cc && jc < dv) { ++ const float vtj = v_base[(c0 + t) * sv2 + jc]; ++ Ud[jc * C + t] = bet[t] * (vtj - gam[t] * cc[l]); ++ } ++ } ++ } + } +- const float vtj = v_base[(c0 + t) * sv2 + j]; +- Ud[j * C + t] = bet[t] * (vtj - gam[t] * ks); ++ __syncthreads(); // RHS written cross-thread -> publish before the per-column solve ++ } else { ++ for (int t = 0; t < Cc; t++) { ++ float ks = 0.0f; // (S0^T k_t)[j] = sum_i S[i][j] k_t[i] ++ for (int i = 0; i < dk; i++) { ++ ks += Sd[j * dk + i] * Kc[t * dk + i]; ++ } ++ const float vtj = v_base[(c0 + t) * sv2 + j]; ++ Ud[j * C + t] = bet[t] * (vtj - gam[t] * ks); ++ } ++ } ++ if constexpr (TC >= 3) { ++ // Zero the stale RHS tail (rows t >= Cc) before the full-K mma consumers (P*U at TC>=3; ++ // apply + state at TC>=4). Without this the masked tail terms compute 0*NaN = NaN. ++ for (int t = Cc; t < C; t++) Ud[j * C + t] = 0.0f; ++ __syncthreads(); + } + +- // --- solve A U = RHS in place (unit lower-tri fwd subst); per-thread, no inter-step sync --- +- for (int t = 1; t < Cc; t++) { +- float acc = Ud[j * C + t]; +- for (int tp = 0; tp < t; tp++) { +- acc -= Amat[t * Cc + tp] * Ud[j * C + tp]; ++ // --- solve A U = RHS (A unit-lower-tri) --- ++ if constexpr (TC >= 4) { ++ // M5/P7: form T = A^{-1} explicitly (FLA UT transform), then U = T*RHS as one ++ // dependency-free tf32 GEMM. At C<=16 A is a single b=16 block, so the off-diagonal ++ // Phase-O is empty; only the f32 in-shared diagonal inverse (Phase-D) + the wide ++ // apply remain. Phase-D: column-parallel EXACT f32 inverse of the Cc x Cc unit- ++ // lower-tri A -- thread c solves A x = e_c, writing column c of T into KKsh (free ++ // since KK was consumed into A). This is the strong-coupling amplifier -> f32. ++ if (j < C) { ++ if (j < Cc) { ++ float x[C]; ++ #pragma unroll ++ for (int r = 0; r < C; r++) x[r] = 0.0f; ++ x[j] = 1.0f; ++ for (int r = j + 1; r < Cc; r++) { ++ float acc = 0.0f; ++ for (int m = j; m < r; m++) acc += Amat[r * Cc + m] * x[m]; ++ x[r] = -acc; ++ } ++ #pragma unroll ++ for (int r = 0; r < C; r++) KKsh[r * C + j] = x[r]; // rows >= Cc are 0 ++ } else { ++ #pragma unroll ++ for (int r = 0; r < C; r++) KKsh[r * C + j] = 0.0f; // cols >= Cc are 0 ++ } + } +- Ud[j * C + t] = acc; ++ __syncthreads(); ++ // Apply U = T*RHS, M=C N=dv K=C; T=KKsh (stride C), RHS=Ud (stride C). In place on ++ // Ud: hold every output tile in registers, sync to finish the RHS reads, then ++ // overwrite Ud with U (avoids the read/write aliasing of a same-buffer GEMM). ++ { ++ const int w = threadIdx.x >> 5; ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; ++ const int lt = lane & 3; ++ constexpr int NWARP = S_v / 32; ++ constexpr int NT = dv / 8; ++ constexpr int NTPW = (NT + NWARP - 1) / NWARP; ++ float ureg[NTPW][4]; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt < NT) gdn_gram_tile_mma(ureg[nn], KKsh, Ud, 0, nt * 8, lg, lt); ++ } ++ __syncthreads(); // all RHS(Ud) reads done before overwriting with U ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt >= NT) continue; ++ const int colbase = nt * 8; ++ const int tt[4] = {lg, lg, lg + 8, lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = jj[l]; ++ if (t < Cc && jc < dv) Ud[jc * C + t] = ureg[nn][l]; ++ } ++ } ++ __syncthreads(); ++ } ++ } else { ++ for (int t = 1; t < Cc; t++) { ++ float acc = Ud[j * C + t]; ++ for (int tp = 0; tp < t; tp++) { ++ acc -= Amat[t * Cc + tp] * Ud[j * C + tp]; ++ } ++ Ud[j * C + t] = acc; ++ } ++ __syncthreads(); // U finalized; Amat free for P below + } +- __syncthreads(); // U finalized; Amat free for P below (and Ud read across-thread? no, own col) + +- // --- P[t][t'] = d(t',t) * (q_t . k_t') for t' <= t (reuse Amat) --- +- for (int e = j; e < Cc * Cc; e += dv) { +- const int t = e / Cc; +- const int tp = e % Cc; +- float p = 0.0f; +- if (tp <= t) { +- float qk = 0.0f; +- for (int i = 0; i < dk; i++) { +- qk += Qc[t * dk + i] * Kc[tp * dk + i]; ++ // --- P[t][t'] = d(t',t) * (q_t . k_t') for t' <= t --- ++ if constexpr (TC >= 3) { ++ // M4: build P (lower-tri, decay pre-baked in f32 -> bounded) IN PLACE in QKsh at ++ // fixed stride C so the P*U output mma can read it as a tf32 A-operand. Full C*C ++ // grid: upper-tri / out-of-range entries are zeroed so the K=C mma needs no masking. ++ for (int e = j; e < C * C; e += dv) { ++ const int t = e / C; ++ const int tp = e % C; ++ float p = 0.0f; ++ if (tp <= t && t < Cc && tp < Cc) { ++ const float dd = expf(csh[t] - csh[tp]); ++ p = dd * QKsh[t * C + tp]; // QKsh holds QK (M2); overwrite in place with P + } +- const float dd = expf(csh[t] - csh[tp]); +- p = dd * qk; ++ QKsh[t * C + tp] = p; ++ } ++ __syncthreads(); ++ } else { ++ for (int e = j; e < Cc * Cc; e += dv) { ++ const int t = e / Cc; ++ const int tp = e % Cc; ++ float p = 0.0f; ++ if (tp <= t) { ++ float qk; ++ if constexpr (TC >= 1) { ++ qk = QKsh[t * C + tp]; ++ } else { ++ qk = 0.0f; ++ for (int i = 0; i < dk; i++) { ++ qk += Qc[t * dk + i] * Kc[tp * dk + i]; ++ } ++ } ++ const float dd = expf(csh[t] - csh[tp]); ++ p = dd * qk; ++ } ++ Amat[t * Cc + tp] = p; + } +- Amat[t * Cc + tp] = p; ++ __syncthreads(); + } +- __syncthreads(); + + // --- O[t][j] = gamma_t (S0^T q_t)[j] + sum_{t'<=t} P[t][t'] U[t'][j] (* scale) --- +- for (int t = 0; t < Cc; t++) { +- float qs = 0.0f; // (S0^T q_t)[j] (uses pre-update S) +- for (int i = 0; i < dk; i++) { +- qs += Sd[j * dk + i] * Qc[t * dk + i]; ++ if constexpr (TC >= 2) { ++ // M3: fused tensor-core QS = Qc * S0 (3xtf32, pre-update S0). Deposit the ++ // gamma_t*QS[t][j] cross-chunk term into dst from the mma registers; the O loop ++ // below reads it back (published via __syncthreads) and adds the intra-chunk P*U. ++ const int w = threadIdx.x >> 5; ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; ++ const int lt = lane & 3; ++ constexpr int NWARP = S_v / 32; ++ constexpr int NT = dv / 8; ++ constexpr int NTPW = (NT + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt >= NT) break; ++ const int colbase = nt * 8; ++ float cc[4]; ++ gdn_gram_tile_mma_3x(cc, Qc, Sd, rowbase, colbase, lg, lt); ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = jj[l]; ++ if (t < Cc && jc < dv) { ++ attn_base[(c0 + t) * S_v * H + jc] = gam[t] * cc[l]; ++ } ++ } ++ } + } +- float o = gam[t] * qs; +- for (int tp = 0; tp <= t; tp++) { +- o += Amat[t * Cc + tp] * Ud[j * C + tp]; ++ __syncthreads(); ++ } ++ if constexpr (TC >= 3) { ++ // M4: O += P*U via tensor-core (tf32-safe: P is f32-bounded, decay pre-baked). ++ // GEMM O[t][j] += sum_t' P[t][t']*U[t'][j], M=C N=dv K=C; P=QKsh (stride C), ++ // U=Ud (stride C). The gamma_t*QS cross-chunk term was deposited into dst above; ++ // fold it in here then * scale. Warp w owns dv n-tiles [w*NTPW, ..). ++ const int w = threadIdx.x >> 5; ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; ++ const int lt = lane & 3; ++ constexpr int NWARP = S_v / 32; ++ constexpr int NT = dv / 8; ++ constexpr int NTPW = (NT + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt >= NT) break; ++ const int colbase = nt * 8; ++ float cc[4]; ++ gdn_gram_tile_mma(cc, QKsh, Ud, rowbase, colbase, lg, lt); ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = jj[l]; ++ if (t < Cc && jc < dv) { ++ const int64_t oi = (int64_t)(c0 + t) * S_v * H + jc; ++ attn_base[oi] = (attn_base[oi] + cc[l]) * scale; // QS term + P*U ++ } ++ } ++ } ++ } ++ } else { ++ for (int t = 0; t < Cc; t++) { ++ float o; ++ if constexpr (TC >= 2) { ++ o = attn_base[(c0 + t) * S_v * H + j]; // gamma_t*QS[t][j] deposited above ++ } else { ++ float qs = 0.0f; // (S0^T q_t)[j] (uses pre-update S) ++ for (int i = 0; i < dk; i++) { ++ qs += Sd[j * dk + i] * Qc[t * dk + i]; ++ } ++ o = gam[t] * qs; ++ } ++ for (int tp = 0; tp <= t; tp++) { ++ o += Amat[t * Cc + tp] * Ud[j * C + tp]; ++ } ++ attn_base[(c0 + t) * S_v * H + j] = o * scale; + } +- attn_base[(c0 + t) * S_v * H + j] = o * scale; + } + + // --- S_C[i][j] = gamma_{C-1} S[i][j] + sum_t d(t,C-1) k_t[i] u_t[j] --- + const float glast = gam[Cc - 1]; + const float cslast = csh[Cc - 1]; +- for (int i = 0; i < dk; i++) { +- float s = glast * Sd[j * dk + i]; +- for (int t = 0; t < Cc; t++) { +- const float dd = expf(cslast - csh[t]); // d(t, last) +- s += dd * Kc[t * dk + i] * Ud[j * C + t]; ++ if constexpr (TC >= 4) { ++ // M5/P6: state carry S_C = glast*S0 + Kc^T * DU via 3xtf32 mma. DU[t][j] = ++ // d(t,last)*U[t][j] is built IN PLACE in Ud (t>=Cc zeroed so the K=C contraction ++ // needs no per-k masking), then S_C accumulates over the chunk dim t. Kc is read ++ // transposed (i as M-row). M=dk N=dv K=C. Each output (i,j) has a unique owner so ++ // the glast*S0 read-modify-write is race-free. ++ for (int t = 0; t < C; t++) { ++ const float dd = (t < Cc) ? expf(cslast - csh[t]) : 0.0f; ++ Ud[j * C + t] = dd * Ud[j * C + t]; // thread j owns column j -> DU in place ++ } ++ __syncthreads(); ++ const int w = threadIdx.x >> 5; ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; ++ const int lt = lane & 3; ++ constexpr int NWARP = S_v / 32; ++ constexpr int MT = dk / 16; // m-tiles over dk ++ constexpr int NT = dv / 8; // n-tiles over dv ++ constexpr int NTILES = MT * NT; ++ constexpr int TPW = (NTILES + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int idx = 0; idx < TPW; idx++) { ++ const int tile = w * TPW + idx; ++ if (tile >= NTILES) break; ++ const int rowbase = (tile / NT) * 16; ++ const int colbase = (tile % NT) * 8; ++ float cc[4]; ++ gdn_state_tile_mma_3x(cc, Kc, Ud, rowbase, colbase, lg, lt); ++ const int ii[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int i = ii[l], jc = jj[l]; ++ Sd[jc * dk + i] = glast * Sd[jc * dk + i] + cc[l]; ++ } ++ } ++ } else { ++ for (int i = 0; i < dk; i++) { ++ float s = glast * Sd[j * dk + i]; ++ for (int t = 0; t < Cc; t++) { ++ const float dd = expf(cslast - csh[t]); // d(t, last) ++ s += dd * Kc[t * dk + i] * Ud[j * C + t]; ++ } ++ Sd[j * dk + i] = s; + } +- Sd[j * dk + i] = s; + } + __syncthreads(); // Sd reused as S0 of next chunk; Kc/Qc/Amat reloaded next chunk + } +@@ -591,7 +991,741 @@ __global__ void gated_delta_net_chunked_cuda( + } + } + +-template ++// ===================== M6: register-resident chunk-state (the occupancy flip) ===================== ++// Register-resident variant of the full tensor-core (TC=4) chunked path. The 128x128 chunk-state S ++// is moved OUT of shared memory: each column-owner thread j keeps its state column S[*][j] in ++// registers (Sreg[dk]). This frees the 64KB shared buffer that pinned 0031/M5 at C=16 -- the ++// load-bearing flip for occupancy (0031's -22% was shared/grid starvation, not the math). With the ++// state out of smem the chunk C is raised (16->32 here; 64 needs the bf16-K/Q + dv-slab budget, ++// M7). The crux (design section 5): S is an ACCUMULATOR for the state-carry (step 6) but a B-OPERAND ++// for KS/QS (steps 3/4), and those fragment layouts differ -- so once per chunk the state is bridged ++// through a small transient shared tile (Sres, SRES_W columns wide, looped over dv-strips): ++// - KS/QS: column-owner writes Sreg->Sres, the mma reads Sres as B (load_generic, NOT ldmatrix). ++// - state-carry: the Kc^T*DU mma writes its accumulator output to Sres, then the column-owner ++// folds Sres into Sreg (S_C = glast*S0 + Kc^T*DU). ++// Decays/gamma/beta stay f32 outside the mma; KS/QS/state-carry stay 3xtf32 (the cross-chunk carry). ++template ++__global__ void gated_delta_net_chunked_rr_cuda( ++ const float * __restrict__ q, const float * __restrict__ k, ++ const float * __restrict__ v, const float * __restrict__ g, ++ const float * __restrict__ beta, const float * __restrict__ curr_state, ++ float * __restrict__ dst, ++ int64_t H, int64_t n_tokens, int64_t n_seqs, ++ int64_t sq1, int64_t sq2, int64_t sq3, ++ int64_t sv1, int64_t sv2, int64_t sv3, ++ int64_t sb1, int64_t sb2, int64_t sb3, ++ uint3 neqk1_magic, uint3 rq3_magic, ++ float scale, float * __restrict__ state_dst, ++ const int32_t * __restrict__ ids, int rs_head) { ++ constexpr int dk = S_v; ++ constexpr int dv = S_v; ++ constexpr int dvt = DV_TILE; // dv-slab width (== dv when not slabbed) ++ constexpr int NWARP = DV_TILE / 32; ++ const int h_idx = blockIdx.x; ++ const int seq = blockIdx.y; ++ const int col0 = blockIdx.z * dvt; // global dv-column base of this slab ++ const int j = threadIdx.x; // slab-local v-column (0..dvt-1); global = col0 + j ++ ++ const uint32_t iq1 = fastmodulo((uint32_t) h_idx, neqk1_magic); ++ const uint32_t iq3 = fastdiv((uint32_t) seq, rq3_magic); ++ ++ extern __shared__ float gdn_smem[]; ++ float * Kc = gdn_smem; // [C*dk] Kc[t*dk + i] ++ float * Qc = Kc + (size_t) C * dk; // [C*dk] Qc[t*dk + i] ++ float * Ud = Qc + (size_t) C * dk; // [dvt*C] Ud[localcol*C + t] (dv-sliced per slab) ++ float * Amat = Ud + (size_t) dvt * C; // [C*C] A / P scratch (row-major Amat[t*C + t']) ++ float * csh = Amat + (size_t) C * C; // [C] cumsum(log-gate) ++ float * gam = csh + C; // [C] gamma_t ++ float * bet = gam + C; // [C] beta_t ++ float * KKsh = bet + C; // [C*C] KK Gram, then reused for T = A^{-1} ++ float * QKsh = KKsh + (size_t) C * C; // [C*C] QK Gram, then reused for P ++ float * Sres = QKsh + (size_t) C * C; // [SRES_W*dk] transient accumulator<->B state bridge ++ ++ // S0 register-resident: thread j owns column j -> Sreg[i] = S[i][j]. Same gather/identity ++ // plumbing as the sequential kernel (non-identity seqs gathered by the dispatcher). ++ const bool identity = (ids != nullptr && ids[seq] == rs_head + seq); ++ const float * read_state = (identity ? state_dst : curr_state) ++ + (int64_t) seq * H * dk * dv + (int64_t) h_idx * dk * dv; ++ float Sreg[dk]; ++ #pragma unroll ++ for (int i = 0; i < dk; i++) { ++ Sreg[i] = read_state[(col0 + j) * dk + i]; ++ } ++ ++ const float * q_base = q + iq3 * sq3 + iq1 * sq1; ++ const float * k_base = k + iq3 * sq3 + iq1 * sq1; ++ const float * v_base = v + seq * sv3 + h_idx * sv1; ++ const int64_t gb_base = seq * sb3 + h_idx * sb1; ++ float * attn_base = dst + (int64_t) (seq * n_tokens * H + h_idx) * S_v; ++ ++ const int w = threadIdx.x >> 5; ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; ++ const int lt = lane & 3; ++ ++ for (int64_t c0 = 0; c0 < n_tokens; c0 += C) { ++ const int Cc = (int) ((n_tokens - c0) < (int64_t) C ? (n_tokens - c0) : (int64_t) C); ++ ++ // --- load chunk K,Q (cooperative), beta and the gate prefix (cs, gamma) --- ++ for (int e = j; e < Cc * dk; e += dvt) { ++ const int t = e / dk; ++ const int i = e % dk; ++ Kc[t * dk + i] = k_base[(c0 + t) * sq2 + i]; ++ Qc[t * dk + i] = q_base[(c0 + t) * sq2 + i]; ++ } ++ // Zero the stale K/Q tail (rows t >= Cc) on short/tail chunks: the state-carry mma reads Kc ++ // over the full chunk dim K=0..C-1, and uninitialized smem may be NaN (0*NaN = NaN). Finite ++ // zeros keep the masked tail terms a clean 0. ++ for (int e = Cc * dk + j; e < C * dk; e += dvt) { ++ Kc[e] = 0.0f; ++ Qc[e] = 0.0f; ++ } ++ if (j < Cc) { ++ csh[j] = g[gb_base + (c0 + j) * sb2]; ++ bet[j] = beta[gb_base + (c0 + j) * sb2]; ++ } ++ __syncthreads(); ++ if (j == 0) { ++ float run = 0.0f; ++ for (int t = 0; t < Cc; t++) { ++ run += csh[t]; ++ csh[t] = run; ++ gam[t] = expf(run); ++ } ++ } ++ __syncthreads(); ++ ++ // --- Phase C: tensor-core tf32 Gram (KK->KKsh via warp0, QK->QKsh via warp1) --- ++ if (w < 2) { ++ const float * Xs = (w == 0) ? Kc : Qc; ++ float * Out = (w == 0) ? KKsh : QKsh; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nt = 0; nt < (C + 7) / 8; nt++) { ++ const int colbase = nt * 8; ++ float cc[4]; ++ gdn_gram_tile_mma(cc, Xs, Kc, rowbase, colbase, lg, lt); ++ const int rr[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int ccol[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ if (rr[l] < C && ccol[l] < C) Out[rr[l] * C + ccol[l]] = cc[l]; ++ } ++ } ++ } ++ } ++ __syncthreads(); ++ ++ // --- A = I + tril(beta_t * d(t',t) * (k_t . k_t'), -1) (decay/beta in f32) --- ++ for (int e = j; e < Cc * Cc; e += dvt) { ++ const int t = e / Cc; ++ const int tp = e % Cc; ++ float a = 0.0f; ++ if (tp < t) { ++ const float kk = KKsh[t * C + tp]; ++ const float dd = expf(csh[t] - csh[tp]); ++ a = bet[t] * dd * kk; ++ } else if (tp == t) { ++ a = 1.0f; ++ } ++ Amat[t * Cc + tp] = a; ++ } ++ __syncthreads(); ++ ++ // --- KS = Kc * S0 (3xtf32) -> RHS = beta_t (v - gamma_t KS) -> Ud ; dv-strip restage --- ++ for (int s0 = 0; s0 < dvt; s0 += SRES_W) { ++ if (j >= s0 && j < s0 + SRES_W) { ++ #pragma unroll ++ for (int i = 0; i < dk; i++) Sres[(j - s0) * dk + i] = Sreg[i]; ++ } ++ __syncthreads(); ++ constexpr int NTPS = SRES_W / 8; ++ constexpr int NTPW = (NTPS + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int ntl = w * NTPW + nn; ++ if (ntl >= NTPS) break; ++ const int colbase = ntl * 8; ++ float cc[4]; ++ gdn_gram_tile_mma_3x(cc, Kc, Sres, rowbase, colbase, lg, lt); ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jl[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = s0 + jl[l]; // jc = slab-local column ++ if (t < Cc && jc < dvt) { ++ const float vtj = v_base[(c0 + t) * sv2 + (col0 + jc)]; ++ Ud[jc * C + t] = bet[t] * (vtj - gam[t] * cc[l]); ++ } ++ } ++ } ++ } ++ __syncthreads(); // RHS reads of Sres done before next strip overwrites it ++ } ++ // Zero the stale RHS tail (rows t >= Cc) for short/tail chunks. The apply (U=T*RHS), the ++ // P*U output mma and the state-carry mma all contract the full chunk dim K=0..C-1; the ++ // zeroed P/T columns there would otherwise multiply UNINITIALIZED Ud entries (0*NaN = NaN), ++ // which is exactly what corrupted the tensor-core output. thread j owns local column j. ++ for (int t = Cc; t < C; t++) Ud[j * C + t] = 0.0f; ++ __syncthreads(); ++ ++ // --- Phase D: form T = A^{-1} (f32 column-parallel exact inverse into KKsh) --- ++ if (j < C) { ++ if (j < Cc) { ++ float x[C]; ++ #pragma unroll ++ for (int r = 0; r < C; r++) x[r] = 0.0f; ++ x[j] = 1.0f; ++ for (int r = j + 1; r < Cc; r++) { ++ float acc = 0.0f; ++ for (int m = j; m < r; m++) acc += Amat[r * Cc + m] * x[m]; ++ x[r] = -acc; ++ } ++ #pragma unroll ++ for (int r = 0; r < C; r++) KKsh[r * C + j] = x[r]; ++ } else { ++ #pragma unroll ++ for (int r = 0; r < C; r++) KKsh[r * C + j] = 0.0f; ++ } ++ } ++ __syncthreads(); ++ // --- apply U = T*RHS (tf32), in place on Ud; loop m-tiles (C may exceed one m16 tile) --- ++ { ++ constexpr int MT = (C + 15) / 16; ++ constexpr int NT = dvt / 8; ++ constexpr int NTPW = (NT + NWARP - 1) / NWARP; ++ float ureg[MT][NTPW][4]; ++ #pragma unroll ++ for (int mt = 0; mt < MT; mt++) { ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt < NT) gdn_gram_tile_mma(ureg[mt][nn], KKsh, Ud, mt * 16, nt * 8, lg, lt); ++ } ++ } ++ __syncthreads(); // all RHS(Ud) reads done before overwriting with U ++ #pragma unroll ++ for (int mt = 0; mt < MT; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt >= NT) continue; ++ const int colbase = nt * 8; ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = jj[l]; // jc = slab-local column ++ if (t < Cc && jc < dvt) Ud[jc * C + t] = ureg[mt][nn][l]; ++ } ++ } ++ } ++ __syncthreads(); ++ } ++ ++ // --- P = d(t',t) * QK (in place in QKsh; bounded f32 decay pre-baked) --- ++ for (int e = j; e < C * C; e += dvt) { ++ const int t = e / C; ++ const int tp = e % C; ++ float p = 0.0f; ++ if (tp <= t && t < Cc && tp < Cc) { ++ const float dd = expf(csh[t] - csh[tp]); ++ p = dd * QKsh[t * C + tp]; ++ } ++ QKsh[t * C + tp] = p; ++ } ++ __syncthreads(); ++ ++ // --- O = gamma_t * QS (3xtf32, pre-update S0) ; dv-strip restage --- ++ for (int s0 = 0; s0 < dvt; s0 += SRES_W) { ++ if (j >= s0 && j < s0 + SRES_W) { ++ #pragma unroll ++ for (int i = 0; i < dk; i++) Sres[(j - s0) * dk + i] = Sreg[i]; ++ } ++ __syncthreads(); ++ constexpr int NTPS = SRES_W / 8; ++ constexpr int NTPW = (NTPS + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int ntl = w * NTPW + nn; ++ if (ntl >= NTPS) break; ++ const int colbase = ntl * 8; ++ float cc[4]; ++ gdn_gram_tile_mma_3x(cc, Qc, Sres, rowbase, colbase, lg, lt); ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jl[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = s0 + jl[l]; // jc = slab-local column ++ if (t < Cc && jc < dvt) attn_base[(c0 + t) * S_v * H + (col0 + jc)] = gam[t] * cc[l]; ++ } ++ } ++ } ++ __syncthreads(); ++ } ++ // --- O += P*U (tf32-safe) then * scale (slab-local n-tiles) --- ++ { ++ constexpr int NT = dvt / 8; ++ constexpr int NTPW = (NT + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt >= NT) break; ++ const int colbase = nt * 8; ++ float cc[4]; ++ gdn_gram_tile_mma(cc, QKsh, Ud, rowbase, colbase, lg, lt); ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = jj[l]; // jc = slab-local column ++ if (t < Cc && jc < dvt) { ++ const int64_t oi = (int64_t)(c0 + t) * S_v * H + (col0 + jc); ++ attn_base[oi] = (attn_base[oi] + cc[l]) * scale; ++ } ++ } ++ } ++ } ++ } ++ __syncthreads(); // O done; Ud free to become DU, Sres free for the state restage ++ ++ // --- state carry: S_C = glast*S0 + Kc^T * DU (3xtf32) ; dv-strip mma->Sres->Sreg fold --- ++ const float glast = gam[Cc - 1]; ++ const float cslast = csh[Cc - 1]; ++ for (int t = 0; t < C; t++) { ++ const float dd = (t < Cc) ? expf(cslast - csh[t]) : 0.0f; ++ Ud[j * C + t] = dd * Ud[j * C + t]; // DU in place (thread j owns column j) ++ } ++ __syncthreads(); ++ for (int s0 = 0; s0 < dvt; s0 += SRES_W) { ++ constexpr int MT = dk / 16; ++ constexpr int NT_S = SRES_W / 8; ++ constexpr int NTILES = MT * NT_S; ++ constexpr int TPW = (NTILES + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int idx = 0; idx < TPW; idx++) { ++ const int tile = w * TPW + idx; ++ if (tile >= NTILES) break; ++ const int rowbase = (tile / NT_S) * 16; // i-tile over dk ++ const int coll = (tile % NT_S) * 8; // local jc within the strip ++ const int colg = s0 + coll; // slab-local jc (Ud is slab-local) ++ float cc[4]; ++ gdn_state_tile_mma_3x(cc, Kc, Ud, rowbase, colg, lg, lt); ++ const int ii[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jl[4] = {coll + 2*lt, coll + 2*lt + 1, coll + 2*lt, coll + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) Sres[jl[l] * dk + ii[l]] = cc[l]; ++ } ++ __syncthreads(); ++ if (j >= s0 && j < s0 + SRES_W) { ++ const int jcl = j - s0; ++ #pragma unroll ++ for (int i = 0; i < dk; i++) Sreg[i] = glast * Sreg[i] + Sres[jcl * dk + i]; ++ } ++ __syncthreads(); // Sres reads done before next strip's mma overwrites it ++ } ++ } ++ ++ // --- final-state write-back (M-layout) --- ++ const int64_t state_out_offset = (int64_t) (seq * H + h_idx) * S_v * S_v; ++ const int64_t attn_score_elems = (int64_t) S_v * H * n_tokens * n_seqs; ++ float * st = (state_dst != nullptr) ? (state_dst + state_out_offset) ++ : (dst + attn_score_elems + state_out_offset); ++ #pragma unroll ++ for (int i = 0; i < dk; i++) { ++ st[(col0 + j) * dk + i] = Sreg[i]; ++ } ++} ++ ++// ============================== CONFIG C (M8): the real occupancy flip ============================== ++// Built on the M6/M7 register-resident state (gated_delta_net_chunked_rr_cuda). Two smem reliefs ++// let C reach 64 under the 99KB dynamic-smem opt-in AND drop C=32 under the ~49.5KB/block budget ++// for >=2 resident blocks/SM: ++// (1) Kc/Qc staged as BF16 (half the f32 footprint). The tf32 m16n8k8 mma reads bf16-upconverted ++// operands; bf16's 8-bit mantissa fits tf32's 10 bits (scope: "bf16 only for the Gram terms"). ++// (2) ONE extra C*C buffer dropped: the QK Gram is computed LATE (Phase F, after T=A^-1 is consumed ++// by the apply), reusing the KKsh slot. So only 2 C*C f32 scratch (A in Amat, T->QK->P in KKsh) ++// instead of M6/M7's 3, and a small SRES_W=16 restage strip instead of SRES_W=C. ++// State stays register-resident in the per-thread dv-column layout (Sreg[dk]); in this column model ++// registers are NOT the 2-blk/SM cap (block = DV_TILE <= 128 threads, so regs/thread budget is ample ++// at 2 blocks) -- smem is, and the two reliefs above flip it. Decays/gamma/beta stay f32 outside the ++// mma; KS/QS + state-carry stay 3xtf32 (the cross-chunk carry compounds over every chunk). ++// C=64, dv_tile=64 -> ~89KB, 1 blk/SM (BW-max, C raised to 64). ++// C=32, dv_tile=64 -> ~40KB, 2 blk/SM (the occupancy flip). C=32, dv_tile=128 -> ~48KB, 2 blk/SM. ++// MIN_BLK feeds __launch_bounds__ so the compiler caps registers for the target occupancy. ++template ++__global__ void __launch_bounds__(DV_TILE, MIN_BLK) ++gated_delta_net_chunked_cc_cuda( ++ const float * __restrict__ q, const float * __restrict__ k, ++ const float * __restrict__ v, const float * __restrict__ g, ++ const float * __restrict__ beta, const float * __restrict__ curr_state, ++ float * __restrict__ dst, ++ int64_t H, int64_t n_tokens, int64_t n_seqs, ++ int64_t sq1, int64_t sq2, int64_t sq3, ++ int64_t sv1, int64_t sv2, int64_t sv3, ++ int64_t sb1, int64_t sb2, int64_t sb3, ++ uint3 neqk1_magic, uint3 rq3_magic, ++ float scale, float * __restrict__ state_dst, ++ const int32_t * __restrict__ ids, int rs_head) { ++ constexpr int dk = S_v; ++ constexpr int dv = S_v; ++ constexpr int dvt = DV_TILE; // dv-slab width (== dv when not slabbed) ++ constexpr int NWARP = DV_TILE / 32; ++ const int h_idx = blockIdx.x; ++ const int seq = blockIdx.y; ++ const int col0 = blockIdx.z * dvt; // global dv-column base of this slab ++ const int j = threadIdx.x; // slab-local v-column (0..dvt-1); global = col0 + j ++ ++ const uint32_t iq1 = fastmodulo((uint32_t) h_idx, neqk1_magic); ++ const uint32_t iq3 = fastdiv((uint32_t) seq, rq3_magic); ++ ++ // smem: BF16 Kc,Qc prefix then the f32 scratch. Reuse the float gdn_smem[] symbol (same name as ++ // the other chunked kernels, no extern-shared type clash) and reinterpret its front as bf16: the ++ // two bf16 buffers are 2*C*dk bf16 = 4*C*dk bytes = exactly C*dk floats, so Ud starts at the ++ // float offset C*dk (4-byte aligned). Two C*C f32 buffers only (Amat, KKsh; QK Gram deferred). ++ extern __shared__ float gdn_smem[]; ++ nv_bfloat16 * Kc = (nv_bfloat16 *) gdn_smem; // [C*dk] Kc[t*dk + i] (bf16) ++ nv_bfloat16 * Qc = Kc + (size_t) C * dk; // [C*dk] Qc[t*dk + i] (bf16) ++ float * Ud = gdn_smem + (size_t) C * dk; // [dvt*C] Ud[localcol*C + t] ++ float * Amat = Ud + (size_t) dvt * C; // [C*C] A scratch ++ float * csh = Amat + (size_t) C * C; // [C] cumsum(log-gate) ++ float * gam = csh + C; // [C] gamma_t ++ float * bet = gam + C; // [C] beta_t ++ float * KKsh = bet + C; // [C*C] KK Gram -> T=A^{-1} -> QK -> P ++ float * Sres = KKsh + (size_t) C * C; // [SRES_W*dk] accumulator<->B state bridge ++ ++ // S0 register-resident: thread j owns column j -> Sreg[i] = S[i][j]. ++ const bool identity = (ids != nullptr && ids[seq] == rs_head + seq); ++ const float * read_state = (identity ? state_dst : curr_state) ++ + (int64_t) seq * H * dk * dv + (int64_t) h_idx * dk * dv; ++ float Sreg[dk]; ++ #pragma unroll ++ for (int i = 0; i < dk; i++) { ++ Sreg[i] = read_state[(col0 + j) * dk + i]; ++ } ++ ++ const float * q_base = q + iq3 * sq3 + iq1 * sq1; ++ const float * k_base = k + iq3 * sq3 + iq1 * sq1; ++ const float * v_base = v + seq * sv3 + h_idx * sv1; ++ const int64_t gb_base = seq * sb3 + h_idx * sb1; ++ float * attn_base = dst + (int64_t) (seq * n_tokens * H + h_idx) * S_v; ++ ++ const int w = threadIdx.x >> 5; ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; ++ const int lt = lane & 3; ++ ++ for (int64_t c0 = 0; c0 < n_tokens; c0 += C) { ++ const int Cc = (int) ((n_tokens - c0) < (int64_t) C ? (n_tokens - c0) : (int64_t) C); ++ ++ // --- load chunk K,Q (cooperative, f32->bf16), beta and the gate prefix (cs, gamma) --- ++ for (int e = j; e < Cc * dk; e += dvt) { ++ const int t = e / dk; ++ const int i = e % dk; ++ Kc[t * dk + i] = __float2bfloat16(k_base[(c0 + t) * sq2 + i]); ++ Qc[t * dk + i] = __float2bfloat16(q_base[(c0 + t) * sq2 + i]); ++ } ++ // Zero the stale K/Q tail (rows t >= Cc): the state-carry mma reads Kc over K=0..C-1. ++ for (int e = Cc * dk + j; e < C * dk; e += dvt) { ++ Kc[e] = __float2bfloat16(0.0f); ++ Qc[e] = __float2bfloat16(0.0f); ++ } ++ if (j < Cc) { ++ csh[j] = g[gb_base + (c0 + j) * sb2]; ++ bet[j] = beta[gb_base + (c0 + j) * sb2]; ++ } ++ __syncthreads(); ++ if (j == 0) { ++ float run = 0.0f; ++ for (int t = 0; t < Cc; t++) { ++ run += csh[t]; ++ csh[t] = run; ++ gam[t] = expf(run); ++ } ++ } ++ __syncthreads(); ++ ++ // --- Phase C: tensor-core tf32 Gram, KK->KKsh (warp0). QK is deferred to Phase F. --- ++ if (w == 0) { ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nt = 0; nt < (C + 7) / 8; nt++) { ++ const int colbase = nt * 8; ++ float cc[4]; ++ gdn_gram_tile_mma(cc, Kc, Kc, rowbase, colbase, lg, lt); ++ const int rr[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int ccol[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ if (rr[l] < C && ccol[l] < C) KKsh[rr[l] * C + ccol[l]] = cc[l]; ++ } ++ } ++ } ++ } ++ __syncthreads(); ++ ++ // --- A = I + tril(beta_t * d(t',t) * (k_t . k_t'), -1) (decay/beta in f32) --- ++ for (int e = j; e < Cc * Cc; e += dvt) { ++ const int t = e / Cc; ++ const int tp = e % Cc; ++ float a = 0.0f; ++ if (tp < t) { ++ const float kk = KKsh[t * C + tp]; ++ const float dd = expf(csh[t] - csh[tp]); ++ a = bet[t] * dd * kk; ++ } else if (tp == t) { ++ a = 1.0f; ++ } ++ Amat[t * Cc + tp] = a; ++ } ++ __syncthreads(); ++ ++ // --- KS = Kc * S0 (3xtf32) -> RHS = beta_t (v - gamma_t KS) -> Ud ; dv-strip restage --- ++ for (int s0 = 0; s0 < dvt; s0 += SRES_W) { ++ if (j >= s0 && j < s0 + SRES_W) { ++ #pragma unroll ++ for (int i = 0; i < dk; i++) Sres[(j - s0) * dk + i] = Sreg[i]; ++ } ++ __syncthreads(); ++ constexpr int NTPS = SRES_W / 8; ++ constexpr int NTPW = (NTPS + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int ntl = w * NTPW + nn; ++ if (ntl >= NTPS) break; ++ const int colbase = ntl * 8; ++ float cc[4]; ++ gdn_gram_tile_mma_3x(cc, Kc, Sres, rowbase, colbase, lg, lt); ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jl[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = s0 + jl[l]; // jc = slab-local column ++ if (t < Cc && jc < dvt) { ++ const float vtj = v_base[(c0 + t) * sv2 + (col0 + jc)]; ++ Ud[jc * C + t] = bet[t] * (vtj - gam[t] * cc[l]); ++ } ++ } ++ } ++ } ++ __syncthreads(); // RHS reads of Sres done before next strip overwrites it ++ } ++ // Zero the stale RHS tail (rows t >= Cc): apply/P*U/state-carry contract K=0..C-1. ++ for (int t = Cc; t < C; t++) Ud[j * C + t] = 0.0f; ++ __syncthreads(); ++ ++ // --- Phase D: form T = A^{-1} (f32 column-parallel exact inverse into KKsh) --- ++ if (j < C) { ++ if (j < Cc) { ++ float x[C]; ++ #pragma unroll ++ for (int r = 0; r < C; r++) x[r] = 0.0f; ++ x[j] = 1.0f; ++ for (int r = j + 1; r < Cc; r++) { ++ float acc = 0.0f; ++ for (int m = j; m < r; m++) acc += Amat[r * Cc + m] * x[m]; ++ x[r] = -acc; ++ } ++ #pragma unroll ++ for (int r = 0; r < C; r++) KKsh[r * C + j] = x[r]; ++ } else { ++ #pragma unroll ++ for (int r = 0; r < C; r++) KKsh[r * C + j] = 0.0f; ++ } ++ } ++ __syncthreads(); ++ // --- apply U = T*RHS (tf32), in place on Ud; loop m-tiles --- ++ { ++ constexpr int MT = (C + 15) / 16; ++ constexpr int NT = dvt / 8; ++ constexpr int NTPW = (NT + NWARP - 1) / NWARP; ++ float ureg[MT][NTPW][4]; ++ #pragma unroll ++ for (int mt = 0; mt < MT; mt++) { ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt < NT) gdn_gram_tile_mma(ureg[mt][nn], KKsh, Ud, mt * 16, nt * 8, lg, lt); ++ } ++ } ++ __syncthreads(); // all RHS(Ud) reads done before overwriting with U ++ #pragma unroll ++ for (int mt = 0; mt < MT; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt >= NT) continue; ++ const int colbase = nt * 8; ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = jj[l]; // jc = slab-local column ++ if (t < Cc && jc < dvt) Ud[jc * C + t] = ureg[mt][nn][l]; ++ } ++ } ++ } ++ __syncthreads(); ++ } ++ ++ // --- Phase F: QK Gram LATE (Kc/Qc bf16) -> KKsh (T is consumed), then P = d(t',t)*QK --- ++ if (w == 0) { ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nt = 0; nt < (C + 7) / 8; nt++) { ++ const int colbase = nt * 8; ++ float cc[4]; ++ gdn_gram_tile_mma(cc, Qc, Kc, rowbase, colbase, lg, lt); ++ const int rr[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int ccol[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ if (rr[l] < C && ccol[l] < C) KKsh[rr[l] * C + ccol[l]] = cc[l]; ++ } ++ } ++ } ++ } ++ __syncthreads(); ++ for (int e = j; e < C * C; e += dvt) { ++ const int t = e / C; ++ const int tp = e % C; ++ float p = 0.0f; ++ if (tp <= t && t < Cc && tp < Cc) { ++ const float dd = expf(csh[t] - csh[tp]); ++ p = dd * KKsh[t * C + tp]; ++ } ++ KKsh[t * C + tp] = p; ++ } ++ __syncthreads(); ++ ++ // --- O = gamma_t * QS (3xtf32, pre-update S0) ; dv-strip restage --- ++ for (int s0 = 0; s0 < dvt; s0 += SRES_W) { ++ if (j >= s0 && j < s0 + SRES_W) { ++ #pragma unroll ++ for (int i = 0; i < dk; i++) Sres[(j - s0) * dk + i] = Sreg[i]; ++ } ++ __syncthreads(); ++ constexpr int NTPS = SRES_W / 8; ++ constexpr int NTPW = (NTPS + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int ntl = w * NTPW + nn; ++ if (ntl >= NTPS) break; ++ const int colbase = ntl * 8; ++ float cc[4]; ++ gdn_gram_tile_mma_3x(cc, Qc, Sres, rowbase, colbase, lg, lt); ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jl[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = s0 + jl[l]; // jc = slab-local column ++ if (t < Cc && jc < dvt) attn_base[(c0 + t) * S_v * H + (col0 + jc)] = gam[t] * cc[l]; ++ } ++ } ++ } ++ __syncthreads(); ++ } ++ // --- O += P*U (tf32-safe) then * scale (slab-local n-tiles) --- ++ { ++ constexpr int NT = dvt / 8; ++ constexpr int NTPW = (NT + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt >= NT) break; ++ const int colbase = nt * 8; ++ float cc[4]; ++ gdn_gram_tile_mma(cc, KKsh, Ud, rowbase, colbase, lg, lt); ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = jj[l]; // jc = slab-local column ++ if (t < Cc && jc < dvt) { ++ const int64_t oi = (int64_t)(c0 + t) * S_v * H + (col0 + jc); ++ attn_base[oi] = (attn_base[oi] + cc[l]) * scale; ++ } ++ } ++ } ++ } ++ } ++ __syncthreads(); // O done; Ud free to become DU, Sres free for the state restage ++ ++ // --- state carry: S_C = glast*S0 + Kc^T * DU (3xtf32) ; dv-strip mma->Sres->Sreg fold --- ++ const float glast = gam[Cc - 1]; ++ const float cslast = csh[Cc - 1]; ++ for (int t = 0; t < C; t++) { ++ const float dd = (t < Cc) ? expf(cslast - csh[t]) : 0.0f; ++ Ud[j * C + t] = dd * Ud[j * C + t]; // DU in place (thread j owns column j) ++ } ++ __syncthreads(); ++ for (int s0 = 0; s0 < dvt; s0 += SRES_W) { ++ constexpr int MT = dk / 16; ++ constexpr int NT_S = SRES_W / 8; ++ constexpr int NTILES = MT * NT_S; ++ constexpr int TPW = (NTILES + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int idx = 0; idx < TPW; idx++) { ++ const int tile = w * TPW + idx; ++ if (tile >= NTILES) break; ++ const int rowbase = (tile / NT_S) * 16; // i-tile over dk ++ const int coll = (tile % NT_S) * 8; // local jc within the strip ++ const int colg = s0 + coll; // slab-local jc (Ud is slab-local) ++ float cc[4]; ++ gdn_state_tile_mma_3x(cc, Kc, Ud, rowbase, colg, lg, lt); ++ const int ii[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jl[4] = {coll + 2*lt, coll + 2*lt + 1, coll + 2*lt, coll + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) Sres[jl[l] * dk + ii[l]] = cc[l]; ++ } ++ __syncthreads(); ++ if (j >= s0 && j < s0 + SRES_W) { ++ const int jcl = j - s0; ++ #pragma unroll ++ for (int i = 0; i < dk; i++) Sreg[i] = glast * Sreg[i] + Sres[jcl * dk + i]; ++ } ++ __syncthreads(); // Sres reads done before next strip's mma overwrites it ++ } ++ } ++ ++ // --- final-state write-back (M-layout) --- ++ const int64_t state_out_offset = (int64_t) (seq * H + h_idx) * S_v * S_v; ++ const int64_t attn_score_elems = (int64_t) S_v * H * n_tokens * n_seqs; ++ float * st = (state_dst != nullptr) ? (state_dst + state_out_offset) ++ : (dst + attn_score_elems + state_out_offset); ++ #pragma unroll ++ for (int i = 0; i < dk; i++) { ++ st[(col0 + j) * dk + i] = Sreg[i]; ++ } ++} ++ ++template + static void launch_gdn_chunked( + const float * q_d, const float * k_d, const float * v_d, + const float * g_d, const float * b_d, const float * s_d, +@@ -603,10 +1737,11 @@ static void launch_gdn_chunked( + const uint3 neqk1_magic, const uint3 rq3_magic, + float scale, cudaStream_t stream) { + const size_t smem = ((size_t) S_v * S_v + (size_t) 2 * C * S_v + (size_t) S_v * C +- + (size_t) C * C + (size_t) 3 * C) * sizeof(float); ++ + (size_t) C * C + (size_t) 3 * C ++ + (TC >= 1 ? (size_t) 2 * C * C : (size_t) 0)) * sizeof(float); + static bool attr_set = false; + if (!attr_set) { +- const cudaError_t e = cudaFuncSetAttribute(gated_delta_net_chunked_cuda, ++ const cudaError_t e = cudaFuncSetAttribute(gated_delta_net_chunked_cuda, + cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smem); + if (e != cudaSuccess) { + GGML_ABORT("gdn chunked: cudaFuncSetAttribute(maxDynSmem=%zu) failed: %s\n", smem, cudaGetErrorString(e)); +@@ -615,7 +1750,82 @@ static void launch_gdn_chunked( + } + dim3 grid_dims(H, n_seqs, 1); + dim3 block_dims(S_v, 1, 1); +- gated_delta_net_chunked_cuda<<>>( ++ gated_delta_net_chunked_cuda<<>>( ++ q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, ++ sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, ++ neqk1_magic, rq3_magic, scale, state_dst_d, ids_d, rs_head); ++} ++ ++// Launcher for the register-resident full-TC path (M6/M7). State lives in registers, so the smem ++// budget drops the 64KB Sd buffer and gains a small SRES_W*dk restage tile; this is what lets C ++// grow past 16 under the 99KB dynamic-smem opt-in. DV_TILE < S_v dv-slabs the state across blocks ++// (M7): the grid gains a z-axis (S_v/DV_TILE slabs), each block holds only a dk x DV_TILE state ++// slab register-resident (halving the per-thread state regs at DV_TILE=64) and computes a DV_TILE ++// strip of O/U/state; the dv-independent work (Kc/Qc load, KK/QK Gram, A and its inverse, gates) ++// is recomputed per slab. This multiplies the grid (fixes the low-n_seqs grid starvation) and ++// relieves the register cap toward >=2 blocks/SM. ++template ++static void launch_gdn_chunked_rr( ++ const float * q_d, const float * k_d, const float * v_d, ++ const float * g_d, const float * b_d, const float * s_d, ++ float * dst_d, float * state_dst_d, const int32_t * ids_d, int rs_head, ++ int64_t H, int64_t n_tokens, int64_t n_seqs, ++ int64_t sq1, int64_t sq2, int64_t sq3, ++ int64_t sv1, int64_t sv2, int64_t sv3, ++ int64_t sb1, int64_t sb2, int64_t sb3, ++ const uint3 neqk1_magic, const uint3 rq3_magic, ++ float scale, cudaStream_t stream) { ++ // Kc + Qc (2*C*S_v) + Ud (DV_TILE*C) + Amat (C*C) + 3*C gates + KKsh + QKsh (2*C*C) + Sres (SRES_W*S_v) ++ const size_t smem = ((size_t) 2 * C * S_v + (size_t) DV_TILE * C + (size_t) 3 * C * C ++ + (size_t) 3 * C + (size_t) SRES_W * S_v) * sizeof(float); ++ static bool attr_set = false; ++ if (!attr_set) { ++ const cudaError_t e = cudaFuncSetAttribute(gated_delta_net_chunked_rr_cuda, ++ cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smem); ++ if (e != cudaSuccess) { ++ GGML_ABORT("gdn rr: cudaFuncSetAttribute(maxDynSmem=%zu) failed: %s\n", smem, cudaGetErrorString(e)); ++ } ++ attr_set = true; ++ } ++ dim3 grid_dims(H, n_seqs, S_v / DV_TILE); ++ dim3 block_dims(DV_TILE, 1, 1); ++ gated_delta_net_chunked_rr_cuda<<>>( ++ q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, ++ sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, ++ neqk1_magic, rq3_magic, scale, state_dst_d, ids_d, rs_head); ++} ++ ++// CONFIG C launcher (M8). Same dv-slab grid as the rr launcher; the smem budget is BF16 Kc/Qc ++// (2*C*S_v*2 bytes) + Ud (DV_TILE*C f32) + 2 C*C f32 scratch (Amat + KKsh) + 3*C gates + Sres ++// (SRES_W*S_v f32). C=64/dv64 ~= 89KB (1 blk/SM); C=32/dv64 ~= 40KB and C=32/dv128 ~= 48KB ++// (2 blk/SM). cudaFuncSetAttribute return is CHECKED (0031 precedent). ++template ++static void launch_gdn_chunked_cc( ++ const float * q_d, const float * k_d, const float * v_d, ++ const float * g_d, const float * b_d, const float * s_d, ++ float * dst_d, float * state_dst_d, const int32_t * ids_d, int rs_head, ++ int64_t H, int64_t n_tokens, int64_t n_seqs, ++ int64_t sq1, int64_t sq2, int64_t sq3, ++ int64_t sv1, int64_t sv2, int64_t sv3, ++ int64_t sb1, int64_t sb2, int64_t sb3, ++ const uint3 neqk1_magic, const uint3 rq3_magic, ++ float scale, cudaStream_t stream) { ++ const size_t smem = (size_t) 2 * C * S_v * sizeof(nv_bfloat16) // Kc + Qc (bf16) ++ + ((size_t) DV_TILE * C + (size_t) 2 * C * C ++ + (size_t) 3 * C + (size_t) SRES_W * S_v) * sizeof(float); ++ static bool attr_set = false; ++ if (!attr_set) { ++ const cudaError_t e = cudaFuncSetAttribute( ++ gated_delta_net_chunked_cc_cuda, ++ cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smem); ++ if (e != cudaSuccess) { ++ GGML_ABORT("gdn cc: cudaFuncSetAttribute(maxDynSmem=%zu) failed: %s\n", smem, cudaGetErrorString(e)); ++ } ++ attr_set = true; ++ } ++ dim3 grid_dims(H, n_seqs, S_v / DV_TILE); ++ dim3 block_dims(DV_TILE, 1, 1); ++ gated_delta_net_chunked_cc_cuda<<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, + sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, + neqk1_magic, rq3_magic, scale, state_dst_d, ids_d, rs_head); +@@ -645,17 +1855,115 @@ static void launch_gated_delta_net( + // sequential recurrence. Mathematically equivalent up to FP reduction order (NEW per-path md5; + // validated benign by test-backend-ops NMSE + greedy output). Toggle: GDN_CHUNK_OFF / GDN_CHUNK_MIN. + if constexpr (!KDA && !keep_rs_t && !STATE_BF16 && !HYBRID) { +- // OPT-IN: this chunked path is bit-exact-benign (test-backend-ops green) but, at C=16 +- // (forced by GB10 99KB dyn-smem opt-in, all-shared), it is NOT yet faster than the tuned +- // sequential recurrence on this model (measured ~22%% slower S_PP, grid-starved at low +- // n_seqs + 1 block/SM occupancy). Default OFF so the backend default is regression-free; +- // enable for experiments / tuning with GDN_CHUNK_MIN=. See README section 5 (dev notes / rejected-flat levers). +- static const int gdn_chunk_min = []{ const char * e = getenv("GDN_CHUNK_MIN"); return e ? atoi(e) : INT_MAX; }(); ++ // DEFAULT-ON UNDER PAGED KV (patch 0044). The M5 tensor-core path (GDN_TC=5: full-TC ++ // form-T solve + state-update mma) is greedy-bit-exact (per-path md5 == the sequential ++ // canonical) and now *beats* the tuned sequential recurrence on the Qwen3.6 MoE prefill: ++ // GB10, q36-35b-a3b-nvfp4, LLAMA_KV_PAGED=1 LLAMA_MOE_FORCE_GRAPHS=1, ++ // -npp 512 -ntg 4 -npl 32 : S_PP 2185.8 -> 2315.2 t/s (+5.9%) ++ // -npp 2048 -ntg 4 -npl 32 : S_PP 2021.5 -> 2379.8 t/s (+17.7%) ++ // GDN_CHUNK_MIN is the engage threshold (per-call token count). It must stay > 1 so the ++ // 1-tok-per-call decode steps (and any tiny GDN call) keep the sequential recurrence - ++ // at GDN_CHUNK_MIN=1 the chunked path also swallows decode and collapses S_TG by ~25%. ++ // Tuned to 64: above decode/tiny-call sizes, below the real MoE-prefill per-call count ++ // (which is < 256 on the 512-prompt shape, so 256 barely fires). Both knobs stay ++ // env-overridable; OFF (INT_MAX) when not paged so the stock/non-paged default is ++ // regression-free. See README sections 3 (0044) and 5. ++ static const bool kv_paged = (getenv("LLAMA_KV_PAGED") != nullptr); ++ static const int gdn_chunk_min = []{ ++ const char * e = getenv("GDN_CHUNK_MIN"); ++ if (e) return atoi(e); ++ return kv_paged ? 64 : INT_MAX; ++ }(); ++ // Tensor-core level selector (single build, clean runtime A/B). GDN_TC: ++ // 0 = serial scan (patch 0031); 1 = KK/QK Gram mma (M2); ++ // 2 = + KS/QS state-boundary mma, 3xtf32 (M3); 3 = + P*U output mma (M4); ++ // 4 = CONFIG C (M8): bf16 Kc/Qc + 2 C*C scratch (QK Gram deferred) + small restage. The ++ // smem relief that lets C reach 64 under the 99KB opt-in (GDN_CHUNK_C=64 -> ~89KB, ++ // 1 blk/SM) AND drops C=32 under ~49.5KB for >=2 blocks/SM (GDN_CHUNK_C=32 -> ~40/48KB); ++ // 5 = M5 (full TC, state in 64KB smem, C=16) - the DEFAULT under paged KV (patch 0044); ++ // 6+ = M6/M7 register-resident state, f32 Kc/Qc, chunk C from GDN_CHUNK_C. ++ // GDN_GRAM_MMA=1 kept as an alias for level 1. GDN_CHUNK_C in {16,32} (M6/M7) or {32,64} ++ // (CONFIG C) selects the chunk; GDN_DV_TILE in {64,128} the dv-slab. ++ static const int gdn_tc = []{ ++ const char * e = getenv("GDN_TC"); ++ if (e) return atoi(e); ++ const char * g = getenv("GDN_GRAM_MMA"); ++ if (g && atoi(g) != 0) return 1; ++ return kv_paged ? 5 : 0; ++ }(); ++ static const int gdn_chunk_c = []{ const char * e = getenv("GDN_CHUNK_C"); return e ? atoi(e) : 32; }(); ++ // M7 dv-slab width: 128 = no slab (default, 1 slab); 64 = 2 slabs (grid x2, half the state ++ // regs/block). Must be >= max(C, 64) so the KK/QK two-warp split and the j= gdn_chunk_min) { +- launch_gdn_chunked<128, 16>( +- q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, +- H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, +- neqk1_magic, rq3_magic, scale, stream); ++ switch (gdn_tc) { ++ case 0: ++ launch_gdn_chunked<128, 16, 0>( ++ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, ++ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, ++ neqk1_magic, rq3_magic, scale, stream); ++ break; ++ case 1: ++ launch_gdn_chunked<128, 16, 1>( ++ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, ++ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, ++ neqk1_magic, rq3_magic, scale, stream); ++ break; ++ case 2: ++ launch_gdn_chunked<128, 16, 2>( ++ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, ++ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, ++ neqk1_magic, rq3_magic, scale, stream); ++ break; ++ case 3: ++ launch_gdn_chunked<128, 16, 3>( ++ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, ++ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, ++ neqk1_magic, rq3_magic, scale, stream); ++ break; ++ case 4: { ++ // CONFIG C (M8): the real occupancy flip. bf16 Kc/Qc + 2 C*C scratch (QK Gram ++ // computed late) + SRES_W=16 restage strip. GDN_CHUNK_C=64 -> dv-slab forced to ++ // 64 (full dv=128 would exceed the 99KB opt-in), ~89KB, 1 blk/SM, "C raised to ++ // 64". GDN_CHUNK_C=32 -> dv from GDN_DV_TILE (64 or 128), ~40/48KB, 2 blk/SM. ++#define GDN_CC_LAUNCH(C_, SRES_, DV_, MB_) \ ++ launch_gdn_chunked_cc<128, C_, SRES_, DV_, MB_>( \ ++ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, \ ++ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, \ ++ neqk1_magic, rq3_magic, scale, stream) ++ if (gdn_chunk_c >= 64) { ++ GDN_CC_LAUNCH(64, 16, 64, 1); ++ } else if (gdn_dv_tile <= 64) { ++ GDN_CC_LAUNCH(32, 16, 64, 2); ++ } else { ++ GDN_CC_LAUNCH(32, 16, 128, 2); ++ } ++#undef GDN_CC_LAUNCH ++ break; ++ } ++ case 5: ++ // M5 reference: full TC, state still in 64KB smem, C=16 (A/B baseline). ++ launch_gdn_chunked<128, 16, 4>( ++ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, ++ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, ++ neqk1_magic, rq3_magic, scale, stream); ++ break; ++ default: ++ // M6/M7: full TC, REGISTER-RESIDENT state; chunk C from GDN_CHUNK_C (default 32), ++ // dv-slab width from GDN_DV_TILE (default 128 = no slab; 64 = 2 slabs, M7). ++#define GDN_RR_LAUNCH(C_, SRES_, DV_) \ ++ launch_gdn_chunked_rr<128, C_, SRES_, DV_>( \ ++ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, \ ++ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, \ ++ neqk1_magic, rq3_magic, scale, stream) ++ if (gdn_chunk_c <= 16) { ++ if (gdn_dv_tile <= 64) { GDN_RR_LAUNCH(16, 16, 64); } else { GDN_RR_LAUNCH(16, 16, 128); } ++ } else { ++ if (gdn_dv_tile <= 64) { GDN_RR_LAUNCH(32, 32, 64); } else { GDN_RR_LAUNCH(32, 32, 128); } ++ } ++#undef GDN_RR_LAUNCH ++ break; ++ } + return; + } + } +diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp +index 951bffc..65f2c62 100644 +--- a/tests/test-backend-ops.cpp ++++ b/tests/test-backend-ops.cpp +@@ -9433,6 +9433,11 @@ static std::vector> make_test_cases_eval() { + } + + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1)); ++ // Tensor-core chunked-GDN prefill path (S_v==128): multi-chunk (C=16) coverage, ++ // incl. a tail chunk (100 = 6*16+4) and multi-seq. Exercised via GDN_CHUNK_MIN + GDN_TC. ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1)); ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 100, 1)); ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 128, 2)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true)); +-- +2.43.0