Attention (Transformer)

Flash Attention

FlashAttention is a memory-efficient exact-attention algorithm that fuses the whole attention computation into a single tiled CUDA kernel, avoiding ever materializing the full N×N attention matrix (Dao et al., 2022).

What's wrong with standard attention?

The naive implementation computes S = QKᵀ of shape N×N, writes it to HBM, reads it back to take the softmax, writes the result again, then reads it once more to multiply with V. For long contexts this is both memory-quadratic and bandwidth-bound: HBM ↔ SRAM traffic dominates the runtime.
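
To make the data movement concrete, here is a minimal NumPy sketch of the naive path. The function name and the CPU/NumPy setting are illustrative assumptions; the point is that the N×N score matrix exists as a full array, which a GPU implementation would round-trip through HBM.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard attention: materializes the full N x N score matrix.
    On a GPU this intermediate would be written to and re-read from HBM."""
    scale = 1.0 / np.sqrt(Q.shape[-1])
    S = Q @ K.T * scale                       # N x N scores
    S = S - S.max(axis=-1, keepdims=True)     # stabilize the softmax
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)     # row-wise softmax, another full pass over N x N
    return P @ V                              # read the N x N matrix again to multiply with V
```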

The trick: tiling + online softmax

  • Load blocks of Q, K, and V into SRAM one tile at a time
  • Compute the tile’s partial attention in SRAM, never writing the N×N intermediate to HBM
  • Use an online softmax that updates a running row-wise max and normalizer as new tiles arrive, so a separate global pass over the scores isn’t needed
  • Kernel fusion: softmax + value-multiply live in the same kernel as the matmul

Net effect: O(N) memory instead of O(N²), plus 2-4× wall-clock speedup on long sequences.
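
To show how the tiling and online softmax fit together, here is a CPU/NumPy sketch (the function name, the default block size of 64, and tiling only over K/V are illustrative assumptions; the real fused kernel also blocks over rows of Q and keeps the tiles in SRAM):

```python
import numpy as np

def flash_attention_reference(Q, K, V, block=64):
    """Reference sketch of tiled attention with an online softmax.
    Only a N x block score tile exists at any time, never the full N x N matrix."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))          # running (unnormalized) output accumulator
    m = np.full(N, -np.inf)       # running row-wise max of the scores
    l = np.zeros(N)               # running softmax normalizer

    for j in range(0, N, block):  # stream K/V tiles through on-chip memory
        Kj, Vj = K[j:j + block], V[j:j + block]
        S = Q @ Kj.T * scale                     # scores for this tile only
        m_new = np.maximum(m, S.max(axis=-1))    # update the running max
        alpha = np.exp(m - m_new)                # rescale previous accumulators
        P = np.exp(S - m_new[:, None])
        l = alpha * l + P.sum(axis=-1)
        O = alpha[:, None] * O + P @ Vj          # accumulate unnormalized output
        m = m_new

    return O / l[:, None]                        # final normalization
```

Running this and the naive version on the same random Q, K, V gives outputs that agree up to floating-point rounding, which is the sense in which FlashAttention is exact rather than an approximation.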

Variants

  • FlashAttention (2022): core tiling + online softmax
  • FlashAttention-2 (2023): better work partitioning across warps, ~2× over FA-1
  • FlashAttention-3 (2024): Hopper-specific (warp-specialization, FP8), another ~1.5-2×