Attention (Transformer)
Terms:
- Self-Attention
- Masked Attention
- Sparse Attention
- Flash Attention
- Paged Attention (KV-cache memory management for LLM serving, e.g. vLLM)
- Multi-Head Latent Attention
- Multi-Head Attention
- Multi-Query Attention (modern)
- Grouped Query Attention
Confusion:
Intuition
Soft dictionary lookup. Each query asks “who’s relevant to me?”, each key advertises “here’s what I’m about,” and each value carries the actual content. The softmax over the query-key scores picks a weighted blend of values, so “it” in a sentence can literally point at “the cat” three tokens back and pull its information forward.
In matrix form
Attention mechanism
- Query, roughly speaking: what am I looking for?
- Key: what I represent
- Value: what I actually contain
- Query (Q): “What am I looking for?”
- Key (K): “What do I have?”
- Value (V): “What do I give you if you choose me?”

Another vid
I was really confused about QKV. https://www.reddit.com/r/MachineLearning/comments/19ewfm9/d_attention_mystery_which_is_which_q_k_or_v/ https://stats.stackexchange.com/questions/421935/what-exactly-are-keys-queries-and-values-in-attention-mechanisms
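Conceptually, there are no physical properties that distinguish Q from K: both are linear projections of the input, and the dot product $q \cdot k$ is symmetric. For all we know, Q is K and K is Q. (The one structural asymmetry is that the softmax normalizes over keys, separately for each query.)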
Think of a library:
- Query = what you want to read about
- Keys = the summary on the card catalog for each book
- Values = the full content of each book
Question
In attention, are W_Q, W_K and W_V really necessary? Why not just use the raw embedding as Q, K and V? The embedding is going to be learned anyway.
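One concrete consequence of skipping the projections, as a minimal NumPy sketch with random toy data (sizes, seed and the random `W_Q`/`W_K` are arbitrary assumptions): with Q = K = X the score matrix $XX^\top$ is forced to be symmetric, so token $i$ scores token $j$ exactly as $j$ scores $i$. Separate $W_Q$ and $W_K$ lift that constraint, and a separate $W_V$ likewise decouples what is matched on from what is passed through.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 5, 8
X = rng.standard_normal((N, D))  # toy "raw embeddings", one row per token

# No projections: Q = K = X, so S[i, j] = x_i . x_j
S_raw = X @ X.T

# Hypothetical random projections with W_Q != W_K
W_Q = rng.standard_normal((D, D))
W_K = rng.standard_normal((D, D))
S_proj = (X @ W_Q) @ (X @ W_K).T

print(np.allclose(S_raw, S_raw.T))    # symmetric by construction
print(np.allclose(S_proj, S_proj.T))  # generally not symmetric
```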
- Inputs: Query, Key, Value
- Output: a matrix with one row per query, each row a weighted blend of the value vectors
The dot product of the query and key tells you how well the key and query are aligned. Large positive dot product = “this key matches what the query is asking about”, so softmax puts most of its mass there and the output is pulled toward that value.
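Then apply a row-wise softmax and weight the values: $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$, where row-wise means each query's weights over all keys sum to 1.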
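A minimal NumPy sketch of scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V, with arbitrary toy shapes; `softmax` and `attention` are illustrative helper names, not a library API. It also returns the weight matrix so you can check that each query's row sums to 1.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the row max for numerical stability; the result is mathematically unchanged
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n_q, n_k): how well each query matches each key
    weights = softmax(scores, axis=-1)   # row-wise: each query's weights sum to 1
    return weights @ V, weights          # (n_q, d_v): weighted blend of values

rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 4))
K = rng.standard_normal((6, 4))
V = rng.standard_normal((6, 5))
out, A = attention(Q, K, V)
print(out.shape)        # (3, 5)
print(A.sum(axis=-1))   # each row sums to 1
```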
What is $\sqrt{d_k}$? I think $d_k$ is the number of dimensions (the width of each query/key vector).
- It’s just a scaling factor
- I asked the professor, and he said that empirically it gives the best performance
Why $\sqrt{d_k}$?
Without it, dot products grow with dimension ($q \cdot k$ has variance proportional to $d_k$ when the entries of $q$ and $k$ are independent with mean 0 and variance 1), softmax saturates into a near one-hot, and gradients vanish. Dividing by $\sqrt{d_k}$ keeps logit variance around 1, so softmax stays in the informative regime where it can be nudged during training.
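A quick Monte Carlo check of the variance argument, under the stated assumption that query and key entries are i.i.d. with mean 0 and variance 1 (real activations need not be); the sample count and the $d_k$ values are arbitrary. The raw dot-product variance tracks $d_k$, while the scaled version stays near 1.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 5_000
results = {}

for d_k in (4, 64, 512):
    # Assumption: entries i.i.d. with mean 0 and variance 1
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = (q * k).sum(axis=1)          # n_samples raw dot products
    scaled = dots / np.sqrt(d_k)        # the same dot products divided by sqrt(d_k)
    results[d_k] = (dots.var(), scaled.var())
    print(f"d_k={d_k:4d}  var(q.k)={dots.var():7.1f}  var(q.k/sqrt(d_k))={scaled.var():.2f}")
```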
Attention Layer
This is cross-attention

Building up Attention from scratch (CS231n 2025 Lec 8)
Bahdanau attention (the original motivation)
Seq2seq (Sutskever 2014) bottlenecked the full input through a single fixed-size context vector $c$. Long inputs (e.g. $T = 1000$) can’t be compressed into one vector. Bahdanau et al. (ICLR 2015) let the decoder look back at all encoder states at each step.
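For each decoder timestep $t$:
- Alignment scores: $e_{t,i} = f_{att}(s_{t-1}, h_i)$, where $f_{att}$ is a small learned MLP
- Attention weights: $a_{t,:} = \text{softmax}(e_{t,:})$, so $0 < a_{t,i} < 1$ and $\sum_i a_{t,i} = 1$
- Context vector: $c_t = \sum_i a_{t,i} h_i$
- Decoder step: $s_t = g_U(y_{t-1}, s_{t-1}, c_t)$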
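One Bahdanau decoder step as a NumPy sketch. Assumptions, none of them from the notes: toy sizes, random weights, the alignment MLP `f_att` in the additive form $v^\top \tanh(W[s_{t-1}; h_i])$, and the decoder cell `g_U` as a plain tanh RNN; `y_prev` is a hypothetical stand-in for the embedding of the previously emitted word.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D_h, D_s, D_a, D_y = 4, 6, 6, 8, 5   # toy sizes (assumptions)

h = rng.standard_normal((T, D_h))       # encoder states h_1..h_T
s_prev = rng.standard_normal(D_s)       # decoder state s_{t-1}
y_prev = rng.standard_normal(D_y)       # embedding of previous output word (hypothetical)

# f_att: small MLP in additive form v^T tanh(W [s; h])
W_att = 0.1 * rng.standard_normal((D_a, D_s + D_h))
v_att = rng.standard_normal(D_a)

def f_att(s, h_i):
    return v_att @ np.tanh(W_att @ np.concatenate([s, h_i]))

e = np.array([f_att(s_prev, h_i) for h_i in h])   # alignment scores e_{t,i}, shape (T,)
a = np.exp(e - e.max())
a /= a.sum()                                      # attention weights a_{t,i}, sum to 1
c = a @ h                                         # context c_t = sum_i a_{t,i} h_i

# g_U: vanilla RNN cell on [y_{t-1}; s_{t-1}; c_t] (an assumption)
U = 0.1 * rng.standard_normal((D_s, D_y + D_s + D_h))
s_t = np.tanh(U @ np.concatenate([y_prev, s_prev, c]))
```

Every step is differentiable, so gradients from the decoder loss flow back into `W_att` and `v_att`; that is how the alignment gets learned without direct supervision.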
Intuition: translating “we see the sky” → “vediamo il cielo”. At the first decoder step, “vediamo” = “we see”, so the model learns weights like high $a_{1,1}, a_{1,2}$ (attend to “we”, “see”) and low $a_{1,3}, a_{1,4}$. Everything is differentiable, with no supervision on the weights: the model learns alignment as a byproduct of translation quality.
The fixed bottleneck vector had to hold the whole input sentence, compressed. Attention lets the decoder reach back and reread the input as it writes each word, so no single vector has to carry everything.
Generalized Attention Layer (inputs decoupled from scores)
Throw away the RNN; keep the primitive. Inputs:
- Query vectors $Q$ ($N_Q \times D_Q$)
- Data vectors $X$ ($N_X \times D_X$)
- Key projection $W_K$ ($D_X \times D_Q$) and Value projection $W_V$ ($D_X \times D_V$)
Two upgrades over Bahdanau:
- Scaled dot product replaces the alignment MLP: $E = QK^\top / \sqrt{D_Q}$, where $K = XW_K$ and $D_Q$ is the query dimension. The scaling keeps pre-softmax logits in the “not saturated” regime; without it, large $D_Q$ makes softmax collapse to one-hot and gradients vanish
- Separate key & value matrices: the query compares against $K = XW_K$ but sums up $V = XW_V$, giving $Y = AV$ with $A = \text{softmax}(E)$ over the data vectors. Decouples “what to match on” from “what to pass through”. Library analogy: the catalog card (key) advertises the book’s topic; the book itself (value) is what you actually read. You don’t want to be forced to score books using their full contents
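The generalized (cross-)attention layer as a NumPy sketch, with the shapes spelled out in the docstring; the toy sizes and random inputs are assumptions. `D_V` is deliberately different from `D_Q` to show that what is matched on and what is passed through can have different widths.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # stabilise; does not change the result
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_layer(Q, X, W_K, W_V):
    """Q: (N_Q, D_Q) queries, X: (N_X, D_X) data,
    W_K: (D_X, D_Q) key projection, W_V: (D_X, D_V) value projection."""
    K = X @ W_K                         # (N_X, D_Q): what each data vector is matched on
    V = X @ W_V                         # (N_X, D_V): what each data vector passes through
    E = Q @ K.T / np.sqrt(Q.shape[1])   # (N_Q, N_X) scaled dot-product scores
    A = softmax(E, axis=1)              # each query's weights over the data sum to 1
    return A @ V                        # (N_Q, D_V)

rng = np.random.default_rng(0)
N_Q, N_X, D_Q, D_X, D_V = 2, 7, 4, 10, 3
Q = rng.standard_normal((N_Q, D_Q))
X = rng.standard_normal((N_X, D_X))
Y = attention_layer(Q, X, rng.standard_normal((D_X, D_Q)), rng.standard_normal((D_X, D_V)))
print(Y.shape)  # (2, 3)
```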
Self-Attention = Cross-Attention where Q, K, V all come from X
Each token plays all three roles against every other token: it asks its own question, offers its own advertisement, and provides its own content. The layer is just “every token mixes in a weighted blend of every other token’s value, weighted by query-key similarity.” That’s it.
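Everything else is the same. The QKV projections are often fused into one matmul: $[Q \; K \; V] = X W_{QKV}$, with $W_{QKV} = [W_Q \; W_K \; W_V]$ of shape $D_X \times (2D_Q + D_V)$ ($D \times 3D$ when all widths equal $D$).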
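Fusing the projections is just concatenating $W_Q$, $W_K$, $W_V$ column-wise. A NumPy sketch (toy sizes and random weights are assumptions) checking that one fused matmul gives the same self-attention output as three separate projections:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention_fused(X, W_qkv, d_q):
    """X: (N, D_X); W_qkv: (D_X, 2*d_q + d_v), i.e. [W_Q | W_K | W_V] side by side."""
    QKV = X @ W_qkv                                    # one matmul for all three projections
    Q, K, V = QKV[:, :d_q], QKV[:, d_q:2 * d_q], QKV[:, 2 * d_q:]
    A = softmax(Q @ K.T / np.sqrt(d_q), axis=1)        # (N, N): every token vs every token
    return A @ V                                       # (N, d_v)

rng = np.random.default_rng(0)
N, D_X, d_q, d_v = 5, 8, 4, 6
X = rng.standard_normal((N, D_X))
W_Q = rng.standard_normal((D_X, d_q))
W_K = rng.standard_normal((D_X, d_q))
W_V = rng.standard_normal((D_X, d_v))

fused = self_attention_fused(X, np.concatenate([W_Q, W_K, W_V], axis=1), d_q)
Q, K, V = X @ W_Q, X @ W_K, X @ W_V                    # three separate projections
separate = softmax(Q @ K.T / np.sqrt(d_q), axis=1) @ V
print(np.allclose(fused, separate))  # True
```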
Is self-attention permutation-equivariant?
Yes: $f(PX) = P f(X)$ for any permutation matrix $P$. Permuting the inputs permutes $Q$, $K$, $V$, and thus the outputs, in the same way. Self-attention works on sets of vectors, not sequences. This is why Transformers need positional encoding (see Positional Encoding); the layer itself has no notion of order.
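A numerical check of permutation equivariance, as a NumPy sketch with random toy inputs and weights (assumptions): permuting the rows of $X$ before the layer gives the same result as permuting the output rows afterwards.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_Q, W_K, W_V):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    return softmax(Q @ K.T / np.sqrt(Q.shape[1]), axis=1) @ V

rng = np.random.default_rng(1)
N, D = 6, 8
X = rng.standard_normal((N, D))
W_Q, W_K, W_V = (0.3 * rng.standard_normal((D, D)) for _ in range(3))
perm = rng.permutation(N)

permute_after = self_attention(X, W_Q, W_K, W_V)[perm]   # P f(X)
permute_before = self_attention(X[perm], W_Q, W_K, W_V)  # f(P X)
print(np.allclose(permute_after, permute_before))        # True
```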
Masked self-attention
For autoregressive language modeling, token $i$ must not see tokens $j > i$. Implement by setting $E_{ij} = -\infty$ for $j > i$ before the softmax, which forces $A_{ij} = 0$. No changes to the forward code path, just a mask.
Since $\exp(-\infty) = 0$ in the softmax, those positions contribute zero weight. The same forward pass trains all prefixes in parallel without the token at position $i$ ever seeing its own future, which is what makes teacher-forced training tractable on GPUs.
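A NumPy sketch of the causal mask (toy sizes and random weights are assumptions): the upper triangle of scores is set to $-\infty$, and the check confirms that changing the last token leaves every earlier position's output untouched.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # row max is finite: the diagonal is never masked
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def causal_self_attention(X, W_Q, W_K, W_V):
    N = X.shape[0]
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    E = Q @ K.T / np.sqrt(Q.shape[1])
    future = np.triu(np.ones((N, N), dtype=bool), k=1)  # True where j > i
    E = np.where(future, -np.inf, E)                    # exp(-inf) = 0 in the softmax
    A = softmax(E, axis=1)
    return A @ V, A

rng = np.random.default_rng(0)
N, D = 5, 8
X = rng.standard_normal((N, D))
W_Q, W_K, W_V = (rng.standard_normal((D, D)) for _ in range(3))
Y, A = causal_self_attention(X, W_Q, W_K, W_V)

# Changing the last token must not affect any earlier position's output
X_changed = X.copy()
X_changed[-1] += 10.0
Y_changed, _ = causal_self_attention(X_changed, W_Q, W_K, W_V)
print(np.allclose(Y[:-1], Y_changed[:-1]))  # True
```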
Multi-head self-attention
Run $H$ independent self-attention heads in parallel, each with head dim $D_H = D/H$, then concat and project: $O = [Y_1 \; \cdots \; Y_H]\, W_O$
Intuition
Ask multiple questions in parallel. One head can track syntax, another semantics, another long-range co-reference. Each head has its own $W_Q, W_K, W_V$, so it learns its own notion of “what to match on” and “what to pass through.” Concatenating lets the next layer read the sequence from several perspectives at once. The total param count is the same as one big head of width $D$, since each head is $D/H$ wide.
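| name | shape | computation |
|---|---|---|
| Queries | $H \times N \times D_H$ | $XW_Q$, split into heads |
| Keys | $H \times N \times D_H$ | $XW_K$, split into heads |
| Values | $H \times N \times D_H$ | $XW_V$, split into heads |
| Similarities | $H \times N \times N$ | $E = QK^\top / \sqrt{D_H}$ |
| Attention | $H \times N \times N$ | $A = \text{softmax}(E)$ over the last axis |
| Head outputs | $H \times N \times D_H$, reshaped to $N \times D$ | $Y = AV$ |
| Output | $N \times D$ | $O = YW_O$ |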
In practice all heads run as one batched matmul.
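A NumPy sketch of multi-head attention where all heads run in one batched matmul over a leading head axis, checked against an explicit per-head loop; sizes and random weights are assumptions, and $W_Q, W_K, W_V, W_O$ are taken as $D \times D$ with $D = H \cdot D_H$.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, H):
    """X: (N, D); W_Q, W_K, W_V, W_O: (D, D); H heads of width D_H = D // H."""
    N, D = X.shape
    D_H = D // H

    def split_heads(M):  # (N, H*D_H) -> (H, N, D_H)
        return M.reshape(N, H, D_H).transpose(1, 0, 2)

    Q, K, V = (split_heads(X @ W) for W in (W_Q, W_K, W_V))
    E = Q @ K.transpose(0, 2, 1) / np.sqrt(D_H)    # (H, N, N): all heads in one batched matmul
    Y = softmax(E, axis=-1) @ V                    # (H, N, D_H)
    Y = Y.transpose(1, 0, 2).reshape(N, H * D_H)   # concatenate heads -> (N, D)
    return Y @ W_O                                 # output projection

rng = np.random.default_rng(0)
N, D, H = 6, 16, 4
X = rng.standard_normal((N, D))
W_Q, W_K, W_V, W_O = (0.25 * rng.standard_normal((D, D)) for _ in range(4))
batched = multi_head_attention(X, W_Q, W_K, W_V, W_O, H)

# Reference: run each head on its own column slice of the projections, then concat
D_H = D // H
heads = []
for h in range(H):
    cols = slice(h * D_H, (h + 1) * D_H)
    Qh, Kh, Vh = X @ W_Q[:, cols], X @ W_K[:, cols], X @ W_V[:, cols]
    heads.append(softmax(Qh @ Kh.T / np.sqrt(D_H), axis=-1) @ Vh)
looped = np.concatenate(heads, axis=1) @ W_O
print(np.allclose(batched, looped))  # True
```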
Self-attention is four matmuls
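Everything above collapses to four matrix multiplies, starting from $X$ ($N \times D$):
- QKV projection: $[N \times D] \cdot [D \times 3HD_H] \to [N \times 3HD_H]$, split into Q/K/V
- QK similarity: $[H \times N \times D_H] \cdot [H \times D_H \times N] \to [H \times N \times N]$
- V-weighting: $[H \times N \times N] \cdot [H \times N \times D_H] \to [H \times N \times D_H]$
- Output projection: $[N \times HD_H] \cdot [HD_H \times D] \to [N \times D]$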
Compute and memory both scale as $O(N^2)$ in sequence length, from steps 2+3. At $N = 100\text{K}$ and $H = 64$, the $H \times N \times N$ attention matrix alone is 1.19 TB; no GPU fits that. Flash Attention fuses steps 2+3 tile by tile, avoiding the full materialization and reducing memory to $O(N)$.
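A back-of-the-envelope check of the 1.19 TB figure, assuming $N = 100{,}000$ tokens, $H = 64$ heads, one $H \times N \times N$ tensor for a single layer, and 2 bytes per weight (fp16/bf16). The result, about 1192 GiB, matches the quoted number when read in binary units.

```python
N, H = 100_000, 64
bytes_per_entry = 2                        # assumption: 16-bit (fp16/bf16) attention weights
attn_bytes = H * N * N * bytes_per_entry   # one H x N x N attention tensor
print(f"{attn_bytes / 2**30:,.0f} GiB")    # about 1,192 GiB
```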
The $O(N^2)$ cost is the price of “every token sees every token.” You never actually need the whole matrix in memory: each query row only needs its own running max and softmax denominator, which Flash Attention accumulates tile by tile in on-chip SRAM.
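The streaming idea as a NumPy sketch: keys and values are visited one tile at a time, and each query row keeps only a running max, a running softmax denominator and a running weighted sum of values, so the full score matrix never exists. This shows the online-softmax math only; the tile size and helper names are arbitrary, and the real kernel's SRAM blocking and fusion are not modeled.

```python
import numpy as np

def naive_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])                # full (n_q, n_k) score matrix
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, tile=4):
    """Online softmax over K/V tiles (Flash-Attention-style math, not the GPU kernel)."""
    n_q, d = Q.shape
    m = np.full(n_q, -np.inf)            # running max score per query row
    l = np.zeros(n_q)                    # running softmax denominator per row
    acc = np.zeros((n_q, V.shape[1]))    # running unnormalised sum of weights * V
    for start in range(0, K.shape[0], tile):
        K_t, V_t = K[start:start + tile], V[start:start + tile]
        S = Q @ K_t.T / np.sqrt(d)                   # (n_q, tile): only this tile's scores
        m_new = np.maximum(m, S.max(axis=1))
        rescale = np.exp(m - m_new)                  # re-base old stats to the new max
        P = np.exp(S - m_new[:, None])
        l = l * rescale + P.sum(axis=1)
        acc = acc * rescale[:, None] + P @ V_t
        m = m_new
    return acc / l[:, None]                          # normalise once at the end

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
print(np.allclose(tiled_attention(Q, K, V), naive_attention(Q, K, V)))  # True
```

With 10 keys and `tile=4` the last tile is short (2 keys), which the slicing handles without special cases.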
Spatio-Temporal Self-Attention / Nonlocal Block (CS231n 2025 Lec 10)
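The Nonlocal Block (Wang et al., CVPR 2018) drops self-attention into a 3D CNN as a residual sub-module, letting any spatio-temporal location attend to any other. Pipeline starting from a 3D-CNN feature map $X$ of shape $C \times T \times H \times W$:
- Three $1 \times 1 \times 1$ convs project $X$ to $Q$, $K$, $V$ with $C'$ channels each (typically $C' = C/2$)
- Reshape to $THW \times C'$ and compute $A = \text{softmax}(QK^\top)$ of shape $THW \times THW$; every spatio-temporal voxel attends to every other
- Apply $A$ to $V$ ($THW \times C'$) and reshape back to $C' \times T \times H \times W$
- $1 \times 1 \times 1$ conv back to $C$ channels
- Residual add to the input feature map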
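The pipeline above as a NumPy sketch. Assumptions: toy sizes, random weights, no biases or normalization, and unscaled scores $\text{softmax}(QK^\top)$ (dividing by $\sqrt{C'}$ would be an equally valid variant). A $1 \times 1 \times 1$ conv applies the same channel matrix at every voxel, so it is written as a matmul on the flattened voxels.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def nonlocal_block(X, W_q, W_k, W_v, W_out):
    """X: (C, T, H, W) feature map; W_q, W_k, W_v: (C2, C); W_out: (C, C2)."""
    C, T, H, W = X.shape
    flat = X.reshape(C, T * H * W).T                      # (THW, C): one row per voxel
    Q, K, V = flat @ W_q.T, flat @ W_k.T, flat @ W_v.T    # 1x1x1 convs -> (THW, C2) each
    A = softmax(Q @ K.T, axis=1)                          # (THW, THW): voxel-to-voxel weights
    Y = (A @ V) @ W_out.T                                 # (THW, C): 1x1x1 conv back to C channels
    return X + Y.T.reshape(C, T, H, W)                    # residual add in the original layout

rng = np.random.default_rng(0)
C, T, H, W = 8, 2, 3, 3
C2 = C // 2                                               # channel bottleneck C' = C/2
X = rng.standard_normal((C, T, H, W))
W_q, W_k, W_v = (0.1 * rng.standard_normal((C2, C)) for _ in range(3))
W_out = 0.1 * rng.standard_normal((C, C2))

out = nonlocal_block(X, W_q, W_k, W_v, W_out)
print(out.shape)  # (8, 2, 3, 3)
```

With `W_out` set to zero the block reduces to the identity, so it can be inserted into an existing network without changing that network's outputs at first.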
Plug it anywhere into an existing 3D CNN (e.g. I3D, SlowFast) for “global temporal context for free”; SlowFast + Nonlocal hits 79.8 on Kinetics-400, with plain SlowFast scoring lower.
This is structurally identical to ViT-style self-attention, just over $THW$ tokens with $1 \times 1 \times 1$ conv projections instead of linear layers. Foreshadows the Video Transformer wave (ViViT, MViTv2, VideoMAE).
Context Parallelism, parallelizing attention across the sequence (CS231n 2025 Lec 11)
When sequence length gets long (131k for Llama3-405B stage 2), even FSDP+TP can’t fit activations. Context Parallelism (CP) splits the tokens across GPUs along the sequence axis. Most Transformer sublayers are easy:
- RMSNorm / residual: elementwise, no comm needed
- MLP / QKV projection: matmul on a slice of tokens, same weights everywhere; treat like DP on the sliced tokens
- Self-attention: hard, because every query needs every key & value; the $N \times N$ score matrix spans GPUs
Two implementations:
Option 1, Ring Attention (Liu et al., arXiv 2023). Outer loop over $Q$ tiles (each stays local to its GPU); inner loop rotates $K/V$ tiles around a ring of GPUs. Each GPU processes its $Q$ shard against every $K/V$ shard by receiving $K/V$ from its left neighbor and sending to its right. Scales to arbitrarily long sequences (no single GPU ever holds the full $K/V$), at the cost of $P - 1$ rounds of peer-to-peer communication per attention layer on a ring of $P$ GPUs.
Option 2, DeepSpeed Ulysses (Jacobs et al., arXiv 2023). Parallelize over heads instead of sequence: after an all-to-all, each GPU gets all tokens for a subset of heads. Max parallelism = $H$ (number of heads). Simpler than Ring Attention but caps at $H$ GPUs; fine for Llama3 (128 heads) but doesn’t scale past that.
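A single-process simulation of the ring schedule, assuming $P$ equal sequence shards; the send/receive is modeled only by indexing which $K/V$ shard each simulated GPU holds at each step (no communication library is used). Each GPU reuses the online-softmax accumulation, so its output is exact.

```python
import numpy as np

def full_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def ring_attention(Q_shards, K_shards, V_shards):
    """GPU p keeps Q_shards[p]; K/V shards hop one GPU to the right per step."""
    n_gpus = len(Q_shards)
    outputs = []
    for p in range(n_gpus):                      # on real hardware these run concurrently
        Q = Q_shards[p]
        m = np.full(Q.shape[0], -np.inf)         # running row max
        l = np.zeros(Q.shape[0])                 # running softmax denominator
        acc = np.zeros((Q.shape[0], V_shards[0].shape[1]))
        for step in range(n_gpus):
            src = (p - step) % n_gpus            # shard that has reached GPU p after `step` hops
            S = Q @ K_shards[src].T / np.sqrt(Q.shape[1])
            m_new = np.maximum(m, S.max(axis=1))
            rescale = np.exp(m - m_new)
            P_tile = np.exp(S - m_new[:, None])
            l = l * rescale + P_tile.sum(axis=1)
            acc = acc * rescale[:, None] + P_tile @ V_shards[src]
            m = m_new
        outputs.append(acc / l[:, None])
    return np.concatenate(outputs)               # token order matches the original sequence

rng = np.random.default_rng(0)
N, D, n_gpus = 12, 8, 4
Q, K, V = (rng.standard_normal((N, D)) for _ in range(3))

def shard(M):
    return np.split(M, n_gpus)                   # sequence-axis sharding, N / n_gpus tokens each

ring = ring_attention(shard(Q), shard(K), shard(V))
print(np.allclose(ring, full_attention(Q, K, V)))  # True
```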
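A single-process simulation of the head-parallel layout, assuming the number of simulated GPUs divides both $H$ and $N$; the two all-to-alls are modeled with slicing and concatenation, not a real collective. The `assert` makes the $H$ cap explicit: with more GPUs than heads, some GPU would get no head to work on.

```python
import numpy as np

def attention(Q, K, V):
    """Batched over heads. Q, K, V: (h, n, d)."""
    S = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def ulysses_attention(Q, K, V, n_gpus):
    """Q, K, V: (H, N, D_H). Returns per-GPU token shards of the output."""
    H, N, _ = Q.shape
    assert H % n_gpus == 0 and N % n_gpus == 0, "n_gpus must divide both H and N"
    tok = [slice(g * N // n_gpus, (g + 1) * N // n_gpus) for g in range(n_gpus)]
    hd = [slice(g * H // n_gpus, (g + 1) * H // n_gpus) for g in range(n_gpus)]

    # Start: GPU g holds its token shard for ALL heads (sequence-parallel layout)
    local = [(Q[:, t], K[:, t], V[:, t]) for t in tok]

    # All-to-all #1: GPU g gathers its head group for ALL tokens, then attends locally
    head_out = []
    for g in range(n_gpus):
        Qg, Kg, Vg = (np.concatenate([s[i][hd[g]] for s in local], axis=1) for i in range(3))
        head_out.append(attention(Qg, Kg, Vg))        # (H / n_gpus, N, D_H)

    # All-to-all #2: back to token shards carrying all heads
    return [np.concatenate([o[:, t] for o in head_out], axis=0) for t in tok]

rng = np.random.default_rng(0)
H, N, D_H, n_gpus = 8, 12, 4, 4
Q, K, V = (rng.standard_normal((H, N, D_H)) for _ in range(3))
shards = ulysses_attention(Q, K, V, n_gpus)
print(np.allclose(np.concatenate(shards, axis=1), attention(Q, K, V)))  # True
```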
Llama3-405B usage:
- Stage 1 (seq 8192): no CP, TP=8, PP=16, DP=64/128 → 43% MFU
- Stage 2 (seq 131072): CP=16 (so 8192 tokens/GPU), TP=8, PP=16, DP=8 → 38% MFU
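A quick consistency check of the configurations above, assuming the total GPU count is the product of the four parallelism degrees (CP × TP × PP × DP). Under that assumption stage 2 uses as many GPUs as stage 1 at DP=128, and each CP rank holds the same 8192 tokens that stage 1 processes per sequence.

```python
# Parallelism degrees from the notes; GPUs = CP * TP * PP * DP (assumed multiplicative)
stages = {
    "stage 1, DP=64":  dict(seq=8192,   CP=1,  TP=8, PP=16, DP=64),
    "stage 1, DP=128": dict(seq=8192,   CP=1,  TP=8, PP=16, DP=128),
    "stage 2":         dict(seq=131072, CP=16, TP=8, PP=16, DP=8),
}
gpus, tokens = {}, {}
for name, s in stages.items():
    gpus[name] = s["CP"] * s["TP"] * s["PP"] * s["DP"]
    tokens[name] = s["seq"] // s["CP"]          # sequence length held by each CP rank
    print(f"{name}: {gpus[name]} GPUs, {tokens[name]} tokens per CP rank")
```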
CP is Flash Attention-friendly; both avoid materializing the full attention matrix, just at different granularities (tile-within-GPU vs GPU-across-ring).