docs(paged): scope tensor-core (mma) chunked GDN prefill kernel

Scopes the follow-up recorded by patch 0031 + README section 5: replace the
serial per-thread reductions of the chunked gated-DeltaNet prefill scan with
mma.sync tensor-core matmuls and lift the 1-block/SM occupancy ceiling, the
path that would beat the tuned sequential scan and close the GDN prefill
bucket toward vLLM's ~2.5x-cheaper chunked scan.

Confirmed (not assumed) the GB10/sm_121a tensor-core reality: consumer
Blackwell (SM12x) has NO wgmma (Hopper-only) and NO tcgen05/TMEM (sm_100a
data-center only); the usable path is the extended mma.sync family. So the
kernel is a warp-synchronous mma.sync + cp.async design (reusing ggml's
mma.cuh tiles), not a wgmma/TMA/tcgen05 design - patch 0031's 'mma/wgmma'
shorthand reads as mma only on this part.

Design: register-resident state frees the 64KB that forced C=16, admitting
C=64 under the 99KB shared opt-in; tf32 inputs / f32 accumulate with a 3xtf32
precision ladder; decays/gamma/beta stay f32 outside the mma to preserve the
bounded de-gating; A-inverse via blocked forward substitution (FLA UT
transform) with mma off-diagonal coupling. Mechanism: chunking cuts state-BW
~Cx, mma absorbs the O(C^2) intra-chunk flops the serial 0031 could not.
Honest: multi-week, high risk, no vendor kernel to route to on sm_121; gains
beat the sequential scan and close most of the bucket but not full sm_100-class
parity. KL-gate binding (NMSE likely fails at reduced precision). Phased:
re-profile -> two-product PoC -> full intra-chunk + C=64 + reg-state ->
occupancy/cp.async; opt-in default-OFF until A/B-proven.

Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-06-28 17:23:51 +00:00
parent 9a28f23134
commit 4bdd26a7f0

View File

@@ -0,0 +1,362 @@
# TENSORCORE_GDN_SCOPE - tensor-core chunked gated-DeltaNet prefill (design only)
**Status: DESIGN + SCOPE ONLY. No kernel written, no GPU run, no PTX in this pass.**
This scopes the follow-up recorded by patch 0031 and README section 5: a
tensor-core (`mma`) chunked gated-DeltaNet (GDN) prefill kernel - the path that
would actually *beat* the tuned sequential scan and close the GDN prefill bucket
toward vLLM. vLLM's chunked GDN scan was measured ~2.5x cheaper in the prefill
ground-truth precisely because it pushes the intra-chunk products through
tensor-core matmuls; patch 0031 proved the chunking math but, with serial
per-thread reductions at the GB10-forced `C=16`, came out ~22% *slower* than the
sequential recurrence. This document scopes replacing those reductions with
`mma.sync` matmuls and lifting the occupancy ceiling.
> **Read patch 0031 + README section 5 first.** The bounded/stable de-gating form
> (pairwise decays `d <= 1`, `gamma <= 1`), the per-path bit-exact precedent, and
> the honest negative ("C=16 all-shared -> 1 block/SM -> serial reductions -> 22%
> slower, grid-starved at low n_seqs") are the starting point. This doc does not
> re-derive the algebra; it maps it onto tensor cores.
> **Regime note (the mechanism, read this).** The sequential scan is
> **bandwidth-bound**: it re-streams the entire `128x128` f32 state (64KB) once
> *per token*. README section 5 already records it runs at ~84.7% of GB10 peak BW
> (decode) and the recurrence is a llama *win* vs vLLM's BW. So a tensor-core
> kernel does **not** win by doing the same work faster - it wins by **changing
> the work**: chunking by `C` reads/writes the state `n_tokens/C` times instead of
> `n_tokens` (a ~`C`x cut in state traffic, the dominant prefill GDN cost), and the
> price is `O(C^2)` extra intra-chunk dot-products per chunk. The naive 0031 paid
> that price in serial f32 reductions, which cost *more* than the BW it saved -
> hence 22% slower. **Tensor cores make the added intra-chunk flops nearly free,
> so the BW saving becomes a net win.** That is exactly why vLLM's chunked scan is
> 2.5x cheaper. The whole lever rests on this trade; if a GPU re-profile shows
> prefill GDN is *not* state-BW-bound, stop and re-scope (step 0 below).
---
## 1. GB10 tensor-core reality (sm_121a) - confirmed, not assumed
GB10 / DGX Spark reports **compute capability 12.1 (sm_121)**, CUDA 13 (README
section "Hardware: GB10 / DGX Spark (CUDA 13, sm_121)"). sm_121a is **consumer
Blackwell** (the SM12x family, same tensor-core programming model as RTX 50 /
sm_120), **not** data-center Blackwell (sm_100a / GB200). This distinction is the
single most important input to the design and is confirmed from sources, not
assumed:
- **No `wgmma`.** Warp-group MMA is Hopper (sm_90a) only; targeting SM12x yields
`ptxas error: Instruction 'wgmma.fence' not supported on .target 'sm_120'`.
Do **not** design around Hopper-style warp-group MMA.
- **No `tcgen05` / no TMEM.** SM12x lacks the Tensor Memory hardware entirely, so
the autonomous 5th-gen tensor-core path (`tcgen05.mma`, the sm_100a data-center
instruction) is unavailable. This is the same wall that makes vLLM/CUTLASS fall
back to Marlin and gate FP4 to sm_100a on GB10 (tracked in CUTLASS #2800/#2947,
vLLM #43906). We cannot use it either.
- **What sm_121a DOES have: extended `mma.sync`.** The Ampere/Ada warp-level
`mma.sync` family, extended with the Blackwell numeric formats (FP8/FP6/FP4).
"Consumer Blackwell put new data types on top of the oldest programming model."
For our operands (q/k/v/state are f32 in the op, see below) the usable tiles are
the standard warp-level ones:
- **bf16/f16 inputs, f32 accumulate:** `mma.sync.aligned.m16n8k16` (and
`m16n8k8`). 7-bit (bf16) / 10-bit (f16) input mantissa.
- **tf32 inputs, f32 accumulate:** `m16n8k8` / `m16n8k4`. 10-bit input mantissa
- the **highest-precision tensor-core option** on this part, and the one this
design defaults to (the GDN is decay-sensitive; see section 4).
- FP8 (`m16n8k32`) / FP4 (`m16n8k64.kind::mxf4nvf4`, block-scaled) compile on
sm_121a but are **out of scope** here - the GDN q/k/v/state are not 4/8-bit.
- **`cp.async` is available** (Ampere+), so global->shared double-buffering of the
K/Q chunk tiles is on the table for the occupancy phase. There is **no TMA** on
SM12x; staging is plain `cp.async`, not `cp.async.bulk`.
**Reuse, do not hand-roll PTX.** ggml already ships a warp-level MMA tile
abstraction at `ggml/src/ggml-cuda/mma.cuh` (the `tile<M,N,T>` fragments +
`mma()` used by the FlashAttention-mma and MMQ kernels), and it already routes
through `turing_mma_available(cc)` / `ampere_mma_available(cc)` - i.e. it is
sm_121-correct today. Build the GDN matmuls on that API (bf16/half/tf32 fragments,
f32 accumulators), not on raw `asm volatile("mma.sync...")`. This de-risks the
kernel and keeps it consistent with the backend's other tensor-core paths.
**Bottom line for the design:** the kernel is a **warp-synchronous `mma.sync`**
kernel (Ampere-class programming model with Blackwell silicon), *not* a
warp-group / TMA / tcgen05 kernel. Every "wgmma"/"tcgen05" idea from FLA's
sm_90/sm_100 kernels must be down-translated to `mma.sync` + `cp.async`. Patch
0031's and README's shorthand "mma/wgmma" should be read as **mma only** on this
part.
---
## 2. Mapping the chunked GDN matmuls onto `mma.sync`
The chunked gated-delta-rule (patch 0031 header) has six dot-product families.
Five are plain matmuls and map cleanly to `mma`; the sixth (the A-inverse) is a
unit-lower-triangular solve and is the one subtle case. Notation: `C` = chunk
length, `dk = dv = S_v = 128` (GDN head dim), per `(head, seq)` block.
| # | Product (0031 step) | Shape | mma form | Notes |
|---|---|---|---|---|
| 1 | `KK[t,t'] = k_t . k_t'` (for `A`) | `C x C` over `k=dk=128` | `(C x dk) x (dk x C)` | Gram matrix; only strict-lower triangle used. Decay `d(t',t)` + `beta_t` applied **after** mma in f32. |
| 2 | `QK[t,t'] = q_t . k_t'` (for `P`/`O`) | `C x C` over `k=dk` | `(C x dk) x (dk x C)` | Lower triangle (`t' <= t`); decay applied after in f32. |
| 3 | `KS[t,j] = (S0^T k_t)[j]` | `C x dv` over `k=dk` | `(C x dk) x (dk x dv)` | `S0` is the chunk-entry state (stationary operand). Feeds RHS of the solve. |
| 4 | `QS[t,j] = (S0^T q_t)[j]` | `C x dv` over `k=dk` | `(C x dk) x (dk x dv)` | The `gamma_t` cross-chunk term of `O`. |
| 5 | `O += P . U` | `C x dv` over `k=C` | `(C x C) x (C x dv)` | `P` (decay-masked `QK`) times the solved `U`. |
| 6 | `S_C += K^T (D .* U)` | `dk x dv` over `k=C` | `(dk x C) x (C x dv)` | The state update; `D` = `diag(d(t,last))` applied to `U` in f32 first. |
| 7 | `U = A^{-1} RHS` | `C x C` solve, `C x dv` RHS | blocked fwd-subst (see below) | The only non-GEMM. |
**Critical precision invariant (preserve the bounded de-gating).** Every decay
(`gamma_t`, `d(t',t) = exp(cs_t - cs_t')`) and every `beta_t` stays in **f32** and
is applied as an elementwise scale **before/after** the mma, never inside it. The
mma only ever multiplies the raw, unweighted dot-products (`k.k`, `q.k`,
`S0^T k`, `S0^T q`, `P.U`, `K^T U`). This keeps the strong-decay underflow-to-zero
behaviour (the adversarial `g in [-20, -1e-4]` op test) exactly as 0031 has it -
the numerically delicate part never touches reduced precision. This is the
discipline that makes a tf32/bf16 mma kernel safe for a decay-sensitive op.
### The A-inverse (step 7) - it CAN be tensor-core'd
`A = I + N`, `N = tril(beta_t d(t',t) k_t.k_t', -1)` is **strictly lower
triangular**, hence **nilpotent** (`N^C = 0`). Two routes, both better than 0031's
serial per-thread forward substitution:
- **Blocked forward substitution (RECOMMENDED, this is the FLA "UT transform").**
Partition `C` into sub-blocks of `b` (e.g. `b = 16`, one mma `m`-tile). Invert
each `b x b` diagonal block in registers (it is unit-lower-triangular `b x b`,
cheap: a short serial solve or the finite Neumann series on a `b`-nilpotent,
`<= b-1` terms), then propagate to the off-diagonal sub-blocks with **mma**
(the inter-block coupling `U_i -= A_ij U_j` is exactly a `(b x b) x (b x dv)`
matmul). For `C = 64, b = 16` that is 4 tiny in-register diagonal solves + a
triangular sweep of mma updates - the bulk of the solve is on tensor cores, only
the `16 x 16` diagonals stay scalar.
- **Neumann/Newton-Schulz inverse (fallback).** `A^{-1} = I - N + N^2 - ... ` is
finite (`C` terms) but `O(C)` mma's of `C x C`; Newton-Schulz
(`X <- X(2I - AX)`) converges in `~log2(C)` steps for the nilpotent part. Cheap
in flops, but more numerically exposed than blocked subst for adversarial decays.
Keep as a fallback if blocked subst's register pressure hurts occupancy.
Verdict: **blocked forward substitution** - it keeps the sensitive diagonal solve
exact-in-registers and tensor-core's only the well-conditioned off-diagonal
coupling. This is precisely the structure FLA/vLLM use, down-translated to `mma`.
### Tile/chunk design that fits the 99KB shared budget AND feeds the mma
The 0031 failure was a layout failure: the all-shared `128x128` f32 state (64KB)
crowded out everything and forced `C=16`. The fix is to get the state **out of the
bulk shared footprint**. Two complementary mechanisms:
1. **State register-resident across the chunk loop (the key move).** `S` only
participates at chunk boundaries (steps 3,4 at entry; step 6 at exit). Keep it
as **mma accumulator fragments distributed across the block's warps** (each
warp owns a `dk x dv` sub-tile of `S`), persisting in registers across the
sequential chunk loop. Steps 3/4 read `S` as the stationary mma operand; step 6
accumulates into it. This **frees the entire 64KB** - shared then holds only the
per-chunk K/Q/U/A tiles. (The chunked algorithm's whole point is that the heavy
work is intra-chunk and state-free, so the state need not be in shared.)
2. **dv-slab tiling for occupancy (the secondary move).** If register pressure
from a register-resident `128x128` state caps the kernel at 1 block/SM (likely
- that is a lot of accumulator registers), split the `dv=128` value dimension
into slabs (`dv_tile in {64, 32}`). Each warp-group owns a `128 x dv_tile`
state slab. `A` and the solve depend only on `K` (not `dv`), so they are
computed once and the `C x C` `A^{-1}` is **broadcast/recomputed** per slab
(cheap once it is mma'd). This shrinks per-block register/shared pressure and is
the lever for >1 block/SM.
**Shared budget at `C = 64` (state register-resident), staging K/Q as bf16/tf32:**
| Buffer | Elems | Bytes |
|---|---|---|
| `Kc` (chunk K) | `C x dk = 64x128` | 16KB (bf16) |
| `Qc` (chunk Q) | `C x dk` | 16KB (bf16) |
| `Uc` (solved U) | `C x dv = 64x128` | 32KB (f32 for the solve) / 16KB (bf16 for the P.U + K^T U mma) |
| `A`/`P` scratch | `C x C = 64x64` | 16KB (f32) |
| gates `cs/gam/beta` | `~3C` | <1KB |
| **state** | (registers) | **0KB shared** |
| **Total** | | **~64-80KB** (under the 99KB opt-in) |
So **`C = 64` fits the 99KB budget once the state is register-resident** - 4x the
0031 chunk, and a natural multiple of the `m16n8k*` tiles. For >1 block/SM, drop
to `C = 32` + bf16-staged U (`8 + 8 + 16 + 4 = 36KB`, two blocks fit under the
~49.5KB/block needed) and/or dv-slab the state. **Recommended default: `C = 64`,
tf32 mma, state register-resident** (maximize the BW-saving `C` first; chase the
second block/SM only if the bench says occupancy, not BW, is the residual).
---
## 3. Occupancy plan (break the 1 block/SM ceiling)
0031 is pinned to 1 block/SM by the 64KB shared state. The plan, in priority order:
1. **Free the 64KB: state register-resident** (section 2). This alone may not give
2 blocks/SM (the register-distributed `128x128` f32 accumulator is heavy), but
it is the precondition for everything and it lets `C` grow to 64 - which is the
dominant win (`C`x less state BW). Even at 1 block/SM, `C=64` + mma should flip
the sign vs 0031.
2. **dv-slab the state** (`dv_tile = 64` then `32`): halve/quarter the per-block
accumulator-register and shared pressure to admit a 2nd resident block, at the
cost of recomputing the `C x C` `A^{-1}` per slab (cheap on mma). This is the
primary occupancy lever once (1) is in.
3. **`cp.async` double-buffer the K/Q chunk loads**: overlap the next chunk's
global->shared staging with the current chunk's mma, hiding LPDDR5x latency that
1-2 blocks/SM cannot. No TMA on sm_121, so plain `cp.async` (`commit_group` /
`wait_group`), Ampere-style.
4. **Grid starvation at low `n_seqs`** (0031's other failure: grid is `H x n_seqs`,
~few hundred blocks): the larger `C` reduces per-block serial chunk steps, and
dv-slabbing **multiplies the grid by the slab count** (`H x n_seqs x n_slabs`),
directly mitigating the low-`n_seqs` starvation that hurt 0031.
Honest occupancy caveat: a register-resident `128x128` f32 state is a large
register commitment; the realistic outcome is **1-2 blocks/SM**, not high
occupancy. The design leans on **mma throughput + cp.async latency hiding + the
`C`x BW cut**, not on many resident blocks, to win. If profiling shows the kernel
register-capped at 1 block/SM *and* tensor-core-active-% still low, that is the
signal to dv-slab harder (smaller `dv_tile`) or accept the achieved win.
---
## 4. Bit-exactness + precision risk
This is a **NEW FP path on top of a NEW FP path**. 0031 is already not byte-equal
to the sequential recurrence (different reduction order; README s5 records it as a
benign per-path result). Adding tf32/bf16 mma is a *further* reduced-precision
step. Gate it exactly like the backend's other new-FP-path precedents
(`PAGED_BITEXACT_NOTE.md`, the paged-MoE `8cb0ce23`, the PREFILL_GEMM scope):
- **Greedy md5 stability** on the standard prompt (README s5 harness) - to catch
*unexpected* divergence on the non-prefill paths (decode must stay on the tuned
sequential kernel and byte-match its reference; this lever is prefill-only and
opt-in, so the default path is untouched).
- **`test-backend-ops GATED_DELTA_NET`** at the 0031 prefill shapes (the
`S_v=128` exact-multiple / tail / multi-seq / GQA / permuted cases), CUDA0 vs the
CPU f32 oracle. **Honest expectation: bf16 mma will likely NOT clear the 1e-7
NMSE gate; tf32 is borderline.** So the binding gate is the **KL-gate**, not
strict NMSE: require `KLD(tensorcore || f16) <= KLD(sequential || f16)` and PPL
within the established band, recorded in `PAGED_BITEXACT_NOTE.md`. tf32 (10-bit
mantissa, f32 accumulate) is the precision default precisely to give the KL-gate
the best chance.
- **Precision fallback ladder if tf32 fails the KL-gate:** (i) **3xtf32**
emulation (split each f32 operand into 3 tf32 limbs, 3 mma's, recombine - the
CUTLASS fp32-emulation trick; near-f32 accuracy at 3x the mma cost, still far
cheaper than serial f32 loops and still a likely net win given the `C`x BW cut);
(ii) keep the **decay-coupled and state-boundary products in 3xtf32/f32** while
the well-conditioned intra-chunk Gram products use plain tf32 (mixed precision by
sensitivity). Do **not** fall back to bf16 for the decay-sensitive terms.
- **Preserve the bounded de-gating (section 2):** decays/`gamma`/`beta` stay f32,
applied outside the mma. Re-run the adversarial `g in [-20, -1e-4]` op case
specifically; a tensor-core kernel that moved a decay inside the mma would be a
silent precision regression even if the benign cases pass.
The likely-favourable framing (as in PREFILL_GEMM): keeping the heavy reductions
in f32-accumulate tensor cores is *more* precise than a naive f32 serial loop only
if the inputs stay full-width; here inputs are down-cast (tf32/bf16), so this is a
genuine precision *trade*, not a free win - hence the KL-gate is mandatory and the
3xtf32 ladder exists. Treat NMSE-gate-pass as a bonus, KL-gate-pass as the bar.
---
## 5. Honest effort + expected gain
**This is a multi-week GPU kernel project, not a routing change.** Unlike the
PREFILL_GEMM dense lever (a dispatch flip onto an existing vendor kernel), there is
no vendor chunked-GDN kernel to route to on sm_121 (CUTLASS/FLA gate the good
paths to sm_100a; that is the whole reason vLLM falls back to Marlin on GB10). We
must write the `mma` kernel ourselves. Realistic estimate: **4-8 weeks** of
focused kernel work, high risk, with non-trivial probability the occupancy/register
wall caps the win.
**Expected gain (mechanism-grounded, section 0/regime-note):** the lever attacks
the state-BW that dominates sequential GDN prefill by `~C`x (chunking) while
tensor cores absorb the `O(C^2)` intra-chunk flops. Fully realized, it targets
vLLM's ~2.5x-cheaper chunked GDN prefill bucket = the ~17% prefill lever the
ground-truth attributes to GDN. It should also help the serial-SSM portion of the
**decode** residual (README names the irreducible "serial-SSM host loop" as part
of the decode floor; a chunked state-update reduces the per-step state traffic
there too, though decode `n_tokens` is small so the prefill regime is where it
pays). **Honest ceiling:** sm_121 has no wgmma/tcgen05, so we cannot match a
hypothetical sm_100a FLA kernel's throughput - the `mma.sync` path is the Ampere-
class programming model on Blackwell silicon. But `mma` over serial f32 reductions
is an order-of-magnitude flop-rate jump, which is more than enough to flip 0031's
-22% into a win and recover most of the GDN prefill bucket. Do not promise full
parity with vLLM's sm_100-class kernels; promise "beats the sequential scan and
closes most of the GDN prefill gap."
**Risk register:**
- Register-resident `128x128` state may cap occupancy at 1 block/SM (section 3) -
mitigated by dv-slabbing, but slabbing recomputes `A^{-1}` per slab.
- tf32 may miss the KL-gate -> 3xtf32 ladder (3x mma cost) -> thinner margin.
- The win is contingent on prefill GDN being state-BW-bound (regime note); a GPU
re-profile that says otherwise kills the lever (step 0).
- Blocked-forward-subst register pressure trades against state-register pressure;
both compete for the same budget on a 1-block/SM kernel.
---
## 6. Phased build plan
Smallest tensor-core proof-of-concept first, bit-exact/KL-gate + A/B bench at every
phase, per `.agents/vllm-parity-methodology.md` (one lever at a time, record
rejected/flat variants, ground-truth both engines).
### Phase 0 - re-confirm the regime on GPU (NO code)
nsys a **prefill-only** window (`llama-batched-bench -npp <large> -ntg 0/1`,
exclude graph capture) on q36-27b-nvfp4 + q36-35b-a3b, at the backend pin, with
`GDN_CHUNK_MIN` set so 0031 runs. Confirm (a) the GDN prefill bucket is
state-BW-bound (state memcpy/recurrence dominates, tensor-core-active-% low), and
(b) it is ~17% of the prefill step / ~2.5x vLLM's. **If prefill GDN is not
state-BW-bound, stop and re-scope** - the entire mechanism (section 0) fails.
### Phase 1 - PoC: tensor-core just TWO products, same occupancy
Keep 0031's `C=16` all-shared layout and 1 block/SM. Replace **only** the two
cleanest `C x C` Gram products - step 1 (`KK` for `A`) and step 2 (`QK` for `P`) -
with `ggml/src/ggml-cuda/mma.cuh` tf32 tiles (decays still applied in f32 after).
Leave the solve, the `S0` products, and the state update serial. This is the
minimal "do tensor cores help here at all" probe at fixed occupancy.
- Gate: greedy md5 stable; `test-backend-ops GATED_DELTA_NET` prefill shapes via
the KL-gate (NMSE if it passes).
- Bench: `llama-batched-bench` S_PP, A/B vs sequential and vs 0031-serial, same
harness. **If even this does not move S_PP, the head-dim/occupancy is the wall,
not the reductions - learn it cheaply before the big build.**
### Phase 2 - full intra-chunk tensor-core + register-resident state + C=64
State register-resident (free the 64KB), `C=64`, tf32 mma for all of steps 1-6,
blocked-forward-subst `A^{-1}` (step 7) with mma off-diagonal coupling +
in-register `16x16` diagonal solves. Decays/gamma/beta stay f32 throughout.
- Gate: as Phase 1, plus the adversarial `g in [-20,-1e-4]` op case explicitly.
If tf32 misses the KL-gate, climb the 3xtf32 ladder (section 4).
- Bench: S_PP A/B vs sequential, sweep prefill length and `npl`; record the
`C in {32,64,128}` sweep and any rejected `C`.
### Phase 3 - occupancy + latency hiding
dv-slab the state (`dv_tile in {64,32}`) for a 2nd resident block and to multiply
the grid (fix low-`n_seqs` starvation); `cp.async` double-buffer the K/Q chunk
loads. Tune `C`, `dv_tile`, warp count per the bench.
- Gate: unchanged (the FP path does not change; this is scheduling).
- Bench: final S_PP vs sequential + indicative % of vLLM prefill; name the
residual floor honestly (register-cap / sm_121-has-no-tcgen05).
### Disposition
Like 0031, ship **opt-in default-OFF first** (extend the existing `GDN_CHUNK_MIN`
gate, add a `GDN_CHUNK_TC` selector if the serial path is kept as fallback). Flip
the default only when a separately-built A/B proves S_PP beats the sequential scan
*and* the KL-gate holds, recorded in README section 5 + `PAGED_BITEXACT_NOTE.md`.
If a phase comes back flat-or-slower, record it as a rejected lever with the reason
(the most valuable output if it fails) and keep 0031's serial path as the shipped
prefill kernel.
---
## 7. Summary
| Aspect | Decision |
|---|---|
| Tensor-core ISA | **`mma.sync` only** (sm_121a: no wgmma, no tcgen05/TMEM - confirmed) |
| Building block | reuse `ggml/src/ggml-cuda/mma.cuh` tiles, not raw PTX |
| Precision default | **tf32** inputs / f32 accumulate; **3xtf32** ladder if KL-gate misses; bf16 only for well-conditioned Gram terms |
| Decay handling | gamma/d/beta stay **f32**, applied outside the mma (preserve bounded de-gating) |
| A-inverse | blocked forward substitution (FLA UT-transform): in-register diagonal solves + mma off-diagonal |
| Chunk size | **C=64** default (4x 0031), C=32 for 2 blocks/SM |
| State | **register-resident** (frees the 64KB that forced C=16); dv-slab for occupancy |
| Shared budget | ~64-80KB at C=64 state-register-resident (under the 99KB opt-in) |
| Mechanism / why it wins | chunking cuts state-BW by ~Cx; mma absorbs the O(C^2) intra-chunk flops the serial 0031 could not |
| Bit-exact | NEW per-path; **KL-gate** binding (NMSE likely fails at reduced precision), greedy md5 + adversarial-decay op case |
| Effort | **multi-week (4-8 wk), high risk**; no vendor kernel to route to on sm_121 |
| Expected gain | beats the sequential scan, closes most of the ~17% GDN prefill bucket toward vLLM's 2.5x; also helps the decode serial-SSM residual. NOT full sm_100-class parity. |
| Phasing | P0 re-profile -> P1 two-product PoC -> P2 full intra-chunk + C=64 + reg-state -> P3 occupancy/cp.async; opt-in default-OFF until A/B-proven |
Decode is untouched (this is prefill-only, opt-in); the stock `llama-cpp` backend
stays patch-free. This lever lives entirely in `llama-cpp-localai-paged`.