feat(w4a16): grow tile to BN128/16w (q4_K +17%, pp512 148->178)

P3b-2 for the Blackwell W4A16 Marlin GEMM. The q4_K dequant wall is partly
cross-N-block-redundant: every N-block re-decodes the same weight strip, so
halving the N-block count (BN 64->128) halves that redundant 6-bit superblock
decode. A BN sweep showed this only pays off when BN is spread across more
warps (16 warps, 8 m16n8 C-tiles/warp) rather than more fragments-per-warp -
the FN=8 / FM=4 variants (16 C-tiles/warp) regressed to ~6.6 TFLOPS on
register pressure. Shipping tile is now WM=4,WN=4,FM=2,FN=4 -> BM=128, BN=128,
16 warps.

Thermally-bracketed cold A/B (q4_K n=512 / q4_0 n=512 via test-backend-ops
perf; pp512/pp2048 via llama-bench Qwen3-32B-Q4_K_M):
  BN64/8w  (prev): 8.50 / 10.56 TFLOPS, measured 8.45/10.51 again (bracket)
  BN128/16w (this): 9.92 / 11.68 TFLOPS, pp512 177.6, pp2048 185.0
  -> +17% q4_K, +11% q4_0, +20% pp512 vs the previous commit; +49% pp512 vs
     the original block-tiled kernel (119).

Parity gate GGML_CUDA_W4A16=1 test-backend-ops MUL_MAT = 1103/1103, flag set
and unset (byte-identical when unset). Still ~4.7x under MMQ (47 TFLOPS) and
does NOT beat MMQ; BN growth divides the redundant decode but cannot remove
the per-k-step decode itself - the offline weight prepack remains the next
unlock for q4_K. Plan doc P3 table + bottleneck notes updated.

Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-06-21 02:01:12 +00:00
parent 2f648dc6a0
commit 2b79083b71
2 changed files with 30 additions and 20 deletions

View File

@@ -86,21 +86,24 @@ the `BN x 16` f32->bf16 activation strip into shared, one `__syncthreads`, then
(`W4A16_SPAD`): ldmatrix's per-lane address is `row*stride`, and the natural stride 8 (a divisor of the
32-bank / 128-byte cycle) collides rows 0,4,8,12 into a 2-way bank conflict; skewing to 12 (4-byte aligned, so
ldmatrix's 16-byte alignment holds) makes `{r*12 mod 32}` hit 8 distinct bank-quads for r in 0..7, so both
halves of ldmatrix are conflict-free at only +50% on the small (~6 KB) staged tile. Shipping config
`WM=4,WN=2,FM=2,FN=4` -> `BM=128, BN=64`, 8 warps. M/N tails zero-padded in-kernel; still gated to contiguous
2D Q4_0/Q4_K f32 prefill, else falls back to MMQ.
halves of ldmatrix are conflict-free at only +50% on the small staged tile (~12 KB at the shipping tile).
Shipping config `WM=4,WN=4,FM=2,FN=4` -> `BM=128, BN=128`, 16 warps, 8 m16n8 C-tiles per warp (keeping
register pressure low is what lets BN grow without an occupancy cliff). M/N tails zero-padded in-kernel; still
gated to contiguous 2D Q4_0/Q4_K f32 prefill, else falls back to MMQ.
**Per-step results (q4_K n=512 via `test-backend-ops perf`; pp512/pp2048 via llama-bench Qwen3-32B-Q4_K_M):**
| step | q4_K n=512 | q4_0 n=512 | pp512 | pp2048 | vs MMQ 47 / 718 | notes |
|---|---|---|---|---|---|---|
| P2 (1 warp/tile) | ~2 TFLOPS | - | 31.75 | - | 0.04x | correctness checkpoint |
| Step 1: block tiling (load_generic, BM64/4w) | 6.63 (cold) | 7.53 | 119 | 123 | 0.14x | prior committed kernel |
| **P3b: skew-pad ldmatrix + BM128/8w** | **8.52 (cold)** | **10.49** | **148.5** | **153.9** | **0.18x** | +28% q4_K, +40% q4_0, +25% pp512 over step 1 |
| Step 1: block tiling (load_generic, BM64/4w) | 6.63 (cold) | 7.53 | 119 | 123 | 0.14x | original committed kernel |
| P3b-1: skew-pad ldmatrix + BM128/8w | 8.50 (cold) | 10.56 | 148.5 | 153.9 | 0.18x | +28% q4_K, +40% q4_0 over step 1 |
| **P3b-2: + BN128/16w (current)** | **9.92 (cold)** | **11.68** | **177.6** | **185.0** | **0.21x** | +17% q4_K, +20% pp512 over P3b-1 (+49% pp512 over step 1) |
Parity gate **1103/1103** at every step, flag set and unset (byte-identical when unset). All P3b numbers above
are from a single thermally-bracketed cold A/B session (committed measured 6.63/7.53 immediately before AND
after the P3b kernel, identical both times -> the deltas are real, not thermal).
are from thermally-bracketed cold A/B sessions (committed measured immediately before AND after each candidate,
identical both times -> the deltas are real, not thermal). P3b-1 cold A/B: 6.63/7.53 vs 8.52/10.49. P3b-2 cold
A/B: BN64/8w 10.56/8.50 then 10.51/8.45 (bracket) vs BN128/16w 11.68/9.92.
**What landed / what was tried (honest):**
- **P3b - LANDED (committed).** Two combined changes lift the prior committed kernel: (1) **skew-pad
@@ -119,18 +122,25 @@ after the P3b kernel, identical both times -> the deltas are real, not thermal).
when the tile grew (BM128/8w). A 5-config tile sweep then split the two quant types:
- **q4_0 SCALES with warps/tiles** (7.7 -> 10.5 -> **15.8 TFLOPS at BM128/16w**): feed/global-traffic bound,
helped by cutting redundant activation re-reads (more BM = fewer M-blocks each re-reading the act column).
- **q4_K is now DEQUANT-COMPUTE bound** (stuck 6.7-8.5 across every tile; at 16 warps q4_0=15.8 but q4_K=6.8 -
they diverge hard). This **refines P3's "within 12%" finding**: that held only in the low-throughput memory
-bound regime; once the feed is unblocked, q4_K's per-element 6-bit superblock decode (`get_scale_min_k4` +
superblock indexing, redone every k-step AND re-done per N-block) becomes the wall. BM256 regressed both
(too few blocks / register pressure).
- **Next blocker (the real q4_K unlock) = offline prepack.** The dequant wall is cross-block-redundant: the same
q4_K weights are superblock-decoded by all 8 N-blocks. The fix is the **one-time offline repack** - decode the
Q4 tensor ONCE into a cached device buffer keyed off the tensor data pointer, in a layout with the scale/min
pre-applied (store reshuffled 4-bit + per-subblock bf16 d,m, ~1.25x the q4 size, NOT a full bf16 blow-up which
would be ~4x), so the in-kernel path becomes a cheap `q*d - m` with coalesced loads. Then `cp.async`
multi-stage (sized to NOT widen shared past the occupancy cliff) and **Stream-K** over M. These remain the
multi-week core; **prepack is the highest-value next step for q4_K specifically.**
- **q4_K is largely DEQUANT-COMPUTE bound** (the BM64/16w tile gives q4_0=15.8 but q4_K=6.8 - they diverge
hard). This **refines P3's "within 12%" finding**: that held only in the low-throughput memory-bound regime;
once the feed is unblocked, q4_K's per-element 6-bit superblock decode (`get_scale_min_k4` + superblock
indexing, redone every k-step AND re-done by every N-block) becomes the wall. BM256 regressed both (too few
blocks / register pressure).
- **Growing BN partly relieves the q4_K dequant wall (P3b-2).** Because every N-block re-decodes the same
weight strip, halving the N-block count (BN 64->128) halves that redundant q4_K decode - but only when BN is
spread across MORE WARPS (16w, 8 C-tiles/warp), not more fragments-per-warp: the FN=8 / FM=4 variants (16
C-tiles/warp) regressed to ~6.6 on register pressure, while WM=4,WN=4,FM=2,FN=4 (16w, 8 tiles/warp) lifted
q4_K 8.5->9.9 and q4_0 10.6->11.7 cold. BN=256 was no better and costs more shared. **BN128/16w is the
shipping tile.**
- **Next blocker (the remaining q4_K unlock) = offline prepack.** BN growth only divides the redundant decode by
the N-block count; it cannot remove the per-k-step decode itself. The full fix is the **one-time offline
repack** - decode the Q4 tensor ONCE into a cached device buffer keyed off the tensor data pointer, in a layout
with the scale/min pre-applied (store reshuffled 4-bit + per-subblock bf16 d,m, ~1.25x the q4 size, NOT a full
bf16 blow-up which would be ~4x), so the in-kernel path becomes a cheap `q*d - m` with coalesced loads. Then
`cp.async` multi-stage (sized to NOT widen shared past the occupancy cliff) and **Stream-K** over M. These
remain the multi-week core; **prepack is the highest-value next step for q4_K specifically** (it should let
q4_K join q4_0 on the feed-bound scaling curve instead of plateauing at ~10).
- **Methodology note (unchanged):** the box thermally throttles under sustained perf+bench runs (identical code
~8.8 cold vs ~6.6 hot earlier), so only same-session A/Bs are trustworthy. The P3b deltas above were taken in
one bracketed cold session for exactly this reason.

View File

@@ -239,7 +239,7 @@ bool ggml_cuda_w4a16_mul_mat(
cudaStream_t stream = ctx.stream();
// Block tile config: WM*WN warps compute BM(=WM*FM*16) x BN(=WN*FN*8).
constexpr int WM = 4, WN = 2, FM = 2, FN = 4; // BM=128, BN=64, 8 warps
constexpr int WM = 4, WN = 4, FM = 2, FN = 4; // BM=128, BN=128, 16 warps
constexpr int BM = WM*FM*16;
constexpr int BN = WN*FN*8;
const dim3 grid((unsigned)((M + BM - 1) / BM), (unsigned)((N + BN - 1) / BN), 1);