III.7: Batched decode

The scheduler from III.6 runs concurrent requests by interleaving them: decode_tick walks every active slot and runs each one's forward pass, one after another. Four active slots means four separate forward passes per tick.

That works, but it leaves performance on the table, and to see why we have to look at the shape of a decode-step forward pass.

A decode step processes exactly one token. Every projection in the model (the Q/K/V projections, the attention output projection, the MLP's three matrices) is a matmul of that one token's hidden vector against a weight matrix. A vector times a matrix. In matmul terms, the "batch" dimension is 1.

That is the worst case for matmul hardware. A CPU's SIMD units and a GPU's cores are built to chew through matrix-times-matrix work: a tall left operand, with many rows processed per pass. Hand them a single-row left operand and most of the machine sits idle: the weight matrix still has to be streamed from memory, but it's used for just one row instead of many. The matmul is memory-bound, and the compute units starve.
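To put a rough number on "memory-bound", here is a back-of-the-envelope sketch. It assumes f32 weights (4 bytes each) that are read from memory exactly once per matmul, and it ignores activation traffic and caches; the function name is made up for illustration and is not part of the project.

```rust
/// FLOPs performed per byte of weight data streamed, for a
/// [batch, d_in] x [d_in, d_out] matmul.
/// Assumption: f32 weights, each read from memory exactly once.
fn flops_per_weight_byte(batch: usize, d_in: usize, d_out: usize) -> f64 {
    // Every output element costs d_in multiply-adds, i.e. 2 * d_in FLOPs.
    let flops = 2.0 * batch as f64 * d_in as f64 * d_out as f64;
    // The whole weight matrix is streamed once, at 4 bytes per f32.
    let weight_bytes = 4.0 * d_in as f64 * d_out as f64;
    flops / weight_bytes
}
```

Under these assumptions the ratio is simply batch / 2: a single-row decode step does 0.5 FLOPs per weight byte regardless of the matrix size, and the only way to raise it is to add rows to the left operand.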

Now the key observation. When four slots all take a decode step in the same tick, that's four separate single-row matmuls against the same weight matrices. Stack those four hidden vectors into a 4-row matrix and you get one matmul with a batch dimension of 4: the weight matrix streamed once, used four times. Same arithmetic, a quarter of the weight-memory traffic. This is batched decode, and it's this chapter.
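The stacking argument can be checked on a toy scale. The sketch below uses a naive row-major matmul over plain slices (not the project's Tensor or backend types, and all names here are illustrative) and compares the per-slot path with the stacked path.

```rust
/// Naive row-major matmul: `a` is [m, k], `b` is [k, n], result is [m, n].
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for p in 0..k {
            let a_ip = a[i * k + p];
            for j in 0..n {
                out[i * n + j] += a_ip * b[p * n + j];
            }
        }
    }
    out
}

/// Interleaved decode: one single-row matmul per hidden vector.
fn per_slot(hiddens: &[Vec<f32>], w: &[f32], k: usize, n: usize) -> Vec<Vec<f32>> {
    hiddens.iter().map(|h| matmul(h, w, 1, k, n)).collect()
}

/// Batched decode: stack the vectors into one [bsz, k] matrix, run one matmul,
/// then split the [bsz, n] result back into rows.
fn batched(hiddens: &[Vec<f32>], w: &[f32], k: usize, n: usize) -> Vec<Vec<f32>> {
    let stacked: Vec<f32> = hiddens.iter().flatten().copied().collect();
    let out = matmul(&stacked, w, hiddens.len(), k, n);
    out.chunks(n).map(|row| row.to_vec()).collect()
}
```

Both functions return the same rows; the difference is that batched touches the weight slice in a single call, which is where a real backend gets to reuse each streamed weight across every row.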

What batches and what doesn't

Not every part of the forward pass batches the same way. There are two kinds of computation in a transformer layer:

  • Position-independent projections. Q/K/V projections, the attention output projection, the MLP. Each is hidden_vector × weight_matrix, and the weight matrix is shared by every sequence. Stack the hidden vectors → one bigger matmul. These batch perfectly.
  • Attention itself. Each sequence attends over its own KV history, with different lengths, different contents, and a different KV cache per slot. There's no shared matrix to stack against. Attention stays per-sequence: a loop over the batch.

So a batched decode step looks like: gather the batch's input tokens, do the shared projections as fused matmuls, loop per-sequence for the attention (each slot reads and appends to its own paged KV cache from III.4), then do the shared MLP as a fused matmul. The expensive matmuls (projections and MLP, the bulk of the FLOPs) are batched; only attention, which is comparatively cheap at decode time, stays in a loop.

This chapter adds forward_decode_batch to the model, teaches the scheduler worker to call it when two or more slots are active, and ships a GuideLLM A/B benchmark to measure the gain.

The trait method

src/model/model_trait.rs gains forward_decode_batch, and it ships with a default implementation so any model that doesn't override it still works:

src/model/model_trait.rsRUST
    fn forward_decode_batch(
        &self,
        token_ids: &[usize],
        positions: &[usize],
        caches: &mut [&mut dyn KvCache],
    ) -> Result<Vec<Tensor>, String> {
        assert!(token_ids.len() == positions.len() && token_ids.len() == caches.len());
        let mut out = Vec::with_capacity(token_ids.len());
        for i in 0..token_ids.len() {
            // `?` unwraps each per-slot result and returns early on the first error.
            out.push(self.forward_decode_with_kv_cache(
                token_ids[i],
                positions[i],
                caches[i],
            )?);
        }
        Ok(out)
    }

The signature: a slice of token ids (one per batched slot), a slice of positions, and a slice of KV cache references (&mut [&mut dyn KvCache], one mutable cache borrow per slot). It returns one logits tensor per slot.

The default just loops, calling the single-sequence forward_decode_with_kv_cache once per slot, exactly what III.6's decode_tick did. So the trait method changes nothing on its own; the speedup comes from a model that overrides it with a genuinely fused pass. Qwen3 does.
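The pattern here, a trait method with a looping default that individual models may override, can be seen in isolation with a toy trait. Everything in this sketch (ToyModel, step, step_batch, the two structs) is invented for illustration and stands in for the real Model trait only loosely.

```rust
trait ToyModel {
    /// Single-sequence step; a stand-in for a per-slot forward pass.
    fn step(&self, token: usize) -> Result<f32, String>;

    /// Default batch: loop over single steps, stopping at the first error.
    fn step_batch(&self, tokens: &[usize]) -> Result<Vec<f32>, String> {
        let mut out = Vec::with_capacity(tokens.len());
        for &t in tokens {
            out.push(self.step(t)?);
        }
        Ok(out)
    }
}

/// Relies on the default: correct, but no faster than calling `step` in a loop.
struct Looping;
impl ToyModel for Looping {
    fn step(&self, token: usize) -> Result<f32, String> {
        Ok(token as f32 * 2.0)
    }
}

/// Overrides the default with its own whole-batch implementation.
struct Fused;
impl ToyModel for Fused {
    fn step(&self, token: usize) -> Result<f32, String> {
        Ok(token as f32 * 2.0)
    }
    fn step_batch(&self, tokens: &[usize]) -> Result<Vec<f32>, String> {
        // A real override would run one fused pass; this toy just maps the batch.
        Ok(tokens.iter().map(|&t| t as f32 * 2.0).collect())
    }
}
```

Callers only ever see step_batch, so they do not need to know which of the two implementations they are getting.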

The fused Qwen3 decode

src/model/qwen3/forward.rs gets the real implementation. We'll take it in pieces.

src/model/qwen3/forward.rsRUST
    pub fn forward_decode_batch(
        &self,
        token_ids: &[usize],
        positions: &[usize],
        caches: &mut [&mut dyn KvCache],
    ) -> Result<Vec<Tensor>, String> {
        assert_eq!(token_ids.len(), positions.len());
        assert_eq!(token_ids.len(), caches.len());
        let bsz = token_ids.len();
        if bsz == 0 {
            return Ok(Vec::new());
        }
 
        let ops = self.cpu_backend.as_ref();
        let cfg = &self.config;
        let hidden = cfg.hidden_size;
 
        let mut x = ops.gather_rows(&self.embed, token_ids);

bsz is the batch size, how many slots are decoding this tick. The crucial line is the last one: gather_rows looks up the embedding for every token in the batch at once, producing x as a [bsz, hidden] matrix, one row per slot. Compare the single-sequence path, which embeds one token into a [1, hidden] row. From here, x flows through the layers as a real matrix, and that's what makes the projections batch.
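For intuition, a gather over a row-major embedding table might look like the sketch below. This is an assumption about what such an operation does, written over plain f32 slices; the backend's actual gather_rows signature and Tensor layout are not shown in this chapter and may differ.

```rust
/// Hypothetical gather: copies row `id` of a row-major [vocab, hidden] table
/// for every id, producing a row-major [ids.len(), hidden] buffer.
/// Panics if an id is out of range for the table.
fn gather_rows(table: &[f32], hidden: usize, ids: &[usize]) -> Vec<f32> {
    let mut out = Vec::with_capacity(ids.len() * hidden);
    for &id in ids {
        let start = id * hidden;
        out.extend_from_slice(&table[start..start + hidden]);
    }
    out
}
```

Row j of the result is the embedding of ids[j], so the order of token_ids fixes which row belongs to which slot for the rest of the pass.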

Now the per-layer loop. The projections first:

src/model/qwen3/forward.rsRUST
        for (li, layer) in self.layers.iter().enumerate() {
            let normed = rms_norm_weighted_last(ops, &x, &layer.input_layernorm, cfg.rms_norm_eps);
            let q_all = ops.matmul(&normed, &layer.q_proj);
            let k_all = ops.matmul(&normed, &layer.k_proj);
            let v_all = ops.matmul(&normed, &layer.v_proj);

normed is the RMSNorm of x, a [bsz, hidden] matrix. Then three matmuls for the Q, K, V projections. Because normed has bsz rows, q_all, k_all, v_all each come out [bsz, *]; one matmul produces the projections for every slot. This is the batching win in three lines: where III.6 did bsz separate Q-projection matmuls, this does one with a bsz-row left operand. The weight matrix layer.q_proj is streamed from memory once, used bsz times.

Attention can't batch the same way, since each slot has its own KV history, so it loops:

src/model/qwen3/forward.rsRUST
            let mut attn_data = vec![0.0f32; bsz * hidden];
            for (bi, cache) in caches.iter_mut().enumerate() {
                let q_row = ops.copy_row_2d(&q_all, bi);
                let k_row = ops.copy_row_2d(&k_all, bi);
                let v_row = ops.copy_row_2d(&v_all, bi);
 
                let q_norm = headwise_rms_norm_weighted(
                    ops,
                    &q_row,
                    cfg.num_attention_heads,
                    cfg.head_dim,
                    &layer.attn_q_norm,
                    cfg.rms_norm_eps,
                );
                let k_norm = headwise_rms_norm_weighted(
                    ops,
                    &k_row,
                    cfg.num_key_value_heads,
                    cfg.head_dim,
                    &layer.attn_k_norm,
                    cfg.rms_norm_eps,
                );
                let q_rope = ops.apply_rope_single_row(
                    &q_norm,
                    positions[bi],
                    cfg.head_dim,
                    cfg.rope_theta,
                );
                let k_rope = ops.apply_rope_single_row(
                    &k_norm,
                    positions[bi],
                    cfg.head_dim,
                    cfg.rope_theta,
                );
 
                cache.push_row(li, &k_rope, &v_row);
                let (k_full, v_full) = cache.materialize(li);
 
                let attn_out = gqa_attention_decode_one_query(
                    ops,
                    &q_rope,
                    &k_full,
                    &v_full,
                    &layer.o_proj,
                    cfg.num_attention_heads,
                    cfg.num_key_value_heads,
                    cfg.head_dim,
                );
                let dst = &mut attn_data[bi * hidden..(bi + 1) * hidden];
                ops.copy_contiguous_into(&attn_out, 0, dst);
            }

attn_data is a bsz * hidden scratch buffer for the attention output of every slot. The loop runs once per slot bi:

  1. copy_row_2d pulls slot bi's row out of the batched q_all/k_all/v_all matrices, back to single-row tensors.
  2. The Q/K head-wise RMSNorm and RoPE are exactly the single-sequence attention math from I.5. RoPE uses positions[bi] because each slot is at a different point in its own sequence.
  3. cache.push_row appends this slot's new K/V row to its own paged KV cache; materialize reads back its full history. This is why the loop can't be a single matmul: caches[bi] is a different cache with a different length per slot.
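  4. gqa_attention_decode_one_query runs grouped-query attention for this one token against its history, then copy_contiguous_into writes the result into slot bi's slice of attn_data. Because layer.o_proj is passed into that call, the attention output projection in this implementation is applied inside the per-slot loop rather than as one batched matmul.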
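To see why the history length forces a per-slot loop, here is a stripped-down sketch of one query attending over its own keys and values. It is single-head scaled dot-product attention over plain slices, with no grouped-query sharing, no RoPE, and no output projection, so it is a simplification rather than the project's gqa_attention_decode_one_query; the function name is illustrative.

```rust
/// One query row against a history of `len` key/value rows.
/// `keys` and `values` are row-major [len, dim]; `len` is whatever this
/// slot's history happens to be, and differs from slot to slot.
fn attend_one_query(q: &[f32], keys: &[f32], values: &[f32], dim: usize) -> Vec<f32> {
    assert!(dim > 0 && q.len() == dim);
    assert_eq!(keys.len(), values.len());
    assert!(!keys.is_empty() && keys.len() % dim == 0);
    let len = keys.len() / dim;
    let scale = 1.0 / (dim as f32).sqrt();

    // Scaled dot-product score for each history position.
    let scores: Vec<f32> = (0..len)
        .map(|t| {
            let k = &keys[t * dim..(t + 1) * dim];
            q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale
        })
        .collect();

    // Numerically stable softmax over the scores.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let denom: f32 = exps.iter().sum();

    // Weighted sum of the value rows.
    let mut out = vec![0.0f32; dim];
    for t in 0..len {
        let w = exps[t] / denom;
        for d in 0..dim {
            out[d] += w * values[t * dim + d];
        }
    }
    out
}
```

The score vector has one entry per history position, so two slots with histories of 5 and 500 tokens produce differently shaped intermediate results; there is no common matrix shape to stack them into.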

After the loop, the attention outputs are reassembled into one batched matrix and fed through the residual and MLP, both of which batch:

src/model/qwen3/forward.rsRUST
            let attn_out = Tensor::new(attn_data, vec![bsz, hidden]);
            x = ops.add(&x, &attn_out);
 
            let normed_mlp =
                rms_norm_weighted_last(ops, &x, &layer.post_attention_layernorm, cfg.rms_norm_eps);
            let mlp_out = mlp_forward(ops, &normed_mlp, layer);
            x = ops.add(&x, &mlp_out);
        }

attn_data becomes a [bsz, hidden] tensor. The residual add, the post-attention RMSNorm, and mlp_forward all operate on that batched matrix; the MLP's three matmuls each have a bsz-row left operand, so the MLP batches exactly like the projections did. The MLP is the largest matmul in the layer, so batching it is the biggest single win.

After all layers, the final norm and output head (both batched) produce per-slot logits:

src/model/qwen3/forward.rsRUST
        let x = rms_norm_weighted_last(ops, &x, &self.norm, cfg.rms_norm_eps);
        let logits = ops.matmul(&x, &self.lm_head);
        Ok((0..bsz).map(|i| ops.copy_row_2d(&logits, i)).collect())
    }

logits is [bsz, vocab_size]. The final copy_row_2d per slot splits it into a Vec<Tensor>, one logits row per slot, the same shape forward_decode_batch's caller expects.
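The split back into per-slot rows is just a row-major chunking. A minimal sketch over a flat f32 buffer follows; split_rows and argmax_row are illustrative names, and the greedy argmax is only an assumption about how a caller might consume each row, since the project's sampling lives elsewhere.

```rust
/// Splits a row-major [bsz, vocab] buffer into one Vec per slot.
fn split_rows(logits: &[f32], bsz: usize, vocab: usize) -> Vec<Vec<f32>> {
    assert!(vocab > 0);
    assert_eq!(logits.len(), bsz * vocab);
    logits.chunks(vocab).map(|row| row.to_vec()).collect()
}

/// Index of the largest logit in one row (first one wins on ties).
/// Assumption: greedy selection, used here only to show per-row consumption.
fn argmax_row(row: &[f32]) -> usize {
    let mut best = 0;
    for (i, &v) in row.iter().enumerate() {
        if v > row[best] {
            best = i;
        }
    }
    best
}
```

Row j of the split corresponds to token_ids[j], so each slot can pick its next token from its own row without looking at any other slot's logits.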

The Model trait impl forwards to it:

src/model/qwen3/forward.rsRUST
    fn forward_decode_batch(
        &self,
        token_ids: &[usize],
        positions: &[usize],
        caches: &mut [&mut dyn KvCache],
    ) -> Result<Vec<Tensor>, String> {
        Qwen3Model::forward_decode_batch(self, token_ids, positions, caches)
    }

Slot support for batched steps

The III.6 slot had decode_one, which did the whole step (push, forward, apply) for one slot. A batched step needs that work split in two: every slot does its push phase, then one fused forward call covers them all, then every slot applies its result. src/scheduler/slot.rs gains the two halves.

batch_push_phase does the decision-and-append part (III.1's kv_decode_push_phase) without any forward call:

src/scheduler/slot.rsRUST
    /// Push phase of a batched decode step. Returns the `(token_id, position)`
    /// that needs a forward pass, or `None` if the slot finished here.
    pub fn batch_push_phase(&mut self, tokenizer: &dyn Tokenizer) -> Option<(usize, usize)> {
        let ids_before = self.ids.len();
 
        match kv_decode_push_phase(
            &mut self.ids,
            &mut self.next_id,
            self.prompt_ids.len(),
            self.eos_token_id,
            self.max_new_tokens,
        ) {
            KvPushPhase::Finished(step) => {
                self.finish_step(tokenizer, ids_before, Ok(step));
                None
            }
            KvPushPhase::NeedForward { token_id, position } => Some((token_id, position)),
        }
    }

It appends next_id and checks the stop conditions. If the slot finished here (EOS or token cap), finish_step marks it done and it returns None; this slot won't be in the batched forward call. Otherwise it returns the (token_id, position) the forward pass needs. This is exactly why III.1 split kv_decode_push_phase out as a pure function: the batched path reuses it.
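For readers who don't have III.1 to hand, the shape of such a pure push-phase function can be sketched as below. This is a guess at the general idea based on the description above (append, then check EOS and the token cap), not the project's kv_decode_push_phase: the enum, the function name, the by-value next_id, and the exact order of checks are all assumptions.

```rust
#[derive(Debug, PartialEq)]
enum PushPhase {
    /// The slot hit a stop condition; no forward pass is needed.
    Finished,
    /// The slot needs a forward pass for this token at this position.
    NeedForward { token_id: usize, position: usize },
}

/// Hypothetical push phase: appends the pending token, then checks stop conditions.
/// Takes no model and does no I/O, so it can be reused by any decode path.
fn toy_push_phase(
    ids: &mut Vec<usize>,
    next_id: usize,
    prompt_len: usize,
    eos_token_id: usize,
    max_new_tokens: usize,
) -> PushPhase {
    ids.push(next_id);
    let generated = ids.len().saturating_sub(prompt_len);
    if next_id == eos_token_id || generated >= max_new_tokens {
        return PushPhase::Finished;
    }
    // The token just appended sits at the last index of the sequence.
    PushPhase::NeedForward { token_id: next_id, position: ids.len() - 1 }
}
```

Because the function only reads and writes the values passed in, the serial path and the batched path can both call it and then decide separately how to run the forward pass.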

batch_apply_logits does the after-forward part for one slot:

src/scheduler/slot.rsRUST
    pub fn batch_apply_logits(
        &mut self,
        logits: &Tensor,
        elapsed: std::time::Duration,
        batch_size: usize,
        backend: &dyn Backend,
        tokenizer: &dyn Tokenizer,
    ) {
        self.metrics.record_forward_share(elapsed, batch_size);
 
        let (next_id, _) = next_token_id_from_logits(backend, logits);
        self.next_id = next_id;
 
        if self.stream.is_some() {
            self.emit_stream_delta(tokenizer);
        }
    }

It takes this slot's logits row, picks the next token, and (if streaming) emits a delta. record_forward_share is the timing subtlety: one fused forward pass took elapsed wall-clock for the whole batch, so each slot is charged elapsed / batch_size, its fair share. Otherwise each of bsz slots would record the full batch time and the metrics would be bsz× inflated.
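The fair-share accounting is simple enough to sketch with std::time::Duration, which supports division by a u32. The SlotMetrics struct and its fields below are hypothetical stand-ins; the project's metrics type is not shown in this chapter.

```rust
use std::time::Duration;

/// Hypothetical per-slot metrics: accumulated forward time and step count.
#[derive(Default)]
struct SlotMetrics {
    forward_time: Duration,
    forward_steps: u32,
}

impl SlotMetrics {
    /// Charges this slot its share of one fused forward pass.
    fn record_forward_share(&mut self, elapsed: Duration, batch_size: usize) {
        // Guard against a zero batch size so the division cannot panic.
        let share = elapsed / batch_size.max(1) as u32;
        self.forward_time += share;
        self.forward_steps += 1;
    }
}
```

With an 80 ms fused pass over eight slots, each slot records 10 ms, and the eight shares add back up to the 80 ms that actually elapsed.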

The batched decode tick

src/scheduler/worker.rs's decode_tick learns to choose: batch when it's worth it, fall back when it isn't.

src/scheduler/worker.rsRUST
    fn decode_tick(&mut self) {
        let active: Vec<usize> = self
            .slots
            .iter()
            .enumerate()
            .filter_map(|(i, s)| s.as_ref().filter(|s| s.is_active()).map(|_| i))
            .collect();
 
        if active.len() >= 2 {
            self.decode_tick_batched(&active);
        } else {
            for slot in self.slots.iter_mut().flatten() {
                slot.decode_one(
                    self.model.as_ref(),
                    self.tokenizer.as_ref(),
                    self.backend.as_ref(),
                );
            }
        }
    }

It collects the indices of active slots. With two or more, decode_tick_batched runs the fused path; with zero or one there's nothing to batch, so it falls back to III.6's per-slot decode_one. Batching a single sequence would still be a [1, hidden] matmul, the slow case batched decode exists to avoid, so nothing is gained below two active slots.

decode_tick_batched is the orchestration. First, the push phase for every active slot:

src/scheduler/worker.rsRUST
    fn decode_tick_batched(&mut self, active: &[usize]) {
        let tokenizer = self.tokenizer.as_ref();
 
        let mut batch: Vec<usize> = Vec::new();
        let mut token_ids: Vec<usize> = Vec::new();
        let mut positions: Vec<usize> = Vec::new();
 
        for &i in active {
            let s = self.slots[i].as_mut().unwrap();
            if let Some((token_id, position)) = s.batch_push_phase(tokenizer) {
                batch.push(i);
                token_ids.push(token_id);
                positions.push(position);
            }
        }
 
        if batch.is_empty() {
            return;
        }

For each active slot, call batch_push_phase. A slot that finished in its push phase returns None and is not added to the batch; its forward pass would be wasted. Slots that need a forward pass contribute their index to batch and their token/position to the parallel token_ids/positions arrays. If every slot finished, the batch is empty and there's nothing to do.

Next, gather one mutable cache reference per batched slot:

src/scheduler/worker.rsRUST
        // Collect one cache reference per batched slot. `iter_mut` hands out
        // disjoint mutable borrows, and both `batch` and the iteration run in
        // ascending index order, so `caches[j]` lines up with `batch[j]`.
        let mut caches: Vec<&mut dyn crate::cache::KvCache> = Vec::with_capacity(batch.len());
        for (i, slot) in self.slots.iter_mut().enumerate() {
            if batch.contains(&i) {
                caches.push(slot.as_mut().unwrap().cache.as_mut());
            }
        }

This is a borrow-checker puzzle. forward_decode_batch wants &mut [&mut dyn KvCache], many mutable cache borrows at once. You can't take those by indexing self.slots[i] in a loop; the compiler can't prove the indices are distinct. iter_mut, though, yields provably disjoint mutable borrows of each element. The comment pins down the correctness argument: batch was built in ascending index order and this iteration is also ascending, so caches[j] is the cache of batch[j], and the slices line up.
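The same trick works on any slice, so it can be tried outside the scheduler. In this sketch the slots are Option<Vec<u32>> stand-ins for real slots with caches, and collect_batch_mut is an illustrative helper name; unlike the worker code, it silently skips an empty slot instead of unwrapping.

```rust
/// Collects one mutable borrow per slot index listed in `batch`.
/// `iter_mut` yields each element at most once, so the borrows are disjoint.
/// The result follows ascending slot order, so `batch` must be ascending too
/// for `result[j]` to line up with `batch[j]`.
fn collect_batch_mut<'a>(
    slots: &'a mut [Option<Vec<u32>>],
    batch: &[usize],
) -> Vec<&'a mut Vec<u32>> {
    slots
        .iter_mut()
        .enumerate()
        .filter(|(i, _)| batch.contains(i))
        .filter_map(|(_, slot)| slot.as_mut())
        .collect()
}
```

All the returned borrows are live at the same time, which is exactly what a loop of slots[i] index expressions cannot give you; once the Vec is dropped, the slots are free to be borrowed again.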

Then the one fused forward call:

src/scheduler/worker.rsRUST
        let start = Instant::now();
        let logits_rows = match self
            .model
            .forward_decode_batch(&token_ids, &positions, &mut caches)
        {
            Ok(v) => v,
            Err(e) => {
                for &i in &batch {
                    self.slots[i].as_mut().unwrap().fail(e.clone());
                }
                return;
            }
        };
        let elapsed = start.elapsed();
        let n = batch.len();

One forward_decode_batch call for the whole batch, timed with Instant. If it errors, every batched slot is failed with that error. On success, logits_rows holds one logits tensor per batched slot, and elapsed is the wall time of the whole fused pass.

Finally, apply each slot's logits:

src/scheduler/worker.rsRUST
        let backend = self.backend.as_ref();
        for (j, &i) in batch.iter().enumerate() {
            self.slots[i].as_mut().unwrap().batch_apply_logits(
                &logits_rows[j],
                elapsed,
                n,
                backend,
                tokenizer,
            );
        }
    }

For each batched slot, batch_apply_logits with its logits row (logits_rows[j]), the shared elapsed, and the batch size n; record_forward_share charges each slot elapsed / n. Every slot picks its next token and emits its stream delta. The tick is done; the worker loops back, retires anything finished, admits anything waiting, and ticks again.

The GuideLLM benchmark

The III.7 commit ships an A/B benchmark to measure the gain, under benchmarking/guidellm/. GuideLLM is a load-testing tool for inference servers: it fires synthetic requests at a concurrency you choose and reports throughput.

The A/B is simple. Run the same concurrent load against chat-server twice:

  • A: --max-concurrent 1. One slot. decode_tick never sees two active slots, so it never batches; every request decodes serially.
  • B: --max-concurrent N. N slots. Concurrent requests fill multiple slots, decode_tick sees >= 2 active, and the fused path kicks in.

Same client load, same model, same data; the only difference is whether batched decode engages. The throughput gap is the chapter's payoff.

benchmarking/guidellm/guidellm-batched-decode-ab.sh automates it. It documents its knobs up top:

benchmarking/guidellm/guidellm-batched-decode-ab.shBASH
#!/usr/bin/env bash
# A/B GuideLLM runs: same concurrent client load vs chat-server with
#   A) --max-concurrent 1  → serial decode (no cross-request batching)
#   B) --max-concurrent N  → multiple KV slots → fused batched decode when loaded
#
# Usage (from repo root):
#   ./benchmarking/guidellm/guidellm-batched-decode-ab.sh /path/to/model.gguf

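It builds the binary, then runs each phase through a helper, run_phase, that starts a server, waits for /health, runs GuideLLM, and stops the server. The server launch itself lives in run_server, which starts chat-server in the background and echoes its PID: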

benchmarking/guidellm/guidellm-batched-decode-ab.shBASH
run_server() {
  local port=$1
  local mc=$2
  local log=$3
  RUST_LOG="${RUST_LOG:-warn}" "$BIN" --kv "$KV" --max-concurrent "$mc" \
    --bind "127.0.0.1:$port" "$GGUF" "$DEFAULT_MAX_TOK" >"$log" 2>&1 &
  echo $!
}

The two phases, phase A pinning --max-concurrent 1 and phase B using the configurable MC_BATCH:

benchmarking/guidellm/guidellm-batched-decode-ab.shBASH
run_phase "Run A: max-concurrent=1 (serial decode), GuideLLM concurrent rate=$RATE" \
  "$PORT_A" 1 "$OUT_A"
run_phase "Run B: max-concurrent=$MC_BATCH (batched decode when loaded), same GuideLLM load" \
  "$PORT_B" "$MC_BATCH" "$OUT_B"
 
python3 "$SCRIPT_DIR/summarize-batched-decode-ab.py" \
  "$OUT_A/benchmark.json" "$OUT_B/benchmark.json" "$MC_BATCH"

GuideLLM writes a benchmark.json per phase; summarize-batched-decode-ab.py reads both and prints the comparison. Its core metric is aggregate output throughput, that is, total tokens generated divided by wall time:

benchmarking/guidellm/summarize-batched-decode-ab.pyPYTHON
import json
import statistics
from pathlib import Path


def summarize_benchmark(path: Path) -> dict:
    with open(path) as f:
        d = json.load(f)
    b = d["benchmarks"][0]
    dur = b.get("duration") or 0.0
    succ = b["requests"].get("successful") or []
    inc = b["requests"].get("incomplete") or []
    tot_out = sum((r.get("output_tokens") or 0) for r in succ)
    per_req = [
        r.get("output_tokens_per_second")
        for r in succ
        if r.get("output_tokens_per_second") is not None
    ]
    agg = (tot_out / dur) if dur > 0 else None
    mean_pr = statistics.mean(per_req) if per_req else None
    rm = b["scheduler_metrics"]["requests_made"]
    return {
        "successful": rm["successful"],
        "errored": rm["errored"],
        "incomplete": rm.get("incomplete", len(inc)),
        "duration_s": dur,
        "out_tokens_ok": tot_out,
        "agg_output_tok_s": agg,
        "mean_output_tok_s_per_req": mean_pr,
    }

agg_output_tok_s, completed output tokens over phase wall time, is the number to compare. The script's own notes explain why: per-request rates can read low under concurrency because each request's window overlaps the others', so aggregate throughput is the honest A/B measure.

A small note at benchmarking/README.md points at it:

benchmarking/README.mdMARKDOWN
# Benchmarking
 
- **[`guidellm/`](guidellm/)** — GuideLLM A/B for batched vs serial decode on `chat-server` ([`guidellm/README.md`](guidellm/README.md)).

Running it

Install GuideLLM (pip install guidellm) and run the A/B from the repo root:

BASH
./benchmarking/guidellm/guidellm-batched-decode-ab.sh path/to/qwen3-0.6b.gguf

It builds chat-server, runs phase A (one slot, serial decode), then phase B (eight slots, batched decode), each under the same eight-stream concurrent load, and prints:

PLAINTEXT
Summary (GuideLLM benchmark index 0)
  single-slot (mc=1):  ok=40 incomplete=0 err=0  wall_s=86.4
      output tokens (completed): 2560  agg output tok/s (ok_wall): 29.6  mean ok-req output tok/s: 3.71
  batched (mc=8): ok=40 incomplete=0 err=0  wall_s=24.1
      output tokens (completed): 2560  agg output tok/s (ok_wall): 106.2  mean ok-req output tok/s: 13.3

Same 40 requests, same 2560 generated tokens. With one slot they decode serially: 86 seconds, ~30 tok/s aggregate. With eight slots the concurrent requests fill the slots, decode_tick batches them into fused passes, and the same work finishes in 24 seconds at ~106 tok/s, roughly 3.6× the throughput. The exact multiple depends on the backend and the batch size, but the shape is always the same: batching turns starved single-row matmuls into matmuls the hardware can actually feed.
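The aggregate figures in that summary can be reproduced by hand from the token count and the two wall times. The helper below is a small illustrative function, not part of the benchmark scripts; it mirrors the "tokens over wall time, or nothing if the duration is zero" rule.

```rust
/// Aggregate output throughput: completed output tokens over phase wall time.
/// Returns None when the duration is not positive.
fn aggregate_tok_per_s(tokens: u64, wall_s: f64) -> Option<f64> {
    if wall_s > 0.0 {
        Some(tokens as f64 / wall_s)
    } else {
        None
    }
}

/// Throughput ratio of phase B over phase A for the same token count.
fn speedup(tokens: u64, wall_a_s: f64, wall_b_s: f64) -> Option<f64> {
    let a = aggregate_tok_per_s(tokens, wall_a_s)?;
    let b = aggregate_tok_per_s(tokens, wall_b_s)?;
    Some(b / a)
}
```

With 2560 tokens, 86.4 s gives about 29.6 tok/s and 24.1 s gives about 106.2 tok/s, a ratio of about 3.59. Since both phases generated the same number of tokens, the ratio is just the ratio of the wall times.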

Where this leaves us

The decode loop batches. When two or more slots are decoding, their forward passes fuse into one: the projections and the MLP (the bulk of the FLOPs) become single matmuls with a real batch dimension, while attention stays a per-slot loop over each sequence's own paged KV cache. The GuideLLM A/B turns the abstract win into a concrete throughput number.

That closes Act 3, and the project. The engine that started Act 1 as a GGUF parser is now a complete OpenAI-compatible inference server: it speaks HTTP and SSE, admits concurrent conversations through a scheduler, stores KV in fragmentation-free paged blocks, reuses shared prompt prefixes through a radix cache, and batches concurrent decode steps. The Act 3 recap takes stock of the whole journey.