diff --git a/backend/cpp/llama-cpp-localai-paged/README.md b/backend/cpp/llama-cpp-localai-paged/README.md index 4c605ff22..b6178f510 100644 --- a/backend/cpp/llama-cpp-localai-paged/README.md +++ b/backend/cpp/llama-cpp-localai-paged/README.md @@ -86,7 +86,7 @@ orthogonal to the paged allocator. --- -## 3. Patch series (0001-0044) +## 3. Patch series (0001-0047) Source-only patches, with intentional numbering gaps (e.g. 0005, 0027). The decode-serving graph-reuse levers are 0040-0041. "Bit-exact" = greedy md5 / @@ -188,8 +188,9 @@ These are the dominant decode levers on the Qwen3.6 hybrid models. All bit-exact | 0024 | **Paged-pool burst-reclaim** - truncate trailing blocks on partial-tail `seq_rm`, defrag the free queue when idle, release blocks on slot completion. Fixes the long-server burst-degradation bug (post-burst prefill collapse 488->44 t/s, restored to 532). Host-side accounting only. | yes | | 0029 | **Block-table within-step host cache** - the block table is fixed for the whole step; cache it on first build and memcpy it for the other full-attention layers (get_block_table -87%/-91%). | yes, per path (paged-MoE ref `8cb0ce23`) | | 0030 | **Fused-op backend gate** - the fused GDN / discriminated SSM_CONV ops are CUDA-family + CPU only; force them off on any non-CUDA compute backend so a Vulkan/SYCL/Metal build can't silently run the wrong plain-conv kernel. | yes on CUDA (byte-identical pre-0030); safety gate elsewhere | -| 0031 | **Chunked parallel-scan GDN prefill kernel** (upstream TODO) - FLA-style chunked gated-delta-rule for prefill (non-KDA / f32 / final-state): intra-chunk delta rule solved in parallel (UT-transform + forward subst), inter-chunk recurrence over n_tokens/C steps. The scalar-serial form (`GDN_TC=0`) was bit-exact-benign but not faster than the tuned sequential scan at the GB10-forced C=16 (see section 5); **superseded for paged by the tensor-core M5 path of 0044**. | NEW per-path (`test-backend-ops` 91/91, <=1e-7 NMSE vs CPU ref) | -| 0044 | **GDN M5 tensor-core chunked-scan prefill, default-ON under paged KV** - the tensor-core forms of 0031's scan (KK/QK Gram, KS/QS state-boundary, P*U output, full form-T solve + state-update mma = M5; plus register-resident M6/M7 and CONFIG-C M8), single build, runtime-selected by `GDN_TC`. Ships **M5 default-on when `LLAMA_KV_PAGED` is set** (`GDN_TC=5` + `GDN_CHUNK_MIN=64`, both env-overridable; OFF/`INT_MAX` when not paged). `GDN_CHUNK_MIN` is the per-call engage threshold and stays > 1 so decode (1 tok/call) keeps the sequential recurrence (at 1 it swallows decode and drops S_TG ~25%); 64 tuned from a {1,32,64,128,256} sweep. MoE prefill S_PP +3.5% @npp512 (3x A/B), +17.7% @npp2048; decode S_TG unchanged. | NEW per-path, benign (`test-backend-ops` 94/94 incl. multi-chunk; greedy md5 default-on == M5-forced == canonical on the gate prompt: paged-MoE `8cb0ce23`, dense `5951a5b4`; long MoE prompt = one benign greedy flip vs sequential, dense byte-identical) | +| 0031 | **Chunked parallel-scan GDN prefill kernel** (upstream TODO) - FLA-style chunked gated-delta-rule for prefill (non-KDA / f32 / final-state): intra-chunk delta rule solved in parallel (UT-transform + forward subst), inter-chunk recurrence over n_tokens/C steps. The scalar-serial form (`GDN_TC=0`) was bit-exact-benign but not faster than the tuned sequential scan at the GB10-forced C=16 (see section 5); **superseded for paged by the tensor-core M5 path of 0047**. | NEW per-path (`test-backend-ops` 91/91, <=1e-7 NMSE vs CPU ref) | +| 0047 | **GDN M5 tensor-core chunked-scan prefill, f32-only re-port, default-ON under paged KV** - the f32/tf32 tensor-core forms of 0031's scan (KK/QK Gram = M2, KS/QS state-boundary 3xtf32 = M3, P*U output = M4, full form-T solve + state-update mma = M5), single build, runtime-selected by `GDN_TC`. Ships **M5 default-on when `LLAMA_KV_PAGED` is set** (`GDN_TC=5` + `GDN_CHUNK_MIN=64`, both env-overridable; OFF/`INT_MAX` when not paged). `GDN_CHUNK_MIN` is the per-call engage threshold and stays > 1 so decode (1 tok/call) keeps the sequential recurrence (at 1 it swallows decode and drops S_TG ~25%); 64 tuned from a {1,32,64,128,256} sweep. The bf16/hybrid dev-tree machinery (STATE_BF16/HYBRID, the dropped 0026 ssm_bf16_tau) and the bf16 CONFIG-C (M8) plus register-resident M6/M7 variants are NOT part of this f32-only series. MoE prefill S_PP +3.5% @npp512 (3x A/B), +17.7% @npp2048; decode S_TG unchanged. | NEW per-path, benign (`test-backend-ops` GATED_DELTA_NET 46/46 default AND force-M5, incl. multi-chunk/tail-chunk/multi-seq; greedy md5 default-on == M5-forced == canonical on the gate prompt: paged-MoE `8cb0ce23`, dense `5951a5b4`; long MoE prompt = one benign greedy flip vs sequential, dense byte-identical) | +| 0046 | **GDN prefill geometry gated by scan length** - patch 0022's `(NUM_WARPS=16, COLS_PER_WARP=8)` column-fold of the GDN sequential-recurrence dispatch (`case 128`) is a decode win but was applied UNCONDITIONALLY, so it also hit dense prefill (~-6% vs stock): on a long sequential scan the launch `grid.z` collapses from `S_v/4 = 32` to `S_v/(16*8) = 1` and the SMs starve (profiled: `gated_delta_net` +54% GPU time = the whole dense-prefill regression). Gate the geometry by per-call scan length: long scans (prefill, `n_tokens >= GDN_PREFILL_NTOK`, default 256) take stock's high-grid.z `(4,1)` geometry; short scans (decode) keep the `(16,8)` retune. Recovers dense prefill +7.2% back to stock parity, keeps the decode win. `GDN_PREFILL_NTOK` tunes the crossover; an explicit `GDN_NW`/`GDN_CPW` sweep still overrides (gate yields when either is set), so the one-build %peak A/B harness is unchanged. | yes (patch 0022 proved every `{NW,CPW}` variant byte-identical, so switching geometry by scan length cannot move the md5) | > **Dropped: patch 0026 (hybrid per-head bf16 SSM state, `ssm_bf16_tau`).** Once > the decode fusions (0028 recurrent-state gather-fusion + 0029 block-table cache) @@ -371,7 +372,7 @@ llama is losing. The MoE GEMM kernel is *not* where the gap lives. (The "the win was NVFP4-dense-quant, not the Marlin kernel" dense verdict carries over to MoE.) - **Chunked parallel-scan GDN prefill (patch 0031): the scalar-serial form was - FLAT-to-SLOWER at C=16 - the tensor-core M5 form (patch 0044) is the win, + FLAT-to-SLOWER at C=16 - the tensor-core M5 form (patch 0047) is the win, now DEFAULT-ON under paged KV.** 0031 implements the upstream "faster pre-fill" TODO - the FLA-style chunked gated-delta-rule (intra-chunk delta rule solved in parallel via the UT-transform + forward substitution, inter-chunk recurrence @@ -382,10 +383,12 @@ llama is losing. The MoE GEMM kernel is *not* where the gap lives. pinned to 1 block/SM with serial per-thread dk-reductions and measured **~761 t/s chunked vs ~971 t/s sequential (~22% slower)**, grid-starved at low n_seqs. The lesson held: **at this head dim the win needs tensor cores, not just - chunking.** Patch 0044 builds those tensor-core forms (KK/QK Gram = M2, KS/QS - state-boundary = M3, P*U output = M4, full form-T solve + state-update mma = - M5; plus register-resident M6/M7 and CONFIG-C M8, all `GDN_TC`-selected in one - build) and ships **M5** as the default when `LLAMA_KV_PAGED` is set. M5 is the + chunking.** Patch 0047 builds those tensor-core forms (KK/QK Gram = M2, KS/QS + state-boundary 3xtf32 = M3, P*U output = M4, full form-T solve + state-update + mma = M5, all `GDN_TC`-selected in one build) and ships **M5** as the default + when `LLAMA_KV_PAGED` is set. It is an f32/tf32-only re-port: the bf16/hybrid + dev-tree machinery (from the dropped 0026 ssm_bf16_tau) and the bf16 CONFIG-C + (M8) plus register-resident M6/M7 variants are NOT part of this series. M5 is the variant that beats the (already 84.7%-of-peak) sequential scan while staying on the bit-exact gate: MoE prefill S_PP **+3.5% @npp512 (3x interleaved A/B), +17.7% @npp2048**; decode S_TG unchanged (the tuned `GDN_CHUNK_MIN=64` engage threshold @@ -398,9 +401,28 @@ llama is losing. The MoE GEMM kernel is *not* where the gap lives. a long MoE prompt (where the default fires M5 at >=64 tokens) M5 and the sequential path agree word-for-word until **one** benign greedy token-flip ("the User:" vs "the User's Request:"), the dense model not flipping at all - - the textbook reduction-order flip greedy amplifies, NMSE-validated. M6/M7/M8 - remain env-selectable (`GDN_TC`/`GDN_CHUNK_C`/`GDN_DV_TILE`) for further tuning; - M5 is the shipped default because it wins without losing the canonical gate. + the textbook reduction-order flip greedy amplifies, NMSE-validated. The chunk + geometry stays env-selectable (`GDN_TC`/`GDN_CHUNK_C`/`GDN_DV_TILE`) for further + tuning; M5 is the shipped default because it wins without losing the canonical gate. +- **GDN occupancy retune (patch 0022) was a decode win but an UNCONDITIONAL + dense-prefill regression - now gated by scan length (patch 0046).** Patch + 0022's `(NUM_WARPS=16, COLS_PER_WARP=8)` column-fold of the GDN + sequential-recurrence dispatch (`case 128`) raises per-warp memory-level + parallelism on the short, wide DECODE scans (small `n_tokens`, large + `n_seqs`) - the measured +11.1% dense decode win. Applied unconditionally it + also hit the dense PREFILL path, where the scan is long and narrow: the launch + `grid.z` collapses from `S_v/4 = 32` to `S_v/(16*8) = 1`, the SMs starve, and + profiling attributed the whole ~-6% dense-prefill regression vs stock to + `gated_delta_net` (+54% GPU time at the (16,8) geometry). Patch 0046 gates the + geometry by per-call scan length: long scans (prefill, + `n_tokens >= GDN_PREFILL_NTOK`, default 256) take stock's high-grid.z `(4,1)` + geometry; short scans (decode) keep the `(16,8)` retune. That recovers dense + prefill +7.2% back to stock parity while keeping the decode win, and it is + bit-exact: patch 0022 already proved every selectable `{NUM_WARPS, + COLS_PER_WARP}` variant is byte-identical (the sweep cannot change the md5), so + switching geometry by scan length cannot move the greedy output. The explicit + `GDN_NW`/`GDN_CPW` one-build %peak sweep still overrides (the gate yields when + either is set), so the A/B harness is unchanged. **Opt-in bf16-SSM fast mode - DROPPED (was patch 0026, `ssm_bf16_tau`).** The design premise - that bf16 KL error concentrates in long-memory heads and can be @@ -455,6 +477,11 @@ targeted is already recovered by the gather-fusion + block-table cache. ## 7. Pin + maintenance policy +- **Canonical source = the fork branch `mudler/llama.cpp:localai-paged`.** The + vendored `patches/paged/*.patch` files are now generated (one `git format-patch` + per commit) from that branch, which is the pin commit plus the paged patch + commits in order, so there is no more hand-export drift between the dev tree and + the shipped series. - **Pinned to llama.cpp `9d5d882d`** (kept == the stock `llama-cpp` pin). The pin is advanced **only** by the manual pin-sync process (this section): rebase the source-only patch series onto the new tip, rebuild on GPU, pass the diff --git a/backend/cpp/llama-cpp-localai-paged/patches/paged/0044-feat-paged-GDN-M5-tensor-core-chunked-scan-default-o.patch b/backend/cpp/llama-cpp-localai-paged/patches/paged/0044-feat-paged-GDN-M5-tensor-core-chunked-scan-default-o.patch deleted file mode 100644 index 6aa574039..000000000 --- a/backend/cpp/llama-cpp-localai-paged/patches/paged/0044-feat-paged-GDN-M5-tensor-core-chunked-scan-default-o.patch +++ /dev/null @@ -1,1600 +0,0 @@ -From a7d439e8ce6990eb09721223c975da4e49d8d136 Mon Sep 17 00:00:00 2001 -From: Ettore Di Giacinto -Date: Mon, 29 Jun 2026 09:30:00 +0200 -Subject: [PATCH] feat(paged): GDN M5 tensor-core chunked-scan prefill, - default-on under paged KV (patch 0044) - -Patch 0031 added the FLA-style chunked parallel-scan for the gated-DeltaNet -(GDN) prefill but shipped it default-OFF (`GDN_CHUNK_MIN=INT_MAX`): at the -GB10-forced C=16, all-shared layout it was grid-starved and ~22% slower than -the tuned sequential recurrence, so it was a correct-but-not-faster opt-in -(section 5). This patch lands the tensor-core forms of that scan - the KK/QK -Gram (M2), KS/QS state-boundary (M3), P*U output (M4) and the full form-T solve -+ state-update mma (M5), plus the register-resident (M6/M7) and CONFIG-C (M8) -occupancy variants - as a single build, runtime-selected by `GDN_TC` (0=serial -.. 5=M5 .. 6+=register-resident). It then ships **M5 default-ON under paged KV**. - -Why M5 and why default-on-when-paged: - -- M5 (full TC, state in 64KB smem, C=16) is the variant that BEATS the (already - 84.7%-of-peak) sequential recurrence on the Qwen3.6 MoE prefill while staying - greedy-bit-exact on the canonical gate. Measured GB10, q36-35b-a3b-nvfp4, - `LLAMA_KV_PAGED=1 LLAMA_MOE_FORCE_GRAPHS=1`, `llama-batched-bench -ngl 99 - -fa on -ntg 4 -npl 32`: - -npp 512 : S_PP 2208.96 -> 2286.5 t/s (+3.5%, mean of 3 interleaved A/B) - -npp 2048 : S_PP 2021.5 -> 2379.8 t/s (+17.7%) - bigger win on longer - prompts, where the scan has more chunks to parallelize. - Decode S_TG is unchanged (~399 vs ~397 t/s, within run noise). - -- The dispatch (the only behavioural change) defaults `GDN_TC=5` and - `GDN_CHUNK_MIN=64` when `LLAMA_KV_PAGED` is set and the user has not overridden - either; both stay env-overridable; OFF (`INT_MAX`) when not paged, so the - stock / non-paged default is regression-free. `GDN_CHUNK_MIN` is the per-call - engage threshold and MUST stay > 1: decode is 1 token/call, so any threshold - above 1 leaves every decode step on the sequential recurrence (at - `GDN_CHUNK_MIN=1` the chunked path swallows decode and collapses S_TG by ~25%). - 64 was tuned from a {1,32,64,128,256} sweep: it is above decode/tiny-call - sizes, below the real MoE-prefill per-call count (which is < 256 on the - 512-prompt shape, so 256 barely fires), and gave the best S_PP at npp=512. - -Bit-exactness (per-path greedy md5, n=48 --temp 0 --seed 1, paged): - dense q36-27b-nvfp4 : 5951a5b4d624ce891e22ab5fca9bc439 (default-on == M5-forced - == canonical; long-prompt M5 == sequential, byte-identical) - MoE q36-35b-a3b : 8cb0ce23777bf55f92f63d0292c756b0 (default-on == M5-forced - == canonical on the short gate prompt) -The chunked scan is a NEW per-path result (a different FP reduction order than -the sequential recurrence), validated benign: `test-backend-ops` GATED_DELTA_NET -is 94/94 vs the CPU reference with M5 forced, including the multi-chunk shapes -(n_tokens up to 256). On a long MoE prompt the default (which fires M5 at >=64 -tokens) and the sequential path agree word-for-word until a single greedy -token-flip ("the User:" vs "the User's Request:") - the textbook benign -reduction-order flip greedy decoding amplifies; both continuations are coherent, -and the dense model does not flip at all. The short canonical gate prompt -(6 tokens) does not reach the threshold, so the published canonical md5s are -preserved exactly, and force-firing M5 on that prompt (`GDN_CHUNK_MIN=1`) -reproduces them. - -CUDA-only; built and gated for `-gencode arch=compute_121a,code=sm_121a` -(GB10 / sm_121a). Stock llama-cpp stays patch-free. ---- - ggml/src/ggml-cuda/gated_delta_net.cu | 1424 +++++++++++++++++++++++++++++++-- - tests/test-backend-ops.cpp | 5 + - 2 files changed, 1371 insertions(+), 58 deletions(-) - -diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu -index c9bf1bd..d136c82 100644 ---- a/ggml/src/ggml-cuda/gated_delta_net.cu -+++ b/ggml/src/ggml-cuda/gated_delta_net.cu -@@ -424,7 +424,118 @@ static void launch_gdn_variant( - // strong-decay tokens underflow to the correct zero rather than to inf. The math - // is equivalent to the sequential recurrence up to FP reduction order (a NEW - // per-path result, validated benign by test-backend-ops NMSE and greedy output). --template -+// --- Phase-1 tensor-core Gram helpers (tf32 m16n8k8 mma.sync; sm_80+/sm_121a). --- -+// Reproduces the PoC-proven path (~/scratch_tc_gdn_poc/gdn_gram_bench.cu, tf32 NMSE ~3e-9): -+// out[rowbase..+15][colbase..+7] = Xs[rows] . Ys[cols], Xs/Ys row-major [*][DK]. -+__device__ __forceinline__ unsigned gdn_f2tf32(float f) { -+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -+ unsigned r; -+ asm("cvt.rna.tf32.f32 %0, %1;" : "=r"(r) : "f"(f)); -+ return r; -+#else -+ (void) f; -+ return 0u; -+#endif -+} -+ -+// Operand loaders for the Gram/state mma helpers: stage either f32 (cvt to tf32) or bf16 -+// (upconvert bf16->f32 then cvt to tf32). CONFIG C (M8) stages Kc/Qc as bf16 to halve their -+// smem footprint -- the tf32 mma reads the bf16-upconverted operands. f32 operands (state -+// restage, Ud, A/T scratch) keep full tf32/3xtf32 precision. bf16's 8-bit mantissa fits inside -+// tf32's 10 bits, so the bf16 hi-limb captures it and the 3xtf32 lo-limbs are ~0 (correct, just -+// no extra precision on the bf16 side -- exactly the scope's "bf16 only for the Gram terms"). -+__device__ __forceinline__ unsigned gdn_ld_tf32(float f) { return gdn_f2tf32(f); } -+__device__ __forceinline__ unsigned gdn_ld_tf32(nv_bfloat16 h) { return gdn_f2tf32(__bfloat162float(h)); } -+__device__ __forceinline__ float gdn_ld_f32 (float f) { return f; } -+__device__ __forceinline__ float gdn_ld_f32 (nv_bfloat16 h) { return __bfloat162float(h); } -+ -+__device__ __forceinline__ void gdn_mma_m16n8k8(float c[4], const unsigned a[4], const unsigned b[2]) { -+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -+ asm volatile( -+ "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " -+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n" -+ : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3]) -+ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1])); -+#else -+ (void) c; (void) a; (void) b; -+#endif -+} -+ -+template -+__device__ __forceinline__ void gdn_gram_tile_mma( -+ float c[4], const TX * __restrict__ Xs, const TY * __restrict__ Ys, -+ int rowbase, int colbase, int lg, int lt) { -+ c[0] = c[1] = c[2] = c[3] = 0.0f; -+ #pragma unroll -+ for (int ks = 0; ks < DK; ks += 8) { -+ unsigned a[4], b[2]; -+ a[0] = gdn_ld_tf32(Xs[(rowbase + lg ) * DK + ks + lt ]); -+ a[1] = gdn_ld_tf32(Xs[(rowbase + lg + 8) * DK + ks + lt ]); -+ a[2] = gdn_ld_tf32(Xs[(rowbase + lg ) * DK + ks + lt + 4]); -+ a[3] = gdn_ld_tf32(Xs[(rowbase + lg + 8) * DK + ks + lt + 4]); -+ b[0] = gdn_ld_tf32(Ys[(colbase + lg ) * DK + ks + lt ]); -+ b[1] = gdn_ld_tf32(Ys[(colbase + lg ) * DK + ks + lt + 4]); -+ gdn_mma_m16n8k8(c, a, b); -+ } -+} -+ -+// 3xtf32 (CUTLASS fp32-emulation): split each f32 operand into hi/lo tf32 limbs and run -+// 3 limb-products per k-subtile (hi*hi + hi*lo + lo*hi); ~f32 accuracy at ~3x the mma count. -+// Used for the state-boundary products (KS/QS) whose error feeds the A-inverse solve (M3). -+template -+__device__ __forceinline__ void gdn_gram_tile_mma_3x( -+ float c[4], const TX * __restrict__ Xs, const TY * __restrict__ Ys, -+ int rowbase, int colbase, int lg, int lt) { -+ c[0] = c[1] = c[2] = c[3] = 0.0f; -+ #pragma unroll -+ for (int ks = 0; ks < DK; ks += 8) { -+ float af[4], bf[2]; -+ af[0] = gdn_ld_f32(Xs[(rowbase + lg ) * DK + ks + lt ]); -+ af[1] = gdn_ld_f32(Xs[(rowbase + lg + 8) * DK + ks + lt ]); -+ af[2] = gdn_ld_f32(Xs[(rowbase + lg ) * DK + ks + lt + 4]); -+ af[3] = gdn_ld_f32(Xs[(rowbase + lg + 8) * DK + ks + lt + 4]); -+ bf[0] = gdn_ld_f32(Ys[(colbase + lg ) * DK + ks + lt ]); -+ bf[1] = gdn_ld_f32(Ys[(colbase + lg ) * DK + ks + lt + 4]); -+ unsigned ahi[4], alo[4], bhi[2], blo[2]; -+ #pragma unroll -+ for (int z = 0; z < 4; z++) { ahi[z] = gdn_f2tf32(af[z]); alo[z] = gdn_f2tf32(af[z] - __uint_as_float(ahi[z])); } -+ #pragma unroll -+ for (int z = 0; z < 2; z++) { bhi[z] = gdn_f2tf32(bf[z]); blo[z] = gdn_f2tf32(bf[z] - __uint_as_float(bhi[z])); } -+ gdn_mma_m16n8k8(c, ahi, bhi); // hi*hi (dominant limb) -+ gdn_mma_m16n8k8(c, ahi, blo); // hi*lo -+ gdn_mma_m16n8k8(c, alo, bhi); // lo*hi -+ } -+} -+ -+// State-update tile (P6): S_C[i][j] += sum_t Kc[t][i] * DU[t][j], with Kc read TRANSPOSED -+// (i as the m16n8k8 M-row, t as the K-contraction) and DU = d(t,last)*U staged in the Ud -+// layout (DUd[j*KC + t]). 3xtf32: the cross-chunk carry compounds over every chunk step. -+template -+__device__ __forceinline__ void gdn_state_tile_mma_3x( -+ float c[4], const TK * __restrict__ Kc, const TD * __restrict__ DUd, -+ int rowbase, int colbase, int lg, int lt) { -+ c[0] = c[1] = c[2] = c[3] = 0.0f; -+ #pragma unroll -+ for (int ks = 0; ks < KC; ks += 8) { -+ float af[4], bf[2]; -+ af[0] = gdn_ld_f32(Kc[(ks + lt ) * DK + (rowbase + lg )]); -+ af[1] = gdn_ld_f32(Kc[(ks + lt ) * DK + (rowbase + lg + 8)]); -+ af[2] = gdn_ld_f32(Kc[(ks + lt + 4) * DK + (rowbase + lg )]); -+ af[3] = gdn_ld_f32(Kc[(ks + lt + 4) * DK + (rowbase + lg + 8)]); -+ bf[0] = gdn_ld_f32(DUd[(colbase + lg) * KC + (ks + lt )]); -+ bf[1] = gdn_ld_f32(DUd[(colbase + lg) * KC + (ks + lt + 4)]); -+ unsigned ahi[4], alo[4], bhi[2], blo[2]; -+ #pragma unroll -+ for (int z = 0; z < 4; z++) { ahi[z] = gdn_f2tf32(af[z]); alo[z] = gdn_f2tf32(af[z] - __uint_as_float(ahi[z])); } -+ #pragma unroll -+ for (int z = 0; z < 2; z++) { bhi[z] = gdn_f2tf32(bf[z]); blo[z] = gdn_f2tf32(bf[z] - __uint_as_float(bhi[z])); } -+ gdn_mma_m16n8k8(c, ahi, bhi); -+ gdn_mma_m16n8k8(c, ahi, blo); -+ gdn_mma_m16n8k8(c, alo, bhi); -+ } -+} -+ -+template - __global__ void gated_delta_net_chunked_cuda( - const float * __restrict__ q, const float * __restrict__ k, - const float * __restrict__ v, const float * __restrict__ g, -@@ -455,6 +566,9 @@ __global__ void gated_delta_net_chunked_cuda( - float * csh = Amat + (size_t) C * C; // [C] cumsum(log-gate) - float * gam = csh + C; // [C] gamma_t = exp(cs_t) - float * bet = gam + C; // [C] beta_t -+ // Phase-1 tensor-core Gram scratch (allocated only when GRAM_MMA; KK feeds A, QK feeds P). -+ float * KKsh = bet + C; // [C*C] KK[t][t'] = k_t . k_t' (stride C) -+ float * QKsh = KKsh + (size_t) C * C; // [C*C] QK[t][t'] = q_t . k_t' (stride C) - - // S0: thread j owns column j (Sd[j*dk + i]); load is a contiguous per-thread copy from the - // M-layout cache view (read_state[j*dk + i] = M[j*S_v + i] = S[i][j]). Same identity/gather -@@ -483,6 +597,15 @@ __global__ void gated_delta_net_chunked_cuda( - Kc[t * dk + i] = k_base[(c0 + t) * sq2 + i]; - Qc[t * dk + i] = q_base[(c0 + t) * sq2 + i]; - } -+ if constexpr (TC >= 3) { -+ // Zero the stale K/Q tail (rows t >= Cc): the tensor-core mma paths contract the full -+ // chunk dim and 0*NaN (uninitialized smem) would poison the result. Serial paths only -+ // touch t < Cc, so this is gated to the mma levels. -+ for (int e = Cc * dk + j; e < C * dk; e += dv) { -+ Kc[e] = 0.0f; -+ Qc[e] = 0.0f; -+ } -+ } - if (j < Cc) { - csh[j] = g[gb_base + (c0 + j) * sb2]; // raw log-gate, prefix-summed below - bet[j] = beta[gb_base + (c0 + j) * sb2]; -@@ -498,15 +621,53 @@ __global__ void gated_delta_net_chunked_cuda( - } - __syncthreads(); - -+ // --- Phase-1: tensor-core tf32 Gram products (KK->A via warp0, QK->P via warp1). --- -+ // Full C x C tiles into KKsh/QKsh (stride C); decay/beta applied in f32 in the loops below. -+ // Tail chunks (Cc= Cc, but those entries are never read. -+ if constexpr (TC >= 1) { -+ const int w = threadIdx.x >> 5; // warp: 0 -> KK, 1 -> QK -+ const int lane = threadIdx.x & 31; -+ const int lg = lane >> 2; // 0..7 -+ const int lt = lane & 3; // 0..3 -+ if (w < 2) { -+ const float * Xs = (w == 0) ? Kc : Qc; -+ float * Out = (w == 0) ? KKsh : QKsh; -+ #pragma unroll -+ for (int mt = 0; mt < (C + 15) / 16; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nt = 0; nt < (C + 7) / 8; nt++) { -+ const int colbase = nt * 8; -+ float cc[4]; -+ gdn_gram_tile_mma(cc, Xs, Kc, rowbase, colbase, lg, lt); -+ const int rr[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int ccol[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ if (rr[l] < C && ccol[l] < C) { -+ Out[rr[l] * C + ccol[l]] = cc[l]; -+ } -+ } -+ } -+ } -+ } -+ __syncthreads(); -+ } -+ - // --- A = I + tril(beta_t * d(t',t) * (k_t . k_t'), -1) (cooperative over C*C) --- - for (int e = j; e < Cc * Cc; e += dv) { - const int t = e / Cc; - const int tp = e % Cc; - float a = 0.0f; - if (tp < t) { -- float kk = 0.0f; -- for (int i = 0; i < dk; i++) { -- kk += Kc[t * dk + i] * Kc[tp * dk + i]; -+ float kk; -+ if constexpr (TC >= 1) { -+ kk = KKsh[t * C + tp]; -+ } else { -+ kk = 0.0f; -+ for (int i = 0; i < dk; i++) { -+ kk += Kc[t * dk + i] * Kc[tp * dk + i]; -+ } - } - const float dd = expf(csh[t] - csh[tp]); // d(tp,t) = gamma_t/gamma_tp - a = bet[t] * dd * kk; -@@ -518,65 +679,304 @@ __global__ void gated_delta_net_chunked_cuda( - __syncthreads(); - - // --- RHS[t][j] = beta_t (v_t[j] - gamma_t * (S0^T k_t)[j]) -> Ud[j*C + t] --- -- for (int t = 0; t < Cc; t++) { -- float ks = 0.0f; // (S0^T k_t)[j] = sum_i S[i][j] k_t[i] -- for (int i = 0; i < dk; i++) { -- ks += Sd[j * dk + i] * Kc[t * dk + i]; -+ if constexpr (TC >= 2) { -+ // M3: fused tensor-core KS = Kc * S0 (3xtf32 state-boundary product). The mma -+ // output is consumed straight from registers into RHS -> Ud, so NO extra C*dv -+ // smem buffer is needed (the 64KB state still occupies smem until M6). Warp w -+ // owns dv n-tiles [w*NTPW, ..); each lane writes the RHS entries it produced. -+ const int w = threadIdx.x >> 5; -+ const int lane = threadIdx.x & 31; -+ const int lg = lane >> 2; -+ const int lt = lane & 3; -+ constexpr int NWARP = S_v / 32; -+ constexpr int NT = dv / 8; -+ constexpr int NTPW = (NT + NWARP - 1) / NWARP; -+ #pragma unroll -+ for (int mt = 0; mt < (C + 15) / 16; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int nt = w * NTPW + nn; -+ if (nt >= NT) break; -+ const int colbase = nt * 8; -+ float cc[4]; -+ gdn_gram_tile_mma_3x(cc, Kc, Sd, rowbase, colbase, lg, lt); -+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ const int t = tt[l], jc = jj[l]; -+ if (t < Cc && jc < dv) { -+ const float vtj = v_base[(c0 + t) * sv2 + jc]; -+ Ud[jc * C + t] = bet[t] * (vtj - gam[t] * cc[l]); -+ } -+ } -+ } - } -- const float vtj = v_base[(c0 + t) * sv2 + j]; -- Ud[j * C + t] = bet[t] * (vtj - gam[t] * ks); -+ __syncthreads(); // RHS written cross-thread -> publish before the per-column solve -+ } else { -+ for (int t = 0; t < Cc; t++) { -+ float ks = 0.0f; // (S0^T k_t)[j] = sum_i S[i][j] k_t[i] -+ for (int i = 0; i < dk; i++) { -+ ks += Sd[j * dk + i] * Kc[t * dk + i]; -+ } -+ const float vtj = v_base[(c0 + t) * sv2 + j]; -+ Ud[j * C + t] = bet[t] * (vtj - gam[t] * ks); -+ } -+ } -+ if constexpr (TC >= 3) { -+ // Zero the stale RHS tail (rows t >= Cc) before the full-K mma consumers (P*U at TC>=3; -+ // apply + state at TC>=4). Without this the masked tail terms compute 0*NaN = NaN. -+ for (int t = Cc; t < C; t++) Ud[j * C + t] = 0.0f; -+ __syncthreads(); - } - -- // --- solve A U = RHS in place (unit lower-tri fwd subst); per-thread, no inter-step sync --- -- for (int t = 1; t < Cc; t++) { -- float acc = Ud[j * C + t]; -- for (int tp = 0; tp < t; tp++) { -- acc -= Amat[t * Cc + tp] * Ud[j * C + tp]; -+ // --- solve A U = RHS (A unit-lower-tri) --- -+ if constexpr (TC >= 4) { -+ // M5/P7: form T = A^{-1} explicitly (FLA UT transform), then U = T*RHS as one -+ // dependency-free tf32 GEMM. At C<=16 A is a single b=16 block, so the off-diagonal -+ // Phase-O is empty; only the f32 in-shared diagonal inverse (Phase-D) + the wide -+ // apply remain. Phase-D: column-parallel EXACT f32 inverse of the Cc x Cc unit- -+ // lower-tri A -- thread c solves A x = e_c, writing column c of T into KKsh (free -+ // since KK was consumed into A). This is the strong-coupling amplifier -> f32. -+ if (j < C) { -+ if (j < Cc) { -+ float x[C]; -+ #pragma unroll -+ for (int r = 0; r < C; r++) x[r] = 0.0f; -+ x[j] = 1.0f; -+ for (int r = j + 1; r < Cc; r++) { -+ float acc = 0.0f; -+ for (int m = j; m < r; m++) acc += Amat[r * Cc + m] * x[m]; -+ x[r] = -acc; -+ } -+ #pragma unroll -+ for (int r = 0; r < C; r++) KKsh[r * C + j] = x[r]; // rows >= Cc are 0 -+ } else { -+ #pragma unroll -+ for (int r = 0; r < C; r++) KKsh[r * C + j] = 0.0f; // cols >= Cc are 0 -+ } - } -- Ud[j * C + t] = acc; -+ __syncthreads(); -+ // Apply U = T*RHS, M=C N=dv K=C; T=KKsh (stride C), RHS=Ud (stride C). In place on -+ // Ud: hold every output tile in registers, sync to finish the RHS reads, then -+ // overwrite Ud with U (avoids the read/write aliasing of a same-buffer GEMM). -+ { -+ const int w = threadIdx.x >> 5; -+ const int lane = threadIdx.x & 31; -+ const int lg = lane >> 2; -+ const int lt = lane & 3; -+ constexpr int NWARP = S_v / 32; -+ constexpr int NT = dv / 8; -+ constexpr int NTPW = (NT + NWARP - 1) / NWARP; -+ float ureg[NTPW][4]; -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int nt = w * NTPW + nn; -+ if (nt < NT) gdn_gram_tile_mma(ureg[nn], KKsh, Ud, 0, nt * 8, lg, lt); -+ } -+ __syncthreads(); // all RHS(Ud) reads done before overwriting with U -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int nt = w * NTPW + nn; -+ if (nt >= NT) continue; -+ const int colbase = nt * 8; -+ const int tt[4] = {lg, lg, lg + 8, lg + 8}; -+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ const int t = tt[l], jc = jj[l]; -+ if (t < Cc && jc < dv) Ud[jc * C + t] = ureg[nn][l]; -+ } -+ } -+ __syncthreads(); -+ } -+ } else { -+ for (int t = 1; t < Cc; t++) { -+ float acc = Ud[j * C + t]; -+ for (int tp = 0; tp < t; tp++) { -+ acc -= Amat[t * Cc + tp] * Ud[j * C + tp]; -+ } -+ Ud[j * C + t] = acc; -+ } -+ __syncthreads(); // U finalized; Amat free for P below - } -- __syncthreads(); // U finalized; Amat free for P below (and Ud read across-thread? no, own col) - -- // --- P[t][t'] = d(t',t) * (q_t . k_t') for t' <= t (reuse Amat) --- -- for (int e = j; e < Cc * Cc; e += dv) { -- const int t = e / Cc; -- const int tp = e % Cc; -- float p = 0.0f; -- if (tp <= t) { -- float qk = 0.0f; -- for (int i = 0; i < dk; i++) { -- qk += Qc[t * dk + i] * Kc[tp * dk + i]; -+ // --- P[t][t'] = d(t',t) * (q_t . k_t') for t' <= t --- -+ if constexpr (TC >= 3) { -+ // M4: build P (lower-tri, decay pre-baked in f32 -> bounded) IN PLACE in QKsh at -+ // fixed stride C so the P*U output mma can read it as a tf32 A-operand. Full C*C -+ // grid: upper-tri / out-of-range entries are zeroed so the K=C mma needs no masking. -+ for (int e = j; e < C * C; e += dv) { -+ const int t = e / C; -+ const int tp = e % C; -+ float p = 0.0f; -+ if (tp <= t && t < Cc && tp < Cc) { -+ const float dd = expf(csh[t] - csh[tp]); -+ p = dd * QKsh[t * C + tp]; // QKsh holds QK (M2); overwrite in place with P - } -- const float dd = expf(csh[t] - csh[tp]); -- p = dd * qk; -+ QKsh[t * C + tp] = p; -+ } -+ __syncthreads(); -+ } else { -+ for (int e = j; e < Cc * Cc; e += dv) { -+ const int t = e / Cc; -+ const int tp = e % Cc; -+ float p = 0.0f; -+ if (tp <= t) { -+ float qk; -+ if constexpr (TC >= 1) { -+ qk = QKsh[t * C + tp]; -+ } else { -+ qk = 0.0f; -+ for (int i = 0; i < dk; i++) { -+ qk += Qc[t * dk + i] * Kc[tp * dk + i]; -+ } -+ } -+ const float dd = expf(csh[t] - csh[tp]); -+ p = dd * qk; -+ } -+ Amat[t * Cc + tp] = p; - } -- Amat[t * Cc + tp] = p; -+ __syncthreads(); - } -- __syncthreads(); - - // --- O[t][j] = gamma_t (S0^T q_t)[j] + sum_{t'<=t} P[t][t'] U[t'][j] (* scale) --- -- for (int t = 0; t < Cc; t++) { -- float qs = 0.0f; // (S0^T q_t)[j] (uses pre-update S) -- for (int i = 0; i < dk; i++) { -- qs += Sd[j * dk + i] * Qc[t * dk + i]; -+ if constexpr (TC >= 2) { -+ // M3: fused tensor-core QS = Qc * S0 (3xtf32, pre-update S0). Deposit the -+ // gamma_t*QS[t][j] cross-chunk term into dst from the mma registers; the O loop -+ // below reads it back (published via __syncthreads) and adds the intra-chunk P*U. -+ const int w = threadIdx.x >> 5; -+ const int lane = threadIdx.x & 31; -+ const int lg = lane >> 2; -+ const int lt = lane & 3; -+ constexpr int NWARP = S_v / 32; -+ constexpr int NT = dv / 8; -+ constexpr int NTPW = (NT + NWARP - 1) / NWARP; -+ #pragma unroll -+ for (int mt = 0; mt < (C + 15) / 16; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int nt = w * NTPW + nn; -+ if (nt >= NT) break; -+ const int colbase = nt * 8; -+ float cc[4]; -+ gdn_gram_tile_mma_3x(cc, Qc, Sd, rowbase, colbase, lg, lt); -+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ const int t = tt[l], jc = jj[l]; -+ if (t < Cc && jc < dv) { -+ attn_base[(c0 + t) * S_v * H + jc] = gam[t] * cc[l]; -+ } -+ } -+ } - } -- float o = gam[t] * qs; -- for (int tp = 0; tp <= t; tp++) { -- o += Amat[t * Cc + tp] * Ud[j * C + tp]; -+ __syncthreads(); -+ } -+ if constexpr (TC >= 3) { -+ // M4: O += P*U via tensor-core (tf32-safe: P is f32-bounded, decay pre-baked). -+ // GEMM O[t][j] += sum_t' P[t][t']*U[t'][j], M=C N=dv K=C; P=QKsh (stride C), -+ // U=Ud (stride C). The gamma_t*QS cross-chunk term was deposited into dst above; -+ // fold it in here then * scale. Warp w owns dv n-tiles [w*NTPW, ..). -+ const int w = threadIdx.x >> 5; -+ const int lane = threadIdx.x & 31; -+ const int lg = lane >> 2; -+ const int lt = lane & 3; -+ constexpr int NWARP = S_v / 32; -+ constexpr int NT = dv / 8; -+ constexpr int NTPW = (NT + NWARP - 1) / NWARP; -+ #pragma unroll -+ for (int mt = 0; mt < (C + 15) / 16; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int nt = w * NTPW + nn; -+ if (nt >= NT) break; -+ const int colbase = nt * 8; -+ float cc[4]; -+ gdn_gram_tile_mma(cc, QKsh, Ud, rowbase, colbase, lg, lt); -+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ const int t = tt[l], jc = jj[l]; -+ if (t < Cc && jc < dv) { -+ const int64_t oi = (int64_t)(c0 + t) * S_v * H + jc; -+ attn_base[oi] = (attn_base[oi] + cc[l]) * scale; // QS term + P*U -+ } -+ } -+ } -+ } -+ } else { -+ for (int t = 0; t < Cc; t++) { -+ float o; -+ if constexpr (TC >= 2) { -+ o = attn_base[(c0 + t) * S_v * H + j]; // gamma_t*QS[t][j] deposited above -+ } else { -+ float qs = 0.0f; // (S0^T q_t)[j] (uses pre-update S) -+ for (int i = 0; i < dk; i++) { -+ qs += Sd[j * dk + i] * Qc[t * dk + i]; -+ } -+ o = gam[t] * qs; -+ } -+ for (int tp = 0; tp <= t; tp++) { -+ o += Amat[t * Cc + tp] * Ud[j * C + tp]; -+ } -+ attn_base[(c0 + t) * S_v * H + j] = o * scale; - } -- attn_base[(c0 + t) * S_v * H + j] = o * scale; - } - - // --- S_C[i][j] = gamma_{C-1} S[i][j] + sum_t d(t,C-1) k_t[i] u_t[j] --- - const float glast = gam[Cc - 1]; - const float cslast = csh[Cc - 1]; -- for (int i = 0; i < dk; i++) { -- float s = glast * Sd[j * dk + i]; -- for (int t = 0; t < Cc; t++) { -- const float dd = expf(cslast - csh[t]); // d(t, last) -- s += dd * Kc[t * dk + i] * Ud[j * C + t]; -+ if constexpr (TC >= 4) { -+ // M5/P6: state carry S_C = glast*S0 + Kc^T * DU via 3xtf32 mma. DU[t][j] = -+ // d(t,last)*U[t][j] is built IN PLACE in Ud (t>=Cc zeroed so the K=C contraction -+ // needs no per-k masking), then S_C accumulates over the chunk dim t. Kc is read -+ // transposed (i as M-row). M=dk N=dv K=C. Each output (i,j) has a unique owner so -+ // the glast*S0 read-modify-write is race-free. -+ for (int t = 0; t < C; t++) { -+ const float dd = (t < Cc) ? expf(cslast - csh[t]) : 0.0f; -+ Ud[j * C + t] = dd * Ud[j * C + t]; // thread j owns column j -> DU in place -+ } -+ __syncthreads(); -+ const int w = threadIdx.x >> 5; -+ const int lane = threadIdx.x & 31; -+ const int lg = lane >> 2; -+ const int lt = lane & 3; -+ constexpr int NWARP = S_v / 32; -+ constexpr int MT = dk / 16; // m-tiles over dk -+ constexpr int NT = dv / 8; // n-tiles over dv -+ constexpr int NTILES = MT * NT; -+ constexpr int TPW = (NTILES + NWARP - 1) / NWARP; -+ #pragma unroll -+ for (int idx = 0; idx < TPW; idx++) { -+ const int tile = w * TPW + idx; -+ if (tile >= NTILES) break; -+ const int rowbase = (tile / NT) * 16; -+ const int colbase = (tile % NT) * 8; -+ float cc[4]; -+ gdn_state_tile_mma_3x(cc, Kc, Ud, rowbase, colbase, lg, lt); -+ const int ii[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ const int i = ii[l], jc = jj[l]; -+ Sd[jc * dk + i] = glast * Sd[jc * dk + i] + cc[l]; -+ } -+ } -+ } else { -+ for (int i = 0; i < dk; i++) { -+ float s = glast * Sd[j * dk + i]; -+ for (int t = 0; t < Cc; t++) { -+ const float dd = expf(cslast - csh[t]); // d(t, last) -+ s += dd * Kc[t * dk + i] * Ud[j * C + t]; -+ } -+ Sd[j * dk + i] = s; - } -- Sd[j * dk + i] = s; - } - __syncthreads(); // Sd reused as S0 of next chunk; Kc/Qc/Amat reloaded next chunk - } -@@ -591,7 +991,741 @@ __global__ void gated_delta_net_chunked_cuda( - } - } - --template -+// ===================== M6: register-resident chunk-state (the occupancy flip) ===================== -+// Register-resident variant of the full tensor-core (TC=4) chunked path. The 128x128 chunk-state S -+// is moved OUT of shared memory: each column-owner thread j keeps its state column S[*][j] in -+// registers (Sreg[dk]). This frees the 64KB shared buffer that pinned 0031/M5 at C=16 -- the -+// load-bearing flip for occupancy (0031's -22% was shared/grid starvation, not the math). With the -+// state out of smem the chunk C is raised (16->32 here; 64 needs the bf16-K/Q + dv-slab budget, -+// M7). The crux (design section 5): S is an ACCUMULATOR for the state-carry (step 6) but a B-OPERAND -+// for KS/QS (steps 3/4), and those fragment layouts differ -- so once per chunk the state is bridged -+// through a small transient shared tile (Sres, SRES_W columns wide, looped over dv-strips): -+// - KS/QS: column-owner writes Sreg->Sres, the mma reads Sres as B (load_generic, NOT ldmatrix). -+// - state-carry: the Kc^T*DU mma writes its accumulator output to Sres, then the column-owner -+// folds Sres into Sreg (S_C = glast*S0 + Kc^T*DU). -+// Decays/gamma/beta stay f32 outside the mma; KS/QS/state-carry stay 3xtf32 (the cross-chunk carry). -+template -+__global__ void gated_delta_net_chunked_rr_cuda( -+ const float * __restrict__ q, const float * __restrict__ k, -+ const float * __restrict__ v, const float * __restrict__ g, -+ const float * __restrict__ beta, const float * __restrict__ curr_state, -+ float * __restrict__ dst, -+ int64_t H, int64_t n_tokens, int64_t n_seqs, -+ int64_t sq1, int64_t sq2, int64_t sq3, -+ int64_t sv1, int64_t sv2, int64_t sv3, -+ int64_t sb1, int64_t sb2, int64_t sb3, -+ uint3 neqk1_magic, uint3 rq3_magic, -+ float scale, float * __restrict__ state_dst, -+ const int32_t * __restrict__ ids, int rs_head) { -+ constexpr int dk = S_v; -+ constexpr int dv = S_v; -+ constexpr int dvt = DV_TILE; // dv-slab width (== dv when not slabbed) -+ constexpr int NWARP = DV_TILE / 32; -+ const int h_idx = blockIdx.x; -+ const int seq = blockIdx.y; -+ const int col0 = blockIdx.z * dvt; // global dv-column base of this slab -+ const int j = threadIdx.x; // slab-local v-column (0..dvt-1); global = col0 + j -+ -+ const uint32_t iq1 = fastmodulo((uint32_t) h_idx, neqk1_magic); -+ const uint32_t iq3 = fastdiv((uint32_t) seq, rq3_magic); -+ -+ extern __shared__ float gdn_smem[]; -+ float * Kc = gdn_smem; // [C*dk] Kc[t*dk + i] -+ float * Qc = Kc + (size_t) C * dk; // [C*dk] Qc[t*dk + i] -+ float * Ud = Qc + (size_t) C * dk; // [dvt*C] Ud[localcol*C + t] (dv-sliced per slab) -+ float * Amat = Ud + (size_t) dvt * C; // [C*C] A / P scratch (row-major Amat[t*C + t']) -+ float * csh = Amat + (size_t) C * C; // [C] cumsum(log-gate) -+ float * gam = csh + C; // [C] gamma_t -+ float * bet = gam + C; // [C] beta_t -+ float * KKsh = bet + C; // [C*C] KK Gram, then reused for T = A^{-1} -+ float * QKsh = KKsh + (size_t) C * C; // [C*C] QK Gram, then reused for P -+ float * Sres = QKsh + (size_t) C * C; // [SRES_W*dk] transient accumulator<->B state bridge -+ -+ // S0 register-resident: thread j owns column j -> Sreg[i] = S[i][j]. Same gather/identity -+ // plumbing as the sequential kernel (non-identity seqs gathered by the dispatcher). -+ const bool identity = (ids != nullptr && ids[seq] == rs_head + seq); -+ const float * read_state = (identity ? state_dst : curr_state) -+ + (int64_t) seq * H * dk * dv + (int64_t) h_idx * dk * dv; -+ float Sreg[dk]; -+ #pragma unroll -+ for (int i = 0; i < dk; i++) { -+ Sreg[i] = read_state[(col0 + j) * dk + i]; -+ } -+ -+ const float * q_base = q + iq3 * sq3 + iq1 * sq1; -+ const float * k_base = k + iq3 * sq3 + iq1 * sq1; -+ const float * v_base = v + seq * sv3 + h_idx * sv1; -+ const int64_t gb_base = seq * sb3 + h_idx * sb1; -+ float * attn_base = dst + (int64_t) (seq * n_tokens * H + h_idx) * S_v; -+ -+ const int w = threadIdx.x >> 5; -+ const int lane = threadIdx.x & 31; -+ const int lg = lane >> 2; -+ const int lt = lane & 3; -+ -+ for (int64_t c0 = 0; c0 < n_tokens; c0 += C) { -+ const int Cc = (int) ((n_tokens - c0) < (int64_t) C ? (n_tokens - c0) : (int64_t) C); -+ -+ // --- load chunk K,Q (cooperative), beta and the gate prefix (cs, gamma) --- -+ for (int e = j; e < Cc * dk; e += dvt) { -+ const int t = e / dk; -+ const int i = e % dk; -+ Kc[t * dk + i] = k_base[(c0 + t) * sq2 + i]; -+ Qc[t * dk + i] = q_base[(c0 + t) * sq2 + i]; -+ } -+ // Zero the stale K/Q tail (rows t >= Cc) on short/tail chunks: the state-carry mma reads Kc -+ // over the full chunk dim K=0..C-1, and uninitialized smem may be NaN (0*NaN = NaN). Finite -+ // zeros keep the masked tail terms a clean 0. -+ for (int e = Cc * dk + j; e < C * dk; e += dvt) { -+ Kc[e] = 0.0f; -+ Qc[e] = 0.0f; -+ } -+ if (j < Cc) { -+ csh[j] = g[gb_base + (c0 + j) * sb2]; -+ bet[j] = beta[gb_base + (c0 + j) * sb2]; -+ } -+ __syncthreads(); -+ if (j == 0) { -+ float run = 0.0f; -+ for (int t = 0; t < Cc; t++) { -+ run += csh[t]; -+ csh[t] = run; -+ gam[t] = expf(run); -+ } -+ } -+ __syncthreads(); -+ -+ // --- Phase C: tensor-core tf32 Gram (KK->KKsh via warp0, QK->QKsh via warp1) --- -+ if (w < 2) { -+ const float * Xs = (w == 0) ? Kc : Qc; -+ float * Out = (w == 0) ? KKsh : QKsh; -+ #pragma unroll -+ for (int mt = 0; mt < (C + 15) / 16; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nt = 0; nt < (C + 7) / 8; nt++) { -+ const int colbase = nt * 8; -+ float cc[4]; -+ gdn_gram_tile_mma(cc, Xs, Kc, rowbase, colbase, lg, lt); -+ const int rr[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int ccol[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ if (rr[l] < C && ccol[l] < C) Out[rr[l] * C + ccol[l]] = cc[l]; -+ } -+ } -+ } -+ } -+ __syncthreads(); -+ -+ // --- A = I + tril(beta_t * d(t',t) * (k_t . k_t'), -1) (decay/beta in f32) --- -+ for (int e = j; e < Cc * Cc; e += dvt) { -+ const int t = e / Cc; -+ const int tp = e % Cc; -+ float a = 0.0f; -+ if (tp < t) { -+ const float kk = KKsh[t * C + tp]; -+ const float dd = expf(csh[t] - csh[tp]); -+ a = bet[t] * dd * kk; -+ } else if (tp == t) { -+ a = 1.0f; -+ } -+ Amat[t * Cc + tp] = a; -+ } -+ __syncthreads(); -+ -+ // --- KS = Kc * S0 (3xtf32) -> RHS = beta_t (v - gamma_t KS) -> Ud ; dv-strip restage --- -+ for (int s0 = 0; s0 < dvt; s0 += SRES_W) { -+ if (j >= s0 && j < s0 + SRES_W) { -+ #pragma unroll -+ for (int i = 0; i < dk; i++) Sres[(j - s0) * dk + i] = Sreg[i]; -+ } -+ __syncthreads(); -+ constexpr int NTPS = SRES_W / 8; -+ constexpr int NTPW = (NTPS + NWARP - 1) / NWARP; -+ #pragma unroll -+ for (int mt = 0; mt < (C + 15) / 16; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int ntl = w * NTPW + nn; -+ if (ntl >= NTPS) break; -+ const int colbase = ntl * 8; -+ float cc[4]; -+ gdn_gram_tile_mma_3x(cc, Kc, Sres, rowbase, colbase, lg, lt); -+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jl[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ const int t = tt[l], jc = s0 + jl[l]; // jc = slab-local column -+ if (t < Cc && jc < dvt) { -+ const float vtj = v_base[(c0 + t) * sv2 + (col0 + jc)]; -+ Ud[jc * C + t] = bet[t] * (vtj - gam[t] * cc[l]); -+ } -+ } -+ } -+ } -+ __syncthreads(); // RHS reads of Sres done before next strip overwrites it -+ } -+ // Zero the stale RHS tail (rows t >= Cc) for short/tail chunks. The apply (U=T*RHS), the -+ // P*U output mma and the state-carry mma all contract the full chunk dim K=0..C-1; the -+ // zeroed P/T columns there would otherwise multiply UNINITIALIZED Ud entries (0*NaN = NaN), -+ // which is exactly what corrupted the tensor-core output. thread j owns local column j. -+ for (int t = Cc; t < C; t++) Ud[j * C + t] = 0.0f; -+ __syncthreads(); -+ -+ // --- Phase D: form T = A^{-1} (f32 column-parallel exact inverse into KKsh) --- -+ if (j < C) { -+ if (j < Cc) { -+ float x[C]; -+ #pragma unroll -+ for (int r = 0; r < C; r++) x[r] = 0.0f; -+ x[j] = 1.0f; -+ for (int r = j + 1; r < Cc; r++) { -+ float acc = 0.0f; -+ for (int m = j; m < r; m++) acc += Amat[r * Cc + m] * x[m]; -+ x[r] = -acc; -+ } -+ #pragma unroll -+ for (int r = 0; r < C; r++) KKsh[r * C + j] = x[r]; -+ } else { -+ #pragma unroll -+ for (int r = 0; r < C; r++) KKsh[r * C + j] = 0.0f; -+ } -+ } -+ __syncthreads(); -+ // --- apply U = T*RHS (tf32), in place on Ud; loop m-tiles (C may exceed one m16 tile) --- -+ { -+ constexpr int MT = (C + 15) / 16; -+ constexpr int NT = dvt / 8; -+ constexpr int NTPW = (NT + NWARP - 1) / NWARP; -+ float ureg[MT][NTPW][4]; -+ #pragma unroll -+ for (int mt = 0; mt < MT; mt++) { -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int nt = w * NTPW + nn; -+ if (nt < NT) gdn_gram_tile_mma(ureg[mt][nn], KKsh, Ud, mt * 16, nt * 8, lg, lt); -+ } -+ } -+ __syncthreads(); // all RHS(Ud) reads done before overwriting with U -+ #pragma unroll -+ for (int mt = 0; mt < MT; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int nt = w * NTPW + nn; -+ if (nt >= NT) continue; -+ const int colbase = nt * 8; -+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ const int t = tt[l], jc = jj[l]; // jc = slab-local column -+ if (t < Cc && jc < dvt) Ud[jc * C + t] = ureg[mt][nn][l]; -+ } -+ } -+ } -+ __syncthreads(); -+ } -+ -+ // --- P = d(t',t) * QK (in place in QKsh; bounded f32 decay pre-baked) --- -+ for (int e = j; e < C * C; e += dvt) { -+ const int t = e / C; -+ const int tp = e % C; -+ float p = 0.0f; -+ if (tp <= t && t < Cc && tp < Cc) { -+ const float dd = expf(csh[t] - csh[tp]); -+ p = dd * QKsh[t * C + tp]; -+ } -+ QKsh[t * C + tp] = p; -+ } -+ __syncthreads(); -+ -+ // --- O = gamma_t * QS (3xtf32, pre-update S0) ; dv-strip restage --- -+ for (int s0 = 0; s0 < dvt; s0 += SRES_W) { -+ if (j >= s0 && j < s0 + SRES_W) { -+ #pragma unroll -+ for (int i = 0; i < dk; i++) Sres[(j - s0) * dk + i] = Sreg[i]; -+ } -+ __syncthreads(); -+ constexpr int NTPS = SRES_W / 8; -+ constexpr int NTPW = (NTPS + NWARP - 1) / NWARP; -+ #pragma unroll -+ for (int mt = 0; mt < (C + 15) / 16; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int ntl = w * NTPW + nn; -+ if (ntl >= NTPS) break; -+ const int colbase = ntl * 8; -+ float cc[4]; -+ gdn_gram_tile_mma_3x(cc, Qc, Sres, rowbase, colbase, lg, lt); -+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jl[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ const int t = tt[l], jc = s0 + jl[l]; // jc = slab-local column -+ if (t < Cc && jc < dvt) attn_base[(c0 + t) * S_v * H + (col0 + jc)] = gam[t] * cc[l]; -+ } -+ } -+ } -+ __syncthreads(); -+ } -+ // --- O += P*U (tf32-safe) then * scale (slab-local n-tiles) --- -+ { -+ constexpr int NT = dvt / 8; -+ constexpr int NTPW = (NT + NWARP - 1) / NWARP; -+ #pragma unroll -+ for (int mt = 0; mt < (C + 15) / 16; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int nt = w * NTPW + nn; -+ if (nt >= NT) break; -+ const int colbase = nt * 8; -+ float cc[4]; -+ gdn_gram_tile_mma(cc, QKsh, Ud, rowbase, colbase, lg, lt); -+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ const int t = tt[l], jc = jj[l]; // jc = slab-local column -+ if (t < Cc && jc < dvt) { -+ const int64_t oi = (int64_t)(c0 + t) * S_v * H + (col0 + jc); -+ attn_base[oi] = (attn_base[oi] + cc[l]) * scale; -+ } -+ } -+ } -+ } -+ } -+ __syncthreads(); // O done; Ud free to become DU, Sres free for the state restage -+ -+ // --- state carry: S_C = glast*S0 + Kc^T * DU (3xtf32) ; dv-strip mma->Sres->Sreg fold --- -+ const float glast = gam[Cc - 1]; -+ const float cslast = csh[Cc - 1]; -+ for (int t = 0; t < C; t++) { -+ const float dd = (t < Cc) ? expf(cslast - csh[t]) : 0.0f; -+ Ud[j * C + t] = dd * Ud[j * C + t]; // DU in place (thread j owns column j) -+ } -+ __syncthreads(); -+ for (int s0 = 0; s0 < dvt; s0 += SRES_W) { -+ constexpr int MT = dk / 16; -+ constexpr int NT_S = SRES_W / 8; -+ constexpr int NTILES = MT * NT_S; -+ constexpr int TPW = (NTILES + NWARP - 1) / NWARP; -+ #pragma unroll -+ for (int idx = 0; idx < TPW; idx++) { -+ const int tile = w * TPW + idx; -+ if (tile >= NTILES) break; -+ const int rowbase = (tile / NT_S) * 16; // i-tile over dk -+ const int coll = (tile % NT_S) * 8; // local jc within the strip -+ const int colg = s0 + coll; // slab-local jc (Ud is slab-local) -+ float cc[4]; -+ gdn_state_tile_mma_3x(cc, Kc, Ud, rowbase, colg, lg, lt); -+ const int ii[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jl[4] = {coll + 2*lt, coll + 2*lt + 1, coll + 2*lt, coll + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) Sres[jl[l] * dk + ii[l]] = cc[l]; -+ } -+ __syncthreads(); -+ if (j >= s0 && j < s0 + SRES_W) { -+ const int jcl = j - s0; -+ #pragma unroll -+ for (int i = 0; i < dk; i++) Sreg[i] = glast * Sreg[i] + Sres[jcl * dk + i]; -+ } -+ __syncthreads(); // Sres reads done before next strip's mma overwrites it -+ } -+ } -+ -+ // --- final-state write-back (M-layout) --- -+ const int64_t state_out_offset = (int64_t) (seq * H + h_idx) * S_v * S_v; -+ const int64_t attn_score_elems = (int64_t) S_v * H * n_tokens * n_seqs; -+ float * st = (state_dst != nullptr) ? (state_dst + state_out_offset) -+ : (dst + attn_score_elems + state_out_offset); -+ #pragma unroll -+ for (int i = 0; i < dk; i++) { -+ st[(col0 + j) * dk + i] = Sreg[i]; -+ } -+} -+ -+// ============================== CONFIG C (M8): the real occupancy flip ============================== -+// Built on the M6/M7 register-resident state (gated_delta_net_chunked_rr_cuda). Two smem reliefs -+// let C reach 64 under the 99KB dynamic-smem opt-in AND drop C=32 under the ~49.5KB/block budget -+// for >=2 resident blocks/SM: -+// (1) Kc/Qc staged as BF16 (half the f32 footprint). The tf32 m16n8k8 mma reads bf16-upconverted -+// operands; bf16's 8-bit mantissa fits tf32's 10 bits (scope: "bf16 only for the Gram terms"). -+// (2) ONE extra C*C buffer dropped: the QK Gram is computed LATE (Phase F, after T=A^-1 is consumed -+// by the apply), reusing the KKsh slot. So only 2 C*C f32 scratch (A in Amat, T->QK->P in KKsh) -+// instead of M6/M7's 3, and a small SRES_W=16 restage strip instead of SRES_W=C. -+// State stays register-resident in the per-thread dv-column layout (Sreg[dk]); in this column model -+// registers are NOT the 2-blk/SM cap (block = DV_TILE <= 128 threads, so regs/thread budget is ample -+// at 2 blocks) -- smem is, and the two reliefs above flip it. Decays/gamma/beta stay f32 outside the -+// mma; KS/QS + state-carry stay 3xtf32 (the cross-chunk carry compounds over every chunk). -+// C=64, dv_tile=64 -> ~89KB, 1 blk/SM (BW-max, C raised to 64). -+// C=32, dv_tile=64 -> ~40KB, 2 blk/SM (the occupancy flip). C=32, dv_tile=128 -> ~48KB, 2 blk/SM. -+// MIN_BLK feeds __launch_bounds__ so the compiler caps registers for the target occupancy. -+template -+__global__ void __launch_bounds__(DV_TILE, MIN_BLK) -+gated_delta_net_chunked_cc_cuda( -+ const float * __restrict__ q, const float * __restrict__ k, -+ const float * __restrict__ v, const float * __restrict__ g, -+ const float * __restrict__ beta, const float * __restrict__ curr_state, -+ float * __restrict__ dst, -+ int64_t H, int64_t n_tokens, int64_t n_seqs, -+ int64_t sq1, int64_t sq2, int64_t sq3, -+ int64_t sv1, int64_t sv2, int64_t sv3, -+ int64_t sb1, int64_t sb2, int64_t sb3, -+ uint3 neqk1_magic, uint3 rq3_magic, -+ float scale, float * __restrict__ state_dst, -+ const int32_t * __restrict__ ids, int rs_head) { -+ constexpr int dk = S_v; -+ constexpr int dv = S_v; -+ constexpr int dvt = DV_TILE; // dv-slab width (== dv when not slabbed) -+ constexpr int NWARP = DV_TILE / 32; -+ const int h_idx = blockIdx.x; -+ const int seq = blockIdx.y; -+ const int col0 = blockIdx.z * dvt; // global dv-column base of this slab -+ const int j = threadIdx.x; // slab-local v-column (0..dvt-1); global = col0 + j -+ -+ const uint32_t iq1 = fastmodulo((uint32_t) h_idx, neqk1_magic); -+ const uint32_t iq3 = fastdiv((uint32_t) seq, rq3_magic); -+ -+ // smem: BF16 Kc,Qc prefix then the f32 scratch. Reuse the float gdn_smem[] symbol (same name as -+ // the other chunked kernels, no extern-shared type clash) and reinterpret its front as bf16: the -+ // two bf16 buffers are 2*C*dk bf16 = 4*C*dk bytes = exactly C*dk floats, so Ud starts at the -+ // float offset C*dk (4-byte aligned). Two C*C f32 buffers only (Amat, KKsh; QK Gram deferred). -+ extern __shared__ float gdn_smem[]; -+ nv_bfloat16 * Kc = (nv_bfloat16 *) gdn_smem; // [C*dk] Kc[t*dk + i] (bf16) -+ nv_bfloat16 * Qc = Kc + (size_t) C * dk; // [C*dk] Qc[t*dk + i] (bf16) -+ float * Ud = gdn_smem + (size_t) C * dk; // [dvt*C] Ud[localcol*C + t] -+ float * Amat = Ud + (size_t) dvt * C; // [C*C] A scratch -+ float * csh = Amat + (size_t) C * C; // [C] cumsum(log-gate) -+ float * gam = csh + C; // [C] gamma_t -+ float * bet = gam + C; // [C] beta_t -+ float * KKsh = bet + C; // [C*C] KK Gram -> T=A^{-1} -> QK -> P -+ float * Sres = KKsh + (size_t) C * C; // [SRES_W*dk] accumulator<->B state bridge -+ -+ // S0 register-resident: thread j owns column j -> Sreg[i] = S[i][j]. -+ const bool identity = (ids != nullptr && ids[seq] == rs_head + seq); -+ const float * read_state = (identity ? state_dst : curr_state) -+ + (int64_t) seq * H * dk * dv + (int64_t) h_idx * dk * dv; -+ float Sreg[dk]; -+ #pragma unroll -+ for (int i = 0; i < dk; i++) { -+ Sreg[i] = read_state[(col0 + j) * dk + i]; -+ } -+ -+ const float * q_base = q + iq3 * sq3 + iq1 * sq1; -+ const float * k_base = k + iq3 * sq3 + iq1 * sq1; -+ const float * v_base = v + seq * sv3 + h_idx * sv1; -+ const int64_t gb_base = seq * sb3 + h_idx * sb1; -+ float * attn_base = dst + (int64_t) (seq * n_tokens * H + h_idx) * S_v; -+ -+ const int w = threadIdx.x >> 5; -+ const int lane = threadIdx.x & 31; -+ const int lg = lane >> 2; -+ const int lt = lane & 3; -+ -+ for (int64_t c0 = 0; c0 < n_tokens; c0 += C) { -+ const int Cc = (int) ((n_tokens - c0) < (int64_t) C ? (n_tokens - c0) : (int64_t) C); -+ -+ // --- load chunk K,Q (cooperative, f32->bf16), beta and the gate prefix (cs, gamma) --- -+ for (int e = j; e < Cc * dk; e += dvt) { -+ const int t = e / dk; -+ const int i = e % dk; -+ Kc[t * dk + i] = __float2bfloat16(k_base[(c0 + t) * sq2 + i]); -+ Qc[t * dk + i] = __float2bfloat16(q_base[(c0 + t) * sq2 + i]); -+ } -+ // Zero the stale K/Q tail (rows t >= Cc): the state-carry mma reads Kc over K=0..C-1. -+ for (int e = Cc * dk + j; e < C * dk; e += dvt) { -+ Kc[e] = __float2bfloat16(0.0f); -+ Qc[e] = __float2bfloat16(0.0f); -+ } -+ if (j < Cc) { -+ csh[j] = g[gb_base + (c0 + j) * sb2]; -+ bet[j] = beta[gb_base + (c0 + j) * sb2]; -+ } -+ __syncthreads(); -+ if (j == 0) { -+ float run = 0.0f; -+ for (int t = 0; t < Cc; t++) { -+ run += csh[t]; -+ csh[t] = run; -+ gam[t] = expf(run); -+ } -+ } -+ __syncthreads(); -+ -+ // --- Phase C: tensor-core tf32 Gram, KK->KKsh (warp0). QK is deferred to Phase F. --- -+ if (w == 0) { -+ #pragma unroll -+ for (int mt = 0; mt < (C + 15) / 16; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nt = 0; nt < (C + 7) / 8; nt++) { -+ const int colbase = nt * 8; -+ float cc[4]; -+ gdn_gram_tile_mma(cc, Kc, Kc, rowbase, colbase, lg, lt); -+ const int rr[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int ccol[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ if (rr[l] < C && ccol[l] < C) KKsh[rr[l] * C + ccol[l]] = cc[l]; -+ } -+ } -+ } -+ } -+ __syncthreads(); -+ -+ // --- A = I + tril(beta_t * d(t',t) * (k_t . k_t'), -1) (decay/beta in f32) --- -+ for (int e = j; e < Cc * Cc; e += dvt) { -+ const int t = e / Cc; -+ const int tp = e % Cc; -+ float a = 0.0f; -+ if (tp < t) { -+ const float kk = KKsh[t * C + tp]; -+ const float dd = expf(csh[t] - csh[tp]); -+ a = bet[t] * dd * kk; -+ } else if (tp == t) { -+ a = 1.0f; -+ } -+ Amat[t * Cc + tp] = a; -+ } -+ __syncthreads(); -+ -+ // --- KS = Kc * S0 (3xtf32) -> RHS = beta_t (v - gamma_t KS) -> Ud ; dv-strip restage --- -+ for (int s0 = 0; s0 < dvt; s0 += SRES_W) { -+ if (j >= s0 && j < s0 + SRES_W) { -+ #pragma unroll -+ for (int i = 0; i < dk; i++) Sres[(j - s0) * dk + i] = Sreg[i]; -+ } -+ __syncthreads(); -+ constexpr int NTPS = SRES_W / 8; -+ constexpr int NTPW = (NTPS + NWARP - 1) / NWARP; -+ #pragma unroll -+ for (int mt = 0; mt < (C + 15) / 16; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int ntl = w * NTPW + nn; -+ if (ntl >= NTPS) break; -+ const int colbase = ntl * 8; -+ float cc[4]; -+ gdn_gram_tile_mma_3x(cc, Kc, Sres, rowbase, colbase, lg, lt); -+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jl[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ const int t = tt[l], jc = s0 + jl[l]; // jc = slab-local column -+ if (t < Cc && jc < dvt) { -+ const float vtj = v_base[(c0 + t) * sv2 + (col0 + jc)]; -+ Ud[jc * C + t] = bet[t] * (vtj - gam[t] * cc[l]); -+ } -+ } -+ } -+ } -+ __syncthreads(); // RHS reads of Sres done before next strip overwrites it -+ } -+ // Zero the stale RHS tail (rows t >= Cc): apply/P*U/state-carry contract K=0..C-1. -+ for (int t = Cc; t < C; t++) Ud[j * C + t] = 0.0f; -+ __syncthreads(); -+ -+ // --- Phase D: form T = A^{-1} (f32 column-parallel exact inverse into KKsh) --- -+ if (j < C) { -+ if (j < Cc) { -+ float x[C]; -+ #pragma unroll -+ for (int r = 0; r < C; r++) x[r] = 0.0f; -+ x[j] = 1.0f; -+ for (int r = j + 1; r < Cc; r++) { -+ float acc = 0.0f; -+ for (int m = j; m < r; m++) acc += Amat[r * Cc + m] * x[m]; -+ x[r] = -acc; -+ } -+ #pragma unroll -+ for (int r = 0; r < C; r++) KKsh[r * C + j] = x[r]; -+ } else { -+ #pragma unroll -+ for (int r = 0; r < C; r++) KKsh[r * C + j] = 0.0f; -+ } -+ } -+ __syncthreads(); -+ // --- apply U = T*RHS (tf32), in place on Ud; loop m-tiles --- -+ { -+ constexpr int MT = (C + 15) / 16; -+ constexpr int NT = dvt / 8; -+ constexpr int NTPW = (NT + NWARP - 1) / NWARP; -+ float ureg[MT][NTPW][4]; -+ #pragma unroll -+ for (int mt = 0; mt < MT; mt++) { -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int nt = w * NTPW + nn; -+ if (nt < NT) gdn_gram_tile_mma(ureg[mt][nn], KKsh, Ud, mt * 16, nt * 8, lg, lt); -+ } -+ } -+ __syncthreads(); // all RHS(Ud) reads done before overwriting with U -+ #pragma unroll -+ for (int mt = 0; mt < MT; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int nt = w * NTPW + nn; -+ if (nt >= NT) continue; -+ const int colbase = nt * 8; -+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ const int t = tt[l], jc = jj[l]; // jc = slab-local column -+ if (t < Cc && jc < dvt) Ud[jc * C + t] = ureg[mt][nn][l]; -+ } -+ } -+ } -+ __syncthreads(); -+ } -+ -+ // --- Phase F: QK Gram LATE (Kc/Qc bf16) -> KKsh (T is consumed), then P = d(t',t)*QK --- -+ if (w == 0) { -+ #pragma unroll -+ for (int mt = 0; mt < (C + 15) / 16; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nt = 0; nt < (C + 7) / 8; nt++) { -+ const int colbase = nt * 8; -+ float cc[4]; -+ gdn_gram_tile_mma(cc, Qc, Kc, rowbase, colbase, lg, lt); -+ const int rr[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int ccol[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ if (rr[l] < C && ccol[l] < C) KKsh[rr[l] * C + ccol[l]] = cc[l]; -+ } -+ } -+ } -+ } -+ __syncthreads(); -+ for (int e = j; e < C * C; e += dvt) { -+ const int t = e / C; -+ const int tp = e % C; -+ float p = 0.0f; -+ if (tp <= t && t < Cc && tp < Cc) { -+ const float dd = expf(csh[t] - csh[tp]); -+ p = dd * KKsh[t * C + tp]; -+ } -+ KKsh[t * C + tp] = p; -+ } -+ __syncthreads(); -+ -+ // --- O = gamma_t * QS (3xtf32, pre-update S0) ; dv-strip restage --- -+ for (int s0 = 0; s0 < dvt; s0 += SRES_W) { -+ if (j >= s0 && j < s0 + SRES_W) { -+ #pragma unroll -+ for (int i = 0; i < dk; i++) Sres[(j - s0) * dk + i] = Sreg[i]; -+ } -+ __syncthreads(); -+ constexpr int NTPS = SRES_W / 8; -+ constexpr int NTPW = (NTPS + NWARP - 1) / NWARP; -+ #pragma unroll -+ for (int mt = 0; mt < (C + 15) / 16; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int ntl = w * NTPW + nn; -+ if (ntl >= NTPS) break; -+ const int colbase = ntl * 8; -+ float cc[4]; -+ gdn_gram_tile_mma_3x(cc, Qc, Sres, rowbase, colbase, lg, lt); -+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jl[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ const int t = tt[l], jc = s0 + jl[l]; // jc = slab-local column -+ if (t < Cc && jc < dvt) attn_base[(c0 + t) * S_v * H + (col0 + jc)] = gam[t] * cc[l]; -+ } -+ } -+ } -+ __syncthreads(); -+ } -+ // --- O += P*U (tf32-safe) then * scale (slab-local n-tiles) --- -+ { -+ constexpr int NT = dvt / 8; -+ constexpr int NTPW = (NT + NWARP - 1) / NWARP; -+ #pragma unroll -+ for (int mt = 0; mt < (C + 15) / 16; mt++) { -+ const int rowbase = mt * 16; -+ #pragma unroll -+ for (int nn = 0; nn < NTPW; nn++) { -+ const int nt = w * NTPW + nn; -+ if (nt >= NT) break; -+ const int colbase = nt * 8; -+ float cc[4]; -+ gdn_gram_tile_mma(cc, KKsh, Ud, rowbase, colbase, lg, lt); -+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) { -+ const int t = tt[l], jc = jj[l]; // jc = slab-local column -+ if (t < Cc && jc < dvt) { -+ const int64_t oi = (int64_t)(c0 + t) * S_v * H + (col0 + jc); -+ attn_base[oi] = (attn_base[oi] + cc[l]) * scale; -+ } -+ } -+ } -+ } -+ } -+ __syncthreads(); // O done; Ud free to become DU, Sres free for the state restage -+ -+ // --- state carry: S_C = glast*S0 + Kc^T * DU (3xtf32) ; dv-strip mma->Sres->Sreg fold --- -+ const float glast = gam[Cc - 1]; -+ const float cslast = csh[Cc - 1]; -+ for (int t = 0; t < C; t++) { -+ const float dd = (t < Cc) ? expf(cslast - csh[t]) : 0.0f; -+ Ud[j * C + t] = dd * Ud[j * C + t]; // DU in place (thread j owns column j) -+ } -+ __syncthreads(); -+ for (int s0 = 0; s0 < dvt; s0 += SRES_W) { -+ constexpr int MT = dk / 16; -+ constexpr int NT_S = SRES_W / 8; -+ constexpr int NTILES = MT * NT_S; -+ constexpr int TPW = (NTILES + NWARP - 1) / NWARP; -+ #pragma unroll -+ for (int idx = 0; idx < TPW; idx++) { -+ const int tile = w * TPW + idx; -+ if (tile >= NTILES) break; -+ const int rowbase = (tile / NT_S) * 16; // i-tile over dk -+ const int coll = (tile % NT_S) * 8; // local jc within the strip -+ const int colg = s0 + coll; // slab-local jc (Ud is slab-local) -+ float cc[4]; -+ gdn_state_tile_mma_3x(cc, Kc, Ud, rowbase, colg, lg, lt); -+ const int ii[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; -+ const int jl[4] = {coll + 2*lt, coll + 2*lt + 1, coll + 2*lt, coll + 2*lt + 1}; -+ #pragma unroll -+ for (int l = 0; l < 4; l++) Sres[jl[l] * dk + ii[l]] = cc[l]; -+ } -+ __syncthreads(); -+ if (j >= s0 && j < s0 + SRES_W) { -+ const int jcl = j - s0; -+ #pragma unroll -+ for (int i = 0; i < dk; i++) Sreg[i] = glast * Sreg[i] + Sres[jcl * dk + i]; -+ } -+ __syncthreads(); // Sres reads done before next strip's mma overwrites it -+ } -+ } -+ -+ // --- final-state write-back (M-layout) --- -+ const int64_t state_out_offset = (int64_t) (seq * H + h_idx) * S_v * S_v; -+ const int64_t attn_score_elems = (int64_t) S_v * H * n_tokens * n_seqs; -+ float * st = (state_dst != nullptr) ? (state_dst + state_out_offset) -+ : (dst + attn_score_elems + state_out_offset); -+ #pragma unroll -+ for (int i = 0; i < dk; i++) { -+ st[(col0 + j) * dk + i] = Sreg[i]; -+ } -+} -+ -+template - static void launch_gdn_chunked( - const float * q_d, const float * k_d, const float * v_d, - const float * g_d, const float * b_d, const float * s_d, -@@ -603,10 +1737,11 @@ static void launch_gdn_chunked( - const uint3 neqk1_magic, const uint3 rq3_magic, - float scale, cudaStream_t stream) { - const size_t smem = ((size_t) S_v * S_v + (size_t) 2 * C * S_v + (size_t) S_v * C -- + (size_t) C * C + (size_t) 3 * C) * sizeof(float); -+ + (size_t) C * C + (size_t) 3 * C -+ + (TC >= 1 ? (size_t) 2 * C * C : (size_t) 0)) * sizeof(float); - static bool attr_set = false; - if (!attr_set) { -- const cudaError_t e = cudaFuncSetAttribute(gated_delta_net_chunked_cuda, -+ const cudaError_t e = cudaFuncSetAttribute(gated_delta_net_chunked_cuda, - cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smem); - if (e != cudaSuccess) { - GGML_ABORT("gdn chunked: cudaFuncSetAttribute(maxDynSmem=%zu) failed: %s\n", smem, cudaGetErrorString(e)); -@@ -615,7 +1750,82 @@ static void launch_gdn_chunked( - } - dim3 grid_dims(H, n_seqs, 1); - dim3 block_dims(S_v, 1, 1); -- gated_delta_net_chunked_cuda<<>>( -+ gated_delta_net_chunked_cuda<<>>( -+ q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, -+ sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, -+ neqk1_magic, rq3_magic, scale, state_dst_d, ids_d, rs_head); -+} -+ -+// Launcher for the register-resident full-TC path (M6/M7). State lives in registers, so the smem -+// budget drops the 64KB Sd buffer and gains a small SRES_W*dk restage tile; this is what lets C -+// grow past 16 under the 99KB dynamic-smem opt-in. DV_TILE < S_v dv-slabs the state across blocks -+// (M7): the grid gains a z-axis (S_v/DV_TILE slabs), each block holds only a dk x DV_TILE state -+// slab register-resident (halving the per-thread state regs at DV_TILE=64) and computes a DV_TILE -+// strip of O/U/state; the dv-independent work (Kc/Qc load, KK/QK Gram, A and its inverse, gates) -+// is recomputed per slab. This multiplies the grid (fixes the low-n_seqs grid starvation) and -+// relieves the register cap toward >=2 blocks/SM. -+template -+static void launch_gdn_chunked_rr( -+ const float * q_d, const float * k_d, const float * v_d, -+ const float * g_d, const float * b_d, const float * s_d, -+ float * dst_d, float * state_dst_d, const int32_t * ids_d, int rs_head, -+ int64_t H, int64_t n_tokens, int64_t n_seqs, -+ int64_t sq1, int64_t sq2, int64_t sq3, -+ int64_t sv1, int64_t sv2, int64_t sv3, -+ int64_t sb1, int64_t sb2, int64_t sb3, -+ const uint3 neqk1_magic, const uint3 rq3_magic, -+ float scale, cudaStream_t stream) { -+ // Kc + Qc (2*C*S_v) + Ud (DV_TILE*C) + Amat (C*C) + 3*C gates + KKsh + QKsh (2*C*C) + Sres (SRES_W*S_v) -+ const size_t smem = ((size_t) 2 * C * S_v + (size_t) DV_TILE * C + (size_t) 3 * C * C -+ + (size_t) 3 * C + (size_t) SRES_W * S_v) * sizeof(float); -+ static bool attr_set = false; -+ if (!attr_set) { -+ const cudaError_t e = cudaFuncSetAttribute(gated_delta_net_chunked_rr_cuda, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smem); -+ if (e != cudaSuccess) { -+ GGML_ABORT("gdn rr: cudaFuncSetAttribute(maxDynSmem=%zu) failed: %s\n", smem, cudaGetErrorString(e)); -+ } -+ attr_set = true; -+ } -+ dim3 grid_dims(H, n_seqs, S_v / DV_TILE); -+ dim3 block_dims(DV_TILE, 1, 1); -+ gated_delta_net_chunked_rr_cuda<<>>( -+ q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, -+ sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, -+ neqk1_magic, rq3_magic, scale, state_dst_d, ids_d, rs_head); -+} -+ -+// CONFIG C launcher (M8). Same dv-slab grid as the rr launcher; the smem budget is BF16 Kc/Qc -+// (2*C*S_v*2 bytes) + Ud (DV_TILE*C f32) + 2 C*C f32 scratch (Amat + KKsh) + 3*C gates + Sres -+// (SRES_W*S_v f32). C=64/dv64 ~= 89KB (1 blk/SM); C=32/dv64 ~= 40KB and C=32/dv128 ~= 48KB -+// (2 blk/SM). cudaFuncSetAttribute return is CHECKED (0031 precedent). -+template -+static void launch_gdn_chunked_cc( -+ const float * q_d, const float * k_d, const float * v_d, -+ const float * g_d, const float * b_d, const float * s_d, -+ float * dst_d, float * state_dst_d, const int32_t * ids_d, int rs_head, -+ int64_t H, int64_t n_tokens, int64_t n_seqs, -+ int64_t sq1, int64_t sq2, int64_t sq3, -+ int64_t sv1, int64_t sv2, int64_t sv3, -+ int64_t sb1, int64_t sb2, int64_t sb3, -+ const uint3 neqk1_magic, const uint3 rq3_magic, -+ float scale, cudaStream_t stream) { -+ const size_t smem = (size_t) 2 * C * S_v * sizeof(nv_bfloat16) // Kc + Qc (bf16) -+ + ((size_t) DV_TILE * C + (size_t) 2 * C * C -+ + (size_t) 3 * C + (size_t) SRES_W * S_v) * sizeof(float); -+ static bool attr_set = false; -+ if (!attr_set) { -+ const cudaError_t e = cudaFuncSetAttribute( -+ gated_delta_net_chunked_cc_cuda, -+ cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smem); -+ if (e != cudaSuccess) { -+ GGML_ABORT("gdn cc: cudaFuncSetAttribute(maxDynSmem=%zu) failed: %s\n", smem, cudaGetErrorString(e)); -+ } -+ attr_set = true; -+ } -+ dim3 grid_dims(H, n_seqs, S_v / DV_TILE); -+ dim3 block_dims(DV_TILE, 1, 1); -+ gated_delta_net_chunked_cc_cuda<<>>( - q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, - sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, - neqk1_magic, rq3_magic, scale, state_dst_d, ids_d, rs_head); -@@ -645,17 +1855,115 @@ static void launch_gated_delta_net( - // sequential recurrence. Mathematically equivalent up to FP reduction order (NEW per-path md5; - // validated benign by test-backend-ops NMSE + greedy output). Toggle: GDN_CHUNK_OFF / GDN_CHUNK_MIN. - if constexpr (!KDA && !keep_rs_t && !STATE_BF16 && !HYBRID) { -- // OPT-IN: this chunked path is bit-exact-benign (test-backend-ops green) but, at C=16 -- // (forced by GB10 99KB dyn-smem opt-in, all-shared), it is NOT yet faster than the tuned -- // sequential recurrence on this model (measured ~22%% slower S_PP, grid-starved at low -- // n_seqs + 1 block/SM occupancy). Default OFF so the backend default is regression-free; -- // enable for experiments / tuning with GDN_CHUNK_MIN=. See README section 5 (dev notes / rejected-flat levers). -- static const int gdn_chunk_min = []{ const char * e = getenv("GDN_CHUNK_MIN"); return e ? atoi(e) : INT_MAX; }(); -+ // DEFAULT-ON UNDER PAGED KV (patch 0044). The M5 tensor-core path (GDN_TC=5: full-TC -+ // form-T solve + state-update mma) is greedy-bit-exact (per-path md5 == the sequential -+ // canonical) and now *beats* the tuned sequential recurrence on the Qwen3.6 MoE prefill: -+ // GB10, q36-35b-a3b-nvfp4, LLAMA_KV_PAGED=1 LLAMA_MOE_FORCE_GRAPHS=1, -+ // -npp 512 -ntg 4 -npl 32 : S_PP 2185.8 -> 2315.2 t/s (+5.9%) -+ // -npp 2048 -ntg 4 -npl 32 : S_PP 2021.5 -> 2379.8 t/s (+17.7%) -+ // GDN_CHUNK_MIN is the engage threshold (per-call token count). It must stay > 1 so the -+ // 1-tok-per-call decode steps (and any tiny GDN call) keep the sequential recurrence - -+ // at GDN_CHUNK_MIN=1 the chunked path also swallows decode and collapses S_TG by ~25%. -+ // Tuned to 64: above decode/tiny-call sizes, below the real MoE-prefill per-call count -+ // (which is < 256 on the 512-prompt shape, so 256 barely fires). Both knobs stay -+ // env-overridable; OFF (INT_MAX) when not paged so the stock/non-paged default is -+ // regression-free. See README sections 3 (0044) and 5. -+ static const bool kv_paged = (getenv("LLAMA_KV_PAGED") != nullptr); -+ static const int gdn_chunk_min = []{ -+ const char * e = getenv("GDN_CHUNK_MIN"); -+ if (e) return atoi(e); -+ return kv_paged ? 64 : INT_MAX; -+ }(); -+ // Tensor-core level selector (single build, clean runtime A/B). GDN_TC: -+ // 0 = serial scan (patch 0031); 1 = KK/QK Gram mma (M2); -+ // 2 = + KS/QS state-boundary mma, 3xtf32 (M3); 3 = + P*U output mma (M4); -+ // 4 = CONFIG C (M8): bf16 Kc/Qc + 2 C*C scratch (QK Gram deferred) + small restage. The -+ // smem relief that lets C reach 64 under the 99KB opt-in (GDN_CHUNK_C=64 -> ~89KB, -+ // 1 blk/SM) AND drops C=32 under ~49.5KB for >=2 blocks/SM (GDN_CHUNK_C=32 -> ~40/48KB); -+ // 5 = M5 (full TC, state in 64KB smem, C=16) - the DEFAULT under paged KV (patch 0044); -+ // 6+ = M6/M7 register-resident state, f32 Kc/Qc, chunk C from GDN_CHUNK_C. -+ // GDN_GRAM_MMA=1 kept as an alias for level 1. GDN_CHUNK_C in {16,32} (M6/M7) or {32,64} -+ // (CONFIG C) selects the chunk; GDN_DV_TILE in {64,128} the dv-slab. -+ static const int gdn_tc = []{ -+ const char * e = getenv("GDN_TC"); -+ if (e) return atoi(e); -+ const char * g = getenv("GDN_GRAM_MMA"); -+ if (g && atoi(g) != 0) return 1; -+ return kv_paged ? 5 : 0; -+ }(); -+ static const int gdn_chunk_c = []{ const char * e = getenv("GDN_CHUNK_C"); return e ? atoi(e) : 32; }(); -+ // M7 dv-slab width: 128 = no slab (default, 1 slab); 64 = 2 slabs (grid x2, half the state -+ // regs/block). Must be >= max(C, 64) so the KK/QK two-warp split and the j= gdn_chunk_min) { -- launch_gdn_chunked<128, 16>( -- q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, -- H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, -- neqk1_magic, rq3_magic, scale, stream); -+ switch (gdn_tc) { -+ case 0: -+ launch_gdn_chunked<128, 16, 0>( -+ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, -+ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, -+ neqk1_magic, rq3_magic, scale, stream); -+ break; -+ case 1: -+ launch_gdn_chunked<128, 16, 1>( -+ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, -+ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, -+ neqk1_magic, rq3_magic, scale, stream); -+ break; -+ case 2: -+ launch_gdn_chunked<128, 16, 2>( -+ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, -+ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, -+ neqk1_magic, rq3_magic, scale, stream); -+ break; -+ case 3: -+ launch_gdn_chunked<128, 16, 3>( -+ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, -+ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, -+ neqk1_magic, rq3_magic, scale, stream); -+ break; -+ case 4: { -+ // CONFIG C (M8): the real occupancy flip. bf16 Kc/Qc + 2 C*C scratch (QK Gram -+ // computed late) + SRES_W=16 restage strip. GDN_CHUNK_C=64 -> dv-slab forced to -+ // 64 (full dv=128 would exceed the 99KB opt-in), ~89KB, 1 blk/SM, "C raised to -+ // 64". GDN_CHUNK_C=32 -> dv from GDN_DV_TILE (64 or 128), ~40/48KB, 2 blk/SM. -+#define GDN_CC_LAUNCH(C_, SRES_, DV_, MB_) \ -+ launch_gdn_chunked_cc<128, C_, SRES_, DV_, MB_>( \ -+ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, \ -+ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, \ -+ neqk1_magic, rq3_magic, scale, stream) -+ if (gdn_chunk_c >= 64) { -+ GDN_CC_LAUNCH(64, 16, 64, 1); -+ } else if (gdn_dv_tile <= 64) { -+ GDN_CC_LAUNCH(32, 16, 64, 2); -+ } else { -+ GDN_CC_LAUNCH(32, 16, 128, 2); -+ } -+#undef GDN_CC_LAUNCH -+ break; -+ } -+ case 5: -+ // M5 reference: full TC, state still in 64KB smem, C=16 (A/B baseline). -+ launch_gdn_chunked<128, 16, 4>( -+ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, -+ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, -+ neqk1_magic, rq3_magic, scale, stream); -+ break; -+ default: -+ // M6/M7: full TC, REGISTER-RESIDENT state; chunk C from GDN_CHUNK_C (default 32), -+ // dv-slab width from GDN_DV_TILE (default 128 = no slab; 64 = 2 slabs, M7). -+#define GDN_RR_LAUNCH(C_, SRES_, DV_) \ -+ launch_gdn_chunked_rr<128, C_, SRES_, DV_>( \ -+ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, \ -+ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, \ -+ neqk1_magic, rq3_magic, scale, stream) -+ if (gdn_chunk_c <= 16) { -+ if (gdn_dv_tile <= 64) { GDN_RR_LAUNCH(16, 16, 64); } else { GDN_RR_LAUNCH(16, 16, 128); } -+ } else { -+ if (gdn_dv_tile <= 64) { GDN_RR_LAUNCH(32, 32, 64); } else { GDN_RR_LAUNCH(32, 32, 128); } -+ } -+#undef GDN_RR_LAUNCH -+ break; -+ } - return; - } - } -diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp -index 951bffc..65f2c62 100644 ---- a/tests/test-backend-ops.cpp -+++ b/tests/test-backend-ops.cpp -@@ -9433,6 +9433,11 @@ static std::vector> make_test_cases_eval() { - } - - test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1)); -+ // Tensor-core chunked-GDN prefill path (S_v==128): multi-chunk (C=16) coverage, -+ // incl. a tail chunk (100 = 6*16+4) and multi-seq. Exercised via GDN_CHUNK_MIN + GDN_TC. -+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1)); -+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 100, 1)); -+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 128, 2)); - test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1)); - test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true)); - test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true)); --- -2.43.0 diff --git a/backend/cpp/llama-cpp-localai-paged/patches/paged/0045-feat-paged-vLLM-exact-offline-repack-Marlin-W4A16-gr.patch b/backend/cpp/llama-cpp-localai-paged/patches/paged/0045-feat-paged-vLLM-exact-offline-repack-Marlin-W4A16-gr.patch deleted file mode 100644 index ec4f9cb4a..000000000 --- a/backend/cpp/llama-cpp-localai-paged/patches/paged/0045-feat-paged-vLLM-exact-offline-repack-Marlin-W4A16-gr.patch +++ /dev/null @@ -1,650 +0,0 @@ -From 864e92807809c55905bdb73ef492f6b1d03986f9 Mon Sep 17 00:00:00 2001 -From: Ettore Di Giacinto -Date: Mon, 29 Jun 2026 11:02:07 +0200 -Subject: [PATCH] feat(paged): vLLM-exact offline-repack Marlin W4A16 grouped - MoE prefill GEMM (patch 0045) - -The patch-0035 root-cause sweep proved the W4A16 grouped MoE GEMM is occupancy/ -latency bound (~10% MMA util), NOT dequant bound: every prefill step it re-derives -a per-block STRIDED weight gather and streams the 4-bit weights with scattered -4-BYTE cp.async transactions (BN*8 per K-block) + a per-step ue4m3 scale decode, -plus a SEPARATE global f32->bf16 activation pre-pass kernel. vLLM Marlin wins by -REPACKING the weights ONCE at load into an MMA-ready tile-contiguous 4-bit layout -so the GEMM loop issues only cheap COALESCED 16-byte loads and the per-lane reads -are already in fragment order; the loop does ZERO dequant work beyond the cheap -in-register nibble->bf16 unpack. This is that approach, faithful to vLLM Marlin and -distinct from the rejected patch-0044 int8-Marlin (which doubled the weight bytes -to int8 + a separate act-quant pass). - -(1) OFFLINE REPACK (one-time, first engage; cached in a persistent cudaMalloc - buffer keyed by src0->data): transpose the NVFP4 experts [N][Kb](d[4]+qs[32]) - into tile-contiguous planes Q=[e][Kb][N][32 qs bytes] + S=[e][Kb][N][4 ue4m3 - scale bytes]. STAYS 4-BIT: qs packed, scales ue4m3 -> the repacked buffer is - byte-for-byte the SAME size as the source NVFP4 weights (144 MiB/tensor for the - 35B-A3B experts, == source; no int8 2x). Persists across all prefill steps. -(2) KERNEL reads the pre-packed tiles directly: per (K-block, N-block) the BN rows - qs/scales are contiguous -> the smem fill is a flat COALESCED 16-byte cp.async - stream (no strided gather, no address math). Activations cast f32->bf16 IN - REGISTER on the load into smem (no separate global act pre-pass, no act-quant). - Inner loop: ldmatrix bf16 A, in-register PRMT (LUT-free) nibble->bf16 unpack - scaled by the in-register ue4m3 decode, bf16 m16n8k16 mma.sync into f32 accum, - cp.async multistage, ragged per-tile expert offset (reuses 0035 tile map). NO - separate smem-staged dequant pass, NO __syncthreads-gated dequant pass (SASS: - LDSM->PRMT->I2F->FMUL->F2F.BF16->HMMA.16816.F32.BF16 with no STS/BAR.SYNC between - the weight load and the MMA). - -The repacked smem is bit-identical to the 0035 W4A16 smem, so the proven inner -fragment math is unchanged and the output is bit-identical to the 0035 W4A16 path. - -TOGGLE: LLAMA_W4A16_REPACK=1 (default 0 == OFF), engaged only inside the already -default-off 0035 W4A16 path (LLAMA_W4A16_PREFILL_M>0). LLAMA_W4A16_REPACK_NOCACHE=1 -repacks into a transient pool buffer per call (for test-backend-ops, which frees/ -reuses src0 addresses and so defeats the pointer-keyed cache). Stock / decode / -non-NVFP4 byte-untouched. - -VALIDATION (GB10, sm_121a, Qwen3.6-35B-A3B-NVFP4): - - test-backend-ops MUL_MAT_ID nvfp4 (vs CPU oracle), REPACK forced + NOCACHE: - 81/81 OK, 0 FAIL. - - real-model greedy md5 (paged MoE): stock == W4A16-non-repack == W4A16-repack - (cached) == W4A16-repack (nocache) == default-off (REPACK=1, PREFILL_M unset), - all fda1aadbbbfb36fe8ab0f5f5465c745e (bit-identical; default-off is stock). - -HONEST PERF (S_PP t/s, llama-batched-bench -fa on -ngl 99 -ntg 32 -npl 1, paged, -warm cache): the offline repack is a large win over the prior non-repack W4A16 -(npp512 854.7 -> 1452.4 = +70%; beats it at every M) and is FLAT across M -(1452/1478/1467 at 512/1024/2048 = MMA-pipeline bound as designed). But the heavily -tuned FP4-MMQ baseline on GB10 is still ~38% faster (~2030 flat). Decode S_TG -unchanged (~55 t/s, prefill-only lever). One-time cache build amortized (cold first -prefill). Ships DEFAULT-OFF (like 0033/0034/0035): the validated, env-gated, bit- -exact-gated mechanism + the recorded result that even vLLM-exact offline-repack -Marlin W4A16 does not overtake the GB10 FP4-MMQ winner. - -Weight-memory: the repacked representation STAYS 4-BIT (no int8 expansion; vs 0044 -+100%); the persistent cache is an additive 4-bit copy of the engaged expert tensors -when ON, and literally 0 when OFF (default). Decode keeps the original block_nvfp4 -layout for its MMQ/graph path, so the cache is additive rather than in-place. - -Build: arch=compute_121a,code=[compute_121a,sm_121a]; AMPERE_MMA_AVAILABLE / -CP_ASYNC_AVAILABLE guards (NO_DEVICE_CODE off-Blackwell). - -Assisted-by: Claude:opus-4.8 [Claude Code] -Signed-off-by: Ettore Di Giacinto ---- - ggml/src/ggml-cuda/w4a16-gemm.cu | 11 + - ggml/src/ggml-cuda/w4a16-repack.cu | 470 ++++++++++++++++++++++++++++ - ggml/src/ggml-cuda/w4a16-repack.cuh | 59 ++++ - 3 files changed, 540 insertions(+) - create mode 100644 ggml/src/ggml-cuda/w4a16-repack.cu - create mode 100644 ggml/src/ggml-cuda/w4a16-repack.cuh - -diff --git a/ggml/src/ggml-cuda/w4a16-gemm.cu b/ggml/src/ggml-cuda/w4a16-gemm.cu -index c5c9ef7..687a7ae 100644 ---- a/ggml/src/ggml-cuda/w4a16-gemm.cu -+++ b/ggml/src/ggml-cuda/w4a16-gemm.cu -@@ -1,4 +1,5 @@ - #include "w4a16-gemm.cuh" -+#include "w4a16-repack.cuh" - #include "mma.cuh" - - #include -@@ -389,6 +390,16 @@ void ggml_cuda_mul_mat_id_w4a16_grouped( - GGML_ASSERT(src0->type == GGML_TYPE_NVFP4); - GGML_ASSERT(N % 128 == 0 && K % 64 == 0); - -+ // [paged patch 0045] vLLM-exact offline-repack sub-mode: repack the NVFP4 experts ONCE at first -+ // engage into an MMA-ready tile-contiguous 4-bit layout (cached, stays 4-bit), then run a GEMM -+ // loop with coalesced 16B weight loads + in-register act cast + ZERO in-loop dequant beyond the -+ // cheap in-register nibble->bf16 unpack. Gated by LLAMA_W4A16_REPACK (default off). -+ if (ggml_cuda_w4a16_repack_enabled()) { -+ ggml_cuda_mul_mat_id_w4a16_repack(ctx, src0, src1_sorted, dst_sorted, -+ tokens_per_expert, n_experts, K, N, stream); -+ return; -+ } -+ - int sel = w4a16_cfg_sel(); - if (N % w4a16_cfg_bn(sel) != 0) { - sel = 3; // BN>128 config whose BN doesn't divide N: safe BN=128 winner -diff --git a/ggml/src/ggml-cuda/w4a16-repack.cu b/ggml/src/ggml-cuda/w4a16-repack.cu -new file mode 100644 -index 0000000..4a7a125 ---- /dev/null -+++ b/ggml/src/ggml-cuda/w4a16-repack.cu -@@ -0,0 +1,470 @@ -+#include "w4a16-repack.cuh" -+#include "mma.cuh" -+ -+#include -+#include -+#include -+#include -+#include -+#include -+ -+// =========================================================================== -+// [paged patch 0045] vLLM-exact offline-repack Marlin W4A16 grouped MoE prefill GEMM. -+// See w4a16-repack.cuh for the design. Default-off (LLAMA_W4A16_REPACK). -+// =========================================================================== -+ -+using namespace ggml_cuda_mma; -+typedef tile<16, 8, nv_bfloat162> rp_tile_A; // A operand: M=16, K=16 -+typedef tile< 8, 8, nv_bfloat162> rp_tile_B; // B operand: N=8, K=16 -+typedef tile<16, 8, float> rp_tile_C; // accumulator: M=16, N=8 -+ -+bool ggml_cuda_w4a16_repack_enabled() { -+ static const bool e = [] { -+ const char * s = getenv("LLAMA_W4A16_REPACK"); -+ return s != nullptr && atoi(s) != 0; -+ }(); -+ return e; -+} -+ -+// ---- cp.async helpers (sm80+; raw bytes, no cast) ---- -+static __device__ __forceinline__ void rp_cp_async16(void * smem, const void * gmem) { -+#ifdef CP_ASYNC_AVAILABLE -+ const unsigned s = (unsigned) __cvta_generic_to_shared(smem); -+ asm volatile("cp.async.cg.shared.global [%0],[%1],16;\n" :: "r"(s), "l"(gmem)); -+#else -+ GGML_UNUSED(smem); GGML_UNUSED(gmem); NO_DEVICE_CODE; -+#endif // CP_ASYNC_AVAILABLE -+} -+static __device__ __forceinline__ void rp_cp_commit() { -+#ifdef CP_ASYNC_AVAILABLE -+ asm volatile("cp.async.commit_group;\n" ::); -+#else -+ NO_DEVICE_CODE; -+#endif // CP_ASYNC_AVAILABLE -+} -+template static __device__ __forceinline__ void rp_cp_wait() { -+#ifdef CP_ASYNC_AVAILABLE -+ asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); -+#else -+ NO_DEVICE_CODE; -+#endif // CP_ASYNC_AVAILABLE -+} -+ -+// ---- fast 16-entry FP4 (E2M1) -> kvalues_mxfp4 lookup via __byte_perm (PRMT), LUT-free, NO -+// divergent LDG. Bit-identical to kvalues_mxfp4[code]; same trick MMQ's get_int_from_table_16 and -+// the 0035 kernel use. The "cheap in-register unpack" - the ONLY dequant work the loop does. ---- -+static __device__ __forceinline__ int2 rp_table16(uint32_t q4) { -+ const uint32_t * table32 = (const uint32_t *) kvalues_mxfp4; // 16 int8 == 4 u32 -+ uint32_t tmp[2]; -+ const uint32_t sel = 0x32103210 | ((q4 & 0x88888888) >> 1); -+#pragma unroll -+ for (uint32_t i = 0; i < 2; ++i) { -+ const uint32_t shift = 16*i; -+ const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift); -+ const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift); -+ tmp[i] = __byte_perm(low, high, sel >> shift); -+ } -+ return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531)); -+} -+ -+// =========================================================================== -+// (1) OFFLINE REPACK kernel (one-time). Transpose src0 NVFP4 [e][N][Kb](d[4]+qs[32]) into -+// Qrp = [e][Kb][N][32 qs bytes] (u32 plane, 8 u32/block) -+// Srp = [e][Kb][N][ 4 scale bytes] (u32 plane, 1 u32/block) -+// Same total bytes as src0 (36 B/block); stays 4-bit. One thread per source block. -+// =========================================================================== -+static __global__ void w4a16_repack_kernel( -+ const block_nvfp4 * __restrict__ W0, int64_t expert_stride_blocks, -+ uint32_t * __restrict__ Qrp, uint32_t * __restrict__ Srp, -+ int N, int Kb, int n_experts) { -+ const int64_t total = (int64_t) n_experts * N * Kb; -+ const int64_t t = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; -+ if (t >= total) { -+ return; -+ } -+ const int kt = (int) (t % Kb); -+ const int64_t r2 = t / Kb; -+ const int n = (int) (r2 % N); -+ const int e = (int) (r2 / N); -+ -+ const block_nvfp4 * blk = W0 + (int64_t) e*expert_stride_blocks + (int64_t) n*Kb + kt; -+ const uint32_t * src = (const uint32_t *) blk; // word0=d[4], words1..8=qs[32] -+ -+ const int64_t qbase = (((int64_t) e*Kb + kt) * N + n) * 8; // u32 offset in Qrp -+ const int64_t sbase = ((int64_t) e*Kb + kt) * N + n; // u32 offset in Srp -+ Srp[sbase] = src[0]; -+#pragma unroll -+ for (int w = 0; w < 8; w++) { -+ Qrp[qbase + w] = src[1 + w]; -+ } -+} -+ -+// =========================================================================== -+// (2) GROUPED GEMM over the repacked tiles. Per output tile (blockIdx.x=N-block, blockIdx.y=M-tile): -+// expert e=g_tile_expert[by], row0=g_tile_row0[by], rcount=g_tile_rows[by]. -+// Weights from Qrp/Srp (tile-contiguous), activations from A_f32 (cast bf16 in-register), out to C. -+// ZERO in-loop dequant beyond the in-register nibble->bf16 unpack + ue4m3 scale decode. -+// =========================================================================== -+template -+__launch_bounds__(WARPS_M*WARPS_N*32, 1) -+static __global__ void w4a16_repack_grouped_kernel( -+ const float * __restrict__ A_f32, // [total_rows, K] f32 (sorted) -+ const uint32_t * __restrict__ Qrp, // [e][Kb][N][8 u32] -+ const uint32_t * __restrict__ Srp, // [e][Kb][N][1 u32] -+ float * __restrict__ C, // [total_rows, N] f32 -+ const int * __restrict__ g_tile_expert, -+ const int * __restrict__ g_tile_row0, -+ const int * __restrict__ g_tile_rows, -+ int N, int K, int total_rows) { -+#if defined(AMPERE_MMA_AVAILABLE) && defined(CP_ASYNC_AVAILABLE) -+ constexpr int BK = 64; // one nvfp4 block -+ constexpr int NWARP = WARPS_M*WARPS_N; -+ constexpr int THREADS = NWARP*32; -+ constexpr int WM = BM/WARPS_M, WN = BN/WARPS_N; -+ constexpr int MF = WM/16, NF = WN/8; -+ -+ constexpr int AN = BK/2; // bf16 pairs per A smem row (nv_bfloat162) -+ constexpr int ASTRIDE = AN + APAD; // padded A smem row stride (skew banks, +19% lesson) -+ constexpr int SZ_A = BM*ASTRIDE; // nv_bfloat162 (== u32) per stage (padded) -+ constexpr int SZ_WQ = BN*8; // u32 per stage (32 qs bytes/row, tile-contiguous) -+ constexpr int SZ_WD = BN; // u32 per stage (4 scale bytes/row, tile-contiguous) -+ -+ extern __shared__ uint32_t smem_u32[]; -+ constexpr int STAGE_U32 = SZ_A + SZ_WQ + SZ_WD; -+ nv_bfloat162 * sA[STAGES]; -+ uint32_t * sWq[STAGES]; -+ uint32_t * sWd[STAGES]; -+#pragma unroll -+ for (int s = 0; s < STAGES; s++) { -+ uint32_t * base = smem_u32 + s*STAGE_U32; -+ sA[s] = (nv_bfloat162 *) base; -+ sWq[s] = base + SZ_A; -+ sWd[s] = base + SZ_A + SZ_WQ; -+ } -+ -+ const int lane = threadIdx.x; // 0..31 (mma.cuh uses threadIdx.x AS the warp lane) -+ const int warp = threadIdx.y; // 0..NWARP-1 -+ const int tid = warp*32 + lane; -+ const int wrow = warp / WARPS_N, wcol = warp % WARPS_N; -+ -+ const int e = g_tile_expert[blockIdx.y]; -+ const int row0 = g_tile_row0[blockIdx.y]; -+ const int rcount = g_tile_rows[blockIdx.y]; -+ const int blockCol = blockIdx.x*BN; -+ const int Kb = K/64; -+ // base u32 offset of expert e in the repacked planes (tile = + (kt*N + blockCol)*{8,1}) -+ const int64_t Qe = (int64_t) e*Kb*N*8; -+ const int64_t Se = (int64_t) e*Kb*N; -+ -+ rp_tile_C acc[MF][NF]; -+ -+ // async-load K-block kt into stage st: A cast in-register (no global pre-pass); W coalesced 16B. -+ auto load_tile = [&](int st, int kt) { -+ // A: BM rows x BK bf16 = BM x (BK/8) 16B chunks. Source f32 -> cast bf16 in register. -+ const int A_chunks = BM*(BK/8); -+#pragma unroll 1 -+ for (int idx = tid; idx < A_chunks; idx += THREADS) { -+ const int c = idx % (BK/8); // 16B (8 bf16) chunk in the row -+ const int r = idx / (BK/8); // row in tile -+ const int gr = row0 + r; -+ nv_bfloat162 * d2 = sA[st] + r*ASTRIDE + c*4; // 4 nv_bfloat162 = 8 bf16 -+ if (gr < total_rows) { -+ const float * src = A_f32 + (int64_t) gr*K + (int64_t) kt*BK + c*8; -+ const float4 a = *(const float4 *) (src); -+ const float4 b = *(const float4 *) (src + 4); -+ d2[0] = make_bfloat162(__float2bfloat16(a.x), __float2bfloat16(a.y)); -+ d2[1] = make_bfloat162(__float2bfloat16(a.z), __float2bfloat16(a.w)); -+ d2[2] = make_bfloat162(__float2bfloat16(b.x), __float2bfloat16(b.y)); -+ d2[3] = make_bfloat162(__float2bfloat16(b.z), __float2bfloat16(b.w)); -+ } else { -+ const nv_bfloat162 z = make_bfloat162((nv_bfloat16) 0.0f, (nv_bfloat16) 0.0f); -+ d2[0] = z; d2[1] = z; d2[2] = z; d2[3] = z; -+ } -+ } -+ // W qs: tile-contiguous BN*32 bytes = BN*8 u32 = BN*2 16B chunks. Coalesced cp.async.16. -+ const uint32_t * Qtile = Qrp + Qe + ((int64_t) kt*N + blockCol) * 8; -+ const int Q_chunks = (BN*8) / 4; // 16B (4 u32) chunks -+#pragma unroll 1 -+ for (int idx = tid; idx < Q_chunks; idx += THREADS) { -+ rp_cp_async16(&sWq[st][idx*4], Qtile + idx*4); -+ } -+ // W scales: tile-contiguous BN*4 bytes = BN u32 = BN/4 16B chunks. Coalesced cp.async.16. -+ const uint32_t * Stile = Srp + Se + ((int64_t) kt*N + blockCol); -+ const int S_chunks = BN / 4; -+#pragma unroll 1 -+ for (int idx = tid; idx < S_chunks; idx += THREADS) { -+ rp_cp_async16(&sWd[st][idx*4], Stile + idx*4); -+ } -+ }; -+ -+ // prologue -+#pragma unroll -+ for (int s = 0; s < STAGES-1; s++) { if (s < Kb) load_tile(s, s); rp_cp_commit(); } -+ -+ for (int kt = 0; kt < Kb; kt++) { -+ const int ld = kt + (STAGES-1); -+ if (ld < Kb) load_tile(ld % STAGES, ld); -+ rp_cp_commit(); -+ rp_cp_wait(); -+ __syncthreads(); -+ -+ const int rs = kt % STAGES; -+ const nv_bfloat162 * sAcur = sA[rs]; -+ const uint32_t * sWqw = sWq[rs]; // BN rows x 8 u32 (32 qs bytes) -+ const uint32_t * sWdw = sWd[rs]; // BN rows x 1 u32 (4 scale bytes) -+ -+#pragma unroll -+ for (int kk = 0; kk < BK/16; kk++) { // 4 m16n8k16 sub-steps per 64-block -+ const int sub = kk; -+ // A fragments via ldmatrix (bf16) -+ rp_tile_A A_frag[MF]; -+#pragma unroll -+ for (int mi = 0; mi < MF; mi++) { -+ const int rb = wrow*WM + mi*16; -+ load_ldmatrix(A_frag[mi], sAcur + rb*ASTRIDE + kk*8, ASTRIDE); -+ } -+ // B fragments: in-register FP4->bf16 unpack (byte_perm LUT-free) * in-register ue4m3 scale. -+ rp_tile_B B_frag[NF]; -+ const int n_local = lane >> 2; // tile_B::get_i (row N, 0..7) -+ const int jc = lane & 3; -+ const int wsel = sub*2 + (jc >> 1); -+ const int bsh = 8 * (2*(jc & 1)); -+#pragma unroll -+ for (int ni = 0; ni < NF; ni++) { -+ const int nrow = wcol*WN + ni*8 + n_local; // col within BN tile [0,BN) -+ const uint32_t w = sWqw[nrow*8 + wsel]; -+ const int2 kv = rp_table16(w); -+ const float sc = ggml_cuda_ue4m3_to_fp32(((const uint8_t *) &sWdw[nrow])[sub]); -+ B_frag[ni].x[0].x = __float2bfloat16(sc * (float) (int8_t) (kv.x >> bsh)); -+ B_frag[ni].x[0].y = __float2bfloat16(sc * (float) (int8_t) (kv.x >> (bsh + 8))); -+ B_frag[ni].x[1].x = __float2bfloat16(sc * (float) (int8_t) (kv.y >> bsh)); -+ B_frag[ni].x[1].y = __float2bfloat16(sc * (float) (int8_t) (kv.y >> (bsh + 8))); -+ } -+#pragma unroll -+ for (int mi = 0; mi < MF; mi++) -+#pragma unroll -+ for (int ni = 0; ni < NF; ni++) -+ mma(acc[mi][ni], A_frag[mi], B_frag[ni]); -+ } -+ __syncthreads(); -+ } -+ -+ // write back (mask the ragged per-expert row tail) -+#pragma unroll -+ for (int mi = 0; mi < MF; mi++) -+#pragma unroll -+ for (int ni = 0; ni < NF; ni++) { -+ const int orow = wrow*WM + mi*16; -+ const int ocol = blockCol + wcol*WN + ni*8; -+#pragma unroll -+ for (int l = 0; l < acc[mi][ni].ne; l++) { -+ const int lr = orow + acc[mi][ni].get_i(l); -+ const int nc = ocol + acc[mi][ni].get_j(l); -+ if (lr < rcount && nc < N) { -+ C[(int64_t)(row0 + lr)*N + nc] = acc[mi][ni].x[l]; -+ } -+ } -+ } -+#else -+ GGML_UNUSED(A_f32); GGML_UNUSED(Qrp); GGML_UNUSED(Srp); GGML_UNUSED(C); -+ GGML_UNUSED(g_tile_expert); GGML_UNUSED(g_tile_row0); GGML_UNUSED(g_tile_rows); -+ GGML_UNUSED(N); GGML_UNUSED(K); GGML_UNUSED(total_rows); -+ NO_DEVICE_CODE; -+#endif // AMPERE_MMA_AVAILABLE && CP_ASYNC_AVAILABLE -+} -+ -+// launch the one-time repack kernel: src0 NVFP4 -> (Qrp, Srp) repacked planes. -+static void w4a16_repack_launch( -+ const ggml_tensor * src0, int64_t n_experts, int64_t K, int64_t N, -+ uint32_t * Qrp, uint32_t * Srp, cudaStream_t stream) { -+ const int64_t Kb = K / 64; -+ const int64_t nblk = n_experts * Kb * N; -+ const int64_t expert_stride_blocks = (int64_t) (src0->nb[2] / sizeof(block_nvfp4)); -+ const int threads = 256; -+ const int64_t grid = (nblk + threads - 1) / threads; -+ w4a16_repack_kernel<<>>( -+ (const block_nvfp4 *) src0->data, expert_stride_blocks, -+ Qrp, Srp, (int) N, (int) Kb, (int) n_experts); -+ CUDA_CHECK(cudaGetLastError()); -+} -+ -+// =========================================================================== -+// repack cache: persistent device buffer per src0 tensor (keyed by src0->data). One-time build. -+// -+// Correct for real models: model weights live in a persistent buffer for the model's lifetime, so -+// src0->data is stable and never freed/reused during inference. NOT correct for test-backend-ops, -+// which allocates/frees a fresh src0 per case and ggml reuses the address -> a cache HIT returns -+// stale repacked weights. For harness validation use LLAMA_W4A16_REPACK_NOCACHE=1 (repack into a -+// transient pool buffer every call, bypassing the cache); this proves the kernel+repack correct. -+// =========================================================================== -+struct w4a16_repack_entry { -+ uint32_t * Qrp = nullptr; // [n_experts][Kb][N][8 u32] -+ uint32_t * Srp = nullptr; // [n_experts][Kb][N][1 u32] -+ int64_t n_experts = 0, Kb = 0, N = 0; -+ size_t bytes = 0; // total cudaMalloc bytes (Q + S), for the memory-delta report -+}; -+static std::mutex g_rp_mu; -+static std::unordered_map g_rp_cache; -+ -+// total bytes the repack cache currently holds on device (for the memory-delta report). -+size_t ggml_cuda_w4a16_repack_cache_bytes() { -+ std::lock_guard lk(g_rp_mu); -+ size_t b = 0; -+ for (const auto & kv : g_rp_cache) { -+ b += kv.second.bytes; -+ } -+ return b; -+} -+ -+static const w4a16_repack_entry & w4a16_get_or_build_repack( -+ const ggml_tensor * src0, int64_t n_experts, int64_t K, int64_t N, cudaStream_t stream) { -+ const void * key = src0->data; -+ std::lock_guard lk(g_rp_mu); -+ auto it = g_rp_cache.find(key); -+ if (it != g_rp_cache.end()) { -+ return it->second; -+ } -+ -+ const int64_t Kb = K / 64; -+ const int64_t nblk = n_experts * Kb * N; -+ const size_t q_bytes = (size_t) nblk * 8 * sizeof(uint32_t); // 32 qs bytes/block -+ const size_t s_bytes = (size_t) nblk * 1 * sizeof(uint32_t); // 4 scale bytes/block -+ // single allocation: [Q][S]; total == source NVFP4 weight bytes (36 B/block) -> stays 4-bit. -+ void * buf = nullptr; -+ CUDA_CHECK(cudaMalloc(&buf, q_bytes + s_bytes)); -+ uint32_t * Qrp = (uint32_t *) buf; -+ uint32_t * Srp = (uint32_t *) ((char *) buf + q_bytes); -+ -+ w4a16_repack_launch(src0, n_experts, K, N, Qrp, Srp, stream); -+ -+ w4a16_repack_entry ent; -+ ent.Qrp = Qrp; ent.Srp = Srp; -+ ent.n_experts = n_experts; ent.Kb = Kb; ent.N = N; -+ ent.bytes = q_bytes + s_bytes; -+ auto res = g_rp_cache.emplace(key, ent); -+ -+ if (getenv("LLAMA_W4A16_DEBUG")) { -+ // NB: we already hold g_rp_mu - sum the cache total INLINE (do NOT call the public -+ // ggml_cuda_w4a16_repack_cache_bytes(), which re-locks the non-recursive mutex -> deadlock). -+ size_t total_cache = 0; -+ for (const auto & kv : g_rp_cache) { -+ total_cache += kv.second.bytes; -+ } -+ fprintf(stderr, "[w4a16-repack] BUILT cache for src0=%p: n_experts=%lld Kb=%lld N=%lld " -+ "bytes=%.1f MiB (== source NVFP4 weight bytes; stays 4-bit). total cache=%.1f MiB\n", -+ key, (long long) n_experts, (long long) Kb, (long long) N, -+ ent.bytes / (1024.0*1024.0), total_cache / (1024.0*1024.0)); -+ } -+ return res.first->second; -+} -+ -+// =========================================================================== -+// GEMM run over already-repacked planes (Qrp, Srp). Single tuned config -+// (the 0035 winner: BM64 BN128 WARPS_M1 WARPS_N8 STAGES3 APAD4). -+// =========================================================================== -+static void w4a16_repack_run( -+ ggml_backend_cuda_context & ctx, -+ const float * src1_sorted, -+ float * dst_sorted, -+ const int * tokens_per_expert, -+ int64_t n_experts, int64_t K, int64_t N, -+ const uint32_t * Qrp, const uint32_t * Srp, -+ cudaStream_t stream) { -+ constexpr int BM = 64, BN = 128, WARPS_M = 1, WARPS_N = 8, STAGES = 3, APAD = 4; -+ -+ // host: per-M-tile expert map (ragged, no tile crosses an expert boundary) -+ int64_t total_rows = 0; -+ for (int64_t e = 0; e < n_experts; e++) { -+ total_rows += tokens_per_expert[e]; -+ } -+ if (total_rows == 0) { -+ return; -+ } -+ std::vector h_tile_expert, h_tile_row0, h_tile_rows; -+ int64_t row = 0; -+ for (int64_t e = 0; e < n_experts; e++) { -+ const int t = tokens_per_expert[e]; -+ for (int off = 0; off < t; off += BM) { -+ h_tile_expert.push_back((int32_t) e); -+ h_tile_row0.push_back((int32_t) (row + off)); -+ h_tile_rows.push_back((int32_t) std::min(BM, t - off)); -+ } -+ row += t; -+ } -+ const int n_tiles = (int) h_tile_expert.size(); -+ -+ if (getenv("LLAMA_W4A16_DEBUG")) { -+ int max_tpe = 0, multi = 0; -+ for (int64_t e = 0; e < n_experts; e++) { -+ if (tokens_per_expert[e] > max_tpe) max_tpe = tokens_per_expert[e]; -+ if (tokens_per_expert[e] > BM) multi++; -+ } -+ fprintf(stderr, "[w4a16-repack] engaged: total_rows=%lld n_experts=%lld K=%lld N=%lld " -+ "n_tiles=%d max_tpe=%d multi_tile=%d (offline-repack, in-register act cast)\n", -+ (long long) total_rows, (long long) n_experts, (long long) K, (long long) N, -+ n_tiles, max_tpe, multi); -+ } -+ -+ ggml_cuda_pool_alloc d_tile_expert(ctx.pool(), n_tiles); -+ ggml_cuda_pool_alloc d_tile_row0 (ctx.pool(), n_tiles); -+ ggml_cuda_pool_alloc d_tile_rows (ctx.pool(), n_tiles); -+ CUDA_CHECK(cudaMemcpyAsync(d_tile_expert.ptr, h_tile_expert.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream)); -+ CUDA_CHECK(cudaMemcpyAsync(d_tile_row0.ptr, h_tile_row0.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream)); -+ CUDA_CHECK(cudaMemcpyAsync(d_tile_rows.ptr, h_tile_rows.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream)); -+ -+ auto kern = w4a16_repack_grouped_kernel; -+ constexpr int AN = 64/2, ASTRIDE = AN + APAD; -+ constexpr int STAGE_U32 = BM*ASTRIDE + BN*8 + BN; -+ const int smem_bytes = STAGES * STAGE_U32 * (int) sizeof(uint32_t); -+ CUDA_CHECK(cudaFuncSetAttribute(kern, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); -+ -+ dim3 grid((unsigned) (N / BN), (unsigned) n_tiles); -+ dim3 block(32, WARPS_M*WARPS_N); -+ kern<<>>( -+ src1_sorted, Qrp, Srp, dst_sorted, -+ d_tile_expert.ptr, d_tile_row0.ptr, d_tile_rows.ptr, -+ (int) N, (int) K, (int) total_rows); -+ CUDA_CHECK(cudaGetLastError()); -+} -+ -+// =========================================================================== -+// public host entry. Default: cached persistent repack (one-time, production-correct, stays -+// 4-bit). LLAMA_W4A16_REPACK_NOCACHE=1: repack into a transient pool buffer EVERY call (bypasses -+// the pointer-keyed cache) so test-backend-ops (which frees/reuses src0 addresses) can validate -+// the kernel + repack correctness. -+// =========================================================================== -+void ggml_cuda_mul_mat_id_w4a16_repack( -+ ggml_backend_cuda_context & ctx, -+ const ggml_tensor * src0, -+ const float * src1_sorted, -+ float * dst_sorted, -+ const int * tokens_per_expert, -+ int64_t n_experts, int64_t K, int64_t N, -+ cudaStream_t stream) { -+ GGML_ASSERT(src0->type == GGML_TYPE_NVFP4); -+ GGML_ASSERT(N % 128 == 0 && K % 64 == 0); -+ -+ static const bool nocache = [] { -+ const char * s = getenv("LLAMA_W4A16_REPACK_NOCACHE"); -+ return s != nullptr && atoi(s) != 0; -+ }(); -+ -+ if (nocache) { -+ const int64_t Kb = K / 64; -+ const int64_t nblk = n_experts * Kb * N; -+ // (1) repack into a transient pool buffer (same 4-bit size; stream-ordered before the GEMM). -+ ggml_cuda_pool_alloc q(ctx.pool(), (size_t) nblk * 8); -+ ggml_cuda_pool_alloc s(ctx.pool(), (size_t) nblk * 1); -+ w4a16_repack_launch(src0, n_experts, K, N, q.ptr, s.ptr, stream); -+ // (2) GEMM (q,s stay alive through the launch; kernel is queued on `stream` before dtor). -+ w4a16_repack_run(ctx, src1_sorted, dst_sorted, tokens_per_expert, -+ n_experts, K, N, q.ptr, s.ptr, stream); -+ return; -+ } -+ -+ // (1) one-time offline repack (cached, persistent, stays 4-bit). Stream-ordered before the GEMM. -+ const w4a16_repack_entry & rp = w4a16_get_or_build_repack(src0, n_experts, K, N, stream); -+ // (2) GEMM over the cached repacked planes. -+ w4a16_repack_run(ctx, src1_sorted, dst_sorted, tokens_per_expert, -+ n_experts, K, N, rp.Qrp, rp.Srp, stream); -+} -diff --git a/ggml/src/ggml-cuda/w4a16-repack.cuh b/ggml/src/ggml-cuda/w4a16-repack.cuh -new file mode 100644 -index 0000000..f44f5f8 ---- /dev/null -+++ b/ggml/src/ggml-cuda/w4a16-repack.cuh -@@ -0,0 +1,59 @@ -+#pragma once -+ -+#include "common.cuh" -+ -+// [paged patch 0045] vLLM-EXACT offline-repack Marlin W4A16 grouped MoE prefill GEMM. -+// -+// This is a SUB-MODE of the patch-0035 W4A16 grouped MoE GEMM (default-off). The 0035 root-cause -+// sweep proved the W4A16 kernel is occupancy/latency bound (~10% MMA util): every prefill step it -+// re-derives a per-block STRIDED weight gather (blk = We + (blockCol+r)*Kb + kt) and streams the -+// 4-bit weights with scattered 4-BYTE cp.async transactions (BN*8 per K-block) + a per-step ue4m3 -+// scale decode, plus a SEPARATE global f32->bf16 activation pre-pass kernel. vLLM's Marlin wins by -+// REPACKING the weights ONCE at load into an MMA-ready, tile-contiguous 4-bit layout so the GEMM -+// loop issues only cheap COALESCED 16-byte loads and the per-lane reads are already in fragment -+// order - the loop does ZERO dequant work beyond the cheap in-register nibble->bf16 unpack. -+// -+// What this mode does, faithful to vLLM Marlin: -+// (1) OFFLINE REPACK (one-time, at first engage; cached in a persistent cudaMalloc device buffer -+// keyed by src0->data): transpose the NVFP4 expert weights [N rows][Kb blocks](d[4]+qs[32]) -+// into two tile-contiguous planes, Q=[expert][Kb][N][32 qs bytes] and S=[expert][Kb][N][4 -+// ue4m3 scale bytes]. STAYS 4-BIT: qs kept packed, scales kept ue4m3 - the repacked buffer is -+// byte-for-byte the SAME size as the source NVFP4 weights (36 B / 64 weights). No int8 -+// expansion, no 2x memory (vs the rejected patch-0044 int8-Marlin which doubled the weight -+// bytes). Persists across all prefill steps. -+// (2) KERNEL reads the pre-packed tiles DIRECTLY: per (K-block, N-block) the BN rows' qs/scales -+// are contiguous, so the smem fill is a flat COALESCED 16-byte cp.async stream (no strided -+// gather, no address math). Activations are cast f32->bf16 IN REGISTER on the load into smem -+// (no separate global act pre-pass, no act-quant). The inner loop does only: ldmatrix the -+// bf16 A, in-register nibble->bf16 unpack of the weight (byte_perm/PRMT, LUT-free) scaled by -+// the in-register ue4m3 decode, and bf16 m16n8k16 mma.sync into f32 accumulators. cp.async -+// multistage pipelined, ragged per-tile expert offset (reuses 0035's tile map). NO separate -+// smem-staged dequant pass, NO __syncthreads-gated dequant pass. -+// -+// The repacked smem contents are bit-identical to the 0035 W4A16 smem, so the inner fragment math -+// (proven 81/81 on test-backend-ops MUL_MAT_ID) is unchanged and the output is bit-identical to the -+// 0035 W4A16 path; only the global->smem load path (coalesced) and the act cast (in-register) change. -+// -+// Toggle: LLAMA_W4A16_REPACK=1 (default 0 == OFF). Engages only inside the already-default-off -+// 0035 W4A16 grouped path (LLAMA_W4A16_PREFILL_M>0). Stock / decode / non-NVFP4 byte-untouched. -+ -+// True iff LLAMA_W4A16_REPACK != 0. -+bool ggml_cuda_w4a16_repack_enabled(); -+ -+// Total bytes the persistent repack cache currently holds on device (for the weight-memory-delta -+// report). The repacked layout stays 4-bit, so this equals the source NVFP4 bytes of the engaged -+// expert weight tensors; 0 when the toggle is OFF (no repack happens). -+size_t ggml_cuda_w4a16_repack_cache_bytes(); -+ -+// Offline-repack W4A16 grouped MoE GEMM over the token-sorted buffer. Same contract as -+// ggml_cuda_mul_mat_id_w4a16_grouped (see w4a16-gemm.cuh) but src1_sorted is the RAW f32 sorted -+// activations (cast to bf16 in-register; no global bf16 pre-pass needed) and the weights are read -+// from the cached repacked layout (built one-time from src0). -+void ggml_cuda_mul_mat_id_w4a16_repack( -+ ggml_backend_cuda_context & ctx, -+ const ggml_tensor * src0, -+ const float * src1_sorted, -+ float * dst_sorted, -+ const int * tokens_per_expert, -+ int64_t n_experts, int64_t K, int64_t N, -+ cudaStream_t stream); --- -2.43.0 - diff --git a/backend/cpp/llama-cpp-localai-paged/patches/paged/0046-paged-gate-GDN-prefill-geometry-by-scan-length.patch b/backend/cpp/llama-cpp-localai-paged/patches/paged/0046-paged-gate-GDN-prefill-geometry-by-scan-length.patch new file mode 100644 index 000000000..caaee575e --- /dev/null +++ b/backend/cpp/llama-cpp-localai-paged/patches/paged/0046-paged-gate-GDN-prefill-geometry-by-scan-length.patch @@ -0,0 +1,64 @@ +From 85266d4c10750b419716e4b8939ebd96ab424630 Mon Sep 17 00:00:00 2001 +From: Ettore Di Giacinto +Date: Tue, 30 Jun 2026 00:51:26 +0000 +Subject: [PATCH] feat(paged): gate GDN prefill geometry by scan length (patch + 0046) + +Patch 0022 retuned the gated-DeltaNet (GDN) sequential-recurrence dispatch +(case 128) to a (NUM_WARPS=16, COLS_PER_WARP=8) column-fold tile. That is a +DECODE win (short scans: small n_tokens, large n_seqs) but an UNCONDITIONAL +dense-prefill regression vs stock: on a long sequential scan the launch grid.z +collapses from S_v/4=32 to S_v/(16*8)=1, so the SMs starve. Profiling the +dense-prefill path attributed the whole regression (~-6%) to gated_delta_net +(+54% GPU time) at the (16,8) geometry. + +Gate the geometry by per-call scan length instead of applying (16,8) +unconditionally. Long scans (prefill, n_tokens >= GDN_PREFILL_NTOK, default 256) +take stock's high-grid.z (4,1) geometry; short scans (decode) keep the (16,8) +retune. This recovers dense prefill +7.2% back to stock parity while preserving +the (16,8) decode win. + +Bit-exact: patch 0022 proved every selectable {NUM_WARPS, COLS_PER_WARP} variant +is byte-identical (the sweep cannot change the md5), so this scan-length gate is +greedy-md5 bit-exact. GDN_PREFILL_NTOK tunes the crossover; the explicit +GDN_NW / GDN_CPW one-build %peak sweep still wins (the gate yields when either is +set), so the A/B harness is unchanged. + +Root cause: patch 0022 applied the (16,8) tile unconditionally. This patch +sequences after 0022/0044 (it edits the same gated_delta_net.cu case-128 +dispatch) and adds only the scan-length gate. + +Assisted-by: Claude:opus-4.8 [Claude Code] +Signed-off-by: Ettore Di Giacinto + +diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu +index 7121d807f..26667afa2 100644 +--- a/ggml/src/ggml-cuda/gated_delta_net.cu ++++ b/ggml/src/ggml-cuda/gated_delta_net.cu +@@ -550,6 +550,23 @@ static void launch_gated_delta_net( + launch_gdn_variant<64, KDA, keep_rs_t, 4, 1, 2>(GDN_LAUNCH_ARGS); + break; + case 128: { ++ // Dense-prefill regression fix: gate patch 0022's column-fold geometry by per-call scan ++ // length. The (16,8) tile is a DECODE win (short scans: n_tokens small, n_seqs large) but a ++ // long-sequential-scan PREFILL loss - grid.z collapses from S_v/4=32 to S_v/(16*8)=1, so the ++ // SMs starve on the long scan (profiled: gated_delta_net +54% GPU time == the whole dense- ++ // prefill regression). Long scans (prefill) take stock's high-grid.z (4,1) geometry; short ++ // scans (decode) keep the (16,8) winner. Every {NW,CPW} variant is byte-identical (patch 0022 ++ // proved md5-invariance across the ladder), so this stays greedy-md5 bit-exact. Default-on; ++ // GDN_PREFILL_NTOK tunes the crossover; the explicit GDN_NW/GDN_CPW sweep still wins (gate ++ // yields when either is set) so the one-build %peak A/B harness is unchanged. ++ static const int64_t gdn_prefill_ntok = ++ []{ const char * e = getenv("GDN_PREFILL_NTOK"); return e ? (int64_t) atoll(e) : (int64_t) 256; }(); ++ static const bool gdn_nw_forced = (getenv("GDN_NW") != nullptr); ++ static const bool gdn_cpw_forced = (getenv("GDN_CPW") != nullptr); ++ if (n_tokens >= gdn_prefill_ntok && !gdn_nw_forced && !gdn_cpw_forced) { ++ launch_gdn_variant<128, KDA, keep_rs_t, 4, 1, 2>(GDN_LAUNCH_ARGS); ++ break; ++ } + // Bit-exact occupancy/coalescing retune (patch 0022): fold COLS_PER_WARP columns per warp + // to raise per-warp memory-level parallelism on this bandwidth-bound recurrence. Default is + // the measured winner; GDN_NW / GDN_CPW override it for the one-build %peak sweep (every +-- +2.43.0 + diff --git a/backend/cpp/llama-cpp-localai-paged/patches/paged/0047-paged-GDN-M5-tensor-core-chunked-scan-f32.patch b/backend/cpp/llama-cpp-localai-paged/patches/paged/0047-paged-GDN-M5-tensor-core-chunked-scan-f32.patch new file mode 100644 index 000000000..fb4622ba8 --- /dev/null +++ b/backend/cpp/llama-cpp-localai-paged/patches/paged/0047-paged-GDN-M5-tensor-core-chunked-scan-f32.patch @@ -0,0 +1,713 @@ +From 2c32ab8b7a6c5bc90454881b8c10f8bad4f7cee0 Mon Sep 17 00:00:00 2001 +From: Ettore Di Giacinto +Date: Tue, 30 Jun 2026 09:45:13 +0200 +Subject: [PATCH] feat(paged): GDN M5 tensor-core chunked-scan prefill, + f32-only re-port (was patch 0044) + +Re-port the M5 tensor-core chunked gated-DeltaNet (GDN) prefill kernel from the +bf16/hybrid dev tree as an f32-only native commit, recovering the prefill win +that patch 0044 encoded, on the f32-only series (0026 ssm_bf16_tau dropped). + +What landed (f32/tf32 only): +- The mma.sync m16n8k8 helpers (tf32 + 3xtf32 limb-split; decays/gamma/beta stay + f32 outside the mma to preserve the bounded de-gating). +- gated_delta_net_chunked_cuda: the full tensor-core chunked scan, + KK/QK Gram (M2), KS/QS state-boundary 3xtf32 (M3), P*U output (M4), and the + form-T (A^-1) solve + Kc^T*DU state-update (M5). Selected by GDN_TC (0=serial + .. 4/5+=M5); the C=16 chunk-state stays in the 64KB smem buffer. +- Default-on under paged KV: GDN_TC=5, GDN_CHUNK_MIN=64 when LLAMA_KV_PAGED is set + and the user has not overridden either; OFF (INT_MAX) otherwise so the stock / + non-paged default is regression-free. GDN_CHUNK_MIN must stay > 1 (decode is 1 + token/call; at 1 the chunked path swallows decode and collapses S_TG). + +Stripped (not part of the f32-only series): the STATE_BF16 / HYBRID / gdn_state_t +/ gdn_hybrid_args template machinery (from dropped patch 0026), and the bf16 +CONFIG-C (M8) plus register-resident M6/M7 occupancy variants. The 0046 dense- +prefill geometry gate is untouched and coexists (it gates the SERIAL path; M5 is +the chunked path). + +Gates (GB10, sm_121a): +- Builds clean. +- Greedy md5 bit-exact (per-path, n=48 --temp 0 --seed 1, paged): dense + q36-27b-nvfp4 = 5951a5b4d624ce891e22ab5fca9bc439, MoE q36-35b-a3b-nvfp4 = + 8cb0ce23777bf55f92f63d0292c756b0, both default AND force-M5 (GDN_CHUNK_MIN=1). + test-backend-ops GATED_DELTA_NET 46/46 default and force-M5 (incl. the + multi-chunk, tail-chunk and multi-seq shapes). +- Prefill S_PP, MoE, LLAMA_KV_PAGED=1 LLAMA_MOE_FORCE_GRAPHS=1, -ntg 4 -npl 32, + vs the patch-0044 baseline (pre-0046, GDN_PREFILL_NTOK huge): +4.3% @512, + +17.8% @2048 (reproduces patch 0044; M5-on absolute matches patch 0044 M5). + vs the current 0046 baseline (0046 already raised the long-scan sequential + prefill): +4.3% @512, +1.2% @2048. +- Decode S_TG unchanged (within run noise). + +Assisted-by: Claude:opus-4.8 [Claude Code] +Signed-off-by: Ettore Di Giacinto +--- + ggml/src/ggml-cuda/gated_delta_net.cu | 550 +++++++++++++++++++++++--- + tests/test-backend-ops.cpp | 5 + + 2 files changed, 496 insertions(+), 59 deletions(-) + +diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu +index 26667afa2..0ceb1bc8f 100644 +--- a/ggml/src/ggml-cuda/gated_delta_net.cu ++++ b/ggml/src/ggml-cuda/gated_delta_net.cu +@@ -298,7 +298,115 @@ static void launch_gdn_variant( + // strong-decay tokens underflow to the correct zero rather than to inf. The math + // is equivalent to the sequential recurrence up to FP reduction order (a NEW + // per-path result, validated benign by test-backend-ops NMSE and greedy output). +-template ++// --- Phase-1 tensor-core Gram helpers (tf32 m16n8k8 mma.sync; sm_80+/sm_121a). --- ++// Reproduces the PoC-proven path (~/scratch_tc_gdn_poc/gdn_gram_bench.cu, tf32 NMSE ~3e-9): ++// out[rowbase..+15][colbase..+7] = Xs[rows] . Ys[cols], Xs/Ys row-major [*][DK]. ++__device__ __forceinline__ unsigned gdn_f2tf32(float f) { ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 ++ unsigned r; ++ asm("cvt.rna.tf32.f32 %0, %1;" : "=r"(r) : "f"(f)); ++ return r; ++#else ++ (void) f; ++ return 0u; ++#endif ++} ++ ++// Operand loaders for the Gram/state mma helpers: stage f32 operands as tf32. This f32-only ++// re-port keeps every operand full-width -- the plain-tf32 path (10-bit mantissa, f32 accumulate) ++// is the highest-precision tensor-core option on sm_121a, and the 3xtf32 limb-split helpers below ++// recover near-f32 accuracy for the decay-coupled state-boundary (KS/QS) and state-carry products ++// whose error feeds the A-inverse solve / compounds across chunks. ++__device__ __forceinline__ unsigned gdn_ld_tf32(float f) { return gdn_f2tf32(f); } ++__device__ __forceinline__ float gdn_ld_f32 (float f) { return f; } ++ ++__device__ __forceinline__ void gdn_mma_m16n8k8(float c[4], const unsigned a[4], const unsigned b[2]) { ++#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 ++ asm volatile( ++ "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " ++ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n" ++ : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3]) ++ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1])); ++#else ++ (void) c; (void) a; (void) b; ++#endif ++} ++ ++template ++__device__ __forceinline__ void gdn_gram_tile_mma( ++ float c[4], const TX * __restrict__ Xs, const TY * __restrict__ Ys, ++ int rowbase, int colbase, int lg, int lt) { ++ c[0] = c[1] = c[2] = c[3] = 0.0f; ++ #pragma unroll ++ for (int ks = 0; ks < DK; ks += 8) { ++ unsigned a[4], b[2]; ++ a[0] = gdn_ld_tf32(Xs[(rowbase + lg ) * DK + ks + lt ]); ++ a[1] = gdn_ld_tf32(Xs[(rowbase + lg + 8) * DK + ks + lt ]); ++ a[2] = gdn_ld_tf32(Xs[(rowbase + lg ) * DK + ks + lt + 4]); ++ a[3] = gdn_ld_tf32(Xs[(rowbase + lg + 8) * DK + ks + lt + 4]); ++ b[0] = gdn_ld_tf32(Ys[(colbase + lg ) * DK + ks + lt ]); ++ b[1] = gdn_ld_tf32(Ys[(colbase + lg ) * DK + ks + lt + 4]); ++ gdn_mma_m16n8k8(c, a, b); ++ } ++} ++ ++// 3xtf32 (CUTLASS fp32-emulation): split each f32 operand into hi/lo tf32 limbs and run ++// 3 limb-products per k-subtile (hi*hi + hi*lo + lo*hi); ~f32 accuracy at ~3x the mma count. ++// Used for the state-boundary products (KS/QS) whose error feeds the A-inverse solve (M3). ++template ++__device__ __forceinline__ void gdn_gram_tile_mma_3x( ++ float c[4], const TX * __restrict__ Xs, const TY * __restrict__ Ys, ++ int rowbase, int colbase, int lg, int lt) { ++ c[0] = c[1] = c[2] = c[3] = 0.0f; ++ #pragma unroll ++ for (int ks = 0; ks < DK; ks += 8) { ++ float af[4], bf[2]; ++ af[0] = gdn_ld_f32(Xs[(rowbase + lg ) * DK + ks + lt ]); ++ af[1] = gdn_ld_f32(Xs[(rowbase + lg + 8) * DK + ks + lt ]); ++ af[2] = gdn_ld_f32(Xs[(rowbase + lg ) * DK + ks + lt + 4]); ++ af[3] = gdn_ld_f32(Xs[(rowbase + lg + 8) * DK + ks + lt + 4]); ++ bf[0] = gdn_ld_f32(Ys[(colbase + lg ) * DK + ks + lt ]); ++ bf[1] = gdn_ld_f32(Ys[(colbase + lg ) * DK + ks + lt + 4]); ++ unsigned ahi[4], alo[4], bhi[2], blo[2]; ++ #pragma unroll ++ for (int z = 0; z < 4; z++) { ahi[z] = gdn_f2tf32(af[z]); alo[z] = gdn_f2tf32(af[z] - __uint_as_float(ahi[z])); } ++ #pragma unroll ++ for (int z = 0; z < 2; z++) { bhi[z] = gdn_f2tf32(bf[z]); blo[z] = gdn_f2tf32(bf[z] - __uint_as_float(bhi[z])); } ++ gdn_mma_m16n8k8(c, ahi, bhi); // hi*hi (dominant limb) ++ gdn_mma_m16n8k8(c, ahi, blo); // hi*lo ++ gdn_mma_m16n8k8(c, alo, bhi); // lo*hi ++ } ++} ++ ++// State-update tile (P6): S_C[i][j] += sum_t Kc[t][i] * DU[t][j], with Kc read TRANSPOSED ++// (i as the m16n8k8 M-row, t as the K-contraction) and DU = d(t,last)*U staged in the Ud ++// layout (DUd[j*KC + t]). 3xtf32: the cross-chunk carry compounds over every chunk step. ++template ++__device__ __forceinline__ void gdn_state_tile_mma_3x( ++ float c[4], const TK * __restrict__ Kc, const TD * __restrict__ DUd, ++ int rowbase, int colbase, int lg, int lt) { ++ c[0] = c[1] = c[2] = c[3] = 0.0f; ++ #pragma unroll ++ for (int ks = 0; ks < KC; ks += 8) { ++ float af[4], bf[2]; ++ af[0] = gdn_ld_f32(Kc[(ks + lt ) * DK + (rowbase + lg )]); ++ af[1] = gdn_ld_f32(Kc[(ks + lt ) * DK + (rowbase + lg + 8)]); ++ af[2] = gdn_ld_f32(Kc[(ks + lt + 4) * DK + (rowbase + lg )]); ++ af[3] = gdn_ld_f32(Kc[(ks + lt + 4) * DK + (rowbase + lg + 8)]); ++ bf[0] = gdn_ld_f32(DUd[(colbase + lg) * KC + (ks + lt )]); ++ bf[1] = gdn_ld_f32(DUd[(colbase + lg) * KC + (ks + lt + 4)]); ++ unsigned ahi[4], alo[4], bhi[2], blo[2]; ++ #pragma unroll ++ for (int z = 0; z < 4; z++) { ahi[z] = gdn_f2tf32(af[z]); alo[z] = gdn_f2tf32(af[z] - __uint_as_float(ahi[z])); } ++ #pragma unroll ++ for (int z = 0; z < 2; z++) { bhi[z] = gdn_f2tf32(bf[z]); blo[z] = gdn_f2tf32(bf[z] - __uint_as_float(bhi[z])); } ++ gdn_mma_m16n8k8(c, ahi, bhi); ++ gdn_mma_m16n8k8(c, ahi, blo); ++ gdn_mma_m16n8k8(c, alo, bhi); ++ } ++} ++ ++template + __global__ void gated_delta_net_chunked_cuda( + const float * __restrict__ q, const float * __restrict__ k, + const float * __restrict__ v, const float * __restrict__ g, +@@ -329,6 +437,9 @@ __global__ void gated_delta_net_chunked_cuda( + float * csh = Amat + (size_t) C * C; // [C] cumsum(log-gate) + float * gam = csh + C; // [C] gamma_t = exp(cs_t) + float * bet = gam + C; // [C] beta_t ++ // Phase-1 tensor-core Gram scratch (allocated only when GRAM_MMA; KK feeds A, QK feeds P). ++ float * KKsh = bet + C; // [C*C] KK[t][t'] = k_t . k_t' (stride C) ++ float * QKsh = KKsh + (size_t) C * C; // [C*C] QK[t][t'] = q_t . k_t' (stride C) + + // S0: thread j owns column j (Sd[j*dk + i]); load is a contiguous per-thread copy from the + // M-layout cache view (read_state[j*dk + i] = M[j*S_v + i] = S[i][j]). Same identity/gather +@@ -357,6 +468,15 @@ __global__ void gated_delta_net_chunked_cuda( + Kc[t * dk + i] = k_base[(c0 + t) * sq2 + i]; + Qc[t * dk + i] = q_base[(c0 + t) * sq2 + i]; + } ++ if constexpr (TC >= 3) { ++ // Zero the stale K/Q tail (rows t >= Cc): the tensor-core mma paths contract the full ++ // chunk dim and 0*NaN (uninitialized smem) would poison the result. Serial paths only ++ // touch t < Cc, so this is gated to the mma levels. ++ for (int e = Cc * dk + j; e < C * dk; e += dv) { ++ Kc[e] = 0.0f; ++ Qc[e] = 0.0f; ++ } ++ } + if (j < Cc) { + csh[j] = g[gb_base + (c0 + j) * sb2]; // raw log-gate, prefix-summed below + bet[j] = beta[gb_base + (c0 + j) * sb2]; +@@ -372,15 +492,53 @@ __global__ void gated_delta_net_chunked_cuda( + } + __syncthreads(); + ++ // --- Phase-1: tensor-core tf32 Gram products (KK->A via warp0, QK->P via warp1). --- ++ // Full C x C tiles into KKsh/QKsh (stride C); decay/beta applied in f32 in the loops below. ++ // Tail chunks (Cc= Cc, but those entries are never read. ++ if constexpr (TC >= 1) { ++ const int w = threadIdx.x >> 5; // warp: 0 -> KK, 1 -> QK ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; // 0..7 ++ const int lt = lane & 3; // 0..3 ++ if (w < 2) { ++ const float * Xs = (w == 0) ? Kc : Qc; ++ float * Out = (w == 0) ? KKsh : QKsh; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nt = 0; nt < (C + 7) / 8; nt++) { ++ const int colbase = nt * 8; ++ float cc[4]; ++ gdn_gram_tile_mma(cc, Xs, Kc, rowbase, colbase, lg, lt); ++ const int rr[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int ccol[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ if (rr[l] < C && ccol[l] < C) { ++ Out[rr[l] * C + ccol[l]] = cc[l]; ++ } ++ } ++ } ++ } ++ } ++ __syncthreads(); ++ } ++ + // --- A = I + tril(beta_t * d(t',t) * (k_t . k_t'), -1) (cooperative over C*C) --- + for (int e = j; e < Cc * Cc; e += dv) { + const int t = e / Cc; + const int tp = e % Cc; + float a = 0.0f; + if (tp < t) { +- float kk = 0.0f; +- for (int i = 0; i < dk; i++) { +- kk += Kc[t * dk + i] * Kc[tp * dk + i]; ++ float kk; ++ if constexpr (TC >= 1) { ++ kk = KKsh[t * C + tp]; ++ } else { ++ kk = 0.0f; ++ for (int i = 0; i < dk; i++) { ++ kk += Kc[t * dk + i] * Kc[tp * dk + i]; ++ } + } + const float dd = expf(csh[t] - csh[tp]); // d(tp,t) = gamma_t/gamma_tp + a = bet[t] * dd * kk; +@@ -392,65 +550,304 @@ __global__ void gated_delta_net_chunked_cuda( + __syncthreads(); + + // --- RHS[t][j] = beta_t (v_t[j] - gamma_t * (S0^T k_t)[j]) -> Ud[j*C + t] --- +- for (int t = 0; t < Cc; t++) { +- float ks = 0.0f; // (S0^T k_t)[j] = sum_i S[i][j] k_t[i] +- for (int i = 0; i < dk; i++) { +- ks += Sd[j * dk + i] * Kc[t * dk + i]; ++ if constexpr (TC >= 2) { ++ // M3: fused tensor-core KS = Kc * S0 (3xtf32 state-boundary product). The mma ++ // output is consumed straight from registers into RHS -> Ud, so NO extra C*dv ++ // smem buffer is needed (the 64KB state still occupies smem until M6). Warp w ++ // owns dv n-tiles [w*NTPW, ..); each lane writes the RHS entries it produced. ++ const int w = threadIdx.x >> 5; ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; ++ const int lt = lane & 3; ++ constexpr int NWARP = S_v / 32; ++ constexpr int NT = dv / 8; ++ constexpr int NTPW = (NT + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt >= NT) break; ++ const int colbase = nt * 8; ++ float cc[4]; ++ gdn_gram_tile_mma_3x(cc, Kc, Sd, rowbase, colbase, lg, lt); ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = jj[l]; ++ if (t < Cc && jc < dv) { ++ const float vtj = v_base[(c0 + t) * sv2 + jc]; ++ Ud[jc * C + t] = bet[t] * (vtj - gam[t] * cc[l]); ++ } ++ } ++ } ++ } ++ __syncthreads(); // RHS written cross-thread -> publish before the per-column solve ++ } else { ++ for (int t = 0; t < Cc; t++) { ++ float ks = 0.0f; // (S0^T k_t)[j] = sum_i S[i][j] k_t[i] ++ for (int i = 0; i < dk; i++) { ++ ks += Sd[j * dk + i] * Kc[t * dk + i]; ++ } ++ const float vtj = v_base[(c0 + t) * sv2 + j]; ++ Ud[j * C + t] = bet[t] * (vtj - gam[t] * ks); + } +- const float vtj = v_base[(c0 + t) * sv2 + j]; +- Ud[j * C + t] = bet[t] * (vtj - gam[t] * ks); ++ } ++ if constexpr (TC >= 3) { ++ // Zero the stale RHS tail (rows t >= Cc) before the full-K mma consumers (P*U at TC>=3; ++ // apply + state at TC>=4). Without this the masked tail terms compute 0*NaN = NaN. ++ for (int t = Cc; t < C; t++) Ud[j * C + t] = 0.0f; ++ __syncthreads(); + } + +- // --- solve A U = RHS in place (unit lower-tri fwd subst); per-thread, no inter-step sync --- +- for (int t = 1; t < Cc; t++) { +- float acc = Ud[j * C + t]; +- for (int tp = 0; tp < t; tp++) { +- acc -= Amat[t * Cc + tp] * Ud[j * C + tp]; ++ // --- solve A U = RHS (A unit-lower-tri) --- ++ if constexpr (TC >= 4) { ++ // M5/P7: form T = A^{-1} explicitly (FLA UT transform), then U = T*RHS as one ++ // dependency-free tf32 GEMM. At C<=16 A is a single b=16 block, so the off-diagonal ++ // Phase-O is empty; only the f32 in-shared diagonal inverse (Phase-D) + the wide ++ // apply remain. Phase-D: column-parallel EXACT f32 inverse of the Cc x Cc unit- ++ // lower-tri A -- thread c solves A x = e_c, writing column c of T into KKsh (free ++ // since KK was consumed into A). This is the strong-coupling amplifier -> f32. ++ if (j < C) { ++ if (j < Cc) { ++ float x[C]; ++ #pragma unroll ++ for (int r = 0; r < C; r++) x[r] = 0.0f; ++ x[j] = 1.0f; ++ for (int r = j + 1; r < Cc; r++) { ++ float acc = 0.0f; ++ for (int m = j; m < r; m++) acc += Amat[r * Cc + m] * x[m]; ++ x[r] = -acc; ++ } ++ #pragma unroll ++ for (int r = 0; r < C; r++) KKsh[r * C + j] = x[r]; // rows >= Cc are 0 ++ } else { ++ #pragma unroll ++ for (int r = 0; r < C; r++) KKsh[r * C + j] = 0.0f; // cols >= Cc are 0 ++ } ++ } ++ __syncthreads(); ++ // Apply U = T*RHS, M=C N=dv K=C; T=KKsh (stride C), RHS=Ud (stride C). In place on ++ // Ud: hold every output tile in registers, sync to finish the RHS reads, then ++ // overwrite Ud with U (avoids the read/write aliasing of a same-buffer GEMM). ++ { ++ const int w = threadIdx.x >> 5; ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; ++ const int lt = lane & 3; ++ constexpr int NWARP = S_v / 32; ++ constexpr int NT = dv / 8; ++ constexpr int NTPW = (NT + NWARP - 1) / NWARP; ++ float ureg[NTPW][4]; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt < NT) gdn_gram_tile_mma(ureg[nn], KKsh, Ud, 0, nt * 8, lg, lt); ++ } ++ __syncthreads(); // all RHS(Ud) reads done before overwriting with U ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt >= NT) continue; ++ const int colbase = nt * 8; ++ const int tt[4] = {lg, lg, lg + 8, lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = jj[l]; ++ if (t < Cc && jc < dv) Ud[jc * C + t] = ureg[nn][l]; ++ } ++ } ++ __syncthreads(); + } +- Ud[j * C + t] = acc; ++ } else { ++ for (int t = 1; t < Cc; t++) { ++ float acc = Ud[j * C + t]; ++ for (int tp = 0; tp < t; tp++) { ++ acc -= Amat[t * Cc + tp] * Ud[j * C + tp]; ++ } ++ Ud[j * C + t] = acc; ++ } ++ __syncthreads(); // U finalized; Amat free for P below + } +- __syncthreads(); // U finalized; Amat free for P below (and Ud read across-thread? no, own col) + +- // --- P[t][t'] = d(t',t) * (q_t . k_t') for t' <= t (reuse Amat) --- +- for (int e = j; e < Cc * Cc; e += dv) { +- const int t = e / Cc; +- const int tp = e % Cc; +- float p = 0.0f; +- if (tp <= t) { +- float qk = 0.0f; +- for (int i = 0; i < dk; i++) { +- qk += Qc[t * dk + i] * Kc[tp * dk + i]; ++ // --- P[t][t'] = d(t',t) * (q_t . k_t') for t' <= t --- ++ if constexpr (TC >= 3) { ++ // M4: build P (lower-tri, decay pre-baked in f32 -> bounded) IN PLACE in QKsh at ++ // fixed stride C so the P*U output mma can read it as a tf32 A-operand. Full C*C ++ // grid: upper-tri / out-of-range entries are zeroed so the K=C mma needs no masking. ++ for (int e = j; e < C * C; e += dv) { ++ const int t = e / C; ++ const int tp = e % C; ++ float p = 0.0f; ++ if (tp <= t && t < Cc && tp < Cc) { ++ const float dd = expf(csh[t] - csh[tp]); ++ p = dd * QKsh[t * C + tp]; // QKsh holds QK (M2); overwrite in place with P ++ } ++ QKsh[t * C + tp] = p; ++ } ++ __syncthreads(); ++ } else { ++ for (int e = j; e < Cc * Cc; e += dv) { ++ const int t = e / Cc; ++ const int tp = e % Cc; ++ float p = 0.0f; ++ if (tp <= t) { ++ float qk; ++ if constexpr (TC >= 1) { ++ qk = QKsh[t * C + tp]; ++ } else { ++ qk = 0.0f; ++ for (int i = 0; i < dk; i++) { ++ qk += Qc[t * dk + i] * Kc[tp * dk + i]; ++ } ++ } ++ const float dd = expf(csh[t] - csh[tp]); ++ p = dd * qk; + } +- const float dd = expf(csh[t] - csh[tp]); +- p = dd * qk; ++ Amat[t * Cc + tp] = p; + } +- Amat[t * Cc + tp] = p; ++ __syncthreads(); + } +- __syncthreads(); + + // --- O[t][j] = gamma_t (S0^T q_t)[j] + sum_{t'<=t} P[t][t'] U[t'][j] (* scale) --- +- for (int t = 0; t < Cc; t++) { +- float qs = 0.0f; // (S0^T q_t)[j] (uses pre-update S) +- for (int i = 0; i < dk; i++) { +- qs += Sd[j * dk + i] * Qc[t * dk + i]; ++ if constexpr (TC >= 2) { ++ // M3: fused tensor-core QS = Qc * S0 (3xtf32, pre-update S0). Deposit the ++ // gamma_t*QS[t][j] cross-chunk term into dst from the mma registers; the O loop ++ // below reads it back (published via __syncthreads) and adds the intra-chunk P*U. ++ const int w = threadIdx.x >> 5; ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; ++ const int lt = lane & 3; ++ constexpr int NWARP = S_v / 32; ++ constexpr int NT = dv / 8; ++ constexpr int NTPW = (NT + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt >= NT) break; ++ const int colbase = nt * 8; ++ float cc[4]; ++ gdn_gram_tile_mma_3x(cc, Qc, Sd, rowbase, colbase, lg, lt); ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = jj[l]; ++ if (t < Cc && jc < dv) { ++ attn_base[(c0 + t) * S_v * H + jc] = gam[t] * cc[l]; ++ } ++ } ++ } + } +- float o = gam[t] * qs; +- for (int tp = 0; tp <= t; tp++) { +- o += Amat[t * Cc + tp] * Ud[j * C + tp]; ++ __syncthreads(); ++ } ++ if constexpr (TC >= 3) { ++ // M4: O += P*U via tensor-core (tf32-safe: P is f32-bounded, decay pre-baked). ++ // GEMM O[t][j] += sum_t' P[t][t']*U[t'][j], M=C N=dv K=C; P=QKsh (stride C), ++ // U=Ud (stride C). The gamma_t*QS cross-chunk term was deposited into dst above; ++ // fold it in here then * scale. Warp w owns dv n-tiles [w*NTPW, ..). ++ const int w = threadIdx.x >> 5; ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; ++ const int lt = lane & 3; ++ constexpr int NWARP = S_v / 32; ++ constexpr int NT = dv / 8; ++ constexpr int NTPW = (NT + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int mt = 0; mt < (C + 15) / 16; mt++) { ++ const int rowbase = mt * 16; ++ #pragma unroll ++ for (int nn = 0; nn < NTPW; nn++) { ++ const int nt = w * NTPW + nn; ++ if (nt >= NT) break; ++ const int colbase = nt * 8; ++ float cc[4]; ++ gdn_gram_tile_mma(cc, QKsh, Ud, rowbase, colbase, lg, lt); ++ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int t = tt[l], jc = jj[l]; ++ if (t < Cc && jc < dv) { ++ const int64_t oi = (int64_t)(c0 + t) * S_v * H + jc; ++ attn_base[oi] = (attn_base[oi] + cc[l]) * scale; // QS term + P*U ++ } ++ } ++ } ++ } ++ } else { ++ for (int t = 0; t < Cc; t++) { ++ float o; ++ if constexpr (TC >= 2) { ++ o = attn_base[(c0 + t) * S_v * H + j]; // gamma_t*QS[t][j] deposited above ++ } else { ++ float qs = 0.0f; // (S0^T q_t)[j] (uses pre-update S) ++ for (int i = 0; i < dk; i++) { ++ qs += Sd[j * dk + i] * Qc[t * dk + i]; ++ } ++ o = gam[t] * qs; ++ } ++ for (int tp = 0; tp <= t; tp++) { ++ o += Amat[t * Cc + tp] * Ud[j * C + tp]; ++ } ++ attn_base[(c0 + t) * S_v * H + j] = o * scale; + } +- attn_base[(c0 + t) * S_v * H + j] = o * scale; + } + + // --- S_C[i][j] = gamma_{C-1} S[i][j] + sum_t d(t,C-1) k_t[i] u_t[j] --- + const float glast = gam[Cc - 1]; + const float cslast = csh[Cc - 1]; +- for (int i = 0; i < dk; i++) { +- float s = glast * Sd[j * dk + i]; +- for (int t = 0; t < Cc; t++) { +- const float dd = expf(cslast - csh[t]); // d(t, last) +- s += dd * Kc[t * dk + i] * Ud[j * C + t]; ++ if constexpr (TC >= 4) { ++ // M5/P6: state carry S_C = glast*S0 + Kc^T * DU via 3xtf32 mma. DU[t][j] = ++ // d(t,last)*U[t][j] is built IN PLACE in Ud (t>=Cc zeroed so the K=C contraction ++ // needs no per-k masking), then S_C accumulates over the chunk dim t. Kc is read ++ // transposed (i as M-row). M=dk N=dv K=C. Each output (i,j) has a unique owner so ++ // the glast*S0 read-modify-write is race-free. ++ for (int t = 0; t < C; t++) { ++ const float dd = (t < Cc) ? expf(cslast - csh[t]) : 0.0f; ++ Ud[j * C + t] = dd * Ud[j * C + t]; // thread j owns column j -> DU in place ++ } ++ __syncthreads(); ++ const int w = threadIdx.x >> 5; ++ const int lane = threadIdx.x & 31; ++ const int lg = lane >> 2; ++ const int lt = lane & 3; ++ constexpr int NWARP = S_v / 32; ++ constexpr int MT = dk / 16; // m-tiles over dk ++ constexpr int NT = dv / 8; // n-tiles over dv ++ constexpr int NTILES = MT * NT; ++ constexpr int TPW = (NTILES + NWARP - 1) / NWARP; ++ #pragma unroll ++ for (int idx = 0; idx < TPW; idx++) { ++ const int tile = w * TPW + idx; ++ if (tile >= NTILES) break; ++ const int rowbase = (tile / NT) * 16; ++ const int colbase = (tile % NT) * 8; ++ float cc[4]; ++ gdn_state_tile_mma_3x(cc, Kc, Ud, rowbase, colbase, lg, lt); ++ const int ii[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8}; ++ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1}; ++ #pragma unroll ++ for (int l = 0; l < 4; l++) { ++ const int i = ii[l], jc = jj[l]; ++ Sd[jc * dk + i] = glast * Sd[jc * dk + i] + cc[l]; ++ } ++ } ++ } else { ++ for (int i = 0; i < dk; i++) { ++ float s = glast * Sd[j * dk + i]; ++ for (int t = 0; t < Cc; t++) { ++ const float dd = expf(cslast - csh[t]); // d(t, last) ++ s += dd * Kc[t * dk + i] * Ud[j * C + t]; ++ } ++ Sd[j * dk + i] = s; + } +- Sd[j * dk + i] = s; + } + __syncthreads(); // Sd reused as S0 of next chunk; Kc/Qc/Amat reloaded next chunk + } +@@ -464,8 +861,7 @@ __global__ void gated_delta_net_chunked_cuda( + st[j * dk + i] = Sd[j * dk + i]; + } + } +- +-template ++template + static void launch_gdn_chunked( + const float * q_d, const float * k_d, const float * v_d, + const float * g_d, const float * b_d, const float * s_d, +@@ -477,10 +873,11 @@ static void launch_gdn_chunked( + const uint3 neqk1_magic, const uint3 rq3_magic, + float scale, cudaStream_t stream) { + const size_t smem = ((size_t) S_v * S_v + (size_t) 2 * C * S_v + (size_t) S_v * C +- + (size_t) C * C + (size_t) 3 * C) * sizeof(float); ++ + (size_t) C * C + (size_t) 3 * C ++ + (TC >= 1 ? (size_t) 2 * C * C : (size_t) 0)) * sizeof(float); + static bool attr_set = false; + if (!attr_set) { +- const cudaError_t e = cudaFuncSetAttribute(gated_delta_net_chunked_cuda, ++ const cudaError_t e = cudaFuncSetAttribute(gated_delta_net_chunked_cuda, + cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smem); + if (e != cudaSuccess) { + GGML_ABORT("gdn chunked: cudaFuncSetAttribute(maxDynSmem=%zu) failed: %s\n", smem, cudaGetErrorString(e)); +@@ -489,7 +886,7 @@ static void launch_gdn_chunked( + } + dim3 grid_dims(H, n_seqs, 1); + dim3 block_dims(S_v, 1, 1); +- gated_delta_net_chunked_cuda<<>>( ++ gated_delta_net_chunked_cuda<<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, + sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, + neqk1_magic, rq3_magic, scale, state_dst_d, ids_d, rs_head); +@@ -519,17 +916,52 @@ static void launch_gated_delta_net( + // sequential recurrence. Mathematically equivalent up to FP reduction order (NEW per-path md5; + // validated benign by test-backend-ops NMSE + greedy output). Toggle: GDN_CHUNK_OFF / GDN_CHUNK_MIN. + if constexpr (!KDA && !keep_rs_t) { +- // OPT-IN: this chunked path is bit-exact-benign (test-backend-ops green) but, at C=16 +- // (forced by GB10 99KB dyn-smem opt-in, all-shared), it is NOT yet faster than the tuned +- // sequential recurrence on this model (measured ~22%% slower S_PP, grid-starved at low +- // n_seqs + 1 block/SM occupancy). Default OFF so the backend default is regression-free; +- // enable for experiments / tuning with GDN_CHUNK_MIN=. See README section 5 (dev notes / rejected-flat levers). +- static const int gdn_chunk_min = []{ const char * e = getenv("GDN_CHUNK_MIN"); return e ? atoi(e) : INT_MAX; }(); ++ // DEFAULT-ON UNDER PAGED KV (f32-only re-port of patch 0044's M5). The M5 tensor-core path ++ // (GDN_TC=5: full-TC form-T solve + state-update mma, state in the 64KB smem buffer, C=16) ++ // is greedy-bit-exact (per-path md5 == the sequential canonical on the short gate prompt) ++ // and *beats* the tuned sequential recurrence on the Qwen3.6 MoE prefill: GB10, ++ // q36-35b-a3b-nvfp4, LLAMA_KV_PAGED=1 LLAMA_MOE_FORCE_GRAPHS=1, -ntg 4 -npl 32: ++ // -npp 512 : S_PP +3.5% ; -npp 2048 : S_PP +17.7% (more chunks to parallelize). ++ // Decode S_TG is unchanged (1-token calls never reach the engage threshold). ++ // GDN_CHUNK_MIN is the per-call engage threshold and MUST stay > 1: decode is 1 token/call, ++ // so any threshold above 1 leaves every decode step on the sequential recurrence (at ++ // GDN_CHUNK_MIN=1 the chunked path swallows decode and collapses S_TG by ~25%). Tuned to 64: ++ // above decode/tiny-call sizes, below the real MoE-prefill per-call count. OFF (INT_MAX) when ++ // not paged, so the stock / non-paged default is regression-free. Both knobs env-overridable. ++ static const bool kv_paged = (getenv("LLAMA_KV_PAGED") != nullptr); ++ static const int gdn_chunk_min = []{ ++ const char * e = getenv("GDN_CHUNK_MIN"); ++ if (e) return atoi(e); ++ return kv_paged ? 64 : INT_MAX; ++ }(); ++ // Tensor-core level selector (single build, clean runtime A/B). GDN_TC: ++ // 0 = serial scan (patch 0031); 1 = KK/QK Gram mma (M2); ++ // 2 = + KS/QS state-boundary mma, 3xtf32 (M3); 3 = + P*U output mma (M4); ++ // 4/5+ = M5 (full TC: form-T solve + state-update mma) - the DEFAULT under paged KV. ++ // (The bf16 CONFIG-C and register-resident M6/M7/M8 occupancy variants of patch 0044 are ++ // intentionally absent from this f32-only series; the +3.5/+17.7% prefill win is the M5 path.) ++ // GDN_GRAM_MMA=1 is kept as an alias for level 1. ++ static const int gdn_tc = []{ ++ const char * e = getenv("GDN_TC"); ++ if (e) return atoi(e); ++ const char * gm = getenv("GDN_GRAM_MMA"); ++ if (gm && atoi(gm) != 0) return 1; ++ return kv_paged ? 5 : 0; ++ }(); + if (S_v == 128 && n_tokens >= gdn_chunk_min) { +- launch_gdn_chunked<128, 16>( +- q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, +- H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, +- neqk1_magic, rq3_magic, scale, stream); ++#define GDN_CHUNKED_LAUNCH(TC_) \ ++ launch_gdn_chunked<128, 16, TC_>( \ ++ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, \ ++ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, \ ++ neqk1_magic, rq3_magic, scale, stream) ++ switch (gdn_tc) { ++ case 0: GDN_CHUNKED_LAUNCH(0); break; ++ case 1: GDN_CHUNKED_LAUNCH(1); break; ++ case 2: GDN_CHUNKED_LAUNCH(2); break; ++ case 3: GDN_CHUNKED_LAUNCH(3); break; ++ default: GDN_CHUNKED_LAUNCH(4); break; // GDN_TC >= 4 -> M5 (full TC, kernel TC=4) ++ } ++#undef GDN_CHUNKED_LAUNCH + return; + } + } +diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp +index 4e40d2353..817069860 100644 +--- a/tests/test-backend-ops.cpp ++++ b/tests/test-backend-ops.cpp +@@ -9372,6 +9372,11 @@ static std::vector> make_test_cases_eval() { + } + + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1)); ++ // Tensor-core chunked-GDN prefill path (S_v==128): multi-chunk (C=16) coverage, ++ // incl. a tail chunk (100 = 6*16+4) and multi-seq. Exercised via GDN_CHUNK_MIN + GDN_TC. ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1)); ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 100, 1)); ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 128, 2)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true)); +-- +2.43.0 +