Variational Inference (theory)

Tags: CS 228, Inference, Learning

Why and what is variational inference?

The term comes from the calculus of variations, in which you optimize over functions (rather than numbers) to fit something.

Our main problem comes from marginal inference. Given a joint distribution, can we compute something like p(x_1 | e)? This requires marginalization, which takes time exponential in the number of variables. We can instead approximate this posterior with a simpler distribution. Therefore, we reframe inference as optimization.
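To make the exponential cost concrete, here is a minimal sketch (the chain model and function names are my own invention) of exact marginal inference by brute-force enumeration; the loop over all 2^n assignments is exactly the blow-up we want to avoid.

```python
import itertools

# Toy unnormalized joint over n binary variables: a hypothetical chain MRF
# where neighboring variables prefer to agree. Any positive function works.
def joint_unnormalized(x):
    agreements = sum(1 for i in range(len(x) - 1) if x[i] == x[i + 1])
    return 2.718281828 ** agreements

def marginal_given_evidence(n, query, evidence):
    """Compute p(x_query | evidence) by brute-force enumeration.

    evidence is a dict {index: value}. Cost is O(2^n): we sum the joint
    over every assignment consistent with the evidence.
    """
    totals = {0: 0.0, 1: 0.0}
    for x in itertools.product((0, 1), repeat=n):
        if all(x[i] == v for i, v in evidence.items()):
            totals[x[query]] += joint_unnormalized(x)
    z = totals[0] + totals[1]  # proportional to p(e)
    return {v: t / z for v, t in totals.items()}

# Query p(x_0 | x_9 = 1) on a 10-variable chain: 2^10 terms already
post = marginal_given_evidence(n=10, query=0, evidence={9: 1})
```

At n = 10 this is instant; at n = 100 it is hopeless, which is the whole motivation for what follows.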

The key insight

The prior distribution is typically complicated because it encompasses so much about the world. However, once you condition on some evidence, the posterior p(x | e) is typically simpler.

If we are looking at MRFs, the posterior is proportional to the joint distribution but you just clamp the observations in the factors, which of course makes it simpler. But it’s unnormalized, and the normalization constant p(e) is hard to compute.

Methods of optimization

The goal is to find some q that is close to p(x | e). We compare the two through the KL divergence, and we compute either \arg \min_q D(p \| q) or \arg \min_q D(q \| p).

The first one is called the M-projection of p, and the second one is called the I-projection. These two objectives actually yield quite different results.

M-projection

Here, we are essentially optimizing an expected value taken over p. This means that if q is very small at some x where p has mass, the KL divergence becomes very large. As such, q^* is forced to have very wide support, covering everything p covers. Philosophically speaking, we are finding the q such that everything that is correct is encapsulated by it.

As we learned from last time with exponential families, M-projection is equivalent to setting the moments of the sufficient statistics equal to each other.

Below is an example of the M-projection of Gaussians, where q is constrained to have a diagonal covariance. You can clearly see that the means match.

I-projection

In this case, we are taking the expectation over q, which means that we only need to capture high density within the support of q. Essentially, this method says “everything I see is correct.” This has the effect of excluding things, as can be seen below. Given a multimodal distribution, I-projection will settle on one of the modes, while M-projection will cover the center of mass.
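To see the asymmetry numerically, here is a small sketch with toy numbers of my own: a bimodal discrete p, one candidate q covering both modes, and one committing to a single mode, compared under both KL directions.

```python
import math

def kl(p, q):
    """D(p || q) for discrete distributions given as aligned lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Bimodal target over 4 states: two well-separated modes.
p = [0.45, 0.05, 0.05, 0.45]

q_wide = [0.25, 0.25, 0.25, 0.25]   # spreads mass over everything p supports
q_mode = [0.85, 0.05, 0.05, 0.05]   # commits to the first mode only

# M-projection objective D(p || q): punishes q for being small where p is not,
# so the mode-seeking candidate looks much worse.
m_wide, m_mode = kl(p, q_wide), kl(p, q_mode)

# I-projection objective D(q || p): only evaluated where q puts mass, so
# committing to a single mode is perfectly acceptable.
i_wide, i_mode = kl(q_wide, p), kl(q_mode, p)
```

The two directions disagree about which candidate is better, which is exactly the mass-covering vs. mode-seeking distinction above.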

Mean Field inference

What do we count as “simple” for q? Well, we can make one key assumption: q is fully factorizable. In other words, each variable under q is independent, and the MRF graph is just a bunch of vertices without any edges.

We define Q as the family of all fully-factored distributions, and we can try solving the M-projection.

M-projection derivation

We can set up the KL divergence and use a little multiply-by-one trick, using q_M(x) = \prod_i p(x_i). We have a key question: which q \in Q minimizes the divergence?

We get this upshot: to minimize divergence between the simple q and the complicated p, just let q_i(x_i) = p(x_i), i.e. take q to be the product of the marginals of p.
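This upshot can be checked numerically. The sketch below (toy joint of my own) compares D(p || q) for the product-of-marginals q against randomly perturbed factored alternatives:

```python
import math
import random

# A correlated toy joint p(a, b) over two binary variables
p = {(0, 0): 0.4, (0, 1): 0.1, (1, 0): 0.2, (1, 1): 0.3}

def marginal(p, axis):
    m = [0.0, 0.0]
    for (a, b), v in p.items():
        m[(a, b)[axis]] += v
    return m

def kl_p_vs_factored(p, qa, qb):
    """D(p || qa x qb) for a fully factored candidate."""
    return sum(v * math.log(v / (qa[a] * qb[b])) for (a, b), v in p.items())

pa, pb = marginal(p, 0), marginal(p, 1)
best = kl_p_vs_factored(p, pa, pb)  # the claimed optimum: q_i = p's marginals

# Perturbing the first factor away from p's marginal never helps.
random.seed(0)
others = []
for _ in range(100):
    eps = random.uniform(-0.2, 0.2)
    others.append(kl_p_vs_factored(p, [pa[0] + eps, pa[1] - eps], pb))
```

Every perturbed candidate scores at least as badly as the product of marginals, matching the derivation.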

From our previous discussion on M-projection, we also know that if QQ is an exponential family, then we have

E_{q^*}[f(x)] = E_p[f(x)]

If we were to model this, then we would probably use something like an exponential family for q. However, this defeats the purpose: computing the required moments under p means marginalizing (or sampling) from p, which is generally intractable — or else why would you need the simpler distribution?

Variational Models through I-projections

Now, we saw that the M-projection approach was conceptually simple: you just need q to match the marginals of some distribution, or if q is an exponential family, you just need to match the moments. However, with this simplicity comes a difficulty: the required expectations under p are intractable to compute.

Here, we look at a different approach that is the pinnacle of variational inference, in which you turn an inference problem into an optimization problem.

The exact inference task here is computing the partition function, but this is a general framework, and we will talk about why we care after the derivation.

I-projection

Let’s say that we had a generic exponential family distribution p (this includes any MRF). Moreover, we know the factors, but we don’t know the normalizing constant because it is intractable to marginalize.

You can derive the I-projection starting from the KL divergence. Writing p(x) = \tilde p(x) / Z(\theta) with \tilde p(x) = \exp\left(\sum_c \theta_c(x_c)\right), you get

D(q \| p) = E_q[\log q(x)] - E_q\left[\sum_c \theta_c(x_c)\right] + \log Z(\theta)

If we use the property that the KL divergence is at least zero, we get that the log partition function is bounded by

\log Z(\theta) \geq E_q\left[\sum_c \theta_c(x_c)\right] + H(q)

let’s unpack this. The bound holds for every q, and it is tight when q = p,

\log Z(\theta) = \max_q \; E_q\left[\sum_c \theta_c(x_c)\right] + H(q)

which means that you can approximate the partition function by picking the best q.

again, this is variational because we are optimizing over a function q.

First, what are we doing by optimizing q? Well, we are pushing q closer to p, which is our big goal. This is an optimization problem. It’s worth noting that if p were an exponential model, you don’t consider the \log Z(\theta) when optimizing over q because it doesn’t depend on q.

But we are also pushing closer to \log Z(\theta), which is an inference problem. For example, if your target distribution were P(x | e) and P(x, e) were a directed model, then the Z(\theta) would be P(e), which can be otherwise intractable to compute. By optimizing q, you kill two birds with one stone. You get an approximate distribution and an inference!

We call the right-hand side the variational lower bound (also known as the evidence lower bound, or ELBO). It can also be written as

\log Z(\theta) \geq E_{q(x)}[\log \tilde p(x) - \log q(x)]

where \tilde p(x) = \exp\left(\sum_c \theta_c(x_c)\right) is the unnormalized distribution.
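The bound can be verified on a tiny model. In the sketch below (toy chain potentials of my own), we compute log Z exactly by enumeration and evaluate the variational lower bound for a uniform q and for q = p:

```python
import math
import itertools

# Toy pairwise MRF on 3 binary variables in a chain: theta rewards agreement.
def log_ptilde(x):
    return sum(0.8 if x[i] == x[i + 1] else 0.0 for i in range(len(x) - 1))

states = list(itertools.product((0, 1), repeat=3))
log_z = math.log(sum(math.exp(log_ptilde(x)) for x in states))

def elbo(q):
    """E_q[log ptilde(x) - log q(x)] for q given as a dict over assignments."""
    return sum(w * (log_ptilde(x) - math.log(w)) for x, w in q.items() if w > 0)

uniform = {x: 1.0 / len(states) for x in states}
exact = {x: math.exp(log_ptilde(x) - log_z) for x in states}

gap_uniform = log_z - elbo(uniform)  # strictly positive: uniform != p
gap_exact = log_z - elbo(exact)      # zero: the bound is tight at q = p
```

The gap between log Z and the ELBO is exactly D(q || p), which is why it vanishes when q matches p.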

Optimizing the I-projection: initial idea

We can expand the expectation in the objective into explicit summations:

\max_q \; \sum_c \sum_{x_c} q(x_c)\,\theta_c(x_c) + H(q)

One immediate thought is that we can just find the x^* that maximizes \sum_c \theta_c(x_c), and then let q(x^*) = 1. This basically looks at the distribution and places a Dirac delta at the mode.

This is a tempting thought, but it doesn’t actually work, because the entropy term H(q) encourages a “fat” distribution, and a delta has zero entropy.

It is, however, a good approximation if pp peaks sharply.

Variational algorithms

We now focus on a more rigorous treatment of how we might optimize the I-projection. There are two approaches: mean-field and relaxation.

Mean-field relies on the factorization of q into individual components, while relaxation removes some constraints to make the optimization easier. First, we will look at mean field.

Mean field variational

We start with the naive mean field, in which you take a connected network and approximate it with an independent set of variables.

Naive Mean-field: the setup

Let’s go back to our original assumption, which is that q is fully factorizable: q(x) = \prod_i q_i(x_i). If this is the case, then our objective can be simplified. We start with

\max_q \; \sum_c \sum_{x_c} q(x_c)\,\theta_c(x_c) + H(q)

where q(x_c) = \prod_{i\in c} q_i(x_i).

The entropy term

Right now, the entropy term looks a little nasty. Can we factor it, given that we have our mean field? Well, let’s start from the definition:

H(q) = -\sum_x q(x)\log q(x) = -\sum_x \left(\prod_j q_j(x_j)\right) \sum_i \log q_i(x_i)

Now, we realize that because q(x) is factorizable, we can push the summation over x into the product; for each term i, every factor other than the i-th marginalizes to 1:

H(q) = -\sum_i \sum_{x_i} q_i(x_i)\log q_i(x_i) = \sum_i H(q_i)

The notation is a little confusing, but think about pushing the summation inside.

The objective

And our objective becomes

\max_{q_1,\ldots,q_n} \; \sum_c \sum_{x_c} \theta_c(x_c) \prod_{i\in c} q_i(x_i) + \sum_i H(q_i)

subject to the standard constraints of each q_i as a distribution (non-negative, sums to one).

This is ultimately a lower bound on the partition function \log Z(\theta).

Block coordinate ascent

This optimization problem isn’t really fun to deal with directly, but we can use block coordinate ascent. The procedure is pretty straightforward:

  1. initialize the distributions q_i randomly
  2. iterate through each of the variables i
  3. maximize the objective WRT q_i, holding the others fixed
  4. repeat until convergence

This is feasible because you remove consideration of all variables that are not immediate neighbors of variable x_i. You essentially form a Markov blanket with distributions and then you optimize the current distribution based on the surrounding distributions. One step looks like this (derived from the Lagrangian):

q_i(x_i) \propto \exp\left( E_{q_{-i}}\left[ \sum_{c : i \in c} \theta_c(x_c) \right] \right)

where the expectation is taken over the other factors q_j, j \neq i.

This is guaranteed to converge, though we might get trapped in a local optimum.
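Here is a minimal sketch of the whole procedure on a toy 3-node cycle MRF (the potentials and names are my own). Each sweep applies the exact coordinate update per node, so the lower bound never decreases:

```python
import math

# A hypothetical 3-node cycle MRF with agreement potentials and one biased node.
n = 3
edges = [(0, 1), (1, 2), (0, 2)]
coupling = 1.0          # theta_ij(a, b) = coupling if a == b else 0
bias = [0.5, 0.0, 0.0]  # theta_i(x_i) = bias[i] if x_i == 1 else 0

def theta_edge(a, b):
    return coupling if a == b else 0.0

def elbo(q):
    """sum_c E_q[theta_c] + sum_i H(q_i) under the fully factored q."""
    val = 0.0
    for i, j in edges:
        val += sum(q[i][a] * q[j][b] * theta_edge(a, b)
                   for a in (0, 1) for b in (0, 1))
    for i in range(n):
        val += q[i][1] * bias[i]
        val -= sum(p * math.log(p) for p in q[i] if p > 0)
    return val

q = [[0.5, 0.5] for _ in range(n)]
history = [elbo(q)]
for sweep in range(20):
    for i in range(n):
        # Coordinate update: log q_i(x_i) = E_{q_-i}[factors touching i] + const
        logits = []
        for a in (0, 1):
            s = bias[i] if a == 1 else 0.0
            for u, v in edges:
                if i == u:
                    s += sum(q[v][b] * theta_edge(a, b) for b in (0, 1))
                elif i == v:
                    s += sum(q[u][b] * theta_edge(b, a) for b in (0, 1))
            logits.append(s)
        m = max(logits)
        w = [math.exp(s - m) for s in logits]  # normalize stably
        q[i] = [w[0] / (w[0] + w[1]), w[1] / (w[0] + w[1])]
    history.append(elbo(q))
```

Because each step exactly maximizes the objective over one q_i, the recorded ELBO values are monotonically non-decreasing, which is the convergence guarantee mentioned above.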

Coordinate ascent and Gibbs sampling

In both Gibbs sampling and coordinate ascent, we are doing some sort of Markov blanket computation.

In Gibbs sampling, we would “select” concrete values for the elements in the Markov blanket by using their previous samples.

In coordinate ascent, we are not selecting values, but rather doing a soft selection, averaging over the previous estimate of the distribution q.

Accuracy of approximation

Mean field approximation can yield very wrong answers if the variables have a strong relationship between them. A classic example is the XOR distribution, where p(a, b) = 0.5 - \epsilon if a \neq b and p(a, b) = \epsilon if a = b. Mean field fails here because a factored q cannot represent the anti-correlation: it ends up either uniform or confidently committed to one corner.
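A quick numeric check of this failure (the epsilon value and state encoding are my own choices): under the I-projection objective, a factored q collapsed onto a single disagreeing corner beats the factored q with the correct uniform marginals, so mean field ends up with confidently wrong marginals.

```python
import math

eps = 0.01
# XOR-ish joint: almost all mass on disagreement
p = {(0, 0): eps, (0, 1): 0.5 - eps, (1, 0): 0.5 - eps, (1, 1): eps}

def kl_factored_vs_p(s, t):
    """D(q || p) for q(a, b) = qa(a) qb(b) with qa = (s, 1-s), qb = (t, 1-t)."""
    qa, qb = (s, 1 - s), (t, 1 - t)
    return sum(qa[a] * qb[b] * math.log(qa[a] * qb[b] / p[(a, b)])
               for a in (0, 1) for b in (0, 1) if qa[a] * qb[b] > 0)

# The true marginals of p are uniform, so this q matches the marginals exactly.
d_uniform = kl_factored_vs_p(0.5, 0.5)

# A corner q that commits all its mass to the single configuration (0, 1).
d_corner = kl_factored_vs_p(1.0, 0.0)
```

The corner wins under D(q || p) even though its marginals are deterministic while the true marginals are uniform — the I-projection happily trades correct marginals for avoiding the low-probability cells.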

Beyond mean field

We can use a q that keeps some of the connections, even if it is less connected than the true distribution. This can give better variational approximations. We will talk more about this in the next section.

Relaxation algorithms

Instead of using the mean field assumption, the next variational algorithm we look at involves relaxation, meaning we loosen parts of the problem to achieve tractability. Honestly, this makes sense because the mean field algorithm makes a rather strong assumption, don’t you think?

For the problems below, we are assuming that everything is in the exponential family. MRFs are also in the exponential family; the f(x) is just an indicator vector, and \theta are the weights. So

p(x) = \exp\left( \theta^\top f(x) - \log Z(\theta) \right)

Reframing

Because we have seen that moments are helpful, let’s try to reframe our efforts in terms of moments. We have previously established that

\log Z(\theta) \geq E_q[\theta^\top f(x)] + H(q) = \theta^\top E_q[f(x)] + H(q)

Here, we define

\mu_q = E_q[f(x)]

We call this the mean parameters, or the marginals, of q(x). Because f(x) is a vector of indicators, \mu_q is indeed a vector of marginals. So \mu_q[i] might represent q_j(x_j = a_i), or something like that (depending on how you encode f).

Note that this is a little different from our earlier E_q\left[\sum_c \theta_c(x_c)\right]: in an exponential family, \theta enters linearly, so it can be pulled outside the expectation.

Now, what are we trying to do here? Well, maybe something a bit weird. Instead of optimizing over q, because we don’t know how to do this, maybe we can optimize over this \mu_q instead. Hmm.

Marginal polytope

This set M = \{E_q[f(x)] : q \text{ a distribution}\} is pretty special. It’s called the marginal polytope, for reasons that will become apparent in a second.

It’s the set of all possible \mu that can arise from E_q[f(x)]. Now, you can think of the expectation as a convex combination of the f(x) vectors, and the choice of q determines the combination weights. Therefore, if you see the set of possible E_q[f(x)] as a geometric shape, you can imagine it as the convex hull of the points f(x), one per assignment x.

Think about why this is the case for a second. The points represent the different f(x), and q is just deciding how to combine them into the final \mu. And again, this \mu is literally a vector of marginals, so you can imagine that we are making a marginal distribution from the original indicator vectors.
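Here is a sketch making the geometry concrete for two binary variables (the indicator encoding is one arbitrary choice of mine): \mu = E_q[f(x)] is a convex combination of the vertex vectors f(x), and its coordinates really are node and edge marginals.

```python
import itertools
import random

# Sufficient statistics f(x) for two binary variables: four node indicators
# [x0=0, x0=1, x1=0, x1=1] followed by four edge (joint-state) indicators.
def f(x):
    node = [1.0 if x[0] == v else 0.0 for v in (0, 1)]
    node += [1.0 if x[1] == v else 0.0 for v in (0, 1)]
    edge = [1.0 if x == (a, b) else 0.0 for a in (0, 1) for b in (0, 1)]
    return node + edge

states = list(itertools.product((0, 1), repeat=2))  # the 4 vertices f(x)

# A random distribution q over the four joint states
random.seed(1)
w = [random.random() for _ in states]
q = [wi / sum(w) for wi in w]

# mu = E_q[f(x)]: a convex combination of the vertex vectors, weighted by q
mu = [sum(q[s] * f(states[s])[k] for s in range(4)) for k in range(8)]
```

The first block of mu is the node marginal of x0, and the edge block reproduces the joint probabilities, so any point in the hull is a valid set of marginals.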

Reframing in terms of the polytope

This is just a continuation from before. We needed to discuss M as a geometric object, but algebraically, this is not very stimulating. We just let

\log Z(\theta) = \max_{\mu \in M} \; \theta^\top \mu + H(\mu)

where H(\mu) denotes the entropy of the distribution achieving mean parameters \mu.

Intuitively, we are trying to optimize over marginal distributions, in the hopes that we optimize the original distribution.

But this is kind of a dead end as stated, because the marginal polytope is complex to describe, and H(\mu) seems hopeless, because there is literally an inference problem hiding inside the \max.

To continue, we must relax some things!

The key approximations

  1. We relax the marginal polytope M to a larger set M_L that includes M but is not exactly M.
  2. We replace H(\mu) with a tractable approximation.

First: relaxing the polytope M

Essentially, our task is to produce something in M. This is very hard, so we try something different: local consistency.

In other words, we try to restrict some local properties of the \mu and hope that these local properties lead to some \mu \in M. To help solidify this idea, we look at a pairwise MRF.

In this pairwise MRF example, \mu has a component \mu_i(x_i) for each node and \mu_{ij}(x_i, x_j) for each edge. The original f(x) outputs indicators; the expectation over f yields a convex combination, which means \mu is no longer a vector of indicators. However, some key conditions still hold:

\mu_i(x_i) \geq 0, \qquad \sum_{x_i} \mu_i(x_i) = 1

So these are some easy checks you can enforce on the \mu. But we also know that the edge marginals must be locally consistent with the node marginals:

\sum_{x_i} \mu_{ij}(x_i, x_j) = \mu_j(x_j)

Now, these conditions are NOT sufficient for a \mu to be in M. We lack the global consistency that loops require. For example, if we had a cycle A \to B \to C \to A, it is not enough to check that each pair of neighbors is consistent; we need to check that the loop is consistent with itself.

However, if \mu \in M, then these conditions are still satisfied, because they are a looser constraint. We can define these constraints as forming M_L. Geometrically, it might look like this:

So with these constraints, we can (decently) easily generate a \mu. And interestingly, if the graph is a tree, then M_L = M, because the local assumption is all you need.
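The cycle example can be made concrete. The sketch below (my own encoding) builds pseudo-marginals that pass every local-consistency check, then shows by enumeration that no joint distribution can realize them: on a binary triangle, no assignment makes all three edges disagree.

```python
import itertools

# Proposed pseudo-marginals on a 3-cycle: every edge table says "the endpoints
# always disagree" (0.5 on (0,1), 0.5 on (1,0)), every node marginal uniform.
mu_node = {i: [0.5, 0.5] for i in range(3)}
mu_edge = {e: {(0, 1): 0.5, (1, 0): 0.5, (0, 0): 0.0, (1, 1): 0.0}
           for e in [(0, 1), (1, 2), (0, 2)]}

# Local consistency: summing out either end of each edge table must
# reproduce the corresponding node marginal. These all pass.
locally_consistent = all(
    abs(sum(mu_edge[(i, j)][(a, b)] for b in (0, 1)) - mu_node[i][a]) < 1e-9
    for (i, j) in mu_edge for a in (0, 1)
) and all(
    abs(sum(mu_edge[(i, j)][(a, b)] for a in (0, 1)) - mu_node[j][b]) < 1e-9
    for (i, j) in mu_edge for b in (0, 1)
)

# Global obstruction: any single binary assignment on a triangle has 0 or 2
# disagreeing edges, so E[#disagreeing edges] <= 2, but these mu demand 3.
max_disagree = max(
    sum(1 for (i, j) in [(0, 1), (1, 2), (0, 2)] if x[i] != x[j])
    for x in itertools.product((0, 1), repeat=3)
)
```

So this \mu lies in M_L but not in M, exactly the gap pictured above.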

The entropy

Let’s look at the tree case first. If the MRF is a tree, can we compute H(\mu)? Well, recall that the exponential family is the maximum-entropy family subject to moment constraints. Let’s not worry about where that comes from, but keep it in mind: the distribution whose entropy defines H(\mu) is itself an exponential family member.

We take the \mu, recover the corresponding parameters through the reverse mapping \mu \mapsto \theta(\mu) (an MLE-style optimization problem) to get the maximum-entropy distribution with those marginals,

and then you can compute the entropy easily.

For trees, the entropy decomposes over nodes and edges:

H(\mu) = \sum_i H(\mu_i) - \sum_{(i,j)\in E} I(\mu_{ij})

where we have

H(\mu_i) = -\sum_{x_i}\mu_i(x_i)\log\mu_i(x_i), \qquad I(\mu_{ij}) = \sum_{x_i,x_j}\mu_{ij}(x_i,x_j)\log\frac{\mu_{ij}(x_i,x_j)}{\mu_i(x_i)\,\mu_j(x_j)}

Again, this is valid for trees only. But we can use the Bethe free energy approximation, in which we just use the same decomposition

H_{\text{Bethe}}(\mu) = \sum_i H(\mu_i) - \sum_{(i,j)\in E} I(\mu_{ij})

for all graphs.
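The tree decomposition can be sanity-checked by enumeration on a small chain (the potentials are my own toy choice): the exact joint entropy matches node entropies minus edge mutual informations.

```python
import math
import itertools

# A 3-node chain MRF (a tree): x0 - x1 - x2 with agreement potentials.
def ptilde(x):
    return math.exp(sum(0.7 if x[i] == x[i + 1] else 0.0 for i in range(2)))

states = list(itertools.product((0, 1), repeat=3))
z = sum(ptilde(x) for x in states)
p = {x: ptilde(x) / z for x in states}

# Exact joint entropy by enumeration
h_exact = -sum(v * math.log(v) for v in p.values())

def node_marg(i):
    m = [0.0, 0.0]
    for x, v in p.items():
        m[x[i]] += v
    return m

def edge_marg(i, j):
    m = {(a, b): 0.0 for a in (0, 1) for b in (0, 1)}
    for x, v in p.items():
        m[(x[i], x[j])] += v
    return m

def entropy(dist):
    return -sum(v * math.log(v) for v in dist if v > 0)

def mutual_info(mij, mi, mj):
    return sum(v * math.log(v / (mi[a] * mj[b]))
               for (a, b), v in mij.items() if v > 0)

# Bethe-style decomposition: node entropies minus edge mutual informations
edges = [(0, 1), (1, 2)]
h_bethe = sum(entropy(node_marg(i)) for i in range(3)) - sum(
    mutual_info(edge_marg(i, j), node_marg(i), node_marg(j)) for i, j in edges)
```

On this tree the two quantities agree exactly; on a loopy graph the same formula would only be an approximation.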

Putting this together

With these two relaxations, we have

\max_{\mu \in M_L} \; \theta^\top \mu + H_{\text{Bethe}}(\mu)

Now, this is very interesting. For non-trees, it is not concave, but for trees, it is.

But this is the same objective that loopy belief propagation works with! In fact, we can claim that if loopy BP converges, the estimated marginals \mu are stationary points of this variational objective.