II.5: Metal GPU backend

The CPU backends are about as fast as a CPU gets for this: SIMD packs 16 floats into one instruction (II.3), threads spread the work over every core (II.4). A ten-core laptop CPU has, very roughly, ten powerful arithmetic units running at once.

A GPU has thousands of small ones. Matmul, where every output element is an independent dot product, is the canonical workload they were built for. This chapter writes a fourth Backend, Metal, that runs matmul on the GPU through Apple's Metal API: a compute kernel written in Metal Shading Language, command buffers to launch it, and Apple Silicon's unified memory so the CPU and GPU share the same bytes without copying.

Like every backend before it, Metal changes only where matmul runs, not what it computes. Like Parallel, it uses a size threshold to decide when the GPU is worth the trip and when to fall back to the CPU.

What "offload to the GPU" actually means

For a generalist, here's the model. The GPU is a separate processor. It does not share the CPU's instruction stream or its program. To make it do work you must, every time:

  1. Make sure the data the GPU needs lives in memory the GPU can read. On a discrete graphics card this means copying data across the PCIe bus into the card's own memory, which is slow. On Apple Silicon the CPU and GPU share one physical pool of RAM (unified memory), so "making the data available" can be free, which matters a lot here.
  2. Tell the GPU which program to run. A GPU program is called a kernel (or shader). You write it in a small C-like language (for Metal, that's Metal Shading Language, or MSL), and the driver compiles it for the GPU.
  3. Launch the kernel over a grid of threads. The GPU runs the same kernel on thousands of threads simultaneously, each with a different (x, y) coordinate. For matmul, you launch one thread per output element; each thread computes its own C[row][col].
  4. Wait for it to finish, then read the results back.

Steps 1 and 4 (getting data to and from the GPU, and the round-trip of asking it to do something) are latency. They cost a fixed amount of time no matter how small the actual computation is. The GPU only pays off when step 3, the kernel, is doing enough work to dwarf that fixed overhead. That trade-off drives every design decision in this chapter, and it's the same trade-off as II.4's MIN_ROWS_FOR_PARALLEL threshold, just with a bigger fixed cost.
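To make the trade-off concrete, here is a toy cost model: each processor pays a fixed overhead per matmul plus time proportional to the 2·m·n·p floating-point operations. The Processor type, both sets of numbers, and the 1024 × 3072 shape are invented for illustration. They are not measurements of any real CPU or GPU. The point is only that a large fixed cost lets the slower-per-flop processor win when m is small.

```rust
/// Toy cost model: fixed per-launch overhead plus time proportional to the work.
/// All numbers used below are made up to show the shape of the trade-off.
#[derive(Clone, Copy)]
struct Processor {
    fixed_overhead_us: f64, // paid even for a tiny matmul
    gflops: f64,            // sustained throughput once running
}

impl Processor {
    /// Estimated microseconds for an (m x n) by (n x p) matmul.
    fn matmul_time_us(&self, m: usize, n: usize, p: usize) -> f64 {
        let flops = 2.0 * (m * n * p) as f64; // one multiply + one add per term
        // 1 GFLOP/s = 1e3 flops per microsecond
        self.fixed_overhead_us + flops / (self.gflops * 1e3)
    }
}

fn gpu_wins(cpu: Processor, gpu: Processor, m: usize, n: usize, p: usize) -> bool {
    gpu.matmul_time_us(m, n, p) < cpu.matmul_time_us(m, n, p)
}

/// Smallest row count m (up to max_m) at which the GPU becomes faster.
fn crossover_m(cpu: Processor, gpu: Processor, n: usize, p: usize, max_m: usize) -> Option<usize> {
    (1..=max_m).find(|&m| gpu_wins(cpu, gpu, m, n, p))
}

// Hypothetical figures: cheap-to-start CPU, expensive-to-start but much faster GPU.
const CPU: Processor = Processor { fixed_overhead_us: 5.0, gflops: 100.0 };
const GPU: Processor = Processor { fixed_overhead_us: 150.0, gflops: 2000.0 };

fn main() {
    let (n, p) = (1024, 3072); // hypothetical weight-matrix shape
    for m in [1, 2, 3, 8, 512] {
        println!(
            "m = {m:>3}: cpu {:>9.1} us, gpu {:>9.1} us",
            CPU.matmul_time_us(m, n, p),
            GPU.matmul_time_us(m, n, p)
        );
    }
    println!("GPU starts winning at m = {:?}", crossover_m(CPU, GPU, n, p, 4096));
}
```

With these invented numbers the GPU wins from m = 3 upward. The crossover moves with every overhead and throughput figure, which is why the threshold ends up as a tunable constant rather than a law.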

Three crates for talking to Metal

We don't reimplement Metal; we bind to Apple's system framework through three crates:

Cargo.toml
[dependencies]
regex = "1"
rayon = "1"
mtl-gpu = "1.0"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
mtl-foundation = "1.0.1"
mtl-sys = "1.0.1"

mtl-sys is the raw, unsafe layer: the Objective-C messaging machinery (Metal's API is Objective-C under the hood). mtl-foundation and mtl-gpu are safer Rust wrappers over the device, command queue, buffers, and pipelines. We mostly use mtl-gpu, dropping to mtl-sys once for one operation the wrapper doesn't expose.

The Metal code lives in a new submodule, src/backend/metal/, split three ways: shaders.rs (the kernel source), context.rs (device and dispatch plumbing), and backend.rs (the Backend impl). We'll take them in that order.

The matmul kernel

src/backend/metal/shaders.rs holds the GPU program as a Rust string constant, MSL source that the Metal driver compiles at startup:

src/backend/metal/shaders.rs
pub const SHADERS: &str = r#"
#include <metal_stdlib>
using namespace metal;
 
constant uint TS = 16;
 
kernel void matmul_fp32_fp32(
    device const float* a [[buffer(0)]],
    device const float* b [[buffer(1)]],
    device       float* c [[buffer(2)]],
    constant     uint&  m [[buffer(3)]],
    constant     uint&  n [[buffer(4)]],
    constant     uint&  p [[buffer(5)]],
    uint2 gid   [[thread_position_in_grid]],
    uint2 lid   [[thread_position_in_threadgroup]]
) {
    threadgroup float tA[16][16];
    threadgroup float tB[16][16];
 
    uint row = gid.x;
    uint col = gid.y;
    float sum = 0.0f;
 
    uint num_tiles = (n + TS - 1) / TS;
 
    for (uint t = 0; t < num_tiles; t++) {
        uint a_col = t * TS + lid.y;
        uint b_row = t * TS + lid.x;
 
        tA[lid.x][lid.y] = (row < m && a_col < n) ? a[row * n + a_col] : 0.0f;
        tB[lid.x][lid.y] = (b_row < n && col < p) ? b[b_row * p + col] : 0.0f;
 
        threadgroup_barrier(mem_flags::mem_threadgroup);
 
        for (uint j = 0; j < TS; j++) {
            sum += tA[lid.x][j] * tB[j][lid.y];
        }
 
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
 
    if (row < m && col < p) {
        c[row * p + col] = sum;
    }
}
"#;

This computes C = A · B with A being m × n and B being n × p. The kernel keyword marks matmul_fp32_fp32 as a GPU entry point. Its parameters are bound to numbered slots: [[buffer(0)]] through [[buffer(5)]] are how Rust passes the three matrices and the three dimensions in. [[thread_position_in_grid]] and [[thread_position_in_threadgroup]] are filled in by the hardware: every thread sees a different gid, its coordinate in the whole launch grid, and an lid, its coordinate within its own 16 × 16 threadgroup.

The basic idea: one GPU thread per output element. Thread at grid position (row, col) computes C[row][col]. With one thread per element, the GPU can have thousands of output elements in flight at once.

The kernel is a tiled matmul rather than the naive version, and the reason is the same memory story as everywhere else in Act 2. In the naive approach each thread streams a full row of A and a full column of B straight from main memory, and neighboring threads re-read the same data, hammering the slow path. So the kernel works in 16×16 tiles. The 256 threads of a 16×16 threadgroup cooperate: each loads one element of A and one of B into the small, fast threadgroup-shared arrays tA and tB, all in parallel. A threadgroup_barrier makes every thread wait until the tile is fully loaded. Then each thread does 16 multiply-adds, reading only from the shared tiles, and a second barrier stops any thread from overwriting the tiles with the next slice while others are still reading them. Slide to the next tile and repeat. Each element of A and B is now fetched from main memory once per threadgroup that needs it instead of once per thread that needs it, a 16-fold cut in main-memory reads and the classic GPU matmul optimization. The (row < m && ...) guards handle matrices whose dimensions aren't exact multiples of 16 by zero-padding the edge tiles.
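Tile indexing is easy to get wrong, so it can help to run the same arithmetic on the CPU. Below is a sketch that emulates matmul_fp32_fp32 sequentially in plain Rust: one loop nest per threadgroup, with the work between the threadgroup_barrier calls turned into two separate loops over the 16 × 16 "threads". It is only a way to check the tile arithmetic against a naive matmul, not how the GPU schedules work, and naive_matmul and tiled_matmul are names introduced here.

```rust
const TS: usize = 16; // mirrors the kernel's tile size

/// Reference: C[row][col] = sum over k of A[row][k] * B[k][col].
fn naive_matmul(a: &[f32], b: &[f32], m: usize, n: usize, p: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * p];
    for row in 0..m {
        for col in 0..p {
            let mut sum = 0.0f32;
            for k in 0..n {
                sum += a[row * n + k] * b[k * p + col];
            }
            c[row * p + col] = sum;
        }
    }
    c
}

/// Sequential emulation of the tiled kernel, using the same row = gid.x, col = gid.y mapping.
fn tiled_matmul(a: &[f32], b: &[f32], m: usize, n: usize, p: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * p];
    let num_tiles = (n + TS - 1) / TS;
    for group_x in 0..(m + TS - 1) / TS {
        for group_y in 0..(p + TS - 1) / TS {
            // One running `sum` per thread (a register in the real kernel).
            let mut sums = [[0.0f32; TS]; TS];
            for t in 0..num_tiles {
                // Phase 1: each thread (lx, ly) loads one element of A and one of B.
                // Out-of-range elements stay 0.0, which is the zero-padding.
                let mut t_a = [[0.0f32; TS]; TS];
                let mut t_b = [[0.0f32; TS]; TS];
                for lx in 0..TS {
                    for ly in 0..TS {
                        let (row, col) = (group_x * TS + lx, group_y * TS + ly);
                        let (a_col, b_row) = (t * TS + ly, t * TS + lx);
                        if row < m && a_col < n {
                            t_a[lx][ly] = a[row * n + a_col];
                        }
                        if b_row < n && col < p {
                            t_b[lx][ly] = b[b_row * p + col];
                        }
                    }
                }
                // (threadgroup_barrier: every load finishes before anyone computes)
                // Phase 2: each thread does TS multiply-adds from the shared tiles.
                for lx in 0..TS {
                    for ly in 0..TS {
                        for j in 0..TS {
                            sums[lx][ly] += t_a[lx][j] * t_b[j][ly];
                        }
                    }
                }
            }
            // Guarded store: only threads inside the m x p output write.
            for lx in 0..TS {
                for ly in 0..TS {
                    let (row, col) = (group_x * TS + lx, group_y * TS + ly);
                    if row < m && col < p {
                        c[row * p + col] = sums[lx][ly];
                    }
                }
            }
        }
    }
    c
}

fn main() {
    // Deliberately awkward sizes: none is a multiple of 16, so edge tiles are padded.
    let (m, n, p) = (20, 35, 18);
    let a: Vec<f32> = (0..m * n).map(|i| (i % 7) as f32 - 3.0).collect();
    let b: Vec<f32> = (0..n * p).map(|i| (i % 5) as f32 * 0.5).collect();
    let max_diff = naive_matmul(&a, &b, m, n, p)
        .iter()
        .zip(tiled_matmul(&a, &b, m, n, p))
        .map(|(x, y)| (x - y).abs())
        .fold(0.0f32, f32::max);
    println!("max |naive - tiled| = {max_diff}");
}
```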

The Metal context

src/backend/metal/context.rs owns the GPU connection (the device, the command queue, the compiled kernel) and the function that dispatches a matmul. It opens with kernel registration and two helpers:

src/backend/metal/context.rs
use std::collections::HashMap;
use std::ffi::c_void;
 
use mtl_foundation::object::Referencing;
use mtl_gpu::device as mtl_device;
use mtl_gpu::{
    Buffer, CommandQueue, ComputeCommandEncoder, ComputePipelineState, Device, ResourceOptions,
    Size,
};
use mtl_sys::{msg_send_4, sel};
 
use super::shaders::SHADERS;
 
pub(crate) const KERNEL_MATMUL_FP32_FP32: &str = "matmul_fp32_fp32";
 
pub(crate) const KERNELS: &[&str] = &[KERNEL_MATMUL_FP32_FP32];

KERNELS is the list of kernel names to compile: one for now, with II.6 adding a second.

src/backend/metal/context.rs
unsafe fn wrap_shared_storage<T: Copy>(device: &Device, data: &[T]) -> Option<Buffer> {
    unsafe {
        let ptr: *mut c_void = msg_send_4(
            device.as_ptr(),
            sel!(newBufferWithBytesNoCopy: length: options: deallocator:),
            data.as_ptr() as *mut c_void,
            std::mem::size_of_val(data),
            ResourceOptions::STORAGE_MODE_SHARED,
            std::ptr::null_mut::<c_void>(),
        );
        Buffer::from_raw(ptr)
    }
}

This is the unified-memory trick, and it is the one place we drop to raw mtl-sys. newBufferWithBytesNoCopy tells Metal: "make a GPU buffer that points directly at this CPU memory; do not allocate, do not copy." Because Apple Silicon's CPU and GPU share one physical RAM pool (STORAGE_MODE_SHARED), the GPU can read the exact bytes the Rust Vec already holds. On a discrete GPU, getting data across to the card is a real copy over PCIe; here it is a pointer wrap. That is why "step 1" (getting data to the GPU) is nearly free on this hardware, and it's the reason a GPU backend is worth writing for a model this small.
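One practical detail: Apple's documentation for newBufferWithBytesNoCopy asks for a page-aligned base address, while a Vec<f32> only guarantees the 4-byte alignment of f32. A minimal sketch of allocating an f32 buffer on a page boundary with std::alloc follows. PageAlignedF32 and is_page_aligned are hypothetical helpers, not part of the chapter's code; the 16 KiB page size is assumed as a constant rather than queried from the OS; and rounding the allocation up to whole pages is a conservative extra.

```rust
use std::alloc::{alloc_zeroed, dealloc, handle_alloc_error, Layout};

/// Assumed page size (16 KiB on Apple Silicon). Real code should query the OS.
const PAGE_SIZE: usize = 16 * 1024;

/// Hypothetical helper: a zeroed f32 buffer that starts on a page boundary.
struct PageAlignedF32 {
    ptr: *mut f32,
    len: usize,
    layout: Layout,
}

impl PageAlignedF32 {
    fn zeroed(len: usize) -> Self {
        let bytes = (len * std::mem::size_of::<f32>()).max(1);
        // Round the size up to whole pages and align the start to a page.
        let layout = Layout::from_size_align(bytes.div_ceil(PAGE_SIZE) * PAGE_SIZE, PAGE_SIZE)
            .expect("page-sized layout");
        // SAFETY: the layout has a non-zero size.
        let ptr = unsafe { alloc_zeroed(layout) } as *mut f32;
        if ptr.is_null() {
            handle_alloc_error(layout);
        }
        Self { ptr, len, layout }
    }

    fn as_slice(&self) -> &[f32] {
        // SAFETY: ptr is valid for `len` zero-initialised f32s while self lives.
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }

    fn as_mut_slice(&mut self) -> &mut [f32] {
        // SAFETY: as above, and &mut self guarantees exclusive access.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }
}

impl Drop for PageAlignedF32 {
    fn drop(&mut self) {
        // SAFETY: ptr was allocated with exactly this layout.
        unsafe { dealloc(self.ptr as *mut u8, self.layout) }
    }
}

fn is_page_aligned<T>(ptr: *const T) -> bool {
    ptr as usize % PAGE_SIZE == 0
}

fn main() {
    let mut out = PageAlignedF32::zeroed(1000);
    out.as_mut_slice()[0] = 1.5;
    println!(
        "page-aligned: {}, first = {}",
        is_page_aligned(out.as_slice().as_ptr()),
        out.as_slice()[0]
    );
    // A plain Vec<f32> makes no page-alignment promise; this may print either value.
    let v = vec![0.0f32; 1000];
    println!("Vec page-aligned: {}", is_page_aligned(v.as_ptr()));
}
```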

src/backend/metal/context.rs
#[inline]
fn set_u32(enc: &ComputeCommandEncoder, val: u32, index: usize) {
    let raw = val.to_ne_bytes();
    enc.set_bytes(&raw, index);
}

set_u32 passes a scalar (a matrix dimension) into a kernel argument slot by copying its four bytes in with set_bytes. That is how the kernel's constant uint& m, n, and p parameters get their values.

The context struct and its constructor:

src/backend/metal/context.rs
pub struct MetalContext {
    device: Device,
    queue: CommandQueue,
    pipelines: HashMap<&'static str, ComputePipelineState>,
}
 
impl MetalContext {
    pub fn new() -> Self {
        let device = mtl_device::system_default().unwrap();
 
        let library = device.new_library_with_source(SHADERS, None).unwrap();
 
        let pipelines: HashMap<&'static str, _> = KERNELS
            .iter()
            .copied()
            .map(|name| {
                let func = library.new_function_with_name(name).unwrap();
                let pipeline = device
                    .new_compute_pipeline_state_with_function(&func)
                    .unwrap();
                (name, pipeline)
            })
            .collect();
 
        let queue = device.new_command_queue().unwrap();
 
        eprintln!(
            "MetalBackend: device={} unified_memory={}",
            device.name(),
            device.has_unified_memory()
        );
 
        Self {
            device,
            queue,
            pipelines,
        }
    }

new does the one-time GPU setup, so it runs once at startup and never again. system_default() grabs the system's default GPU. new_library_with_source(SHADERS, ...) hands our MSL string to the driver, which compiles it. For each kernel name, new_compute_pipeline_state_with_function builds a compute pipeline, the GPU-ready, launchable form of that kernel, and we stash them in a HashMap keyed by name. new_command_queue() creates the channel for submitting work. The eprintln! reports the GPU's name and whether it has unified memory, which the no-copy buffers depend on.
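Because dispatch looks pipelines up with .get(kernel).unwrap(), a misspelled kernel name only surfaces as a panic at dispatch time. Here is a sketch of the same compile-once, look-up-by-name pattern that reports an error instead. CompiledKernel, compile, and PipelineCache are hypothetical stand-ins for the Metal objects, since the real calls need a GPU.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a compiled ComputePipelineState.
#[derive(Debug, Clone, PartialEq)]
struct CompiledKernel {
    name: &'static str,
}

/// Hypothetical "compiler": fails for names the library does not contain.
fn compile(library: &[&'static str], name: &'static str) -> Result<CompiledKernel, String> {
    if library.contains(&name) {
        Ok(CompiledKernel { name })
    } else {
        Err(format!("kernel {name:?} not found in library"))
    }
}

struct PipelineCache {
    pipelines: HashMap<&'static str, CompiledKernel>,
}

impl PipelineCache {
    /// Compile every kernel once, up front; any failure aborts construction.
    fn new(library: &[&'static str], kernels: &[&'static str]) -> Result<Self, String> {
        let pipelines = kernels
            .iter()
            .copied()
            .map(|name| compile(library, name).map(|k| (name, k)))
            .collect::<Result<HashMap<_, _>, _>>()?;
        Ok(Self { pipelines })
    }

    /// Look a kernel up by name at dispatch time, without panicking.
    fn get(&self, name: &str) -> Result<&CompiledKernel, String> {
        self.pipelines
            .get(name)
            .ok_or_else(|| format!("no pipeline named {name:?}"))
    }
}

fn main() {
    let library = ["matmul_fp32_fp32"];
    let cache = PipelineCache::new(&library, &["matmul_fp32_fp32"]).expect("all kernels compile");
    println!("{:?}", cache.get("matmul_fp32_fp32"));
    println!("{:?}", cache.get("no_such_kernel")); // an Err, not a panic
}
```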

The dispatch function, the generic "run a kernel" routine:

src/backend/metal/context.rs
    fn dispatch(&self, kernel: &str, bufs: &[&Buffer], scalars: &[u32], grid: Size, threads: Size) {
        let cmd = self.queue.command_buffer().unwrap();
        let enc =
            unsafe { ComputeCommandEncoder::from_raw(cmd.compute_command_encoder()) }.unwrap();
        let pipeline = self.pipelines.get(kernel).unwrap();
        enc.set_compute_pipeline_state(pipeline);
        for (i, buf) in bufs.iter().enumerate() {
            enc.set_buffer(buf, 0, i);
        }
        let buf_count = bufs.len();
        for (i, &val) in scalars.iter().enumerate() {
            set_u32(&enc, val, buf_count + i);
        }
        enc.dispatch_threadgroups(grid, threads);
        enc.end_encoding();
        cmd.commit();
        cmd.wait_until_completed();
    }

This is the universal shape of "make the GPU do something":

  • command_buffer(): a command buffer is a batch of GPU instructions you build up and then submit.
  • compute_command_encoder(): the encoder writes compute commands into that buffer.
  • set_compute_pipeline_state(pipeline): select which kernel to run.
  • The two loops bind the arguments: each tensor Buffer goes into slot i (matching [[buffer(0..2)]]), each scalar goes into the slots after them ([[buffer(3..5)]]).
  • dispatch_threadgroups(grid, threads): launch. threads is the threadgroup size; grid is how many threadgroups. Together they define the thread grid the kernel runs over.
  • end_encoding() / commit(): finish and submit to the GPU.
  • wait_until_completed(): block until the GPU is done.
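The two binding loops produce a fixed layout: buffers take slots 0 to bufs.len() - 1 and scalars take the slots after them. A pure-Rust model of that assignment makes the matmul layout concrete. Slot and bind_slots are hypothetical names, and the dimension values are arbitrary.

```rust
#[derive(Debug, PartialEq)]
enum Slot {
    Buffer(&'static str),
    Scalar(&'static str, u32),
}

/// Mirror of dispatch's binding order: buffers first, then scalars after them.
fn bind_slots(bufs: &[&'static str], scalars: &[(&'static str, u32)]) -> Vec<(usize, Slot)> {
    let mut slots = Vec::new();
    for (i, &name) in bufs.iter().enumerate() {
        slots.push((i, Slot::Buffer(name)));
    }
    let buf_count = bufs.len();
    for (i, &(name, val)) in scalars.iter().enumerate() {
        slots.push((buf_count + i, Slot::Scalar(name, val)));
    }
    slots
}

fn main() {
    let slots = bind_slots(&["a", "b", "c"], &[("m", 512), ("n", 1024), ("p", 3072)]);
    for (index, slot) in &slots {
        println!("[[buffer({index})]] <- {slot:?}");
    }
}
```

For matmul this puts a, b, and c in slots 0 to 2 and m, n, and p in slots 3 to 5, exactly the [[buffer(0)]] through [[buffer(5)]] the kernel declares.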

That final wait_until_completed is the synchronous round-trip ("step 4"), and it is the fixed latency the size threshold exists to avoid paying on matmuls too small to make up for it.

wrap_f32 is the small convenience that turns a float slice into a shared-storage buffer:

src/backend/metal/context.rs
    fn wrap_f32(&self, data: &[f32]) -> Buffer {
        unsafe { wrap_shared_storage(&self.device, data) }.unwrap()
    }

And the matmul entry point:

src/backend/metal/context.rs
    pub fn matmul_fp32_fp32(&self, a: &[f32], b: &[f32], m: usize, n: usize, p: usize) -> Vec<f32> {
        const TILE: usize = 16;
 
        let an = m.checked_mul(n).unwrap();
        let np = n.checked_mul(p).unwrap();
        let mp = m.checked_mul(p).unwrap();
        assert_eq!(a.len(), an);
        assert_eq!(b.len(), np);
 
        let c_data = vec![0.0f32; mp];
 
        let a_buf = self.wrap_f32(a);
        let b_buf = self.wrap_f32(b);
        let c_buf = self.wrap_f32(&c_data);
 
        self.dispatch(
            KERNEL_MATMUL_FP32_FP32,
            &[&a_buf, &b_buf, &c_buf],
            &[m as u32, n as u32, p as u32],
            Size::new((m + TILE - 1) / TILE, (p + TILE - 1) / TILE, 1),
            Size::new(TILE, TILE, 1),
        );
 
        c_data
    }
}

It allocates the output c_data, wraps all three matrices as shared buffers (no copy), and dispatches. The grid sizing is the key arithmetic: Size::new(TILE, TILE, 1) is a 16×16 threadgroup (256 threads, matching the kernel's tA[16][16] tiles) and the grid is ((m + 15) / 16, (p + 15) / 16, 1) threadgroups, i.e. enough 16×16 tiles to cover the whole m × p output, rounding up. Then the crucial detail of unified memory: the kernel wrote its results straight into the memory c_data already owns, so once dispatch returns, c_data is the answer, with no read-back step. We just return it.
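The rounding-up is easy to check in isolation. A minimal sketch follows, with threadgroup_grid as a hypothetical helper name. usize::div_ceil computes the same (x + TILE - 1) / TILE as the code above, and the 100 × 1024 output is an arbitrary example shape.

```rust
const TILE: usize = 16;

/// Threadgroups needed to cover an m x p output with TILE x TILE groups.
fn threadgroup_grid(m: usize, p: usize) -> (usize, usize) {
    (m.div_ceil(TILE), p.div_ceil(TILE))
}

fn main() {
    let (m, p) = (100, 1024);
    let (gx, gy) = threadgroup_grid(m, p);
    let launched = gx * TILE * gy * TILE;
    let useful = m * p;
    println!("grid = ({gx}, {gy}, 1), threads launched = {launched}, outputs = {useful}");
    println!("threads with no output to write: {}", launched - useful);
}
```

For a 100-row output the grid is 7 threadgroups tall, so 112 rows of threads run; the last 12 rows load zeros, pass the barriers, and skip the final store because of the row < m && col < p guard.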

The Metal backend

src/backend/metal/backend.rs is the Backend impl. Same delegation pattern as SimdCpu and Parallel: implement matmul, forward the rest:

src/backend/metal/backend.rs
use crate::tensor::{Tensor, TensorData};
 
use crate::backend::Backend;
 
use super::context::MetalContext;
 
const MIN_M_FOR_GPU_MATMUL: usize = 8;
 
pub struct Metal<B: Backend> {
    ctx: MetalContext,
    fallback: B,
}
 
impl<B: Backend> Metal<B> {
    pub fn new(fallback: B) -> Self {
        Self {
            ctx: MetalContext::new(),
            fallback,
        }
    }

Metal<B> holds the MetalContext and a fallback backend. MIN_M_FOR_GPU_MATMUL = 8 is this backend's size threshold, the GPU equivalent of Parallel's MIN_ROWS_FOR_PARALLEL, set higher because the GPU's fixed cost (the command-buffer round-trip, wait_until_completed) is larger than rayon's thread-wakeup cost.

src/backend/metal/backend.rs
    fn matmul_fp32_fp32(
        &self,
        a_data: &[f32],
        b_data: &[f32],
        m: usize,
        n: usize,
        p: usize,
    ) -> Tensor {
        assert_eq!(a_data.len(), m * n, "matmul_fp32_fp32: len(a) must be m*n");
        assert_eq!(b_data.len(), n * p, "matmul_fp32_fp32: len(b) must be n*p");
        let data = self.ctx.matmul_fp32_fp32(a_data, b_data, m, n, p);
        Tensor::new(data, vec![m, p])
    }
}

A thin adapter from the Backend world to the MetalContext world: validate the shapes, dispatch to the GPU, wrap the result back into a Tensor.

The Backend impl, starting with matmul:

src/backend/metal/backend.rs
impl<B: Backend> Backend for Metal<B> {
    fn name(&self) -> String {
        "metal".to_string()
    }
 
    fn matmul(&self, a: &Tensor, b: &Tensor) -> Tensor {
        assert_eq!(a.shape().len(), 2);
        assert_eq!(b.shape().len(), 2);
        let a_shape = a.shape();
        let b_shape = b.shape();
        let m = a_shape[0];
        let n = a_shape[1];
        match (a.as_data(), b.as_data()) {
            (TensorData::Fp32(a_data), TensorData::Fp32(b_data)) => {
                let p = b_shape[1];
                assert_eq!(n, b_shape[0], "tensor shape mismatch");
                if m == 0 || p == 0 {
                    return Tensor::new(vec![], vec![m, p]);
                }
                if m < MIN_M_FOR_GPU_MATMUL {
                    return self.fallback.matmul(a, b);
                }
                self.matmul_fp32_fp32(a_data, b_data, m, n, p)
            }
        }
    }

Besides the shape assertions, two guards run before the GPU is touched. An empty result (m == 0 || p == 0) short-circuits. m < MIN_M_FOR_GPU_MATMUL, fewer than 8 output rows, falls back to the CPU. This is the same prefill/decode logic as II.4: a prefill matmul has one row per prompt token (dozens or hundreds, straight to the GPU), but a decode matmul has a single output row, far below 8, so it runs on the fallback, because the GPU round-trip would dwarf the work of a one-row matmul. Decode, once again, stays on the CPU; the GPU is a prefill accelerator here.
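Stripped of tensors, the routing decision is a pure function of the output shape. Here is a sketch with hypothetical names (Route, route_matmul) and the chapter's threshold of 8, applying the checks in the same order as matmul:

```rust
const MIN_M_FOR_GPU_MATMUL: usize = 8;

#[derive(Debug, PartialEq)]
enum Route {
    Empty,
    CpuFallback,
    Gpu,
}

/// Same decision order as Metal::matmul: empty result first, then the size threshold.
fn route_matmul(m: usize, p: usize) -> Route {
    if m == 0 || p == 0 {
        Route::Empty
    } else if m < MIN_M_FOR_GPU_MATMUL {
        Route::CpuFallback
    } else {
        Route::Gpu
    }
}

fn main() {
    // Hypothetical row counts: one decode step, a short prompt, a 512-token prefill.
    for (label, m) in [("decode", 1), ("short prompt", 5), ("prefill", 512)] {
        println!("{label:>12}: m = {m:>3} -> {:?}", route_matmul(m, 1024));
    }
}
```

One side effect is visible in the middle case: a prompt shorter than 8 tokens never reaches the GPU, even during prefill.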

The rest of the trait delegates to fallback. The first few:

src/backend/metal/backend.rs
    fn sum_squares_axis(&self, x: &Tensor, axis: usize) -> Tensor {
        self.fallback.sum_squares_axis(x, axis)
    }
    fn add(&self, a: &Tensor, b: &Tensor) -> Tensor {
        self.fallback.add(a, b)
    }
    fn hadamard(&self, a: &Tensor, b: &Tensor) -> Tensor {
        self.fallback.hadamard(a, b)
    }
    fn scale(&self, x: &Tensor, s: f32) -> Tensor {
        self.fallback.scale(x, s)
    }
    fn silu(&self, x: &Tensor) -> Tensor {
        self.fallback.silu(x)
    }

…and identically for add_scalar, rsqrt_elem, broadcast_row_scalars, transpose_2d, softmax_rows, gather_rows, reshape_data, fill_strict_upper_tri, copy_2d_from_cols, copy_2d_into_cols, repeat_row_as_matrix, apply_rope, apply_rope_single_row, concat_dim0, copy_row_2d, copy_contiguous_into, and argmax_with_prob. Only the big matmuls go to the GPU; everything else (including, by the threshold, every decode matmul) stays on the CPU fallback.
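Writing twenty-odd forwarding methods by hand is mechanical. One way to cut that boilerplate is a small macro_rules! that expands each listed signature into a call on self.fallback. It is sketched here on a cut-down three-method trait rather than the chapter's real Backend; Ops, Cpu, Gpu, and forward_to_fallback! are all hypothetical names.

```rust
trait Ops {
    fn name(&self) -> String;
    fn add(&self, a: &[f32], b: &[f32]) -> Vec<f32>;
    fn scale(&self, x: &[f32], s: f32) -> Vec<f32>;
}

struct Cpu;

impl Ops for Cpu {
    fn name(&self) -> String {
        "cpu".to_string()
    }
    fn add(&self, a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b).map(|(x, y)| x + y).collect()
    }
    fn scale(&self, x: &[f32], s: f32) -> Vec<f32> {
        x.iter().map(|v| v * s).collect()
    }
}

/// Expands `fn method(&self, arg: Ty, ...) -> Ret;` into a body that calls
/// `self.fallback.method(arg, ...)`.
macro_rules! forward_to_fallback {
    ($(fn $method:ident(&self $(, $arg:ident: $ty:ty)*) -> $ret:ty;)*) => {
        $(
            fn $method(&self $(, $arg: $ty)*) -> $ret {
                self.fallback.$method($($arg),*)
            }
        )*
    };
}

struct Gpu<B: Ops> {
    fallback: B,
}

impl<B: Ops> Ops for Gpu<B> {
    // Methods this backend actually implements are written out by hand...
    fn name(&self) -> String {
        "gpu".to_string()
    }
    // ...and everything else is forwarded in one place.
    forward_to_fallback! {
        fn add(&self, a: &[f32], b: &[f32]) -> Vec<f32>;
        fn scale(&self, x: &[f32], s: f32) -> Vec<f32>;
    }
}

fn main() {
    let gpu = Gpu { fallback: Cpu };
    println!("{}: add -> {:?}", gpu.name(), gpu.add(&[1.0, 2.0], &[3.0, 4.0]));
}
```

The trade-off is readability: the forwarding list still names every method, but the bodies can no longer drift out of sync with the fallback call.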

The module file exports Metal:

src/backend/metal/mod.rs
mod backend;
mod context;
mod shaders;
 
pub use backend::Metal;

Wiring it into the factory

src/backend/factory.rs gets the "metal" arm:

src/backend/factory.rs
use std::sync::Arc;
 
use super::Backend;
use super::{CpuBackend, Metal, Parallel, SimdCpu, TracingBackend};
 
pub fn create_backend(name: &str, enable_tracing: bool) -> Result<Arc<dyn Backend>, String> {
    let name = name.trim();
    match name {
        "scalar" => Ok(wrap_scalar(enable_tracing)),
        "simd" => Ok(wrap_simd(enable_tracing)),
        "parallel" => Ok(wrap_parallel(enable_tracing)),
        "metal" => Ok(wrap_metal(enable_tracing)),
        other => Err(format!(
            "unknown backend {other:?} (supported: scalar, simd, parallel, metal)"
        )),
    }
}

src/backend/factory.rs
fn wrap_metal(enable_tracing: bool) -> Arc<dyn Backend> {
    let metal = Metal::new(SimdCpu::new(CpuBackend));
    if enable_tracing {
        Arc::new(TracingBackend::new(metal))
    } else {
        Arc::new(metal)
    }
}

Metal::new(SimdCpu::new(CpuBackend)) gives the GPU backend a SIMD fallback. So a small (sub-8-row) matmul, which is every decode matmul, runs on the vectorized CPU kernel from II.3; a large one goes to the GPU. We pick SimdCpu rather than Parallel for the fallback because the fallback only ever handles tiny matmuls (single-row decode matmuls), and those are below Parallel's threshold too, so the extra layer would do nothing.

The module exports it:

src/backend/mod.rs
mod backend_trait;
pub(crate) mod cpu;
mod factory;
pub(crate) mod metal;
pub(crate) mod parallel_cpu;
pub(crate) mod simd_cpu;
pub(crate) mod tracing;
 
pub use backend_trait::Backend;
pub use factory::create_backend;
 
pub(crate) use cpu::CpuBackend;
pub(crate) use metal::Metal;
pub(crate) use parallel_cpu::Parallel;
pub(crate) use simd_cpu::SimdCpu;
pub(crate) use tracing::TracingBackend;

And model-generate's usage string adds metal:

src/bin/model-generate.rs
fn usage() -> ! {
    eprintln!(
        "usage: model-generate [--kv [basic]] [--backend scalar|simd|parallel|metal] <gguf_path> [prompt] [max_new_tokens]"
    );
    std::process::exit(2);
}

Running it

The GPU's win is prefill, so use a long prompt and compare parallel against metal:

cargo run --release --bin model-generate -- --kv --backend parallel path/to/qwen3-0.6b.gguf "$(cat long-prompt.txt)" 32
cargo run --release --bin model-generate -- --kv --backend metal    path/to/qwen3-0.6b.gguf "$(cat long-prompt.txt)" 32

The Metal run first prints the device line from MetalContext::new, then the usual metrics:

MetalBackend: device=Apple M-series GPU unified_memory=true
backend: metal
kv cache: basic
 
metrics:
  time_to_first_token_ms: 96.214
  decode_tokens_per_second: 279.881
  per_forward_ms: min 3.511  max 96.214  mean 6.402  (n=32)

Against the parallel baseline on the same ~512-token prompt (~313 ms time to first token from II.4), prefill drops to roughly 96 ms: the thousands of GPU threads chew through the big prefill matmuls far faster than ten CPU cores. Decode throughput is essentially unchanged (~280 tokens/sec), since decode's single-row matmuls are below MIN_M_FOR_GPU_MATMUL, so they fall back to the SIMD CPU kernel and never reach the GPU. The split is exactly what the threshold guarantees: the GPU accelerates prefill; decode stays on the CPU.
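The headline comparison reduces to two divisions, using only the figures quoted above: about 313 ms from II.4, and 96.214 ms and 279.881 tokens/sec from this run. speedup and ms_per_token are helper names introduced for this sketch.

```rust
/// How many times faster `after` is than `before` (both in the same time unit).
fn speedup(before_ms: f64, after_ms: f64) -> f64 {
    before_ms / after_ms
}

/// Average milliseconds per decoded token, given a tokens-per-second rate.
fn ms_per_token(tokens_per_second: f64) -> f64 {
    1000.0 / tokens_per_second
}

fn main() {
    let parallel_ttft_ms = 313.0; // II.4's parallel time to first token (approximate)
    let metal_ttft_ms = 96.214; // this run's time_to_first_token_ms
    let decode_tps = 279.881; // this run's decode_tokens_per_second

    println!("prefill speedup: {:.2}x", speedup(parallel_ttft_ms, metal_ttft_ms));
    println!("decode: {:.2} ms per token", ms_per_token(decode_tps));
}
```

That is roughly a 3.25x faster prefill, while each decode step still takes about 3.57 ms, close to the 3.511 ms minimum in per_forward_ms.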

Where this leaves us

Metal is a fourth Backend, --backend metal, that compiles an MSL matmul kernel, dispatches big matmuls to the GPU through command buffers, and uses Apple Silicon's unified memory to share data with no copy. Prefill, the compute-bound half of inference, now runs on hardware built for exactly this shape of work.

But notice what hasn't moved through the last three chapters: decode throughput. SIMD lifted it, and then parallel and metal left it essentially flat; every one of those backends sends decode's single-row matmuls down the same CPU path. That is not a backend failing to try. It is the act intro's central point asserting itself: decode is memory-bandwidth-bound. The bottleneck is not arithmetic throughput. It is the time spent reading the model's weights out of memory, the same gigabytes loaded for every single token. No amount of faster arithmetic touches that. The only lever is reading fewer bytes. The final chapter of Act 2 pulls it: Q8_0 quantization, which stores the weights in 8 bits instead of 32. That cuts the weight bytes read per token to roughly a quarter, enough to roughly halve decode latency.
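The byte counting can be made concrete. This sketch assumes GGUF's Q8_0 block layout (32 int8 weights plus one f16 scale, so 34 bytes per 32 weights) and uses a round 600 million parameters as a stand-in for Qwen3-0.6B's exact count. The figures are weight bytes, which approximate what each decode step streams if every weight is read once.

```rust
/// Bytes to store `params` weights as f32.
fn fp32_bytes(params: u64) -> u64 {
    params * 4
}

/// Bytes to store `params` weights as Q8_0: blocks of 32 int8 values plus a 2-byte f16 scale.
fn q8_0_bytes(params: u64) -> u64 {
    const BLOCK: u64 = 32;
    const BLOCK_BYTES: u64 = BLOCK + 2;
    params.div_ceil(BLOCK) * BLOCK_BYTES
}

fn main() {
    let params: u64 = 600_000_000; // assumed round figure, not the model's exact count
    let f32_gb = fp32_bytes(params) as f64 / 1e9;
    let q8_gb = q8_0_bytes(params) as f64 / 1e9;
    println!("fp32 weights: {f32_gb:.2} GB, Q8_0 weights: {q8_gb:.2} GB");
    println!("ratio: {:.2}x fewer bytes", f32_gb / q8_gb);
}
```

The per-block scale is why the ratio comes out near 3.8x rather than exactly 4x.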