FP8 Training
Going to spend a bit more time to understand how FP8 training actually works.
Libraries
- TorchAO
- NVIDIA Transformer Engine
1. Keep the “real” training state in higher precision
People usually do not make the whole model FP8.
Instead they keep:
- parameters in bf16 or fp32
- optimizer states in fp32
- accumulations/outputs in bf16/fp32
and use FP8 mainly for the expensive GEMMs / linear layers. NVIDIA’s Transformer Engine and PyTorch’s float8 stack both frame FP8 training this way: low-precision compute with scaling metadata around supported ops, not “everything is FP8 now.”
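A minimal sketch of the "scale + low-precision value" pattern (pure Python, no real FP8 hardware or library APIs). The function names and the rounding shortcut are mine for illustration: e4m3 is approximated by rounding to a 3-bit mantissa, ignoring the exponent range and subnormals that real kernels handle. The point is that the master weight stays in fp32, and only a scaled, rounded compute copy plus its scale factor is what an FP8 GEMM would see.

```python
import math

E4M3_MAX = 448.0  # largest finite value representable in FP8 e4m3

def round_to_e4m3_grid(x: float) -> float:
    """Round x to the nearest value with a 3-bit mantissa.
    A rough stand-in for an FP8 e4m3 cast (ignores the exponent
    range and subnormals, which real kernels also handle)."""
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)        # x = m * 2**e, with 0.5 <= |m| < 1
    m_q = round(m * 16) / 16    # keep 1 implicit + 3 explicit mantissa bits
    return math.ldexp(m_q, e)

def fp8_quantize(x: float, amax: float):
    """Scale x into the e4m3 range using the running amax, then round.
    Returns (quantized value, scale) -- the scale is the metadata that
    travels alongside the low-precision tensor."""
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    q = round_to_e4m3_grid(max(-E4M3_MAX, min(E4M3_MAX, x * scale)))
    return q, scale

# Master weight stays in high precision; FP8 is only the compute copy.
w_fp32 = 0.0123
q, s = fp8_quantize(w_fp32, abs(w_fp32))
w_compute = q / s               # what the FP8 GEMM effectively sees
```

With 3 mantissa bits the relative rounding error is bounded by 2**-4, so `w_compute` lands within about 6% of the master weight; the optimizer still updates `w_fp32` at full precision, so these rounding errors do not accumulate in the stored state.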