Batch Normalization
Batch Normalization lets you train neural networks faster and more reliably by normalizing the layers’ inputs.
It was proposed by Sergey Ioffe and Christian Szegedy in 2015.
Other resources
Formalization of batchnorm
Input: values of $x$ over a mini-batch $\mathcal{B} = \{x_{1 \dots m}\}$; parameters to be learned: $\gamma$, $\beta$
- $\gamma$ is the scaling factor to make the distribution wider or skinnier
- $\beta$ is the bias term to shift the distribution
Output: $\{y_i = \text{BN}_{\gamma,\beta}(x_i)\}$, where
$$\mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i \qquad \sigma_\mathcal{B}^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i - \mu_\mathcal{B})^2$$
$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} \qquad y_i = \gamma \hat{x}_i + \beta$$
What is $\epsilon$ here?
- ”$\epsilon$ is a constant added to the mini-batch variance for numerical stability”
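The transform above can be written directly in a few lines of PyTorch as a sanity check (a minimal sketch; `x` is a hypothetical mini-batch of shape `(batch, features)`, and `gamma`, `beta`, `eps` stand for $\gamma$, $\beta$, $\epsilon$):

```python
import torch

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    # x: (batch, features) mini-batch of pre-activations
    mu = x.mean(0, keepdim=True)                   # mini-batch mean, per feature
    var = x.var(0, unbiased=False, keepdim=True)   # mini-batch variance, per feature
    xhat = (x - mu) / torch.sqrt(var + eps)        # normalize to zero mean, unit variance
    return gamma * xhat + beta                     # scale and shift

x = torch.randn(32, 100) * 5 + 3                   # toy batch with non-trivial mean/variance
y = batchnorm_forward(x, torch.ones(100), torch.zeros(100))
print(y.mean(0)[:3], y.std(0)[:3])                 # per-feature mean ≈ 0, std ≈ 1
```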
How is this done in practice?
- Takes the outputs of the linear/conv layer before the activation
- Computes mean and variance across the batch for each feature/channel
- See Neural Network Dimensions for what we mean by channel
- Normalizes those outputs so they have zero mean and unit variance
2 Parameters per FEATURE, not layer
Each feature’s mean and variance are computed independently. You are NOT computing a single mean and variance value for the entire layer; rather, one for each feature.
- If you did it across the entire layer, that is Layer Normalization!
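To make the two reduction axes concrete, here is a small sketch (the tensor `h` and its shape are hypothetical):

```python
import torch

h = torch.randn(32, 100)                 # (batch, features)

# Batch norm: statistics per FEATURE, computed across the batch dimension
bn_mean = h.mean(dim=0)                  # shape (100,): one mean per feature
bn_var = h.var(dim=0, unbiased=False)    # shape (100,): one variance per feature

# Layer norm: statistics per EXAMPLE, computed across the feature dimension
ln_mean = h.mean(dim=1)                  # shape (32,): one mean per example
ln_var = h.var(dim=1, unbiased=False)    # shape (32,): one variance per example
```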
Another view
Pipeline
Linear/Conv layer → BatchNorm → Activation (like ReLU)
Mathematically: before, $z = g(Wu + b)$; after, $z = g(\text{BN}(Wu))$.
The bias term is absorbed into the batch norm shift.
Note that since we normalize $Wu + b$ and then apply a learned shift $\beta$, we don’t really need the $+ b$ anymore.
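A quick way to convince yourself of this: a per-feature constant bias is removed by the mean subtraction, so the normalized result is unchanged (sketch with arbitrary shapes):

```python
import torch

x = torch.randn(32, 10)
W = torch.randn(10, 100)
b = torch.randn(100)

def normalize(h):
    # just the mean/variance part of batch norm
    return (h - h.mean(0)) / h.std(0)

with_bias = normalize(x @ W + b)
without_bias = normalize(x @ W)
print(torch.allclose(with_bias, without_bias, atol=1e-5))  # True: b is subtracted right back out
```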
> [!danger] Is batch norm becoming unpopular?
> Seems like it is expensive. https://www.reddit.com/r/MachineLearning/comments/nnivo6/d_why_is_batch_norm_becoming_so_unpopular/
Batch normalization requires different processing at training and inference times. At inference time, we use the running (moving) averages of the mean and variance collected during training.
This is not the case with LayerNorm: since layer normalization is computed over the features of each individual example rather than over the batch, the same set of operations can be used at both training and inference times.
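With PyTorch's built-in layer, this train/inference difference is controlled by the module's mode (a small sketch; the sizes are arbitrary):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(100)
x = torch.randn(32, 100) * 4 + 2

bn.train()
y_train = bn(x)   # normalizes with THIS batch's mean/var and updates the running stats

bn.eval()
y_eval = bn(x)    # normalizes with the running mean/var accumulated during training

print(bn.running_mean[:3], bn.running_var[:3])
```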
Motivation guided by Andrej Karpathy
Think about the Activation Functions. If the weights aren’t initialized properly, then by the time the pre-activations reach the activation nodes they will be either very negative or very positive. With a tanh function, the gradient is essentially 0 at those values, so during backpropagation the node effectively acts as a dead neuron: no training/learning can be done.
Ideally, you want the pre-activations at each node to be approximately Gaussian, but how can you achieve this?
The most common way to initialize is Kaiming Initialization. However, as the network gets deeper and deeper, it becomes hard to determine the right scales by hand.
Batch Normalization solves this by simply normalizing the layers’ pre-activations. However, you also want the activations to be able to adapt; you don’t want them to always be unit Gaussian. They should be able to get scaled and shifted, so there are two extra parameters that are learned.
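To see the saturation problem described above concretely, here is a tiny sketch (the ×10 scale is just to exaggerate badly-scaled pre-activations):

```python
import torch

h = (torch.randn(32, 100) * 10).requires_grad_(True)  # badly scaled pre-activations
out = torch.tanh(h)
out.sum().backward()

print((out.abs() > 0.99).float().mean())  # most activations are saturated near ±1
print(h.grad.abs().mean())                # gradient is 1 - tanh(h)^2, which is ~0 here
```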
From part 2 of the notebook:
hpreact = embcat @ W1  # hidden layer pre-activations
# BatchNorm layer
bnmeani = hpreact.mean(0, keepdim=True)  # batch mean, per hidden unit
bnstdi = hpreact.std(0, keepdim=True)    # batch std, per hidden unit
hpreact = bngain * (hpreact - bnmeani) / bnstdi + bnbias  # normalize, then scale and shift
with torch.no_grad():  # running values so we can use them at inference time
    bnmean_running = 0.999 * bnmean_running + 0.001 * bnmeani
    bnstd_running = 0.999 * bnstd_running + 0.001 * bnstdi
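For context, these batch norm tensors are typically created once before training, something like the following (a sketch; `n_hidden` and the exact shapes are assumptions, not the notebook's exact code):

```python
import torch

n_hidden = 100  # assumed size of the hidden layer

bngain = torch.ones((1, n_hidden))           # gamma: learnable scale
bnbias = torch.zeros((1, n_hidden))          # beta: learnable shift
bnmean_running = torch.zeros((1, n_hidden))  # running mean, updated outside of backprop
bnstd_running = torch.ones((1, n_hidden))    # running std, updated outside of backprop
```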
At inference time:
embcat = emb.view(emb.shape[0], -1)  # flatten the embeddings into one vector per example
hpreact = embcat @ W1  # + b1 (not needed with batch norm)
hpreact = bngain * (hpreact - bnmean_running) / bnstd_running + bnbias  # use running stats
h = torch.tanh(hpreact)  # (32, 100)
logits = h @ W2 + b2  # (32, 27)
loss = F.cross_entropy(logits, y)
We use these running averages of the batch norm mean and standard deviation for each of the hidden nodes.
- When you use batch normalization, make sure not to use bias terms in the preceding linear/conv layer, since the batch norm bias will take care of the shift
Problem
One problem with batch norm is that it actually couples different examples. Before, each example was processed independently; with batch norm, the examples in a mini-batch become dependent on each other. In practice, this tends to be a good thing, since it essentially acts as a form of regularization.
However, you can easily shoot yourself in the foot.
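You can see the coupling directly: the same input produces a different normalized output depending on which batch-mates it is processed with (a sketch with arbitrary sizes):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4).train()
x = torch.randn(1, 4)                               # one example we care about

batch_a = torch.cat([x, torch.randn(7, 4)])         # x with one set of batch-mates
batch_b = torch.cat([x, torch.randn(7, 4) + 5.0])   # x with a very different set

out_a = bn(batch_a)[0]
out_b = bn(batch_b)[0]
print(torch.allclose(out_a, out_b))                 # False: same input, different output
```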
Other Intuition
“It has been long known that the network training converges faster if its inputs are whitened – i.e., linearly transformed to have zero means and unit variances, and decorrelated”.
In Practice
You stack the layers like this, for example in a ResNet (a sketch follows the list):
- Linear layer
- Batchnorm layer
- Activation layer
- Repeat
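A hedged sketch of that block pattern in PyTorch (the channel counts are arbitrary; note `bias=False` on the conv, per the point about the redundant bias above):

```python
import torch.nn as nn

block = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),  # linear/conv layer, no bias
    nn.BatchNorm2d(64),                                        # batchnorm layer
    nn.ReLU(),                                                 # activation layer
)
```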
PyTorch
PyTorch applies this kind of fan-in scaling by default, but using a uniform distribution as the initialization: `torch.nn.Linear` draws its weights from $\mathcal{U}(-\sqrt{k}, \sqrt{k})$ with $k = 1/\text{fan\_in}$.
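A quick way to check the defaults (sketch; the layer sizes are arbitrary):

```python
import torch.nn as nn

lin = nn.Linear(200, 100)
print(lin.weight.abs().max())       # bounded by sqrt(1/200) ≈ 0.07 (uniform, fan-in scaled)

bn = nn.BatchNorm1d(100)
print(bn.weight[:5], bn.bias[:5])   # the learnable gamma and beta
print(bn.running_mean[:5], bn.running_var[:5])  # buffers used at inference time
```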