Scatter-Gather Pattern
https://www.enterpriseintegrationpatterns.com/patterns/messaging/BroadcastAggregate.html
Use a Scatter-Gather that broadcasts a message to multiple recipients and re-aggregates the responses back into a single message.
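As a rough illustration of the messaging pattern, the sketch below broadcasts one request to three recipients concurrently and folds their replies into a single message. The `ask_vendor` recipients, their prices, and the "pick the cheapest quote" aggregation rule are all hypothetical; a real system would send the request over a message channel instead of calling local coroutines.

```python
import asyncio


async def ask_vendor(request: str, name: str, price: float, delay: float) -> dict:
    """Hypothetical recipient: answers the request with a quote after a short delay."""
    await asyncio.sleep(delay)
    return {"vendor": name, "request": request, "price": price}


async def scatter_gather(request: str) -> dict:
    # Scatter: broadcast the same request to every recipient concurrently.
    replies = await asyncio.gather(
        ask_vendor(request, "A", 12.0, 0.02),
        ask_vendor(request, "B", 9.5, 0.01),
        ask_vendor(request, "C", 11.0, 0.03),
    )
    # Gather: aggregate all replies into one message (here, the cheapest quote).
    best = min(replies, key=lambda reply: reply["price"])
    return {"request": request, "best": best, "num_replies": len(replies)}


result = asyncio.run(scatter_gather("quote for item 42"))
```

The aggregation step is where the design choice lives: it could just as well wait for only the first reply, apply a timeout, or merge all replies into a list.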
Scatter-Gather in transformer architecture
- In the forward pass, when you retrieve the embedding for a particular token from the embedding matrix, you are doing a gather operation.
- In the backward pass, propagating the gradients back to the correct rows of the embedding matrix is a scatter operation.
- Forward pass: It’s just a gather operation.
- We “gather” rows from the embedding matrix corresponding to the input token indices.
- This is essentially `embedding[token_ids]` under the hood.
- Backward pass: This is effectively a scatter-add operation.
- We have gradients (`dL/d_embeddings`) for each gathered row.
- We “scatter-add” those gradients back into the corresponding rows of the embedding matrix.
- If the same token appears multiple times, gradients for that row are summed together.
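
The forward gather and backward scatter-add can be sketched in plain Python. This is a toy illustration and not how any framework implements it: the 4x2 `embedding` matrix, the `token_ids`, and the all-ones upstream gradient are made-up values, and `gather` / `scatter_add` are hypothetical helper names.

```python
# Toy embedding matrix: 4 tokens in the vocabulary, 2 dimensions each (made-up values).
embedding = [
    [0.0, 0.1],
    [1.0, 1.1],
    [2.0, 2.1],
    [3.0, 3.1],
]
token_ids = [2, 0, 2]  # token 2 appears twice


def gather(matrix, ids):
    """Forward pass: copy out one row of the matrix per input token."""
    return [list(matrix[i]) for i in ids]


def scatter_add(num_rows, ids, row_grads):
    """Backward pass: add each gathered row's gradient into the row it came from."""
    dim = len(row_grads[0])
    grad = [[0.0] * dim for _ in range(num_rows)]
    for i, g in zip(ids, row_grads):
        for d in range(dim):
            grad[i][d] += g[d]  # repeated ids accumulate instead of overwriting
    return grad


gathered = gather(embedding, token_ids)

# Assume the upstream gradient dL/d_embeddings is all ones for every gathered row.
upstream = [[1.0, 1.0] for _ in token_ids]
grad_embedding = scatter_add(len(embedding), token_ids, upstream)
```

Row 2 of `grad_embedding` ends up with twice the gradient of row 0 because token 2 was gathered twice, and rows 1 and 3 stay at zero. In NumPy the gather is `embedding[token_ids]` and the accumulating scatter is `np.add.at(grad, token_ids, upstream)`; a plain `grad[token_ids] += upstream` would count a repeated index only once.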