Variational Inference (practicals) ⭐

Tags: CS 330

What do we want?

Given a dataset $D(x)$, can we model $p(x)$, the distribution that $D$ was drawn from?

We can do this by doing maximum likelihood, i.e. $\max_\theta \log p_\theta(x)$, and this is pretty easy for things like Gaussians or categoricals. In fact, MSE loss and cross entropy are both ways of doing exactly this.
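As a quick check of that claim (standard algebra, not specific to these notes): for a Gaussian with fixed variance $\sigma^2$,

$$-\log p_\theta(y \mid x) = -\log \mathcal{N}\!\big(y;\, \mu_\theta(x),\, \sigma^2\big) = \frac{\big(y - \mu_\theta(x)\big)^2}{2\sigma^2} + \tfrac{1}{2}\log\big(2\pi\sigma^2\big),$$

so minimizing the squared error in $\mu_\theta$ is exactly maximizing the log-likelihood; the analogous calculation for a categorical distribution recovers cross entropy.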

But what about more complicated distributions?

Latent variable models

One thing that helps us is the latent variable model, in which you have two simple models, $p(z)$ and $p(x \mid z)$, but when you compose them, you get a complicated distribution.

GMMs

A simple example is a Gaussian Mixture Model, in which $p(z)$ is a categorical distribution and $p(x \mid z)$ is a Gaussian. In this case, $z$ is a discrete variable, and it allows you to get some degree of coverage with $p(x) = \sum_z p(x \mid z)\, p(z)$.

You can even program a neural network to do GMM regression, which amounts to

$$p(y \mid x) = \sum_z p(y \mid x, z)\, p(z \mid x)$$

and the model outputs means, covariances, and weights.
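As a concrete sketch of what such a network could look like (assuming PyTorch; the architecture, sizes, and names here are made up, and covariances are kept diagonal for simplicity):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MDN(nn.Module):
    """Minimal mixture density network: p(y|x) = sum_k pi_k(x) N(y; mu_k(x), sigma_k(x)^2)."""
    def __init__(self, x_dim=1, y_dim=1, n_components=5, hidden=64):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU())
        self.pi = nn.Linear(hidden, n_components)                  # mixture weight logits
        self.mu = nn.Linear(hidden, n_components * y_dim)          # per-component means
        self.log_sigma = nn.Linear(hidden, n_components * y_dim)   # per-component log std-devs
        self.k, self.y_dim = n_components, y_dim

    def loss(self, x, y):
        h = self.trunk(x)
        log_pi = F.log_softmax(self.pi(h), dim=-1)                           # (B, K)
        mu = self.mu(h).view(-1, self.k, self.y_dim)                         # (B, K, Dy)
        sigma = self.log_sigma(h).view(-1, self.k, self.y_dim).exp()
        # log N(y; mu_k, sigma_k^2) for each component, summed over y's dimensions
        log_n = torch.distributions.Normal(mu, sigma).log_prob(y.unsqueeze(1)).sum(-1)  # (B, K)
        # negative log-likelihood of the mixture, -log sum_k pi_k N_k, via logsumexp for stability
        return -torch.logsumexp(log_pi + log_n, dim=-1).mean()
```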

Moving to general models

Now, GMMs are good, but they aren’t fully general. What if we wanted to fit an arbitrary distribution? Well, suppose that $p(z)$ is a (standard) Gaussian and $p(x \mid z) = \mathcal{N}(\mu(z), \sigma(z))$.

The two distributions are simple, but $\mu(\cdot), \sigma(\cdot)$ are arbitrarily complicated (e.g., neural networks). You can actually show that this composition, $p(x) = \int p(x \mid z)\, p(z)\, dz$, can represent essentially any distribution.

A good intuition is that for each slice of $z$, we can place a Gaussian “stamp” on the $p(x)$ graph at any location and with any width. Repeat these “stamps” infinitely many times, and you have infinite resolution.

Once trained, you generate a sample from $p(x)$ by sampling $z \sim p(z)$, running it through the $\mu(z), \sigma(z)$ network, and then sampling from $\mathcal{N}(\mu(z), \sigma(z))$.

Evaluating the likelihood of a given sample is a little more difficult. You can approximate $p(x) = E_{p(z)}[p(x \mid z)]$ with a Monte Carlo estimate.
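As a sketch, assuming PyTorch and a hypothetical decoder network that maps $z \mapsto (\mu(z), \log\sigma(z))$ (the names and sizes are made up), the two procedures look like this:

```python
import math
import torch
import torch.nn as nn

Z_DIM, X_DIM = 2, 5   # made-up dimensions

class Decoder(nn.Module):
    """Hypothetical decoder: z -> (mu(z), log_sigma(z))."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(Z_DIM, 64), nn.ReLU(), nn.Linear(64, 2 * X_DIM))
    def forward(self, z):
        return self.net(z).chunk(2, dim=-1)

decoder = Decoder()

def sample_x(n):
    """Ancestral sampling: z ~ p(z) = N(0, I), then x ~ N(mu(z), sigma(z))."""
    z = torch.randn(n, Z_DIM)
    mu, log_sigma = decoder(z)
    return mu + log_sigma.exp() * torch.randn_like(mu)

def log_px_naive_mc(x, n_samples=1000):
    """Naive Monte Carlo estimate: p(x) ~ (1/K) sum_k p(x | z_k), with z_k ~ p(z)."""
    z = torch.randn(n_samples, Z_DIM)
    mu, log_sigma = decoder(z)
    log_px_given_z = torch.distributions.Normal(mu, log_sigma.exp()).log_prob(x).sum(-1)  # (K,)
    return torch.logsumexp(log_px_given_z, dim=0) - math.log(n_samples)  # log of the average
```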

Training latent variable models: what you can’t do

You might try to maximize the data likelihood directly,

$$\max_\theta \frac{1}{N}\sum_i \log p_\theta(x_i) = \max_\theta \frac{1}{N}\sum_i \log \int p_\theta(x_i \mid z)\, p(z)\, dz,$$

but the integral is intractable. You could try the same Monte Carlo estimate of $E_{p(z)}[p(x \mid z)]$, but this is very sample inefficient: most $z$ drawn from the prior explain a given $x_i$ poorly.

So…what can we do? Well, we can propose a lower bound on the likelihood and optimize that instead. As it turns out, it has some really nice theoretical properties. In the next section, we will see how this comes to be!

Variational Inference

Importance sampling

We can start by setting things up as importance sampling. The key problem with estimating the integral as $p(x) = E_{p(z)}[p(x \mid z)]$ is that $p(x \mid z)$ may be very small for a lot of values of $z$, or it may have very weird coverage that is hard to hit. It’s the classic problem of trying to hit a bullseye while essentially throwing random darts.

Now, if we can sample with respect to a distribution that models the most likely $z$ given $x$ (i.e. $p(z \mid x)$), now we’re talking! It’s like using a well-calibrated dart-throwing method, which allows for much greater sample efficiency.

What this looks like is the following. We start with a (bad) assumption that we have a variational approximation $q_i \approx p(z \mid x_i)$ for every data point. We will see how to improve on this later.

Now, part of the variational approximation is that each $q_i$ is a Gaussian ($q_i = \mathcal{N}$), which is not necessarily true. But again, you can think of $q$ as the dart thrower. You don’t need a professional dart thrower; you just need someone good enough to hit the bullseye once in a while.
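In the same hypothetical setup as the earlier sketch (a `decoder` mapping $z \mapsto (\mu, \log\sigma)$, with `q_mu`, `q_log_sigma` standing in for the Gaussian proposal $q_i$), the importance-sampled estimator might look like:

```python
import math
import torch

def log_px_importance(x, decoder, q_mu, q_log_sigma, n_samples=100):
    """Importance-sampled estimate of log p(x), using
        p(x) = E_{q(z)}[ p(x|z) p(z) / q(z) ]
    with a Gaussian proposal q(z) ~ p(z|x) and a standard normal prior p(z)."""
    q = torch.distributions.Normal(q_mu, q_log_sigma.exp())
    prior = torch.distributions.Normal(torch.zeros_like(q_mu), torch.ones_like(q_mu))
    z = q.sample((n_samples,))                                    # z_k ~ q(z), shape (K, Z_DIM)
    mu, log_sigma = decoder(z)
    log_px_given_z = torch.distributions.Normal(mu, log_sigma.exp()).log_prob(x).sum(-1)  # (K,)
    log_w = prior.log_prob(z).sum(-1) - q.log_prob(z).sum(-1)     # log [ p(z_k) / q(z_k) ]
    return torch.logsumexp(log_px_given_z + log_w, dim=0) - math.log(n_samples)
```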

Creation of ELBO

By using Jensen’s inequality, we know that $\log E[y] \geq E[\log y]$, which gives us the following lower bound on the main objective:

$$\log p(x_i) = \log E_{z \sim q_i(z)}\!\left[\frac{p(x_i \mid z)\, p(z)}{q_i(z)}\right] \;\geq\; E_{z \sim q_i(z)}\!\left[\log \frac{p(x_i \mid z)\, p(z)}{q_i(z)}\right]$$

which actually becomes

$$\mathcal{L}(p, q_i) = E_{z \sim q_i(z)}\big[\log p(x_i \mid z) + \log p(z)\big] + \mathcal{H}(q_i)$$

and the last term is the entropy. Now, you can more or less use this directly, but we want to understand what exactly this ELBO is doing!

The first part tries to find a $p(x, z)$ whose likelihood is maximized under samples from $q(z)$. On its own, this would yield a degenerate solution: you could make $q$ as narrow as possible, centered on the mode of $p(x, z)$. The second term, the entropy, counteracts this by keeping the sampling distribution as wide as possible.

So, by maximizing the ELBO objective, you can understand it as jointly optimizing the likelihood of the data and making the likelihood estimator (the sampling distribution $q$) as accurate as possible.
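Sketched in code (same hypothetical setup and assumptions as above), a Monte Carlo estimate of this ELBO for a single data point could look like:

```python
import torch

def elbo(x_i, decoder, q_mu, q_log_sigma, n_samples=10):
    """Monte Carlo estimate of L(p, q_i) = E_{q_i(z)}[log p(x_i|z) + log p(z)] + H(q_i),
    assuming a Gaussian q_i (q_mu, q_log_sigma) and a standard normal prior."""
    q = torch.distributions.Normal(q_mu, q_log_sigma.exp())
    prior = torch.distributions.Normal(torch.zeros_like(q_mu), torch.ones_like(q_mu))
    z = q.sample((n_samples,))                                    # z_k ~ q_i(z)
    mu, log_sigma = decoder(z)
    log_px_given_z = torch.distributions.Normal(mu, log_sigma.exp()).log_prob(x_i).sum(-1)  # (K,)
    log_pz = prior.log_prob(z).sum(-1)                                                      # (K,)
    return (log_px_given_z + log_pz).mean() + q.entropy().sum()   # entropy is closed-form for a Gaussian
```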

At this point, you can also start thinking of it as an approximate EM algorithm, with objective $\mathcal{L}(\theta, q)$, where $\theta$ parameterizes $p_\theta(x \mid z)$ and $q$ is the collection of variational approximators. The E step fits each $q_i$ to be wide and centered around the appropriate $p_\theta(x_i \mid z)\, p(z)$, and the M step maximizes the combined term with $q$ held fixed as the variational approximator. Just like the standard EM algorithm, it is a process of coordinate ascent.

Tightness of the lower bound

Again, while we can use this loss out of the box, we want to dig a little deeper into how much bang we’re getting for our buck, and also to continue drawing the EM parallel.

We started this analysis by claiming that we want $q$ to be as similar to $p(z \mid x_i)$ as possible. Now, bear in mind that because it’s importance sampling, $q$ could be anything. We claim now that this restriction is already encoded in the ELBO. To show this, let’s compute $D_{KL}(q_i(z) \,\|\, p(z \mid x_i))$:

$$D_{KL}\big(q_i(z) \,\|\, p(z \mid x_i)\big) = E_{q_i(z)}\!\left[\log \frac{q_i(z)\, p(x_i)}{p(x_i, z)}\right] = -E_{q_i(z)}\big[\log p(x_i \mid z) + \log p(z)\big] - \mathcal{H}(q_i) + \log p(x_i)$$

So, we get that

$$\log p(x_i) = \mathcal{L}(p, q_i) + D_{KL}\big(q_i(z) \,\|\, p(z \mid x_i)\big) \;\geq\; \mathcal{L}(p, q_i)$$

which means two things:

  1. If $q(z) = p(z \mid x)$, then the bound is tight.
  2. Yet again, because the KL divergence is non-negative, this constructed $\mathcal{L}$ is a lower bound (same result, different story).

This is an entirely different derivation, but it arrives at the same bound. Furthermore, we can use this form to highlight the EM-esque style of variational inference. We can rewrite the equation as

$$\mathcal{L}(p, q_i) = \log p(x_i) - D_{KL}\big(q_i(z) \,\|\, p(z \mid x_i)\big)$$

which means that when you’re optimizing the $q_i$ of the ELBO, you’re just minimizing the KL divergence between the variational distribution and the posterior, which is the “E” step. When you’re optimizing the ELBO with respect to $p$, you’re optimizing $\log p(x_i)$, dragged down by some KL divergence. Because $p$ is changing, even if the bound was very tight at the start of the M step, the gap can open back up: the posterior $p(z \mid x_i)$ moves while $q_i$ stays fixed. The same thing happens in the standard EM algorithm.
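To make the coordinate-ascent picture concrete, here is a self-contained toy sketch (the sizes, learning rates, and random stand-in data are all made up; a standard normal prior and Gaussian $q_i$’s are assumed):

```python
import torch
import torch.nn as nn

# Toy sketch of the coordinate-ascent view: a per-datapoint Gaussian q_i plus a shared decoder.
N, Z_DIM, X_DIM = 100, 2, 5
data = torch.randn(N, X_DIM)                               # stand-in dataset
decoder = nn.Sequential(nn.Linear(Z_DIM, 64), nn.ReLU(), nn.Linear(64, 2 * X_DIM))
q_mu = nn.Parameter(torch.zeros(N, Z_DIM))                 # one variational q_i per data point
q_log_sigma = nn.Parameter(torch.zeros(N, Z_DIM))
opt_q = torch.optim.SGD([q_mu, q_log_sigma], lr=1e-2)      # updates the q_i ("E" step)
opt_p = torch.optim.Adam(decoder.parameters(), lr=1e-3)    # updates p_theta ("M" step)

def neg_elbo(i):
    q = torch.distributions.Normal(q_mu[i], q_log_sigma[i].exp())
    z = q.rsample()                                        # reparameterized sample (see below)
    mu, log_sigma = decoder(z).chunk(2, dim=-1)
    log_px_z = torch.distributions.Normal(mu, log_sigma.exp()).log_prob(data[i]).sum()
    log_pz = torch.distributions.Normal(0.0, 1.0).log_prob(z).sum()
    return -(log_px_z + log_pz + q.entropy().sum())

for step in range(1000):
    i = torch.randint(N, (1,)).item()
    opt_q.zero_grad(); neg_elbo(i).backward(); opt_q.step()   # E step: pull q_i toward p(z | x_i)
    opt_p.zero_grad(); neg_elbo(i).backward(); opt_p.step()   # M step: improve p_theta under q_i
```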

Amortized Variational Inference

So far, there is one problem. The complexity of our model grows with the number of data points, because we need to keep track of a distribution $q_i$ for every point. As it turns out, the solution is very simple: just use a single network $q_\phi(z \mid x)$ in place of all the individual distributions.

This is very easy for $\theta$ optimization (the M step), as you just sample $z$ through $q_\phi$ and take the gradient. For the $\phi$ step, we run into a problem:

$$\nabla_\phi\, E_{z \sim q_\phi(z \mid x_i)}\big[\log p_\theta(x_i \mid z) + \log p(z)\big]$$

the model is in the sampler! Uh oh…that doesn’t look good: the expectation is taken with respect to a distribution that itself depends on $\phi$, so we can’t simply push the gradient inside.

Reparameterization trick

As it turns out, we approximated $q_\phi$ as a Gaussian for a reason. For Gaussians, we have

$$z \sim \mathcal{N}(\mu, \sigma) \quad\Longleftrightarrow\quad z = \mu + \epsilon\,\sigma, \qquad \epsilon \sim \mathcal{N}(0, 1)$$

which means that we can pull the $\phi$-dependence inside the expectation: the sampling distribution becomes a fixed $\mathcal{N}(0, 1)$, so gradients can flow through $\mu_\phi$ and $\sigma_\phi$!
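A tiny illustration of why this works (the loss here is just a made-up stand-in for the terms inside the expectation):

```python
import torch

# The randomness lives in eps ~ N(0, 1), which does not depend on the parameters,
# so gradients flow through mu and sigma.
mu = torch.tensor([0.5], requires_grad=True)
log_sigma = torch.tensor([0.0], requires_grad=True)

eps = torch.randn(1)                       # parameter-free noise
z = mu + eps * log_sigma.exp()             # a sample from N(mu, sigma), differentiable in mu, sigma
loss = (z ** 2).sum()                      # stand-in for the terms inside the expectation
loss.backward()
print(mu.grad, log_sigma.grad)             # both gradients are populated

# torch.distributions.Normal(mu, sigma).rsample() implements exactly this pattern.
```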

There are other methods of dealing with the sampler issue. If we don’t want to use a Gaussian, we could use something like REINFORCE (the score-function gradient estimator), which handles the same issue for arbitrary sampling distributions.

Another way of understanding the bound

So, once again, we are totally done with the derivation. But we can also look at things a different way, which helps motivate the variational autoencoder. You can massage the original ELBO loss as follows:

$$\mathcal{L} = E_{z \sim q_\phi(z \mid x_i)}\big[\log p_\theta(x_i \mid z)\big] + E_{z \sim q_\phi(z \mid x_i)}\big[\log p(z)\big] + \mathcal{H}\big(q_\phi(z \mid x_i)\big)$$

which means that the objective can be written as

$$\mathcal{L} = E_{z \sim q_\phi(z \mid x_i)}\big[\log p_\theta(x_i \mid z)\big] - D_{KL}\big(q_\phi(z \mid x_i)\,\|\,p(z)\big)$$
The first term you can think of as a reconstruction objective: you take the input $x_i$, encode it by sampling $z \sim q_\phi(z \mid x_i)$, and then try to maximize the likelihood of decoding it back through $p_\theta(x_i \mid z)$.

The second term you can think of as a regularizer: the distribution $q_\phi(z \mid x_i)$ should stay close to the non-information-bearing prior $p(z)$, which is a Gaussian. So, put together, we have our variational autoencoder!

Don’t let this scare you! It’s literally just a neural network autoencoder with a sampling procedure added at the bottleneck and a special regularizer.
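To tie it together, here is a minimal sketch of that picture (assuming PyTorch, a Gaussian decoder, a standard normal prior, and made-up sizes; not the exact model from the lecture):

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    """Minimal VAE sketch: encoder q_phi(z|x), decoder p_theta(x|z), standard normal prior p(z)."""
    def __init__(self, x_dim=784, z_dim=8, hidden=256):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU(), nn.Linear(hidden, 2 * z_dim))
        self.decoder = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU(), nn.Linear(hidden, 2 * x_dim))

    def neg_elbo(self, x):
        # Encode: q_phi(z|x) = N(mu, sigma^2), sample at the bottleneck with the reparameterization trick
        q_mu, q_log_sigma = self.encoder(x).chunk(2, dim=-1)
        q = torch.distributions.Normal(q_mu, q_log_sigma.exp())
        z = q.rsample()
        # Decode: p_theta(x|z) = N(mu(z), sigma(z)^2)  -> reconstruction term
        mu, log_sigma = self.decoder(z).chunk(2, dim=-1)
        recon = torch.distributions.Normal(mu, log_sigma.exp()).log_prob(x).sum(-1)
        # Regularizer: KL(q_phi(z|x) || p(z)), closed form for two Gaussians
        prior = torch.distributions.Normal(torch.zeros_like(q_mu), torch.ones_like(q_mu))
        kl = torch.distributions.kl_divergence(q, prior).sum(-1)
        return (kl - recon).mean()                         # minimize the negative ELBO

# Hypothetical usage:
#   model = VAE(); opt = torch.optim.Adam(model.parameters(), lr=1e-3)
#   loss = model.neg_elbo(x_batch); opt.zero_grad(); loss.backward(); opt.step()
```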