FP8 Training

Going to spend a bit more time understanding how FP8 training works.

Libraries

  • TorchAO
  • NVIDIA Transformer Engine

1. Keep the “real” training state in higher precision

People usually do not make the whole model FP8.

Instead they keep:

  • parameters in bf16 or fp32
  • optimizer states usually in fp32
  • accumulations/outputs in bf16/fp32

and use FP8 mainly for the expensive GEMMs / linear layers. NVIDIA’s Transformer Engine and PyTorch’s float8 stack both frame FP8 training this way: low-precision compute with scaling metadata around supported ops, not “everything is FP8 now.”
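The "scaling metadata" part can be sketched in plain Python. This is a simplified illustration of per-tensor dynamic scaling (not any specific library's implementation): FP8 E4M3 only represents magnitudes up to 448, so each GEMM input is rescaled into that range and the scale factor is carried alongside the tensor so results can be converted back to real units.

```python
# Hypothetical sketch of per-tensor FP8 scaling. Real FP8 also has
# ~3 mantissa bits of rounding; here we only model the dynamic-range
# clamp, which is what the scale metadata exists to manage.

E4M3_MAX = 448.0  # largest finite value in FP8 E4M3


def fp8_scale(tensor):
    """Per-tensor scale so the tensor's amax maps onto E4M3_MAX."""
    amax = max(abs(x) for x in tensor)
    return E4M3_MAX / amax if amax > 0 else 1.0


def quantize_dequantize(tensor):
    """Round-trip a tensor through the (simulated) FP8 range.

    In a real FP8 linear layer, the GEMM runs on the scaled values,
    and the output is multiplied by 1/scale per input to undo it.
    """
    scale = fp8_scale(tensor)
    scaled = [max(-E4M3_MAX, min(E4M3_MAX, x * scale)) for x in tensor]
    return [x / scale for x in scaled]


activations = [0.003, -1.7, 42.0, 0.0001]  # made-up example values
roundtrip = quantize_dequantize(activations)
```

Because the master weights and optimizer states stay in bf16/fp32, this lossy round-trip only happens at the compute boundary of supported ops, not to the persistent training state.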