We have an intractable posterior distribution \(p(x|D)\) that we wish to approximate with \(q(x)\), chosen from a given family of tractable distributions (e.g., a Gaussian).

We define \(\widetilde{p}(x) = p(x|D)p(D) = p(x,D)\), the unnormalized posterior, which is easier to evaluate pointwise because it does not require computing the expensive \(p(D)\).

The goal is to find the \(q\) that minimizes a cost function defined using the KL divergence from Information Theory:

\[J(q) = KL(q\|\widetilde{p}) = \sum_x q(x) \log \frac{q(x)}{\widetilde{p}(x)}\]

or using an integral for continuous distributions.
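For the discrete case, the sum above can be sketched directly in R. This is a minimal illustration: the helper name `kl_discrete` and the three-state probability vectors are made up for the example. Terms with \(q(x) = 0\) are skipped, following the usual convention \(0 \log 0 = 0\).

```r
# Discrete version of J(q) = sum_x q(x) * log(q(x) / p_tilde(x)).
# kl_discrete is a hypothetical helper; q and p_tilde are vectors over the same states.
kl_discrete <- function(q, p_tilde) {
  stopifnot(length(q) == length(p_tilde))
  keep <- q > 0                      # skip states where q(x) = 0 (0 * log 0 := 0)
  sum(q[keep] * log(q[keep] / p_tilde[keep]))
}

# made-up distributions over three states
p   <- c(0.5, 0.3, 0.2)
q_a <- c(0.4, 0.4, 0.2)   # close to p
q_b <- c(0.1, 0.1, 0.8)   # far from p

kl_discrete(q_a, p)       # small
kl_discrete(q_b, p)       # larger
kl_discrete(p, p)         # zero when q equals p
```

The closer candidate `q_a` yields the smaller value, which is the behavior we want from the cost function.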

We can implement KL divergence in R:

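# KL for continuous distributions, by numerical integration over [lower, upper]
# (extra arguments in ... are passed on to the density q)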
KL <- function(q, p_tilde, lower, upper, ...) {
  f <- function(x) q(x, ...) * log(q(x, ...)/p_tilde(x))
  integrate(f, lower, upper)$value
}

# an eg:
p_tilde <- function(x) dgamma(x, shape=3.0, scale=0.25)
q1 <- function(x) dlnorm(x, 0, 1)
q2 <- function(x) dlnorm(x, 0, .45)

curve(p_tilde, 0, 6, lwd=2, ylab="")
curve(q1,      0, 6, lwd=2, col="red",   add=TRUE)
curve(q2,      0, 6, lwd=2, col="green", add=TRUE)  # this one 'seems' closer...

KL(q1,p_tilde, lower=1e-3, upper=100)
## [1] 1.709245
KL(q2,p_tilde, lower=1e-3, upper=100) # ...and in fact, KL gives a smaller number
## [1] 0.3400462

Notice that \(J\) is not exactly a KL divergence in general, since \(\widetilde{p}\) is a non-normalized ‘distribution’.

To see that the cost function \(J\) works as desired, let’s expand its definition:

\[\begin{array}{lcll} J(q) & = & \sum_x q(x) \log \frac{q(x)}{\widetilde{p}(x)} & \\ & = & \sum_x q(x) \log \frac{q(x)}{p(x|D)p(D)} & \\ & = & \sum_x q(x) \left( \log \frac{q(x)}{p(x|D)} - \log p(D) \right) & \\ & = & \sum_x q(x) \log \frac{q(x)}{p(x|D)} - \log p(D) & \color{blue}{ \sum_x q(x) = 1} \\ & = & KL(q\|p(x|D)) - \log p(D) & \end{array} \]

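Since \(p(D)\) does not depend on \(q\), minimizing \(J(q)\) is equivalent to minimizing \(KL(q\|p(x|D))\), so the optimal \(q(x)\) is the member of the chosen family closest to \(p(x|D)\). Moreover, because a KL divergence is never negative, \(J(q) \geq -\log p(D)\).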
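The identity \(J(q) = KL(q\|p(x|D)) - \log p(D)\) can be checked numerically. The sketch below is only an illustration: the value `p_D = 0.2` is a made-up evidence, the gamma density plays the role of the normalized posterior, and `kl_integral` is a hypothetical helper that repeats the integral form of the cost so the snippet runs on its own.

```r
# Integral form of the cost, restricted to [lower, upper]
kl_integral <- function(q, p, lower, upper) {
  f <- function(x) q(x) * log(q(x) / p(x))
  integrate(f, lower, upper)$value
}

p_D     <- 0.2                                               # hypothetical evidence p(D)
post    <- function(x) dgamma(x, shape = 3.0, scale = 0.25)  # normalized posterior p(x|D)
p_tilde <- function(x) p_D * post(x)                         # unnormalized p(x, D)
q       <- function(x) dlnorm(x, 0, 0.45)                    # candidate approximation

J      <- kl_integral(q, p_tilde, lower = 1e-3, upper = 20)
kl_q_p <- kl_integral(q, post,    lower = 1e-3, upper = 20)

# both sides of J(q) = KL(q || p(x|D)) - log p(D);
# they should agree up to numerical integration error
c(J = J, rhs = kl_q_p - log(p_D))
```

The two numbers differ only by the constant \(-\log p(D)\) from the true KL divergence, which is why the unknown evidence never needs to be computed during optimization.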

Approximation via optim

Here’s an example where we approximate a gamma density with a log-normal:

variational_lnorm <- function(p_tilde, lower, upper) {
  q <- dlnorm # in this eg, q is a log-normal
  
  J <- function(params) {
    KL(q, p_tilde, lower=lower, upper=upper, meanlog=params[1], sdlog=params[2])
  }
  
  optim(par=c(0, 1), fn=J)$par
}

# an eg:
p_tilde <- function(x) dgamma(x, shape=3.0, scale=0.25)

approximation_params <- variational_lnorm(p_tilde, lower=1e-3, upper=100)
# get the resulting approximation:
q <- function(x) dlnorm(x, approximation_params[1], approximation_params[2])

KL(q,p_tilde,1e-3,10) # compute the divergence between them
## [1] 0.02765858
curve(p_tilde, 0, 6, lwd=2, ylab="", ylim=c(0,1.25))
curve(q,       0, 6, lwd=2, col="red",   add=TRUE)

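We can mimic a Laplace approximation (i.e., a Gaussian approximation) of, say, a given beta. Strictly speaking, a Laplace approximation matches the mode and curvature of the target, whereas here the Gaussian is fitted by minimizing the KL divergence: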

variational_norm <- function(p_tilde, lower, upper) {
  q <- dnorm  # in this eg, q is a normal
  
  J <- function(params) {
    KL(q, p_tilde, lower=lower, upper=upper, mean=params[1], sd=params[2])
  }
  
  optim(par=c(0.5, 0.2), fn=J)$par  # initial values are tricky, not very stable
}

p_tilde <- function(x) dbeta(x,11,9)

approximation_params <- variational_norm(p_tilde, lower=1e-3, upper=1-1e-3)
# get the resulting approximation:
q <- function(x) dnorm(x, mean=approximation_params[1], sd=approximation_params[2])

KL(q,p_tilde,0,1) # compute the divergence between them
## [1] 0.004112233
curve(p_tilde, 0, 1, lwd=2, ylab="", ylim=c(0,4))
curve(q,       0, 1, lwd=2, col="red",   add=TRUE)
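The code comment above notes that the initial values are tricky. One possible way to make the search a little more robust, sketched here as an assumption rather than as part of the original method, is to optimize over `log(sd)` so that the standard deviation stays positive whatever step `optim` takes, and to guard the integrand where `q(x)` underflows to zero. The names `kl_guarded` and `variational_norm_logsd` are hypothetical.

```r
# Integral form of the cost, treating 0 * log(0) as 0 where q underflows
kl_guarded <- function(q, p_tilde, lower, upper, ...) {
  f <- function(x) {
    qx <- q(x, ...)
    ifelse(qx > 0, qx * (log(qx) - log(p_tilde(x))), 0)
  }
  integrate(f, lower, upper)$value
}

# Gaussian fit parameterized as (mean, log(sd)), so sd = exp(.) is always positive
variational_norm_logsd <- function(p_tilde, lower, upper,
                                   start = c(0.5, log(0.2))) {
  J <- function(params) {
    kl_guarded(dnorm, p_tilde, lower, upper,
               mean = params[1], sd = exp(params[2]))
  }
  fit <- optim(par = start, fn = J)   # default Nelder-Mead, as in the text
  c(mean = fit$par[1], sd = exp(fit$par[2]))
}

# same target as above
p_tilde <- function(x) dbeta(x, 11, 9)
fit <- variational_norm_logsd(p_tilde, lower = 1e-3, upper = 1 - 1e-3)
fit
```

The fitted mean and standard deviation should land close to those of the beta itself (about 0.55 and 0.11), although the result can still depend on the starting point and on the integration limits.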