#include "paged_kv_manager.h"

#include <cassert>
#include <stdexcept>

namespace paged {

// ---------------------------------------------------------------------------
// FreeBlockQueue (port of kv_cache_utils.py FreeKVCacheBlockQueue)
//
// Intrusive doubly linked list threaded through KVCacheBlock::prev_free /
// next_free. fake_head and fake_tail are sentinel nodes owned by the queue, so
// every real block in the list always has non-null neighbours.
// ---------------------------------------------------------------------------

// Links `blocks` in the given order between the two sentinels.
FreeBlockQueue::FreeBlockQueue(const std::vector<KVCacheBlock*>& blocks) {
    num_free_blocks = blocks.size();

    for (size_t i = 0; i < blocks.size(); ++i) {
        if (i > 0) blocks[i]->prev_free = blocks[i - 1];
        if (i + 1 < blocks.size()) blocks[i]->next_free = blocks[i + 1];
    }

    if (!blocks.empty()) {
        fake_head.next_free = blocks.front();
        blocks.front()->prev_free = &fake_head;
        fake_tail.prev_free = blocks.back();
        blocks.back()->next_free = &fake_tail;
    } else {
        fake_head.next_free = &fake_tail;
        fake_tail.prev_free = &fake_head;
    }
}

// Detaches and returns the block at the front of the queue.
// Throws std::runtime_error when the queue is empty.
KVCacheBlock* FreeBlockQueue::popleft() {
    KVCacheBlock* first = fake_head.next_free;
    if (first == &fake_tail || first == nullptr) {
        assert(num_free_blocks == 0);
        throw std::runtime_error("No free blocks available");
    }

    fake_head.next_free = first->next_free;
    first->next_free->prev_free = &fake_head;
    first->prev_free = first->next_free = nullptr;
    num_free_blocks--;
    return first;
}

// Detaches and returns the first `n` blocks, front to back.
// Throws std::runtime_error when fewer than `n` blocks are queued. This used to
// be an assert only, which in NDEBUG builds underflowed the counter and walked
// through the tail sentinel.
std::vector<KVCacheBlock*> FreeBlockQueue::popleft_n(size_t n) {
    std::vector<KVCacheBlock*> ret;
    if (n == 0) return ret;
    if (n > num_free_blocks) {
        throw std::runtime_error("Not enough free blocks available");
    }

    num_free_blocks -= n;
    ret.reserve(n);

    KVCacheBlock* curr = fake_head.next_free;
    for (size_t i = 0; i < n; ++i) {
        assert(curr != nullptr);
        ret.push_back(curr);
        KVCacheBlock* taken = curr;
        curr = curr->next_free;
        taken->prev_free = taken->next_free = nullptr;
    }

    // `curr` is now the first remaining node (possibly the tail sentinel).
    if (curr != nullptr) {
        fake_head.next_free = curr;
        curr->prev_free = &fake_head;
    }
    return ret;
}

// Unlinks `block` from anywhere in the queue.
// Throws std::runtime_error if the block has no neighbours, i.e. is not queued.
void FreeBlockQueue::remove(KVCacheBlock* block) {
    if (!block->prev_free || !block->next_free) {
        throw std::runtime_error("remove() called on an invalid block");
    }

    block->prev_free->next_free = block->next_free;
    block->next_free->prev_free = block->prev_free;
    block->prev_free = block->next_free = nullptr;
    num_free_blocks--;
}

// Links `block` in just before the tail sentinel.
void FreeBlockQueue::append(KVCacheBlock* block) {
    KVCacheBlock* last = fake_tail.prev_free;
    last->next_free = block;
    block->prev_free = last;
    block->next_free = &fake_tail;
    fake_tail.prev_free = block;
    num_free_blocks++;
}

// Links `blocks`, in order, in just before the tail sentinel.
void FreeBlockQueue::append_n(const std::vector<KVCacheBlock*>& blocks) {
    if (blocks.empty()) return;

    KVCacheBlock* last = fake_tail.prev_free;
    for (KVCacheBlock* b : blocks) {
        b->prev_free = last;
        last->next_free = b;
        last = b;
    }
    last->next_free = &fake_tail;
    fake_tail.prev_free = last;
    num_free_blocks += blocks.size();
}

// Links `blocks`, in order, in just after the head sentinel, so blocks[0]
// becomes the next block handed out by popleft().
void FreeBlockQueue::prepend_n(const std::vector<KVCacheBlock*>& blocks) {
    if (blocks.empty()) return;

    KVCacheBlock* old_first = fake_head.next_free;
    KVCacheBlock* prev = &fake_head;
    for (KVCacheBlock* b : blocks) {
        b->prev_free = prev;
        prev->next_free = b;
        prev = b;
    }
    prev->next_free = old_first;
    old_first->prev_free = prev;
    num_free_blocks += blocks.size();
}

// Returns the queued blocks front to back, excluding both sentinels.
std::vector<KVCacheBlock*> FreeBlockQueue::get_all_free_blocks() const {
    std::vector<KVCacheBlock*> ret;
    ret.reserve(num_free_blocks);

    // The walk stops at the node whose next_free is null; this relies on the
    // tail sentinel's next_free never being set by any method in this file.
    KVCacheBlock* curr = fake_head.next_free;
    while (curr && curr->next_free != nullptr) {
        ret.push_back(curr);
        curr = curr->next_free;
    }
    return ret;
}

// ---------------------------------------------------------------------------
// BlockPool (port of block_pool.py)
// ---------------------------------------------------------------------------

// Builds a vector of pointers to every element of `v`, in order.
static std::vector<KVCacheBlock*> make_ptrs(std::vector<KVCacheBlock>& v) {
    std::vector<KVCacheBlock*> p;
    p.reserve(v.size());
    for (auto& b : v) p.push_back(&b);
    return p;
}

// Builds `num_blocks` blocks, each constructed from its index 0..num_blocks-1.
static std::vector<KVCacheBlock> make_block_vec(int32_t num_blocks) {
    std::vector<KVCacheBlock> v;
    v.reserve(num_blocks);
    for (int32_t i = 0; i < num_blocks; ++i) v.emplace_back(i);
    return v;
}

// Creates the pool and immediately takes the first free block as the null
// block. Throws std::runtime_error (from popleft) if the pool has no blocks.
BlockPool::BlockPool(int32_t num_blocks, bool enable_caching)
    : enable_caching_(enable_caching),
      blocks_(make_block_vec(num_blocks)),
      ptrs_(make_ptrs(blocks_)),
      free_queue_(ptrs_) {
    // vLLM reserves block_id 0 as the null block (never cached).
    null_block = free_queue_.popleft();
    null_block->is_null = true;
}

// Clears the hash of `block` so it can be reused for new content.
// Returns true only when the hash-map entry for that hash pointed at `block`
// and was erased. The hash is reset either way: if a later block was cached
// under the same hash, the map no longer refers to this one, and leaving
// has_hash set would make cache_full_blocks skip it after reallocation.
bool BlockPool::maybe_evict_cached_block(KVCacheBlock* block) {
    if (!block->has_hash) return false;

    bool evicted = false;
    auto it = cached_block_hash_to_block_.find(block->block_hash);
    if (it != cached_block_hash_to_block_.end() && it->second == block) {
        cached_block_hash_to_block_.erase(it);
        evicted = true;
    }
    block->reset_hash();
    return evicted;
}

// Takes `n` blocks from the front of the free queue, drops any cached hash they
// carried (when caching is enabled) and gives each a reference count of 1.
// Throws std::runtime_error when fewer than `n` blocks are free.
std::vector<KVCacheBlock*> BlockPool::get_new_blocks(size_t n) {
    if (n > get_num_free_blocks()) {
        throw std::runtime_error("Cannot get free blocks from pool");
    }

    auto ret = free_queue_.popleft_n(n);
    for (KVCacheBlock* b : ret) {
        if (enable_caching_) maybe_evict_cached_block(b);
        assert(b->ref_cnt == 0);
        b->ref_cnt += 1;
    }
    return ret;
}

// Returns the block cached under `block_hash`, or nullptr on a miss.
KVCacheBlock* BlockPool::get_cached_block(uint64_t block_hash) {
    auto it = cached_block_hash_to_block_.find(block_hash);
    return it == cached_block_hash_to_block_.end() ? nullptr : it->second;
}

// Adds one reference to every block in `blocks`.
void BlockPool::touch(const std::vector<KVCacheBlock*>& blocks) {
    for (KVCacheBlock* b : blocks) {
        // ref_cnt==0 means the block is a free-list eviction candidate; pull it out.
        if (b->ref_cnt == 0 && !b->is_null) free_queue_.remove(b);
        b->ref_cnt += 1;
    }
}

// Drops one reference from every non-null block in `ordered_blocks` and returns
// blocks that reach zero to the free queue.
void BlockPool::free_blocks(const std::vector<KVCacheBlock*>& ordered_blocks) {
    std::vector<KVCacheBlock*> without_hash;
    std::vector<KVCacheBlock*> with_hash;

    for (KVCacheBlock* b : ordered_blocks) {
        if (b->is_null) continue;
        // Releasing a block that holds no reference is a caller bug (double free).
        assert(b->ref_cnt > 0);
        b->ref_cnt -= 1;
        if (b->ref_cnt == 0) {
            (b->has_hash ? with_hash : without_hash).push_back(b);
        }
    }

    free_queue_.prepend_n(without_hash);  // un-hashed: evicted first (front)
    free_queue_.append_n(with_hash);      // hashed: kept warm (tail)
}

// Registers req_blocks[num_cached_blocks .. num_full_blocks) in the hash map
// under the matching entries of `block_hashes`. Blocks that already carry a
// hash are left untouched.
//
// Does nothing when caching is disabled: get_new_blocks only evicts hashes when
// caching is enabled, so hashing blocks here would leave stale map entries
// pointing at reused blocks. The range is clamped to the sizes of both input
// vectors so an over-large num_full_blocks cannot index out of bounds.
void BlockPool::cache_full_blocks(const std::vector<KVCacheBlock*>& req_blocks,
                                  size_t num_cached_blocks,
                                  size_t num_full_blocks,
                                  const std::vector<uint64_t>& block_hashes) {
    if (!enable_caching_) return;

    size_t end = num_full_blocks;
    if (end > req_blocks.size()) end = req_blocks.size();
    if (end > block_hashes.size()) end = block_hashes.size();

    for (size_t i = num_cached_blocks; i < end; ++i) {
        KVCacheBlock* blk = req_blocks[i];
        if (blk->has_hash) continue;
        blk->has_hash = true;
        blk->block_hash = block_hashes[i];
        cached_block_hash_to_block_[blk->block_hash] = blk;
    }
}

// ---------------------------------------------------------------------------
// PagedKVManager (port of SingleTypeKVCacheManager / FullAttentionManager)
// ---------------------------------------------------------------------------

// Ceiling division for non-zero `b`.
static inline size_t cdiv(size_t a, size_t b) { return (a + b - 1) / b; }

// Throws std::invalid_argument when block_size is not positive; every
// position/block computation below divides by it.
PagedKVManager::PagedKVManager(int32_t num_blocks, int block_size, bool enable_caching)
    : block_size_(block_size), pool_(num_blocks, enable_caching) {
    if (block_size <= 0) {
        throw std::invalid_argument("block_size must be positive");
    }
}

// Grows the block list of `seq_id` so it can hold `total_tokens` tokens.
// Returns false, leaving the sequence's blocks unchanged, when the pool does
// not have enough free blocks; returns true otherwise (including when no growth
// is needed).
bool PagedKVManager::allocate(int seq_id, size_t total_tokens) {
    auto& req = req_to_blocks_[seq_id];

    const size_t need = cdiv(total_tokens, block_size_);
    if (need <= req.size()) return true;

    const size_t add = need - req.size();
    if (add > pool_.get_num_free_blocks()) return false;  // OOM

    auto nb = pool_.get_new_blocks(add);
    req.insert(req.end(), nb.begin(), nb.end());
    return true;
}

// Returns the block ids of `seq_id` in logical order; empty for an unknown id.
std::vector<int32_t> PagedKVManager::block_table(int seq_id) const {
    std::vector<int32_t> bt;
    auto it = req_to_blocks_.find(seq_id);
    if (it == req_to_blocks_.end()) return bt;

    bt.reserve(it->second.size());
    for (KVCacheBlock* b : it->second) bt.push_back(b->block_id);
    return bt;
}

// Maps token position `pos` of `seq_id` to its slot index:
// block_id * block_size + offset within the block.
// Throws std::out_of_range for an unknown seq_id, a negative position, or a
// position beyond the blocks currently allocated to the sequence.
int64_t PagedKVManager::slot(int seq_id, int pos) const {
    const auto& req = req_to_blocks_.at(seq_id);
    if (pos < 0) {
        throw std::out_of_range("slot(): negative position");
    }

    const size_t logical = static_cast<size_t>(pos / block_size_);
    if (logical >= req.size()) {
        throw std::out_of_range("slot(): position beyond allocated blocks");
    }

    const int32_t phys = req[logical]->block_id;
    return static_cast<int64_t>(phys) * block_size_ + (pos % block_size_);
}

// Applies slot() to every entry of `positions`, preserving order.
std::vector<int64_t> PagedKVManager::slot_mapping(int seq_id,
                                                  const std::vector<int>& positions) const {
    std::vector<int64_t> sm;
    sm.reserve(positions.size());
    for (int p : positions) sm.push_back(slot(seq_id, p));
    return sm;
}

// Releases every block of `seq_id` and forgets the sequence. No-op for an
// unknown id.
void PagedKVManager::free(int seq_id) {
    auto it = req_to_blocks_.find(seq_id);
    if (it == req_to_blocks_.end()) return;

    // Free in reverse so the tail of the block chain is evicted first (vLLM order).
    std::vector<KVCacheBlock*> ordered(it->second.rbegin(), it->second.rend());
    pool_.free_blocks(ordered);
    req_to_blocks_.erase(it);
}

// FNV-1a chained block hash. Deterministic and prefix-sensitive; folds the parent
// hash into the seed so each block hash transitively encodes its whole prefix
// (behavioral parity with vLLM hash_block_tokens chaining; vLLM uses sha256 bytes).
uint64_t PagedKVManager::hash_block(uint64_t parent_hash, const std::vector<int>& token_ids) {
    uint64_t h = 1469598103934665603ull ^ parent_hash;
    for (int t : token_ids) {
        h ^= static_cast<uint64_t>(static_cast<uint32_t>(t));
        h *= 1099511628211ull;
    }
    if (h == 0) h = 0x9e3779b97f4a7c15ull;  // never 0 (0 reads as "no hash")
    return h;
}

// Returns one chained hash per full block of `token_ids`; a trailing partial
// block is not hashed.
std::vector<uint64_t> PagedKVManager::compute_block_hashes(
        const std::vector<int>& token_ids) const {
    std::vector<uint64_t> hashes;
    uint64_t parent = 0;  // NONE_HASH analogue

    const size_t n_full = token_ids.size() / block_size_;
    hashes.reserve(n_full);
    for (size_t i = 0; i < n_full; ++i) {
        std::vector<int> blk(token_ids.begin() + i * block_size_,
                             token_ids.begin() + (i + 1) * block_size_);
        parent = hash_block(parent, blk);
        hashes.push_back(parent);
    }
    return hashes;
}

// Returns the number of tokens covered by the longest cached prefix of
// `block_hashes`, and takes one reference on every block in that prefix.
//
// NOTE(review): the referenced blocks are not recorded in req_to_blocks_ here,
// so nothing in this file releases the references taken — confirm the caller
// has a path to drop them, otherwise these blocks never return to the free list.
size_t PagedKVManager::get_computed_blocks(const std::vector<uint64_t>& block_hashes) {
    std::vector<KVCacheBlock*> hits;
    for (uint64_t bh : block_hashes) {  // stop at first miss (prefix property)
        KVCacheBlock* cb = pool_.get_cached_block(bh);
        if (!cb) break;
        hits.push_back(cb);
    }
    pool_.touch(hits);  // ++ref_cnt, pull from free list
    return hits.size() * static_cast<size_t>(block_size_);
}

// Registers the full blocks of `seq_id` covering the first `num_tokens` tokens
// in the prefix cache. No-op for an unknown seq_id (looked up with find so no
// empty entry is created as a side effect).
void PagedKVManager::cache_blocks(int seq_id,
                                  const std::vector<uint64_t>& block_hashes,
                                  size_t num_tokens) {
    auto it = req_to_blocks_.find(seq_id);
    if (it == req_to_blocks_.end()) return;

    const size_t n_full = num_tokens / block_size_;
    pool_.cache_full_blocks(it->second, /*num_cached=*/0, n_full, block_hashes);
}

}  // namespace paged