From fe5bd3f53d6828a6f8c0fadc8e38792d41d1731c Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 26 Jun 2026 16:21:33 +0000 Subject: [PATCH] feat(paged): qwen35 hybrid per-head f32/bf16 SSM state (patch 0026) Lever A patch + build/de-risk results. Splits the persisted gated-DeltaNet recurrent state per head: f32 on long-memory heads (where bf16 rounding does not contract and the KL error concentrates), bf16 on fast-decaying heads, classified at model load by tau_h = 1/(|ssm_a|*softplus(ssm_dt)). Default ssm_hybrid_tau_thresh = 0.0 keeps every head f32 (bit-exact opt-out). De-risk gates BOTH PASS: test-backend-ops GATED_DELTA_NET CUDA0 OK (incl 32 hybrid mixed CUDA-vs-CPU cases); default all-f32 greedy md5 == 0023 baseline both models (dense 5951a5b4d624ce891e22ab5fca9bc439, MoE 07db32c2bcb78d17a43ed18bc22705cd). Known open issue (opt-in hybrid only; default unaffected): hybrid-ON model decode (ids in-place path) is incoherent; classifier/cache/kernel-params verified correct, bug isolated to the ids in-place cross-step state path. See A_HYBRID_SSM_RESULTS.md. Not ready for the GateSweep until fixed. Signed-off-by: Ettore Di Giacinto Assisted-by: Claude:opus-4.8 [Claude Code] --- ...0026-qwen35-hybrid-perhead-ssm-state.patch | 1983 +++++++++++++++++ .../patches/paged/A_HYBRID_PROGRESS.md | 48 + .../patches/paged/A_HYBRID_SSM_RESULTS.md | 90 + 3 files changed, 2121 insertions(+) create mode 100644 backend/cpp/llama-cpp/patches/paged/0026-qwen35-hybrid-perhead-ssm-state.patch create mode 100644 backend/cpp/llama-cpp/patches/paged/A_HYBRID_PROGRESS.md create mode 100644 backend/cpp/llama-cpp/patches/paged/A_HYBRID_SSM_RESULTS.md diff --git a/backend/cpp/llama-cpp/patches/paged/0026-qwen35-hybrid-perhead-ssm-state.patch b/backend/cpp/llama-cpp/patches/paged/0026-qwen35-hybrid-perhead-ssm-state.patch new file mode 100644 index 000000000..bf5f580c5 --- /dev/null +++ b/backend/cpp/llama-cpp/patches/paged/0026-qwen35-hybrid-perhead-ssm-state.patch @@ -0,0 +1,1983 @@ +diff --git a/common/arg.cpp b/common/arg.cpp +index 841a38e..3e05bd4 100644 +--- a/common/arg.cpp ++++ b/common/arg.cpp +@@ -2157,6 +2157,47 @@ common_params_context common_params_parser_init(common_params & params, llama_ex + params.cache_type_v = kv_cache_type_from_str(value); + } + ).set_env("LLAMA_ARG_CACHE_TYPE_V")); ++ add_opt(common_arg( ++ {"-ctssm", "--cache-type-ssm"}, "TYPE", ++ string_format( ++ "recurrent SSM-state cache data type (default f32 for bit-exact recurrence; pass bf16 to\n" ++ "halve the dominant gated-DeltaNet decode byte stream)\n" ++ "allowed values: %s\n" ++ "(default: %s)", ++ get_all_kv_cache_types().c_str(), ++ ggml_type_name(params.cache_type_ssm) ++ ), ++ [](common_params & params, const std::string & value) { ++ params.cache_type_ssm = kv_cache_type_from_str(value); ++ } ++ ).set_env("LLAMA_ARG_CACHE_TYPE_SSM")); ++ add_opt(common_arg( ++ {"-ctconv", "--cache-type-conv"}, "TYPE", ++ string_format( ++ "recurrent conv-state cache data type (default f32)\n" ++ "allowed values: %s\n" ++ "(default: %s)", ++ get_all_kv_cache_types().c_str(), ++ ggml_type_name(params.cache_type_conv) ++ ), ++ [](common_params & params, const std::string & value) { ++ params.cache_type_conv = kv_cache_type_from_str(value); ++ } ++ ).set_env("LLAMA_ARG_CACHE_TYPE_CONV")); ++ add_opt(common_arg( ++ {"--ssm-bf16-tau"}, "TAU", ++ string_format( ++ "hybrid per-head SSM-state precision (lever A): a gated-DeltaNet head is kept f32 iff its\n" ++ "memory length tau_h = 1/(|ssm_a|*softplus(ssm_dt)) tokens > TAU; else its persisted state\n" ++ "is bf16 (halving that head's recurrence byte stream). 0 => every head f32 (bit-exact\n" ++ "default); raise it (e.g. 32/64) to bf16 the fast-decaying heads.\n" ++ "(default: %.1f)", ++ (double) params.ssm_hybrid_tau_thresh ++ ), ++ [](common_params & params, const std::string & value) { ++ params.ssm_hybrid_tau_thresh = std::stof(value); ++ } ++ ).set_env("LLAMA_ARG_SSM_BF16_TAU")); + add_opt(common_arg( + {"--hellaswag"}, + "compute HellaSwag score over random tasks from datafile supplied with -f", +diff --git a/common/common.cpp b/common/common.cpp +index a14e7bb..c4ab884 100644 +--- a/common/common.cpp ++++ b/common/common.cpp +@@ -1600,6 +1600,9 @@ struct llama_context_params common_context_params_to_llama(const common_params & + + cparams.type_k = params.cache_type_k; + cparams.type_v = params.cache_type_v; ++ cparams.type_r = params.cache_type_conv; ++ cparams.type_s = params.cache_type_ssm; ++ cparams.ssm_hybrid_tau_thresh = params.ssm_hybrid_tau_thresh; + + return cparams; + } +diff --git a/common/common.h b/common/common.h +index 94147d5..ad79a4e 100644 +--- a/common/common.h ++++ b/common/common.h +@@ -578,6 +578,9 @@ struct common_params { + + ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K + ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V ++ ggml_type cache_type_conv = GGML_TYPE_F32; // recurrent conv-state cache type (bit-exact f32 default) ++ ggml_type cache_type_ssm = GGML_TYPE_F32; // recurrent SSM-state cache type (f32 default; bf16 opt-in for decode BW) ++ float ssm_hybrid_tau_thresh = 0.0f; // hybrid per-head SSM precision (lever A): f32 head iff tau>thresh; 0 => all f32 + + common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; + +diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h +index 76fa401..2a5cbce 100644 +--- a/ggml/include/ggml.h ++++ b/ggml/include/ggml.h +@@ -2626,6 +2626,42 @@ extern "C" { + struct ggml_tensor * ids, + int rs_head); + ++ // Hybrid per-head mixed-dtype recurrent state (lever A). The persisted SSM state is split into two ++ // dtype-homogeneous partitions sized by head COUNT: `state_f32` ([S_v,S_v,n_f32,*]) holds the ++ // long-memory heads at f32, `state_bf16` ([S_v,S_v,n_bf16,*]) holds the fast-decaying heads at bf16. ++ // `head_slot` is an I32 [H] map: head h -> local_idx in its partition, encoded local_idx (>=0) for ++ // an f32 head and -(local_idx+1) (<0) for a bf16 head. q/k/v/g/beta keep natural head order. The ++ // recurrence math runs in f32 registers; only the per-head load/store crosses the dtype boundary. ++ // Output-append form (mirrors ggml_gated_delta_net): the op output carries attn + the full f32 state. ++ GGML_API struct ggml_tensor * ggml_gated_delta_net_hybrid( ++ struct ggml_context * ctx, ++ struct ggml_tensor * q, ++ struct ggml_tensor * k, ++ struct ggml_tensor * v, ++ struct ggml_tensor * g, ++ struct ggml_tensor * beta, ++ struct ggml_tensor * state_f32, ++ struct ggml_tensor * state_bf16, ++ struct ggml_tensor * head_slot, ++ int64_t K); ++ ++ // Hybrid in-place ids form (mirrors ggml_gated_delta_net_inplace_ids). `state_f32`/`state_bf16` are ++ // the FULL split caches; `state_dst_f32` is the in-place write view into the f32 partition at rs_head ++ // (the bf16 partition's write view is derived from state_bf16 + rs_head). Used by the GDN decode path. ++ GGML_API struct ggml_tensor * ggml_gated_delta_net_inplace_ids_hybrid( ++ struct ggml_context * ctx, ++ struct ggml_tensor * q, ++ struct ggml_tensor * k, ++ struct ggml_tensor * v, ++ struct ggml_tensor * g, ++ struct ggml_tensor * beta, ++ struct ggml_tensor * state_f32, ++ struct ggml_tensor * state_dst_f32, ++ struct ggml_tensor * ids, ++ struct ggml_tensor * state_bf16, ++ struct ggml_tensor * head_slot, ++ int rs_head); ++ + // custom operators + + typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); +diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c +index eb8341c..7f19cd3 100644 +--- a/ggml/src/ggml-cpu/ggml-cpu.c ++++ b/ggml/src/ggml-cpu/ggml-cpu.c +@@ -2949,7 +2949,10 @@ struct ggml_cplan ggml_graph_plan( + { + const int64_t S_v = node->src[2]->ne[0]; + const int64_t K = ggml_get_op_params_i32(node, 0); +- const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); ++ // bf16 in-place final-state cache needs an extra f32 working buffer (the ++ // recurrence cannot run in-place on a bf16 cache); mirror need_work in ops.cpp. ++ const bool inplace_bf16 = node->src[6] != NULL && node->src[6]->type == GGML_TYPE_BF16; ++ const int64_t per_thread = S_v + ((K > 1 || inplace_bf16) ? S_v * S_v : 0); + cur = per_thread * sizeof(float) * n_tasks; + } break; + case GGML_OP_COUNT: +diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp +index fbe30d6..07ab9e5 100644 +--- a/ggml/src/ggml-cpu/ops.cpp ++++ b/ggml/src/ggml-cpu/ops.cpp +@@ -2254,6 +2254,25 @@ static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, gg + } + } + ++static void ggml_compute_forward_fill_bf16(const ggml_compute_params * params, ggml_tensor * dst) { ++ const ggml_bf16_t c = GGML_FP32_TO_BF16(ggml_get_op_params_f32(dst, 0)); ++ ++ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); ++ GGML_TENSOR_LOCALS(size_t, nb, dst, nb); ++ ++ const auto [ir0, ir1] = get_thread_range(params, dst); ++ ++ for (int64_t ir = ir0; ir < ir1; ++ir) { ++ const int64_t i03 = ir/(ne2*ne1); ++ const int64_t i02 = (ir - i03*ne2*ne1)/ne1; ++ const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1); ++ ++ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1); ++ ++ for (int64_t i = 0; i < ne0; ++i) { dst_ptr[i] = c; } ++ } ++} ++ + void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + +@@ -2266,6 +2285,10 @@ void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * + { + ggml_compute_forward_fill_f16(params, dst); + } break; ++ case GGML_TYPE_BF16: ++ { ++ ggml_compute_forward_fill_bf16(params, dst); ++ } break; + default: + { + GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type)); +@@ -10699,6 +10722,12 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + ggml_tensor * src_state_dst = dst->src[6]; // optional in-place final-state write-back target ++ // Hybrid per-head mixed-dtype state (lever A): src[8] = bf16 partition, src[9] = head_slot map. ++ // The CPU reference handles the output-append form (the test path): each head reads its partition ++ // (f32 src[5] or bf16 src[8]) per head_slot[h]; the recurrence + output store stay f32 (unchanged). ++ ggml_tensor * src_state_bf16 = dst->src[8]; ++ ggml_tensor * src_head_slot = dst->src[9]; ++ const bool hybrid = (src_head_slot != nullptr); + + const int64_t S_v = src_v->ne[0]; + const int64_t H = src_v->ne[1]; +@@ -10730,14 +10759,22 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int64_t K = ggml_get_op_params_i32(dst, 0); + GGML_ASSERT(K >= 1); +- // per-seq stride in floats (seq s starts at state + s * seq_stride) +- int64_t state_seq_stride = src_state->nb[3] / sizeof(float); +- +- const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); ++ // prior-state read source: f32 cache/scratch (default) or bf16 cache (opt-in width halving). ++ // Track element size + dtype so the per-seq stride and the load conversion are correct either way. ++ size_t state_in_ts = ggml_type_size(src_state->type); ++ bool read_bf16 = (src_state->type == GGML_TYPE_BF16); ++ // per-seq stride in elements (seq s starts at state + s * seq_stride elements of state_in_ts) ++ int64_t state_seq_stride = (int64_t)(src_state->nb[3] / state_in_ts); ++ ++ // bf16 in-place final-state cache: run the recurrence in an f32 working buffer (cannot do f32 ++ // math in-place on a bf16 cache); convert-store to the bf16 cache after each (head,seq) token loop. ++ const bool inplace_bf16 = (src_state_dst != nullptr && src_state_dst->type == GGML_TYPE_BF16); ++ const bool need_work = (K > 1) || inplace_bf16; ++ const int64_t per_thread = S_v + (need_work ? S_v * S_v : 0); + const int ith = params->ith; + + float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32; +- float * state_work = K > 1 ? (delta + S_v) : nullptr; ++ float * state_work = need_work ? (delta + S_v) : nullptr; + + // output layout: [attn_scores | new_states] + // attn_scores: S_v * H * n_tokens * n_seqs floats +@@ -10750,7 +10787,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. + +- const float * state_in_base = (const float *)src_state->data; ++ const char * state_in_base = (const char *)src_state->data; + + // Step 2: fused recurrent-state gather (ids == s_copy in src[7]). Read the prior state directly + // from the full cache at cache[ids[seq]] instead of from a materialized gather. For the identity +@@ -10767,9 +10804,14 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( + if (ids[s] != rs_head + (int32_t) s) { identity = false; break; } + } + state_seq_stride = D; +- state_in_base = identity +- ? (const float *) src_state->data + (int64_t) rs_head * D +- : (const float *) state_out_base; // gathered by the dispatcher (non-identity) ++ if (identity) { ++ state_in_base = (const char *) src_state->data + (int64_t) rs_head * D * state_in_ts; ++ } else { ++ // dispatcher gathered cache[ids[seq]] into the f32 output scratch (already widened to f32) ++ state_in_base = (const char *) state_out_base; ++ state_in_ts = sizeof(float); ++ read_bf16 = false; ++ } + } + + //const int64_t rq1 = nev1 / neq1; +@@ -10784,9 +10826,28 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( + float * inplace_state_base = nullptr; + if (src_state_dst != nullptr) { + GGML_ASSERT(K == 1); +- GGML_ASSERT(src_state_dst->nb[0] == sizeof(float)); +- GGML_ASSERT(src_state_dst->nb[1] == (size_t) S_v * S_v * H * sizeof(float)); +- inplace_state_base = (float *) src_state_dst->data; ++ GGML_ASSERT(src_state_dst->nb[0] == ggml_type_size(src_state_dst->type)); ++ GGML_ASSERT(src_state_dst->nb[1] == (size_t) S_v * S_v * H * ggml_type_size(src_state_dst->type)); ++ // f32 in-place writes directly to the cache; bf16 in-place uses state_work + convert-store. ++ inplace_state_base = inplace_bf16 ? nullptr : (float *) src_state_dst->data; ++ } ++ ++ // Hybrid (lever A): CPU reference covers the output-append form. Each head reads from its partition ++ // (f32 src[5] / bf16 src[8]) at the partition's per-seq stride; the recurrence + op-output store ++ // stay f32 and are unchanged. The in-place ids decode path is GPU-only (asserted out here). ++ const int32_t * head_slot = nullptr; ++ int64_t n_f32_heads = 0; ++ int64_t n_bf16_heads = 0; ++ const char * f32_part_base = nullptr; ++ const char * bf16_part_base = nullptr; ++ if (hybrid) { ++ GGML_ASSERT(src_ids == nullptr && src_state_dst == nullptr); // GPU owns the in-place ids path ++ GGML_ASSERT(src_state->type == GGML_TYPE_F32 && src_state_bf16->type == GGML_TYPE_BF16); ++ head_slot = (const int32_t *) src_head_slot->data; ++ n_f32_heads = src_state->ne[2]; ++ n_bf16_heads = src_state_bf16->ne[2]; ++ f32_part_base = (const char *) src_state->data; ++ bf16_part_base = (const char *) src_state_bf16->data; + } + + for (int64_t ir = ir0; ir < ir1; ++ir) { +@@ -10799,16 +10860,38 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( + const int64_t iq3 = iv3 / rq3; + const int64_t ik3 = iv3 / rk3; + +- // For K=1, write directly to the single output slot to avoid an extra memcpy at the end. +- // For K>1, work in scratch and copy out per-token when the slot is in range. +- float * s_out = (K > 1) ++ // For K=1 f32, write directly to the single output/cache slot to avoid an extra copy. ++ // For K>1 or bf16 in-place, run in the f32 scratch buffer and copy/convert out afterwards. ++ float * s_out = need_work + ? state_work + : (inplace_state_base ? inplace_state_base : state_out_base) + (iv3 * H + iv1) * S_v * S_v; + +- // copy input state into the working buffer and operate in-place +- // state layout [S_v, S_v, H, n_seqs]: seq iv3 starts at iv3 * state_seq_stride. +- const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v; +- memcpy(s_out, s_in, S_v * S_v * sizeof(float)); ++ // copy input state into the f32 working buffer and operate in-place ++ // state layout [S_v, S_v, H, n_seqs]: seq iv3 starts at iv3 * state_seq_stride (in elements). ++ // Hybrid: select this head's partition (f32 src[5] / bf16 src[8]) with the partition's per-seq ++ // stride (n_part * S_v*S_v) and the head's local index; else the uniform single-cache layout. ++ const char * s_in; ++ bool ld_bf16; ++ if (hybrid) { ++ const int32_t hs = head_slot[iv1]; ++ const bool h_bf = hs < 0; ++ const int64_t local = h_bf ? (-(int64_t) hs - 1) : (int64_t) hs; ++ const int64_t n_p = h_bf ? n_bf16_heads : n_f32_heads; ++ const size_t ts = h_bf ? sizeof(ggml_bf16_t) : sizeof(float); ++ s_in = (h_bf ? bf16_part_base : f32_part_base) + (iv3 * n_p * S_v * S_v + local * S_v * S_v) * ts; ++ ld_bf16 = h_bf; ++ } else { ++ s_in = state_in_base + (iv3 * state_seq_stride + iv1 * S_v * S_v) * state_in_ts; ++ ld_bf16 = read_bf16; ++ } ++ if (ld_bf16) { ++ const ggml_bf16_t * s_in_bf = (const ggml_bf16_t *) s_in; ++ for (int64_t e = 0; e < S_v * S_v; ++e) { ++ s_out[e] = GGML_BF16_TO_FP32(s_in_bf[e]); ++ } ++ } else { ++ memcpy(s_out, (const float *) s_in, S_v * S_v * sizeof(float)); ++ } + + // attn output pointer for first token of this (head, seq) + float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v; +@@ -10867,6 +10950,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( + } + } + } ++ ++ // bf16 in-place final-state write-back: convert the f32 working state into the bf16 cache slot ++ // (K==1; the f32 in-place path already wrote the cache directly via s_out == inplace_state_base). ++ if (inplace_bf16) { ++ ggml_bf16_t * dst_cache = (ggml_bf16_t *) src_state_dst->data + (iv3 * H + iv1) * S_v * S_v; ++ for (int64_t e = 0; e < S_v * S_v; ++e) { ++ dst_cache[e] = GGML_FP32_TO_BF16(s_out[e]); ++ } ++ } + } + } + +@@ -10915,10 +11007,21 @@ static void ggml_compute_forward_gated_delta_net_f32( + } + if (!identity) { + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; +- const float * cache = (const float *) src_state->data; + float * scratch = (float *) dst->data + attn_score_elems; +- for (int64_t s = 0; s < n_seqs; ++s) { +- memcpy(scratch + s * D, cache + (int64_t) ids[s] * D, D * sizeof(float)); ++ // gather widens to f32: bf16 cache rows are converted; f32 cache rows are copied verbatim. ++ if (src_state->type == GGML_TYPE_BF16) { ++ const ggml_bf16_t * cache = (const ggml_bf16_t *) src_state->data; ++ for (int64_t s = 0; s < n_seqs; ++s) { ++ const ggml_bf16_t * src = cache + (int64_t) ids[s] * D; ++ for (int64_t e = 0; e < D; ++e) { ++ scratch[s * D + e] = GGML_BF16_TO_FP32(src[e]); ++ } ++ } ++ } else { ++ const float * cache = (const float *) src_state->data; ++ for (int64_t s = 0; s < n_seqs; ++s) { ++ memcpy(scratch + s * D, cache + (int64_t) ids[s] * D, D * sizeof(float)); ++ } + } + } + } +diff --git a/ggml/src/ggml-cuda/fill.cu b/ggml/src/ggml-cuda/fill.cu +index 739062c..8be4d56 100644 +--- a/ggml/src/ggml-cuda/fill.cu ++++ b/ggml/src/ggml-cuda/fill.cu +@@ -31,6 +31,9 @@ void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + case GGML_TYPE_F16: + fill_kernel<<>>((half *)dst_d, k, ggml_cuda_cast(value)); + break; ++ case GGML_TYPE_BF16: ++ fill_kernel<<>>((nv_bfloat16 *)dst_d, k, ggml_cuda_cast(value)); ++ break; + default: + GGML_ABORT("unsupported type"); + } +diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu +index d071d5a..830118a 100644 +--- a/ggml/src/ggml-cuda/gated_delta_net.cu ++++ b/ggml/src/ggml-cuda/gated_delta_net.cu +@@ -2,12 +2,38 @@ + #include "ggml-cuda/common.cuh" + + #include ++#include ++#include ++ ++// Element type of the persisted recurrent-state cache: f32 (default, bit-exact) or bf16 (opt-in, ++// halves the dominant decode byte stream). All recurrence math stays in f32 registers; only the ++// load (cache->f32) and store (f32->cache) cross the dtype boundary. Selected by STATE_BF16. ++template using gdn_state_t = std::conditional_t; ++ ++// Hybrid per-head mixed-dtype state (lever A). Both partition bases are passed and the per-block ++// branch on head_slot[h_idx] picks the partition + local index. The branch is UNIFORM within a block ++// (all threads share h_idx == blockIdx.x) so there is NO warp divergence. The recurrence math stays ++// byte-for-byte f32-register; only the per-head load/store crosses the dtype boundary. Encoding: ++// head_slot[h] = local_idx (>=0) for an f32 head, -(local_idx+1) (<0) for a bf16 head. ++struct gdn_hybrid_args { ++ const float * rin_f32; // f32 read-input base (s0 partition or gather scratch), seq-0-based ++ const nv_bfloat16 * rin_bf16; // bf16 read-input base ++ float * d_f32; // f32 in-place dst base, PRE-OFFSET to rs_head slot (null => append) ++ nv_bfloat16 * d_bf16; // bf16 in-place dst base, PRE-OFFSET to rs_head slot (null => append) ++ const int32_t * head_slot; // [H] per-head partition+local-index map ++ int n_f32; // f32 partition head count (per-seq stride n_f32*S_v*S_v) ++ int n_bf16; // bf16 partition head count ++}; + + // Step 2: gather only the NON-identity sequences' prior recurrent state from the full cache into a + // disjoint scratch buffer. Identity sequences (ids[s] == rs_head + s) are read in place from the +-// destination slot by the recurrence kernel and are skipped here. One block per sequence. +-__global__ void gdn_gather_nonident_kernel(const float * cache, const int32_t * ids, int rs_head, +- float * scratch, int64_t D, int n_seqs) { ++// destination slot by the recurrence kernel and are skipped here. One block per sequence. The ++// scratch shares the cache element type (STATE_T): the gather is a pure element copy, so the ++// recurrence kernel performs the single bf16->f32 load conversion uniformly regardless of whether a ++// sequence read the in-place slot or the gathered scratch (eliminates the mixed-dtype read path). ++template ++__global__ void gdn_gather_nonident_kernel(const STATE_T * cache, const int32_t * ids, int rs_head, ++ STATE_T * scratch, int64_t D, int n_seqs) { + const int s = blockIdx.x; + if (s >= n_seqs) { + return; +@@ -16,19 +42,20 @@ __global__ void gdn_gather_nonident_kernel(const float * cache, const int32_t * + if (r == rs_head + s) { + return; // identity: prior state already lives in the in-place destination slot + } +- const float * src = cache + (int64_t) r * D; +- float * dst = scratch + (int64_t) s * D; ++ const STATE_T * src = cache + (int64_t) r * D; ++ STATE_T * dst = scratch + (int64_t) s * D; + for (int64_t i = threadIdx.x; i < D; i += blockDim.x) { + dst[i] = src[i]; + } + } + +-static void ggml_cuda_gdn_gather_nonident(const float * cache, const int32_t * ids, int rs_head, +- float * scratch, int64_t D, int64_t n_seqs, cudaStream_t stream) { ++template ++static void ggml_cuda_gdn_gather_nonident(const STATE_T * cache, const int32_t * ids, int rs_head, ++ STATE_T * scratch, int64_t D, int64_t n_seqs, cudaStream_t stream) { + if (n_seqs <= 0) { + return; + } +- gdn_gather_nonident_kernel<<<(unsigned) n_seqs, 256, 0, stream>>>(cache, ids, rs_head, scratch, D, (int) n_seqs); ++ gdn_gather_nonident_kernel<<<(unsigned) n_seqs, 256, 0, stream>>>(cache, ids, rs_head, scratch, D, (int) n_seqs); + } + + // Occupancy/coalescing retune (patch 0022). Each warp owns COLS_PER_WARP columns of the recurrent +@@ -45,14 +72,14 @@ static void ggml_cuda_gdn_gather_nonident(const float * cache, const int32_t * i + // to hide each other's shfl latency) which covers more DRAM latency on this bandwidth-bound kernel. + // Every individual global access stays IDENTICALLY coalesced (32 consecutive lanes -> one 128B + // sector), so this is a latency-coverage / scheduling win, not a coalescing change. +-template ++template + __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * NUM_WARPS, MIN_BLOCKS) + gated_delta_net_cuda(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, +- const float * curr_state, ++ const gdn_state_t * curr_state, + float * dst, + int64_t H, + int64_t n_tokens, +@@ -70,9 +97,11 @@ gated_delta_net_cuda(const float * q, + const uint3 rq3_magic, + float scale, + int K, +- float * state_dst, ++ gdn_state_t * state_dst, + const int32_t * ids, +- int rs_head) { ++ int rs_head, ++ gdn_hybrid_args hyb) { ++ using STATE_T = gdn_state_t; + const uint32_t h_idx = blockIdx.x; + const uint32_t sequence = blockIdx.y; + // each warp owns COLS_PER_WARP columns, using warp-level primitives to reduce across rows. +@@ -84,27 +113,45 @@ gated_delta_net_cuda(const float * q, + + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + float * attn_data = dst; +- // when state_dst is provided (in-place decode write-back) the final recurrent state is written +- // directly into the persistent cache view instead of being appended to the op output; this +- // eliminates the per-layer per-step D2D state copy-back. Only used when keep_rs_t == false. +- float * state = (state_dst != nullptr) ? state_dst : (dst + attn_score_elems); ++ // When state_dst is provided (in-place decode write-back) the final recurrent state is written ++ // directly into the persistent cache view (STATE_T, possibly bf16) instead of being appended to ++ // the f32 op output. The bf16 store conversion is applied only on that cache path; the f32 op ++ // output (keep_rs snapshots and the non-in-place final state) always stays f32. Selected at the ++ // write sites below by (state_dst != nullptr). + + // input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. +- const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; +- state += state_out_offset; +- // Step 2: select the prior-state read base per sequence. For the ids variant, identity +- // sequences (ids[seq] == rs_head + seq) read s0 directly from the in-place destination slot +- // state_dst (no materialization); non-identity sequences read from the pre-gathered scratch +- // (curr_state). state_in_offset == state_out_offset, so both bases use the same per-(seq,head) +- // offset. The whole s0 is loaded into registers before the new state is written, so reading and +- // writing the same slot per block (identity) is race-free. +- const float * read_state = (ids != nullptr && ids[sequence] == rs_head + (int) sequence) +- ? state_dst : curr_state; +- read_state += state_in_offset; + attn_data += (sequence * n_tokens * H + h_idx) * S_v; + ++ // identity decode: ids[seq] == rs_head + seq => the prior state already lives in the in-place ++ // destination slot (no materialization); else read from the pre-gathered scratch. ++ const bool identity = (ids != nullptr && ids[sequence] == rs_head + (int) sequence); ++ ++ // Homogeneous (HYBRID=false) read base: select state_dst (identity) or curr_state (gathered), ++ // both share STATE_T (the gather scratch matches the cache dtype). state_in_offset uses full H. ++ const STATE_T * read_state = nullptr; ++ if constexpr (!HYBRID) { ++ const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v; ++ read_state = (identity ? state_dst : curr_state) + state_in_offset; ++ } ++ ++ // Hybrid (lever A) read bases: per-block partition from head_slot[h_idx] (uniform => no divergence). ++ bool hyb_bf16 = false; ++ int64_t hyb_seqoff = 0; // sequence * n_part * S_v*S_v + local * S_v*S_v ++ const float * hyb_rf32 = nullptr; ++ const nv_bfloat16 * hyb_rbf16 = nullptr; ++ if constexpr (HYBRID) { ++ const int hs = hyb.head_slot[h_idx]; ++ hyb_bf16 = hs < 0; ++ const int local = hyb_bf16 ? (-hs - 1) : hs; ++ const int n_p = hyb_bf16 ? hyb.n_bf16 : hyb.n_f32; ++ hyb_seqoff = (int64_t) sequence * n_p * S_v * S_v + (int64_t) local * S_v * S_v; ++ // identity reads the (pre-rs_head-offset) in-place dst partition; else the seq-0-based scratch/s0. ++ hyb_rf32 = (identity ? hyb.d_f32 : hyb.rin_f32) + hyb_seqoff; ++ hyb_rbf16 = (identity ? hyb.d_bf16 : hyb.rin_bf16) + hyb_seqoff; ++ } ++ + constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; + static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size"); + constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size; +@@ -114,12 +161,36 @@ gated_delta_net_cuda(const float * q, + ggml_cuda_pdl_sync(); + #pragma unroll + for (int cc = 0; cc < COLS_PER_WARP; cc++) { +- const int col = col_base + cc * NUM_WARPS; +- const float * rs = read_state + col * S_v; ++ const int col = col_base + cc * NUM_WARPS; ++ if constexpr (HYBRID) { ++ // LOAD: per-head partition -> f32 register. bf16 widened via __bfloat162float; f32 verbatim. ++ if (hyb_bf16) { ++ const nv_bfloat16 * rs = hyb_rbf16 + col * S_v; + #pragma unroll +- for (int r = 0; r < rows_per_lane; r++) { +- const int i = r * warp_size + lane; +- s_shard[cc][r] = rs[i]; ++ for (int r = 0; r < rows_per_lane; r++) { ++ const int i = r * warp_size + lane; ++ s_shard[cc][r] = __bfloat162float(rs[i]); ++ } ++ } else { ++ const float * rs = hyb_rf32 + col * S_v; ++#pragma unroll ++ for (int r = 0; r < rows_per_lane; r++) { ++ const int i = r * warp_size + lane; ++ s_shard[cc][r] = rs[i]; ++ } ++ } ++ } else { ++ const STATE_T * rs = read_state + col * S_v; ++#pragma unroll ++ for (int r = 0; r < rows_per_lane; r++) { ++ const int i = r * warp_size + lane; ++ // LOAD: cache(STATE_T) -> f32 register. bf16 widened via __bfloat162float; f32 verbatim. ++ if constexpr (STATE_BF16) { ++ s_shard[cc][r] = __bfloat162float(rs[i]); ++ } else { ++ s_shard[cc][r] = rs[i]; ++ } ++ } + } + } + +@@ -218,17 +289,19 @@ gated_delta_net_cuda(const float * q, + if constexpr (keep_rs_t) { + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. ++ // Snapshots target the f32 op-output scratch (never the bf16 cache); the persisted bf16 ++ // write-back for this path is the downstream ggml_cpy (f32->bf16). KEEP F32 here. + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + const int target_slot = (int) n_tokens - 1 - t; + if (target_slot >= 0 && target_slot < K) { + #pragma unroll + for (int cc = 0; cc < COLS_PER_WARP; cc++) { + const int col = col_base + cc * NUM_WARPS; +- float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; ++ float * snap_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; + #pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; +- curr_state[col * S_v + i] = s_shard[cc][r]; ++ snap_state[col * S_v + i] = s_shard[cc][r]; + } + } + } +@@ -236,13 +309,67 @@ gated_delta_net_cuda(const float * q, + } + + if constexpr (!keep_rs_t) { ++ if constexpr (HYBRID) { ++ // Hybrid final-state write. In-place decode (d_f32/d_bf16 != null) writes the per-head ++ // partition slot: bf16 heads store via __float2bfloat16, f32 heads verbatim. The branch is ++ // uniform per block (hyb_bf16). Output-append (both null) writes the full f32 op-output. ++ const bool inplace = (hyb.d_f32 != nullptr || hyb.d_bf16 != nullptr); ++ if (inplace) { ++ float * wb_f32 = hyb.d_f32 + hyb_seqoff; ++ nv_bfloat16 * wb_bf16 = hyb.d_bf16 + hyb_seqoff; + #pragma unroll +- for (int cc = 0; cc < COLS_PER_WARP; cc++) { +- const int col = col_base + cc * NUM_WARPS; ++ for (int cc = 0; cc < COLS_PER_WARP; cc++) { ++ const int col = col_base + cc * NUM_WARPS; + #pragma unroll +- for (int r = 0; r < rows_per_lane; r++) { +- const int i = r * warp_size + lane; +- state[col * S_v + i] = s_shard[cc][r]; ++ for (int r = 0; r < rows_per_lane; r++) { ++ const int i = r * warp_size + lane; ++ if (hyb_bf16) { ++ wb_bf16[col * S_v + i] = __float2bfloat16(s_shard[cc][r]); ++ } else { ++ wb_f32[col * S_v + i] = s_shard[cc][r]; ++ } ++ } ++ } ++ } else { ++ float * st = (dst + attn_score_elems) + state_out_offset; ++#pragma unroll ++ for (int cc = 0; cc < COLS_PER_WARP; cc++) { ++ const int col = col_base + cc * NUM_WARPS; ++#pragma unroll ++ for (int r = 0; r < rows_per_lane; r++) { ++ const int i = r * warp_size + lane; ++ st[col * S_v + i] = s_shard[cc][r]; ++ } ++ } ++ } ++ } else ++ // Final-state write. In-place decode (state_dst != nullptr) writes the persistent cache view ++ // (STATE_T): bf16 store via __float2bfloat16. Non-in-place writes the f32 op-output scratch. ++ if (state_dst != nullptr) { ++ STATE_T * st = state_dst + state_out_offset; ++#pragma unroll ++ for (int cc = 0; cc < COLS_PER_WARP; cc++) { ++ const int col = col_base + cc * NUM_WARPS; ++#pragma unroll ++ for (int r = 0; r < rows_per_lane; r++) { ++ const int i = r * warp_size + lane; ++ if constexpr (STATE_BF16) { ++ st[col * S_v + i] = __float2bfloat16(s_shard[cc][r]); ++ } else { ++ st[col * S_v + i] = s_shard[cc][r]; ++ } ++ } ++ } ++ } else { ++ float * st = (dst + attn_score_elems) + state_out_offset; ++#pragma unroll ++ for (int cc = 0; cc < COLS_PER_WARP; cc++) { ++ const int col = col_base + cc * NUM_WARPS; ++#pragma unroll ++ for (int r = 0; r < rows_per_lane; r++) { ++ const int i = r * warp_size + lane; ++ st[col * S_v + i] = s_shard[cc][r]; ++ } + } + } + } +@@ -258,11 +385,12 @@ gated_delta_net_cuda(const float * q, + #define GDN_DEFAULT_CPW 8 + #endif + +-template ++template + static void launch_gdn_variant( + const float * q_d, const float * k_d, const float * v_d, +- const float * g_d, const float * b_d, const float * s_d, +- float * dst_d, float * state_dst_d, const int32_t * ids_d, int rs_head, ++ const float * g_d, const float * b_d, const gdn_state_t * s_d, ++ float * dst_d, gdn_state_t * state_dst_d, const int32_t * ids_d, int rs_head, ++ gdn_hybrid_args hyb, + int64_t H, int64_t n_tokens, int64_t n_seqs, + int64_t sq1, int64_t sq2, int64_t sq3, + int64_t sv1, int64_t sv2, int64_t sv3, +@@ -273,18 +401,18 @@ static void launch_gdn_variant( + dim3 grid_dims(H, n_seqs, S_v / (NUM_WARPS * COLS_PER_WARP)); + dim3 block_dims(warp_size <= S_v ? warp_size : S_v, NUM_WARPS, 1); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); +- ggml_cuda_kernel_launch(gated_delta_net_cuda, launch_params, ++ ggml_cuda_kernel_launch(gated_delta_net_cuda, launch_params, + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, +- sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K, state_dst_d, ids_d, rs_head); ++ sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K, state_dst_d, ids_d, rs_head, hyb); + } + +-template ++template + static void launch_gated_delta_net( + const float * q_d, const float * k_d, const float * v_d, +- const float * g_d, const float * b_d, const float * s_d, +- float * dst_d, float * state_dst_d, +- const int32_t * ids_d, int rs_head, ++ const float * g_d, const float * b_d, const gdn_state_t * s_d, ++ float * dst_d, gdn_state_t * state_dst_d, ++ const int32_t * ids_d, int rs_head, gdn_hybrid_args hyb, + int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, + int64_t sq1, int64_t sq2, int64_t sq3, + int64_t sv1, int64_t sv2, int64_t sv3, +@@ -298,19 +426,19 @@ static void launch_gated_delta_net( + const uint3 rq3_magic = init_fastdiv_values(rq3); + + #define GDN_LAUNCH_ARGS \ +- q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d, ids_d, rs_head, \ ++ q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d, ids_d, rs_head, hyb, \ + H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, \ + neqk1_magic, rq3_magic, scale, K, warp_size, stream + + switch (S_v) { + case 16: +- launch_gdn_variant<16, KDA, keep_rs_t, 4, 1, 2>(GDN_LAUNCH_ARGS); ++ launch_gdn_variant<16, KDA, keep_rs_t, 4, 1, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); + break; + case 32: +- launch_gdn_variant<32, KDA, keep_rs_t, 4, 1, 2>(GDN_LAUNCH_ARGS); ++ launch_gdn_variant<32, KDA, keep_rs_t, 4, 1, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); + break; + case 64: +- launch_gdn_variant<64, KDA, keep_rs_t, 4, 1, 2>(GDN_LAUNCH_ARGS); ++ launch_gdn_variant<64, KDA, keep_rs_t, 4, 1, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); + break; + case 128: { + // Bit-exact occupancy/coalescing retune (patch 0022): fold COLS_PER_WARP columns per warp +@@ -322,18 +450,18 @@ static void launch_gated_delta_net( + // NUM_WARPS in {4,8,16} x COLS_PER_WARP ladder (all <=512 threads/block, no 1024-thread + // .minnctapersm warnings). Measured GB10 %peak: (4,1)=73 baseline ... (16,4)=82 ... + // (16,8)=84.7 winner ~ tied with (8,8)/(8,16)/(32,4); the plateau is just above vLLM (82.4). +- if (gdn_nw == 4 && gdn_cpw == 1) launch_gdn_variant<128, KDA, keep_rs_t, 4, 1, 2>(GDN_LAUNCH_ARGS); +- else if (gdn_nw == 4 && gdn_cpw == 2) launch_gdn_variant<128, KDA, keep_rs_t, 4, 2, 2>(GDN_LAUNCH_ARGS); +- else if (gdn_nw == 4 && gdn_cpw == 4) launch_gdn_variant<128, KDA, keep_rs_t, 4, 4, 2>(GDN_LAUNCH_ARGS); +- else if (gdn_nw == 8 && gdn_cpw == 1) launch_gdn_variant<128, KDA, keep_rs_t, 8, 1, 2>(GDN_LAUNCH_ARGS); +- else if (gdn_nw == 8 && gdn_cpw == 2) launch_gdn_variant<128, KDA, keep_rs_t, 8, 2, 2>(GDN_LAUNCH_ARGS); +- else if (gdn_nw == 8 && gdn_cpw == 4) launch_gdn_variant<128, KDA, keep_rs_t, 8, 4, 2>(GDN_LAUNCH_ARGS); +- else if (gdn_nw == 8 && gdn_cpw == 8) launch_gdn_variant<128, KDA, keep_rs_t, 8, 8, 2>(GDN_LAUNCH_ARGS); +- else if (gdn_nw == 16 && gdn_cpw == 1) launch_gdn_variant<128, KDA, keep_rs_t, 16, 1, 2>(GDN_LAUNCH_ARGS); +- else if (gdn_nw == 16 && gdn_cpw == 2) launch_gdn_variant<128, KDA, keep_rs_t, 16, 2, 2>(GDN_LAUNCH_ARGS); +- else if (gdn_nw == 16 && gdn_cpw == 4) launch_gdn_variant<128, KDA, keep_rs_t, 16, 4, 2>(GDN_LAUNCH_ARGS); +- else if (gdn_nw == 16 && gdn_cpw == 8) launch_gdn_variant<128, KDA, keep_rs_t, 16, 8, 2>(GDN_LAUNCH_ARGS); +- else launch_gdn_variant<128, KDA, keep_rs_t, GDN_DEFAULT_NW, GDN_DEFAULT_CPW, 2>(GDN_LAUNCH_ARGS); ++ if (gdn_nw == 4 && gdn_cpw == 1) launch_gdn_variant<128, KDA, keep_rs_t, 4, 1, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); ++ else if (gdn_nw == 4 && gdn_cpw == 2) launch_gdn_variant<128, KDA, keep_rs_t, 4, 2, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); ++ else if (gdn_nw == 4 && gdn_cpw == 4) launch_gdn_variant<128, KDA, keep_rs_t, 4, 4, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); ++ else if (gdn_nw == 8 && gdn_cpw == 1) launch_gdn_variant<128, KDA, keep_rs_t, 8, 1, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); ++ else if (gdn_nw == 8 && gdn_cpw == 2) launch_gdn_variant<128, KDA, keep_rs_t, 8, 2, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); ++ else if (gdn_nw == 8 && gdn_cpw == 4) launch_gdn_variant<128, KDA, keep_rs_t, 8, 4, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); ++ else if (gdn_nw == 8 && gdn_cpw == 8) launch_gdn_variant<128, KDA, keep_rs_t, 8, 8, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); ++ else if (gdn_nw == 16 && gdn_cpw == 1) launch_gdn_variant<128, KDA, keep_rs_t, 16, 1, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); ++ else if (gdn_nw == 16 && gdn_cpw == 2) launch_gdn_variant<128, KDA, keep_rs_t, 16, 2, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); ++ else if (gdn_nw == 16 && gdn_cpw == 4) launch_gdn_variant<128, KDA, keep_rs_t, 16, 4, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); ++ else if (gdn_nw == 16 && gdn_cpw == 8) launch_gdn_variant<128, KDA, keep_rs_t, 16, 8, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); ++ else launch_gdn_variant<128, KDA, keep_rs_t, GDN_DEFAULT_NW, GDN_DEFAULT_CPW, 2, STATE_BF16, HYBRID>(GDN_LAUNCH_ARGS); + break; + } + default: +@@ -352,6 +480,12 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + ggml_tensor * src_state_dst = dst->src[6]; // optional in-place state write-back target ++ // Hybrid per-head mixed-dtype state (lever A): src[8] = bf16 partition, src[9] = head_slot map. ++ // When present the persisted state is split into an f32 partition (src[5]) and a bf16 partition ++ // (src[8]); each head reads/writes its partition per head_slot. Detected by src[9] != null. ++ ggml_tensor * src_state_bf16 = dst->src[8]; ++ ggml_tensor * src_head_slot = dst->src[9]; ++ const bool hybrid = (src_head_slot != nullptr); + + GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); + GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb); +@@ -381,13 +515,19 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * + + float * dst_d = (float *) dst->data; + +- float * state_dst_d = nullptr; +- if (src_state_dst != nullptr) { ++ // Recurrent SSM state cache element type: f32 (default, bit-exact) or bf16 (opt-in, halves the ++ // dominant decode byte stream). Math stays f32 in registers; only load/store cross the boundary. ++ const ggml_type state_type = src_state->type; ++ const bool state_bf16 = (state_type == GGML_TYPE_BF16); ++ const size_t state_ts = ggml_type_size(state_type); ++ ++ void * state_dst_d = nullptr; ++ if (!hybrid && src_state_dst != nullptr) { + // in-place final-state cache view: per-seq stride must be the dense state size D = S_v*S_v*H +- GGML_ASSERT(src_state_dst->type == GGML_TYPE_F32); +- GGML_ASSERT(src_state_dst->nb[0] == sizeof(float)); +- GGML_ASSERT(src_state_dst->nb[1] == (size_t) S_v * S_v * H * sizeof(float)); +- state_dst_d = (float *) src_state_dst->data; ++ GGML_ASSERT(src_state_dst->type == state_type); ++ GGML_ASSERT(src_state_dst->nb[0] == state_ts); ++ GGML_ASSERT(src_state_dst->nb[1] == (size_t) S_v * S_v * H * state_ts); ++ state_dst_d = src_state_dst->data; + } + + // Step 2: fused recurrent-state gather (src[7] = ids == s_copy). Read the prior state directly +@@ -398,18 +538,26 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * + // the recurrence never reads a slot another block writes, so it is race-free and bit-identical to + // the get_rows path. ids stays a DEVICE pointer (dereferenced only inside the kernels). + ggml_tensor * src_ids = dst->src[7]; +- const float * s_d = (const float *) src_state->data; ++ const void * s_d = src_state->data; + const int32_t * ids_d = nullptr; + int rs_head = 0; +- ggml_cuda_pool_alloc ids_state_scratch(ctx.pool()); +- if (src_ids != nullptr) { ++ // gather scratch shares the cache element type (state_ts bytes/elem) so the recurrence kernel's ++ // load conversion is uniform for in-place and gathered sequences (see gdn_gather_nonident_kernel). ++ ggml_cuda_pool_alloc ids_state_scratch(ctx.pool()); ++ if (!hybrid && src_ids != nullptr) { + GGML_ASSERT(state_dst_d != nullptr); + GGML_ASSERT(src_ids->type == GGML_TYPE_I32); + rs_head = ggml_get_op_params_i32(dst, 1); + ids_d = (const int32_t *) src_ids->data; + const int64_t D = S_v * S_v * H; +- float * scratch = ids_state_scratch.alloc((size_t) D * n_seqs); +- ggml_cuda_gdn_gather_nonident(s_d, ids_d, rs_head, scratch, D, n_seqs, ctx.stream()); ++ char * scratch = ids_state_scratch.alloc((size_t) D * n_seqs * state_ts); ++ if (state_bf16) { ++ ggml_cuda_gdn_gather_nonident((const nv_bfloat16 *) s_d, ids_d, rs_head, ++ (nv_bfloat16 *) scratch, D, n_seqs, ctx.stream()); ++ } else { ++ ggml_cuda_gdn_gather_nonident((const float *) s_d, ids_d, rs_head, ++ (float *) scratch, D, n_seqs, ctx.stream()); ++ } + s_d = scratch; + } + +@@ -444,25 +592,95 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * + // in-place write-back is only valid for the single-snapshot (final-state) case + GGML_ASSERT(state_dst_d == nullptr || !keep_rs); + +- if (kda) { +- if (keep_rs) { +- launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d, ids_d, rs_head, +- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, +- sb1, sb2, sb3, neqk1, rq3, scale, K, stream); ++ // Hybrid per-head mixed-dtype dispatch (lever A). The state lives in two partitions: src[5] f32 ++ // (long-memory heads) + src[8] bf16 (fast heads); head_slot maps each head to its partition+index. ++ // q/k/v/g/beta and the op output stay f32; only the per-head load/store crosses the dtype boundary. ++ if (hybrid) { ++ GGML_ASSERT(src_state->type == GGML_TYPE_F32); ++ GGML_ASSERT(src_state_bf16->type == GGML_TYPE_BF16); ++ GGML_ASSERT(src_head_slot->type == GGML_TYPE_I32); ++ const int n_f32 = (int) src_state->ne[2]; ++ const int n_bf16 = (int) src_state_bf16->ne[2]; ++ GGML_ASSERT((int64_t) n_f32 + n_bf16 == H); ++ const int64_t SS = S_v * S_v; ++ const int64_t D_f32 = SS * n_f32; ++ const int64_t D_bf16 = SS * n_bf16; ++ ++ gdn_hybrid_args hyb = {}; ++ hyb.head_slot = (const int32_t *) src_head_slot->data; ++ hyb.n_f32 = n_f32; ++ hyb.n_bf16 = n_bf16; ++ ++ const int32_t * hids = nullptr; ++ int hrs_head = 0; ++ ggml_cuda_pool_alloc hyb_scr_f32(ctx.pool()); ++ ggml_cuda_pool_alloc hyb_scr_bf16(ctx.pool()); ++ ++ if (src_ids != nullptr) { ++ // in-place ids decode: gather both partitions for non-identity seqs; identity reads dst. ++ GGML_ASSERT(src_state_dst != nullptr && src_state_dst->type == GGML_TYPE_F32); ++ GGML_ASSERT(src_ids->type == GGML_TYPE_I32); ++ GGML_ASSERT(src_state_dst->ne[0] == D_f32); ++ GGML_ASSERT(src_state_dst->nb[1] == (size_t) D_f32 * sizeof(float)); ++ hrs_head = ggml_get_op_params_i32(dst, 1); ++ hids = (const int32_t *) src_ids->data; ++ float * scr_f32 = hyb_scr_f32.alloc((size_t) D_f32 * n_seqs); ++ nv_bfloat16 * scr_bf16 = hyb_scr_bf16.alloc((size_t) D_bf16 * n_seqs); ++ ggml_cuda_gdn_gather_nonident((const float *) src_state->data, hids, hrs_head, ++ scr_f32, D_f32, n_seqs, stream); ++ ggml_cuda_gdn_gather_nonident((const nv_bfloat16 *) src_state_bf16->data, hids, hrs_head, ++ scr_bf16, D_bf16, n_seqs, stream); ++ hyb.rin_f32 = scr_f32; ++ hyb.rin_bf16 = scr_bf16; ++ // in-place dst bases, PRE-OFFSET to the rs_head slot. ++ hyb.d_f32 = (float *) src_state_dst->data; // view already at rs_head*D_f32 ++ hyb.d_bf16 = (nv_bfloat16 *) src_state_bf16->data + (int64_t) hrs_head * D_bf16; + } else { +- launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d, ids_d, rs_head, +- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, +- sb1, sb2, sb3, neqk1, rq3, scale, K, stream); ++ // output-append form (test / non-ids): read s0 partitions directly; write the f32 op output. ++ hyb.rin_f32 = (const float *) src_state->data; ++ hyb.rin_bf16 = (const nv_bfloat16 *) src_state_bf16->data; ++ hyb.d_f32 = nullptr; ++ hyb.d_bf16 = nullptr; + } +- } else { +- if (keep_rs) { +- launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d, ids_d, rs_head, +- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, +- sb1, sb2, sb3, neqk1, rq3, scale, K, stream); ++ ++ float * s_null = nullptr; ++#define GDN_DISPATCH_HYB(KDA_, KEEP_) \ ++ launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, \ ++ (const float *) s_null, dst_d, (float *) s_null, hids, hrs_head, hyb, \ ++ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, \ ++ sb1, sb2, sb3, neqk1, rq3, scale, K, stream) ++ if (kda) { ++ if (keep_rs) { GDN_DISPATCH_HYB(true, true); } else { GDN_DISPATCH_HYB(true, false); } + } else { +- launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_dst_d, ids_d, rs_head, +- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, +- sb1, sb2, sb3, neqk1, rq3, scale, K, stream); ++ if (keep_rs) { GDN_DISPATCH_HYB(false, true); } else { GDN_DISPATCH_HYB(false, false); } + } ++#undef GDN_DISPATCH_HYB ++ return; ++ } ++ ++ // Dispatch on (kda, keep_rs, state_bf16). The state pointers (s_d / state_dst_d) are typed by the ++ // cache dtype; q/k/v/g/beta and the op output (dst_d) stay f32. f32 is the default; bf16 opt-in. ++ // hyb_zero is unused on the homogeneous path (HYBRID=false elides every hyb read at compile time). ++ gdn_hybrid_args hyb_zero = {}; ++#define GDN_DISPATCH(KDA_, KEEP_) \ ++ do { \ ++ if (state_bf16) { \ ++ launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, \ ++ (const nv_bfloat16 *) s_d, dst_d, (nv_bfloat16 *) state_dst_d, ids_d, rs_head, hyb_zero,\ ++ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, \ ++ sb1, sb2, sb3, neqk1, rq3, scale, K, stream); \ ++ } else { \ ++ launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, \ ++ (const float *) s_d, dst_d, (float *) state_dst_d, ids_d, rs_head, hyb_zero, \ ++ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, \ ++ sb1, sb2, sb3, neqk1, rq3, scale, K, stream); \ ++ } \ ++ } while (0) ++ ++ if (kda) { ++ if (keep_rs) { GDN_DISPATCH(true, true); } else { GDN_DISPATCH(true, false); } ++ } else { ++ if (keep_rs) { GDN_DISPATCH(false, true); } else { GDN_DISPATCH(false, false); } + } ++#undef GDN_DISPATCH + } +diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c +index 0b8798d..16b180f 100644 +--- a/ggml/src/ggml.c ++++ b/ggml/src/ggml.c +@@ -5257,7 +5257,7 @@ static struct ggml_tensor * ggml_fill_impl( + struct ggml_tensor * a, + float c, + bool inplace) { +- GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16); ++ GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16); + GGML_ASSERT(ggml_is_contiguous(a)); + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); +@@ -6302,7 +6302,9 @@ struct ggml_tensor * ggml_gated_delta_net( + GGML_ASSERT(v->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(beta->type == GGML_TYPE_F32); +- GGML_ASSERT(state->type == GGML_TYPE_F32); ++ // recurrent SSM state cache may be f32 (default, bit-exact) or bf16 (opt-in width halving); the ++ // recurrence math is always f32 in registers (load->f32, store->cache dtype). ++ GGML_ASSERT(state->type == GGML_TYPE_F32 || state->type == GGML_TYPE_BF16); + + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; +@@ -6364,9 +6366,9 @@ struct ggml_tensor * ggml_gated_delta_net_inplace( + GGML_ASSERT(v->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(beta->type == GGML_TYPE_F32); +- GGML_ASSERT(state->type == GGML_TYPE_F32); ++ GGML_ASSERT(state->type == GGML_TYPE_F32 || state->type == GGML_TYPE_BF16); + GGML_ASSERT(state_dst != NULL); +- GGML_ASSERT(state_dst->type == GGML_TYPE_F32); ++ GGML_ASSERT(state_dst->type == GGML_TYPE_F32 || state_dst->type == GGML_TYPE_BF16); + + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; +@@ -6384,7 +6386,7 @@ struct ggml_tensor * ggml_gated_delta_net_inplace( + // state_dst holds the per-seq final state contiguously: [S_v*S_v*H, >= n_seqs] + GGML_ASSERT(state_dst->ne[0] == S_v * S_v * H); + GGML_ASSERT(state_dst->ne[1] >= n_seqs); +- GGML_ASSERT(state_dst->nb[0] == sizeof(float)); ++ GGML_ASSERT(state_dst->nb[0] == ggml_type_size(state_dst->type)); + + const int64_t state_rows = S_v * n_seqs; // K == 1 + const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; +@@ -6434,8 +6436,8 @@ struct ggml_tensor * ggml_gated_delta_net_inplace_ids( + GGML_ASSERT(v->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(beta->type == GGML_TYPE_F32); +- GGML_ASSERT(state->type == GGML_TYPE_F32); +- GGML_ASSERT(state_dst != NULL && state_dst->type == GGML_TYPE_F32); ++ GGML_ASSERT(state->type == GGML_TYPE_F32 || state->type == GGML_TYPE_BF16); ++ GGML_ASSERT(state_dst != NULL && (state_dst->type == GGML_TYPE_F32 || state_dst->type == GGML_TYPE_BF16)); + GGML_ASSERT(ids != NULL && ids->type == GGML_TYPE_I32); + + const int64_t S_v = v->ne[0]; +@@ -6455,7 +6457,7 @@ struct ggml_tensor * ggml_gated_delta_net_inplace_ids( + // state_dst holds the per-seq final state contiguously: [S_v*S_v*H, >= n_seqs] + GGML_ASSERT(state_dst->ne[0] == S_v * S_v * H); + GGML_ASSERT(state_dst->ne[1] >= n_seqs); +- GGML_ASSERT(state_dst->nb[0] == sizeof(float)); ++ GGML_ASSERT(state_dst->nb[0] == ggml_type_size(state_dst->type)); + + // ids: per-seq source slot into the full cache (s_copy_main) + GGML_ASSERT(ids->ne[0] >= n_seqs); +@@ -6480,6 +6482,158 @@ struct ggml_tensor * ggml_gated_delta_net_inplace_ids( + return result; + } + ++// ggml_gated_delta_net_hybrid ++// ++// Per-head mixed-dtype recurrent state (lever A), output-append form. The persisted state is split ++// into two dtype-homogeneous partitions sized by head COUNT: state_f32 [S_v,S_v,n_f32,n_seqs] (f32, ++// long-memory heads) and state_bf16 [S_v,S_v,n_bf16,n_seqs] (bf16, fast heads). head_slot is I32[H] ++// mapping head h -> local_idx in its partition: encoded local_idx (>=0) for f32, -(local_idx+1) (<0) ++// for bf16. q/k/v/g/beta keep natural head order. The recurrence runs in f32 registers; only the ++// per-head load crosses the dtype boundary. The op output is FULL f32 (attn + state region), like ++// ggml_gated_delta_net. ++struct ggml_tensor * ggml_gated_delta_net_hybrid( ++ struct ggml_context * ctx, ++ struct ggml_tensor * q, ++ struct ggml_tensor * k, ++ struct ggml_tensor * v, ++ struct ggml_tensor * g, ++ struct ggml_tensor * beta, ++ struct ggml_tensor * state_f32, ++ struct ggml_tensor * state_bf16, ++ struct ggml_tensor * head_slot, ++ int64_t K) { ++ GGML_ASSERT(ggml_is_contiguous_rows(q)); ++ GGML_ASSERT(ggml_is_contiguous_rows(k)); ++ GGML_ASSERT(ggml_is_contiguous_rows(v)); ++ GGML_ASSERT(ggml_is_contiguous(g)); ++ GGML_ASSERT(ggml_is_contiguous(beta)); ++ GGML_ASSERT(ggml_is_contiguous(state_f32)); ++ GGML_ASSERT(ggml_is_contiguous(state_bf16)); ++ ++ GGML_ASSERT(q->type == GGML_TYPE_F32); ++ GGML_ASSERT(k->type == GGML_TYPE_F32); ++ GGML_ASSERT(v->type == GGML_TYPE_F32); ++ GGML_ASSERT(g->type == GGML_TYPE_F32); ++ GGML_ASSERT(beta->type == GGML_TYPE_F32); ++ GGML_ASSERT(state_f32->type == GGML_TYPE_F32); ++ GGML_ASSERT(state_bf16->type == GGML_TYPE_BF16); ++ GGML_ASSERT(head_slot != NULL && head_slot->type == GGML_TYPE_I32); ++ ++ const int64_t S_v = v->ne[0]; ++ const int64_t H = v->ne[1]; ++ const int64_t n_tokens = v->ne[2]; ++ const int64_t n_seqs = v->ne[3]; ++ ++ GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); ++ GGML_ASSERT(beta->ne[0] == 1); ++ ++ const int64_t n_f32 = state_f32->ne[2]; ++ const int64_t n_bf16 = state_bf16->ne[2]; ++ GGML_ASSERT(state_f32->ne[0] == S_v && state_f32->ne[1] == S_v && state_f32->ne[3] == n_seqs); ++ GGML_ASSERT(state_bf16->ne[0] == S_v && state_bf16->ne[1] == S_v && state_bf16->ne[3] == n_seqs); ++ GGML_ASSERT(n_f32 + n_bf16 == H); ++ GGML_ASSERT(head_slot->ne[0] >= H); ++ GGML_ASSERT(K >= 1); ++ ++ const int64_t state_rows = K * S_v * n_seqs; ++ const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; ++ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); ++ ++ ggml_set_op_params_i32(result, 0, (int32_t) K); ++ ++ result->op = GGML_OP_GATED_DELTA_NET; ++ result->src[0] = q; ++ result->src[1] = k; ++ result->src[2] = v; ++ result->src[3] = g; ++ result->src[4] = beta; ++ result->src[5] = state_f32; // f32 partition (long-memory heads) ++ result->src[8] = state_bf16; // bf16 partition (fast heads) ++ result->src[9] = head_slot; // per-head -> partition + local index ++ ++ return result; ++} ++ ++// ggml_gated_delta_net_inplace_ids_hybrid ++// ++// Per-head mixed-dtype recurrent state (lever A), in-place ids form (the GDN decode path). state_f32 ++// /state_bf16 are the FULL split caches [S_v,S_v,n_part,n_rs_slots]; state_dst_f32 is the in-place ++// f32 write view at rs_head; the bf16 write view is derived in the backend from state_bf16 + rs_head. ++struct ggml_tensor * ggml_gated_delta_net_inplace_ids_hybrid( ++ struct ggml_context * ctx, ++ struct ggml_tensor * q, ++ struct ggml_tensor * k, ++ struct ggml_tensor * v, ++ struct ggml_tensor * g, ++ struct ggml_tensor * beta, ++ struct ggml_tensor * state_f32, ++ struct ggml_tensor * state_dst_f32, ++ struct ggml_tensor * ids, ++ struct ggml_tensor * state_bf16, ++ struct ggml_tensor * head_slot, ++ int rs_head) { ++ GGML_ASSERT(ggml_is_contiguous_rows(q)); ++ GGML_ASSERT(ggml_is_contiguous_rows(k)); ++ GGML_ASSERT(ggml_is_contiguous_rows(v)); ++ GGML_ASSERT(ggml_is_contiguous(g)); ++ GGML_ASSERT(ggml_is_contiguous(beta)); ++ GGML_ASSERT(ggml_is_contiguous(state_f32)); ++ GGML_ASSERT(ggml_is_contiguous(state_bf16)); ++ ++ GGML_ASSERT(q->type == GGML_TYPE_F32); ++ GGML_ASSERT(k->type == GGML_TYPE_F32); ++ GGML_ASSERT(v->type == GGML_TYPE_F32); ++ GGML_ASSERT(g->type == GGML_TYPE_F32); ++ GGML_ASSERT(beta->type == GGML_TYPE_F32); ++ GGML_ASSERT(state_f32->type == GGML_TYPE_F32); ++ GGML_ASSERT(state_bf16->type == GGML_TYPE_BF16); ++ GGML_ASSERT(state_dst_f32 != NULL && state_dst_f32->type == GGML_TYPE_F32); ++ GGML_ASSERT(ids != NULL && ids->type == GGML_TYPE_I32); ++ GGML_ASSERT(head_slot != NULL && head_slot->type == GGML_TYPE_I32); ++ ++ const int64_t S_v = v->ne[0]; ++ const int64_t H = v->ne[1]; ++ const int64_t n_tokens = v->ne[2]; ++ const int64_t n_seqs = v->ne[3]; ++ ++ GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); ++ GGML_ASSERT(beta->ne[0] == 1); ++ ++ const int64_t n_f32 = state_f32->ne[2]; ++ const int64_t n_bf16 = state_bf16->ne[2]; ++ GGML_ASSERT(state_f32->ne[0] == S_v && state_f32->ne[1] == S_v && state_f32->ne[3] >= n_seqs); ++ GGML_ASSERT(state_bf16->ne[0] == S_v && state_bf16->ne[1] == S_v && state_bf16->ne[3] >= n_seqs); ++ GGML_ASSERT(n_f32 + n_bf16 == H); ++ GGML_ASSERT(head_slot->ne[0] >= H); ++ ++ // state_dst_f32 holds the per-seq f32-partition final state contiguously: [S_v*S_v*n_f32, >= n_seqs] ++ GGML_ASSERT(state_dst_f32->ne[0] == S_v * S_v * n_f32); ++ GGML_ASSERT(state_dst_f32->ne[1] >= n_seqs); ++ GGML_ASSERT(state_dst_f32->nb[0] == sizeof(float)); ++ GGML_ASSERT(ids->ne[0] >= n_seqs); ++ ++ const int64_t state_rows = S_v * n_seqs; // K == 1 ++ const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; ++ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); ++ ++ ggml_set_op_params_i32(result, 0, 1); // K == 1 ++ ggml_set_op_params_i32(result, 1, rs_head); // destination base slot ++ ++ result->op = GGML_OP_GATED_DELTA_NET; ++ result->src[0] = q; ++ result->src[1] = k; ++ result->src[2] = v; ++ result->src[3] = g; ++ result->src[4] = beta; ++ result->src[5] = state_f32; // FULL f32 cache (read via ids / gather) ++ result->src[6] = state_dst_f32; // in-place f32 final-state write view at rs_head ++ result->src[7] = ids; // per-seq source slots (s_copy) ++ result->src[8] = state_bf16; // FULL bf16 cache (read via ids; dst derived from rs_head) ++ result->src[9] = head_slot; // per-head -> partition + local index ++ ++ return result; ++} ++ + //////////////////////////////////////////////////////////////////////////////// + + struct ggml_hash_set ggml_hash_set_new(size_t size) { +diff --git a/include/llama.h b/include/llama.h +index f723c9f..74d8599 100644 +--- a/include/llama.h ++++ b/include/llama.h +@@ -364,6 +364,12 @@ extern "C" { + + enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] + enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] ++ enum ggml_type type_r; // data type for the recurrent conv-state cache (default f32) [EXPERIMENTAL] ++ enum ggml_type type_s; // data type for the recurrent SSM-state cache (default f32; bf16 opt-in) [EXPERIMENTAL] ++ // Hybrid per-head SSM-state precision (lever A): a head is kept f32 iff its memory length ++ // tau_h = 1/(|ssm_a|*softplus(ssm_dt)) (tokens) > this threshold; else bf16. Default 0.0 keeps ++ // EVERY head f32 (bit-exact). Raise it (e.g. 32/64) to bf16 the fast-decaying heads. [EXPERIMENTAL] ++ float ssm_hybrid_tau_thresh; + + // Abort callback + // if it returns true, execution of llama_decode() will be aborted +diff --git a/src/llama-context.cpp b/src/llama-context.cpp +index 220240e..5c90c48 100644 +--- a/src/llama-context.cpp ++++ b/src/llama-context.cpp +@@ -201,6 +201,8 @@ llama_context::llama_context( + cparams.fused_gdn_ch = true; + cparams.auto_fgdn = true; + ++ cparams.ssm_hybrid_tau_thresh = params.ssm_hybrid_tau_thresh; ++ + // with causal attention, the batch size is limited by the context size + cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; + +@@ -325,6 +327,9 @@ llama_context::llama_context( + llama_memory_params params_mem = { + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, ++ /*.type_r =*/ params.type_r, ++ /*.type_s =*/ params.type_s, ++ /*.ssm_hybrid_tau_thresh =*/ params.ssm_hybrid_tau_thresh, + /*.swa_full =*/ params.swa_full, + /*.ctx_type =*/ cparams.ctx_type, + /*.mem_other =*/ llama_get_memory(cparams.ctx_other), +@@ -3471,6 +3476,9 @@ llama_context_params llama_context_default_params() { + /*.cb_eval_user_data =*/ nullptr, + /*.type_k =*/ GGML_TYPE_F16, + /*.type_v =*/ GGML_TYPE_F16, ++ /*.type_r =*/ GGML_TYPE_F32, // recurrent conv-state cache: f32 (bit-exact default) ++ /*.type_s =*/ GGML_TYPE_F32, // recurrent SSM-state cache: f32 default, bf16 opt-in ++ /*.ssm_hybrid_tau_thresh =*/ 0.0f, // 0 => all heads f32 (bit-exact default) + /*.abort_callback =*/ nullptr, + /*.abort_callback_data =*/ nullptr, + /*.embeddings =*/ false, +diff --git a/src/llama-cparams.h b/src/llama-cparams.h +index 546ae1e..888f9d7 100644 +--- a/src/llama-cparams.h ++++ b/src/llama-cparams.h +@@ -31,6 +31,9 @@ struct llama_cparams { + float yarn_beta_fast; + float yarn_beta_slow; + ++ // hybrid per-head SSM-state precision threshold (lever A); 0.0 => all heads f32 (bit-exact default) ++ float ssm_hybrid_tau_thresh = 0.0f; ++ + bool embeddings; + bool embeddings_nextn; // also extract the hidden state before the final output norm + bool embeddings_nextn_masked; // extract for only rows where batch.logits != 0 +diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp +index abdb48d..931258d 100644 +--- a/src/llama-graph.cpp ++++ b/src/llama-graph.cpp +@@ -2791,7 +2791,13 @@ ggml_tensor * llm_graph_context::build_rs( + // Clear a single state which will then be copied to the other cleared states. + // Note that this is a no-op when the view is zero-sized. + ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); +- ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); ++ // ggml_scale is f32-only on CUDA; for a bf16 (opt-in) SSM-state cache use ggml_fill, which writes ++ // the cache dtype directly. The f32 conv/SSM path keeps the exact scale-by-0 op (bit-exactness). ++ if (states->type == GGML_TYPE_F32) { ++ ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); ++ } else { ++ ggml_build_forward_expand(gf, ggml_fill_inplace(ctx0, state_zero, 0.0f)); ++ } + + // copy states + // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs +diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp +index c7d4bcd..d9893bb 100644 +--- a/src/llama-memory-hybrid-iswa.cpp ++++ b/src/llama-memory-hybrid-iswa.cpp +@@ -29,7 +29,8 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn, +- const layer_filter_cb & filter_recr) : ++ const layer_filter_cb & filter_recr, ++ float recurrent_ssm_hybrid_tau) : + hparams(model.hparams), + mem_attn(new llama_kv_cache_iswa( + model, +@@ -60,7 +61,8 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( + n_rs_seq, + filter_recr == nullptr ? + [&](int32_t il) { return hparams.is_recr(il); } +- : filter_recr ++ : filter_recr, ++ recurrent_ssm_hybrid_tau + )) {} + + llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { +diff --git a/src/llama-memory-hybrid-iswa.h b/src/llama-memory-hybrid-iswa.h +index c9d3f9f..6be9cf2 100644 +--- a/src/llama-memory-hybrid-iswa.h ++++ b/src/llama-memory-hybrid-iswa.h +@@ -39,7 +39,8 @@ public: + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn = nullptr, +- const layer_filter_cb & filter_recr = nullptr); ++ const layer_filter_cb & filter_recr = nullptr, ++ float recurrent_ssm_hybrid_tau = 0.0f); + + ~llama_memory_hybrid_iswa() = default; + +diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp +index f2d49cb..dc3084d 100644 +--- a/src/llama-memory-hybrid.cpp ++++ b/src/llama-memory-hybrid.cpp +@@ -29,7 +29,8 @@ llama_memory_hybrid::llama_memory_hybrid( + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn, +- const layer_filter_cb & filter_recr) : ++ const layer_filter_cb & filter_recr, ++ float recurrent_ssm_hybrid_tau) : + hparams(model.hparams), + mem_attn(new llama_kv_cache( + model, +@@ -61,7 +62,8 @@ llama_memory_hybrid::llama_memory_hybrid( + n_rs_seq, + filter_recr == nullptr ? + [&](int32_t il) { return hparams.is_recr(il); } +- : filter_recr ++ : filter_recr, ++ recurrent_ssm_hybrid_tau + )) {} + + llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { +diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h +index 484eafb..46618d3 100644 +--- a/src/llama-memory-hybrid.h ++++ b/src/llama-memory-hybrid.h +@@ -39,7 +39,9 @@ public: + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn = nullptr, +- const layer_filter_cb & filter_recr = nullptr); ++ const layer_filter_cb & filter_recr = nullptr, ++ /* hybrid per-head SSM precision (lever A); 0 => all f32 */ ++ float recurrent_ssm_hybrid_tau = 0.0f); + + ~llama_memory_hybrid() = default; + +diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp +index 6a4892f..cb7a9ed 100644 +--- a/src/llama-memory-recurrent.cpp ++++ b/src/llama-memory-recurrent.cpp +@@ -8,6 +8,7 @@ + + #include + #include ++#include + #include + #include + #include +@@ -25,7 +26,8 @@ llama_memory_recurrent::llama_memory_recurrent( + uint32_t mem_size, + uint32_t n_seq_max, + uint32_t n_rs_seq, +- const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { ++ const layer_filter_cb & filter, ++ float ssm_hybrid_tau_thresh) : hparams(model.hparams), n_seq_max(n_seq_max) { + const int32_t n_layer = hparams.n_layer(); + + head = 0; +@@ -51,7 +53,8 @@ llama_memory_recurrent::llama_memory_recurrent( + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { +- /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), ++ // up to 4 tensors per layer when hybrid SSM state is split (r, s_f32, s_bf16, head_slot) ++ /*.mem_size =*/ size_t(4u*n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; +@@ -71,6 +74,20 @@ llama_memory_recurrent::llama_memory_recurrent( + + r_l.resize(n_layer); + s_l.resize(n_layer); ++ s_l_bf16.assign(n_layer, nullptr); ++ head_slot_l.assign(n_layer, nullptr); ++ n_s_f32.assign(n_layer, 0); ++ n_s_bf16.assign(n_layer, 0); ++ ++ // Hybrid per-head SSM-state classifier (lever A). For each GDN layer, tau_h = 1/(|ssm_a[h]| * ++ // softplus(ssm_dt[h])) is the head's memory length in tokens. A head is kept f32 iff tau_h > ++ // ssm_hybrid_tau_thresh, else bf16. Default thresh 0.0 => every tau>0 => ALL heads f32 (no split, ++ // bit-exact). The head_slot map is uploaded after the backend buffers are allocated. ++ const bool want_hybrid = (ssm_hybrid_tau_thresh > 0.0f); ++ const int64_t H_v = (int64_t) hparams.ssm_dt_rank; // n_v_heads ++ const int64_t SS = (H_v > 0) ? (int64_t) hparams.n_embd_s() / H_v : 0; // S_v*S_v ++ auto softplus = [](float x) -> float { return x > 20.0f ? x : log1pf(expf(x)); }; ++ std::vector> head_slot_host(n_layer); + + for (int i = 0; i < n_layer; i++) { + if (filter && !filter(i)) { +@@ -98,11 +115,62 @@ llama_memory_recurrent::llama_memory_recurrent( + + const uint32_t n_rows = mem_size * (1 + n_rs_seq); + ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows); +- ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows); + ggml_format_name(r, "cache_r_l%d", i); +- ggml_format_name(s, "cache_s_l%d", i); + r_l[i] = r; +- s_l[i] = s; ++ ++ // classify this GDN layer's heads (only when hybrid is requested, the SSM weights exist as ++ // f32, and the layer is a true GDN layer with n_embd_s split cleanly into H_v heads). ++ int n_f32 = (int) H_v; ++ int n_bf16 = 0; ++ std::vector hslot; ++ const llama_layer & lay = model.layers[i]; ++ // NOTE: data may be NULL during the device-memory-fitting pre-pass (weights not yet allocated). ++ // In that case fall back to the single f32 cache (a conservative, larger memory estimate); the ++ // real load runs with data populated and performs the actual classification. ++ if (want_hybrid && H_v > 0 && SS * H_v == (int64_t) hparams.n_embd_s() && ++ lay.ssm_a != nullptr && lay.ssm_dt != nullptr && ++ lay.ssm_a->data != nullptr && lay.ssm_dt->data != nullptr && ++ lay.ssm_a->type == GGML_TYPE_F32 && lay.ssm_dt->type == GGML_TYPE_F32 && ++ ggml_nelements(lay.ssm_a) >= H_v && ggml_nelements(lay.ssm_dt) >= H_v) { ++ std::vector a_host(H_v), dt_host(H_v); ++ ggml_backend_tensor_get(lay.ssm_a, a_host.data(), 0, H_v * sizeof(float)); ++ ggml_backend_tensor_get(lay.ssm_dt, dt_host.data(), 0, H_v * sizeof(float)); ++ hslot.assign(H_v, 0); ++ int lf = 0, lb = 0; ++ for (int64_t h = 0; h < H_v; ++h) { ++ const float denom = fabsf(a_host[h]) * softplus(dt_host[h]); ++ const float tau = denom > 0.0f ? 1.0f / denom : INFINITY; ++ const bool is_f32 = (tau > ssm_hybrid_tau_thresh); ++ if (is_f32) { hslot[h] = lf++; } ++ else { hslot[h] = -(lb++ + 1); } ++ } ++ n_f32 = lf; ++ n_bf16 = lb; ++ } ++ ++ const bool split = (n_bf16 > 0 && n_f32 > 0); ++ if (split) { ++ ssm_hybrid = true; ++ ggml_tensor * s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, SS * n_f32, n_rows); ++ ggml_tensor * sb = ggml_new_tensor_2d(ctx, GGML_TYPE_BF16, SS * n_bf16, n_rows); ++ ggml_tensor * hsl = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, H_v); ++ ggml_format_name(s, "cache_s_l%d", i); ++ ggml_format_name(sb, "cache_s_bf16_l%d", i); ++ ggml_format_name(hsl, "cache_headslot_l%d", i); ++ s_l[i] = s; ++ s_l_bf16[i] = sb; ++ head_slot_l[i] = hsl; ++ head_slot_host[i] = std::move(hslot); ++ LLAMA_LOG_INFO("%s, layer %3d: hybrid SSM state n_f32=%d n_bf16=%d (f_bytes=%.3f)\n", ++ __func__, i, n_f32, n_bf16, (float)(n_f32 + n_bf16 * 0.5f) / (float) H_v); ++ } else { ++ // not split: single f32 cache for ALL heads (the existing byte-identical path). ++ ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows); ++ ggml_format_name(s, "cache_s_l%d", i); ++ s_l[i] = s; ++ } ++ n_s_f32[i] = n_f32; ++ n_s_bf16[i] = split ? n_bf16 : 0; + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding +@@ -116,6 +184,14 @@ llama_memory_recurrent::llama_memory_recurrent( + ctxs_bufs.emplace_back(std::move(ctx), buf); + } + ++ // upload the per-layer head_slot maps now that the backend buffers exist (split layers only). ++ for (int i = 0; i < n_layer; i++) { ++ if (head_slot_l[i] != nullptr && !head_slot_host[i].empty()) { ++ ggml_backend_tensor_set(head_slot_l[i], head_slot_host[i].data(), 0, ++ head_slot_host[i].size() * sizeof(int32_t)); ++ } ++ } ++ + { + const size_t memory_size_r = size_r_bytes(); + const size_t memory_size_s = size_s_bytes(); +@@ -729,6 +805,17 @@ size_t llama_memory_recurrent::size_s_bytes() const { + size_s_bytes += ggml_nbytes(s); + } + } ++ // hybrid (lever A): also count the bf16 partitions (+ the tiny head_slot maps). ++ for (const auto & s : s_l_bf16) { ++ if (s != nullptr) { ++ size_s_bytes += ggml_nbytes(s); ++ } ++ } ++ for (const auto & h : head_slot_l) { ++ if (h != nullptr) { ++ size_s_bytes += ggml_nbytes(h); ++ } ++ } + + return size_s_bytes; + } +@@ -1041,6 +1128,44 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell + return true; + } + ++// Back-compat for recurrent state files whose saved dtype differs from the live cache dtype, limited ++// to f32<->bf16 (the SSM-state opt-in). Reads cell_count rows of n_embd elements in the SAVED dtype, ++// converts to the tensor's live dtype, and writes them at `head`. `size_row_ref` is the saved row ++// size already read from the stream. Returns false for any non-f32<->bf16 mismatch or size mismatch. ++static bool recurrent_read_convert_rows(llama_io_read_i & io, ggml_tensor * t, ggml_type saved_type, ++ uint64_t size_row_ref, uint32_t n_embd, ++ uint32_t head, uint32_t cell_count, const char * tag, int il) { ++ const ggml_type live_type = t->type; ++ const bool ok_pair = ++ (saved_type == GGML_TYPE_F32 && live_type == GGML_TYPE_BF16) || ++ (saved_type == GGML_TYPE_BF16 && live_type == GGML_TYPE_F32); ++ if (!ok_pair) { ++ LLAMA_LOG_ERROR("%s: mismatched %s type (%d != %d, layer %d) is not an f32<->bf16 conversion\n", ++ __func__, tag, (int) live_type, (int) saved_type, il); ++ return false; ++ } ++ const size_t saved_row = ggml_row_size(saved_type, n_embd); ++ if (saved_row != (size_t) size_row_ref) { ++ LLAMA_LOG_ERROR("%s: mismatched %s row size (%zu != %zu, layer %d)\n", ++ __func__, tag, saved_row, (size_t) size_row_ref, il); ++ return false; ++ } ++ if (cell_count == 0) { ++ return true; ++ } ++ const int64_t n_el = (int64_t) cell_count * n_embd; ++ std::vector raw((size_t) n_el * ggml_type_size(saved_type)); ++ io.read(raw.data(), raw.size()); ++ std::vector conv((size_t) n_el * ggml_type_size(live_type)); ++ if (saved_type == GGML_TYPE_BF16) { ++ ggml_bf16_to_fp32_row((const ggml_bf16_t *) raw.data(), (float *) conv.data(), n_el); ++ } else { ++ ggml_fp32_to_bf16_row((const float *) raw.data(), (ggml_bf16_t *) conv.data(), n_el); ++ } ++ ggml_backend_tensor_set(t, conv.data(), (size_t) head * ggml_row_size(live_type, n_embd), conv.size()); ++ return true; ++} ++ + bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t s_trans; + uint32_t n_layer; +@@ -1069,14 +1194,20 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell + int32_t r_type_i_ref; + io.read(&r_type_i_ref, sizeof(r_type_i_ref)); + const int32_t r_type_i = (int32_t) r_l[il]->type; +- if (r_type_i != r_type_i_ref) { +- LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il); +- return false; +- } + + // Read row size of key + uint64_t r_size_row_ref; + io.read(&r_size_row_ref, sizeof(r_size_row_ref)); ++ ++ if (r_type_i != r_type_i_ref) { ++ // back-compat: convert f32<->bf16 saved rows into the live cache dtype; else hard error. ++ if (!recurrent_read_convert_rows(io, r_l[il], (ggml_type) r_type_i_ref, r_size_row_ref, ++ hparams.n_embd_r(), head, cell_count, "r", il)) { ++ return false; ++ } ++ continue; ++ } ++ + const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); + if (r_size_row != r_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il); +@@ -1099,14 +1230,21 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell + io.read(&s_type_i_ref, sizeof(s_type_i_ref)); + const int32_t s_type_i = (int32_t)s_l[il]->type; + +- if (s_type_i != s_type_i_ref) { +- LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il); +- return false; +- } +- + // Read row size of value + uint64_t s_size_row_ref; + io.read(&s_size_row_ref, sizeof(s_size_row_ref)); ++ ++ if (s_type_i != s_type_i_ref) { ++ // back-compat: convert f32<->bf16 saved rows into the live cache dtype; else hard error. ++ // The SSM-state opt-in flips this dtype, so an f32-saved session can restore into a ++ // bf16 cache (and vice versa) instead of failing the hard type match. ++ if (!recurrent_read_convert_rows(io, s_l[il], (ggml_type) s_type_i_ref, s_size_row_ref, ++ hparams.n_embd_s(), head, cell_count, "s", il)) { ++ return false; ++ } ++ continue; ++ } ++ + const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); + if (s_size_row != s_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il); +@@ -1241,6 +1379,18 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { + return mem->s_l[il]; + } + ++bool llama_memory_recurrent_context::is_s_hybrid(int32_t il) const { ++ return il >= 0 && il < (int32_t) mem->s_l_bf16.size() && mem->s_l_bf16[il] != nullptr; ++} ++ ++ggml_tensor * llama_memory_recurrent_context::get_s_l_bf16(int32_t il) const { ++ return mem->s_l_bf16[il]; ++} ++ ++ggml_tensor * llama_memory_recurrent_context::get_head_slot(int32_t il) const { ++ return mem->head_slot_l[il]; ++} ++ + int32_t llama_memory_recurrent_context::s_copy(int i) const { + const uint32_t cell_idx = i + mem->head; + const int32_t src0 = mem->cells[cell_idx].src0; +diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h +index b13b7b7..bd8bbd9 100644 +--- a/src/llama-memory-recurrent.h ++++ b/src/llama-memory-recurrent.h +@@ -24,7 +24,8 @@ public: + uint32_t mem_size, + uint32_t n_seq_max, + uint32_t n_rs_seq, +- const layer_filter_cb & filter); ++ const layer_filter_cb & filter, ++ float ssm_hybrid_tau_thresh = 0.0f); + + ~llama_memory_recurrent() = default; + +@@ -112,6 +113,19 @@ public: + std::vector r_l; + std::vector s_l; + ++ // hybrid per-head SSM-state precision (lever A). When a GDN layer has both f32 and bf16 heads, ++ // s_l[il] holds ONLY the f32-head partition ([S_v*S_v*n_s_f32, n_rows]) and s_l_bf16[il] holds ++ // the bf16-head partition ([S_v*S_v*n_s_bf16, n_rows], bf16). head_slot_l[il] is an I32[H] map: ++ // head h -> local_idx (>=0) in the f32 partition, or -(local_idx+1) (<0) in the bf16 partition. ++ // For non-hybrid layers (default, or n_s_bf16==0) s_l[il] holds ALL heads at f32 and the rest are ++ // null/0 (the existing single-cache path is taken, byte-identical). ssm_hybrid is true if ANY ++ // layer is split. ++ bool ssm_hybrid = false; ++ std::vector s_l_bf16; // bf16 partition per layer (null when layer not split) ++ std::vector head_slot_l; // I32[H] per-head partition map (null when layer not split) ++ std::vector n_s_f32; // f32-head count per layer ++ std::vector n_s_bf16; // bf16-head count per layer (0 => layer not split) ++ + private: + //const llama_model & model; + const llama_hparams & hparams; +@@ -171,6 +185,13 @@ public: + ggml_tensor * get_r_l(int32_t il) const; + ggml_tensor * get_s_l(int32_t il) const; + ++ // hybrid per-head SSM precision (lever A): the bf16 partition + the head_slot map for layer il, ++ // and whether the layer is split. When not split, get_s_l_bf16 / get_head_slot return null and ++ // is_s_hybrid is false (the caller takes the existing single-cache f32 path). ++ bool is_s_hybrid(int32_t il) const; ++ ggml_tensor * get_s_l_bf16(int32_t il) const; ++ ggml_tensor * get_head_slot(int32_t il) const; ++ + int32_t s_copy(int i) const; + + private: +diff --git a/src/llama-memory.h b/src/llama-memory.h +index db82539..f512283 100644 +--- a/src/llama-memory.h ++++ b/src/llama-memory.h +@@ -19,6 +19,13 @@ struct llama_memory_params { + ggml_type type_k; + ggml_type type_v; + ++ // recurrent state cache (conv state type_r, SSM state type_s) ++ ggml_type type_r; ++ ggml_type type_s; ++ ++ // hybrid per-head SSM-state precision threshold (lever A); 0.0 => all heads f32 (bit-exact) ++ float ssm_hybrid_tau_thresh; ++ + // use full-size SWA cache + bool swa_full; + +diff --git a/src/llama-model.cpp b/src/llama-model.cpp +index 6cb0ec3..d98b1d6 100644 +--- a/src/llama-model.cpp ++++ b/src/llama-model.cpp +@@ -2054,13 +2054,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, + if (llm_arch_is_recurrent(arch)) { + res = new llama_memory_recurrent( + *this, +- GGML_TYPE_F32, +- GGML_TYPE_F32, ++ params.type_r, ++ params.type_s, + cparams.offload_kqv, + std::max((uint32_t) 1, cparams.n_seq_max), + cparams.n_seq_max, + cparams.n_rs_seq, +- nullptr); ++ nullptr, ++ params.ssm_hybrid_tau_thresh); + } else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) { + // The main difference between hybrid architectures is the + // layer filters, so pick the right one here +@@ -2096,15 +2097,16 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, + /* attn_kv_size */ cparams.n_ctx_seq, + /* attn_n_ubatch */ cparams.n_ubatch, + /* attn_n_pad */ 1, +- /* recurrent_type_r */ GGML_TYPE_F32, +- /* recurrent_type_s */ GGML_TYPE_F32, ++ /* recurrent_type_r */ params.type_r, ++ /* recurrent_type_s */ params.type_s, + /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), +- /* filter_recr */ std::move(filter_recr)); ++ /* filter_recr */ std::move(filter_recr), ++ /* recurrent_ssm_hybrid_tau */ params.ssm_hybrid_tau_thresh); + } else { + res = new llama_memory_hybrid( + /* model */ *this, +@@ -2115,15 +2117,16 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, + /* attn_n_pad */ 1, + /* attn_n_swa */ hparams.n_swa, + /* attn_swa_type */ hparams.swa_type, +- /* recurrent_type_k */ GGML_TYPE_F32, +- /* recurrent_type_v */ GGML_TYPE_F32, ++ /* recurrent_type_k */ params.type_r, // recurrent conv state (r), name is legacy ++ /* recurrent_type_v */ params.type_s, // recurrent SSM state (s), name is legacy + /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), +- /* filter_recr */ std::move(filter_recr)); ++ /* filter_recr */ std::move(filter_recr), ++ /* recurrent_ssm_hybrid_tau */ params.ssm_hybrid_tau_thresh); + } + } else { + llama_kv_cache::layer_filter_cb filter = nullptr; +diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp +index 0eee804..58f3d0c 100644 +--- a/src/models/delta-net-base.cpp ++++ b/src/models/delta-net-base.cpp +@@ -600,6 +600,61 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( + const bool fused = (n_seq_tokens == 1) ? cparams.fused_gdn_ar : cparams.fused_gdn_ch; + + if (!keep && fused) { ++ // Hybrid per-head SSM state (lever A): when this GDN layer is split, the persisted state lives ++ // in an f32 partition (ssm_states_all = s_l[il], the long-memory heads) and a bf16 partition ++ // (s_l_bf16[il], the fast heads), with a head_slot map. build_rs handles the rs_zero clear + ++ // extra-states copy for the f32 partition; the bf16 partition's clear + extra are mirrored here ++ // (no extra gather). The fused hybrid op reads each head from its partition via ids and writes ++ // the final state in place per partition. Non-hybrid layers (default) take the existing path. ++ if (mctx_cur->is_s_hybrid(il)) { ++ ggml_tensor * ssm_bf16 = mctx_cur->get_s_l_bf16(il); ++ ggml_tensor * head_slot = mctx_cur->get_head_slot(il); ++ const int64_t n_embd_s_f32 = ssm_states_all->ne[0]; // S_v*S_v*n_f32 ++ const int64_t n_embd_s_bf16 = ssm_bf16->ne[0]; // S_v*S_v*n_bf16 ++ const int64_t n_f32_heads = n_embd_s_f32 / (S_v * S_v); ++ const int64_t n_bf16_heads = n_embd_s_bf16 / (S_v * S_v); ++ ++ // Mirror build_rs's rs_zero clear + extra-states copy for the bf16 partition (side effects ++ // only; the bf16 cache starts zeroed, this handles slot restart/reuse across the batch). ++ { ++ ggml_tensor * states_bf = ggml_reshape_2d(ctx0, ssm_bf16, n_embd_s_bf16, ssm_bf16->ne[1]); ++ const int32_t rs_zero = inp->rs_z; ++ ggml_tensor * state_zero = ggml_view_1d(ctx0, states_bf, n_embd_s_bf16 * (rs_zero >= 0), ++ rs_zero * states_bf->nb[1] * (rs_zero >= 0)); ++ ggml_build_forward_expand(gf, ggml_fill_inplace(ctx0, state_zero, 0.0f)); ++ const uint32_t n_rs = mctx_cur->get_n_rs(); ++ ggml_tensor * states_extra = ggml_get_rows(ctx0, states_bf, inp->s_copy_extra); ++ ggml_build_forward_expand(gf, ggml_cpy(ctx0, states_extra, ++ ggml_view_2d(ctx0, ssm_bf16, n_embd_s_bf16, (n_rs - n_seqs), ssm_bf16->nb[1], ++ (kv_head + n_seqs) * ssm_bf16->nb[1]))); ++ } ++ ++ auto get_state_op_h = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) -> ggml_tensor * { ++ ggml_tensor * cache4d_f32 = ggml_reshape_4d(ctx, states, S_v, S_v, n_f32_heads, states->ne[1]); ++ ggml_tensor * cache4d_bf16 = ggml_reshape_4d(ctx, ssm_bf16, S_v, S_v, n_bf16_heads, ssm_bf16->ne[1]); ++ ggml_tensor * state_dst_f32 = ggml_view_2d(ctx, ssm_states_all, n_embd_s_f32, n_seqs, ++ ssm_states_all->nb[1], kv_head * n_embd_s_f32 * ggml_element_size(ssm_states_all)); ++ return ggml_gated_delta_net_inplace_ids_hybrid(ctx, q, k, v, g, b, ++ cache4d_f32, state_dst_f32, ids, cache4d_bf16, head_slot, (int) kv_head); ++ }; ++ ++ ggml_tensor * result = build_rs(inp, ssm_states_all, (int) n_embd_s_f32, n_seqs, get_state_op_h); ++ if (n_seq_tokens == 1) { ++ cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); ++ } else { ++ cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il); ++ } ++ ++ ggml_tensor * output = ggml_view_4d(ctx0, result, ++ S_v, H_v, n_seq_tokens, n_seqs, ++ ggml_row_size(result->type, S_v), ++ ggml_row_size(result->type, S_v * H_v), ++ ggml_row_size(result->type, S_v * H_v * n_seq_tokens), 0); ++ cb(output, "attn_output", il); ++ ggml_build_forward_expand(gf, output); ++ return output; ++ } ++ + // build_rs feeds the FULL state cache + the s_copy ids into the op (via the get_state_rows + // lambda, exactly like mamba-base's ggml_ssm_scan) and still performs the rs_zero clear and + // the extra-states copy around it. The op reads curr_state from cache[ids[seq]] and writes +diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp +index 41babb8..b5e3048 100644 +--- a/tests/test-backend-ops.cpp ++++ b/tests/test-backend-ops.cpp +@@ -3901,6 +3901,7 @@ struct test_rwkv_wkv6 : public test_case { + // GGML_OP_GATED_DELTA_NET + struct test_gated_delta_net : public test_case { + const ggml_type type; ++ const ggml_type state_type; // recurrent SSM state cache dtype: f32 (default) or bf16 (opt-in); q/k/v/g/beta stay `type` + + const int64_t head_count; + const int64_t head_size; +@@ -3912,13 +3913,15 @@ struct test_gated_delta_net : public test_case { + const int64_t K; // snapshot slot count: 1 = final-only, >1 = last K states + + std::string vars() override { +- return VARS_TO_STR9(type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda, K); ++ return VARS_TO_STR10(type, state_type, head_count, head_size, n_seq_tokens, n_seqs, v_repeat, permuted, kda, K); + } + + test_gated_delta_net(ggml_type type = GGML_TYPE_F32, + int64_t head_count = 4, int64_t head_size = 16, int64_t n_seq_tokens = 1, int64_t n_seqs = 1, +- int v_repeat = 1, bool permuted = false, bool kda = false, int64_t K = 1) +- : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), ++ int v_repeat = 1, bool permuted = false, bool kda = false, int64_t K = 1, ++ ggml_type state_type = GGML_TYPE_F32) ++ : type(type), state_type(state_type), head_count(head_count), head_size(head_size), ++ n_seq_tokens(n_seq_tokens), n_seqs(n_seqs), + v_repeat(v_repeat), permuted(permuted), kda(kda), K(K) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { +@@ -3941,7 +3944,7 @@ struct test_gated_delta_net : public test_case { + const int64_t g_ne0 = kda ? head_size : 1; + ggml_tensor * g = ggml_new_tensor_4d(ctx, type, g_ne0, head_count * v_repeat, n_seq_tokens, n_seqs); + ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, head_count * v_repeat, n_seq_tokens, n_seqs); +- ggml_tensor * state = ggml_new_tensor_4d(ctx, type, head_size, head_size, head_count * v_repeat, n_seqs); ++ ggml_tensor * state = ggml_new_tensor_4d(ctx, state_type, head_size, head_size, head_count * v_repeat, n_seqs); + ggml_set_name(g, "g"); + ggml_set_name(beta, "beta"); + ggml_set_name(state, "state"); +@@ -3968,6 +3971,85 @@ struct test_gated_delta_net : public test_case { + } + }; + ++// GGML_OP_GATED_DELTA_NET - hybrid per-head mixed-dtype state (lever A). The persisted state is split ++// into an f32 partition (long-memory heads) and a bf16 partition (fast heads); head_slot maps each ++// natural-order head to its partition + local index. Activations (q/k/v/g/beta) stay f32 and keep ++// natural head order; only the per-head load crosses the dtype boundary. This is the de-risk net for ++// the mixed-dtype op/kernel/CPU mirror (CUDA mixed vs CPU mixed), covering single-token decode, ++// multi-token prefill/chunk, and the keep_rs_t (K>1) snapshot path. head_count must be even; heads ++// at even index are f32, odd index are bf16 (interleaved => both branches exercised across blocks). ++struct test_gated_delta_net_hybrid : public test_case { ++ const int64_t head_count; ++ const int64_t head_size; ++ const int64_t n_seq_tokens; ++ const int64_t n_seqs; ++ const bool kda; ++ const int64_t K; ++ ++ std::string vars() override { ++ return VARS_TO_STR6(head_count, head_size, n_seq_tokens, n_seqs, kda, K); ++ } ++ ++ test_gated_delta_net_hybrid(int64_t head_count = 4, int64_t head_size = 64, ++ int64_t n_seq_tokens = 1, int64_t n_seqs = 1, bool kda = false, int64_t K = 1) ++ : head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), ++ n_seqs(n_seqs), kda(kda), K(K) {} ++ ++ int64_t n_bf16_heads() const { return head_count / 2; } ++ int64_t n_f32_heads() const { return head_count - head_count / 2; } ++ ++ ggml_tensor * build_graph(ggml_context * ctx) override { ++ ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_count, n_seq_tokens, n_seqs); ++ ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_count, n_seq_tokens, n_seqs); ++ ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_count, n_seq_tokens, n_seqs); ++ ggml_set_name(q, "q"); ++ ggml_set_name(k, "k"); ++ ggml_set_name(v, "v"); ++ const int64_t g_ne0 = kda ? head_size : 1; ++ ggml_tensor * g = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, g_ne0, head_count, n_seq_tokens, n_seqs); ++ ggml_tensor * beta = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, head_count, n_seq_tokens, n_seqs); ++ ggml_set_name(g, "g"); ++ ggml_set_name(beta, "beta"); ++ ++ ggml_tensor * state_f32 = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_size, n_f32_heads(), n_seqs); ++ ggml_tensor * state_bf16 = ggml_new_tensor_4d(ctx, GGML_TYPE_BF16, head_size, head_size, n_bf16_heads(), n_seqs); ++ ggml_tensor * head_slot = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, head_count); ++ ggml_set_name(state_f32, "state_f32"); ++ ggml_set_name(state_bf16, "state_bf16"); ++ ggml_set_name(head_slot, "head_slot"); ++ ++ // q/k are L2-normalised in qwen35/kimi-linear before delta_net ++ q = ggml_l2_norm(ctx, q, 1e-6f); ++ k = ggml_l2_norm(ctx, k, 1e-6f); ++ ggml_tensor * out = ggml_gated_delta_net_hybrid(ctx, q, k, v, g, beta, state_f32, state_bf16, head_slot, K); ++ return out; ++ } ++ ++ void initialize_tensors(ggml_context * ctx) override { ++ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { ++ if (ggml_is_view_op(t->op)) { continue; } ++ if (strcmp(t->name, "head_slot") == 0) { ++ // deterministic interleaved map: even head -> f32 (local_idx), odd head -> bf16 (-(idx+1)) ++ std::vector hs(head_count); ++ int32_t nf = 0, nb = 0; ++ for (int64_t h = 0; h < head_count; ++h) { ++ if (h % 2 == 1) { hs[h] = -(nb + 1); ++nb; } ++ else { hs[h] = nf; ++nf; } ++ } ++ ggml_backend_tensor_set(t, hs.data(), 0, head_count * sizeof(int32_t)); ++ } else if (strcmp(t->name, "g") == 0) { ++ init_tensor_uniform(t, -20.0f, -1e-4f); ++ } else if (strcmp(t->name, "beta") == 0) { ++ init_tensor_uniform(t, 0.0f, 1.0f); ++ } else if (strcmp(t->name, "v") == 0) { ++ init_tensor_uniform(t, -0.3f, 5.0f); ++ } else { ++ init_tensor_uniform(t); ++ } ++ } ++ } ++}; ++ + // GGML_OP_GATED_LINEAR_ATTN + struct test_gla : public test_case { + const ggml_type type; +@@ -9316,6 +9398,45 @@ static std::vector> make_test_cases_eval() { + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 8, 1, 1, false, false, /*K=*/3)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 16, 2, 1, false, false, /*K=*/4)); + ++ // bf16 recurrent-state cache (opt-in width halving). Activations (q/k/v/g/beta) stay f32; only the ++ // persisted SSM state is bf16: load->f32, recurrence math in f32, store->bf16. De-risk gate for the ++ // dtype-generic kernel/op: CUDA bf16-state vs CPU bf16-state reference across the three regimes that ++ // touch the state path differently — single-token decode, multi-token prefill/chunk, and the ++ // keep_rs_t==true (K>1) snapshot path (the prefill->decode handoff landmine, plan risk R2). ++ for (int64_t hs : {64, 128}) { ++ // single-token decode ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, hs, 1, 1, 1, false, false, 1, GGML_TYPE_BF16)); ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, hs, 1, 2, 1, false, false, 1, GGML_TYPE_BF16)); ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, hs, 1, 1, 1, false, true, 1, GGML_TYPE_BF16)); // kda ++ // multi-token prefill / chunk (n_tokens > 1) ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, hs, 33, 1, 1, false, false, 1, GGML_TYPE_BF16)); ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, hs, 64, 2, 1, false, false, 1, GGML_TYPE_BF16)); ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, hs, 100, 1, 1, false, true, 1, GGML_TYPE_BF16)); // kda ++ // keep_rs_t == true (K>1 snapshot path) ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, hs, 4, 1, 1, false, false, 4, GGML_TYPE_BF16)); ++ test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, hs, 8, 2, 1, false, true, 4, GGML_TYPE_BF16)); // kda ++ } ++ ++ // Hybrid per-head mixed-dtype state (lever A) de-risk net: some heads f32, some bf16 (interleaved), ++ // CUDA mixed vs CPU mixed, across single-token decode, multi-token prefill/chunk, and keep_rs_t. ++ // head_count even (half f32 / half bf16); head_size 64 + 128 (production S_v). The recurrence math ++ // stays f32; this validates the per-head partition load wiring on both backends. ++ for (int64_t hs : {64, 128}) { ++ for (int64_t hc : {4, 8}) { ++ // single-token decode ++ test_cases.emplace_back(new test_gated_delta_net_hybrid(hc, hs, 1, 1, false, 1)); ++ test_cases.emplace_back(new test_gated_delta_net_hybrid(hc, hs, 1, 2, false, 1)); ++ test_cases.emplace_back(new test_gated_delta_net_hybrid(hc, hs, 1, 1, true, 1)); // kda ++ // multi-token prefill / chunk (n_tokens > 1) ++ test_cases.emplace_back(new test_gated_delta_net_hybrid(hc, hs, 33, 1, false, 1)); ++ test_cases.emplace_back(new test_gated_delta_net_hybrid(hc, hs, 64, 2, false, 1)); ++ test_cases.emplace_back(new test_gated_delta_net_hybrid(hc, hs, 100, 1, true, 1)); // kda ++ // keep_rs_t == true (K>1 snapshot path) ++ test_cases.emplace_back(new test_gated_delta_net_hybrid(hc, hs, 4, 1, false, 4)); ++ test_cases.emplace_back(new test_gated_delta_net_hybrid(hc, hs, 8, 2, true, 4)); // kda ++ } ++ } ++ + #if 0 + // these tests are disabled to save execution time, sbut they can be handy for debugging + test_cases.emplace_back(new test_llama(2, true)); diff --git a/backend/cpp/llama-cpp/patches/paged/A_HYBRID_PROGRESS.md b/backend/cpp/llama-cpp/patches/paged/A_HYBRID_PROGRESS.md new file mode 100644 index 000000000..2a6f465e8 --- /dev/null +++ b/backend/cpp/llama-cpp/patches/paged/A_HYBRID_PROGRESS.md @@ -0,0 +1,48 @@ +# A-build: hybrid per-head f32/bf16 SSM state - BUILD PROGRESS + +Label: A-build (GPU agent). Base: DGX `~/llama-paged-dev` branch `paged` HEAD 2f4f5ab (patch 0025), +plus `BF16_SSM_STATE.diff` applied as the bf16 plumbing base. Goal: per-head mixed-dtype SSM state +(f32 long-memory heads, bf16 fast heads); default `ssm_hybrid_tau_thresh=inf` (all-f32, bit-exact). + +## Design recap (from SPEEDUP_HUNT.md A-hybrid-design) +- Classifier (host, model-load): tau_h = 1/(|ssm_a[il][h]| * softplus(ssm_dt[il][h])); f32 if tau_h>T. + ssm_a = SSM_A_NOSCAN = -exp(A_log) (verified qwen35.cpp:376). ssm_dt = SSM_DT bias. +- Split cache: per GDN layer, s_l (f32, n_f32 heads) + s_l_bf16 (bf16, n_bf16 heads). head_slot map. +- Kernel: ONE kernel templated +HYBRID; per-block (h_idx) branch on head_slot (uniform, no divergence). + Recurrence math byte-for-byte f32-register, untouched. Homogeneous (HYBRID=false) path bit-exact. +- Op: extra src[8]=state_bf16, src[9]=head_slot; backend detects hybrid = (src[9]!=null). +- CPU mirror: per-head partition read. +- test-backend-ops: MIXED case (some heads f32, some bf16) output-append, decode+prefill+keep_rs_t. + +## DE-RISK GATE (must pass before sweep) +1. test-backend-ops GATED_DELTA_NET mixed PASS (CUDA mixed vs CPU mixed). +2. T_thresh=inf greedy md5 == 0023 baseline: dense 5951a5b4d624ce891e22ab5fca9bc439, + MoE 07db32c2bcb78d17a43ed18bc22705cd. + +## KNOB SEMANTICS (IMPORTANT - brief endpoint wording corrected) +Rule (brief verbatim + physics + "start 32-64" guidance all agree): a head is kept f32 iff +tau_h > T_thresh, else bf16. tau_h = 1/(|ssm_a|*softplus(ssm_dt)) in tokens. Long-memory (large tau) +heads stay f32 (bf16 rounding does not contract there -> KL); fast (small tau) heads -> bf16. +- ssm_hybrid_tau_thresh DEFAULT = 0.0 => every tau>0 -> ALL F32 (bit-exact opt-out; the gate runs here). +- ssm_hybrid_tau_thresh -> +inf => ALL BF16 (shelved mode). +- sweep: raise T (16/32/64/128 tokens) to bf16 progressively more (longer-memory) heads = more speed. +NOTE: the brief's "inf=>all-f32, 0=>all-bf16" sentence is INVERTED vs the operative rule it states +("keep f32 if tau>T") and vs "start 32-64" + the physics. Correct endpoints: 0=all-f32, inf=all-bf16. +Implemented the physically-correct rule; default 0.0 = bit-exact all-f32. + +## STATUS +- [x] ggml.h/ggml.c hybrid op builders +- [x] gated_delta_net.cu hybrid kernel + dispatch (one kernel, +HYBRID template, uniform per-block branch) +- [x] ops.cpp CPU hybrid read mirror (output-append; ids in-place is GPU-only, asserted) +- [x] test-backend-ops mixed case (32 cases: hc 4/8 x hs 64/128 x decode/prefill/keep_rs_t x kda) +- [x] de-risk gate 1: test-backend-ops GATED_DELTA_NET = 84/84 PASS (incl 32 hybrid mixed CUDA-vs-CPU) +- [x] cparam/CLI ssm_hybrid_tau_thresh plumbing (default 0.0; threaded context->cparams->memory->ctors) +- [x] memory-recurrent split cache + classifier (validated: real tau split, correct 2-partition layout) +- [x] delta-net-base hybrid op build (fused ids decode + bf16 rs_zero/extra mirror) +- [x] full build clean (sm_121; llama-completion/batched-bench/perplexity/test-backend-ops) +- [x] de-risk gate 2 (default/all-f32 md5 == 0023 both models, re-verified post-build) +- [~] hybrid-ON smoke: RUNS (no crash) + classifier/cache/kernel-params verified, but OUTPUT INCOHERENT + => OPEN BUG in the ids in-place cross-step state path (opt-in only; default unaffected). See + A_HYBRID_SSM_RESULTS.md. NOT ready for the sweep until fixed. + +Committed: DGX paged 657e008; worktree patch 0026 + A_HYBRID_SSM_RESULTS.md. diff --git a/backend/cpp/llama-cpp/patches/paged/A_HYBRID_SSM_RESULTS.md b/backend/cpp/llama-cpp/patches/paged/A_HYBRID_SSM_RESULTS.md new file mode 100644 index 000000000..d589c4905 --- /dev/null +++ b/backend/cpp/llama-cpp/patches/paged/A_HYBRID_SSM_RESULTS.md @@ -0,0 +1,90 @@ +# A - HYBRID PER-HEAD f32/bf16 SSM STATE - BUILD + DE-RISK RESULTS + +Label: A-build (the GPU build agent). Lands as patch 0026 on top of 0025 (DGX HEAD 2f4f5ab), +incorporating the bf16-SSM-state plumbing (`BF16_SSM_STATE.diff`) as the base. Built into +`~/llama-paged-dev/build-cuda` (sm_121); committed on the DGX `paged` branch (657e008) and as +`patches/paged/0026-qwen35-hybrid-perhead-ssm-state.patch` + this doc in the worktree. + +## DE-RISK GATE - both required gates PASS + +### Gate 1: test-backend-ops MIXED GATED_DELTA_NET (CUDA mixed vs CPU mixed) +`./bin/test-backend-ops -o GATED_DELTA_NET -b CUDA0` = **84/84 PASS, CUDA0 OK**. This includes the +**32 new hybrid mixed-dtype cases** (`test_gated_delta_net_hybrid`): head_count {4,8} x head_size +{64,128} x {single-token decode, multi-token prefill 33/64/100, keep_rs_t K=4} x kda {0,1}, with an +interleaved head_slot map (even heads f32, odd heads bf16) so both partition branches are exercised +across blocks. CUDA mixed vs CPU mixed agree. (Plus the pre-existing 52 f32 + bf16 cases still pass.) + +### Gate 2: T_thresh=inf (default, all-f32) greedy md5 == 0023 baseline - BOTH MODELS +`llama-completion -ngl 99 -fa on -p "The capital of France is" -n 48 --temp 0 --seed 1`, NO +`--ssm-bf16-tau` flag (default 0.0 => every head f32 => no split => the existing single-cache path): +- dense q36-27b-nvfp4: `5951a5b4d624ce891e22ab5fca9bc439` == 0023 baseline. +- MoE q36-35b-a3b-nvfp4: `07db32c2bcb78d17a43ed18bc22705cd` == 0023 baseline. +Re-verified byte-identical AFTER the full build with every plumbing edit in place. **The bit-exact +opt-out is preserved.** + +## KNOB SEMANTICS (brief endpoint wording corrected) +`ssm_hybrid_tau_thresh` / `--ssm-bf16-tau` T: a gated-DeltaNet head is kept **f32 iff tau_h > T**, +else bf16. `tau_h = 1/(|ssm_a[il][h]| * softplus(ssm_dt[il][h]))` tokens (ssm_a = SSM_A_NOSCAN = +-exp(A_log), verified qwen35.cpp:376; ssm_dt = SSM_DT bias). This is the brief's operative rule + the +"start 32-64" guidance + the physics (long-memory/large-tau heads stay f32; fast/small-tau heads -> +bf16). Endpoints: +- **T = 0.0 (DEFAULT) => every tau>0 -> ALL F32 (bit-exact opt-out; the gate runs here).** +- **T -> +inf => ALL BF16 (shelved mode).** +- sweep T in {16,32,64,128} bf16's progressively more (longer-memory) heads = more speed. + +NOTE: the brief's "inf=>all-f32, 0=>all-bf16" sentence is INVERTED relative to the rule it states +("keep f32 if tau>T") and to "start 32-64" + the physics. The physically-correct rule is implemented; +the bit-exact all-f32 mode is the DEFAULT (T=0), which is exactly what the de-risk gate exercises. + +## What was built (all components, validated correct) +1. **Classifier** (llama-memory-recurrent ctor, host, model-load): reads ssm_a/ssm_dt per GDN layer, + computes tau_h, sets head_is_bf16. VALIDATED on dense q27 (H_v=48, S_v=128): real per-head tau + spread min~0.2-0.5 / max~800-26000 tokens; at T=32 the split is ~13-31 f32 / 17-35 bf16 per layer. + Guarded against the device-memory-fitting pre-pass (weights not yet allocated => data==NULL => + fall back to single f32 cache, a conservative/larger memory estimate; real load classifies). +2. **Split cache** (llama-memory-recurrent): per split GDN layer, s_l[il] holds the f32 partition + [S_v*S_v*n_f32, n_rows] and s_l_bf16[il] the bf16 partition [S_v*S_v*n_bf16, n_rows] + an I32[H] + head_slot map (local_idx>=0 f32, -(local_idx+1)<0 bf16), uploaded after buffer alloc. ctx metadata + budget bumped 2->4 tensors/layer (r, s_f32, s_bf16, head_slot). VALIDATED: cache layout correct + (f32/bf16 partitions 2MB apart, non-overlapping; sizes match counts). +3. **Kernel** (gated_delta_net.cu): ONE kernel templated +HYBRID; per-block (h_idx) branch on + head_slot picks the partition + local index (uniform within a block => no warp divergence). The + homogeneous (HYBRID=false) instantiations are byte-identical to before (if constexpr elides the + hybrid blocks). Two builders: ggml_gated_delta_net_hybrid (output-append, for the test) and + ggml_gated_delta_net_inplace_ids_hybrid (decode). Backend detects hybrid = src[9]!=null; gathers + both partitions for non-identity seqs; derives the bf16 in-place dst from src[8]+rs_head. +4. **CPU mirror** (ops.cpp): per-head partition read for the output-append form (the test path). +5. **Plumbing**: cparam ssm_hybrid_tau_thresh threaded llama_context_params -> cparams -> + llama_memory_params -> recurrent/hybrid/iswa ctors; common_params + CLI --ssm-bf16-tau (default 0). +6. **test-backend-ops**: the 32 mixed cases above. + +## KNOWN OPEN ISSUE - hybrid-ON decode is incoherent (opt-in only; does NOT affect the default) +With `--ssm-bf16-tau` > 0 (any split, even tau=1 = a handful of bf16 heads), the model generates +incoherent text (" the the the > EOF"). The bit-exact all-f32 default is UNAFFECTED (gate 2). + +Diagnosis (everything reachable by inspection was verified correct): +- The op-level MIXED test PASSES, but it only covers the **output-append** form (state read from the + s0 input partitions, write to the f32 op output). The model decode uses the **ids in-place** form: + read from the in-place cache partition (identity), write the new state in place per partition. That + cross-step state path is NOT exercised by a single-op test (the in-place state write is a side + effect, not the compared op output), so it is the only un-netted surface - and that is where the bug + lives. +- Confirmed correct at runtime: the classifier (real tau split), the split cache layout (partitions + 2MB apart, sizes match), and the exact kernel parameters (H=48, S_v=128, n_f32+n_bf16=H, head_slot + values, ids/state_dst/state_bf16 pointers all sane). The hybrid op IS built and dispatched (not a + homogeneous fallback). Garbage persists with CUDA graphs disabled, so it is not a graph-capture + issue. The recurrence math is shared with the (passing) output-append path. +- The bug is therefore in the ids in-place cross-step state handling (identity-d read and/or in-place + partition store, and/or the bf16 partition rs_zero/extra-states mirroring in delta-net-base) - a + state-corruption that cascades. It needs a multi-step reproduction (the single-op harness cannot + catch a cross-step in-place bug; the homogeneous in-place ids op itself has no op test - it was only + ever validated by model md5). + +## NOT ready for the GateSweep yet +The de-risk gates (mixed op test + bit-exact default) BOTH PASS, but the hybrid-ON path must be made +coherent before the T_thresh KL/throughput sweep can produce meaningful numbers. Recommended next +step: build a minimal 2-step in-place reproduction (CPU ids-in-place hybrid mirror + a decode-loop +harness, or a kernel-side state dump comparing hybrid vs homogeneous for an all-f32-disguised split) +to localize identity-d-read vs in-place-store vs the bf16 clear/extra mirror. + +Assisted-by: Claude:opus-4.8 [Claude Code]