
# Generalization Gaps of VAE

### 1. Generalization of Probabilistic Models

The goal of probabilistic modelling is to learn a model $p_\theta(x)$ that fits the training dataset $\mathcal{X}_{train}=\{x_1,\cdots,x_N\}\sim p_d(x)$, where $p_d(x)$ is the unknown data distribution. A common training criterion is maximum likelihood learning:

$\max_\theta \frac{1}{N}\sum_{n=1}^N \log p_\theta(x_n).$

The generalization of $p_\theta(x)$ can be evaluated using the test log-likelihood $\frac{1}{M}\sum_{m=1}^M \log p_\theta(x'_m)$ on the test dataset $\mathcal{X}_{test}=\{x'_1,\cdots,x'_M\}\sim p_d(x)$. This definition of generalization has an important practical implication: using the model $p_\theta(x)$, we can design a compression algorithm that losslessly compresses a data point $x'$ into a binary string whose length is approximately $-\log_2 p_\theta(x')$. Therefore, better generalization, as measured by the test log-likelihood, translates to better performance in practical lossless compression. A detailed discussion of this connection can be found in David MacKay's textbook.
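To make this concrete, the conversion from a test log-likelihood to an ideal code length can be sketched as follows; the numbers are hypothetical and `ideal_code_length_bits` is just an illustrative helper:

```python
import math

# A hypothetical example: an entropy coder (e.g. arithmetic coding) paired with
# a model p_theta can losslessly encode x' in about -log2 p_theta(x') bits,
# so a higher test log-likelihood means a shorter compressed string.
def ideal_code_length_bits(log_likelihood_nats: float) -> float:
    """Convert a natural-log likelihood into an ideal code length in bits."""
    return -log_likelihood_nats / math.log(2)

# Suppose a model assigns log p_theta(x') = -350 nats to a 28x28 binary image.
bits = ideal_code_length_bits(-350.0)
bpd = bits / (28 * 28)  # bits per dimension (per pixel here)
print(f"{bits:.1f} bits total, {bpd:.3f} bits per dimension")
```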

Furthermore, it's noteworthy that this perspective on lossless compression has been instrumental in recent efforts to understand the generalization properties of large language models, see the video by Ilya Sutskever for a discussion.

### 2. Introduction of VAE

The Variational Auto-Encoder (VAE) is a latent variable model of the form

$p_\theta(x)=\int p_\theta(x|z)p(z)dz,$

where $p_\theta(x|z)$ is the decoder and $p(z)$ is the prior. When $p_\theta(x|z)$ is parameterized by a neural network, the integration over $z$ is usually intractable. A lower bound of the log-likelihood is therefore optimized instead, which is referred to as the ELBO:

$\frac{1}{N}\sum_{n=1}^N \log p_\theta(x_n)\geq \frac{1}{N}\sum_{n=1}^N \left(\int q_\phi(z|x_n)\log p_\theta(x_n|z)dz-\mathrm{KL}(q_\phi(z|x_n)||p(z))\right) \equiv \frac{1}{N}\sum_{n=1}^N \mathrm{ELBO}(x_n,\theta,\phi),$

where $q_\phi(z|x)$ is the variational posterior (encoder). The gap between the log-likelihood and the ELBO is the KL divergence $\mathrm{KL}(q_\phi(z|x_n)||p_\theta(z|x_n))$, where $p_\theta(z|x_n)\propto p_\theta(x_n|z)p(z)$ is the true posterior. The VAE has been successfully used in lossless compression, where the test compression length is approximately equal to the negative ($\log_2$) ELBO; see the BB-ANS paper for practical guidance. We therefore focus on the generalization of the ELBO.
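As a minimal sketch of how the ELBO is estimated in practice, the snippet below computes a single-sample Monte Carlo estimate with a factorized Gaussian encoder, a standard normal prior, and a Bernoulli decoder; the linear "networks" and all dimensions are placeholder assumptions, not the architecture of any particular VAE:

```python
import numpy as np

rng = np.random.default_rng(0)

def elbo_estimate(x, enc_W, dec_W):
    """Single-sample Monte Carlo ELBO estimate for one data point."""
    h = x @ enc_W                          # placeholder linear "encoder"
    mu, log_sigma = np.split(h, 2)         # q_phi(z|x) = N(mu, diag(sigma^2))
    z = mu + np.exp(log_sigma) * rng.standard_normal(mu.shape)  # reparameterization
    logits = z @ dec_W                     # placeholder linear "decoder"
    # Bernoulli log-likelihood log p_theta(x|z), numerically stable form
    log_px_z = np.sum(x * logits - np.logaddexp(0.0, logits))
    # Analytic KL( N(mu, sigma^2) || N(0, I) )
    kl = 0.5 * np.sum(mu**2 + np.exp(2 * log_sigma) - 2 * log_sigma - 1)
    return log_px_z - kl

x = rng.integers(0, 2, size=16).astype(float)  # toy binary "image", 16 pixels
enc_W = rng.normal(scale=0.1, size=(16, 8))    # outputs 4-dim mu and log_sigma
dec_W = rng.normal(scale=0.1, size=(4, 16))
val = elbo_estimate(x, enc_W, dec_W)
print(val)
```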

### 3. Generalization Gaps of VAE

The empirical ELBO objective can be decomposed into a model-learning term and an amortized-inference term:

$\frac{1}{N}\sum_{n=1}^N\mathrm{ELBO}(x_n,\theta,\phi)= \underbrace{\frac{1}{N}\sum_{n=1}^N \log p_\theta(x_n)}_{\text{Model learning}}- \underbrace{\frac{1}{N}\sum_{n=1}^N \mathrm{KL}(q_\phi(z|x_n)||p_\theta(z|x_n))}_{\text{Amortized inference}}.$
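This identity can be checked numerically in a toy model where every term is tractable. Assuming $p(z)=\mathcal{N}(0,1)$ and $p_\theta(x|z)=\mathcal{N}(z,\sigma^2)$, so that the marginal and the true posterior are Gaussian in closed form, the sketch below verifies that the ELBO equals $\log p_\theta(x)-\mathrm{KL}(q||p_\theta(z|x))$ for an arbitrary Gaussian $q$:

```python
import math

# Tractable toy model: p(z) = N(0,1), p(x|z) = N(z, sigma2),
# hence p(x) = N(0, 1 + sigma2) and the posterior is Gaussian in closed form.
sigma2 = 0.5
x = 1.3
m, s2 = 0.4, 0.8            # an arbitrary Gaussian q(z|x) = N(m, s2)

# Closed-form marginal likelihood and posterior moments
log_px = -0.5 * math.log(2 * math.pi * (1 + sigma2)) - x**2 / (2 * (1 + sigma2))
mp = x / (1 + sigma2)
sp2 = sigma2 / (1 + sigma2)

# ELBO = E_q[log p(x|z)] - KL(q || p(z)), both terms analytic here
e_log_lik = -0.5 * math.log(2 * math.pi * sigma2) - ((x - m)**2 + s2) / (2 * sigma2)
kl_prior = 0.5 * (m**2 + s2 - 1 - math.log(s2))
elbo = e_log_lik - kl_prior

# KL between q and the true posterior (both Gaussian)
kl_post = 0.5 * (math.log(sp2 / s2) + (s2 + (m - mp)**2) / sp2 - 1)

print(elbo, log_px - kl_post)   # the two quantities agree
```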

If either the decoder $p_\theta(x|z)$ or the encoder $q_\phi(z|x)$ is overly flexible, it can cause the VAE to overfit to the training data. We can define the ELBO generalization gap (EGG) as the difference between the training and test ELBO

$\text{EGG}\equiv \underbrace{\frac{1}{N}\sum_{n=1}^N \mathrm{ELBO}(x_n,\theta^*_N,\phi^*_N)}_{\text{Training ELBO}}-\underbrace{\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_N)}_{\text{Test ELBO}},$

where $\theta^*_N,\phi^*_N$ are defined as the optimal parameters for training the empirical ELBO

$\theta^*_N,\phi^*_N=\arg\max_{\theta,\phi} \frac{1}{N}\sum_{n=1}^N \mathrm{ELBO}(x_n,\theta,\phi).$

We notice that $\phi^*_N$ is also the optimal parameter of the empirical amortized inference, since the ELBO differs from the negative averaged KL to the true posterior only by the term $\frac{1}{N}\sum_{n=1}^N \log p_{\theta^*_N}(x_n)$, which does not depend on $\phi$:

$\phi^*_N=\arg\min_\phi \frac{1}{N}\sum_{n=1}^N \mathrm{KL}\left(q_\phi(z|x_n)||p_{\theta^*_N}(z|x_n)\right).$

For simplicity, assuming a sufficiently flexible amortized inference network, for any training data point $x_n\in\mathcal{X}_{train}$ the encoder $q_{\phi^*_N}(z|x_n)$ attains the optimal posterior approximation within the variational family $\mathcal{Q}$:

$q_{\phi^*_N}(z|x_n)=\arg\min_{q\in \mathcal{Q}}\mathrm{KL} \left(q(z)||p_{\theta^*_N}(z|x_n)\right).$

However, when $q_{\phi^*_N}(z|x)$ overfits to $\mathcal{X}_{train}$, $q_{\phi^*_N}(z|x'_m)$ may not be a good approximation to the true posterior $p_{\theta^*_N}(z|x'_m)$ for test data $x'_m\in \mathcal{X}_{test}$. This discrepancy leads to a suboptimal test ELBO.

To illustrate the generalization of the amortized inference, we denote by $\phi^*_M$ the optimal parameter of the amortized inference for the test data $x'_m\in \mathcal{X}_{test}$:

$\phi^*_M=\arg\min_\phi \frac{1}{M}\sum_{m=1}^M\mathrm{KL}\left(q_\phi(z|x'_m)||p_{\theta^*_N}(z|x'_m)\right).$

Similarly, for a flexible $q_\phi(z|x)$, we assume that $q_{\phi^*_M}(z|x'_m)$ attains the optimal posterior approximation within the variational family $\mathcal{Q}$:

$q_{\phi^*_M}(z|x'_m)=\arg\min_{q\in \mathcal{Q}}\mathrm{KL}\left( q(z)||p_{\theta^*_N}(z|x'_m)\right).$

We then define the amortized inference generalization gap (AIGG) as the difference between two averaged KL divergences

$\text{AIGG}\equiv\frac{1}{M}\sum_{m=1}^M \mathrm{KL}\left(q_{\phi^*_N}(z|x_m')||p_{\theta^*_N}(z|x_m')\right) -\frac{1}{M}\sum_{m=1}^M\mathrm{KL}\left(q_{\phi^*_M}(z|x_m')||p_{\theta^*_N}(z|x_m')\right).$

Intuitively, the AIGG measures how close the posterior $q_{\phi^*_N}(z|x'_m)$, which is produced by the amortized network trained on $\mathcal{X}_{train}$, is to the optimal realizable posterior for $x'_m\in\mathcal{X}_{test}$. Since the ELBO and the negative KL to the true posterior differ only by a term independent of $\phi$, the AIGG can equivalently be written as the difference between two test ELBOs, evaluated with $\phi^*_M$ and $\phi^*_N$ respectively:

$\text{AIGG}= \frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_M)-\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_N).$

It is important to emphasize that this gap cannot be reduced by simply using a more flexible variational family $\mathcal{Q}$. While this might reduce $\mathrm{KL} (q_{\phi^*_N}(z|x_n)||p_{\theta^*_N}(z|x_n))$ for the training data $x_n\in \mathcal{X}_{train}$, it does not explicitly encourage a smaller AIGG.

The AIGG is caused by the overfitting of the amortized inference (encoder). To understand the generalization property of the decoder, we can further rewrite the EGG by subtracting and adding the term $\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_M)$:

$\text{EGG}= \frac{1}{N}\sum_{n=1}^N \mathrm{ELBO}(x_n,\theta^*_N,\phi^*_N)-\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_M) + \underbrace{\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_M)-\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_N)}_{\text{AIGG}},$

where the second term is exactly the AIGG. We define the first term, which is the difference between the training and test ELBOs evaluated with the corresponding optimal inference parameters $\phi^*_N$ and $\phi^*_M$, as the Generative Model Generalization Gap (GMGG); it isolates the contribution of the generative model (decoder) to the overall generalization gap:

$\text{GMGG}= \underbrace{\frac{1}{N}\sum_{n=1}^N \mathrm{ELBO}(x_n,\theta^*_N,\phi^*_N)}_{\text{Training ELBO with optimal inference}}-\underbrace{\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_M)}_{\text{Test ELBO with optimal inference}}.$

Putting these together, the EGG decomposes into the sum of the two gaps:

$\text{EGG}=\text{GMGG}+\text{AIGG}.$
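Since the decomposition is a telescoping identity, it holds exactly for any ELBO values; a quick check with hypothetical numbers (in nats):

```python
# Hypothetical average ELBO values, for illustration only.
train_elbo_phiN = -85.0   # training ELBO with phi*_N
test_elbo_phiN = -95.0    # test ELBO with phi*_N (plain amortized inference)
test_elbo_phiM = -89.0    # test ELBO with phi*_M (test-time optimal inference)

egg = train_elbo_phiN - test_elbo_phiN
gmgg = train_elbo_phiN - test_elbo_phiM
aigg = test_elbo_phiM - test_elbo_phiN
print(egg, gmgg, aigg)    # EGG = GMGG + AIGG by construction
```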

This decomposition highlights that the VAE's generalization performance is influenced by the generalization capabilities of both the generative model and the amortized inference.

### 4. Visualization of the Generalization Gaps

To visualize the generalization gaps, we train a VAE on the Binary MNIST training dataset for 1,000 epochs. Every 100 epochs, we fix the decoder $p_\theta(x|z)$ and train only $q_\phi(z|x)$ on the test data for 1,000 epochs to estimate the optimal amortized posterior $q_{\phi_M^*}$. We can see that the VAE easily overfits the training dataset. We also notice that the test BPD evaluated with the test-time optimal inference (green) remains largely stable, showing only a minor increase during training. This pattern suggests that overfitting of the generative model (decoder) is much less pronounced than that of the amortized inference network (encoder); the overall generalization gap is therefore dominated by the overfitting of the amortized inference.
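The test-time inference step used above (freeze the decoder, optimize only the variational parameters) can be sketched in the tractable toy model $p(z)=\mathcal{N}(0,1)$, $p(x|z)=\mathcal{N}(z,\sigma^2)$, where the optimum is the known Gaussian posterior; a real VAE would run the same loop with SGD on $\phi$:

```python
import math

# Toy model: p(z) = N(0, 1), p(x|z) = N(z, sigma2). The decoder parameter
# sigma2 plays the role of the frozen theta*_N; only the variational
# parameters (m, log_s2) of q(z|x_test) = N(m, s2) are optimized.
sigma2 = 0.5
x_test = 1.3
m, log_s2 = 0.0, 0.0      # initialize q at N(0, 1)

for _ in range(2000):     # gradient ascent on the analytic ELBO
    s2 = math.exp(log_s2)
    grad_m = (x_test - m) / sigma2 - m                    # d ELBO / d m
    grad_log_s2 = 0.5 - s2 * (1 / (2 * sigma2) + 0.5)     # d ELBO / d log_s2
    m += 0.05 * grad_m
    log_s2 += 0.05 * grad_log_s2

# q converges to the true posterior N(x/(1+sigma2), sigma2/(1+sigma2))
print(m, math.exp(log_s2))
```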

Using this observation, we can design algorithms to alleviate the overfitting of the amortized inference. Please see our NeurIPS paper Generalization Gaps in Amortized Inference for further discussion on how to reduce the generalization gaps of VAEs.

I would like to thank Prof. Yee Whye Teh for his constructive feedback on this topic during my PhD viva.