From 2b79083b71ec2c9dd476c46c9f2607471a1fbcb9 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Sun, 21 Jun 2026 02:01:12 +0000
Subject: [PATCH] feat(w4a16): grow tile to BN128/16w (q4_K +17%, pp512 148->178)

P3b-2 for the Blackwell W4A16 Marlin GEMM. The q4_K dequant wall is partly
cross-N-block-redundant: every N-block re-decodes the same weight strip, so
halving the N-block count (BN 64->128) halves that redundant 6-bit superblock
decode. A BN sweep showed this only pays off when BN is spread across more
warps (16 warps, 8 m16n8 C-tiles/warp) rather than more fragments-per-warp -
the FN=8 / FM=4 variants (16 C-tiles/warp) regressed to ~6.6 TFLOPS on
register pressure.

Shipping tile is now WM=4,WN=4,FM=2,FN=4 -> BM=128, BN=128, 16 warps.

Thermally-bracketed cold A/B (q4_K n=512 / q4_0 n=512 via test-backend-ops
perf; pp512/pp2048 via llama-bench Qwen3-32B-Q4_K_M):
  BN64/8w (prev):   8.50 / 10.56 TFLOPS, measured 8.45/10.51 again (bracket)
  BN128/16w (this): 9.92 / 11.68 TFLOPS, pp512 177.6, pp2048 185.0
  -> +17% q4_K, +11% q4_0, +20% pp512 vs the previous commit; +49% pp512 vs
     the original block-tiled kernel (119).

Parity gate GGML_CUDA_W4A16=1 test-backend-ops MUL_MAT = 1103/1103, flag set
and unset (byte-identical when unset).

Still ~4.7x under MMQ (47 TFLOPS) and does NOT beat MMQ; BN growth divides the
redundant decode but cannot remove the per-k-step decode itself - the offline
weight prepack remains the next unlock for q4_K. Plan doc P3 table +
bottleneck notes updated.
Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto
---
 .../paged/W4A16_MARLIN_KERNEL_PLAN.md         | 48 +++++++++++--------
 .../paged/kernel/w4a16/marlin-w4a16.cu        |  2 +-
 2 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/backend/cpp/llama-cpp/paged/W4A16_MARLIN_KERNEL_PLAN.md b/backend/cpp/llama-cpp/paged/W4A16_MARLIN_KERNEL_PLAN.md
index 5db0d18d2..e46cc6712 100644
--- a/backend/cpp/llama-cpp/paged/W4A16_MARLIN_KERNEL_PLAN.md
+++ b/backend/cpp/llama-cpp/paged/W4A16_MARLIN_KERNEL_PLAN.md
@@ -86,21 +86,24 @@ the `BN x 16` f32->bf16 activation strip into shared, one `__syncthreads`, then
 (`W4A16_SPAD`): ldmatrix's per-lane address is `row*stride`, and the natural stride 8 (a divisor of the
 32-bank / 128-byte cycle) collides rows 0,4,8,12 into a 2-way bank conflict; skewing to 12 (4-byte aligned, so
 ldmatrix's 16-byte alignment holds) makes `{r*12 mod 32}` hit 8 distinct bank-quads for r in 0..7, so both
-halves of ldmatrix are conflict-free at only +50% on the small (~6 KB) staged tile. Shipping config
-`WM=4,WN=2,FM=2,FN=4` -> `BM=128, BN=64`, 8 warps. M/N tails zero-padded in-kernel; still gated to contiguous
-2D Q4_0/Q4_K f32 prefill, else falls back to MMQ.
+halves of ldmatrix are conflict-free at only +50% on the small staged tile (~12 KB at the shipping tile).
+Shipping config `WM=4,WN=4,FM=2,FN=4` -> `BM=128, BN=128`, 16 warps, 8 m16n8 C-tiles per warp (keeping
+register pressure low is what lets BN grow without an occupancy cliff). M/N tails zero-padded in-kernel; still
+gated to contiguous 2D Q4_0/Q4_K f32 prefill, else falls back to MMQ.
 **Per-step results (q4_K n=512 via `test-backend-ops perf`; pp512/pp2048 via llama-bench Qwen3-32B-Q4_K_M):**
 
 | step | q4_K n=512 | q4_0 n=512 | pp512 | pp2048 | vs MMQ 47 / 718 | notes |
 |---|---|---|---|---|---|---|
 | P2 (1 warp/tile) | ~2 TFLOPS | - | 31.75 | - | 0.04x | correctness checkpoint |
-| Step 1: block tiling (load_generic, BM64/4w) | 6.63 (cold) | 7.53 | 119 | 123 | 0.14x | prior committed kernel |
-| **P3b: skew-pad ldmatrix + BM128/8w** | **8.52 (cold)** | **10.49** | **148.5** | **153.9** | **0.18x** | +28% q4_K, +40% q4_0, +25% pp512 over step 1 |
+| Step 1: block tiling (load_generic, BM64/4w) | 6.63 (cold) | 7.53 | 119 | 123 | 0.14x | original committed kernel |
+| P3b-1: skew-pad ldmatrix + BM128/8w | 8.50 (cold) | 10.56 | 148.5 | 153.9 | 0.18x | +28% q4_K, +40% q4_0 over step 1 |
+| **P3b-2: + BN128/16w (current)** | **9.92 (cold)** | **11.68** | **177.6** | **185.0** | **0.21x** | +17% q4_K, +20% pp512 over P3b-1 (+49% pp512 over step 1) |
 
 Parity gate **1103/1103** at every step, flag set and unset (byte-identical when unset). All P3b numbers above
-are from a single thermally-bracketed cold A/B session (committed measured 6.63/7.53 immediately before AND
-after the P3b kernel, identical both times -> the deltas are real, not thermal).
+are from thermally-bracketed cold A/B sessions (committed measured immediately before AND after each candidate,
+identical both times -> the deltas are real, not thermal). P3b-1 cold A/B: 6.63/7.53 vs 8.52/10.49. P3b-2 cold
+A/B: BN64/8w 8.50/10.56 then 8.45/10.51 (bracket) vs BN128/16w 9.92/11.68.
 
 **What landed / what was tried (honest):**
 - **P3b - LANDED (committed).** Two combined changes lift the prior committed kernel: (1) **skew-pad
@@ -119,18 +122,25 @@ after the P3b kernel, identical both times -> the deltas are real, not thermal).
   when the tile grew (BM128/8w).
   A 5-config tile sweep then split the two quant types:
   - **q4_0 SCALES with warps/tiles** (7.7 -> 10.5 -> **15.8 TFLOPS at BM128/16w**): feed/global-traffic bound,
     helped by cutting redundant activation re-reads (more BM = fewer M-blocks each re-reading the act column).
-  - **q4_K is now DEQUANT-COMPUTE bound** (stuck 6.7-8.5 across every tile; at 16 warps q4_0=15.8 but q4_K=6.8
-    - they diverge hard). This **refines P3's "within 12%" finding**: that held only in the low-throughput memory
-    -bound regime; once the feed is unblocked, q4_K's per-element 6-bit superblock decode (`get_scale_min_k4` +
-    superblock indexing, redone every k-step AND re-done per N-block) becomes the wall. BM256 regressed both
-    (too few blocks / register pressure).
-- **Next blocker (the real q4_K unlock) = offline prepack.** The dequant wall is cross-block-redundant: the same
-  q4_K weights are superblock-decoded by all 8 N-blocks. The fix is the **one-time offline repack** - decode the
-  Q4 tensor ONCE into a cached device buffer keyed off the tensor data pointer, in a layout with the scale/min
-  pre-applied (store reshuffled 4-bit + per-subblock bf16 d,m, ~1.25x the q4 size, NOT a full bf16 blow-up which
-  would be ~4x), so the in-kernel path becomes a cheap `q*d - m` with coalesced loads. Then `cp.async`
-  multi-stage (sized to NOT widen shared past the occupancy cliff) and **Stream-K** over M. These remain the
-  multi-week core; **prepack is the highest-value next step for q4_K specifically.**
+  - **q4_K is largely DEQUANT-COMPUTE bound** (the BM128/16w tile gives q4_0=15.8 but q4_K=6.8 - they diverge
+    hard). This **refines P3's "within 12%" finding**: that held only in the low-throughput memory-bound regime;
+    once the feed is unblocked, q4_K's per-element 6-bit superblock decode (`get_scale_min_k4` + superblock
+    indexing, redone every k-step AND re-done by every N-block) becomes the wall. BM256 regressed both (too few
+    blocks / register pressure).
+- **Growing BN partly relieves the q4_K dequant wall (P3b-2).** Because every N-block re-decodes the same
+  weight strip, halving the N-block count (BN 64->128) halves that redundant q4_K decode - but only when BN is
+  spread across MORE WARPS (16w, 8 C-tiles/warp), not more fragments-per-warp: the FN=8 / FM=4 variants (16
+  C-tiles/warp) regressed to ~6.6 on register pressure, while WM=4,WN=4,FM=2,FN=4 (16w, 8 tiles/warp) lifted
+  q4_K 8.5->9.9 and q4_0 10.6->11.7 cold. BN=256 was no better and costs more shared. **BN128/16w is the
+  shipping tile.**
+- **Next blocker (the remaining q4_K unlock) = offline prepack.** BN growth only divides the redundant decode by
+  the N-block count; it cannot remove the per-k-step decode itself. The full fix is the **one-time offline
+  repack** - decode the Q4 tensor ONCE into a cached device buffer keyed off the tensor data pointer, in a layout
+  with the scale/min pre-applied (store reshuffled 4-bit + per-subblock bf16 d,m, ~1.25x the q4 size, NOT a full
+  bf16 blow-up which would be ~4x), so the in-kernel path becomes a cheap `q*d - m` with coalesced loads. Then
+  `cp.async` multi-stage (sized to NOT widen shared past the occupancy cliff) and **Stream-K** over M. These
+  remain the multi-week core; **prepack is the highest-value next step for q4_K specifically** (it should let
+  q4_K join q4_0 on the feed-bound scaling curve instead of plateauing at ~10).
 - **Methodology note (unchanged):** the box thermally throttles under sustained perf+bench runs (identical code
   ~8.8 cold vs ~6.6 hot earlier), so only same-session A/Bs are trustworthy. The P3b deltas above were taken in
   one bracketed cold session for exactly this reason.
diff --git a/backend/cpp/llama-cpp/paged/kernel/w4a16/marlin-w4a16.cu b/backend/cpp/llama-cpp/paged/kernel/w4a16/marlin-w4a16.cu
index 48b1816ff..57064ee42 100644
--- a/backend/cpp/llama-cpp/paged/kernel/w4a16/marlin-w4a16.cu
+++ b/backend/cpp/llama-cpp/paged/kernel/w4a16/marlin-w4a16.cu
@@ -239,7 +239,7 @@ bool ggml_cuda_w4a16_mul_mat(
     cudaStream_t stream = ctx.stream();
 
     // Block tile config: WM*WN warps compute BM(=WM*FM*16) x BN(=WN*FN*8).
-    constexpr int WM = 4, WN = 2, FM = 2, FN = 4;   // BM=128, BN=64, 8 warps
+    constexpr int WM = 4, WN = 4, FM = 2, FN = 4;   // BM=128, BN=128, 16 warps
     constexpr int BM = WM*FM*16;
     constexpr int BN = WN*FN*8;
     const dim3 grid((unsigned)((M + BM - 1) / BM), (unsigned)((N + BN - 1) / BN), 1);
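Reviewer note (not part of the patch): the tile arithmetic quoted in the commit message and the plan doc can be re-derived with a small host-side sketch. It uses only the formulas stated in the kernel comment (`BM = WM*FM*16`, `BN = WN*FN*8`, `WM*WN` warps) and the `{r*stride mod 32}` bank argument from the plan doc. `TileShape`, `ceil_div`, `distinct_banks`, `kPrevTile` and `kShipTile` are hypothetical names introduced here for illustration and do not exist in `marlin-w4a16.cu`. Treating the skew stride as a count of 32-bit words is an assumption read off the "4-byte aligned" wording.

```cpp
#include <cassert>
#include <set>

// Hypothetical helper for illustration only; not part of marlin-w4a16.cu.
// Encodes the formulas from the kernel comment:
//   WM*WN warps compute BM(=WM*FM*16) x BN(=WN*FN*8).
struct TileShape {
    int WM, WN, FM, FN;
    constexpr int BM() const { return WM * FM * 16; }            // rows per block tile
    constexpr int BN() const { return WN * FN * 8; }             // columns per block tile
    constexpr int warps() const { return WM * WN; }              // warps per block
    constexpr int ctiles_per_warp() const { return FM * FN; }    // m16n8 C-tiles held by each warp
};

// Number of blocks along one dimension, same rounding as the grid computation.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

constexpr TileShape kPrevTile{4, 2, 2, 4};  // previous commit: BN64/8w
constexpr TileShape kShipTile{4, 4, 2, 4};  // this commit:     BN128/16w

static_assert(kPrevTile.BM() == 128 && kPrevTile.BN() == 64 && kPrevTile.warps() == 8,
              "previous tile matches the removed source comment");
static_assert(kShipTile.BM() == 128 && kShipTile.BN() == 128 && kShipTile.warps() == 16,
              "shipping tile matches the added source comment");
// Growing BN through WN (more warps) leaves the per-warp fragment count unchanged.
static_assert(kPrevTile.ctiles_per_warp() == 8 && kShipTile.ctiles_per_warp() == 8,
              "8 C-tiles per warp before and after");
// For the n=512 benchmark shape, the N-block count (and with it the number of
// times the same weight strip is re-decoded) halves from 8 to 4.
static_assert(ceil_div(512, kPrevTile.BN()) == 8 && ceil_div(512, kShipTile.BN()) == 4,
              "N-block count halves at n=512");

// Distinct values of (r * stride) mod 32 for rows r = 0..7.
// Assumption: stride is measured in 32-bit words, so the value is a bank index.
inline int distinct_banks(int stride_words) {
    std::set<int> banks;
    for (int r = 0; r < 8; ++r) {
        banks.insert((r * stride_words) % 32);
    }
    return static_cast<int>(banks.size());
}
```

Under these assumptions `distinct_banks(8)` gives 4 (rows pair up, the 2-way conflict described in the plan doc) and `distinct_banks(12)` gives 8 (conflict-free). The FN=8 and FM=4 variants that regressed correspond to `ctiles_per_warp()` of 16.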