Generalization Gaps of VAE

1. Generalization of Probabilistic Models

The goal of probabilistic modelling is to learn a model $p_\theta(x)$ that fits a training dataset $\mathcal{X}_{train}=\{x_1,\cdots,x_N\}\sim p_d(x)$, where $p_d(x)$ is the unknown data distribution. A common training criterion is maximum likelihood learning:

$$\max_\theta \frac{1}{N}\sum_{n=1}^N \log p_\theta(x_n).$$

The generalization of $p_\theta(x)$ can be evaluated using the test log-likelihood $\frac{1}{M}\sum_{m=1}^M \log p_\theta(x'_m)$ on a test dataset $\mathcal{X}_{test}=\{x'_1,\cdots,x'_M\}\sim p_d(x)$. This definition of generalization has an important practical implication: using the model $p_\theta(x)$, we can design a compression algorithm that losslessly compresses a data point $x'$ into a binary string of length approximately $-\log_2 p_\theta(x')$. Therefore, better generalization, as measured by the test log-likelihood, translates to better performance in practical lossless compression. A detailed discussion of this connection can be found in David MacKay's textbook.
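As a concrete illustration, the sketch below converts per-datapoint log-likelihoods into an idealized compression length in bits; `model_log_prob` is a hypothetical function (not part of any specific library) assumed to return $\log p_\theta(x)$ in nats for one data point.

```python
import numpy as np

def ideal_code_length_bits(test_data, model_log_prob):
    # model_log_prob(x) is assumed to return log p_theta(x) in nats for one data point.
    # An ideal entropy coder built on p_theta spends about -log2 p_theta(x) bits per point.
    log_probs_nats = np.array([model_log_prob(x) for x in test_data])
    bits_per_point = -log_probs_nats / np.log(2.0)   # convert nats to bits
    return bits_per_point.mean()                     # average ideal code length (bits per data point)
```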

Furthermore, it's noteworthy that this perspective on lossless compression has been instrumental in recent efforts to understand the generalization properties of large language models, see the video by Ilya Sutskever for a discussion.

2. Introduction to the VAE

The Variational Auto-Encoder (VAE) is a latent variable model of the form

$$p_\theta(x)=\int p_\theta(x|z)p(z)dz,$$

where $p_\theta(x|z)$ is the decoder and $p(z)$ is the prior. When $p_\theta(x|z)$ is parameterized by a neural network, the integration over $z$ is usually intractable. A lower bound on the log-likelihood can be used instead, which is referred to as the ELBO:

$$\frac{1}{N}\sum_{n=1}^N \log p_\theta(x_n)\geq \frac{1}{N}\sum_{n=1}^N \left(\int q_\phi(z|x_n)\log p_\theta(x_n|z)dz-\mathrm{KL}(q_\phi(z|x_n)||p(z))\right) \equiv \frac{1}{N}\sum_{n=1}^N \mathrm{ELBO}(x_n,\theta,\phi),$$

where $q_\phi(z|x)$ is the encoder (the amortized variational posterior), which approximates the true posterior $p_\theta(z|x_n)\propto p_\theta(x_n|z)p(z)$. The VAE has been successfully used in lossless compression, where the test compression length is approximately equal to the negative ($\log_2$) ELBO; see the BB-ANS paper for practical guidance. We therefore focus on the generalization of the ELBO.
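To make the objective concrete, here is a minimal single-sample ELBO estimator. It is only a sketch, assuming a hypothetical Gaussian `encoder` that returns the mean and log-standard-deviation of $q_\phi(z|x)$ and a hypothetical `decoder` that returns Bernoulli logits for $p_\theta(x|z)$, as in a binary-image VAE; it is not the exact architecture used in the experiments below.

```python
import torch
from torch.distributions import Normal, Bernoulli, kl_divergence

def elbo(x, encoder, decoder):
    # q_phi(z|x): diagonal Gaussian produced by the (hypothetical) encoder network.
    mu, log_std = encoder(x)
    q = Normal(mu, log_std.exp())
    prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))    # p(z) = N(0, I)
    z = q.rsample()                                              # reparameterized sample z ~ q_phi(z|x)
    log_px_z = Bernoulli(logits=decoder(z)).log_prob(x).sum(-1)  # log p_theta(x|z) for binary data
    kl = kl_divergence(q, prior).sum(-1)                         # KL(q_phi(z|x) || p(z)), closed form
    return log_px_z - kl                                         # single-sample ELBO estimate (nats)
```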

3. Generalization Gaps of VAE

The empirical ELBO objective can be further decomposed into a model-learning term and an amortized-inference term:

$$\frac{1}{N}\sum_{n=1}^N\mathrm{ELBO}(x_n,\theta,\phi)= \underbrace{\frac{1}{N}\sum_{n=1}^N \log p_\theta(x_n)}_{\text{Model learning}}- \underbrace{\frac{1}{N}\sum_{n=1}^N \mathrm{KL}(q_\phi(z|x_n)||p_\theta(z|x_n))}_{\text{Amortized inference}}.$$
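This decomposition follows from the standard per-datapoint identity

$$\log p_\theta(x_n) = \mathrm{ELBO}(x_n,\theta,\phi) + \mathrm{KL}\left(q_\phi(z|x_n)||p_\theta(z|x_n)\right),$$

obtained by writing $p_\theta(x_n)=p_\theta(x_n|z)p(z)/p_\theta(z|x_n)$ inside an expectation over $q_\phi(z|x_n)$; rearranging and averaging over the training set gives the expression above.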

If either the decoder $p_\theta(x|z)$ or the encoder $q_\phi(z|x)$ is overly flexible, the VAE can overfit the training data. We define the ELBO generalization gap (EGG) as the difference between the training and test ELBO:

$$\text{EGG}\equiv \underbrace{\frac{1}{N}\sum_{n=1}^N \mathrm{ELBO}(x_n,\theta^*_N,\phi^*_N)}_{\text{Training ELBO}}-\underbrace{\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_N)}_{\text{Test ELBO}},$$

where $\theta^*_N,\phi^*_N$ are the optimal parameters obtained by maximizing the empirical training ELBO:

$$\theta^*_N,\phi^*_N=\arg\max_{\theta,\phi} \frac{1}{N}\sum_{n=1}^N \mathrm{ELBO}(x_n,\theta,\phi).$$

We notice that $\phi^*_N$ is also the optimal parameter of the empirical amortized inference, since for fixed $\theta^*_N$ maximizing the ELBO over $\phi$ is equivalent to minimizing the KL term in the decomposition above:

$$\phi^*_N=\arg\min_\phi \frac{1}{N}\sum_{n=1}^N \mathrm{KL}\left(q_\phi(z|x_n)||p_{\theta^*_N}(z|x_n)\right).$$
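For reference, a minimal training loop that jointly maximizes the empirical ELBO to obtain approximations of $(\theta^*_N,\phi^*_N)$ might look as follows; it reuses the hypothetical `elbo`, `encoder`, and `decoder` from the sketch above.

```python
import torch

def train_vae(encoder, decoder, train_loader, epochs=1000, lr=1e-3):
    # Jointly maximize the empirical training ELBO over (theta, phi) by
    # minimizing the negative ELBO with a single optimizer over both modules.
    params = list(encoder.parameters()) + list(decoder.parameters())
    opt = torch.optim.Adam(params, lr=lr)
    for _ in range(epochs):
        for x in train_loader:
            loss = -elbo(x, encoder, decoder).mean()   # batch-averaged negative ELBO
            opt.zero_grad()
            loss.backward()
            opt.step()
    return encoder, decoder   # parameters approximate (phi*_N, theta*_N)
```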

For simplicity, assuming a flexible amortized inference network, we take it that for any training data point $x_n\in\mathcal{X}_{train}$, $q_{\phi^*_N}(z|x_n)$ is the optimal posterior within the variational family $\mathcal{Q}$:

$$q_{\phi^*_N}(z|x_n)=\arg\min_{q\in \mathcal{Q}}\mathrm{KL} \left(q(z|x_n)||p_{\theta^*_N}(z|x_n)\right).$$

However, when $q_{\phi^*_N}(z|x)$ overfits to $\mathcal{X}_{train}$, $q_{\phi^*_N}(z|x'_m)$ may not be a good approximation to the true posterior $p_{\theta^*_N}(z|x'_m)$ for test data $x'_m\in \mathcal{X}_{test}$. This discrepancy leads to a suboptimal test ELBO.

To characterize the generalization of the amortized inference, we denote by $\phi^*_M$ the optimal realizable parameter of the amortized inference for the test data $x'_m\in \mathcal{X}_{test}$:

$$\phi^*_M=\arg\min_\phi \frac{1}{M}\sum_{m=1}^M\mathrm{KL}\left(q_\phi(z|x'_m)||p_{\theta^*_N}(z|x'_m)\right).$$

Similarly, for a flexible $q_\phi(z|x)$, we assume that $q_{\phi^*_M}(z|x'_m)$ is the optimal posterior within the variational family $\mathcal{Q}$:

$$q_{\phi^*_M}(z|x'_m)=\arg\min_{q\in \mathcal{Q}}\mathrm{KL}\left( q(z|x'_m)||p_{\theta^*_N}(z|x'_m)\right).$$

We then define the amortized inference generalization gap (AIGG) as the difference between two averaged KL divergences:

$$\text{AIGG}\equiv\frac{1}{M}\sum_{m=1}^M \mathrm{KL}\left(q_{\phi^*_N}(z|x'_m)||p_{\theta^*_N}(z|x'_m)\right) -\frac{1}{M}\sum_{m=1}^M\mathrm{KL}\left(q_{\phi^*_M}(z|x'_m)||p_{\theta^*_N}(z|x'_m)\right).$$

Intuitively, the AIGG measures how close the posterior $q_{\phi^*_N}(z|x'_m)$, derived from the amortized network trained on $\mathcal{X}_{train}$, is to the optimal realizable posterior for $x'_m\in\mathcal{X}_{test}$. Equivalently, the AIGG can be written as the difference between two test ELBOs evaluated with $\phi^*_M$ and $\phi^*_N$ respectively:

$$\text{AIGG}= \frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_M)-\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_N).$$
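In practice, this second form is convenient to estimate: with a copy of the encoder refit on the test data as a stand-in for $\phi^*_M$ (see the experiment below), the AIGG is just a difference of two averaged test ELBOs. A minimal sketch, reusing the hypothetical `elbo` helper above:

```python
import torch

def mean_elbo(loader, encoder, decoder):
    # Average single-sample ELBO over a dataset (nats per data point).
    total, count = 0.0, 0
    with torch.no_grad():
        for x in loader:
            total += elbo(x, encoder, decoder).sum().item()
            count += x.shape[0]
    return total / count

def estimate_aigg(test_loader, encoder_train, encoder_refit, decoder):
    # encoder_train carries phi*_N; encoder_refit approximates phi*_M
    # (an encoder refit on the test set with the decoder held fixed).
    return mean_elbo(test_loader, encoder_refit, decoder) - mean_elbo(test_loader, encoder_train, decoder)
```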

It is important to emphasize that this gap cannot be reduced by simply using a more flexible variational family $\mathcal{Q}$. While this might reduce $\mathrm{KL}(q_{\phi^*_N}(z|x_n)||p_{\theta^*_N}(z|x_n))$ for the training data $x_n\in \mathcal{X}_{train}$, it does not explicitly encourage a smaller AIGG.

The AIGG is caused by overfitting of the amortized inference network (encoder). To understand the generalization of the decoder, we can further rewrite the EGG by subtracting and adding the term $\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_M)$:

$$\text{EGG}= \frac{1}{N}\sum_{n=1}^N \mathrm{ELBO}(x_n,\theta^*_N,\phi^*_N)-\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_M) + \underbrace{\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_M)-\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_N)}_{\text{AIGG}},$$

where the second term is exactly the AIGG. We define the first term, the difference between the training and test ELBO evaluated with the optimal amortized inference parameters $\phi^*_N$ and $\phi^*_M$ respectively, as the Generative Model Generalization Gap (GMGG); it isolates the contribution of the generative model (decoder) to the overall generalization gap:

$$\text{GMGG}= \underbrace{\frac{1}{N}\sum_{n=1}^N \mathrm{ELBO}(x_n,\theta^*_N,\phi^*_N)}_{\text{Training ELBO with optimal inference}}-\underbrace{\frac{1}{M}\sum_{m=1}^M \mathrm{ELBO}(x'_m,\theta^*_N,\phi^*_M)}_{\text{Test ELBO with optimal inference}}.$$

With this in place, the EGG can be expressed as the sum of the two gaps:

$$\text{EGG}=\text{GMGG}+\text{AIGG}.$$
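As a sanity check, all three gaps can be estimated from just three averaged ELBOs. The sketch below reuses the hypothetical `mean_elbo` helper defined earlier and again assumes `encoder_refit` approximates $\phi^*_M$.

```python
def estimate_gaps(train_loader, test_loader, encoder_train, encoder_refit, decoder):
    train_elbo = mean_elbo(train_loader, encoder_train, decoder)    # ELBO(x_n,  theta*_N, phi*_N)
    test_elbo = mean_elbo(test_loader, encoder_train, decoder)      # ELBO(x'_m, theta*_N, phi*_N)
    test_elbo_opt = mean_elbo(test_loader, encoder_refit, decoder)  # ELBO(x'_m, theta*_N, phi*_M)
    egg = train_elbo - test_elbo
    gmgg = train_elbo - test_elbo_opt
    aigg = test_elbo_opt - test_elbo
    return egg, gmgg, aigg   # EGG = GMGG + AIGG holds exactly for these estimates
```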

This decomposition highlights that the VAE's generalization performance is influenced by the generalization capabilities of both the generative model and the amortized inference.

4. Visualization of the Generalization Gaps

To visualize the generalization gaps, we trained a VAE on the Binary MNIST training dataset for 1,000 epochs. Every 100 epochs, we fix the decoder $p_\theta(x|z)$ and train only $q_\phi(z|x)$ for 1,000 epochs on the test data to obtain an estimate of the optimal amortized posterior $q_{\phi_M^*}$.
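A minimal sketch of this test-time refitting step, assuming the hypothetical `elbo`, `encoder`, and `decoder` objects from the earlier sketches: since the decoder is frozen, maximizing the test ELBO over a copy of the encoder estimates $\phi^*_M$.

```python
import copy
import torch

def refit_encoder_on_test(encoder, decoder, test_loader, epochs=1000, lr=1e-3):
    # Freeze the decoder (theta*_N) and optimize a copy of the encoder on the
    # test set only, so the generative model itself is never updated.
    encoder_refit = copy.deepcopy(encoder)
    for p in decoder.parameters():
        p.requires_grad_(False)
    opt = torch.optim.Adam(encoder_refit.parameters(), lr=lr)
    for _ in range(epochs):
        for x in test_loader:
            loss = -elbo(x, encoder_refit, decoder).mean()   # maximize test ELBO w.r.t. phi only
            opt.zero_grad()
            loss.backward()
            opt.step()
    return encoder_refit   # approximates q_{phi*_M}(z|x)
```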

[Figure: training BPD and test BPD curves over training, including the test BPD evaluated with test-time optimal amortized inference.]

We can see that the VAE easily overfits the training dataset. We also notice that, when the test ELBO is evaluated with the test-time optimal inference, the test BPD (green) remains largely stable, showing only a minor increase during training. This suggests that the overfitting of the generative model (decoder) is much less pronounced than that of the amortized inference network (encoder). Consequently, the overall overfitting is dominated by the overfitting of the amortized inference network.

Based on this observation, we can design algorithms to alleviate the overfitting of the amortized inference. Please see our NeurIPS paper Generalization Gaps in Amortized Inference for further discussion on how to reduce the generalization gaps of VAEs.

I would like to thank Prof. Yee Whye Teh for his constructive feedback on this topic during my PhD viva.