mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-30 03:17:01 -04:00
docs(paged): record MoE-prefill engine-gap decomposition + GEMM-port negatives (default-off)
nsys cross-engine decomposition: the MoE prefill 64% gap vs vLLM is engine plumbing, not the kernel (GPU 97% busy, 443 vs 197 us/tok). Three buckets: per-expert W4A4 M-fragmentation (58%), GDN scan (24%), f32<->bf16 casts (15%). Offline-repack (0045) and verbatim vLLM-marlin port both trail FP4-MMQ via wrapper overhead, kept default-off as recorded negatives. Assisted-by: Claude:opus-4.8 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -0,0 +1,650 @@
|
||||
From 864e92807809c55905bdb73ef492f6b1d03986f9 Mon Sep 17 00:00:00 2001
|
||||
From: Ettore Di Giacinto <mudler@localai.io>
|
||||
Date: Mon, 29 Jun 2026 11:02:07 +0200
|
||||
Subject: [PATCH] feat(paged): vLLM-exact offline-repack Marlin W4A16 grouped
|
||||
MoE prefill GEMM (patch 0045)
|
||||
|
||||
The patch-0035 root-cause sweep proved the W4A16 grouped MoE GEMM is occupancy/
|
||||
latency bound (~10% MMA util), NOT dequant bound: every prefill step it re-derives
|
||||
a per-block STRIDED weight gather and streams the 4-bit weights with scattered
|
||||
4-BYTE cp.async transactions (BN*8 per K-block) + a per-step ue4m3 scale decode,
|
||||
plus a SEPARATE global f32->bf16 activation pre-pass kernel. vLLM Marlin wins by
|
||||
REPACKING the weights ONCE at load into an MMA-ready tile-contiguous 4-bit layout
|
||||
so the GEMM loop issues only cheap COALESCED 16-byte loads and the per-lane reads
|
||||
are already in fragment order; the loop does ZERO dequant work beyond the cheap
|
||||
in-register nibble->bf16 unpack. This is that approach, faithful to vLLM Marlin and
|
||||
distinct from the rejected patch-0044 int8-Marlin (which doubled the weight bytes
|
||||
to int8 + a separate act-quant pass).
|
||||
|
||||
(1) OFFLINE REPACK (one-time, first engage; cached in a persistent cudaMalloc
|
||||
buffer keyed by src0->data): transpose the NVFP4 experts [N][Kb](d[4]+qs[32])
|
||||
into tile-contiguous planes Q=[e][Kb][N][32 qs bytes] + S=[e][Kb][N][4 ue4m3
|
||||
scale bytes]. STAYS 4-BIT: qs packed, scales ue4m3 -> the repacked buffer is
|
||||
byte-for-byte the SAME size as the source NVFP4 weights (144 MiB/tensor for the
|
||||
35B-A3B experts, == source; no int8 2x). Persists across all prefill steps.
|
||||
(2) KERNEL reads the pre-packed tiles directly: per (K-block, N-block) the BN rows
|
||||
qs/scales are contiguous -> the smem fill is a flat COALESCED 16-byte cp.async
|
||||
stream (no strided gather, no address math). Activations cast f32->bf16 IN
|
||||
REGISTER on the load into smem (no separate global act pre-pass, no act-quant).
|
||||
Inner loop: ldmatrix bf16 A, in-register PRMT (LUT-free) nibble->bf16 unpack
|
||||
scaled by the in-register ue4m3 decode, bf16 m16n8k16 mma.sync into f32 accum,
|
||||
cp.async multistage, ragged per-tile expert offset (reuses 0035 tile map). NO
|
||||
separate smem-staged dequant pass, NO __syncthreads-gated dequant pass (SASS:
|
||||
LDSM->PRMT->I2F->FMUL->F2F.BF16->HMMA.16816.F32.BF16 with no STS/BAR.SYNC between
|
||||
the weight load and the MMA).
|
||||
|
||||
The repacked smem is bit-identical to the 0035 W4A16 smem, so the proven inner
|
||||
fragment math is unchanged and the output is bit-identical to the 0035 W4A16 path.
|
||||
|
||||
TOGGLE: LLAMA_W4A16_REPACK=1 (default 0 == OFF), engaged only inside the already
|
||||
default-off 0035 W4A16 path (LLAMA_W4A16_PREFILL_M>0). LLAMA_W4A16_REPACK_NOCACHE=1
|
||||
repacks into a transient pool buffer per call (for test-backend-ops, which frees/
|
||||
reuses src0 addresses and so defeats the pointer-keyed cache). Stock / decode /
|
||||
non-NVFP4 byte-untouched.
|
||||
|
||||
VALIDATION (GB10, sm_121a, Qwen3.6-35B-A3B-NVFP4):
|
||||
- test-backend-ops MUL_MAT_ID nvfp4 (vs CPU oracle), REPACK forced + NOCACHE:
|
||||
81/81 OK, 0 FAIL.
|
||||
- real-model greedy md5 (paged MoE): stock == W4A16-non-repack == W4A16-repack
|
||||
(cached) == W4A16-repack (nocache) == default-off (REPACK=1, PREFILL_M unset),
|
||||
all fda1aadbbbfb36fe8ab0f5f5465c745e (bit-identical; default-off is stock).
|
||||
|
||||
HONEST PERF (S_PP t/s, llama-batched-bench -fa on -ngl 99 -ntg 32 -npl 1, paged,
|
||||
warm cache): the offline repack is a large win over the prior non-repack W4A16
|
||||
(npp512 854.7 -> 1452.4 = +70%; beats it at every M) and is FLAT across M
|
||||
(1452/1478/1467 at 512/1024/2048 = MMA-pipeline bound as designed). But the heavily
|
||||
tuned FP4-MMQ baseline on GB10 is still ~38% faster (~2030 flat). Decode S_TG
|
||||
unchanged (~55 t/s, prefill-only lever). One-time cache build amortized (cold first
|
||||
prefill). Ships DEFAULT-OFF (like 0033/0034/0035): the validated, env-gated, bit-
|
||||
exact-gated mechanism + the recorded result that even vLLM-exact offline-repack
|
||||
Marlin W4A16 does not overtake the GB10 FP4-MMQ winner.
|
||||
|
||||
Weight-memory: the repacked representation STAYS 4-BIT (no int8 expansion; vs 0044
|
||||
+100%); the persistent cache is an additive 4-bit copy of the engaged expert tensors
|
||||
when ON, and literally 0 when OFF (default). Decode keeps the original block_nvfp4
|
||||
layout for its MMQ/graph path, so the cache is additive rather than in-place.
|
||||
|
||||
Build: arch=compute_121a,code=[compute_121a,sm_121a]; AMPERE_MMA_AVAILABLE /
|
||||
CP_ASYNC_AVAILABLE guards (NO_DEVICE_CODE off-Blackwell).
|
||||
|
||||
Assisted-by: Claude:opus-4.8 [Claude Code]
|
||||
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
|
||||
---
|
||||
ggml/src/ggml-cuda/w4a16-gemm.cu | 11 +
|
||||
ggml/src/ggml-cuda/w4a16-repack.cu | 470 ++++++++++++++++++++++++++++
|
||||
ggml/src/ggml-cuda/w4a16-repack.cuh | 59 ++++
|
||||
3 files changed, 540 insertions(+)
|
||||
create mode 100644 ggml/src/ggml-cuda/w4a16-repack.cu
|
||||
create mode 100644 ggml/src/ggml-cuda/w4a16-repack.cuh
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/w4a16-gemm.cu b/ggml/src/ggml-cuda/w4a16-gemm.cu
|
||||
index c5c9ef7..687a7ae 100644
|
||||
--- a/ggml/src/ggml-cuda/w4a16-gemm.cu
|
||||
+++ b/ggml/src/ggml-cuda/w4a16-gemm.cu
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "w4a16-gemm.cuh"
|
||||
+#include "w4a16-repack.cuh"
|
||||
#include "mma.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -389,6 +390,16 @@ void ggml_cuda_mul_mat_id_w4a16_grouped(
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_NVFP4);
|
||||
GGML_ASSERT(N % 128 == 0 && K % 64 == 0);
|
||||
|
||||
+ // [paged patch 0045] vLLM-exact offline-repack sub-mode: repack the NVFP4 experts ONCE at first
|
||||
+ // engage into an MMA-ready tile-contiguous 4-bit layout (cached, stays 4-bit), then run a GEMM
|
||||
+ // loop with coalesced 16B weight loads + in-register act cast + ZERO in-loop dequant beyond the
|
||||
+ // cheap in-register nibble->bf16 unpack. Gated by LLAMA_W4A16_REPACK (default off).
|
||||
+ if (ggml_cuda_w4a16_repack_enabled()) {
|
||||
+ ggml_cuda_mul_mat_id_w4a16_repack(ctx, src0, src1_sorted, dst_sorted,
|
||||
+ tokens_per_expert, n_experts, K, N, stream);
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
int sel = w4a16_cfg_sel();
|
||||
if (N % w4a16_cfg_bn(sel) != 0) {
|
||||
sel = 3; // BN>128 config whose BN doesn't divide N: safe BN=128 winner
|
||||
diff --git a/ggml/src/ggml-cuda/w4a16-repack.cu b/ggml/src/ggml-cuda/w4a16-repack.cu
|
||||
new file mode 100644
|
||||
index 0000000..4a7a125
|
||||
--- /dev/null
|
||||
+++ b/ggml/src/ggml-cuda/w4a16-repack.cu
|
||||
@@ -0,0 +1,470 @@
|
||||
+#include "w4a16-repack.cuh"
|
||||
+#include "mma.cuh"
|
||||
+
|
||||
+#include <cstdint>
|
||||
+#include <cstdlib>
|
||||
+#include <mutex>
|
||||
+#include <unordered_map>
|
||||
+#include <vector>
|
||||
+#include <algorithm>
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// [paged patch 0045] vLLM-exact offline-repack Marlin W4A16 grouped MoE prefill GEMM.
|
||||
+// See w4a16-repack.cuh for the design. Default-off (LLAMA_W4A16_REPACK).
|
||||
+// ===========================================================================
|
||||
+
|
||||
+using namespace ggml_cuda_mma;
|
||||
+typedef tile<16, 8, nv_bfloat162> rp_tile_A; // A operand: M=16, K=16
|
||||
+typedef tile< 8, 8, nv_bfloat162> rp_tile_B; // B operand: N=8, K=16
|
||||
+typedef tile<16, 8, float> rp_tile_C; // accumulator: M=16, N=8
|
||||
+
|
||||
+bool ggml_cuda_w4a16_repack_enabled() {
|
||||
+ static const bool e = [] {
|
||||
+ const char * s = getenv("LLAMA_W4A16_REPACK");
|
||||
+ return s != nullptr && atoi(s) != 0;
|
||||
+ }();
|
||||
+ return e;
|
||||
+}
|
||||
+
|
||||
+// ---- cp.async helpers (sm80+; raw bytes, no cast) ----
|
||||
+static __device__ __forceinline__ void rp_cp_async16(void * smem, const void * gmem) {
|
||||
+#ifdef CP_ASYNC_AVAILABLE
|
||||
+ const unsigned s = (unsigned) __cvta_generic_to_shared(smem);
|
||||
+ asm volatile("cp.async.cg.shared.global [%0],[%1],16;\n" :: "r"(s), "l"(gmem));
|
||||
+#else
|
||||
+ GGML_UNUSED(smem); GGML_UNUSED(gmem); NO_DEVICE_CODE;
|
||||
+#endif // CP_ASYNC_AVAILABLE
|
||||
+}
|
||||
+static __device__ __forceinline__ void rp_cp_commit() {
|
||||
+#ifdef CP_ASYNC_AVAILABLE
|
||||
+ asm volatile("cp.async.commit_group;\n" ::);
|
||||
+#else
|
||||
+ NO_DEVICE_CODE;
|
||||
+#endif // CP_ASYNC_AVAILABLE
|
||||
+}
|
||||
+template<int N> static __device__ __forceinline__ void rp_cp_wait() {
|
||||
+#ifdef CP_ASYNC_AVAILABLE
|
||||
+ asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
|
||||
+#else
|
||||
+ NO_DEVICE_CODE;
|
||||
+#endif // CP_ASYNC_AVAILABLE
|
||||
+}
|
||||
+
|
||||
+// ---- fast 16-entry FP4 (E2M1) -> kvalues_mxfp4 lookup via __byte_perm (PRMT), LUT-free, NO
|
||||
+// divergent LDG. Bit-identical to kvalues_mxfp4[code]; same trick MMQ's get_int_from_table_16 and
|
||||
+// the 0035 kernel use. The "cheap in-register unpack" - the ONLY dequant work the loop does. ----
|
||||
+static __device__ __forceinline__ int2 rp_table16(uint32_t q4) {
|
||||
+ const uint32_t * table32 = (const uint32_t *) kvalues_mxfp4; // 16 int8 == 4 u32
|
||||
+ uint32_t tmp[2];
|
||||
+ const uint32_t sel = 0x32103210 | ((q4 & 0x88888888) >> 1);
|
||||
+#pragma unroll
|
||||
+ for (uint32_t i = 0; i < 2; ++i) {
|
||||
+ const uint32_t shift = 16*i;
|
||||
+ const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
|
||||
+ const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
|
||||
+ tmp[i] = __byte_perm(low, high, sel >> shift);
|
||||
+ }
|
||||
+ return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
|
||||
+}
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// (1) OFFLINE REPACK kernel (one-time). Transpose src0 NVFP4 [e][N][Kb](d[4]+qs[32]) into
|
||||
+// Qrp = [e][Kb][N][32 qs bytes] (u32 plane, 8 u32/block)
|
||||
+// Srp = [e][Kb][N][ 4 scale bytes] (u32 plane, 1 u32/block)
|
||||
+// Same total bytes as src0 (36 B/block); stays 4-bit. One thread per source block.
|
||||
+// ===========================================================================
|
||||
+static __global__ void w4a16_repack_kernel(
|
||||
+ const block_nvfp4 * __restrict__ W0, int64_t expert_stride_blocks,
|
||||
+ uint32_t * __restrict__ Qrp, uint32_t * __restrict__ Srp,
|
||||
+ int N, int Kb, int n_experts) {
|
||||
+ const int64_t total = (int64_t) n_experts * N * Kb;
|
||||
+ const int64_t t = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
|
||||
+ if (t >= total) {
|
||||
+ return;
|
||||
+ }
|
||||
+ const int kt = (int) (t % Kb);
|
||||
+ const int64_t r2 = t / Kb;
|
||||
+ const int n = (int) (r2 % N);
|
||||
+ const int e = (int) (r2 / N);
|
||||
+
|
||||
+ const block_nvfp4 * blk = W0 + (int64_t) e*expert_stride_blocks + (int64_t) n*Kb + kt;
|
||||
+ const uint32_t * src = (const uint32_t *) blk; // word0=d[4], words1..8=qs[32]
|
||||
+
|
||||
+ const int64_t qbase = (((int64_t) e*Kb + kt) * N + n) * 8; // u32 offset in Qrp
|
||||
+ const int64_t sbase = ((int64_t) e*Kb + kt) * N + n; // u32 offset in Srp
|
||||
+ Srp[sbase] = src[0];
|
||||
+#pragma unroll
|
||||
+ for (int w = 0; w < 8; w++) {
|
||||
+ Qrp[qbase + w] = src[1 + w];
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// (2) GROUPED GEMM over the repacked tiles. Per output tile (blockIdx.x=N-block, blockIdx.y=M-tile):
|
||||
+// expert e=g_tile_expert[by], row0=g_tile_row0[by], rcount=g_tile_rows[by].
|
||||
+// Weights from Qrp/Srp (tile-contiguous), activations from A_f32 (cast bf16 in-register), out to C.
|
||||
+// ZERO in-loop dequant beyond the in-register nibble->bf16 unpack + ue4m3 scale decode.
|
||||
+// ===========================================================================
|
||||
+template<int BM, int BN, int WARPS_M, int WARPS_N, int STAGES, int APAD>
|
||||
+__launch_bounds__(WARPS_M*WARPS_N*32, 1)
|
||||
+static __global__ void w4a16_repack_grouped_kernel(
|
||||
+ const float * __restrict__ A_f32, // [total_rows, K] f32 (sorted)
|
||||
+ const uint32_t * __restrict__ Qrp, // [e][Kb][N][8 u32]
|
||||
+ const uint32_t * __restrict__ Srp, // [e][Kb][N][1 u32]
|
||||
+ float * __restrict__ C, // [total_rows, N] f32
|
||||
+ const int * __restrict__ g_tile_expert,
|
||||
+ const int * __restrict__ g_tile_row0,
|
||||
+ const int * __restrict__ g_tile_rows,
|
||||
+ int N, int K, int total_rows) {
|
||||
+#if defined(AMPERE_MMA_AVAILABLE) && defined(CP_ASYNC_AVAILABLE)
|
||||
+ constexpr int BK = 64; // one nvfp4 block
|
||||
+ constexpr int NWARP = WARPS_M*WARPS_N;
|
||||
+ constexpr int THREADS = NWARP*32;
|
||||
+ constexpr int WM = BM/WARPS_M, WN = BN/WARPS_N;
|
||||
+ constexpr int MF = WM/16, NF = WN/8;
|
||||
+
|
||||
+ constexpr int AN = BK/2; // bf16 pairs per A smem row (nv_bfloat162)
|
||||
+ constexpr int ASTRIDE = AN + APAD; // padded A smem row stride (skew banks, +19% lesson)
|
||||
+ constexpr int SZ_A = BM*ASTRIDE; // nv_bfloat162 (== u32) per stage (padded)
|
||||
+ constexpr int SZ_WQ = BN*8; // u32 per stage (32 qs bytes/row, tile-contiguous)
|
||||
+ constexpr int SZ_WD = BN; // u32 per stage (4 scale bytes/row, tile-contiguous)
|
||||
+
|
||||
+ extern __shared__ uint32_t smem_u32[];
|
||||
+ constexpr int STAGE_U32 = SZ_A + SZ_WQ + SZ_WD;
|
||||
+ nv_bfloat162 * sA[STAGES];
|
||||
+ uint32_t * sWq[STAGES];
|
||||
+ uint32_t * sWd[STAGES];
|
||||
+#pragma unroll
|
||||
+ for (int s = 0; s < STAGES; s++) {
|
||||
+ uint32_t * base = smem_u32 + s*STAGE_U32;
|
||||
+ sA[s] = (nv_bfloat162 *) base;
|
||||
+ sWq[s] = base + SZ_A;
|
||||
+ sWd[s] = base + SZ_A + SZ_WQ;
|
||||
+ }
|
||||
+
|
||||
+ const int lane = threadIdx.x; // 0..31 (mma.cuh uses threadIdx.x AS the warp lane)
|
||||
+ const int warp = threadIdx.y; // 0..NWARP-1
|
||||
+ const int tid = warp*32 + lane;
|
||||
+ const int wrow = warp / WARPS_N, wcol = warp % WARPS_N;
|
||||
+
|
||||
+ const int e = g_tile_expert[blockIdx.y];
|
||||
+ const int row0 = g_tile_row0[blockIdx.y];
|
||||
+ const int rcount = g_tile_rows[blockIdx.y];
|
||||
+ const int blockCol = blockIdx.x*BN;
|
||||
+ const int Kb = K/64;
|
||||
+ // base u32 offset of expert e in the repacked planes (tile = + (kt*N + blockCol)*{8,1})
|
||||
+ const int64_t Qe = (int64_t) e*Kb*N*8;
|
||||
+ const int64_t Se = (int64_t) e*Kb*N;
|
||||
+
|
||||
+ rp_tile_C acc[MF][NF];
|
||||
+
|
||||
+ // async-load K-block kt into stage st: A cast in-register (no global pre-pass); W coalesced 16B.
|
||||
+ auto load_tile = [&](int st, int kt) {
|
||||
+ // A: BM rows x BK bf16 = BM x (BK/8) 16B chunks. Source f32 -> cast bf16 in register.
|
||||
+ const int A_chunks = BM*(BK/8);
|
||||
+#pragma unroll 1
|
||||
+ for (int idx = tid; idx < A_chunks; idx += THREADS) {
|
||||
+ const int c = idx % (BK/8); // 16B (8 bf16) chunk in the row
|
||||
+ const int r = idx / (BK/8); // row in tile
|
||||
+ const int gr = row0 + r;
|
||||
+ nv_bfloat162 * d2 = sA[st] + r*ASTRIDE + c*4; // 4 nv_bfloat162 = 8 bf16
|
||||
+ if (gr < total_rows) {
|
||||
+ const float * src = A_f32 + (int64_t) gr*K + (int64_t) kt*BK + c*8;
|
||||
+ const float4 a = *(const float4 *) (src);
|
||||
+ const float4 b = *(const float4 *) (src + 4);
|
||||
+ d2[0] = make_bfloat162(__float2bfloat16(a.x), __float2bfloat16(a.y));
|
||||
+ d2[1] = make_bfloat162(__float2bfloat16(a.z), __float2bfloat16(a.w));
|
||||
+ d2[2] = make_bfloat162(__float2bfloat16(b.x), __float2bfloat16(b.y));
|
||||
+ d2[3] = make_bfloat162(__float2bfloat16(b.z), __float2bfloat16(b.w));
|
||||
+ } else {
|
||||
+ const nv_bfloat162 z = make_bfloat162((nv_bfloat16) 0.0f, (nv_bfloat16) 0.0f);
|
||||
+ d2[0] = z; d2[1] = z; d2[2] = z; d2[3] = z;
|
||||
+ }
|
||||
+ }
|
||||
+ // W qs: tile-contiguous BN*32 bytes = BN*8 u32 = BN*2 16B chunks. Coalesced cp.async.16.
|
||||
+ const uint32_t * Qtile = Qrp + Qe + ((int64_t) kt*N + blockCol) * 8;
|
||||
+ const int Q_chunks = (BN*8) / 4; // 16B (4 u32) chunks
|
||||
+#pragma unroll 1
|
||||
+ for (int idx = tid; idx < Q_chunks; idx += THREADS) {
|
||||
+ rp_cp_async16(&sWq[st][idx*4], Qtile + idx*4);
|
||||
+ }
|
||||
+ // W scales: tile-contiguous BN*4 bytes = BN u32 = BN/4 16B chunks. Coalesced cp.async.16.
|
||||
+ const uint32_t * Stile = Srp + Se + ((int64_t) kt*N + blockCol);
|
||||
+ const int S_chunks = BN / 4;
|
||||
+#pragma unroll 1
|
||||
+ for (int idx = tid; idx < S_chunks; idx += THREADS) {
|
||||
+ rp_cp_async16(&sWd[st][idx*4], Stile + idx*4);
|
||||
+ }
|
||||
+ };
|
||||
+
|
||||
+ // prologue
|
||||
+#pragma unroll
|
||||
+ for (int s = 0; s < STAGES-1; s++) { if (s < Kb) load_tile(s, s); rp_cp_commit(); }
|
||||
+
|
||||
+ for (int kt = 0; kt < Kb; kt++) {
|
||||
+ const int ld = kt + (STAGES-1);
|
||||
+ if (ld < Kb) load_tile(ld % STAGES, ld);
|
||||
+ rp_cp_commit();
|
||||
+ rp_cp_wait<STAGES-1>();
|
||||
+ __syncthreads();
|
||||
+
|
||||
+ const int rs = kt % STAGES;
|
||||
+ const nv_bfloat162 * sAcur = sA[rs];
|
||||
+ const uint32_t * sWqw = sWq[rs]; // BN rows x 8 u32 (32 qs bytes)
|
||||
+ const uint32_t * sWdw = sWd[rs]; // BN rows x 1 u32 (4 scale bytes)
|
||||
+
|
||||
+#pragma unroll
|
||||
+ for (int kk = 0; kk < BK/16; kk++) { // 4 m16n8k16 sub-steps per 64-block
|
||||
+ const int sub = kk;
|
||||
+ // A fragments via ldmatrix (bf16)
|
||||
+ rp_tile_A A_frag[MF];
|
||||
+#pragma unroll
|
||||
+ for (int mi = 0; mi < MF; mi++) {
|
||||
+ const int rb = wrow*WM + mi*16;
|
||||
+ load_ldmatrix(A_frag[mi], sAcur + rb*ASTRIDE + kk*8, ASTRIDE);
|
||||
+ }
|
||||
+ // B fragments: in-register FP4->bf16 unpack (byte_perm LUT-free) * in-register ue4m3 scale.
|
||||
+ rp_tile_B B_frag[NF];
|
||||
+ const int n_local = lane >> 2; // tile_B::get_i (row N, 0..7)
|
||||
+ const int jc = lane & 3;
|
||||
+ const int wsel = sub*2 + (jc >> 1);
|
||||
+ const int bsh = 8 * (2*(jc & 1));
|
||||
+#pragma unroll
|
||||
+ for (int ni = 0; ni < NF; ni++) {
|
||||
+ const int nrow = wcol*WN + ni*8 + n_local; // col within BN tile [0,BN)
|
||||
+ const uint32_t w = sWqw[nrow*8 + wsel];
|
||||
+ const int2 kv = rp_table16(w);
|
||||
+ const float sc = ggml_cuda_ue4m3_to_fp32(((const uint8_t *) &sWdw[nrow])[sub]);
|
||||
+ B_frag[ni].x[0].x = __float2bfloat16(sc * (float) (int8_t) (kv.x >> bsh));
|
||||
+ B_frag[ni].x[0].y = __float2bfloat16(sc * (float) (int8_t) (kv.x >> (bsh + 8)));
|
||||
+ B_frag[ni].x[1].x = __float2bfloat16(sc * (float) (int8_t) (kv.y >> bsh));
|
||||
+ B_frag[ni].x[1].y = __float2bfloat16(sc * (float) (int8_t) (kv.y >> (bsh + 8)));
|
||||
+ }
|
||||
+#pragma unroll
|
||||
+ for (int mi = 0; mi < MF; mi++)
|
||||
+#pragma unroll
|
||||
+ for (int ni = 0; ni < NF; ni++)
|
||||
+ mma(acc[mi][ni], A_frag[mi], B_frag[ni]);
|
||||
+ }
|
||||
+ __syncthreads();
|
||||
+ }
|
||||
+
|
||||
+ // write back (mask the ragged per-expert row tail)
|
||||
+#pragma unroll
|
||||
+ for (int mi = 0; mi < MF; mi++)
|
||||
+#pragma unroll
|
||||
+ for (int ni = 0; ni < NF; ni++) {
|
||||
+ const int orow = wrow*WM + mi*16;
|
||||
+ const int ocol = blockCol + wcol*WN + ni*8;
|
||||
+#pragma unroll
|
||||
+ for (int l = 0; l < acc[mi][ni].ne; l++) {
|
||||
+ const int lr = orow + acc[mi][ni].get_i(l);
|
||||
+ const int nc = ocol + acc[mi][ni].get_j(l);
|
||||
+ if (lr < rcount && nc < N) {
|
||||
+ C[(int64_t)(row0 + lr)*N + nc] = acc[mi][ni].x[l];
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+#else
|
||||
+ GGML_UNUSED(A_f32); GGML_UNUSED(Qrp); GGML_UNUSED(Srp); GGML_UNUSED(C);
|
||||
+ GGML_UNUSED(g_tile_expert); GGML_UNUSED(g_tile_row0); GGML_UNUSED(g_tile_rows);
|
||||
+ GGML_UNUSED(N); GGML_UNUSED(K); GGML_UNUSED(total_rows);
|
||||
+ NO_DEVICE_CODE;
|
||||
+#endif // AMPERE_MMA_AVAILABLE && CP_ASYNC_AVAILABLE
|
||||
+}
|
||||
+
|
||||
+// launch the one-time repack kernel: src0 NVFP4 -> (Qrp, Srp) repacked planes.
|
||||
+static void w4a16_repack_launch(
|
||||
+ const ggml_tensor * src0, int64_t n_experts, int64_t K, int64_t N,
|
||||
+ uint32_t * Qrp, uint32_t * Srp, cudaStream_t stream) {
|
||||
+ const int64_t Kb = K / 64;
|
||||
+ const int64_t nblk = n_experts * Kb * N;
|
||||
+ const int64_t expert_stride_blocks = (int64_t) (src0->nb[2] / sizeof(block_nvfp4));
|
||||
+ const int threads = 256;
|
||||
+ const int64_t grid = (nblk + threads - 1) / threads;
|
||||
+ w4a16_repack_kernel<<<grid, threads, 0, stream>>>(
|
||||
+ (const block_nvfp4 *) src0->data, expert_stride_blocks,
|
||||
+ Qrp, Srp, (int) N, (int) Kb, (int) n_experts);
|
||||
+ CUDA_CHECK(cudaGetLastError());
|
||||
+}
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// repack cache: persistent device buffer per src0 tensor (keyed by src0->data). One-time build.
|
||||
+//
|
||||
+// Correct for real models: model weights live in a persistent buffer for the model's lifetime, so
|
||||
+// src0->data is stable and never freed/reused during inference. NOT correct for test-backend-ops,
|
||||
+// which allocates/frees a fresh src0 per case and ggml reuses the address -> a cache HIT returns
|
||||
+// stale repacked weights. For harness validation use LLAMA_W4A16_REPACK_NOCACHE=1 (repack into a
|
||||
+// transient pool buffer every call, bypassing the cache); this proves the kernel+repack correct.
|
||||
+// ===========================================================================
|
||||
+struct w4a16_repack_entry {
|
||||
+ uint32_t * Qrp = nullptr; // [n_experts][Kb][N][8 u32]
|
||||
+ uint32_t * Srp = nullptr; // [n_experts][Kb][N][1 u32]
|
||||
+ int64_t n_experts = 0, Kb = 0, N = 0;
|
||||
+ size_t bytes = 0; // total cudaMalloc bytes (Q + S), for the memory-delta report
|
||||
+};
|
||||
+static std::mutex g_rp_mu;
|
||||
+static std::unordered_map<const void *, w4a16_repack_entry> g_rp_cache;
|
||||
+
|
||||
+// total bytes the repack cache currently holds on device (for the memory-delta report).
|
||||
+size_t ggml_cuda_w4a16_repack_cache_bytes() {
|
||||
+ std::lock_guard<std::mutex> lk(g_rp_mu);
|
||||
+ size_t b = 0;
|
||||
+ for (const auto & kv : g_rp_cache) {
|
||||
+ b += kv.second.bytes;
|
||||
+ }
|
||||
+ return b;
|
||||
+}
|
||||
+
|
||||
+static const w4a16_repack_entry & w4a16_get_or_build_repack(
|
||||
+ const ggml_tensor * src0, int64_t n_experts, int64_t K, int64_t N, cudaStream_t stream) {
|
||||
+ const void * key = src0->data;
|
||||
+ std::lock_guard<std::mutex> lk(g_rp_mu);
|
||||
+ auto it = g_rp_cache.find(key);
|
||||
+ if (it != g_rp_cache.end()) {
|
||||
+ return it->second;
|
||||
+ }
|
||||
+
|
||||
+ const int64_t Kb = K / 64;
|
||||
+ const int64_t nblk = n_experts * Kb * N;
|
||||
+ const size_t q_bytes = (size_t) nblk * 8 * sizeof(uint32_t); // 32 qs bytes/block
|
||||
+ const size_t s_bytes = (size_t) nblk * 1 * sizeof(uint32_t); // 4 scale bytes/block
|
||||
+ // single allocation: [Q][S]; total == source NVFP4 weight bytes (36 B/block) -> stays 4-bit.
|
||||
+ void * buf = nullptr;
|
||||
+ CUDA_CHECK(cudaMalloc(&buf, q_bytes + s_bytes));
|
||||
+ uint32_t * Qrp = (uint32_t *) buf;
|
||||
+ uint32_t * Srp = (uint32_t *) ((char *) buf + q_bytes);
|
||||
+
|
||||
+ w4a16_repack_launch(src0, n_experts, K, N, Qrp, Srp, stream);
|
||||
+
|
||||
+ w4a16_repack_entry ent;
|
||||
+ ent.Qrp = Qrp; ent.Srp = Srp;
|
||||
+ ent.n_experts = n_experts; ent.Kb = Kb; ent.N = N;
|
||||
+ ent.bytes = q_bytes + s_bytes;
|
||||
+ auto res = g_rp_cache.emplace(key, ent);
|
||||
+
|
||||
+ if (getenv("LLAMA_W4A16_DEBUG")) {
|
||||
+ // NB: we already hold g_rp_mu - sum the cache total INLINE (do NOT call the public
|
||||
+ // ggml_cuda_w4a16_repack_cache_bytes(), which re-locks the non-recursive mutex -> deadlock).
|
||||
+ size_t total_cache = 0;
|
||||
+ for (const auto & kv : g_rp_cache) {
|
||||
+ total_cache += kv.second.bytes;
|
||||
+ }
|
||||
+ fprintf(stderr, "[w4a16-repack] BUILT cache for src0=%p: n_experts=%lld Kb=%lld N=%lld "
|
||||
+ "bytes=%.1f MiB (== source NVFP4 weight bytes; stays 4-bit). total cache=%.1f MiB\n",
|
||||
+ key, (long long) n_experts, (long long) Kb, (long long) N,
|
||||
+ ent.bytes / (1024.0*1024.0), total_cache / (1024.0*1024.0));
|
||||
+ }
|
||||
+ return res.first->second;
|
||||
+}
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// GEMM run over already-repacked planes (Qrp, Srp). Single tuned config
|
||||
+// (the 0035 winner: BM64 BN128 WARPS_M1 WARPS_N8 STAGES3 APAD4).
|
||||
+// ===========================================================================
|
||||
+static void w4a16_repack_run(
|
||||
+ ggml_backend_cuda_context & ctx,
|
||||
+ const float * src1_sorted,
|
||||
+ float * dst_sorted,
|
||||
+ const int * tokens_per_expert,
|
||||
+ int64_t n_experts, int64_t K, int64_t N,
|
||||
+ const uint32_t * Qrp, const uint32_t * Srp,
|
||||
+ cudaStream_t stream) {
|
||||
+ constexpr int BM = 64, BN = 128, WARPS_M = 1, WARPS_N = 8, STAGES = 3, APAD = 4;
|
||||
+
|
||||
+ // host: per-M-tile expert map (ragged, no tile crosses an expert boundary)
|
||||
+ int64_t total_rows = 0;
|
||||
+ for (int64_t e = 0; e < n_experts; e++) {
|
||||
+ total_rows += tokens_per_expert[e];
|
||||
+ }
|
||||
+ if (total_rows == 0) {
|
||||
+ return;
|
||||
+ }
|
||||
+ std::vector<int32_t> h_tile_expert, h_tile_row0, h_tile_rows;
|
||||
+ int64_t row = 0;
|
||||
+ for (int64_t e = 0; e < n_experts; e++) {
|
||||
+ const int t = tokens_per_expert[e];
|
||||
+ for (int off = 0; off < t; off += BM) {
|
||||
+ h_tile_expert.push_back((int32_t) e);
|
||||
+ h_tile_row0.push_back((int32_t) (row + off));
|
||||
+ h_tile_rows.push_back((int32_t) std::min(BM, t - off));
|
||||
+ }
|
||||
+ row += t;
|
||||
+ }
|
||||
+ const int n_tiles = (int) h_tile_expert.size();
|
||||
+
|
||||
+ if (getenv("LLAMA_W4A16_DEBUG")) {
|
||||
+ int max_tpe = 0, multi = 0;
|
||||
+ for (int64_t e = 0; e < n_experts; e++) {
|
||||
+ if (tokens_per_expert[e] > max_tpe) max_tpe = tokens_per_expert[e];
|
||||
+ if (tokens_per_expert[e] > BM) multi++;
|
||||
+ }
|
||||
+ fprintf(stderr, "[w4a16-repack] engaged: total_rows=%lld n_experts=%lld K=%lld N=%lld "
|
||||
+ "n_tiles=%d max_tpe=%d multi_tile=%d (offline-repack, in-register act cast)\n",
|
||||
+ (long long) total_rows, (long long) n_experts, (long long) K, (long long) N,
|
||||
+ n_tiles, max_tpe, multi);
|
||||
+ }
|
||||
+
|
||||
+ ggml_cuda_pool_alloc<int32_t> d_tile_expert(ctx.pool(), n_tiles);
|
||||
+ ggml_cuda_pool_alloc<int32_t> d_tile_row0 (ctx.pool(), n_tiles);
|
||||
+ ggml_cuda_pool_alloc<int32_t> d_tile_rows (ctx.pool(), n_tiles);
|
||||
+ CUDA_CHECK(cudaMemcpyAsync(d_tile_expert.ptr, h_tile_expert.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
|
||||
+ CUDA_CHECK(cudaMemcpyAsync(d_tile_row0.ptr, h_tile_row0.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
|
||||
+ CUDA_CHECK(cudaMemcpyAsync(d_tile_rows.ptr, h_tile_rows.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
|
||||
+
|
||||
+ auto kern = w4a16_repack_grouped_kernel<BM, BN, WARPS_M, WARPS_N, STAGES, APAD>;
|
||||
+ constexpr int AN = 64/2, ASTRIDE = AN + APAD;
|
||||
+ constexpr int STAGE_U32 = BM*ASTRIDE + BN*8 + BN;
|
||||
+ const int smem_bytes = STAGES * STAGE_U32 * (int) sizeof(uint32_t);
|
||||
+ CUDA_CHECK(cudaFuncSetAttribute(kern, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
|
||||
+
|
||||
+ dim3 grid((unsigned) (N / BN), (unsigned) n_tiles);
|
||||
+ dim3 block(32, WARPS_M*WARPS_N);
|
||||
+ kern<<<grid, block, smem_bytes, stream>>>(
|
||||
+ src1_sorted, Qrp, Srp, dst_sorted,
|
||||
+ d_tile_expert.ptr, d_tile_row0.ptr, d_tile_rows.ptr,
|
||||
+ (int) N, (int) K, (int) total_rows);
|
||||
+ CUDA_CHECK(cudaGetLastError());
|
||||
+}
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// public host entry. Default: cached persistent repack (one-time, production-correct, stays
|
||||
+// 4-bit). LLAMA_W4A16_REPACK_NOCACHE=1: repack into a transient pool buffer EVERY call (bypasses
|
||||
+// the pointer-keyed cache) so test-backend-ops (which frees/reuses src0 addresses) can validate
|
||||
+// the kernel + repack correctness.
|
||||
+// ===========================================================================
|
||||
+void ggml_cuda_mul_mat_id_w4a16_repack(
|
||||
+ ggml_backend_cuda_context & ctx,
|
||||
+ const ggml_tensor * src0,
|
||||
+ const float * src1_sorted,
|
||||
+ float * dst_sorted,
|
||||
+ const int * tokens_per_expert,
|
||||
+ int64_t n_experts, int64_t K, int64_t N,
|
||||
+ cudaStream_t stream) {
|
||||
+ GGML_ASSERT(src0->type == GGML_TYPE_NVFP4);
|
||||
+ GGML_ASSERT(N % 128 == 0 && K % 64 == 0);
|
||||
+
|
||||
+ static const bool nocache = [] {
|
||||
+ const char * s = getenv("LLAMA_W4A16_REPACK_NOCACHE");
|
||||
+ return s != nullptr && atoi(s) != 0;
|
||||
+ }();
|
||||
+
|
||||
+ if (nocache) {
|
||||
+ const int64_t Kb = K / 64;
|
||||
+ const int64_t nblk = n_experts * Kb * N;
|
||||
+ // (1) repack into a transient pool buffer (same 4-bit size; stream-ordered before the GEMM).
|
||||
+ ggml_cuda_pool_alloc<uint32_t> q(ctx.pool(), (size_t) nblk * 8);
|
||||
+ ggml_cuda_pool_alloc<uint32_t> s(ctx.pool(), (size_t) nblk * 1);
|
||||
+ w4a16_repack_launch(src0, n_experts, K, N, q.ptr, s.ptr, stream);
|
||||
+ // (2) GEMM (q,s stay alive through the launch; kernel is queued on `stream` before dtor).
|
||||
+ w4a16_repack_run(ctx, src1_sorted, dst_sorted, tokens_per_expert,
|
||||
+ n_experts, K, N, q.ptr, s.ptr, stream);
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ // (1) one-time offline repack (cached, persistent, stays 4-bit). Stream-ordered before the GEMM.
|
||||
+ const w4a16_repack_entry & rp = w4a16_get_or_build_repack(src0, n_experts, K, N, stream);
|
||||
+ // (2) GEMM over the cached repacked planes.
|
||||
+ w4a16_repack_run(ctx, src1_sorted, dst_sorted, tokens_per_expert,
|
||||
+ n_experts, K, N, rp.Qrp, rp.Srp, stream);
|
||||
+}
|
||||
diff --git a/ggml/src/ggml-cuda/w4a16-repack.cuh b/ggml/src/ggml-cuda/w4a16-repack.cuh
|
||||
new file mode 100644
|
||||
index 0000000..f44f5f8
|
||||
--- /dev/null
|
||||
+++ b/ggml/src/ggml-cuda/w4a16-repack.cuh
|
||||
@@ -0,0 +1,59 @@
|
||||
+#pragma once
|
||||
+
|
||||
+#include "common.cuh"
|
||||
+
|
||||
+// [paged patch 0045] vLLM-EXACT offline-repack Marlin W4A16 grouped MoE prefill GEMM.
|
||||
+//
|
||||
+// This is a SUB-MODE of the patch-0035 W4A16 grouped MoE GEMM (default-off). The 0035 root-cause
|
||||
+// sweep proved the W4A16 kernel is occupancy/latency bound (~10% MMA util): every prefill step it
|
||||
+// re-derives a per-block STRIDED weight gather (blk = We + (blockCol+r)*Kb + kt) and streams the
|
||||
+// 4-bit weights with scattered 4-BYTE cp.async transactions (BN*8 per K-block) + a per-step ue4m3
|
||||
+// scale decode, plus a SEPARATE global f32->bf16 activation pre-pass kernel. vLLM's Marlin wins by
|
||||
+// REPACKING the weights ONCE at load into an MMA-ready, tile-contiguous 4-bit layout so the GEMM
|
||||
+// loop issues only cheap COALESCED 16-byte loads and the per-lane reads are already in fragment
|
||||
+// order - the loop does ZERO dequant work beyond the cheap in-register nibble->bf16 unpack.
|
||||
+//
|
||||
+// What this mode does, faithful to vLLM Marlin:
|
||||
+// (1) OFFLINE REPACK (one-time, at first engage; cached in a persistent cudaMalloc device buffer
|
||||
+// keyed by src0->data): transpose the NVFP4 expert weights [N rows][Kb blocks](d[4]+qs[32])
|
||||
+// into two tile-contiguous planes, Q=[expert][Kb][N][32 qs bytes] and S=[expert][Kb][N][4
|
||||
+// ue4m3 scale bytes]. STAYS 4-BIT: qs kept packed, scales kept ue4m3 - the repacked buffer is
|
||||
+// byte-for-byte the SAME size as the source NVFP4 weights (36 B / 64 weights). No int8
|
||||
+// expansion, no 2x memory (vs the rejected patch-0044 int8-Marlin which doubled the weight
|
||||
+// bytes). Persists across all prefill steps.
|
||||
+// (2) KERNEL reads the pre-packed tiles DIRECTLY: per (K-block, N-block) the BN rows' qs/scales
|
||||
+// are contiguous, so the smem fill is a flat COALESCED 16-byte cp.async stream (no strided
|
||||
+// gather, no address math). Activations are cast f32->bf16 IN REGISTER on the load into smem
|
||||
+// (no separate global act pre-pass, no act-quant). The inner loop does only: ldmatrix the
|
||||
+// bf16 A, in-register nibble->bf16 unpack of the weight (byte_perm/PRMT, LUT-free) scaled by
|
||||
+// the in-register ue4m3 decode, and bf16 m16n8k16 mma.sync into f32 accumulators. cp.async
|
||||
+// multistage pipelined, ragged per-tile expert offset (reuses 0035's tile map). NO separate
|
||||
+// smem-staged dequant pass, NO __syncthreads-gated dequant pass.
|
||||
+//
|
||||
+// The repacked smem contents are bit-identical to the 0035 W4A16 smem, so the inner fragment math
|
||||
+// (proven 81/81 on test-backend-ops MUL_MAT_ID) is unchanged and the output is bit-identical to the
|
||||
+// 0035 W4A16 path; only the global->smem load path (coalesced) and the act cast (in-register) change.
|
||||
+//
|
||||
+// Toggle: LLAMA_W4A16_REPACK=1 (default 0 == OFF). Engages only inside the already-default-off
|
||||
+// 0035 W4A16 grouped path (LLAMA_W4A16_PREFILL_M>0). Stock / decode / non-NVFP4 byte-untouched.
|
||||
+
|
||||
+// True iff LLAMA_W4A16_REPACK != 0.
|
||||
+bool ggml_cuda_w4a16_repack_enabled();
|
||||
+
|
||||
+// Total bytes the persistent repack cache currently holds on device (for the weight-memory-delta
|
||||
+// report). The repacked layout stays 4-bit, so this equals the source NVFP4 bytes of the engaged
|
||||
+// expert weight tensors; 0 when the toggle is OFF (no repack happens).
|
||||
+size_t ggml_cuda_w4a16_repack_cache_bytes();
|
||||
+
|
||||
+// Offline-repack W4A16 grouped MoE GEMM over the token-sorted buffer. Same contract as
|
||||
+// ggml_cuda_mul_mat_id_w4a16_grouped (see w4a16-gemm.cuh) but src1_sorted is the RAW f32 sorted
|
||||
+// activations (cast to bf16 in-register; no global bf16 pre-pass needed) and the weights are read
|
||||
+// from the cached repacked layout (built one-time from src0).
|
||||
+void ggml_cuda_mul_mat_id_w4a16_repack(
|
||||
+ ggml_backend_cuda_context & ctx,
|
||||
+ const ggml_tensor * src0,
|
||||
+ const float * src1_sorted,
|
||||
+ float * dst_sorted,
|
||||
+ const int * tokens_per_expert,
|
||||
+ int64_t n_experts, int64_t K, int64_t N,
|
||||
+ cudaStream_t stream);
|
||||
--
|
||||
2.43.0
|
||||
|
||||
Reference in New Issue
Block a user