Mixed Precision
Mixed precision trains (or runs inference) with some tensors in FP16 or BF16 while keeping a master copy of weights in FP32 for numerical stability (Micikevicius et al., 2017).
Why mix, not just go all FP16?
Gradients in FP16 often underflow to zero, and some reductions (softmax, loss, variance) lose too much precision. Keeping the master weights in FP32 and scaling the loss avoids those failure modes while still letting the bulk of the matmuls run in FP16.
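The underflow problem is easy to see directly. A minimal sketch, assuming PyTorch: a gradient-sized value of 1e-8 is representable in FP32 but rounds to zero in FP16, and multiplying by a scale factor pushes it back into FP16's representable range.

```python
import torch

small_grad = 1e-8  # a gradient magnitude that commonly appears late in training

print(torch.tensor(small_grad, dtype=torch.float32))  # ~1e-8, representable
print(torch.tensor(small_grad, dtype=torch.float16))  # 0.0 -- underflows

# Loss scaling shifts the value back into FP16's representable range:
scale = 2.0 ** 16
print(torch.tensor(small_grad * scale, dtype=torch.float16))  # ~6.5e-4, survives
```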
Mechanics
- Weights stored FP32 (master copy)
- Matmul inputs / activations cast to FP16 before the op
- Matmul output accumulation in FP32, written back as FP16
- Loss scaling: multiply the loss by a large constant before backward, then unscale the gradients before optimizer.step(); this prevents small gradients from underflowing
- Ops kept in FP32: softmax, layer norm, loss computation (a hand-rolled sketch of these steps follows this list)
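The loop below is a minimal sketch of those mechanics, not how AMP is actually implemented: FP32 master weights, an autocast region so the matmul-heavy forward pass runs in FP16, a static loss scale, and manual unscaling before optimizer.step(). It assumes PyTorch and a CUDA device; the model, data, and scale value are placeholders.

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()            # master weights stay FP32
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_scale = 2.0 ** 16                                 # static scale, for illustration

x = torch.randn(32, 1024, device="cuda")
target = torch.randn(32, 1024, device="cuda")

# Cast inputs/activations to FP16 for the matmul-heavy forward pass.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = model(x)                                     # matmuls in FP16, FP32 accumulation
    loss = torch.nn.functional.mse_loss(out, target)   # loss kept in FP32 by autocast

(loss * loss_scale).backward()                         # scale the loss so small grads survive

# Unscale the FP32 gradients before the optimizer touches the master weights.
for p in model.parameters():
    if p.grad is not None:
        p.grad.div_(loss_scale)

opt.step()
opt.zero_grad()
```

In practice the scale is adjusted dynamically (raised when gradients stay finite, dropped when they overflow), which is exactly what AMP's GradScaler automates.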
Why it’s so fast
- FP16 matmul runs on Tensor Cores at 2-4× the FP32 throughput
- Half the memory bandwidth per weight access
- Half the activation memory, enabling bigger batches
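The memory arithmetic is easy to check directly (assuming PyTorch): each FP16 element occupies 2 bytes instead of 4, so the same activation tensor takes half the space.

```python
import torch

act_fp32 = torch.randn(4, 1024, 1024)   # stand-in for a batch of activations
act_fp16 = act_fp32.half()

for t in (act_fp32, act_fp16):
    mib = t.numel() * t.element_size() / 2**20
    print(f"{t.dtype}: {t.element_size()} bytes/element, {mib:.0f} MiB total")
```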
BF16 vs FP16
- FP16: 5-bit exponent, 10-bit mantissa, needs loss scaling, max value ~65504
- BF16: 8-bit exponent (same range as FP32), 7-bit mantissa, no loss scaling needed, but less precision per value. Default on Ampere+ for most training
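torch.finfo makes the trade-off concrete (assuming PyTorch): BF16 has roughly FP32's dynamic range, while FP16 has a much smaller max value but finer precision (smaller eps).

```python
import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    print(f"{str(dtype):16s} max={info.max:.3e}  smallest normal={info.tiny:.3e}  eps={info.eps:.3e}")
```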
Automatic Mixed Precision (AMP)
PyTorch’s torch.cuda.amp (or NVIDIA’s Apex library) handles the casting and loss scaling for you. See NVIDIA’s guide.
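A typical torch.cuda.amp training step looks like the sketch below; the model, optimizer, and data are placeholders, and a CUDA device is assumed.

```python
import torch

model = torch.nn.Linear(1024, 10).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()                   # handles dynamic loss scaling

x = torch.randn(64, 1024, device="cuda")
y = torch.randint(0, 10, (64,), device="cuda")

with torch.cuda.amp.autocast():                        # casts eligible ops to FP16
    logits = model(x)
    loss = torch.nn.functional.cross_entropy(logits, y)

scaler.scale(loss).backward()                          # backward on the scaled loss
scaler.step(opt)                                       # unscales grads, skips step on inf/nan
scaler.update()                                        # adjusts the scale factor
opt.zero_grad()
```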