Batch Normalization

Batch Normalization lets you train neural networks much more reliably by normalizing the inputs to each layer.

It was proposed by Sergey Ioffe and Christian Szegedy in 2015.

Input: values x₁ … x_m over a mini-batch B; parameters to be learned: γ, β
Output: yᵢ = BN_{γ,β}(xᵢ)

  μ_B  = (1/m) Σᵢ xᵢ                    (mini-batch mean)
  σ²_B = (1/m) Σᵢ (xᵢ − μ_B)²           (mini-batch variance)
  x̂ᵢ   = (xᵢ − μ_B) / √(σ²_B + ε)       (normalize)
  yᵢ   = γ x̂ᵢ + β                       (scale and shift)

  • γ is the scaling factor (gain) that makes the distribution wider or skinnier
  • β is the bias term that shifts the distribution

The rest of these notes follow Andrej Karpathy's lecture.

Think about the activation functions. If the weights aren't initialized properly, then once these values reach the activation nodes they will be either very negative or very positive. With a tanh activation, the gradient is essentially 0 at those values, so during backpropagation the unit acts like a dead neuron: no training/learning can happen through it.
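A quick way to see this (my own sketch, not from the notebook): the local gradient of tanh is 1 − tanh(x)², which is essentially zero once the input is far from zero.

import torch

x = torch.tensor([-10.0, -0.5, 0.0, 0.5, 10.0], requires_grad=True)
y = torch.tanh(x)
y.sum().backward()
print(y)       # saturates to -1 / +1 at the extremes
print(x.grad)  # 1 - tanh(x)^2: ~0.79 at +/-0.5, but ~0 at +/-10, so nothing flows back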

Ideally, you want the pre-activations at each node to be roughly Gaussian, but how can you achieve this?

The most common fix is Kaiming initialization. However, as the network gets deeper and deeper, it becomes harder to pick these scaling values correctly at every layer.
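A minimal sketch of what that looks like (my own example; the 5/3 gain is the standard value for tanh, and the layer sizes are just assumptions):

import torch

fan_in, fan_out = 30, 100
# manual Kaiming init for a tanh layer: scale by gain / sqrt(fan_in), gain = 5/3 for tanh
W1 = torch.randn((fan_in, fan_out)) * (5/3) / fan_in**0.5

# or use PyTorch's built-in initializer (note: weight shape is (out, in) here)
W = torch.empty(fan_out, fan_in)
torch.nn.init.kaiming_normal_(W, nonlinearity='tanh')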

Batch Normalization solves this by simply normalizing the pre-activations at each layer. However, you also want the activations to be able to adapt: you don't want them to always be perfectly Gaussian. They should be able to be scaled and shifted, so two extra parameters (the gain and the bias) are learned.

So, on to part 2 of the notebook.
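The bngain, bnbias, bnmean_running, and bnstd_running tensors used below are defined elsewhere in the notebook; a minimal sketch of how they might be initialized (assuming a hidden size of 100):

import torch

n_hidden = 100                               # assumed hidden-layer size
bngain = torch.ones((1, n_hidden))           # gamma: starts at 1, i.e. no scaling
bnbias = torch.zeros((1, n_hidden))          # beta: starts at 0, i.e. no shift
bnmean_running = torch.zeros((1, n_hidden))  # running mean, updated during training
bnstd_running = torch.ones((1, n_hidden))    # running std, updated during training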

hpreact = embcat @ W1 # + b1 (the bias would be redundant: batch norm has its own)
# Batchnorm layer
bnmeani = hpreact.mean(0, keepdim=True)  # mean of the batch, per hidden node
bnstdi = hpreact.std(0, keepdim=True)    # std of the batch, per hidden node
hpreact = bngain * (hpreact - bnmeani) / bnstdi + bnbias  # normalize, then scale and shift

with torch.no_grad(): # keep running estimates so we can use them at inference time
    bnmean_running = 0.999 * bnmean_running + 0.001 * bnmeani
    bnstd_running = 0.999 * bnstd_running + 0.001 * bnstdi

At inference time:

embcat = emb.view(emb.shape[0], -1)  # flatten the embeddings
hpreact = embcat @ W1 # + b1
hpreact = bngain * (hpreact - bnmean_running) / bnstd_running + bnbias  # running stats, not batch stats
h = torch.tanh(hpreact) # (32, 100)
logits = h @ W2 + b2 # (32, 27)
loss = F.cross_entropy(logits, y)

At inference time we use these running estimates of the batch-norm mean and standard deviation, one per hidden node, instead of the statistics of the current batch.

  • When you use batch normalization, don't use a bias term in the preceding linear layer, since the batch norm bias takes care of that
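A quick sanity check of why that bias is redundant (my own sketch): anything you add before the normalization gets subtracted right back out by the batch mean.

import torch

x = torch.randn(32, 100)   # hypothetical pre-activations
b1 = torch.randn(100)      # a bias we might have added

with_bias = (x + b1) - (x + b1).mean(0, keepdim=True)
without   = x - x.mean(0, keepdim=True)
print(torch.allclose(with_bias, without, atol=1e-6))  # True: b1 has no effect
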
Problem

One problem with batch norm is that it couples different examples together. Before, each example was processed independently; with batch norm, the examples in a batch become dependent on each other. In practice this tends to be a good thing, since it essentially acts as a form of regularization.
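You can see the coupling directly (my own sketch): normalize the same example in two different batches and the outputs differ.

import torch

def bn(h):
    return (h - h.mean(0, keepdim=True)) / h.std(0, keepdim=True)

x = torch.randn(8, 4)       # one batch
other = torch.randn(8, 4)   # a different batch
a = bn(x)[0]                                # example 0, normalized with its batch-mates
b = bn(torch.cat([x[:1], other[1:]]))[0]    # same example 0, different batch-mates
print(torch.allclose(a, b))                 # False: the output depends on the rest of the batch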

However, you can easily shoot yourself in the foot.

In Practice

You repeat this pattern, such as in a ResNet (a minimal sketch follows the list):

  • Linear layer
  • Batchnorm layer
  • Activation layer
  • Repeat
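Here is what that pattern might look like with torch.nn (my own example, with made-up layer sizes; real ResNets use convolutions and ReLU rather than linear layers and tanh):

import torch.nn as nn

def block(n):
    return nn.Sequential(
        nn.Linear(n, n, bias=False),  # bias omitted: the batch norm bias takes over
        nn.BatchNorm1d(n),
        nn.Tanh(),
    )

model = nn.Sequential(block(100), block(100), block(100))  # ...and repeat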

PyTorch

PyTorch's nn.Linear already applies this fan-in scaling by default, drawing the initial weights from a uniform distribution (roughly U(−1/√fan_in, +1/√fan_in)).
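A quick check of that default (my own sketch): the weights of a fresh nn.Linear fall within roughly ±1/√fan_in.

import torch

lin = torch.nn.Linear(100, 50)              # fan_in = 100
print(lin.weight.min(), lin.weight.max())   # roughly -0.1 and +0.1, i.e. +/- 1/sqrt(100)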