III.5: Radix prefix cache
III.4 gave us a KV cache built for serving (fixed-size blocks, O(1) append, no fragmentation) and it quietly added a snapshot method that freezes a cache into a re-materializable form. This chapter is what that method was for.
Here is the waste we're attacking. A chat server's requests are not independent strings. They overlap, heavily:
- A shared system prompt. Every request to one deployment usually begins with the same several-hundred-token system prompt.
- Multi-turn conversations. Turn 3 of a chat contains all of turns 1 and 2 verbatim; recall from III.1 that each turn re-renders and re-prefills the whole message history.
The engine prefills every one of those shared tokens every single time, running the prompt through the model in one forward pass. Prefilling a 500-token shared prefix is real work: it's most of a request's time-to-first-token. Doing it identically for request after request is pure waste.
The fix: cache the result of prefilling a prefix (its KV snapshot) keyed by the token ids of that prefix. When a new request's prompt starts with a prefix we've already prefilled, skip straight past the shared tokens and reuse the saved KV. The data structure that makes "longest cached prefix of these token ids" a fast lookup is a radix tree.
What a radix tree buys us
A prefix cache needs one operation: given a list of token ids, find the longest prefix of that list we've stored, and return its saved KV. A hash map keyed on the full id list can't do this; it only matches exact keys, and we want the longest partial match.
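To make the limitation concrete, here is a minimal sketch of the only way a hash map can answer the question: probing it once per candidate prefix, longest first. The function name `longest_prefix_by_probing` and the string values standing in for KV snapshots are assumptions for illustration, not part of the chapter's code.

```rust
use std::collections::HashMap;

/// Hypothetical fallback: probe a hash map with every prefix of `ids`,
/// longest first. Each probe hashes the whole candidate prefix, so a
/// prompt of n tokens costs up to n probes of O(n) hashing work each.
fn longest_prefix_by_probing(
    cache: &HashMap<Vec<usize>, &'static str>,
    ids: &[usize],
) -> Option<(usize, &'static str)> {
    for len in (1..=ids.len()).rev() {
        // `Vec<usize>` keys can be looked up with a `&[usize]` slice.
        if let Some(value) = cache.get(&ids[..len]) {
            return Some((len, *value));
        }
    }
    None
}
```

A plain `cache.get` on the full prompt misses as soon as the prompt extends the stored key by one token; the probing loop finds the match, but only by rehashing ever-shorter slices of the prompt.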
A radix tree is a kind of prefix tree, or trie, and it is built for exactly this. In the version we build, each edge is labeled with one token id, so a path from the root spells out a sequence of ids. (Strictly, one id per edge makes this a plain trie; a radix tree proper also merges chains of single-child nodes into one edge. Longest-prefix lookup works the same way in both.) Store the system prompt [A, B, C] and the tree has a path root → A → B → C. Now a request whose prompt is [A, B, C, D, E] walks the tree from the root: it follows A, B, C (matched), then hits no edge for D and stops. We matched the 3-token prefix [A, B, C] and learned the request only needs to prefill [D, E].
root
└─ A
   └─ B
      └─ C ● ← terminal: KV snapshot for [A,B,C] stored here
         └─ D
            └─ E ● ← terminal: KV snapshot for [A,B,C,D,E]

Nodes marked ● are terminal: a prefix we actually prefilled and saved. A node can be non-terminal (a waypoint on the path to longer entries). Lookup walks ids one at a time, remembering the deepest terminal it passed; that's the longest cached prefix.
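The walk described above can be sketched in a few lines. This is a toy under stated assumptions: `MiniTrie` is a hypothetical name, it stores only a terminal flag rather than a KV snapshot, and it has no eviction or locking.

```rust
use std::collections::HashMap;

/// Hypothetical toy trie: one token id per edge, and a flag marking
/// which nodes are stored prefixes (the ● nodes in the diagram).
#[derive(Default)]
struct MiniTrie {
    children: HashMap<usize, MiniTrie>,
    terminal: bool,
}

impl MiniTrie {
    /// Store `ids` as a cached prefix, creating nodes as needed.
    fn insert(&mut self, ids: &[usize]) {
        let mut node = self;
        for &id in ids {
            node = node.children.entry(id).or_default();
        }
        node.terminal = true;
    }

    /// Length of the longest stored prefix of `ids`, if any.
    fn longest_prefix_len(&self, ids: &[usize]) -> Option<usize> {
        let mut node = self;
        let mut best = None;
        for (i, id) in ids.iter().enumerate() {
            match node.children.get(id) {
                Some(next) => node = next,
                None => break, // no edge for this id: stop walking
            }
            if node.terminal {
                best = Some(i + 1); // deepest terminal seen so far
            }
        }
        best
    }
}
```

With [1, 2, 3] and [1, 2, 3, 4, 5] stored, a lookup for [1, 2, 3, 4] walks four edges but reports 3, because the node for 4 is only a waypoint; a lookup for [1, 2] reports no hit at all.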
Three things make it production-shaped rather than a toy:
- It's bounded. A server runs forever; the cache can't grow without limit. We cap the number of entries and evict the least-recently-used when full.
- It's safe under concurrency. Several requests touch the cache at once, so it's behind a Mutex.
- It pins live entries. Eviction must never delete a snapshot a live request is still decoding from, so each entry carries a pin count, incremented on lookup and decremented (via Drop) when the request is done. Pinned entries are skipped by the evictor.
The III.5 commit puts this in src/cache/prefix/radix.rs.
The nodes
src/cache/prefix/radix.rs starts with the tree's nodes:
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};

use crate::cache::KvSnapshot;

pub type TokenId = usize;

struct RadixNode {
    children: HashMap<TokenId, Box<RadixNode>>,
    terminal: Option<TerminalEntry>,
}

struct TerminalEntry {
    path: Arc<[TokenId]>,
    snapshot: Arc<dyn KvSnapshot>,
    next_id: usize,
    pin_count: Arc<AtomicUsize>,
}

impl RadixNode {
    fn new() -> Self {
        Self {
            children: HashMap::new(),
            terminal: None,
        }
    }
}

A RadixNode has children keyed by token id (the labeled edges) and an optional terminal. The HashMap for children means following an edge is an O(1) hash lookup; a node can branch into many children when different prompts diverge after a shared prefix.
TerminalEntry is the payload stored at a ● node:
- path: the full id sequence from root to this node, kept so the LRU list and the evictor can identify the entry. Arc<[TokenId]> so it's cheaply shareable.
- snapshot: the KvSnapshot from III.4: the frozen KV state after prefilling path.
- next_id: the first token the model predicted after the prefix. Stored so a full hit (the new prompt equals a cached prefix exactly) needs zero forward passes; we already know the next token.
- pin_count: an AtomicUsize counting live readers. Atomic so it can be bumped without holding the tree's Mutex.
The cache struct
The cache wraps a tree behind a Mutex and adds an LRU list:
pub struct RadixPrefixCache {
    inner: Mutex<Inner>,
    max_entries: usize,
}

struct Inner {
    root: Box<RadixNode>,
    lru: VecDeque<Arc<[TokenId]>>,
    lru_positions: HashMap<Arc<[TokenId]>, ()>,
    entries: usize,
}

impl RadixPrefixCache {
    pub fn new(max_entries: usize) -> Arc<RadixPrefixCache> {
        Arc::new(Self {
            inner: Mutex::new(Inner {
                root: Box::new(RadixNode::new()),
                lru: VecDeque::new(),
                lru_positions: HashMap::new(),
                entries: 0,
            }),
            max_entries,
        })
    }

    pub fn max_entries(&self) -> usize {
        self.max_entries
    }

    pub fn entries(&self) -> usize {
        self.inner.lock().map(|g| g.entries).unwrap_or(0)
    }

Inner is everything the Mutex guards: the tree root, the LRU ordering, and the entry count. lru is a deque of entry paths, front is least-recently-used, back is most-recent. lru_positions is a set of which paths are in the LRU (a HashMap<_, ()> used as a set) so we can check membership without scanning the deque. new returns the cache already wrapped in an Arc, since every part of the server shares one instance.
Lookup
lookup_longest walks the tree following the prompt's ids, tracking the deepest terminal:
    pub fn lookup_longest(&self, ids: &[TokenId]) -> Option<CacheHit> {
        if self.max_entries == 0 || ids.is_empty() {
            return None;
        }
        let mut guard = self.inner.lock().ok()?;
        let mut node: &RadixNode = &guard.root;
        let mut best: Option<(usize, &TerminalEntry)> = None;
        for (i, &tid) in ids.iter().enumerate() {
            let Some(next) = node.children.get(&tid) else {
                break;
            };
            node = next.as_ref();
            if let Some(ref t) = node.terminal {
                best = Some((i + 1, t));
            }
        }
        let (prefix_len, path_for_lru, pin_count, snapshot, next_id) = {
            let (prefix_len, term) = best?;
            (
                prefix_len,
                Arc::clone(&term.path),
                Arc::clone(&term.pin_count),
                Arc::clone(&term.snapshot),
                term.next_id,
            )
        };
        pin_count.fetch_add(1, Ordering::AcqRel);
        Self::touch_lru(&mut guard, &path_for_lru);
        drop(guard);
        Some(CacheHit {
            prefix_len,
            snapshot,
            next_id,
            pin_count,
        })
    }

The walk: for each id, try to follow the matching child edge; if there's no edge, stop. Every time the walk lands on a node with a terminal, record it as best along with i + 1 (the length of the prefix matched so far). When the loop ends, best holds the deepest terminal we passed, the longest cached prefix.
If we found one, three reference-counted handles are cloned out of the entry (path, pin_count, snapshot) plus next_id. Then pin_count.fetch_add(1, ...) pins the entry: a live reader now exists, and the evictor must leave this entry alone. touch_lru marks the entry most-recently-used. We drop the lock and hand back a CacheHit.
Insert
insert adds a freshly prefilled prefix to the tree:
    pub fn insert(
        &self,
        ids: Vec<TokenId>,
        snapshot: Arc<dyn KvSnapshot>,
        next_id: usize,
    ) -> Result<(), String> {
        if self.max_entries == 0 || ids.is_empty() {
            return Ok(());
        }
        let mut guard = self
            .inner
            .lock()
            .map_err(|_| "cache lock poisoned".to_string())?;
        let mut node = guard.root.as_mut();
        for &tid in &ids {
            node = node
                .children
                .entry(tid)
                .or_insert_with(|| Box::new(RadixNode::new()));
        }
        let is_new = node.terminal.is_none();
        let pin_count = if let Some(ref old) = node.terminal {
            Arc::clone(&old.pin_count)
        } else {
            Arc::new(AtomicUsize::new(0))
        };
        let path: Arc<[TokenId]> = match &node.terminal {
            Some(old) => Arc::clone(&old.path),
            None => ids.into(),
        };
        node.terminal = Some(TerminalEntry {
            path: Arc::clone(&path),
            snapshot,
            next_id,
            pin_count,
        });
        if is_new {
            guard.entries += 1;
            guard.lru.push_back(Arc::clone(&path));
            let _ = guard.lru_positions.insert(path, ());
        } else {
            Self::touch_lru(&mut guard, &path);
        }
        self.evict_if_needed(&mut guard);
        Ok(())
    }

It walks the ids from the root, creating any missing child node along the way (entry(...).or_insert_with(...)). When it reaches the final node it installs a TerminalEntry. If that node was already terminal (re-inserting an existing prefix) it reuses the old pin_count and path so any live reader's pin stays valid; otherwise it starts a fresh pin count at 0.
A genuinely new entry bumps the entry count and joins the back of the LRU; a re-insert just touches the LRU. Either way, evict_if_needed runs at the end to enforce the cap.
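Why re-insertion reuses the old pin counter is easier to see in isolation. The following is a sketch with a flat map instead of a tree; `Entry`, `upsert`, and `pins` are hypothetical names, and a `String` stands in for the snapshot.

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Hypothetical flat stand-in for a terminal entry.
struct Entry {
    value: String,
    pin_count: Arc<AtomicUsize>,
}

/// Insert or overwrite `key`, keeping the existing pin counter if any,
/// so readers that pinned the old entry still count against the new one.
fn upsert(map: &mut HashMap<Vec<usize>, Entry>, key: Vec<usize>, value: String) {
    let pin_count = match map.get(&key) {
        Some(old) => Arc::clone(&old.pin_count),
        None => Arc::new(AtomicUsize::new(0)),
    };
    map.insert(key, Entry { value, pin_count });
}

/// Current number of live readers for an entry.
fn pins(entry: &Entry) -> usize {
    entry.pin_count.load(Ordering::Acquire)
}
```

If `upsert` allocated a fresh counter on overwrite instead, a reader holding the old counter would decrement a value the evictor no longer looks at, and the replaced entry would appear unpinned while that reader was still alive.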
touch_lru moves an entry to the most-recently-used end:
    fn touch_lru(inner: &mut Inner, path: &Arc<[TokenId]>) {
        if inner.lru_positions.contains_key(path) {
            inner.lru.retain(|p| !Arc::ptr_eq(p, path));
            inner.lru.push_back(Arc::clone(path));
        }
    }

Remove the path from wherever it sits and push it to the back. Arc::ptr_eq compares by pointer identity, not by contents; the same allocation is being re-positioned.
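The difference between pointer identity and content equality is worth a small standalone sketch. `touch` here is a hypothetical simplification of touch_lru: it works on a bare deque and omits the membership check.

```rust
use std::collections::VecDeque;
use std::sync::Arc;

/// Move `path` to the most-recently-used end (the back) of `lru`,
/// matching by allocation identity rather than by contents.
fn touch(lru: &mut VecDeque<Arc<[usize]>>, path: &Arc<[usize]>) {
    lru.retain(|p| !Arc::ptr_eq(p, path));
    lru.push_back(Arc::clone(path));
}

/// Two separately allocated paths with equal contents compare equal
/// with `==` but are distinct under `Arc::ptr_eq`.
fn equal_but_distinct() -> (bool, bool) {
    let a: Arc<[usize]> = Arc::from(vec![1, 2, 3]);
    let twin: Arc<[usize]> = Arc::from(vec![1, 2, 3]);
    (a == twin, Arc::ptr_eq(&a, &twin))
}
```

Pointer comparison is safe in the cache because every handle to a given path is a clone of the one Arc created at insert time, so identity and equality coincide for entries that are actually in the LRU.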
Eviction
evict_if_needed drops least-recently-used entries until the cache is back under its cap, skipping any entry that's pinned:
    fn evict_if_needed(&self, inner: &mut Inner) {
        while inner.entries > self.max_entries {
            let mut evict_idx = None;
            for (i, p) in inner.lru.iter().enumerate() {
                if Self::is_pinned(inner, p) {
                    continue;
                }
                evict_idx = Some(i);
                break;
            }
            let Some(i) = evict_idx else {
                break;
            };
            let path = inner.lru.remove(i).expect("evict_idx valid");
            inner.lru_positions.remove(&path);
            if Self::remove_terminal(&mut inner.root, &path) {
                inner.entries -= 1;
            }
        }
    }

While over the cap, scan the LRU from the front (least-recent) for the first un-pinned entry. If every entry is pinned, give up; better to run slightly over the cap than to evict KV out from under a request mid-decode. Otherwise remove that entry from the LRU and from the tree.
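The eviction policy can be exercised on its own with a flattened model. This is a sketch, not the chapter's code: `LruEntry` and `evict_to_cap` are hypothetical names, and the pin count is a plain field rather than an atomic read through the tree.

```rust
use std::collections::VecDeque;

/// Hypothetical flat LRU entry: a key plus its current pin count.
struct LruEntry {
    key: u32,
    pins: usize,
}

/// Evict from the front (least recent) until `lru.len() <= cap`,
/// skipping pinned entries. Returns the evicted keys in order.
fn evict_to_cap(lru: &mut VecDeque<LruEntry>, cap: usize) -> Vec<u32> {
    let mut evicted = Vec::new();
    while lru.len() > cap {
        // First un-pinned entry, scanning from the least-recent end.
        let Some(i) = lru.iter().position(|e| e.pins == 0) else {
            break; // everything is pinned: stay over the cap
        };
        let entry = lru.remove(i).expect("index from position is valid");
        evicted.push(entry.key);
    }
    evicted
}
```

With entries 1 (pinned), 2, and 3 in LRU order and a cap of 1, the call evicts 2 and then 3 and leaves the pinned entry in place, even though it is the least recently used. A further call with a cap of 0 evicts nothing and leaves the cache one entry over the cap.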
is_pinned walks to a path's terminal and checks its count:
    fn is_pinned(inner: &Inner, path: &Arc<[TokenId]>) -> bool {
        let mut node: &RadixNode = &inner.root;
        for &tid in path.iter() {
            let Some(next) = node.children.get(&tid) else {
                return false;
            };
            node = next.as_ref();
        }
        node.terminal
            .as_ref()
            .map(|t| t.pin_count.load(Ordering::Acquire) > 0)
            .unwrap_or(false)
    }

remove_terminal clears the terminal marker at a path:
    fn remove_terminal(root: &mut RadixNode, path: &[TokenId]) -> bool {
        fn walk(node: &mut RadixNode, path: &[TokenId]) -> bool {
            if path.is_empty() {
                return node.terminal.take().is_some();
            }
            let Some(child) = node.children.get_mut(&path[0]) else {
                return false;
            };
            walk(child, &path[1..])
        }
        walk(root, path)
    }
}

walk recurses down one id at a time; at the end it take()s the terminal, dropping the cache's reference to the stored snapshot. The intermediate nodes are left in place (they may be on the path to other entries); only the terminal marker is removed.
The hit handle and its pin
CacheHit is what lookup_longest returns, and crucially, it manages the pin via Drop:
pub struct CacheHit {
    pub prefix_len: usize,
    pub snapshot: Arc<dyn KvSnapshot>,
    pub next_id: usize,
    pin_count: Arc<AtomicUsize>,
}

impl CacheHit {
    pub fn is_full_hit(&self, prompt_len: usize) -> bool {
        self.prefix_len == prompt_len
    }
}

impl Drop for CacheHit {
    fn drop(&mut self) {
        self.pin_count.fetch_sub(1, Ordering::AcqRel);
    }
}

impl std::fmt::Debug for CacheHit {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CacheHit")
            .field("prefix_len", &self.prefix_len)
            .field("next_id", &self.next_id)
            .finish()
    }
}

This is the RAII pin. lookup_longest incremented pin_count; the Drop impl decrements it. As long as the CacheHit value is alive somewhere (a request still decoding from its snapshot) the entry counts as pinned and the evictor skips it. The moment the CacheHit is dropped, the pin releases automatically. There is no manual unpin call to forget: Drop runs whenever the value goes out of scope, early returns and error paths included, so the count stays balanced.
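The pin-on-acquire, unpin-on-drop pattern can be tried in isolation. In this sketch `PinGuard` and `use_entry` are hypothetical names; unlike CacheHit, the guard carries no snapshot and does its own increment.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Hypothetical guard: holds one pin for as long as it is alive.
struct PinGuard {
    pin_count: Arc<AtomicUsize>,
}

impl PinGuard {
    fn acquire(pin_count: &Arc<AtomicUsize>) -> Self {
        pin_count.fetch_add(1, Ordering::AcqRel);
        Self { pin_count: Arc::clone(pin_count) }
    }
}

impl Drop for PinGuard {
    fn drop(&mut self) {
        self.pin_count.fetch_sub(1, Ordering::AcqRel);
    }
}

/// Returns early on odd input; the guard still releases its pin.
fn use_entry(pin_count: &Arc<AtomicUsize>, input: u32) -> Result<u32, String> {
    let _guard = PinGuard::acquire(pin_count);
    if input % 2 == 1 {
        return Err("odd input".to_string());
    }
    Ok(input / 2)
}
```

Both the error path and the success path of `use_entry` leave the counter where it started, without either branch mentioning the pin.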
is_full_hit reports whether the cached prefix covered the entire prompt, the case where no prefill is needed at all.
The module re-export:
mod radix;
pub use radix::RadixPrefixCache;

pub(crate) mod prefix;
pub use prefix::RadixPrefixCache;

Using the cache in the prefill path
The prefix cache is wired into prepare_chat_decode_state from III.1, the function that builds the KV cache and runs prefill. It gains a prefix_cache parameter and a three-way branch.
pub(crate) fn prepare_chat_decode_state(
    model: Arc<dyn Model>,
    tokenizer: &dyn Tokenizer,
    backend: Arc<dyn Backend>,
    messages: &[ChatTemplateMessage],
    kv_mode: &str,
    prefix_cache: Option<&Arc<RadixPrefixCache>>,
) -> Result<ChatDecodeState, String> {
    let (prompt, prompt_ids, eos_token_id) = chat_prompt_details(tokenizer, messages)?;
    log_verbose_prompt(&prompt, &prompt_ids);
    let mut metrics = Metrics::default();
    let mut cache: Box<dyn KvCache>;
    let mut next_id: usize;
    let mut full_hit = false;

The first branch: there is a prefix cache, and lookup_longest found a hit.
    if let Some(pc) = prefix_cache {
        if let Some(hit) = pc.lookup_longest(&prompt_ids) {
            let prefix_len = hit.prefix_len;
            next_id = hit.next_id;
            cache = hit.snapshot.materialize();
            drop(hit);
            if prefix_len == prompt_ids.len() {
                full_hit = true;
            } else {
                let suffix = &prompt_ids[prefix_len..];
                let mut last_logits = None;
                for (i, &tid) in suffix.iter().enumerate() {
                    let pos = prefix_len + i;
                    let logits = metrics.record_timed(|| {
                        model
                            .as_ref()
                            .forward_decode_with_kv_cache(tid, pos, cache.as_mut())
                    });
                    last_logits = Some(logits);
                }
                let logits = last_logits.expect("non-empty suffix must produce logits");
                next_id = next_token_id_from_logits(backend.as_ref(), &logits).0;
            }
        } else {

On a hit, materialize rebuilds a usable KV cache from the saved snapshot: the shared prefix's KV state is restored, not recomputed. Then drop(hit) releases the pin; we've copied the snapshot's contents into our own cache, so the cached entry no longer needs protecting from this request.
Two sub-cases. If the hit covered the whole prompt (prefix_len == prompt_ids.len()), it's a full_hit: no prefill at all, and next_id is the cached next_id. Otherwise the prompt has a suffix the cache didn't have; those tokens still need running, but with forward_decode_with_kv_cache one at a time (cheap decode steps) instead of a full prefill of the whole prompt. The last suffix token's logits give the first reply token.
The other two branches are the original prefill path, used on a cache miss and when there's no prefix cache at all:
            cache = create_kv_cache(kv_mode, model.clone(), backend.clone())?;
            let logits = metrics.record_timed(|| {
                model
                    .as_ref()
                    .forward_prefill_with_kv_cache(&prompt_ids, cache.as_mut())
            });
            next_id = next_token_id_from_logits(backend.as_ref(), &logits).0;
        }
    } else {
        cache = create_kv_cache(kv_mode, model.clone(), backend.clone())?;
        let logits = metrics.record_timed(|| {
            model
                .as_ref()
                .forward_prefill_with_kv_cache(&prompt_ids, cache.as_mut())
        });
        next_id = next_token_id_from_logits(backend.as_ref(), &logits).0;
    }

A fresh cache, a full prefill: the III.1 behavior.
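Stripped of the model calls, the decision the branches above make depends on just two numbers. The sketch below is only a summary under assumed names (`PrefillPlan`, `plan_prefill`); it folds "cache miss" and "no prefix cache" into one outcome, since both run the same full prefill.

```rust
/// Hypothetical summary of the outcomes the prefill path handles.
#[derive(Debug, PartialEq)]
enum PrefillPlan {
    /// Cached prefix equals the prompt: no forward pass needed.
    FullHit,
    /// Restore the prefix, then decode the suffix at these positions.
    DecodeSuffix { positions: std::ops::Range<usize> },
    /// No usable prefix: prefill the whole prompt.
    FullPrefill,
}

/// `hit_prefix_len` is `None` on a miss or when no cache is configured.
fn plan_prefill(prompt_len: usize, hit_prefix_len: Option<usize>) -> PrefillPlan {
    match hit_prefix_len {
        Some(n) if n == prompt_len => PrefillPlan::FullHit,
        Some(n) => PrefillPlan::DecodeSuffix { positions: n..prompt_len },
        None => PrefillPlan::FullPrefill,
    }
}
```

For a 5-token prompt with a 3-token hit, the suffix tokens are decoded at positions 3 and 4, which matches the `pos = prefix_len + i` arithmetic in the real loop.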
Then, after prefill by whichever path, insert this prompt's result into the cache so the next request can hit it:
    if let Some(pc) = prefix_cache {
        if !full_hit {
            let snap = cache.snapshot();
            let _ = pc.insert(prompt_ids.clone(), Arc::from(snap), next_id);
        }
    }

We skip insertion on a full_hit; that exact prefix is already in the tree. Otherwise snapshot the cache and insert it keyed by the full prompt ids. The next request with this prompt as a prefix gets a hit.
The prefix_cache parameter threads up through run_chat_turn_with_prefix and run_chat_turn_streaming_with_prefix; both gain it and pass it down. For now the HTTP handlers pass None; the next chapter, which gives the server a real scheduler, is where the server wires in a live cache. chat-repl, though, can use it today.
Wiring chat-repl
chat-repl learns a --prefix-cache-max N flag, the entry cap, defaulting to 0 (disabled):
Some("--prefix-cache-max") => {
    cur.advance();
    prefix_cache_max = Some(parse_usize(&mut cur, "--prefix-cache-max"));
}

pub fn prefix_cache_max(&self, default: usize) -> usize {
    self.prefix_cache_max.unwrap_or(default)
}

main builds the cache when the flag is positive (and a KV cache is on, since a prefix cache with no KV mode has nothing to snapshot):
let prefix_cache = if kv_mode.is_some() && prefix_cache_max > 0 {
    Some(RadixPrefixCache::new(prefix_cache_max))
} else {
    None
};
eprintln!("prefix cache max: {}", prefix_cache_max);

and threads prefix_cache.as_ref() into run_one_shot / run_repl, which pass it on to run_chat_turn_streaming_with_prefix. In the REPL, where the conversation history grows every turn, this is exactly the multi-turn case: turn 2's prompt has turn 1's prompt as a prefix, so turn 2 reuses turn 1's KV instead of re-prefilling it.
Running it
Run the REPL with a small prefix cache:
cargo run --release --bin chat-repl -- --kv paged --prefix-cache-max 16 path/to/qwen3-0.6b.gguf 64

backend: simd
kv cache: paged
prefix cache max: 16
Enter user messages (empty line to quit). Ctrl-D EOF also exits.
user> Name a primary color.
assistant> Red is a primary color.
ttft: 38.1 ms
decode: 31.0 tok/s
user> Name another one.
assistant> Blue is another primary color.
ttft: 9.4 ms
decode: 31.2 tok/s

Look at the time-to-first-token. Turn 1 prefills its whole prompt: 38 ms. Turn 2's prompt re-renders the full history: the system framing, turn 1's question, turn 1's answer, then the new question. All of that except the last few tokens is a prefix of something already in the cache, so turn 2 materializes the saved KV and only decodes the short new suffix: TTFT drops to 9 ms. The shared prefix was prefilled once, in turn 1, and reused.
Where this leaves us
Shared prompt prefixes are no longer recomputed. A radix tree keyed on token ids stores KV snapshots; a request whose prompt extends a cached prefix skips straight past the shared tokens. LRU eviction keeps the cache bounded, and RAII pins guarantee a live request's snapshot can't be evicted mid-decode.
But the HTTP handlers still pass None; the server can't use this yet, because the server still has no place to own a shared cache, and still serves requests one at a time. The next chapter builds the decode scheduler: a background worker that owns the model, the prefix cache, and a set of slots, and runs concurrent requests.