Wasserstein GANs and the unexpected mathematics of a new field
This is a slightly edited version of a report I wrote for a summer research internship I did in 2021/2022 (remember, seasons are swapped in Australia). It captures why I'm skeptical of specific mathematical explanations for specific behaviors of deep learning models. It is why I am motivated to work on a general framework for models that describes them as they appear in practice, and which carefully lists the assumptions about structure and behavior that have been made.
In 2014, Ian Goodfellow argued with friends about a method that he, and seemingly only he, believed could get an algorithm to learn and imitate any dataset. That night, he tested the idea. And it worked. Well. That evening was the beginning of Generative Adversarial Networks (GANs), algorithms that massively outperformed state-of-the-art image generation techniques.
To understand how GANs can do something as remarkable as get a rock to generate pictures of faces, imagine an art student and teacher, who are both as dumb as a rock. Ok, rocks that can communicate with strings of numbers and perform countless operations per second. Nonetheless, getting them to do anything imitating creativity is quite the feat.
Initially, the student’s paintings are erratic. The teacher’s comments aren’t much better. However, the teacher learns to at least distinguish between the student’s horrible attempts and actual pictures of, say, faces. Even a rock can do that. In return, the student uses this feedback to improve their images.
Remarkably, over time this works. In GANs, the student is the generator, and the teacher is the discriminator.
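To make the metaphor concrete, here is a minimal sketch of that back-and-forth as a training loop. It assumes PyTorch; the toy network sizes, the `latent_dim` placeholder, and the simple logistic (binary cross-entropy) loss are illustrative choices rather than a reproduction of any particular paper.

```python
# A minimal sketch of the generator/discriminator game, assuming PyTorch.
# Network sizes, latent_dim, and the data feeding are illustrative placeholders.
import torch
import torch.nn as nn

latent_dim, image_dim = 64, 28 * 28

# The "student": maps random noise to a flattened fake image.
generator = nn.Sequential(
    nn.Linear(latent_dim, 256), nn.ReLU(),
    nn.Linear(256, image_dim), nn.Tanh(),
)

# The "teacher": maps a flattened image to a single realness score (a logit).
discriminator = nn.Sequential(
    nn.Linear(image_dim, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
)

opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
bce = nn.BCEWithLogitsLoss()  # the yes-or-no logistic loss

def train_step(real_images):
    batch = real_images.shape[0]
    real_labels = torch.ones(batch, 1)
    fake_labels = torch.zeros(batch, 1)

    # 1. The teacher learns to tell real images from generated ones.
    fakes = generator(torch.randn(batch, latent_dim)).detach()
    d_loss = bce(discriminator(real_images), real_labels) + \
             bce(discriminator(fakes), fake_labels)
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # 2. The student learns to produce images the teacher scores as real.
    g_loss = bce(discriminator(generator(torch.randn(batch, latent_dim))), real_labels)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()
```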
Training GANs, however, was never easy. At the time there was no metric for judging which algorithms outperformed others, beyond subjectively inspecting their outputs. Algorithms would run for hours on end, only for researchers to return and find meaningless outputs.
Other strange phenomena would arise. In what became known as mode collapse, the generator would learn to make slight variations of a single image, failing to capture the diversity of the target dataset.
It was around this time that we needed fundamental theory, grounded in exotic mathematics, to save us. In 2017, Martin Arjovsky identified a critical flaw in the approach used so far.
The initial discriminators used a logistic loss, a yes-or-no judgement, to evaluate outputs. The problem is that the discriminator can very easily be trained to perfectly differentiate generated data from real data. The technical reason is that manifolds of lower dimension than the space they sit in have essentially no overlap with one another, in the same way that a square has zero volume in three-dimensional space.
How this manifests can be thought of as the student’s faces being rejected by the teacher for having brushstrokes, rather than for not being faces. The generator builds images from a low-dimensional space of noise, so its outputs always occupy a lower-dimensional region than real images, and hence a perfect discriminator can always be found.
In a sense, the teacher is too harsh, never offering a student the chance to learn.
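In slightly more formal terms, here is a sketch of the standard argument, using the usual notation where $D$ is the discriminator, $G$ the generator, $p_z$ the noise fed into the generator, and $p_G$ the distribution of generated images. The original GAN plays the minimax game

$$
\min_G \max_D \;\; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] \;+\; \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big].
$$

For a fixed generator, the best possible discriminator turns this objective into (up to constants) the Jensen-Shannon divergence between $p_{\text{data}}$ and $p_G$. When the two distributions sit on non-overlapping low-dimensional manifolds, that divergence is stuck at its maximum value of $\log 2$: the perfect teacher exists, and it hands the student a gradient of essentially zero.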
Traditionally, this was overcome by blurring images. Getting the teacher to wear out-of-focus glasses in the hope that belligerent attention to detail becomes impossible. Unsurprisingly, this just gets the student to paint blurry pictures.
Arjovsky proposed something else: a Wasserstein distance. The real and fake image distributions are compared by their distance, meaning the overall change required to go from one to the other. All of a sudden, specific details wouldn’t lead to a flat rejection from the teacher, and the blurry glasses are no longer needed. However, an algorithm must be trained to estimate this distance.
The discriminator trains to estimate this distance, while the generator learns to reduce it. This is not a yes-or-no judgement, and therefore subtle brushstrokes can’t be used to justify a perfectly confident, harsh rejection.
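Concretely, the quantity in play is the Wasserstein-1, or earth mover’s, distance, shown below in standard notation alongside its Kantorovich-Rubinstein dual form, where $\Pi(p_{\text{data}}, p_G)$ denotes the set of joint distributions pairing up real and generated samples:

$$
W_1(p_{\text{data}}, p_G) \;=\; \inf_{\gamma \in \Pi(p_{\text{data}},\, p_G)} \mathbb{E}_{(x, y) \sim \gamma}\big[\lVert x - y \rVert\big] \;=\; \sup_{\lVert f \rVert_L \le 1} \Big( \mathbb{E}_{x \sim p_{\text{data}}}\big[f(x)\big] - \mathbb{E}_{z \sim p_z}\big[f(G(z))\big] \Big).
$$

The left-hand form is the “overall change” reading: the cheapest transport plan $\gamma$ for moving the generated distribution onto the real one. The right-hand form is what makes training possible: the discriminator plays the role of $f$, restricted to functions whose output never changes faster than the input (the 1-Lipschitz constraint), and the generator minimizes the resulting estimate.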
In early 2018, a team from NVIDIA used this “state-of-the-art” Wasserstein loss to create cutting-edge GANs with high-resolution outputs. The Wasserstein loss seemed to be one of those elegant, clean solutions that capture the essence of an applied problem.
However, enforcing the required restriction on the discriminator, that its score never change faster than the images themselves (the 1-Lipschitz constraint), proved rather difficult. To achieve these remarkable properties, the discriminator also needs to train more frequently than the generator, slowing down training.
The gradient penalty, which punishes the discriminator when its scores change sharply between similar images, proved expensive to compute. In mid-2018, a paper by Mescheder et al. showed that a Wasserstein-loss discriminator does not converge. When the generator is close to mimicking the real data, the discriminator loses its incentive to refine its assessment, like a teacher getting bored of a student as they approach mastery.
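For concreteness, here is roughly what such a gradient penalty looks like, in the form used by WGAN-GP. This is a sketch that assumes PyTorch and image batches in NCHW layout; `critic` stands for any network producing a scalar score, and `gp_weight` is the penalty coefficient (10 in the WGAN-GP paper).

```python
# A sketch of the WGAN-GP gradient penalty, assuming PyTorch and NCHW images.
import torch

def gradient_penalty(critic, real_images, fake_images, gp_weight=10.0):
    batch = real_images.shape[0]
    # Random points on the straight lines between real and fake samples.
    eps = torch.rand(batch, 1, 1, 1, device=real_images.device)
    interpolates = eps * real_images + (1.0 - eps) * fake_images
    interpolates = interpolates.detach().requires_grad_(True)

    scores = critic(interpolates)
    grads = torch.autograd.grad(
        outputs=scores.sum(), inputs=interpolates, create_graph=True
    )[0]
    # Softly enforce the 1-Lipschitz constraint by pushing the critic's
    # gradient norm at these points towards 1.
    grad_norm = grads.reshape(batch, -1).norm(2, dim=1)
    return gp_weight * ((grad_norm - 1.0) ** 2).mean()
```

The extra pass needed to differentiate through these gradients (a gradient of a gradient) at every training step is a large part of why the penalty is costly.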
NVIDIA’s StyleGAN, released in late 2018, was the last iteration of their GANs to use the Wasserstein loss. Instead, a logistic loss with a gradient penalty around the real data started being used, meaning the discriminator is punished whenever its judgement changes sharply near real images, while still attempting to give a rough yes-or-no probability. This “R1 regularized logistic loss” is used in the latest iteration of NVIDIA’s StyleGAN.
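The R1 penalty from Mescheder et al.’s paper is similar in spirit to the one above, but it is applied only at real images, penalizing the size of the discriminator’s gradient there. The sketch below makes the same PyTorch assumptions as above; `gamma` is the regularization weight, and its value here is illustrative.

```python
# A sketch of the R1 penalty, assuming PyTorch; gamma is illustrative.
import torch

def r1_penalty(discriminator, real_images, gamma=10.0):
    real_images = real_images.detach().requires_grad_(True)
    scores = discriminator(real_images)
    grads = torch.autograd.grad(
        outputs=scores.sum(), inputs=real_images, create_graph=True
    )[0]
    # Punish the discriminator for reacting sharply to small changes in
    # real images; no interpolated or fake samples are involved.
    per_sample = grads.reshape(grads.shape[0], -1).pow(2).sum(dim=1)
    return (gamma / 2.0) * per_sample.mean()
```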
The tale of the Wasserstein loss, a mathematically elegant promise to solve a problem plaguing researchers that was then met with the limitations of engineering, should be a warning to researchers. We do not understand deep learning well, at all. It is not obvious which fields of mathematics will be the most relevant to understanding these systems, and many of our elegant attempts will fail to map onto practice.