II.2: KV cache

The harness from II.1 printed one number that should bother you: the forward passes get slower as the response gets longer. Generating token 16 takes more wall time than generating token 4. Each forward pass produces exactly one token, so why would later passes cost more?

Because the model is doing the same work over and over. To generate token 16, our Act 1 forward pass runs the entire network over all 15 prior tokens, then reads off the next one. Token 17 reruns it over 16 tokens. The sequence keeps growing, every forward pass reprocesses all of it, and the per-token cost climbs linearly. Summed over the whole response, that is 1 + 2 + ... + N ≈ N²/2 token passes, so generating an N-token response costs roughly O(N²) work.

It does not have to. Almost all of that recomputed work produces identical numbers each time, and we can save them. This chapter is the KV cache, the single biggest speedup in Act 2. It turns each decode step from "reprocess the whole sequence" into "process one token," and it does so by splitting the forward pass into two distinct paths.
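To put rough numbers on the difference, here is a small counting sketch. It counts token positions pushed through the network's projections and MLP for a prompt_len-token prompt and new_tokens generated tokens. Without a cache, step s reprocesses prompt_len + s tokens. With a cache, there is one prefill and then one position per step. The function names are made up for illustration. The count deliberately ignores attention, whose cost still grows with the history in both cases.

```rust
/// Token positions pushed through the network (projections + MLP) to
/// generate `new_tokens` tokens after a `prompt_len`-token prompt, with no
/// cache: step s reprocesses the whole sequence, prompt_len + s tokens.
fn positions_without_cache(prompt_len: usize, new_tokens: usize) -> usize {
    (0..new_tokens).map(|s| prompt_len + s).sum()
}

/// The same count with a KV cache: step 0 is a prefill over the prompt,
/// and every later step processes exactly one new token.
fn positions_with_cache(prompt_len: usize, new_tokens: usize) -> usize {
    if new_tokens == 0 {
        0
    } else {
        prompt_len + (new_tokens - 1)
    }
}
```

With a 1-token prompt and 1,000 new tokens, that is 500,500 positions without the cache versus 1,000 with it.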

Prefill and decode are different shapes of work

I.5 built one forward function. It takes a list of token ids, runs all of them through the 28 transformer layers, and returns logits. We've been calling it once per generated token, each time with the full sequence so far. That works, but it is wasteful, because it conflates two genuinely different jobs.

Prefill is processing the prompt. The first forward pass takes all P prompt tokens and runs them through the network together. Each prompt token attends to itself and to every earlier prompt token, so this pass has to see the whole prompt at once. There is no shortcut here, and we don't want one.

Decode is generating the response, one token at a time. Step i has a single new token. It needs to produce one row of logits. And here is the key fact: the new token attends to all earlier tokens, but those earlier tokens were already processed on previous steps. Their contribution to attention does not change. If we save it, we never recompute it.

So the plan is two forward functions instead of one:

  • forward_prefill_with_kv_cache: runs the whole prompt, exactly like the Act 1 forward, but on the way it saves the per-layer attention data into a cache.
  • forward_decode_with_kv_cache: runs just the one new token, reusing everything the cache already holds.

To see what to save, we need to look at attention.

What attention recomputes, and what's worth keeping

Inside each transformer layer, attention works like this (the full mechanism is in I.5; here's the part that matters for caching). Every token is projected into three vectors: a query (Q), a key (K), and a value (V). To compute the output for a given token, attention compares that token's query against the keys of every token up to and including it, turns those comparisons into weights, and uses the weights to mix together the values.

Now look at decode step i. The new token has a fresh Q, K, and V. It also needs the K and V of every earlier token, but those were computed from earlier tokens, which haven't changed. Token 5's key is the same vector whether you compute it on step 5 or step 500. Recomputing it is pure waste.

That is the cache. For each of the 28 layers, we store the K and V vectors of every token processed so far. Hence "KV cache." On a decode step:

  1. Compute Q, K, V for the one new token only.
  2. Append its K and V to the cache for this layer.
  3. Run attention using the new token's Q against the whole cached K and V.

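Step 3 still attends over the full history, so correctness is preserved exactly, but steps 1 and 2 only touch the new token. The quadratic blowup is gone: each decode step pushes one token through the projections and the MLP. Only the attention in step 3 still grows with the history, and that is a cheap pass over the cached rows rather than a full trip through the network.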

What we don't cache is the queries. A query is only ever used on the step that created it: token i's query attends to tokens 0..=i and is then never needed again. Only K and V get reused, so only K and V get stored.

Without KV cache, generating token 4 reprocesses tokens 0,1,2,3:
 
  step 1:  [t0]                 -> t1     (1 token  of work)
  step 2:  [t0 t1]              -> t2     (2 tokens of work)
  step 3:  [t0 t1 t2]           -> t3     (3 tokens of work)
  step 4:  [t0 t1 t2 t3]        -> t4     (4 tokens of work)   ... grows
 
With KV cache, each step processes one new token, reuses the rest:
 
  prefill: [t0 t1 t2]           -> t3     K,V of t0,t1,t2 saved
  step 1:  [t3]  + cached K,V   -> t4     append t3's K,V; 1 token of work
  step 2:  [t4]  + cached K,V   -> t5     append t4's K,V; 1 token of work   ... flat
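The claim that caching preserves correctness exactly can be checked on a toy example. This sketch uses plain Vec<f32> rows instead of the book's Tensor and Backend; attend, causal_full and causal_cached are hypothetical names.

- causal_full recomputes causal attention from scratch for every token.
- causal_cached appends each token's K and V to a cache and attends with only the new token's query.

In this toy, each step sees the same rows in the same order, so the two outputs are bit-for-bit identical.

```rust
/// Single-query attention: `q` against every row of `keys`/`values`.
/// No causal mask: every cached row is an earlier (or the same) token.
fn attend(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale)
        .collect();
    // Numerically stable softmax: subtract the max before exponentiating.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    // Mix the value rows with the normalized weights.
    let mut out = vec![0.0; values[0].len()];
    for (w, v) in exps.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += (w / total) * x;
        }
    }
    out
}

/// No cache: token i recomputes attention over rows 0..=i from scratch.
fn causal_full(qs: &[Vec<f32>], ks: &[Vec<f32>], vs: &[Vec<f32>]) -> Vec<Vec<f32>> {
    (0..qs.len())
        .map(|i| attend(&qs[i], &ks[..=i], &vs[..=i]))
        .collect()
}

/// With a cache: append the new token's K and V, then attend with its
/// query only. Queries are used once and never stored.
fn causal_cached(qs: &[Vec<f32>], ks: &[Vec<f32>], vs: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let (mut k_cache, mut v_cache) = (Vec::new(), Vec::new());
    let mut outputs = Vec::new();
    for i in 0..qs.len() {
        k_cache.push(ks[i].clone()); // append this token's K
        v_cache.push(vs[i].clone()); // append this token's V
        outputs.push(attend(&qs[i], &k_cache, &v_cache));
    }
    outputs
}
```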

The plan

This chapter touches a lot of files because the cache cuts across the codebase: it is per-layer state, it has to flow through the model, and the binary needs a flag to switch it on. The pieces:

  1. A cache module with a KvCache trait and one implementation, BasicKvCache.
  2. Two new Backend operations the decode path needs: appending a row to a tensor (concat_dim0) and applying RoPE position encoding to a single token (apply_rope_single_row).
  3. A decode-specialized attention function, gqa_attention_decode_one_query.
  4. The two new model methods, forward_prefill_with_kv_cache and forward_decode_with_kv_cache.
  5. A --kv flag and the wiring through model-generate and greedy_generate.

We'll take them in roughly that order.

The KvCache trait

A trait, not just a struct, because Act 3 replaces BasicKvCache with smarter layouts (paged caches, shared-prefix caches), and the model code should not care which one it's handed. src/cache/kv_cache_trait.rs:

src/cache/kv_cache_trait.rs
use crate::tensor::Tensor;
 
pub trait KvSnapshot: std::fmt::Debug + Send + Sync {
    fn seq_len(&self) -> usize;
    fn materialize(&self) -> Box<dyn KvCache>;
}
 
pub trait KvCache: std::fmt::Debug + Send + Sync {
    fn snapshot(&self) -> Box<dyn KvSnapshot>;
    fn kind(&self) -> &'static str;
    fn set_prefill(&mut self, layer: usize, k: &Tensor, v: &Tensor);
    fn push_row(&mut self, layer: usize, k_row: &Tensor, v_row: &Tensor);
    fn materialize(&self, layer: usize) -> (Tensor, Tensor);
    fn seq_len(&self, layer: usize) -> usize;
    fn num_layers(&self) -> usize;
}

KvCache is the working cache. The model actually calls four of its methods:

  • set_prefill(layer, k, v): prefill processed the whole prompt for layer; install its K and V into the cache as the starting contents.
  • push_row(layer, k_row, v_row): a decode step produced one new token's K and V for layer; append them.
  • materialize(layer): hand back the full (K, V) for layer so attention can run.
  • seq_len(layer): how many tokens are cached for layer. Used as a consistency check.
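Before the real implementation, a toy version of one layer's cache may help pin down what the four calls mean. This sketch makes three assumptions:

- K and V are flat row-major Vec<f32> buffers shaped [seq_len, kv_width], not Tensors.
- ToyLayerCache and its methods are hypothetical names.
- Appends happen in place with extend_from_slice. BasicKvCache, by contrast, reallocates on every push.

```rust
/// Toy cache for ONE layer (hypothetical, not the book's LayerKvState).
/// K and V are row-major [seq_len, kv_width] buffers: one row per token.
struct ToyLayerCache {
    kv_width: usize,
    k: Vec<f32>,
    v: Vec<f32>,
}

impl ToyLayerCache {
    fn new(kv_width: usize) -> Self {
        Self { kv_width, k: Vec::new(), v: Vec::new() }
    }

    /// Prefill: replace the contents with the whole prompt's K and V.
    fn set_prefill(&mut self, k: &[f32], v: &[f32]) {
        assert_eq!(k.len() % self.kv_width, 0, "K must hold whole rows");
        assert_eq!(k.len(), v.len(), "K and V must have the same shape");
        self.k = k.to_vec();
        self.v = v.to_vec();
    }

    /// Decode: append exactly one token's K row and V row.
    fn push_row(&mut self, k_row: &[f32], v_row: &[f32]) {
        assert_eq!(k_row.len(), self.kv_width);
        assert_eq!(v_row.len(), self.kv_width);
        self.k.extend_from_slice(k_row);
        self.v.extend_from_slice(v_row);
    }

    /// Number of cached tokens = number of rows in K.
    fn seq_len(&self) -> usize {
        self.k.len() / self.kv_width
    }

    /// Token t's key row starts at offset t * kv_width in the flat buffer.
    fn k_row(&self, t: usize) -> &[f32] {
        &self.k[t * self.kv_width..(t + 1) * self.kv_width]
    }
}
```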

KvSnapshot is a frozen copy of a cache. We don't use it in this chapter. It surfaces in Act 3, where the prefix cache needs to save a cache's state at a point in the conversation and restore it later. KvCache::snapshot() produces a snapshot, and KvSnapshot::materialize() turns it back into a live cache. It's in the trait now so the trait stays stable; ignore it for this chapter.

kind() returns a label like "basic" for logging. The Send + Sync bounds mean a cache can be moved to another thread and shared between threads. That isn't needed yet, but it will be once there's a server.

BasicKvCache

The implementation. The strategy is the obvious one: for each layer, hold the K tensor and the V tensor; appending a token is appending a row. src/cache/basic.rs starts with one layer's worth of state:

src/cache/basic.rs
use std::fmt::Debug;
use std::sync::Arc;
 
use super::{KvCache, KvSnapshot};
use crate::backend::Backend;
use crate::tensor::Tensor;
 
#[derive(Debug, Clone)]
pub(crate) struct LayerKvState {
    pub k: Tensor,
    pub v: Tensor,
}
 
impl LayerKvState {
    pub fn empty(kv_width: usize) -> Self {
        Self {
            k: Tensor::new(vec![], vec![0, kv_width]),
            v: Tensor::new(vec![], vec![0, kv_width]),
        }
    }

LayerKvState is the K and V for one transformer layer. Each is a 2-D tensor shaped [seq_len, kv_width], with one row per token and each row holding that token's key (or value) across all the key/value heads concatenated. empty creates a layer with zero rows: shape [0, kv_width], an empty data vector. A fresh cache is 28 of these, one per layer, all empty. As tokens are processed, rows get added and seq_len grows.

src/cache/basic.rs
    pub fn set_prefill(&mut self, k: &Tensor, v: &Tensor) {
        self.k = k.clone();
        self.v = v.clone();
    }
 
    pub fn push_row(&mut self, ops: &dyn Backend, k_row: &Tensor, v_row: &Tensor) {
        self.k = ops.concat_dim0(&self.k, k_row);
        self.v = ops.concat_dim0(&self.v, v_row);
    }
 
    pub fn materialize(&self) -> (Tensor, Tensor) {
        (self.k.clone(), self.v.clone())
    }
 
    pub fn seq_len(&self) -> usize {
        self.k.shape()[0]
    }
}

set_prefill is the prefill case: prefill computed K and V for the entire prompt at once, so we just store those whole tensors as the layer's starting state. push_row is the decode case: it appends one new token's K row and V row to the bottom of the existing tensors with concat_dim0 (concatenation along dimension 0, the row axis). materialize clones out the tensors for attention to use; seq_len reads the row count off the K tensor's shape.

push_row doing a full concatenate-and-reallocate every step is admittedly not the most efficient layout. That's exactly why it's the basic cache. Act 3's paged cache replaces this with fixed-size blocks and no reallocation. For one request, the simple version is plenty, and its simplicity makes the idea legible.
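The cost of that choice is easy to quantify with a back-of-the-envelope sketch; the function names are hypothetical. Each concat_dim0 copies every existing row plus the new one, so the total copying grows quadratically with the number of decode steps. A layout that writes each appended row exactly once would touch only steps * kv_width floats per tensor.

```rust
/// Floats copied by concat-and-reallocate appends (like BasicKvCache's
/// push_row) for ONE tensor (K or V) of one layer.
/// `start_rows`: rows already cached, e.g. the prompt length after prefill.
/// Step s copies the start_rows + s existing rows plus the 1 new row.
fn floats_copied_by_concat(start_rows: usize, steps: usize, kv_width: usize) -> usize {
    (0..steps).map(|s| (start_rows + s + 1) * kv_width).sum()
}

/// A layout that writes each appended row exactly once (ignoring any
/// occasional buffer growth) touches only this many floats.
fn floats_written_once(steps: usize, kv_width: usize) -> usize {
    steps * kv_width
}
```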

Now the cache itself: a Vec of LayerKvState, plus an Arc<dyn Backend> because push_row needs a backend to call concat_dim0:

src/cache/basic.rs
#[derive(Clone)]
pub(crate) struct BasicKvCache {
    pub(crate) layers: Vec<LayerKvState>,
    ops: Arc<dyn Backend>,
}
 
impl Debug for BasicKvCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BasicKvCache")
            .field("num_layers", &self.layers.len())
            .field(
                "seq_len",
                &self.layers.first().map(|l| l.k.shape()[0]).unwrap_or(0),
            )
            .finish()
    }
}
 
impl BasicKvCache {
    pub fn new(num_layers: usize, kv_width: usize, ops: Arc<dyn Backend>) -> Self {
        Self {
            layers: (0..num_layers)
                .map(|_| LayerKvState::empty(kv_width))
                .collect(),
            ops,
        }
    }
}

new builds num_layers empty layer states. The hand-written Debug impl prints just the layer count and sequence length rather than dumping every cached number, since a cache holding thousands of floats per layer is not something you want spilling into a log line. kv_width and num_layers come from the model; we'll see where shortly.

Next the snapshot type. It is a frozen copy: same fields, but immutable:

src/cache/basic.rs
#[derive(Clone)]
struct BasicKvSnapshot {
    layers: Vec<LayerKvState>,
    ops: Arc<dyn Backend>,
}
 
impl std::fmt::Debug for BasicKvSnapshot {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BasicKvSnapshot")
            .field("num_layers", &self.layers.len())
            .field(
                "seq_len",
                &self.layers.first().map(|l| l.seq_len()).unwrap_or(0),
            )
            .finish()
    }
}
 
impl KvSnapshot for BasicKvSnapshot {
    fn seq_len(&self) -> usize {
        self.layers.first().map(|l| l.seq_len()).unwrap_or(0)
    }
 
    fn materialize(&self) -> Box<dyn KvCache> {
        Box::new(BasicKvCache {
            layers: self.layers.clone(),
            ops: self.ops.clone(),
        })
    }
}

A BasicKvSnapshot is a deep copy of the layer states. materialize turns it back into a usable BasicKvCache. This is Act 3 machinery; we include it because it's part of the trait. For this chapter, only the KvCache impl runs.

Finally, the KvCache impl itself, mostly forwarding to LayerKvState:

src/cache/basic.rs
impl KvCache for BasicKvCache {
    fn snapshot(&self) -> Box<dyn KvSnapshot> {
        Box::new(BasicKvSnapshot {
            layers: self.layers.clone(),
            ops: self.ops.clone(),
        })
    }
 
    fn kind(&self) -> &'static str {
        "basic"
    }
 
    fn set_prefill(&mut self, layer: usize, k: &Tensor, v: &Tensor) {
        self.layers[layer].set_prefill(k, v);
    }
 
    fn push_row(&mut self, layer: usize, k_row: &Tensor, v_row: &Tensor) {
        let ops = self.ops.clone();
        self.layers[layer].push_row(ops.as_ref(), k_row, v_row);
    }
 
    fn materialize(&self, layer: usize) -> (Tensor, Tensor) {
        self.layers[layer].materialize()
    }
 
    fn seq_len(&self, layer: usize) -> usize {
        self.layers[layer].seq_len()
    }
 
    fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

Each method indexes into self.layers[layer] and delegates. push_row first clones the Arc<dyn Backend> into a local. This keeps the backend reference visibly separate from the mutable borrow of self.layers[layer]. Rust would also accept self.ops.as_ref() inline, because self.ops and self.layers are disjoint fields; the local just makes the split explicit. The Arc clone is only a refcount bump, so it costs next to nothing.

The factory

src/cache/factory.rs maps a string mode to a cache, the same pattern as create_backend:

src/cache/factory.rs
use std::sync::Arc;
 
use crate::backend::Backend;
use crate::model::Model;
 
use super::basic::BasicKvCache;
use super::kv_cache_trait::KvCache;
 
pub fn create_kv_cache(
    mode: &str,
    model: Arc<dyn Model>,
    backend: Arc<dyn Backend>,
) -> Result<Box<dyn KvCache>, String> {
    let (layers, width) = model.kv_cache_layers_and_width();
    match mode.trim() {
        "basic" => Ok(Box::new(BasicKvCache::new(layers, width, backend))),
        other => Err(format!("unknown kv cache mode {other:?} (supported: basic)")),
    }
}

The cache's dimensions (how many layers, how wide each K/V row is) are properties of the model, so the factory asks the model for them via kv_cache_layers_and_width() (a Model method we add below) rather than hard-coding Qwen3's numbers. "basic" is the only mode for now; Act 3 adds more arms here.

The module file ties it together:

src/cache/mod.rs
mod basic;
mod factory;
mod kv_cache_trait;
 
pub use factory::create_kv_cache;
pub use kv_cache_trait::KvCache;
 
pub(crate) use kv_cache_trait::KvSnapshot;

Two new backend operations

The decode path needs two things the Backend trait can't yet do. We add them to src/backend/backend_trait.rs:

src/backend/backend_trait.rs
    fn apply_rope_single_row(
        &self,
        x: &Tensor,
        position: usize,
        head_dim: usize,
        rope_theta: f32,
    ) -> Tensor;
 
    fn concat_dim0(&self, a: &Tensor, b: &Tensor) -> Tensor;

concat_dim0 is the append we need for push_row. apply_rope_single_row needs a word of explanation. RoPE (rotary position embedding, built in I.5) encodes where a token sits in the sequence by rotating its query and key vectors by an angle that depends on the position. The Act 1 apply_rope rotates a whole batch of rows, each by its own row index. But a decode step has exactly one token, and its position isn't row 0: it's wherever the token actually sits in the sequence (after a 10-token prompt and 3 generated tokens, position 13). So we need a variant that rotates a single row by an explicit position. That's apply_rope_single_row.
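Here is a minimal sketch of rotating a single row by an explicit position, on a plain &mut [f32]. The pairing convention is an assumption: inside each head, element i is paired with element i + head_dim/2, as in Hugging Face's Llama-family RoPE. The book's rope_rotate_row may lay this out differently. rope_single_row and rotated are hypothetical names.

```rust
/// Rotate one row (n_heads * head_dim values, heads concatenated) in place by
/// an explicit `position`. ASSUMPTION: half-split pairing, element i with
/// i + head_dim/2 inside each head; the book's rope_rotate_row may differ.
fn rope_single_row(row: &mut [f32], position: usize, head_dim: usize, theta: f32) {
    assert!(head_dim % 2 == 0, "head_dim must be even");
    assert_eq!(row.len() % head_dim, 0, "row must hold whole heads");
    let half = head_dim / 2;
    for head in row.chunks_mut(head_dim) {
        for i in 0..half {
            // Pair i rotates at frequency theta^(-2i / head_dim).
            let exponent = (2 * i) as f32 / head_dim as f32;
            let angle = position as f32 * theta.powf(-exponent);
            let (sin, cos) = angle.sin_cos();
            let (a, b) = (head[i], head[i + half]);
            head[i] = a * cos - b * sin;
            head[i + half] = b * cos + a * sin;
        }
    }
}

/// Convenience wrapper: return a rotated copy instead of rotating in place.
fn rotated(row: &[f32], position: usize, head_dim: usize, theta: f32) -> Vec<f32> {
    let mut out = row.to_vec();
    rope_single_row(&mut out, position, head_dim, theta);
    out
}
```

Position 0 leaves a row unchanged. The dot product of a rotated query and a rotated key depends only on the distance between their positions. That is why the decode step must pass the token's true position rather than 0: a wrong position changes every attention score.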

The CpuBackend implementations, in src/backend/cpu.rs:

src/backend/cpu.rs
    fn apply_rope_single_row(
        &self,
        x: &Tensor,
        position: usize,
        head_dim: usize,
        rope_theta: f32,
    ) -> Tensor {
        assert_eq!(x.shape().len(), 2);
        assert_eq!(x.shape()[0], 1);
        let total_width = x.shape()[1];
        let n_heads = total_width / head_dim;
        assert_eq!(total_width, n_heads * head_dim);
        assert!(head_dim % 2 == 0);
        let mut out = x.as_f32_slice().to_vec();
        rope_rotate_row(&mut out, 0, position, n_heads, head_dim, rope_theta);
        Tensor::new(out, vec![1, total_width])
    }

It asserts the input is a single row (x.shape()[0] == 1), copies the data, and calls the existing rope_rotate_row helper (the same per-row rotation Act 1's apply_rope uses in a loop) once, with the explicit position instead of a loop index.

src/backend/cpu.rs
    fn concat_dim0(&self, a: &Tensor, b: &Tensor) -> Tensor {
        assert_eq!(a.shape().len(), 2);
        assert_eq!(b.shape().len(), 2);
        let c = a.shape()[1];
        assert_eq!(c, b.shape()[1]);
        let r0 = a.shape()[0];
        let r1 = b.shape()[0];
        let mut data = Vec::with_capacity(a.as_f32_slice().len() + b.as_f32_slice().len());
        data.extend_from_slice(a.as_f32_slice());
        data.extend_from_slice(b.as_f32_slice());
        Tensor::new(data, vec![r0 + r1, c])
    }

concat_dim0 stacks b underneath a. Both must be 2-D with the same column count; the result has r0 + r1 rows. Since tensors are row-major flat Vec<f32>s, stacking along the row axis is just appending b's values after a's, which is exactly what the two extend_from_slice calls do. (When a is the zero-row tensor a fresh cache starts with, this is just a copy of b.)

A decode-specialized attention

Act 1's attention, gqa_attention_forward_with_kv, processes a whole sequence and applies a causal mask. For seq tokens it builds a [seq, seq] score matrix and masks out the upper triangle, so those positions get zero weight after the softmax and token i can't attend to tokens after it. Decode doesn't need that mask. There's one query (the new token), and every cached key belongs to a valid, earlier token (or the new token itself). The query is at the end of the sequence by construction, so no masking is required.

So we add a mask-free attention helper for decode. src/model/common/attention.rs gets a single-head core:

src/model/common/attention.rs
fn gqa_attention_context_one_head_decode(
    ops: &dyn Backend,
    q_h: &Tensor,
    k_h: &Tensor,
    v_h: &Tensor,
    scale_attn: f32,
) -> Tensor {
    let scores = ops.matmul(q_h, &ops.transpose_2d(k_h));
    let scores = ops.scale(&scores, scale_attn);
    let attn = ops.softmax_rows(&scores);
    ops.matmul(&attn, v_h)
}

It is the Act 1 single-head attention with one line removed: the causal_mask_upper_tri call is gone. q_h is one row (the new token's query for this head); k_h and v_h are the full cached keys and values for the matching key/value head. The query dotted against every key gives one row of scores; scale, softmax into weights, and mix the values. The result is one row, this token's attention output for this head.

And the full decode attention that loops over heads:

src/model/common/attention.rs
pub(crate) fn gqa_attention_decode_one_query(
    ops: &dyn Backend,
    q: &Tensor,
    k_cache: &Tensor,
    v_cache: &Tensor,
    o_proj: &Tensor,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
) -> Tensor {
    let seq = k_cache.shape()[0];
    let nh = num_attention_heads;
    let nkv = num_key_value_heads;
    let hd = head_dim;
    let qw = nh * hd;
    let kv_group = nh / nkv;
 
    let scale_attn = 1.0 / (hd as f32).sqrt();
    let mut concat = vec![0.0f32; qw];
 
    for h_idx in 0..nh {
        let kv_h = h_idx / kv_group;
        let q_h = slice_head(ops, q, 1, nh, hd, h_idx);
        let k_h = slice_head(ops, k_cache, seq, nkv, hd, kv_h);
        let v_h = slice_head(ops, v_cache, seq, nkv, hd, kv_h);
 
        let ctx = gqa_attention_context_one_head_decode(ops, &q_h, &k_h, &v_h, scale_attn);
        ops.copy_2d_into_cols(&mut concat, qw, &ctx, h_idx * hd);
    }
 
    let merged = Tensor::new(concat, vec![1, qw]);
    ops.matmul(&merged, o_proj)
}

This is the Act 1 head loop adapted for decode. For each of the nh attention heads, it slices that head's query out of q (one row), slices the corresponding key/value head's full data out of the cache (seq rows, where seq is k_cache.shape()[0], the cached length), runs the mask-free attention, and writes the per-head result into a column slice of concat. kv_group = nh / nkv is the grouped-query-attention sharing factor: Qwen3 has more query heads than key/value heads, and several query heads share one cached K/V head, so kv_h = h_idx / kv_group maps a query head to its key/value head. After all heads, concat is one full row, and the output projection o_proj produces the layer's attention output.
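The head bookkeeping is plain integer arithmetic. This sketch uses hypothetical helper names to compute two things:

- which key/value head each query head reads;
- which columns a head occupies in a concatenated row, i.e. the h_idx * hd offset passed to copy_2d_into_cols.

The divisibility assertion is an addition; the book's function simply computes nh / nkv.

```rust
use std::ops::Range;

/// For each query head h_idx, the key/value head it reads:
/// kv_h = h_idx / (nh / nkv), as in gqa_attention_decode_one_query.
fn kv_head_for_query_heads(nh: usize, nkv: usize) -> Vec<usize> {
    // Added check: the mapping assumes nh is a multiple of nkv.
    assert!(nkv > 0 && nh % nkv == 0, "query heads must split evenly into groups");
    let kv_group = nh / nkv;
    (0..nh).map(|h_idx| h_idx / kv_group).collect()
}

/// Columns occupied by head `h` in a row of concatenated heads; the start
/// matches the h_idx * hd offset passed to copy_2d_into_cols.
fn head_columns(h: usize, head_dim: usize) -> Range<usize> {
    h * head_dim..(h + 1) * head_dim
}
```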

The module re-export gains the new function:

src/model/common/mod.rs
pub(crate) use attention::{gqa_attention_decode_one_query, gqa_attention_forward_with_kv};

The model trait gains three methods

src/model/model_trait.rs grows to declare the two cached forward paths and the dimension query:

src/model/model_trait.rs
use crate::cache::KvCache;
use crate::tensor::Tensor;
 
pub trait Model: Send + Sync {
    fn forward(&self, token_ids: &[usize]) -> Tensor;
 
    fn forward_prefill_with_kv_cache(&self, token_ids: &[usize], cache: &mut dyn KvCache)
    -> Tensor;
 
    fn forward_decode_with_kv_cache(
        &self,
        token_id: usize,
        pos: usize,
        cache: &mut dyn KvCache,
    ) -> Tensor;
 
    fn kv_cache_layers_and_width(&self) -> (usize, usize);
}

forward stays, since the non-cached path is still useful and --kv off keeps using it. forward_prefill_with_kv_cache takes the whole prompt and a cache to fill. forward_decode_with_kv_cache takes a single token id, its position pos, and the cache to read and extend. kv_cache_layers_and_width returns the (num_layers, kv_width) pair the factory needs.

The cached forward paths in Qwen3

Now src/model/qwen3/forward.rs. First, the dimension query: Qwen3 reads both numbers off its config:

src/model/qwen3/forward.rs
    pub fn kv_cache_layers_and_width(&self) -> (usize, usize) {
        (self.config.num_hidden_layers, self.config.kv_width())
    }

num_hidden_layers is 28; kv_width() is the width of one K (or V) row, which is the number of key/value heads times the head dimension.
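The cache stores one K row and one V row per token in every layer, so its size follows directly from these dimensions. This sketch uses placeholder numbers. The 28 layers come from above; the 8 key/value heads and head dimension of 128 are illustrative assumptions, not values read from the config. KvDims is a hypothetical struct, and the sizes assume f32 storage, as in this chapter's cache.

```rust
/// Hypothetical container for the dimensions that size the cache.
struct KvDims {
    num_layers: usize,
    num_key_value_heads: usize,
    head_dim: usize,
}

impl KvDims {
    /// Width of one cached K (or V) row: all key/value heads concatenated.
    fn kv_width(&self) -> usize {
        self.num_key_value_heads * self.head_dim
    }

    /// Bytes the basic cache holds per token: K and V, every layer, as f32.
    fn bytes_per_token(&self) -> usize {
        2 * self.num_layers * self.kv_width() * std::mem::size_of::<f32>()
    }
}
```

With these placeholder numbers, each token costs 229,376 bytes (224 KiB), so a 1,000-token conversation holds roughly 229 MB of cache.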

The Act 1 forward is refactored so prefill can share its body. The trick is a forward_common that takes an optional cache:

src/model/qwen3/forward.rs
    pub fn forward(&self, token_ids: &[usize]) -> Tensor {
        self.forward_common(token_ids, None)
    }
 
    pub fn forward_prefill_with_kv_cache(
        &self,
        token_ids: &[usize],
        cache: &mut dyn KvCache,
    ) -> Tensor {
        self.forward_common(token_ids, Some(cache))
    }
 
    fn forward_common(&self, token_ids: &[usize], mut cache: Option<&mut dyn KvCache>) -> Tensor {
        let ops = self.cpu_backend.as_ref();
        let cfg = &self.config;
        let mut x = ops.gather_rows(&self.embed, token_ids);
        for (li, layer) in self.layers.iter().enumerate() {
            let normed = rms_norm_weighted_last(ops, &x, &layer.input_layernorm, cfg.rms_norm_eps);
            let (attn_out, k_rope, v) = gqa_attention_forward_with_kv(
                ops,
                &normed,
                &layer.q_proj,
                &layer.k_proj,
                &layer.v_proj,
                &layer.o_proj,
                &layer.attn_q_norm,
                &layer.attn_k_norm,
                cfg.num_attention_heads,
                cfg.num_key_value_heads,
                cfg.head_dim,
                cfg.rms_norm_eps,
                cfg.rope_theta,
            );
            if let Some(c) = cache.as_mut() {
                c.set_prefill(li, &k_rope, &v);
            }
            x = ops.add(&x, &attn_out);
            let normed_mlp =
                rms_norm_weighted_last(ops, &x, &layer.post_attention_layernorm, cfg.rms_norm_eps);
            let mlp_out = mlp_forward(ops, &normed_mlp, layer);
            x = ops.add(&x, &mlp_out);
        }
        let x = rms_norm_weighted_last(ops, &x, &self.norm, cfg.rms_norm_eps);
        ops.matmul(&x, &self.lm_head)
    }

forward_common is the Act 1 forward pass with two changes. The layer loop now uses enumerate() so each layer knows its index li. And gqa_attention_forward_with_kv already returned the rotated keys k_rope and values v (Act 1 just discarded them with let (attn_out, _k_rope, _v) = ...). Now, if a cache was passed, those tensors are installed into it via c.set_prefill(li, &k_rope, &v). When cache is None (a plain forward), the if let is skipped and the function behaves exactly as it did in Act 1. Same code, two callers, one of which fills a cache on its way through.
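The Option<&mut dyn KvCache> pattern is easier to see in isolation. In this sketch, a hypothetical Recorder trait stands in for KvCache and a single f32 stands in for the hidden state. run_layers records per-layer values only when a sink is passed, and returns the same result either way.

```rust
/// Hypothetical stand-in for KvCache: receives one value per layer.
trait Recorder {
    fn record(&mut self, layer: usize, value: f32);
}

/// Simple sink that keeps everything it is given.
struct VecRecorder(Vec<(usize, f32)>);

impl Recorder for VecRecorder {
    fn record(&mut self, layer: usize, value: f32) {
        self.0.push((layer, value));
    }
}

/// One body, two kinds of caller: the side effect runs only when a sink is
/// passed, and the returned value is identical either way.
fn run_layers(input: f32, layer_scales: &[f32], mut sink: Option<&mut dyn Recorder>) -> f32 {
    let mut x = input;
    for (li, scale) in layer_scales.iter().enumerate() {
        let per_layer = x * scale; // stands in for k_rope and v
        if let Some(s) = sink.as_mut() {
            s.record(li, per_layer); // stands in for c.set_prefill(li, ...)
        }
        x += per_layer; // stands in for the residual adds
    }
    x
}
```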

Now the decode path, the one that does the real work-saving:

src/model/qwen3/forward.rs
    pub fn forward_decode_with_kv_cache(
        &self,
        token_id: usize,
        pos: usize,
        cache: &mut dyn KvCache,
    ) -> Tensor {
        let ops = self.cpu_backend.as_ref();
        let cfg = &self.config;
        assert_eq!(
            cache.seq_len(0),
            pos,
            "KV row count must equal RoPE position before push"
        );
        let mut x = ops.gather_rows(&self.embed, &[token_id]);

It takes one token_id, not a slice. The assertion is a correctness guard worth pausing on: before we process this token, the cache must already hold exactly pos rows. A token at position pos has pos tokens before it, all of which should already be cached. If those two numbers disagree, the RoPE rotation would use the wrong angle and attention would be silently wrong. The embedding lookup gathers a single row: the one new token.

Then the per-layer loop:

src/model/qwen3/forward.rs
        for (li, layer) in self.layers.iter().enumerate() {
            let normed = rms_norm_weighted_last(ops, &x, &layer.input_layernorm, cfg.rms_norm_eps);
            let q = ops.matmul(&normed, &layer.q_proj);
            let k = ops.matmul(&normed, &layer.k_proj);
            let v = ops.matmul(&normed, &layer.v_proj);
 
            let q = headwise_rms_norm_weighted(
                ops,
                &q,
                cfg.num_attention_heads,
                cfg.head_dim,
                &layer.attn_q_norm,
                cfg.rms_norm_eps,
            );
            let k = headwise_rms_norm_weighted(
                ops,
                &k,
                cfg.num_key_value_heads,
                cfg.head_dim,
                &layer.attn_k_norm,
                cfg.rms_norm_eps,
            );
            let q = ops.apply_rope_single_row(&q, pos, cfg.head_dim, cfg.rope_theta);
            let k_rope = ops.apply_rope_single_row(&k, pos, cfg.head_dim, cfg.rope_theta);
 
            cache.push_row(li, &k_rope, &v);
            let (k_full, v_full) = cache.materialize(li);
 
            let attn_out = gqa_attention_decode_one_query(
                ops,
                &q,
                &k_full,
                &v_full,
                &layer.o_proj,
                cfg.num_attention_heads,
                cfg.num_key_value_heads,
                cfg.head_dim,
            );
            x = ops.add(&x, &attn_out);
            let normed_mlp =
                rms_norm_weighted_last(ops, &x, &layer.post_attention_layernorm, cfg.rms_norm_eps);
            let mlp_out = mlp_forward(ops, &normed_mlp, layer);
            x = ops.add(&x, &mlp_out);
        }
 
        let x = rms_norm_weighted_last(ops, &x, &self.norm, cfg.rms_norm_eps);
        ops.matmul(&x, &self.lm_head)
    }

This is the chapter's whole idea, in one loop. For each layer:

  1. Project the one new token into its query, key, and value: matmul against q_proj, k_proj, v_proj. Each is now a single row times a weight matrix, because there's one token, not a sequence.
  2. Apply head-wise RMS norm and then RoPE, using apply_rope_single_row with the explicit pos, since this single token's position is known but isn't row 0.
  3. cache.push_row(li, &k_rope, &v): append this token's K and V to the layer's cache.
  4. cache.materialize(li): read back the full cached K and V, every token so far.
  5. Run gqa_attention_decode_one_query: the new token's query against the full cached K/V.
  6. The rest of the layer (residual add, MLP, second residual) is identical to the prefill path.

Compare this to what the non-cached forward does to generate the same token: it runs steps 1–6 for every token in the sequence, every time. The decode path runs them for one token and reads the rest out of the cache. That is the N²-to-N collapse, made concrete.

The impl Model for Qwen3Model block gets the three new methods, each just forwarding to the inherent method of the same name:

src/model/qwen3/forward.rs
impl Model for Qwen3Model {
    fn forward(&self, token_ids: &[usize]) -> Tensor {
        Qwen3Model::forward(self, token_ids)
    }
 
    fn forward_prefill_with_kv_cache(
        &self,
        token_ids: &[usize],
        cache: &mut dyn KvCache,
    ) -> Tensor {
        Qwen3Model::forward_prefill_with_kv_cache(self, token_ids, cache)
    }
 
    fn forward_decode_with_kv_cache(
        &self,
        token_id: usize,
        pos: usize,
        cache: &mut dyn KvCache,
    ) -> Tensor {
        Qwen3Model::forward_decode_with_kv_cache(self, token_id, pos, cache)
    }
 
    fn kv_cache_layers_and_width(&self) -> (usize, usize) {
        Qwen3Model::kv_cache_layers_and_width(self)
    }
}

The --kv flag

src/cli/args.rs learns to parse --kv [basic]. First, ArgCursor gets a peek method, which looks at the next argument without consuming it. We need it because --kv's mode argument is optional:

src/cli/args.rs
    fn peek(&self) -> Option<&str> {
        self.args.get(self.i).map(|s| s.as_str())
    }

Then a helper that reads the optional mode word after --kv:

src/cli/args.rs
fn parse_kv_mode(cursor: &mut ArgCursor<'_>) -> &'static str {
    match cursor.peek() {
        Some("basic") => {
            cursor.advance();
            "basic"
        }
        Some(s) if !s.starts_with('-') => {
            panic!("invalid --kv mode '{s}' (expected basic)");
        }
        Some(_) | None => "basic",
    }
}

--kv can be followed by an explicit mode (--kv basic) or nothing (--kv, which defaults to basic). The logic: if the next word is basic, consume it; if it's a non-flag word that isn't a valid mode, that's a user error so we panic; otherwise (the next word is a flag, or there is no next word) leave it alone and default to basic.
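Here is the same decision logic as a standalone sketch over a plain slice. parse_kv_mode_from is a hypothetical variant with two differences from the original:

- It returns a Result instead of panicking, so every case is easy to test.
- It reports how many arguments it consumed instead of advancing a cursor.

```rust
/// Hypothetical variant of parse_kv_mode. `rest` is everything after "--kv".
/// Returns the mode and how many arguments it consumed (0 or 1).
fn parse_kv_mode_from(rest: &[&str]) -> Result<(&'static str, usize), String> {
    match rest.first() {
        // Explicit mode: consume it.
        Some(&"basic") => Ok(("basic", 1)),
        // A non-flag word that isn't a known mode is a user error.
        Some(s) if !s.starts_with('-') => {
            Err(format!("invalid --kv mode '{s}' (expected basic)"))
        }
        // Next word is another flag, or nothing follows: default, consume nothing.
        Some(_) | None => Ok(("basic", 0)),
    }
}
```

One consequence of this rule: any non-flag word right after --kv is taken as a mode. So --kv placed directly before a positional argument would reject that positional as an invalid mode. Writing --kv basic, or putting --kv after the positionals, avoids this.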

CliArgs gains a kv_mode field and the --kv arm in the parse loop:

src/cli/args.rs
pub struct CliArgs {
    kv_mode: Option<&'static str>,
    positionals: Vec<String>,
}
 
impl CliArgs {
    pub fn from_env() -> Self {
        Self::parse(std::env::args().collect())
    }
 
    pub fn parse(args: Vec<String>) -> Self {
        let mut kv_mode = None;
        let mut positionals = Vec::new();
 
        let mut cur = ArgCursor::new(&args);
        while cur.has_more() {
            match cur.peek() {
                Some("--kv") => {
                    cur.advance();
                    kv_mode = Some(parse_kv_mode(&mut cur));
                }
                _ => positionals.push(cur.take()),
            }
        }
 
        Self {
            kv_mode,
            positionals,
        }
    }
 
    pub fn positionals(&self) -> &[String] {
        &self.positionals
    }
 
    pub fn kv_cache_mode(&self) -> Option<&'static str> {
        self.kv_mode
    }
}

kv_mode is None when --kv is absent (cache off) and Some("basic") when present. The parse loop now peeks: if the next argument is --kv, consume it and parse the mode; anything else is a positional, exactly as in I.1. This is the additive-flag pattern that chapter set up. kv_cache_mode() exposes the result.

The crate root re-exports the factory and trait:

src/lib.rs
pub use cache::{create_kv_cache, KvCache};

Wiring through the binary and the loop

greedy_generate takes the cache and picks a forward path per step:

src/decode/greedy.rs
use tracing::info;
 
use crate::backend::Backend;
use crate::cache::KvCache;
use crate::decode::{Metrics, next_token_id_from_logits};
use crate::model::Model;
use crate::tokenizer::Tokenizer;

src/decode/greedy.rs
pub fn greedy_generate(
    model: &dyn Model,
    ops: &dyn Backend,
    tokenizer: &dyn Tokenizer,
    prompt_ids: &[usize],
    max_new_tokens: usize,
    eos_token_id: usize,
    metrics: &mut Metrics,
    cache: &mut Option<Box<dyn KvCache>>,
) -> Vec<usize> {
    let mut ids = prompt_ids.to_vec();
 
    for step in 0..max_new_tokens {
        let logits = metrics.record_timed(|| match cache.as_mut() {
            Some(c) if step == 0 => model.forward_prefill_with_kv_cache(prompt_ids, c.as_mut()),
            Some(c) => {
                model.forward_decode_with_kv_cache(*ids.last().unwrap(), ids.len() - 1, c.as_mut())
            }
            None => model.forward(&ids),
        });
 
        let (next_id, prob) = next_token_id_from_logits(ops, &logits);
        ids.push(next_id);
 
        info!(
            step = step + 1,
            token_id = next_id,
            prob = prob * 100.0,
            token = %tokenizer.decode(&[next_id]),
            phase = if step == 0 { "prefill" } else { "decode" },
        );
 
        if next_id == eos_token_id {
            info!("[EOS reached]");
            break;
        }
    }
 
    ids
}

The new parameter is cache: &mut Option<Box<dyn KvCache>>, None if --kv was not passed. The match inside record_timed is the dispatch:

  • Some(c) if step == 0: cache on, first step: run prefill, filling the cache from the whole prompt.
  • Some(c) (step > 0): cache on, a decode step: run the decode path with the last token id and its position, ids.len() - 1. After a 4-token prompt, prefill produces a token at position 4; on step 1, ids holds those 5 ids, so that token is fed to the decode path at position ids.len() - 1 = 4, and each later step advances the position by one.
  • None: cache off: the Act 1 forward over the whole growing sequence. This is the path II.1 measured.
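Getting the position wrong by one is the classic bug in this loop, so it helps to replay the dispatch without a model. In this sketch, ForwardCall and plan_step are hypothetical names: plan_step returns which call greedy_generate would make at a given step instead of making it, using the same three cases and the same ids.len() - 1 position.

```rust
/// Which forward path a generation step takes (illustrative only; the real
/// loop calls the Model methods directly instead of returning a plan).
#[derive(Debug, PartialEq)]
enum ForwardCall {
    /// Cache on, step 0: run the whole prompt once and fill the cache.
    Prefill { prompt_len: usize },
    /// Cache on, step > 0: feed only the last accepted token at its position.
    Decode { token_id: usize, position: usize },
    /// Cache off: rerun the full, growing sequence (the Act 1 path).
    Full { seq_len: usize },
}

/// Mirrors the three match arms in greedy_generate.
fn plan_step(step: usize, ids: &[usize], prompt_len: usize, cache_on: bool) -> ForwardCall {
    match (cache_on, step) {
        (true, 0) => ForwardCall::Prefill { prompt_len },
        (true, _) => ForwardCall::Decode {
            token_id: *ids.last().expect("ids always holds at least the prompt"),
            position: ids.len() - 1,
        },
        (false, _) => ForwardCall::Full { seq_len: ids.len() },
    }
}
```

After a 4-token prompt, step 1 decodes at position 4 and step 2 at position 5: each step appends exactly one accepted id, so the last id always sits at ids.len() - 1.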

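The phase field on the log event records "prefill" or "decode". With the cache on, that names the path that produced each token; with it off, the label only tracks the step number, since every step takes the full forward path.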

Finally, model-generate reads the --kv flag, builds the cache, and threads it into greedy_generate:

src/bin/model-generate.rs (Rust)
use std::path::Path;
 
use inferno::{
    CliArgs, Metrics, create_backend, create_kv_cache, greedy_generate, load_from_gguf_path,
};
 
fn usage() -> ! {
    eprintln!("usage: model-generate [--kv [basic]] <gguf_path> [prompt] [max_new_tokens]");
    std::process::exit(2);
}
src/bin/model-generate.rs (Rust)
    let args = CliArgs::from_env();
    let kv_mode = args.kv_cache_mode();
src/bin/model-generate.rs (Rust)
    println!("kv cache: {}", kv_mode.unwrap_or("off"));
    println!();
 
    let mut metrics = Metrics::default();
 
    let mut cache = kv_mode.map(|mode| {
        create_kv_cache(mode, model.clone(), backend.clone()).unwrap_or_else(|e| {
            eprintln!("error: {e}");
            std::process::exit(2);
        })
    });
 
    let full_ids = greedy_generate(
        model.as_ref(),
        &*backend,
        tokenizer.as_ref(),
        &prompt_ids,
        max_new_tokens,
        eos_token_id,
        &mut metrics,
        &mut cache,
    );

Because kv_mode is None when --kv is absent, kv_mode.map(...) builds a cache only if a mode was given: cache is Some(Box<dyn KvCache>) with --kv and None without. create_kv_cache needs the model (for the dimensions) and the backend (for concat_dim0). The rest of main is unchanged from II.1.
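The factory-plus-Option shape is worth seeing in isolation: a mode string selects a concrete cache behind a trait object, and mapping over Option turns "no mode" into "no cache". This is a simplified sketch under stated assumptions. The KvCache trait here, BasicKvCache, UnknownKvMode, and the parameter list are hypothetical: the chapter's trait has its own methods, and its factory takes the model and backend rather than raw dimensions.

```rust
use std::fmt;

/// Hypothetical, simplified stand-in for the chapter's KvCache trait.
trait KvCache {
    /// Append one or more whole rows of keys and values for `layer`.
    fn append(&mut self, layer: usize, k: &[f32], v: &[f32]);
    /// Number of positions cached for `layer`.
    fn positions(&self, layer: usize) -> usize;
}

/// One growable, row-major buffer per layer for keys and one for values.
struct BasicKvCache {
    kv_dim: usize,
    keys: Vec<Vec<f32>>,
    values: Vec<Vec<f32>>,
}

impl KvCache for BasicKvCache {
    fn append(&mut self, layer: usize, k: &[f32], v: &[f32]) {
        // Whole rows only, and keys and values must stay in step.
        assert_eq!(k.len() % self.kv_dim, 0, "partial row");
        assert_eq!(k.len(), v.len(), "key/value length mismatch");
        // Append along dim 0 (the chapter's code uses the backend's concat_dim0).
        self.keys[layer].extend_from_slice(k);
        self.values[layer].extend_from_slice(v);
    }

    fn positions(&self, layer: usize) -> usize {
        self.keys[layer].len() / self.kv_dim
    }
}

/// Error for a mode string the factory does not recognize.
#[derive(Debug)]
struct UnknownKvMode(String);

impl fmt::Display for UnknownKvMode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "unknown kv cache mode: {}", self.0)
    }
}

/// String-keyed factory. Assumption: takes layer count and per-position K/V
/// width directly instead of the model and backend.
fn create_kv_cache(
    mode: &str,
    n_layers: usize,
    kv_dim: usize,
) -> Result<Box<dyn KvCache>, UnknownKvMode> {
    match mode {
        "basic" => Ok(Box::new(BasicKvCache {
            kv_dim,
            keys: vec![Vec::new(); n_layers],
            values: vec![Vec::new(); n_layers],
        })),
        other => Err(UnknownKvMode(other.to_string())),
    }
}
```

Returning Box<dyn KvCache> keeps greedy_generate ignorant of which cache it holds, so a later cache variant only needs a new match arm here.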

Running it

Run the same prompt twice, once with the cache off and once with it on, and let the harness compare them:

BASH
cargo run --release --bin model-generate -- path/to/qwen3-0.6b.gguf "Once upon a time" 32
cargo run --release --bin model-generate -- --kv basic path/to/qwen3-0.6b.gguf "Once upon a time" 32

Without --kv, the metrics block looks like the one from II.1, and the per-forward dump climbs:

PLAINTEXT
kv cache: off
 
metrics:
  time_to_first_token_ms: 421.880
  decode_tokens_per_second: 1.984
  per_forward_ms: min 421.880  max 1224.402  mean 504.114  (n=32)
    forward 1: 421.880 ms
    forward 2: 447.221 ms
    ...
    forward 32: 1224.402 ms

With --kv, the per-forward times go flat and decode throughput jumps:

PLAINTEXT
kv cache: basic
 
metrics:
  time_to_first_token_ms: 423.104
  decode_tokens_per_second: 47.602
  per_forward_ms: min 19.882  max 423.104  mean 32.690  (n=32)
    forward 1: 423.104 ms
    forward 2: 21.114 ms
    forward 3: 20.402 ms
    ...
    forward 32: 21.985 ms

Two things to read here. Time to first token is unchanged (~420 ms either way). That's prefill, and the cache doesn't make prefill faster; it makes prefill save its work. The payoff is decode: throughput went from about 2 tokens/sec to roughly 48, better than 20×. And the per-forward dump tells the real story. Without the cache, forward 32 takes nearly 3× as long as forward 2; with it, every decode forward pass costs the same flat ~21 ms whether it's the 2nd token or the 32nd. The $O(N^2)$ growth is gone. Longer responses no longer get slower per token.
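The shape of those two dumps follows from counting how many token positions each strategy pushes through the network. The sketch below counts work, not milliseconds, and its function names are illustrative. Assuming the prompt is 4 tokens and 32 tokens are generated, the cache-off loop processes 624 positions while the cached loop processes 35.

```rust
/// Token positions pushed through the layers to generate `n` tokens from a
/// `p`-token prompt when every step reruns the whole sequence (cache off).
fn positions_without_cache(p: usize, n: usize) -> usize {
    // Step s (0-based) sees the prompt plus the s tokens generated so far.
    (0..n).map(|s| p + s).sum()
}

/// Closed form of the same sum: n*p + n*(n-1)/2, quadratic in n.
fn positions_without_cache_closed(p: usize, n: usize) -> usize {
    n * p + n * n.saturating_sub(1) / 2
}

/// The same count with the cache on: prefill sees the prompt once, then each
/// of the remaining n - 1 decode steps sees exactly one new token.
fn positions_with_cache(p: usize, n: usize) -> usize {
    if n == 0 {
        0
    } else {
        p + (n - 1)
    }
}
```

The n(n-1)/2 term is the quadratic part; with the cache, each generated token after the first adds exactly one position.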

Where this leaves us

The KV cache is the single largest speedup in Act 2, and it came from an algorithmic fix rather than a faster kernel: stop recomputing what doesn't change. Decode is now constant work per token instead of growing work (apart from attention reading one more cached row each step, which is small at these lengths), and the harness from II.1 confirms it: flat per-forward times and a 20×-plus throughput jump.
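The central claim, that earlier tokens' keys and values never change and can therefore be saved instead of recomputed, is easy to check on a toy. The sketch below is a single attention head over 2-dimensional vectors with made-up fixed weights; WQ, WK, WV, attend, last_row_full, and ToyKvCache are all hypothetical names, and a real Qwen3 layer adds many heads, RoPE, normalization, and the MLP. last_row_full recomputes keys and values for the whole sequence, as the cache-off path does; ToyKvCache::step projects only the new token and appends its row.

```rust
type Vec2 = [f32; 2];

// Fixed toy weights standing in for a layer's learned projections (hypothetical).
const WQ: [Vec2; 2] = [[0.5, -0.2], [0.1, 0.9]];
const WK: [Vec2; 2] = [[0.3, 0.8], [-0.6, 0.2]];
const WV: [Vec2; 2] = [[1.0, 0.4], [0.0, -0.7]];

/// 2x2 matrix times a 2-vector.
fn project(w: &[Vec2; 2], x: &Vec2) -> Vec2 {
    [w[0][0] * x[0] + w[0][1] * x[1], w[1][0] * x[0] + w[1][1] * x[1]]
}

fn dot(a: &Vec2, b: &Vec2) -> f32 {
    a[0] * b[0] + a[1] * b[1]
}

/// Scaled dot-product attention for one query over all visible keys/values.
fn attend(q: &Vec2, keys: &[Vec2], values: &[Vec2]) -> Vec2 {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let scores: Vec<f32> = keys.iter().map(|k| dot(q, k) * scale).collect();
    // Subtract the max before exponentiating for a numerically stable softmax.
    let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = weights.iter().sum();
    let mut out = [0.0f32; 2];
    for (w, v) in weights.iter().zip(values) {
        out[0] += w / total * v[0];
        out[1] += w / total * v[1];
    }
    out
}

/// No cache: project K and V for every token again, keep only the last row.
fn last_row_full(xs: &[Vec2]) -> Vec2 {
    let keys: Vec<Vec2> = xs.iter().map(|x| project(&WK, x)).collect();
    let values: Vec<Vec2> = xs.iter().map(|x| project(&WV, x)).collect();
    let q = project(&WQ, xs.last().expect("non-empty sequence"));
    attend(&q, &keys, &values)
}

/// With a cache: project only the new token, append its K/V, attend over the cache.
#[derive(Default)]
struct ToyKvCache {
    keys: Vec<Vec2>,
    values: Vec<Vec2>,
}

impl ToyKvCache {
    fn step(&mut self, x: &Vec2) -> Vec2 {
        self.keys.push(project(&WK, x));
        self.values.push(project(&WV, x));
        let q = project(&WQ, x);
        attend(&q, &self.keys, &self.values)
    }
}
```

Both paths produce the same output at every step. The sketch also shows where the remaining growth lives: step projects one token regardless of length, while attend still loops over every cached key, one more per step.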

But look again at that decode number: each forward pass still takes ~21 ms. A pass is now a fixed amount of arithmetic, and that arithmetic is running on a single CPU core, one float at a time, in a plain scalar loop. The algorithm is right; the kernel is slow. The next chapter keeps the algorithm exactly as it is and attacks the kernel, rewriting matmul to use the CPU's SIMD vector instructions, doing 4 and then 16 multiply-adds per instruction instead of one.