Cross-Entropy Loss
Cross-entropy loss measures the difference between two probability distributions, and is the standard loss for classification problems. LLMs use this loss to train on next-token prediction.
The cross-entropy between a "true" distribution $p$ and an estimated distribution $q$ is defined as:

$$H(p, q) = -\sum_{x} p(x) \log q(x)$$
Intuition
Cross-entropy penalizes confident wrong predictions catastrophically. As the predicted probability $q$ of the true class approaches zero, $-\log q \to \infty$. A mildly-wrong hedge (e.g. $q = 0.5$) costs $-\log 0.5 \approx 0.69$ nats, a confidently-wrong prediction ($q = 0.01$) costs $-\log 0.01 \approx 4.6$ nats, and being sure the true class has probability zero ($q = 0$) is infinitely bad. This is what forces models to be calibrated: they cannot cheaply bluff.
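The blow-up of $-\log q$ is easy to see numerically. A quick sketch (the probabilities here are just illustrative values):

```python
import math

# -log(q), where q is the predicted probability of the TRUE class:
# hedged mistakes are cheap, confident mistakes are catastrophic.
for q in [0.9, 0.5, 0.1, 0.01]:
    print(f"q = {q:>4}: loss = {-math.log(q):.2f} nats")
```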
Where does this log come from?
This comes from Shannon Entropy, $H(p) = -\sum_x p(x) \log p(x)$.
Shannon proved that under certain very reasonable axioms (continuity, monotonicity, and additivity over independent events), $\log$ is the only possible choice.
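As a tiny illustration of the definition above (a sketch in plain Python; the distributions are made up): a fair coin has entropy $\log 2$ nats, while a certain outcome carries no information.

```python
import math

def entropy(p):
    """Shannon entropy H(p) = -sum_x p(x) log p(x), in nats."""
    # terms with p(x) = 0 contribute 0 (lim p log p = 0), so skip them
    return -sum(px * math.log(px) for px in p if px > 0)

print(entropy([0.5, 0.5]))  # fair coin: log 2 ≈ 0.693
print(entropy([1.0, 0.0]))  # certain outcome: zero entropy
```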
For our ML classification problems, let $c$ be the correct class. We can simplify the cross-entropy equation:
- $p(x)$ vanishes into a one-hot distribution (a Kronecker delta), since we have $p(x) = \delta_{x,c}$. So we have

$$H(p, q) = -\log q(c)$$
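To see the collapse concretely, here is a small sketch (the distribution $q$ is illustrative): with a one-hot $p$, every term of the full sum except the one at $c$ drops out.

```python
import math

q = [0.1, 0.7, 0.2]   # model's predicted distribution over 3 classes
c = 1                 # index of the correct class
p = [1.0 if i == c else 0.0 for i in range(3)]  # one-hot "true" distribution

full = -sum(p[i] * math.log(q[i]) for i in range(3))  # H(p, q)
simplified = -math.log(q[c])                          # -log q(c)
print(full, simplified)  # equal: terms with p(x) = 0 contribute nothing
```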
PyTorch Cross-Entropy Loss
https://docs.pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
- $N$ spans the minibatch dimension (i.e. there are $N$ samples in a batch). If this is already confusing to you, see a more basic example in L1 Loss.
There are 2 different formulations depending on how the target classes are given.

If you use class indices as the target:

$$l_n = -w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^{C} \exp(x_{n,c})}$$

where:
- $x_n$ is the vector of logits for the $n$-th example
- $C$ is the number of classes
- $y_n$ is the correct class for the $n$-th example, $y_n \in \{0, \dots, C-1\}$
- $l_n$ is the loss for the $n$-th example
- $w_{y_n}$ is an optional weight for the class (useful for class imbalance)
- $x_{n,c}$ is the value of the logit at index $c$ for the $n$-th example

This is the Negative Log Likelihood of the correct class under the softmax distribution.

If you use class probabilities as the target:

$$l_n = -\sum_{c=1}^{C} w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^{C} \exp(x_{n,i})} \, y_{n,c}$$

where $y_{n,c}$ is the target probability of class $c$ for the $n$-th example.
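A quick check that `torch.nn.functional.cross_entropy` (the functional form of `nn.CrossEntropyLoss`) matches the hand-rolled negative-log-softmax. The logits here are made up; `cross_entropy` averages over the batch by default:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])  # shape (N=2, C=3)
targets = torch.tensor([0, 2])            # class indices

# Built-in: log-softmax + negative log likelihood, mean over the batch.
loss = F.cross_entropy(logits, targets)

# Manual equivalent: pick -log_softmax at the target index per example.
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()
print(loss.item(), manual.item())  # the two values agree
```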
Summary
Model outputs logits → softmax(logits) → log(softmax(logits)) → negate the entry at the true class.
The cross-entropy loss is calculated over the softmax of the logits, not over the raw logits.
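The pipeline above, written out in plain Python for a single example (a sketch with illustrative logits; subtracting the max before `exp` is the standard numerical-stability trick):

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max so exp() never overflows
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 0.5, -1.0]
true_class = 0
probs = softmax(logits)
loss = -math.log(probs[true_class])  # cross-entropy for this example
print(loss)
```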
Exercise:
- Explain how the Behavior Cloning loss is derived, and how it ends up being the Cross-Entropy Loss (which is where the log comes from).