A Failure Case of VAE

The Variational Auto-Encoder (VAE) has been widely used to model complex data distributions. In this post, we discuss a simple but important failure mode of VAEs as well as a solution proposed in the paper Spread Divergences to fix the broken VAE. The code to reproduce the experiments behind the figures can be found here.

1. Introduction to the VAE

Given a dataset $\{x_1,\dots,x_N\}$ sampled from some unknown data distribution $p_d(x)$, we are interested in using a model $p_\theta(x)$ to approximate $p_d(x)$. A classic way to learn $\theta$ is to minimize the KL divergence:

$$\mathrm{KL}(p_d(x)\|p_\theta(x))=\underbrace{\int p_d(x)\log p_d(x)dx}_{\text{constant}}-\int p_d(x)\log p_\theta(x)dx.$$

From the definition of "divergence", we know that:

$$\mathrm{KL}(p_d(x)\|p_\theta(x))=0 \Leftrightarrow p_d(x)=p_\theta(x).$$

The expectation over $p_d(x)$ can be approximated by the Monte-Carlo method:

$$\mathrm{KL}(p_d(x)\|p_\theta(x))\approx -\frac{1}{N}\sum_{n=1}^N\log p_\theta(x_n)+\text{const.}$$

As $N \rightarrow \infty$, minimizing the KL divergence is equivalent to Maximum Likelihood Estimation (MLE). For a latent variable model $p_\theta(x)=\int p_\theta(x|z)p(z)dz$ with $p_\theta(x|z)$ parameterized by a neural network, the log-likelihood $\log p_\theta(x)$ is usually intractable. We can instead maximize the Evidence Lower Bound (ELBO):
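To make the Monte-Carlo estimate above concrete, here is a small hypothetical sketch (both the data distribution and the model family are made up for illustration): the average negative log-likelihood over samples estimates the KL divergence up to a constant, and is minimized at the MLE.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: data from N(1, 1), model family N(mu, 1).
def avg_nll(samples, mu):
    # -log N(x | mu, 1) = 0.5*log(2*pi) + 0.5*(x - mu)^2
    return np.mean(0.5 * np.log(2 * np.pi) + 0.5 * (samples - mu) ** 2)

x = rng.normal(1.0, 1.0, size=100_000)

# The estimate is minimized near the true mean mu = 1 (the MLE).
print(avg_nll(x, 0.0), avg_nll(x, 1.0))
```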

$$\log p_\theta(x)\geq \int \left(\log p_\theta(x|z)+\log p(z)-\log q_\phi(z|x)\right) q_\phi(z|x)dz = \mathrm{ELBO}(x,\theta,\phi),$$

where $q_\phi(z|x)$ is the amortized posterior introduced to approximate the true model posterior $p_\theta(z|x) \propto p_\theta(x|z)p(z)$. The training objective becomes maximizing $\frac{1}{N}\sum_{n=1}^N \mathrm{ELBO}(x_n,\theta,\phi)$.

2. A failure case

Let’s consider the following generation process for data $x\sim p_d(x)$:

$$z \sim \text{Bernoulli}(0.5),\quad a \sim \text{Gaussian}(2z-1,\ 0.1),\quad x =(a,0).$$
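This generation process is straightforward to sample from; a minimal numpy sketch (I read the $0.1$ above as the variance of the Gaussian, hence the square root):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_data(n):
    """Sample n points from the data-generating process above."""
    z = rng.binomial(1, 0.5, size=n)           # z ~ Bernoulli(0.5)
    a = rng.normal(2 * z - 1, np.sqrt(0.1))    # a ~ Gaussian(2z-1, 0.1)
    return np.stack([a, np.zeros(n)], axis=1)  # x = (a, 0)

x = sample_data(1000)
```

Every sample lies on the line $y=0$, in two clusters around $x=\pm 1$.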

We also plot the data samples:

[Figure: scatter plot of samples from the data distribution]

We attempt to learn this distribution with a VAE of the following form:

  • $p_\theta(x|z)$: 2D Gaussian with learned mean and variance;
  • $p(z)$: 1D Bernoulli with mean equal to 0.5;
  • $q_\phi(z|x)$: 1D Bernoulli with learned mean.
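Since the latent here is a single Bernoulli variable, the expectation in the ELBO can be computed exactly by summing over $z\in\{0,1\}$, with no sampling needed. A minimal numpy sketch of this computation (the hand-set `decoder` below is hypothetical, chosen to roughly match the data-generating process, with a small but nonzero variance on the degenerate y-axis):

```python
import numpy as np

def gaussian_logpdf(x, mean, log_var):
    # Diagonal Gaussian log-density, summed over dimensions.
    return -0.5 * np.sum(np.log(2 * np.pi) + log_var
                         + (x - mean) ** 2 / np.exp(log_var))

def elbo(x, q_mean, decoder):
    """Exact ELBO for one point: the expectation over the Bernoulli
    posterior q(z|x) is just a weighted sum over z in {0, 1}."""
    total = 0.0
    for z, qz in ((0, 1 - q_mean), (1, q_mean)):
        mean, log_var = decoder(z)
        log_p_xz = gaussian_logpdf(x, mean, log_var)
        log_pz = np.log(0.5)  # prior p(z) = Bernoulli(0.5)
        total += qz * (log_p_xz + log_pz - np.log(qz))
    return total

# Hypothetical decoder: mean (2z-1, 0), variances (0.1, 1e-4).
def decoder(z):
    return np.array([2 * z - 1, 0.0]), np.array([np.log(0.1), np.log(1e-4)])

# A point near the z=1 cluster is better explained by q_mean close to 1.
print(elbo(np.array([1.0, 0.0]), q_mean=0.99, decoder=decoder))
```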

If we plot the training loss and samples from the trained VAE, we find that training is unstable and the sample quality is poor.

[Figure: training loss curve and samples from the trained VAE]

3. Why does the failure happen?

Computational: the y-coordinate of $p_d(x)$ has zero variance, so the learned y-variance of $p_\theta(x|z)$ is driven towards 0 during training and the log-likelihood diverges to infinity, destabilizing the optimization.
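The blow-up is easy to see numerically: the log-density of a Gaussian at its own mean grows without bound as the variance shrinks, so the optimizer can increase the likelihood indefinitely by collapsing the y-variance.

```python
import numpy as np

def log_density_at_mean(sigma):
    # log N(mu | mu, sigma^2) = -0.5 * log(2 * pi * sigma^2):
    # unbounded above as sigma -> 0.
    return -0.5 * np.log(2 * np.pi * sigma ** 2)

for sigma in (1e-1, 1e-4, 1e-8):
    print(sigma, log_density_at_mean(sigma))
```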

Theoretical: the data distribution is a 1D manifold distribution embedded in a 2D space. Its density function is not well-defined (the distribution is not absolutely continuous w.r.t. the Lebesgue measure), so the KL divergence and MLE are ill-defined and cannot provide valid gradients for training. This problem has also been discussed in the Wasserstein GAN paper.

4. A simple fix: spread KL divergence

A simple trick to fix the problem is to convolve both the data distribution $p_d(x)$ and the model $p_\theta(x)$ with the same noise distribution. For example, we can choose Gaussian noise $n(\tilde{x}|x)=\mathcal{N}(\tilde{x}\,|\,x,\sigma_{\text{fix}}^2 I)$ and let:

$$\tilde{p}_d(\tilde{x})= \int p_d(x)\, n(\tilde{x}|x)dx,\quad \tilde{p}_\theta(\tilde{x})=\int p_\theta(x)\, n(\tilde{x}|x)dx.$$

We further define the spread KL divergence $\widetilde{\mathrm{KL}}_n(p_d(x)\|p_\theta(x))$ as:

$$\widetilde{\mathrm{KL}}_n(p_d(x)\|p_\theta(x))\equiv \mathrm{KL}(\tilde{p}_d(\tilde{x})\|\tilde{p}_\theta(\tilde{x})).$$
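For an empirical distribution, spreading simply means adding the fixed Gaussian noise to the samples. A minimal numpy sketch on the manifold data from this post (the noise level `sigma_fix = 0.2` is an arbitrary choice for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
sigma_fix = 0.2  # hypothetical fixed noise level

def spread(samples):
    """Convolve an empirical distribution with N(0, sigma_fix^2 I)
    by adding independent Gaussian noise to every sample."""
    return samples + rng.normal(0.0, sigma_fix, size=samples.shape)

# Data on the 1D manifold: the y-coordinate has zero variance...
x = np.stack([rng.normal(0, 1, 1000), np.zeros(1000)], axis=1)
x_tilde = spread(x)

# ...but after spreading, every coordinate has positive variance,
# so the noised distribution has a well-defined density on R^2.
print(x[:, 1].var(), x_tilde[:, 1].var())
```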

Our paper shows that for certain noise distributions (e.g. Gaussian), the spread KL divergence is still a valid divergence:

$$\widetilde{\mathrm{KL}}_n(p_d(x)\|p_\theta(x))=0 \Leftrightarrow p_d(x)=p_\theta(x).$$

Moreover, both $\tilde{p}_d$ and $\tilde{p}_\theta$ have well-defined densities on the whole space, so the spread KL divergence is well-defined even when the original KL divergence is not, and it can provide valid gradients for training. When we train with the spread KL divergence, the result becomes much better:

[Figure: training loss curve and samples from the VAE trained with the spread KL divergence]

5. Conclusion

Even a simple model like the VAE can fail in a seemingly straightforward situation, for both computational and theoretical reasons. Recognizing these pitfalls can help us design better models. A simple but effective fix is the spread KL divergence, which provides valid training signals and resolves the problem.

This note was written in 2019, but I still found it interesting at the end of 2021, so I made it my first blog post. I want to thank Peter Hayes for useful feedback. Hope you enjoy it, and happy new year!