mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-28 10:27:30 -04:00
fix(paged): serialize both SSM partitions in hybrid bf16-tau state save/restore (patch 0026)
The opt-in ssm_bf16_tau hybrid mode splits a gated-DeltaNet layer's recurrent SSM state into an f32 partition (s_l) and a bf16 partition (s_l_bf16). The recurrent state serialization paths (state_write_data / state_read_data) were never updated for the split: they read/wrote s_l using the FULL hparams.n_embd_s() (S_v*S_v*H) row width, but a split layer's s_l only holds S_v*S_v*n_f32, so the access overruns the smaller tensor (an out-of-bounds ggml_backend tensor read), and the bf16 fast-head partition was never persisted at all. This is what broke high-concurrency serving with --ssm-bf16-tau: the server's context-checkpoint feature serializes per-sequence state via state_seq_get_data. With a checkpoint enabled, even a single request triggered the out-of-bounds read; at higher concurrency the cell range starts at a higher base slot, so the overrun reaches further (a hard abort in a debug build; in a release build, silent state corruption followed by a single token and then EOS on restore). The static batched-bench never exercises save/restore, so it did not catch the bug; the GDN decode kernel and per-head partition offsets were already correct (decode with checkpoints disabled is fine at N=8/16/32). Fix: serialize the f32 partition and, when the layer is split, the bf16 partition right after it, each with its OWN row width (tensor ne[0]). head_slot is rebuilt deterministically at load (same model + tau), so it is not serialized. Non-split layers have ne[0] == n_embd_s() and no bf16 partition, so their on-disk format and behavior are byte-identical (the default f32 path and the bit-exact gate are unaffected). Verified on GB10/DGX with Qwen3.6-35B-A3B-NVFP4 + --ssm-bf16-tau 64 via a continuous-batching llama-server: with context checkpoints enabled, N=8, N=16 and N=32 (slot reuse + restore) all now produce full, coherent 128-token output and the server stays up; pre-fix, the same config aborted on the first checkpoint.
Assisted-by: Claude:claude-opus-4-8[1m] [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
diff --git a/common/arg.cpp b/common/arg.cpp
|
||||
index 841a38e..3e05bd4 100644
|
||||
index 841ca3c..0b5b6ec 100644
|
||||
--- a/common/arg.cpp
|
||||
+++ b/common/arg.cpp
|
||||
@@ -2157,6 +2157,47 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
@@ -2194,6 +2194,47 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.cache_type_v = kv_cache_type_from_str(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_CACHE_TYPE_V"));
|
||||
@@ -54,23 +54,13 @@ diff --git a/common/common.cpp b/common/common.cpp
|
||||
index a14e7bb..c4ab884 100644
|
||||
--- a/common/common.cpp
|
||||
+++ b/common/common.cpp
|
||||
@@ -1600,6 +1600,19 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||
@@ -1600,6 +1600,9 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||
|
||||
cparams.type_k = params.cache_type_k;
|
||||
cparams.type_v = params.cache_type_v;
|
||||
+ cparams.type_r = params.cache_type_conv;
|
||||
+ cparams.type_s = params.cache_type_ssm;
|
||||
+ cparams.ssm_hybrid_tau_thresh = params.ssm_hybrid_tau_thresh;
|
||||
+ // LocalAI per-model option hook: when the --ssm-bf16-tau CLI flag is at its bit-exact
|
||||
+ // default (0), honor LLAMA_SSM_BF16_TAU (set by the grpc-server from the model YAML
|
||||
+ // `options: [ssm_bf16_tau:N]`) so the reduced-precision hybrid fast mode is selectable
|
||||
+ // per model without a process-wide CLI flag. Absent/non-positive env => untouched, so
|
||||
+ // stock stays bit-exact; the CLI flag, when set, takes precedence.
|
||||
+ if (cparams.ssm_hybrid_tau_thresh == 0.0f) {
|
||||
+ if (const char * tau_env = std::getenv("LLAMA_SSM_BF16_TAU")) {
|
||||
+ try { cparams.ssm_hybrid_tau_thresh = std::stof(tau_env); } catch (...) {}
|
||||
+ }
|
||||
+ }
|
||||
|
||||
return cparams;
|
||||
}
|
||||
@@ -1358,7 +1348,7 @@ index 484eafb..46618d3 100644
|
||||
~llama_memory_hybrid() = default;
|
||||
|
||||
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
|
||||
index 6a4892f..aae57a4 100644
|
||||
index 6a4892f..5aba1e4 100644
|
||||
--- a/src/llama-memory-recurrent.cpp
|
||||
+++ b/src/llama-memory-recurrent.cpp
|
||||
@@ -8,6 +8,7 @@
|
||||
@@ -1529,7 +1519,52 @@ index 6a4892f..aae57a4 100644
|
||||
|
||||
return size_s_bytes;
|
||||
}
|
||||
@@ -1041,6 +1136,44 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
@@ -892,24 +987,33 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
|
||||
}
|
||||
|
||||
if (!s_trans) {
|
||||
- for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
- // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
|
||||
- if (s_l[il] == nullptr) continue;
|
||||
-
|
||||
- // Write S tensor type
|
||||
- const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
+ // Hybrid per-head SSM split (lever A): a split layer's SSM state lives in an f32 partition
|
||||
+ // (s_l, S_v*S_v*n_f32 wide) PLUS a bf16 partition (s_l_bf16, S_v*S_v*n_bf16 wide). BOTH must be
|
||||
+ // serialized, each with its OWN row width (tensor ne[0]) - using hparams.n_embd_s() (the full
|
||||
+ // S_v*S_v*H) over-reads the smaller f32 partition tensor out of bounds (the serving crash /
|
||||
+ // garbage-on-restore at high concurrency, where the cell range starts at a higher base) and
|
||||
+ // silently drops the fast-head bf16 state. head_slot is rebuilt deterministically at load (same
|
||||
+ // model + tau) so it is not serialized. Non-split layers have ne[0] == n_embd_s(), so their
|
||||
+ // on-disk format and behavior are byte-identical to before.
|
||||
+ auto write_s_partition = [&](ggml_tensor * s) {
|
||||
+ const int32_t s_type_i = (int32_t) s->type;
|
||||
io.write(&s_type_i, sizeof(s_type_i));
|
||||
-
|
||||
- // Write row size of S tensor
|
||||
- const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
||||
+ const uint64_t s_size_row = ggml_row_size(s->type, s->ne[0]);
|
||||
io.write(&s_size_row, sizeof(s_size_row));
|
||||
-
|
||||
// Write each logical cell row range. With pending recurrent rollback,
|
||||
// the logical current state may live in a rollback snapshot plane.
|
||||
for (const auto & range : cell_ranges) {
|
||||
const size_t range_size = range.second - range.first;
|
||||
const size_t buf_size = range_size * s_size_row;
|
||||
- io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
|
||||
+ io.write_tensor(s, range.first * s_size_row, buf_size);
|
||||
+ }
|
||||
+ };
|
||||
+ for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
+ // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
|
||||
+ if (s_l[il] == nullptr) continue;
|
||||
+ write_s_partition(s_l[il]);
|
||||
+ if (s_l_bf16[il] != nullptr) {
|
||||
+ write_s_partition(s_l_bf16[il]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -1041,6 +1145,44 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1574,7 +1609,7 @@ index 6a4892f..aae57a4 100644
|
||||
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
||||
uint32_t s_trans;
|
||||
uint32_t n_layer;
|
||||
@@ -1069,14 +1202,20 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
@@ -1069,14 +1211,20 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
int32_t r_type_i_ref;
|
||||
io.read(&r_type_i_ref, sizeof(r_type_i_ref));
|
||||
const int32_t r_type_i = (int32_t) r_l[il]->type;
|
||||
@@ -1599,34 +1634,71 @@ index 6a4892f..aae57a4 100644
|
||||
const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
||||
if (r_size_row != r_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
|
||||
@@ -1099,14 +1238,21 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
@@ -1090,32 +1238,50 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
}
|
||||
|
||||
if (!s_trans) {
|
||||
- for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
- // skip null layers
|
||||
- if (s_l[il] == nullptr) continue;
|
||||
-
|
||||
+ // Hybrid per-head SSM split (lever A): mirror state_write_data - read the f32 partition (s_l)
|
||||
+ // and, when the layer is split, the bf16 partition (s_l_bf16) right after it, each with its OWN
|
||||
+ // row width (tensor ne[0]). For non-split layers ne[0] == n_embd_s(), so the on-disk format and
|
||||
+ // behavior are unchanged. The f32<->bf16 dtype back-compat conversion still applies per partition.
|
||||
+ auto read_s_partition = [&](ggml_tensor * s, uint32_t n_embd, uint32_t il) -> bool {
|
||||
// Read type of value
|
||||
int32_t s_type_i_ref;
|
||||
io.read(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
- const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
-
|
||||
- if (s_type_i != s_type_i_ref) {
|
||||
- LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
||||
- return false;
|
||||
- }
|
||||
-
|
||||
+ const int32_t s_type_i = (int32_t) s->type;
|
||||
|
||||
// Read row size of value
|
||||
uint64_t s_size_row_ref;
|
||||
io.read(&s_size_row_ref, sizeof(s_size_row_ref));
|
||||
- const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
||||
+
|
||||
+ if (s_type_i != s_type_i_ref) {
|
||||
+ // back-compat: convert f32<->bf16 saved rows into the live cache dtype; else hard error.
|
||||
+ // The SSM-state opt-in flips this dtype, so an f32-saved session can restore into a
|
||||
+ // bf16 cache (and vice versa) instead of failing the hard type match.
|
||||
+ if (!recurrent_read_convert_rows(io, s_l[il], (ggml_type) s_type_i_ref, s_size_row_ref,
|
||||
+ hparams.n_embd_s(), head, cell_count, "s", il)) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ continue;
|
||||
+ return recurrent_read_convert_rows(io, s, (ggml_type) s_type_i_ref, s_size_row_ref,
|
||||
+ n_embd, head, cell_count, "s", (int) il);
|
||||
+ }
|
||||
+
|
||||
const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
||||
+ const size_t s_size_row = ggml_row_size(s->type, n_embd);
|
||||
if (s_size_row != s_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
|
||||
@@ -1241,6 +1387,18 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
|
||||
- LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
|
||||
+ LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, (int) il);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the values for the whole cell range
|
||||
- io.read_tensor(s_l[il], head * s_size_row, cell_count * s_size_row);
|
||||
+ io.read_tensor(s, head * s_size_row, cell_count * s_size_row);
|
||||
+ }
|
||||
+ return true;
|
||||
+ };
|
||||
+ for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
+ // skip null layers
|
||||
+ if (s_l[il] == nullptr) continue;
|
||||
+ if (!read_s_partition(s_l[il], (uint32_t) s_l[il]->ne[0], il)) {
|
||||
+ return false;
|
||||
+ }
|
||||
+ if (s_l_bf16[il] != nullptr) {
|
||||
+ if (!read_s_partition(s_l_bf16[il], (uint32_t) s_l_bf16[il]->ne[0], il)) {
|
||||
+ return false;
|
||||
+ }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -1241,6 +1407,18 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
|
||||
return mem->s_l[il];
|
||||
}
|
||||
|
||||
@@ -1848,7 +1920,7 @@ index 0eee804..58f3d0c 100644
|
||||
// lambda, exactly like mamba-base's ggml_ssm_scan) and still performs the rs_zero clear and
|
||||
// the extra-states copy around it. The op reads curr_state from cache[ids[seq]] and writes
|
||||
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
|
||||
index 41babb8..b5e3048 100644
|
||||
index 330d936..188d7b3 100644
|
||||
--- a/tests/test-backend-ops.cpp
|
||||
+++ b/tests/test-backend-ops.cpp
|
||||
@@ -3901,6 +3901,7 @@ struct test_rwkv_wkv6 : public test_case {
|
||||
@@ -1973,7 +2045,7 @@ index 41babb8..b5e3048 100644
|
||||
// GGML_OP_GATED_LINEAR_ATTN
|
||||
struct test_gla : public test_case {
|
||||
const ggml_type type;
|
||||
@@ -9316,6 +9398,45 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
@@ -9325,6 +9407,45 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 8, 1, 1, false, false, /*K=*/3));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 16, 2, 1, false, false, /*K=*/4));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user