Distributed Data Parallel (DDP)
DDP is a PyTorch module that parallelizes training across multiple processes (typically one per GPU, possibly spread over multiple machines) by replicating the model in each process and splitting the data between them, making it well suited for large-scale deep learning workloads.
How does it actually work in terms of ordering? Is training done first, then validation, and then the weights are synchronized, or is it some other order?
1. Model Weights Synchronization
- At the start of DDP training, all model replicas across processes are synchronized to have the same weights.
- During training, gradients are synchronized (averaged) across processes after each backward pass, so all replicas remain in sync; see the sketch below.
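A minimal sketch of a DDP training loop, assuming a single node with one process per GPU launched via torchrun; the model, dataset, and hyperparameters are placeholders, not part of the original question.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def train():
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for each process.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Placeholder model and data; replace with your own.
    model = torch.nn.Linear(10, 1).cuda(local_rank)
    # Wrapping in DDP broadcasts rank 0's weights to all processes,
    # so every replica starts from identical parameters.
    model = DDP(model, device_ids=[local_rank])

    dataset = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 1))
    sampler = DistributedSampler(dataset)      # each rank gets a distinct shard
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()

    for epoch in range(3):
        sampler.set_epoch(epoch)               # reshuffle shards each epoch
        for x, y in loader:
            x, y = x.cuda(local_rank), y.cuda(local_rank)
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()                    # gradients are all-reduced (averaged) here
            optimizer.step()                   # identical update on every rank

    dist.destroy_process_group()

if __name__ == "__main__":
    train()
```

Because the gradient all-reduce happens inside `backward()` and every rank applies the same averaged gradient, no separate weight-synchronization step is needed after the optimizer update.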
2. Validation Phase
- By default, each process holds a full copy of the model weights, so when you enter validation every process sees the same weights (DDP keeps them in sync).
- Redundant validation: if you simply run validation in every process, each rank either computes metrics on only its own shard of the validation data (when using a DistributedSampler), leaving you with partial per-rank results, or evaluates the entire validation set (when not using one), duplicating the work. Either way, naively logging from every rank produces duplicated output. A common pattern is to shard the validation set, aggregate the metrics across ranks, and log only on rank 0, as sketched below.
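A minimal sketch of that pattern, assuming the training setup above: shard the validation set with a DistributedSampler, all-reduce the summed loss and sample count, and print only on rank 0. The names (`validate`, `val_dataset`) are placeholders introduced here for illustration.

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

@torch.no_grad()
def validate(model, val_dataset, device, batch_size=64):
    # Each rank evaluates only its own shard of the validation set.
    sampler = DistributedSampler(val_dataset, shuffle=False)
    loader = DataLoader(val_dataset, batch_size=batch_size, sampler=sampler)
    loss_fn = torch.nn.MSELoss(reduction="sum")

    model.eval()
    total_loss = torch.zeros(1, device=device)
    total_count = torch.zeros(1, device=device)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        total_loss += loss_fn(model(x), y)
        total_count += y.numel()
    model.train()

    # Combine the partial sums from every rank so each one sees the global metric.
    dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
    dist.all_reduce(total_count, op=dist.ReduceOp.SUM)
    avg_loss = (total_loss / total_count).item()

    # Log from a single rank to avoid duplicated output.
    if dist.get_rank() == 0:
        print(f"validation loss: {avg_loss:.4f}")
    return avg_loss
```

Note that DistributedSampler pads the last shard by repeating samples so that every rank gets the same number of batches, which can slightly bias the aggregated metric; exact evaluation would need to account for that padding.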