mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-23 08:08:52 -04:00
docs(paged): refine 0003 plan - used-cell gather, per-ubatch rebuild, single-stream first
Assisted-by: Claude:opus-4.8 [Claude Code] Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
@@ -17,6 +17,27 @@ ggml note: `ggml_get_rows(a,b)` gathers `a`'s **ne1** by `b` (I32). Raw K is `[n
→ ne1 = cells → direct. The mask is `[n_kv, n_tokens, 1, n_stream]` → n_kv is **ne0**, so gather as
`transpose → get_rows → transpose`.

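The `transpose → get_rows → transpose` note can be checked on plain row-major arrays. The following is a minimal sketch, not ggml code: `Mat`, `transpose`, `get_rows`, and `gather_mask_cells` are hypothetical stand-ins written for illustration. It assumes only what the note states: a row gather picks whole rows (ne1), while the mask keeps its cells along ne0, which here means along the columns.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major matrix: one row corresponds to one index along ne1,
// and the elements inside a row run along ne0.
struct Mat {
    int rows;
    int cols;
    std::vector<float> data;
    Mat(int r, int c) : rows(r), cols(c), data(static_cast<std::size_t>(r) * c, 0.0f) {}
    float & at(int r, int c) { return data[static_cast<std::size_t>(r) * cols + c]; }
    float at(int r, int c) const { return data[static_cast<std::size_t>(r) * cols + c]; }
};

Mat transpose(const Mat & m) {
    Mat t(m.cols, m.rows);
    for (int r = 0; r < m.rows; ++r) {
        for (int c = 0; c < m.cols; ++c) {
            t.at(c, r) = m.at(r, c);
        }
    }
    return t;
}

// Stand-in for a row gather: output row i is input row idx[i].
Mat get_rows(const Mat & m, const std::vector<std::int32_t> & idx) {
    Mat out(static_cast<int>(idx.size()), m.cols);
    for (std::size_t i = 0; i < idx.size(); ++i) {
        assert(idx[i] >= 0 && idx[i] < m.rows);
        for (int c = 0; c < m.cols; ++c) {
            out.at(static_cast<int>(i), c) = m.at(idx[i], c);
        }
    }
    return out;
}

// The mask is stored as n_tokens rows of n_kv entries, so the cells sit on the
// column axis. Transposing moves them onto the row axis, where get_rows can
// select them; the second transpose restores the original layout.
Mat gather_mask_cells(const Mat & mask, const std::vector<std::int32_t> & gi) {
    return transpose(get_rows(transpose(mask), gi));
}
```

With a 2-token, 4-cell mask and `gi = {0, 2, 3}`, the result keeps 2 rows and 3 columns, and each token's row holds exactly its entries for cells 0, 2, and 3 in that order.
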
### KEY CORRECTIONS (found while implementing; these change the edits)

1. **Gather index = ALL used (non-empty) cells in `[0,n_kv)`, NOT `sinfo.idxs`.** `sinfo.idxs` holds only the
*current ubatch's write slots*, whereas attention reads the *full history*. Which cells each token may attend to
is already decided by `kq_mask`, so gathering the union of all used cells and gathering the mask with the same
index is token-identical and drops exactly the empty (already-masked) cells.
So: `gather = { i in [0,n_kv) : !cells.is_empty(i) }`.

2. **Static-graph size is fine because llama.cpp rebuilds the graph every ubatch.** `n_gather` (the used-cell
count) is therefore a build-time constant for that ubatch: `build_input_gather_idxs` sizes the I32
tensor to `get_n_gather()` computed at build time, and `set_input_gather_idxs` fills it with the identical cell
list. Both MUST use the same loop (`for i in [0,n_kv): if !is_empty(i) push i`) so build-order == fill-order.

3. **K/V gather can live entirely in `build_attn`, with no change to the cache's `get_k`.** The `get_k` 4d view
is contiguous in `[ne0,ne1,ne2]` from cell 0 (nb2 == n_embd_head*n_head_kv*elemsz), so for **single stream (ns==1)**:
`reshape_3d(k, n_embd_head*n_head_kv, n_kv, 1) → get_rows(., gi) → reshape_4d(., n_embd_head, n_head_kv, n_gather, 1)`.
Multi-stream (ns>1) breaks contiguity (nb3 uses kv_size), so gate on ns==1 first and handle multi-stream as a follow-up.

4. So the ONLY cache additions are `is_paged()`, `get_n_gather(n_kv)`, `build/set_input_gather_idxs(n_kv)`;
everything else (K/V/mask gather) is in `build_attn`. `set_input_kq_mask` is **unchanged** (the mask is built
over n_kv, then gathered). This is smaller than the 7-edit estimate above.

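Corrections 1 and 2 can be sketched outside llama.cpp. The code below is an illustrative model, not the cache implementation: `Cells`, `used_cells`, `gather`, and `attend` are hypothetical names, and the free function `get_n_gather` only borrows the name from the plan. It assumes that empty cells are masked with negative infinity, and it uses a single scalar value per cell to keep the attention arithmetic short. One loop produces the gather list, the build-time size is derived from that same loop, and attention over the gathered cells matches attention over all `n_kv` cells.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for the cache's cell table: true marks an empty cell.
struct Cells {
    std::vector<bool> empty;
    bool is_empty(std::uint32_t i) const { return empty[i]; }
};

// Single source of truth for the gather order. Sizing the index tensor and
// filling it both go through this loop, so the two cannot disagree.
std::vector<std::int32_t> used_cells(const Cells & cells, std::uint32_t n_kv) {
    std::vector<std::int32_t> out;
    for (std::uint32_t i = 0; i < n_kv; ++i) {
        if (!cells.is_empty(i)) {
            out.push_back(static_cast<std::int32_t>(i));
        }
    }
    return out;
}

// Build-time size of the index tensor for the current ubatch.
std::uint32_t get_n_gather(const Cells & cells, std::uint32_t n_kv) {
    return static_cast<std::uint32_t>(used_cells(cells, n_kv).size());
}

// Select src[gi[0]], src[gi[1]], ... in gather order.
std::vector<float> gather(const std::vector<float> & src, const std::vector<std::int32_t> & gi) {
    std::vector<float> out;
    out.reserve(gi.size());
    for (std::int32_t i : gi) {
        out.push_back(src[static_cast<std::size_t>(i)]);
    }
    return out;
}

// Masked softmax-weighted sum for one query token. A mask entry of -inf gives
// weight exp(-inf) == 0, so that cell contributes nothing to either sum.
// Assumes at least one cell is unmasked.
float attend(const std::vector<float> & score,
             const std::vector<float> & mask,
             const std::vector<float> & value) {
    float mx = -std::numeric_limits<float>::infinity();
    for (std::size_t i = 0; i < score.size(); ++i) {
        mx = std::fmax(mx, score[i] + mask[i]);
    }
    float num = 0.0f;
    float den = 0.0f;
    for (std::size_t i = 0; i < score.size(); ++i) {
        const float w = std::exp(score[i] + mask[i] - mx);
        num += w * value[i];
        den += w;
    }
    return num / den;
}
```

A caller would compute `gi = used_cells(cells, n_kv)` once, gather scores, mask, and values with it, and compare `attend` on the gathered vectors against `attend` on the full ones. A used cell that the mask hides from a given token (for example for causality) stays in the gather list and is still excluded by its gathered mask entry.
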
## Edits

### 1. `src/llama-kv-cache.h`: declare gather infra (in `llama_kv_cache`)
