docs(paged): record MoE-prefill engine-gap decomposition + GEMM-port negatives (default-off)

nsys cross-engine decomposition: the MoE prefill 64% gap vs vLLM is engine plumbing, not the kernel (GPU 97% busy, 443 vs 197 us/tok). Three buckets: per-expert W4A4 M-fragmentation (58%), GDN scan (24%), f32<->bf16 casts (15%). Offline-repack (0045) and verbatim vLLM-marlin port both trail FP4-MMQ via wrapper overhead, kept default-off as recorded negatives.

Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-06-29 17:20:07 +00:00
parent 7b38c6b2a3
commit be65438eac

View File

@@ -0,0 +1,650 @@
From 864e92807809c55905bdb73ef492f6b1d03986f9 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto <mudler@localai.io>
Date: Mon, 29 Jun 2026 11:02:07 +0200
Subject: [PATCH] feat(paged): vLLM-exact offline-repack Marlin W4A16 grouped
MoE prefill GEMM (patch 0045)
The patch-0035 root-cause sweep proved the W4A16 grouped MoE GEMM is occupancy/
latency bound (~10% MMA util), NOT dequant bound: every prefill step it re-derives
a per-block STRIDED weight gather and streams the 4-bit weights with scattered
4-BYTE cp.async transactions (BN*8 per K-block) + a per-step ue4m3 scale decode,
plus a SEPARATE global f32->bf16 activation pre-pass kernel. vLLM Marlin wins by
REPACKING the weights ONCE at load into an MMA-ready tile-contiguous 4-bit layout
so the GEMM loop issues only cheap COALESCED 16-byte loads and the per-lane reads
are already in fragment order; the loop does ZERO dequant work beyond the cheap
in-register nibble->bf16 unpack. This is that approach, faithful to vLLM Marlin and
distinct from the rejected patch-0044 int8-Marlin (which doubled the weight bytes
to int8 + a separate act-quant pass).
(1) OFFLINE REPACK (one-time, first engage; cached in a persistent cudaMalloc
buffer keyed by src0->data): transpose the NVFP4 experts [N][Kb](d[4]+qs[32])
into tile-contiguous planes Q=[e][Kb][N][32 qs bytes] + S=[e][Kb][N][4 ue4m3
scale bytes]. STAYS 4-BIT: qs packed, scales ue4m3 -> the repacked buffer is
byte-for-byte the SAME size as the source NVFP4 weights (144 MiB/tensor for the
35B-A3B experts, == source; no int8 2x). Persists across all prefill steps.
(2) KERNEL reads the pre-packed tiles directly: per (K-block, N-block) the BN rows
qs/scales are contiguous -> the smem fill is a flat COALESCED 16-byte cp.async
stream (no strided gather, no address math). Activations cast f32->bf16 IN
REGISTER on the load into smem (no separate global act pre-pass, no act-quant).
Inner loop: ldmatrix bf16 A, in-register PRMT (LUT-free) nibble->bf16 unpack
scaled by the in-register ue4m3 decode, bf16 m16n8k16 mma.sync into f32 accum,
cp.async multistage, ragged per-tile expert offset (reuses 0035 tile map). NO
separate smem-staged dequant pass, NO __syncthreads-gated dequant pass (SASS:
LDSM->PRMT->I2F->FMUL->F2F.BF16->HMMA.16816.F32.BF16 with no STS/BAR.SYNC between
the weight load and the MMA).
The repacked smem is bit-identical to the 0035 W4A16 smem, so the proven inner
fragment math is unchanged and the output is bit-identical to the 0035 W4A16 path.
TOGGLE: LLAMA_W4A16_REPACK=1 (default 0 == OFF), engaged only inside the already
default-off 0035 W4A16 path (LLAMA_W4A16_PREFILL_M>0). LLAMA_W4A16_REPACK_NOCACHE=1
repacks into a transient pool buffer per call (for test-backend-ops, which frees/
reuses src0 addresses and so defeats the pointer-keyed cache). Stock / decode /
non-NVFP4 byte-untouched.
VALIDATION (GB10, sm_121a, Qwen3.6-35B-A3B-NVFP4):
- test-backend-ops MUL_MAT_ID nvfp4 (vs CPU oracle), REPACK forced + NOCACHE:
81/81 OK, 0 FAIL.
- real-model greedy md5 (paged MoE): stock == W4A16-non-repack == W4A16-repack
(cached) == W4A16-repack (nocache) == default-off (REPACK=1, PREFILL_M unset),
all fda1aadbbbfb36fe8ab0f5f5465c745e (bit-identical; default-off is stock).
HONEST PERF (S_PP t/s, llama-batched-bench -fa on -ngl 99 -ntg 32 -npl 1, paged,
warm cache): the offline repack is a large win over the prior non-repack W4A16
(npp512 854.7 -> 1452.4 = +70%; beats it at every M) and is FLAT across M
(1452/1478/1467 at 512/1024/2048 = MMA-pipeline bound as designed). But the heavily
tuned FP4-MMQ baseline on GB10 is still ~38% faster (~2030 flat). Decode S_TG
unchanged (~55 t/s, prefill-only lever). One-time cache build amortized (cold first
prefill). Ships DEFAULT-OFF (like 0033/0034/0035): the validated, env-gated, bit-
exact-gated mechanism + the recorded result that even vLLM-exact offline-repack
Marlin W4A16 does not overtake the GB10 FP4-MMQ winner.
Weight-memory: the repacked representation STAYS 4-BIT (no int8 expansion; vs 0044
+100%); the persistent cache is an additive 4-bit copy of the engaged expert tensors
when ON, and literally 0 when OFF (default). Decode keeps the original block_nvfp4
layout for its MMQ/graph path, so the cache is additive rather than in-place.
Build: arch=compute_121a,code=[compute_121a,sm_121a]; AMPERE_MMA_AVAILABLE /
CP_ASYNC_AVAILABLE guards (NO_DEVICE_CODE off-Blackwell).
Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
---
ggml/src/ggml-cuda/w4a16-gemm.cu | 11 +
ggml/src/ggml-cuda/w4a16-repack.cu | 470 ++++++++++++++++++++++++++++
ggml/src/ggml-cuda/w4a16-repack.cuh | 59 ++++
3 files changed, 540 insertions(+)
create mode 100644 ggml/src/ggml-cuda/w4a16-repack.cu
create mode 100644 ggml/src/ggml-cuda/w4a16-repack.cuh
diff --git a/ggml/src/ggml-cuda/w4a16-gemm.cu b/ggml/src/ggml-cuda/w4a16-gemm.cu
index c5c9ef7..687a7ae 100644
--- a/ggml/src/ggml-cuda/w4a16-gemm.cu
+++ b/ggml/src/ggml-cuda/w4a16-gemm.cu
@@ -1,4 +1,5 @@
#include "w4a16-gemm.cuh"
+#include "w4a16-repack.cuh"
#include "mma.cuh"
#include <algorithm>
@@ -389,6 +390,16 @@ void ggml_cuda_mul_mat_id_w4a16_grouped(
GGML_ASSERT(src0->type == GGML_TYPE_NVFP4);
GGML_ASSERT(N % 128 == 0 && K % 64 == 0);
+ // [paged patch 0045] vLLM-exact offline-repack sub-mode: repack the NVFP4 experts ONCE at first
+ // engage into an MMA-ready tile-contiguous 4-bit layout (cached, stays 4-bit), then run a GEMM
+ // loop with coalesced 16B weight loads + in-register act cast + ZERO in-loop dequant beyond the
+ // cheap in-register nibble->bf16 unpack. Gated by LLAMA_W4A16_REPACK (default off).
+ if (ggml_cuda_w4a16_repack_enabled()) {
+ ggml_cuda_mul_mat_id_w4a16_repack(ctx, src0, src1_sorted, dst_sorted,
+ tokens_per_expert, n_experts, K, N, stream);
+ return;
+ }
+
int sel = w4a16_cfg_sel();
if (N % w4a16_cfg_bn(sel) != 0) {
sel = 3; // BN>128 config whose BN doesn't divide N: safe BN=128 winner
diff --git a/ggml/src/ggml-cuda/w4a16-repack.cu b/ggml/src/ggml-cuda/w4a16-repack.cu
new file mode 100644
index 0000000..4a7a125
--- /dev/null
+++ b/ggml/src/ggml-cuda/w4a16-repack.cu
@@ -0,0 +1,470 @@
+#include "w4a16-repack.cuh"
+#include "mma.cuh"
+
+#include <cstdint>
+#include <cstdlib>
+#include <mutex>
+#include <unordered_map>
+#include <vector>
+#include <algorithm>
+
+// ===========================================================================
+// [paged patch 0045] vLLM-exact offline-repack Marlin W4A16 grouped MoE prefill GEMM.
+// See w4a16-repack.cuh for the design. Default-off (LLAMA_W4A16_REPACK).
+// ===========================================================================
+
+using namespace ggml_cuda_mma;
+typedef tile<16, 8, nv_bfloat162> rp_tile_A; // A operand: M=16, K=16
+typedef tile< 8, 8, nv_bfloat162> rp_tile_B; // B operand: N=8, K=16
+typedef tile<16, 8, float> rp_tile_C; // accumulator: M=16, N=8
+
+bool ggml_cuda_w4a16_repack_enabled() {
+ static const bool e = [] {
+ const char * s = getenv("LLAMA_W4A16_REPACK");
+ return s != nullptr && atoi(s) != 0;
+ }();
+ return e;
+}
+
+// ---- cp.async helpers (sm80+; raw bytes, no cast) ----
+static __device__ __forceinline__ void rp_cp_async16(void * smem, const void * gmem) {
+#ifdef CP_ASYNC_AVAILABLE
+ const unsigned s = (unsigned) __cvta_generic_to_shared(smem);
+ asm volatile("cp.async.cg.shared.global [%0],[%1],16;\n" :: "r"(s), "l"(gmem));
+#else
+ GGML_UNUSED(smem); GGML_UNUSED(gmem); NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+static __device__ __forceinline__ void rp_cp_commit() {
+#ifdef CP_ASYNC_AVAILABLE
+ asm volatile("cp.async.commit_group;\n" ::);
+#else
+ NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+template<int N> static __device__ __forceinline__ void rp_cp_wait() {
+#ifdef CP_ASYNC_AVAILABLE
+ asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
+#else
+ NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+
+// ---- fast 16-entry FP4 (E2M1) -> kvalues_mxfp4 lookup via __byte_perm (PRMT), LUT-free, NO
+// divergent LDG. Bit-identical to kvalues_mxfp4[code]; same trick MMQ's get_int_from_table_16 and
+// the 0035 kernel use. The "cheap in-register unpack" - the ONLY dequant work the loop does. ----
+static __device__ __forceinline__ int2 rp_table16(uint32_t q4) {
+ const uint32_t * table32 = (const uint32_t *) kvalues_mxfp4; // 16 int8 == 4 u32
+ uint32_t tmp[2];
+ const uint32_t sel = 0x32103210 | ((q4 & 0x88888888) >> 1);
+#pragma unroll
+ for (uint32_t i = 0; i < 2; ++i) {
+ const uint32_t shift = 16*i;
+ const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
+ const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
+ tmp[i] = __byte_perm(low, high, sel >> shift);
+ }
+ return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
+}
+
+// ===========================================================================
+// (1) OFFLINE REPACK kernel (one-time). Transpose src0 NVFP4 [e][N][Kb](d[4]+qs[32]) into
+// Qrp = [e][Kb][N][32 qs bytes] (u32 plane, 8 u32/block)
+// Srp = [e][Kb][N][ 4 scale bytes] (u32 plane, 1 u32/block)
+// Same total bytes as src0 (36 B/block); stays 4-bit. One thread per source block.
+// ===========================================================================
+static __global__ void w4a16_repack_kernel(
+ const block_nvfp4 * __restrict__ W0, int64_t expert_stride_blocks,
+ uint32_t * __restrict__ Qrp, uint32_t * __restrict__ Srp,
+ int N, int Kb, int n_experts) {
+ const int64_t total = (int64_t) n_experts * N * Kb;
+ const int64_t t = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
+ if (t >= total) {
+ return;
+ }
+ const int kt = (int) (t % Kb);
+ const int64_t r2 = t / Kb;
+ const int n = (int) (r2 % N);
+ const int e = (int) (r2 / N);
+
+ const block_nvfp4 * blk = W0 + (int64_t) e*expert_stride_blocks + (int64_t) n*Kb + kt;
+ const uint32_t * src = (const uint32_t *) blk; // word0=d[4], words1..8=qs[32]
+
+ const int64_t qbase = (((int64_t) e*Kb + kt) * N + n) * 8; // u32 offset in Qrp
+ const int64_t sbase = ((int64_t) e*Kb + kt) * N + n; // u32 offset in Srp
+ Srp[sbase] = src[0];
+#pragma unroll
+ for (int w = 0; w < 8; w++) {
+ Qrp[qbase + w] = src[1 + w];
+ }
+}
+
+// ===========================================================================
+// (2) GROUPED GEMM over the repacked tiles. Per output tile (blockIdx.x=N-block, blockIdx.y=M-tile):
+// expert e=g_tile_expert[by], row0=g_tile_row0[by], rcount=g_tile_rows[by].
+// Weights from Qrp/Srp (tile-contiguous), activations from A_f32 (cast bf16 in-register), out to C.
+// ZERO in-loop dequant beyond the in-register nibble->bf16 unpack + ue4m3 scale decode.
+// ===========================================================================
+template<int BM, int BN, int WARPS_M, int WARPS_N, int STAGES, int APAD>
+__launch_bounds__(WARPS_M*WARPS_N*32, 1)
+static __global__ void w4a16_repack_grouped_kernel(
+ const float * __restrict__ A_f32, // [total_rows, K] f32 (sorted)
+ const uint32_t * __restrict__ Qrp, // [e][Kb][N][8 u32]
+ const uint32_t * __restrict__ Srp, // [e][Kb][N][1 u32]
+ float * __restrict__ C, // [total_rows, N] f32
+ const int * __restrict__ g_tile_expert,
+ const int * __restrict__ g_tile_row0,
+ const int * __restrict__ g_tile_rows,
+ int N, int K, int total_rows) {
+#if defined(AMPERE_MMA_AVAILABLE) && defined(CP_ASYNC_AVAILABLE)
+ constexpr int BK = 64; // one nvfp4 block
+ constexpr int NWARP = WARPS_M*WARPS_N;
+ constexpr int THREADS = NWARP*32;
+ constexpr int WM = BM/WARPS_M, WN = BN/WARPS_N;
+ constexpr int MF = WM/16, NF = WN/8;
+
+ constexpr int AN = BK/2; // bf16 pairs per A smem row (nv_bfloat162)
+ constexpr int ASTRIDE = AN + APAD; // padded A smem row stride (skew banks, +19% lesson)
+ constexpr int SZ_A = BM*ASTRIDE; // nv_bfloat162 (== u32) per stage (padded)
+ constexpr int SZ_WQ = BN*8; // u32 per stage (32 qs bytes/row, tile-contiguous)
+ constexpr int SZ_WD = BN; // u32 per stage (4 scale bytes/row, tile-contiguous)
+
+ extern __shared__ uint32_t smem_u32[];
+ constexpr int STAGE_U32 = SZ_A + SZ_WQ + SZ_WD;
+ nv_bfloat162 * sA[STAGES];
+ uint32_t * sWq[STAGES];
+ uint32_t * sWd[STAGES];
+#pragma unroll
+ for (int s = 0; s < STAGES; s++) {
+ uint32_t * base = smem_u32 + s*STAGE_U32;
+ sA[s] = (nv_bfloat162 *) base;
+ sWq[s] = base + SZ_A;
+ sWd[s] = base + SZ_A + SZ_WQ;
+ }
+
+ const int lane = threadIdx.x; // 0..31 (mma.cuh uses threadIdx.x AS the warp lane)
+ const int warp = threadIdx.y; // 0..NWARP-1
+ const int tid = warp*32 + lane;
+ const int wrow = warp / WARPS_N, wcol = warp % WARPS_N;
+
+ const int e = g_tile_expert[blockIdx.y];
+ const int row0 = g_tile_row0[blockIdx.y];
+ const int rcount = g_tile_rows[blockIdx.y];
+ const int blockCol = blockIdx.x*BN;
+ const int Kb = K/64;
+ // base u32 offset of expert e in the repacked planes (tile = + (kt*N + blockCol)*{8,1})
+ const int64_t Qe = (int64_t) e*Kb*N*8;
+ const int64_t Se = (int64_t) e*Kb*N;
+
+ rp_tile_C acc[MF][NF];
+
+ // async-load K-block kt into stage st: A cast in-register (no global pre-pass); W coalesced 16B.
+ auto load_tile = [&](int st, int kt) {
+ // A: BM rows x BK bf16 = BM x (BK/8) 16B chunks. Source f32 -> cast bf16 in register.
+ const int A_chunks = BM*(BK/8);
+#pragma unroll 1
+ for (int idx = tid; idx < A_chunks; idx += THREADS) {
+ const int c = idx % (BK/8); // 16B (8 bf16) chunk in the row
+ const int r = idx / (BK/8); // row in tile
+ const int gr = row0 + r;
+ nv_bfloat162 * d2 = sA[st] + r*ASTRIDE + c*4; // 4 nv_bfloat162 = 8 bf16
+ if (gr < total_rows) {
+ const float * src = A_f32 + (int64_t) gr*K + (int64_t) kt*BK + c*8;
+ const float4 a = *(const float4 *) (src);
+ const float4 b = *(const float4 *) (src + 4);
+ d2[0] = make_bfloat162(__float2bfloat16(a.x), __float2bfloat16(a.y));
+ d2[1] = make_bfloat162(__float2bfloat16(a.z), __float2bfloat16(a.w));
+ d2[2] = make_bfloat162(__float2bfloat16(b.x), __float2bfloat16(b.y));
+ d2[3] = make_bfloat162(__float2bfloat16(b.z), __float2bfloat16(b.w));
+ } else {
+ const nv_bfloat162 z = make_bfloat162((nv_bfloat16) 0.0f, (nv_bfloat16) 0.0f);
+ d2[0] = z; d2[1] = z; d2[2] = z; d2[3] = z;
+ }
+ }
+ // W qs: tile-contiguous BN*32 bytes = BN*8 u32 = BN*2 16B chunks. Coalesced cp.async.16.
+ const uint32_t * Qtile = Qrp + Qe + ((int64_t) kt*N + blockCol) * 8;
+ const int Q_chunks = (BN*8) / 4; // 16B (4 u32) chunks
+#pragma unroll 1
+ for (int idx = tid; idx < Q_chunks; idx += THREADS) {
+ rp_cp_async16(&sWq[st][idx*4], Qtile + idx*4);
+ }
+ // W scales: tile-contiguous BN*4 bytes = BN u32 = BN/4 16B chunks. Coalesced cp.async.16.
+ const uint32_t * Stile = Srp + Se + ((int64_t) kt*N + blockCol);
+ const int S_chunks = BN / 4;
+#pragma unroll 1
+ for (int idx = tid; idx < S_chunks; idx += THREADS) {
+ rp_cp_async16(&sWd[st][idx*4], Stile + idx*4);
+ }
+ };
+
+ // prologue
+#pragma unroll
+ for (int s = 0; s < STAGES-1; s++) { if (s < Kb) load_tile(s, s); rp_cp_commit(); }
+
+ for (int kt = 0; kt < Kb; kt++) {
+ const int ld = kt + (STAGES-1);
+ if (ld < Kb) load_tile(ld % STAGES, ld);
+ rp_cp_commit();
+ rp_cp_wait<STAGES-1>();
+ __syncthreads();
+
+ const int rs = kt % STAGES;
+ const nv_bfloat162 * sAcur = sA[rs];
+ const uint32_t * sWqw = sWq[rs]; // BN rows x 8 u32 (32 qs bytes)
+ const uint32_t * sWdw = sWd[rs]; // BN rows x 1 u32 (4 scale bytes)
+
+#pragma unroll
+ for (int kk = 0; kk < BK/16; kk++) { // 4 m16n8k16 sub-steps per 64-block
+ const int sub = kk;
+ // A fragments via ldmatrix (bf16)
+ rp_tile_A A_frag[MF];
+#pragma unroll
+ for (int mi = 0; mi < MF; mi++) {
+ const int rb = wrow*WM + mi*16;
+ load_ldmatrix(A_frag[mi], sAcur + rb*ASTRIDE + kk*8, ASTRIDE);
+ }
+ // B fragments: in-register FP4->bf16 unpack (byte_perm LUT-free) * in-register ue4m3 scale.
+ rp_tile_B B_frag[NF];
+ const int n_local = lane >> 2; // tile_B::get_i (row N, 0..7)
+ const int jc = lane & 3;
+ const int wsel = sub*2 + (jc >> 1);
+ const int bsh = 8 * (2*(jc & 1));
+#pragma unroll
+ for (int ni = 0; ni < NF; ni++) {
+ const int nrow = wcol*WN + ni*8 + n_local; // col within BN tile [0,BN)
+ const uint32_t w = sWqw[nrow*8 + wsel];
+ const int2 kv = rp_table16(w);
+ const float sc = ggml_cuda_ue4m3_to_fp32(((const uint8_t *) &sWdw[nrow])[sub]);
+ B_frag[ni].x[0].x = __float2bfloat16(sc * (float) (int8_t) (kv.x >> bsh));
+ B_frag[ni].x[0].y = __float2bfloat16(sc * (float) (int8_t) (kv.x >> (bsh + 8)));
+ B_frag[ni].x[1].x = __float2bfloat16(sc * (float) (int8_t) (kv.y >> bsh));
+ B_frag[ni].x[1].y = __float2bfloat16(sc * (float) (int8_t) (kv.y >> (bsh + 8)));
+ }
+#pragma unroll
+ for (int mi = 0; mi < MF; mi++)
+#pragma unroll
+ for (int ni = 0; ni < NF; ni++)
+ mma(acc[mi][ni], A_frag[mi], B_frag[ni]);
+ }
+ __syncthreads();
+ }
+
+ // write back (mask the ragged per-expert row tail)
+#pragma unroll
+ for (int mi = 0; mi < MF; mi++)
+#pragma unroll
+ for (int ni = 0; ni < NF; ni++) {
+ const int orow = wrow*WM + mi*16;
+ const int ocol = blockCol + wcol*WN + ni*8;
+#pragma unroll
+ for (int l = 0; l < acc[mi][ni].ne; l++) {
+ const int lr = orow + acc[mi][ni].get_i(l);
+ const int nc = ocol + acc[mi][ni].get_j(l);
+ if (lr < rcount && nc < N) {
+ C[(int64_t)(row0 + lr)*N + nc] = acc[mi][ni].x[l];
+ }
+ }
+ }
+#else
+ GGML_UNUSED(A_f32); GGML_UNUSED(Qrp); GGML_UNUSED(Srp); GGML_UNUSED(C);
+ GGML_UNUSED(g_tile_expert); GGML_UNUSED(g_tile_row0); GGML_UNUSED(g_tile_rows);
+ GGML_UNUSED(N); GGML_UNUSED(K); GGML_UNUSED(total_rows);
+ NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE && CP_ASYNC_AVAILABLE
+}
+
+// launch the one-time repack kernel: src0 NVFP4 -> (Qrp, Srp) repacked planes.
+static void w4a16_repack_launch(
+ const ggml_tensor * src0, int64_t n_experts, int64_t K, int64_t N,
+ uint32_t * Qrp, uint32_t * Srp, cudaStream_t stream) {
+ const int64_t Kb = K / 64;
+ const int64_t nblk = n_experts * Kb * N;
+ const int64_t expert_stride_blocks = (int64_t) (src0->nb[2] / sizeof(block_nvfp4));
+ const int threads = 256;
+ const int64_t grid = (nblk + threads - 1) / threads;
+ w4a16_repack_kernel<<<grid, threads, 0, stream>>>(
+ (const block_nvfp4 *) src0->data, expert_stride_blocks,
+ Qrp, Srp, (int) N, (int) Kb, (int) n_experts);
+ CUDA_CHECK(cudaGetLastError());
+}
+
+// ===========================================================================
+// repack cache: persistent device buffer per src0 tensor (keyed by src0->data). One-time build.
+//
+// Correct for real models: model weights live in a persistent buffer for the model's lifetime, so
+// src0->data is stable and never freed/reused during inference. NOT correct for test-backend-ops,
+// which allocates/frees a fresh src0 per case and ggml reuses the address -> a cache HIT returns
+// stale repacked weights. For harness validation use LLAMA_W4A16_REPACK_NOCACHE=1 (repack into a
+// transient pool buffer every call, bypassing the cache); this proves the kernel+repack correct.
+// ===========================================================================
+struct w4a16_repack_entry {
+ uint32_t * Qrp = nullptr; // [n_experts][Kb][N][8 u32]
+ uint32_t * Srp = nullptr; // [n_experts][Kb][N][1 u32]
+ int64_t n_experts = 0, Kb = 0, N = 0;
+ size_t bytes = 0; // total cudaMalloc bytes (Q + S), for the memory-delta report
+};
+static std::mutex g_rp_mu;
+static std::unordered_map<const void *, w4a16_repack_entry> g_rp_cache;
+
+// total bytes the repack cache currently holds on device (for the memory-delta report).
+size_t ggml_cuda_w4a16_repack_cache_bytes() {
+ std::lock_guard<std::mutex> lk(g_rp_mu);
+ size_t b = 0;
+ for (const auto & kv : g_rp_cache) {
+ b += kv.second.bytes;
+ }
+ return b;
+}
+
+static const w4a16_repack_entry & w4a16_get_or_build_repack(
+ const ggml_tensor * src0, int64_t n_experts, int64_t K, int64_t N, cudaStream_t stream) {
+ const void * key = src0->data;
+ std::lock_guard<std::mutex> lk(g_rp_mu);
+ auto it = g_rp_cache.find(key);
+ if (it != g_rp_cache.end()) {
+ return it->second;
+ }
+
+ const int64_t Kb = K / 64;
+ const int64_t nblk = n_experts * Kb * N;
+ const size_t q_bytes = (size_t) nblk * 8 * sizeof(uint32_t); // 32 qs bytes/block
+ const size_t s_bytes = (size_t) nblk * 1 * sizeof(uint32_t); // 4 scale bytes/block
+ // single allocation: [Q][S]; total == source NVFP4 weight bytes (36 B/block) -> stays 4-bit.
+ void * buf = nullptr;
+ CUDA_CHECK(cudaMalloc(&buf, q_bytes + s_bytes));
+ uint32_t * Qrp = (uint32_t *) buf;
+ uint32_t * Srp = (uint32_t *) ((char *) buf + q_bytes);
+
+ w4a16_repack_launch(src0, n_experts, K, N, Qrp, Srp, stream);
+
+ w4a16_repack_entry ent;
+ ent.Qrp = Qrp; ent.Srp = Srp;
+ ent.n_experts = n_experts; ent.Kb = Kb; ent.N = N;
+ ent.bytes = q_bytes + s_bytes;
+ auto res = g_rp_cache.emplace(key, ent);
+
+ if (getenv("LLAMA_W4A16_DEBUG")) {
+ // NB: we already hold g_rp_mu - sum the cache total INLINE (do NOT call the public
+ // ggml_cuda_w4a16_repack_cache_bytes(), which re-locks the non-recursive mutex -> deadlock).
+ size_t total_cache = 0;
+ for (const auto & kv : g_rp_cache) {
+ total_cache += kv.second.bytes;
+ }
+ fprintf(stderr, "[w4a16-repack] BUILT cache for src0=%p: n_experts=%lld Kb=%lld N=%lld "
+ "bytes=%.1f MiB (== source NVFP4 weight bytes; stays 4-bit). total cache=%.1f MiB\n",
+ key, (long long) n_experts, (long long) Kb, (long long) N,
+ ent.bytes / (1024.0*1024.0), total_cache / (1024.0*1024.0));
+ }
+ return res.first->second;
+}
+
+// ===========================================================================
+// GEMM run over already-repacked planes (Qrp, Srp). Single tuned config
+// (the 0035 winner: BM64 BN128 WARPS_M1 WARPS_N8 STAGES3 APAD4).
+// ===========================================================================
+static void w4a16_repack_run(
+ ggml_backend_cuda_context & ctx,
+ const float * src1_sorted,
+ float * dst_sorted,
+ const int * tokens_per_expert,
+ int64_t n_experts, int64_t K, int64_t N,
+ const uint32_t * Qrp, const uint32_t * Srp,
+ cudaStream_t stream) {
+ constexpr int BM = 64, BN = 128, WARPS_M = 1, WARPS_N = 8, STAGES = 3, APAD = 4;
+
+ // host: per-M-tile expert map (ragged, no tile crosses an expert boundary)
+ int64_t total_rows = 0;
+ for (int64_t e = 0; e < n_experts; e++) {
+ total_rows += tokens_per_expert[e];
+ }
+ if (total_rows == 0) {
+ return;
+ }
+ std::vector<int32_t> h_tile_expert, h_tile_row0, h_tile_rows;
+ int64_t row = 0;
+ for (int64_t e = 0; e < n_experts; e++) {
+ const int t = tokens_per_expert[e];
+ for (int off = 0; off < t; off += BM) {
+ h_tile_expert.push_back((int32_t) e);
+ h_tile_row0.push_back((int32_t) (row + off));
+ h_tile_rows.push_back((int32_t) std::min(BM, t - off));
+ }
+ row += t;
+ }
+ const int n_tiles = (int) h_tile_expert.size();
+
+ if (getenv("LLAMA_W4A16_DEBUG")) {
+ int max_tpe = 0, multi = 0;
+ for (int64_t e = 0; e < n_experts; e++) {
+ if (tokens_per_expert[e] > max_tpe) max_tpe = tokens_per_expert[e];
+ if (tokens_per_expert[e] > BM) multi++;
+ }
+ fprintf(stderr, "[w4a16-repack] engaged: total_rows=%lld n_experts=%lld K=%lld N=%lld "
+ "n_tiles=%d max_tpe=%d multi_tile=%d (offline-repack, in-register act cast)\n",
+ (long long) total_rows, (long long) n_experts, (long long) K, (long long) N,
+ n_tiles, max_tpe, multi);
+ }
+
+ ggml_cuda_pool_alloc<int32_t> d_tile_expert(ctx.pool(), n_tiles);
+ ggml_cuda_pool_alloc<int32_t> d_tile_row0 (ctx.pool(), n_tiles);
+ ggml_cuda_pool_alloc<int32_t> d_tile_rows (ctx.pool(), n_tiles);
+ CUDA_CHECK(cudaMemcpyAsync(d_tile_expert.ptr, h_tile_expert.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
+ CUDA_CHECK(cudaMemcpyAsync(d_tile_row0.ptr, h_tile_row0.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
+ CUDA_CHECK(cudaMemcpyAsync(d_tile_rows.ptr, h_tile_rows.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
+
+ auto kern = w4a16_repack_grouped_kernel<BM, BN, WARPS_M, WARPS_N, STAGES, APAD>;
+ constexpr int AN = 64/2, ASTRIDE = AN + APAD;
+ constexpr int STAGE_U32 = BM*ASTRIDE + BN*8 + BN;
+ const int smem_bytes = STAGES * STAGE_U32 * (int) sizeof(uint32_t);
+ CUDA_CHECK(cudaFuncSetAttribute(kern, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
+
+ dim3 grid((unsigned) (N / BN), (unsigned) n_tiles);
+ dim3 block(32, WARPS_M*WARPS_N);
+ kern<<<grid, block, smem_bytes, stream>>>(
+ src1_sorted, Qrp, Srp, dst_sorted,
+ d_tile_expert.ptr, d_tile_row0.ptr, d_tile_rows.ptr,
+ (int) N, (int) K, (int) total_rows);
+ CUDA_CHECK(cudaGetLastError());
+}
+
+// ===========================================================================
+// public host entry. Default: cached persistent repack (one-time, production-correct, stays
+// 4-bit). LLAMA_W4A16_REPACK_NOCACHE=1: repack into a transient pool buffer EVERY call (bypasses
+// the pointer-keyed cache) so test-backend-ops (which frees/reuses src0 addresses) can validate
+// the kernel + repack correctness.
+// ===========================================================================
+void ggml_cuda_mul_mat_id_w4a16_repack(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0,
+ const float * src1_sorted,
+ float * dst_sorted,
+ const int * tokens_per_expert,
+ int64_t n_experts, int64_t K, int64_t N,
+ cudaStream_t stream) {
+ GGML_ASSERT(src0->type == GGML_TYPE_NVFP4);
+ GGML_ASSERT(N % 128 == 0 && K % 64 == 0);
+
+ static const bool nocache = [] {
+ const char * s = getenv("LLAMA_W4A16_REPACK_NOCACHE");
+ return s != nullptr && atoi(s) != 0;
+ }();
+
+ if (nocache) {
+ const int64_t Kb = K / 64;
+ const int64_t nblk = n_experts * Kb * N;
+ // (1) repack into a transient pool buffer (same 4-bit size; stream-ordered before the GEMM).
+ ggml_cuda_pool_alloc<uint32_t> q(ctx.pool(), (size_t) nblk * 8);
+ ggml_cuda_pool_alloc<uint32_t> s(ctx.pool(), (size_t) nblk * 1);
+ w4a16_repack_launch(src0, n_experts, K, N, q.ptr, s.ptr, stream);
+ // (2) GEMM (q,s stay alive through the launch; kernel is queued on `stream` before dtor).
+ w4a16_repack_run(ctx, src1_sorted, dst_sorted, tokens_per_expert,
+ n_experts, K, N, q.ptr, s.ptr, stream);
+ return;
+ }
+
+ // (1) one-time offline repack (cached, persistent, stays 4-bit). Stream-ordered before the GEMM.
+ const w4a16_repack_entry & rp = w4a16_get_or_build_repack(src0, n_experts, K, N, stream);
+ // (2) GEMM over the cached repacked planes.
+ w4a16_repack_run(ctx, src1_sorted, dst_sorted, tokens_per_expert,
+ n_experts, K, N, rp.Qrp, rp.Srp, stream);
+}
diff --git a/ggml/src/ggml-cuda/w4a16-repack.cuh b/ggml/src/ggml-cuda/w4a16-repack.cuh
new file mode 100644
index 0000000..f44f5f8
--- /dev/null
+++ b/ggml/src/ggml-cuda/w4a16-repack.cuh
@@ -0,0 +1,59 @@
+#pragma once
+
+#include "common.cuh"
+
+// [paged patch 0045] vLLM-EXACT offline-repack Marlin W4A16 grouped MoE prefill GEMM.
+//
+// This is a SUB-MODE of the patch-0035 W4A16 grouped MoE GEMM (default-off). The 0035 root-cause
+// sweep proved the W4A16 kernel is occupancy/latency bound (~10% MMA util): every prefill step it
+// re-derives a per-block STRIDED weight gather (blk = We + (blockCol+r)*Kb + kt) and streams the
+// 4-bit weights with scattered 4-BYTE cp.async transactions (BN*8 per K-block) + a per-step ue4m3
+// scale decode, plus a SEPARATE global f32->bf16 activation pre-pass kernel. vLLM's Marlin wins by
+// REPACKING the weights ONCE at load into an MMA-ready, tile-contiguous 4-bit layout so the GEMM
+// loop issues only cheap COALESCED 16-byte loads and the per-lane reads are already in fragment
+// order - the loop does ZERO dequant work beyond the cheap in-register nibble->bf16 unpack.
+//
+// What this mode does, faithful to vLLM Marlin:
+// (1) OFFLINE REPACK (one-time, at first engage; cached in a persistent cudaMalloc device buffer
+// keyed by src0->data): transpose the NVFP4 expert weights [N rows][Kb blocks](d[4]+qs[32])
+// into two tile-contiguous planes, Q=[expert][Kb][N][32 qs bytes] and S=[expert][Kb][N][4
+// ue4m3 scale bytes]. STAYS 4-BIT: qs kept packed, scales kept ue4m3 - the repacked buffer is
+// byte-for-byte the SAME size as the source NVFP4 weights (36 B / 64 weights). No int8
+// expansion, no 2x memory (vs the rejected patch-0044 int8-Marlin which doubled the weight
+// bytes). Persists across all prefill steps.
+// (2) KERNEL reads the pre-packed tiles DIRECTLY: per (K-block, N-block) the BN rows' qs/scales
+// are contiguous, so the smem fill is a flat COALESCED 16-byte cp.async stream (no strided
+// gather, no address math). Activations are cast f32->bf16 IN REGISTER on the load into smem
+// (no separate global act pre-pass, no act-quant). The inner loop does only: ldmatrix the
+// bf16 A, in-register nibble->bf16 unpack of the weight (byte_perm/PRMT, LUT-free) scaled by
+// the in-register ue4m3 decode, and bf16 m16n8k16 mma.sync into f32 accumulators. cp.async
+// multistage pipelined, ragged per-tile expert offset (reuses 0035's tile map). NO separate
+// smem-staged dequant pass, NO __syncthreads-gated dequant pass.
+//
+// The repacked smem contents are bit-identical to the 0035 W4A16 smem, so the inner fragment math
+// (proven 81/81 on test-backend-ops MUL_MAT_ID) is unchanged and the output is bit-identical to the
+// 0035 W4A16 path; only the global->smem load path (coalesced) and the act cast (in-register) change.
+//
+// Toggle: LLAMA_W4A16_REPACK=1 (default 0 == OFF). Engages only inside the already-default-off
+// 0035 W4A16 grouped path (LLAMA_W4A16_PREFILL_M>0). Stock / decode / non-NVFP4 byte-untouched.
+
+// True iff LLAMA_W4A16_REPACK != 0.
+bool ggml_cuda_w4a16_repack_enabled();
+
+// Total bytes the persistent repack cache currently holds on device (for the weight-memory-delta
+// report). The repacked layout stays 4-bit, so this equals the source NVFP4 bytes of the engaged
+// expert weight tensors; 0 when the toggle is OFF (no repack happens).
+size_t ggml_cuda_w4a16_repack_cache_bytes();
+
+// Offline-repack W4A16 grouped MoE GEMM over the token-sorted buffer. Same contract as
+// ggml_cuda_mul_mat_id_w4a16_grouped (see w4a16-gemm.cuh) but src1_sorted is the RAW f32 sorted
+// activations (cast to bf16 in-register; no global bf16 pre-pass needed) and the weights are read
+// from the cached repacked layout (built one-time from src0).
+void ggml_cuda_mul_mat_id_w4a16_repack(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0,
+ const float * src1_sorted,
+ float * dst_sorted,
+ const int * tokens_per_expert,
+ int64_t n_experts, int64_t K, int64_t N,
+ cudaStream_t stream);
--
2.43.0