VAE

Tags: CS 236

The idea

Some distributions are better explained with a latent variable, such that $p(x | z)$ is easier to model than $p(x)$.

The simple model

A mixture of gaussians is a very simple latent-variable model where $p(x | z) \sim N(\mu_z, \sigma_z)$ and $z \in \text{Cat}(k)$. We know from CS 229 that we can apply the EM algorithm to iteratively cluster the data and perform maximum likelihood. We need to compute the posterior $p(z_i | x)$, which is possible because $z$ is categorical.
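
As a concrete reference point (my own illustration, not from the notes), here is a minimal NumPy sketch of the E-step for a 1-D mixture of gaussians: it computes the posterior responsibilities $p(z = j | x_i)$ in closed form, which is exactly the step that stops being tractable once $z$ is continuous and high-dimensional.

```python
# Hypothetical E-step sketch for a 1-D mixture of gaussians (illustrative, not from the notes).
import numpy as np

def gaussian_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def e_step(x, pis, mus, sigmas):
    """Posterior responsibilities p(z = j | x_i) for each datapoint i and component j."""
    # Joint p(x_i, z = j) = pi_j * N(x_i; mu_j, sigma_j); shape (n, k) via broadcasting.
    joint = pis[None, :] * gaussian_pdf(x[:, None], mus[None, :], sigmas[None, :])
    # Bayes' rule: normalize over the k components.
    return joint / joint.sum(axis=1, keepdims=True)

# Example: two components, three datapoints.
resp = e_step(np.array([0.1, 2.3, -1.0]),
              np.array([0.5, 0.5]), np.array([0.0, 2.0]), np.array([1.0, 1.0]))
print(resp.sum(axis=1))  # each row sums to 1
```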

VAE is an infinite mixture of gaussians

Because every $z$ defines a unique gaussian, the marginal $p(x)$ can be interpreted as a mixture of an uncountably infinite number of gaussians, which is why $p(x)$ is so expressive.
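
Concretely (standard VAE notation, not spelled out in the notes): draw $z$ from a simple prior and let a neural network map it to the parameters of a gaussian over $x$:

$$z \sim N(0, I), \qquad p_\theta(x | z) = N\big(\mu_\theta(z), \sigma_\theta(z)\big), \qquad p_\theta(x) = \int p_\theta(x | z) \, p(z) \, dz$$

Every value of $z$ in the integral contributes its own gaussian component, hence the "uncountably infinite mixture."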

Moving to Variational Inference

In our derivation today, we are taking inspiration from EM but we are doing it in a more sophisticated way. Or, rather, we are forced to do it differently because we are now defining $z$ as a continuous random variable.

There is a big pro to this: the marginal $p(x)$ becomes really complex and flexible because it's a mixture of an infinite number of gaussians!

Moving towards the Lower Bound

Attempt 1: Naive Monte Carlo

The idea here is that we can approximate the marginalization by doing

$$p_\theta(x) = \sum_{z} p_\theta(x, z) = |\mathcal{Z}| \sum_{z} \frac{1}{|\mathcal{Z}|} p_\theta(x, z) = |\mathcal{Z}| \, \mathbb{E}_{z \sim \text{Uniform}(\mathcal{Z})}\big[p_\theta(x, z)\big]$$

by definition of expectation, which means that our Monte Carlo estimate would be

$$p_\theta(x) \approx |\mathcal{Z}| \, \frac{1}{k} \sum_{j=1}^{k} p_\theta(x, z^{(j)}), \quad z^{(j)} \sim \text{Uniform}(\mathcal{Z})$$

However, because we are in such a large sample space, the true $p_\theta(x, z)$ is very low for almost every randomly chosen $z$, which means this estimator has very high variance. You may never "hit" the right completions.

Attempt 2: Importance Sampling

If we had a surrogate distribution $q(z)$, we could rewrite the problem as

$$p_\theta(x) = \sum_{z} p_\theta(x, z) = \sum_{z} q(z) \frac{p_\theta(x, z)}{q(z)} = \mathbb{E}_{z \sim q(z)}\left[\frac{p_\theta(x, z)}{q(z)}\right]$$

And we could approximate it with the sample average

$$p_\theta(x) \approx \frac{1}{k} \sum_{j=1}^{k} \frac{p_\theta(x, z^{(j)})}{q(z^{(j)})}, \quad z^{(j)} \sim q(z)$$

Because of the linearity of expectation, it's trivial to show that this is an unbiased estimator of $p_\theta(x)$.
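
To make the variance claim concrete, here is a small NumPy experiment of my own (toy numbers, not from the notes): we estimate $p_\theta(x) = \sum_z p_\theta(x, z)$ for a discrete toy model where only a handful of the many possible $z$ values carry any mass, once with uniform sampling and once with importance sampling from a $q(z)$ concentrated on those values.

```python
# Toy illustration (not from the notes): naive uniform MC vs. importance sampling
# for estimating p(x) = sum_z p(x, z) when only a few z values "explain" x.
import numpy as np

rng = np.random.default_rng(0)
Z = 10_000                                  # size of the latent sample space
p_xz = np.zeros(Z)
p_xz[:10] = 0.01                            # only 10 latent values have non-negligible p(x, z)
true_px = p_xz.sum()                        # = 0.1

def naive_estimate(k):
    z = rng.integers(0, Z, size=k)          # z ~ Uniform(Z)
    return Z * p_xz[z].mean()               # |Z| * average of p(x, z)

def importance_estimate(k, q):
    z = rng.choice(Z, size=k, p=q)          # z ~ q(z)
    return (p_xz[z] / q[z]).mean()          # average of p(x, z) / q(z)

q = np.full(Z, 0.5 / (Z - 10))              # q puts half its mass on the 10 relevant z values
q[:10] = 0.5 / 10

naive = [naive_estimate(100) for _ in range(1000)]
imp = [importance_estimate(100, q) for _ in range(1000)]
print(true_px, np.std(naive), np.std(imp))  # importance sampling has far lower variance
```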

Adding the Log (and deriving ELBO)

Because we want the log-likelihood, we take the log of the whole expectation:

$$\log p_\theta(x) = \log \mathbb{E}_{z \sim q(z)}\left[\frac{p_\theta(x, z)}{q(z)}\right] \approx \log\left(\frac{1}{k} \sum_{j=1}^{k} \frac{p_\theta(x, z^{(j)})}{q(z^{(j)})}\right)$$

Now this is not necessarily the best expression to work with, but Jensen's inequality tells us that it is lower-bounded by

$$\mathbb{E}_{z \sim q(z)}\left[\log \frac{p_\theta(x, z)}{q(z)}\right]$$

which we call the Evidence Lower Bound (ELBO).

We can split this quantity up into

$$\mathbb{E}_{z \sim q(z)}\big[\log p_\theta(x, z)\big] + H(q)$$

where $H(q)$ is the entropy of $q$. And we can show that the equality holds if $q = p(z | x)$.

Tightness of ELBO bound

The tightness of the bound is something we can also determine! If we ditch Jensen's inequality and instead work backwards from the KL divergence between $q(z)$ and $p(z | x)$, we get

$$D_{KL}\big(q(z) \,||\, p(z | x; \theta)\big) = \mathbb{E}_{z \sim q(z)}\left[\log \frac{q(z)}{p(z | x; \theta)}\right] = -\mathbb{E}_{z \sim q(z)}\left[\log \frac{p(x, z; \theta)}{q(z)}\right] + \log p(x; \theta) \ge 0$$

From the RHS we can reassemble the ELBO, but because we now have an equality, we can say that the gap in the bound is exactly $D_{KL}(q \,||\, p(z | x))$. Or, in other words:

$$\log p(x) = \text{ELBO} + D_{KL}\big(q(z) \,||\, p(z | x; \theta)\big)$$

(Diagram: $\log p(x)$ decomposed into the ELBO plus the KL gap.)

How do we actually train?

So recall that

$$\log p(x; \theta) \ge \mathcal{L}(x; \theta, \phi) = \mathbb{E}_{z \sim q(z; \phi)}\big[\log p(x, z; \theta) - \log q(z; \phi)\big]$$

How do you actually optimize for $\theta, \phi$?

💡
So later we're going to do amortized learning with $q(z | x; \phi)$, but for now each datapoint has its own individual $\phi$

Variational Learning

Two facts:

  1. We can increase $\mathcal{L}$ by optimizing over $\theta$. Note that this ELBO holds for any $\phi$, which means that whenever we optimize over $\theta$, we are pushing up a valid lower bound on the log-likelihood.
  2. Note that we can also increase $\mathcal{L}$ by optimizing over $\phi$.
💡
We CAN'T optimize the exact ELBO + KL decomposition directly, because the KL term involves the posterior $p(z | x)$, which can be intractable

Concretely, this means that you perform alternating optimization on $\phi$ and $\theta$, much like in the EM algorithm!

The key difference from the EM algorithm is that we're not fitting $\phi$ to the posterior explicitly. However, implicitly, by increasing $\mathcal{L}$ without changing $\theta$, we're tightening the bound. Think about this for a second. The left-hand side, $\log p(x; \theta)$, depends only on $\theta$. So, if you optimize over $\phi$ and $\mathcal{L}$ increases, the only way this can happen is if the KL term goes down! So you're performing the E-step implicitly.
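
A hedged sketch of what this alternating scheme could look like before amortization (hypothetical PyTorch code: `mu[i]` and `log_sigma[i]` are per-datapoint variational parameters, `p_theta` is assumed to return $\log p(x, z; \theta)$, and `opt_phi` / `opt_theta` are assumed to be optimizers over the corresponding parameters):

```python
# Hypothetical sketch of variational learning with per-datapoint phi_i (no amortization yet).
import torch

def elbo(x_i, mu_i, log_sigma_i, decoder, n_samples=1):
    """Monte Carlo ELBO for one datapoint, with q(z) = N(mu_i, sigma_i^2)."""
    sigma_i = log_sigma_i.exp()
    eps = torch.randn(n_samples, mu_i.shape[0])
    z = mu_i + sigma_i * eps                                    # reparameterized samples
    log_q = torch.distributions.Normal(mu_i, sigma_i).log_prob(z).sum(-1)
    log_p = decoder(x_i, z)                                     # assumed: returns log p(x_i, z; theta)
    return (log_p - log_q).mean()

def train_step(data, mu, log_sigma, p_theta, opt_theta, opt_phi):
    # "E-like" step: tighten the bound by updating each datapoint's phi_i.
    for i, x_i in enumerate(data):
        opt_phi.zero_grad()
        (-elbo(x_i, mu[i], log_sigma[i], p_theta)).backward()
        opt_phi.step()
    # "M-like" step: push up the bound by updating theta with phi held fixed.
    opt_theta.zero_grad()
    loss = -torch.stack([elbo(x_i, mu[i].detach(), log_sigma[i].detach(), p_theta)
                         for i, x_i in enumerate(data)]).mean()
    loss.backward()
    opt_theta.step()
```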

Computing Gradients

The gradient with respect to $\theta$ is easy:

$$\nabla_\theta \mathcal{L} = \nabla_\theta \, \mathbb{E}_{z \sim q(z; \phi)}\big[\log p(x, z; \theta) - \log q(z; \phi)\big] = \mathbb{E}_{z \sim q(z; \phi)}\big[\nabla_\theta \log p(x, z; \theta)\big] \approx \frac{1}{k} \sum_{j=1}^{k} \nabla_\theta \log p(x, z^{(j)}; \theta)$$

with $z^{(j)} \sim q(z; \phi)$; this works because the sampling distribution $q(z; \phi)$ does not depend on $\theta$.

But what about $\phi$? How do we optimize through the expectation?

Idea 1: use Monte Carlo sampling:

$$\nabla_\phi \mathcal{L} \approx \frac{1}{k} \sum_{j=1}^{k} \nabla_\phi \big(\log p(x, z^{(j)}; \theta) - \log q(z^{(j)}; \phi)\big), \quad z^{(j)} \sim q(z; \phi)$$

This gives a usable estimate because you've nullified the gradient flowing through the sampling step itself: you treat the samples as constants and differentiate only the terms inside the expectation. This is OK, but we can do better. We could use REINFORCE (policy gradient), but for now there's a cheap trick called the reparameterization trick, which applies when samples from your distribution can be written as an affine transformation of fixed noise (the mean plus noise scaled by the standard deviation).
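
For reference (a standard result, not spelled out in the notes), the REINFORCE / score-function identity for a function $f(z)$ that does not itself depend on $\phi$ is

$$\nabla_\phi \, \mathbb{E}_{z \sim q(z; \phi)}[f(z)] = \mathbb{E}_{z \sim q(z; \phi)}\big[f(z) \, \nabla_\phi \log q(z; \phi)\big] \approx \frac{1}{k} \sum_{j=1}^{k} f(z^{(j)}) \, \nabla_\phi \log q(z^{(j)}; \phi), \quad z^{(j)} \sim q(z; \phi)$$

For the ELBO the integrand does contain $\log q(z; \phi)$, but the extra term this creates, $\mathbb{E}_q[\nabla_\phi \log q(z; \phi)]$, is zero in expectation, so the same estimator applies with $f(z) = \log p(x, z; \theta) - \log q(z; \phi)$.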

If $q(z; \phi) \sim N(\mu_\phi, \sigma_\phi^2)$, then $z \sim q$ is the same as $z = \mu_\phi + \sigma_\phi \epsilon$ with $\epsilon \sim N(0, 1)$. Note: it's $\sigma_\phi$, not $\sigma_\phi^2$, that you scale $\epsilon$ by. This is important because you've booted the parameters out of the sampling distribution! This gets you, for any differentiable function $f$,

$$\nabla_\phi \, \mathbb{E}_{z \sim q(z; \phi)}[f(z)] = \nabla_\phi \, \mathbb{E}_{\epsilon \sim N(0, 1)}\big[f(\mu_\phi + \sigma_\phi \epsilon)\big] = \mathbb{E}_{\epsilon \sim N(0, 1)}\big[\nabla_\phi f(\mu_\phi + \sigma_\phi \epsilon)\big]$$

which is computable through Monte Carlo estimation. This is far lower variance than the REINFORCE estimator, and it preserves the parameter dependence!
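
A minimal PyTorch sketch of the trick (my own toy illustration, with a made-up $f$): gradients flow back to $\mu_\phi$ and $\sigma_\phi$ through the sample because the randomness is isolated in $\epsilon$.

```python
# Toy reparameterization example (illustrative): differentiate an expectation
# E_{z ~ N(mu, sigma^2)}[f(z)] with respect to mu and log_sigma.
import torch

mu = torch.tensor(0.5, requires_grad=True)
log_sigma = torch.tensor(-1.0, requires_grad=True)

def f(z):
    return (z - 2.0) ** 2          # any differentiable function of z

k = 1000
eps = torch.randn(k)               # epsilon ~ N(0, 1), no parameters involved
z = mu + log_sigma.exp() * eps     # reparameterized samples: z ~ N(mu, sigma^2)
loss = f(z).mean()                 # Monte Carlo estimate of E[f(z)]
loss.backward()                    # gradients flow into mu and log_sigma

print(mu.grad, log_sigma.grad)     # mu.grad should be close to the exact value 2 * (mu - 2)
```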

Amortized Inference

The idea here is to generalize $q(z; \phi)$ to one neural network shared across all datapoints, $q(z | x; \phi)$. Everything else stays the same and you get the ELBO

$$\mathcal{L}(x; \theta, \phi) = \mathbb{E}_{z \sim q(z | x; \phi)}\big[\log p(x, z; \theta) - \log q(z | x; \phi)\big]$$

Perspective as autoencoder

If you play with the ELBO and do some rearrangement, you will get this:

$$\mathcal{L}(x; \theta, \phi) = \mathbb{E}_{z \sim q(z | x; \phi)}\big[\log p(x | z; \theta)\big] - D_{KL}\big(q(z | x; \phi) \,||\, p(z)\big)$$

The $q$ is your encoder. You take some input $x$ and map it to a distribution over the latent space, $q_\phi(z | x)$.

The $p$ is your decoder. Sample any $z$ and you can get $p(x | z)$.

The first term in the rearrangement of the ELBO is the reconstruction objective. The second term is a regularization objective. What is $p(z)$? Well, it's some prior. We can assume that this prior is a simple distribution, like a gaussian. We constrain the output distribution of $q(z | x)$ to be close to the prior $p(z)$.

This prior doesn't get drawn out of thin air; when you compute $p(z, x)$ in the original setup, you need to factor it as $p(x | z) \, p(z)$, and here is where your prior comes in.

💡
In general, when you code a VAE, this is the formulation that you use. Also, you usually use a one-sample Monte Carlo approximation of the first term, and a closed form expression for the second term.
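
As a concrete reference, here is a minimal PyTorch sketch of that recipe (my own illustration, not the course's code): a gaussian encoder $q_\phi(z | x)$, a Bernoulli decoder $p_\theta(x | z)$, a one-sample Monte Carlo estimate of the reconstruction term, and the closed-form KL between the diagonal gaussian posterior and the standard normal prior.

```python
# Minimal VAE sketch (illustrative): one-sample MC reconstruction term + closed-form KL.
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, x_dim=784, z_dim=20, h_dim=400):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())
        self.enc_mu = nn.Linear(h_dim, z_dim)                   # q(z|x) mean
        self.enc_logvar = nn.Linear(h_dim, z_dim)               # q(z|x) log-variance
        self.dec = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU(),
                                 nn.Linear(h_dim, x_dim))       # logits of p(x|z)

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        z = mu + (0.5 * logvar).exp() * torch.randn_like(mu)    # reparameterized sample
        return self.dec(z), mu, logvar

def neg_elbo(model, x):
    logits, mu, logvar = model(x)
    # One-sample Monte Carlo estimate of -E_q[log p(x|z)] (Bernoulli decoder).
    recon = F.binary_cross_entropy_with_logits(logits, x, reduction="none").sum(-1)
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal gaussian q.
    kl = 0.5 * (mu.pow(2) + logvar.exp() - 1.0 - logvar).sum(-1)
    return (recon + kl).mean()

# Usage sketch:
# model = VAE(); opt = torch.optim.Adam(model.parameters(), lr=1e-3)
# for x in data_loader:                 # x: (batch, 784), values in [0, 1]
#     opt.zero_grad(); neg_elbo(model, x).backward(); opt.step()
```

Minimizing this negative ELBO trains $\theta$ and $\phi$ jointly with plain SGD, since the reparameterized sample keeps gradients flowing into the encoder.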

We call this KL term the variational bottleneck

Regularization and Posterior Collapse

So you might be wondering: if we're constraining $q$ to be close to the prior $p(z)$, wouldn't that destroy the purpose of $q$? Yes! But $\mathcal{L}$ can be maximized without this KL being zero, because $q$ also shows up in the first (reconstruction) term. In fact, in an ideal system, you want $q$ to not be the same as $p(z)$.

If $q$ collapses to the prior $p(z)$, this is known as posterior collapse, and it means that the input is no longer giving information to the latent space.
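
A small, hedged sketch of how one might watch for this in practice (my own suggestion, reusing the `mu` / `logvar` encoder outputs from the sketch above): log the closed-form KL per latent dimension; dimensions whose KL stays near zero have collapsed onto the prior and carry no information about the input.

```python
# Illustrative check for posterior collapse: per-dimension KL of q(z|x) from N(0, I).
import torch

def kl_per_dim(mu, logvar):
    """Average KL contribution of each latent dimension over a batch."""
    kl = 0.5 * (mu.pow(2) + logvar.exp() - 1.0 - logvar)   # shape (batch, z_dim)
    return kl.mean(dim=0)                                   # shape (z_dim,)

# Dimensions with KL ~ 0 carry no information about x (collapsed), e.g.:
# collapsed = kl_per_dim(mu, logvar) < 0.01
```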