Fully Sharded Data Parallel (FSDP)
https://engineering.fb.com/2021/07/15/open-source/fsdp/
They actually have a paper, *PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel*, that you should read if you really want to understand what's going on.
FSDP = DP + ZeRO sharding (CS231n 2025 Lec 11)
Plain DP keeps a full copy of every weight on every GPU. That stops fitting once the model is large: a 100B-param model trained with Adam needs 100B params × 4 states × 2 bytes = 800GB per GPU (weights + grads + two Adam moments, all in bf16), and an H100 has 80GB. FSDP (Rajbhandari et al., ZeRO, arXiv 2019) solves this by sharding each weight across the DP group: GPU i is the sole owner of weight shard i and its corresponding grad / optimizer-state shards. Split 800GB over 80 GPUs → 10GB/GPU.
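A quick sanity check of that arithmetic; the 100B-param model and 80-GPU group are the illustrative numbers from the paragraph above, not any particular system:

```python
# Back-of-the-envelope memory math for DP vs FSDP sharding.
PARAMS = 100e9          # model parameters
BYTES_PER_STATE = 2     # bf16
STATES = 4              # weights + grads + Adam first/second moments
GPUS = 80               # size of the DP / FSDP group

total_gb = PARAMS * STATES * BYTES_PER_STATE / 1e9
print(f"unsharded per-GPU footprint:   {total_gb:.0f} GB")        # 800 GB
print(f"FSDP-sharded per-GPU footprint: {total_gb / GPUS:.0f} GB")  # 10 GB
```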
Forward (one layer at a time):
- All GPUs all_gather the shards of layer ℓ's weights → the full weight is materialized on every GPU.
- Compute as in plain DP.
- Free the non-owned shards (each GPU only keeps its own slice).
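A minimal sketch of that per-layer forward loop. It assumes a process group is already initialized (e.g. launched with torchrun) and that each rank holds `my_shard[l]`, a 1/world_size slice of layer l's flattened weights; `load_full_weight` / `drop_full_weight` are hypothetical helpers, not the torch.distributed.fsdp API.

```python
import torch
import torch.distributed as dist

def fsdp_forward(layers, my_shard, x, world_size):
    for l, layer in enumerate(layers):
        # 1) all_gather the shards -> every rank materializes the full weight
        full_flat = torch.empty(world_size * my_shard[l].numel(),
                                dtype=my_shard[l].dtype, device=my_shard[l].device)
        dist.all_gather_into_tensor(full_flat, my_shard[l])
        layer.load_full_weight(full_flat)   # hypothetical helper: unflatten into the module

        # 2) compute exactly as in plain DP
        x = layer(x)

        # 3) free the non-owned part; keep only this rank's shard
        #    (skip the free on the last layer -- backward needs it first, see below)
        if l != len(layers) - 1:
            layer.drop_full_weight()        # hypothetical helper
    return x
```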
Backward (one layer at a time, reverse order):
- all_gather the shards of layer ℓ's weights again.
- Compute the local gradient from the micro-batch.
- reduce_scatter the gradients — each GPU ends up with the averaged grad for its own shard only.
- Apply Adam locally to the owned shard.
Optimization — on the last layer of the forward pass, don't drop the full weights, since backward needs them next. Saves one all_gather (see the sketch below).
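A matching sketch for the backward pass, under the same assumptions and hypothetical helpers (`load_full_weight`, `drop_full_weight`, `backward_step`) as the forward sketch above; it also shows where the saved last-layer all_gather comes from.

```python
import torch
import torch.distributed as dist

def fsdp_backward(layers, my_shard, my_grad_shard, loss, world_size, shard_optimizers):
    for l in reversed(range(len(layers))):
        layer = layers[l]

        # all_gather the weight shards again -- except for the last layer,
        # whose full weights were kept at the end of forward (saves one all_gather)
        if l != len(layers) - 1:
            full_flat = torch.empty(world_size * my_shard[l].numel(),
                                    dtype=my_shard[l].dtype, device=my_shard[l].device)
            dist.all_gather_into_tensor(full_flat, my_shard[l])
            layer.load_full_weight(full_flat)

        # local gradient from this rank's micro-batch (hypothetical helper);
        # full-size flattened gradient, numel == world_size * shard numel
        local_grad = layer.backward_step(loss)

        # reduce_scatter: average across ranks, each rank keeps only its shard's grad
        dist.reduce_scatter_tensor(my_grad_shard[l], local_grad, op=dist.ReduceOp.AVG)

        # Adam step on the owned shard only, then drop the full weights again
        shard_optimizers[l].step()
        layer.drop_full_weight()
```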
HSDP (Hybrid Sharded Data Parallel) — when you have many GPUs, split them into groups (e.g. one group per node): FSDP within a group, plain DP across groups. Intra-group comm happens every layer (3 collectives × L layers); inter-group comm happens once per step (1 all_reduce at the end). So the heavy ~3× collective traffic stays intra-group and only 1× goes inter-group — route the heavy traffic onto fast NVLink inside a node, light traffic over slower inter-node networking.
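A rough count of collectives per training step under that HSDP layout; the layer count is an illustrative assumption borrowed from the 126-layer example below.

```python
# Per-step collective counts for HSDP: FSDP inside a group, plain DP across groups.
L = 126  # transformer layers (illustrative)

# FSDP inside a group: all_gather (fwd) + all_gather (bwd) + reduce_scatter (bwd) per layer
# (ignoring the one all_gather saved by the last-layer optimization above)
intra_group_collectives = 3 * L
# Plain DP across groups: a single gradient all_reduce at the end of the step
inter_group_collectives = 1

print(f"intra-group (NVLink) collectives per step:   {intra_group_collectives}")
print(f"inter-group (network) collectives per step: {inter_group_collectives}")
```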
Memory arithmetic example
Llama3-405B FFN with hidden dim 16384, 126 layers, bf16 activations, batch 1, seq 4096. Activations scale with batch × seq — this is why activation checkpointing matters even with FSDP.
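A rough version of that arithmetic, under a simplifying assumption of mine (not from the slide): count only one batch × seq × hidden bf16 tensor saved per layer. Real FFN/attention blocks save several times that, which only strengthens the point.

```python
# Activation-memory estimate for the Llama3-405B-style example above.
batch, seq, hidden, layers, bytes_bf16 = 1, 4096, 16384, 126, 2

per_layer_gb = batch * seq * hidden * bytes_bf16 / 1e9
total_gb = per_layer_gb * layers
print(f"~{per_layer_gb:.2f} GB per layer, ~{total_gb:.0f} GB across {layers} layers")
# ~0.13 GB per layer, ~17 GB total at batch 1 -- and this grows linearly with
# batch x seq, hence activation checkpointing even with FSDP.
```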
Source
CS231n 2025 Lec 11 slides ~45–70 (plain DP, FSDP algorithm with all_gather/reduce_scatter, last-layer optimization, HSDP grouping, memory arithmetic). 2026 PDF not published — using 2025 fallback.