We have an intractable posterior distribution $$p(x|D)$$ that we wish to approximate with $$q(x)$$, chosen from a given family of tractable distributions (e.g., a Gaussian).

We define $$\widetilde{p}(x) = p(x|D)p(D) = p(x,D)$$ which is easier to compute pointwise, since we don’t need to compute the expensive $$p(D)$$.

The goal is to find the $$q$$ that minimizes a cost function defined using the KL divergence from information theory:

$J(q) = KL(q\|\widetilde{p}) = \sum_x q(x) \log \frac{q(x)}{\widetilde{p}(x)}$

or using an integral for continuous distributions.

We can implement KL divergence in R:

# KL for continuous functions
#
# Integrates q(x) * log(q(x) / p_tilde(x)) over [lower, upper].
#   q             density function, called as q(x, ...)
#   p_tilde       target function (may be non-normalized), called as p_tilde(x)
#   lower, upper  integration limits handed to integrate()
#   ...           extra arguments forwarded to q (eg, its parameters)
# Returns the value of the integral as a single number.
KL <- function(q, p_tilde, lower, upper, ...) {
  f <- function(x) {
    qx <- q(x, ...) # evaluate q only once per call
    term <- qx * log(qx / p_tilde(x))
    # 0 * log(0) evaluates to NaN although its limit is 0; without this
    # integrate() aborts with a non-finite value as soon as q underflows to 0.
    # NaN coming from q itself (eg, invalid parameters) is left untouched.
    term[!is.na(qx) & qx == 0] <- 0
    term
  }
  integrate(f, lower, upper)$value
}

# an eg:
p_tilde <- function(x) dgamma(x, shape = 3.0, scale = 0.25)
q1 <- function(x) dlnorm(x, 0, 1)
q2 <- function(x) dlnorm(x, 0, .45)

curve(p_tilde, 0, 6, lwd = 2, ylab = "")
curve(q1, 0, 6, lwd = 2, col = "red", add = TRUE)
curve(q2, 0, 6, lwd = 2, col = "green", add = TRUE) # this one 'seems' closer...

KL(q1, p_tilde, lower = 1e-3, upper = 100)
## [1] 1.709245
KL(q2, p_tilde, lower = 1e-3, upper = 100) # ...and in fact, KL gives a smaller number
## [1] 0.3400462

Notice that this is not exactly a KL divergence, since $$\widetilde{p}$$ is a non-normalized ‘distribution’.

To see that cost function $$J$$ works as desired, let’s develop the equation

$\begin{array}{lcll}
J(q) & = & \sum_x q(x) \log \frac{q(x)}{\widetilde{p}(x)} & \\
 & = & \sum_x q(x) \log \frac{q(x)}{p(x|D)p(D)} & \\
 & = & \sum_x q(x) \left( \log \frac{q(x)}{p(x|D)} - \log p(D) \right) & \\
 & = & \sum_x q(x) \log \frac{q(x)}{p(x|D)} - \log p(D) & \color{blue}{ \sum_x q(x) = 1} \\
 & = & KL(q\|p(x|D)) - \log p(D) &
\end{array}$

Since $$p(D)$$ is a constant, it means that minimizing $$J(q)$$ is minimizing $$KL(q\|p(x|D))$$, and so $$q(x)$$ will approach $$p(x|D)$$.

## Approximation via optim

Here’s an eg where we want to approximate a gamma using log-normals:

# Fits a log-normal to p_tilde by minimizing KL(q || p_tilde) on [lower, upper].
#   p_tilde       target function, called as p_tilde(x)
#   lower, upper  integration limits used for every KL evaluation
# Returns c(meanlog, sdlog) as found by optim() with its default method.
variational_lnorm <- function(p_tilde, lower, upper) {
  q <- dlnorm # in this eg, q is a log-normal

  # cost function: KL between the candidate log-normal and the target
  J <- function(params) {
    KL(q, p_tilde,
      lower = lower, upper = upper,
      meanlog = params[1], sdlog = params[2]
    )
  }

  optim(par = c(0, 1), fn = J)$par
}

# an eg: approximate a gamma density with the best-fitting log-normal
p_tilde <- function(x) dgamma(x, shape = 3.0, scale = 0.25)

approximation_params <- variational_lnorm(p_tilde, lower = 1e-3, upper = 100)
# get the resulting approximation:
q <- function(x) dlnorm(x, approximation_params[1], approximation_params[2])

# compute their distance
# (integrated only up to 10 here, whereas the fit above used upper = 100)
KL(q, p_tilde, 1e-3, 10)
## [1] 0.02765858

# target in black, fitted log-normal in red
curve(p_tilde, 0, 6, lwd = 2, ylab = "", ylim = c(0, 1.25))
curve(q, 0, 6, lwd = 2, col = "red", add = TRUE)

We can mimic a Laplace-style approximation (i.e., a Gaussian approximation) of, say, a given beta distribution:

# Fits a normal to p_tilde by minimizing KL(q || p_tilde) on [lower, upper].
#   p_tilde       target function, called as p_tilde(x)
#   lower, upper  integration limits used for every KL evaluation
# Returns c(mean, sd) as found by optim() with its default method.
variational_norm <- function(p_tilde, lower, upper) {
  q <- dnorm # in this eg, q is a normal

  # cost function: KL between the candidate normal and the target
  J <- function(params) {
    KL(q, p_tilde,
      lower = lower, upper = upper,
      mean = params[1], sd = params[2]
    )
  }

  optim(par = c(0.5, 0.2), fn = J)$par # initial values are tricky, not very stable
}

p_tilde <- function(x) dbeta(x, 11, 9)

approximation_params <- variational_norm(p_tilde, lower = 1e-3, upper = 1 - 1e-3)
# get the resulting approximation:
q <- function(x) dnorm(x, mean = approximation_params[1], sd = approximation_params[2])

KL(q, p_tilde, 0, 1) # compute their distance
## [1] 0.004112233

curve(p_tilde, 0, 1, lwd = 2, ylab = "", ylim = c(0, 4))
curve(q, 0, 1, lwd = 2, col = "red", add = TRUE)

The next eg approximates beta distributions with Kumaraswamy distributions, which have a simpler PDF

$f_\text{kumar}(x|a,b) = abx^{a-1}(1-x^a)^{b-1}$

and a closed-form CDF:

$F_\text{kumar}(x|a,b) = 1-(1-x^a)^b$

In this post, John Cook notes that, since the CDF is easy to invert, it’s simple to generate random samples from $$K(a,b)$$ by generating $$u \sim U(0,1)$$ and returning

$F^{-1}(u) = (1-(1-u)^{1/b})^{1/a}$

library(extraDistr)

n <- 1e4
u <- runif(n)
a <- 1/2
b <- 1/2
ku <- (1 - (1 - u)^(1 / b))^(1 / a) # inverse-CDF sampling
hist(ku, breaks = 50, probability = TRUE)

q <- function(x) dkumar(x, a, b)
curve(q, 0, 1, lwd = 2, col = "red", add = TRUE)

Let’s do the approximation:

# Fits a Kumaraswamy density to p_tilde on [lower, upper] with BFGS.
#   p_tilde         target function, called as p_tilde(x)
#   lower, upper    integration limits used for every objective evaluation
#   mass_corrected  FALSE (default) keeps the original objective abs(KL);
#                   TRUE minimizes the integral of q*log(q/p) - q + p instead
# Returns c(a, b).
#
# Why KL() can be negative here: the integral only covers [lower, upper], and
# it is bounded below by Q * log(Q / P), where Q and P are the masses that q
# and p_tilde put inside that interval. It can therefore go negative when q
# leaves less mass in the interval than p_tilde does. Wrapping it in abs()
# hides the sign, but the optimizer then looks for a zero crossing of the
# truncated integral rather than for its minimum, so a near-zero result does
# not by itself mean the two curves match. With mass_corrected = TRUE the
# integrand q*log(q/p) - q + p is non-negative at every point and is zero only
# where q equals p_tilde, so no abs() is needed.
variational_kumaraswamy <- function(p_tilde, lower, upper,
                                    mass_corrected = FALSE) {
  q <- dkumar

  J <- function(params) {
    if (mass_corrected) {
      f <- function(x) {
        qx <- q(x, a = params[1], b = params[2])
        px <- p_tilde(x)
        qx * log(qx / px) - qx + px
      }
      integrate(f, lower, upper)$value
    } else {
      # original behaviour, kept as the default so earlier results reproduce
      abs(KL(q, p_tilde,
        lower = lower, upper = upper,
        a = params[1], b = params[2]
      ))
    }
  }

  optim(par = c(0.5, 0.5), fn = J, method = "BFGS")$par
}

# target: a U-shaped beta
p_tilde <- function(x) dbeta(x, 1/2, 1/2)

approximation_params <- variational_kumaraswamy(p_tilde, lower = 0.01, upper = 0.99)
# get the resulting approximation:
q <- function(x) dkumar(x, approximation_params[1], approximation_params[2])

KL(q, p_tilde, 0.01, 0.99) # compute their distance
## [1] 1.240327e-16

# target in black, fitted Kumaraswamy in red
curve(p_tilde, 0, 1, lwd = 2, ylab = "", ylim = c(0, 4))
curve(q, 0, 1, lwd = 2, col = "red", add = TRUE)

Another eg:

# target: a symmetric, bell-shaped beta
p_tilde <- function(x) dbeta(x, 3, 3)

approximation_params <- variational_kumaraswamy(p_tilde, lower = 0.01, upper = 0.99)
# get the resulting approximation:
q <- function(x) dkumar(x, approximation_params[1], approximation_params[2])

#q <- function(x) dkumar(x, 5, 251/40)

KL(q, p_tilde, 0.01, 0.99) # compute their distance
## [1] 0.001386222

# target in black, fitted Kumaraswamy in red
curve(p_tilde, 0, 1, lwd = 2, ylab = "", ylim = c(0, 4))
curve(q, 0, 1, lwd = 2, col = "red", add = TRUE)

But this method, for these two distributions, is very unstable. It diverges with many parameter values…