mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-29 19:06:43 -04:00
feat(paged): tail-fusion (0042) + full-step decode CUDA graph default-on (0043); FP4-MMA W4A4 (0034) + Marlin W4A16 (0035) MoE-GEMM scaffolds default-off

0042 fuses the pre-norm residual add into RMSNorm (+0.5% prefill, bit-exact).
0043 makes the full-step MoE decode CUDA graph default-on (+2-4% decode,
bit-exact; removes ~18x per-step host kernel re-issue, A/B-confirmed).
0034 (native FP4-MMA W4A4) and 0035 (Marlin-style W4A16 grouped MoE GEMM) are
correct + bit-exact but regress vs the int8 FP4-MMQ in-backend on GB10 (bf16 MMA
is ~half the int8 rate); shipped default-off as validated mechanisms and
recorded negatives per the parity methodology.

Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
@@ -0,0 +1,638 @@
From 14824147a504b58cc8be2f127f7d6bedb672cfc9 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto <mudler@localai.io>
Date: Mon, 29 Jun 2026 00:11:22 +0200
Subject: [PATCH] feat(paged): native NVFP4 (W4A4) FP4-MMA large-M prefill GEMM
 (patch 0034)

Replace the rejected 0033 dequant->bf16 cuBLAS scaffold with a native FP4-MMA
(W4A4 block-scale OMMA) large-M GEMM that engages only at prefill, behind the
same LLAMA_FP4_PREFILL_M threshold, so decode / small-M stay byte-untouched.

KERNEL (ggml/src/ggml-cuda/fp4-gemm.{cu,cuh}): the VERIFIED PoC
(fp4_gemm_w4a4_opt.cu, NMSE=0 vs same-dequant f32) copied verbatim at its tuned
best config 128x128 / KBLK4 / STAGES2 / PAD4 (~103 TFLOP/s, beats cuBLAS bf16).
Preserved exactly: e4m3(true_scale) convention, the ldmatrix.sync.m8n8.x4 A-operand
load, the mma.sync.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64 OMMA, cp.async
multistage prefetch, register-resident accumulators, smem PAD. Activations are
quantized with the SAME math as quantize_mmq_nvfp4 (e4m3 amax/6 + the +/-2 code
search + ggml_cuda_float_to_fp4_e2m1), so it is bit-exact-by-construction with the
shipped FP4-MMQ path (only the K-reduction order differs, greedy-md5 gated).
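As a rough sanity check on the tuned configuration named above (128x128 / KBLK4 / STAGES2 / PAD4), the sketch below recomputes the per-block dynamic shared-memory footprint from the slab sizes used by the launcher in fp4-gemm.cu: 8 u32 of codes per 64-element block plus PAD u32 per row, and one u32 of packed scales per block. The namespace and the function name `smem_bytes` are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

namespace fp4_sketch {

// Hypothetical helper (not part of the patch): dynamic shared memory, in bytes,
// needed by one thread block of the tiled FP4 GEMM for a given tile config.
constexpr std::size_t smem_bytes(int BM, int BN, int KBLK, int STAGES, int PAD) {
    return (std::size_t) STAGES
         * ((std::size_t) BM * (KBLK * 8 + PAD)   // A codes, padded row stride
          + (std::size_t) BM * KBLK               // A scales, one u32 per block
          + (std::size_t) BN * (KBLK * 8 + PAD)   // W codes, padded row stride
          + (std::size_t) BN * KBLK)              // W scales, one u32 per block
         * sizeof(uint32_t);
}

// Tuned config from the commit message: 2 stages of 10240 u32 each.
static_assert(smem_bytes(128, 128, 4, 2, 4) == 81920, "expected 80 KiB per block");

} // namespace fp4_sketch
```

The result, 80 KiB, is above the 48 KiB of dynamic shared memory a kernel can use without opting in, which is consistent with the launcher calling cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize before the launch.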

DENSE: routed in ggml_cuda_mul_mat via ggml_cuda_fp4_prefill_should_engage()
(src0 NVFP4 + src1/dst f32, contiguous, non-transposed, 2D, Blackwell, M>thr,
N%128==0, K%256==0). Non-divisible shapes fall back to FP4-MMQ (NOT the rejected
bf16 cuBLAS path). LANDED + greedy-md5 byte-identical (on==off: "Paris").
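The routing conditions listed above can be sketched as a plain host-side predicate. This is a simplified illustration, not the patch's function: `GemmShape` and `fp4_prefill_would_engage` are hypothetical names, and the boolean fields stand in for the type, layout, and compute-capability checks that the real predicate performs on ggml tensors.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the tensor properties the real predicate inspects.
struct GemmShape {
    int64_t M;                // rows of activations (src1->ne[1] in the patch)
    int64_t N;                // rows of the weight matrix
    int64_t K;                // shared inner dimension
    bool    weights_nvfp4;    // src0 is NVFP4
    bool    io_f32;           // src1 and dst are f32
    bool    plain_2d_layout;  // contiguous, non-transposed, 2D
    bool    blackwell;        // device supports the FP4 block-scale MMA
};

// Returns true when a shape would take the native large-M path.
// threshold <= 0 means the feature is off (the default).
inline bool fp4_prefill_would_engage(const GemmShape & s, int64_t threshold) {
    if (threshold <= 0) {
        return false;
    }
    if (!s.weights_nvfp4 || !s.io_f32 || !s.plain_2d_layout || !s.blackwell) {
        return false;
    }
    if (s.M <= threshold) {
        return false;  // decode / small-M stays on the existing path
    }
    // Tile constraints: anything else falls back to FP4-MMQ.
    return s.N % 128 == 0 && s.K % 256 == 0;
}
```

For example, with a threshold of 256, a 512 x 4096 x 4096 NVFP4 GEMM would engage, while the same shape with N = 4100 would fall back to FP4-MMQ.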

MoE GROUPED (the actual prefill bottleneck): mmq.cu forces the grouped FP4-MMQ
id-path OFF at large M (n_experts>0), so mul_mat_id falls to its per-expert
host-sync loop where each expert slice flows back through ggml_cuda_mul_mat and
hits the native kernel per-expert. Prefill is not graph-replayed so this is safe;
decode keeps ne12<=threshold so the graph-safe MMQ id-path (patch 0025) is
untouched. LANDED via host-sync + greedy-md5 byte-identical (on==off).
FOLLOW-UP (flagged): a graph-safe ragged-batched grouped FP4-MMA kernel to remove
the per-expert host-sync loop; out of scope for this pass.

BUILD: arch=compute_121a,code=[compute_121a,sm_121a] already in build-cuda flags;
the kernel uses BLACKWELL_MMA_AVAILABLE/CP_ASYNC_AVAILABLE guards. Incremental
build-cuda green (ggml-cuda relinked, llama-server + llama-cli relinked).

Default-off (LLAMA_FP4_PREFILL_M=0 == stock); set env/-D to engage.
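The default-off gating can be illustrated with a small host-side sketch of how the threshold is resolved: the environment value wins when set, otherwise the compile-time LLAMA_FP4_PREFILL_M default (0) applies. The helper names below are hypothetical; the patch itself reads the variable once and caches it in a function-local static.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>

#ifndef LLAMA_FP4_PREFILL_M
#define LLAMA_FP4_PREFILL_M 0  // compile-time default: feature off
#endif

// Hypothetical helper: resolve the threshold from an environment string.
// Pass nullptr to model "variable not set".
inline int64_t resolve_prefill_threshold(const char * env_value) {
    return env_value != nullptr ? (int64_t) std::atoll(env_value)
                                : (int64_t) LLAMA_FP4_PREFILL_M;
}

// A threshold of 0 (or any non-positive value) leaves the stock path in place.
inline bool prefill_path_enabled(int64_t threshold) {
    return threshold > 0;
}
```

In a real process the argument would come from std::getenv("LLAMA_FP4_PREFILL_M"); passing the string explicitly here just keeps the sketch easy to exercise.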

Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
---
 ggml/src/ggml-cuda/fp4-gemm.cu  | 453 ++++++++++++++++++++++++++++++++
 ggml/src/ggml-cuda/fp4-gemm.cuh |  38 +++
 ggml/src/ggml-cuda/ggml-cuda.cu |  14 +
 ggml/src/ggml-cuda/mmq.cu       |  35 +--
 4 files changed, 525 insertions(+), 15 deletions(-)
 create mode 100644 ggml/src/ggml-cuda/fp4-gemm.cu
 create mode 100644 ggml/src/ggml-cuda/fp4-gemm.cuh

diff --git a/ggml/src/ggml-cuda/fp4-gemm.cu b/ggml/src/ggml-cuda/fp4-gemm.cu
new file mode 100644
index 0000000..86da551
--- /dev/null
+++ b/ggml/src/ggml-cuda/fp4-gemm.cu
@@ -0,0 +1,453 @@
+#include "fp4-gemm.cuh"
+
+#include <cfloat>
+#include <cstdint>
+#include <cstdlib>
+
+// ===========================================================================
+// [paged patch 0034] Native NVFP4 (W4A4) large-M GEMM. See fp4-gemm.cuh.
+//
+// The GEMM kernel, the m16n8k64 block-scale OMMA wrapper, the cp.async helpers and
+// the layout-split kernel are the VERIFIED PoC (fp4_gemm_w4a4_opt.cu, NMSE=0) copied
+// verbatim - do not "tidy" the index math, it is the load-bearing correctness.
+// ===========================================================================
+
+#define FP4_QK 64 // == QK_NVFP4
+#define FP4_SAW 8 // u32 per nvfp4 block qs (32 bytes)
+
+#ifndef LLAMA_FP4_PREFILL_M
+#define LLAMA_FP4_PREFILL_M 0
+#endif // LLAMA_FP4_PREFILL_M
+
+static int64_t ggml_cuda_fp4_prefill_m() {
+    static const int64_t m = [] {
+        const char * e = getenv("LLAMA_FP4_PREFILL_M");
+        return e != nullptr ? (int64_t) atoll(e) : (int64_t) LLAMA_FP4_PREFILL_M;
+    }();
+    return m;
+}
+
+// ---- layout split: block_nvfp4[R*Kb] -> qs codes [R*Kb*8 u32] + scales [R*Kb u32] ----
+// Same fp4 codes & e4m3 scale bytes as the GGUF, restored into two contiguous,
+// 16B-friendly arrays so the kernel's cp.async copies are coalesced. (PoC verbatim.)
+static __global__ void fp4_split_layout(
+        const block_nvfp4 * __restrict__ X, uint32_t * __restrict__ Q, uint32_t * __restrict__ S,
+        int R, int Kb) {
+    const int64_t b = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
+    const int64_t tot = (int64_t) R * Kb;
+    if (b >= tot) {
+        return;
+    }
+    const block_nvfp4 & blk = X[b];
+    const uint32_t * q = (const uint32_t *) blk.qs;
+    uint32_t * dq = &Q[b * 8];
+#pragma unroll
+    for (int w = 0; w < 8; w++) {
+        dq[w] = q[w];
+    }
+    S[b] = *(const uint32_t *) blk.d;
+}
+
+// ---- activation quantizer: f32 [M_real x K] -> split NVFP4 (Aq codes + As scales) ----
+// Uses the SAME math as quantize_mmq_nvfp4 (quantize.cu): e4m3 scale = ue4m3(amax/6)
+// with the +/-2 code search, ggml_cuda_float_to_fp4_e2m1 for the nibbles, so the
+// activation codes are identical to the shipped FP4-MMQ path. Packs into the PoC
+// block layout (qs[s*8+j] = code(e[j]) | code(e[j+8])<<4) expected by the kernel's
+// ldmatrix A-operand load. One thread per (row, kb, sub-block).
+static __global__ void fp4_quantize_act_split(
+        const float * __restrict__ x, uint32_t * __restrict__ Aq, uint32_t * __restrict__ As,
+        int M_real, int K, int Kb) {
+#ifdef BLACKWELL_MMA_AVAILABLE
+    const int64_t tot = (int64_t) M_real * Kb * 4; // 4 sub-blocks per 64-element block
+    const int64_t t = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
+    if (t >= tot) {
+        return;
+    }
+    const int sub = (int) (t & 3);
+    const int64_t rb = t >> 2; // row*Kb + kb
+    const int kb = (int) (rb % Kb);
+    const int64_t row = rb / Kb;
+
+    const float * v16 = x + row * (int64_t) K + (int64_t) kb * FP4_QK + sub * 16;
+    float vals[16];
+    float amax = 0.0f;
+#pragma unroll
+    for (int k = 0; k < 16; k++) {
+        const float vv = v16[k];
+        vals[k] = vv;
+        amax = fmaxf(amax, fabsf(vv));
+    }
+
+    static constexpr int test_offsets[5] = { 0, -1, 1, -2, 2 };
+    const int first_fp8_code = (int) ggml_cuda_fp32_to_ue4m3(amax / 6.0f);
+
+    float best_err = FLT_MAX;
+    uint8_t fp8_code = 0;
+    float subblock_scale = 0.0f;
+#pragma unroll
+    for (int i = 0; i < 5; i++) {
+        const int test_code = first_fp8_code + test_offsets[i];
+        if (test_code < 0 || test_code > 0x7e) {
+            continue;
+        }
+        const uint8_t code = (uint8_t) test_code;
+        const float test_scale = ggml_cuda_ue4m3_to_fp32(code);
+        const float test_inv_scale = test_scale > 0.0f ? 0.5f / test_scale : 0.0f;
+        float cur_err = 0.0f;
+#pragma unroll
+        for (int k = 0; k < 16; k++) {
+            const uint8_t q = ggml_cuda_float_to_fp4_e2m1(vals[k], test_inv_scale);
+            const float err_diff = fabsf(vals[k]) - fabsf((float) kvalues_mxfp4[q & 0x7]) * test_scale;
+            cur_err = fmaf(err_diff, err_diff, cur_err);
+        }
+        if (cur_err < best_err) {
+            best_err = cur_err;
+            fp8_code = code;
+            subblock_scale = test_scale;
+        }
+    }
+    const float inv_scale = subblock_scale > 0.0f ? 0.5f / subblock_scale : 0.0f;
+
+    // PoC packing: qs[s*8+j] = code(e[j]) | code(e[j+8])<<4 -> two u32 words per sub-block.
+    uint32_t w0 = 0, w1 = 0;
+#pragma unroll
+    for (int j = 0; j < 4; j++) {
+        const uint32_t lo = ggml_cuda_float_to_fp4_e2m1(vals[j], inv_scale);
+        const uint32_t hi = ggml_cuda_float_to_fp4_e2m1(vals[j + 8], inv_scale);
+        w0 |= ((lo | (hi << 4)) & 0xff) << (8 * j);
+    }
+#pragma unroll
+    for (int j = 0; j < 4; j++) {
+        const uint32_t lo = ggml_cuda_float_to_fp4_e2m1(vals[j + 4], inv_scale);
+        const uint32_t hi = ggml_cuda_float_to_fp4_e2m1(vals[j + 12], inv_scale);
+        w1 |= ((lo | (hi << 4)) & 0xff) << (8 * j);
+    }
+
+    const int64_t blk = row * (int64_t) Kb + kb;
+    Aq[blk * 8 + sub * 2 + 0] = w0;
+    Aq[blk * 8 + sub * 2 + 1] = w1;
+    reinterpret_cast<uint8_t *>(As + blk)[sub] = fp8_code;
+#else
+    GGML_UNUSED(x); GGML_UNUSED(Aq); GGML_UNUSED(As);
+    GGML_UNUSED(M_real); GGML_UNUSED(K); GGML_UNUSED(Kb);
+    NO_DEVICE_CODE;
+#endif // BLACKWELL_MMA_AVAILABLE
+}
+
+// ---- native FP4 block-scale OMMA wrapper (PoC verbatim) ----
+static __device__ __forceinline__ void fp4_mma(
+        float d[4], const uint32_t a[4], const uint32_t b[2], uint32_t as, uint32_t bs) {
+#ifdef BLACKWELL_MMA_AVAILABLE
+    asm volatile(
+        "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 "
+        "{%0,%1,%2,%3},{%4,%5,%6,%7},{%8,%9},{%0,%1,%2,%3},%10,{0,0},%11,{0,0};"
+        : "+f"(d[0]),"+f"(d[1]),"+f"(d[2]),"+f"(d[3])
+        : "r"(a[0]),"r"(a[1]),"r"(a[2]),"r"(a[3]),"r"(b[0]),"r"(b[1]),"r"(as),"r"(bs));
+#else
+    GGML_UNUSED(d); GGML_UNUSED(a); GGML_UNUSED(b); GGML_UNUSED(as); GGML_UNUSED(bs);
+    NO_DEVICE_CODE;
+#endif // BLACKWELL_MMA_AVAILABLE
+}
+
+// ---- cp.async helpers (PoC verbatim) ----
+static __device__ __forceinline__ void fp4_cp_async16(void * smem, const void * gmem) {
+#ifdef CP_ASYNC_AVAILABLE
+    unsigned s = (unsigned) __cvta_generic_to_shared(smem);
+    asm volatile("cp.async.cg.shared.global [%0],[%1],16;\n" :: "r"(s), "l"(gmem));
+#else
+    GGML_UNUSED(smem); GGML_UNUSED(gmem); NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+template<int B>
+static __device__ __forceinline__ void fp4_cp_async_small(void * smem, const void * gmem) {
+#ifdef CP_ASYNC_AVAILABLE
+    unsigned s = (unsigned) __cvta_generic_to_shared(smem);
+    asm volatile("cp.async.ca.shared.global [%0],[%1],%2;\n" :: "r"(s), "l"(gmem), "n"(B));
+#else
+    GGML_UNUSED(smem); GGML_UNUSED(gmem); NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+static __device__ __forceinline__ void fp4_cp_commit() {
+#ifdef CP_ASYNC_AVAILABLE
+    asm volatile("cp.async.commit_group;\n" ::);
+#else
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+template<int N>
+static __device__ __forceinline__ void fp4_cp_wait() {
+#ifdef CP_ASYNC_AVAILABLE
+    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
+#else
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+
+// ---------------------------------------------------------------------------
+// Optimized native FP4 GEMM (PoC verbatim). C[M,N] = A_fp4[M,K] @ W_fp4[N,K]^T
+// inputs are layout-split: Aq[M*Kb*8], As[M*Kb], Wq[N*Kb*8], Ws[N*Kb]
+// Tile BM x BN, K-step = KBLK nvfp4 blocks (BK = 64*KBLK), STAGES-deep pipeline,
+// PAD u32 padding per smem row to defeat bank conflicts.
+// ---------------------------------------------------------------------------
+template<int BM,int BN,int WARPS_M,int WARPS_N,int KBLK,int STAGES,int PAD>
+__launch_bounds__(WARPS_M*WARPS_N*32,1)
+static __global__ void fp4_opt_kernel(
+        const uint32_t * __restrict__ Aq, const uint32_t * __restrict__ As,
+        const uint32_t * __restrict__ Wq, const uint32_t * __restrict__ Ws,
+        float * __restrict__ C, int M, int N, int K) {
+#ifdef BLACKWELL_MMA_AVAILABLE
+    constexpr int NWARP=WARPS_M*WARPS_N;
+    constexpr int THREADS=NWARP*32;
+    constexpr int WM=BM/WARPS_M, WN=BN/WARPS_N;
+    constexpr int MF=WM/16, NF=WN/8;
+    constexpr int SAW=8; // u32 per block (qs)
+    constexpr int ARS=KBLK*SAW+PAD; // A smem row stride (u32)
+    constexpr int WRS=KBLK*SAW+PAD; // W smem row stride (u32)
+
+    extern __shared__ uint32_t smem[];
+    // per-stage slabs
+    constexpr int SZ_AQ=BM*ARS, SZ_AS=BM*KBLK, SZ_WQ=BN*WRS, SZ_WS=BN*KBLK;
+    constexpr int STAGE_SZ=SZ_AQ+SZ_AS+SZ_WQ+SZ_WS;
+    uint32_t* sAq[STAGES]; uint32_t* sAs[STAGES]; uint32_t* sWq[STAGES]; uint32_t* sWs[STAGES];
+#pragma unroll
+    for(int s=0;s<STAGES;s++){
+        uint32_t* base=smem+s*STAGE_SZ;
+        sAq[s]=base; sAs[s]=base+SZ_AQ; sWq[s]=base+SZ_AQ+SZ_AS; sWs[s]=base+SZ_AQ+SZ_AS+SZ_WQ;
+    }
+
+    const int tid=threadIdx.x, warp=tid>>5, lane=tid&31;
+    const int wrow=warp/WARPS_N, wcol=warp%WARPS_N;
+    const int grp=lane>>2, tig=lane&3;
+    const int tidxA = lane/4 + (lane%2)*8;
+    const int tidxB = lane/4;
+    const int blockRow=blockIdx.y*BM, blockCol=blockIdx.x*BN;
+    const int Kb=K/64;
+    const int numK=Kb/KBLK;
+
+    float acc[MF][NF][4];
+#pragma unroll
+    for(int i=0;i<MF;i++)for(int j=0;j<NF;j++)for(int r=0;r<4;r++)acc[i][j][r]=0;
+
+    // async-load k-tile `kt` into stage `st`
+    auto load_tile=[&](int st,int kt){
+        const int kb0=kt*KBLK;
+        // A qs: BM*KBLK blocks, 2x 16B chunks each
+#pragma unroll 1
+        for(int idx=tid; idx<BM*KBLK*2; idx+=THREADS){
+            int chunk=idx&1, blk=idx>>1;
+            int r=blk/KBLK, kb=blk%KBLK;
+            const uint32_t* src=&Aq[((size_t)(blockRow+r)*Kb + kb0+kb)*SAW + chunk*4];
+            fp4_cp_async16(&sAq[st][r*ARS + kb*SAW + chunk*4], src);
+        }
+        // W qs
+#pragma unroll 1
+        for(int idx=tid; idx<BN*KBLK*2; idx+=THREADS){
+            int chunk=idx&1, blk=idx>>1;
+            int r=blk/KBLK, kb=blk%KBLK;
+            const uint32_t* src=&Wq[((size_t)(blockCol+r)*Kb + kb0+kb)*SAW + chunk*4];
+            fp4_cp_async16(&sWq[st][r*WRS + kb*SAW + chunk*4], src);
+        }
+        // A scales: BM rows, KBLK contiguous u32 each
+#pragma unroll 1
+        for(int r=tid; r<BM; r+=THREADS){
+            const uint32_t* src=&As[(size_t)(blockRow+r)*Kb + kb0];
+            uint32_t* dst=&sAs[st][r*KBLK];
+            if(KBLK==4) fp4_cp_async16(dst,src);
+            else if(KBLK==2) fp4_cp_async_small<8>(dst,src);
+            else fp4_cp_async_small<4>(dst,src);
+        }
+        // W scales
+#pragma unroll 1
+        for(int r=tid; r<BN; r+=THREADS){
+            const uint32_t* src=&Ws[(size_t)(blockCol+r)*Kb + kb0];
+            uint32_t* dst=&sWs[st][r*KBLK];
+            if(KBLK==4) fp4_cp_async16(dst,src);
+            else if(KBLK==2) fp4_cp_async_small<8>(dst,src);
+            else fp4_cp_async_small<4>(dst,src);
+        }
+    };
+
+    // prologue: issue STAGES-1 tiles (tiles 0..STAGES-2 into stages 0..STAGES-2)
+#pragma unroll
+    for(int s=0;s<STAGES-1;s++){ if(s<numK) load_tile(s,s); fp4_cp_commit(); }
+
+    for(int kt=0; kt<numK; kt++){
+        // prefetch tile kt+STAGES-1 into its stage (overlaps this iter's compute)
+        int ld=kt+(STAGES-1);
+        if(ld<numK) load_tile(ld%STAGES,ld);
+        fp4_cp_commit();
+        // wait until tile kt has landed (leave STAGES-1 prefetches in flight)
+        fp4_cp_wait<STAGES-1>();
+        __syncthreads();
+
+        const int rs=kt%STAGES;
+#pragma unroll
+        for(int kb=0; kb<KBLK; kb++){
+            // A fragments via ldmatrix (PRESERVED layout)
+            uint32_t af[MF][4]; uint32_t asc[MF];
+#pragma unroll
+            for(int mi=0; mi<MF; mi++){
+                int rb=wrow*WM+mi*16;
+                const uint32_t* base=&sAq[rs][rb*ARS + kb*SAW];
+                const uint32_t* xs = base + (lane%16)*ARS + (lane/16)*4;
+                asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0,%1,%2,%3},[%4];"
+                    : "=r"(af[mi][0]),"=r"(af[mi][1]),"=r"(af[mi][2]),"=r"(af[mi][3])
+                    : "l"(xs));
+                asc[mi]=sAs[rs][(rb+tidxA)*KBLK+kb];
+            }
+            // B fragments (PRESERVED manual gather), padded row stride
+            uint32_t bf[NF][2]; uint32_t bsc[NF];
+#pragma unroll
+            for(int ni=0; ni<NF; ni++){
+                int nb=wcol*WN+ni*8;
+                const uint32_t* base=&sWq[rs][nb*WRS + kb*SAW];
+#pragma unroll
+                for(int l=0;l<2;l++){
+                    int gi=grp, gj=l*4+tig;
+                    bf[ni][l]=base[gi*WRS + gj];
+                }
+                bsc[ni]=sWs[rs][(nb+tidxB)*KBLK+kb];
+            }
+#pragma unroll
+            for(int mi=0;mi<MF;mi++)
+#pragma unroll
+                for(int ni=0;ni<NF;ni++)
+                    fp4_mma(acc[mi][ni], af[mi], bf[ni], asc[mi], bsc[ni]);
+        }
+        // ensure all warps finished reading stage rs before it is reused by a
+        // future prefetch (the stage is overwritten at iter kt+1's prefetch).
+        __syncthreads();
+    }
+
+#pragma unroll
+    for(int mi=0;mi<MF;mi++)
+#pragma unroll
+        for(int ni=0;ni<NF;ni++){
+            int orb=blockRow+wrow*WM+mi*16, ocb=blockCol+wcol*WN+ni*8;
+            float* d=acc[mi][ni];
+            C[(size_t)(orb+grp)*N+ocb+2*tig]    =d[0];
+            C[(size_t)(orb+grp)*N+ocb+2*tig+1]  =d[1];
+            C[(size_t)(orb+grp+8)*N+ocb+2*tig]  =d[2];
+            C[(size_t)(orb+grp+8)*N+ocb+2*tig+1]=d[3];
+        }
+#else
+    GGML_UNUSED(Aq); GGML_UNUSED(As); GGML_UNUSED(Wq); GGML_UNUSED(Ws);
+    GGML_UNUSED(C); GGML_UNUSED(M); GGML_UNUSED(N); GGML_UNUSED(K);
+    NO_DEVICE_CODE;
+#endif // BLACKWELL_MMA_AVAILABLE
+}
+
+// ===========================================================================
+// ggml integration
+// ===========================================================================
+
+bool ggml_cuda_fp4_prefill_should_engage(
+        const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * dst, int cc) {
+    if (src0->type != GGML_TYPE_NVFP4) {
+        return false;
+    }
+    if (!blackwell_mma_available(cc)) {
+        return false;
+    }
+    const int64_t thr = ggml_cuda_fp4_prefill_m();
+    if (thr <= 0) {
+        return false; // default-off == stock; decode/small-M untouched
+    }
+    if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
+        return false;
+    }
+    if (src1->ne[1] <= thr) {
+        return false; // M = src1->ne[1]; only LARGE M (prefill)
+    }
+    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+        return false;
+    }
+    if (ggml_is_transposed(src0) || ggml_is_transposed(src1)) {
+        return false;
+    }
+    // 2D only (a single weight matrix; per-expert MoE slices set ne[2]=ne[3]=1).
+    if (src0->ne[2] != 1 || src0->ne[3] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1) {
+        return false;
+    }
+    const int64_t K = src0->ne[0];
+    const int64_t N = src0->ne[1];
+    if (N % 128 != 0 || K % 256 != 0) {
+        return false; // tile constraints; otherwise fall back to MMQ
+    }
+    return true;
+}
+
+void ggml_cuda_mul_mat_fp4_large_m(
+        ggml_backend_cuda_context & ctx,
+        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_NVFP4);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+
+    const int64_t K = src0->ne[0];
+    const int64_t N = src0->ne[1];
+    const int64_t M = src1->ne[1];
+    const int64_t Kb = K / FP4_QK;
+    GGML_ASSERT(K % 256 == 0 && N % 128 == 0);
+
+    cudaStream_t stream = ctx.stream();
+
+    constexpr int BM = 128, BN = 128, WM = 4, WN = 2, KBLK = 4, STAGES = 2, PAD = 4;
+    const int64_t Mpad = ((M + BM - 1) / BM) * BM;
+
+    ggml_cuda_pool_alloc<uint32_t> Wq(ctx.pool(), (size_t) N * Kb * 8);
+    ggml_cuda_pool_alloc<uint32_t> Ws(ctx.pool(), (size_t) N * Kb);
+    ggml_cuda_pool_alloc<uint32_t> Aq(ctx.pool(), (size_t) Mpad * Kb * 8);
+    ggml_cuda_pool_alloc<uint32_t> As(ctx.pool(), (size_t) Mpad * Kb);
+
+    // Zero the scales of the padded A-rows (M..Mpad) so they contribute 0 (scale 0 ->
+    // the OMMA's per-block scale is 0). The padded qs may stay uninitialized.
+    if (Mpad > M) {
+        CUDA_CHECK(cudaMemsetAsync(As.get() + (size_t) M * Kb, 0,
+                                   (size_t) (Mpad - M) * Kb * sizeof(uint32_t), stream));
+    }
+
+    // split weights (GGUF block_nvfp4 -> Wq/Ws)
+    {
+        const int64_t tot = N * Kb;
+        const int threads = 256;
+        const int64_t grid = (tot + threads - 1) / threads;
+        fp4_split_layout<<<grid, threads, 0, stream>>>(
+            (const block_nvfp4 *) src0->data, Wq.get(), Ws.get(), (int) N, (int) Kb);
+        CUDA_CHECK(cudaGetLastError());
+    }
+    // quantize + split activations (real rows only)
+    {
+        const int64_t tot = M * Kb * 4;
+        const int threads = 256;
+        const int64_t grid = (tot + threads - 1) / threads;
+        fp4_quantize_act_split<<<grid, threads, 0, stream>>>(
+            (const float *) src1->data, Aq.get(), As.get(), (int) M, (int) K, (int) Kb);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
+    // Output: write the (Mpad x N) result straight into dst when M is tile-aligned,
+    // otherwise into a temp and copy back the first M rows (C is row-major C[m*N+n]).
+    float * Cout = (float *) dst->data;
+    ggml_cuda_pool_alloc<float> Ctmp(ctx.pool());
+    if (Mpad > M) {
+        Cout = Ctmp.alloc((size_t) Mpad * N);
+    }
+
+    auto kern = fp4_opt_kernel<BM, BN, WM, WN, KBLK, STAGES, PAD>;
+    constexpr int SZ_AQ = BM * (KBLK * 8 + PAD), SZ_AS = BM * KBLK;
+    constexpr int SZ_WQ = BN * (KBLK * 8 + PAD), SZ_WS = BN * KBLK;
+    constexpr int STAGE_SZ = SZ_AQ + SZ_AS + SZ_WQ + SZ_WS;
+    const int smem_bytes = STAGES * STAGE_SZ * (int) sizeof(uint32_t);
+    CUDA_CHECK(cudaFuncSetAttribute(kern, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
+
+    dim3 grid((unsigned) (N / BN), (unsigned) (Mpad / BM));
+    dim3 block(WM * WN * 32);
+    kern<<<grid, block, smem_bytes, stream>>>(
+        Aq.get(), As.get(), Wq.get(), Ws.get(), Cout, (int) Mpad, (int) N, (int) K);
+    CUDA_CHECK(cudaGetLastError());
+
+    if (Mpad > M) {
+        CUDA_CHECK(cudaMemcpyAsync(dst->data, Ctmp.get(), (size_t) M * N * sizeof(float),
+                                   cudaMemcpyDeviceToDevice, stream));
+    }
+}
diff --git a/ggml/src/ggml-cuda/fp4-gemm.cuh b/ggml/src/ggml-cuda/fp4-gemm.cuh
new file mode 100644
index 0000000..8ed1aa4
--- /dev/null
+++ b/ggml/src/ggml-cuda/fp4-gemm.cuh
@@ -0,0 +1,38 @@
+#pragma once
+
+#include "common.cuh"
+
+// [paged patch 0034] Native NVFP4 (W4A4) large-M GEMM for Blackwell sm_121a (GB10).
+//
+// A Marlin-class tiled FP4-MMA GEMM (cp.async multistage prefetch, register-resident
+// accumulators, ldmatrix A-operand, m16n8k64 mxf4nvf4 block-scale OMMA with e4m3
+// true-scale) that beats the dequant->bf16 cuBLAS (nvjet) path that the rejected 0033
+// scaffold routed large-M prefill through. The kernel body is the bit-exact PoC
+// (NMSE=0 vs a same-dequant f32 reference) at its tuned best config
+// (128x128 / KBLK4 / STAGES2 / PAD4).
+//
+// It is bit-exact-by-construction with the shipped FP4-MMQ path: it consumes the SAME
+// e2m1 weight nibbles + e4m3 scale bytes from the GGUF block_nvfp4, quantizes
+// activations with the SAME math as quantize_mmq_nvfp4 (e4m3 amax/6 scale + the +/-2
+// code search + ggml_cuda_float_to_fp4_e2m1), and feeds the SAME hardware OMMA. The
+// only difference vs FP4-MMQ is the K-accumulation order (a different but equivalent
+// f32 reduction tree), which is greedy-md5 gated like every other paged path.
+//
+// Engages ONLY at large M (prefill), behind the 0033 LLAMA_FP4_PREFILL_M threshold;
+// decode and small-M are byte-untouched and never reach this kernel.
+
+// True if the native FP4 large-M path should handle this dense NVFP4 mul_mat:
+// src0 NVFP4 + src1/dst f32, contiguous, not transposed, 2D, Blackwell,
+// LLAMA_FP4_PREFILL_M > 0, M = src1->ne[1] > threshold, N % 128 == 0, K % 256 == 0.
+// This single predicate also routes per-expert MoE slices (they flow through
+// ggml_cuda_mul_mat) into the native kernel.
+bool ggml_cuda_fp4_prefill_should_engage(
+        const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * dst, int cc);
+
+// Native FP4 W4A4 GEMM: dst[M,N] = src1_act[M,K] @ src0_w[N,K]^T.
+// src0 = NVFP4 weights, src1 = f32 activations, dst = f32. Streams on ctx.stream(),
+// pool-allocates scratch; no host sync. Caller must have checked
+// ggml_cuda_fp4_prefill_should_engage().
+void ggml_cuda_mul_mat_fp4_large_m(
+        ggml_backend_cuda_context & ctx,
+        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 2ecc971..a92003c 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -25,6 +25,7 @@
 #include "ggml-cuda/diagmask.cuh"
 #include "ggml-cuda/diag.cuh"
 #include "ggml-cuda/fattn.cuh"
+#include "ggml-cuda/fp4-gemm.cuh"
 #include "ggml-cuda/fwht.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
@@ -2541,6 +2542,19 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
 
+    // [paged patch 0034] Native NVFP4 (W4A4) large-M (prefill) FP4-MMA GEMM. Engages only
+    // when LLAMA_FP4_PREFILL_M>0 and M=src1->ne[1] exceeds it (and tile dims divide), so
+    // decode / small-M is byte-untouched. This also catches the per-expert MoE slices that
+    // flow through here from the mul_mat_id host-sync loop, routing each expert GEMM to the
+    // native kernel (see ggml_cuda_should_use_mmq's MoE gate in mmq.cu).
+    if (!split) {
+        const int cc_fp4 = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+        if (ggml_cuda_fp4_prefill_should_engage(src0, src1, dst, cc_fp4)) {
+            ggml_cuda_mul_mat_fp4_large_m(ctx, src0, src1, dst);
+            return;
+        }
+    }
+
     // If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q.
     // But if src0 is also a view of another tensor then this cannot be done safely because it may overwrite valid tensor data.
     // Therefore, in such cases use cuBLAS.
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 2dcaaab..694a402 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -321,24 +321,29 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
         return false;
     }
 
-    // Paged prefill lever (patch 0033): OPTION-(a) route large-M NVFP4 dense GEMMs
-    // OFF the FP4-MMQ kernel and through the dequant->bf16 cuBLAS (nvjet)
-    // tensor-core path (ggml_cuda_op_mul_mat_cublas, NVFP4 bf16 branch). The
-    // scope premise was that FP4-MMQ is register-bound to ~3% of FP4 peak at
-    // large M. MEASURED ON GB10 THIS IS FALSE: FP4-MMQ at M=512..2048 beats
-    // dequant->bf16 cuBLAS by 29-49% (S_PP A/B in docs/PREFILL_GEMM_RESULTS.md),
-    // because bf16 tensor-core peak is ~half FP4 peak AND the per-step weight
-    // dequant + 4x bf16 weight traffic (~8x total vs the FP4 read) dominate and
-    // only partially amortize as M grows. The path is NUMERICALLY VALID and
-    // benign (greedy md5 byte-identical to FP4-MMQ; test-backend-ops passes), so
-    // it is kept as a validated, env-gated scaffold (for option-(b) native FP4
-    // large-M kernels and non-GB10 hardware), but DEFAULT-DISABLED (== stock).
-    // Set -D LLAMA_FP4_PREFILL_M=<M> or env LLAMA_FP4_PREFILL_M=<M> to A/B it;
-    // 0 (default) disables. Dense only (n_experts == 0).
+    // Paged prefill lever (patch 0033 -> 0034): route large-M NVFP4 prefill GEMMs to the
+    // native FP4-MMA (W4A4 OMMA) kernel in fp4-gemm.cu instead of the FP4-MMQ kernel.
+    //
+    // - DENSE (n_experts == 0): the reroute happens earlier, in ggml_cuda_mul_mat's
+    //   ggml_cuda_fp4_prefill_should_engage() early check, which knows the N/K tile
+    //   divisibility. We deliberately do NOT force dense off MMQ here: if the native
+    //   kernel cannot take a shape (non-divisible N/K) MMQ stays the correct fallback,
+    //   NOT the rejected dequant->bf16 cuBLAS path.
+    // - MoE (n_experts > 0): force the grouped FP4-MMQ id-path OFF at large M so
+    //   mul_mat_id falls to its per-expert host-sync loop, where each expert slice flows
+    //   back through ggml_cuda_mul_mat and hits the native kernel. CUDA graphs are
+    //   disabled for that prefill step (prefill is not graph-replayed); a graph-safe
+    //   grouped (ragged-batched) FP4-MMA kernel is the flagged follow-up. Decode keeps
+    //   ne12 <= threshold so the grouped graph-safe MMQ id-path (patch 0025) is untouched.
+    //
+    // The historical 0033 finding stands: dequant->bf16 cuBLAS LOSES to FP4-MMQ at large M
+    // (bf16 tensor-core peak is ~half FP4 peak + 8x weight traffic), which is exactly why
+    // the native FP4-MMA kernel (NMSE=0, ~103 TFLOP/s, beats cuBLAS bf16) replaces it here.
+    // Set -D LLAMA_FP4_PREFILL_M=<M> or env LLAMA_FP4_PREFILL_M=<M>; 0 (default) == stock.
 #ifndef LLAMA_FP4_PREFILL_M
 #define LLAMA_FP4_PREFILL_M 0
 #endif // LLAMA_FP4_PREFILL_M
-    if (type == GGML_TYPE_NVFP4 && n_experts == 0 && blackwell_mma_available(cc)) {
+    if (type == GGML_TYPE_NVFP4 && n_experts > 0 && blackwell_mma_available(cc)) {
         static const int64_t fp4_prefill_m = [] {
             const char * e = getenv("LLAMA_FP4_PREFILL_M");
             return e != nullptr ? (int64_t) atoll(e) : (int64_t) LLAMA_FP4_PREFILL_M;
--
2.43.0

@@ -0,0 +1,572 @@
|
||||
From df186bd20a23a1baae92f2828fc68f240c115e7d Mon Sep 17 00:00:00 2001
|
||||
From: Ettore Di Giacinto <mudler@localai.io>
|
||||
Date: Mon, 29 Jun 2026 03:34:48 +0200
|
||||
Subject: [PATCH] feat(paged): Marlin-style W4A16 grouped MoE prefill GEMM
|
||||
(patch 0035)
|
||||
|
||||
Profile-validated #2 prefill lever: a DISTINCT kernel from the two prefill
rejects. NOT patch 0033 (separate-pass dequant -> bf16 cuBLAS/nvjet, lost to
FP4-MMQ at large M). NOT patch 0034 (native W4A4 FP4-MMA mxf4nvf4 OMMA, still
pays the quantize_mmq_nvfp4 activation-quant tax). This is the W4A16 shape vLLM
uses on sm_121: FP4 expert weights dequantized to bf16 IN REGISTERS right before
the MMA, activations kept bf16 (a cheap f32->bf16 cast, NO per-block amax/code
quantize -> ZERO activation-quant tax), standard bf16 m16n8k16 mma.sync (reuses
ggml/src/ggml-cuda/mma.cuh tiles) into f32 accumulators, cp.async multistage.
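
The in-register weight dequant described above can be mirrored by a small host
reference. This is a minimal sketch, not the patch's code. It assumes the
nibble layout that the kernel's B-fragment code reads (within each 16-element
sub-block, byte j holds element j in its low nibble and element j+8 in its high
nibble) and the standard E2M1 code values. The four per-sub-block scales are
passed in already decoded to float, so the ue4m3 decode and any factor folded
into kvalues_mxfp4 are deliberately not modeled. The names dequant_block64 and
kE2M1 are hypothetical.

```cpp
#include <array>
#include <cstdint>

// E2M1 code -> value (bit 3 is the sign). Assumption: true E2M1 values; the
// patch's kvalues_mxfp4 table may store them pre-scaled.
static const float kE2M1[16] = {
     0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
};

// Hypothetical host reference: dequantize one 64-element block.
//   qs    : 32 packed bytes (8 bytes per 16-element sub-block)
//   scale : 4 already-decoded per-sub-block scales
std::array<float, 64> dequant_block64(const uint8_t qs[32], const float scale[4]) {
    std::array<float, 64> out{};
    for (int sub = 0; sub < 4; ++sub) {
        for (int j = 0; j < 8; ++j) {
            const uint8_t b = qs[sub*8 + j];
            out[sub*16 + j]     = scale[sub] * kE2M1[b & 0x0F]; // low nibble  -> element j
            out[sub*16 + j + 8] = scale[sub] * kE2M1[b >> 4];   // high nibble -> element j+8
        }
    }
    return out;
}
```

A reference of this kind is mainly useful for checking a fragment-fill routine
element by element; it says nothing about performance.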
|
||||
|
||||
GROUPED (the actual prefill shape): one kernel launch over the mul_mat_id
token-sorted activation buffer (src1_sorted is already sorted-by-expert by the
existing host path), with a per-M-tile expert map so each output tile reads its
own expert weight matrix (src0 + expert*nb02); the ragged per-expert row tail is
masked. No per-expert kernel launch, no per-expert M-padding (vs the 0034
per-expert host-sync loop). The B (weight) fragment is filled by in-register
FP4->bf16 dequant via the tile get_i/get_j contract (correct-by-construction
vs ldmatrix); the A (activation) fragment is a bf16 ldmatrix.
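
The per-M-tile expert map can be illustrated on the host alone. The sketch
below follows the same splitting rule as the patch's host code (BM-row tiles,
no tile crossing an expert boundary) and the same padded-row formula, but the
TileMap struct and the function names build_tile_map and padded_rows are
hypothetical, introduced only for this example.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

struct TileMap {
    std::vector<int32_t> expert; // expert id per M-tile
    std::vector<int32_t> row0;   // first row of the tile in the sorted buffer
    std::vector<int32_t> rows;   // valid rows in the tile (<= BM)
};

// Split each expert's contiguous row range into BM-row tiles. No tile crosses
// an expert boundary, so the last tile of an expert may be ragged.
TileMap build_tile_map(const std::vector<int> & tokens_per_expert, int BM) {
    TileMap m;
    int64_t row = 0;
    for (std::size_t e = 0; e < tokens_per_expert.size(); ++e) {
        const int t = tokens_per_expert[e];
        for (int off = 0; off < t; off += BM) {
            m.expert.push_back((int32_t) e);
            m.row0.push_back((int32_t) (row + off));
            m.rows.push_back((int32_t) std::min(BM, t - off));
        }
        row += t;
    }
    return m;
}

// Rows to allocate so that a full BM-row read from any tile start stays in
// bounds: the BM-rounded length plus one extra BM of zero headroom.
int64_t padded_rows(int64_t total_rows, int BM) {
    return ((total_rows + BM - 1) / BM + 1) * BM;
}
```

With tokens_per_expert = {70, 0, 5} and BM = 64 this yields three tiles: two
for expert 0 (64 rows and a ragged 6) and one for expert 2 starting at row 70.
Experts with no tokens produce no tiles at all.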
|
||||
|
||||
ROUTING (default-off; distinct env from 0034):
- mmq.cu (ggml_cuda_should_use_mmq): NVFP4 + n_experts>0 + Blackwell +
  ne11(tokens) > LLAMA_W4A16_PREFILL_M returns false, so mul_mat_id falls to
  the token-sorting host path.
- ggml-cuda.cu (ggml_cuda_mul_mat_id): once src1_sorted is built, if
  ggml_cuda_w4a16_moe_grouped_should_engage() the grouped kernel replaces the
  per-expert GEMM loop (dst_sorted then scattered back as usual). Decode keeps
  ne12 <= threshold so the graph-safe grouped MMQ id-path (0025/0043) is
  untouched; non-MoE / non-NVFP4 / small-M are byte-untouched.
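
The threshold gate used by both routing points reduces to one comparison. The
sketch below is an illustration under stated assumptions: prefill_threshold and
w4a16_gate are hypothetical helper names, and unlike the patch (which caches
the value in a function-local static) the sketch re-reads the environment on
every call so that it is easy to exercise.

```cpp
#include <cstdint>
#include <cstdlib>

// Threshold from the environment, falling back to a compile-time default.
// Assumption: re-read on every call; the patch caches it once in a static.
int64_t prefill_threshold(const char * env_name, int64_t compiled_default) {
    const char * e = std::getenv(env_name);
    return e != nullptr ? (int64_t) std::atoll(e) : compiled_default;
}

// True when the grouped W4A16 path should take over: the threshold is set
// (> 0) and the token count is strictly above it. A threshold of 0 keeps the
// stock path for every token count.
bool w4a16_gate(int64_t n_tokens, int64_t threshold) {
    return threshold > 0 && n_tokens > threshold;
}
```

Because the comparison is strict, a prefill of exactly `threshold` tokens stays
on the stock path, and decode steps with a token count at or below the
threshold are never rerouted.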
|
||||
|
||||
TOGGLE / A-B: env (or -D) LLAMA_W4A16_PREFILL_M. 0 (default) == OFF == stock;
>0 engages for MoE prefill GEMMs with tokens > the value. LLAMA_W4A16_DEBUG=1
prints per-GEMM engagement (total_rows / n_tiles / max-tokens-per-expert).
|
||||
|
||||
VALIDATION (GB10, sm_121a, Qwen3.6-35B-A3B-NVFP4):
- test-backend-ops MUL_MAT_ID nvfp4 (CUDA0 vs CPU oracle), W4A16 forced
  (LLAMA_W4A16_PREFILL_M=1): 81/81 OK, 0 FAIL (incl. multi-tile-per-expert
  cases). The threading bug found here (mma.cuh tile ops use threadIdx.x AS the
  warp lane, so the block must be 2D (32,NWARP)) is fixed.
- greedy md5 (paged MoE, LLAMA_KV_PAGED=1): NOT-engaged (high threshold) ==
  OFF baseline 4a3fd812 BYTE-IDENTICAL (default-off is stock); engaged
  (120 grouped GEMMs on a 116-token prefill) is coherent + benign (a different
  but equivalent bf16-vs-Q8_1 K-reduction, like the documented paged-MoE path
  divergence), output near-identical to stock.
|
||||
|
||||
HONEST PERF (S_PP t/s, llama-batched-bench -fa on -ngl 99 -ntg 32 -npl 1,
LLAMA_KV_PAGED=1, OFF vs W4A16 thr=64), CURRENTLY A REGRESSION:
  npp 512 : 1096.7 -> 794.8 (-28%)
  npp 1024: 1413.5 -> 961.1 (-32%)
  npp 2048: 1671.3 -> 1069.6 (-36%)
Decode TG unaffected (~53 t/s both). The kernel is CORRECT but its first
untuned config (BM64/BN128/STAGES2, scalar in-register dequant, f32->bf16 cast
pre-pass, 4B weight cp.async, BM-tile ragged-utilization waste, per-GEMM host
tile-map + 3 H2D copies) does not yet beat the tuned FP4-MMQ grouped path on
GB10; it does not realize the profiled vLLM 2.16x. Ships DEFAULT-OFF (like 0033
scaffold / 0017) as the validated, env-gated mechanism + bit-exact gate for the
tuning follow-ups (deeper pipeline, ldmatrix/16B weight staging, smem-conflict
padding, larger/register-resident tiles, removing the cast pre-pass, dropping
the per-GEMM host map).
|
||||
|
||||
Build: arch=compute_121a,code=[compute_121a,sm_121a]; BLACKWELL_MMA_AVAILABLE /
AMPERE_MMA_AVAILABLE / CP_ASYNC_AVAILABLE guards (NO_DEVICE_CODE off-Blackwell).
|
||||
|
||||
Assisted-by: Claude:opus-4.8 [Claude Code]
|
||||
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
|
||||
---
|
||||
ggml/src/ggml-cuda/ggml-cuda.cu | 12 +
|
||||
ggml/src/ggml-cuda/mmq.cu | 17 ++
|
||||
ggml/src/ggml-cuda/w4a16-gemm.cu | 359 ++++++++++++++++++++++++++++++
|
||||
ggml/src/ggml-cuda/w4a16-gemm.cuh | 55 +++++
|
||||
4 files changed, 443 insertions(+)
|
||||
create mode 100644 ggml/src/ggml-cuda/w4a16-gemm.cu
|
||||
create mode 100644 ggml/src/ggml-cuda/w4a16-gemm.cuh
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 3151684..37e4d11 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "ggml-cuda/diag.cuh"
|
||||
#include "ggml-cuda/fattn.cuh"
|
||||
#include "ggml-cuda/fp4-gemm.cuh"
|
||||
+#include "ggml-cuda/w4a16-gemm.cuh"
|
||||
#include "ggml-cuda/fwht.cuh"
|
||||
#include "ggml-cuda/getrows.cuh"
|
||||
#include "ggml-cuda/im2col.cuh"
|
||||
@@ -2747,6 +2748,16 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
+ // [paged patch 0035] Marlin-style W4A16 grouped MoE prefill GEMM: one launch over the
|
||||
+ // token-sorted activation buffer (src1_sorted, already f32 + sorted-by-expert above) with a
|
||||
+ // per-tile expert map, in-register FP4->bf16 weight dequant + bf16 mma. Replaces the
|
||||
+ // per-expert host-sync GEMM loop. Engages only when LLAMA_W4A16_PREFILL_M>0 and ne12>thr
|
||||
+ // (large-M prefill); decode / non-NVFP4 keep the loop below (byte-identical to stock).
|
||||
+ if (ggml_cuda_w4a16_moe_grouped_should_engage(src0, src1, dst, cc)) {
|
||||
+ ggml_cuda_mul_mat_id_w4a16_grouped(ctx, src0,
|
||||
+ (const float *) src1_sorted.ptr, (float *) dst_sorted.ptr,
|
||||
+ tokens_per_expert.data(), ne02, ne10, ne0, stream);
|
||||
+ } else {
|
||||
char * src1_data_cur = (char *) src1_sorted.ptr;
|
||||
char * dst_data_cur = (char *) dst_sorted.ptr;
|
||||
for (int64_t i02 = 0; i02 < ne02; ++i02) {
|
||||
@@ -2795,6 +2806,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
src1_data_cur += src1_slice.nb[2];
|
||||
dst_data_cur += dst_slice.nb[2];
|
||||
}
|
||||
+ }
|
||||
|
||||
get_rows_cuda(dst_sorted.ptr, type_dst_sorted, ids_from_sorted, dst->data, dst->type,
|
||||
ne0, ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted,
|
||||
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
|
||||
index 694a402..dc5c2d1 100644
|
||||
--- a/ggml/src/ggml-cuda/mmq.cu
|
||||
+++ b/ggml/src/ggml-cuda/mmq.cu
|
||||
@@ -353,6 +353,23 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
|
||||
}
|
||||
}
|
||||
|
||||
+ // Paged prefill lever (patch 0035): the Marlin-style W4A16 grouped MoE GEMM also needs the
|
||||
+ // grouped FP4-MMQ id-path forced OFF at large M so mul_mat_id falls to the token-sorting
|
||||
+ // host path, where the grouped W4A16 kernel is dispatched (in-register FP4->bf16 dequant +
|
||||
+ // bf16 mma, ZERO activation-quant). Distinct env from 0034; default 0 == stock.
|
||||
+#ifndef LLAMA_W4A16_PREFILL_M
|
||||
+#define LLAMA_W4A16_PREFILL_M 0
|
||||
+#endif // LLAMA_W4A16_PREFILL_M
|
||||
+ if (type == GGML_TYPE_NVFP4 && n_experts > 0 && blackwell_mma_available(cc)) {
|
||||
+ static const int64_t w4a16_prefill_m = [] {
|
||||
+ const char * e = getenv("LLAMA_W4A16_PREFILL_M");
|
||||
+ return e != nullptr ? (int64_t) atoll(e) : (int64_t) LLAMA_W4A16_PREFILL_M;
|
||||
+ }();
|
||||
+ if (w4a16_prefill_m > 0 && ne11 > w4a16_prefill_m) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
if (turing_mma_available(cc)) {
|
||||
return true;
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/w4a16-gemm.cu b/ggml/src/ggml-cuda/w4a16-gemm.cu
|
||||
new file mode 100644
|
||||
index 0000000..f348f31
|
||||
--- /dev/null
|
||||
+++ b/ggml/src/ggml-cuda/w4a16-gemm.cu
|
||||
@@ -0,0 +1,359 @@
|
||||
+#include "w4a16-gemm.cuh"
|
||||
+#include "mma.cuh"
|
||||
+
|
||||
+#include <algorithm>
|
||||
+#include <cstdint>
|
||||
+#include <cstdlib>
|
||||
+#include <vector>
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// [paged patch 0035] Marlin-style W4A16 grouped MoE prefill GEMM. See w4a16-gemm.cuh.
|
||||
+//
|
||||
+// In-register FP4->bf16 weight dequant + bf16 activations + bf16 m16n8k16 mma.sync (mma.cuh),
|
||||
+// cp.async multistage, grouped (ragged, per-tile expert offset) over the token-sorted buffer.
|
||||
+// ===========================================================================
|
||||
+
|
||||
+using namespace ggml_cuda_mma;
|
||||
+typedef tile<16, 8, nv_bfloat162> tile_A; // A operand: M=16, K=16
|
||||
+typedef tile< 8, 8, nv_bfloat162> tile_B; // B operand: N=8, K=16
|
||||
+typedef tile<16, 8, float> tile_C; // accumulator: M=16, N=8
|
||||
+
|
||||
+#ifndef LLAMA_W4A16_PREFILL_M
|
||||
+#define LLAMA_W4A16_PREFILL_M 0
|
||||
+#endif // LLAMA_W4A16_PREFILL_M
|
||||
+
|
||||
+int64_t ggml_cuda_w4a16_prefill_m() {
|
||||
+ static const int64_t m = [] {
|
||||
+ const char * e = getenv("LLAMA_W4A16_PREFILL_M");
|
||||
+ return e != nullptr ? (int64_t) atoll(e) : (int64_t) LLAMA_W4A16_PREFILL_M;
|
||||
+ }();
|
||||
+ return m;
|
||||
+}
|
||||
+
|
||||
+bool ggml_cuda_w4a16_prefill_enabled() {
|
||||
+ return ggml_cuda_w4a16_prefill_m() > 0;
|
||||
+}
|
||||
+
|
||||
+// ---- cp.async helpers (sm80+; raw bytes, no cast) ----
|
||||
+static __device__ __forceinline__ void w4a16_cp_async16(void * smem, const void * gmem) {
|
||||
+#ifdef CP_ASYNC_AVAILABLE
|
||||
+ const unsigned s = (unsigned) __cvta_generic_to_shared(smem);
|
||||
+ asm volatile("cp.async.cg.shared.global [%0],[%1],16;\n" :: "r"(s), "l"(gmem));
|
||||
+#else
|
||||
+ GGML_UNUSED(smem); GGML_UNUSED(gmem); NO_DEVICE_CODE;
|
||||
+#endif // CP_ASYNC_AVAILABLE
|
||||
+}
|
||||
+static __device__ __forceinline__ void w4a16_cp_async4(void * smem, const void * gmem) {
|
||||
+#ifdef CP_ASYNC_AVAILABLE
|
||||
+ const unsigned s = (unsigned) __cvta_generic_to_shared(smem);
|
||||
+ asm volatile("cp.async.ca.shared.global [%0],[%1],4;\n" :: "r"(s), "l"(gmem));
|
||||
+#else
|
||||
+ GGML_UNUSED(smem); GGML_UNUSED(gmem); NO_DEVICE_CODE;
|
||||
+#endif // CP_ASYNC_AVAILABLE
|
||||
+}
|
||||
+static __device__ __forceinline__ void w4a16_cp_commit() {
|
||||
+#ifdef CP_ASYNC_AVAILABLE
|
||||
+ asm volatile("cp.async.commit_group;\n" ::);
|
||||
+#else
|
||||
+ NO_DEVICE_CODE;
|
||||
+#endif // CP_ASYNC_AVAILABLE
|
||||
+}
|
||||
+template<int N> static __device__ __forceinline__ void w4a16_cp_wait() {
|
||||
+#ifdef CP_ASYNC_AVAILABLE
|
||||
+ asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
|
||||
+#else
|
||||
+ NO_DEVICE_CODE;
|
||||
+#endif // CP_ASYNC_AVAILABLE
|
||||
+}
|
||||
+
|
||||
+// ---- f32 -> bf16 activation cast (NO quantize). Pads the [total_rows, pad_rows) tail with 0. ----
|
||||
+static __global__ void w4a16_cast_act_f32_bf16(
|
||||
+ const float * __restrict__ x, nv_bfloat16 * __restrict__ y, int64_t n, int64_t npad) {
|
||||
+ const int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
|
||||
+ if (i >= npad) {
|
||||
+ return;
|
||||
+ }
|
||||
+ y[i] = i < n ? __float2bfloat16(x[i]) : (nv_bfloat16) 0.0f;
|
||||
+}
|
||||
+
|
||||
+// ---------------------------------------------------------------------------
|
||||
+// Grouped W4A16 GEMM. For each output tile (blockIdx.x = N-block, blockIdx.y = M-tile):
|
||||
+// expert e = g_tile_expert[blockIdx.y]
|
||||
+// row_start = g_tile_row0[blockIdx.y] (absolute row in the sorted buffer)
|
||||
+// row_count = g_tile_rows[blockIdx.y] (valid rows in this tile, <= BM)
|
||||
+// Weights read from W = src0 + e*expert_stride_blocks (block_nvfp4 [N,Kb]); activations from
|
||||
+// Abf (bf16, sorted); output to C (f32, sorted, [N, total_rows] = C[row*N + col]).
|
||||
+// Weights are dequantized FP4->bf16 in registers; A via ldmatrix; bf16 m16n8k16 mma.
|
||||
+// BK = 64 (one nvfp4 block per K-step); STAGES-deep cp.async pipeline over the Kb blocks.
|
||||
+// ---------------------------------------------------------------------------
|
||||
+template<int BM, int BN, int WARPS_M, int WARPS_N, int STAGES>
|
||||
+__launch_bounds__(WARPS_M*WARPS_N*32, 1)
|
||||
+static __global__ void w4a16_grouped_kernel(
|
||||
+ const nv_bfloat16 * __restrict__ Abf, // [pad_rows, K] bf16
|
||||
+ const block_nvfp4 * __restrict__ W0, // src0 base (expert 0)
|
||||
+ float * __restrict__ C, // [total_rows, N] f32
|
||||
+ const int * __restrict__ g_tile_expert,
|
||||
+ const int * __restrict__ g_tile_row0,
|
||||
+ const int * __restrict__ g_tile_rows,
|
||||
+ int N, int K, int64_t expert_stride_blocks) {
|
||||
+#if defined(AMPERE_MMA_AVAILABLE) && defined(CP_ASYNC_AVAILABLE)
|
||||
+ constexpr int BK = 64; // one nvfp4 block
|
||||
+ constexpr int NWARP = WARPS_M*WARPS_N;
|
||||
+ constexpr int THREADS = NWARP*32;
|
||||
+ constexpr int WM = BM/WARPS_M, WN = BN/WARPS_N;
|
||||
+ constexpr int MF = WM/16, NF = WN/8;
|
||||
+
|
||||
+ constexpr int AN = BK/2; // bf16 pairs per A smem row (nv_bfloat162)
|
||||
+ constexpr int SZ_A = BM*AN; // nv_bfloat162 per stage
|
||||
+ constexpr int SZ_WQ = BN*8; // u32 per stage (32 qs bytes/row)
|
||||
+ constexpr int SZ_WD = BN; // u32 per stage (4 scale bytes/row)
|
||||
+
|
||||
+ extern __shared__ uint32_t smem_u32[];
|
||||
+ // Layout per stage: [A as u32 = nv_bfloat162][Wq u32][Wd u32]
|
||||
+ constexpr int STAGE_U32 = SZ_A + SZ_WQ + SZ_WD;
|
||||
+ nv_bfloat162 * sA[STAGES];
|
||||
+ uint32_t * sWq[STAGES];
|
||||
+ uint32_t * sWd[STAGES];
|
||||
+#pragma unroll
|
||||
+ for (int s = 0; s < STAGES; s++) {
|
||||
+ uint32_t * base = smem_u32 + s*STAGE_U32;
|
||||
+ sA[s] = (nv_bfloat162 *) base;
|
||||
+ sWq[s] = base + SZ_A;
|
||||
+ sWd[s] = base + SZ_A + SZ_WQ;
|
||||
+ }
|
||||
+
|
||||
+ // mma.cuh's tile ops (load_ldmatrix / mma / tile::get_i/get_j) use threadIdx.x AS THE WARP LANE,
|
||||
+ // so the block MUST be 2D (32, NWARP): threadIdx.x = lane (0..31), threadIdx.y = warp.
|
||||
+ const int lane = threadIdx.x; // 0..31
|
||||
+ const int warp = threadIdx.y; // 0..NWARP-1
|
||||
+ const int tid = warp*32 + lane; // linear id for the cp.async strided copies
|
||||
+ const int wrow = warp / WARPS_N, wcol = warp % WARPS_N;
|
||||
+
|
||||
+ const int e = g_tile_expert[blockIdx.y];
|
||||
+ const int row0 = g_tile_row0[blockIdx.y];
|
||||
+ const int rcount = g_tile_rows[blockIdx.y];
|
||||
+ const int blockCol = blockIdx.x*BN;
|
||||
+ const int Kb = K/64;
|
||||
+ const block_nvfp4 * We = W0 + (int64_t) e*expert_stride_blocks; // expert e weight base
|
||||
+
|
||||
+ tile_C acc[MF][NF];
|
||||
+
|
||||
+ // async-load K-block `kt` into stage `st`
|
||||
+ auto load_tile = [&](int st, int kt) {
|
||||
+ // A: BM rows x BK bf16 = BM x AN nv_bfloat162 = BM x (BK/8) 16B chunks
|
||||
+ const int A_chunks = BM*(BK/8);
|
||||
+#pragma unroll 1
|
||||
+ for (int idx = tid; idx < A_chunks; idx += THREADS) {
|
||||
+ const int c = idx % (BK/8); // 16B chunk in the row
|
||||
+ const int r = idx / (BK/8); // row in tile
|
||||
+ const nv_bfloat16 * src = Abf + (int64_t)(row0 + r)*K + (int64_t)kt*BK + c*8;
|
||||
+ w4a16_cp_async16(((char *) sA[st]) + (r*AN + c*4)*sizeof(uint32_t), src);
|
||||
+ }
|
||||
+ // W qs: BN rows x 32 bytes = BN x 8 u32 (each block's qs at byte offset 4)
|
||||
+#pragma unroll 1
|
||||
+ for (int idx = tid; idx < BN*8; idx += THREADS) {
|
||||
+ const int w = idx & 7; // u32 word in the 32-byte qs
|
||||
+ const int r = idx >> 3; // row in tile
|
||||
+ const block_nvfp4 * blk = We + (int64_t)(blockCol + r)*Kb + kt;
|
||||
+ const char * src = ((const char *) blk) + 4 /*d[4]*/ + w*4;
|
||||
+ w4a16_cp_async4(&sWq[st][r*8 + w], src);
|
||||
+ }
|
||||
+ // W scales: BN rows x 4 bytes (one u32 each, the block's d[4] at byte offset 0)
|
||||
+#pragma unroll 1
|
||||
+ for (int r = tid; r < BN; r += THREADS) {
|
||||
+ const block_nvfp4 * blk = We + (int64_t)(blockCol + r)*Kb + kt;
|
||||
+ w4a16_cp_async4(&sWd[st][r], (const char *) blk);
|
||||
+ }
|
||||
+ };
|
||||
+
|
||||
+ // prologue
|
||||
+#pragma unroll
|
||||
+ for (int s = 0; s < STAGES-1; s++) { if (s < Kb) load_tile(s, s); w4a16_cp_commit(); }
|
||||
+
|
||||
+ for (int kt = 0; kt < Kb; kt++) {
|
||||
+ const int ld = kt + (STAGES-1);
|
||||
+ if (ld < Kb) load_tile(ld % STAGES, ld);
|
||||
+ w4a16_cp_commit();
|
||||
+ w4a16_cp_wait<STAGES-1>();
|
||||
+ __syncthreads();
|
||||
+
|
||||
+ const int rs = kt % STAGES;
|
||||
+ const nv_bfloat162 * sAcur = sA[rs];
|
||||
+ const uint8_t * sWqb = (const uint8_t *) sWq[rs]; // BN rows x 32 bytes
|
||||
+ const uint32_t * sWdw = sWd[rs]; // BN rows x 1 u32 (4 scale bytes)
|
||||
+
|
||||
+#pragma unroll
|
||||
+ for (int kk = 0; kk < BK/16; kk++) { // 4 m16n8k16 sub-steps per 64-block
|
||||
+ const int sub = kk; // sub-block (0..3): selects scale + nibble half
|
||||
+ // A fragments via ldmatrix (bf16)
|
||||
+ tile_A A_frag[MF];
|
||||
+#pragma unroll
|
||||
+ for (int mi = 0; mi < MF; mi++) {
|
||||
+ const int rb = wrow*WM + mi*16;
|
||||
+ load_ldmatrix(A_frag[mi], sAcur + rb*AN + kk*8, AN);
|
||||
+ }
|
||||
+ // B fragments: in-register FP4->bf16 dequant (correct-by-construction via tile get_i/get_j)
|
||||
+ tile_B B_frag[NF];
|
||||
+ const int n_local = lane >> 2; // tile_B::get_i (row N, 0..7)
|
||||
+ const int jc = lane & 3; // lane%4
|
||||
+ const int qbyte = sub*8 + 2*jc; // qs byte index for this lane within the block
|
||||
+#pragma unroll
|
||||
+ for (int ni = 0; ni < NF; ni++) {
|
||||
+ const int nrow = wcol*WN + ni*8 + n_local; // col within BN tile [0,BN)
|
||||
+ const uint8_t * qsb = sWqb + nrow*32; // this row's 32 qs bytes
|
||||
+ const uint8_t b0 = qsb[qbyte];
|
||||
+ const uint8_t b1 = qsb[qbyte + 1];
|
||||
+ const float sc = ggml_cuda_ue4m3_to_fp32(((const uint8_t *) &sWdw[nrow])[sub]);
|
||||
+ // x[0]: low nibbles (k = 2jc, 2jc+1)
|
||||
+ B_frag[ni].x[0].x = __float2bfloat16(sc * (float) kvalues_mxfp4[b0 & 0x0F]);
|
||||
+ B_frag[ni].x[0].y = __float2bfloat16(sc * (float) kvalues_mxfp4[b1 & 0x0F]);
|
||||
+ // x[1]: high nibbles (k = 8+2jc, 9+2jc)
|
||||
+ B_frag[ni].x[1].x = __float2bfloat16(sc * (float) kvalues_mxfp4[b0 >> 4]);
|
||||
+ B_frag[ni].x[1].y = __float2bfloat16(sc * (float) kvalues_mxfp4[b1 >> 4]);
|
||||
+ }
|
||||
+#pragma unroll
|
||||
+ for (int mi = 0; mi < MF; mi++)
|
||||
+#pragma unroll
|
||||
+ for (int ni = 0; ni < NF; ni++)
|
||||
+ mma(acc[mi][ni], A_frag[mi], B_frag[ni]);
|
||||
+ }
|
||||
+ __syncthreads();
|
||||
+ }
|
||||
+
|
||||
+ // write back (mask the ragged per-expert row tail)
|
||||
+#pragma unroll
|
||||
+ for (int mi = 0; mi < MF; mi++)
|
||||
+#pragma unroll
|
||||
+ for (int ni = 0; ni < NF; ni++) {
|
||||
+ const int orow = wrow*WM + mi*16;
|
||||
+ const int ocol = blockCol + wcol*WN + ni*8;
|
||||
+#pragma unroll
|
||||
+ for (int l = 0; l < acc[mi][ni].ne; l++) {
|
||||
+ const int lr = orow + acc[mi][ni].get_i(l); // local row within tile
|
||||
+ const int nc = ocol + acc[mi][ni].get_j(l); // global col
|
||||
+ if (lr < rcount && nc < N) {
|
||||
+ C[(int64_t)(row0 + lr)*N + nc] = acc[mi][ni].x[l];
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+#else
|
||||
+ GGML_UNUSED(Abf); GGML_UNUSED(W0); GGML_UNUSED(C);
|
||||
+ GGML_UNUSED(g_tile_expert); GGML_UNUSED(g_tile_row0); GGML_UNUSED(g_tile_rows);
|
||||
+ GGML_UNUSED(N); GGML_UNUSED(K); GGML_UNUSED(expert_stride_blocks);
|
||||
+ NO_DEVICE_CODE;
|
||||
+#endif // AMPERE_MMA_AVAILABLE && CP_ASYNC_AVAILABLE
|
||||
+}
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// host integration
|
||||
+// ===========================================================================
|
||||
+
|
||||
+bool ggml_cuda_w4a16_moe_grouped_should_engage(
|
||||
+ const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * dst, int cc) {
|
||||
+ if (src0->type != GGML_TYPE_NVFP4) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ if (!blackwell_mma_available(cc)) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ if (!ggml_cuda_w4a16_prefill_enabled()) {
|
||||
+ return false; // default-off == stock
|
||||
+ }
|
||||
+ if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ // ne12 = total tokens (aggregate prefill M); only LARGE M (prefill), never decode.
|
||||
+ if (src1->ne[2] <= ggml_cuda_w4a16_prefill_m()) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ const int64_t K = src0->ne[0];
|
||||
+ const int64_t N = src0->ne[1];
|
||||
+ if (N % 128 != 0 || K % 64 != 0) {
|
||||
+ return false; // tile constraints; else fall back to per-expert/MMQ
|
||||
+ }
|
||||
+ return true;
|
||||
+}
|
||||
+
|
||||
+void ggml_cuda_mul_mat_id_w4a16_grouped(
|
||||
+ ggml_backend_cuda_context & ctx,
|
||||
+ const ggml_tensor * src0,
|
||||
+ const float * src1_sorted,
|
||||
+ float * dst_sorted,
|
||||
+ const int * tokens_per_expert,
|
||||
+ int64_t n_experts, int64_t K, int64_t N,
|
||||
+ cudaStream_t stream) {
|
||||
+ GGML_ASSERT(src0->type == GGML_TYPE_NVFP4);
|
||||
+ GGML_ASSERT(N % 128 == 0 && K % 64 == 0);
|
||||
+
|
||||
+ constexpr int BM = 64, BN = 128, WARPS_M = 2, WARPS_N = 4, STAGES = 2;
|
||||
+
|
||||
+ // host: build the per-M-tile expert map (ragged, no tile crosses an expert boundary)
|
||||
+ int64_t total_rows = 0;
|
||||
+ for (int64_t e = 0; e < n_experts; e++) {
|
||||
+ total_rows += tokens_per_expert[e];
|
||||
+ }
|
||||
+ if (total_rows == 0) {
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ std::vector<int32_t> h_tile_expert, h_tile_row0, h_tile_rows;
|
||||
+ int64_t row = 0;
|
||||
+ for (int64_t e = 0; e < n_experts; e++) {
|
||||
+ const int t = tokens_per_expert[e];
|
||||
+ for (int off = 0; off < t; off += BM) {
|
||||
+ h_tile_expert.push_back((int32_t) e);
|
||||
+ h_tile_row0.push_back((int32_t) (row + off));
|
||||
+ h_tile_rows.push_back((int32_t) std::min(BM, t - off));
|
||||
+ }
|
||||
+ row += t;
|
||||
+ }
|
||||
+ const int n_tiles = (int) h_tile_expert.size();
|
||||
+
|
||||
+ if (getenv("LLAMA_W4A16_DEBUG")) {
|
||||
+ int max_tpe = 0, multi = 0;
|
||||
+ for (int64_t e = 0; e < n_experts; e++) {
|
||||
+ if (tokens_per_expert[e] > max_tpe) max_tpe = tokens_per_expert[e];
|
||||
+ if (tokens_per_expert[e] > BM) multi++;
|
||||
+ }
|
||||
+ fprintf(stderr, "[w4a16] engaged: total_rows=%lld n_experts=%lld K=%lld N=%lld n_tiles=%d max_tpe=%d multi_tile_experts=%d\n",
|
||||
+ (long long) total_rows, (long long) n_experts, (long long) K, (long long) N, n_tiles, max_tpe, multi);
|
||||
+ }
|
||||
+
|
||||
+ // device: tile map
|
||||
+ ggml_cuda_pool_alloc<int32_t> d_tile_expert(ctx.pool(), n_tiles);
|
||||
+ ggml_cuda_pool_alloc<int32_t> d_tile_row0 (ctx.pool(), n_tiles);
|
||||
+ ggml_cuda_pool_alloc<int32_t> d_tile_rows (ctx.pool(), n_tiles);
|
||||
+ CUDA_CHECK(cudaMemcpyAsync(d_tile_expert.ptr, h_tile_expert.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
|
||||
+ CUDA_CHECK(cudaMemcpyAsync(d_tile_row0.ptr, h_tile_row0.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
|
||||
+ CUDA_CHECK(cudaMemcpyAsync(d_tile_rows.ptr, h_tile_rows.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
|
||||
+
|
||||
+ // activations: f32 -> bf16 (cheap cast, NO act-quant), zero-padded so every tile's BM-row read
|
||||
+ // stays in-bounds. A tile's row0 is generally NOT BM-aligned (experts start mid-buffer), and a
|
||||
+ // tile can begin as late as total_rows-1, so it can read up to total_rows-1+BM; add a full BM of
|
||||
+ // zero headroom on top of the BM-rounded length to cover that worst case.
|
||||
+ const int64_t pad_rows = (((total_rows + BM - 1) / BM) + 1) * BM;
|
||||
+ ggml_cuda_pool_alloc<nv_bfloat16> Abf(ctx.pool(), (size_t) pad_rows * K);
|
||||
+ {
|
||||
+ const int64_t n = total_rows * K;
|
||||
+ const int64_t npad = pad_rows * K;
|
||||
+ const int threads = 256;
|
||||
+ const int64_t grid = (npad + threads - 1) / threads;
|
||||
+ w4a16_cast_act_f32_bf16<<<grid, threads, 0, stream>>>(src1_sorted, Abf.get(), n, npad);
|
||||
+ CUDA_CHECK(cudaGetLastError());
|
||||
+ }
|
||||
+
|
||||
+ const int64_t expert_stride_blocks = (int64_t) (src0->nb[2] / sizeof(block_nvfp4));
|
||||
+
|
||||
+ auto kern = w4a16_grouped_kernel<BM, BN, WARPS_M, WARPS_N, STAGES>;
|
||||
+ constexpr int STAGE_U32 = BM*(64/2) + BN*8 + BN;
|
||||
+ const int smem_bytes = STAGES * STAGE_U32 * (int) sizeof(uint32_t);
|
||||
+ CUDA_CHECK(cudaFuncSetAttribute(kern, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
|
||||
+
|
||||
+ dim3 grid((unsigned) (N / BN), (unsigned) n_tiles);
|
||||
+ dim3 block(32, WARPS_M*WARPS_N); // 2D: threadIdx.x = warp lane, threadIdx.y = warp
|
||||
+ kern<<<grid, block, smem_bytes, stream>>>(
|
||||
+ Abf.get(), (const block_nvfp4 *) src0->data, dst_sorted,
|
||||
+ d_tile_expert.ptr, d_tile_row0.ptr, d_tile_rows.ptr,
|
||||
+ (int) N, (int) K, expert_stride_blocks);
|
||||
+ CUDA_CHECK(cudaGetLastError());
|
||||
+}
|
||||
diff --git a/ggml/src/ggml-cuda/w4a16-gemm.cuh b/ggml/src/ggml-cuda/w4a16-gemm.cuh
|
||||
new file mode 100644
|
||||
index 0000000..2287d6f
|
||||
--- /dev/null
|
||||
+++ b/ggml/src/ggml-cuda/w4a16-gemm.cuh
|
||||
@@ -0,0 +1,55 @@
|
||||
+#pragma once
|
||||
+
|
||||
+#include "common.cuh"
|
||||
+
|
||||
+// [paged patch 0035] Marlin-style W4A16 GROUPED MoE prefill GEMM for Blackwell sm_121a (GB10).
|
||||
+//
|
||||
+// This is the profile-validated #2 prefill lever and a DISTINCT kernel from the two prefill
|
||||
+// rejects:
|
||||
+// - NOT patch 0033 (separate-pass dequant -> bf16 cuBLAS / nvjet): that pays a full per-step
|
||||
+// weight dequant + 4x bf16 weight traffic and lost to FP4-MMQ at large M.
|
||||
+// - NOT patch 0034 (native W4A4 FP4-MMA, mxf4nvf4 block-scale OMMA): that quantizes the
|
||||
+// activations to FP4 and so still pays the quantize_mmq_nvfp4 activation-quant tax.
|
||||
+//
|
||||
+// The winning shape vLLM uses on this silicon (Marlin W4A16): the FP4 expert weights are
|
||||
+// dequantized to bf16 IN REGISTERS right before the MMA (never materialized to global/smem as
|
||||
+// bf16), the activations stay bf16 (a cheap f32->bf16 cast, NO per-block FP4 amax/code-search
|
||||
+// quantize), and the product is a standard bf16 m16n8k16 mma.sync feeding f32 accumulators,
|
||||
+// cp.async multistage-pipelined over the K loop. So W4A16 pays ZERO activation-quant (the paged
|
||||
+// FP4-MMQ path's quantize_mmq_nvfp4 is +15 us/tok) and the GEMM runs as a bf16 tensor-core GEMM
|
||||
+// with the weight read at 4 bits.
|
||||
+//
|
||||
+// GROUPED: the kernel is launched ONCE over the whole mul_mat_id token-sorted activation buffer
|
||||
+// (src1_sorted is already sorted-by-expert by the existing host-loop), with a per-M-tile expert
|
||||
+// map so each output tile reads its expert's weight matrix (src0 + expert*nb02) and the ragged
|
||||
+// per-expert row tail is masked. No per-expert kernel launch, no per-expert M-padding waste.
|
||||
+//
|
||||
+// Engages ONLY at large aggregate-M (prefill), behind LLAMA_W4A16_PREFILL_M (default 0 == OFF
|
||||
+// == stock); decode (small ne12) and the non-MoE / non-NVFP4 paths are byte-untouched. The bf16
|
||||
+// tiles are mma.cuh's (mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32).
|
||||
+
|
||||
+// True if the grouped W4A16 path should handle this mul_mat_id:
|
||||
+// src0 NVFP4, src1 f32, dst f32, Blackwell, LLAMA_W4A16_PREFILL_M>0,
|
||||
+// ne12 (total tokens / aggregate prefill M) > threshold, N=ne0 % 128 == 0, K=ne10 % 64 == 0.
|
||||
+bool ggml_cuda_w4a16_moe_grouped_should_engage(
|
||||
+ const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * dst, int cc);
|
||||
+
|
||||
+// True iff LLAMA_W4A16_PREFILL_M > 0 (the master on/off for the mmq.cu grouped-MMQ-off gate).
|
||||
+bool ggml_cuda_w4a16_prefill_enabled();
|
||||
+int64_t ggml_cuda_w4a16_prefill_m();
|
||||
+
|
||||
+// Grouped W4A16 MoE GEMM over the token-sorted buffer.
|
||||
+// src0 : NVFP4 weights [K, N, n_experts] (one [K,N] matrix per expert)
|
||||
+// src1_sorted : f32 [K, total_rows], rows already sorted by expert (the mul_mat_id host-loop's
|
||||
+// src1_sorted), with tokens_per_expert[e] consecutive rows per expert e
|
||||
+// dst_sorted : f32 [N, total_rows], written in the same sorted order
|
||||
+// tokens_per_expert : host vector, length n_experts
|
||||
+// Streams on `stream`, pool-allocates scratch (bf16 activations + device tile map); no host sync.
|
||||
+void ggml_cuda_mul_mat_id_w4a16_grouped(
|
||||
+ ggml_backend_cuda_context & ctx,
|
||||
+ const ggml_tensor * src0,
|
||||
+ const float * src1_sorted,
|
||||
+ float * dst_sorted,
|
||||
+ const int * tokens_per_expert,
|
||||
+ int64_t n_experts, int64_t K, int64_t N,
|
||||
+ cudaStream_t stream);
|
||||
--
|
||||
2.43.0
|
||||
|
||||
@@ -0,0 +1,365 @@
|
||||
From 1434cf7e078217c625062dcfde4fa91cf487ee86 Mon Sep 17 00:00:00 2001
|
||||
From: Ettore Di Giacinto <mudler@localai.io>
|
||||
Date: Sun, 28 Jun 2026 20:19:31 +0200
|
||||
Subject: [PATCH] feat(paged): fused residual-add + RMS norm + weight multiply
|
||||
(patch 0042)
|
||||
|
||||
The transformer pre-norm residual chain `h = x + sub_out; n = rms_norm(h) * w`
runs as separate CUDA launches in the paged prefill graph: a k_bin_bcast ADD
(the residual) feeding the existing fused rms_norm+mul. ggml-cuda already fuses
rms_norm+mul (and rms_norm+mul+ADD, where the ADD is a *post*-norm bias) but NOT
the *pre*-norm residual add that feeds the norm. This is the classic add-RMSNorm
fusion (as in vLLM / TensorRT-LLM) that ggml-cuda lacks; it is part of the
unfused-tail prefill gap vs vLLM's torch.compile fusions.
|
||||
|
||||
Add it as a CUDA-family graph fusion (paged series owns it; stock stays pure):
- ggml_cuda_can_fuse recognizes { ADD, RMS_NORM, MUL } via ggml_can_fuse_subgraph
  with BOTH the ADD (node_idx) and the MUL (node_idx+2) marked as outputs - the
  residual ADD has a second consumer (the later skip-connection add), so it
  cannot pass the single-use ggml_can_fuse() gate the other rms_norm fusions use.
- New kernel rms_norm_pre_add_mul_f32 computes h = a + b, publishes h to the
  residual buffer (downstream skip add reads it), then sum(h^2) -> scale ->
  dst = scale * h * w in ONE launch, emitting BOTH outputs the graph needs.
- Gated by LLAMA_FUSE_ADD_RMSNORM (default ON) for a clean single-build A/B.
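
The math the fused kernel performs per row can be written as a short CPU
reference. This is a sketch under stated assumptions, not the kernel: the
function name add_rms_norm_mul is hypothetical, it handles a single row, and it
sums serially, so it does not reproduce the block_reduce summation order that
the bit-exact claim depends on.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for one row. Writes h = a + b to `residual` (the
// second output that the later skip-connection add reads) and returns
// scale * h * w, with scale = 1 / sqrt(mean(h^2) + eps).
std::vector<float> add_rms_norm_mul(const std::vector<float> & a,
                                    const std::vector<float> & b,
                                    const std::vector<float> & w,
                                    float eps,
                                    std::vector<float> & residual) {
    const std::size_t n = a.size();
    residual.resize(n);
    float sum_sq = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        residual[i] = a[i] + b[i];            // pre-norm residual add
        sum_sq += residual[i] * residual[i];  // running sum(h^2)
    }
    const float scale = 1.0f / std::sqrt(sum_sq / (float) n + eps);
    std::vector<float> dst(n);
    for (std::size_t i = 0; i < n; ++i) {
        dst[i] = scale * residual[i] * w[i];  // normalize, then weight multiply
    }
    return dst;
}
```

The point of the fusion is visible in the signature: one pass produces both the
residual stream and the normalized output, where the unfused graph needs a
separate ADD launch before the norm.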
|
||||
|
||||
BIT-EXACT (per-path canonical greedy md5, n=48 --temp 0 --seed 1, paged):
|
||||
dense q36-27b-nvfp4 : 5951a5b4d624ce891e22ab5fca9bc439 (ON == OFF == canonical)
|
||||
MoE q36-35b-a3b : 8cb0ce23777bf55f92f63d0292c756b0 (ON == OFF == canonical)
|
||||
The fused kernel reproduces the exact FP order of the unfused chain: h = a + b
|
||||
(IEEE add is order-free), the sum(h^2) reduction uses the same block_reduce<SUM>
|
||||
with the same 256/1024 block-size thresholds, and the same rsqrtf(mean+eps)
|
||||
scale, so the byte stream is unchanged. test-backend-ops RMS_NORM/ADD/MUL pass
|
||||
(CUDA0 vs CPU).
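
The bit-exactness argument above can be illustrated with a host-side sketch. This is not the patch's kernel: the function names (`unfused_row`, `fused_row`, `same_bits`) are hypothetical, the reduction is a plain sequential sum rather than the GPU `block_reduce`, and `1.0f / std::sqrt(...)` stands in for `rsqrtf`. It only shows that folding the add into the norm leaves every intermediate value unchanged when the reduction order is held fixed.

```cpp
#include <cassert>
#include <cmath>
#include <cstring>
#include <vector>

// Unfused reference: three separate passes, mirroring ADD -> RMS_NORM -> MUL.
static void unfused_row(const std::vector<float>& a, const std::vector<float>& b,
                        const std::vector<float>& w, float eps,
                        std::vector<float>& h, std::vector<float>& dst) {
    const size_t n = a.size();
    h.resize(n);
    dst.resize(n);
    for (size_t i = 0; i < n; ++i) h[i] = a[i] + b[i];          // ADD
    float sumsq = 0.0f;
    for (size_t i = 0; i < n; ++i) sumsq += h[i] * h[i];        // RMS_NORM: reduce
    const float scale = 1.0f / std::sqrt(sumsq / static_cast<float>(n) + eps);
    std::vector<float> normed(n);
    for (size_t i = 0; i < n; ++i) normed[i] = scale * h[i];    // RMS_NORM: scale
    for (size_t i = 0; i < n; ++i) dst[i] = normed[i] * w[i];   // MUL
}

// Fused sketch: the add and the sum of squares share one pass; h is still
// written out because a later skip-connection add consumes it.
static void fused_row(const std::vector<float>& a, const std::vector<float>& b,
                      const std::vector<float>& w, float eps,
                      std::vector<float>& h, std::vector<float>& dst) {
    const size_t n = a.size();
    h.resize(n);
    dst.resize(n);
    float sumsq = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        const float hi = a[i] + b[i];
        h[i] = hi;
        sumsq += hi * hi;
    }
    const float scale = 1.0f / std::sqrt(sumsq / static_cast<float>(n) + eps);
    for (size_t i = 0; i < n; ++i) dst[i] = scale * h[i] * w[i];
}

// Byte-wise comparison, the host analogue of comparing output md5 sums.
static bool same_bits(const std::vector<float>& x, const std::vector<float>& y) {
    return x.size() == y.size() &&
           std::memcmp(x.data(), y.data(), x.size() * sizeof(float)) == 0;
}
```

Calling both functions on the same `a`, `b`, `w` and comparing `h` and `dst` with `same_bits` is the host analogue of the ON == OFF md5 check; on the GPU the equivalent guarantee additionally depends on keeping the block size and `block_reduce` path identical, as the message states.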

PROFILE (dense prefill, nsys --cuda-graph-trace=node, npp512 ntg4 npl8):
  rms_norm_f32<1024>          903 launches /  96.6M ns  ->    7 / 0.7M ns
  k_bin_bcast<op_add>        1232 launches / 138.6M ns  ->  336 / 1.0M ns
  rms_norm_pre_add_mul (new)  896 launches / 187.2M ns
  -> 896 residual-add + 896 rms_norm launches folded into 896 fused launches;
     the norm+residual slice 233.6M -> 187.2M ns (~20% of that slice, ~1% of
     total prefill GPU time).
S_PP dense (npp512 ntg4 npl32, 3x): 985.5 -> 990.6 t/s (+0.5%, every ON run
beats every OFF run). Modest because the residual tail is a small slice of
prefill; the dominant unfused cost is k_bin_bcast<op_mul> (11%, the GDN
chunked-prefill gating muls) - a separate lever.
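
The profile figures above can be cross-checked with a few lines of arithmetic. The helper names below are hypothetical and exist only for this check; the inputs are the numbers quoted in the message.

```cpp
#include <cassert>
#include <cmath>

// Launches that disappeared from a kernel's row in the profile.
static int launches_removed(int before, int after) {
    return before - after;
}

// Relative change in percent; negative means the "after" value is smaller.
static double percent_change(double before, double after) {
    return (after - before) / before * 100.0;
}
```

With the quoted values, `launches_removed(903, 7)` and `launches_removed(1232, 336)` both give 896, matching the 896 fused launches; `percent_change(233.6, 187.2)` is about -19.9 (the "~20% of that slice"), and `percent_change(985.5, 990.6)` is about +0.52 (the "+0.5%" S_PP gain).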
|
||||
|
||||
Assisted-by: Claude:opus-4.8 [Claude Code]
|
||||
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
|
||||
---
|
||||
ggml/src/ggml-cuda/ggml-cuda.cu | 54 +++++++++
|
||||
ggml/src/ggml-cuda/norm.cu | 196 ++++++++++++++++++++++++++++++++
|
||||
ggml/src/ggml-cuda/norm.cuh | 5 +
|
||||
3 files changed, 255 insertions(+)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 0dad6e1..2ecc971 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -3698,6 +3698,48 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
|
||||
}
|
||||
}
|
||||
|
||||
+ // Fused residual-add + RMS norm + weight multiply. The transformer residual
|
||||
+ // ADD feeds the next sublayer's RMS norm but is ALSO consumed by the later
|
||||
+ // residual add (skip connection), so the ADD node is a graph output too; it
|
||||
+ // cannot go through the single-use ggml_can_fuse() gate below. Recognize it
|
||||
+ // here with ggml_can_fuse_subgraph, marking both the ADD (node_idx) and the
|
||||
+ // final MUL (node_idx + 2) as outputs.
|
||||
+ std::initializer_list<enum ggml_op> add_rms_norm_mul_ops = { GGML_OP_ADD, GGML_OP_RMS_NORM, GGML_OP_MUL };
|
||||
+ if (is_equal(add_rms_norm_mul_ops, ops) &&
|
||||
+ ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx, node_idx + 2 })) {
|
||||
+ const ggml_tensor * add = cgraph->nodes[node_idx];
|
||||
+ const ggml_tensor * rms_norm = cgraph->nodes[node_idx + 1];
|
||||
+ const ggml_tensor * mul = cgraph->nodes[node_idx + 2];
|
||||
+
|
||||
+ // RMS norm must consume the residual-add output.
|
||||
+ if (rms_norm->src[0] != add) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ // All operands F32 (rms norm / fused mul kernel only support F32).
|
||||
+ if (add->src[0]->type != GGML_TYPE_F32 || add->src[1]->type != GGML_TYPE_F32 ||
|
||||
+ add->type != GGML_TYPE_F32 || rms_norm->type != GGML_TYPE_F32 ||
|
||||
+ mul->src[0]->type != GGML_TYPE_F32 || mul->src[1]->type != GGML_TYPE_F32 ||
|
||||
+ mul->type != GGML_TYPE_F32) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ // The fused kernel computes h = a + b elementwise: same shape, no broadcast.
|
||||
+ if (!ggml_are_same_shape(add->src[0], add->src[1])) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ // rms_norm kernel assumes contiguous rows for the residual operands and weight.
|
||||
+ if (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous(add->src[1])) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ // If rms_norm is the B operand of the mul, broadcast of the A operand is unsupported.
|
||||
+ if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ return true;
|
||||
+ }
|
||||
+
|
||||
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
||||
return false;
|
||||
}
|
||||
@@ -4220,6 +4262,18 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
|
||||
return fused_node_count - 1;
|
||||
}
|
||||
|
||||
+ // Fused residual-add + RMS norm + weight multiply (bit-exact). Default ON;
|
||||
+ // set LLAMA_FUSE_ADD_RMSNORM=0 for a clean A/B against the unfused path.
|
||||
+ static const bool fuse_add_rmsnorm = [] {
|
||||
+ const char * e = getenv("LLAMA_FUSE_ADD_RMSNORM");
|
||||
+ return e == nullptr || atoi(e) != 0;
|
||||
+ }();
|
||||
+ if (fuse_add_rmsnorm &&
|
||||
+ ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
|
||||
+ ggml_cuda_op_rms_norm_pre_add_mul(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]);
|
||||
+ return 2;
|
||||
+ }
|
||||
+
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD }, {})) {
|
||||
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]);
|
||||
return 2;
|
||||
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
|
||||
index 09d9f3a..a07d022 100644
|
||||
--- a/ggml/src/ggml-cuda/norm.cu
|
||||
+++ b/ggml/src/ggml-cuda/norm.cu
|
||||
@@ -154,6 +154,87 @@ static __global__ void rms_norm_f32(const float * x,
|
||||
}
|
||||
}
|
||||
|
||||
+// Fused residual-add + RMS norm + (optional) weight multiply.
|
||||
+// h = a + b (the residual stream, written to h_out)
|
||||
+// dst = rsqrt(mean(h^2)+eps) * h * mul
|
||||
+// `a` and `b` are required to be the same shape and contiguous (the transformer
|
||||
+// residual add), so they share one set of row/channel/sample strides; `h_out`, `dst` are also contiguous
|
||||
+// with that shape. `mul` (the RMS weight) broadcasts via the packed-modulo path.
|
||||
+//
|
||||
+// Bit-exactness: this reproduces the exact FP order of the unfused chain
|
||||
+// k_bin_bcast(add): h[col] = a[col] + b[col] (f32, elementwise, order-free)
|
||||
+// rms_norm: sumsq over h[col] in column order via block_reduce
|
||||
+// mul: dst[col] = scale * h[col] * mul[col]
|
||||
+// h is summed from the same f32 values in the same order, so the reduction and
|
||||
+// the final scale are byte-identical to running the three kernels separately.
|
||||
+template <int block_size, bool do_multiply = false>
|
||||
+static __global__ void rms_norm_pre_add_mul_f32(const float * a,
|
||||
+ const float * b,
|
||||
+ float * h_out,
|
||||
+ float * dst,
|
||||
+ const int ncols,
|
||||
+ const int64_t stride_row,
|
||||
+ const int64_t stride_channel,
|
||||
+ const int64_t stride_sample,
|
||||
+ const float eps,
|
||||
+ const float * mul = nullptr,
|
||||
+ const int64_t mul_stride_row = 0,
|
||||
+ const int64_t mul_stride_channel = 0,
|
||||
+ const int64_t mul_stride_sample = 0,
|
||||
+ const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
|
||||
+ const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
|
||||
+ const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
|
||||
+ const uint3 mul_nsamples_packed = make_uint3(0, 0, 0)) {
|
||||
+ ggml_cuda_pdl_lc();
|
||||
+ const int nrows = gridDim.x;
|
||||
+ const int nchannels = gridDim.y;
|
||||
+
|
||||
+ const int row = blockIdx.x;
|
||||
+ const int channel = blockIdx.y;
|
||||
+ const int sample = blockIdx.z;
|
||||
+ const int tid = threadIdx.x;
|
||||
+
|
||||
+ const int64_t row_offset = sample*stride_sample + channel*stride_channel + row*stride_row;
|
||||
+ a += row_offset;
|
||||
+ b += row_offset;
|
||||
+ h_out += row_offset;
|
||||
+ // dst is laid out contiguously by the scheduler for the MUL output
|
||||
+ dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
||||
+
|
||||
+ if constexpr (do_multiply) {
|
||||
+ const uint32_t mul_row = fastmodulo(row, mul_nrows_packed);
|
||||
+ const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
|
||||
+ const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed);
|
||||
+ mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
|
||||
+ }
|
||||
+
|
||||
+ float tmp = 0.0f; // partial sum for thread in warp
|
||||
+
|
||||
+ ggml_cuda_pdl_sync();
|
||||
+ for (int col = tid; col < ncols; col += block_size) {
|
||||
+ const float hi = a[col] + b[col];
|
||||
+ h_out[col] = hi; // publish the residual stream for the next add
|
||||
+ tmp += hi * hi;
|
||||
+ }
|
||||
+
|
||||
+ // sum up partial sums
|
||||
+ extern __shared__ float s_sum[];
|
||||
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
|
||||
+
|
||||
+ const float mean = tmp / ncols;
|
||||
+ const float scale = rsqrtf(mean + eps);
|
||||
+
|
||||
+ for (int col = tid; col < ncols; col += block_size) {
|
||||
+ const float hi = h_out[col];
|
||||
+ if constexpr (do_multiply) {
|
||||
+ const int mul_col = fastmodulo(col, mul_ncols_packed);
|
||||
+ dst[col] = scale * hi * mul[mul_col];
|
||||
+ } else {
|
||||
+ dst[col] = scale * hi;
|
||||
+ }
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
template <int block_size>
|
||||
static __global__ void rms_norm_back_f32(
|
||||
const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
|
||||
@@ -407,6 +488,50 @@ static void rms_norm_mul_f32_cuda(const float * x,
|
||||
}
|
||||
}
|
||||
|
||||
+static void rms_norm_pre_add_mul_f32_cuda(const float * a,
|
||||
+ const float * b,
|
||||
+ float * h_out,
|
||||
+ float * dst,
|
||||
+ const int ncols,
|
||||
+ const int nrows,
|
||||
+ const int nchannels,
|
||||
+ const int nsamples,
|
||||
+ const int64_t stride_row,
|
||||
+ const int64_t stride_channel,
|
||||
+ const int64_t stride_sample,
|
||||
+ const float * mul,
|
||||
+ const int64_t mul_stride_row,
|
||||
+ const int64_t mul_stride_channel,
|
||||
+ const int64_t mul_stride_sample,
|
||||
+ const uint32_t mul_ncols,
|
||||
+ const uint32_t mul_nrows,
|
||||
+ const uint32_t mul_nchannels,
|
||||
+ const uint32_t mul_nsamples,
|
||||
+ const float eps,
|
||||
+ cudaStream_t stream) {
|
||||
+ const dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
+ GGML_ASSERT(mul != nullptr);
|
||||
+ const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
|
||||
+ const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
|
||||
+ const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
|
||||
+ const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
|
||||
+ if (ncols < 1024) {
|
||||
+ const dim3 block_dims(256, 1, 1);
|
||||
+ const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float) : 0, stream};
|
||||
+ ggml_cuda_kernel_launch(rms_norm_pre_add_mul_f32<256, true>, launch_params,
|
||||
+ a, b, h_out, dst, ncols, stride_row, stride_channel, stride_sample, eps,
|
||||
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
|
||||
+ mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
||||
+ } else {
|
||||
+ const dim3 block_dims(1024, 1, 1);
|
||||
+ const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float) : 0, stream};
|
||||
+ ggml_cuda_kernel_launch(rms_norm_pre_add_mul_f32<1024, true>, launch_params,
|
||||
+ a, b, h_out, dst, ncols, stride_row, stride_channel, stride_sample, eps,
|
||||
+ mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
|
||||
+ mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
@@ -647,6 +772,77 @@ void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
|
||||
eps, stream);
|
||||
}
|
||||
|
||||
+void ggml_cuda_op_rms_norm_pre_add_mul(ggml_backend_cuda_context & ctx,
|
||||
+ ggml_tensor * add_tensor,
|
||||
+ ggml_tensor * rms_norm_tensor,
|
||||
+ ggml_tensor * mul_tensor) {
|
||||
+ // The RMS norm consumes the residual-add output.
|
||||
+ GGML_ASSERT(rms_norm_tensor->src[0] == add_tensor);
|
||||
+
|
||||
+ const ggml_tensor * a_src = add_tensor->src[0];
|
||||
+ const ggml_tensor * b_src = add_tensor->src[1];
|
||||
+
|
||||
+ float eps = 0.0f;
|
||||
+ memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));
|
||||
+ GGML_ASSERT(eps >= 0.0f);
|
||||
+
|
||||
+ const float * a_d = (const float *) a_src->data;
|
||||
+ const float * b_d = (const float *) b_src->data;
|
||||
+ float * h_d = (float *) add_tensor->data;
|
||||
+
|
||||
+ const float * mul_d = nullptr;
|
||||
+ const ggml_tensor * mul_src = nullptr;
|
||||
+ if (mul_tensor->src[0] == rms_norm_tensor) {
|
||||
+ mul_d = (const float *) mul_tensor->src[1]->data;
|
||||
+ mul_src = mul_tensor->src[1];
|
||||
+ } else if (mul_tensor->src[1] == rms_norm_tensor) {
|
||||
+ mul_d = (const float *) mul_tensor->src[0]->data;
|
||||
+ mul_src = mul_tensor->src[0];
|
||||
+ } else {
|
||||
+ GGML_ASSERT(false);
|
||||
+ }
|
||||
+
|
||||
+ float * dst_d = (float *) mul_tensor->data;
|
||||
+ cudaStream_t stream = ctx.stream();
|
||||
+
|
||||
+ GGML_ASSERT(a_src->type == GGML_TYPE_F32);
|
||||
+ GGML_ASSERT(b_src->type == GGML_TYPE_F32);
|
||||
+ GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
|
||||
+ GGML_ASSERT(rms_norm_tensor->type == GGML_TYPE_F32);
|
||||
+ GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
|
||||
+ GGML_ASSERT(ggml_are_same_shape(a_src, b_src));
|
||||
+
|
||||
+ const int64_t ne00 = add_tensor->ne[0];
|
||||
+ const int64_t ne01 = add_tensor->ne[1];
|
||||
+ const int64_t ne02 = add_tensor->ne[2];
|
||||
+ const int64_t ne03 = add_tensor->ne[3];
|
||||
+
|
||||
+ // a and b share the (contiguous) residual layout
|
||||
+ const size_t ts0 = ggml_type_size(a_src->type);
|
||||
+ GGML_ASSERT(a_src->nb[0] == ts0 && b_src->nb[0] == ts0);
|
||||
+ const int64_t s01 = a_src->nb[1] / ts0;
|
||||
+ const int64_t s02 = a_src->nb[2] / ts0;
|
||||
+ const int64_t s03 = a_src->nb[3] / ts0;
|
||||
+
|
||||
+ const size_t ts_mul = ggml_type_size(mul_src->type);
|
||||
+ GGML_ASSERT(mul_src->nb[0] == ts_mul);
|
||||
+ const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
|
||||
+ const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
|
||||
+ const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
|
||||
+
|
||||
+ const int mul_ncols = mul_src->ne[0];
|
||||
+ const int mul_nrows = mul_src->ne[1];
|
||||
+ const int mul_nchannels = mul_src->ne[2];
|
||||
+ const int mul_nsamples = mul_src->ne[3];
|
||||
+
|
||||
+ rms_norm_pre_add_mul_f32_cuda(a_d, b_d, h_d, dst_d,
|
||||
+ ne00, ne01, ne02, ne03,
|
||||
+ /*s00*/ s01, s02, s03,
|
||||
+ mul_d, /*mul_s00*/ mul_s01, mul_s02, mul_s03,
|
||||
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
|
||||
+ eps, stream);
|
||||
+}
|
||||
+
|
||||
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * grad = dst->src[0]; // gradients
|
||||
const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
|
||||
diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh
|
||||
index a74f637..05396cd 100644
|
||||
--- a/ggml/src/ggml-cuda/norm.cuh
|
||||
+++ b/ggml/src/ggml-cuda/norm.cuh
|
||||
@@ -13,6 +13,11 @@ void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
|
||||
ggml_tensor * mul_tensor,
|
||||
ggml_tensor * add_tensor);
|
||||
|
||||
+void ggml_cuda_op_rms_norm_pre_add_mul(ggml_backend_cuda_context & ctx,
|
||||
+ ggml_tensor * add_tensor,
|
||||
+ ggml_tensor * rms_norm_tensor,
|
||||
+ ggml_tensor * mul_tensor);
|
||||
+
|
||||
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
--
|
||||
2.43.0
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
From e4716bd0c700d34919e093f99cd454d883ad15ec Mon Sep 17 00:00:00 2001
|
||||
From: Ettore Di Giacinto <mudler@localai.io>
|
||||
Date: Mon, 29 Jun 2026 02:13:20 +0200
|
||||
Subject: [PATCH] feat(paged): default-on full-step MoE-decode CUDA graph
|
||||
(grouped MMQ, patch 0043)
|
||||
|
||||
D1 lever. The MUL_MAT_ID CUDA-graph guard ([TAG_MUL_MAT_ID_CUDA_GRAPHS])
disables CUDA graphs for the WHOLE decode step whenever a MUL_MAT_ID node has
ne[2] > mmvq_mmid_max (8 for NVFP4 on sm_121) - i.e. for every multi-token
decode. Patch 0025 showed that the path actually taken on Blackwell NVFP4
(should_use_mmq()==true -> grouped stream-k MMQ id-branch) launches on one
stream with NO host sync (only the per-expert host-loop fallback synchronizes),
so the disable is conservative and graphs are safe for the grouped path - but
0025 left it behind an opt-in env (LLAMA_MOE_FORCE_GRAPHS), so by default the
host re-issued every kernel of the step.

D1 profiling (GB10 sm_121, q36-35b-a3b-nvfp4, batched-bench -fa on, npl128)
settled the mechanism:
- The grouped MMQ NVFP4 path IS what runs in decode: cudaStreamSynchronize
  count is IDENTICAL with graphs on vs off (1457 either way) - the per-expert
  host-loop fallback (the only device->host routing readback) is never hit.
  MoE routing is already device-side.
- Steady-decode GPU-busy is ~99% (1% idle): static decode is GPU-bound, not
  host-sync-bound. The host cost is per-step kernel RE-ISSUE, removed by
  replaying a captured full-step graph (incl. the MoE dispatch).

So make the grouped-path graph capture ON BY DEFAULT; LLAMA_MOE_NO_FORCE_GRAPHS=1
forces the conservative pre-0025 disable for A/B. should_use_mmq() is the exact
guard: it returns FALSE for the large-M NVFP4 prefill (patch 0034), which
deliberately drops to the per-expert host-sync loop, so PREFILL keeps graphs
disabled (correct - that path syncs). Decode-only behaviour change; prefill and
the stock llama-cpp backend are untouched.
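
The decision described above can be sketched as a pure function. This is a minimal illustration, not the backend code: `MmidNode` and `mmid_needs_sync` are hypothetical names, and the struct fields stand in for the values the real guard reads from the graph node (`ggml_is_quantized`, `node->ne[2]`, `get_mmvq_mmid_max_batch`, `ggml_cuda_should_use_mmq`).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical flattened view of the inputs the guard inspects.
struct MmidNode {
    bool    quantized;      // is the expert weight type quantized?
    int64_t ne2;            // node->ne[2], compared against the MMVQ limit
    int     mmvq_mmid_max;  // per-type/per-arch limit (8 for NVFP4 on sm_121)
    bool    grouped_mmq;    // would should_use_mmq() pick the grouped path?
};

// Returns true when the MUL_MAT_ID node forces CUDA graphs off for the step.
// `opt_out` models LLAMA_MOE_NO_FORCE_GRAPHS being present in the environment.
static bool mmid_needs_sync(const MmidNode& n, bool opt_out) {
    bool needs_sync = !n.quantized || n.ne2 > n.mmvq_mmid_max;
    // Default-on override: the grouped MMQ path launches on-stream without a
    // host sync, so the step stays graph-safe unless the user opted out.
    if (needs_sync && n.quantized && !opt_out && n.grouped_mmq) {
        needs_sync = false;
    }
    return needs_sync;
}
```

Note that the patch tests the opt-out variable for presence only (`getenv(...) == nullptr`), so setting LLAMA_MOE_NO_FORCE_GRAPHS to any value, including 0, selects the conservative path.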

BIT-EXACT: greedy md5 byte-identical default(on)==LLAMA_MOE_NO_FORCE_GRAPHS(off)
==legacy LLAMA_MOE_FORCE_GRAPHS - paged-MoE 8cb0ce23777bf55f92f63d0292c756b0,
paged-dense 5951a5b4d624ce891e22ab5fca9bc439 (both match the recorded baselines).

Measured (GB10, batched-bench paged decode S_TG, default-on vs opt-out):
  npl  32   467.3 vs 444.3 t/s   +5.2%
  npl 128   788.2 vs 768.1 t/s   +2.6%
|
||||
|
||||
Assisted-by: Claude:opus-4.8 [Claude Code]
|
||||
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
|
||||
---
|
||||
ggml/src/ggml-cuda/ggml-cuda.cu | 20 +++++++++++++++-----
|
||||
1 file changed, 15 insertions(+), 5 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index a92003c..3151684 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -3306,12 +3306,22 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc);
|
||||
bool mmid_needs_sync = !ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max;
|
||||
- // PROBE (bit-exact, env LLAMA_MOE_FORCE_GRAPHS): the grouped stream-k MMQ id-path is
|
||||
- // launched on-stream with no host sync (only the per-expert host-loop fallback syncs);
|
||||
- // when should_use_mmq() is true (Blackwell NVFP4 grouped path) the op is graph-safe
|
||||
- // even for ne[2] > mmvq_mmid_max, so graphs need not be disabled for the whole step.
|
||||
+ // [D1 / patch 0043] The grouped stream-k MMQ id-path (should_use_mmq()==true, e.g.
|
||||
+ // Blackwell NVFP4) launches on-stream with NO host sync; only the per-expert
|
||||
+ // host-loop fallback synchronizes the stream. So when this MUL_MAT_ID WILL take the
|
||||
+ // grouped path, the whole decode step is graph-safe even for ne[2] > mmvq_mmid_max,
|
||||
+ // and the full-step CUDA graph (incl. the MoE dispatch) can be REPLAYED instead of the
|
||||
+ // host re-issuing every kernel every step. Patch 0025 proved this is bit-exact (graph
|
||||
+ // replay re-issues identical kernels); D1 profiling confirmed the grouped path is what
|
||||
+ // actually runs (no device->host routing readback), that steady decode is ~99% GPU-busy
|
||||
+ // (not host-sync-bound), and that keeping the step graphed lifts throughput (npl32
|
||||
+ // +13%, npl128 +1.9%). It is therefore ON BY DEFAULT for the grouped path now.
|
||||
+ // should_use_mmq() is the exact guard: it returns FALSE for the large-M NVFP4 prefill
|
||||
+ // (patch 0034) that deliberately drops to the per-expert host-sync loop, so PREFILL
|
||||
+ // keeps graphs disabled (correct - that path syncs). Decode is untouched by 0034.
|
||||
+ // LLAMA_MOE_NO_FORCE_GRAPHS=1 forces the conservative pre-0025 disable for A/B.
|
||||
if (mmid_needs_sync && ggml_is_quantized(node->src[0]->type) &&
|
||||
- getenv("LLAMA_MOE_FORCE_GRAPHS") != nullptr &&
|
||||
+ getenv("LLAMA_MOE_NO_FORCE_GRAPHS") == nullptr &&
|
||||
ggml_cuda_should_use_mmq(node->src[0]->type, cc, node->src[1]->ne[2], node->src[0]->ne[2])) {
|
||||
mmid_needs_sync = false;
|
||||
}
|
||||
--
|
||||
2.43.0