feat(paged): qwen35 SSM decode fused recurrent-state gather (patch 0019)

Mirror of the llama-paged-dev patch 0019 engine change plus the measured
results. Step 2 of the SSM decode work: after Step 1 (in-place state write-back,
patch 0018) the largest non-GEMM decode bucket was the recurrent-state get_rows
gather (18.8 percent of decode GPU time). This removes that materialization,
mirroring ggml_ssm_scan's ids source: ggml_gated_delta_net_inplace_ids reads each
sequence's prior state directly from cache[ids[seq]] (src[5] = full cache,
src[7] = ids), so combined with Step 1's in-place write the op reads AND writes
the cache directly with no state materialization at all.

Race-free by construction: identity sequences (ids[seq] == rs_head + seq, the
whole AR decode path) read s0 in place from the destination slot; non-identity
sequences (reorder / rs_zero, e.g. multi-new-seq prefill) read from a disjoint
scratch a small gather kernel populates first. ids stays a device pointer.
Bit-identical to the get_rows path. Gated to qwen35 + qwen35moe; qwen3next,
kimi-linear, the non-fused and rollback paths are unchanged.

Measured (decode_agg S_TG, npp128 ntg128, -fa on, paged on, fusion off):
  q36-27b-nvfp4 dense: npl32 137.64 -> 170.68 (+24.0 percent),
    npl128 186.25 -> 256.57 (+37.8 percent, 47.6 -> 65.6 percent of vLLM 391).
  q36-35b-a3b-nvfp4 MoE: npl32 299.68 -> 366.69 (+22.4 percent),
    npl128 409.30 -> 553.63 (+35.3 percent).
Greedy (--temp 0 --seed 1) llama-completion bit-identical vs the Step-1 build
(dense + MoE). nsys k_get_rows_float bucket 18.8 -> 0.7 percent. The residual
decode gap to vLLM is now the FP4 GEMM (~48 percent of decode). See
SSM_DECODE_FIX_RESULTS.md.

Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-06-24 23:47:51 +00:00
parent 5ce2f1df51
commit 6f0792c3be
2 changed files with 764 additions and 0 deletions

View File

@@ -0,0 +1,678 @@
From 46d7dd80bbce7f3c1dbf9363d6527c8c9b687a6b Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto <mudler@localai.io>
Date: Thu, 25 Jun 2026 01:45:02 +0200
Subject: [PATCH] feat(paged): qwen35 SSM decode fused recurrent-state gather
(patch 0019)
Step 2 of the SSM decode-throughput work. After Step 1 (in-place state
write-back, patch 0018) the largest non-GEMM decode bucket was the recurrent-
state get_rows gather (18.8% of decode GPU time): build_rs materialized each
sequence's prior state into a contiguous scratch via ggml_get_rows before the
gated-DeltaNet op read it.
This eliminates that materialization, mirroring ggml_ssm_scan's ids source.
ggml_gated_delta_net_inplace_ids takes the FULL recurrent-state cache plus the
s_copy ids (src[5] = full cache, src[7] = ids, op_param[1] = rs_head) and reads
each sequence's prior state directly from cache[ids[seq]]. Combined with Step 1's
in-place write the op now reads AND writes the cache directly: no recurrent-state
materialization at all. build_recurrent_attn feeds the full cache + ids through
the build_rs get_state_rows lambda exactly like mamba-base, keeping the rs_zero
clear and the extra-states copy around the op.
Race-free by construction on CUDA. In-place write plus an ids read of the same
cache is only safe when read slot == write slot; s_copy is identity
(rs_head + s) for stable continuing sequences (the whole AR decode path) but can
remap on reorder or rs_zero (e.g. multiple new sequences in one prefill ubatch).
The recurrence kernel handles both per (seq, head) block on device: identity
sequences read s0 in place from the destination slot (the kernel loads all of s0
into registers before writing, so reading and writing the same slot is safe),
and non-identity sequences read from a disjoint scratch that a small gather
kernel copies from cache[ids[seq]] first, so the recurrence never reads a slot
another block writes. The CPU op mirrors this (host identity check + a serial
gather in the dispatcher). ids stays a device pointer (read only in-kernel; it is
device-resident at op-execute time). Bit-identical to the get_rows path in every
case.
- new builder ggml_gated_delta_net_inplace_ids; CUDA gather kernel
(gdn_gather_nonident) + per-block read-base select in gated_delta_net_cuda;
CPU identity guard + serial gather fallback in the dispatcher
- delta-net-base build_recurrent_attn gains a gather-free overload; qwen35 and
qwen35moe drop the pre-gather. qwen3next, kimi-linear, the non-fused path and
the rollback (n_rs_seq > 0) path are unchanged.
Measured (decode_agg S_TG, npp128 ntg128, -fa on, paged on, fusion off):
dense q36-27b-nvfp4 : npl 32 137.64 -> 170.68 (+24.0 percent)
npl 128 186.25 -> 256.57 (+37.8 percent, 47.6 -> 65.6 percent of vLLM 391)
MoE q36-35b-a3b-nvfp4: npl 32 299.68 -> 366.69 (+22.4 percent)
npl 128 409.30 -> 553.63 (+35.3 percent)
Greedy (--temp 0 --seed 1) llama-completion bit-identical vs the Step-1 build
(dense model text md5 match, MoE byte-identical, step2 run1 == run2). nsys
k_get_rows_float bucket 18.8 -> 0.7 percent; the new gdn_gather_nonident kernel
is 1.7 percent (no-op at decode, median 1.2 us). The residual decode gap to vLLM
is now the FP4 GEMM (~48 percent of decode), a separate kernel track.
Assisted-by: Claude:opus-4.8 [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
---
SSM_DECODE_FIX_RESULTS.md | 86 +++++++++++++++++++++++++++
ggml/include/ggml.h | 17 ++++++
ggml/src/ggml-cpu/ops.cpp | 49 ++++++++++++++-
ggml/src/ggml-cuda/gated_delta_net.cu | 85 ++++++++++++++++++++++----
ggml/src/ggml.c | 76 +++++++++++++++++++++++
src/models/delta-net-base.cpp | 63 ++++++++++++++++++++
src/models/models.h | 13 ++++
src/models/qwen35.cpp | 6 +-
src/models/qwen35moe.cpp | 6 +-
9 files changed, 378 insertions(+), 23 deletions(-)
diff --git a/SSM_DECODE_FIX_RESULTS.md b/SSM_DECODE_FIX_RESULTS.md
index 2e7c8c2..77879e4 100644
--- a/SSM_DECODE_FIX_RESULTS.md
+++ b/SSM_DECODE_FIX_RESULTS.md
@@ -96,3 +96,89 @@ precedent (`ssm_scan` `ids`) and is the clear next move. The residual gap to vLL
after both SSM steps is the FP4 GEMM (~37% of decode), which is a separate kernel
track. No paged/graph/block-table change can move decode on this model (full
attention is 0.4% of decode).
+
+## STEP 2 (patch 0019): fuse the recurrent-state gather into the op
+
+After Step 1 the largest non-GEMM decode bucket was the recurrent-state
+`get_rows` gather (18.8% of decode GPU time): `build_rs` materialized each
+sequence's prior state into a contiguous scratch via `ggml_get_rows` before the
+gated-DeltaNet op read it. Step 2 eliminates that materialization, mirroring
+`ggml_ssm_scan`'s `ids` source.
+
+`ggml_gated_delta_net_inplace_ids` takes the FULL recurrent-state cache plus the
+`s_copy` ids (`src[5]` = full cache `[S_v, S_v, H, n_rs_slots]`, `src[7]` = ids,
+`op_param[1]` = `rs_head`) and reads each sequence's prior state directly from
+`cache[ids[seq]]`. Combined with Step 1's in-place write the op now reads AND
+writes the cache directly: no recurrent-state materialization at all. The
+`build_recurrent_attn` fused path feeds the full cache and ids through the
+`build_rs` `get_state_rows` lambda exactly like `mamba-base.cpp`, keeping the
+`rs_zero` clear and the extra-states copy around the op.
+
+### Race-free by construction (CUDA)
+
+In-place write plus an ids read of the same cache is only safe when the read slot
+equals the write slot. `s_copy(s) = cells[s + head].src0`, which is identity
+(`rs_head + s`) for stable continuing sequences (the entire AR decode path) but
+can remap on sequence reorder or `rs_zero` (e.g. multiple new sequences in one
+prefill ubatch). The kernel handles both per (seq, head) block on device:
+
+- identity sequences read `s0` in place from the destination slot `state_dst`
+ (the kernel loads all of `s0` into registers before it writes the new state,
+ so reading and writing the same slot is race-free) -- no materialization;
+- non-identity sequences read from a disjoint scratch that a small
+ `gdn_gather_nonident_kernel` copies from `cache[ids[seq]]` first, so the
+ recurrence never reads a slot another block writes.
+
+`ids` stays a device pointer (dereferenced only in the kernels; the input is
+device-resident at op-execute time, so a host read segfaults). The CPU op
+mirrors the same logic (host identity check + a serial gather in the dispatcher
+for the non-identity case). The math is unchanged, so the result is bit-identical
+to the `get_rows` path in every case.
+
+Gated to the `qwen35` / `qwen35moe` fused decode/prefill path; `qwen3next`,
+`kimi-linear`, the non-fused path and the rollback (`n_rs_seq > 0`) path are
+untouched (they keep the materialized-state overload).
+
+### Measured decode_agg (`S_TG` t/s, npp 128, ntg 128, -fa on, paged on, fusion off)
+
+Dense `q36-27b-nvfp4`:
+
+| npl | Step 1 (baseline) | Step 2 | delta | % of vLLM (391 @128) |
+|-----|-------------------|----------|---------|----------------------|
+| 32 | 137.64 | 170.68 | +24.0% | - |
+| 128 | 186.25 | 256.57 | +37.8% | 47.6% -> 65.6% |
+
+The npl-128 result (256.57 t/s) beats the predicted ~247 t/s Step-2 ceiling.
+
+MoE `q36-35b-a3b-nvfp4`:
+
+| npl | Step 1 (baseline) | Step 2 | delta |
+|-----|-------------------|----------|---------|
+| 32 | 299.68 | 366.69 | +22.4% |
+| 128 | 409.30 | 553.63 | +35.3% |
+
+(Step-1 baselines re-measured in the same session; the brief's reference figures
+were 136 / 180 dense and 279 / 373 MoE.)
+
+### Bit-exact gate
+
+Greedy (`--temp 0 --seed 1`) `llama-completion` output (256 tokens, paged on,
+fusion off) vs the Step-1 build:
+
+- dense `q36-27b-nvfp4`: model text byte-identical (md5 match);
+- MoE `q36-35b-a3b-nvfp4`: byte-identical;
+- Step-2 dense run1 == run2 (deterministic, no race).
+
+### nsys confirmation (npp 128, ntg 24, npl 128, fusion off, eager)
+
+The recurrent-state gather bucket collapsed:
+
+| kernel | Step 1 | Step 2 |
+|----------------------------|----------|-----------------------------------------|
+| `k_get_rows_float` | 18.8% | 0.7% (residual: embeddings / conv-state)|
+| `gdn_gather_nonident` | - | 1.7% (no-op at decode, median ~1.2 us) |
+| `gated_delta_net_cuda` | 26.0% | 22.5% |
+| FP4 GEMM family | ~37.5% | ~48% (now the dominant residual) |
+
+The SSM state gather is effectively eliminated. The residual decode gap to vLLM
+is now the FP4 GEMM (~48% of decode), a separate kernel track.
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 4e7ab32..951dd21 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -2593,6 +2593,23 @@ extern "C" {
struct ggml_tensor * state,
struct ggml_tensor * state_dst);
+ // Step 2: same recurrence as ggml_gated_delta_net_inplace, but the prior recurrent state is read
+ // directly from the full state cache via per-sequence indices (ids == s_copy), mirroring
+ // ggml_ssm_scan, instead of from a materialized ggml_get_rows gather. `state` is the FULL cache
+ // [S_v, S_v, H, n_rs_slots]; `ids` are the per-seq source slots; `rs_head` is the destination
+ // base slot. Eliminates the recurrent-state gather on the decode path.
+ GGML_API struct ggml_tensor * ggml_gated_delta_net_inplace_ids(
+ struct ggml_context * ctx,
+ struct ggml_tensor * q,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * g,
+ struct ggml_tensor * beta,
+ struct ggml_tensor * state,
+ struct ggml_tensor * state_dst,
+ struct ggml_tensor * ids,
+ int rs_head);
+
// custom operators
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 9457add..b6a1976 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -10633,7 +10633,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const int64_t K = ggml_get_op_params_i32(dst, 0);
GGML_ASSERT(K >= 1);
// per-seq stride in floats (seq s starts at state + s * seq_stride)
- const int64_t state_seq_stride = src_state->nb[3] / sizeof(float);
+ int64_t state_seq_stride = src_state->nb[3] / sizeof(float);
const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0);
const int ith = params->ith;
@@ -10654,6 +10654,26 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const float * state_in_base = (const float *)src_state->data;
+ // Step 2: fused recurrent-state gather (ids == s_copy in src[7]). Read the prior state directly
+ // from the full cache at cache[ids[seq]] instead of from a materialized gather. For the identity
+ // decode case the prior state is the in-place destination block [rs_head, rs_head+n_seqs);
+ // otherwise the dispatcher has gathered cache[ids[seq]] into the (unused) output-state scratch
+ // region. Bit-identical to the get_rows path.
+ ggml_tensor * src_ids = dst->src[7];
+ if (src_ids != nullptr) {
+ const int64_t D = S_v * S_v * H;
+ const int32_t rs_head = ggml_get_op_params_i32(dst, 1);
+ const int32_t * ids = (const int32_t *) src_ids->data;
+ bool identity = true;
+ for (int64_t s = 0; s < n_seqs; ++s) {
+ if (ids[s] != rs_head + (int32_t) s) { identity = false; break; }
+ }
+ state_seq_stride = D;
+ state_in_base = identity
+ ? (const float *) src_state->data + (int64_t) rs_head * D
+ : (const float *) state_out_base; // gathered by the dispatcher (non-identity)
+ }
+
//const int64_t rq1 = nev1 / neq1;
//const int64_t rk1 = nev1 / nek1;
const int64_t rq3 = nev3 / neq3;
@@ -10777,6 +10797,33 @@ static void ggml_compute_forward_gated_delta_net_f32(
if (ith == 0) {
ggml_threadpool_chunk_set(params->threadpool, nth);
+
+ // Step 2: non-identity ids fallback -- serially gather each sequence's prior state from
+ // cache[ids[seq]] into the (otherwise unused) output-state scratch region before the parallel
+ // recurrence, so the in-place write never aliases another sequence's read.
+ ggml_tensor * src_ids = dst->src[7];
+ if (src_ids != nullptr) {
+ const ggml_tensor * src_state = dst->src[5];
+ const int64_t S_v = V->ne[0];
+ const int64_t H = V->ne[1];
+ const int64_t n_tokens = V->ne[2];
+ const int64_t n_seqs = V->ne[3];
+ const int64_t D = S_v * S_v * H;
+ const int32_t rs_head = ggml_get_op_params_i32(dst, 1);
+ const int32_t * ids = (const int32_t *) src_ids->data;
+ bool identity = true;
+ for (int64_t s = 0; s < n_seqs; ++s) {
+ if (ids[s] != rs_head + (int32_t) s) { identity = false; break; }
+ }
+ if (!identity) {
+ const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
+ const float * cache = (const float *) src_state->data;
+ float * scratch = (float *) dst->data + attn_score_elems;
+ for (int64_t s = 0; s < n_seqs; ++s) {
+ memcpy(scratch + s * D, cache + (int64_t) ids[s] * D, D * sizeof(float));
+ }
+ }
+ }
}
ggml_barrier(params->threadpool);
diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu
index 61a2b91..86d5e2a 100644
--- a/ggml/src/ggml-cuda/gated_delta_net.cu
+++ b/ggml/src/ggml-cuda/gated_delta_net.cu
@@ -1,6 +1,34 @@
#include "gated_delta_net.cuh"
#include "ggml-cuda/common.cuh"
+// Step 2: gather only the NON-identity sequences' prior recurrent state from the full cache into a
+// disjoint scratch buffer. Identity sequences (ids[s] == rs_head + s) are read in place from the
+// destination slot by the recurrence kernel and are skipped here. One block per sequence.
+__global__ void gdn_gather_nonident_kernel(const float * cache, const int32_t * ids, int rs_head,
+ float * scratch, int64_t D, int n_seqs) {
+ const int s = blockIdx.x;
+ if (s >= n_seqs) {
+ return;
+ }
+ const int r = ids[s];
+ if (r == rs_head + s) {
+ return; // identity: prior state already lives in the in-place destination slot
+ }
+ const float * src = cache + (int64_t) r * D;
+ float * dst = scratch + (int64_t) s * D;
+ for (int64_t i = threadIdx.x; i < D; i += blockDim.x) {
+ dst[i] = src[i];
+ }
+}
+
+static void ggml_cuda_gdn_gather_nonident(const float * cache, const int32_t * ids, int rs_head,
+ float * scratch, int64_t D, int64_t n_seqs, cudaStream_t stream) {
+ if (n_seqs <= 0) {
+ return;
+ }
+ gdn_gather_nonident_kernel<<<(unsigned) n_seqs, 256, 0, stream>>>(cache, ids, rs_head, scratch, D, (int) n_seqs);
+}
+
template <int S_v, bool KDA, bool keep_rs_t>
__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
gated_delta_net_cuda(const float * q,
@@ -26,7 +54,9 @@ gated_delta_net_cuda(const float * q,
const uint3 rq3_magic,
float scale,
int K,
- float * state_dst) {
+ float * state_dst,
+ const int32_t * ids,
+ int rs_head) {
const uint32_t h_idx = blockIdx.x;
const uint32_t sequence = blockIdx.y;
// each warp owns one column, using warp-level primitives to reduce across rows
@@ -48,7 +78,15 @@ gated_delta_net_cuda(const float * q,
const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v;
const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v;
state += state_out_offset;
- curr_state += state_in_offset + col * S_v;
+ // Step 2: select the prior-state read base per sequence. For the ids variant, identity
+ // sequences (ids[seq] == rs_head + seq) read s0 directly from the in-place destination slot
+ // state_dst (no materialization); non-identity sequences read from the pre-gathered scratch
+ // (curr_state). state_in_offset == state_out_offset, so both bases use the same per-(seq,head)
+ // offset. The whole s0 is loaded into registers before the new state is written, so reading and
+ // writing the same slot per block (identity) is race-free.
+ const float * read_state = (ids != nullptr && ids[sequence] == rs_head + (int) sequence)
+ ? state_dst : curr_state;
+ read_state += state_in_offset + col * S_v;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
@@ -61,7 +99,7 @@ gated_delta_net_cuda(const float * q,
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
- s_shard[r] = curr_state[i];
+ s_shard[r] = read_state[i];
}
for (int t = 0; t < n_tokens; t++) {
@@ -176,6 +214,7 @@ static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
const float * g_d, const float * b_d, const float * s_d,
float * dst_d, float * state_dst_d,
+ const int32_t * ids_d, int rs_head,
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
int64_t sq1, int64_t sq2, int64_t sq3,
int64_t sv1, int64_t sv2, int64_t sv3,
@@ -199,26 +238,26 @@ static void launch_gated_delta_net(
ggml_cuda_kernel_launch(gated_delta_net_cuda<16, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K, state_dst_d);
+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K, state_dst_d, ids_d, rs_head);
break;
case 32:
ggml_cuda_kernel_launch(gated_delta_net_cuda<32, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K, state_dst_d);
+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K, state_dst_d, ids_d, rs_head);
break;
case 64: {
ggml_cuda_kernel_launch(gated_delta_net_cuda<64, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K, state_dst_d);
+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K, state_dst_d, ids_d, rs_head);
break;
}
case 128: {
ggml_cuda_kernel_launch(gated_delta_net_cuda<128, KDA, keep_rs_t>, launch_params,
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
- sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K, state_dst_d);
+ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K, state_dst_d, ids_d, rs_head);
break;
}
default:
@@ -262,7 +301,6 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
const float * g_d = (const float *) src_g->data;
const float * b_d = (const float *) src_beta->data;
- const float * s_d = (const float *) src_state->data;
float * dst_d = (float *) dst->data;
float * state_dst_d = nullptr;
@@ -274,6 +312,29 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
state_dst_d = (float *) src_state_dst->data;
}
+ // Step 2: fused recurrent-state gather (src[7] = ids == s_copy). Read the prior state directly
+ // from the full cache via ids instead of from a materialized ggml_get_rows gather. The recurrence
+ // kernel reads identity sequences (ids[seq] == rs_head + seq) in place from state_dst (no
+ // materialization at all); any non-identity sequence (reorder / rs_zero remap) is gathered here
+ // into a disjoint scratch that the kernel reads instead. The gather writes a disjoint buffer and
+ // the recurrence never reads a slot another block writes, so it is race-free and bit-identical to
+ // the get_rows path. ids stays a DEVICE pointer (dereferenced only inside the kernels).
+ ggml_tensor * src_ids = dst->src[7];
+ const float * s_d = (const float *) src_state->data;
+ const int32_t * ids_d = nullptr;
+ int rs_head = 0;
+ ggml_cuda_pool_alloc<float> ids_state_scratch(ctx.pool());
+ if (src_ids != nullptr) {
+ GGML_ASSERT(state_dst_d != nullptr);
+ GGML_ASSERT(src_ids->type == GGML_TYPE_I32);
+ rs_head = ggml_get_op_params_i32(dst, 1);
+ ids_d = (const int32_t *) src_ids->data;
+ const int64_t D = S_v * S_v * H;
+ float * scratch = ids_state_scratch.alloc((size_t) D * n_seqs);
+ ggml_cuda_gdn_gather_nonident(s_d, ids_d, rs_head, scratch, D, n_seqs, ctx.stream());
+ s_d = scratch;
+ }
+
GGML_ASSERT(ggml_is_contiguous_rows(src_q));
GGML_ASSERT(ggml_is_contiguous_rows(src_k));
GGML_ASSERT(ggml_is_contiguous_rows(src_v));
@@ -307,21 +368,21 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
if (kda) {
if (keep_rs) {
- launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d,
+ launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d, ids_d, rs_head,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
} else {
- launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d,
+ launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d, ids_d, rs_head,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
}
} else {
if (keep_rs) {
- launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d,
+ launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d, ids_d, rs_head,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
} else {
- launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d,
+ launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d, ids_d, rs_head,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
}
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index b8d34bf..1762037 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -6353,6 +6353,82 @@ struct ggml_tensor * ggml_gated_delta_net_inplace(
return result;
}
+// ggml_gated_delta_net_inplace_ids
+//
+// Same recurrence as ggml_gated_delta_net_inplace, but the prior recurrent state is read directly
+// from the FULL state cache `state` ([S_v, S_v, H, n_rs_slots]) at cache[ids[seq]] (mirroring
+// ggml_ssm_scan's ids source) instead of from a materialized ggml_get_rows gather. `rs_head` is the
+// destination base slot, used by the backend to detect the common identity case (ids[s] == rs_head
+// + s), where the prior state already lives in the in-place destination slots.
+struct ggml_tensor * ggml_gated_delta_net_inplace_ids(
+ struct ggml_context * ctx,
+ struct ggml_tensor * q,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * g,
+ struct ggml_tensor * beta,
+ struct ggml_tensor * state,
+ struct ggml_tensor * state_dst,
+ struct ggml_tensor * ids,
+ int rs_head) {
+ GGML_ASSERT(ggml_is_contiguous_rows(q));
+ GGML_ASSERT(ggml_is_contiguous_rows(k));
+ GGML_ASSERT(ggml_is_contiguous_rows(v));
+ GGML_ASSERT(ggml_is_contiguous(g));
+ GGML_ASSERT(ggml_is_contiguous(beta));
+ GGML_ASSERT(ggml_is_contiguous(state));
+
+ GGML_ASSERT(q->type == GGML_TYPE_F32);
+ GGML_ASSERT(k->type == GGML_TYPE_F32);
+ GGML_ASSERT(v->type == GGML_TYPE_F32);
+ GGML_ASSERT(g->type == GGML_TYPE_F32);
+ GGML_ASSERT(beta->type == GGML_TYPE_F32);
+ GGML_ASSERT(state->type == GGML_TYPE_F32);
+ GGML_ASSERT(state_dst != NULL && state_dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(ids != NULL && ids->type == GGML_TYPE_I32);
+
+ const int64_t S_v = v->ne[0];
+ const int64_t H = v->ne[1];
+ const int64_t n_tokens = v->ne[2];
+ const int64_t n_seqs = v->ne[3];
+
+ GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
+ GGML_ASSERT(beta->ne[0] == 1);
+
+ // state is the FULL recurrent-state cache: [S_v, S_v, H, n_rs_slots], n_rs_slots >= n_seqs
+ GGML_ASSERT(state->ne[0] == S_v);
+ GGML_ASSERT(state->ne[1] == S_v);
+ GGML_ASSERT(state->ne[2] == H);
+ GGML_ASSERT(state->ne[3] >= n_seqs);
+
+ // state_dst holds the per-seq final state contiguously: [S_v*S_v*H, >= n_seqs]
+ GGML_ASSERT(state_dst->ne[0] == S_v * S_v * H);
+ GGML_ASSERT(state_dst->ne[1] >= n_seqs);
+ GGML_ASSERT(state_dst->nb[0] == sizeof(float));
+
+ // ids: per-seq source slot into the full cache (s_copy_main)
+ GGML_ASSERT(ids->ne[0] >= n_seqs);
+
+ const int64_t state_rows = S_v * n_seqs; // K == 1
+ const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ ggml_set_op_params_i32(result, 0, 1); // K == 1
+ ggml_set_op_params_i32(result, 1, rs_head); // destination base slot (for the ids identity check)
+
+ result->op = GGML_OP_GATED_DELTA_NET;
+ result->src[0] = q;
+ result->src[1] = k;
+ result->src[2] = v;
+ result->src[3] = g;
+ result->src[4] = beta;
+ result->src[5] = state; // FULL cache (read via ids)
+ result->src[6] = state_dst; // in-place final-state write-back target
+ result->src[7] = ids; // per-seq source slots (s_copy)
+
+ return result;
+}
+
////////////////////////////////////////////////////////////////////////////////
struct ggml_hash_set ggml_hash_set_new(size_t size) {
diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp
index 26a718b..194e611 100644
--- a/src/models/delta-net-base.cpp
+++ b/src/models/delta-net-base.cpp
@@ -524,6 +524,69 @@ ggml_tensor * llm_build_delta_net_base::build_conv_state(
return conv_input;
}
+// Step 2: gather-free recurrent attention. Mirrors mamba-base's get_ssm_rows pattern: the fused
+// gated-DeltaNet op reads each sequence's prior state directly from the full cache via the s_copy
+// ids (no ggml_get_rows materialization) and writes the new state in place (Step 1). The non-fused
+// and rollback paths fall back to materializing the prior state and delegating below.
+ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
+ llm_graph_input_rs * inp,
+ ggml_tensor * ssm_states_all,
+ ggml_tensor * q,
+ ggml_tensor * k,
+ ggml_tensor * v,
+ ggml_tensor * g,
+ ggml_tensor * b,
+ int il) {
+ const auto * mctx_cur = inp->mctx;
+ const auto kv_head = mctx_cur->get_head();
+
+ const int64_t S_v = v->ne[0];
+ const int64_t H_v = v->ne[1];
+ const int64_t n_seqs = v->ne[3];
+ const int64_t n_seq_tokens = q->ne[2];
+
+ const bool keep = cparams.n_rs_seq > 0;
+ const bool fused = (n_seq_tokens == 1) ? cparams.fused_gdn_ar : cparams.fused_gdn_ch;
+
+ if (!keep && fused) {
+ // build_rs feeds the FULL state cache + the s_copy ids into the op (via the get_state_rows
+ // lambda, exactly like mamba-base's ggml_ssm_scan) and still performs the rs_zero clear and
+ // the extra-states copy around it. The op reads curr_state from cache[ids[seq]] and writes
+ // the final state in place at kv_head; no recurrent-state materialization at all.
+ auto get_state_op = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) -> ggml_tensor * {
+ ggml_tensor * cache4d = ggml_reshape_4d(ctx, states, S_v, S_v, H_v, states->ne[1]);
+ ggml_tensor * state_dst = ggml_view_2d(ctx, ssm_states_all, hparams.n_embd_s(), n_seqs,
+ ssm_states_all->nb[1], kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all));
+ return ggml_gated_delta_net_inplace_ids(ctx, q, k, v, g, b, cache4d, state_dst, ids, (int) kv_head);
+ };
+
+ ggml_tensor * result = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs, get_state_op);
+ if (n_seq_tokens == 1) {
+ cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il);
+ } else {
+ cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il);
+ }
+
+ ggml_tensor * output = ggml_view_4d(ctx0, result,
+ S_v, H_v, n_seq_tokens, n_seqs,
+ ggml_row_size(result->type, S_v),
+ ggml_row_size(result->type, S_v * H_v),
+ ggml_row_size(result->type, S_v * H_v * n_seq_tokens), 0);
+ cb(output, "attn_output", il);
+
+ // the state write is a side effect of the op; pull the op into the graph via the output
+ ggml_build_forward_expand(gf, output);
+
+ return output;
+ }
+
+ // non-fused / rollback: materialize the prior state via gather and delegate to the
+ // state-taking overload (its fused !keep branch performs the Step-1 in-place write).
+ ggml_tensor * s = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
+ s = ggml_reshape_4d(ctx0, s, S_v, S_v, H_v, n_seqs);
+ return build_recurrent_attn(inp, ssm_states_all, q, k, v, g, b, s, il);
+}
+
ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
llm_graph_input_rs * inp,
ggml_tensor * ssm_states_all,
diff --git a/src/models/models.h b/src/models/models.h
index 2ac8415..98b89e9 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -88,6 +88,19 @@ struct llm_build_delta_net_base : public llm_graph_context {
ggml_tensor * b,
ggml_tensor * s,
int il);
+
+ // Step 2: gather-free variant. Reads the prior recurrent state directly from the full cache via
+ // the s_copy ids (no ggml_get_rows materialization) on the fused decode/prefill path, and
+ // delegates to the state-taking overload for the non-fused and rollback paths.
+ ggml_tensor * build_recurrent_attn(
+ llm_graph_input_rs * inp,
+ ggml_tensor * ssm_states_all,
+ ggml_tensor * q,
+ ggml_tensor * k,
+ ggml_tensor * v,
+ ggml_tensor * g,
+ ggml_tensor * b,
+ int il);
};
struct llm_build_rwkv6_base : public llm_graph_context {
diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp
index 6783d98..0be3247 100644
--- a/src/models/qwen35.cpp
+++ b/src/models/qwen35.cpp
@@ -385,10 +385,6 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il);
- ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
- state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
- cb(state, "state_predelta", il);
-
ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
cb(conv_output_proper, "conv_output_raw", il);
@@ -445,7 +441,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear(
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
- ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il);
+ ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, il);
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp
index eb5e9a4..2995f04 100644
--- a/src/models/qwen35moe.cpp
+++ b/src/models/qwen35moe.cpp
@@ -409,10 +409,6 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il);
- ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
- state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
- cb(state, "state_predelta", il);
-
ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
cb(conv_output_proper, "conv_output_raw", il);
@@ -469,7 +465,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear(
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);
- ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il);
+ ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, il);
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
--
2.43.0

View File

@@ -96,3 +96,89 @@ precedent (`ssm_scan` `ids`) and is the clear next move. The residual gap to vLL
after both SSM steps is the FP4 GEMM (~37% of decode), which is a separate kernel
track. No paged/graph/block-table change can move decode on this model (full
attention is 0.4% of decode).
## STEP 2 (patch 0019): fuse the recurrent-state gather into the op
After Step 1 the largest non-GEMM decode bucket was the recurrent-state
`get_rows` gather (18.8% of decode GPU time): `build_rs` materialized each
sequence's prior state into a contiguous scratch via `ggml_get_rows` before the
gated-DeltaNet op read it. Step 2 eliminates that materialization, mirroring
`ggml_ssm_scan`'s `ids` source.
`ggml_gated_delta_net_inplace_ids` takes the FULL recurrent-state cache plus the
`s_copy` ids (`src[5]` = full cache `[S_v, S_v, H, n_rs_slots]`, `src[7]` = ids,
`op_param[1]` = `rs_head`) and reads each sequence's prior state directly from
`cache[ids[seq]]`. Combined with Step 1's in-place write the op now reads AND
writes the cache directly: no recurrent-state materialization at all. The
`build_recurrent_attn` fused path feeds the full cache and ids through the
`build_rs` `get_state_rows` lambda exactly like `mamba-base.cpp`, keeping the
`rs_zero` clear and the extra-states copy around the op.
### Race-free by construction (CUDA)
In-place write plus an ids read of the same cache is only safe when the read slot
equals the write slot. `s_copy(s) = cells[s + head].src0`, which is identity
(`rs_head + s`) for stable continuing sequences (the entire AR decode path) but
can remap on sequence reorder or `rs_zero` (e.g. multiple new sequences in one
prefill ubatch). The kernel handles both per (seq, head) block on device:
- identity sequences read `s0` in place from the destination slot `state_dst`
(the kernel loads all of `s0` into registers before it writes the new state,
so reading and writing the same slot is race-free) -- no materialization;
- non-identity sequences read from a disjoint scratch that a small
`gdn_gather_nonident_kernel` copies from `cache[ids[seq]]` first, so the
recurrence never reads a slot another block writes.
`ids` stays a device pointer (dereferenced only in the kernels; the input is
device-resident at op-execute time, so a host read segfaults). The CPU op
mirrors the same logic (host identity check + a serial gather in the dispatcher
for the non-identity case). The math is unchanged, so the result is bit-identical
to the `get_rows` path in every case.
Gated to the `qwen35` / `qwen35moe` fused decode/prefill path; `qwen3next`,
`kimi-linear`, the non-fused path and the rollback (`n_rs_seq > 0`) path are
untouched (they keep the materialized-state overload).
### Measured decode_agg (`S_TG` t/s, npp 128, ntg 128, -fa on, paged on, fusion off)
Dense `q36-27b-nvfp4`:
| npl | Step 1 (baseline) | Step 2 | delta | % of vLLM (391 @128) |
|-----|-------------------|----------|---------|----------------------|
| 32 | 137.64 | 170.68 | +24.0% | - |
| 128 | 186.25 | 256.57 | +37.8% | 47.6% -> 65.6% |
The npl-128 result (256.57 t/s) beats the predicted ~247 t/s Step-2 ceiling.
MoE `q36-35b-a3b-nvfp4`:
| npl | Step 1 (baseline) | Step 2 | delta |
|-----|-------------------|----------|---------|
| 32 | 299.68 | 366.69 | +22.4% |
| 128 | 409.30 | 553.63 | +35.3% |
(Step-1 baselines re-measured in the same session; the brief's reference figures
were 136 / 180 dense and 279 / 373 MoE.)
### Bit-exact gate
Greedy (`--temp 0 --seed 1`) `llama-completion` output (256 tokens, paged on,
fusion off) vs the Step-1 build:
- dense `q36-27b-nvfp4`: model text byte-identical (md5 match);
- MoE `q36-35b-a3b-nvfp4`: byte-identical;
- Step-2 dense run1 == run2 (deterministic, no race).
### nsys confirmation (npp 128, ntg 24, npl 128, fusion off, eager)
The recurrent-state gather bucket collapsed:
| kernel | Step 1 | Step 2 |
|----------------------------|----------|-----------------------------------------|
| `k_get_rows_float` | 18.8% | 0.7% (residual: embeddings / conv-state)|
| `gdn_gather_nonident` | - | 1.7% (no-op at decode, median ~1.2 us) |
| `gated_delta_net_cuda` | 26.0% | 22.5% |
| FP4 GEMM family | ~37.5% | ~48% (now the dominant residual) |
The SSM state gather is effectively eliminated. The residual decode gap to vLLM
is now the FP4 GEMM (~48% of decode), a separate kernel track.