Reference: David M. Blei, Alp Kucukelbir & Jon D. McAuliffe (2017) Variational Inference: Review for Statisticians, Journal of the American Statistical Association, 112:518, 859-877.

Model setting

\(\displaystyle \begin{array}{{>{\displaystyle}l}} x_{i} |c_{i} ,\ \mu \ \sim \ N\left( c^{T}_{i} \mu ,\ 1\right)\\ \\ c_{i} \ \sim \ Categorical( 1/K,\ ...,\ 1/K)\\ \\ \mu _{k} \ \sim \ N\left( 0,\ \sigma ^{2}\right)\\ \\ q( c_{i} ;\ \varphi _{i}) \ \sim \ Categorical( \varphi _{i1} ,\ ...,\ \varphi _{iK})\\ \\ q( \mu _{k} ;\ m_{k} ,\ s_{k}) \ \sim \ N\left( m_{k} ,\ s^{2}_{k}\right)\\ \\ \end{array}\)

Calculate evidence lower bound (ELBO)

\[\begin{eqnarray*} ELBO = E_q[\log p(x, c, \mu)] - E_q[\log q(\mu, c)] \\ \\ \\ \end{eqnarray*}\]

Compute \(\ E( log( p( x,c,\mu )))\)

\(\displaystyle \begin{array}{{>{\displaystyle}l}} p( \mu_{k}) \ =\ \frac{1}{\sqrt{2\pi } \sigma } exp\left( -\frac{\mu ^{2}_{k}}{2\sigma ^{2}}\right) \\ \\ p( c_{i}) \ =\ \frac{1}{K} \\ \\ p( x_{i} |\mu ,\ c_{i}) \ =\ \frac{1}{\sqrt{2\pi }}\prod ^{K}_{k} exp\left( -\frac{( x_{i} -\mu_{k})^{2}}{2}\right)^{c_{ik}}\\ \\ p( x,c,\mu ) \ =\ \left(\prod ^{N}_{i} p( x_{i} ,c_{i} |\mu )\right) p( \mu ) \ =\ \left(\prod ^{N}_{i} p( x_{i} |c_{i} ,\mu ) p( c_{i}|\mu)\right)\left(\prod ^{K}_{k} p( \mu_{k})\right) =\ \left(\prod ^{N}_{i} p( x_{i} |c_{i} ,\mu ) p( c_{i})\right)\left(\prod ^{K}_{k} p( \mu _{k})\right) \ \\ \\ log( p( x,c,\mu )) \ =\ \underbrace{\sum ^{N}_{i} log( p( x_{i} |c_{i} ,\mu ))} \ +\ \underbrace{\sum ^{N}_{i} log( p( c_{i}))} \ +\ \underbrace{\sum ^{K}_{k} log( p( \mu _{k}))}\\ \hspace{4.9cm} I \hspace{3.55cm} II \hspace{3.0cm} III \\ \\ I.\hspace{0.2cm} \sum ^{N}_{i} log( p( x_{i} |c_{i} ,\mu )) \ =\sum ^{N}_{i} \ log\left(\frac{1}{\sqrt{2\pi }}\right) \ +\ \sum ^{N}_{i}\sum ^{K}_{k} c_{ik}\left( -\frac{( x_{i} -\mu_{k})^{2}}{2}\right) \ =-Nlog\left(\sqrt{2\pi }\right) +\ \sum ^{N}_{i}\sum ^{K}_{k} c_{ik}\left( -\frac{x^{2}_{i} +\mu ^{2}_{k} -2x_{i} \mu _{k}}{2}\right) \ \\ \hspace{0.6cm} exp.ll\ =\ E\left( \ \sum ^{N}_{i} log( p( x_{i} |c_{i} ,\mu )) \ \right) \ =\ -Nlog\left(\sqrt{2\pi }\right) +\ \ \sum ^{N}_{i}\sum ^{K}_{k} \varphi _{ik}\left( -\frac{x^{2}_{i} +s^{2}_{k} +m^{2}_{k} -2x_{i} m_{k}}{2}\right)\\ \\ II.\hspace{0.2cm} \sum ^{N}_{i} log( p( c_{i})) \ =\ \sum ^{N}_{i} log\left(\frac{1}{K}\right) \ =\ -Nlog( K)\\ \hspace{0.8cm} exp.pc\ =\ E\left(\sum ^{N}_{i} log( p( c_{i}))\right) \ =\ -Nlog( K)\\ \\ III.\hspace{0.2cm} \sum ^{K}_{k} log( p( \mu _{k})) \ =\ \sum ^{K}_{k} log\left(\frac{1}{\sqrt{2\pi } \sigma }\right) -\ \sum ^{K}_{k}\frac{\mu ^{2}_{k}}{2\sigma ^{2}} \ \ =\ -Klog\left( \sigma \sqrt{2\pi }\right) -\ \sum ^{K}_{k}\frac{\mu ^{2}_{k}}{2\sigma ^{2}}\\ \hspace{1cm} exp.pm\ =E\left(\sum ^{K}_{k} log( p( \mu _{k}))\right) \ =\ 
-Klog\left( \sigma \sqrt{2\pi }\right) \ -\ \sum ^{K}_{k}\frac{s^{2}_{k} +m^{2}_{k}}{2\sigma ^{2}}\\ \\ \\ \end{array}\)

Compute \(\ E( log(q( \mu ,c)))\)

\(\displaystyle \begin{array}{{>{\displaystyle}l}} q( \mu_{k} ;\ m_{k} ,\ s_{k}) \ =\ \ \frac{1}{\sqrt{2\pi } s_{k}} exp\left( -\frac{( \mu_{k} -m_{k})^{2}}{2s^{2}_{k}}\right) \\ \\ q( c_{i} ;\ \varphi_{i}) \ =\ \prod ^{K}_{k} \varphi ^{c_{ik}}_{ik}\\ \\ q( \mu ,c) \ =\ \left(\prod ^{K}_{k} q( \mu_{k})\right)\left(\prod ^{N}_{i} q( c_{i})\right) \\ \\ log( q( \mu ,c)) \ =\ \underbrace{\sum ^{K}_{k} log( q( \mu_{k}))} \ +\ \underbrace{\sum ^{N}_{i} log( q( c_{i}))}\\ \hspace{4.1cm} I \hspace{3.2cm} II \\ \\ \\ I. \hspace{0.2cm} \sum ^{K}_{k} log( q( \mu _{k})) \ =\ \sum ^{K}_{k} log\left(\frac{1}{\sqrt{2\pi } s_{k}}\right) \ -\ \sum ^{K}_{k}\frac{( \mu _{k} -m_{k})^{2}}{2s^{2}_{k}} \ =\ -\sum ^{K}_{k} log\left( s_{k}\sqrt{2\pi }\right) \ -\ \sum ^{K}_{k}\frac{\mu ^{2}_{k} +m^{2}_{k} -2\mu _{k} m_{k}}{2s^{2}_{k}} \ \\ \hspace{0.6cm} exp.vm\ =E\left(\sum ^{K}_{k} log( q( \mu _{k}))\right) \ =\ -\sum ^{K}_{k} log\left( s_{k}\sqrt{2\pi }\right) -\ \sum ^{K}_{k}\frac{m^{2}_{k} +s^{2}_{k} +m^{2}_{k} -2m^{2}_{k}}{2s^{2}_{k}} \ =\ -\sum ^{K}_{k} log\left( s_{k}\sqrt{2\pi }\right) \ -\ \frac{K}{2}\\ \\ \\ II. \hspace{0.2cm} \sum ^{N}_{i} log( q( c_{i})) \ =\ \sum ^{N}_{i}\sum ^{K}_{k} c_{ik} log( \varphi _{ik})\\ \hspace{0.8cm} exp.vc\ =E\left(\sum ^{N}_{i} log( q( c_{i}))\right) =\ \sum ^{N}_{i}\sum ^{K}_{k} \varphi _{ik} log( \varphi _{ik})\\ \\ \end{array}\)

Compute ELBO

\(\displaystyle ELBO\ =\ exp.ll\ +\ exp.pm\ +\ exp.pc\ -\ exp.vm\ -\ exp.vc\)

calc_elbo <- function(x, N, K, sigma, m, s, phi){
  # Evidence lower bound for the univariate Gaussian mixture under the
  # mean-field variational family q(mu, c) = prod_k N(m_k, s_k^2) *
  # prod_i Categorical(phi_i), following the derivation above.
  #
  # Args:
  #   x:     numeric vector of observations (length N)
  #   N:     number of observations
  #   K:     number of mixture components
  #   sigma: prior standard deviation of the component means
  #   m:     variational means of the components (length K)
  #   s:     variational standard deviations of the components (length K)
  #   phi:   N x K matrix of variational cluster responsibilities
  #
  # Returns: the scalar ELBO value.

  # Expected log-likelihood E_q[sum_i log p(x_i | c_i, mu)], vectorized:
  # quad[i, k] = x_i^2 + s_k^2 + m_k^2 - 2 x_i m_k.
  quad <- outer(x^2, s^2 + m^2, `+`) - 2 * outer(x, m)
  exp.ll <- -N*log(sqrt(2*pi)) - sum(phi * quad)/2

  # Expected log prior of the cluster assignments (uniform 1/K).
  exp.pc <- -N*log(K)

  # Expected log prior of the component means.
  exp.pm <- -K*log(sigma*sqrt(2*pi)) - sum((s^2 + m^2)/(2*sigma^2))

  # Expected log variational factor of the component means.
  exp.vm <- sum(-log(s*sqrt(2*pi))) - K/2

  # Expected log variational factor of the assignments (negative entropy).
  # Define 0*log(0) = 0 so responsibilities that underflow to exactly
  # zero do not produce NaN (the original loop returned NaN there).
  exp.vc <- sum(ifelse(phi > 0, phi * log(phi), 0))

  exp.ll + exp.pc + exp.pm - exp.vc - exp.vm
}

Perform CAVI

\(\displaystyle \begin{array}{{>{\displaystyle}l}} p( \mu_{k}) \ =\ \frac{1}{\sqrt{2\pi } \sigma } exp\left( -\frac{\mu ^{2}_{k}}{2\sigma ^{2}}\right) \\ p( c_{i}) \ =\ \frac{1}{K} \\ p( x_{i} |c_{i} ,\ \mu ) \ =\frac{1}{\sqrt{2\pi }}\prod ^{K}_{k} exp\left( -\frac{( x_{i} \ -\ \mu_{k})^{2}}{2}\right)^{{c_{i}}_{k}}\\ \\ p( x ,\ c ,\ \mu ) \ =\ p( x,\ c\ |\ \mu ) p( \mu ) \ =\ \left(\prod ^{N}_{i} p( x_{i} |c_{i} ,\ \mu ) \ p( c_{i})\right)\left(\prod ^{K}_{k} \ p( \mu _{k}) \ \right) \\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ =\ C\left(\prod ^{N}_{i}\prod ^{K}_{k} exp\left( -\frac{( x_{i} \ -\ \mu_{k})^{2}}{2}\right)^{{c_{i}}_{k}}\right)\left(\prod ^{K}_{k} exp\left( -\frac{\mu ^{2}_{k}}{2\sigma ^{2}}\right)\right) ,\ \ where\ C\ is\ a\ constant.\\ \\ \end{array}\)

Update for \(\mu _{k}\)

\(\displaystyle \begin{array}{{>{\displaystyle}l}} \\ log( p( x,\ c ,\ \mu_{k} ,\mu_{-k})) \ =\ C\ +\ \sum ^{N}_{i}\sum ^{K}_{k}\left( -\ c_{ik}\frac{( x_{i} \ -\ \mu_{k})^{2}}{2}\right) \ +\ \sum ^{K}_{k} -\frac{\mu ^{2}_{k}}{2\sigma ^{2}} \\ \hspace{3.9cm} =\ C\ -\ \frac{\mu ^{2}_{k}}{2\sigma ^{2}} \ -\ \sum ^{N}_{i}\left( c_{ik}\frac{( x_{i} \ -\ \mu _{k})^{2}}{2}\right)\\ \\ E_{c,\ \mu_{-k}}( log( p( x,\ c ,\ \mu _{k} ,\mu _{-k}))) \ =\ C\ -\ \frac{\mu ^{2}_{k}}{2\sigma ^{2}} \ -\ \sum ^{N}_{i}\left( E_{c_{i}}( c_{ik})\frac{( x_{i} \ -\ \mu _{k})^{2}}{2}\right)\\ \hspace{5.3cm} =\ C\ -\ \ \frac{\mu ^{2}_{k}}{2\sigma ^{2}} \ -\sum ^{N}_{i}\left( E_{c_{i}}( c_{ik})\frac{x^{2}_{i} \ +\ \mu ^{2}_{k} \ -\ 2x_{i} \mu _{k}}{2}\right)\\ \hspace{5.3cm} =\ C\ -\ \ \frac{\mu ^{2}_{k}}{2\sigma ^{2}} \ -\sum ^{N}_{i}\left( E_{c_{i}}( c_{ik})\frac{\mu ^{2}_{k} \ -\ 2x_{i} \mu _{k}}{2}\right)\\ \hspace{5.3cm} =\ C\ -\ \ \frac{\mu ^{2}_{k}}{2\sigma ^{2}} \ -\sum ^{N}_{i}\frac{E_{c_{i}}( c_{ik}) \mu ^{2}_{k}}{2} \ +\ \sum ^{N}_{i} E_{c_{i}}( c_{ik}) x_{i} \mu _{k}\\ \\ E_{c_{i}}( c_{ik}) \ =\ \varphi _{ik}\\ \\ q^{*}( \mu_{k}) \ \varpropto \ exp( E_{c,\ \mu_{-k}}( log( p( x,\ c ,\ \mu_{k} ,\mu_{-k})))) \ \varpropto \ exp\left( -\ \ \frac{\mu ^{2}_{k}}{2\sigma ^{2}} \ -\sum ^{N}_{i}\frac{\varphi_{ik} \mu ^{2}_{k}}{2} \ +\ \sum ^{N}_{i} \varphi_{ik} x_{i} \mu _{k}\right)\\ \hspace{1.5cm} =\ exp\left( -\ \ \left(\frac{1}{2\sigma ^{2}} \ +\sum ^{N}_{i}\frac{\varphi_{ik}}{2}\right) \mu ^{2}_{k} \ +\ \sum ^{N}_{i} \varphi_{ik} x_{i} \mu _{k}\right)\\ \hspace{1.5cm} =\ exp\left(\frac{}{}\frac{\mu ^{2}_{k} \ -\ \frac{\sum ^{N}_{i} \varphi_{ik} x_{i} \mu_{k}}{\frac{1}{2\sigma ^{2}} \ +\sum ^{N}_{i}\frac{\varphi_{ik}}{2}}}{\frac{1}{\frac{1}{2\sigma ^{2}} \ +\sum ^{N}_{i}\frac{\varphi_{ik}}{2}}}\right) \ =\ exp\left(\frac{}{}\frac{\mu ^{2}_{k} \ -\ 2\ \cdot \frac{\sum ^{N}_{i} \varphi_{ik} x_{i}}{\frac{1}{\sigma ^{2}} \ +\sum ^{N}_{i} \varphi_{ik}} \mu_{k}}{2\cdot \ 
\frac{1}{\frac{1}{\sigma ^{2}} \ +\sum ^{N}_{i} \varphi _{ik}}}\right)\\ \\ \\ For\ an\ appropriate\ posterior\ distribution,\ \int q^{*}( \mu_{k}) \ d\mu _{k} \ =\ 1\\ \\ \therefore \ \mu_{k} \ \sim \ N\left(\frac{\sum ^{N}_{i} \varphi_{ik} x_{i}}{\frac{1}{\sigma ^{2}} \ +\sum ^{N}_{i} \varphi_{ik}} ,\ \frac{1}{\frac{1}{\sigma ^{2}} \ +\sum ^{N}_{i} \varphi _{ik}}\right)\\ \\ that\ is:\ m_{k} \ =\ \frac{\sum ^{N}_{i} \varphi_{ik} x_{i}}{\frac{1}{\sigma ^{2}} \ +\sum ^{N}_{i} \varphi_{ik}} ,\ s^{2}_{k} \ =\ \frac{1}{\frac{1}{\sigma ^{2}} \ +\sum ^{N}_{i} \varphi _{ik}}\\ \\ \\ \end{array}\)

Update for\(\ c_{ik}\)

\(\displaystyle \begin{array}{{>{\displaystyle}l}} \\ log( p( x,\ \mu ,\ c_{i} ,\ c_{-i})) \ =\ C\ +\ \sum ^{N}_{i}\sum ^{K}_{k}\left( -\ c_{ik}\frac{( x_{i} \ -\ \mu_{k})^{2}}{2}\right) \ +\ \sum ^{K}_{k} -\frac{\mu ^{2}_{k}}{2\sigma ^{2}} \ \\ \hspace{3.8cm} =\ C\ -\sum ^{K}_{k} \ c_{ik}\frac{( x_{i} \ -\ \mu _{k})^{2}}{2}\\ \\ E_{\mu ,\ c_{-i}}( log( p( x,\ \mu ,\ c_{ik} ,\ c_{-ik}))) \ =\ \ C\ \ -\sum ^{K}_{k}\left(\frac{c_{ik} x^{2}_{i}}{2} \ +\ \frac{c_{ik} E_{\mu_{k}}\left( \mu ^{2}_{k}\right)}{2} -\ c_{ik} x_{i} E_{\mu_{k}}( \mu _{k})\right)\\ \\ \\ E_{\mu_{k}}( \mu_{k}) \ =\ m_{k}\\ \\ \\ E_{\mu_{k}}\left( \mu ^{2}_{k}\right) \ =\ Var( \mu_{k}) \ +\ E^{2}_{\mu_{k}}( \mu_{k}) \ =\ s^{2}_{k} \ +\ m^{2}_{k}\\ \\ \\ q^{*}( c_{i}) \ \varpropto \ exp( E_{\mu ,\ c_{-i}}( log( p( x,\ \mu ,\ c_{i} ,\ c_{-i})))) \ \varpropto \ exp\left(\sum ^{K}_{k}\left( -\frac{c_{ik} x^{2}_{i}}{2} \ -\ \frac{c_{ik} \ \left( s^{2}_{k} \ +\ m^{2}_{k}\right)}{2} +\ c_{ik} x_{i} m_{k}\right)\right)\\ \hspace{1.3cm} =\ exp\left( -\frac{x^{2}_{i}}{2} \ +\sum ^{K}_{k}\left( \ -\ \frac{c_{ik} \ \left( s^{2}_{k} \ +\ m^{2}_{k}\right)}{2} +\ c_{ik} x_{i} m_{k}\right)\right) \ \varpropto exp\left(\sum ^{K}_{k}\left( \ -\ \frac{c_{ik} \ \left( s^{2}_{k} \ +\ m^{2}_{k}\right)}{2} +\ c_{ik} x_{i} m_{k}\right)\right)\\ \\ \varphi_{ik} =\ C\ \cdot \ exp\left( \ -\ \frac{s^{2}_{k} \ +\ m^{2}_{k}}{2} +\ x_{i} m_{k}\right)\\ \\ For\ an\ appropriate\ posterior\ distribution,\ \sum ^{K}_{k} \varphi _{ik} \ =\ 1\\ \\ \varphi_{ik} \ =\ \frac{exp\left( \ -\ \frac{s^{2}_{k} \ +\ m^{2}_{k}}{2} +\ x_{i} m_{k}\right)}{\sum ^{K}_{k} exp\left( \ -\ \frac{s^{2}_{k} \ +\ m^{2}_{k}}{2} +\ x_{i} m_{k}\right)} \end{array}\)

update.mean <- function(x, N, sigma, phi_k){
  # CAVI coordinate update for one component mean's variational factor.
  # The posterior precision is 1/sigma^2 + sum_i phi_ik; the variational
  # mean is the responsibility-weighted sum of the data over that
  # precision, and s is the corresponding standard deviation.
  #
  # Args:
  #   x:     numeric data vector
  #   N:     number of observations (kept for interface compatibility)
  #   sigma: prior standard deviation of the component means
  #   phi_k: responsibilities of every observation for this component
  #
  # Returns: list with variational mean `m` and standard deviation `s`.
  precision <- 1/sigma^2 + sum(phi_k)
  list(m = sum(phi_k * x)/precision, s = sqrt(1/precision))
}


update.cluster <- function(x_i, m, s){
  # CAVI update for the cluster responsibilities of one observation:
  #   phi_ik  propto  exp(x_i * m_k - (s_k^2 + m_k^2)/2)
  #
  # Args:
  #   x_i: a single observation
  #   m:   current variational means (length K)
  #   s:   current variational standard deviations (length K)
  #
  # Returns: length-K probability vector summing to 1.
  #
  # Subtract the max log-weight before exponentiating (log-sum-exp
  # trick): the normalization is unchanged, but raw exp() could overflow
  # to Inf for large |x_i * m_k| and return NaN after normalizing.
  log_phi <- -(s^2 + m^2)/2 + x_i*m
  phi_i <- exp(log_phi - max(log_phi))
  phi_i/sum(phi_i)
}


CAVI <- function(x, N, K, sigma, m, s, phi, tol = 1e-10, maxiter = 1000){
  # Coordinate-ascent variational inference for the Gaussian mixture.
  # Alternates the mean-factor and responsibility updates until the
  # ELBO change falls below `tol`.
  #
  # Args:
  #   x:       numeric data vector (length N)
  #   N, K:    number of observations / components
  #   sigma:   prior standard deviation of the component means
  #   m, s:    initial variational means / standard deviations (length K)
  #   phi:     initial N x K responsibility matrix (rows sum to 1)
  #   tol:     ELBO-change convergence threshold
  #   maxiter: hard cap on iterations; the original while(TRUE) could
  #            loop forever if the ELBO change never dropped below tol
  #
  # Returns: list(m, s, phi, elbo) with the fitted variational
  # parameters and the ELBO trace per iteration.
  elbo <- calc_elbo(x, N, K, sigma, m, s, phi)
  iter <- 1
  while(iter < maxiter){
    # Update each component's variational mean factor given phi.
    for(k in 1:K){
      new.mean.factor <- update.mean(x, N, sigma, phi[,k])
      m[k] <- new.mean.factor$m
      s[k] <- new.mean.factor$s
    }
    # Update each observation's responsibilities given (m, s).
    for(i in 1:N){
      phi[i,] <- update.cluster(x[i], m, s)
    }
    iter <- iter + 1
    elbo[iter] <- calc_elbo(x, N, K, sigma, m, s, phi)

    # Converged when the ELBO improvement is below tolerance.
    if(abs(elbo[iter] - elbo[iter - 1]) < tol){
      break
    }
  }
  return(list(m = m, s = s, phi = phi, elbo = elbo))
}

Simulate data

set.seed(1995)

#set N and the true number of clusters
N <- 1000
r.K <- 4

#true component means
r.mu <- c(0, 5, 10, 15)

#true within-cluster variance
r.var  <- 1

#true cluster assignment: N/r.K observations per cluster, one-hot coded
r.c <- rep(c(1, 2, 3, 4), each = N/r.K)
r.c <- model.matrix(~0 + as.factor(r.c))

#simulate data: x_i ~ N(mean of its cluster, r.var).
#NOTE: rnorm's third argument is the standard deviation, so pass
#sqrt(r.var); the original passed the variance directly, which only
#gave the intended spread because r.var happened to equal 1.
x <- rnorm(N, r.c %*% r.mu, sqrt(r.var))

# Scatter the simulated data against sample index, colored by the true
# component mean of each observation's cluster.
ggplot(data.frame(idx = seq_len(N), value = x, mu = as.factor(r.c %*% r.mu))) +
  geom_point(aes(x = idx, y = value, col = mu), size = 2) +
  scale_color_discrete('real mean') +
  xlab('') +
  ylab('')

Implementation

Initialization

#prior standard deviation of the component means
sigma <- 5

#number of components assumed by the model
K <- 4

#initial responsibilities: uniform random draws, each row normalized to
#sum to one. matrix(runif(N*K), byrow = TRUE) consumes the RNG stream in
#the same order as t(replicate(N, runif(K))), so the values are identical.
phi <- matrix(runif(N*K), nrow = N, byrow = TRUE)
phi <- phi/rowSums(phi)

#initial variational means and standard deviations
m <- rnorm(K, sd = 3)
s <- runif(K, 0, 10)

Perform CAVI

#row-wise cumulative sums of the initial responsibilities give the
#stacked ribbon boundaries for each sample
ribbon.plot <- t(apply(phi, 1, cumsum))

ggplot(data = data.frame(ribbon.plot)) +
  geom_ribbon(aes(1:N, ymin = 0, ymax = X1, fill = 'Group1')) +
  geom_ribbon(aes(1:N, ymin = X1, ymax = X2, fill = 'Group2')) +
  geom_ribbon(aes(1:N, ymin = X2, ymax = X3, fill = 'Group3')) +
  geom_ribbon(aes(1:N, ymin = X3, ymax = X4, fill = 'Group4')) + xlab('samples')

result <- CAVI(x, N, K, sigma, m, s, phi)

#plot the fitted responsibilities: row-wise cumulative sums give the
#stacked ribbon boundaries for each sample
ribbon.plot <- t(apply(result$phi, 1, cumsum))

ggplot(data = data.frame(ribbon.plot)) +
  geom_ribbon(aes(1:N, ymin = 0, ymax = X1, fill = 'Group1')) +
  geom_ribbon(aes(1:N, ymin = X1, ymax = X2, fill = 'Group2')) +
  geom_ribbon(aes(1:N, ymin = X2, ymax = X3, fill = 'Group3')) +
  geom_ribbon(aes(1:N, ymin = X3, ymax = X4, fill = 'Group4')) + xlab('samples')

#ELBO trace over the CAVI iterations (should be monotone increasing)
plot(result$elbo, type = 'l', lwd = 4, col = alpha('steelblue2', 0.7), xlab = 'iteration', ylab = 'elbo')

Posterior \(m_k\) and \(s_k\)

#fitted variational means (one per component; component order is
#arbitrary relative to the true means c(0, 5, 10, 15))
result$m
## [1] 10.05792975 14.97314177  5.12440010  0.00259356
#fitted variational standard deviations of the component means
result$s
## [1] 0.06349192 0.06309637 0.06350073 0.06287964