Generative Adversarial Networks for dummies

I've found it surprisingly hard to find a simple explanation of GANs online. Most either leave out the key details or point you at the original research paper to make sense of them. Here I hope to give a rough overview that covers the essential details.

What is a GAN?

I'm assuming you already have some idea if you landed here, and anyway other sites do a fine job of explaining this. But basically, a Generative Adversarial Network pits two networks against each other: one (the Generator, G) tries to generate realistic-looking images (or whatever data is in your domain), while the other (the Discriminator, D) tries to identify those generated images as fakes. You have a pool of real images that D should learn to recognize as real and that G will eventually learn to imitate.

Architecture details

The Discriminator is any network that takes in an image (say) and outputs a single value between 0 and 1: near 0 for "image is fake," near 1 for "image is real." (The targets are binary, but the output itself must be continuous so the losses below are differentiable.) A convolutional network might be a good choice here, for example.
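
To make this concrete, here's a minimal discriminator sketch in TensorFlow 1.x-style code. The specific architecture (two conv layers with leaky ReLU, a 28x28 grayscale input) is just one reasonable choice for illustration, not something the GAN framework requires:

    import tensorflow as tf  # TensorFlow 1.x graph API assumed throughout

    def discriminator(image, reuse=False):
        # Maps a batch of images, shape [batch, 28, 28, 1], to a value in
        # (0, 1): D's estimate of the probability that the image is real.
        with tf.variable_scope('discriminator', reuse=reuse):
            h = tf.layers.conv2d(image, 32, 5, strides=2, padding='same',
                                 activation=tf.nn.leaky_relu)
            h = tf.layers.conv2d(h, 64, 5, strides=2, padding='same',
                                 activation=tf.nn.leaky_relu)
            h = tf.layers.flatten(h)
            logit = tf.layers.dense(h, 1)  # unbounded "realness" score
            return tf.nn.sigmoid(logit)    # squashed into (0, 1)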

The Generator takes as input a random vector (z) drawn from some distribution (often Normal(0, 1)) and generates an image from it. This way, once the Generator learns to generate realistic images, you can produce more of them by drawing new samples from the same distribution and feeding them into G.
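
A matching generator sketch, assuming a 100-dimensional z and the same 28x28 output as above (again, arbitrary illustrative choices):

    def generator(z):
        # Maps noise z ~ Normal(0, 1), shape [batch, 100], to an image of
        # shape [batch, 28, 28, 1] with pixel values in (-1, 1).
        with tf.variable_scope('generator'):
            h = tf.layers.dense(z, 7 * 7 * 64, activation=tf.nn.relu)
            h = tf.reshape(h, [-1, 7, 7, 64])
            h = tf.layers.conv2d_transpose(h, 32, 5, strides=2, padding='same',
                                           activation=tf.nn.relu)
            return tf.layers.conv2d_transpose(h, 1, 5, strides=2, padding='same',
                                              activation=tf.nn.tanh)

Once trained, sampling a new image is just a matter of drawing a fresh z from Normal(0, 1) and running it through generator().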

The output of G is fed as the input to D, so that (a) D can be given a fake image to train on, and (b) gradients can flow back through D and into G, letting both networks learn simultaneously.
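
Wiring the two sketches together might look like this (batch size and z dimension are arbitrary; note the reuse flag, which makes the second discriminator call share the first call's weights):

    batch_size, z_dim = 64, 100

    real = tf.placeholder(tf.float32, [batch_size, 28, 28, 1])  # real images
    z = tf.random_normal([batch_size, z_dim])  # random input to G
    fake = generator(z)                        # f = G(z)

    y_x = discriminator(real)              # D's verdict on a real image
    y_f = discriminator(fake, reuse=True)  # D's verdict on the fake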

Training

This piece is often left out of descriptions. The process goes like this:
  1. Pick our random input (z) to G by drawing from Normal(0, 1) (or whatever distribution we picked).
  2. Feed it to G, which outputs a fake (f).
  3. Feed f to D and get y_f. Also feed a real training image (x) to D and get y_x.
  4. Calculate discriminator loss (loss_d):  -log(y_x) - log(1 - y_f). This is minimized when y_x is 1 and y_f is 0. In other words, when D categorizes image x as real (1) and f as fake (0).
  5. Calculate generator loss (loss_g): -log(y_f). This is minimized when y_f is 1. In other words, when D classifies the fake as real (1).
  6. Backpropagate, ensuring that only D learns from loss_d and only G learns from loss_g. This is important. If D were to learn from loss_g, then it would move toward classifying fakes as real. If G were to learn from loss_d, it would generate worse fakes.
In TensorFlow, an Optimizer's minimize() function takes a var_list argument that lets you specify which variables should learn. You create one Optimizer to minimize loss_d (with D's variables) and one for loss_g (with G's), as in the sketch below.
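
Continuing the wiring sketched earlier (with y_x and y_f as defined there), the losses and the two optimizers might look like this; Adam and the learning rate here are illustrative choices, not requirements:

    eps = 1e-8  # keep log() away from zero

    # Step 4: D wants y_x -> 1 and y_f -> 0.
    loss_d = -tf.reduce_mean(tf.log(y_x + eps) + tf.log(1.0 - y_f + eps))
    # Step 5: G wants to fool D, i.e. y_f -> 1.
    loss_g = -tf.reduce_mean(tf.log(y_f + eps))

    # Step 6: var_list restricts each optimizer to one network's variables.
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
    g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
    train_d = tf.train.AdamOptimizer(2e-4).minimize(loss_d, var_list=d_vars)
    train_g = tf.train.AdamOptimizer(2e-4).minimize(loss_g, var_list=g_vars)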

To summarize (a runnable sketch of this loop follows the list):
  1. Generate a fake image (by feeding G a random input).
  2. Feed the fake and one real image to D (you can think of this as two "copies" of D with the same weights).
  3. Calculate loss_d and backpropagate through D.
  4. Calculate loss_g and backpropagate all the way back through G (without updating D).
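
Put together, one pass through these four steps might look like the loop below; real_batch() is a hypothetical stand-in for whatever yields your real training images:

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(10000):
            x = real_batch(batch_size)  # hypothetical: next real batch
            # z is drawn inside the graph, so each run sees fresh noise.
            sess.run(train_d, feed_dict={real: x})  # update D on loss_d
            sess.run(train_g)                       # update G on loss_g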

Sometimes people get confused by the "two copies" idea, but it's quite straightforward. Imagine that you have a network with fixed weights. You run some data (x) through it and get a result. You run some other data (f) and get a second result. You calculate some loss that's a function of both results. Jittering any parameter of D (weights, biases, etc.) will affect that loss function in some precise way, so you can calculate which way to jitter each parameter to decrease that loss in exactly the same way you would with a single "copy."
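
If it helps, here's that idea in miniature, stripped of everything GAN-specific (a toy one-layer net, purely for illustration):

    import tensorflow as tf

    def tiny_net(inp, reuse=False):
        # One set of weights; reuse=True makes a second call share them.
        with tf.variable_scope('shared', reuse=reuse):
            return tf.layers.dense(inp, 1)

    a = tf.placeholder(tf.float32, [None, 4])
    b = tf.placeholder(tf.float32, [None, 4])
    out_a = tiny_net(a)              # first "copy"
    out_b = tiny_net(b, reuse=True)  # second "copy", identical weights

    # A loss that depends on both outputs still has a well-defined gradient
    # with respect to the single shared set of parameters.
    loss = tf.reduce_mean(tf.square(out_a) + tf.square(out_b))
    shared_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'shared')
    grads = tf.gradients(loss, shared_vars)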

