A central task when working with probabilistic models is the evaluation of the posterior distribution. The posterior is often intractable (integrals with no closed-form solution; exponentially many discrete states), so approximation methods need to be employed. Broadly, these methods fall into two categories: stochastic and deterministic. In this post I discuss variational inference, a deterministic posterior approximation method. Historically, variational methods date back to the 18th century and the work of Euler, Lagrange, and others on the calculus of variations (the search for functions that optimize a functional).

The general idea is to approximate the intractable posterior distribution $p(z \mid x)$ with a (simpler) variational distribution $q(z)$. This is treated as an optimization problem: find the variational distribution $q(z)$ that is closest to the posterior, with closeness measured by the Kullback–Leibler (KL) divergence. Let’s spell that out:

$$$\begin{split} \text{KL} ( q(z) \parallel p(z | x) ) & = \text{E}_{q} \left[ \log \frac{ q(z) }{ p(z | x) } \right] \\ & = \text{E}_{q} \left[ \log q(z) \right] - \text{E}_{q} \left[ \log p(z | x) \right] \\ &= \text{E}_{q} \left[ \log q(z) \right] - \text{E}_{q} \left[ \log p(z, x) \right] + \log p(x) \\ \end{split}$$$

Noting that the log evidence $\log p(x)$ is constant w.r.t. the variational distribution $q(z)$, minimizing $\text{KL} ( q(z) \parallel p(z \mid x) )$ is equivalent to maximizing the evidence lower bound (ELBO) $\mathcal{L} = \text{E}_{q} \left[ \log p(z, x) \right] - \text{E}_{q} \left[ \log q(z) \right]$. To see that $\mathcal{L}$ really is a lower bound on the log evidence, rearrange the last line above into $\log p(x) = \mathcal{L} + \text{KL} ( q(z) \parallel p(z \mid x) )$ and remember that the KL divergence between any two distributions is non-negative, so $\mathcal{L} \leq \log p(x)$.
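To make the relationship between the log evidence, the bound and the KL divergence concrete, here is a minimal numerical sketch. The three-state toy model and all of its numbers are made up for illustration; the only thing it checks is the identity $\log p(x) = \mathcal{L} + \text{KL}$ from above.

```python
import numpy as np

def elbo_and_kl(q, joint):
    """Return (ELBO, KL(q || p(z|x))) for a discrete latent z.

    joint[k] = p(z = k, x) for the observed x; q[k] = q(z = k).
    All entries are assumed strictly positive.
    """
    posterior = joint / joint.sum()                    # p(z | x)
    elbo = np.sum(q * (np.log(joint) - np.log(q)))     # E_q[log p(z, x)] - E_q[log q(z)]
    kl = np.sum(q * (np.log(q) - np.log(posterior)))   # KL(q || p(z | x))
    return float(elbo), float(kl)

# Hypothetical toy model: z takes 3 values, x is fixed at its observed value.
joint = np.array([0.10, 0.25, 0.05])
log_evidence = float(np.log(joint.sum()))              # log p(x)

# An arbitrary variational distribution over the same 3 states.
q = np.array([0.2, 0.5, 0.3])
elbo, kl = elbo_and_kl(q, joint)

print("log p(x)  :", log_evidence)
print("ELBO + KL :", elbo + kl)
print("ELBO      :", elbo)
```

Setting `q` to the exact posterior `joint / joint.sum()` drives the KL term to zero, at which point the bound touches the log evidence.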

Before moving forward, it’s worth pointing out that the KL divergence is an asymmetric measure. The direction we chose above is appealing from a tractability standpoint (the expectation is taken w.r.t. the variational distribution). The alternative direction, minimizing $\text{KL}(p \parallel q)$, is considered in expectation propagation, a method I might cover in a future session. The direction also affects the outcome of the optimization. This can be seen from the expectation term in the first line of the derivation above: in places where $p$ has (very) small mass, $q$ will avoid placing any mass, since over-estimating there makes the log ratio $\log \frac{q(z)}{p(z \mid x)}$ large and leads to a large penalty (a large KL).
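A small sketch of the asymmetry, under made-up numbers: a hypothetical five-state target `p` with two modes and a nearly empty valley between them, compared against one candidate that sits on a single mode and one that spreads its mass everywhere. Nothing here comes from a real model; it only illustrates which candidate each direction of the divergence prefers.

```python
import numpy as np

def kl(a, b):
    """KL(a || b) for discrete distributions with strictly positive entries."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.sum(a * (np.log(a) - np.log(b))))

# Hypothetical target: modes at states 0 and 4, almost no mass in between.
p = np.array([0.49, 0.009, 0.002, 0.009, 0.49])

# Two hypothetical candidates for q.
q_one_mode = np.array([0.96, 0.01, 0.01, 0.01, 0.01])  # commits to a single mode
q_spread = np.full(5, 0.2)                              # covers the modes and the valley

for name, q in [("one mode", q_one_mode), ("spread", q_spread)]:
    print(f"{name:9s} KL(q||p) = {kl(q, p):.3f}   KL(p||q) = {kl(p, q):.3f}")
```

With these numbers, $\text{KL}(q \parallel p)$ favours the single-mode candidate (it is penalized for putting mass in the valley where $p$ is tiny), while $\text{KL}(p \parallel q)$ favours the spread-out one.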

Let’s remind ourselves of the objective function, the evidence lower bound, that we need to maximize:

$\mathcal{L} = \text{E}_{q} \left[ \log p(z, x) \right] - \text{E}_{q} \left[ \log q(z) \right]$

The approximation in variational methods comes from restricting the family of distributions $q(z)$ - note that up to this point we made no assumption about the variational distribution. A common choice, referred to as the mean field assumption, is to use a fully factorized distribution $q(z) = \prod_{i} q(z_i)$ - note we’re not imposing any constraints on the functional form of the factors. The optimization problem then involves maximizing the lower bound w.r.t. each of the distributions $q(z_{j})$. Focusing on the $j$-th factor, and dropping the terms that do not depend on it (the entropies of the other factors), we have:

$$$\begin{split} \mathcal{L}_{q_j} & = \text{E}_{q} \left[ \log p(z, x) \right] - \text{E}_{q_{j}} \left[ \log q(z_{j}) \right] \\ & = \int_{z} q(z) \log p(z, x) - \int_{z_{j}} q(z_{j}) \log q(z_{j}) \\ &= \int_{z_{j}} q(z_{j}) \int_{z_{-j}} q(z_{-j}) \log p(x, z) - \int_{z_{j}} q(z_{j}) \log q(z_{j}) \\ &= \int_{z_{j}} q(z_{j}) \text{E}_{q_{-j}} \left[ \log p(x, z) \right] - \int_{z_{j}} q(z_{j}) \log q(z_{j}) \\ &= \int_{z_{j}} q(z_{j}) \log \frac{ \exp\{\text{E}_{q_{-j}} \left[ \log p(x, z) \right] \} }{ q(z_{j}) } \end{split}$$$

We can recognize above the negative KL divergence between $q(z_{j})$ and an unnormalized distribution proportional to $\exp\{\text{E}_{q_{-j}} \left[ \log p(x, z) \right] \}$ (the missing normalizer only contributes a constant). The negative KL divergence is maximized when $q(z_{j})$ equals that distribution after normalization, leading to:

$$$\begin{split} q(z_{j}) & \propto \exp\{\text{E}_{q_{-j}} \left[ \log p(x, z) \right] \} \\ & \propto \exp\{\text{E}_{q_{-j}} \left[ \log p(z_{j} | x, z_{-j}) \right] \} \end{split}$$$
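The second proportionality may deserve one extra step. A short sketch, using nothing beyond the product rule $p(x, z) = p(z_{j} | x, z_{-j}) \, p(x, z_{-j})$:

$$$\begin{split} \text{E}_{q_{-j}} \left[ \log p(x, z) \right] & = \text{E}_{q_{-j}} \left[ \log p(z_{j} | x, z_{-j}) \right] + \text{E}_{q_{-j}} \left[ \log p(x, z_{-j}) \right] \\ & = \text{E}_{q_{-j}} \left[ \log p(z_{j} | x, z_{-j}) \right] + \text{const} \end{split}$$$

The last expectation does not involve $z_{j}$, so after exponentiating it becomes a multiplicative constant that is absorbed by the normalization of $q(z_{j})$.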

We’ve just derived the optimal solution for the $q(z_{j})$ factor. The solution also determines its form (recall that the only assumption we made was the factorization). This is a general result in mean field variational inference - you can underline it :-)

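Since each factor depends on an expectation taken w.r.t. the other factors, the final estimates are obtained by updating them in turn while keeping the others fixed (coordinate ascent). This is done iteratively, until the evidence lower bound converges. Each update maximizes the bound w.r.t. one factor, so the bound is guaranteed not to decrease from one iteration to the next, although in general it converges only to a local optimum.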
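Here is a minimal sketch of that loop on a toy problem, assuming two discrete latent variables whose joint with the observed $x$ is given as a small table. The table is random and entirely hypothetical; the point is just to watch the update $q(z_{j}) \propto \exp\{\text{E}_{q_{-j}}[\log p(x, z)]\}$ in action and to track the bound.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy model: z1 has 3 states, z2 has 4 states, x is fixed.
# joint[a, b] = p(z1 = a, z2 = b, x); rescaled so that p(x) = 0.3 (arbitrary).
joint = rng.random((3, 4))
joint *= 0.3 / joint.sum()
log_joint = np.log(joint)

def normalize_log(log_w):
    """Turn unnormalized log-weights into a probability vector."""
    w = np.exp(log_w - log_w.max())
    return w / w.sum()

def elbo(q1, q2):
    """E_q[log p(z, x)] - E_q[log q(z)] for the factorized q(z) = q1(z1) q2(z2)."""
    expected_log_joint = q1 @ log_joint @ q2
    entropy_terms = np.sum(q1 * np.log(q1)) + np.sum(q2 * np.log(q2))
    return float(expected_log_joint - entropy_terms)

# Start from uniform factors and update them in turn.
q1 = np.full(3, 1 / 3)
q2 = np.full(4, 1 / 4)
history = [elbo(q1, q2)]
for _ in range(50):
    q1 = normalize_log(log_joint @ q2)   # q1(z1) ∝ exp(E_{q2}[log p(z1, z2, x)])
    q2 = normalize_log(q1 @ log_joint)   # q2(z2) ∝ exp(E_{q1}[log p(z1, z2, x)])
    history.append(elbo(q1, q2))
    if history[-1] - history[-2] < 1e-12:  # bound has stopped improving
        break

print("final ELBO:", history[-1])
print("log p(x)  :", np.log(joint.sum()))
```

The recorded `history` never goes down, and the final value stays below $\log p(x)$; the remaining gap is the KL divergence between the factorized $q$ and the exact posterior.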

We could stop now and enjoy what we’ve achieved so far, but there is one more useful result we can derive. A large class of models have their complete conditionals in the exponential family:

$$$p(z_{j} \mid z_{-j}, x) = h(z_{j}) \exp \{ \eta(z_{-j}, x)^T t(z_{j}) - a(\eta(z_{-j}, x)) \}$$$

Plugging the exponential family form of the conditional into the general result we derived before, we get:

$$$\begin{split} q(z_{j}) & \propto \exp\{\text{E}_{q_{-j}} \left[ \log p(z_{j} | x, z_{-j}) \right] \} \\ & \propto h(z_{j}) \exp\{ \text{E}_{q_{-j}} \left[ \eta(z_{-j}, x) \right]^T t(z_j) \} \end{split}$$$

We can see the optimal $q(z_{j})$ has the same exponential family form as the conditional, where the natural parameter $v_{j}$ is expressed as:

$$$v_{j} = \text{E}_{q_{-j}} \left[ \eta(z_{-j}, x) \right]$$$
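As an illustration, here is a sketch for the classic toy case of a Gaussian with unknown mean and precision. The model is my choice for the example and is not tied to anything above: $x_i \sim \mathcal{N}(\mu, \tau^{-1})$, $\mu \mid \tau \sim \mathcal{N}(\mu_0, (\lambda_0 \tau)^{-1})$, $\tau \sim \text{Gamma}(a_0, b_0)$, with $q(\mu, \tau) = q(\mu) q(\tau)$. Both complete conditionals are in the exponential family (Gaussian and Gamma), so each update just replaces the other variable in the conditional’s parameters by the relevant expectation. The function name, argument names and default hyperparameters are all hypothetical.

```python
import numpy as np

def cavi_gaussian(x, mu0=0.0, lambda0=1.0, a0=1.0, b0=1.0, n_iter=50):
    """Mean field updates for q(mu) = N(mu_n, 1/lambda_n), q(tau) = Gamma(a_n, b_n)."""
    x = np.asarray(x, dtype=float)
    n, xbar = x.size, x.mean()

    # These two quantities do not depend on the other factor, so they are fixed.
    mu_n = (lambda0 * mu0 + n * xbar) / (lambda0 + n)
    a_n = a0 + 0.5 * (n + 1)

    e_tau = a0 / b0  # initial guess for E_q[tau]
    for _ in range(n_iter):
        # q(mu): the conditional p(mu | tau, x) has precision (lambda0 + n) * tau,
        # so the update plugs in E_q[tau].
        lambda_n = (lambda0 + n) * e_tau

        # q(tau): the conditional p(tau | mu, x) has a rate that is quadratic in mu,
        # so the update plugs in E_q[(x_i - mu)^2] and E_q[(mu - mu0)^2].
        var_mu = 1.0 / lambda_n
        e_sq_data = np.sum((x - mu_n) ** 2) + n * var_mu
        e_sq_prior = (mu_n - mu0) ** 2 + var_mu
        b_n = b0 + 0.5 * (e_sq_data + lambda0 * e_sq_prior)

        e_tau = a_n / b_n  # mean of Gamma(a_n, b_n) with rate b_n
    return mu_n, lambda_n, a_n, b_n

# Synthetic data: true mean 2.0, true standard deviation 0.5 (precision 4.0).
rng = np.random.default_rng(0)
data = rng.normal(loc=2.0, scale=0.5, size=1000)

mu_n, lambda_n, a_n, b_n = cavi_gaussian(data)
print("E_q[mu]  =", mu_n)
print("E_q[tau] =", a_n / b_n)
```

With this much data the posterior means should land close to the values used to generate it, and the loop typically settles after a handful of iterations because only `lambda_n` and `b_n` depend on each other.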

Countless conditionally conjugate models that use mean field variational inference have been published over the years. I’ve recently applied this theory to a model for aggregating crowdsourced anaphoric annotations. Check out the appendix and code for the derivations and implementation.

A Probabilistic Annotation Model for Crowdsourcing Coreference
Silviu Paun, Jon Chamberlain, Udo Kruschwitz, Juntao Yu, Massimo Poesio
In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing (EMNLP), 2018