Laplace approximation to fit models including random effects in Julia

Author

Sean L. Wu (slwood89@gmail.com)

Published

December 1, 2022

Urchin model

This example comes from Simon Wood’s book Core Statistics, available for free on his website (click the link and scroll down). See (Wood 2015, pp 99-111) for more information.

This is a nice example of how to use the Laplace approximation to fit models with random effects in Julia because it does not use special DSLs (domain specific languages) and as much as possible sticks to Julia functionality available for general users. The model itself is coded in standard idiomatic Julia, and therefore should be applicable to a fairly broad class of statistical models.

The model is one of sea urchin growth. There is a mechanistic model to relate volume \(V\) to age \(a\) by the differential equation:

\[ \frac{dV}{da} = \begin{cases} g_{i} V, & V < p_{i}/g_{i} \\ p_{i} & \text{otherwise} \end{cases} \]

This is an “individual-based” model, because the random effects \(p\) and \(g\) are indexed by the individual urchin. Each urchin has an initial condition for the ODE (initial volume) of \(\omega\) (a fixed effect). The differential equation changes when the animal hits reproductive age, which is:

\[ a_{mi} = \frac{1}{g_{i}}\log\left( \frac{p_{i}}{g_{i}\omega}\right) \]

The ODE has the analytical solution:

\[ V(a) = \begin{cases} \omega \exp(g_{i} a_{i}), & a_{i} < a_{mi} \\ p_{i}/g_{i} + p_{i}(a_{i}-a_{mi}) & \text{otherwise} \end{cases} \]

The measured urchin volumes are \(v\), and the likelihood model is:

\[ \sqrt{v_{i}} \sim N\left(\sqrt{V(a_{i})},\sigma^{2}\right) \]

And we assume the random effects follow:

\[ \begin{split} \log g_{i} &\sim N(\mu_{g},\sigma_{g}^{2}) \\ \log p_{i} &\sim N(\mu_{p},\sigma_{p}^{2}) \end{split} \]

The model parameters (“fixed effects”) are \(\omega, \mu_{g}, \sigma_{g}, \mu_{p}, \sigma_{p}, \sigma\), and the random effects are \(g, p\).

Urchin data

We are going to load the data from Simon Wood’s website.

using CSV
using HTTP
using DataFrames

http_response = HTTP.request("GET", "https://www.maths.ed.ac.uk/~swood34/data/urchin-vol.txt")
uv = CSV.read(IOBuffer(http_response.body), DataFrame, header=["row", "age", "vol"], skipto=2)
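
It is worth confirming what was downloaded before going further; a quick look at the dimensions and first few rows:

size(uv)
first(uv, 5)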

We can take a look at it:

using Plots

y::Vector{Float64} = uv[:,:vol]
a::Vector{Int64} = uv[:,:age]

scatter(a,y,label=false,xlabel="age",ylabel="volume")

If we square-root transform the response we can see that the likelihood model’s assumption of constant variance is probably not too bad, although we’d want to examine it more carefully for “real” modeling.

scatter(a,sqrt.(y),label=false,xlabel="age",ylabel="Sqrt(volume)")

Laplace approximation

These are summarized from the account given in (Wood 2015). Let a statistical model with random and fixed effects be given by the joint density function of the data \(y\) and random effects \(b\), parameterized by \(\theta\) and written \(f_{\theta}(y,b)\). The likelihood is the marginal density of the data:

\[ L(\theta) = f_{\theta}(y) = \int f_{\theta}(y,b) db \]

We can then evaluate \(L(\theta)\) through:

  1. Numerical integration
  2. Monte Carlo methods
  3. Approximate the integral
  4. Replace the integral with another function whose maximum coincides with the maximum of \(L(\theta)\).

The Laplace approximation falls under option 3. The EM algorithm falls under option 4.

The whole point of Laplace approximation is to approximate the integral to calculate \(L(\theta)\). Then one can find the MLE \(\hat{\theta}\) via standard optimization methods.

One starts by writing a second order Taylor expansion of \(\log f_{\theta}(y,b)\), as a function of \(b\), around \(\hat{b}\), the value of \(b\) which maximizes the joint density for the observed \(y\). Exponentiating the expansion and integrating the resulting (unnormalized) Gaussian over \(b\) gives:

\[ \int f_{\theta}(y,b) db \approx f_{\theta}(y,\hat{b}) \frac{(2\pi)^{n_{b}/2}}{|H|^{1/2}} \]

where \(H\) is the Hessian of the negative log joint density with respect to \(b\), evaluated at \(\hat{b}\), and \(n_{b}\) is the number of random effects. So we have replaced the very difficult integral with an optimization problem (finding \(\hat{b}\)) and the calculation of \(H\).
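
To make this concrete, here is a minimal sketch of the Laplace approximation on a toy problem where the integral is known exactly (all names here are illustrative and not part of the urchin model): for a standard bivariate Gaussian “joint density”, the approximation recovers the exact normalizing constant \(2\pi\).

using ForwardDiff, Optim, LinearAlgebra

logf(b) = -0.5 * dot(b, b)                # log of an (unnormalized) joint density in b
res = optimize(b -> -logf(b), zeros(2))   # find b̂, the maximizer of the joint density
b̂ = Optim.minimizer(res)
H = ForwardDiff.hessian(b -> -logf(b), b̂) # Hessian of the negative log density at b̂
nb = length(b̂)
# Laplace approximation to ∫ exp(logf(b)) db; exact for a Gaussian, ≈ 2π here
exp(logf(b̂)) * (2π)^(nb/2) / sqrt(det(H))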

Julia implementation

Now we can look at how to use the Laplace approximation to do maximum likelihood estimation (MLE) on the urchin problem.

Let’s first load some packages we will need. We use ForwardDiff.jl for forward mode automatic differentiation, Optim.jl for optimization algorithms, LineSearches.jl for choosing options for optimization (see this issue in Optim.jl), and LinearAlgebra.jl for logdet.

using ForwardDiff
using Optim
using LineSearches 
using LinearAlgebra

First, we define the mathematical model which calculates the urchin volumes \(V(a_{i})\) from the ages \(a_{i}\), the random effects (\(\log g_{i}, \log p_{i}\), which jointly make up \(b\)), and the parameter \(\log\omega\). Note that unlike R, Julia is not “vectorized” by default, so we use the @. macro to apply broadcasting to every function call in that line (see the broadcasting documentation for more information).
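
For example, the following two expressions are equivalent; the @. macro simply adds a dot to every function call and operator on the line:

x = [1.0, 2.0, 3.0]
@. log(x) / x # same as log.(x) ./ x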

Please note we treat \(\log\omega\) as a vector due to an issue in ReverseDiff.jl https://github.com/JuliaDiff/ReverseDiff.jl/issues/214. While this tutorial will use ForwardDiff.jl, we keep this convention so reverse mode AD can be used if one wishes.

function urchins(log_ω, log_g, log_p, a)
    ω = @. exp(log_ω)
    p = @. exp(log_p)
    g = @. exp(log_g)
    am = @. log(p/(g*ω))/g # age at which each urchin reaches reproductive maturity
    # piecewise analytical solution of the growth ODE, evaluated per urchin
    μ = map(p,g,a,am) do pi,gi,ai,ami
        ifelse(ai < ami, ω[1] * exp(gi*ai), pi/gi + pi*(ai-ami))
    end
    return μ
end

Next we write the function \(f_{\theta}(y,b)\), the joint density of the data and random effects. We use the loglikelihood function from Distributions.jl to calculate the contributions to the likelihood (density) from \(y\) and \(b\).

using Distributions

function urchins_jll(y, θ, log_g, log_p, a)
    log_ω, μg, log_σg, μp, log_σp, log_σ = θ
    μ = urchins([log_ω], log_g, log_p, a) # predicted volumes
    data_ll = @. loglikelihood(Normal(sqrt(μ), exp(log_σ)), sqrt(y)) # data contribution
    g_ll = loglikelihood(Normal(μg, exp(log_σg)), log_g) # contribution of the g random effects
    p_ll = loglikelihood(Normal(μp, exp(log_σp)), log_p) # contribution of the p random effects
    return sum(data_ll) + g_ll + p_ll
end
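
As a quick sanity check that the function runs (the values below are arbitrary and purely illustrative), we can evaluate the joint log density at trial parameters, with every urchin’s random effects set to the corresponding means:

n = length(y)
θ_test = [-4.0, -0.2, log(0.1), 0.2, log(0.1), log(0.5)]
urchins_jll(y, θ_test, fill(θ_test[2], n), fill(θ_test[4], n), a)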

Now we can define the function that does the Laplace approximation of the marginal log density (likelihood) of the data. We use ForwardDiff.jl for automatic differentiation of gradients (for the LBFGS optimization of \(\hat{b}\)) and Hessians (for \(H\)).

We use GradientConfig to get a better chunk size for forward AD (see docs). The function takes arguments y for the urchin volumes, a for the urchin ages, θ for the fixed effects/parameters, as well as cols for a color vector and sparse for a sparsity pattern. These last two arguments make the calculation of the Hessian much more efficient; they are passed to numauto_color_hessian, a function from SparseDiffTools.jl, to compute \(H\).

function urchins_laplace_ll(y, a, θ, cols, sparse)

    # generic info
    n = length(y)
    b = [fill(θ[2],n); fill(θ[4],n)]
    nb = length(b)

    # objective to optimize the random effects
    f(x) = -urchins_jll(y, θ, x[1:n], x[n+1:end], a)

    # set up ForwardDiff gradient for optimizing random effects
    gconfig = ForwardDiff.GradientConfig(f, b)
    g!(G, x) = ForwardDiff.gradient!(G, f, x, gconfig)

    # optim r.e.
    re_mle = optimize(f, g!, b, LBFGS(linesearch=LineSearches.BackTracking()))

    # compute Hessian at MLE of r.e.
    H = numauto_color_hessian(f, re_mle.minimizer, cols, sparse)

    # -re_mle.minimum is log f_θ(y, b̂); the second term is the Laplace correction.
    # Writing nb*log(2π) avoids overflowing (2π)^nb when there are many random effects.
    return -re_mle.minimum + 0.5 * (nb * log(2π) - logdet(H))
end

The final line of urchins_laplace_ll computes the log of the Laplace approximation to the marginal likelihood. Taking the log gives \(\log \left( f_{\theta}(y,\hat{b}) \frac{(2\pi)^{n_{b}/2}}{|H|^{1/2}} \right) = \log f_{\theta}(y,\hat{b}) + \log \frac{(2\pi)^{n_{b}/2}}{|H|^{1/2}}\), and \(\log \frac{(2\pi)^{n_{b}/2}}{|H|^{1/2}} = \frac{1}{2}\left( n_{b} \log 2\pi - \log |H|\right)\). So the final line returns \(\log f_{\theta}(y,\hat{b}) + \frac{1}{2}\left( n_{b}\log 2\pi - \log |H|\right)\), where \(\log f_{\theta}(y,\hat{b})\) is -re_mle.minimum, the joint log density at \(\hat{b}\).

Now we are ready to fit the random effects model. The starting values of parameters in θ0 are from Simon Wood’s book.

log_ω = -4.0
μg = -0.2
log_σg = log(0.1)
μp = 0.2
log_σp = log(0.1)
log_σ = log(0.5)

θ0::Vector{Float64} = [log_ω, μg, log_σg, μp, log_σp, log_σ]

We now calculate the color vector and sparsity pattern to efficiently compute \(H\).

using SparseDiffTools
using SparseArrays

n = length(y)
b = [fill(θ0[2],n); fill(θ0[4],n)] # should sample from the random effects, or make this more generic

f(x) = -urchins_jll(y, θ0, x[1:n], x[n+1:end], a)

# dense Hessian at the starting point, used to extract the sparsity pattern and
# color vector needed for efficient Hessian computation in the Laplace approximation
A = ForwardDiff.hessian(f, b)
A = SparseArrays.sparse(A)
A_colors = SparseDiffTools.matrix_colors(A)
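
As an optional check, the sparse colored Hessian computed this way should agree with the dense ForwardDiff Hessian at the same point up to numerical error, and we can already evaluate the Laplace-approximate log likelihood at the starting parameters:

H_check = numauto_color_hessian(f, b, A_colors, A)
maximum(abs.(H_check .- ForwardDiff.hessian(f, b))) # should be small

urchins_laplace_ll(y, a, θ0, A_colors, A)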

Now we can fit the model. Note that we need to minimize the (approximate) negative log-likelihood.

urchins_laplace_nll(θ) = -urchins_laplace_ll(y, a, θ, A_colors, A) 
fit = optimize(urchins_laplace_nll, θ0, Optim.Options(f_tol = 1e-5, g_tol = 1e-5, show_trace = false))
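
Before interpreting the results it is worth confirming that the optimizer reports convergence, using Optim’s accessor:

Optim.converged(fit)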

The minimized (approximate) negative log-likelihood is about the same as the value obtained by the R code in the book:

fit.minimum
92.14564988143138

As are the estimates of \(\hat{\theta}\):

fit.minimizer
6-element Vector{Float64}:
 -4.010108413056277
 -0.20953496592362672
 -1.8228234557070078
  0.17312824510878386
 -1.6003923376943126
 -1.1715877153672012
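
The parameters in θ are on the log scale (apart from \(\mu_{g}\) and \(\mu_{p}\)), so it can be easier to read the estimates after transforming back to the natural scale:

θ̂ = fit.minimizer
(ω = exp(θ̂[1]), μg = θ̂[2], σg = exp(θ̂[3]), μp = θ̂[4], σp = exp(θ̂[5]), σ = exp(θ̂[6]))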

We plot the data again, along with the estimated mean trajectory of an urchin, and the 0.025 and 0.975 quantiles of 5000 trajectories sampled from the fitted random effects distribution, to see the range of trajectories the fitted model considers plausible.

using Statistics

a_pred = collect(1:0.1:28)
log_ω_pred = fill(fit.minimizer[1], length(a_pred))
log_g_pred = fill(fit.minimizer[2], length(a_pred))
log_p_pred = fill(fit.minimizer[4], length(a_pred))
y_pred = urchins(log_ω_pred, log_g_pred, log_p_pred, a_pred)

samp = 5000
out = zeros(length(a_pred), samp)
for i in 1:samp
    # draw one urchin's random effects from the fitted distributions
    log_g_pred_i = rand(Normal(fit.minimizer[2], exp(fit.minimizer[3])))
    log_p_pred_i = rand(Normal(fit.minimizer[4], exp(fit.minimizer[5])))
    log_g_pred = fill(log_g_pred_i, length(a_pred))
    log_p_pred = fill(log_p_pred_i, length(a_pred))
    # compute that urchin's trajectory over the prediction ages
    out[:,i] = urchins(log_ω_pred, log_g_pred, log_p_pred, a_pred)
end

quant_lo = quantile.(eachrow(out), 0.025)
quant_hi = quantile.(eachrow(out), 0.975)

scatter(a,y,xlabel="age",ylabel="volume",legend=false)
plot!(a_pred,y_pred,seriescolor=2)
plot!(a_pred,quant_lo; fillrange = quant_hi,alpha=0.25,seriescolor=2)

References

Wood, Simon N. 2015. Core Statistics. Institute of Mathematical Statistics Textbooks 6. Cambridge University Press.