mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-23 08:08:52 -04:00
feat(w4a16): grow tile to BN128/16w (q4_K +17%, pp512 148->178)
P3b-2 for the Blackwell W4A16 Marlin GEMM. The q4_K dequant wall is partly
cross-N-block-redundant: every N-block re-decodes the same weight strip, so
halving the N-block count (BN 64->128) halves that redundant 6-bit superblock
decode. A BN sweep showed this only pays off when BN is spread across more
warps (16 warps, 8 m16n8 C-tiles/warp) rather than more fragments-per-warp -
the FN=8 / FM=4 variants (16 C-tiles/warp) regressed to ~6.6 TFLOPS on
register pressure. Shipping tile is now WM=4,WN=4,FM=2,FN=4 -> BM=128, BN=128,
16 warps.
Thermally-bracketed cold A/B (q4_K n=512 / q4_0 n=512 via test-backend-ops
perf; pp512/pp2048 via llama-bench Qwen3-32B-Q4_K_M):
BN64/8w (prev): 8.50 / 10.56 TFLOPS, measured 8.45/10.51 again (bracket)
BN128/16w (this): 9.92 / 11.68 TFLOPS, pp512 177.6, pp2048 185.0
-> +17% q4_K, +11% q4_0, +20% pp512 vs the previous commit; +49% pp512 vs
the original block-tiled kernel (119).
Parity gate GGML_CUDA_W4A16=1 test-backend-ops MUL_MAT = 1103/1103, flag set
and unset (byte-identical when unset). Still ~4.7x under MMQ (47 TFLOPS) and
does NOT beat MMQ; BN growth divides the redundant decode but cannot remove
the per-k-step decode itself - the offline weight prepack remains the next
unlock for q4_K. Plan doc P3 table + bottleneck notes updated.
Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -86,21 +86,24 @@ the `BN x 16` f32->bf16 activation strip into shared, one `__syncthreads`, then
|
||||
(`W4A16_SPAD`): ldmatrix's per-lane address is `row*stride`, and the natural stride 8 (a divisor of the
|
||||
32-bank / 128-byte cycle) collides rows 0,4,8,12 into a 2-way bank conflict; skewing to 12 (4-byte aligned, so
|
||||
ldmatrix's 16-byte alignment holds) makes `{r*12 mod 32}` hit 8 distinct bank-quads for r in 0..7, so both
|
||||
halves of ldmatrix are conflict-free at only +50% on the small (~6 KB) staged tile. Shipping config
|
||||
`WM=4,WN=2,FM=2,FN=4` -> `BM=128, BN=64`, 8 warps. M/N tails zero-padded in-kernel; still gated to contiguous
|
||||
2D Q4_0/Q4_K f32 prefill, else falls back to MMQ.
|
||||
halves of ldmatrix are conflict-free at only +50% on the small staged tile (~12 KB at the shipping tile).
|
||||
Shipping config `WM=4,WN=4,FM=2,FN=4` -> `BM=128, BN=128`, 16 warps, 8 m16n8 C-tiles per warp (keeping
|
||||
register pressure low is what lets BN grow without an occupancy cliff). M/N tails zero-padded in-kernel; still
|
||||
gated to contiguous 2D Q4_0/Q4_K f32 prefill, else falls back to MMQ.
|
||||
|
||||
**Per-step results (q4_K n=512 via `test-backend-ops perf`; pp512/pp2048 via llama-bench Qwen3-32B-Q4_K_M):**
|
||||
|
||||
| step | q4_K n=512 | q4_0 n=512 | pp512 | pp2048 | vs MMQ 47 / 718 | notes |
|
||||
|---|---|---|---|---|---|---|
|
||||
| P2 (1 warp/tile) | ~2 TFLOPS | - | 31.75 | - | 0.04x | correctness checkpoint |
|
||||
| Step 1: block tiling (load_generic, BM64/4w) | 6.63 (cold) | 7.53 | 119 | 123 | 0.14x | prior committed kernel |
|
||||
| **P3b: skew-pad ldmatrix + BM128/8w** | **8.52 (cold)** | **10.49** | **148.5** | **153.9** | **0.18x** | +28% q4_K, +40% q4_0, +25% pp512 over step 1 |
|
||||
| Step 1: block tiling (load_generic, BM64/4w) | 6.63 (cold) | 7.53 | 119 | 123 | 0.14x | original committed kernel |
|
||||
| P3b-1: skew-pad ldmatrix + BM128/8w | 8.50 (cold) | 10.56 | 148.5 | 153.9 | 0.18x | +28% q4_K, +40% q4_0 over step 1 |
|
||||
| **P3b-2: + BN128/16w (current)** | **9.92 (cold)** | **11.68** | **177.6** | **185.0** | **0.21x** | +17% q4_K, +20% pp512 over P3b-1 (+49% pp512 over step 1) |
|
||||
|
||||
Parity gate **1103/1103** at every step, flag set and unset (byte-identical when unset). All P3b numbers above
|
||||
are from a single thermally-bracketed cold A/B session (committed measured 6.63/7.53 immediately before AND
|
||||
after the P3b kernel, identical both times -> the deltas are real, not thermal).
|
||||
are from thermally-bracketed cold A/B sessions (committed measured immediately before AND after each candidate,
|
||||
identical both times -> the deltas are real, not thermal). P3b-1 cold A/B: 6.63/7.53 vs 8.52/10.49. P3b-2 cold
|
||||
A/B: BN64/8w 10.56/8.50 then 10.51/8.45 (bracket) vs BN128/16w 11.68/9.92.
|
||||
|
||||
**What landed / what was tried (honest):**
|
||||
- **P3b - LANDED (committed).** Two combined changes lift the prior committed kernel: (1) **skew-pad
|
||||
@@ -119,18 +122,25 @@ after the P3b kernel, identical both times -> the deltas are real, not thermal).
|
||||
when the tile grew (BM128/8w). A 5-config tile sweep then split the two quant types:
|
||||
- **q4_0 SCALES with warps/tiles** (7.7 -> 10.5 -> **15.8 TFLOPS at BM128/16w**): feed/global-traffic bound,
|
||||
helped by cutting redundant activation re-reads (more BM = fewer M-blocks each re-reading the act column).
|
||||
- **q4_K is now DEQUANT-COMPUTE bound** (stuck 6.7-8.5 across every tile; at 16 warps q4_0=15.8 but q4_K=6.8 -
|
||||
they diverge hard). This **refines P3's "within 12%" finding**: that held only in the low-throughput memory
|
||||
-bound regime; once the feed is unblocked, q4_K's per-element 6-bit superblock decode (`get_scale_min_k4` +
|
||||
superblock indexing, redone every k-step AND re-done per N-block) becomes the wall. BM256 regressed both
|
||||
(too few blocks / register pressure).
|
||||
- **Next blocker (the real q4_K unlock) = offline prepack.** The dequant wall is cross-block-redundant: the same
|
||||
q4_K weights are superblock-decoded by all 8 N-blocks. The fix is the **one-time offline repack** - decode the
|
||||
Q4 tensor ONCE into a cached device buffer keyed off the tensor data pointer, in a layout with the scale/min
|
||||
pre-applied (store reshuffled 4-bit + per-subblock bf16 d,m, ~1.25x the q4 size, NOT a full bf16 blow-up which
|
||||
would be ~4x), so the in-kernel path becomes a cheap `q*d - m` with coalesced loads. Then `cp.async`
|
||||
multi-stage (sized to NOT widen shared past the occupancy cliff) and **Stream-K** over M. These remain the
|
||||
multi-week core; **prepack is the highest-value next step for q4_K specifically.**
|
||||
- **q4_K is largely DEQUANT-COMPUTE bound** (the BM64/16w tile gives q4_0=15.8 but q4_K=6.8 - they diverge
|
||||
hard). This **refines P3's "within 12%" finding**: that held only in the low-throughput memory-bound regime;
|
||||
once the feed is unblocked, q4_K's per-element 6-bit superblock decode (`get_scale_min_k4` + superblock
|
||||
indexing, redone every k-step AND re-done by every N-block) becomes the wall. BM256 regressed both (too few
|
||||
blocks / register pressure).
|
||||
- **Growing BN partly relieves the q4_K dequant wall (P3b-2).** Because every N-block re-decodes the same
|
||||
weight strip, halving the N-block count (BN 64->128) halves that redundant q4_K decode - but only when BN is
|
||||
spread across MORE WARPS (16w, 8 C-tiles/warp), not more fragments-per-warp: the FN=8 / FM=4 variants (16
|
||||
C-tiles/warp) regressed to ~6.6 on register pressure, while WM=4,WN=4,FM=2,FN=4 (16w, 8 tiles/warp) lifted
|
||||
q4_K 8.5->9.9 and q4_0 10.6->11.7 cold. BN=256 was no better and costs more shared. **BN128/16w is the
|
||||
shipping tile.**
|
||||
- **Next blocker (the remaining q4_K unlock) = offline prepack.** BN growth only divides the redundant decode by
|
||||
the N-block count; it cannot remove the per-k-step decode itself. The full fix is the **one-time offline
|
||||
repack** - decode the Q4 tensor ONCE into a cached device buffer keyed off the tensor data pointer, in a layout
|
||||
with the scale/min pre-applied (store reshuffled 4-bit + per-subblock bf16 d,m, ~1.25x the q4 size, NOT a full
|
||||
bf16 blow-up which would be ~4x), so the in-kernel path becomes a cheap `q*d - m` with coalesced loads. Then
|
||||
`cp.async` multi-stage (sized to NOT widen shared past the occupancy cliff) and **Stream-K** over M. These
|
||||
remain the multi-week core; **prepack is the highest-value next step for q4_K specifically** (it should let
|
||||
q4_K join q4_0 on the feed-bound scaling curve instead of plateauing at ~10).
|
||||
- **Methodology note (unchanged):** the box thermally throttles under sustained perf+bench runs (identical code
|
||||
~8.8 cold vs ~6.6 hot earlier), so only same-session A/Bs are trustworthy. The P3b deltas above were taken in
|
||||
one bracketed cold session for exactly this reason.
|
||||
|
||||
@@ -239,7 +239,7 @@ bool ggml_cuda_w4a16_mul_mat(
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
// Block tile config: WM*WN warps compute BM(=WM*FM*16) x BN(=WN*FN*8).
|
||||
constexpr int WM = 4, WN = 2, FM = 2, FN = 4; // BM=128, BN=64, 8 warps
|
||||
constexpr int WM = 4, WN = 4, FM = 2, FN = 4; // BM=128, BN=128, 16 warps
|
||||
constexpr int BM = WM*FM*16;
|
||||
constexpr int BN = WN*FN*8;
|
||||
const dim3 grid((unsigned)((M + BM - 1) / BM), (unsigned)((N + BN - 1) / BN), 1);
|
||||
|
||||
Reference in New Issue
Block a user