A few weeks ago, I introduced generative adversarial networks (GANs), a class of generative models, and described the difficulties of training them. Not long after that post, a group of researchers from Facebook and the Courant Institute introduced the Wasserstein GAN, which uses the Wasserstein distance, also known as the Earth Mover (EM) distance, instead of the Jensen-Shannon (JS) divergence as the cost function.
In their paper (arXiv:1701.07875), they proposed the following changes to improve GAN training:
- do not use a sigmoid as the activation function of the discriminator's output layer;
- the cost functions of the discriminator and generator must not contain logarithms;
- clip the discriminator's parameters to a fixed range after each training step; and
- do not use momentum-based optimizers such as SGD with momentum or Adam; RMSprop or plain SGD is recommended instead.
These recommendations are mostly empirical rather than theoretically derived. However, the most important change is to use the Wasserstein distance as the cost function, which the authors explained in detail.
There are many metrics for measuring the difference between probability distributions, as summarized by Gibbs and Su (arXiv:math/0209021). The authors of Wasserstein GAN discussed four of them: the total variation (TV) distance, the Kullback-Leibler (KL) divergence, the Jensen-Shannon (JS) divergence, and the Earth Mover (EM, or Wasserstein) distance. They used an example of two uniform distributions on parallel line segments to illustrate that only the EM distance changes continuously as one distribution moves toward the other. This addresses the problem of training stalling when the gradient reaching the generator becomes too small. The EM distance indicates, intuitively, how much “mass” must be transported to turn one distribution into the other.
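Formally, the EM distance is

$$W(\mathbb{P}_r, \mathbb{P}_g) = \inf_{\gamma \in \Pi(\mathbb{P}_r, \mathbb{P}_g)} \mathbb{E}_{(x, y) \sim \gamma}\left[\|x - y\|\right],$$

where $\Pi(\mathbb{P}_r, \mathbb{P}_g)$ is the set of joint distributions whose marginals are the real distribution $\mathbb{P}_r$ and the generated distribution $\mathbb{P}_g$, and the training involves finding the optimal transport plan $\gamma$.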
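To make the comparison of the four distances concrete, here is a small sketch in plain Python. It is not the paper's two-dimensional example itself but a one-dimensional discrete stand-in that I am assuming for illustration: two point masses on a grid, one fixed at bin 0 and one shifted by a few bins. The helper names (`point_mass`, `tv`, `kl`, `js`, `em_1d`) are my own.

```python
import math

def point_mass(index, size):
    """Histogram with all probability in one bin (illustrative helper)."""
    return [1.0 if i == index else 0.0 for i in range(size)]

def tv(p, q):
    """Total variation distance between two histograms."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))

def kl(p, q):
    """KL(p || q); infinite when p has mass where q has none."""
    total = 0.0
    for a, b in zip(p, q):
        if a > 0.0:
            if b == 0.0:
                return math.inf
            total += a * math.log(a / b)
    return total

def js(p, q):
    """Jensen-Shannon divergence, built from KL against the mixture."""
    m = [0.5 * (a + b) for a, b in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def em_1d(p, q, spacing=1.0):
    """EM distance on a line: the summed gap between the two CDFs."""
    total, cdf_p, cdf_q = 0.0, 0.0, 0.0
    for a, b in zip(p, q):
        cdf_p += a
        cdf_q += b
        total += abs(cdf_p - cdf_q) * spacing
    return total

if __name__ == "__main__":
    p = point_mass(0, 8)
    for shift in range(4):
        q = point_mass(shift, 8)
        print(shift, tv(p, q), kl(p, q), js(p, q), em_1d(p, q))
```

For any nonzero shift, TV stays at 1, KL is infinite, and JS stays at log 2, so none of them says whether the two distributions are getting closer. Only the EM value shrinks in proportion to the shift, which is the behaviour the parallel-lines example is meant to show.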
Unfortunately, the EM distance cannot be computed directly from the definition. However, as an optimization problem, it has a corresponding dual problem. While the authors did not explain the dual in much detail, Vincent Herrmann explained it thoroughly on his blog.
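For reference, the dual form used in the paper is the Kantorovich-Rubinstein duality. In the notation assumed here, $\mathbb{P}_r$ is the real data distribution, $\mathbb{P}_g$ is the generated distribution, and the supremum runs over all 1-Lipschitz functions $f$:

$$W(\mathbb{P}_r, \mathbb{P}_g) = \sup_{\|f\|_L \le 1} \; \mathbb{E}_{x \sim \mathbb{P}_r}\left[f(x)\right] - \mathbb{E}_{x \sim \mathbb{P}_g}\left[f(x)\right]$$

In practice $f$ is approximated by a neural network whose weights are kept in a compact set, which is where the parameter clipping comes from. The expression contains neither a logarithm nor a sigmoid, since $f$ outputs an unbounded score rather than a probability.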
The algorithm is described as follows:
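As a rough illustration of how the practical changes listed earlier fit together in a training loop, here is a toy sketch in plain Python. It is not the paper's algorithm verbatim. I am assuming a one-parameter linear critic `f(x) = w * x`, a generator `g(z) = z + theta`, gradients derived by hand, and the made-up function names `rmsprop_step` and `train_toy_wgan`. The values of `n_critic`, `clip`, and `batch` mirror the defaults given in the paper, while the learning rate is chosen only to suit this toy problem.

```python
import random

def rmsprop_step(param, grad, cache, lr, decay=0.9, eps=1e-8):
    """Apply one RMSprop update and return (new_param, new_cache)."""
    cache = decay * cache + (1.0 - decay) * grad * grad
    return param - lr * grad / (cache ** 0.5 + eps), cache

def train_toy_wgan(real_mean=3.0, steps=1000, n_critic=5, clip=0.01,
                   batch=64, lr=0.01, seed=0):
    """Fit theta so that z + theta matches samples from N(real_mean, 1)."""
    rng = random.Random(seed)
    w, theta = 0.0, 0.0              # critic weight and generator shift
    w_cache, theta_cache = 0.0, 0.0  # RMSprop running averages
    for _ in range(steps):
        # Train the critic several times per generator update.
        for _ in range(n_critic):
            real = [rng.gauss(real_mean, 1.0) for _ in range(batch)]
            fake = [rng.gauss(0.0, 1.0) + theta for _ in range(batch)]
            # Critic loss: mean f(fake) - mean f(real); no sigmoid, no log.
            # Its gradient with respect to w is mean(fake) - mean(real).
            grad_w = (sum(fake) - sum(real)) / batch
            w, w_cache = rmsprop_step(w, grad_w, w_cache, lr)
            # Clip the critic weight to [-clip, clip] after every update.
            w = max(-clip, min(clip, w))
        # Generator loss: -mean f(g(z)) = -w * mean(z + theta),
        # so its gradient with respect to theta is simply -w.
        theta, theta_cache = rmsprop_step(theta, -w, theta_cache, lr)
    return theta, w

if __name__ == "__main__":
    theta, w = train_toy_wgan()
    print("learned shift:", theta, "critic weight:", w)
```

With a linear critic the generator can only learn to match the mean of the real data, so this shows the structure of the loop (several critic steps, clipping, RMSprop, a loss without logarithms) rather than the modelling power of a real network.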