1. Generalization of Probabilistic Models
The goal of probabilistic modelling is to learn a model $p_\theta(x)$ to fit the training dataset $\mathcal{D}_{\text{train}} = \{x_n\}_{n=1}^{N} \sim p_d(x)$, where $p_d(x)$ is the unknown data distribution. A common training criterion is maximum likelihood learning:

$$\theta^* = \arg\max_{\theta} \frac{1}{N} \sum_{n=1}^{N} \log p_\theta(x_n).$$
The generalization of the model can be evaluated using the test log-likelihood $\frac{1}{M} \sum_{m=1}^{M} \log p_{\theta^*}(x_m)$ with the test dataset $\mathcal{D}_{\text{test}} = \{x_m\}_{m=1}^{M} \sim p_d(x)$. This definition of generalization has an important practical implication: using the model $p_{\theta^*}(x)$, we can design a compression algorithm to losslessly compress a datum $x$ into a binary string whose length is approximately $-\log_2 p_{\theta^*}(x)$ bits. Therefore, better generalization, as measured by the test log-likelihood, translates to enhanced performance in practical lossless compression. A detailed discussion of this connection can be found in David MacKay's textbook.
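As a toy illustration of this connection (a sketch with made-up numbers, not tied to any particular model), an ideal entropy coder spends about $-\log_2 p(x)$ bits per datum, so a model that assigns higher probability to the test data produces shorter codes:

```python
import math

# Toy sketch: under an ideal entropy coder, a symbol with model
# probability p costs about -log2(p) bits. All numbers are hypothetical.
def ideal_code_length_bits(probs):
    """Total bits to losslessly encode a sequence whose symbols were
    assigned these probabilities by the model."""
    return sum(-math.log2(p) for p in probs)

# A model that generalizes better (higher test likelihood) yields a
# shorter compressed length.
good_model = [0.9, 0.8, 0.95]  # hypothetical per-symbol test probabilities
bad_model = [0.5, 0.5, 0.5]
assert ideal_code_length_bits(good_model) < ideal_code_length_bits(bad_model)
```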
Furthermore, it's noteworthy that this perspective on lossless compression has been instrumental in recent efforts to understand the generalization properties of large language models, see the video by Ilya Sutskever for a discussion.
2. Introduction to the VAE
The Variational Auto-Encoder (VAE) is a latent variable model of the form

$$p_\theta(x) = \int p_\theta(x|z)\, p(z)\, dz,$$

where $p_\theta(x|z)$ is the decoder and $p(z)$ is the prior. When $p_\theta(x|z)$ is parameterized by a neural network, the integration over $z$ is usually not tractable. A lower bound of the log-likelihood can be used instead, which is referred to as the ELBO:

$$\log p_\theta(x) = \underbrace{\mathbb{E}_{q_\phi(z|x)}\!\left[\log \frac{p_\theta(x|z)\, p(z)}{q_\phi(z|x)}\right]}_{\mathcal{L}(x;\, \theta,\, \phi)} + \mathrm{KL}\!\left(q_\phi(z|x)\,\|\,p_\theta(z|x)\right),$$
where $p_\theta(z|x) = p_\theta(x|z)\, p(z) / p_\theta(x)$ is the true posterior. Since the KL divergence is non-negative, $\mathcal{L}(x; \theta, \phi) \leq \log p_\theta(x)$. The VAE has been successfully used in lossless compression, where the test compression length is approximately equal to the negative ELBO (converted to bits); see the BB-ANS paper for practical guidance. Therefore, we will focus on the generalization of the ELBO.
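To make the bound concrete, here is a minimal sketch (illustrative only, not the paper's code) that Monte Carlo estimates the ELBO in a one-dimensional linear-Gaussian "VAE" where everything is tractable: prior $p(z) = \mathcal{N}(0, 1)$ and decoder $p(x|z) = \mathcal{N}(z, 1)$, so the marginal is $\mathcal{N}(0, 2)$ and the true posterior is $\mathcal{N}(x/2, 1/2)$. When $q$ equals the true posterior, the bound is tight:

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal(x, mean, std):
    """Log density of N(mean, std^2) evaluated at x."""
    return -0.5 * np.log(2 * np.pi * std**2) - (x - mean) ** 2 / (2 * std**2)

def elbo(x, q_mean, q_std, n_samples=10_000):
    """Monte Carlo estimate of E_q[log p(x|z) + log p(z) - log q(z|x)]
    for the toy model p(z) = N(0,1), p(x|z) = N(z,1)."""
    z = q_mean + q_std * rng.standard_normal(n_samples)
    return np.mean(log_normal(x, z, 1.0)           # decoder likelihood
                   + log_normal(z, 0.0, 1.0)       # prior
                   - log_normal(z, q_mean, q_std)) # encoder

x = 1.0
log_px = log_normal(x, 0.0, np.sqrt(2.0))  # exact marginal: N(0, 2)
# With q equal to the true posterior N(x/2, 1/2) the bound is exact
# (the integrand is constant), so even the Monte Carlo estimate matches:
assert abs(elbo(x, x / 2, np.sqrt(0.5)) - log_px) < 1e-9
# Any other q gives a lower value; the gap is KL(q || true posterior).
assert elbo(x, 0.0, 1.0) < log_px
```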
3. Generalization Gaps of the VAE
The empirical ELBO objective on the training set, $\mathcal{L}(\mathcal{D}_{\text{train}}; \theta, \phi) = \frac{1}{N} \sum_{n=1}^{N} \mathcal{L}(x_n; \theta, \phi)$, can be represented as a combination of an empirical model-fit term and an empirical amortized inference term:

$$\mathcal{L}(\mathcal{D}_{\text{train}}; \theta, \phi) = \frac{1}{N} \sum_{n=1}^{N} \log p_\theta(x_n) - \frac{1}{N} \sum_{n=1}^{N} \mathrm{KL}\!\left(q_\phi(z|x_n)\,\|\,p_\theta(z|x_n)\right).$$
If either the decoder or the encoder is overly flexible, the VAE can overfit the training data. We define the ELBO generalization gap (EGG) as the difference between the training and test ELBO:

$$\mathcal{G}_{\text{ELBO}} = \mathcal{L}(\mathcal{D}_{\text{train}}; \theta^*, \phi^*) - \mathcal{L}(\mathcal{D}_{\text{test}}; \theta^*, \phi^*),$$
where $\theta^*, \phi^*$ are defined as the optimal parameters of the empirical training ELBO:

$$\theta^*, \phi^* = \arg\max_{\theta, \phi} \mathcal{L}(\mathcal{D}_{\text{train}}; \theta, \phi).$$
We notice that $\phi^*$ is also the optimal parameter of the empirical amortized inference objective, since the model-fit term does not depend on $\phi$:

$$\phi^* = \arg\min_{\phi} \frac{1}{N} \sum_{n=1}^{N} \mathrm{KL}\!\left(q_\phi(z|x_n)\,\|\,p_{\theta^*}(z|x_n)\right).$$
For simplicity, considering a flexible amortized inference network, we can assume that for any training data point $x_n \in \mathcal{D}_{\text{train}}$, the encoder $q_{\phi^*}$ can generate the optimal posterior approximation within the variational family $\mathcal{Q}$:

$$q_{\phi^*}(z|x_n) = \arg\min_{q \in \mathcal{Q}} \mathrm{KL}\!\left(q(z)\,\|\,p_{\theta^*}(z|x_n)\right).$$
However, when $q_{\phi^*}$ overfits $\mathcal{D}_{\text{train}}$, $q_{\phi^*}(z|x_m)$ may not be a good approximation of the true posterior $p_{\theta^*}(z|x_m)$ for a test data point $x_m \in \mathcal{D}_{\text{test}}$. This discrepancy can lead to a suboptimal test ELBO.
To illustrate the generalization of the amortized inference, we denote by $\phi^*_{\text{test}}$ the optimal realizable parameter of the amortized inference for the test data $\mathcal{D}_{\text{test}}$:

$$\phi^*_{\text{test}} = \arg\min_{\phi} \frac{1}{M} \sum_{m=1}^{M} \mathrm{KL}\!\left(q_\phi(z|x_m)\,\|\,p_{\theta^*}(z|x_m)\right).$$
Similarly, for a flexible encoder, we assume that $q_{\phi^*_{\text{test}}}$ can generate the optimal posterior approximation within the variational family $\mathcal{Q}$:

$$q_{\phi^*_{\text{test}}}(z|x_m) = \arg\min_{q \in \mathcal{Q}} \mathrm{KL}\!\left(q(z)\,\|\,p_{\theta^*}(z|x_m)\right).$$
We then define the amortized inference generalization gap (AIGG) as the difference between two averaged KL divergences:

$$\mathcal{G}_{\text{AI}} = \frac{1}{M} \sum_{m=1}^{M} \left[\mathrm{KL}\!\left(q_{\phi^*}(z|x_m)\,\|\,p_{\theta^*}(z|x_m)\right) - \mathrm{KL}\!\left(q_{\phi^*_{\text{test}}}(z|x_m)\,\|\,p_{\theta^*}(z|x_m)\right)\right].$$
Intuitively, the AIGG assesses the proximity of the posterior $q_{\phi^*}(z|x_m)$, which is derived from the amortized network trained on $\mathcal{D}_{\text{train}}$, to the optimal realizable posterior $q_{\phi^*_{\text{test}}}(z|x_m)$ for $x_m \in \mathcal{D}_{\text{test}}$. Equivalently, the AIGG can be written as the difference between two test ELBOs, evaluated with $\phi^*_{\text{test}}$ and $\phi^*$ respectively:

$$\mathcal{G}_{\text{AI}} = \mathcal{L}(\mathcal{D}_{\text{test}}; \theta^*, \phi^*_{\text{test}}) - \mathcal{L}(\mathcal{D}_{\text{test}}; \theta^*, \phi^*).$$
It is important to emphasize that this gap cannot be reduced by simply using a more flexible variational family $\mathcal{Q}$. While this might reduce $\mathrm{KL}(q_{\phi^*}(z|x_n)\,\|\,p_{\theta^*}(z|x_n))$ for the training data $\mathcal{D}_{\text{train}}$, it would not explicitly encourage a smaller AIGG.
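As a numeric sanity check of the equivalence between the two forms of the AIGG (a toy setting, not the paper's setup), consider a one-dimensional model with prior $\mathcal{N}(0, 1)$ and decoder $p(x|z) = \mathcal{N}(z, 1)$, so the true posterior is $\mathcal{N}(x/2, 1/2)$ and everything is in closed form; the KL-difference and ELBO-difference definitions then agree exactly:

```python
import math

def kl_gauss(m1, s1, m2, s2):
    """Closed-form KL( N(m1, s1^2) || N(m2, s2^2) )."""
    return math.log(s2 / s1) + (s1**2 + (m1 - m2) ** 2) / (2 * s2**2) - 0.5

def log_px(x):
    """Exact marginal of the toy model: p(x) = N(0, 2)."""
    return -0.5 * math.log(2 * math.pi * 2.0) - x**2 / 4.0

def elbo(x, q_mean, q_std):
    """ELBO = log p(x) - KL(q || true posterior N(x/2, 1/2))."""
    return log_px(x) - kl_gauss(q_mean, q_std, x / 2.0, math.sqrt(0.5))

# One "test" point; a hypothetical trained encoder phi* (slightly off)
# vs. the test-optimal encoder phi*_test, which here is the true posterior.
x = 1.0
kl_phi = kl_gauss(0.3, 0.8, x / 2.0, math.sqrt(0.5))  # KL under phi*
kl_phi_test = 0.0                                     # optimal q is exact
aigg_kl = kl_phi - kl_phi_test
aigg_elbo = elbo(x, x / 2.0, math.sqrt(0.5)) - elbo(x, 0.3, 0.8)
assert abs(aigg_kl - aigg_elbo) < 1e-12  # the two definitions coincide
```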
The AIGG is caused by the overfitting of the amortized inference network (encoder). To understand the generalization property of the decoder, we can further rewrite the EGG by subtracting and adding the term $\mathcal{L}(\mathcal{D}_{\text{test}}; \theta^*, \phi^*_{\text{test}})$:

$$\mathcal{G}_{\text{ELBO}} = \left[\mathcal{L}(\mathcal{D}_{\text{train}}; \theta^*, \phi^*) - \mathcal{L}(\mathcal{D}_{\text{test}}; \theta^*, \phi^*_{\text{test}})\right] + \left[\mathcal{L}(\mathcal{D}_{\text{test}}; \theta^*, \phi^*_{\text{test}}) - \mathcal{L}(\mathcal{D}_{\text{test}}; \theta^*, \phi^*)\right],$$
where the second term is just the AIGG. We define the expression in the first term, the difference between the training and test ELBOs evaluated with their respective optimal amortized inference parameters, as the Generative Model Generalization Gap (GMGG):

$$\mathcal{G}_{\text{GM}} = \mathcal{L}(\mathcal{D}_{\text{train}}; \theta^*, \phi^*) - \mathcal{L}(\mathcal{D}_{\text{test}}; \theta^*, \phi^*_{\text{test}}).$$

This isolates the degree to which each part affects the overall generalization.
With this in perspective, we can express the EGG as the sum of both gaps:

$$\mathcal{G}_{\text{ELBO}} = \mathcal{G}_{\text{GM}} + \mathcal{G}_{\text{AI}}.$$
This decomposition highlights that the VAE's generalization performance is influenced by the generalization capabilities of both the generative model and the amortized inference.
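The decomposition is a telescoping identity: subtracting and adding the same test ELBO term. A minimal check with made-up numbers (illustrative values in nats, not experimental results):

```python
# Hypothetical ELBO values, chosen only to illustrate the identity.
elbo_train = -85.0      # L(D_train; theta*, phi*)
elbo_test = -95.0       # L(D_test;  theta*, phi*)
elbo_test_opt = -92.0   # L(D_test;  theta*, phi*_test)

egg = elbo_train - elbo_test        # ELBO generalization gap
gmgg = elbo_train - elbo_test_opt   # generative model generalization gap
aigg = elbo_test_opt - elbo_test    # amortized inference generalization gap
assert egg == gmgg + aigg           # exact telescoping
```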
4. Visualization of the Generalization Gaps
To visualize the generalization gaps, we trained a VAE on the Binary MNIST training dataset for 1,000 epochs. Every 100 epochs, we fix the decoder and train only the encoder for 1,000 epochs on the test data to obtain an estimate of the optimal amortized posterior $q_{\phi^*_{\text{test}}}$.
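The encoder-refitting step can be sketched in a toy linear-Gaussian model (an illustrative stand-in for the actual MNIST experiment; all parameters here are hypothetical): the decoder $p(x|z) = \mathcal{N}(z, 1)$ and prior $p(z) = \mathcal{N}(0, 1)$ stay fixed, so the true posterior is $\mathcal{N}(x/2, 1/2)$, and only the encoder parameters $(a, b)$ of $q(z|x) = \mathcal{N}(ax + b, 1/2)$ are fitted on the test points by minimizing the average KL to the true posterior:

```python
import numpy as np

rng = np.random.default_rng(1)

# "Test" points drawn from the toy marginal p(x) = N(0, 2).
x_test = rng.normal(0.0, np.sqrt(2.0), size=256)

def avg_kl(a, b):
    # With matching variances, KL(N(a x + b, 1/2) || N(x/2, 1/2))
    # reduces to the squared mean difference (a x + b - x/2)^2.
    return np.mean(((a - 0.5) * x_test + b) ** 2)

# Gradient descent on the encoder parameters only (decoder is fixed).
a, b, lr = 0.0, 0.0, 0.05
for _ in range(500):
    resid = (a - 0.5) * x_test + b
    a -= lr * np.mean(2 * resid * x_test)  # d(avg_kl)/da
    b -= lr * np.mean(2 * resid)           # d(avg_kl)/db

# The refit recovers the test-optimal amortization a = 1/2, b = 0.
assert avg_kl(a, b) < 1e-4
assert abs(a - 0.5) < 0.01 and abs(b) < 0.01
```

Gradient descent recovers the test-optimal encoder, mirroring how the full experiment estimates $q_{\phi^*_{\text{test}}}$ while holding the trained decoder fixed.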
We can see that the VAE easily overfits the training dataset. We also notice that when evaluating the test ELBO with the test-time optimal inference parameters, the test BPD (green) remains largely stable, showing only a minor increase during training. This pattern suggests that the overfitting of the generative model (decoder) is less pronounced than that of the amortized inference network (encoder). Consequently, the overall generalization gap is dominated by the overfitting of the amortized inference network.
Using this observation, we can design algorithms to alleviate the overfitting of the amortized inference. Please see our NeurIPS paper, Generalization Gaps in Amortized Inference, for further discussion of how to reduce the generalization gaps of VAEs.
I would like to thank Prof. Yee Whye Teh for the constructive feedback on this topic during my PhD viva.