mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-23 16:19:07 -04:00
feat(w4a16): conflict-free skew-pad ldmatrix + BM128/8w tile (q4_K +28%, q4_0 +40%)
P3b for the Blackwell (sm_120/121) W4A16 Marlin GEMM. Two combined changes
over the prior block-tiled kernel, both verified by a thermally-bracketed
cold A/B (committed measured identically before and after):
- Skew-padded shared layout: store the staged weight/activation rows at a
padded stride of 12 bf162 (8 data + 4 pad) and feed the tensor cores with
ldmatrix.x4 (A) / ldmatrix.x2 (B). ldmatrix's per-lane address is
row*stride; the natural stride 8 divides the 32-bank cycle and collides
rows 0,4,8,12 (2-way bank conflict). Skewing to 12 (still 16-byte aligned)
spreads {r*12 mod 32} across 8 distinct bank-quads, so both ldmatrix halves
are conflict-free at only +50% on the ~6 KB staged tile - unlike a 128-byte
-row XOR swizzle, which is conflict-free but needs 16 KB shared and
collapses occupancy on GB10 (measured 2.84 TFLOPS, worse than baseline).
- Larger tile: BM=128, BN=64, 8 warps (WM=4,WN=2,FM=2,FN=4), which cuts the
redundant per-M-block activation re-reads.
Cold A/B (q4_K n=512 / q4_0 n=512 via test-backend-ops perf; pp512/pp2048 via
llama-bench Qwen3-32B-Q4_K_M):
committed: 6.63 / 7.53 TFLOPS, pp512 119
this: 8.52 / 10.49 TFLOPS, pp512 148.5, pp2048 153.9 (+28% / +40% / +25%)
Parity gate GGML_CUDA_W4A16=1 test-backend-ops MUL_MAT = 1103/1103, flag set
and unset (byte-identical when unset). Still ~5.5x under MMQ (47 TFLOPS) and
does NOT beat MMQ yet; the q4_K limiter has now moved from the mma feed to the
per-element 6-bit superblock dequant (q4_0 scales to 15.8 TFLOPS with more
warps while q4_K stays ~8.5), so the offline weight prepack is the next unlock.
Plan doc P3 section updated with the sweep data and the corrected bottleneck.
Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -21,23 +21,28 @@
|
||||
// Thread layout: blockDim = (32, WM*WN). threadIdx.x is the warp lane (0..31,
|
||||
// required by mma.cuh get_i/get_j), threadIdx.y is the warp index.
|
||||
//
|
||||
// P3 structure:
|
||||
// - Step 1 (block tiling): WM*WN warps compute a BM(=WM*FM*16) x BN(=WN*FN*8)
|
||||
// output tile; each warp owns an FM x FN grid of m16n8 mma fragments. Replaces
|
||||
// P2's 1-warp-per-16x8 launch (kills warp underutilization).
|
||||
// - Step 2 (dequant reuse): the BM x 16 dequantized weight strip is staged once
|
||||
// per k-step in shared and reused across the block's whole BN span.
|
||||
// - Small shared footprint (one 16-deep k-step per buffer) keeps occupancy high,
|
||||
// so block-level parallelism hides the dequant + global-load latency. On this
|
||||
// path q4_0 and q4_K perform within ~12% of each other, so the dequant compute
|
||||
// is NOT the limiter - the shared-load / mma-feed throughput is. Measured
|
||||
// dead-ends (kept here so they are not re-tried blindly): a double-buffered
|
||||
// cp.async-style pipeline with a large KSTAGE (32 KB shared) collapsed
|
||||
// occupancy (8.8 -> 2.7 TFLOPS at q4_K n=512), and swapping load_generic for
|
||||
// ldmatrix regressed to 6.6 TFLOPS because the unswizzled row-major shared
|
||||
// layout makes ldmatrix bank-conflict. Beating MMQ here needs the full Marlin
|
||||
// machinery (XOR-swizzled shared layout + tuned async pipeline + offline
|
||||
// weight reshuffle), which is deferred (P3 step 4).
|
||||
// P3b step 1 - conflict-free shared layout via SKEW PADDING:
|
||||
// - WM*WN warps compute a BM(=WM*FM*16) x BN(=WN*FN*8) output tile; each warp
|
||||
// owns an FM x FN grid of m16n8k16 mma fragments accumulated in F32.
|
||||
// - Per 16-deep k-step the warps cooperatively dequant the BM x 16 Q4 weight
|
||||
// strip + load the BN x 16 f32->bf16 activation strip into shared, then feed
|
||||
// the tensor cores with ldmatrix.x4 (A) / ldmatrix.x2 (B).
|
||||
// - The shared rows are PADDED to SPAD(=12) bf162 instead of the natural 8.
|
||||
// ldmatrix's per-lane address is row*stride; with the natural stride 8 (a
|
||||
// divisor of the 32-bank / 128-byte cycle) rows 0,4,8,12 collide -> 2-way
|
||||
// bank conflict on every fragment load (this is why P3 measured a plain
|
||||
// ldmatrix swap as neutral). Skewing the stride to 12 (4-byte aligned, so
|
||||
// ldmatrix's 16-byte alignment holds) makes {r*12 mod 32} hit 8 distinct
|
||||
// bank-quads for r in 0..7, so both halves of ldmatrix.x4 and ldmatrix.x2 are
|
||||
// conflict-free. The pad costs only +50% on the small (~4 KB) staged tile, so
|
||||
// unlike a 128-byte-row XOR swizzle it does NOT collapse occupancy on GB10
|
||||
// (a wide-row swizzle pushed shared to 16 KB and dropped this to ~2.8 TFLOPS).
|
||||
//
|
||||
// Dead-ends already proven (do not re-try): a double-buffered KSTAGE=64 cp.async
|
||||
// pipeline collapsed occupancy (32 KB shared -> 2.7 TFLOPS); a plain ldmatrix on
|
||||
// the UNpadded layout was neutral (bank conflicts); a wide-row (BK=64) XOR swizzle
|
||||
// was conflict-free but occupancy-starved (16 KB shared -> 2.8 TFLOPS). Skew
|
||||
// padding gets the conflict-free feed at near-zero occupancy cost.
|
||||
|
||||
using namespace ggml_cuda_mma;
|
||||
|
||||
@@ -45,6 +50,11 @@ typedef tile<16, 8, nv_bfloat162> tile_A; // 16(M) x 16(K)
|
||||
typedef tile< 8, 8, nv_bfloat162> tile_B; // 8(N) x 16(K)
|
||||
typedef tile<16, 8, float> tile_C; // 16(M) x 8(N)
|
||||
|
||||
// bf162 columns actually live per shared row (16 k-values = 8 bf162) ...
|
||||
#define W4A16_KP 8
|
||||
// ... padded to this stride to bank-skew the ldmatrix row addresses.
|
||||
#define W4A16_SPAD 12
|
||||
|
||||
static bool w4a16_enabled() {
|
||||
static const bool en = (std::getenv("GGML_CUDA_W4A16") != nullptr);
|
||||
return en;
|
||||
@@ -99,7 +109,8 @@ w4a16_gemm_kernel(
|
||||
float * __restrict__ dst,
|
||||
const int M, const int N, const int K,
|
||||
const int64_t nb01, const int64_t nb11, const int64_t dst_ne0) {
|
||||
constexpr int KP = 8; // bf162 pairs per 16-wide k-step (row stride in shared)
|
||||
constexpr int KP = W4A16_KP; // 8 bf162 = 16 k per row
|
||||
constexpr int SPAD = W4A16_SPAD; // padded row stride (bank skew)
|
||||
constexpr int BM = WM*FM*16;
|
||||
constexpr int BN = WN*FN*8;
|
||||
constexpr int NTH = WM*WN*32;
|
||||
@@ -112,8 +123,8 @@ w4a16_gemm_kernel(
|
||||
const int warp_m = warp_id / WN;
|
||||
const int tid = threadIdx.y*32 + threadIdx.x;
|
||||
|
||||
__shared__ nv_bfloat162 sW[BM*KP]; // [m][kpair], row stride KP (16-byte aligned)
|
||||
__shared__ nv_bfloat162 sB[BN*KP]; // [n][kpair], row stride KP
|
||||
__shared__ nv_bfloat162 sW[BM*SPAD]; // [m][kpair], padded row stride SPAD
|
||||
__shared__ nv_bfloat162 sB[BN*SPAD]; // [n][kpair], padded row stride SPAD
|
||||
|
||||
tile_C C[FM][FN]; // zero-initialized accumulators
|
||||
|
||||
@@ -130,7 +141,7 @@ w4a16_gemm_kernel(
|
||||
if (IS_Q4_K) { w0 = w4a16_dq_q4_K(row, k); w1 = w4a16_dq_q4_K(row, k + 1); }
|
||||
else { w0 = w4a16_dq_q4_0(row, k); w1 = w4a16_dq_q4_0(row, k + 1); }
|
||||
}
|
||||
sW[idx] = __floats2bfloat162_rn(w0, w1);
|
||||
sW[m*SPAD + kk] = __floats2bfloat162_rn(w0, w1);
|
||||
}
|
||||
// Load the BN x 16 activation strip (f32 -> bf16).
|
||||
#pragma unroll
|
||||
@@ -143,7 +154,7 @@ w4a16_gemm_kernel(
|
||||
const float * arow = (const float *)(src1 + (int64_t)(n0 + n) * nb11);
|
||||
a0 = arow[k]; a1 = arow[k + 1];
|
||||
}
|
||||
sB[idx] = __floats2bfloat162_rn(a0, a1);
|
||||
sB[n*SPAD + kk] = __floats2bfloat162_rn(a0, a1);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@@ -152,12 +163,12 @@ w4a16_gemm_kernel(
|
||||
#pragma unroll
|
||||
for (int fm = 0; fm < FM; ++fm) {
|
||||
const int mrow = (warp_m*FM + fm) * 16;
|
||||
load_generic(Af[fm], sW + mrow*KP, KP);
|
||||
load_ldmatrix(Af[fm], sW + mrow*SPAD, SPAD);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int fn = 0; fn < FN; ++fn) {
|
||||
const int ncol = (warp_n*FN + fn) * 8;
|
||||
load_generic(Bf[fn], sB + ncol*KP, KP);
|
||||
load_ldmatrix(Bf[fn], sB + ncol*SPAD, SPAD);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int fm = 0; fm < FM; ++fm) {
|
||||
@@ -228,7 +239,7 @@ bool ggml_cuda_w4a16_mul_mat(
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
// Block tile config: WM*WN warps compute BM(=WM*FM*16) x BN(=WN*FN*8).
|
||||
constexpr int WM = 2, WN = 2, FM = 2, FN = 4; // BM=64, BN=64, 4 warps
|
||||
constexpr int WM = 4, WN = 2, FM = 2, FN = 4; // BM=128, BN=64, 8 warps
|
||||
constexpr int BM = WM*FM*16;
|
||||
constexpr int BN = WN*FN*8;
|
||||
const dim3 grid((unsigned)((M + BM - 1) / BM), (unsigned)((N + BN - 1) / BN), 1);
|
||||
|
||||
Reference in New Issue
Block a user