mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-30 11:26:32 -04:00
fix(paged): repair the patch series, sync to the fork branch (drop dev-tree 0044/0045, add f32-only M5 as 0047)
The 0044/0045 patches were exported from the old bf16/hybrid dev tree and no longer apply on the f32-only series (0026 ssm_bf16_tau is dropped), so the build broke at `git apply`. Re-sync the vendored series to the now feature-complete fork branch mudler/llama.cpp:localai-paged, which is the canonical source (pin 0ed235ea + the paged patch commits in order). - git rm the dev-tree-based 0044 (GDN M5, bf16-machinery base) and 0045 (Marlin W4A16 offline-repack, not part of the fork branch). - Add the fork branch's newest commit (2c32ab8b7, "GDN M5 tensor-core chunked-scan prefill, f32-only re-port") as 0047, generated with a single git format-patch off that branch. It sequences after 0046 (its parent on the branch) and recovers the prefill win 0044 encoded (+3.5% S_PP @npp512, +17.7% @npp2048), bit-exact per-path (test-backend-ops GATED_DELTA_NET 46/46 default and force-M5; greedy md5 default-on == M5-forced == canonical). - Track patch 0046 (dense-prefill geometry gate), which was on disk but never committed, so the series is complete in git. - README: patch-table header 0001-0046 -> 0001-0047, replace the 0044 row with the f32-only 0047 row, fix the dangling 0044 prose references, note the bf16 M6/M7/M8 variants are not part of this f32-only series, and add a maintenance bullet that the series is now generated from the fork branch so there is no more patch-export drift. Verified: on a pristine llama.cpp at pin 0ed235ea the full series 0001-0043, 0046, 0047 applies clean in sorted order with the Makefile's exact `git apply --verbose` method (37/37 OK), and the resulting tree is byte-identical to the fork branch tip 2c32ab8b7. Assisted-by: Claude:opus-4.8 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -86,7 +86,7 @@ orthogonal to the paged allocator.
|
||||
|
||||
---
|
||||
|
||||
## 3. Patch series (0001-0044)
|
||||
## 3. Patch series (0001-0047)
|
||||
|
||||
Source-only patches, with intentional numbering gaps (e.g. 0005, 0027). The
|
||||
decode-serving graph-reuse levers are 0040-0041. "Bit-exact" = greedy md5 /
|
||||
@@ -188,8 +188,9 @@ These are the dominant decode levers on the Qwen3.6 hybrid models. All bit-exact
|
||||
| 0024 | **Paged-pool burst-reclaim** - truncate trailing blocks on partial-tail `seq_rm`, defrag the free queue when idle, release blocks on slot completion. Fixes the long-server burst-degradation bug (post-burst prefill collapse 488->44 t/s, restored to 532). Host-side accounting only. | yes |
|
||||
| 0029 | **Block-table within-step host cache** - the block table is fixed for the whole step; cache it on first build and memcpy it for the other full-attention layers (get_block_table -87%/-91%). | yes, per path (paged-MoE ref `8cb0ce23`) |
|
||||
| 0030 | **Fused-op backend gate** - the fused GDN / discriminated SSM_CONV ops are CUDA-family + CPU only; force them off on any non-CUDA compute backend so a Vulkan/SYCL/Metal build can't silently run the wrong plain-conv kernel. | yes on CUDA (byte-identical pre-0030); safety gate elsewhere |
|
||||
| 0031 | **Chunked parallel-scan GDN prefill kernel** (upstream TODO) - FLA-style chunked gated-delta-rule for prefill (non-KDA / f32 / final-state): intra-chunk delta rule solved in parallel (UT-transform + forward subst), inter-chunk recurrence over n_tokens/C steps. The scalar-serial form (`GDN_TC=0`) was bit-exact-benign but not faster than the tuned sequential scan at the GB10-forced C=16 (see section 5); **superseded for paged by the tensor-core M5 path of 0044**. | NEW per-path (`test-backend-ops` 91/91, <=1e-7 NMSE vs CPU ref) |
|
||||
| 0044 | **GDN M5 tensor-core chunked-scan prefill, default-ON under paged KV** - the tensor-core forms of 0031's scan (KK/QK Gram, KS/QS state-boundary, P*U output, full form-T solve + state-update mma = M5; plus register-resident M6/M7 and CONFIG-C M8), single build, runtime-selected by `GDN_TC`. Ships **M5 default-on when `LLAMA_KV_PAGED` is set** (`GDN_TC=5` + `GDN_CHUNK_MIN=64`, both env-overridable; OFF/`INT_MAX` when not paged). `GDN_CHUNK_MIN` is the per-call engage threshold and stays > 1 so decode (1 tok/call) keeps the sequential recurrence (at 1 it swallows decode and drops S_TG ~25%); 64 tuned from a {1,32,64,128,256} sweep. MoE prefill S_PP +3.5% @npp512 (3x A/B), +17.7% @npp2048; decode S_TG unchanged. | NEW per-path, benign (`test-backend-ops` 94/94 incl. multi-chunk; greedy md5 default-on == M5-forced == canonical on the gate prompt: paged-MoE `8cb0ce23`, dense `5951a5b4`; long MoE prompt = one benign greedy flip vs sequential, dense byte-identical) |
|
||||
| 0031 | **Chunked parallel-scan GDN prefill kernel** (upstream TODO) - FLA-style chunked gated-delta-rule for prefill (non-KDA / f32 / final-state): intra-chunk delta rule solved in parallel (UT-transform + forward subst), inter-chunk recurrence over n_tokens/C steps. The scalar-serial form (`GDN_TC=0`) was bit-exact-benign but not faster than the tuned sequential scan at the GB10-forced C=16 (see section 5); **superseded for paged by the tensor-core M5 path of 0047**. | NEW per-path (`test-backend-ops` 91/91, <=1e-7 NMSE vs CPU ref) |
|
||||
| 0047 | **GDN M5 tensor-core chunked-scan prefill, f32-only re-port, default-ON under paged KV** - the f32/tf32 tensor-core forms of 0031's scan (KK/QK Gram = M2, KS/QS state-boundary 3xtf32 = M3, P*U output = M4, full form-T solve + state-update mma = M5), single build, runtime-selected by `GDN_TC`. Ships **M5 default-on when `LLAMA_KV_PAGED` is set** (`GDN_TC=5` + `GDN_CHUNK_MIN=64`, both env-overridable; OFF/`INT_MAX` when not paged). `GDN_CHUNK_MIN` is the per-call engage threshold and stays > 1 so decode (1 tok/call) keeps the sequential recurrence (at 1 it swallows decode and drops S_TG ~25%); 64 tuned from a {1,32,64,128,256} sweep. The bf16/hybrid dev-tree machinery (STATE_BF16/HYBRID, the dropped 0026 ssm_bf16_tau) and the bf16 CONFIG-C (M8) plus register-resident M6/M7 variants are NOT part of this f32-only series. MoE prefill S_PP +3.5% @npp512 (3x A/B), +17.7% @npp2048; decode S_TG unchanged. | NEW per-path, benign (`test-backend-ops` GATED_DELTA_NET 46/46 default AND force-M5, incl. multi-chunk/tail-chunk/multi-seq; greedy md5 default-on == M5-forced == canonical on the gate prompt: paged-MoE `8cb0ce23`, dense `5951a5b4`; long MoE prompt = one benign greedy flip vs sequential, dense byte-identical) |
|
||||
| 0046 | **GDN prefill geometry gated by scan length** - patch 0022's `(NUM_WARPS=16, COLS_PER_WARP=8)` column-fold of the GDN sequential-recurrence dispatch (`case 128`) is a decode win but was applied UNCONDITIONALLY, so it also hit dense prefill (~-6% vs stock): on a long sequential scan the launch `grid.z` collapses from `S_v/4 = 32` to `S_v/(16*8) = 1` and the SMs starve (profiled: `gated_delta_net` +54% GPU time = the whole dense-prefill regression). Gate the geometry by per-call scan length: long scans (prefill, `n_tokens >= GDN_PREFILL_NTOK`, default 256) take stock's high-grid.z `(4,1)` geometry; short scans (decode) keep the `(16,8)` retune. Recovers dense prefill +7.2% back to stock parity, keeps the decode win. `GDN_PREFILL_NTOK` tunes the crossover; an explicit `GDN_NW`/`GDN_CPW` sweep still overrides (gate yields when either is set), so the one-build %peak A/B harness is unchanged. | yes (patch 0022 proved every `{NW,CPW}` variant byte-identical, so switching geometry by scan length cannot move the md5) |
|
||||
|
||||
> **Dropped: patch 0026 (hybrid per-head bf16 SSM state, `ssm_bf16_tau`).** Once
|
||||
> the decode fusions (0028 recurrent-state gather-fusion + 0029 block-table cache)
|
||||
@@ -371,7 +372,7 @@ llama is losing. The MoE GEMM kernel is *not* where the gap lives.
|
||||
(The "the win was NVFP4-dense-quant, not the Marlin kernel" dense verdict
|
||||
carries over to MoE.)
|
||||
- **Chunked parallel-scan GDN prefill (patch 0031): the scalar-serial form was
|
||||
FLAT-to-SLOWER at C=16 - the tensor-core M5 form (patch 0044) is the win,
|
||||
FLAT-to-SLOWER at C=16 - the tensor-core M5 form (patch 0047) is the win,
|
||||
now DEFAULT-ON under paged KV.** 0031 implements the upstream "faster pre-fill"
|
||||
TODO - the FLA-style chunked gated-delta-rule (intra-chunk delta rule solved in
|
||||
parallel via the UT-transform + forward substitution, inter-chunk recurrence
|
||||
@@ -382,10 +383,12 @@ llama is losing. The MoE GEMM kernel is *not* where the gap lives.
|
||||
pinned to 1 block/SM with serial per-thread dk-reductions and measured **~761
|
||||
t/s chunked vs ~971 t/s sequential (~22% slower)**, grid-starved at low n_seqs.
|
||||
The lesson held: **at this head dim the win needs tensor cores, not just
|
||||
chunking.** Patch 0044 builds those tensor-core forms (KK/QK Gram = M2, KS/QS
|
||||
state-boundary = M3, P*U output = M4, full form-T solve + state-update mma =
|
||||
M5; plus register-resident M6/M7 and CONFIG-C M8, all `GDN_TC`-selected in one
|
||||
build) and ships **M5** as the default when `LLAMA_KV_PAGED` is set. M5 is the
|
||||
chunking.** Patch 0047 builds those tensor-core forms (KK/QK Gram = M2, KS/QS
|
||||
state-boundary 3xtf32 = M3, P*U output = M4, full form-T solve + state-update
|
||||
mma = M5, all `GDN_TC`-selected in one build) and ships **M5** as the default
|
||||
when `LLAMA_KV_PAGED` is set. It is an f32/tf32-only re-port: the bf16/hybrid
|
||||
dev-tree machinery (from the dropped 0026 ssm_bf16_tau) and the bf16 CONFIG-C
|
||||
(M8) plus register-resident M6/M7 variants are NOT part of this series. M5 is the
|
||||
variant that beats the (already 84.7%-of-peak) sequential scan while staying on
|
||||
the bit-exact gate: MoE prefill S_PP **+3.5% @npp512 (3x interleaved A/B), +17.7%
|
||||
@npp2048**; decode S_TG unchanged (the tuned `GDN_CHUNK_MIN=64` engage threshold
|
||||
@@ -398,9 +401,28 @@ llama is losing. The MoE GEMM kernel is *not* where the gap lives.
|
||||
a long MoE prompt (where the default fires M5 at >=64 tokens) M5 and the
|
||||
sequential path agree word-for-word until **one** benign greedy token-flip
|
||||
("the User:" vs "the User's Request:"), the dense model not flipping at all -
|
||||
the textbook reduction-order flip greedy amplifies, NMSE-validated. M6/M7/M8
|
||||
remain env-selectable (`GDN_TC`/`GDN_CHUNK_C`/`GDN_DV_TILE`) for further tuning;
|
||||
M5 is the shipped default because it wins without losing the canonical gate.
|
||||
the textbook reduction-order flip greedy amplifies, NMSE-validated. The chunk
|
||||
geometry stays env-selectable (`GDN_TC`/`GDN_CHUNK_C`/`GDN_DV_TILE`) for further
|
||||
tuning; M5 is the shipped default because it wins without losing the canonical gate.
|
||||
- **GDN occupancy retune (patch 0022) was a decode win but an UNCONDITIONAL
|
||||
dense-prefill regression - now gated by scan length (patch 0046).** Patch
|
||||
0022's `(NUM_WARPS=16, COLS_PER_WARP=8)` column-fold of the GDN
|
||||
sequential-recurrence dispatch (`case 128`) raises per-warp memory-level
|
||||
parallelism on the short, wide DECODE scans (small `n_tokens`, large
|
||||
`n_seqs`) - the measured +11.1% dense decode win. Applied unconditionally it
|
||||
also hit the dense PREFILL path, where the scan is long and narrow: the launch
|
||||
`grid.z` collapses from `S_v/4 = 32` to `S_v/(16*8) = 1`, the SMs starve, and
|
||||
profiling attributed the whole ~-6% dense-prefill regression vs stock to
|
||||
`gated_delta_net` (+54% GPU time at the (16,8) geometry). Patch 0046 gates the
|
||||
geometry by per-call scan length: long scans (prefill,
|
||||
`n_tokens >= GDN_PREFILL_NTOK`, default 256) take stock's high-grid.z `(4,1)`
|
||||
geometry; short scans (decode) keep the `(16,8)` retune. That recovers dense
|
||||
prefill +7.2% back to stock parity while keeping the decode win, and it is
|
||||
bit-exact: patch 0022 already proved every selectable `{NUM_WARPS,
|
||||
COLS_PER_WARP}` variant is byte-identical (the sweep cannot change the md5), so
|
||||
switching geometry by scan length cannot move the greedy output. The explicit
|
||||
`GDN_NW`/`GDN_CPW` one-build %peak sweep still overrides (the gate yields when
|
||||
either is set), so the A/B harness is unchanged.
|
||||
|
||||
**Opt-in bf16-SSM fast mode - DROPPED (was patch 0026, `ssm_bf16_tau`).** The
|
||||
design premise - that bf16 KL error concentrates in long-memory heads and can be
|
||||
@@ -455,6 +477,11 @@ targeted is already recovered by the gather-fusion + block-table cache.
|
||||
|
||||
## 7. Pin + maintenance policy
|
||||
|
||||
- **Canonical source = the fork branch `mudler/llama.cpp:localai-paged`.** The
|
||||
vendored `patches/paged/*.patch` files are now generated (one `git format-patch`
|
||||
per commit) from that branch, which is the pin commit plus the paged patch
|
||||
commits in order, so there is no more hand-export drift between the dev tree and
|
||||
the shipped series.
|
||||
- **Pinned to llama.cpp `9d5d882d`** (kept == the stock `llama-cpp` pin). The pin
|
||||
is advanced **only** by the manual pin-sync process (this section):
|
||||
rebase the source-only patch series onto the new tip, rebuild on GPU, pass the
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,650 +0,0 @@
|
||||
From 864e92807809c55905bdb73ef492f6b1d03986f9 Mon Sep 17 00:00:00 2001
|
||||
From: Ettore Di Giacinto <mudler@localai.io>
|
||||
Date: Mon, 29 Jun 2026 11:02:07 +0200
|
||||
Subject: [PATCH] feat(paged): vLLM-exact offline-repack Marlin W4A16 grouped
|
||||
MoE prefill GEMM (patch 0045)
|
||||
|
||||
The patch-0035 root-cause sweep proved the W4A16 grouped MoE GEMM is occupancy/
|
||||
latency bound (~10% MMA util), NOT dequant bound: every prefill step it re-derives
|
||||
a per-block STRIDED weight gather and streams the 4-bit weights with scattered
|
||||
4-BYTE cp.async transactions (BN*8 per K-block) + a per-step ue4m3 scale decode,
|
||||
plus a SEPARATE global f32->bf16 activation pre-pass kernel. vLLM Marlin wins by
|
||||
REPACKING the weights ONCE at load into an MMA-ready tile-contiguous 4-bit layout
|
||||
so the GEMM loop issues only cheap COALESCED 16-byte loads and the per-lane reads
|
||||
are already in fragment order; the loop does ZERO dequant work beyond the cheap
|
||||
in-register nibble->bf16 unpack. This is that approach, faithful to vLLM Marlin and
|
||||
distinct from the rejected patch-0044 int8-Marlin (which doubled the weight bytes
|
||||
to int8 + a separate act-quant pass).
|
||||
|
||||
(1) OFFLINE REPACK (one-time, first engage; cached in a persistent cudaMalloc
|
||||
buffer keyed by src0->data): transpose the NVFP4 experts [N][Kb](d[4]+qs[32])
|
||||
into tile-contiguous planes Q=[e][Kb][N][32 qs bytes] + S=[e][Kb][N][4 ue4m3
|
||||
scale bytes]. STAYS 4-BIT: qs packed, scales ue4m3 -> the repacked buffer is
|
||||
byte-for-byte the SAME size as the source NVFP4 weights (144 MiB/tensor for the
|
||||
35B-A3B experts, == source; no int8 2x). Persists across all prefill steps.
|
||||
(2) KERNEL reads the pre-packed tiles directly: per (K-block, N-block) the BN rows
|
||||
qs/scales are contiguous -> the smem fill is a flat COALESCED 16-byte cp.async
|
||||
stream (no strided gather, no address math). Activations cast f32->bf16 IN
|
||||
REGISTER on the load into smem (no separate global act pre-pass, no act-quant).
|
||||
Inner loop: ldmatrix bf16 A, in-register PRMT (LUT-free) nibble->bf16 unpack
|
||||
scaled by the in-register ue4m3 decode, bf16 m16n8k16 mma.sync into f32 accum,
|
||||
cp.async multistage, ragged per-tile expert offset (reuses 0035 tile map). NO
|
||||
separate smem-staged dequant pass, NO __syncthreads-gated dequant pass (SASS:
|
||||
LDSM->PRMT->I2F->FMUL->F2F.BF16->HMMA.16816.F32.BF16 with no STS/BAR.SYNC between
|
||||
the weight load and the MMA).
|
||||
|
||||
The repacked smem is bit-identical to the 0035 W4A16 smem, so the proven inner
|
||||
fragment math is unchanged and the output is bit-identical to the 0035 W4A16 path.
|
||||
|
||||
TOGGLE: LLAMA_W4A16_REPACK=1 (default 0 == OFF), engaged only inside the already
|
||||
default-off 0035 W4A16 path (LLAMA_W4A16_PREFILL_M>0). LLAMA_W4A16_REPACK_NOCACHE=1
|
||||
repacks into a transient pool buffer per call (for test-backend-ops, which frees/
|
||||
reuses src0 addresses and so defeats the pointer-keyed cache). Stock / decode /
|
||||
non-NVFP4 byte-untouched.
|
||||
|
||||
VALIDATION (GB10, sm_121a, Qwen3.6-35B-A3B-NVFP4):
|
||||
- test-backend-ops MUL_MAT_ID nvfp4 (vs CPU oracle), REPACK forced + NOCACHE:
|
||||
81/81 OK, 0 FAIL.
|
||||
- real-model greedy md5 (paged MoE): stock == W4A16-non-repack == W4A16-repack
|
||||
(cached) == W4A16-repack (nocache) == default-off (REPACK=1, PREFILL_M unset),
|
||||
all fda1aadbbbfb36fe8ab0f5f5465c745e (bit-identical; default-off is stock).
|
||||
|
||||
HONEST PERF (S_PP t/s, llama-batched-bench -fa on -ngl 99 -ntg 32 -npl 1, paged,
|
||||
warm cache): the offline repack is a large win over the prior non-repack W4A16
|
||||
(npp512 854.7 -> 1452.4 = +70%; beats it at every M) and is FLAT across M
|
||||
(1452/1478/1467 at 512/1024/2048 = MMA-pipeline bound as designed). But the heavily
|
||||
tuned FP4-MMQ baseline on GB10 is still ~38% faster (~2030 flat). Decode S_TG
|
||||
unchanged (~55 t/s, prefill-only lever). One-time cache build amortized (cold first
|
||||
prefill). Ships DEFAULT-OFF (like 0033/0034/0035): the validated, env-gated, bit-
|
||||
exact-gated mechanism + the recorded result that even vLLM-exact offline-repack
|
||||
Marlin W4A16 does not overtake the GB10 FP4-MMQ winner.
|
||||
|
||||
Weight-memory: the repacked representation STAYS 4-BIT (no int8 expansion; vs 0044
|
||||
+100%); the persistent cache is an additive 4-bit copy of the engaged expert tensors
|
||||
when ON, and literally 0 when OFF (default). Decode keeps the original block_nvfp4
|
||||
layout for its MMQ/graph path, so the cache is additive rather than in-place.
|
||||
|
||||
Build: arch=compute_121a,code=[compute_121a,sm_121a]; AMPERE_MMA_AVAILABLE /
|
||||
CP_ASYNC_AVAILABLE guards (NO_DEVICE_CODE off-Blackwell).
|
||||
|
||||
Assisted-by: Claude:opus-4.8 [Claude Code]
|
||||
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
|
||||
---
|
||||
ggml/src/ggml-cuda/w4a16-gemm.cu | 11 +
|
||||
ggml/src/ggml-cuda/w4a16-repack.cu | 470 ++++++++++++++++++++++++++++
|
||||
ggml/src/ggml-cuda/w4a16-repack.cuh | 59 ++++
|
||||
3 files changed, 540 insertions(+)
|
||||
create mode 100644 ggml/src/ggml-cuda/w4a16-repack.cu
|
||||
create mode 100644 ggml/src/ggml-cuda/w4a16-repack.cuh
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/w4a16-gemm.cu b/ggml/src/ggml-cuda/w4a16-gemm.cu
|
||||
index c5c9ef7..687a7ae 100644
|
||||
--- a/ggml/src/ggml-cuda/w4a16-gemm.cu
|
||||
+++ b/ggml/src/ggml-cuda/w4a16-gemm.cu
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "w4a16-gemm.cuh"
|
||||
+#include "w4a16-repack.cuh"
|
||||
#include "mma.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
@@ -389,6 +390,16 @@ void ggml_cuda_mul_mat_id_w4a16_grouped(
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_NVFP4);
|
||||
GGML_ASSERT(N % 128 == 0 && K % 64 == 0);
|
||||
|
||||
+ // [paged patch 0045] vLLM-exact offline-repack sub-mode: repack the NVFP4 experts ONCE at first
|
||||
+ // engage into an MMA-ready tile-contiguous 4-bit layout (cached, stays 4-bit), then run a GEMM
|
||||
+ // loop with coalesced 16B weight loads + in-register act cast + ZERO in-loop dequant beyond the
|
||||
+ // cheap in-register nibble->bf16 unpack. Gated by LLAMA_W4A16_REPACK (default off).
|
||||
+ if (ggml_cuda_w4a16_repack_enabled()) {
|
||||
+ ggml_cuda_mul_mat_id_w4a16_repack(ctx, src0, src1_sorted, dst_sorted,
|
||||
+ tokens_per_expert, n_experts, K, N, stream);
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
int sel = w4a16_cfg_sel();
|
||||
if (N % w4a16_cfg_bn(sel) != 0) {
|
||||
sel = 3; // BN>128 config whose BN doesn't divide N: safe BN=128 winner
|
||||
diff --git a/ggml/src/ggml-cuda/w4a16-repack.cu b/ggml/src/ggml-cuda/w4a16-repack.cu
|
||||
new file mode 100644
|
||||
index 0000000..4a7a125
|
||||
--- /dev/null
|
||||
+++ b/ggml/src/ggml-cuda/w4a16-repack.cu
|
||||
@@ -0,0 +1,470 @@
|
||||
+#include "w4a16-repack.cuh"
|
||||
+#include "mma.cuh"
|
||||
+
|
||||
+#include <cstdint>
|
||||
+#include <cstdlib>
|
||||
+#include <mutex>
|
||||
+#include <unordered_map>
|
||||
+#include <vector>
|
||||
+#include <algorithm>
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// [paged patch 0045] vLLM-exact offline-repack Marlin W4A16 grouped MoE prefill GEMM.
|
||||
+// See w4a16-repack.cuh for the design. Default-off (LLAMA_W4A16_REPACK).
|
||||
+// ===========================================================================
|
||||
+
|
||||
+using namespace ggml_cuda_mma;
|
||||
+typedef tile<16, 8, nv_bfloat162> rp_tile_A; // A operand: M=16, K=16
|
||||
+typedef tile< 8, 8, nv_bfloat162> rp_tile_B; // B operand: N=8, K=16
|
||||
+typedef tile<16, 8, float> rp_tile_C; // accumulator: M=16, N=8
|
||||
+
|
||||
+bool ggml_cuda_w4a16_repack_enabled() {
|
||||
+ static const bool e = [] {
|
||||
+ const char * s = getenv("LLAMA_W4A16_REPACK");
|
||||
+ return s != nullptr && atoi(s) != 0;
|
||||
+ }();
|
||||
+ return e;
|
||||
+}
|
||||
+
|
||||
+// ---- cp.async helpers (sm80+; raw bytes, no cast) ----
|
||||
+static __device__ __forceinline__ void rp_cp_async16(void * smem, const void * gmem) {
|
||||
+#ifdef CP_ASYNC_AVAILABLE
|
||||
+ const unsigned s = (unsigned) __cvta_generic_to_shared(smem);
|
||||
+ asm volatile("cp.async.cg.shared.global [%0],[%1],16;\n" :: "r"(s), "l"(gmem));
|
||||
+#else
|
||||
+ GGML_UNUSED(smem); GGML_UNUSED(gmem); NO_DEVICE_CODE;
|
||||
+#endif // CP_ASYNC_AVAILABLE
|
||||
+}
|
||||
+static __device__ __forceinline__ void rp_cp_commit() {
|
||||
+#ifdef CP_ASYNC_AVAILABLE
|
||||
+ asm volatile("cp.async.commit_group;\n" ::);
|
||||
+#else
|
||||
+ NO_DEVICE_CODE;
|
||||
+#endif // CP_ASYNC_AVAILABLE
|
||||
+}
|
||||
+template<int N> static __device__ __forceinline__ void rp_cp_wait() {
|
||||
+#ifdef CP_ASYNC_AVAILABLE
|
||||
+ asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
|
||||
+#else
|
||||
+ NO_DEVICE_CODE;
|
||||
+#endif // CP_ASYNC_AVAILABLE
|
||||
+}
|
||||
+
|
||||
+// ---- fast 16-entry FP4 (E2M1) -> kvalues_mxfp4 lookup via __byte_perm (PRMT), LUT-free, NO
|
||||
+// divergent LDG. Bit-identical to kvalues_mxfp4[code]; same trick MMQ's get_int_from_table_16 and
|
||||
+// the 0035 kernel use. The "cheap in-register unpack" - the ONLY dequant work the loop does. ----
|
||||
+static __device__ __forceinline__ int2 rp_table16(uint32_t q4) {
|
||||
+ const uint32_t * table32 = (const uint32_t *) kvalues_mxfp4; // 16 int8 == 4 u32
|
||||
+ uint32_t tmp[2];
|
||||
+ const uint32_t sel = 0x32103210 | ((q4 & 0x88888888) >> 1);
|
||||
+#pragma unroll
|
||||
+ for (uint32_t i = 0; i < 2; ++i) {
|
||||
+ const uint32_t shift = 16*i;
|
||||
+ const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
|
||||
+ const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
|
||||
+ tmp[i] = __byte_perm(low, high, sel >> shift);
|
||||
+ }
|
||||
+ return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
|
||||
+}
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// (1) OFFLINE REPACK kernel (one-time). Transpose src0 NVFP4 [e][N][Kb](d[4]+qs[32]) into
|
||||
+// Qrp = [e][Kb][N][32 qs bytes] (u32 plane, 8 u32/block)
|
||||
+// Srp = [e][Kb][N][ 4 scale bytes] (u32 plane, 1 u32/block)
|
||||
+// Same total bytes as src0 (36 B/block); stays 4-bit. One thread per source block.
|
||||
+// ===========================================================================
|
||||
+static __global__ void w4a16_repack_kernel(
|
||||
+ const block_nvfp4 * __restrict__ W0, int64_t expert_stride_blocks,
|
||||
+ uint32_t * __restrict__ Qrp, uint32_t * __restrict__ Srp,
|
||||
+ int N, int Kb, int n_experts) {
|
||||
+ const int64_t total = (int64_t) n_experts * N * Kb;
|
||||
+ const int64_t t = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
|
||||
+ if (t >= total) {
|
||||
+ return;
|
||||
+ }
|
||||
+ const int kt = (int) (t % Kb);
|
||||
+ const int64_t r2 = t / Kb;
|
||||
+ const int n = (int) (r2 % N);
|
||||
+ const int e = (int) (r2 / N);
|
||||
+
|
||||
+ const block_nvfp4 * blk = W0 + (int64_t) e*expert_stride_blocks + (int64_t) n*Kb + kt;
|
||||
+ const uint32_t * src = (const uint32_t *) blk; // word0=d[4], words1..8=qs[32]
|
||||
+
|
||||
+ const int64_t qbase = (((int64_t) e*Kb + kt) * N + n) * 8; // u32 offset in Qrp
|
||||
+ const int64_t sbase = ((int64_t) e*Kb + kt) * N + n; // u32 offset in Srp
|
||||
+ Srp[sbase] = src[0];
|
||||
+#pragma unroll
|
||||
+ for (int w = 0; w < 8; w++) {
|
||||
+ Qrp[qbase + w] = src[1 + w];
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// (2) GROUPED GEMM over the repacked tiles. Per output tile (blockIdx.x=N-block, blockIdx.y=M-tile):
|
||||
+// expert e=g_tile_expert[by], row0=g_tile_row0[by], rcount=g_tile_rows[by].
|
||||
+// Weights from Qrp/Srp (tile-contiguous), activations from A_f32 (cast bf16 in-register), out to C.
|
||||
+// ZERO in-loop dequant beyond the in-register nibble->bf16 unpack + ue4m3 scale decode.
|
||||
+// ===========================================================================
|
||||
+template<int BM, int BN, int WARPS_M, int WARPS_N, int STAGES, int APAD>
|
||||
+__launch_bounds__(WARPS_M*WARPS_N*32, 1)
|
||||
+static __global__ void w4a16_repack_grouped_kernel(
|
||||
+ const float * __restrict__ A_f32, // [total_rows, K] f32 (sorted)
|
||||
+ const uint32_t * __restrict__ Qrp, // [e][Kb][N][8 u32]
|
||||
+ const uint32_t * __restrict__ Srp, // [e][Kb][N][1 u32]
|
||||
+ float * __restrict__ C, // [total_rows, N] f32
|
||||
+ const int * __restrict__ g_tile_expert,
|
||||
+ const int * __restrict__ g_tile_row0,
|
||||
+ const int * __restrict__ g_tile_rows,
|
||||
+ int N, int K, int total_rows) {
|
||||
+#if defined(AMPERE_MMA_AVAILABLE) && defined(CP_ASYNC_AVAILABLE)
|
||||
+ constexpr int BK = 64; // one nvfp4 block
|
||||
+ constexpr int NWARP = WARPS_M*WARPS_N;
|
||||
+ constexpr int THREADS = NWARP*32;
|
||||
+ constexpr int WM = BM/WARPS_M, WN = BN/WARPS_N;
|
||||
+ constexpr int MF = WM/16, NF = WN/8;
|
||||
+
|
||||
+ constexpr int AN = BK/2; // bf16 pairs per A smem row (nv_bfloat162)
|
||||
+ constexpr int ASTRIDE = AN + APAD; // padded A smem row stride (skew banks, +19% lesson)
|
||||
+ constexpr int SZ_A = BM*ASTRIDE; // nv_bfloat162 (== u32) per stage (padded)
|
||||
+ constexpr int SZ_WQ = BN*8; // u32 per stage (32 qs bytes/row, tile-contiguous)
|
||||
+ constexpr int SZ_WD = BN; // u32 per stage (4 scale bytes/row, tile-contiguous)
|
||||
+
|
||||
+ extern __shared__ uint32_t smem_u32[];
|
||||
+ constexpr int STAGE_U32 = SZ_A + SZ_WQ + SZ_WD;
|
||||
+ nv_bfloat162 * sA[STAGES];
|
||||
+ uint32_t * sWq[STAGES];
|
||||
+ uint32_t * sWd[STAGES];
|
||||
+#pragma unroll
|
||||
+ for (int s = 0; s < STAGES; s++) {
|
||||
+ uint32_t * base = smem_u32 + s*STAGE_U32;
|
||||
+ sA[s] = (nv_bfloat162 *) base;
|
||||
+ sWq[s] = base + SZ_A;
|
||||
+ sWd[s] = base + SZ_A + SZ_WQ;
|
||||
+ }
|
||||
+
|
||||
+ const int lane = threadIdx.x; // 0..31 (mma.cuh uses threadIdx.x AS the warp lane)
|
||||
+ const int warp = threadIdx.y; // 0..NWARP-1
|
||||
+ const int tid = warp*32 + lane;
|
||||
+ const int wrow = warp / WARPS_N, wcol = warp % WARPS_N;
|
||||
+
|
||||
+ const int e = g_tile_expert[blockIdx.y];
|
||||
+ const int row0 = g_tile_row0[blockIdx.y];
|
||||
+ const int rcount = g_tile_rows[blockIdx.y];
|
||||
+ const int blockCol = blockIdx.x*BN;
|
||||
+ const int Kb = K/64;
|
||||
+ // base u32 offset of expert e in the repacked planes (tile = + (kt*N + blockCol)*{8,1})
|
||||
+ const int64_t Qe = (int64_t) e*Kb*N*8;
|
||||
+ const int64_t Se = (int64_t) e*Kb*N;
|
||||
+
|
||||
+ rp_tile_C acc[MF][NF];
|
||||
+
|
||||
+ // async-load K-block kt into stage st: A cast in-register (no global pre-pass); W coalesced 16B.
|
||||
+ auto load_tile = [&](int st, int kt) {
|
||||
+ // A: BM rows x BK bf16 = BM x (BK/8) 16B chunks. Source f32 -> cast bf16 in register.
|
||||
+ const int A_chunks = BM*(BK/8);
|
||||
+#pragma unroll 1
|
||||
+ for (int idx = tid; idx < A_chunks; idx += THREADS) {
|
||||
+ const int c = idx % (BK/8); // 16B (8 bf16) chunk in the row
|
||||
+ const int r = idx / (BK/8); // row in tile
|
||||
+ const int gr = row0 + r;
|
||||
+ nv_bfloat162 * d2 = sA[st] + r*ASTRIDE + c*4; // 4 nv_bfloat162 = 8 bf16
|
||||
+ if (gr < total_rows) {
|
||||
+ const float * src = A_f32 + (int64_t) gr*K + (int64_t) kt*BK + c*8;
|
||||
+ const float4 a = *(const float4 *) (src);
|
||||
+ const float4 b = *(const float4 *) (src + 4);
|
||||
+ d2[0] = make_bfloat162(__float2bfloat16(a.x), __float2bfloat16(a.y));
|
||||
+ d2[1] = make_bfloat162(__float2bfloat16(a.z), __float2bfloat16(a.w));
|
||||
+ d2[2] = make_bfloat162(__float2bfloat16(b.x), __float2bfloat16(b.y));
|
||||
+ d2[3] = make_bfloat162(__float2bfloat16(b.z), __float2bfloat16(b.w));
|
||||
+ } else {
|
||||
+ const nv_bfloat162 z = make_bfloat162((nv_bfloat16) 0.0f, (nv_bfloat16) 0.0f);
|
||||
+ d2[0] = z; d2[1] = z; d2[2] = z; d2[3] = z;
|
||||
+ }
|
||||
+ }
|
||||
+ // W qs: tile-contiguous BN*32 bytes = BN*8 u32 = BN*2 16B chunks. Coalesced cp.async.16.
|
||||
+ const uint32_t * Qtile = Qrp + Qe + ((int64_t) kt*N + blockCol) * 8;
|
||||
+ const int Q_chunks = (BN*8) / 4; // 16B (4 u32) chunks
|
||||
+#pragma unroll 1
|
||||
+ for (int idx = tid; idx < Q_chunks; idx += THREADS) {
|
||||
+ rp_cp_async16(&sWq[st][idx*4], Qtile + idx*4);
|
||||
+ }
|
||||
+ // W scales: tile-contiguous BN*4 bytes = BN u32 = BN/4 16B chunks. Coalesced cp.async.16.
|
||||
+ const uint32_t * Stile = Srp + Se + ((int64_t) kt*N + blockCol);
|
||||
+ const int S_chunks = BN / 4;
|
||||
+#pragma unroll 1
|
||||
+ for (int idx = tid; idx < S_chunks; idx += THREADS) {
|
||||
+ rp_cp_async16(&sWd[st][idx*4], Stile + idx*4);
|
||||
+ }
|
||||
+ };
|
||||
+
|
||||
+ // prologue
|
||||
+#pragma unroll
|
||||
+ for (int s = 0; s < STAGES-1; s++) { if (s < Kb) load_tile(s, s); rp_cp_commit(); }
|
||||
+
|
||||
+ for (int kt = 0; kt < Kb; kt++) {
|
||||
+ const int ld = kt + (STAGES-1);
|
||||
+ if (ld < Kb) load_tile(ld % STAGES, ld);
|
||||
+ rp_cp_commit();
|
||||
+ rp_cp_wait<STAGES-1>();
|
||||
+ __syncthreads();
|
||||
+
|
||||
+ const int rs = kt % STAGES;
|
||||
+ const nv_bfloat162 * sAcur = sA[rs];
|
||||
+ const uint32_t * sWqw = sWq[rs]; // BN rows x 8 u32 (32 qs bytes)
|
||||
+ const uint32_t * sWdw = sWd[rs]; // BN rows x 1 u32 (4 scale bytes)
|
||||
+
|
||||
+#pragma unroll
|
||||
+ for (int kk = 0; kk < BK/16; kk++) { // 4 m16n8k16 sub-steps per 64-block
|
||||
+ const int sub = kk;
|
||||
+ // A fragments via ldmatrix (bf16)
|
||||
+ rp_tile_A A_frag[MF];
|
||||
+#pragma unroll
|
||||
+ for (int mi = 0; mi < MF; mi++) {
|
||||
+ const int rb = wrow*WM + mi*16;
|
||||
+ load_ldmatrix(A_frag[mi], sAcur + rb*ASTRIDE + kk*8, ASTRIDE);
|
||||
+ }
|
||||
+ // B fragments: in-register FP4->bf16 unpack (byte_perm LUT-free) * in-register ue4m3 scale.
|
||||
+ rp_tile_B B_frag[NF];
|
||||
+ const int n_local = lane >> 2; // tile_B::get_i (row N, 0..7)
|
||||
+ const int jc = lane & 3;
|
||||
+ const int wsel = sub*2 + (jc >> 1);
|
||||
+ const int bsh = 8 * (2*(jc & 1));
|
||||
+#pragma unroll
|
||||
+ for (int ni = 0; ni < NF; ni++) {
|
||||
+ const int nrow = wcol*WN + ni*8 + n_local; // col within BN tile [0,BN)
|
||||
+ const uint32_t w = sWqw[nrow*8 + wsel];
|
||||
+ const int2 kv = rp_table16(w);
|
||||
+ const float sc = ggml_cuda_ue4m3_to_fp32(((const uint8_t *) &sWdw[nrow])[sub]);
|
||||
+ B_frag[ni].x[0].x = __float2bfloat16(sc * (float) (int8_t) (kv.x >> bsh));
|
||||
+ B_frag[ni].x[0].y = __float2bfloat16(sc * (float) (int8_t) (kv.x >> (bsh + 8)));
|
||||
+ B_frag[ni].x[1].x = __float2bfloat16(sc * (float) (int8_t) (kv.y >> bsh));
|
||||
+ B_frag[ni].x[1].y = __float2bfloat16(sc * (float) (int8_t) (kv.y >> (bsh + 8)));
|
||||
+ }
|
||||
+#pragma unroll
|
||||
+ for (int mi = 0; mi < MF; mi++)
|
||||
+#pragma unroll
|
||||
+ for (int ni = 0; ni < NF; ni++)
|
||||
+ mma(acc[mi][ni], A_frag[mi], B_frag[ni]);
|
||||
+ }
|
||||
+ __syncthreads();
|
||||
+ }
|
||||
+
|
||||
+ // write back (mask the ragged per-expert row tail)
|
||||
+#pragma unroll
|
||||
+ for (int mi = 0; mi < MF; mi++)
|
||||
+#pragma unroll
|
||||
+ for (int ni = 0; ni < NF; ni++) {
|
||||
+ const int orow = wrow*WM + mi*16;
|
||||
+ const int ocol = blockCol + wcol*WN + ni*8;
|
||||
+#pragma unroll
|
||||
+ for (int l = 0; l < acc[mi][ni].ne; l++) {
|
||||
+ const int lr = orow + acc[mi][ni].get_i(l);
|
||||
+ const int nc = ocol + acc[mi][ni].get_j(l);
|
||||
+ if (lr < rcount && nc < N) {
|
||||
+ C[(int64_t)(row0 + lr)*N + nc] = acc[mi][ni].x[l];
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+#else
|
||||
+ GGML_UNUSED(A_f32); GGML_UNUSED(Qrp); GGML_UNUSED(Srp); GGML_UNUSED(C);
|
||||
+ GGML_UNUSED(g_tile_expert); GGML_UNUSED(g_tile_row0); GGML_UNUSED(g_tile_rows);
|
||||
+ GGML_UNUSED(N); GGML_UNUSED(K); GGML_UNUSED(total_rows);
|
||||
+ NO_DEVICE_CODE;
|
||||
+#endif // AMPERE_MMA_AVAILABLE && CP_ASYNC_AVAILABLE
|
||||
+}
|
||||
+
|
||||
+// launch the one-time repack kernel: src0 NVFP4 -> (Qrp, Srp) repacked planes.
|
||||
+static void w4a16_repack_launch(
|
||||
+ const ggml_tensor * src0, int64_t n_experts, int64_t K, int64_t N,
|
||||
+ uint32_t * Qrp, uint32_t * Srp, cudaStream_t stream) {
|
||||
+ const int64_t Kb = K / 64;
|
||||
+ const int64_t nblk = n_experts * Kb * N;
|
||||
+ const int64_t expert_stride_blocks = (int64_t) (src0->nb[2] / sizeof(block_nvfp4));
|
||||
+ const int threads = 256;
|
||||
+ const int64_t grid = (nblk + threads - 1) / threads;
|
||||
+ w4a16_repack_kernel<<<grid, threads, 0, stream>>>(
|
||||
+ (const block_nvfp4 *) src0->data, expert_stride_blocks,
|
||||
+ Qrp, Srp, (int) N, (int) Kb, (int) n_experts);
|
||||
+ CUDA_CHECK(cudaGetLastError());
|
||||
+}
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// repack cache: persistent device buffer per src0 tensor (keyed by src0->data). One-time build.
|
||||
+//
|
||||
+// Correct for real models: model weights live in a persistent buffer for the model's lifetime, so
|
||||
+// src0->data is stable and never freed/reused during inference. NOT correct for test-backend-ops,
|
||||
+// which allocates/frees a fresh src0 per case and ggml reuses the address -> a cache HIT returns
|
||||
+// stale repacked weights. For harness validation use LLAMA_W4A16_REPACK_NOCACHE=1 (repack into a
|
||||
+// transient pool buffer every call, bypassing the cache); this proves the kernel+repack correct.
|
||||
+// ===========================================================================
|
||||
+struct w4a16_repack_entry {
|
||||
+ uint32_t * Qrp = nullptr; // [n_experts][Kb][N][8 u32]
|
||||
+ uint32_t * Srp = nullptr; // [n_experts][Kb][N][1 u32]
|
||||
+ int64_t n_experts = 0, Kb = 0, N = 0;
|
||||
+ size_t bytes = 0; // total cudaMalloc bytes (Q + S), for the memory-delta report
|
||||
+};
|
||||
+static std::mutex g_rp_mu;
|
||||
+static std::unordered_map<const void *, w4a16_repack_entry> g_rp_cache;
|
||||
+
|
||||
+// total bytes the repack cache currently holds on device (for the memory-delta report).
|
||||
+size_t ggml_cuda_w4a16_repack_cache_bytes() {
|
||||
+ std::lock_guard<std::mutex> lk(g_rp_mu);
|
||||
+ size_t b = 0;
|
||||
+ for (const auto & kv : g_rp_cache) {
|
||||
+ b += kv.second.bytes;
|
||||
+ }
|
||||
+ return b;
|
||||
+}
|
||||
+
|
||||
+static const w4a16_repack_entry & w4a16_get_or_build_repack(
|
||||
+ const ggml_tensor * src0, int64_t n_experts, int64_t K, int64_t N, cudaStream_t stream) {
|
||||
+ const void * key = src0->data;
|
||||
+ std::lock_guard<std::mutex> lk(g_rp_mu);
|
||||
+ auto it = g_rp_cache.find(key);
|
||||
+ if (it != g_rp_cache.end()) {
|
||||
+ return it->second;
|
||||
+ }
|
||||
+
|
||||
+ const int64_t Kb = K / 64;
|
||||
+ const int64_t nblk = n_experts * Kb * N;
|
||||
+ const size_t q_bytes = (size_t) nblk * 8 * sizeof(uint32_t); // 32 qs bytes/block
|
||||
+ const size_t s_bytes = (size_t) nblk * 1 * sizeof(uint32_t); // 4 scale bytes/block
|
||||
+ // single allocation: [Q][S]; total == source NVFP4 weight bytes (36 B/block) -> stays 4-bit.
|
||||
+ void * buf = nullptr;
|
||||
+ CUDA_CHECK(cudaMalloc(&buf, q_bytes + s_bytes));
|
||||
+ uint32_t * Qrp = (uint32_t *) buf;
|
||||
+ uint32_t * Srp = (uint32_t *) ((char *) buf + q_bytes);
|
||||
+
|
||||
+ w4a16_repack_launch(src0, n_experts, K, N, Qrp, Srp, stream);
|
||||
+
|
||||
+ w4a16_repack_entry ent;
|
||||
+ ent.Qrp = Qrp; ent.Srp = Srp;
|
||||
+ ent.n_experts = n_experts; ent.Kb = Kb; ent.N = N;
|
||||
+ ent.bytes = q_bytes + s_bytes;
|
||||
+ auto res = g_rp_cache.emplace(key, ent);
|
||||
+
|
||||
+ if (getenv("LLAMA_W4A16_DEBUG")) {
|
||||
+ // NB: we already hold g_rp_mu - sum the cache total INLINE (do NOT call the public
|
||||
+ // ggml_cuda_w4a16_repack_cache_bytes(), which re-locks the non-recursive mutex -> deadlock).
|
||||
+ size_t total_cache = 0;
|
||||
+ for (const auto & kv : g_rp_cache) {
|
||||
+ total_cache += kv.second.bytes;
|
||||
+ }
|
||||
+ fprintf(stderr, "[w4a16-repack] BUILT cache for src0=%p: n_experts=%lld Kb=%lld N=%lld "
|
||||
+ "bytes=%.1f MiB (== source NVFP4 weight bytes; stays 4-bit). total cache=%.1f MiB\n",
|
||||
+ key, (long long) n_experts, (long long) Kb, (long long) N,
|
||||
+ ent.bytes / (1024.0*1024.0), total_cache / (1024.0*1024.0));
|
||||
+ }
|
||||
+ return res.first->second;
|
||||
+}
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// GEMM run over already-repacked planes (Qrp, Srp). Single tuned config
|
||||
+// (the 0035 winner: BM64 BN128 WARPS_M1 WARPS_N8 STAGES3 APAD4).
|
||||
+// ===========================================================================
|
||||
+static void w4a16_repack_run(
|
||||
+ ggml_backend_cuda_context & ctx,
|
||||
+ const float * src1_sorted,
|
||||
+ float * dst_sorted,
|
||||
+ const int * tokens_per_expert,
|
||||
+ int64_t n_experts, int64_t K, int64_t N,
|
||||
+ const uint32_t * Qrp, const uint32_t * Srp,
|
||||
+ cudaStream_t stream) {
|
||||
+ constexpr int BM = 64, BN = 128, WARPS_M = 1, WARPS_N = 8, STAGES = 3, APAD = 4;
|
||||
+
|
||||
+ // host: per-M-tile expert map (ragged, no tile crosses an expert boundary)
|
||||
+ int64_t total_rows = 0;
|
||||
+ for (int64_t e = 0; e < n_experts; e++) {
|
||||
+ total_rows += tokens_per_expert[e];
|
||||
+ }
|
||||
+ if (total_rows == 0) {
|
||||
+ return;
|
||||
+ }
|
||||
+ std::vector<int32_t> h_tile_expert, h_tile_row0, h_tile_rows;
|
||||
+ int64_t row = 0;
|
||||
+ for (int64_t e = 0; e < n_experts; e++) {
|
||||
+ const int t = tokens_per_expert[e];
|
||||
+ for (int off = 0; off < t; off += BM) {
|
||||
+ h_tile_expert.push_back((int32_t) e);
|
||||
+ h_tile_row0.push_back((int32_t) (row + off));
|
||||
+ h_tile_rows.push_back((int32_t) std::min(BM, t - off));
|
||||
+ }
|
||||
+ row += t;
|
||||
+ }
|
||||
+ const int n_tiles = (int) h_tile_expert.size();
|
||||
+
|
||||
+ if (getenv("LLAMA_W4A16_DEBUG")) {
|
||||
+ int max_tpe = 0, multi = 0;
|
||||
+ for (int64_t e = 0; e < n_experts; e++) {
|
||||
+ if (tokens_per_expert[e] > max_tpe) max_tpe = tokens_per_expert[e];
|
||||
+ if (tokens_per_expert[e] > BM) multi++;
|
||||
+ }
|
||||
+ fprintf(stderr, "[w4a16-repack] engaged: total_rows=%lld n_experts=%lld K=%lld N=%lld "
|
||||
+ "n_tiles=%d max_tpe=%d multi_tile=%d (offline-repack, in-register act cast)\n",
|
||||
+ (long long) total_rows, (long long) n_experts, (long long) K, (long long) N,
|
||||
+ n_tiles, max_tpe, multi);
|
||||
+ }
|
||||
+
|
||||
+ ggml_cuda_pool_alloc<int32_t> d_tile_expert(ctx.pool(), n_tiles);
|
||||
+ ggml_cuda_pool_alloc<int32_t> d_tile_row0 (ctx.pool(), n_tiles);
|
||||
+ ggml_cuda_pool_alloc<int32_t> d_tile_rows (ctx.pool(), n_tiles);
|
||||
+ CUDA_CHECK(cudaMemcpyAsync(d_tile_expert.ptr, h_tile_expert.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
|
||||
+ CUDA_CHECK(cudaMemcpyAsync(d_tile_row0.ptr, h_tile_row0.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
|
||||
+ CUDA_CHECK(cudaMemcpyAsync(d_tile_rows.ptr, h_tile_rows.data(), n_tiles*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
|
||||
+
|
||||
+ auto kern = w4a16_repack_grouped_kernel<BM, BN, WARPS_M, WARPS_N, STAGES, APAD>;
|
||||
+ constexpr int AN = 64/2, ASTRIDE = AN + APAD;
|
||||
+ constexpr int STAGE_U32 = BM*ASTRIDE + BN*8 + BN;
|
||||
+ const int smem_bytes = STAGES * STAGE_U32 * (int) sizeof(uint32_t);
|
||||
+ CUDA_CHECK(cudaFuncSetAttribute(kern, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
|
||||
+
|
||||
+ dim3 grid((unsigned) (N / BN), (unsigned) n_tiles);
|
||||
+ dim3 block(32, WARPS_M*WARPS_N);
|
||||
+ kern<<<grid, block, smem_bytes, stream>>>(
|
||||
+ src1_sorted, Qrp, Srp, dst_sorted,
|
||||
+ d_tile_expert.ptr, d_tile_row0.ptr, d_tile_rows.ptr,
|
||||
+ (int) N, (int) K, (int) total_rows);
|
||||
+ CUDA_CHECK(cudaGetLastError());
|
||||
+}
|
||||
+
|
||||
+// ===========================================================================
|
||||
+// public host entry. Default: cached persistent repack (one-time, production-correct, stays
|
||||
+// 4-bit). LLAMA_W4A16_REPACK_NOCACHE=1: repack into a transient pool buffer EVERY call (bypasses
|
||||
+// the pointer-keyed cache) so test-backend-ops (which frees/reuses src0 addresses) can validate
|
||||
+// the kernel + repack correctness.
|
||||
+// ===========================================================================
|
||||
+void ggml_cuda_mul_mat_id_w4a16_repack(
|
||||
+ ggml_backend_cuda_context & ctx,
|
||||
+ const ggml_tensor * src0,
|
||||
+ const float * src1_sorted,
|
||||
+ float * dst_sorted,
|
||||
+ const int * tokens_per_expert,
|
||||
+ int64_t n_experts, int64_t K, int64_t N,
|
||||
+ cudaStream_t stream) {
|
||||
+ GGML_ASSERT(src0->type == GGML_TYPE_NVFP4);
|
||||
+ GGML_ASSERT(N % 128 == 0 && K % 64 == 0);
|
||||
+
|
||||
+ static const bool nocache = [] {
|
||||
+ const char * s = getenv("LLAMA_W4A16_REPACK_NOCACHE");
|
||||
+ return s != nullptr && atoi(s) != 0;
|
||||
+ }();
|
||||
+
|
||||
+ if (nocache) {
|
||||
+ const int64_t Kb = K / 64;
|
||||
+ const int64_t nblk = n_experts * Kb * N;
|
||||
+ // (1) repack into a transient pool buffer (same 4-bit size; stream-ordered before the GEMM).
|
||||
+ ggml_cuda_pool_alloc<uint32_t> q(ctx.pool(), (size_t) nblk * 8);
|
||||
+ ggml_cuda_pool_alloc<uint32_t> s(ctx.pool(), (size_t) nblk * 1);
|
||||
+ w4a16_repack_launch(src0, n_experts, K, N, q.ptr, s.ptr, stream);
|
||||
+ // (2) GEMM (q,s stay alive through the launch; kernel is queued on `stream` before dtor).
|
||||
+ w4a16_repack_run(ctx, src1_sorted, dst_sorted, tokens_per_expert,
|
||||
+ n_experts, K, N, q.ptr, s.ptr, stream);
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ // (1) one-time offline repack (cached, persistent, stays 4-bit). Stream-ordered before the GEMM.
|
||||
+ const w4a16_repack_entry & rp = w4a16_get_or_build_repack(src0, n_experts, K, N, stream);
|
||||
+ // (2) GEMM over the cached repacked planes.
|
||||
+ w4a16_repack_run(ctx, src1_sorted, dst_sorted, tokens_per_expert,
|
||||
+ n_experts, K, N, rp.Qrp, rp.Srp, stream);
|
||||
+}
|
||||
diff --git a/ggml/src/ggml-cuda/w4a16-repack.cuh b/ggml/src/ggml-cuda/w4a16-repack.cuh
|
||||
new file mode 100644
|
||||
index 0000000..f44f5f8
|
||||
--- /dev/null
|
||||
+++ b/ggml/src/ggml-cuda/w4a16-repack.cuh
|
||||
@@ -0,0 +1,59 @@
|
||||
+#pragma once
|
||||
+
|
||||
+#include "common.cuh"
|
||||
+
|
||||
+// [paged patch 0045] vLLM-EXACT offline-repack Marlin W4A16 grouped MoE prefill GEMM.
|
||||
+//
|
||||
+// This is a SUB-MODE of the patch-0035 W4A16 grouped MoE GEMM (default-off). The 0035 root-cause
|
||||
+// sweep proved the W4A16 kernel is occupancy/latency bound (~10% MMA util): every prefill step it
|
||||
+// re-derives a per-block STRIDED weight gather (blk = We + (blockCol+r)*Kb + kt) and streams the
|
||||
+// 4-bit weights with scattered 4-BYTE cp.async transactions (BN*8 per K-block) + a per-step ue4m3
|
||||
+// scale decode, plus a SEPARATE global f32->bf16 activation pre-pass kernel. vLLM's Marlin wins by
|
||||
+// REPACKING the weights ONCE at load into an MMA-ready, tile-contiguous 4-bit layout so the GEMM
|
||||
+// loop issues only cheap COALESCED 16-byte loads and the per-lane reads are already in fragment
|
||||
+// order - the loop does ZERO dequant work beyond the cheap in-register nibble->bf16 unpack.
|
||||
+//
|
||||
+// What this mode does, faithful to vLLM Marlin:
|
||||
+// (1) OFFLINE REPACK (one-time, at first engage; cached in a persistent cudaMalloc device buffer
|
||||
+// keyed by src0->data): transpose the NVFP4 expert weights [N rows][Kb blocks](d[4]+qs[32])
|
||||
+// into two tile-contiguous planes, Q=[expert][Kb][N][32 qs bytes] and S=[expert][Kb][N][4
|
||||
+// ue4m3 scale bytes]. STAYS 4-BIT: qs kept packed, scales kept ue4m3 - the repacked buffer is
|
||||
+// byte-for-byte the SAME size as the source NVFP4 weights (36 B / 64 weights). No int8
|
||||
+// expansion, no 2x memory (vs the rejected patch-0044 int8-Marlin which doubled the weight
|
||||
+// bytes). Persists across all prefill steps.
|
||||
+// (2) KERNEL reads the pre-packed tiles DIRECTLY: per (K-block, N-block) the BN rows' qs/scales
|
||||
+// are contiguous, so the smem fill is a flat COALESCED 16-byte cp.async stream (no strided
|
||||
+// gather, no address math). Activations are cast f32->bf16 IN REGISTER on the load into smem
|
||||
+// (no separate global act pre-pass, no act-quant). The inner loop does only: ldmatrix the
|
||||
+// bf16 A, in-register nibble->bf16 unpack of the weight (byte_perm/PRMT, LUT-free) scaled by
|
||||
+// the in-register ue4m3 decode, and bf16 m16n8k16 mma.sync into f32 accumulators. cp.async
|
||||
+// multistage pipelined, ragged per-tile expert offset (reuses 0035's tile map). NO separate
|
||||
+// smem-staged dequant pass, NO __syncthreads-gated dequant pass.
|
||||
+//
|
||||
+// The repacked smem contents are bit-identical to the 0035 W4A16 smem, so the inner fragment math
|
||||
+// (proven 81/81 on test-backend-ops MUL_MAT_ID) is unchanged and the output is bit-identical to the
|
||||
+// 0035 W4A16 path; only the global->smem load path (coalesced) and the act cast (in-register) change.
|
||||
+//
|
||||
+// Toggle: LLAMA_W4A16_REPACK=1 (default 0 == OFF). Engages only inside the already-default-off
|
||||
+// 0035 W4A16 grouped path (LLAMA_W4A16_PREFILL_M>0). Stock / decode / non-NVFP4 byte-untouched.
|
||||
+
|
||||
+// True iff LLAMA_W4A16_REPACK != 0.
|
||||
+bool ggml_cuda_w4a16_repack_enabled();
|
||||
+
|
||||
+// Total bytes the persistent repack cache currently holds on device (for the weight-memory-delta
|
||||
+// report). The repacked layout stays 4-bit, so this equals the source NVFP4 bytes of the engaged
|
||||
+// expert weight tensors; 0 when the toggle is OFF (no repack happens).
|
||||
+size_t ggml_cuda_w4a16_repack_cache_bytes();
|
||||
+
|
||||
+// Offline-repack W4A16 grouped MoE GEMM over the token-sorted buffer. Same contract as
|
||||
+// ggml_cuda_mul_mat_id_w4a16_grouped (see w4a16-gemm.cuh) but src1_sorted is the RAW f32 sorted
|
||||
+// activations (cast to bf16 in-register; no global bf16 pre-pass needed) and the weights are read
|
||||
+// from the cached repacked layout (built one-time from src0).
|
||||
+void ggml_cuda_mul_mat_id_w4a16_repack(
|
||||
+ ggml_backend_cuda_context & ctx,
|
||||
+ const ggml_tensor * src0,
|
||||
+ const float * src1_sorted,
|
||||
+ float * dst_sorted,
|
||||
+ const int * tokens_per_expert,
|
||||
+ int64_t n_experts, int64_t K, int64_t N,
|
||||
+ cudaStream_t stream);
|
||||
--
|
||||
2.43.0
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
From 85266d4c10750b419716e4b8939ebd96ab424630 Mon Sep 17 00:00:00 2001
|
||||
From: Ettore Di Giacinto <mudler@localai.io>
|
||||
Date: Tue, 30 Jun 2026 00:51:26 +0000
|
||||
Subject: [PATCH] feat(paged): gate GDN prefill geometry by scan length (patch
|
||||
0046)
|
||||
|
||||
Patch 0022 retuned the gated-DeltaNet (GDN) sequential-recurrence dispatch
|
||||
(case 128) to a (NUM_WARPS=16, COLS_PER_WARP=8) column-fold tile. That is a
|
||||
DECODE win (short scans: small n_tokens, large n_seqs) but an UNCONDITIONAL
|
||||
dense-prefill regression vs stock: on a long sequential scan the launch grid.z
|
||||
collapses from S_v/4=32 to S_v/(16*8)=1, so the SMs starve. Profiling the
|
||||
dense-prefill path attributed the whole regression (~-6%) to gated_delta_net
|
||||
(+54% GPU time) at the (16,8) geometry.
|
||||
|
||||
Gate the geometry by per-call scan length instead of applying (16,8)
|
||||
unconditionally. Long scans (prefill, n_tokens >= GDN_PREFILL_NTOK, default 256)
|
||||
take stock's high-grid.z (4,1) geometry; short scans (decode) keep the (16,8)
|
||||
retune. This recovers dense prefill +7.2% back to stock parity while preserving
|
||||
the (16,8) decode win.
|
||||
|
||||
Bit-exact: patch 0022 proved every selectable {NUM_WARPS, COLS_PER_WARP} variant
|
||||
is byte-identical (the sweep cannot change the md5), so this scan-length gate is
|
||||
greedy-md5 bit-exact. GDN_PREFILL_NTOK tunes the crossover; the explicit
|
||||
GDN_NW / GDN_CPW one-build %peak sweep still wins (the gate yields when either is
|
||||
set), so the A/B harness is unchanged.
|
||||
|
||||
Root cause: patch 0022 applied the (16,8) tile unconditionally. This patch
|
||||
sequences after 0022/0044 (it edits the same gated_delta_net.cu case-128
|
||||
dispatch) and adds only the scan-length gate.
|
||||
|
||||
Assisted-by: Claude:opus-4.8 [Claude Code]
|
||||
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu
|
||||
index 7121d807f..26667afa2 100644
|
||||
--- a/ggml/src/ggml-cuda/gated_delta_net.cu
|
||||
+++ b/ggml/src/ggml-cuda/gated_delta_net.cu
|
||||
@@ -550,6 +550,23 @@ static void launch_gated_delta_net(
|
||||
launch_gdn_variant<64, KDA, keep_rs_t, 4, 1, 2>(GDN_LAUNCH_ARGS);
|
||||
break;
|
||||
case 128: {
|
||||
+ // Dense-prefill regression fix: gate patch 0022's column-fold geometry by per-call scan
|
||||
+ // length. The (16,8) tile is a DECODE win (short scans: n_tokens small, n_seqs large) but a
|
||||
+ // long-sequential-scan PREFILL loss - grid.z collapses from S_v/4=32 to S_v/(16*8)=1, so the
|
||||
+ // SMs starve on the long scan (profiled: gated_delta_net +54% GPU time == the whole dense-
|
||||
+ // prefill regression). Long scans (prefill) take stock's high-grid.z (4,1) geometry; short
|
||||
+ // scans (decode) keep the (16,8) winner. Every {NW,CPW} variant is byte-identical (patch 0022
|
||||
+ // proved md5-invariance across the ladder), so this stays greedy-md5 bit-exact. Default-on;
|
||||
+ // GDN_PREFILL_NTOK tunes the crossover; the explicit GDN_NW/GDN_CPW sweep still wins (gate
|
||||
+ // yields when either is set) so the one-build %peak A/B harness is unchanged.
|
||||
+ static const int64_t gdn_prefill_ntok =
|
||||
+ []{ const char * e = getenv("GDN_PREFILL_NTOK"); return e ? (int64_t) atoll(e) : (int64_t) 256; }();
|
||||
+ static const bool gdn_nw_forced = (getenv("GDN_NW") != nullptr);
|
||||
+ static const bool gdn_cpw_forced = (getenv("GDN_CPW") != nullptr);
|
||||
+ if (n_tokens >= gdn_prefill_ntok && !gdn_nw_forced && !gdn_cpw_forced) {
|
||||
+ launch_gdn_variant<128, KDA, keep_rs_t, 4, 1, 2>(GDN_LAUNCH_ARGS);
|
||||
+ break;
|
||||
+ }
|
||||
// Bit-exact occupancy/coalescing retune (patch 0022): fold COLS_PER_WARP columns per warp
|
||||
// to raise per-warp memory-level parallelism on this bandwidth-bound recurrence. Default is
|
||||
// the measured winner; GDN_NW / GDN_CPW override it for the one-build %peak sweep (every
|
||||
--
|
||||
2.43.0
|
||||
|
||||
@@ -0,0 +1,713 @@
|
||||
From 2c32ab8b7a6c5bc90454881b8c10f8bad4f7cee0 Mon Sep 17 00:00:00 2001
|
||||
From: Ettore Di Giacinto <mudler@localai.io>
|
||||
Date: Tue, 30 Jun 2026 09:45:13 +0200
|
||||
Subject: [PATCH] feat(paged): GDN M5 tensor-core chunked-scan prefill,
|
||||
f32-only re-port (was patch 0044)
|
||||
|
||||
Re-port the M5 tensor-core chunked gated-DeltaNet (GDN) prefill kernel from the
|
||||
bf16/hybrid dev tree as an f32-only native commit, recovering the prefill win
|
||||
that patch 0044 encoded, on the f32-only series (0026 ssm_bf16_tau dropped).
|
||||
|
||||
What landed (f32/tf32 only):
|
||||
- The mma.sync m16n8k8 helpers (tf32 + 3xtf32 limb-split; decays/gamma/beta stay
|
||||
f32 outside the mma to preserve the bounded de-gating).
|
||||
- gated_delta_net_chunked_cuda<S_v, C, TC>: the full tensor-core chunked scan,
|
||||
KK/QK Gram (M2), KS/QS state-boundary 3xtf32 (M3), P*U output (M4), and the
|
||||
form-T (A^-1) solve + Kc^T*DU state-update (M5). Selected by GDN_TC (0=serial
|
||||
.. 4/5+=M5); the C=16 chunk-state stays in the 64KB smem buffer.
|
||||
- Default-on under paged KV: GDN_TC=5, GDN_CHUNK_MIN=64 when LLAMA_KV_PAGED is set
|
||||
and the user has not overridden either; OFF (INT_MAX) otherwise so the stock /
|
||||
non-paged default is regression-free. GDN_CHUNK_MIN must stay > 1 (decode is 1
|
||||
token/call; at 1 the chunked path swallows decode and collapses S_TG).
|
||||
|
||||
Stripped (not part of the f32-only series): the STATE_BF16 / HYBRID / gdn_state_t
|
||||
/ gdn_hybrid_args template machinery (from dropped patch 0026), and the bf16
|
||||
CONFIG-C (M8) plus register-resident M6/M7 occupancy variants. The 0046 dense-
|
||||
prefill geometry gate is untouched and coexists (it gates the SERIAL path; M5 is
|
||||
the chunked path).
|
||||
|
||||
Gates (GB10, sm_121a):
|
||||
- Builds clean.
|
||||
- Greedy md5 bit-exact (per-path, n=48 --temp 0 --seed 1, paged): dense
|
||||
q36-27b-nvfp4 = 5951a5b4d624ce891e22ab5fca9bc439, MoE q36-35b-a3b-nvfp4 =
|
||||
8cb0ce23777bf55f92f63d0292c756b0, both default AND force-M5 (GDN_CHUNK_MIN=1).
|
||||
test-backend-ops GATED_DELTA_NET 46/46 default and force-M5 (incl. the
|
||||
multi-chunk, tail-chunk and multi-seq shapes).
|
||||
- Prefill S_PP, MoE, LLAMA_KV_PAGED=1 LLAMA_MOE_FORCE_GRAPHS=1, -ntg 4 -npl 32,
|
||||
vs the patch-0044 baseline (pre-0046, GDN_PREFILL_NTOK huge): +4.3% @512,
|
||||
+17.8% @2048 (reproduces patch 0044; M5-on absolute matches patch 0044 M5).
|
||||
vs the current 0046 baseline (0046 already raised the long-scan sequential
|
||||
prefill): +4.3% @512, +1.2% @2048.
|
||||
- Decode S_TG unchanged (within run noise).
|
||||
|
||||
Assisted-by: Claude:opus-4.8 [Claude Code]
|
||||
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
|
||||
---
|
||||
ggml/src/ggml-cuda/gated_delta_net.cu | 550 +++++++++++++++++++++++---
|
||||
tests/test-backend-ops.cpp | 5 +
|
||||
2 files changed, 496 insertions(+), 59 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu
|
||||
index 26667afa2..0ceb1bc8f 100644
|
||||
--- a/ggml/src/ggml-cuda/gated_delta_net.cu
|
||||
+++ b/ggml/src/ggml-cuda/gated_delta_net.cu
|
||||
@@ -298,7 +298,115 @@ static void launch_gdn_variant(
|
||||
// strong-decay tokens underflow to the correct zero rather than to inf. The math
|
||||
// is equivalent to the sequential recurrence up to FP reduction order (a NEW
|
||||
// per-path result, validated benign by test-backend-ops NMSE and greedy output).
|
||||
-template <int S_v, int C>
|
||||
+// --- Phase-1 tensor-core Gram helpers (tf32 m16n8k8 mma.sync; sm_80+/sm_121a). ---
|
||||
+// Reproduces the PoC-proven path (~/scratch_tc_gdn_poc/gdn_gram_bench.cu, tf32 NMSE ~3e-9):
|
||||
+// out[rowbase..+15][colbase..+7] = Xs[rows] . Ys[cols], Xs/Ys row-major [*][DK].
|
||||
+__device__ __forceinline__ unsigned gdn_f2tf32(float f) {
|
||||
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
+ unsigned r;
|
||||
+ asm("cvt.rna.tf32.f32 %0, %1;" : "=r"(r) : "f"(f));
|
||||
+ return r;
|
||||
+#else
|
||||
+ (void) f;
|
||||
+ return 0u;
|
||||
+#endif
|
||||
+}
|
||||
+
|
||||
+// Operand loaders for the Gram/state mma helpers: stage f32 operands as tf32. This f32-only
|
||||
+// re-port keeps every operand full-width -- the plain-tf32 path (10-bit mantissa, f32 accumulate)
|
||||
+// is the highest-precision tensor-core option on sm_121a, and the 3xtf32 limb-split helpers below
|
||||
+// recover near-f32 accuracy for the decay-coupled state-boundary (KS/QS) and state-carry products
|
||||
+// whose error feeds the A-inverse solve / compounds across chunks.
|
||||
+__device__ __forceinline__ unsigned gdn_ld_tf32(float f) { return gdn_f2tf32(f); }
|
||||
+__device__ __forceinline__ float gdn_ld_f32 (float f) { return f; }
|
||||
+
|
||||
+__device__ __forceinline__ void gdn_mma_m16n8k8(float c[4], const unsigned a[4], const unsigned b[2]) {
|
||||
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
+ asm volatile(
|
||||
+ "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 "
|
||||
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n"
|
||||
+ : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3])
|
||||
+ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
|
||||
+#else
|
||||
+ (void) c; (void) a; (void) b;
|
||||
+#endif
|
||||
+}
|
||||
+
|
||||
+template <int DK, typename TX = float, typename TY = float>
|
||||
+__device__ __forceinline__ void gdn_gram_tile_mma(
|
||||
+ float c[4], const TX * __restrict__ Xs, const TY * __restrict__ Ys,
|
||||
+ int rowbase, int colbase, int lg, int lt) {
|
||||
+ c[0] = c[1] = c[2] = c[3] = 0.0f;
|
||||
+ #pragma unroll
|
||||
+ for (int ks = 0; ks < DK; ks += 8) {
|
||||
+ unsigned a[4], b[2];
|
||||
+ a[0] = gdn_ld_tf32(Xs[(rowbase + lg ) * DK + ks + lt ]);
|
||||
+ a[1] = gdn_ld_tf32(Xs[(rowbase + lg + 8) * DK + ks + lt ]);
|
||||
+ a[2] = gdn_ld_tf32(Xs[(rowbase + lg ) * DK + ks + lt + 4]);
|
||||
+ a[3] = gdn_ld_tf32(Xs[(rowbase + lg + 8) * DK + ks + lt + 4]);
|
||||
+ b[0] = gdn_ld_tf32(Ys[(colbase + lg ) * DK + ks + lt ]);
|
||||
+ b[1] = gdn_ld_tf32(Ys[(colbase + lg ) * DK + ks + lt + 4]);
|
||||
+ gdn_mma_m16n8k8(c, a, b);
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+// 3xtf32 (CUTLASS fp32-emulation): split each f32 operand into hi/lo tf32 limbs and run
|
||||
+// 3 limb-products per k-subtile (hi*hi + hi*lo + lo*hi); ~f32 accuracy at ~3x the mma count.
|
||||
+// Used for the state-boundary products (KS/QS) whose error feeds the A-inverse solve (M3).
|
||||
+template <int DK, typename TX = float, typename TY = float>
|
||||
+__device__ __forceinline__ void gdn_gram_tile_mma_3x(
|
||||
+ float c[4], const TX * __restrict__ Xs, const TY * __restrict__ Ys,
|
||||
+ int rowbase, int colbase, int lg, int lt) {
|
||||
+ c[0] = c[1] = c[2] = c[3] = 0.0f;
|
||||
+ #pragma unroll
|
||||
+ for (int ks = 0; ks < DK; ks += 8) {
|
||||
+ float af[4], bf[2];
|
||||
+ af[0] = gdn_ld_f32(Xs[(rowbase + lg ) * DK + ks + lt ]);
|
||||
+ af[1] = gdn_ld_f32(Xs[(rowbase + lg + 8) * DK + ks + lt ]);
|
||||
+ af[2] = gdn_ld_f32(Xs[(rowbase + lg ) * DK + ks + lt + 4]);
|
||||
+ af[3] = gdn_ld_f32(Xs[(rowbase + lg + 8) * DK + ks + lt + 4]);
|
||||
+ bf[0] = gdn_ld_f32(Ys[(colbase + lg ) * DK + ks + lt ]);
|
||||
+ bf[1] = gdn_ld_f32(Ys[(colbase + lg ) * DK + ks + lt + 4]);
|
||||
+ unsigned ahi[4], alo[4], bhi[2], blo[2];
|
||||
+ #pragma unroll
|
||||
+ for (int z = 0; z < 4; z++) { ahi[z] = gdn_f2tf32(af[z]); alo[z] = gdn_f2tf32(af[z] - __uint_as_float(ahi[z])); }
|
||||
+ #pragma unroll
|
||||
+ for (int z = 0; z < 2; z++) { bhi[z] = gdn_f2tf32(bf[z]); blo[z] = gdn_f2tf32(bf[z] - __uint_as_float(bhi[z])); }
|
||||
+ gdn_mma_m16n8k8(c, ahi, bhi); // hi*hi (dominant limb)
|
||||
+ gdn_mma_m16n8k8(c, ahi, blo); // hi*lo
|
||||
+ gdn_mma_m16n8k8(c, alo, bhi); // lo*hi
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+// State-update tile (P6): S_C[i][j] += sum_t Kc[t][i] * DU[t][j], with Kc read TRANSPOSED
|
||||
+// (i as the m16n8k8 M-row, t as the K-contraction) and DU = d(t,last)*U staged in the Ud
|
||||
+// layout (DUd[j*KC + t]). 3xtf32: the cross-chunk carry compounds over every chunk step.
|
||||
+template <int KC, int DK, typename TK = float, typename TD = float>
|
||||
+__device__ __forceinline__ void gdn_state_tile_mma_3x(
|
||||
+ float c[4], const TK * __restrict__ Kc, const TD * __restrict__ DUd,
|
||||
+ int rowbase, int colbase, int lg, int lt) {
|
||||
+ c[0] = c[1] = c[2] = c[3] = 0.0f;
|
||||
+ #pragma unroll
|
||||
+ for (int ks = 0; ks < KC; ks += 8) {
|
||||
+ float af[4], bf[2];
|
||||
+ af[0] = gdn_ld_f32(Kc[(ks + lt ) * DK + (rowbase + lg )]);
|
||||
+ af[1] = gdn_ld_f32(Kc[(ks + lt ) * DK + (rowbase + lg + 8)]);
|
||||
+ af[2] = gdn_ld_f32(Kc[(ks + lt + 4) * DK + (rowbase + lg )]);
|
||||
+ af[3] = gdn_ld_f32(Kc[(ks + lt + 4) * DK + (rowbase + lg + 8)]);
|
||||
+ bf[0] = gdn_ld_f32(DUd[(colbase + lg) * KC + (ks + lt )]);
|
||||
+ bf[1] = gdn_ld_f32(DUd[(colbase + lg) * KC + (ks + lt + 4)]);
|
||||
+ unsigned ahi[4], alo[4], bhi[2], blo[2];
|
||||
+ #pragma unroll
|
||||
+ for (int z = 0; z < 4; z++) { ahi[z] = gdn_f2tf32(af[z]); alo[z] = gdn_f2tf32(af[z] - __uint_as_float(ahi[z])); }
|
||||
+ #pragma unroll
|
||||
+ for (int z = 0; z < 2; z++) { bhi[z] = gdn_f2tf32(bf[z]); blo[z] = gdn_f2tf32(bf[z] - __uint_as_float(bhi[z])); }
|
||||
+ gdn_mma_m16n8k8(c, ahi, bhi);
|
||||
+ gdn_mma_m16n8k8(c, ahi, blo);
|
||||
+ gdn_mma_m16n8k8(c, alo, bhi);
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+template <int S_v, int C, int TC = 0>
|
||||
__global__ void gated_delta_net_chunked_cuda(
|
||||
const float * __restrict__ q, const float * __restrict__ k,
|
||||
const float * __restrict__ v, const float * __restrict__ g,
|
||||
@@ -329,6 +437,9 @@ __global__ void gated_delta_net_chunked_cuda(
|
||||
float * csh = Amat + (size_t) C * C; // [C] cumsum(log-gate)
|
||||
float * gam = csh + C; // [C] gamma_t = exp(cs_t)
|
||||
float * bet = gam + C; // [C] beta_t
|
||||
+ // Phase-1 tensor-core Gram scratch (allocated only when GRAM_MMA; KK feeds A, QK feeds P).
|
||||
+ float * KKsh = bet + C; // [C*C] KK[t][t'] = k_t . k_t' (stride C)
|
||||
+ float * QKsh = KKsh + (size_t) C * C; // [C*C] QK[t][t'] = q_t . k_t' (stride C)
|
||||
|
||||
// S0: thread j owns column j (Sd[j*dk + i]); load is a contiguous per-thread copy from the
|
||||
// M-layout cache view (read_state[j*dk + i] = M[j*S_v + i] = S[i][j]). Same identity/gather
|
||||
@@ -357,6 +468,15 @@ __global__ void gated_delta_net_chunked_cuda(
|
||||
Kc[t * dk + i] = k_base[(c0 + t) * sq2 + i];
|
||||
Qc[t * dk + i] = q_base[(c0 + t) * sq2 + i];
|
||||
}
|
||||
+ if constexpr (TC >= 3) {
|
||||
+ // Zero the stale K/Q tail (rows t >= Cc): the tensor-core mma paths contract the full
|
||||
+ // chunk dim and 0*NaN (uninitialized smem) would poison the result. Serial paths only
|
||||
+ // touch t < Cc, so this is gated to the mma levels.
|
||||
+ for (int e = Cc * dk + j; e < C * dk; e += dv) {
|
||||
+ Kc[e] = 0.0f;
|
||||
+ Qc[e] = 0.0f;
|
||||
+ }
|
||||
+ }
|
||||
if (j < Cc) {
|
||||
csh[j] = g[gb_base + (c0 + j) * sb2]; // raw log-gate, prefix-summed below
|
||||
bet[j] = beta[gb_base + (c0 + j) * sb2];
|
||||
@@ -372,15 +492,53 @@ __global__ void gated_delta_net_chunked_cuda(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
+ // --- Phase-1: tensor-core tf32 Gram products (KK->A via warp0, QK->P via warp1). ---
|
||||
+ // Full C x C tiles into KKsh/QKsh (stride C); decay/beta applied in f32 in the loops below.
|
||||
+ // Tail chunks (Cc<C) compute stale rows >= Cc, but those entries are never read.
|
||||
+ if constexpr (TC >= 1) {
|
||||
+ const int w = threadIdx.x >> 5; // warp: 0 -> KK, 1 -> QK
|
||||
+ const int lane = threadIdx.x & 31;
|
||||
+ const int lg = lane >> 2; // 0..7
|
||||
+ const int lt = lane & 3; // 0..3
|
||||
+ if (w < 2) {
|
||||
+ const float * Xs = (w == 0) ? Kc : Qc;
|
||||
+ float * Out = (w == 0) ? KKsh : QKsh;
|
||||
+ #pragma unroll
|
||||
+ for (int mt = 0; mt < (C + 15) / 16; mt++) {
|
||||
+ const int rowbase = mt * 16;
|
||||
+ #pragma unroll
|
||||
+ for (int nt = 0; nt < (C + 7) / 8; nt++) {
|
||||
+ const int colbase = nt * 8;
|
||||
+ float cc[4];
|
||||
+ gdn_gram_tile_mma<dk>(cc, Xs, Kc, rowbase, colbase, lg, lt);
|
||||
+ const int rr[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8};
|
||||
+ const int ccol[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1};
|
||||
+ #pragma unroll
|
||||
+ for (int l = 0; l < 4; l++) {
|
||||
+ if (rr[l] < C && ccol[l] < C) {
|
||||
+ Out[rr[l] * C + ccol[l]] = cc[l];
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ __syncthreads();
|
||||
+ }
|
||||
+
|
||||
// --- A = I + tril(beta_t * d(t',t) * (k_t . k_t'), -1) (cooperative over C*C) ---
|
||||
for (int e = j; e < Cc * Cc; e += dv) {
|
||||
const int t = e / Cc;
|
||||
const int tp = e % Cc;
|
||||
float a = 0.0f;
|
||||
if (tp < t) {
|
||||
- float kk = 0.0f;
|
||||
- for (int i = 0; i < dk; i++) {
|
||||
- kk += Kc[t * dk + i] * Kc[tp * dk + i];
|
||||
+ float kk;
|
||||
+ if constexpr (TC >= 1) {
|
||||
+ kk = KKsh[t * C + tp];
|
||||
+ } else {
|
||||
+ kk = 0.0f;
|
||||
+ for (int i = 0; i < dk; i++) {
|
||||
+ kk += Kc[t * dk + i] * Kc[tp * dk + i];
|
||||
+ }
|
||||
}
|
||||
const float dd = expf(csh[t] - csh[tp]); // d(tp,t) = gamma_t/gamma_tp
|
||||
a = bet[t] * dd * kk;
|
||||
@@ -392,65 +550,304 @@ __global__ void gated_delta_net_chunked_cuda(
|
||||
__syncthreads();
|
||||
|
||||
// --- RHS[t][j] = beta_t (v_t[j] - gamma_t * (S0^T k_t)[j]) -> Ud[j*C + t] ---
|
||||
- for (int t = 0; t < Cc; t++) {
|
||||
- float ks = 0.0f; // (S0^T k_t)[j] = sum_i S[i][j] k_t[i]
|
||||
- for (int i = 0; i < dk; i++) {
|
||||
- ks += Sd[j * dk + i] * Kc[t * dk + i];
|
||||
+ if constexpr (TC >= 2) {
|
||||
+ // M3: fused tensor-core KS = Kc * S0 (3xtf32 state-boundary product). The mma
|
||||
+ // output is consumed straight from registers into RHS -> Ud, so NO extra C*dv
|
||||
+ // smem buffer is needed (the 64KB state still occupies smem until M6). Warp w
|
||||
+ // owns dv n-tiles [w*NTPW, ..); each lane writes the RHS entries it produced.
|
||||
+ const int w = threadIdx.x >> 5;
|
||||
+ const int lane = threadIdx.x & 31;
|
||||
+ const int lg = lane >> 2;
|
||||
+ const int lt = lane & 3;
|
||||
+ constexpr int NWARP = S_v / 32;
|
||||
+ constexpr int NT = dv / 8;
|
||||
+ constexpr int NTPW = (NT + NWARP - 1) / NWARP;
|
||||
+ #pragma unroll
|
||||
+ for (int mt = 0; mt < (C + 15) / 16; mt++) {
|
||||
+ const int rowbase = mt * 16;
|
||||
+ #pragma unroll
|
||||
+ for (int nn = 0; nn < NTPW; nn++) {
|
||||
+ const int nt = w * NTPW + nn;
|
||||
+ if (nt >= NT) break;
|
||||
+ const int colbase = nt * 8;
|
||||
+ float cc[4];
|
||||
+ gdn_gram_tile_mma_3x<dk>(cc, Kc, Sd, rowbase, colbase, lg, lt);
|
||||
+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8};
|
||||
+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1};
|
||||
+ #pragma unroll
|
||||
+ for (int l = 0; l < 4; l++) {
|
||||
+ const int t = tt[l], jc = jj[l];
|
||||
+ if (t < Cc && jc < dv) {
|
||||
+ const float vtj = v_base[(c0 + t) * sv2 + jc];
|
||||
+ Ud[jc * C + t] = bet[t] * (vtj - gam[t] * cc[l]);
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ __syncthreads(); // RHS written cross-thread -> publish before the per-column solve
|
||||
+ } else {
|
||||
+ for (int t = 0; t < Cc; t++) {
|
||||
+ float ks = 0.0f; // (S0^T k_t)[j] = sum_i S[i][j] k_t[i]
|
||||
+ for (int i = 0; i < dk; i++) {
|
||||
+ ks += Sd[j * dk + i] * Kc[t * dk + i];
|
||||
+ }
|
||||
+ const float vtj = v_base[(c0 + t) * sv2 + j];
|
||||
+ Ud[j * C + t] = bet[t] * (vtj - gam[t] * ks);
|
||||
}
|
||||
- const float vtj = v_base[(c0 + t) * sv2 + j];
|
||||
- Ud[j * C + t] = bet[t] * (vtj - gam[t] * ks);
|
||||
+ }
|
||||
+ if constexpr (TC >= 3) {
|
||||
+ // Zero the stale RHS tail (rows t >= Cc) before the full-K mma consumers (P*U at TC>=3;
|
||||
+ // apply + state at TC>=4). Without this the masked tail terms compute 0*NaN = NaN.
|
||||
+ for (int t = Cc; t < C; t++) Ud[j * C + t] = 0.0f;
|
||||
+ __syncthreads();
|
||||
}
|
||||
|
||||
- // --- solve A U = RHS in place (unit lower-tri fwd subst); per-thread, no inter-step sync ---
|
||||
- for (int t = 1; t < Cc; t++) {
|
||||
- float acc = Ud[j * C + t];
|
||||
- for (int tp = 0; tp < t; tp++) {
|
||||
- acc -= Amat[t * Cc + tp] * Ud[j * C + tp];
|
||||
+ // --- solve A U = RHS (A unit-lower-tri) ---
|
||||
+ if constexpr (TC >= 4) {
|
||||
+ // M5/P7: form T = A^{-1} explicitly (FLA UT transform), then U = T*RHS as one
|
||||
+ // dependency-free tf32 GEMM. At C<=16 A is a single b=16 block, so the off-diagonal
|
||||
+ // Phase-O is empty; only the f32 in-shared diagonal inverse (Phase-D) + the wide
|
||||
+ // apply remain. Phase-D: column-parallel EXACT f32 inverse of the Cc x Cc unit-
|
||||
+ // lower-tri A -- thread c solves A x = e_c, writing column c of T into KKsh (free
|
||||
+ // since KK was consumed into A). This is the strong-coupling amplifier -> f32.
|
||||
+ if (j < C) {
|
||||
+ if (j < Cc) {
|
||||
+ float x[C];
|
||||
+ #pragma unroll
|
||||
+ for (int r = 0; r < C; r++) x[r] = 0.0f;
|
||||
+ x[j] = 1.0f;
|
||||
+ for (int r = j + 1; r < Cc; r++) {
|
||||
+ float acc = 0.0f;
|
||||
+ for (int m = j; m < r; m++) acc += Amat[r * Cc + m] * x[m];
|
||||
+ x[r] = -acc;
|
||||
+ }
|
||||
+ #pragma unroll
|
||||
+ for (int r = 0; r < C; r++) KKsh[r * C + j] = x[r]; // rows >= Cc are 0
|
||||
+ } else {
|
||||
+ #pragma unroll
|
||||
+ for (int r = 0; r < C; r++) KKsh[r * C + j] = 0.0f; // cols >= Cc are 0
|
||||
+ }
|
||||
+ }
|
||||
+ __syncthreads();
|
||||
+ // Apply U = T*RHS, M=C N=dv K=C; T=KKsh (stride C), RHS=Ud (stride C). In place on
|
||||
+ // Ud: hold every output tile in registers, sync to finish the RHS reads, then
|
||||
+ // overwrite Ud with U (avoids the read/write aliasing of a same-buffer GEMM).
|
||||
+ {
|
||||
+ const int w = threadIdx.x >> 5;
|
||||
+ const int lane = threadIdx.x & 31;
|
||||
+ const int lg = lane >> 2;
|
||||
+ const int lt = lane & 3;
|
||||
+ constexpr int NWARP = S_v / 32;
|
||||
+ constexpr int NT = dv / 8;
|
||||
+ constexpr int NTPW = (NT + NWARP - 1) / NWARP;
|
||||
+ float ureg[NTPW][4];
|
||||
+ #pragma unroll
|
||||
+ for (int nn = 0; nn < NTPW; nn++) {
|
||||
+ const int nt = w * NTPW + nn;
|
||||
+ if (nt < NT) gdn_gram_tile_mma<C>(ureg[nn], KKsh, Ud, 0, nt * 8, lg, lt);
|
||||
+ }
|
||||
+ __syncthreads(); // all RHS(Ud) reads done before overwriting with U
|
||||
+ #pragma unroll
|
||||
+ for (int nn = 0; nn < NTPW; nn++) {
|
||||
+ const int nt = w * NTPW + nn;
|
||||
+ if (nt >= NT) continue;
|
||||
+ const int colbase = nt * 8;
|
||||
+ const int tt[4] = {lg, lg, lg + 8, lg + 8};
|
||||
+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1};
|
||||
+ #pragma unroll
|
||||
+ for (int l = 0; l < 4; l++) {
|
||||
+ const int t = tt[l], jc = jj[l];
|
||||
+ if (t < Cc && jc < dv) Ud[jc * C + t] = ureg[nn][l];
|
||||
+ }
|
||||
+ }
|
||||
+ __syncthreads();
|
||||
}
|
||||
- Ud[j * C + t] = acc;
|
||||
+ } else {
|
||||
+ for (int t = 1; t < Cc; t++) {
|
||||
+ float acc = Ud[j * C + t];
|
||||
+ for (int tp = 0; tp < t; tp++) {
|
||||
+ acc -= Amat[t * Cc + tp] * Ud[j * C + tp];
|
||||
+ }
|
||||
+ Ud[j * C + t] = acc;
|
||||
+ }
|
||||
+ __syncthreads(); // U finalized; Amat free for P below
|
||||
}
|
||||
- __syncthreads(); // U finalized; Amat free for P below (and Ud read across-thread? no, own col)
|
||||
|
||||
- // --- P[t][t'] = d(t',t) * (q_t . k_t') for t' <= t (reuse Amat) ---
|
||||
- for (int e = j; e < Cc * Cc; e += dv) {
|
||||
- const int t = e / Cc;
|
||||
- const int tp = e % Cc;
|
||||
- float p = 0.0f;
|
||||
- if (tp <= t) {
|
||||
- float qk = 0.0f;
|
||||
- for (int i = 0; i < dk; i++) {
|
||||
- qk += Qc[t * dk + i] * Kc[tp * dk + i];
|
||||
+ // --- P[t][t'] = d(t',t) * (q_t . k_t') for t' <= t ---
|
||||
+ if constexpr (TC >= 3) {
|
||||
+ // M4: build P (lower-tri, decay pre-baked in f32 -> bounded) IN PLACE in QKsh at
|
||||
+ // fixed stride C so the P*U output mma can read it as a tf32 A-operand. Full C*C
|
||||
+ // grid: upper-tri / out-of-range entries are zeroed so the K=C mma needs no masking.
|
||||
+ for (int e = j; e < C * C; e += dv) {
|
||||
+ const int t = e / C;
|
||||
+ const int tp = e % C;
|
||||
+ float p = 0.0f;
|
||||
+ if (tp <= t && t < Cc && tp < Cc) {
|
||||
+ const float dd = expf(csh[t] - csh[tp]);
|
||||
+ p = dd * QKsh[t * C + tp]; // QKsh holds QK (M2); overwrite in place with P
|
||||
+ }
|
||||
+ QKsh[t * C + tp] = p;
|
||||
+ }
|
||||
+ __syncthreads();
|
||||
+ } else {
|
||||
+ for (int e = j; e < Cc * Cc; e += dv) {
|
||||
+ const int t = e / Cc;
|
||||
+ const int tp = e % Cc;
|
||||
+ float p = 0.0f;
|
||||
+ if (tp <= t) {
|
||||
+ float qk;
|
||||
+ if constexpr (TC >= 1) {
|
||||
+ qk = QKsh[t * C + tp];
|
||||
+ } else {
|
||||
+ qk = 0.0f;
|
||||
+ for (int i = 0; i < dk; i++) {
|
||||
+ qk += Qc[t * dk + i] * Kc[tp * dk + i];
|
||||
+ }
|
||||
+ }
|
||||
+ const float dd = expf(csh[t] - csh[tp]);
|
||||
+ p = dd * qk;
|
||||
}
|
||||
- const float dd = expf(csh[t] - csh[tp]);
|
||||
- p = dd * qk;
|
||||
+ Amat[t * Cc + tp] = p;
|
||||
}
|
||||
- Amat[t * Cc + tp] = p;
|
||||
+ __syncthreads();
|
||||
}
|
||||
- __syncthreads();
|
||||
|
||||
// --- O[t][j] = gamma_t (S0^T q_t)[j] + sum_{t'<=t} P[t][t'] U[t'][j] (* scale) ---
|
||||
- for (int t = 0; t < Cc; t++) {
|
||||
- float qs = 0.0f; // (S0^T q_t)[j] (uses pre-update S)
|
||||
- for (int i = 0; i < dk; i++) {
|
||||
- qs += Sd[j * dk + i] * Qc[t * dk + i];
|
||||
+ if constexpr (TC >= 2) {
|
||||
+ // M3: fused tensor-core QS = Qc * S0 (3xtf32, pre-update S0). Deposit the
|
||||
+ // gamma_t*QS[t][j] cross-chunk term into dst from the mma registers; the O loop
|
||||
+ // below reads it back (published via __syncthreads) and adds the intra-chunk P*U.
|
||||
+ const int w = threadIdx.x >> 5;
|
||||
+ const int lane = threadIdx.x & 31;
|
||||
+ const int lg = lane >> 2;
|
||||
+ const int lt = lane & 3;
|
||||
+ constexpr int NWARP = S_v / 32;
|
||||
+ constexpr int NT = dv / 8;
|
||||
+ constexpr int NTPW = (NT + NWARP - 1) / NWARP;
|
||||
+ #pragma unroll
|
||||
+ for (int mt = 0; mt < (C + 15) / 16; mt++) {
|
||||
+ const int rowbase = mt * 16;
|
||||
+ #pragma unroll
|
||||
+ for (int nn = 0; nn < NTPW; nn++) {
|
||||
+ const int nt = w * NTPW + nn;
|
||||
+ if (nt >= NT) break;
|
||||
+ const int colbase = nt * 8;
|
||||
+ float cc[4];
|
||||
+ gdn_gram_tile_mma_3x<dk>(cc, Qc, Sd, rowbase, colbase, lg, lt);
|
||||
+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8};
|
||||
+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1};
|
||||
+ #pragma unroll
|
||||
+ for (int l = 0; l < 4; l++) {
|
||||
+ const int t = tt[l], jc = jj[l];
|
||||
+ if (t < Cc && jc < dv) {
|
||||
+ attn_base[(c0 + t) * S_v * H + jc] = gam[t] * cc[l];
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
}
|
||||
- float o = gam[t] * qs;
|
||||
- for (int tp = 0; tp <= t; tp++) {
|
||||
- o += Amat[t * Cc + tp] * Ud[j * C + tp];
|
||||
+ __syncthreads();
|
||||
+ }
|
||||
+ if constexpr (TC >= 3) {
|
||||
+ // M4: O += P*U via tensor-core (tf32-safe: P is f32-bounded, decay pre-baked).
|
||||
+ // GEMM O[t][j] += sum_t' P[t][t']*U[t'][j], M=C N=dv K=C; P=QKsh (stride C),
|
||||
+ // U=Ud (stride C). The gamma_t*QS cross-chunk term was deposited into dst above;
|
||||
+ // fold it in here then * scale. Warp w owns dv n-tiles [w*NTPW, ..).
|
||||
+ const int w = threadIdx.x >> 5;
|
||||
+ const int lane = threadIdx.x & 31;
|
||||
+ const int lg = lane >> 2;
|
||||
+ const int lt = lane & 3;
|
||||
+ constexpr int NWARP = S_v / 32;
|
||||
+ constexpr int NT = dv / 8;
|
||||
+ constexpr int NTPW = (NT + NWARP - 1) / NWARP;
|
||||
+ #pragma unroll
|
||||
+ for (int mt = 0; mt < (C + 15) / 16; mt++) {
|
||||
+ const int rowbase = mt * 16;
|
||||
+ #pragma unroll
|
||||
+ for (int nn = 0; nn < NTPW; nn++) {
|
||||
+ const int nt = w * NTPW + nn;
|
||||
+ if (nt >= NT) break;
|
||||
+ const int colbase = nt * 8;
|
||||
+ float cc[4];
|
||||
+ gdn_gram_tile_mma<C>(cc, QKsh, Ud, rowbase, colbase, lg, lt);
|
||||
+ const int tt[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8};
|
||||
+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1};
|
||||
+ #pragma unroll
|
||||
+ for (int l = 0; l < 4; l++) {
|
||||
+ const int t = tt[l], jc = jj[l];
|
||||
+ if (t < Cc && jc < dv) {
|
||||
+ const int64_t oi = (int64_t)(c0 + t) * S_v * H + jc;
|
||||
+ attn_base[oi] = (attn_base[oi] + cc[l]) * scale; // QS term + P*U
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ } else {
|
||||
+ for (int t = 0; t < Cc; t++) {
|
||||
+ float o;
|
||||
+ if constexpr (TC >= 2) {
|
||||
+ o = attn_base[(c0 + t) * S_v * H + j]; // gamma_t*QS[t][j] deposited above
|
||||
+ } else {
|
||||
+ float qs = 0.0f; // (S0^T q_t)[j] (uses pre-update S)
|
||||
+ for (int i = 0; i < dk; i++) {
|
||||
+ qs += Sd[j * dk + i] * Qc[t * dk + i];
|
||||
+ }
|
||||
+ o = gam[t] * qs;
|
||||
+ }
|
||||
+ for (int tp = 0; tp <= t; tp++) {
|
||||
+ o += Amat[t * Cc + tp] * Ud[j * C + tp];
|
||||
+ }
|
||||
+ attn_base[(c0 + t) * S_v * H + j] = o * scale;
|
||||
}
|
||||
- attn_base[(c0 + t) * S_v * H + j] = o * scale;
|
||||
}
|
||||
|
||||
// --- S_C[i][j] = gamma_{C-1} S[i][j] + sum_t d(t,C-1) k_t[i] u_t[j] ---
|
||||
const float glast = gam[Cc - 1];
|
||||
const float cslast = csh[Cc - 1];
|
||||
- for (int i = 0; i < dk; i++) {
|
||||
- float s = glast * Sd[j * dk + i];
|
||||
- for (int t = 0; t < Cc; t++) {
|
||||
- const float dd = expf(cslast - csh[t]); // d(t, last)
|
||||
- s += dd * Kc[t * dk + i] * Ud[j * C + t];
|
||||
+ if constexpr (TC >= 4) {
|
||||
+ // M5/P6: state carry S_C = glast*S0 + Kc^T * DU via 3xtf32 mma. DU[t][j] =
|
||||
+ // d(t,last)*U[t][j] is built IN PLACE in Ud (t>=Cc zeroed so the K=C contraction
|
||||
+ // needs no per-k masking), then S_C accumulates over the chunk dim t. Kc is read
|
||||
+ // transposed (i as M-row). M=dk N=dv K=C. Each output (i,j) has a unique owner so
|
||||
+ // the glast*S0 read-modify-write is race-free.
|
||||
+ for (int t = 0; t < C; t++) {
|
||||
+ const float dd = (t < Cc) ? expf(cslast - csh[t]) : 0.0f;
|
||||
+ Ud[j * C + t] = dd * Ud[j * C + t]; // thread j owns column j -> DU in place
|
||||
+ }
|
||||
+ __syncthreads();
|
||||
+ const int w = threadIdx.x >> 5;
|
||||
+ const int lane = threadIdx.x & 31;
|
||||
+ const int lg = lane >> 2;
|
||||
+ const int lt = lane & 3;
|
||||
+ constexpr int NWARP = S_v / 32;
|
||||
+ constexpr int MT = dk / 16; // m-tiles over dk
|
||||
+ constexpr int NT = dv / 8; // n-tiles over dv
|
||||
+ constexpr int NTILES = MT * NT;
|
||||
+ constexpr int TPW = (NTILES + NWARP - 1) / NWARP;
|
||||
+ #pragma unroll
|
||||
+ for (int idx = 0; idx < TPW; idx++) {
|
||||
+ const int tile = w * TPW + idx;
|
||||
+ if (tile >= NTILES) break;
|
||||
+ const int rowbase = (tile / NT) * 16;
|
||||
+ const int colbase = (tile % NT) * 8;
|
||||
+ float cc[4];
|
||||
+ gdn_state_tile_mma_3x<C, dk>(cc, Kc, Ud, rowbase, colbase, lg, lt);
|
||||
+ const int ii[4] = {rowbase + lg, rowbase + lg, rowbase + lg + 8, rowbase + lg + 8};
|
||||
+ const int jj[4] = {colbase + 2*lt, colbase + 2*lt + 1, colbase + 2*lt, colbase + 2*lt + 1};
|
||||
+ #pragma unroll
|
||||
+ for (int l = 0; l < 4; l++) {
|
||||
+ const int i = ii[l], jc = jj[l];
|
||||
+ Sd[jc * dk + i] = glast * Sd[jc * dk + i] + cc[l];
|
||||
+ }
|
||||
+ }
|
||||
+ } else {
|
||||
+ for (int i = 0; i < dk; i++) {
|
||||
+ float s = glast * Sd[j * dk + i];
|
||||
+ for (int t = 0; t < Cc; t++) {
|
||||
+ const float dd = expf(cslast - csh[t]); // d(t, last)
|
||||
+ s += dd * Kc[t * dk + i] * Ud[j * C + t];
|
||||
+ }
|
||||
+ Sd[j * dk + i] = s;
|
||||
}
|
||||
- Sd[j * dk + i] = s;
|
||||
}
|
||||
__syncthreads(); // Sd reused as S0 of next chunk; Kc/Qc/Amat reloaded next chunk
|
||||
}
|
||||
@@ -464,8 +861,7 @@ __global__ void gated_delta_net_chunked_cuda(
|
||||
st[j * dk + i] = Sd[j * dk + i];
|
||||
}
|
||||
}
|
||||
-
|
||||
-template <int S_v, int C>
|
||||
+template <int S_v, int C, int TC = 0>
|
||||
static void launch_gdn_chunked(
|
||||
const float * q_d, const float * k_d, const float * v_d,
|
||||
const float * g_d, const float * b_d, const float * s_d,
|
||||
@@ -477,10 +873,11 @@ static void launch_gdn_chunked(
|
||||
const uint3 neqk1_magic, const uint3 rq3_magic,
|
||||
float scale, cudaStream_t stream) {
|
||||
const size_t smem = ((size_t) S_v * S_v + (size_t) 2 * C * S_v + (size_t) S_v * C
|
||||
- + (size_t) C * C + (size_t) 3 * C) * sizeof(float);
|
||||
+ + (size_t) C * C + (size_t) 3 * C
|
||||
+ + (TC >= 1 ? (size_t) 2 * C * C : (size_t) 0)) * sizeof(float);
|
||||
static bool attr_set = false;
|
||||
if (!attr_set) {
|
||||
- const cudaError_t e = cudaFuncSetAttribute(gated_delta_net_chunked_cuda<S_v, C>,
|
||||
+ const cudaError_t e = cudaFuncSetAttribute(gated_delta_net_chunked_cuda<S_v, C, TC>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smem);
|
||||
if (e != cudaSuccess) {
|
||||
GGML_ABORT("gdn chunked: cudaFuncSetAttribute(maxDynSmem=%zu) failed: %s\n", smem, cudaGetErrorString(e));
|
||||
@@ -489,7 +886,7 @@ static void launch_gdn_chunked(
|
||||
}
|
||||
dim3 grid_dims(H, n_seqs, 1);
|
||||
dim3 block_dims(S_v, 1, 1);
|
||||
- gated_delta_net_chunked_cuda<S_v, C><<<grid_dims, block_dims, smem, stream>>>(
|
||||
+ gated_delta_net_chunked_cuda<S_v, C, TC><<<grid_dims, block_dims, smem, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs,
|
||||
sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3,
|
||||
neqk1_magic, rq3_magic, scale, state_dst_d, ids_d, rs_head);
|
||||
@@ -519,17 +916,52 @@ static void launch_gated_delta_net(
|
||||
// sequential recurrence. Mathematically equivalent up to FP reduction order (NEW per-path md5;
|
||||
// validated benign by test-backend-ops NMSE + greedy output). Toggle: GDN_CHUNK_OFF / GDN_CHUNK_MIN.
|
||||
if constexpr (!KDA && !keep_rs_t) {
|
||||
- // OPT-IN: this chunked path is bit-exact-benign (test-backend-ops green) but, at C=16
|
||||
- // (forced by GB10 99KB dyn-smem opt-in, all-shared), it is NOT yet faster than the tuned
|
||||
- // sequential recurrence on this model (measured ~22%% slower S_PP, grid-starved at low
|
||||
- // n_seqs + 1 block/SM occupancy). Default OFF so the backend default is regression-free;
|
||||
- // enable for experiments / tuning with GDN_CHUNK_MIN=<token-threshold>. See README section 5 (dev notes / rejected-flat levers).
|
||||
- static const int gdn_chunk_min = []{ const char * e = getenv("GDN_CHUNK_MIN"); return e ? atoi(e) : INT_MAX; }();
|
||||
+ // DEFAULT-ON UNDER PAGED KV (f32-only re-port of patch 0044's M5). The M5 tensor-core path
|
||||
+ // (GDN_TC=5: full-TC form-T solve + state-update mma, state in the 64KB smem buffer, C=16)
|
||||
+ // is greedy-bit-exact (per-path md5 == the sequential canonical on the short gate prompt)
|
||||
+ // and *beats* the tuned sequential recurrence on the Qwen3.6 MoE prefill: GB10,
|
||||
+ // q36-35b-a3b-nvfp4, LLAMA_KV_PAGED=1 LLAMA_MOE_FORCE_GRAPHS=1, -ntg 4 -npl 32:
|
||||
+ // -npp 512 : S_PP +3.5% ; -npp 2048 : S_PP +17.7% (more chunks to parallelize).
|
||||
+ // Decode S_TG is unchanged (1-token calls never reach the engage threshold).
|
||||
+ // GDN_CHUNK_MIN is the per-call engage threshold and MUST stay > 1: decode is 1 token/call,
|
||||
+ // so any threshold above 1 leaves every decode step on the sequential recurrence (at
|
||||
+ // GDN_CHUNK_MIN=1 the chunked path swallows decode and collapses S_TG by ~25%). Tuned to 64:
|
||||
+ // above decode/tiny-call sizes, below the real MoE-prefill per-call count. OFF (INT_MAX) when
|
||||
+ // not paged, so the stock / non-paged default is regression-free. Both knobs env-overridable.
|
||||
+ static const bool kv_paged = (getenv("LLAMA_KV_PAGED") != nullptr);
|
||||
+ static const int gdn_chunk_min = []{
|
||||
+ const char * e = getenv("GDN_CHUNK_MIN");
|
||||
+ if (e) return atoi(e);
|
||||
+ return kv_paged ? 64 : INT_MAX;
|
||||
+ }();
|
||||
+ // Tensor-core level selector (single build, clean runtime A/B). GDN_TC:
|
||||
+ // 0 = serial scan (patch 0031); 1 = KK/QK Gram mma (M2);
|
||||
+ // 2 = + KS/QS state-boundary mma, 3xtf32 (M3); 3 = + P*U output mma (M4);
|
||||
+ // 4/5+ = M5 (full TC: form-T solve + state-update mma) - the DEFAULT under paged KV.
|
||||
+ // (The bf16 CONFIG-C and register-resident M6/M7/M8 occupancy variants of patch 0044 are
|
||||
+ // intentionally absent from this f32-only series; the +3.5/+17.7% prefill win is the M5 path.)
|
||||
+ // GDN_GRAM_MMA=1 is kept as an alias for level 1.
|
||||
+ static const int gdn_tc = []{
|
||||
+ const char * e = getenv("GDN_TC");
|
||||
+ if (e) return atoi(e);
|
||||
+ const char * gm = getenv("GDN_GRAM_MMA");
|
||||
+ if (gm && atoi(gm) != 0) return 1;
|
||||
+ return kv_paged ? 5 : 0;
|
||||
+ }();
|
||||
if (S_v == 128 && n_tokens >= gdn_chunk_min) {
|
||||
- launch_gdn_chunked<128, 16>(
|
||||
- q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head,
|
||||
- H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3,
|
||||
- neqk1_magic, rq3_magic, scale, stream);
|
||||
+#define GDN_CHUNKED_LAUNCH(TC_) \
|
||||
+ launch_gdn_chunked<128, 16, TC_>( \
|
||||
+ q_d, k_d, v_d, g_d, b_d, (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, \
|
||||
+ H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, \
|
||||
+ neqk1_magic, rq3_magic, scale, stream)
|
||||
+ switch (gdn_tc) {
|
||||
+ case 0: GDN_CHUNKED_LAUNCH(0); break;
|
||||
+ case 1: GDN_CHUNKED_LAUNCH(1); break;
|
||||
+ case 2: GDN_CHUNKED_LAUNCH(2); break;
|
||||
+ case 3: GDN_CHUNKED_LAUNCH(3); break;
|
||||
+ default: GDN_CHUNKED_LAUNCH(4); break; // GDN_TC >= 4 -> M5 (full TC, kernel TC=4)
|
||||
+ }
|
||||
+#undef GDN_CHUNKED_LAUNCH
|
||||
return;
|
||||
}
|
||||
}
|
||||
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
|
||||
index 4e40d2353..817069860 100644
|
||||
--- a/tests/test-backend-ops.cpp
|
||||
+++ b/tests/test-backend-ops.cpp
|
||||
@@ -9372,6 +9372,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));
|
||||
+ // Tensor-core chunked-GDN prefill path (S_v==128): multi-chunk (C=16) coverage,
|
||||
+ // incl. a tail chunk (100 = 6*16+4) and multi-seq. Exercised via GDN_CHUNK_MIN + GDN_TC.
|
||||
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1));
|
||||
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 100, 1));
|
||||
+ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 128, 2));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true));
|
||||
--
|
||||
2.43.0
|
||||
|
||||
Reference in New Issue
Block a user