fix(paged): serialize both SSM partitions in hybrid bf16-tau state save/restore (patch 0026)

The opt-in ssm_bf16_tau hybrid mode splits a gated-DeltaNet layer's
recurrent SSM state into an f32 partition (s_l) and a bf16 partition
(s_l_bf16). The recurrent state serialization paths (state_write_data /
state_read_data) were never updated for the split: they read/wrote s_l
using the FULL hparams.n_embd_s() (S_v*S_v*H) row width, but a split
layer's s_l rows are only S_v*S_v*n_f32 wide, so the access overruns the
smaller tensor (a ggml_backend tensor read out of bounds), and the bf16
fast-head partition was never persisted at all.
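
The overrun described above can be sketched with plain byte arithmetic.
This is a minimal illustration, not the llama.cpp code: state_tensor,
read_end and overruns are hypothetical names, and the sizes used in the
note below are made up for the example.

```cpp
#include <cstdint>

// Hypothetical stand-in for one state partition tensor.
struct state_tensor {
    uint64_t row_elems;  // elements per cell row (the tensor's ne[0])
    uint64_t n_cells;    // number of cell rows
    uint64_t elem_size;  // bytes per element (4 for f32, 2 for bf16)

    uint64_t nbytes() const { return row_elems * n_cells * elem_size; }
};

// End offset (exclusive, in bytes) of an access to cells [first, first + count)
// when the caller assumes each row is `row_elems_used` elements wide.
uint64_t read_end(const state_tensor & t, uint64_t row_elems_used,
                  uint64_t first, uint64_t count) {
    const uint64_t row_bytes = row_elems_used * t.elem_size;
    return (first + count) * row_bytes;
}

// True when that access would run past the end of the tensor.
bool overruns(const state_tensor & t, uint64_t row_elems_used,
              uint64_t first, uint64_t count) {
    return read_end(t, row_elems_used, first, count) > t.nbytes();
}
```

With assumed values S_v = 128, H = 32, n_f32 = 8 and a 2-cell f32
partition, the partition row is 128*128*8 = 131072 elements and the
tensor is 1048576 bytes. Reading cell 0 with the full 128*128*32 = 524288
element width ends at byte 2097152, past the tensor; reading cell 1 the
same way ends at byte 4194304, further out. Using the partition's own
width keeps both reads in bounds.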

This is what broke high-concurrency serving with --ssm-bf16-tau: the
server's context-checkpoint feature serializes per-sequence state via
state_seq_get_data. With a checkpoint enabled, even a single request
triggered the out-of-bounds read; at higher concurrency the cell range
starts at a higher base slot, so the overrun reaches further. A debug
build hard-aborts; a release build silently corrupts state and then
emits one token followed by EOS on restore. The static batched-bench
never exercises save/restore, so it did not catch this; the GDN decode
kernel and per-head partition offsets were already correct (decode with
checkpoints disabled is fine at N=8/16/32).

Fix: serialize the f32 partition and, when the layer is split, the bf16
partition right after it, each with its OWN row width (tensor ne[0]).
head_slot is rebuilt deterministically at load (same model + tau), so it
is not serialized. Non-split layers have ne[0] == n_embd_s() and no bf16
partition, so their on-disk format and behavior are byte-identical (the
default f32 path and the bit-exact gate are unaffected).
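
The layout the fix produces can be sketched as follows. This is a
simplified model under stated assumptions, not the patched code: the
partition struct, append, write_partition and write_layer are
hypothetical, the type ids are placeholders, and a std::vector stands in
for the llama_io_write_i stream.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical model of one SSM state partition (f32 or bf16).
struct partition {
    int32_t  type;              // placeholder type id written to the stream
    uint64_t elem_size;         // bytes per element
    uint64_t ne0;               // row width in elements (this tensor's own ne[0])
    std::vector<uint8_t> data;  // n_cells rows of ne0 * elem_size bytes each
};

static void append(std::vector<uint8_t> & out, const void * src, size_t n) {
    const uint8_t * p = static_cast<const uint8_t *>(src);
    out.insert(out.end(), p, p + n);
}

// Write type, row size, then the rows for cells [first, first + count),
// all derived from the partition's OWN row width.
void write_partition(std::vector<uint8_t> & out, const partition & p,
                     uint64_t first, uint64_t count) {
    const uint64_t row_bytes = p.elem_size * p.ne0;
    assert((first + count) * row_bytes <= p.data.size());  // stay in bounds
    append(out, &p.type, sizeof(p.type));
    append(out, &row_bytes, sizeof(row_bytes));
    append(out, p.data.data() + first * row_bytes, count * row_bytes);
}

// f32 partition first; the bf16 partition follows only for split layers.
void write_layer(std::vector<uint8_t> & out, const partition & f32,
                 const partition * bf16, uint64_t first, uint64_t count) {
    write_partition(out, f32, first, count);
    if (bf16 != nullptr) {
        write_partition(out, *bf16, first, count);
    }
}
```

In this model a non-split layer (bf16 == nullptr) emits exactly the bytes
of its single partition, which mirrors why the unsplit format stays
unchanged; a split layer appends a second self-describing record.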

Verified on GB10/DGX with Qwen3.6-35B-A3B-NVFP4 + --ssm-bf16-tau 64 via a
continuous-batching llama-server: with context checkpoints enabled, N=8,
N=16 and N=32 (slot reuse + restore) all now produce full coherent
128-token output and the server stays up; pre-fix the same config
aborted on the first checkpoint.

Assisted-by: Claude:claude-opus-4-8[1m] [Claude Code]
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-06-28 07:47:17 +00:00
parent 4da769c1ca
commit 1f3e5ba301

@@ -1,8 +1,8 @@
diff --git a/common/arg.cpp b/common/arg.cpp
index 841a38e..3e05bd4 100644
index 841ca3c..0b5b6ec 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2157,6 +2157,47 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
@@ -2194,6 +2194,47 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.cache_type_v = kv_cache_type_from_str(value);
}
).set_env("LLAMA_ARG_CACHE_TYPE_V"));
@@ -54,23 +54,13 @@ diff --git a/common/common.cpp b/common/common.cpp
index a14e7bb..c4ab884 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1600,6 +1600,19 @@ struct llama_context_params common_context_params_to_llama(const common_params &
@@ -1600,6 +1600,9 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.type_k = params.cache_type_k;
cparams.type_v = params.cache_type_v;
+ cparams.type_r = params.cache_type_conv;
+ cparams.type_s = params.cache_type_ssm;
+ cparams.ssm_hybrid_tau_thresh = params.ssm_hybrid_tau_thresh;
+ // LocalAI per-model option hook: when the --ssm-bf16-tau CLI flag is at its bit-exact
+ // default (0), honor LLAMA_SSM_BF16_TAU (set by the grpc-server from the model YAML
+ // `options: [ssm_bf16_tau:N]`) so the reduced-precision hybrid fast mode is selectable
+ // per model without a process-wide CLI flag. Absent/non-positive env => untouched, so
+ // stock stays bit-exact; the CLI flag, when set, takes precedence.
+ if (cparams.ssm_hybrid_tau_thresh == 0.0f) {
+ if (const char * tau_env = std::getenv("LLAMA_SSM_BF16_TAU")) {
+ try { cparams.ssm_hybrid_tau_thresh = std::stof(tau_env); } catch (...) {}
+ }
+ }
return cparams;
}
@@ -1358,7 +1348,7 @@ index 484eafb..46618d3 100644
~llama_memory_hybrid() = default;
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 6a4892f..aae57a4 100644
index 6a4892f..5aba1e4 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -8,6 +8,7 @@
@@ -1529,7 +1519,52 @@ index 6a4892f..aae57a4 100644
return size_s_bytes;
}
@@ -1041,6 +1136,44 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
@@ -892,24 +987,33 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
}
if (!s_trans) {
- for (uint32_t il = 0; il < n_layer; ++il) {
- // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
- if (s_l[il] == nullptr) continue;
-
- // Write S tensor type
- const int32_t s_type_i = (int32_t)s_l[il]->type;
+ // Hybrid per-head SSM split (lever A): a split layer's SSM state lives in an f32 partition
+ // (s_l, S_v*S_v*n_f32 wide) PLUS a bf16 partition (s_l_bf16, S_v*S_v*n_bf16 wide). BOTH must be
+ // serialized, each with its OWN row width (tensor ne[0]) - using hparams.n_embd_s() (the full
+ // S_v*S_v*H) over-reads the smaller f32 partition tensor out of bounds (the serving crash /
+ // garbage-on-restore at high concurrency, where the cell range starts at a higher base) and
+ // silently drops the fast-head bf16 state. head_slot is rebuilt deterministically at load (same
+ // model + tau) so it is not serialized. Non-split layers have ne[0] == n_embd_s(), so their
+ // on-disk format and behavior are byte-identical to before.
+ auto write_s_partition = [&](ggml_tensor * s) {
+ const int32_t s_type_i = (int32_t) s->type;
io.write(&s_type_i, sizeof(s_type_i));
-
- // Write row size of S tensor
- const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
+ const uint64_t s_size_row = ggml_row_size(s->type, s->ne[0]);
io.write(&s_size_row, sizeof(s_size_row));
-
// Write each logical cell row range. With pending recurrent rollback,
// the logical current state may live in a rollback snapshot plane.
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t buf_size = range_size * s_size_row;
- io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
+ io.write_tensor(s, range.first * s_size_row, buf_size);
+ }
+ };
+ for (uint32_t il = 0; il < n_layer; ++il) {
+ // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+ if (s_l[il] == nullptr) continue;
+ write_s_partition(s_l[il]);
+ if (s_l_bf16[il] != nullptr) {
+ write_s_partition(s_l_bf16[il]);
}
}
} else {
@@ -1041,6 +1145,44 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
return true;
}
@@ -1574,7 +1609,7 @@ index 6a4892f..aae57a4 100644
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
uint32_t s_trans;
uint32_t n_layer;
@@ -1069,14 +1202,20 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
@@ -1069,14 +1211,20 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
int32_t r_type_i_ref;
io.read(&r_type_i_ref, sizeof(r_type_i_ref));
const int32_t r_type_i = (int32_t) r_l[il]->type;
@@ -1599,34 +1634,71 @@ index 6a4892f..aae57a4 100644
const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
if (r_size_row != r_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
@@ -1099,14 +1238,21 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
const int32_t s_type_i = (int32_t)s_l[il]->type;
@@ -1090,32 +1238,50 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
}
if (!s_trans) {
- for (uint32_t il = 0; il < n_layer; ++il) {
- // skip null layers
- if (s_l[il] == nullptr) continue;
-
+ // Hybrid per-head SSM split (lever A): mirror state_write_data - read the f32 partition (s_l)
+ // and, when the layer is split, the bf16 partition (s_l_bf16) right after it, each with its OWN
+ // row width (tensor ne[0]). For non-split layers ne[0] == n_embd_s(), so the on-disk format and
+ // behavior are unchanged. The f32<->bf16 dtype back-compat conversion still applies per partition.
+ auto read_s_partition = [&](ggml_tensor * s, uint32_t n_embd, uint32_t il) -> bool {
// Read type of value
int32_t s_type_i_ref;
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
- const int32_t s_type_i = (int32_t)s_l[il]->type;
-
- if (s_type_i != s_type_i_ref) {
- LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
- return false;
- }
-
+ const int32_t s_type_i = (int32_t) s->type;
// Read row size of value
uint64_t s_size_row_ref;
io.read(&s_size_row_ref, sizeof(s_size_row_ref));
- const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
+
+ if (s_type_i != s_type_i_ref) {
+ // back-compat: convert f32<->bf16 saved rows into the live cache dtype; else hard error.
+ // The SSM-state opt-in flips this dtype, so an f32-saved session can restore into a
+ // bf16 cache (and vice versa) instead of failing the hard type match.
+ if (!recurrent_read_convert_rows(io, s_l[il], (ggml_type) s_type_i_ref, s_size_row_ref,
+ hparams.n_embd_s(), head, cell_count, "s", il)) {
+ return false;
+ }
+ continue;
+ return recurrent_read_convert_rows(io, s, (ggml_type) s_type_i_ref, s_size_row_ref,
+ n_embd, head, cell_count, "s", (int) il);
+ }
+
const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
+ const size_t s_size_row = ggml_row_size(s->type, n_embd);
if (s_size_row != s_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
@@ -1241,6 +1387,18 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
- LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
+ LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, (int) il);
return false;
}
if (cell_count) {
// Read and set the values for the whole cell range
- io.read_tensor(s_l[il], head * s_size_row, cell_count * s_size_row);
+ io.read_tensor(s, head * s_size_row, cell_count * s_size_row);
+ }
+ return true;
+ };
+ for (uint32_t il = 0; il < n_layer; ++il) {
+ // skip null layers
+ if (s_l[il] == nullptr) continue;
+ if (!read_s_partition(s_l[il], (uint32_t) s_l[il]->ne[0], il)) {
+ return false;
+ }
+ if (s_l_bf16[il] != nullptr) {
+ if (!read_s_partition(s_l_bf16[il], (uint32_t) s_l_bf16[il]->ne[0], il)) {
+ return false;
+ }
}
}
} else {
@@ -1241,6 +1407,18 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
return mem->s_l[il];
}
@@ -1848,7 +1920,7 @@ index 0eee804..58f3d0c 100644
// lambda, exactly like mamba-base's ggml_ssm_scan) and still performs the rs_zero clear and
// the extra-states copy around it. The op reads curr_state from cache[ids[seq]] and writes
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 41babb8..b5e3048 100644
index 330d936..188d7b3 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3901,6 +3901,7 @@ struct test_rwkv_wkv6 : public test_case {
@@ -1973,7 +2045,7 @@ index 41babb8..b5e3048 100644
// GGML_OP_GATED_LINEAR_ATTN
struct test_gla : public test_case {
const ggml_type type;
@@ -9316,6 +9398,45 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
@@ -9325,6 +9407,45 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 8, 1, 1, false, false, /*K=*/3));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 16, 2, 1, false, false, /*K=*/4));