From 2f648dc6a06b3bc7d157bdfd6c6f6da745afaa80 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 21 Jun 2026 01:15:07 +0000 Subject: [PATCH] feat(w4a16): conflict-free skew-pad ldmatrix + BM128/8w tile (q4_K +28%, q4_0 +40%) P3b for the Blackwell (sm_120/121) W4A16 Marlin GEMM. Two combined changes over the prior block-tiled kernel, both verified by a thermally-bracketed cold A/B (committed measured identically before and after): - Skew-padded shared layout: store the staged weight/activation rows at a padded stride of 12 bf162 (8 data + 4 pad) and feed the tensor cores with ldmatrix.x4 (A) / ldmatrix.x2 (B). ldmatrix's per-lane address is row*stride; the natural stride 8 divides the 32-bank cycle and collides rows 0,4,8,12 (2-way bank conflict). Skewing to 12 (still 16-byte aligned) spreads {r*12 mod 32} across 8 distinct bank-quads, so both ldmatrix halves are conflict-free at only +50% on the ~6 KB staged tile - unlike a 128-byte -row XOR swizzle, which is conflict-free but needs 16 KB shared and collapses occupancy on GB10 (measured 2.84 TFLOPS, worse than baseline). - Larger tile: BM=128, BN=64, 8 warps (WM=4,WN=2,FM=2,FN=4), which cuts the redundant per-M-block activation re-reads. Cold A/B (q4_K n=512 / q4_0 n=512 via test-backend-ops perf; pp512/pp2048 via llama-bench Qwen3-32B-Q4_K_M): committed: 6.63 / 7.53 TFLOPS, pp512 119 this: 8.52 / 10.49 TFLOPS, pp512 148.5, pp2048 153.9 (+28% / +40% / +25%) Parity gate GGML_CUDA_W4A16=1 test-backend-ops MUL_MAT = 1103/1103, flag set and unset (byte-identical when unset). Still ~5.5x under MMQ (47 TFLOPS) and does NOT beat MMQ yet; the q4_K limiter has now moved from the mma feed to the per-element 6-bit superblock dequant (q4_0 scales to 15.8 TFLOPS with more warps while q4_K stays ~8.5), so the offline weight prepack is the next unlock. Plan doc P3 section updated with the sweep data and the corrected bottleneck. Assisted-by: Claude:opus-4.8 [Claude Code] Signed-off-by: Ettore Di Giacinto --- .../paged/W4A16_MARLIN_KERNEL_PLAN.md | 86 +++++++++++-------- .../paged/kernel/w4a16/marlin-w4a16.cu | 61 +++++++------ 2 files changed, 86 insertions(+), 61 deletions(-) diff --git a/backend/cpp/llama-cpp/paged/W4A16_MARLIN_KERNEL_PLAN.md b/backend/cpp/llama-cpp/paged/W4A16_MARLIN_KERNEL_PLAN.md index 60ff8d667..5db0d18d2 100644 --- a/backend/cpp/llama-cpp/paged/W4A16_MARLIN_KERNEL_PLAN.md +++ b/backend/cpp/llama-cpp/paged/W4A16_MARLIN_KERNEL_PLAN.md @@ -70,19 +70,24 @@ and **Stream-K** partitioning. Sources: IST-DASLab/marlin, arXiv 2408.11743, vLL re-dequantized per n-tile, no pipeline) - this is the correctness checkpoint; P3 brings the speedup. The real Q4_K model matmul path engages the kernel without error. -### P3 — The Marlin pipeline (the speedup) — STEP 1 LANDED; STEPS 3-4 DEFERRED +### P3 — The Marlin pipeline (the speedup) — STEP 1 + SKEW-PAD/TILING LANDED; PREPACK + PIPELINE + STREAM-K DEFERRED Goal: `cp.async` double/triple-buffered global->shared; offline weight reshuffle (a one-time repack of the Q4 tensor into the mma+pipeline layout); register-resident activation tiles; Stream-K split for the prefill M. Target: >=150 TFLOP/s (>=~2,300 t/s), then ~213. **MMQ baseline to beat: 47.1 TFLOPS (q4_K n=512) / pp512 718.** -**Kernel structure now (committed):** block-tiled multi-warp GEMM. `blockDim=(32, WM*WN)` so `threadIdx.x` is the -warp lane (required by `mma.cuh` get_i/get_j) and `threadIdx.y` is the warp index; the original 1-warp P2 -launch put 128 threads on `threadIdx.x` and exploded `get_j` into an out-of-bounds shared read (found via -compute-sanitizer). `WM*WN` warps compute a `BM(=WM*FM*16) x BN(=WN*FN*8)` output tile; each warp owns an -`FM x FN` grid of m16n8k16 mma fragments accumulated in F32. Per k-step (16-deep): all warps cooperatively -dequant the `BM x 16` Q4 weight strip + load the `BN x 16` f32->bf16 activation strip into a single small -shared buffer (~4 KB), one `__syncthreads`, then `load_generic` fragments + `FM*FN` mmas. Shipping config -`WM=2,WN=2,FM=2,FN=4` -> `BM=64, BN=64`, 4 warps. M/N tails zero-padded in-kernel; still gated to contiguous +**Kernel structure now (committed P3b):** block-tiled multi-warp GEMM with a CONFLICT-FREE shared feed via skew +padding. `blockDim=(32, WM*WN)` so `threadIdx.x` is the warp lane (required by `mma.cuh` get_i/get_j) and +`threadIdx.y` is the warp index; the original 1-warp P2 launch put 128 threads on `threadIdx.x` and exploded +`get_j` into an out-of-bounds shared read (found via compute-sanitizer). `WM*WN` warps compute a +`BM(=WM*FM*16) x BN(=WN*FN*8)` output tile; each warp owns an `FM x FN` grid of m16n8k16 mma fragments +accumulated in F32. Per k-step (16-deep): all warps cooperatively dequant the `BM x 16` Q4 weight strip + load +the `BN x 16` f32->bf16 activation strip into shared, one `__syncthreads`, then `ldmatrix.x4` (A) / `ldmatrix.x2` +(B) fragments + `FM*FN` mmas. The shared rows hold 8 bf162 of data but are stored at a PADDED stride of 12 bf162 +(`W4A16_SPAD`): ldmatrix's per-lane address is `row*stride`, and the natural stride 8 (a divisor of the +32-bank / 128-byte cycle) collides rows 0,4,8,12 into a 2-way bank conflict; skewing to 12 (4-byte aligned, so +ldmatrix's 16-byte alignment holds) makes `{r*12 mod 32}` hit 8 distinct bank-quads for r in 0..7, so both +halves of ldmatrix are conflict-free at only +50% on the small (~6 KB) staged tile. Shipping config +`WM=4,WN=2,FM=2,FN=4` -> `BM=128, BN=64`, 8 warps. M/N tails zero-padded in-kernel; still gated to contiguous 2D Q4_0/Q4_K f32 prefill, else falls back to MMQ. **Per-step results (q4_K n=512 via `test-backend-ops perf`; pp512/pp2048 via llama-bench Qwen3-32B-Q4_K_M):** @@ -90,36 +95,45 @@ shared buffer (~4 KB), one `__syncthreads`, then `load_generic` fragments + `FM* | step | q4_K n=512 | q4_0 n=512 | pp512 | pp2048 | vs MMQ 47 / 718 | notes | |---|---|---|---|---|---|---| | P2 (1 warp/tile) | ~2 TFLOPS | - | 31.75 | - | 0.04x | correctness checkpoint | -| **Step 1: block tiling** | **6.6-8.8 TFLOPS** | 7.5-9.9 | **118-142** | 122-156 | **~0.15-0.19x** | ~3.5-4.4x over P2; the banked win | -| Step 2: dequant reuse | (folded into step 1) | | | | | see below | -| Step 3: pipeline | regressed/neutral | | | | | reverted, see below | -| Step 4: reshuffle + Stream-K | deferred | | | | | not started | +| Step 1: block tiling (load_generic, BM64/4w) | 6.63 (cold) | 7.53 | 119 | 123 | 0.14x | prior committed kernel | +| **P3b: skew-pad ldmatrix + BM128/8w** | **8.52 (cold)** | **10.49** | **148.5** | **153.9** | **0.18x** | +28% q4_K, +40% q4_0, +25% pp512 over step 1 | -Parity gate **1103/1103** at every step, flag set and unset (byte-identical when unset). +Parity gate **1103/1103** at every step, flag set and unset (byte-identical when unset). All P3b numbers above +are from a single thermally-bracketed cold A/B session (committed measured 6.63/7.53 immediately before AND +after the P3b kernel, identical both times -> the deltas are real, not thermal). **What landed / what was tried (honest):** -- **Step 1 (block tiling) - LANDED.** The bulk of the realised win (P2 ~2 -> ~7-9 TFLOPS). This is the - committed kernel. -- **Step 2 (dequant reuse across N) - no extra gain, root-caused.** A tile sweep (BM/BN from 64 to 128, 4-16 - warps) held flat at 8.6-8.8 TFLOPS: enlarging BN to amortize the weight dequant did **not** help. Decisive - diagnostic: q4_0 (trivial dequant) and q4_K (heavy 6-bit superblock dequant) run **within ~12%** of each - other, so **dequant compute is not the limiter** - the shared-load / mma-feed throughput (and occupancy-hidden - global latency) is. Larger BN already reuses the strip across the block; cross-block reuse needs step 4. -- **Step 3 (software pipeline) - tried, reverted.** (a) A double-buffered (`NBUF=2`) KSTAGE=64 stage loader - (dequant stage s+1 into the spare shared buffer while the mma of stage s runs) collapsed occupancy via 32 KB - shared and dropped q4_K n=512 to **2.7 TFLOPS**. (b) Swapping `load_generic` for `ldmatrix` was **neutral** - (~6.6 vs ~6.7 TFLOPS measured in the same thermal window) because the unswizzled row-major shared layout makes - `ldmatrix.x4` bank-conflict. Both reverted; step 1 (small shared, high occupancy) is strictly better on this - GB10. **Methodology note:** the box thermally throttles under sustained perf+bench runs (identical step-1 code - measured 8.83 TFLOPS cold vs 6.65 hot), so only same-session A/Bs are trustworthy - earlier cross-run deltas - were partly thermal. -- **Step 4 (offline weight reshuffle + Stream-K) - DEFERRED, and now known to be the real unlock.** The - evidence above says the path to >=150 TFLOPS is *not* bigger tiles or a naive cp.async pipeline but the full - Marlin machinery: an **XOR-swizzled shared layout** (so `ldmatrix` is conflict-free), a **one-time offline - repack** of the Q4 tensor into that mma+pipeline layout (a load-time transform keyed off the tensor data - pointer; ~M*K/2 bytes prepacked buffer, same size as the q4 weights) so dequant becomes cheap conflict-free - bit-extraction and the per-(m,n)-block re-dequant disappears, a **tuned cp.async multi-stage** sized to keep - occupancy, and **Stream-K** over M. That is the remaining multi-week core. +- **P3b - LANDED (committed).** Two combined changes lift the prior committed kernel: (1) **skew-pad + conflict-free ldmatrix** (shared row stride 8->12 bf162; makes `ldmatrix.x4`/`.x2` bank-conflict-free at near + zero occupancy cost) and (2) **bigger tile / more warps** (`BM=128, BN=64`, 8 warps). Cold A/B: q4_K + 6.63->8.52 (+28%), q4_0 7.53->10.49 (+40%), pp512 119->148.5 (+25%). **Still ~5.5x under MMQ (47) per-op and + ~4.8x under pp512 718 - does NOT beat MMQ.** This is forward progress, not the finish line. +- **The XOR-swizzle-FIRST plan was tested and is WRONG for this GPU - documented so it is not re-tried.** A + wide-row (BK=64, 128-byte rows) XOR swizzle `seg ^ (row&7)` IS conflict-free, but the 16 KB shared it needs + collapsed occupancy and dropped q4_K n=512 to **2.84 TFLOPS** (worse than the unswizzled 6.63) - the same + occupancy cliff P3 hit with a 32 KB pipeline. The conflict-free feed must be bought WITHOUT widening shared: + skew padding (above) does exactly that (6 KB), which is why it is the committed form. Lesson: on GB10 occupancy + dominates bank-conflict latency; never trade occupancy for a conflict-free layout. +- **Conflict-free feed alone did NOT beat the unswizzled kernel - the limiter moved.** At the SAME BM64/4w tile, + skew-pad ldmatrix (6.70) ~= load_generic (6.63): removing bank conflicts bought ~nothing. The win came only + when the tile grew (BM128/8w). A 5-config tile sweep then split the two quant types: + - **q4_0 SCALES with warps/tiles** (7.7 -> 10.5 -> **15.8 TFLOPS at BM128/16w**): feed/global-traffic bound, + helped by cutting redundant activation re-reads (more BM = fewer M-blocks each re-reading the act column). + - **q4_K is now DEQUANT-COMPUTE bound** (stuck 6.7-8.5 across every tile; at 16 warps q4_0=15.8 but q4_K=6.8 - + they diverge hard). This **refines P3's "within 12%" finding**: that held only in the low-throughput memory + -bound regime; once the feed is unblocked, q4_K's per-element 6-bit superblock decode (`get_scale_min_k4` + + superblock indexing, redone every k-step AND re-done per N-block) becomes the wall. BM256 regressed both + (too few blocks / register pressure). +- **Next blocker (the real q4_K unlock) = offline prepack.** The dequant wall is cross-block-redundant: the same + q4_K weights are superblock-decoded by all 8 N-blocks. The fix is the **one-time offline repack** - decode the + Q4 tensor ONCE into a cached device buffer keyed off the tensor data pointer, in a layout with the scale/min + pre-applied (store reshuffled 4-bit + per-subblock bf16 d,m, ~1.25x the q4 size, NOT a full bf16 blow-up which + would be ~4x), so the in-kernel path becomes a cheap `q*d - m` with coalesced loads. Then `cp.async` + multi-stage (sized to NOT widen shared past the occupancy cliff) and **Stream-K** over M. These remain the + multi-week core; **prepack is the highest-value next step for q4_K specifically.** +- **Methodology note (unchanged):** the box thermally throttles under sustained perf+bench runs (identical code + ~8.8 cold vs ~6.6 hot earlier), so only same-session A/Bs are trustworthy. The P3b deltas above were taken in + one bracketed cold session for exactly this reason. ### P4 — Tune - Tile (mmq_x/y analogues), warps, pipeline depth, occupancy. We have nsys (throughput) but **not ncu** on the diff --git a/backend/cpp/llama-cpp/paged/kernel/w4a16/marlin-w4a16.cu b/backend/cpp/llama-cpp/paged/kernel/w4a16/marlin-w4a16.cu index 63a9f1908..48b1816ff 100644 --- a/backend/cpp/llama-cpp/paged/kernel/w4a16/marlin-w4a16.cu +++ b/backend/cpp/llama-cpp/paged/kernel/w4a16/marlin-w4a16.cu @@ -21,23 +21,28 @@ // Thread layout: blockDim = (32, WM*WN). threadIdx.x is the warp lane (0..31, // required by mma.cuh get_i/get_j), threadIdx.y is the warp index. // -// P3 structure: -// - Step 1 (block tiling): WM*WN warps compute a BM(=WM*FM*16) x BN(=WN*FN*8) -// output tile; each warp owns an FM x FN grid of m16n8 mma fragments. Replaces -// P2's 1-warp-per-16x8 launch (kills warp underutilization). -// - Step 2 (dequant reuse): the BM x 16 dequantized weight strip is staged once -// per k-step in shared and reused across the block's whole BN span. -// - Small shared footprint (one 16-deep k-step per buffer) keeps occupancy high, -// so block-level parallelism hides the dequant + global-load latency. On this -// path q4_0 and q4_K perform within ~12% of each other, so the dequant compute -// is NOT the limiter - the shared-load / mma-feed throughput is. Measured -// dead-ends (kept here so they are not re-tried blindly): a double-buffered -// cp.async-style pipeline with a large KSTAGE (32 KB shared) collapsed -// occupancy (8.8 -> 2.7 TFLOPS at q4_K n=512), and swapping load_generic for -// ldmatrix regressed to 6.6 TFLOPS because the unswizzled row-major shared -// layout makes ldmatrix bank-conflict. Beating MMQ here needs the full Marlin -// machinery (XOR-swizzled shared layout + tuned async pipeline + offline -// weight reshuffle), which is deferred (P3 step 4). +// P3b step 1 - conflict-free shared layout via SKEW PADDING: +// - WM*WN warps compute a BM(=WM*FM*16) x BN(=WN*FN*8) output tile; each warp +// owns an FM x FN grid of m16n8k16 mma fragments accumulated in F32. +// - Per 16-deep k-step the warps cooperatively dequant the BM x 16 Q4 weight +// strip + load the BN x 16 f32->bf16 activation strip into shared, then feed +// the tensor cores with ldmatrix.x4 (A) / ldmatrix.x2 (B). +// - The shared rows are PADDED to SPAD(=12) bf162 instead of the natural 8. +// ldmatrix's per-lane address is row*stride; with the natural stride 8 (a +// divisor of the 32-bank / 128-byte cycle) rows 0,4,8,12 collide -> 2-way +// bank conflict on every fragment load (this is why P3 measured a plain +// ldmatrix swap as neutral). Skewing the stride to 12 (4-byte aligned, so +// ldmatrix's 16-byte alignment holds) makes {r*12 mod 32} hit 8 distinct +// bank-quads for r in 0..7, so both halves of ldmatrix.x4 and ldmatrix.x2 are +// conflict-free. The pad costs only +50% on the small (~4 KB) staged tile, so +// unlike a 128-byte-row XOR swizzle it does NOT collapse occupancy on GB10 +// (a wide-row swizzle pushed shared to 16 KB and dropped this to ~2.8 TFLOPS). +// +// Dead-ends already proven (do not re-try): a double-buffered KSTAGE=64 cp.async +// pipeline collapsed occupancy (32 KB shared -> 2.7 TFLOPS); a plain ldmatrix on +// the UNpadded layout was neutral (bank conflicts); a wide-row (BK=64) XOR swizzle +// was conflict-free but occupancy-starved (16 KB shared -> 2.8 TFLOPS). Skew +// padding gets the conflict-free feed at near-zero occupancy cost. using namespace ggml_cuda_mma; @@ -45,6 +50,11 @@ typedef tile<16, 8, nv_bfloat162> tile_A; // 16(M) x 16(K) typedef tile< 8, 8, nv_bfloat162> tile_B; // 8(N) x 16(K) typedef tile<16, 8, float> tile_C; // 16(M) x 8(N) +// bf162 columns actually live per shared row (16 k-values = 8 bf162) ... +#define W4A16_KP 8 +// ... padded to this stride to bank-skew the ldmatrix row addresses. +#define W4A16_SPAD 12 + static bool w4a16_enabled() { static const bool en = (std::getenv("GGML_CUDA_W4A16") != nullptr); return en; @@ -99,7 +109,8 @@ w4a16_gemm_kernel( float * __restrict__ dst, const int M, const int N, const int K, const int64_t nb01, const int64_t nb11, const int64_t dst_ne0) { - constexpr int KP = 8; // bf162 pairs per 16-wide k-step (row stride in shared) + constexpr int KP = W4A16_KP; // 8 bf162 = 16 k per row + constexpr int SPAD = W4A16_SPAD; // padded row stride (bank skew) constexpr int BM = WM*FM*16; constexpr int BN = WN*FN*8; constexpr int NTH = WM*WN*32; @@ -112,8 +123,8 @@ w4a16_gemm_kernel( const int warp_m = warp_id / WN; const int tid = threadIdx.y*32 + threadIdx.x; - __shared__ nv_bfloat162 sW[BM*KP]; // [m][kpair], row stride KP (16-byte aligned) - __shared__ nv_bfloat162 sB[BN*KP]; // [n][kpair], row stride KP + __shared__ nv_bfloat162 sW[BM*SPAD]; // [m][kpair], padded row stride SPAD + __shared__ nv_bfloat162 sB[BN*SPAD]; // [n][kpair], padded row stride SPAD tile_C C[FM][FN]; // zero-initialized accumulators @@ -130,7 +141,7 @@ w4a16_gemm_kernel( if (IS_Q4_K) { w0 = w4a16_dq_q4_K(row, k); w1 = w4a16_dq_q4_K(row, k + 1); } else { w0 = w4a16_dq_q4_0(row, k); w1 = w4a16_dq_q4_0(row, k + 1); } } - sW[idx] = __floats2bfloat162_rn(w0, w1); + sW[m*SPAD + kk] = __floats2bfloat162_rn(w0, w1); } // Load the BN x 16 activation strip (f32 -> bf16). #pragma unroll @@ -143,7 +154,7 @@ w4a16_gemm_kernel( const float * arow = (const float *)(src1 + (int64_t)(n0 + n) * nb11); a0 = arow[k]; a1 = arow[k + 1]; } - sB[idx] = __floats2bfloat162_rn(a0, a1); + sB[n*SPAD + kk] = __floats2bfloat162_rn(a0, a1); } __syncthreads(); @@ -152,12 +163,12 @@ w4a16_gemm_kernel( #pragma unroll for (int fm = 0; fm < FM; ++fm) { const int mrow = (warp_m*FM + fm) * 16; - load_generic(Af[fm], sW + mrow*KP, KP); + load_ldmatrix(Af[fm], sW + mrow*SPAD, SPAD); } #pragma unroll for (int fn = 0; fn < FN; ++fn) { const int ncol = (warp_n*FN + fn) * 8; - load_generic(Bf[fn], sB + ncol*KP, KP); + load_ldmatrix(Bf[fn], sB + ncol*SPAD, SPAD); } #pragma unroll for (int fm = 0; fm < FM; ++fm) { @@ -228,7 +239,7 @@ bool ggml_cuda_w4a16_mul_mat( cudaStream_t stream = ctx.stream(); // Block tile config: WM*WN warps compute BM(=WM*FM*16) x BN(=WN*FN*8). - constexpr int WM = 2, WN = 2, FM = 2, FN = 4; // BM=64, BN=64, 4 warps + constexpr int WM = 4, WN = 2, FM = 2, FN = 4; // BM=128, BN=64, 8 warps constexpr int BM = WM*FM*16; constexpr int BN = WN*FN*8; const dim3 grid((unsigned)((M + BM - 1) / BM), (unsigned)((N + BN - 1) / BN), 1);