Intro
You've probably seen or used cross-entropy loss as a cost function for a classifier:
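$$\text{loss} = -\sum_{c} y_c \log \hat{y}_c$$
(Formula image from Udacity, written out here with $y_c$ as the actual value and $\hat{y}_c$ as the predicted value for class $c$.)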
To calculate the loss for a single example, we sum over all classes, multiplying the actual value by the log of the predicted value. For example, if the three classes are (cat, dog, rabbit), then a label (1.0, 0.0, 0.0) means the example is actually a cat, and if our prediction is (0.7, 0.2, 0.1) then our loss is simply -log(0.7). Furthermore, to evaluate the cost over all examples, we just sum their losses.
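As a quick check of the cat/dog/rabbit arithmetic, here is a minimal sketch in plain Python. The function name `cross_entropy` is my own choice for illustration, and I'm assuming the natural log. Classes with a zero label are skipped, which also avoids taking log(0) for classes that don't contribute.

```python
import math

def cross_entropy(label, prediction):
    """Sum over classes of -label * log(prediction) for one example."""
    # Classes whose label is 0 contribute nothing, so skip them.
    return -sum(y * math.log(p) for y, p in zip(label, prediction) if y > 0)

label = (1.0, 0.0, 0.0)        # the example is actually a cat
prediction = (0.7, 0.2, 0.1)   # predicted probabilities for (cat, dog, rabbit)

loss = cross_entropy(label, prediction)
print(loss)  # equals -log(0.7), roughly 0.357
```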
The question is: what is our theoretical justification for doing so?
Background: probability and statistics
Let's start at the beginning, at the difference between probability and statistics. If I tell you I have a fair coin and ask you for the probability of flipping "HTTH," that's probability. If I tell you I flipped "HTTH" and ask you the likelihood that it's a fair coin, that's statistics.
Note the distinction between the words "probability" and "likelihood." They are not interchangeable.
Let's look more closely at likelihood. If I tell you I flipped "HHHHHHH," it seems unlikely that the coin is fair -- unless, for example, we know that I selected it from a group of coins known to be fair. The distribution over the possible coins I might have picked is called the "prior distribution," and we won't talk about it too much here. If we don't know anything about the prior, then we might say it's likely that the coin is biased.
Moreover, there is some bias which maximizes the probability of this outcome. In this case (without more information) it seems most likely that it's biased 100% toward heads. We say that 100% is the maximum likelihood estimate for the bias of the coin. If we next flip a T, we'll have to adjust it downward.
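For a coin, the maximum likelihood estimate of the bias toward heads works out to the observed fraction of heads (a standard result, stated here without proof). A small sketch, with `mle_bias` as a name I made up for illustration, shows the estimate dropping once a tail appears:

```python
def mle_bias(flips):
    """Maximum likelihood estimate of P(H): the observed fraction of heads."""
    return flips.count("H") / len(flips)

print(mle_bias("HHHHHHH"))   # 1.0
print(mle_bias("HHHHHHHT"))  # 0.875
```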
In this example, the bias-toward-heads of the coin (b) is a parameter to the probability function. If we want to know P(H), we have to know b. Thus instead of writing P(x) we write P(x; b): the probability (or in the continuous case, probability density) of getting outcome x given bias b.
Normally instead of b we call the parameter(s) θ, and in the below we'll use f instead of P. A coin is parameterized by only one value, but in general there may be many parameters (all part of θ).
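Now for some math.
$$L(\theta \mid x_1, \ldots, x_n) = f(x_1, \ldots, x_n \mid \theta) = \prod_{i=1}^{n} f(x_i \mid \theta)$$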
This just says that the likelihood (L) of θ given a set of observations xi is equal to the probability of the observations given θ, and they are both equal to the product of the probabilities of the individual outcomes. (This assumes the observations are independent and identically distributed.)
Let's calculate the likelihood of θ=0.5 (a fair coin) given the tosses "HHHHHHH." It's f(H | 0.5) multiplied together 7 times, i.e. (1/2)^7 = 1/128. The likelihood of θ=1.0 is 1^7 = 1. Remember, don't interpret those as probabilities of θ. They just tell us that θ=1.0 is far more likely than θ=0.5.
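The two likelihoods can be reproduced with exact fractions. This is only a sketch: `likelihood` is a hypothetical helper, and it assumes each flip is independent with P(H) = θ.

```python
from fractions import Fraction

def likelihood(theta, flips):
    """Product of the per-flip probabilities, assuming independent flips."""
    result = Fraction(1)
    for flip in flips:
        result *= theta if flip == "H" else 1 - theta
    return result

print(likelihood(Fraction(1, 2), "HHHHHHH"))  # 1/128
print(likelihood(Fraction(1), "HHHHHHH"))     # 1
```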
Neural network as a model generator
We can interpret our neural network as something that generates a probability distribution f(θ, x), where θ is the network parameters (which we're trying to optimize). We treat the labels of our data as the observed outcomes (like the coin tosses), and ask: what values of θ make those outcomes most likely?
Using the above formula, since we're trying to maximize L(θ | xi), we can just maximize the product on the right. That is, we maximize the product of f(xi | θ) across all training examples, where f(x | θ) is the probability of getting label x given our model. (Normally, we use 'x' as our input and 'y' as the label, so this might be a bit confusing. This should really be f(y | θ, x), since the network parameters and the example data are both inputs.)
Now notice that:
- The maximum of f(x) and the maximum of log(f(x)) occur at the same x, since log is monotonically increasing.
- log (x1 * x2 ... * xn) = log(x1) + log(x2) + ... + log(xn)
So to maximize that product, we can just maximize the sum of log(f(xi | θ)) across all examples. Or alternatively, minimize the sum of -log(f(xi | θ)).
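To see that the two objectives agree, here is a rough sketch that goes back to the coin. The flip sequence is made up, and the grid search over candidate biases is just a simple stand-in for a real optimizer. Maximizing the product and minimizing the sum of negative logs should pick the same θ.

```python
import math

flips = "HHTHHHTH"  # hypothetical data: 6 heads, 2 tails

def prob(flip, theta):
    """Probability of a single flip for a coin with P(H) = theta."""
    return theta if flip == "H" else 1.0 - theta

def product_likelihood(theta):
    return math.prod(prob(f, theta) for f in flips)

def negative_log_likelihood(theta):
    return -sum(math.log(prob(f, theta)) for f in flips)

# Candidate biases 0.01 .. 0.99; 0 and 1 are skipped to avoid log(0).
candidates = [i / 100 for i in range(1, 100)]
best_by_product = max(candidates, key=product_likelihood)
best_by_nll = min(candidates, key=negative_log_likelihood)
print(best_by_product, best_by_nll)  # both 0.75
```

In practice the sum of logs is also kinder numerically: a product of many small probabilities can underflow to zero, while the sum of their logs stays in a comfortable range.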
Applying this to our first example, we find that f(cat) = 0.7 given our predicted model of (0.7, 0.2, 0.1), so the corresponding summand in our desired sum is -log(0.7) -- which matches the value of the loss function we started with. In that example, we also summed across categories for the single example, but since the other "actual" values are always zero (if it's a cat, it's never a dog or rabbit), those terms disappear.
So now we know why this formula helps us find the best θ. But why is it called "cross-entropy loss?"
Cross entropy
Entropy is a measure of the information produced by a stochastic process. If you have a stream of information and want to encode it as densely as possible, it helps to encode the more common elements with fewer bits than the less common elements. For example, Morse code uses a single dot for the very common letter 'E' and much longer sequences of dots and dashes for less common letters like 'J' and 'Q'.
If you know the frequencies of the letters, the best possible encoding will require, on average, this many bits per letter:
$$H(p) = -\sum_{x} p(x) \log_2 p(x)$$
Where p(x) is the frequency / probability of letter x. This is called the (Shannon) entropy.
But you will often not produce the optimal encoding. In that case, the "cross" entropy is given by:
$$H(p, q) = -\sum_{x} p(x) \log_2 q(x)$$
p(x) is the same as above, since the actual distribution of letters hasn't changed. q(x) represents your (incorrect) prediction of the frequencies. The cross-entropy is never less than the entropy, and the two are equal only when p=q. Their difference (known as the Kullback-Leibler divergence) gives a measure of how far off q is from p. By minimizing cross-entropy with respect to q, we bring q as close to the true distribution as possible.
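A small numeric sketch may make the gap concrete. The distributions p and q are made up for illustration, and the logs are base 2 so the results are in bits.

```python
import math

def entropy(p):
    """Shannon entropy of p, in bits."""
    return -sum(px * math.log2(px) for px in p if px > 0)

def cross_entropy(p, q):
    """Average bits per symbol when data from p is encoded as if it came from q."""
    return -sum(px * math.log2(qx) for px, qx in zip(p, q) if px > 0)

p = (0.5, 0.25, 0.25)   # hypothetical true symbol frequencies
q = (0.25, 0.25, 0.5)   # hypothetical mistaken estimate of them

print(entropy(p))            # 1.5
print(cross_entropy(p, q))   # 1.75
print(cross_entropy(p, p))   # 1.5, same as the entropy when q equals p
```

Here the wrong estimate costs an extra 0.25 bits per symbol on average.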
Again, in our use of cross-entropy, the sum across categories can disappear since there's only one nonzero value in the label. We still sum across examples to get the total cost of our model.
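As a sketch of the total cost over a few examples: the predictions and labels below are invented, and I'm assuming each label is one-hot, so only the predicted probability of the true class matters. Summing the per-example losses gives the same number as taking the negative log of the product of those probabilities, which is the likelihood from earlier.

```python
import math

# Hypothetical predictions for three examples over (cat, dog, rabbit).
predictions = [
    (0.7, 0.2, 0.1),
    (0.1, 0.8, 0.1),
    (0.3, 0.3, 0.4),
]
true_classes = [0, 1, 2]  # index of the correct class for each example

# With one-hot labels, each example's loss is -log(probability of its true class).
total_cost = -sum(math.log(pred[k]) for pred, k in zip(predictions, true_classes))

# Same quantity written as the negative log of the product of those probabilities.
likelihood = math.prod(pred[k] for pred, k in zip(predictions, true_classes))
print(total_cost, -math.log(likelihood))
```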
Conclusion
Cross-entropy is a common loss function to use when computing cost for a classifier. We can view it as a way of comparing our predicted distribution (in our example, (0.7, 0.2, 0.1)) against the true distribution (1.0, 0.0, 0.0), and we see that minimizing the sum of this loss also gives us the maximum likelihood estimate for the network parameters.
Thanks to Rob DiPietro, whose blog explains all of this better than mine does.