```julia
using CSV
using HTTP
using DataFrames

http_response = HTTP.request("GET", "https://www.maths.ed.ac.uk/~swood34/data/urchin-vol.txt")
uv = CSV.read(IOBuffer(http_response.body), DataFrame, header=["row", "age", "vol"], skipto=2)
```
Laplace approximation to fit models including random effects in Julia
Urchin model
This example comes from Simon Wood’s book Core Statistics, which is available for free on his website. See Wood (2015, pp. 99-111) for more information.
This is a nice example of how to use the Laplace approximation to fit models with random effects in Julia, because it does not rely on special DSLs (domain-specific languages) and sticks as much as possible to functionality available to general Julia users. The model itself is written in standard, idiomatic Julia, so the approach should carry over to a fairly broad class of statistical models.
The model is one of sea urchin growth. There is a mechanistic model to relate volume \(V\) to age \(a\) by the differential equation:
\[ \frac{dV}{da} = \begin{cases} g_{i} V, & V < p_{i}/g_{i} \\ p_{i} & \text{otherwise} \end{cases} \]
This is an “individual-based” model, because the random effects \(p\) and \(g\) are indexed by individual urchin. All urchins share the same initial volume \(\omega\) (a fixed effect), which is the initial condition of the ODE. The differential equation switches form when the animal reaches reproductive age, which is:
\[ a_{mi} = \frac{1}{g_{i}}\log\left( \frac{p_{i}}{g_{i}\omega}\right) \]
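The expression for \(a_{mi}\) follows in one step. During the first phase, \(dV/da = g_{i} V\) with \(V(0) = \omega\) gives \(V = \omega \exp(g_{i} a)\), and the switch happens when this volume reaches the threshold \(p_{i}/g_{i}\):
\[ \omega \exp(g_{i} a_{mi}) = \frac{p_{i}}{g_{i}} \quad\Longrightarrow\quad g_{i} a_{mi} = \log\left( \frac{p_{i}}{g_{i}\omega}\right) \quad\Longrightarrow\quad a_{mi} = \frac{1}{g_{i}}\log\left( \frac{p_{i}}{g_{i}\omega}\right) \]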
The ODE has the analytical solution:
\[ V(a_{i}) = \begin{cases} \omega \exp(g_{i} a_{i}), & a_{i} < a_{mi} \\ p_{i}/g_{i} + p_{i}(a_{i}-a_{mi}), & \text{otherwise} \end{cases} \]
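As a quick sanity check on the analytical solution, the sketch below codes \(V(a)\) for a single urchin in plain Julia and confirms that the two branches meet at \(a_{mi}\) with volume \(p_{i}/g_{i}\), and that the slope in each phase matches the ODE (\(gV\) before maturity, \(p\) after). The function name `urchin_volume` and the parameter values are illustrative assumptions, not part of the original model code.

```julia
# Hypothetical scalar version of the analytical solution V(a) for a single urchin:
# ω is the initial volume, g the exponential growth rate, p the post-maturity growth rate.
function urchin_volume(a, ω, g, p)
    am = log(p / (g * ω)) / g          # reproductive age a_m
    return a < am ? ω * exp(g * a) : p / g + p * (a - am)
end

# Illustrative parameter values, chosen only for this check
ω, g, p = exp(-4.0), 0.8, 1.2
am = log(p / (g * ω)) / g

# The exponential branch evaluated at a_m, and the function itself at a_m
V_left = ω * exp(g * am)
V_right = urchin_volume(am, ω, g, p)

# Central finite-difference slope of V, to compare with the ODE's right-hand side
h = 1e-6
slope(a) = (urchin_volume(a + h, ω, g, p) - urchin_volume(a - h, ω, g, p)) / (2h)
a_early, a_late = am / 2, am + 2.0   # one age in each phase
```

Both `V_left` and `V_right` equal \(p/g\) (up to rounding), `slope(a_early)` matches `g * urchin_volume(a_early, ω, g, p)`, and `slope(a_late)` matches `p`.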
The measured urchin volumes are \(v\), and the likelihood model is:
\[ \sqrt{v_{i}} \sim N\left(\sqrt{V(a_{i})},\sigma^{2}\right) \]
And we assume the random effects follow:
\[ \begin{split} \log g_{i} &\sim N(\mu_{g},\sigma_{g}^{2}) \\ \log p_{i} &\sim N(\mu_{p},\sigma_{p}^{2}) \end{split} \]
The model parameters (“fixed effects”) are \(\omega, \mu_{g}, \sigma_{g}, \mu_{p}, \sigma_{p}, \sigma\), and the random effects are \(g, p\).
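To make the hierarchical structure concrete, here is a minimal sketch that simulates data from the model. It assumes the fixed effects are supplied on the same scale used later in the code (\(\log\omega, \mu_{g}, \log\sigma_{g}, \mu_{p}, \log\sigma_{p}, \log\sigma\)). The function name `simulate_urchins` and the ages are hypothetical, and only the Random standard library is used.

```julia
using Random

# Hypothetical simulator for the urchin model.
# θ = [log_ω, μg, log_σg, μp, log_σp, log_σ]; `ages` holds one age per urchin.
function simulate_urchins(rng::AbstractRNG, θ::AbstractVector, ages::AbstractVector)
    log_ω, μg, log_σg, μp, log_σp, log_σ = θ
    n = length(ages)
    ω = exp(log_ω)
    # Random effects: log g_i ~ N(μg, σg^2), log p_i ~ N(μp, σp^2)
    g = exp.(μg .+ exp(log_σg) .* randn(rng, n))
    p = exp.(μp .+ exp(log_σp) .* randn(rng, n))
    # Reproductive age and the analytical solution V(a_i)
    am = @. log(p / (g * ω)) / g
    V = @. ifelse(ages < am, ω * exp(g * ages), p / g + p * (ages - am))
    # Observation model on the square-root scale: sqrt(v_i) ~ N(sqrt(V(a_i)), σ^2)
    sqrt_v = sqrt.(V) .+ exp(log_σ) .* randn(rng, n)
    return (g = g, p = p, V = V, sqrt_v = sqrt_v)
end

# Starting values from Wood's book (also used as θ0 later in this post)
θ_example = [-4.0, -0.2, log(0.1), 0.2, log(0.1), log(0.5)]
sim = simulate_urchins(MersenneTwister(1), θ_example, collect(1.0:1.0:28.0))
```

Each simulated urchin gets its own \(g_{i}\) and \(p_{i}\), while \(\omega\) and \(\sigma\) are shared, which is exactly the split between random and fixed effects described above.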
Urchin data
We are going to load the data from Simon Wood’s website.
We can take a look at it:
```julia
using Plots

y::Vector{Float64} = uv[:,:vol]
a::Vector{Int64} = uv[:,:age]

scatter(a,y,label=false,xlabel="age",ylabel="volume")
```
If we square-root transform the volumes, we can see that the likelihood model’s assumption of constant variance is probably not too bad, although for “real” modeling we would want to address the remaining non-constant variance.
```julia
scatter(a,sqrt.(y),label=false,xlabel="age",ylabel="Sqrt(volume)")
```
Laplace approximation
The following summarizes the account given in (Wood 2015). Let a statistical model with random and fixed effects be given by the joint density function of the data \(y\) and random effects \(b\), parameterized by \(\theta\) and written as \(f_{\theta}(y,b)\). The likelihood is the marginal density of the data:
\[ L(\theta) = f_{\theta}(y) = \int f_{\theta}(y,b) db \]
Evaluating \(L(\theta)\) requires handling this integral, which can be done by:
1. Numerical integration
2. Monte Carlo methods
3. Approximating the integral
4. Replacing the integral with another function whose maximum coincides with the maximum of \(L(\theta)\)
The Laplace approximation falls under option 3. The EM algorithm falls under option 4.
The point of the Laplace approximation is to approximate the integral so that \(L(\theta)\) can be evaluated cheaply. The MLE \(\hat{\theta}\) can then be found with standard optimization methods.
One starts by writing a second-order Taylor expansion of \(\log f_{\theta}(y,b)\) around \(\hat{b}\), the values of \(b\) that maximize the joint density for fixed \(y\). After exponentiating the expansion and simplifying, one gets:
\[ \int f_{\theta}(y,b)\, db \approx f_{\theta}(y,\hat{b}) \frac{(2\pi)^{n_{b}/2}}{|H|^{1/2}} \]
where \(H = -\nabla^{2}_{b} \log f_{\theta}(y,b)\big|_{b=\hat{b}}\) is the Hessian of the negative log joint density with respect to the random effects, evaluated at \(\hat{b}\), and \(n_{b}\) is the number of random effects. So we have replaced a very difficult integral with an optimization problem (to get \(\hat{b}\)) and the calculation of \(H\).
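The formula can be checked on two one-dimensional integrals whose exact values are known. The sketch below uses only the LinearAlgebra standard library; the function and variable names are hypothetical. `laplace_logintegral` takes the log of the integrand at its mode and the Hessian \(H\) of the negative log integrand. For a Gaussian random-intercept model, where \(y \mid b \sim N(b, \sigma^{2})\) and \(b \sim N(\mu, \tau^{2})\), the log joint density is exactly quadratic in \(b\), so the approximation should be exact. For \(\Gamma(n+1) = \int_{0}^{\infty} x^{n} e^{-x}\,dx = n!\), it reproduces Stirling’s formula.

```julia
using LinearAlgebra

# Log of the Laplace approximation:
# log ∫ f(b) db ≈ log f(b̂) + (n_b/2) log(2π) - (1/2) log|H|,
# where H is the Hessian of -log f at the mode b̂.
laplace_logintegral(logf_at_mode::Real, H::AbstractMatrix) =
    logf_at_mode + size(H, 1) / 2 * log(2π) - logdet(H) / 2

normal_logpdf(x, μ, σ) = -log(σ) - log(2π) / 2 - ((x - μ) / σ)^2 / 2

# Example 1: Gaussian random intercept, y | b ~ N(b, σ^2), b ~ N(μ, τ^2)
y_obs, μ, σ, τ = 1.3, 0.5, 0.8, 1.1
logjoint(b) = normal_logpdf(y_obs, b, σ) + normal_logpdf(b, μ, τ)
Hg = 1 / σ^2 + 1 / τ^2                        # -d²/db² of the log joint (constant in b)
b_hat = (y_obs / σ^2 + μ / τ^2) / Hg          # mode of the joint density in b
laplace_gauss = laplace_logintegral(logjoint(b_hat), fill(Hg, 1, 1))
exact_gauss = normal_logpdf(y_obs, μ, sqrt(σ^2 + τ^2))  # marginal: y ~ N(μ, σ^2 + τ^2)

# Example 2: Γ(n + 1) with log integrand h(x) = n log(x) - x
n = 10
x_hat = n                                     # h'(x) = n/x - 1 = 0 at x = n
Hs = n / x_hat^2                              # -h''(x̂) = n / x̂^2 = 1/n
laplace_fact = exp(laplace_logintegral(n * log(x_hat) - x_hat, fill(Hs, 1, 1)))
```

In the Gaussian case the second-order Taylor expansion has no remainder, so `laplace_gauss` agrees with `exact_gauss` to rounding error. For the non-Gaussian \(\Gamma(11)\) integrand, `laplace_fact` slightly underestimates `factorial(10)`, with a relative error below 1%.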
Julia implementation
Now we can look at how to use the Laplace approximation to do maximum likelihood estimation (MLE) on the urchin problem.
Let’s first load some packages we will need. We use ForwardDiff.jl for forward-mode automatic differentiation, Optim.jl for optimization algorithms, LineSearches.jl for choosing line search options for optimization (see this issue in Optim.jl), and the LinearAlgebra standard library for logdet.
```julia
using ForwardDiff
using Optim
using LineSearches
using LinearAlgebra
```
First, we define the mathematical model, which calculates predicted urchin volumes \(V(a_{i})\) from the ages \(a_{i}\), the random effects (\(\log g_{i}, \log p_{i}\), which jointly make up \(b\)), and the parameter \(\log\omega\). Note that unlike R, Julia functions are not “vectorized” by default, so we use the @. macro to broadcast every function call and operator in an expression (see the broadcasting documentation for more information).
Please note we treat \(\log\omega\) as a vector due to an issue in ReverseDiff.jl https://github.com/JuliaDiff/ReverseDiff.jl/issues/214. While this tutorial will use ForwardDiff.jl, we keep this convention so reverse mode AD can be used if one wishes.
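A small illustration, with arbitrary values, of the two broadcasting behaviors the `urchins` function relies on: `@.` rewrites every call and operator in an expression into its dotted form, and a length-1 vector such as `[log_ω]` broadcasts against a length-n vector.

```julia
x = [0.0, 1.0, 2.0]

# @. rewrites the whole expression into dotted (broadcast) form
lhs = @. exp(x) / (1 + x)
rhs = exp.(x) ./ (1 .+ x)

# A length-1 vector broadcasts against a length-3 vector,
# which is why ω can be passed around as a one-element vector
ω = [2.0]
scaled = @. ω * x
```

Here `lhs == rhs`, and `scaled` is `[0.0, 2.0, 4.0]`.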
```julia
function urchins(log_ω, log_g, log_p, a)
    ω = @. exp(log_ω)
    p = @. exp(log_p)
    g = @. exp(log_g)
    # reproductive age of each urchin
    am = @. log(p/(g*ω))/g
    # analytical solution of the ODE, one urchin at a time
    μ = map(p,g,a,am) do pi,gi,ai,ami
        ifelse(ai < ami, ω[1] * exp(gi*ai), pi/gi + pi*(ai-ami))
    end
    return μ
end
```
Next we write the function \(f_{\theta}(y,b)\), the joint density of the data and random effects, on the log scale. We use the loglikelihood function from Distributions.jl to calculate the contributions to the log density from \(y\) and \(b\).
```julia
using Distributions

function urchins_jll(y, θ, log_g, log_p, a)
    log_ω, μg, log_σg, μp, log_σp, log_σ = θ
    μ = urchins([log_ω], log_g, log_p, a) # predicted values
    # data on the square-root scale, plus the two random effect distributions
    data_ll = @. loglikelihood(Normal(sqrt(μ), exp(log_σ)), sqrt(y))
    g_ll = loglikelihood(Normal(μg, exp(log_σg)), log_g)
    p_ll = loglikelihood(Normal(μp, exp(log_σp)), log_p)
    return sum(data_ll) + g_ll + p_ll
end
```
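Written out, the value returned by urchins_jll is the sum of three groups of normal log-density terms. Here \(\phi(x;\mu,\sigma)\) denotes the normal density with mean \(\mu\) and standard deviation \(\sigma\), \(y\) holds the measured volumes \(v_{i}\), and the code receives \(\sigma, \sigma_{g}, \sigma_{p}\) on the log scale and exponentiates them:
\[ \log f_{\theta}(y,b) = \sum_{i=1}^{n} \log\phi\left(\sqrt{v_{i}};\sqrt{V(a_{i})},\sigma\right) + \sum_{i=1}^{n}\log\phi\left(\log g_{i};\mu_{g},\sigma_{g}\right) + \sum_{i=1}^{n}\log\phi\left(\log p_{i};\mu_{p},\sigma_{p}\right) \]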
Now we can define the function that computes the Laplace approximation of the marginal log density (log-likelihood) of the data. We use ForwardDiff.jl for automatic differentiation of the gradients needed by the LBFGS optimization of \(\hat{b}\), and GradientConfig to get a better chunk size for forward-mode AD (see the docs). The function takes the arguments y for the urchin volumes, a for the urchin ages, θ for the fixed effects (parameters), cols for a color vector, and sparse for a sparsity pattern. The last two arguments make computing the Hessian much more efficient: they are passed to numauto_color_hessian, a function from SparseDiffTools.jl, to compute \(H\).
```julia
function urchins_laplace_ll(y, a, θ, cols, sparse)
    # generic info
    n = length(y)
    b = [fill(θ[2],n); fill(θ[4],n)]  # start the random effects at their means
    nb = length(b)

    # objective to optimize the random effects: the negative log joint density
    f(x) = -urchins_jll(y, θ, x[1:n], x[n+1:end], a)

    # set up ForwardDiff gradient for optimizing random effects
    gconfig = ForwardDiff.GradientConfig(f, b)
    g!(G, x) = ForwardDiff.gradient!(G, f, x, gconfig)

    # optimize the random effects to get b̂
    re_mle = optimize(f, g!, b, LBFGS(linesearch=LineSearches.BackTracking()))

    # Hessian of the negative log joint density at b̂
    H = numauto_color_hessian(f, re_mle.minimizer, cols, sparse)

    # log f(y, b̂) + (nb/2) log(2π) - (1/2) log|H|;
    # nb * log(2π) equals log((2π)^nb) but cannot overflow for large nb
    return -re_mle.minimum + 0.5 * (nb * log(2π) - logdet(H))
end
```
The final line of the function urchins_laplace_ll calculates the Laplace approximation to the marginal log-likelihood. Taking the log of the Laplace approximation to the marginal likelihood gives \(\log\left( f_{\theta}(y,\hat{b}) \frac{(2\pi)^{n_{b}/2}}{|H|^{1/2}}\right) = \log f_{\theta}(y,\hat{b}) + \log \frac{(2\pi)^{n_{b}/2}}{|H|^{1/2}}\). Then \(\log \frac{(2\pi)^{n_{b}/2}}{|H|^{1/2}} = \log \left(\frac{(2\pi)^{n_{b}}}{|H|}\right)^{1/2} = \frac{1}{2} \left( n_{b}\log 2\pi - \log |H|\right)\), so the final line calculates \(\log f_{\theta}(y,\hat{b}) + \frac{1}{2} \left( n_{b}\log 2\pi - \log |H|\right)\). Here \(\log f_{\theta}(y,\hat{b})\) is -re_mle.minimum, because the objective f minimized over the random effects is the negative log joint density, and \(\log |H|\) is computed with logdet.
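Evaluating the approximation on the log scale keeps it numerically stable as \(n_{b}\) grows, and the urchin model has \(2n\) random effects, so \(n_{b}\) grows with the number of urchins. The stdlib-only sketch below uses arbitrary sizes and values. It shows that `log((2π)^nb)` overflows to `Inf` for large `nb` while `nb * log(2π)` stays finite, and that `log(det(H))` underflows to `-Inf` where `logdet(H)` does not.

```julia
using LinearAlgebra

nb = 400
naive_const = log((2π)^nb)     # (2π)^400 exceeds the Float64 range, so this is Inf
stable_const = nb * log(2π)    # the same quantity computed without overflow

H = Diagonal(fill(1e-3, 200))  # an illustrative diagonal Hessian with small entries
naive_logdet = log(det(H))     # det(H) = 1e-600 underflows to 0.0, so this is -Inf
stable_logdet = logdet(H)      # 200 * log(1e-3), roughly -1381.55
```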
Now we are ready to fit the random effects model. The starting values of the parameters in θ0 are from Simon Wood’s book.
```julia
log_ω = -4.0
μg = -0.2
log_σg = log(0.1)
μp = 0.2
log_σp = log(0.1)
log_σ = log(0.5)

θ0::Vector{Float64} = [log_ω, μg, log_σg, μp, log_σp, log_σ]
```
We now calculate the color vector and sparsity pattern to efficiently compute \(H\).
```julia
using SparseDiffTools
using SparseArrays

n = length(y)
b = [fill(θ0[2],n); fill(θ0[4],n)] # should sample from the random effects, or make this more generic
f(x) = -urchins_jll(y, θ0, x[1:n], x[n+1:end], a)

# sparsity pattern and coloring of the Hessian used in the Laplace approximation
A = ForwardDiff.hessian(f, b)
A = SparseArrays.sparse(A)
A_colors = SparseDiffTools.matrix_colors(A)
```
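Now we can fit the model. Note that we need to minimize the (approximate) negative log-likelihood. Since no optimization method is specified, Optim.jl uses its default Nelder-Mead algorithm for this outer optimization over θ.
```julia
urchins_laplace_nll(θ) = -urchins_laplace_ll(y, a, θ, A_colors, A)
fit = optimize(urchins_laplace_nll, θ0, Optim.Options(f_tol = 1e-5, g_tol = 1e-5, show_trace = false))
```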
The minimized negative log-likelihood is about the same as the value obtained with the R code in the book:
```julia
fit.minimum
```
```
92.14564988143138
```
As are the estimates of \(\hat{\theta}\):
```julia
fit.minimizer
```
```
6-element Vector{Float64}:
 -4.010108413056277
 -0.20953496592362672
 -1.8228234557070078
  0.17312824510878386
 -1.6003923376943126
 -1.1715877153672012
```
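Because \(\omega\) and the three standard deviations are estimated on the log scale, exponentiating the printed estimates puts them on their natural scale. The values below are copied from the output above, and the variable names are illustrative.

```julia
# Estimates copied from fit.minimizer above, in the order
# [log_ω, μg, log_σg, μp, log_σp, log_σ]
θ_hat = [-4.010108413056277, -0.20953496592362672, -1.8228234557070078,
         0.17312824510878386, -1.6003923376943126, -1.1715877153672012]

ω_hat  = exp(θ_hat[1])   # initial volume
σg_hat = exp(θ_hat[3])   # standard deviation of log g_i
σp_hat = exp(θ_hat[5])   # standard deviation of log p_i
σ_hat  = exp(θ_hat[6])   # residual standard deviation on the square-root volume scale
g_typ  = exp(θ_hat[2])   # g for an urchin with log g_i at its mean
p_typ  = exp(θ_hat[4])   # p for an urchin with log p_i at its mean
```

So the estimated initial volume is about 0.018, and the residual standard deviation on the square-root scale is about 0.31.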
We plot the data again, along with the trajectory of an urchin whose random effects are at their estimated means. We also show the pointwise 0.025 and 0.975 quantiles of 5000 trajectories simulated from the fitted random effects distribution, which gives the range of trajectories the fitted model considers plausible.
```julia
using Statistics

# trajectory of an urchin whose random effects are at their estimated means
a_pred = collect(1:0.1:28)
log_ω_pred = fill(fit.minimizer[1], length(a_pred))
log_g_pred = fill(fit.minimizer[2], length(a_pred))
log_p_pred = fill(fit.minimizer[4], length(a_pred))
y_pred = urchins(log_ω_pred, log_g_pred, log_p_pred, a_pred)

# trajectories for urchins with random effects drawn from the fitted model
samp = 5000
out = zeros(length(a_pred), samp)
for i in 1:samp
    log_g_pred_i = rand(Normal(fit.minimizer[2], exp(fit.minimizer[3])))
    log_p_pred_i = rand(Normal(fit.minimizer[4], exp(fit.minimizer[5])))
    log_ω_pred = fill(fit.minimizer[1], length(a_pred))
    log_g_pred = fill(log_g_pred_i, length(a_pred))
    log_p_pred = fill(log_p_pred_i, length(a_pred))
    out[:,i] = urchins(log_ω_pred, log_g_pred, log_p_pred, a_pred)
end

# pointwise 2.5% and 97.5% quantiles across the sampled trajectories
quant_lo = quantile.(eachrow(out), 0.025)
quant_hi = quantile.(eachrow(out), 0.975)

scatter(a,y,xlabel="age",ylabel="volume",legend=false)
plot!(a_pred,y_pred,seriescolor=2)
plot!(a_pred,quant_lo; fillrange = quant_hi,alpha=0.25,seriescolor=2)
```