Variational Inference (theory)
| Tags | CS 228, Inference, Learning |
|---|---|
Why and what is variational inference?
The term comes from the calculus of variations, in which you optimize a function (and not numbers) to fit something.
Our main problem comes from marginal inference. Given a joint distribution, can we compute something like $p(x \mid e)$? This requires marginalization, which takes time exponential in the number of variables. We can approximate this posterior with a simpler distribution. Therefore, we reframe inference as optimization.
The key insight
The prior distribution is typically complicated because it encompasses so much about the world. However, once you condition on some evidence $e$, the posterior is typically simpler.
If we are looking at MRFs, the posterior is proportional to the joint distribution, but you just clamp the observations in the factors, which of course makes it simpler. But it’s unnormalized, and the normalization constant is hard to compute.
$$
p(x \mid e) = \frac{\tilde p(x)}{Z(e)}, \qquad \tilde p(x) = \prod_c \phi_c(x_c, e)
$$
Methods of optimization
The goal is to find some $q$ that is close to $p$. We compare the two through the KL divergence, computing either $D_{KL}(p \,\|\, q)$ or $D_{KL}(q \,\|\, p)$.
The first one is called the M-projection of $p$, and the second one is called the I-projection. These two objectives actually yield quite different results.
M-projection
$$
q^* = \arg\min_{q \in \mathcal{Q}} D_{KL}(p \,\|\, q) = \arg\min_{q \in \mathcal{Q}} \mathbb{E}_{x \sim p}\left[\log \frac{p(x)}{q(x)}\right]
$$
Here, we are essentially optimizing an expected value over $p$. This means that if $q(x)$ is very small somewhere that $p(x)$ is not, then the KL divergence becomes very large. As such, the $q$ is forced to have a very wide support. Philosophically speaking, we are finding the $q$ such that everything that is correct (under $p$) is encapsulated by it.
As we learned from last time with exponential families, the M-projection is equivalent to setting the moments of the sufficient statistics equal to each other.
Below is an example of the M-projection of Gaussians, where $q$ is constrained to have a diagonal covariance. You can clearly see that the means match.
*(Figure: M-projection of a correlated Gaussian onto a diagonal-covariance Gaussian; the means of $p$ and $q$ coincide.)*
I-projection
$$
q^* = \arg\min_{q \in \mathcal{Q}} D_{KL}(q \,\|\, p) = \arg\min_{q \in \mathcal{Q}} \mathbb{E}_{x \sim q}\left[\log \frac{q(x)}{p(x)}\right]
$$
In this case, we are taking the expectation across $q$, which means that we want to capture the highest density of $p$ within the support of $q$. Essentially, this method says “everything I see is correct.” This has the effect of excluding things, as can be seen below. When given a multimodal distribution, the I-projection will settle on one of the modes, while the M-projection will cover the center.
*(Figure: on a multimodal $p$, the I-projection locks onto a single mode, while the M-projection averages across the modes.)*
Mean Field inference
What do we count as “simple” for $q$? Well, we can make one key assumption: $q$ is fully factorized. In other words, each variable under $q$ is independent, and the MRF graph is just a bunch of vertices without any edges.
We define $\mathcal{Q}$ as the family of all fully-factored distributions, and we can try solving the M-projection.
M-projection derivation
We can set up the KL divergence and use a little multiply-by-one trick, using the true marginals $p(x_i)$. We have a key question: what minimizes the divergence?
$$
D_{KL}(p \,\|\, q) = \sum_x p(x) \log \frac{p(x)}{\prod_i q_i(x_i)} = \sum_x p(x) \log \frac{p(x)}{\prod_i p(x_i)} + \sum_i D_{KL}\big(p(x_i) \,\|\, q_i(x_i)\big)
$$
We get this upshot: to minimize the divergence between the simple $q$ and the complicated $p$, just let $q_i(x_i) = p(x_i)$, the true marginals.
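As a sanity check, here is a tiny numeric sketch (the joint table is made up for illustration) showing that the fully factored M-projection is just the product of the true marginals:

```python
import itertools

# Hypothetical toy joint over two binary variables (numbers chosen for illustration).
p = {(0, 0): 0.30, (0, 1): 0.20, (1, 0): 0.10, (1, 1): 0.40}

# M-projection onto a fully factored family: set each q_i to the true marginal p(x_i).
q1 = {v: sum(p[(v, b)] for b in (0, 1)) for v in (0, 1)}  # marginal of x1
q2 = {v: sum(p[(a, v)] for a in (0, 1)) for v in (0, 1)}  # marginal of x2

# The resulting factored approximation q(x) = q1(x1) * q2(x2).
q = {(a, b): q1[a] * q2[b] for a, b in itertools.product((0, 1), repeat=2)}

print(q1, q2)  # marginals: q1 = (0.5, 0.5), q2 = (0.4, 0.6)
```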
From our previous discussion on M-projection, we also know that if $q$ is an exponential family, then we have the moment-matching condition $\mathbb{E}_{x \sim q}[\phi(x)] = \mathbb{E}_{x \sim p}[\phi(x)]$.
If we were to model this, then we would probably use something like an exponential family. However, this defeats the purpose, because marginalization (and sampling) is generally intractable for $p$ (or else why would you need this simpler distribution?).
Variational Models through I-projections
Now, we saw that the M-projection approach was simple because you just needed to get $q_i$ to match the marginals of $p$, or, if $q$ were an exponential family, we just needed to match the moments. However, with this simplicity comes a difficulty of expectation: the required expectations under $p$ are intractable to compute.
Here, we look at a different approach that is the pinnacle of variational inference, in which you turn an inference problem into an optimization problem.
The exact inference task here is finding the partition function, but this is a general framework, and we will talk about why we care after the derivation.
I-projection
Let’s say that we had a generic exponential family distribution $p(x) = \tilde p(x) / Z$ (this includes any MRF). Moreover, we know the factors $\tilde p(x)$, but we don’t know the normalizing constant $Z$ because it is intractable to marginalize.
You can derive the I-projection starting from the DKL and you get
$$
D_{KL}(q \,\|\, p) = \mathbb{E}_{x \sim q}[\log q(x)] - \mathbb{E}_{x \sim q}[\log \tilde p(x)] + \log Z
$$
If we use the property that the KL divergence is at least zero, we get that the log partition function is bounded below by
$$
\log Z \;\ge\; \mathbb{E}_{x \sim q}[\log \tilde p(x)] + H(q)
$$
Let’s unpack this. It means that you can approximate the partition function by picking the best $q$:
$$
\log Z \;\approx\; \max_{q \in \mathcal{Q}} \; \mathbb{E}_{x \sim q}[\log \tilde p(x)] + H(q)
$$
Again, this is variational because we are optimizing over a distribution $q$ (a function).
First, what are we doing by optimizing this bound? Well, we are pushing $q$ closer to $p$, which is our big goal. This is an optimization problem. It’s worth noting that if $p$ were an exponential model, you don’t consider the $\log Z$ term when optimizing over $q$ because it doesn’t depend on $q$.
But we are also pushing the bound closer to $\log Z$, which is an inference problem. For example, if your target distribution were $p(x \mid e)$ and $p$ were a directed model, then the $Z$ would be $p(e)$, which can be otherwise intractable to compute. By optimizing $q$, you kill two birds with one stone. You get an approximate distribution and an inference!
We call the right-hand side the variational lower bound (the ELBO). It can also be written as

$$
\mathbb{E}_{x \sim q}[\log \tilde p(x)] + H(q) = \log Z - D_{KL}(q \,\|\, p).
$$
Optimizing the I-projection: initial idea
We can rearrange the objective through the summations
$$
\mathcal{L}(q) = \sum_x q(x) \log \tilde p(x) - \sum_x q(x) \log q(x)
$$
One immediate thought is that we can just find the $x^*$ that maximizes $\tilde p(x)$, and then let $q = \delta_{x^*}$. This basically looks at the distribution and puts a Dirac delta where the mode occurs.
This is an intrusive thought, but we see that it doesn’t actually work because we have the entropy term $H(q)$, which encourages a “fat” distribution (a point mass has zero entropy).
It is, however, a good approximation if $\tilde p$ peaks sharply.
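A minimal numeric sketch of this tradeoff, using a made-up unnormalized distribution over three states: the delta at the mode only achieves $\log \tilde p(x^*)$, while the exact posterior attains $\log Z$ itself.

```python
import math

# Hypothetical unnormalized distribution over three states (values made up).
p_tilde = {"a": 1.0, "b": 2.0, "c": 5.0}
logZ = math.log(sum(p_tilde.values()))  # exact log partition function: log 8

def elbo(q):
    """E_q[log p_tilde] + H(q), with the 0 * log 0 terms dropped."""
    return sum(q[x] * (math.log(p_tilde[x]) - math.log(q[x]))
               for x in q if q[x] > 0)

delta = {"a": 0.0, "b": 0.0, "c": 1.0}                    # point mass at the mode
exact = {x: v / sum(p_tilde.values()) for x, v in p_tilde.items()}

print(elbo(delta))  # log p_tilde(mode) = log 5 ≈ 1.609, a loose bound
print(elbo(exact), logZ)  # the exact posterior attains log Z = log 8 ≈ 2.079
```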
Variational algorithms
We now focus on a more rigorous treatment of how we might optimize the I-projection. There are two approaches: mean-field and relaxation.
Mean-field relies on the factorization of $q$ into individual components, and relaxation removes some constraints to make the optimization easier. First, we will look at mean field.
Mean field variational
We start with the naive mean field, in which you take some connected network and turn it into an independent set of variables
*(Figure: a connected MRF approximated by a fully disconnected mean-field graph.)*
Naive Mean-field: the setup
Let’s go back to our original assumption, which is that $q$ is fully factorized. If this is the case, then our objective can be simplified. We start with
$$
\max_q \; \mathbb{E}_{x \sim q}[\log \tilde p(x)] + H(q)
$$

where $q(x) = \prod_i q_i(x_i)$.
The entropy term
Right now, the entropy term looks a little nasty. Can we factor it out, given that we have our mean field? Well, let’s start from the definitions
$$
H(q) = -\sum_x q(x) \log q(x) = -\sum_x \left(\prod_i q_i(x_i)\right) \left(\sum_j \log q_j(x_j)\right)
$$
Now, we realize that because $q$ is factorized, we can push the summation into the product. We actually discover that in doing so, everything else marginalizes to 1. The notation is a little confusing, but think about pushing the summation inside.
$$
H(q) = \sum_i H(q_i) = -\sum_i \sum_{x_i} q_i(x_i) \log q_i(x_i)
$$
The objective
And our objective becomes
$$
\max_{q_1, \dots, q_n} \; \mathbb{E}_{x \sim q}[\log \tilde p(x)] + \sum_i H(q_i)
$$
subject to the standard constraints of each $q_i$ as a distribution:

$$
\sum_{x_i} q_i(x_i) = 1, \qquad q_i(x_i) \ge 0
$$
This is ultimately a lower bound on the log partition function $\log Z$.
Block coordinate ascent
This optimization problem isn’t really fun to deal with, but we can use block coordinate ascent. The procedure is pretty straightforward:
- initialize the distributions $q_1, \dots, q_n$ randomly
- iterate through each of the variables $i$
- maximize the objective with respect to $q_i$, holding the others fixed
- repeat until convergence
This is feasible because you remove consideration of all variables that are not immediate neighbors of variable $i$. You essentially form a Markov blanket with distributions, and then you optimize the current distribution based on the surrounding distributions. One step looks like this (derived from the Lagrangian):
$$
q_i(x_i) \;\propto\; \exp\left( \mathbb{E}_{q_{-i}}\left[\log \tilde p(x_i, x_{-i})\right] \right)
$$
This is guaranteed to converge, though we might get trapped in a local optimum.
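The loop above can be sketched in a few dozen lines. This is a self-contained toy run of naive mean field on a made-up 3-node binary chain MRF (the log-potentials `theta` are arbitrary); it brute-forces $\log Z$ only to confirm the ELBO stays below it.

```python
import itertools
import math

# Hypothetical 3-node binary chain MRF: log p_tilde(x) = theta12[x1][x2] + theta23[x2][x3].
theta = {(0, 1): [[0.5, -0.5], [-0.5, 1.0]],
         (1, 2): [[1.0, 0.0], [0.0, 0.5]]}

def log_p_tilde(x):
    return sum(t[x[i]][x[j]] for (i, j), t in theta.items())

states = list(itertools.product((0, 1), repeat=3))
logZ = math.log(sum(math.exp(log_p_tilde(x)) for x in states))  # brute force, for checking

# Naive mean field q(x) = prod_i q_i(x_i), initialized uniformly.
q = [[0.5, 0.5] for _ in range(3)]

def elbo():
    total = 0.0
    for x in states:
        qx = q[0][x[0]] * q[1][x[1]] * q[2][x[2]]
        if qx > 0:
            total += qx * (log_p_tilde(x) - math.log(qx))
    return total

elbo_start = elbo()
for sweep in range(50):
    for i in range(3):
        # Block update: q_i(x_i) ∝ exp( E_{q_{-i}}[ log p_tilde(x_i, x_{-i}) ] ).
        logits = []
        for v in (0, 1):
            e = 0.0
            for x in states:
                if x[i] != v:
                    continue
                w = 1.0
                for j in range(3):
                    if j != i:
                        w *= q[j][x[j]]
                e += w * log_p_tilde(x)
            logits.append(e)
        m = max(logits)
        z = sum(math.exp(l - m) for l in logits)
        q[i] = [math.exp(l - m) / z for l in logits]

print(elbo(), logZ)  # the ELBO improves monotonically and lower-bounds log Z
```

Each block update is exact given the other factors, which is why the objective never decreases.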
Coordinate ascent and Gibbs sampling
In both Gibbs sampling and coordinate ascent, we are doing some sort of Markov blanket computation.
*(Figure: the Markov blanket of a variable in an MRF, used by both methods.)*
In Gibbs sampling, we would “select” the elements in the Markov blanket by using their previous estimations
$$
x_i \sim p\left(x_i \mid x_{\text{MB}(i)}\right)
$$
In this coordinate ascent, we are not selecting, but rather doing a soft selection based on a previous estimate of the distribution
$$
q_i(x_i) \;\propto\; \exp\left( \mathbb{E}_{q_{\text{MB}(i)}}\left[\log \tilde p\left(x_i, x_{\text{MB}(i)}\right)\right] \right)
$$
Accuracy of approximation
Mean field approximation can yield very wrong answers if the variables have a strong relationship between them. A classic example is the XOR distribution, where $p(x_1, x_2) = \epsilon$ if $x_1 = x_2$ and $p(x_1, x_2) = 0.5 - \epsilon$ if $x_1 \ne x_2$. Mean field approximation will fail here for obvious reasons.
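A quick numeric sketch of this failure, using a soft-XOR distribution with $\epsilon = 0.01$: the exact marginals are uniform, but the mean-field updates with a slightly asymmetric initialization lock onto a single mode.

```python
import math

# Soft XOR: p(x1, x2) = eps if x1 == x2, else 0.5 - eps (eps chosen small).
eps = 0.01
p = {(a, b): (eps if a == b else 0.5 - eps) for a in (0, 1) for b in (0, 1)}

# Exact marginals are uniform by symmetry.
m1 = [sum(p[(v, b)] for b in (0, 1)) for v in (0, 1)]
print(m1)  # both entries are 0.5

# Naive mean field with a slightly asymmetric initialization breaks the symmetry.
q1, q2 = [0.5, 0.5], [0.4, 0.6]
for _ in range(25):
    # q1(v) ∝ exp( E_{q2}[ log p(v, x2) ] ), then the symmetric update for q2.
    l1 = [sum(q2[b] * math.log(p[(v, b)]) for b in (0, 1)) for v in (0, 1)]
    z = sum(math.exp(l) for l in l1)
    q1 = [math.exp(l) / z for l in l1]
    l2 = [sum(q1[a] * math.log(p[(a, v)]) for a in (0, 1)) for v in (0, 1)]
    z = sum(math.exp(l) for l in l2)
    q2 = [math.exp(l) / z for l in l2]

print(q1, q2)  # q1 concentrates on x1 = 0 and q2 on x2 = 1: one mode, not the marginals
```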
Beyond mean field
We can make a mean field that has some sort of connections, even if it’s less connected than the true distribution. This can help with better variational approximations. We will talk more about this in the next section.
*(Figure: a structured mean-field approximation that keeps some of the original graph’s edges.)*
Relaxation algorithms
Instead of using the mean field assumption, the next variational algorithm we look at involves relaxation. This means approximating something to achieve tractability. Honestly, this makes sense because the mean field algorithm makes a rather strong assumption, don’t you think?
For the problems below, we are assuming that everything is in the exponential family. MRFs are also in the exponential family; the $\phi(x)$ is just an indicator vector, and $\theta$ are the weights (log-potentials). So
$$
p(x) = \frac{1}{Z(\theta)} \exp\left(\theta^\top \phi(x)\right)
$$
Reframing
Because we have seen that moments are helpful, let’s try to reframe our efforts in terms of moments. So we have previously established that
$$
\log Z(\theta) \;\ge\; \mathbb{E}_{x \sim q}\left[\theta^\top \phi(x)\right] + H(q) = \theta^\top \mathbb{E}_{x \sim q}[\phi(x)] + H(q)
$$
Here, we define

$$
\mu = \mathbb{E}_{x \sim q}[\phi(x)].
$$

We call this the vector of mean parameters or the marginals of $q$. Because $\phi$ is an indicator vector, each entry is indeed a marginal. So $\mu_k$ might represent $q(x_i = 1)$, or something like that (it depends on how you encode $\phi$).
Note that this is a little different from our examples above, because we generalized to exponential families, which have a linear term $\theta^\top \phi(x)$ whose expectation can be pulled outside: $\mathbb{E}_{x \sim q}[\theta^\top \phi(x)] = \theta^\top \mu$.
Now, what are we trying to do here? Well, maybe something a bit weird. Instead of optimizing over $q$, because we don’t know how to do this, maybe we can optimize over this $\mu$ instead. Hmm.
$$
\mathbb{M} = \left\{ \mu \in \mathbb{R}^d \;:\; \exists\, q \text{ such that } \mu = \mathbb{E}_{x \sim q}[\phi(x)] \right\}
$$
Marginal polytope
This is pretty special. It’s called the marginal polytope $\mathbb{M}$, for reasons that will become apparent in a second.
It’s the set of all possible $\mu$ that can arise from some distribution $q$. Now, you can think of the expectation as a convex combination of the vectors $\phi(x)$, where the choice of $q$ determines the weights of the combination. Therefore, if you view $\mathbb{M}$ as a geometric shape, you can imagine it as the convex hull of the points $\phi(x)$, one for each assignment $x$.
*(Figure: the marginal polytope $\mathbb{M}$ drawn as the convex hull of the vectors $\phi(x)$.)*
Think about why this is the case for a second. The points represent $\phi(x)$ for different assignments $x$, and $q$ is just deciding how to combine them into the final $\mu$. And again, this is literally a marginal, so you can imagine that we are making a marginal distribution from the original indicator vectors.
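A tiny sketch with two binary variables and a made-up joint: encoding $\phi$ as one-hot indicators per variable (one hypothetical minimal choice), $\mu = \mathbb{E}_p[\phi(x)]$ is a convex combination of the $\phi(x)$ vertices, and its entries are literally the marginals.

```python
import itertools

# Hypothetical encoding for two binary variables:
# phi(x) = [x1==0, x1==1, x2==0, x2==1], a one-hot indicator per variable.
def phi(x):
    return [int(x[0] == 0), int(x[0] == 1), int(x[1] == 0), int(x[1] == 1)]

states = list(itertools.product((0, 1), repeat=2))
p = {(0, 0): 0.3, (0, 1): 0.2, (1, 0): 0.1, (1, 1): 0.4}  # made-up joint

# mu = E_p[phi(x)]: a convex combination of the vertices {phi(x)}, weighted by p.
mu = [sum(p[x] * phi(x)[k] for x in states) for k in range(4)]
print(mu)  # ≈ [0.5, 0.5, 0.4, 0.6]: exactly the marginals p(x1 = v) and p(x2 = v)
```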
Reframing in terms of the polytope
$$
\log Z(\theta) = \max_{\mu \in \mathbb{M}} \; \theta^\top \mu + H(\mu)
$$
This is just a continuation from before. We needed to discuss the polytope for its geometric interpretation, but algebraically, this is not very stimulating. We just let

$$
H(\mu) = H(q_\mu),
$$

the entropy of the distribution $q_\mu$ whose mean parameters are $\mu$.
Intuitively, we are trying to optimize some marginal distribution, in the hopes that we optimize the original distribution.
But this is kinda dumb, because the marginal polytope $\mathbb{M}$ is complex to describe, and $H(\mu)$ seems hopeless, because there is literally an inference problem inside it.
To continue, we must relax some things!
The key approximations
- $\mathbb{M}_{\text{local}}$ is a relaxation of the marginal polytope such that it includes $\mathbb{M}$ but is not exactly $\mathbb{M}$
- We replace $H(\mu)$ with an approximation.
First: the relaxed polytope $\mathbb{M}_{\text{local}}$
Essentially, our task is to generate something in $\mathbb{M}$. This is very hard, so we try something different: local consistency.
In other words, we try to enforce some local properties of the $\mu$ and hope that these local properties lead to something close to $\mathbb{M}$. To help solidify this idea, we look at a pairwise MRF.
In this MRF example, $\phi$ is a function that takes in an assignment $x$ and outputs an indicator vector. The expectation over $q$ yields a convex combination, which means that $\mu$ is no longer a vector of indicators. However, some key conditions still hold:
$$
\mu_i(x_i) \ge 0, \qquad \sum_{x_i} \mu_i(x_i) = 1, \qquad \mu_{ij}(x_i, x_j) \ge 0, \qquad \sum_{x_i, x_j} \mu_{ij}(x_i, x_j) = 1
$$
So these are some easy checks you can enforce on the $\mu$. But we also know that the assignments must be locally consistent:
$$
\sum_{x_j} \mu_{ij}(x_i, x_j) = \mu_i(x_i), \qquad \sum_{x_i} \mu_{ij}(x_i, x_j) = \mu_j(x_j)
$$
Now, these conditions are NOT sufficient for a $\mu$ to be in $\mathbb{M}$. We lack the global consistency that is necessary for loops. For example, if we had a loop of variables, it is not enough to check that each neighboring pair is consistent; we need to check that the loop is consistent with itself.
However, if $\mu \in \mathbb{M}$, then these conditions are still satisfied, because they are a looser constraint. We define the set satisfying these constraints as $\mathbb{M}_{\text{local}}$. Geometrically, it might look like this:
*(Figure: $\mathbb{M}_{\text{local}}$ as an outer bound containing the marginal polytope $\mathbb{M}$.)*
So with these constraints, we can (decently) easily generate a $\mu \in \mathbb{M}_{\text{local}}$. And interestingly, if the graph were a tree, then $\mathbb{M}_{\text{local}} = \mathbb{M}$, because local consistency is all you need.
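The classic loop counterexample can be checked in a few lines: on a triangle of binary variables, take uniform singletons and perfectly anti-correlated pairwise pseudo-marginals. They pass every local consistency check, yet no joint distribution realizes them.

```python
import itertools

# Triangle of binary variables: uniform singletons, each pair perfectly anti-correlated.
tau_i = {i: {0: 0.5, 1: 0.5} for i in (0, 1, 2)}
tau_ij = {e: {(0, 1): 0.5, (1, 0): 0.5, (0, 0): 0.0, (1, 1): 0.0}
          for e in [(0, 1), (1, 2), (0, 2)]}

# Local consistency holds: each pairwise marginal sums back to the singleton.
for (i, j), t in tau_ij.items():
    for v in (0, 1):
        assert abs(sum(t[(v, b)] for b in (0, 1)) - tau_i[i][v]) < 1e-12

# But a realizing joint would need all its mass on assignments where every pair
# disagrees, and no such binary assignment of three variables exists (pigeonhole).
feasible = [x for x in itertools.product((0, 1), repeat=3)
            if x[0] != x[1] and x[1] != x[2] and x[0] != x[2]]
print(feasible)  # [] -- so tau lies in M_local but not in M
```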
The entropy
Let’s look at a tree assumption first. If the MRF were a tree, can we compute $H(\mu)$? Well, recall that the exponential family maximizes entropy subject to moment constraints. Let’s not worry about where that comes from, but keep this in mind. So $q_\mu$ would be an exponential family member.
We find the distribution with the largest possible entropy consistent with $\mu$, computing the reverse mapping through an MLE-style optimization problem to get

$$
\theta(\mu) \;\text{ such that }\; \mathbb{E}_{\theta(\mu)}[\phi(x)] = \mu,
$$

and then you can compute the entropy easily.
For trees, the entropy decomposes
$$
H(\mu) = \sum_{i \in V} H(\mu_i) - \sum_{(i,j) \in E} I(\mu_{ij})
$$
where we have
$$
H(\mu_i) = -\sum_{x_i} \mu_i(x_i) \log \mu_i(x_i), \qquad I(\mu_{ij}) = \sum_{x_i, x_j} \mu_{ij}(x_i, x_j) \log \frac{\mu_{ij}(x_i, x_j)}{\mu_i(x_i)\,\mu_j(x_j)}
$$
Again, this is valid for trees only. But we can use the Bethe free-energy approximation, in which we just use
$$
H_{\text{Bethe}}(\mu) = \sum_{i \in V} H(\mu_i) - \sum_{(i,j) \in E} I(\mu_{ij})
$$
for all graphs.
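A quick numeric check of the tree case, on a made-up 3-node chain: the Bethe entropy (node entropies minus edge mutual informations) matches the exact joint entropy.

```python
import itertools
import math

# Hypothetical 3-node chain x1 - x2 - x3 (a tree), defined by conditionals.
p1 = {0: 0.6, 1: 0.4}
p2_1 = {0: {0: 0.7, 1: 0.3}, 1: {0: 0.2, 1: 0.8}}   # p(x2 | x1)
p3_2 = {0: {0: 0.9, 1: 0.1}, 1: {0: 0.4, 1: 0.6}}   # p(x3 | x2)

states = list(itertools.product((0, 1), repeat=3))
p = {x: p1[x[0]] * p2_1[x[0]][x[1]] * p3_2[x[1]][x[2]] for x in states}

def marg(idxs):
    """Marginal of p over the variables in idxs."""
    out = {}
    for x, v in p.items():
        key = tuple(x[i] for i in idxs)
        out[key] = out.get(key, 0.0) + v
    return out

def H(d):
    return -sum(v * math.log(v) for v in d.values() if v > 0)

def MI(i, j):
    mij, mi, mj = marg((i, j)), marg((i,)), marg((j,))
    return sum(v * math.log(v / (mi[(a,)] * mj[(b,)]))
               for (a, b), v in mij.items() if v > 0)

H_exact = H(p)  # entropy of the full joint

# Bethe entropy: sum_i H(mu_i) - sum over the chain's edges of I(mu_ij).
H_bethe = sum(H(marg((i,))) for i in range(3)) - MI(0, 1) - MI(1, 2)
print(H_exact, H_bethe)  # equal, since the graph is a tree
```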
Putting this together
With these two relaxations, we have
$$
\log Z(\theta) \;\approx\; \max_{\mu \in \mathbb{M}_{\text{local}}} \; \theta^\top \mu + H_{\text{Bethe}}(\mu)
$$
Now, this is very interesting. For non-trees, the objective is not concave, but for trees, it is.
But this is the same thing as loopy belief propagation! In fact, we can claim that if loopy BP converges, the estimated marginals are stationary points of the variational objective, i.e. of the Bethe free energy.