Variational autoencoders for dummies


Variational autoencoders are another architecture that can be surprisingly hard to get your head around, given how simple they ultimately are. If you're already very familiar with Bayesian stats, maybe descriptions like the following make sense to you. They don't to me:
Variational Autoencoders (VAEs) incorporate regularization by explicitly learning the joint distribution over data and a set of latent variables that is most compatible with observed datapoints and some designated prior distribution over latent space. The prior informs the model by shaping the corresponding posterior, conditioned on a given observation, into a regularized distribution over latent space (the coordinate system spanned by the hidden representation).
Let's start by recalling the general idea of autoencoders.

Autoencoder

Suppose we have some images (such as the MNIST dataset of handwritten digits). We can build a network composed of two parts:

  1. An encoder, which maps each image to a 2D vector. This 2D space is called the latent space, and the 2D vector for an image is its latent representation.
  2. A decoder, which maps the latent representation back into an image.
The idea is that the two together should be able to take an image (28 x 28 pixels = 784 dimensions), compress it into 2D, and then decompress it back into the same image.
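
To make that concrete, here's a minimal sketch of a plain (non-variational) autoencoder in Keras. The 256-unit intermediate layer, the activations, and the optimizer are illustrative choices, not anything canonical:

 from keras.layers import Input, Dense
 from keras.models import Model

 # Encoder: 784-dimensional image down to a 2D latent vector
 inputs = Input(shape=(784,))
 h = Dense(256, activation='relu')(inputs)
 latent = Dense(2)(h)

 # Decoder: 2D latent vector back up to a 784-dimensional image
 h_decoded = Dense(256, activation='relu')(latent)
 outputs = Dense(784, activation='sigmoid')(h_decoded)

 autoencoder = Model(inputs, outputs)
 autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
 # autoencoder.fit(x_train, x_train, ...)  -- trained to reproduce its own input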

Ideally, your encoding would have some nice structure. For example, all the twos from your dataset would be encoded to 2D points that are near to each other. That way, if you were to pick a vector in that cluster and feed it to the decoder, you would get back an image that looks like a two.


But actually it's possible to overtrain and not get such nice properties: with enough degrees of freedom, the network might be able to "memorize" the data. The encoder could map the inputs to arbitrary 2D vectors, and the decoder would simply memorize their reconstructions, bypassing the need to cluster similar images together.

Variation to the rescue

One clever way around this is to introduce some jitter after the encoding: the encoder produces a 2D vector, then the network randomly nudges it a bit, and the decoder has to reconstruct the original from this nudged vector. This way, the network is incentivized not to have dissimilar images end up in similar places in the latent space. Otherwise you might encode a two, jiggle it, and get back out a zero, leading to a large loss.

A funny way to add variation

Suppose in our autoencoder, the encoder maps a certain image to the vector (0.6, 0.4). This is given by a dense (fully connected) layer with two neurons.

A slightly wonky way of looking at this is that the encoder is mapping the image to two Gaussian distributions. The first distribution has a mean of 0.6 and a variance of zero; the second has a mean of 0.4 and a variance of zero. This can be represented by four neurons (two of which are forced to be zero).

The "jitter" (or "sampling") phase of the network then selects from each distribution, producing a vector to pass on. Since the variance is zero for each, the only thing it can produce is the vector (0.6, 0.4). This is a standard (non-variational) autoencoder.

Suppose we remove the constraint that the variances must be zero. Then the sampling phase would nudge that (0.6, 0.4) value a bit (maybe producing (0.58, 0.43), for example). But because lower variance makes perfect reconstruction easier, the variances may naturally drift back toward zero anyway.
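
To see the nudge concretely, here's a toy sketch in plain numpy (the variances are made-up numbers, chosen only to be small but nonzero):

 import numpy as np

 z_mean = np.array([0.6, 0.4])        # what the encoder produced
 z_var = np.array([0.001, 0.002])     # small but nonzero variances (made-up values)

 epsilon = np.random.normal(size=2)   # one draw from Normal(0, 1) per dimension
 z = z_mean + np.sqrt(z_var) * epsilon
 print(z)                             # something like [0.58  0.43]: a slightly nudged vector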

So we encourage it away from zero by introducing a component to our loss function.

The loss

Our loss function thus far has been a "reconstruction loss": how similar is the reconstructed image to the input? But there's a second thing we want to encourage in a variational autoencoder: we want the generated Gaussians to be constrained in some way.

As mentioned above, we want them to have nonzero variance (or else we just degenerate to a regular autoencoder). We also want the means to be well-controlled. Why? Because VAEs are meant to be generative networks. We want to be able to select a point in the latent space and decode it into something similar to our trained data. And therefore we want our latent space to be densely packed, leaving no "holes" in our encoding (which would probably not have interesting decodings).

For a single image, for each dimension, we have one encoded mean and variance. One thing we can do is try to pull the Gaussian they describe toward a distribution of our choice. A common choice is Normal(0, 1). And luckily, there's a well-understood way to measure how far one distribution is from another. It's known as the Kullback-Leibler (KL) divergence, and for two normal distributions $p = N(\mu_1, \sigma_1^2)$ and $q = N(\mu_2, \sigma_2^2)$ it has a closed form:
$$KL(p, q) = \log \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2 \sigma_2^2} - \frac{1}{2}$$

We can plug in $\mu_2 = 0$, $\sigma_2 = 1$ (our target distribution $q$) and measure how far each encoded Gaussian has strayed from it. That term gets added to our loss.
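
Doing that substitution (just algebra on the formula above) gives, for each latent dimension with encoded mean $\mu$ and standard deviation $\sigma$:
$$KL\big(N(\mu, \sigma^2),\, N(0, 1)\big) = -\log \sigma + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} = -\frac{1}{2}\left(1 + \log \sigma^2 - \mu^2 - \sigma^2\right)$$
The rightmost form is exactly what will show up in the Keras loss code below.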

Putting it all together

Nothing helps you understand stuff like code. So here's sample Keras code for a VAE.

First, the encoder. We'll gloss over the initial parts (which here are just dense layers, but can also be convolution layers, or pretty much anything). The interesting part is that we create latent_dim means and as many variances:

 z_mean = Dense(latent_dim)(h)      # one mean per latent dimension
 z_log_var = Dense(latent_dim)(h)   # one log-variance per latent dimension

Actually we're creating the log of the variance rather than the variance itself: that way the layer's output can be any real number while the variance stays positive, and it makes some of the math below work out more easily. Next, we sample from this distribution:
 def sampling(args):  # wrapped as a function so it can be plugged into a Lambda layer
     z_mean, z_log_var = args
     epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., stddev=1.0)  # one Normal(0, 1) draw per dimension
     return z_mean + K.exp(z_log_var / 2) * epsilon  # scale and shift into Normal(z_mean, z_var)

This looks a bit intimidating, but it's really just a vectorized way of sampling from latent_dim Gaussians. For each dimension, we pick a number from Normal(0, 1) (epsilon) and then scale and shift it so that it appears to have come from Normal(z_mean, z_var). Also note that
$$e^{z\_log\_var / 2} = \sqrt{e^{z\_log\_var}} = \sqrt{z\_var} = z\_stddev$$

Finally, the loss function:
 def vae_loss(x, x_decoded_mean):
     # Reconstruction term: pixel-wise distance between the input and its reconstruction
     xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
     # KL term: the closed-form KL divergence between Normal(z_mean, z_var) and Normal(0, 1)
     kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
     return K.mean(xent_loss + kl_loss)

We have the cross-entropy loss, which is one way to calculate a pixel-wise distance from the image (x) to the reconstruction (x_decoded_mean). We also have the KL loss, which is just the formula we worked out above, summed over the latent dimensions. We multiply the xent_loss by a large number (original_dim) so that it doesn't get overpowered by the kl_loss; Keras's binary_crossentropy averages over the pixels, so the multiplication effectively turns it back into a sum, on the same scale as the summed KL term. If the kl_loss were too powerful, every encoding would collapse to the same Normal(0, 1) distribution, carrying no information about the input, which would be unhelpful.
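
For completeness, here's one way the pieces above might be wired into a trainable model. It follows the layout of the classic Keras VAE example, but the layer sizes, activations, and optimizer are my own illustrative choices, and on recent Keras versions you may need to attach the loss with add_loss rather than passing a closure like this to compile:

 from keras.layers import Input, Dense, Lambda
 from keras.models import Model
 from keras import backend as K
 from keras import metrics

 original_dim = 784       # 28 x 28 MNIST pixels
 intermediate_dim = 256   # illustrative hidden-layer size
 latent_dim = 2

 # Encoder: image -> (z_mean, z_log_var)
 x = Input(shape=(original_dim,))
 h = Dense(intermediate_dim, activation='relu')(x)
 z_mean = Dense(latent_dim)(h)
 z_log_var = Dense(latent_dim)(h)

 # Sampling ("jitter") layer
 def sampling(args):
     z_mean, z_log_var = args
     epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., stddev=1.0)
     return z_mean + K.exp(z_log_var / 2) * epsilon

 z = Lambda(sampling)([z_mean, z_log_var])

 # Decoder: latent vector -> reconstructed image
 h_decoded = Dense(intermediate_dim, activation='relu')(z)
 x_decoded_mean = Dense(original_dim, activation='sigmoid')(h_decoded)

 vae = Model(x, x_decoded_mean)

 # Loss: reconstruction term plus KL term
 def vae_loss(x, x_decoded_mean):
     xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
     kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
     return K.mean(xent_loss + kl_loss)

 vae.compile(optimizer='rmsprop', loss=vae_loss)
 # vae.fit(x_train, x_train, epochs=50, batch_size=100)

Once trained, the generative part is just the decoder half: pick any point in the latent space, run it through the decoder layers, and you get back a new digit-like image.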

