Gradient Checkpointing

Gradient checkpointing trades compute for memory by recomputing activations on the backward pass instead of storing them from the forward pass.

Why is this needed?

The backward pass needs each layer’s input activations to compute its weight gradients. Storing all of them is what makes activation memory blow up as batch size, sequence length, and depth grow, often exceeding the memory taken by the weights themselves.

Mechanics

  • Mark certain layers as checkpoints and keep only their activations
  • Discard the activations of everything between checkpoints
  • On backward, re-run the forward pass from the nearest checkpoint to reconstruct the discarded activations, then compute the gradients
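The mechanics above can be sketched in plain Python, with no autograd: store only every k-th activation, then rebuild any discarded one from the nearest checkpoint. The names `forward_checkpointed` and `recompute` are illustrative, not from any library.

```python
def forward_checkpointed(layers, x, every=2):
    """Forward pass that stores only every `every`-th activation."""
    saved = {0: x}            # checkpoint index -> input to that layer
    for i, layer in enumerate(layers):
        x = layer(x)
        if (i + 1) % every == 0:
            saved[i + 1] = x  # keep this one; the rest are discarded
    return x, saved

def recompute(layers, saved, i):
    """Rebuild the discarded input to layer `i` from the nearest
    checkpoint at or before it (the backward-pass recomputation step)."""
    start = max(k for k in saved if k <= i)
    x = saved[start]
    for j in range(start, i):
        x = layers[j](x)      # re-run the forward between checkpoints
    return x

# Toy "layers": layer k maps v -> 2v + k
layers = [lambda v, k=k: v * 2 + k for k in range(6)]

# Reference: a full forward that stores every activation
acts = [0.0]
for layer in layers:
    acts.append(layer(acts[-1]))

out, saved = forward_checkpointed(layers, 0.0, every=2)
assert out == acts[-1]        # same output as the full forward
assert len(saved) == 4        # only 4 of the 7 activations were kept...
# ...but recomputation reconstructs any discarded one exactly:
assert all(recompute(layers, saved, i) == acts[i] for i in range(len(layers)))
```

The `every` parameter is the knob: larger spacing means fewer stored checkpoints but more recomputation per segment.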

Net effect: a simple every-other-layer scheme roughly halves activation memory, and sparser checkpoints (every √n layers) cut it further, to O(√n) instead of O(n). The cost is roughly one extra forward pass over the non-checkpointed segments, typically ~20–30% more step time.
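A back-of-envelope sketch of that tradeoff, under the assumption that peak storage is the kept checkpoints plus the one segment being rebuilt (the layer count and helper here are hypothetical):

```python
import math

def peak_activations(n, s):
    """Peak stored activations for n layers with a checkpoint every s
    layers: the kept checkpoints plus one segment rebuilt in flight."""
    return math.ceil(n / s) + s

n = 64                          # hypothetical layer count
baseline = n                    # store everything: one activation per layer
s = round(math.sqrt(n))         # classic sqrt(n) checkpoint spacing
print(baseline, peak_activations(n, s))   # 64 vs 16: a 4x reduction
```

The √n spacing is what minimizes this sum, which is why it shows up as the textbook choice.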

When to reach for it

  • You’re OOMing on a model that almost fits
  • You’d rather spend the time than buy more VRAM
  • Common pairing: gradient accumulation, to squeeze a larger effective batch out of a fixed GPU
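In PyTorch this is typically switched on by wrapping a segment’s forward in `torch.utils.checkpoint.checkpoint`, so everything inside the segment is recomputed on backward. The tiny model below is an illustrative sketch, not from the text:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)

class Model(nn.Module):
    def __init__(self, dim=32, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))

    def forward(self, x):
        for block in self.blocks:
            # Activations inside `block` are discarded after the forward
            # and recomputed during backward; only `x` itself is kept.
            x = checkpoint(block, x, use_reentrant=False)
        return x

model = Model()
x = torch.randn(8, 32, requires_grad=True)
model(x).sum().backward()   # gradients match the non-checkpointed model
```

Hugging Face Transformers exposes the same mechanism as a one-liner, `model.gradient_checkpointing_enable()`, which is the usual route for models like bert-large-uncased.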

Not a silver bullet

Even with checkpointing plus batch=1, bert-large-uncased still doesn’t fit on ecetesla0 (7.43 GiB). Checkpointing buys you a factor, not an order of magnitude.