mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-23 08:08:52 -04:00
feat(w4a16): conflict-free skew-pad ldmatrix + BM128/8w tile (q4_K +28%, q4_0 +40%)
P3b for the Blackwell (sm_120/121) W4A16 Marlin GEMM. Two combined changes
over the prior block-tiled kernel, both verified by a thermally-bracketed
cold A/B (committed measured identically before and after):
- Skew-padded shared layout: store the staged weight/activation rows at a
padded stride of 12 bf162 (8 data + 4 pad) and feed the tensor cores with
ldmatrix.x4 (A) / ldmatrix.x2 (B). ldmatrix's per-lane address is
row*stride; the natural stride 8 divides the 32-bank cycle and collides
rows 0,4,8,12 (2-way bank conflict). Skewing to 12 (still 16-byte aligned)
spreads {r*12 mod 32} across 8 distinct bank-quads, so both ldmatrix halves
are conflict-free at only +50% on the ~6 KB staged tile - unlike a 128-byte
-row XOR swizzle, which is conflict-free but needs 16 KB shared and
collapses occupancy on GB10 (measured 2.84 TFLOPS, worse than baseline).
- Larger tile: BM=128, BN=64, 8 warps (WM=4,WN=2,FM=2,FN=4), which cuts the
redundant per-M-block activation re-reads.
Cold A/B (q4_K n=512 / q4_0 n=512 via test-backend-ops perf; pp512/pp2048 via
llama-bench Qwen3-32B-Q4_K_M):
committed: 6.63 / 7.53 TFLOPS, pp512 119
this: 8.52 / 10.49 TFLOPS, pp512 148.5, pp2048 153.9 (+28% / +40% / +25%)
Parity gate GGML_CUDA_W4A16=1 test-backend-ops MUL_MAT = 1103/1103, flag set
and unset (byte-identical when unset). Still ~5.5x under MMQ (47 TFLOPS) and
does NOT beat MMQ yet; the q4_K limiter has now moved from the mma feed to the
per-element 6-bit superblock dequant (q4_0 scales to 15.8 TFLOPS with more
warps while q4_K stays ~8.5), so the offline weight prepack is the next unlock.
Plan doc P3 section updated with the sweep data and the corrected bottleneck.
Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -70,19 +70,24 @@ and **Stream-K** partitioning. Sources: IST-DASLab/marlin, arXiv 2408.11743, vLL
|
||||
re-dequantized per n-tile, no pipeline) - this is the correctness checkpoint; P3 brings the speedup. The real
|
||||
Q4_K model matmul path engages the kernel without error.
|
||||
|
||||
### P3 — The Marlin pipeline (the speedup) — STEP 1 LANDED; STEPS 3-4 DEFERRED
|
||||
### P3 — The Marlin pipeline (the speedup) — STEP 1 + SKEW-PAD/TILING LANDED; PREPACK + PIPELINE + STREAM-K DEFERRED
|
||||
Goal: `cp.async` double/triple-buffered global->shared; offline weight reshuffle (a one-time repack of the Q4
|
||||
tensor into the mma+pipeline layout); register-resident activation tiles; Stream-K split for the prefill M.
|
||||
Target: >=150 TFLOP/s (>=~2,300 t/s), then ~213. **MMQ baseline to beat: 47.1 TFLOPS (q4_K n=512) / pp512 718.**
|
||||
|
||||
**Kernel structure now (committed):** block-tiled multi-warp GEMM. `blockDim=(32, WM*WN)` so `threadIdx.x` is the
|
||||
warp lane (required by `mma.cuh` get_i/get_j) and `threadIdx.y` is the warp index; the original 1-warp P2
|
||||
launch put 128 threads on `threadIdx.x` and exploded `get_j` into an out-of-bounds shared read (found via
|
||||
compute-sanitizer). `WM*WN` warps compute a `BM(=WM*FM*16) x BN(=WN*FN*8)` output tile; each warp owns an
|
||||
`FM x FN` grid of m16n8k16 mma fragments accumulated in F32. Per k-step (16-deep): all warps cooperatively
|
||||
dequant the `BM x 16` Q4 weight strip + load the `BN x 16` f32->bf16 activation strip into a single small
|
||||
shared buffer (~4 KB), one `__syncthreads`, then `load_generic` fragments + `FM*FN` mmas. Shipping config
|
||||
`WM=2,WN=2,FM=2,FN=4` -> `BM=64, BN=64`, 4 warps. M/N tails zero-padded in-kernel; still gated to contiguous
|
||||
**Kernel structure now (committed P3b):** block-tiled multi-warp GEMM with a CONFLICT-FREE shared feed via skew
|
||||
padding. `blockDim=(32, WM*WN)` so `threadIdx.x` is the warp lane (required by `mma.cuh` get_i/get_j) and
|
||||
`threadIdx.y` is the warp index; the original 1-warp P2 launch put 128 threads on `threadIdx.x` and exploded
|
||||
`get_j` into an out-of-bounds shared read (found via compute-sanitizer). `WM*WN` warps compute a
|
||||
`BM(=WM*FM*16) x BN(=WN*FN*8)` output tile; each warp owns an `FM x FN` grid of m16n8k16 mma fragments
|
||||
accumulated in F32. Per k-step (16-deep): all warps cooperatively dequant the `BM x 16` Q4 weight strip + load
|
||||
the `BN x 16` f32->bf16 activation strip into shared, one `__syncthreads`, then `ldmatrix.x4` (A) / `ldmatrix.x2`
|
||||
(B) fragments + `FM*FN` mmas. The shared rows hold 8 bf162 of data but are stored at a PADDED stride of 12 bf162
|
||||
(`W4A16_SPAD`): ldmatrix's per-lane address is `row*stride`, and the natural stride 8 (a divisor of the
|
||||
32-bank / 128-byte cycle) collides rows 0,4,8,12 into a 2-way bank conflict; skewing to 12 (4-byte aligned, so
|
||||
ldmatrix's 16-byte alignment holds) makes `{r*12 mod 32}` hit 8 distinct bank-quads for r in 0..7, so both
|
||||
halves of ldmatrix are conflict-free at only +50% on the small (~6 KB) staged tile. Shipping config
|
||||
`WM=4,WN=2,FM=2,FN=4` -> `BM=128, BN=64`, 8 warps. M/N tails zero-padded in-kernel; still gated to contiguous
|
||||
2D Q4_0/Q4_K f32 prefill, else falls back to MMQ.
|
||||
|
||||
**Per-step results (q4_K n=512 via `test-backend-ops perf`; pp512/pp2048 via llama-bench Qwen3-32B-Q4_K_M):**
|
||||
@@ -90,36 +95,45 @@ shared buffer (~4 KB), one `__syncthreads`, then `load_generic` fragments + `FM*
|
||||
| step | q4_K n=512 | q4_0 n=512 | pp512 | pp2048 | vs MMQ 47 / 718 | notes |
|
||||
|---|---|---|---|---|---|---|
|
||||
| P2 (1 warp/tile) | ~2 TFLOPS | - | 31.75 | - | 0.04x | correctness checkpoint |
|
||||
| **Step 1: block tiling** | **6.6-8.8 TFLOPS** | 7.5-9.9 | **118-142** | 122-156 | **~0.15-0.19x** | ~3.5-4.4x over P2; the banked win |
|
||||
| Step 2: dequant reuse | (folded into step 1) | | | | | see below |
|
||||
| Step 3: pipeline | regressed/neutral | | | | | reverted, see below |
|
||||
| Step 4: reshuffle + Stream-K | deferred | | | | | not started |
|
||||
| Step 1: block tiling (load_generic, BM64/4w) | 6.63 (cold) | 7.53 | 119 | 123 | 0.14x | prior committed kernel |
|
||||
| **P3b: skew-pad ldmatrix + BM128/8w** | **8.52 (cold)** | **10.49** | **148.5** | **153.9** | **0.18x** | +28% q4_K, +40% q4_0, +25% pp512 over step 1 |
|
||||
|
||||
Parity gate **1103/1103** at every step, flag set and unset (byte-identical when unset).
|
||||
Parity gate **1103/1103** at every step, flag set and unset (byte-identical when unset). All P3b numbers above
|
||||
are from a single thermally-bracketed cold A/B session (committed measured 6.63/7.53 immediately before AND
|
||||
after the P3b kernel, identical both times -> the deltas are real, not thermal).
|
||||
|
||||
**What landed / what was tried (honest):**
|
||||
- **Step 1 (block tiling) - LANDED.** The bulk of the realised win (P2 ~2 -> ~7-9 TFLOPS). This is the
|
||||
committed kernel.
|
||||
- **Step 2 (dequant reuse across N) - no extra gain, root-caused.** A tile sweep (BM/BN from 64 to 128, 4-16
|
||||
warps) held flat at 8.6-8.8 TFLOPS: enlarging BN to amortize the weight dequant did **not** help. Decisive
|
||||
diagnostic: q4_0 (trivial dequant) and q4_K (heavy 6-bit superblock dequant) run **within ~12%** of each
|
||||
other, so **dequant compute is not the limiter** - the shared-load / mma-feed throughput (and occupancy-hidden
|
||||
global latency) is. Larger BN already reuses the strip across the block; cross-block reuse needs step 4.
|
||||
- **Step 3 (software pipeline) - tried, reverted.** (a) A double-buffered (`NBUF=2`) KSTAGE=64 stage loader
|
||||
(dequant stage s+1 into the spare shared buffer while the mma of stage s runs) collapsed occupancy via 32 KB
|
||||
shared and dropped q4_K n=512 to **2.7 TFLOPS**. (b) Swapping `load_generic` for `ldmatrix` was **neutral**
|
||||
(~6.6 vs ~6.7 TFLOPS measured in the same thermal window) because the unswizzled row-major shared layout makes
|
||||
`ldmatrix.x4` bank-conflict. Both reverted; step 1 (small shared, high occupancy) is strictly better on this
|
||||
GB10. **Methodology note:** the box thermally throttles under sustained perf+bench runs (identical step-1 code
|
||||
measured 8.83 TFLOPS cold vs 6.65 hot), so only same-session A/Bs are trustworthy - earlier cross-run deltas
|
||||
were partly thermal.
|
||||
- **Step 4 (offline weight reshuffle + Stream-K) - DEFERRED, and now known to be the real unlock.** The
|
||||
evidence above says the path to >=150 TFLOPS is *not* bigger tiles or a naive cp.async pipeline but the full
|
||||
Marlin machinery: an **XOR-swizzled shared layout** (so `ldmatrix` is conflict-free), a **one-time offline
|
||||
repack** of the Q4 tensor into that mma+pipeline layout (a load-time transform keyed off the tensor data
|
||||
pointer; ~M*K/2 bytes prepacked buffer, same size as the q4 weights) so dequant becomes cheap conflict-free
|
||||
bit-extraction and the per-(m,n)-block re-dequant disappears, a **tuned cp.async multi-stage** sized to keep
|
||||
occupancy, and **Stream-K** over M. That is the remaining multi-week core.
|
||||
- **P3b - LANDED (committed).** Two combined changes lift the prior committed kernel: (1) **skew-pad
|
||||
conflict-free ldmatrix** (shared row stride 8->12 bf162; makes `ldmatrix.x4`/`.x2` bank-conflict-free at near
|
||||
zero occupancy cost) and (2) **bigger tile / more warps** (`BM=128, BN=64`, 8 warps). Cold A/B: q4_K
|
||||
6.63->8.52 (+28%), q4_0 7.53->10.49 (+40%), pp512 119->148.5 (+25%). **Still ~5.5x under MMQ (47) per-op and
|
||||
~4.8x under pp512 718 - does NOT beat MMQ.** This is forward progress, not the finish line.
|
||||
- **The XOR-swizzle-FIRST plan was tested and is WRONG for this GPU - documented so it is not re-tried.** A
|
||||
wide-row (BK=64, 128-byte rows) XOR swizzle `seg ^ (row&7)` IS conflict-free, but the 16 KB shared it needs
|
||||
collapsed occupancy and dropped q4_K n=512 to **2.84 TFLOPS** (worse than the unswizzled 6.63) - the same
|
||||
occupancy cliff P3 hit with a 32 KB pipeline. The conflict-free feed must be bought WITHOUT widening shared:
|
||||
skew padding (above) does exactly that (6 KB), which is why it is the committed form. Lesson: on GB10 occupancy
|
||||
dominates bank-conflict latency; never trade occupancy for a conflict-free layout.
|
||||
- **Conflict-free feed alone did NOT beat the unswizzled kernel - the limiter moved.** At the SAME BM64/4w tile,
|
||||
skew-pad ldmatrix (6.70) ~= load_generic (6.63): removing bank conflicts bought ~nothing. The win came only
|
||||
when the tile grew (BM128/8w). A 5-config tile sweep then split the two quant types:
|
||||
- **q4_0 SCALES with warps/tiles** (7.7 -> 10.5 -> **15.8 TFLOPS at BM128/16w**): feed/global-traffic bound,
|
||||
helped by cutting redundant activation re-reads (more BM = fewer M-blocks each re-reading the act column).
|
||||
- **q4_K is now DEQUANT-COMPUTE bound** (stuck 6.7-8.5 across every tile; at 16 warps q4_0=15.8 but q4_K=6.8 -
|
||||
they diverge hard). This **refines P3's "within 12%" finding**: that held only in the low-throughput memory
|
||||
-bound regime; once the feed is unblocked, q4_K's per-element 6-bit superblock decode (`get_scale_min_k4` +
|
||||
superblock indexing, redone every k-step AND re-done per N-block) becomes the wall. BM256 regressed both
|
||||
(too few blocks / register pressure).
|
||||
- **Next blocker (the real q4_K unlock) = offline prepack.** The dequant wall is cross-block-redundant: the same
|
||||
q4_K weights are superblock-decoded by all 8 N-blocks. The fix is the **one-time offline repack** - decode the
|
||||
Q4 tensor ONCE into a cached device buffer keyed off the tensor data pointer, in a layout with the scale/min
|
||||
pre-applied (store reshuffled 4-bit + per-subblock bf16 d,m, ~1.25x the q4 size, NOT a full bf16 blow-up which
|
||||
would be ~4x), so the in-kernel path becomes a cheap `q*d - m` with coalesced loads. Then `cp.async`
|
||||
multi-stage (sized to NOT widen shared past the occupancy cliff) and **Stream-K** over M. These remain the
|
||||
multi-week core; **prepack is the highest-value next step for q4_K specifically.**
|
||||
- **Methodology note (unchanged):** the box thermally throttles under sustained perf+bench runs (identical code
|
||||
~8.8 cold vs ~6.6 hot earlier), so only same-session A/Bs are trustworthy. The P3b deltas above were taken in
|
||||
one bracketed cold session for exactly this reason.
|
||||
|
||||
### P4 — Tune
|
||||
- Tile (mmq_x/y analogues), warps, pipeline depth, occupancy. We have nsys (throughput) but **not ncu** on the
|
||||
|
||||
@@ -21,23 +21,28 @@
|
||||
// Thread layout: blockDim = (32, WM*WN). threadIdx.x is the warp lane (0..31,
|
||||
// required by mma.cuh get_i/get_j), threadIdx.y is the warp index.
|
||||
//
|
||||
// P3 structure:
|
||||
// - Step 1 (block tiling): WM*WN warps compute a BM(=WM*FM*16) x BN(=WN*FN*8)
|
||||
// output tile; each warp owns an FM x FN grid of m16n8 mma fragments. Replaces
|
||||
// P2's 1-warp-per-16x8 launch (kills warp underutilization).
|
||||
// - Step 2 (dequant reuse): the BM x 16 dequantized weight strip is staged once
|
||||
// per k-step in shared and reused across the block's whole BN span.
|
||||
// - Small shared footprint (one 16-deep k-step per buffer) keeps occupancy high,
|
||||
// so block-level parallelism hides the dequant + global-load latency. On this
|
||||
// path q4_0 and q4_K perform within ~12% of each other, so the dequant compute
|
||||
// is NOT the limiter - the shared-load / mma-feed throughput is. Measured
|
||||
// dead-ends (kept here so they are not re-tried blindly): a double-buffered
|
||||
// cp.async-style pipeline with a large KSTAGE (32 KB shared) collapsed
|
||||
// occupancy (8.8 -> 2.7 TFLOPS at q4_K n=512), and swapping load_generic for
|
||||
// ldmatrix regressed to 6.6 TFLOPS because the unswizzled row-major shared
|
||||
// layout makes ldmatrix bank-conflict. Beating MMQ here needs the full Marlin
|
||||
// machinery (XOR-swizzled shared layout + tuned async pipeline + offline
|
||||
// weight reshuffle), which is deferred (P3 step 4).
|
||||
// P3b step 1 - conflict-free shared layout via SKEW PADDING:
|
||||
// - WM*WN warps compute a BM(=WM*FM*16) x BN(=WN*FN*8) output tile; each warp
|
||||
// owns an FM x FN grid of m16n8k16 mma fragments accumulated in F32.
|
||||
// - Per 16-deep k-step the warps cooperatively dequant the BM x 16 Q4 weight
|
||||
// strip + load the BN x 16 f32->bf16 activation strip into shared, then feed
|
||||
// the tensor cores with ldmatrix.x4 (A) / ldmatrix.x2 (B).
|
||||
// - The shared rows are PADDED to SPAD(=12) bf162 instead of the natural 8.
|
||||
// ldmatrix's per-lane address is row*stride; with the natural stride 8 (a
|
||||
// divisor of the 32-bank / 128-byte cycle) rows 0,4,8,12 collide -> 2-way
|
||||
// bank conflict on every fragment load (this is why P3 measured a plain
|
||||
// ldmatrix swap as neutral). Skewing the stride to 12 (4-byte aligned, so
|
||||
// ldmatrix's 16-byte alignment holds) makes {r*12 mod 32} hit 8 distinct
|
||||
// bank-quads for r in 0..7, so both halves of ldmatrix.x4 and ldmatrix.x2 are
|
||||
// conflict-free. The pad costs only +50% on the small (~4 KB) staged tile, so
|
||||
// unlike a 128-byte-row XOR swizzle it does NOT collapse occupancy on GB10
|
||||
// (a wide-row swizzle pushed shared to 16 KB and dropped this to ~2.8 TFLOPS).
|
||||
//
|
||||
// Dead-ends already proven (do not re-try): a double-buffered KSTAGE=64 cp.async
|
||||
// pipeline collapsed occupancy (32 KB shared -> 2.7 TFLOPS); a plain ldmatrix on
|
||||
// the UNpadded layout was neutral (bank conflicts); a wide-row (BK=64) XOR swizzle
|
||||
// was conflict-free but occupancy-starved (16 KB shared -> 2.8 TFLOPS). Skew
|
||||
// padding gets the conflict-free feed at near-zero occupancy cost.
|
||||
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
@@ -45,6 +50,11 @@ typedef tile<16, 8, nv_bfloat162> tile_A; // 16(M) x 16(K)
|
||||
typedef tile< 8, 8, nv_bfloat162> tile_B; // 8(N) x 16(K)
|
||||
typedef tile<16, 8, float> tile_C; // 16(M) x 8(N)
|
||||
|
||||
// bf162 columns actually live per shared row (16 k-values = 8 bf162) ...
|
||||
#define W4A16_KP 8
|
||||
// ... padded to this stride to bank-skew the ldmatrix row addresses.
|
||||
#define W4A16_SPAD 12
|
||||
|
||||
static bool w4a16_enabled() {
|
||||
static const bool en = (std::getenv("GGML_CUDA_W4A16") != nullptr);
|
||||
return en;
|
||||
@@ -99,7 +109,8 @@ w4a16_gemm_kernel(
|
||||
float * __restrict__ dst,
|
||||
const int M, const int N, const int K,
|
||||
const int64_t nb01, const int64_t nb11, const int64_t dst_ne0) {
|
||||
constexpr int KP = 8; // bf162 pairs per 16-wide k-step (row stride in shared)
|
||||
constexpr int KP = W4A16_KP; // 8 bf162 = 16 k per row
|
||||
constexpr int SPAD = W4A16_SPAD; // padded row stride (bank skew)
|
||||
constexpr int BM = WM*FM*16;
|
||||
constexpr int BN = WN*FN*8;
|
||||
constexpr int NTH = WM*WN*32;
|
||||
@@ -112,8 +123,8 @@ w4a16_gemm_kernel(
|
||||
const int warp_m = warp_id / WN;
|
||||
const int tid = threadIdx.y*32 + threadIdx.x;
|
||||
|
||||
__shared__ nv_bfloat162 sW[BM*KP]; // [m][kpair], row stride KP (16-byte aligned)
|
||||
__shared__ nv_bfloat162 sB[BN*KP]; // [n][kpair], row stride KP
|
||||
__shared__ nv_bfloat162 sW[BM*SPAD]; // [m][kpair], padded row stride SPAD
|
||||
__shared__ nv_bfloat162 sB[BN*SPAD]; // [n][kpair], padded row stride SPAD
|
||||
|
||||
tile_C C[FM][FN]; // zero-initialized accumulators
|
||||
|
||||
@@ -130,7 +141,7 @@ w4a16_gemm_kernel(
|
||||
if (IS_Q4_K) { w0 = w4a16_dq_q4_K(row, k); w1 = w4a16_dq_q4_K(row, k + 1); }
|
||||
else { w0 = w4a16_dq_q4_0(row, k); w1 = w4a16_dq_q4_0(row, k + 1); }
|
||||
}
|
||||
sW[idx] = __floats2bfloat162_rn(w0, w1);
|
||||
sW[m*SPAD + kk] = __floats2bfloat162_rn(w0, w1);
|
||||
}
|
||||
// Load the BN x 16 activation strip (f32 -> bf16).
|
||||
#pragma unroll
|
||||
@@ -143,7 +154,7 @@ w4a16_gemm_kernel(
|
||||
const float * arow = (const float *)(src1 + (int64_t)(n0 + n) * nb11);
|
||||
a0 = arow[k]; a1 = arow[k + 1];
|
||||
}
|
||||
sB[idx] = __floats2bfloat162_rn(a0, a1);
|
||||
sB[n*SPAD + kk] = __floats2bfloat162_rn(a0, a1);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@@ -152,12 +163,12 @@ w4a16_gemm_kernel(
|
||||
#pragma unroll
|
||||
for (int fm = 0; fm < FM; ++fm) {
|
||||
const int mrow = (warp_m*FM + fm) * 16;
|
||||
load_generic(Af[fm], sW + mrow*KP, KP);
|
||||
load_ldmatrix(Af[fm], sW + mrow*SPAD, SPAD);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int fn = 0; fn < FN; ++fn) {
|
||||
const int ncol = (warp_n*FN + fn) * 8;
|
||||
load_generic(Bf[fn], sB + ncol*KP, KP);
|
||||
load_ldmatrix(Bf[fn], sB + ncol*SPAD, SPAD);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int fm = 0; fm < FM; ++fm) {
|
||||
@@ -228,7 +239,7 @@ bool ggml_cuda_w4a16_mul_mat(
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
// Block tile config: WM*WN warps compute BM(=WM*FM*16) x BN(=WN*FN*8).
|
||||
constexpr int WM = 2, WN = 2, FM = 2, FN = 4; // BM=64, BN=64, 4 warps
|
||||
constexpr int WM = 4, WN = 2, FM = 2, FN = 4; // BM=128, BN=64, 8 warps
|
||||
constexpr int BM = WM*FM*16;
|
||||
constexpr int BN = WN*FN*8;
|
||||
const dim3 grid((unsigned)((M + BM - 1) / BM), (unsigned)((N + BN - 1) / BN), 1);
|
||||
|
||||
Reference in New Issue
Block a user