II.2: KV cache
The harness from II.1 printed one number that should bother you: the forward passes get slower as the response gets longer. Generating token 16 takes more wall time than generating token 4. Each forward pass produces exactly one token, so why would later passes cost more?
Because the model is doing the same work over and over. To generate token 16, our Act 1 forward pass runs the entire network over all 15 prior tokens, then reads off the next one. Token 17 reruns it over 16 tokens. The sequence keeps growing, every forward pass reprocesses all of it, and the per-token cost climbs linearly. Generating an N-token response costs roughly N² work.
It does not have to. Almost all of that recomputed work produces identical numbers each time, and we can save them. This chapter builds the KV cache, the single biggest speedup in Act 2. It turns each decode step from "reprocess the whole sequence" into "process one token," and it does so by splitting the forward pass into two distinct paths.
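To put rough numbers on that difference, here is a small counting sketch. It is not part of the harness: both function names are invented for illustration, and "work" here only counts tokens pushed through the network, ignoring the cost of reading cached rows during attention.

```rust
/// Tokens pushed through the network when every step reruns the full
/// sequence: step i (0-based) processes prompt_len + i tokens.
fn tokens_processed_without_cache(prompt_len: usize, new_tokens: usize) -> usize {
    (0..new_tokens).map(|step| prompt_len + step).sum()
}

/// Tokens pushed through the network when the prompt is processed once
/// and every later step handles only the single newest token.
fn tokens_processed_with_cache(prompt_len: usize, new_tokens: usize) -> usize {
    if new_tokens == 0 {
        return 0;
    }
    prompt_len + (new_tokens - 1)
}
```

For a 4-token prompt and 32 generated tokens, the first function returns 624 and the second returns 35. The gap widens quadratically as the response grows.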
Prefill and decode are different shapes of work
I.5 built one forward function. It takes a list of token ids, runs all of them through the 28 transformer layers, and returns logits. We've been calling it once per generated token, each time with the full sequence so far. That works, and it is wasteful, because it conflates two genuinely different jobs.
Prefill is processing the prompt. The first forward pass takes all P prompt tokens and runs them through the network together. Every token attends to every earlier token, so this pass must look at the whole prompt at once. There is no shortcut here, and we don't want one.
Decode is generating the response, one token at a time. Step i has a single new token. It needs to produce one row of logits. And here is the key fact: the new token attends to all earlier tokens, but those earlier tokens were already processed on previous steps. Their contribution to attention does not change. If we save it, we never recompute it.
So the plan is two forward functions instead of one:
- forward_prefill_with_kv_cache: runs the whole prompt, exactly like the Act 1 forward, but on the way it saves the per-layer attention data into a cache.
- forward_decode_with_kv_cache: runs just the one new token, reusing everything the cache already holds.
To see what to save, we need to look at attention.
What attention recomputes, and what's worth keeping
Inside each transformer layer, attention works like this (the full mechanism is in I.5; here's the part that matters for caching). Every token is projected into three vectors: a query (Q), a key (K), and a value (V). To compute the output for a given token, attention compares that token's query against the keys of every token up to and including it, turns those comparisons into weights, and uses the weights to mix together the values.
Now look at decode step i. The new token has a fresh Q, K, and V. It needs the K and V of every earlier token too, but those were computed from earlier tokens, which haven't changed. Token 5's key is the same vector whether you compute it on step 5 or step 500. Recomputing it is pure waste.
That is the cache. For each of the 28 layers, we store the K and V vectors of every token processed so far. Hence "KV cache." On a decode step:
1. Compute Q, K, V for the one new token only.
2. Append its K and V to the cache for this layer.
3. Run attention using the new token's Q against the whole cached K and V.
Step 3 still attends over the full history (correctness is preserved exactly), but steps 1 and 2 only touch the new token. The N² blowup is gone. Each decode step now pushes one token through the network; the only part that still grows with the history is the attention read in step 3, which is small next to the projections and MLP at these sequence lengths.
What we don't cache is the queries. A query is only ever used on the step that created it: token i's query attends to tokens 0..=i and is then never needed again. Only K and V get reused, so only K and V get stored.
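The claim that caching changes nothing about the result can be checked on a toy. The sketch below is an assumption-laden stand-in, not the chapter's code: toy_q and toy_kv replace the real projections with made-up formulas, there is a single head of width 2, and ToyCache is a hypothetical type. What it keeps from the real thing is the structure: keys and values depend only on their own token, and the newest query attends over all of them.

```rust
/// Toy stand-ins for the q_proj / k_proj / v_proj matmuls (illustrative only).
fn toy_q(token: usize) -> [f32; 2] {
    let t = token as f32;
    [1.0 + t * 0.05, 0.5 - t * 0.05]
}

fn toy_kv(token: usize) -> ([f32; 2], [f32; 2]) {
    let t = token as f32;
    ([t * 0.1, 1.0 - t * 0.1], [t, -t])
}

/// One query against every key: scale, softmax, then mix the values.
fn attend(q: [f32; 2], keys: &[[f32; 2]], values: &[[f32; 2]]) -> [f32; 2] {
    let scale = 1.0 / 2.0f32.sqrt();
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| (q[0] * k[0] + q[1] * k[1]) * scale)
        .collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let mut out = [0.0f32; 2];
    for (w, v) in exps.iter().zip(values) {
        out[0] += w / sum * v[0];
        out[1] += w / sum * v[1];
    }
    out
}

/// No cache: rebuild every key and value, then attend with the last query.
fn last_output_recomputed(tokens: &[usize]) -> [f32; 2] {
    let (keys, values): (Vec<_>, Vec<_>) = tokens.iter().map(|&t| toy_kv(t)).unzip();
    attend(toy_q(*tokens.last().unwrap()), &keys, &values)
}

/// Cache: keep keys and values across steps, project only the newest token.
struct ToyCache {
    keys: Vec<[f32; 2]>,
    values: Vec<[f32; 2]>,
}

impl ToyCache {
    fn step(&mut self, token: usize) -> [f32; 2] {
        let (k, v) = toy_kv(token);
        self.keys.push(k);
        self.values.push(v);
        attend(toy_q(token), &self.keys, &self.values)
    }
}
```

Feeding the same tokens through ToyCache::step one at a time and through last_output_recomputed in one go gives the same final output, because both paths hand attend exactly the same keys, values, and query.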
Without KV cache, generating token 4 reprocesses tokens 0,1,2,3:
step 1: [t0] -> t1 (1 token of work)
step 2: [t0 t1] -> t2 (2 tokens of work)
step 3: [t0 t1 t2] -> t3 (3 tokens of work)
step 4: [t0 t1 t2 t3] -> t4 (4 tokens of work) ... grows
With KV cache, each step processes one new token, reuses the rest:
prefill: [t0 t1 t2] -> t3 K,V of t0,t1,t2 saved
step 1: [t3] + cached K,V -> t4 append t3's K,V; 1 token of work
step 2: [t4] + cached K,V -> t5 append t4's K,V; 1 token of work ... flat

The plan
This chapter touches a lot of files because the cache cuts across layers: it is per-layer state, it needs to flow through the model, and the binary needs a flag to switch it on. The pieces:
- A cache module with a KvCache trait and one implementation, BasicKvCache.
- Two new Backend operations the decode path needs: appending a row to a tensor (concat_dim0) and applying RoPE position-encoding to a single token (apply_rope_single_row).
- A decode-specialized attention function, gqa_attention_decode_one_query.
- The two new model methods, forward_prefill_with_kv_cache and forward_decode_with_kv_cache.
- A --kv flag and the wiring through model-generate and greedy_generate.
We'll take them in roughly that order.
The KvCache trait
A trait, not just a struct, because Act 3 replaces BasicKvCache with smarter layouts (paged caches, shared-prefix caches), and the model code should not care which one it's handed. src/cache/kv_cache_trait.rs:
use crate::tensor::Tensor;

pub trait KvSnapshot: std::fmt::Debug + Send + Sync {
    fn seq_len(&self) -> usize;
    fn materialize(&self) -> Box<dyn KvCache>;
}

pub trait KvCache: std::fmt::Debug + Send + Sync {
    fn snapshot(&self) -> Box<dyn KvSnapshot>;
    fn kind(&self) -> &'static str;
    fn set_prefill(&mut self, layer: usize, k: &Tensor, v: &Tensor);
    fn push_row(&mut self, layer: usize, k_row: &Tensor, v_row: &Tensor);
    fn materialize(&self, layer: usize) -> (Tensor, Tensor);
    fn seq_len(&self, layer: usize) -> usize;
    fn num_layers(&self) -> usize;
}

KvCache is the working cache. The methods the model actually calls are four:
- set_prefill(layer, k, v): prefill processed the whole prompt for layer; install its K and V into the cache as the starting contents.
- push_row(layer, k_row, v_row): a decode step produced one new token's K and V for layer; append them.
- materialize(layer): hand back the full (K, V) for layer so attention can run.
- seq_len(layer): how many tokens are cached for layer. Used as a consistency check.
KvSnapshot is a frozen copy of a cache. We don't use it this chapter. It surfaces in Act 3, where the prefix cache needs to save a cache's state at a point in the conversation and restore it later. snapshot() produces one; materialize() turns it back into a live cache. It's in the trait now so the trait is stable; ignore it for this chapter.
kind() returns a label like "basic" for logging. The Send + Sync bound means a cache can be moved to, and shared between, threads, which is not needed yet but will be needed once there's a server.
BasicKvCache
The implementation. The strategy is the obvious one: for each layer, hold the K tensor and the V tensor; appending a token is appending a row. src/cache/basic.rs starts with one layer's worth of state:
use std::fmt::Debug;
use std::sync::Arc;

use super::{KvCache, KvSnapshot};
use crate::backend::Backend;
use crate::tensor::Tensor;

#[derive(Debug, Clone)]
pub(crate) struct LayerKvState {
    pub k: Tensor,
    pub v: Tensor,
}

impl LayerKvState {
    pub fn empty(kv_width: usize) -> Self {
        Self {
            k: Tensor::new(vec![], vec![0, kv_width]),
            v: Tensor::new(vec![], vec![0, kv_width]),
        }
    }

LayerKvState is the K and V for one transformer layer. Each is a 2-D tensor shaped [seq_len, kv_width], with one row per token and each row holding that token's key (or value) across all the key/value heads concatenated. empty creates a layer with zero rows: shape [0, kv_width], an empty data vector. A fresh cache is 28 of these, one per layer, all empty. As tokens are processed, rows get added and seq_len grows.
    pub fn set_prefill(&mut self, k: &Tensor, v: &Tensor) {
        self.k = k.clone();
        self.v = v.clone();
    }

    pub fn push_row(&mut self, ops: &dyn Backend, k_row: &Tensor, v_row: &Tensor) {
        self.k = ops.concat_dim0(&self.k, k_row);
        self.v = ops.concat_dim0(&self.v, v_row);
    }

    pub fn materialize(&self) -> (Tensor, Tensor) {
        (self.k.clone(), self.v.clone())
    }

    pub fn seq_len(&self) -> usize {
        self.k.shape()[0]
    }
}

set_prefill is the prefill case: prefill computed K and V for the entire prompt at once, so we just store those whole tensors as the layer's starting state. push_row is the decode case: it appends one new token's K row and V row to the bottom of the existing tensors with concat_dim0 (concatenation along dimension 0, the row axis). materialize clones out the tensors for attention to use; seq_len reads the row count off the K tensor's shape.
push_row doing a full concatenate-and-reallocate every step is admittedly not the most efficient layout. That's exactly why it's the basic cache. Act 3's paged cache replaces this with fixed-size blocks and no reallocation. For one request, the simple version is plenty, and its simplicity makes the idea legible.
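To see how much copying "concatenate-and-reallocate every step" implies, here is a back-of-the-envelope counter. The function name is made up for this sketch, and it counts only the floats written into the rebuilt buffer for one tensor of one layer; it does not count the clones that materialize performs.

```rust
/// Floats copied in total when a [rows, width] tensor grows one row at a
/// time and every append rebuilds the whole buffer (old rows plus new row).
fn floats_copied_by_repeated_concat(
    start_rows: usize,
    appended_rows: usize,
    width: usize,
) -> usize {
    (0..appended_rows)
        .map(|i| (start_rows + i + 1) * width)
        .sum()
}
```

Starting from a 4-row prefill and appending 31 rows of width 8 copies 4,960 floats to store what ends up as 280. The copy volume grows quadratically with the response length, which is tolerable for one short request and is the cost a block-based layout avoids.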
Now the cache itself: a Vec of LayerKvState, plus an Arc<dyn Backend> because push_row needs a backend to call concat_dim0:
#[derive(Clone)]
pub(crate) struct BasicKvCache {
    pub(crate) layers: Vec<LayerKvState>,
    ops: Arc<dyn Backend>,
}

impl Debug for BasicKvCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BasicKvCache")
            .field("num_layers", &self.layers.len())
            .field(
                "seq_len",
                &self.layers.first().map(|l| l.k.shape()[0]).unwrap_or(0),
            )
            .finish()
    }
}

impl BasicKvCache {
    pub fn new(num_layers: usize, kv_width: usize, ops: Arc<dyn Backend>) -> Self {
        Self {
            layers: (0..num_layers)
                .map(|_| LayerKvState::empty(kv_width))
                .collect(),
            ops,
        }
    }
}

new builds num_layers empty layer states. The hand-written Debug impl prints just the layer count and sequence length rather than dumping every cached number, since a cache holding thousands of floats per layer is not something you want spilling into a log line. kv_width and num_layers come from the model; we'll see where shortly.
Next the snapshot type. It is a frozen copy: the same fields, with no methods that mutate them.
#[derive(Clone)]
struct BasicKvSnapshot {
    layers: Vec<LayerKvState>,
    ops: Arc<dyn Backend>,
}

impl std::fmt::Debug for BasicKvSnapshot {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BasicKvSnapshot")
            .field("num_layers", &self.layers.len())
            .field(
                "seq_len",
                &self.layers.first().map(|l| l.seq_len()).unwrap_or(0),
            )
            .finish()
    }
}

impl KvSnapshot for BasicKvSnapshot {
    fn seq_len(&self) -> usize {
        self.layers.first().map(|l| l.seq_len()).unwrap_or(0)
    }

    fn materialize(&self) -> Box<dyn KvCache> {
        Box::new(BasicKvCache {
            layers: self.layers.clone(),
            ops: self.ops.clone(),
        })
    }
}

A BasicKvSnapshot is a deep copy of the layer states. materialize turns it back into a usable BasicKvCache. This is Act 3 machinery; we include it because it's part of the trait. For this chapter, only the KvCache impl runs.
Finally, the KvCache impl itself, mostly forwarding to LayerKvState:
impl KvCache for BasicKvCache {
    fn snapshot(&self) -> Box<dyn KvSnapshot> {
        Box::new(BasicKvSnapshot {
            layers: self.layers.clone(),
            ops: self.ops.clone(),
        })
    }

    fn kind(&self) -> &'static str {
        "basic"
    }

    fn set_prefill(&mut self, layer: usize, k: &Tensor, v: &Tensor) {
        self.layers[layer].set_prefill(k, v);
    }

    fn push_row(&mut self, layer: usize, k_row: &Tensor, v_row: &Tensor) {
        let ops = self.ops.clone();
        self.layers[layer].push_row(ops.as_ref(), k_row, v_row);
    }

    fn materialize(&self, layer: usize) -> (Tensor, Tensor) {
        self.layers[layer].materialize()
    }

    fn seq_len(&self, layer: usize) -> usize {
        self.layers[layer].seq_len()
    }

    fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

Each method indexes into self.layers[layer] and delegates. push_row clones the Arc<dyn Backend> into a local first, so the backend reference it passes down does not borrow from self while self.layers[layer] is borrowed mutably. The borrow checker can usually tell two distinct fields apart, so treat the clone as a defensive habit rather than a strict requirement. It is just a refcount bump, so it's free in practice.
The factory
src/cache/factory.rs maps a string mode to a cache, the same pattern as create_backend:
use std::sync::Arc;

use crate::backend::Backend;
use crate::model::Model;

use super::basic::BasicKvCache;
use super::kv_cache_trait::KvCache;

pub fn create_kv_cache(
    mode: &str,
    model: Arc<dyn Model>,
    backend: Arc<dyn Backend>,
) -> Result<Box<dyn KvCache>, String> {
    let (layers, width) = model.kv_cache_layers_and_width();
    match mode.trim() {
        "basic" => Ok(Box::new(BasicKvCache::new(layers, width, backend))),
        other => Err(format!("unknown kv cache mode {other:?} (supported: basic)")),
    }
}

The cache's dimensions (how many layers, how wide each K/V row is) are properties of the model, so the factory asks the model for them via kv_cache_layers_and_width() (a Model method we add below) rather than hard-coding Qwen3's numbers. "basic" is the only mode for now; Act 3 adds more arms here.
The module file ties it together:
mod basic;
mod factory;
mod kv_cache_trait;

pub use factory::create_kv_cache;
pub use kv_cache_trait::KvCache;
pub(crate) use kv_cache_trait::KvSnapshot;

Two new backend operations
The decode path needs two things the Backend trait can't yet do. We add them to src/backend/backend_trait.rs:
fn apply_rope_single_row(
    &self,
    x: &Tensor,
    position: usize,
    head_dim: usize,
    rope_theta: f32,
) -> Tensor;

fn concat_dim0(&self, a: &Tensor, b: &Tensor) -> Tensor;

concat_dim0 is the append we need for push_row. apply_rope_single_row needs a word of explanation. RoPE (rotary position embedding, built in I.5) encodes where a token sits in the sequence by rotating its query and key vectors by an angle that depends on the position. The Act 1 apply_rope rotates a whole batch of rows, each by its own row index. But a decode step has exactly one token, and its position isn't row 0: it's wherever the token actually sits in the sequence (after a 10-token prompt and 3 generated tokens, position 13). So we need a variant that rotates a single row by an explicit position. That's apply_rope_single_row.
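To make "rotate a single row by an explicit position" concrete, here is a standalone sketch on a plain slice. It is an illustration under assumptions, not the chapter's rope_rotate_row: the function name rotate_row_at is invented, and the pairing convention (element i rotates with element i + head_dim / 2 inside each head) is one common choice that the real helper may or may not share.

```rust
/// Rotate one row in place, head by head, as if it sat at `position`.
/// Assumed pairing: element i with element i + head_dim / 2 in each head.
fn rotate_row_at(row: &mut [f32], position: usize, head_dim: usize, rope_theta: f32) {
    assert!(head_dim % 2 == 0);
    assert!(row.len() % head_dim == 0);
    let half = head_dim / 2;
    for head in row.chunks_mut(head_dim) {
        for i in 0..half {
            // The rotation frequency falls off geometrically with the pair index.
            let freq = rope_theta.powf(-2.0 * i as f32 / head_dim as f32);
            let angle = position as f32 * freq;
            let (sin, cos) = angle.sin_cos();
            let (a, b) = (head[i], head[i + half]);
            head[i] = a * cos - b * sin;
            head[i + half] = a * sin + b * cos;
        }
    }
}
```

Two properties are worth noticing. At position 0 every angle is zero, so the row comes back unchanged. At any other position the values move but the row's length (its Euclidean norm) does not, because each pair is only rotated. The position is an argument, not a loop index, which is the whole point of the single-row variant.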
The CpuBackend implementations, in src/backend/cpu.rs:
fn apply_rope_single_row(
    &self,
    x: &Tensor,
    position: usize,
    head_dim: usize,
    rope_theta: f32,
) -> Tensor {
    assert_eq!(x.shape().len(), 2);
    assert_eq!(x.shape()[0], 1);
    let total_width = x.shape()[1];
    let n_heads = total_width / head_dim;
    assert_eq!(total_width, n_heads * head_dim);
    assert!(head_dim % 2 == 0);
    let mut out = x.as_f32_slice().to_vec();
    rope_rotate_row(&mut out, 0, position, n_heads, head_dim, rope_theta);
    Tensor::new(out, vec![1, total_width])
}

It asserts the input is a single row (x.shape()[0] == 1), copies the data, and calls the existing rope_rotate_row helper (the same per-row rotation Act 1's apply_rope uses in a loop) once, with the explicit position instead of a loop index.
fn concat_dim0(&self, a: &Tensor, b: &Tensor) -> Tensor {
    assert_eq!(a.shape().len(), 2);
    assert_eq!(b.shape().len(), 2);
    let c = a.shape()[1];
    assert_eq!(c, b.shape()[1]);
    let r0 = a.shape()[0];
    let r1 = b.shape()[0];
    let mut data = Vec::with_capacity(a.as_f32_slice().len() + b.as_f32_slice().len());
    data.extend_from_slice(a.as_f32_slice());
    data.extend_from_slice(b.as_f32_slice());
    Tensor::new(data, vec![r0 + r1, c])
}

concat_dim0 stacks b underneath a. Both must be 2-D with the same column count; the result has r0 + r1 rows. Since tensors are row-major flat Vec<f32>s, stacking along the row axis is just appending b's elements after a's, which is exactly what the two extend_from_slice calls do. (When a is the zero-row tensor a fresh cache starts with, this is just a copy of b.)
A decode-specialized attention
Act 1's attention, gqa_attention_forward_with_kv, processes a whole sequence and applies a causal mask: for seq tokens it builds a [seq, seq] score matrix and masks out the upper triangle before the softmax, so token i can't attend to tokens after it. Decode doesn't need that mask. There's one query (the new token), and every cached key is a valid, earlier token to attend to. The query is at the end of the sequence by construction. No masking required.
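A quick way to convince yourself is to write the causal pattern out as booleans. The helper below is a throwaway for illustration (causal_allowed is not a function from the codebase): it marks which keys each query row may see.

```rust
/// allowed[i][j] is true when query row i may attend to key column j.
/// Causal attention allows j <= i and masks everything above the diagonal.
fn causal_allowed(seq: usize) -> Vec<Vec<bool>> {
    (0..seq)
        .map(|i| (0..seq).map(|j| j <= i).collect())
        .collect()
}
```

For seq = 4, row 0 is [true, false, false, false], while the last row is all true. A decode step only ever computes that last row, so there is nothing for a mask to hide.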
So we add a mask-free attention helper for decode. src/model/common/attention.rs gets a single-head core:
fn gqa_attention_context_one_head_decode(
    ops: &dyn Backend,
    q_h: &Tensor,
    k_h: &Tensor,
    v_h: &Tensor,
    scale_attn: f32,
) -> Tensor {
    let scores = ops.matmul(q_h, &ops.transpose_2d(k_h));
    let scores = ops.scale(&scores, scale_attn);
    let attn = ops.softmax_rows(&scores);
    ops.matmul(&attn, v_h)
}

It is the Act 1 single-head attention with one line removed: the causal_mask_upper_tri call is gone. q_h is one row (the new token's query for this head); k_h and v_h are the full cached keys and values for the matching key/value head. The query dotted against every key gives one row of scores; scale, softmax into weights, and mix the values. The result is one row, this token's attention output for this head.
And the full decode attention that loops over heads:
pub(crate) fn gqa_attention_decode_one_query(
    ops: &dyn Backend,
    q: &Tensor,
    k_cache: &Tensor,
    v_cache: &Tensor,
    o_proj: &Tensor,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
) -> Tensor {
    let seq = k_cache.shape()[0];
    let nh = num_attention_heads;
    let nkv = num_key_value_heads;
    let hd = head_dim;
    let qw = nh * hd;
    let kv_group = nh / nkv;
    let scale_attn = 1.0 / (hd as f32).sqrt();
    let mut concat = vec![0.0f32; qw];
    for h_idx in 0..nh {
        let kv_h = h_idx / kv_group;
        let q_h = slice_head(ops, q, 1, nh, hd, h_idx);
        let k_h = slice_head(ops, k_cache, seq, nkv, hd, kv_h);
        let v_h = slice_head(ops, v_cache, seq, nkv, hd, kv_h);
        let ctx = gqa_attention_context_one_head_decode(ops, &q_h, &k_h, &v_h, scale_attn);
        ops.copy_2d_into_cols(&mut concat, qw, &ctx, h_idx * hd);
    }
    let merged = Tensor::new(concat, vec![1, qw]);
    ops.matmul(&merged, o_proj)
}

This is the Act 1 head loop adapted for decode. For each of the nh attention heads, it slices that head's query out of q (one row), slices the corresponding key/value head's full data out of the cache (seq rows, where seq is k_cache.shape()[0], the cached length), runs the mask-free attention, and writes the per-head result into a column slice of concat. kv_group = nh / nkv is the grouped-query-attention sharing factor: Qwen3 has more query heads than key/value heads, and several query heads share one cached K/V head, so kv_h = h_idx / kv_group maps a query head to its key/value head. After all heads, concat is one full row, and the output projection o_proj produces the layer's attention output.
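The head mapping is easy to check in isolation. This helper is hypothetical (it does not exist in the codebase) and the head counts used below are illustrative, not values read from Qwen3's config; it simply applies the same integer division as the loop.

```rust
/// For each query head, the index of the key/value head it shares.
fn kv_head_for_each_query_head(
    num_attention_heads: usize,
    num_key_value_heads: usize,
) -> Vec<usize> {
    assert!(num_key_value_heads > 0);
    assert_eq!(num_attention_heads % num_key_value_heads, 0);
    let kv_group = num_attention_heads / num_key_value_heads;
    (0..num_attention_heads)
        .map(|h_idx| h_idx / kv_group)
        .collect()
}
```

With 8 query heads and 2 key/value heads, kv_group is 4 and the mapping is [0, 0, 0, 0, 1, 1, 1, 1]: the first four query heads read cached K/V head 0 and the rest read head 1. When the two counts are equal the mapping is the identity, which is ordinary multi-head attention.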
The module re-export gains the new function:
pub(crate) use attention::{gqa_attention_decode_one_query, gqa_attention_forward_with_kv};

The model trait gains three methods
src/model/model_trait.rs grows to declare the two cached forward paths and the dimension query:
use crate::cache::KvCache;
use crate::tensor::Tensor;

pub trait Model: Send + Sync {
    fn forward(&self, token_ids: &[usize]) -> Tensor;
    fn forward_prefill_with_kv_cache(&self, token_ids: &[usize], cache: &mut dyn KvCache)
        -> Tensor;
    fn forward_decode_with_kv_cache(
        &self,
        token_id: usize,
        pos: usize,
        cache: &mut dyn KvCache,
    ) -> Tensor;
    fn kv_cache_layers_and_width(&self) -> (usize, usize);
}

forward stays, since the non-cached path is still useful and runs without --kv keep using it. forward_prefill_with_kv_cache takes the whole prompt and a cache to fill. forward_decode_with_kv_cache takes a single token id, its position pos, and the cache to read and extend. kv_cache_layers_and_width returns the (num_layers, kv_width) pair the factory needs.
The cached forward paths in Qwen3
Now src/model/qwen3/forward.rs. First, the dimension query: Qwen3 reads both numbers off its config:
pub fn kv_cache_layers_and_width(&self) -> (usize, usize) {
    (self.config.num_hidden_layers, self.config.kv_width())
}

num_hidden_layers is 28; kv_width() is the width of one K (or V) row, which is the number of key/value heads times the head dimension.
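Those two numbers also tell you how much memory the basic cache holds. The sketch below is a rough estimate under stated assumptions: the function name is invented, K and V are stored as f32, and the kv_width of 1024 used in the example is a placeholder (for instance 8 key/value heads of dimension 128), not a value taken from the chapter's config.

```rust
/// Bytes held by a cache that stores f32 K and V rows for `seq_len` tokens.
fn kv_cache_bytes(num_layers: usize, kv_width: usize, seq_len: usize) -> usize {
    let tensors_per_layer = 2; // one K tensor and one V tensor
    num_layers * tensors_per_layer * seq_len * kv_width * std::mem::size_of::<f32>()
}
```

Under those assumptions, 28 layers at width 1024 cost 229,376 bytes (224 KiB) per cached token, so a 36-token sequence holds roughly 8 MB. The cache trades that memory for the recomputation it saves.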
The Act 1 forward is refactored so prefill can share its body. The trick is a forward_common that takes an optional cache:
pub fn forward(&self, token_ids: &[usize]) -> Tensor {
    self.forward_common(token_ids, None)
}

pub fn forward_prefill_with_kv_cache(
    &self,
    token_ids: &[usize],
    cache: &mut dyn KvCache,
) -> Tensor {
    self.forward_common(token_ids, Some(cache))
}

fn forward_common(&self, token_ids: &[usize], mut cache: Option<&mut dyn KvCache>) -> Tensor {
    let ops = self.cpu_backend.as_ref();
    let cfg = &self.config;
    let mut x = ops.gather_rows(&self.embed, token_ids);
    for (li, layer) in self.layers.iter().enumerate() {
        let normed = rms_norm_weighted_last(ops, &x, &layer.input_layernorm, cfg.rms_norm_eps);
        let (attn_out, k_rope, v) = gqa_attention_forward_with_kv(
            ops,
            &normed,
            &layer.q_proj,
            &layer.k_proj,
            &layer.v_proj,
            &layer.o_proj,
            &layer.attn_q_norm,
            &layer.attn_k_norm,
            cfg.num_attention_heads,
            cfg.num_key_value_heads,
            cfg.head_dim,
            cfg.rms_norm_eps,
            cfg.rope_theta,
        );
        if let Some(c) = cache.as_mut() {
            c.set_prefill(li, &k_rope, &v);
        }
        x = ops.add(&x, &attn_out);
        let normed_mlp =
            rms_norm_weighted_last(ops, &x, &layer.post_attention_layernorm, cfg.rms_norm_eps);
        let mlp_out = mlp_forward(ops, &normed_mlp, layer);
        x = ops.add(&x, &mlp_out);
    }
    let x = rms_norm_weighted_last(ops, &x, &self.norm, cfg.rms_norm_eps);
    ops.matmul(&x, &self.lm_head)
}

forward_common is the Act 1 forward pass with two changes. The layer loop now uses enumerate() so each layer knows its index li. And gqa_attention_forward_with_kv already returned the rotated keys k_rope and values v (Act 1 just discarded them with let (attn_out, _k_rope, _v) = ...). Now, if a cache was passed, those tensors are installed into it via c.set_prefill(li, &k_rope, &v). When cache is None (a plain forward), the if let is skipped and the function behaves exactly as it did in Act 1. Same code, two callers, one of which fills a cache on its way through.
Now the decode path, the one that does the real work-saving:
pub fn forward_decode_with_kv_cache(
    &self,
    token_id: usize,
    pos: usize,
    cache: &mut dyn KvCache,
) -> Tensor {
    let ops = self.cpu_backend.as_ref();
    let cfg = &self.config;
    assert_eq!(
        cache.seq_len(0),
        pos,
        "KV row count must equal RoPE position before push"
    );
    let mut x = ops.gather_rows(&self.embed, &[token_id]);

It takes one token_id, not a slice. The assertion is a correctness guard worth pausing on: before we process this token, the cache must already hold exactly pos rows. A token at position pos has pos tokens before it, all of which should already be cached. If those two numbers disagree, the RoPE rotation would use the wrong angle and attention would be silently wrong. The embedding lookup gathers a single row: the one new token.
Then the per-layer loop:
    for (li, layer) in self.layers.iter().enumerate() {
        let normed = rms_norm_weighted_last(ops, &x, &layer.input_layernorm, cfg.rms_norm_eps);
        let q = ops.matmul(&normed, &layer.q_proj);
        let k = ops.matmul(&normed, &layer.k_proj);
        let v = ops.matmul(&normed, &layer.v_proj);
        let q = headwise_rms_norm_weighted(
            ops,
            &q,
            cfg.num_attention_heads,
            cfg.head_dim,
            &layer.attn_q_norm,
            cfg.rms_norm_eps,
        );
        let k = headwise_rms_norm_weighted(
            ops,
            &k,
            cfg.num_key_value_heads,
            cfg.head_dim,
            &layer.attn_k_norm,
            cfg.rms_norm_eps,
        );
        let q = ops.apply_rope_single_row(&q, pos, cfg.head_dim, cfg.rope_theta);
        let k_rope = ops.apply_rope_single_row(&k, pos, cfg.head_dim, cfg.rope_theta);
        cache.push_row(li, &k_rope, &v);
        let (k_full, v_full) = cache.materialize(li);
        let attn_out = gqa_attention_decode_one_query(
            ops,
            &q,
            &k_full,
            &v_full,
            &layer.o_proj,
            cfg.num_attention_heads,
            cfg.num_key_value_heads,
            cfg.head_dim,
        );
        x = ops.add(&x, &attn_out);
        let normed_mlp =
            rms_norm_weighted_last(ops, &x, &layer.post_attention_layernorm, cfg.rms_norm_eps);
        let mlp_out = mlp_forward(ops, &normed_mlp, layer);
        x = ops.add(&x, &mlp_out);
    }
    let x = rms_norm_weighted_last(ops, &x, &self.norm, cfg.rms_norm_eps);
    ops.matmul(&x, &self.lm_head)
}

This is the chapter's whole idea, in one loop. For each layer:
1. Project the one new token into its query, key, and value: matmul against q_proj, k_proj, v_proj. These are matrix-by-vector products now, because there's one token, not a sequence.
2. Apply head-wise RMS norm and then RoPE, using apply_rope_single_row with the explicit pos, since this single token's position is known but isn't row 0.
3. cache.push_row(li, &k_rope, &v): append this token's K and V to the layer's cache.
4. cache.materialize(li): read back the full cached K and V, every token so far.
5. Run gqa_attention_decode_one_query: the new token's query against the full cached K/V.
6. The rest of the layer (residual add, MLP, second residual) is identical to the prefill path.
Compare this to what the non-cached forward does to generate the same token: it runs steps 1–6 for every token in the sequence, every time. The decode path runs them for one token and reads the rest out of the cache. That is the N²-to-N collapse, made concrete.
The impl Model for Qwen3Model block gets the three new methods, each just forwarding to the inherent method of the same name:
impl Model for Qwen3Model {
    fn forward(&self, token_ids: &[usize]) -> Tensor {
        Qwen3Model::forward(self, token_ids)
    }

    fn forward_prefill_with_kv_cache(
        &self,
        token_ids: &[usize],
        cache: &mut dyn KvCache,
    ) -> Tensor {
        Qwen3Model::forward_prefill_with_kv_cache(self, token_ids, cache)
    }

    fn forward_decode_with_kv_cache(
        &self,
        token_id: usize,
        pos: usize,
        cache: &mut dyn KvCache,
    ) -> Tensor {
        Qwen3Model::forward_decode_with_kv_cache(self, token_id, pos, cache)
    }

    fn kv_cache_layers_and_width(&self) -> (usize, usize) {
        Qwen3Model::kv_cache_layers_and_width(self)
    }
}

The --kv flag
src/cli/args.rs learns to parse --kv [basic]. First the ArgCursor gets a peek (look at the next argument without consuming it, needed because --kv's mode argument is optional):
fn peek(&self) -> Option<&str> {
    self.args.get(self.i).map(|s| s.as_str())
}

Then a helper that reads the optional mode word after --kv:
fn parse_kv_mode(cursor: &mut ArgCursor<'_>) -> &'static str {
    match cursor.peek() {
        Some("basic") => {
            cursor.advance();
            "basic"
        }
        Some(s) if !s.starts_with('-') => {
            panic!("invalid --kv mode '{s}' (expected basic)");
        }
        Some(_) | None => "basic",
    }
}

--kv can be followed by an explicit mode (--kv basic) or nothing (--kv, which defaults to basic). The logic: if the next word is basic, consume it; if it's a non-flag word that isn't a valid mode, that's a user error so we panic; otherwise (the next word is a flag, or there is no next word) leave it alone and default to basic. One consequence: a bare --kv cannot sit directly in front of a positional argument, because the positional would be read as a bad mode. When the model path comes next, spell the mode out: --kv basic.
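The same decision can be exercised without an ArgCursor. The sketch below is a simplified stand-in, not the project's parser: parse_optional_mode is a made-up name, it works on a plain slice of the remaining arguments, and it returns a Result instead of panicking so the three outcomes are easy to test.

```rust
/// Decide the mode from the words that follow the flag.
/// Returns the mode and how many extra arguments were consumed.
fn parse_optional_mode(rest: &[&str]) -> Result<(&'static str, usize), String> {
    match rest.first().copied() {
        // An explicit, valid mode: consume it.
        Some("basic") => Ok(("basic", 1)),
        // A non-flag word that is not a mode: report it.
        Some(word) if !word.starts_with('-') => {
            Err(format!("invalid --kv mode '{word}' (expected basic)"))
        }
        // Another flag, or nothing at all: consume nothing, use the default.
        Some(_) | None => Ok(("basic", 0)),
    }
}
```

Called with ["basic", "model.gguf"] it consumes one word; called with an empty slice or with another flag it consumes none and defaults. Called with ["model.gguf"] it reports an error, which is the case to keep in mind whenever the flag is written immediately before a positional argument.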
CliArgs gains a kv_mode field and the --kv arm in the parse loop:
pub struct CliArgs {
    kv_mode: Option<&'static str>,
    positionals: Vec<String>,
}

impl CliArgs {
    pub fn from_env() -> Self {
        Self::parse(std::env::args().collect())
    }

    pub fn parse(args: Vec<String>) -> Self {
        let mut kv_mode = None;
        let mut positionals = Vec::new();
        let mut cur = ArgCursor::new(&args);
        while cur.has_more() {
            match cur.peek() {
                Some("--kv") => {
                    cur.advance();
                    kv_mode = Some(parse_kv_mode(&mut cur));
                }
                _ => positionals.push(cur.take()),
            }
        }
        Self {
            kv_mode,
            positionals,
        }
    }

    pub fn positionals(&self) -> &[String] {
        &self.positionals
    }

    pub fn kv_cache_mode(&self) -> Option<&'static str> {
        self.kv_mode
    }
}

kv_mode is None when --kv is absent (cache off) and Some("basic") when present. The parse loop now peeks: if the next argument is --kv, consume it and parse the mode; anything else is a positional, exactly as in I.1. This is the additive-flag pattern that chapter set up. kv_cache_mode() exposes the result.
The crate root re-exports the factory and trait:
pub use cache::{create_kv_cache, KvCache};

Wiring through the binary and the loop
greedy_generate takes the cache and picks a forward path per step:
use tracing::info;

use crate::backend::Backend;
use crate::cache::KvCache;
use crate::decode::{Metrics, next_token_id_from_logits};
use crate::model::Model;
use crate::tokenizer::Tokenizer;

pub fn greedy_generate(
    model: &dyn Model,
    ops: &dyn Backend,
    tokenizer: &dyn Tokenizer,
    prompt_ids: &[usize],
    max_new_tokens: usize,
    eos_token_id: usize,
    metrics: &mut Metrics,
    cache: &mut Option<Box<dyn KvCache>>,
) -> Vec<usize> {
    let mut ids = prompt_ids.to_vec();
    for step in 0..max_new_tokens {
        let logits = metrics.record_timed(|| match cache.as_mut() {
            Some(c) if step == 0 => model.forward_prefill_with_kv_cache(prompt_ids, c.as_mut()),
            Some(c) => {
                model.forward_decode_with_kv_cache(*ids.last().unwrap(), ids.len() - 1, c.as_mut())
            }
            None => model.forward(&ids),
        });
        let (next_id, prob) = next_token_id_from_logits(ops, &logits);
        ids.push(next_id);
        info!(
            step = step + 1,
            token_id = next_id,
            prob = prob * 100.0,
            token = %tokenizer.decode(&[next_id]),
            phase = if step == 0 { "prefill" } else { "decode" },
        );
        if next_id == eos_token_id {
            info!("[EOS reached]");
            break;
        }
    }
    ids
}

The new parameter is cache: &mut Option<Box<dyn KvCache>>, None if --kv was not passed. The match inside record_timed is the dispatch:
- Some(c) if step == 0: cache on, first step: run prefill, filling the cache from the whole prompt.
- Some(c) (step > 0): cache on, a decode step: run the decode path with the last token id and its position ids.len() - 1. After a 4-token prompt, the first decoded token is at position 4, which is ids.len() - 1 when ids holds the prompt plus the newly accepted previous token.
- None: cache off: the Act 1 forward over the whole growing sequence. This is the path II.1 measured.
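The position arithmetic is the easiest place to be off by one, so it helps to replay the loop with counters alone. This is a bookkeeping sketch, not project code: decode_positions is an invented helper, it tracks only lengths (no model, no tensors), and it ignores early exit on EOS.

```rust
/// Replays the loop's bookkeeping and returns, for each decode step,
/// (position passed to decode, rows already in the cache before the push).
fn decode_positions(prompt_len: usize, max_new_tokens: usize) -> Vec<(usize, usize)> {
    let mut ids_len = prompt_len; // length of `ids`
    let mut cached_rows = 0; // rows held for any one layer
    let mut seen = Vec::new();
    for step in 0..max_new_tokens {
        if step == 0 {
            // Prefill installs one row per prompt token.
            cached_rows = prompt_len;
        } else {
            // Decode processes the most recently accepted token.
            let pos = ids_len - 1;
            seen.push((pos, cached_rows));
            cached_rows += 1; // push_row appends that token's K and V
        }
        ids_len += 1; // the newly chosen token is appended to `ids`
    }
    seen
}
```

For a 4-token prompt and 4 steps it returns [(4, 4), (5, 5), (6, 6)]. The two numbers in each pair always agree, which is exactly the condition the assert_eq! at the top of forward_decode_with_kv_cache checks.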
The phase field on the log event records "prefill" or "decode" so the streamed events show which path produced each token.
Finally model-generate. It reads the --kv flag, builds the cache, and threads it in:
use std::path::Path;

use inferno::{
    CliArgs, Metrics, create_backend, create_kv_cache, greedy_generate, load_from_gguf_path,
};

fn usage() -> ! {
    eprintln!("usage: model-generate [--kv [basic]] <gguf_path> [prompt] [max_new_tokens]");
    std::process::exit(2);
}

    let args = CliArgs::from_env();
    let kv_mode = args.kv_cache_mode();

    println!("kv cache: {}", kv_mode.unwrap_or("off"));
    println!();
    let mut metrics = Metrics::default();
    let mut cache = kv_mode
        .map(|mode| create_kv_cache(mode, model.clone(), backend.clone()).unwrap_or_else(|e| {
            eprintln!("error: {e}");
            std::process::exit(2);
        }));
    let full_ids = greedy_generate(
        model.as_ref(),
        &*backend,
        tokenizer.as_ref(),
        &prompt_ids,
        max_new_tokens,
        eos_token_id,
        &mut metrics,
        &mut cache,
    );

kv_mode is None when --kv is absent. kv_mode.map(...) builds a cache only if a mode was given, so cache is Some(Box<dyn KvCache>) with --kv, and None without. create_kv_cache needs the model (for the dimensions) and the backend (for concat_dim0). The rest of main is unchanged from II.1.
Running it
Run the same prompt twice, once with the cache off and once with it on, and let the harness compare them:
cargo run --release --bin model-generate -- path/to/qwen3-0.6b.gguf "Once upon a time" 32
cargo run --release --bin model-generate -- --kv basic path/to/qwen3-0.6b.gguf "Once upon a time" 32

Without --kv, the metrics block looks like II.1 and the per-forward dump climbs:
kv cache: off
metrics:
time_to_first_token_ms: 421.880
decode_tokens_per_second: 1.984
per_forward_ms: min 421.880 max 1224.402 mean 504.114 (n=32)
forward 1: 421.880 ms
forward 2: 447.221 ms
...
forward 32: 1224.402 ms

With --kv, the per-forward times go flat and decode throughput jumps:
kv cache: basic
metrics:
time_to_first_token_ms: 423.104
decode_tokens_per_second: 47.602
per_forward_ms: min 19.882 max 423.104 mean 32.690 (n=32)
forward 1: 423.104 ms
forward 2: 21.114 ms
forward 3: 20.402 ms
...
forward 32: 21.985 ms

Two things to read here. Time to first token is unchanged (~420 ms either way). That's prefill, and the cache doesn't make prefill faster; it makes prefill save its work. The payoff is decode: throughput went from about 2 tokens/sec to roughly 48, better than 20×. And the per-forward dump tells the real story. Without the cache, forward 32 takes nearly 3× as long as forward 2; with it, every decode forward pass costs the same flat ~21 ms whether it's the 2nd token or the 32nd. The N² is gone. Longer responses no longer get slower per token.
Where this leaves us
The KV cache is the single largest speedup in Act 2, and it came from an algorithmic fix rather than a faster kernel: stop recomputing what doesn't change. Decode is now constant work per token instead of growing work, and the harness from II.1 proved it: flat per-forward times, a 20×-plus throughput jump.
But look again at that decode number: ~21 ms per forward pass, still. A pass is now a fixed amount of arithmetic, and that arithmetic is running on a single CPU core, one float at a time, in a plain scalar loop. The algorithm is right; the kernel is slow. The next chapter keeps the algorithm exactly as it is and attacks the kernel, rewriting matmul to use the CPU's SIMD vector instructions, doing 4 and then 16 multiply-adds per instruction instead of one.