Context Parallelism

Context Parallelism (CP) parallelizes a Transformer across the sequence dimension. Each GPU gets a slice of the tokens instead of a full sequence.

Easy parts:

  • RMSNorm and residual ops are token-local (RMSNorm reduces only over the hidden dimension; residual adds are elementwise)
  • MLPs and QKV projections run locally on each token slice
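
A minimal sketch of why these ops need no communication: each output token depends only on the same input token, so running the op on token slices and concatenating gives the same result as running it on the full sequence. The names world_size, rms_norm and layer_local are hypothetical, the RMSNorm omits the learned gain, and the "ranks" are simulated with plain Python lists.

```python
import math

def rms_norm(x, eps=1e-6):
    """RMSNorm of a single token vector (learned gain omitted for brevity)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def layer_local(tokens):
    """Token-local stand-in for norm + residual: y = x + rms_norm(x)."""
    return [[a + b for a, b in zip(x, rms_norm(x))] for x in tokens]

# Toy activations: 8 tokens, hidden size 4.
tokens = [[float(t + h * h + 1) for h in range(4)] for t in range(8)]

world_size = 4  # hypothetical number of CP ranks
chunk = len(tokens) // world_size
shards = [tokens[r * chunk:(r + 1) * chunk] for r in range(world_size)]

# Each simulated rank processes only its own slice; nothing is exchanged.
sharded_out = [y for shard in shards for y in layer_local(shard)]
full_out = layer_local(tokens)

assert sharded_out == full_out
```

The same argument applies to MLPs and QKV projections, since they apply the same weights to every token independently.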

Hard part:

  • self-attention, because every query still needs access to every key and value

Two common approaches:

  1. Ring Attention: rotate key/value shards around the GPUs in a ring so each local query shard eventually attends to the full sequence
  2. DeepSpeed Ulysses: repartition by heads with an all-to-all so each GPU gets all tokens for a subset of heads
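
The two approaches can be sketched on a single process, with ranks simulated as list entries. This is a toy illustration under stated assumptions, not either project's implementation: attention is single-head and non-causal, communication is modeled by list rotation and slicing, and the function names (ring_attention, ulysses_all_to_all) are hypothetical. The ring variant merges per-block results with a running max and normalizer so the softmax over the full sequence is recovered exactly.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def full_attention(q, k, v):
    """Reference: softmax(q k^T / sqrt(d)) v over the whole sequence."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [dot(qi, kj) * scale for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z
                    for c in range(len(v[0]))])
    return out

def ring_attention(q_shards, k_shards, v_shards):
    """Simulated ring: K/V blocks rotate between ranks, Q shards stay put."""
    p = len(q_shards)
    scale = 1.0 / math.sqrt(len(q_shards[0][0]))
    dv = len(v_shards[0][0])
    # Per local query: (running max, running normalizer, weighted value sum).
    state = [[(-math.inf, 0.0, [0.0] * dv) for _ in qs] for qs in q_shards]
    kv = list(zip(k_shards, v_shards))  # kv[r] = block currently held by rank r
    for _ in range(p):
        for r in range(p):
            k_blk, v_blk = kv[r]
            for i, qi in enumerate(q_shards[r]):
                m, z, acc = state[r][i]
                scores = [dot(qi, kj) * scale for kj in k_blk]
                m_new = max(m, max(scores))
                corr = math.exp(m - m_new)  # rescales earlier partial results
                w = [math.exp(s - m_new) for s in scores]
                z = z * corr + sum(w)
                acc = [a * corr + sum(wj * vj[c] for wj, vj in zip(w, v_blk))
                       for c, a in enumerate(acc)]
                state[r][i] = (m_new, z, acc)
        kv = kv[-1:] + kv[:-1]  # rank r-1 "sends" its block to rank r
    return [[[a / z for a in acc] for (_, z, acc) in rank] for rank in state]

def ulysses_all_to_all(seq_sharded, p):
    """Layout change only: [token shard][all heads] -> [all tokens][head shard]."""
    hp = len(seq_sharded[0][0]) // p  # heads per rank
    return [[tok[s * hp:(s + 1) * hp] for shard in seq_sharded for tok in shard]
            for s in range(p)]

# Toy sizes: 8 tokens, head dim 4, 4 simulated ranks.
rng = random.Random(0)
seq, d, p = 8, 4, 4
q, k, v = ([[rng.uniform(-1, 1) for _ in range(d)] for _ in range(seq)]
           for _ in range(3))
chunk = seq // p

def split(x):
    return [x[r * chunk:(r + 1) * chunk] for r in range(p)]

ring_out = [row for rank in ring_attention(split(q), split(k), split(v))
            for row in rank]
ref = full_attention(q, k, v)
max_err = max(abs(a - b) for ra, rb in zip(ring_out, ref) for a, b in zip(ra, rb))
assert max_err < 1e-9  # ring result matches full attention

# Ulysses layout: label each (token, head) pair and watch where it lands.
labels = [[(t, h) for h in range(8)] for t in range(seq)]
head_sharded = ulysses_all_to_all(split(labels), p)
assert head_sharded[1][5] == [(5, 2), (5, 3)]  # rank 1: all tokens, heads 2-3
```

After the Ulysses-style repartition each rank can run ordinary attention for its heads over the full sequence, and a second all-to-all (not shown) would restore the token-sharded layout for the layers that follow.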

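CP is complementary to FlashAttention: both avoid materializing the full attention matrix, but at different levels of the stack. FlashAttention tiles the computation inside a single GPU, while CP splits the sequence across GPUs, so each rank can still run a FlashAttention-style kernel on its local blocks.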

Related: