Loss Function

Cross-Entropy Loss

Cross-entropy loss measures the difference between two probability distributions and is the standard loss for classification problems. LLMs are trained with this loss on next-token prediction.


The cross-entropy between a "true" distribution $p$ and an estimated distribution $q$ is defined as:

$$H(p, q) = -\sum_{x} p(x) \log q(x)$$

Intuition

Cross-entropy penalizes confident wrong predictions catastrophically. As the predicted probability $q$ of the true class approaches zero, $-\log q \to \infty$. A mildly-wrong hedge ($q = 0.5$) costs $\log 2 \approx 0.69$, a confidently-wrong prediction ($q = 0.01$) costs $\log 100 \approx 4.6$, and being sure the true class has probability zero is infinitely bad. This is what forces models to be calibrated: they cannot cheaply bluff.
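To make the penalty curve concrete, here is a minimal sketch (using the natural log, as in the definition above; `ce_term` is a name chosen here, not from any library):

```python
import math

def ce_term(q_true: float) -> float:
    """Per-example cross-entropy when the model assigns probability
    q_true to the correct class: -log(q_true), natural log."""
    return -math.log(q_true)

for q in [0.9, 0.5, 0.1, 0.01]:
    print(f"q = {q:>5}: loss = {ce_term(q):.2f}")
```

Note how the loss grows slowly near $q = 1$ but explodes as $q \to 0$.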

Where does this log come from?

This comes from Shannon entropy, $H(p) = -\sum_x p(x) \log p(x)$.

Shannon proved that, under certain very reasonable assumptions, the logarithm is the only possible choice (up to a constant factor, i.e. the choice of base).
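A small sketch of Shannon entropy in nats (the $0 \log 0 = 0$ convention is handled by skipping zero-probability terms):

```python
import math

def entropy(p):
    """Shannon entropy H(p) = -sum_x p(x) log p(x), in nats.
    Terms with p(x) = 0 contribute 0 by convention."""
    return -sum(px * math.log(px) for px in p if px > 0)

print(entropy([0.5, 0.5]))   # fair coin: log 2 nats
print(entropy([1.0, 0.0]))   # certain outcome: 0
```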

For our ML classification problems, let $c$ be the correct class. We can simplify the cross-entropy equation:

  • $p$ vanishes into a one-hot distribution (a Kronecker delta), since $p(x) = \delta_{x,c}$, i.e. $p(c) = 1$ and $p(x) = 0$ for $x \neq c$. So we have

$$H(p, q) = -\log q(c)$$
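The simplification above can be checked numerically: with a one-hot true distribution, the full sum and the $-\log q(c)$ shortcut agree exactly.

```python
import math

def cross_entropy(p, q):
    """H(p, q) = -sum_x p(x) log q(x); zero-probability terms are skipped."""
    return -sum(px * math.log(qx) for px, qx in zip(p, q) if px > 0)

q = [0.1, 0.7, 0.2]      # model's predicted distribution
p = [0.0, 1.0, 0.0]      # one-hot "true" distribution, correct class c = 1

full = cross_entropy(p, q)
shortcut = -math.log(q[1])   # the simplified form: -log q(c)
print(full, shortcut)        # identical
```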

PyTorch Cross-Entropy Loss

https://docs.pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html

  • $n$ spans the minibatch dimension (i.e. there are $N$ samples in a batch). If this is already confusing to you, see a more basic example in L1 Loss.

There are 2 different formulations depending on how the target classes are given: as class indices or as class probabilities.

If you use class indices as the target:

$$\ell_n = -w_{y_n} \log \frac{\exp(x_{n, y_n})}{\sum_{c=1}^{C} \exp(x_{n, c})}$$

where:

  • $x_n$ is the logit vector for the $n$-th example
  • $C$ is the number of classes
  • $y_n$ is the correct class for the $n$-th example, $y_n \in \{0, \ldots, C-1\}$
  • $\ell_n$ is the loss for the $n$-th example
  • $w$ is an optional weight vector over the classes
  • $x_{n,c}$ is the value of the logit at index $c$ for the $n$-th example
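A pure-Python sketch of the documented per-sample formula for class-index targets (this is an illustration of the formula, not PyTorch's actual implementation; `ce_with_indices` is a name chosen here):

```python
import math

def ce_with_indices(logits, targets, weights=None):
    """Sketch of the class-index formula: l_n = -w_{y_n} * log_softmax(x_n)[y_n].
    logits: N rows of C raw scores; targets: N class indices;
    weights: optional per-class weights (defaults to all ones)."""
    C = len(logits[0])
    w = weights or [1.0] * C
    losses = []
    for x_n, y_n in zip(logits, targets):
        m = max(x_n)                                   # stabilize the exponentials
        log_z = m + math.log(sum(math.exp(v - m) for v in x_n))
        losses.append(-w[y_n] * (x_n[y_n] - log_z))
    # PyTorch's default reduction='mean' divides by the sum of the
    # selected weights w_{y_n}, not simply by N.
    return sum(losses) / sum(w[y] for y in targets)

loss = ce_with_indices([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]], [0, 1])
print(loss)
```

With uniform logits over $C$ classes, this returns $\log C$, as expected from the formula.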

Summary

Model outputs logits → softmax(logits) → log(softmax(logits)) → take the negative log-probability of the true class.

The cross-entropy loss is calculated over the log of the softmax (in practice computed as a single, numerically stable log-softmax).
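The pipeline above, written out step by step (the naive two-step softmax-then-log version, kept separate here for clarity; real implementations fuse the two for numerical stability):

```python
import math

logits = [2.0, 0.5, 0.1]                     # raw model outputs
target = 0                                   # index of the true class

# step 1: softmax(logits) -> probabilities
z = sum(math.exp(v) for v in logits)
probs = [math.exp(v) / z for v in logits]

# step 2: log of the softmax probabilities
log_probs = [math.log(p) for p in probs]

# step 3: cross-entropy = negative log-probability of the true class
loss = -log_probs[target]
print(loss)
```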

Exercise: