Fitting a continuous-time model

While mcstate is primarily designed to work with discrete-time stochastic models written in odin.dust, it is also possible to work with continuous-time deterministic models written in odin (subject to a few restrictions detailed in the Porting from odin vignette).

Consider a simple SIR model, as defined in the odin documentation: $$\begin{align*} \frac{dS}{dt} &= -\beta \frac{SI}{N} \\ \frac{dI}{dt} &= \beta \frac{SI}{N} - \gamma I \\ \frac{dR}{dt} &= \gamma I \end{align*}$$ where S is the number of susceptible individuals, I the number infected and R the number recovered; the total population size N = S + I + R is constant. β is the infection rate and γ is the recovery rate.

We will attempt to infer parameters β and γ using weekly measures of population prevalence, for example taken from household prevalence surveys. These have been generated using β = 0.3 and γ = 0.1, assuming around 100 people are sampled each week. The data contain weekly entries n_tested denoting the number of people sampled, and n_positive recording the number of those sampled who tested positive.

data <- read.csv(system.file("sir_prevalence.csv", package = "mcstate"))
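
A quick look at the first few rows confirms the format described above (output not shown):

head(data)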

par(bty = "n", mar = c(3, 3, 1, 1), mgp = c(1.5, 0.5, 0))
test_cols <- c(n_tested = "darkorange2", n_positive = "green4")
plot(data$day, data$n_tested, type = "b",
     xlab = "Day", ylab = "Number of individuals",
     ylim = c(0, max(data$n_tested)), col = test_cols[1], pch = 20)
lines(data$day, data$n_positive, type = "b", col = test_cols[2], pch = 20)
legend("left", names(test_cols), fill = test_cols, bty = "n")

If we write our continuous-time model in the odin DSL but compile it with odin.dust, we can use the calibration functions in mcstate to fit our model to data. In addition to the model equations, we include a section of code defining the likelihood used to compare model simulations with the observed data. We use a binomial likelihood:

$$n_{\mathrm{positive}}(t) \sim \mathrm{Bin}\bigg(n_{\mathrm{tested}}(t), \frac{I(t)}{N(t)}\bigg)$$ where the probability of testing positive at time t is given by the modelled population prevalence, $\frac{I(t)}{N(t)}$.

sir <- odin.dust::odin_dust({
  ## model
  initial(S) <- N_init - I_init
  initial(I) <- I_init
  initial(R) <- 0
  initial(N) <- N_init
  
  deriv(S) <- -beta * S * I / N
  deriv(I) <- beta * S * I / N - gamma * I
  deriv(R) <- gamma * I
  deriv(N) <- 0
  
  beta <- user(0.4)
  gamma <- user(0.2)
  N_init <- user(1e5)
  I_init <- user(1)
  
  ## likelihood
  n_positive <- data()
  n_tested <- data()
  compare(n_positive) ~ binomial(n_tested, I / N)
})
#> Loading required namespace: pkgbuild
#> ℹ 21 functions decorated with [[cpp11::register]]
#> ✔ generated file 'cpp11.R'
#> ✔ generated file 'cpp11.cpp'
#> ℹ Re-compiling odin64b3d728
#> ── R CMD INSTALL ───────────────────────────────────────────────────────────────
#> * installing *source* package ‘odin64b3d728’ ...
#> ** using staged installation
#> ** libs
#> using C++ compiler: ‘g++ (Ubuntu 13.2.0-23ubuntu4) 13.2.0’
#> g++ -std=gnu++17 -I"/usr/share/R/include" -DNDEBUG  -I'/github/workspace/pkglib/cpp11/include'    -I"/github/workspace/pkglib/dust/include" -DHAVE_INLINE -fopenmp  -fpic  -O2 -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer -ffile-prefix-map=/build/r-base-YnquTi/r-base-4.4.1=. -fstack-protector-strong -fstack-clash-protection -Wformat  -fcf-protection -fdebug-prefix-map=/build/r-base-YnquTi/r-base-4.4.1=/usr/src/r-base-4.4.1-3.2404.0 -Wdate-time -D_FORTIFY_SOURCE=3  -Wall -pedantic  -c cpp11.cpp -o cpp11.o
#> g++ -std=gnu++17 -I"/usr/share/R/include" -DNDEBUG  -I'/github/workspace/pkglib/cpp11/include'    -I"/github/workspace/pkglib/dust/include" -DHAVE_INLINE -fopenmp  -fpic  -O2 -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer -ffile-prefix-map=/build/r-base-YnquTi/r-base-4.4.1=. -fstack-protector-strong -fstack-clash-protection -Wformat  -fcf-protection -fdebug-prefix-map=/build/r-base-YnquTi/r-base-4.4.1=/usr/src/r-base-4.4.1-3.2404.0 -Wdate-time -D_FORTIFY_SOURCE=3  -Wall -pedantic  -c dust.cpp -o dust.o
#> g++ -std=gnu++17 -shared -L/usr/lib/R/lib -Wl,-Bsymbolic-functions -flto=auto -ffat-lto-objects -Wl,-z,relro -o odin64b3d728.so cpp11.o dust.o -fopenmp -L/usr/lib/R/lib -lR
#> installing to /tmp/Rtmp9xLHpI/devtools_install_113e78e0992b/00LOCK-file113e6bdd4269/00new/odin64b3d728/libs
#> ** checking absolute paths in shared objects and dynamic libraries
#> * DONE (odin64b3d728)
#> ℹ Loading odin64b3d728

To calibrate the model to the data, we first need to process the data into the format expected by mcstate. The rate argument, which for discrete-time models gives the number of model steps per unit of time, is not meaningful for a continuous-time model, so we must supply NULL.

sir_data <- mcstate::particle_filter_data(data = data, time = "day",
                                          initial_time = 0, rate = NULL)

rmarkdown::paged_table(sir_data)

We define a simple index function for our model, to provide meaningful names for the model states:

index <- function(info) {
  idx <- unlist(info$index)
  list(run = idx, state = idx)
}
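
For this model the index picks out all four named states (S, I, R, N); we can inspect it by querying a freshly created model instance, for example:

index(sir$new(list(), time = 0, n_particles = 1)$info())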

Although we are fitting a deterministic model, to use the mcstate machinery we still need to define a particle-filter-like object, here using particle_deterministic. We set compare = NULL because we have defined a compiled comparison function in our model code; alternatively you can supply an R compare function here (for example, if you want to use distributions not yet supported in the odin DSL), proceeding as described in the SIR models with odin, dust and mcstate vignette and sketched below.
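
For reference, an R comparison function equivalent to the compiled compare above might look like the following. This is a minimal sketch: the name compare_prevalence is ours, and it assumes the named index defined earlier so that the rows of state can be addressed as "I" and "N" (numeric indices would work equally well).

compare_prevalence <- function(state, observed, pars = NULL) {
  ## modelled prevalence for each particle; rows named via the index above
  prevalence <- state["I", ] / state["N", ]
  ## binomial log-likelihood of the observed positives given that prevalence
  dbinom(observed$n_positive, observed$n_tested, prevalence, log = TRUE)
}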


filter <- mcstate::particle_deterministic$new(data = sir_data, model = sir,
                                              index = index, compare = NULL)

Running this filter returns the log-likelihood; because the model is deterministic, its saved history should be identical to running the model directly:

filter$run(save_history = TRUE, pars = list())
#> [1] -74.86937

times <- c(0, data$day)
## run the compiled model directly over the same times, for comparison with
## the history saved by the filter
mod <- sir$new(list(), time = 0, n_particles = 1)
mod_history <- vapply(times, mod$simulate, numeric(length(mod$info()$dim)))

par(mfrow = c(1, 2), mar = c(3, 3, 1, 1), mgp = c(1.5, 0.5, 0), bty = "n")
cols <- c(S = "#8c8cd9", I = "#cc0044", R = "#999966", N = "black")
matplot(times, t(mod_history), type = "l",
        xlab = "Day", ylab = "Number of individuals", col = cols, lty = 1,
        main = "Model history")
legend("left", lwd = 1, col = cols, legend = names(cols), bty = "n")
matplot(times, t(drop(filter$history())), type = "l",
        xlab = "Day", ylab = "", col = cols, lty = 1,
        main = "Filter history")

Using MCMC to infer parameters

We can now perform MCMC using the Metropolis-Hastings algorithm to infer the values of β and γ from prevalence data. Let’s first describe the parameters we wish to infer, giving an initial value and a lower bound for each. For γ we will also apply a gamma prior with shape 1 and scale 0.2:
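
With shape 1 and scale 0.2 this prior is simply an exponential distribution with mean 0.2:

$$p(\gamma) = \frac{1}{0.2} \exp\left(-\frac{\gamma}{0.2}\right), \qquad \gamma \ge 0$$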

beta <- mcstate::pmcmc_parameter("beta", 0.2, min = 0)
gamma <- mcstate::pmcmc_parameter("gamma", 0.1, min = 0, prior = function(p) {
  dgamma(p, shape = 1, scale = 0.2, log = TRUE)
})

proposal_matrix <- diag(1e-2, 2)
mcmc_pars <- mcstate::pmcmc_parameters$new(list(beta = beta, gamma = gamma),
                                           proposal_matrix)

We can now run our MCMC, opting for 4 chains run in parallel. To speed up chain convergence, we take advantage of the adaptive proposal functionality that is available for deterministic models.

control <- mcstate::pmcmc_control(n_steps = 1e4,
                                  n_workers = 4,
                                  n_threads_total = 4,
                                  n_burnin = 500,
                                  n_chains = 4,
                                  adaptive_proposal = TRUE,
                                  save_state = TRUE,
                                  save_trajectories = TRUE,
                                  progress = TRUE)
mcmc_run <- mcstate::pmcmc(mcmc_pars, filter, control = control)
#> Finished 40000 steps in 13 secs
mcmc_sample <- mcstate::pmcmc_sample(mcmc_run, n_sample = 500)

We can visually check that the chains have converged, and assess the accuracy of our parameter estimates using a sample from the posterior distribution.

par(mfrow = c(2, 2), mar = c(3, 3, 1, 1), mgp = c(1.5, 0.5, 0), bty = "n")
plot(mcmc_run$pars[, "beta"], type = "l", xlab = "Iteration",
     ylab = expression(beta))
plot(mcmc_run$pars[, "gamma"], type = "l", xlab = "Iteration",
     ylab = expression(gamma))
hist(mcmc_sample$pars[, "beta"], main = "", xlab = expression(beta),
     freq = FALSE)
abline(v = 0.3, lty = 2, col = "darkred", lwd = 3)
hist(mcmc_sample$pars[, "gamma"], main = "", xlab = expression(gamma),
     freq = FALSE)
abline(v = 0.1, lty = 2, col = "darkred", lwd = 3)
legend("topright", legend = "True value", col = "darkred", lty = 2, bty = "n")

We can compare our fitted trajectories to the prevalence data by sampling from our comparison distribution, to obtain an estimate of the number of positive tests under our model.


state <- mcmc_sample$trajectories$state

## modelled prevalence at the observation times (dropping the initial time point)
prevalence <- t(state["I", , -1] / state["N", , -1])
modelled_positives <- apply(prevalence, 2, rbinom, n = nrow(data),
                            size = data$n_tested)

par(mfrow = c(1, 2), mar = c(3, 3, 1, 1), mgp = c(1.5, 0.5, 0), bty = "n")
matplot(times, t(state["S", , ]), type = "l", lty = 1, col = cols["S"],
        ylim = c(0, max(state)), xlab = "Day", ylab = "Number of individuals")
matlines(times, t(state["R", , ]), lty = 1, col = cols["R"])
matlines(times, t(state["I", , ]), lty = 1, col = cols["I"])
legend("left", lwd = 1, col = cols[-4], legend = names(cols)[-4], bty = "n")

matplot(data$day, modelled_positives, type = "l", lty = 1, col = grey(0.7, 0.2),
        xlab = "Day", ylab = "Number of positive tests")
points(data$day, data$n_positive, pch = 20)
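
As a final rough check, we could ask how often the observed counts fall within the central 95% posterior predictive interval. The sketch below uses the modelled_positives matrix computed above (rows are observation times, columns are posterior samples):

## 95% posterior predictive interval at each observation time
pred_ci <- apply(modelled_positives, 1, quantile, probs = c(0.025, 0.975))
## proportion of observations covered by the interval
mean(data$n_positive >= pred_ci[1, ] & data$n_positive <= pred_ci[2, ])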