Machine Learning and Pattern Recognition - Variational - Details
Variational methods used to be complicated. After writing down the standard KL-divergence
objective (previous note), researchers would try to derive clever fixed-point update equations
to optimize it. For some models, including simple ones like logistic regression, this strategy
didn’t work out. Special-case variational objectives would be crafted for particular models.
As a result, most text-book treatments of the applications of variational methods are fairly
complicated, and beyond what’s required for this course.
Fortunately the stochastic variational inference (SVI) methods developed in the last few
years are simpler to understand, more general, and scale to enormous datasets. This note
will outline enough of the idea to explain the demonstration code on the website. The
demonstration is applied to logistic regression, but in principle its log-likelihood and
gradients could be replaced with any other model with real-valued parameters.
As a reminder, we wish to minimize

    J(m, V) = E_{N(w; m,V)}[ log N(w; m, V) − log p(w) − log p(D | w) ],    (1)

with respect to our variational parameters {m, V}, the mean and covariance of our Gaussian
approximate posterior. For any {m, V} we have a lower bound on the log-marginal likelihood,
or the model evidence:

    log p(D) ≥ −J(m, V).    (2)
We would like to maximize the model likelihood p(D) with respect to any hyperparameters,
such as the prior variance on the weights, σ_w^2. We can’t do that exactly, but we can instead
minimize J with respect to these parameters. So, we will jointly minimize J with respect to
the variational distribution and the model hyperparameters, aiming for a tight bound and a
large model likelihood.
1 Unconstrained optimization
Stochastic gradient descent on the parameters V and σ_w^2 will sometimes propose negative
variances, or covariance matrices that aren’t positive definite. Instead we should optimize
unconstrained quantities, such as log σ_w.
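For example, a minimal sketch of this idea (the variable names are illustrative, not from the demonstration code):

    import numpy as np

    log_sigma_w = 0.0                       # unconstrained; any real value is valid
    sigma_w2 = np.exp(2.0 * log_sigma_w)    # implied prior variance, always positive
    # A gradient step on log_sigma_w can never produce a negative variance,
    # whereas a step directly on sigma_w2 could.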
To optimize a covariance matrix, we can first write it in terms of its Cholesky decomposition:
V = LL^⊤. The diagonal elements of L are positive, and the other elements are unconstrained.
We take the log of the diagonal elements of L, leave the off-diagonal elements as they are, and
optimize the resulting unconstrained matrix.
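A sketch of this unconstrained parameterization in NumPy (the function names are illustrative, not from the demonstration code):

    import numpy as np

    def covariance_from_unconstrained(U):
        # Map an unconstrained square matrix U to a positive-definite covariance:
        # keep the strict lower triangle of U and exponentiate its diagonal, giving
        # a Cholesky factor L with positive diagonal, then return V = L L^T.
        L = np.tril(U, k=-1) + np.diag(np.exp(np.diag(U)))
        return L @ L.T

    def unconstrained_from_covariance(V):
        # Inverse map, useful for initializing the unconstrained matrix.
        L = np.linalg.cholesky(V)
        return np.tril(L, k=-1) + np.diag(np.log(np.diag(L)))

    # Any real-valued matrix U gives a valid covariance:
    U = np.random.default_rng(0).standard_normal((3, 3))
    V = covariance_from_unconstrained(U)
    assert np.all(np.linalg.eigvalsh(V) > 0)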
[The website version of this note has a question here.]
The expected log-likelihood term of the cost, E_{N(w; m,V)}[log p(D | w)], cannot be computed
in closed form for models like logistic regression. We could convert the integral for each
likelihood term into a 1D integral and compute it numerically. However, that is expensive.
We only need unbiased estimates of the gradients to perform stochastic gradient descent.
We can get a simple “Monte Carlo” unbiased estimate by sampling a random weight vector
from the variational posterior:
    −E_{N(w; m,V)}[log p(D | w)] ≈ −∑_{n=1}^{N} log p(y^(n) | x^(n), w),    w ∼ N(m, V).    (7)
We can also replace the sum by scaling up the contribution of a single random example, and
still get an unbiased estimate:

    −E_{N(w; m,V)}[log p(D | w)] ≈ −N log p(y^(n) | x^(n), w),    n ∼ Uniform{1, …, N},    w ∼ N(m, V).
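A sketch of both estimators for logistic regression, assuming labels in {−1, +1} (the function names are illustrative, not from the demonstration code):

    import numpy as np

    def log_lik(w, X, y):
        # Logistic-regression log-likelihood, log p(y | X, w), with labels y in {-1, +1}.
        return -np.sum(np.log1p(np.exp(-y * (X @ w))))

    def nll_estimate(m, L, X, y, rng):
        # One-sample unbiased estimate of -E_{N(w; m, V)}[log p(D | w)], with V = L L^T.
        w = m + L @ rng.standard_normal(m.shape)    # w ~ N(m, V)
        return -log_lik(w, X, y)

    def nll_estimate_one_example(m, L, X, y, rng):
        # Doubly-stochastic version: one sampled weight vector and one random example,
        # whose contribution is scaled up by N to keep the estimate unbiased.
        N = X.shape[0]
        n = rng.integers(N)
        w = m + L @ rng.standard_normal(m.shape)
        return -N * log_lik(w, X[n:n+1], y[n:n+1])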
Rewriting the weights with the reparameterization w = m + Lν, where ν ∼ N(0, I) and V = LL^⊤,
the expectation is now under a constant distribution, so it’s easy to write down derivatives
with respect to the variational parameters:

    ∇_m E_{N(ν; 0,I)}[log p(D | m + Lν)] ≈ ∇_w log p(D | w),
    ∇_L E_{N(ν; 0,I)}[log p(D | m + Lν)] ≈ [∇_w log p(D | w)] ν^⊤,    with w = m + Lν, ν ∼ N(0, I).
To estimate both gradients, we just need to be able to evaluate gradients of the log-likelihood
function with respect to the weights, which we already know how to do if we can do
maximum likelihood fitting.
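A sketch of these one-sample gradient estimates for logistic regression (again with illustrative names; an autodiff tool such as autograd could compute the same weight gradients automatically):

    import numpy as np

    def grad_log_lik(w, X, y):
        # Gradient of the logistic-regression log-likelihood with respect to the
        # weights, for labels y in {-1, +1}: sum_n y_n x_n sigma(-y_n x_n^T w).
        s = 1.0 / (1.0 + np.exp(y * (X @ w)))
        return X.T @ (y * s)

    def grad_nll_estimate(m, L, X, y, rng):
        # One-sample unbiased estimates of the gradients of -E_{N(w; m, V)}[log p(D | w)]
        # with respect to the mean m and the Cholesky factor L, where V = L L^T.
        nu = rng.standard_normal(m.shape)
        w = m + L @ nu                  # reparameterization: w ~ N(m, V)
        g = grad_log_lik(w, X, y)       # gradient with respect to the weights
        grad_m = -g
        grad_L = -np.outer(g, nu)       # only the lower triangle is used, as L is lower triangular
        return grad_m, grad_L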
[The website version of this note has a question here.]
7 References
Shakir Mohamed has recent tutorial slides on variational inference. The final slide has a
reading list of both the classic and modern variational inference papers that discovered the
theory in this note.
Black-box stochastic variational inference in five lines of Python, David Duvenaud, Ryan
P. Adams, NeurIPS Workshop on Black-box Learning and Inference, 2015. The associated
Python code and neural net demo require autograd.
4. As an example, here is how to find the following expectation over a D-dimensional vector z:

    E_{N(z; 0,V)}[log N(z; 0, V)].    (14)

Using standard manipulations, including the “trace trick”:

    E_{N(z; 0,V)}[log N(z; 0, V)] + (1/2) log |2πV| = E_{N(z; 0,V)}[ −(1/2) Tr(z^⊤ V^{−1} z) ]    (15)
        = E_{N(z; 0,V)}[ −(1/2) Tr(z z^⊤ V^{−1}) ]    (16)
        = −(1/2) Tr( E_{N(z; 0,V)}[z z^⊤] V^{−1} )    (17)
        = −(1/2) Tr(V V^{−1}) = −(1/2) Tr(I_D) = −D/2.    (18)
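The result is easy to check numerically; a quick Monte Carlo sketch (not part of the note’s derivation):

    import numpy as np

    rng = np.random.default_rng(0)
    D = 3
    A = rng.standard_normal((D, D))
    V = A @ A.T + D * np.eye(D)                                 # a positive-definite covariance
    Z = rng.multivariate_normal(np.zeros(D), V, size=200_000)   # samples z ~ N(0, V)

    sign, logdet = np.linalg.slogdet(2 * np.pi * V)
    log_pdf = -0.5 * logdet - 0.5 * np.einsum('ni,ij,nj->n', Z, np.linalg.inv(V), Z)

    print(np.mean(log_pdf))          # Monte Carlo estimate of E[log N(z; 0, V)]
    print(-0.5 * logdet - D / 2)     # closed form from (18): -(1/2) log|2piV| - D/2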