We have an intractable posterior distribution \(p(x|D)\) that we wish to approximate with \(q(x)\), chosen from a given family of tractable distributions (e.g., a Gaussian).

We define \(\widetilde{p}(x) = p(x|D)p(D) = p(x,D)\), which is easier to compute pointwise, since we don’t need to compute the expensive evidence \(p(D)\).
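For instance (a toy illustration of my own, not the example used later): for n coin flips with k heads and a uniform prior on the bias x, the unnormalized posterior is just likelihood times prior, which is cheap to evaluate pointwise:

# toy illustration: unnormalized posterior for a coin-flip model
# (likelihood * prior); no evidence p(D) is needed to evaluate it pointwise
n <- 10; k <- 7
p_tilde_coin <- function(x) dbinom(k, n, x) * dbeta(x, 1, 1)
p_tilde_coin(0.5)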

The goal is to find a \(q\) close to \(\widetilde{p}\) by minimizing a cost function defined using the KL divergence from Information Theory:

\[J(q) = KL(q\|\widetilde{p}) = \sum_x q(x) \log \frac{q(x)}{\widetilde{p}(x)}\]

or using an integral for continuous distributions.
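For the discrete case, the sum translates directly into R (a minimal sketch of my own, with made-up probability vectors):

# discrete case: KL divergence as a sum over a common finite support
kl_discrete <- function(q, p_tilde) sum(q * log(q / p_tilde))

kl_discrete(c(0.2, 0.5, 0.3), c(0.25, 0.25, 0.5))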

For continuous distributions, we can implement the KL divergence via numerical integration:

# KL divergence for continuous distributions (numerical integration);
# extra arguments in ... are passed on to q
KL <- function(q, p_tilde, lower, upper, ...) {
  f <- function(x) q(x, ...) * log(q(x, ...)/p_tilde(x))
  integrate(f, lower, upper)$value
}

# an example:
p_tilde <- function(x) dgamma(x, shape=3.0, scale=0.25)
q1 <- function(x) dlnorm(x, 0, 1)
q2 <- function(x) dlnorm(x, 0, .45)

curve(p_tilde, 0, 6, lwd=2, ylab="")
curve(q1,      0, 6, lwd=2, col="red",   add=TRUE)
curve(q2,      0, 6, lwd=2, col="green", add=TRUE)  # this one 'seems' closer...


KL(q1,p_tilde, lower=1e-3, upper=100)
## [1] 1.709245
KL(q2,p_tilde, lower=1e-3, upper=100) # ...and in fact, KL gives a smaller number
## [1] 0.3400462

Notice that this is not exactly a KL divergence, since \(\widetilde{p}\) is an unnormalized ‘distribution’.

To see that the cost function \(J\) works as desired, let’s expand the expression:

\[ \begin{array}{lcll} J(q) & = & \sum_x q(x) \log \frac{q(x)}{\widetilde{p}(x)} & \\ & = & \sum_x q(x) \log \frac{q(x)}{p(x|D)p(D)} & \\ & = & \sum_x q(x) \left( \log \frac{q(x)}{p(x|D)} - \log p(D) \right) & \\ & = & \sum_x q(x) \log \frac{q(x)}{p(x|D)} - \log p(D) & \color{blue}{ \sum_x q(x) = 1} \\ & = & KL(q\|p(x|D)) - \log p(D) & \end{array} \]

Since \(p(D)\) is a constant, minimizing \(J(q)\) is equivalent to minimizing \(KL(q\|p(x|D))\), and so \(q(x)\) will approach \(p(x|D)\).
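We can check this identity numerically (a quick sanity check of my own, reusing KL() and the log-normal q2 from above; the helper names below are mine):

# treat the gamma as the normalized posterior and scale it by an arbitrary
# 'evidence' p(D) to get an unnormalized target; then J(q) should equal
# KL(q, posterior) - log p(D)
post     <- function(x) dgamma(x, shape=3.0, scale=0.25)
evidence <- 0.5                        # arbitrary value standing in for p(D)
p_scaled <- function(x) post(x) * evidence

J_q2 <- KL(q2, p_scaled, lower=1e-3, upper=100)
J_q2 - (KL(q2, post, lower=1e-3, upper=100) - log(evidence))  # ~ 0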

Approximation via optim

Here’s an example where we approximate a gamma distribution using log-normals:

variational_lnorm <- function(p_tilde, lower, upper) {
  q <- dlnorm # in this example, q is a log-normal

  # cost function: KL divergence of a log-normal(meanlog, sdlog) from p_tilde
  J <- function(params) {
    KL(q, p_tilde, lower=lower, upper=upper, meanlog=params[1], sdlog=params[2])
  }

  optim(par=c(0, 1), fn=J)$par  # minimize J; returns the fitted (meanlog, sdlog)
}

# an example:
p_tilde <- function(x) dgamma(x, shape=3.0, scale=0.25)

approximation_params <- variational_lnorm(p_tilde, lower=1e-3, upper=100)
# get the resulting approximation:
q <- function(x) dlnorm(x, approximation_params[1], approximation_params[2])

KL(q,p_tilde,1e-3,10) # compute their distance
## [1] 0.02765858

curve(p_tilde, 0, 6, lwd=2, ylab="", ylim=c(0,1.25))
curve(q,       0, 6, lwd=2, col="red",   add=TRUE)

We can mimic a Laplace approximation (i.e., a Gaussian approximation) of, say, a given beta distribution:

variational_norm <- function(p_tilde, lower, upper) {
  q <- dnorm  # in this example, q is a normal

  # cost function: KL divergence of a normal(mean, sd) from p_tilde
  J <- function(params) {
    KL(q, p_tilde, lower=lower, upper=upper, mean=params[1], sd=params[2])
  }

  optim(par=c(0.5, 0.2), fn=J)$par  # initial values are tricky, not very stable
}

p_tilde <- function(x) dbeta(x,11,9)

approximation_params <- variational_norm(p_tilde, lower=1e-3, upper=1-1e-3)
# get the resulting approximation:
q <- function(x) dnorm(x, mean=approximation_params[1], sd=approximation_params[2])

KL(q,p_tilde,0,1) # compute their distance
## [1] 0.004112233

curve(p_tilde, 0, 1, lwd=2, ylab="", ylim=c(0,4))
curve(q,       0, 1, lwd=2, col="red",   add=TRUE)
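For comparison (my own addition, not in the original example): the actual Laplace approximation of this beta matches the mode and the curvature of the log-density there, rather than minimizing the KL divergence:

# Laplace approximation of Beta(11,9): a Gaussian centred at the mode, with
# variance equal to the inverse of minus the second derivative of the
# log-density at the mode
a <- 11; b <- 9
post_mode <- (a-1)/(a+b-2)                              # 5/9
curv      <- (a-1)/post_mode^2 + (b-1)/(1-post_mode)^2  # -(log dbeta)'' at the mode
q_laplace <- function(x) dnorm(x, mean=post_mode, sd=sqrt(1/curv))

KL(q_laplace, p_tilde, 0, 1)  # typically a bit larger than the KL-optimal fit above
curve(q_laplace, 0, 1, lwd=2, col="blue", add=TRUE)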

The next example approximates beta distributions with Kumaraswamy distributions, which have a simpler PDF

\[f_\text{kumar}(x|a,b) = abx^{a-1}(1-x^a)^{b-1}\]

and a closed-form CDF:

\[F_\text{kumar}(x|a,b) = 1-(1-x^a)^b\]
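As a quick sanity check (my own addition, with arbitrary parameters a=2, b=3), the closed-form CDF agrees with pkumar from the extraDistr package used below:

library(extraDistr)

F_kumar <- function(x, a, b) 1 - (1 - x^a)^b   # closed-form CDF from above
x <- c(0.1, 0.5, 0.9)
cbind(closed_form = F_kumar(x, 2, 3), pkumar = pkumar(x, 2, 3))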

In this post, John Cook notes that, since the CDF is easy to invert, it’s simple to generate random samples from \(K(a,b)\) by generating \(u \sim U(0,1)\) and returning

\[F^{-1}(u) = (1-(1-u)^{1/b})^{1/a}\]

library(extraDistr)

n <- 1e4
u <- runif(n)

a <- 1/2
b <- 1/2

ku <- (1 - (1-u)^(1/b))^(1/a)   # inverse-CDF sampling

hist(ku, breaks=50, prob=TRUE)
q <- function(x) dkumar(x,a,b)
curve(q, 0, 1, lwd=2, col="red", add=TRUE)

Let’s do the approximation:

variational_kumaraswamy <- function(p_tilde, lower, upper) {
  q <- dkumar

  J <- function(params) {
    # the KL function can return slightly negative values here, most likely
    # because the integration is truncated to [lower, upper], where q does not
    # integrate to 1; abs() is used as a workaround
    abs(KL(q, p_tilde, lower=lower, upper=upper, a=params[1], b=params[2]))
  }

  optim(par=c(0.5, 0.5), fn=J, method="BFGS")$par
}

p_tilde <- function(x) dbeta(x,1/2,1/2)

approximation_params <- variational_kumaraswamy(p_tilde, lower=0.01, upper=0.99)
# get the resulting approximation:
q <- function(x) dkumar(x, approximation_params[1], approximation_params[2])

KL(q,p_tilde,0.01,0.99) # compute their distance
## [1] 1.240327e-16

curve(p_tilde, 0, 1, lwd=2, ylab="", ylim=c(0,4))
curve(q,       0, 1, lwd=2, col="red",   add=TRUE)

Another example:

p_tilde <- function(x) dbeta(x,3,3)

approximation_params <- variational_kumaraswamy(p_tilde, lower=0.01, upper=0.99)
# get the resulting approximation:
q <- function(x) dkumar(x, approximation_params[1], approximation_params[2])

#q <- function(x) dkumar(x, 5, 251/40)

KL(q,p_tilde,0.01,0.99) # compute their distance
## [1] 0.001386222
curve(p_tilde, 0, 1, lwd=2, ylab="", ylim=c(0,4))
curve(q,       0, 1, lwd=2, col="red",   add=TRUE)

But this method, for these two distribution families, is quite unstable: the optimization diverges for many parameter values…
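One possible mitigation (a sketch of my own, not from the original code): optimize over log-parameters, so that a and b always stay positive, which tends to make the search better behaved:

# same idea as variational_kumaraswamy(), but optimizing over log(a), log(b)
# so the parameters cannot wander into invalid (non-positive) values
variational_kumaraswamy_log <- function(p_tilde, lower, upper) {
  J <- function(theta) {
    abs(KL(dkumar, p_tilde, lower=lower, upper=upper,
           a=exp(theta[1]), b=exp(theta[2])))
  }
  exp(optim(par=c(0, 0), fn=J, method="BFGS")$par)  # back-transform to (a, b)
}

approximation_params <- variational_kumaraswamy_log(p_tilde, lower=0.01, upper=0.99)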