Fully Sharded Data Parallel (FSDP)

https://engineering.fb.com/2021/07/15/open-source/fsdp/

There's also a paper, "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel", worth reading if you really want to understand what's going on.

FSDP = DP + ZeRO sharding (CS231n 2025 Lec 11)

Plain DP keeps a full copy of every weight on every GPU. That stops fitting once the model is large: a 100B-param model with Adam needs 100B params × 4 states × 2 bytes = 800GB per GPU (weights + grads + the two Adam moments, all in bf16), and an H100 has 80GB. FSDP (Rajbhandari et al., ZeRO, arXiv 2019) solves this by sharding each weight across the DP group: GPU i is the sole owner of weight slice W_i and its corresponding grad / optimizer-state slices. Split 800GB over 80 GPUs → 10GB/GPU.
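The memory arithmetic above can be sketched as a one-liner (illustrative only; the 4-states-in-bf16 accounting follows the example, not any particular framework's bookkeeping):

```python
def per_gpu_bytes(n_params, n_gpus, states=4, bytes_per=2):
    """weights + grads + two Adam moments = 4 states, bf16 = 2 bytes each,
    evenly sharded across the DP group."""
    return n_params * states * bytes_per / n_gpus

unsharded = per_gpu_bytes(100e9, 1)   # 800 GB on every GPU without sharding
sharded = per_gpu_bytes(100e9, 80)    # 10 GB per GPU with FSDP over 80 GPUs
```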

Forward (one layer at a time):

  1. All GPUs all_gather the shards of W_l → the full W_l is materialized on every GPU.
  2. Compute as in plain DP.
  3. Free the non-owned shards of W_l (each GPU keeps only its own slice).
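A minimal single-process sketch of this forward shard lifecycle, with plain Python lists standing in for tensors and for the all_gather collective (the function names mirror the collectives but are local stand-ins, not torch.distributed calls):

```python
def all_gather(shards):
    """Concatenate every rank's shard into the full weight, as seen on every rank."""
    return [x for shard in shards for x in shard]

shards = [[1.0, 2.0], [3.0, 4.0]]   # rank i owns shard i of this layer's weight
rank = 0

full_w = all_gather(shards)          # step 1: full weight materialized
act = [2 * w for w in full_w]        # step 2: compute (toy stand-in for the layer)
full_w = shards[rank]                # step 3: drop non-owned shards, keep only ours
```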

Backward (one layer at a time, reverse order):

  1. all_gather the shards of W_l again.
  2. Compute local gradient from the micro-batch.
  3. reduce_scatter gradients — each GPU ends up with the averaged grad for its own shard only.
  4. Apply Adam locally to the owned shard.
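A companion sketch for the backward steps: each rank computes a full-size local grad, reduce_scatter sums grads across ranks and leaves each rank only its own shard's portion, and the update is applied locally (a plain SGD step stands in for Adam here; again, list-based stand-ins rather than real collectives):

```python
def reduce_scatter(per_rank_grads, rank, shard_size):
    """Sum full-size grads across ranks, return only this rank's shard of the sum."""
    n = len(per_rank_grads[0])
    summed = [sum(g[i] for g in per_rank_grads) for i in range(n)]
    lo = rank * shard_size
    return summed[lo:lo + shard_size]

world, shard_size = 2, 2
owned = [1.0, 2.0]                        # rank 0's weight shard
local_grads = [[0.5] * 4, [1.5] * 4]      # step 2: full-size grads, one per rank
g = reduce_scatter(local_grads, rank=0, shard_size=shard_size)   # step 3
avg = [gi / world for gi in g]            # average over the DP group
owned = [w - 0.1 * gi for w, gi in zip(owned, avg)]   # step 4: local update
```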

Optimization: on the last layer of the forward pass, don't drop W_L, since backward needs it next. Saves one all_gather.

HSDP (Hybrid Sharded Data Parallel): when you have N GPUs, split them into groups of size G. FSDP within a group, plain DP across groups. Intra-group comm happens every layer (3 collectives × L layers); inter-group comm happens once per step (1 all_reduce at the end). So the heavy traffic is intra-group and the light traffic inter-group: route the heavy traffic onto fast NVLink inside a node, the light traffic over slower inter-node networking.
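The collective counts per optimizer step can be tallied directly from the text (forward all_gather + backward all_gather + reduce_scatter per layer inside a group, one all_reduce across groups); the 126-layer figure is just the Llama3-405B depth from the example below:

```python
def hsdp_collectives_per_step(n_layers):
    """Collective counts under HSDP: 3 intra-group collectives per layer,
    1 inter-group all_reduce per optimizer step."""
    intra = 3 * n_layers   # heavy traffic, routed over fast intra-node NVLink
    inter = 1              # light traffic, over slower inter-node networking
    return intra, inter

intra, inter = hsdp_collectives_per_step(126)   # 378 intra-group, 1 inter-group
```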

Memory arithmetic example

Llama3-405B: hidden dim 16384, 126 layers, bf16 activations, batch 1, seq 4096. Activations scale with batch × seq, which is why activation checkpointing matters even with FSDP.
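A hedged worked version of that arithmetic, counting only the residual-stream activation per layer (FFN intermediates and attention buffers would add substantially more, so this is a lower bound):

```python
def residual_act_bytes(batch, seq, hidden, n_layers, bytes_per=2):
    """Residual-stream activation memory in bf16 (2 bytes per element)."""
    return batch * seq * hidden * n_layers * bytes_per

per_layer = residual_act_bytes(1, 4096, 16384, 1)     # 128 MiB for one layer
all_layers = residual_act_bytes(1, 4096, 16384, 126)  # ~15.75 GiB across 126 layers
# doubling batch or seq doubles this — hence activation checkpointing
```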

Source

CS231n 2025 Lec 11 slides ~45–70 (plain DP, FSDP algorithm with all_gather/reduce_scatter, last-layer optimization, HSDP grouping, memory arithmetic). 2026 PDF not published — using 2025 fallback.