Batch Normalization

Batch Normalization lets you train neural networks much more reliably by normalizing the layers' inputs.

It was proposed by Sergey Ioffe and Christian Szegedy in 2015.

Formalization of batchnorm

Input: Values of $x$ over a mini-batch $\mathcal{B} = \{x_1, \ldots, x_m\}$; parameters to be learned: $\gamma, \beta$
Output: $\{y_i = \mathrm{BN}_{\gamma,\beta}(x_i)\}$

$$\mu_\mathcal{B} = \frac{1}{m} \sum_{i=1}^{m} x_i \quad \text{(mini-batch mean)}$$

$$\sigma_\mathcal{B}^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_\mathcal{B})^2 \quad \text{(mini-batch variance)}$$

$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} \quad \text{(normalize)}$$

$$y_i = \gamma \hat{x}_i + \beta \equiv \mathrm{BN}_{\gamma,\beta}(x_i) \quad \text{(scale and shift)}$$

  • $\gamma$ is the scaling factor that makes the distribution wider or skinnier
  • $\beta$ is the bias term that shifts the distribution

What is $\epsilon$ here?

  • $\epsilon$ is a constant added to the mini-batch variance for numerical stability
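
A minimal sketch of these four steps in PyTorch (the function name, shapes, and defaults here are just illustrative, not the notebook's code):

import torch

def batchnorm(x, gamma, beta, eps=1e-5):
    # x: (batch_size, num_features); gamma, beta: (num_features,)
    mu = x.mean(0, keepdim=True)                   # mini-batch mean, one per feature
    var = x.var(0, unbiased=False, keepdim=True)   # mini-batch variance, one per feature
    xhat = (x - mu) / torch.sqrt(var + eps)        # normalize
    return gamma * xhat + beta                     # scale and shift

y = batchnorm(torch.randn(32, 100), torch.ones(100), torch.zeros(100))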

How is this done in practice?

  1. Takes the outputs of the linear/conv layer before the activation
  2. Computes mean and variance across the batch for each feature/channel
  3. Normalizes those outputs so they have zero mean and unit variance

2 Parameters per FEATURE, not layer

Each feature's mean and variance are computed independently. You are NOT computing a single mean and variance for the entire layer; rather, you compute one for each feature.
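
A quick way to see this in PyTorch (the 32×100 shape is just an example): the statistics come out with one entry per feature, not a single scalar.

import torch

x = torch.randn(32, 100)         # batch of 32 examples, 100 features each
mu = x.mean(0, keepdim=True)     # shape (1, 100): one mean per feature
var = x.var(0, keepdim=True)     # shape (1, 100): one variance per feature
print(mu.shape, var.shape)       # torch.Size([1, 100]) torch.Size([1, 100])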

Another view

Pipeline

Linear/Conv layer → BatchNorm → Activation (like ReLU)

Mathematically:

  • Before: $z = g(Wu + b)$
  • After: $z = g(\mathrm{BN}(Wu))$

where $g$ is the activation function and $W$, $b$ are the layer's weights and bias.

The bias term is moved into the batch norm logic

Note that since we normalize $Wu + b$ and then apply a learned shift $\beta$, we don't really need the $+ b$ anymore.
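
A sketch of this pipeline with PyTorch modules (the layer sizes are arbitrary); note bias=False on the linear layer, for the reason above:

import torch.nn as nn

block = nn.Sequential(
    nn.Linear(100, 200, bias=False),  # no + b: the batch norm shift (beta) replaces it
    nn.BatchNorm1d(200),              # normalize each of the 200 features over the batch
    nn.ReLU(),                        # activation
)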

Is batch norm becoming unpopular?

The main concern seems to be that it is expensive; see the discussion at https://www.reddit.com/r/MachineLearning/comments/nnivo6/d_why_is_batch_norm_becoming_so_unpopular/

Batch normalization requires different processing at training and inference times. At inference time, we use the running (moving) averages of the mean and variance collected during training.

This is not the case with LayerNorm: layer normalization is computed across the features of each individual example, so the same set of operations can be used at both training and inference time.
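
A quick illustration with nn.BatchNorm1d (sizes are arbitrary): in train() mode it normalizes with the current batch's statistics and updates its running buffers; in eval() mode it normalizes with the stored running statistics instead.

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(100)
x = torch.randn(32, 100)

bn.train()
y_train = bn(x)   # uses this batch's mean/var, updates bn.running_mean and bn.running_var

bn.eval()
y_eval = bn(x)    # uses the stored running statistics; no dependence on the rest of the batch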

Motivation guided by Andrej Karpathy

Think about the activation functions. If the weights aren't initialized properly, then by the time the pre-activations reach the activation nodes, they will be either very negative or very positive. With a tanh function, the gradient is essentially 0 at those values, so during backpropagation the node acts as a dead neuron: no training/learning can happen there.

Ideally, you want the pre-activations at each node to be approximately Gaussian, but how can you achieve this?

The most common way to initialize is Kaiming initialization. However, as the network gets deeper and deeper, it becomes hard to keep these values well scaled by initialization alone.
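
As a sketch, Kaiming-style scaling for a tanh layer looks like this (the fan-in/fan-out values are arbitrary; 5/3 is the gain commonly recommended for tanh):

import torch

fan_in, fan_out = 30, 200
gain = 5 / 3                                            # recommended gain for tanh
W1 = torch.randn(fan_in, fan_out) * gain / fan_in**0.5  # scale by gain / sqrt(fan_in)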

Batch Normalization solves this by simply normalizing the pre-activations of the layer. However, you also want the activations to be able to adapt: you don't want them to always be unit Gaussian. They should be able to be scaled and shifted, so two extra parameters are learned for this.

This is implemented in part 2 of the notebook:

hpreact = embcat @ W1  # hidden-layer pre-activation; no + b1 (the batch norm bias handles the shift)
# Batchnorm layer
# bngain and bnbias are the learnable scale and shift (typically initialized to ones and zeros)
bnmeani = hpreact.mean(0, keepdim=True)   # mean of each hidden unit across the batch
bnstdi = hpreact.std(0, keepdim=True)     # std of each hidden unit across the batch
hpreact = bngain * (hpreact - bnmeani) / bnstdi + bnbias   # normalize, then scale and shift

with torch.no_grad():  # track running values so we can use them at inference time
    bnmean_running = 0.999 * bnmean_running + 0.001 * bnmeani
    bnstd_running = 0.999 * bnstd_running + 0.001 * bnstdi

At inference time:

embcat = emb.view(emb.shape[0], -1)   # flatten the embeddings
hpreact = embcat @ W1   # + b1 not needed: the batch norm bias handles the shift
hpreact = bngain * (hpreact - bnmean_running) / bnstd_running + bnbias   # use running stats, not batch stats
h = torch.tanh(hpreact)            # (32, 100)
logits = h @ W2 + b2               # (32, 27)
loss = F.cross_entropy(logits, y)

We use these running averages of the batch norm mean and standard deviation for each of the hidden nodes.

  • When you do batch normalization, make sure not to use the bias terms in the preceding layer, since the batch norm bias will take care of that

Problem

One problem with batch norm is that it actually couples different examples. Before, each example was processed independently; with batch norm, the examples in a batch become dependent on each other. In practice this tends to be a good thing, since it acts as a form of regularization.

However, you can easily shoot yourself in the foot.
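
A quick way to see the coupling (sizes are arbitrary): perturbing a single example in the batch changes the normalized outputs of all the other examples.

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4).train()
x = torch.randn(8, 4)
y1 = bn(x)

x2 = x.clone()
x2[0] += 10.0                        # change only the first example
y2 = bn(x2)
print(torch.allclose(y1[1], y2[1]))  # False: the second example's output changed too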

Other Intuition

“It has been long known that the network training converges faster if its inputs are whitened – i.e., linearly transformed to have zero means and unit variances, and decorrelated”.

In Practice

You stack layers in this order, such as in a ResNet (see the sketch after this list):

  • Linear layer
  • Batchnorm layer
  • Activation layer
  • Repeat
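
A sketch of that repeated pattern with PyTorch modules, in the spirit of a ResNet's convolutional layers (channel counts are arbitrary, and the skip connections are omitted):

import torch.nn as nn

channels = [3, 64, 128, 256]
layers = []
for c_in, c_out in zip(channels, channels[1:]):
    layers += [
        nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),  # bias dropped: batch norm shifts instead
        nn.BatchNorm2d(c_out),   # normalize each output channel over the batch (and spatial dims)
        nn.ReLU(),
    ]
net = nn.Sequential(*layers)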

PyTorch

PyTorch applies this kind of scaling by default: nn.Linear initializes its weights from a uniform distribution bounded by 1/sqrt(fan_in), rather than from a Gaussian.
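
A quick check of that default (assuming a recent PyTorch, where nn.Linear draws weights from a uniform distribution bounded by 1/sqrt(fan_in)):

import torch.nn as nn

lin = nn.Linear(100, 200)
bound = 1 / 100 ** 0.5                         # 1/sqrt(fan_in) with fan_in = 100
print(lin.weight.abs().max().item() <= bound)  # True: weights lie in U(-bound, bound)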