Activation Checkpointing
Activation checkpointing trades compute for memory: intermediate activations are recomputed during the backward pass instead of being stored from the forward pass.
Why?
A standard forward pass caches every activation (each layer’s output) so backward can use them to compute gradients. For a deep model this activation tape dominates GPU memory: it scales as layers × batch × seq × hidden × bytes per element, times the number of tensors cached per layer. The lecture’s Llama3-405B example (bf16, batch 1, seq 4096) runs to roughly 63GB of activations. Activation checkpointing lets you keep only a few activations, recomputing the rest when needed, trading a larger backward-pass compute budget for smaller peak memory.
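The back-of-envelope arithmetic fits in a few lines. The model dimensions below are Llama3-405B’s published shapes; the per-layer cached-tensor count is an assumption for illustration, not a number from the slides:

```python
# Hedged sketch of activation-tape arithmetic. LAYERS/D_MODEL are
# Llama3-405B's published shapes; tensors_per_layer is an assumed
# count of cached (seq x d_model)-sized activations per block.
LAYERS = 126
D_MODEL = 16_384
SEQ = 4096
BATCH = 1
BYTES_BF16 = 2
tensors_per_layer = 4  # assumption, for illustration only

per_layer = BATCH * SEQ * D_MODEL * BYTES_BF16 * tensors_per_layer
total_gib = LAYERS * per_layer / 2**30
print(f"{total_gib:.0f} GiB of activations")  # 63 GiB under these assumptions
```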
The compute/memory Pareto (CS231n 2025 Lec 11)
For a linear chain of layers, three schemes sit on the tradeoff curve:
| Scheme | Forward compute | Backward compute | Peak act memory |
|---|---|---|---|
| Standard (cache everything) | O(N) | O(N) | O(N) |
| Full recompute (no cache) | O(N) | O(N²) | O(1) |
| K checkpoints | O(N) | O(N) | O(K + N/K) |
| √N checkpoints (optimal) | O(N) | O(N) | O(√N) |
Full recompute keeps only the input. To backprop through layer i you first re-run the forward from the input through layer i, which is O(i) flops; summed over i = 1…N that is N(N+1)/2 → O(N²) backward compute.
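The sum can be sanity-checked in one line (units: one layer-forward = 1 flop-unit):

```python
# Recompute cost of full recompute over an n-layer chain:
# backprop of layer i re-runs forward through layers 1..i -> cost i.
def full_recompute_cost(n: int) -> int:
    return sum(range(1, n + 1))  # = n(n+1)/2, i.e. O(n^2)

print(full_recompute_cost(100))  # 5050
```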
K checkpoints divides the N layers into K chunks of size N/K. Only chunk boundaries are cached. To backprop through any chunk, re-run forward from the chunk’s starting checkpoint (cost N/K) then backprop it (cost N/K). Summed over the K chunks → O(N) backward compute (one extra full forward in total), O(K + N/K) memory (K boundary checkpoints plus one in-flight chunk).
Setting K = √N minimizes the sum K + N/K: O(N) compute and O(√N) memory. This is the classical sublinear-memory result: the curve bottoms out at O(√N).
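A small cost model makes the optimum concrete (units: one layer-forward = 1 compute unit, one cached activation = 1 memory unit):

```python
import math

# Cost model for K checkpoints over a chain of n layers.
def backward_compute(n: float, k: float) -> float:
    # each of the k chunks re-runs its forward (n/k) then backprops it (n/k)
    return k * (n / k + n / k)  # = 2n: O(n), one extra full forward overall

def peak_memory(n: float, k: float) -> float:
    # k boundary checkpoints + one chunk's activations while backpropping it
    return k + n / k

n = 1024
best_k = min(range(1, n + 1), key=lambda k: peak_memory(n, k))
print(best_k, math.isqrt(n), peak_memory(n, best_k))  # 32 32 64.0
```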
In practice
- Checkpoints live at Transformer block granularity — each block’s input is cached, the 6 matmuls + attention matrix inside are recomputed in backward.
- Hurts MFU somewhat because backward now does extra matmuls that don’t count as “useful” FLOPs, but lets you fit a much larger batch or model — usually a net throughput win.
- Essential combined with FSDP at scale: FSDP shards params, grads, and optimizer state across GPUs; checkpointing shrinks the activation tape, which FSDP cannot shard.
- PyTorch exposes `torch.utils.checkpoint.checkpoint` as a wrapper.
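A dependency-free sketch of the mechanism that `torch.utils.checkpoint.checkpoint` automates, on a toy chain of scalar tanh “layers” (all names below are illustrative, not PyTorch API):

```python
import math

# One toy "layer" and its local derivative.
fwd = lambda x: math.tanh(x)
dfwd = lambda x: 1.0 - math.tanh(x) ** 2

def grad_standard(x, n):
    acts = [x]                        # cache every layer input: O(n) memory
    for _ in range(n):
        acts.append(fwd(acts[-1]))
    g = 1.0
    for i in range(n - 1, -1, -1):
        g *= dfwd(acts[i])            # chain rule, using cached inputs
    return g

def grad_checkpointed(x, n, k):
    chunk = n // k                    # assumes k divides n, for simplicity
    ckpts = [x]                       # cache only chunk boundaries: O(k)
    for _ in range(k):
        h = ckpts[-1]
        for _ in range(chunk):
            h = fwd(h)
        ckpts.append(h)
    g = 1.0
    for c in range(k - 1, -1, -1):    # backward, chunk by chunk
        acts = [ckpts[c]]
        for _ in range(chunk):        # recompute this chunk's activations
            acts.append(fwd(acts[-1]))
        for i in range(chunk - 1, -1, -1):
            g *= dfwd(acts[i])
    return g

# Same gradient either way, but the checkpointed version never holds
# more than k boundaries + one chunk of activations at once.
print(grad_standard(0.5, 16), grad_checkpointed(0.5, 16, 4))
```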
Source
CS231n 2025 Lec 11 slides ~71–102 (forward/backward cache diagram, full-recompute analysis, K-checkpoint derivation, K = √N optimum, why it matters next to FSDP). 2026 PDF not published; using 2025 fallback.