Scatter-Gather Pattern

https://www.enterpriseintegrationpatterns.com/patterns/messaging/BroadcastAggregate.html

Use a Scatter-Gather that broadcasts a message to multiple recipients and re-aggregates the responses back into a single message.
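
As a rough illustration of the messaging pattern, here is a minimal Python sketch. The vendor functions, the `scatter_gather` helper, and the "cheapest quote" aggregation rule are all hypothetical names invented for this example; a real integration would broadcast over a message channel rather than call local functions.

```python
from concurrent.futures import ThreadPoolExecutor

# Hypothetical recipients: each receives the same request and returns a quote.
def vendor_a(item):
    return {"vendor": "A", "item": item, "price": 12.0}

def vendor_b(item):
    return {"vendor": "B", "item": item, "price": 9.5}

def vendor_c(item):
    return {"vendor": "C", "item": item, "price": 11.0}

def scatter_gather(request, recipients, aggregate):
    """Broadcast `request` to every recipient, then combine the replies."""
    with ThreadPoolExecutor(max_workers=len(recipients)) as pool:
        # Scatter: submit the same request to each recipient.
        futures = [pool.submit(recipient, request) for recipient in recipients]
        # Gather: wait for every reply, in submission order.
        replies = [f.result() for f in futures]
    # Aggregate: reduce the replies to a single message.
    return aggregate(replies)

def cheapest(replies):
    # Assumed aggregation rule: keep the reply with the lowest price.
    return min(replies, key=lambda r: r["price"])

best = scatter_gather("widget", [vendor_a, vendor_b, vendor_c], cheapest)
print(best)
```

The aggregation step is passed in as a function so the same scatter and gather code can serve different policies (best quote, first N replies, merged list).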

Scatter-Gather in transformer architecture

  • In the forward pass, retrieving the embedding for a particular token from the embedding matrix is a gather operation.
  • In the backward pass, propagating the gradients back to the correct rows of the embedding matrix is a scatter operation.
  • Forward pass: It’s just a gather operation.
    • We “gather” rows from the embedding matrix corresponding to the input token indices.
    • This is essentially embedding[token_ids] under the hood.
  • Backward pass: This is effectively a scatter-add operation.
    • We have gradients (dL/d_embeddings) for each gathered row.
    • We “scatter-add” those gradients back into the corresponding rows of the embedding matrix.
    • If the same token appears multiple times, gradients for that row are summed together.
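
The forward and backward passes described above can be sketched in plain Python. This is a toy illustration, not how any framework implements it: the `gather` and `scatter_add` helpers, the 4x2 embedding matrix, and the upstream gradient values are made up for the example.

```python
# Toy embedding matrix: 4 tokens in the vocabulary, embedding dimension 2.
embedding = [
    [0.0, 0.1],
    [1.0, 1.1],
    [2.0, 2.1],
    [3.0, 3.1],
]
token_ids = [2, 0, 2]  # token 2 appears twice

def gather(matrix, indices):
    """Forward pass: copy out the rows selected by `indices`."""
    return [list(matrix[i]) for i in indices]

def scatter_add(num_rows, indices, row_grads):
    """Backward pass: sum each row gradient into the row it was gathered from."""
    dim = len(row_grads[0])
    grad = [[0.0] * dim for _ in range(num_rows)]
    for i, g in zip(indices, row_grads):
        for d in range(dim):
            grad[i][d] += g[d]  # repeated indices accumulate
    return grad

gathered = gather(embedding, token_ids)

# Made-up upstream gradients (dL/d_embeddings), one per gathered row.
upstream = [[1.0, 1.0], [0.5, 0.5], [2.0, 2.0]]
grad_embedding = scatter_add(len(embedding), token_ids, upstream)

print(gathered)        # rows 2, 0, 2 of the embedding matrix
print(grad_embedding)  # row 2 holds the sum of the first and third gradients
```

Rows for tokens that never appear in `token_ids` keep a zero gradient, which is why embedding gradients are typically sparse.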