A Failure Case of VAE

Dec 31, 2021

The Variational Auto-Encoder (VAE) has been widely used to model complex data distributions. In this post we discuss a simple but important failure mode of VAEs, as well as a solution proposed in the paper Spread Divergences to fix the broken VAE. The code to reproduce the experiments behind the figures can be found here.

Introduction to the VAE

Given a dataset $\{x_1,…,x_N\}$ sampled from some unknown data distribution $p_d(x)$, we are interested in using a model $p_\theta(x)$ to approximate $p_d(x)$. A classic way to learn $\theta$ is to minimize the KL divergence
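$$\mathrm{KL}(p_d(x)||p_\theta(x)) = \int p_d(x)\log\frac{p_d(x)}{p_\theta(x)}\,dx;$$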

From the definition of “divergence”, we know that
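$$\mathrm{KL}(p_d(x)||p_\theta(x)) \geq 0, \quad \text{with equality if and only if } p_d = p_\theta.$$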

The integration over $p_d(x)$ can be approximated by the Monte-Carlo method:
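$$\mathrm{KL}(p_d(x)||p_\theta(x)) = \mathbb{E}_{p_d(x)}\!\left[\log\frac{p_d(x)}{p_\theta(x)}\right] \approx \frac{1}{N}\sum_{n=1}^N \log\frac{p_d(x_n)}{p_\theta(x_n)} = \mathrm{const.} - \frac{1}{N}\sum_{n=1}^N \log p_\theta(x_n).$$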

When $N\rightarrow\infty$, and since the first term does not depend on $\theta$, minimizing the KL divergence is equivalent to Maximum Likelihood Estimation (MLE). For a latent variable model $p_\theta(x)=\int p_\theta(x|z)p(z)dz$ with $p_\theta(x|z)$ parameterized by a neural network, the log-likelihood $\log p_\theta(x)$ is usually intractable. We can instead maximize the Evidence Lower Bound (ELBO)
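$$\log p_\theta(x) \geq \mathbb{E}_{q_\phi(z|x)}\!\left[\log\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)}\right] \equiv \mathrm{ELBO}(x,\theta,\phi),$$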

where $q_\phi(z|x)$ is the amortized posterior that is introduced to approximate the true model posterior $p_\theta(z|x) \propto p_\theta(x|z)p(z)$. The training objective becomes maximizing $\frac{1}{N}\sum_{n=1}^N \mathrm{ELBO}(x_n,\theta,\phi)$.

A failure case

Let’s consider the following data generation process for data $x\sim p_d(x)$:

We also plot the data samples:


We attempt to learn this distribution with a VAE of the following form (a minimal code sketch follows the list):

  • $p_\theta(x|z)$: 2D Gaussian with learned mean and variance;
  • $p(z)$: 1D Bernoulli with mean equal to 0.5;
  • $q_\phi(z|x)$: 1D Bernoulli with learned mean.
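Below is a minimal sketch (not the author's released code) of how such a VAE could be set up, assuming PyTorch. Since $z$ is a single Bernoulli variable, the expectation over $q_\phi(z|x)$ in the ELBO can be computed exactly by enumerating $z\in\{0,1\}$; the class name `TinyVAE` and the hidden size are illustrative choices.

```python
# Minimal sketch of the VAE described above (illustrative, assuming PyTorch).
import math
import torch
import torch.nn as nn

class TinyVAE(nn.Module):
    def __init__(self, hidden=32):
        super().__init__()
        # p_theta(x|z): 2D Gaussian with learned mean and log-variance (4 outputs)
        self.decoder = nn.Sequential(nn.Linear(1, hidden), nn.ReLU(), nn.Linear(hidden, 4))
        # q_phi(z|x): 1D Bernoulli with learned mean
        self.encoder = nn.Sequential(nn.Linear(2, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def elbo(self, x):
        q1 = torch.sigmoid(self.encoder(x))            # q(z=1|x), shape [B, 1]
        q = torch.cat([1.0 - q1, q1], dim=1)           # [B, 2]
        # Gaussian log-likelihood log p_theta(x|z) at z = 0 and z = 1
        log_px_z = []
        for z in (0.0, 1.0):
            zt = torch.full((x.shape[0], 1), z)
            mu, log_var = self.decoder(zt).chunk(2, dim=1)
            ll = -0.5 * ((x - mu).pow(2) / log_var.exp() + log_var + math.log(2 * math.pi)).sum(dim=1)
            log_px_z.append(ll)
        log_px_z = torch.stack(log_px_z, dim=1)        # [B, 2]
        # exact expectation over z, plus KL(q(z|x) || Bernoulli(0.5))
        recon = (q * log_px_z).sum(dim=1)
        kl = (q * (q.clamp_min(1e-8).log() - math.log(0.5))).sum(dim=1)
        return (recon - kl).mean()
```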

If we plot the training loss and samples from the resulting trained VAE, we find that the training is unstable and the quality of the samples is poor.


Why does the failure happen?

Computational: The data distribution $p_d(x)$ has zero variance along the y-axis, so the y-axis variance of $p_\theta(x|z)$ converges to $0$ during training and the log-likelihood diverges to infinity.

Theoretical: The data distribution is a 1D manifold distribution that lies in a 2D space. Its density function is not well defined (the distribution is not absolutely continuous w.r.t. the Lebesgue measure), so the KL divergence and MLE are ill-defined or cannot provide valid gradients for training. This problem has also been discussed previously in the Wasserstein GAN paper.

A simple fix: spread KL divergence

A simple trick to fix the problem is to convolve both the data distribution $p_d(x)$ and the model $p_\theta(x)$ with the same noise distribution. For example, we can choose a Gaussian noise distribution $n(\tilde{x}|x)=\mathrm{N}(x,\sigma_{fix}^2 I)$ and let
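$$\tilde{p}_d(\tilde{x}) = \int n(\tilde{x}|x)\,p_d(x)\,dx, \qquad \tilde{p}_\theta(\tilde{x}) = \int n(\tilde{x}|x)\,p_\theta(x)\,dx.$$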

We further define the spread KL divergence $\widetilde{\mathrm{KL}}_n(p_d(x)||p_\theta(x))$ as
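$$\widetilde{\mathrm{KL}}_n(p_d(x)||p_\theta(x)) \equiv \mathrm{KL}(\tilde{p}_d(\tilde{x})||\tilde{p}_\theta(\tilde{x})).$$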

Our paper shows that for certain noise distributions (e.g. Gaussian), we have
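$$\widetilde{\mathrm{KL}}_n(p_d(x)||p_\theta(x)) \geq 0, \quad \text{with equality if and only if } p_d = p_\theta,$$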

so the spread KL is a valid divergence and can be used for model learning. Similarly, spread MLE can also be derived
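$$\theta^* = \arg\max_\theta \frac{1}{N}\sum_{n=1}^N \int n(\tilde{x}|x_n)\log \tilde{p}_\theta(\tilde{x})\,d\tilde{x},$$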

where the integration over $n(\tilde{x}|x_n)$ can be simply approximated using Monte Carlo by adding noise to the data $x_n$. For a latent variable model $p_\theta(x)=\int p_\theta(x|z)p(z)dz$, the ‘noisy’ model becomes
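$$\tilde{p}_\theta(\tilde{x}) = \int n(\tilde{x}|x)\left(\int p_\theta(x|z)p(z)\,dz\right)dx = \int \tilde{p}_\theta(\tilde{x}|z)\,p(z)\,dz, \qquad \tilde{p}_\theta(\tilde{x}|z) = \int n(\tilde{x}|x)\,p_\theta(x|z)\,dx.$$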

Since $p_\theta(x|z) =\mathrm{N}(\mu_\theta(z),\sigma^2_\theta(z))$ and $n(\tilde{x}|x)=\mathrm{N}(x,\sigma^2_{fix} I)$, $\tilde{p}_\theta(\tilde{x}|z)$ will still be a Gaussian, $\tilde{p}_\theta(\tilde{x}|z)=\mathrm{N}(\mu_\theta(z),\sigma^2_\theta(z)+\sigma^2_{fix})$, with closed-form likelihood evaluation. Similar to the ELBO, a lower bound can be further derived for the spread MLE and used as the training objective; see our paper for details.
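In code, the change relative to the sketch above is small: under the Gaussian-noise assumption, we add noise to the data and add $\sigma^2_{fix}$ to the decoder variance when evaluating the Gaussian log-likelihood. A hedged sketch (the helper name and the value of `sigma_fix` are illustrative):

```python
# Sketch of the 'spread' Gaussian log-likelihood (illustrative, assuming PyTorch).
import math
import torch

def spread_gaussian_log_lik(x, mu, log_var, sigma_fix=0.5):
    # convolve the data with the noise distribution: x_tilde ~ N(x, sigma_fix^2 I)
    x_tilde = x + sigma_fix * torch.randn_like(x)
    # the noisy decoder is N(mu_theta(z), sigma_theta^2(z) + sigma_fix^2)
    var = log_var.exp() + sigma_fix ** 2
    # x, mu, log_var: [B, 2]; returns the per-example log-likelihood, shape [B]
    return -0.5 * ((x_tilde - mu).pow(2) / var + var.log() + math.log(2 * math.pi)).sum(dim=1)
```

This term would replace the reconstruction term `log_px_z` in the earlier ELBO sketch.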

If we instead train our VAE on this failure case using the spread KL divergence, we find that training is stable and the model generates good-quality samples.


Why does the spread KL divergence help?

Computational: The variance of $\tilde{p}_\theta(\tilde{x}|z)$ is at least $\sigma_{fix}^2$ even when $\sigma^2_\theta(z)\rightarrow 0$, so $\log \tilde{p}_\theta(\tilde{x}|z)$ is always upper-bounded and well-defined during training.

Theoretical: Any distribution convolved with a Gaussian distribution will be absolutely continuous and have a valid density function, and $\tilde{p}_\theta$ and $\tilde{p}_d$ will have the same support, so the KL divergence or MLE will be well-defined for training.

Discussions

The problem described in this blog post also happens in real-world datasets. For example, the MNIST dataset has a constant black background, so those pixels have zero variance across the dataset. Using a Gaussian $p_\theta(x|z)$ with a learned variance will also lead to an unbounded likelihood in this case.

In general, modeling manifold distributions is an active research direction. I hope this blog post gives some intuition about the fundamental challenges in this field and about where techniques like Spread Divergences may be helpful.


This note was written in 2019, but I still found it interesting at the end of 2021, so I made it my first blog post. I want to thank Peter Hayes for useful feedback. I hope you enjoy it, and happy new year!