Differential equations offer a powerful way of describing dynamic processes. They are widely used to model non-linear behaviour in many areas, including biology. Parameterising these differential equation models allows us to relate them to reality via observed data. Often there is uncertainty in our observed data that we want to account for when fitting our model and making predictions about behaviour in new conditions. We will use Bayesian methods to incorporate this uncertainty. Modern probabilistic programming software such as Stan allows us to use powerful sampling methods such as Hamiltonian Monte Carlo. We write a model description in Stan that we can then call from R, Python or other languages.

When I started using Stan, I found that it takes a bit more time to think through a problem and set it up in the right framework. The rewards for doing so are more than worthwhile, both in sampling efficiency and in clarity of thinking about the problem.

Here, I will look at a case study of fitting a simple ordinary differential equation model to data. Hopefully this will be useful for anyone (like me!) who likes to learn from examples. For a gentler introduction to Stan itself, I’d recommend https://cran.r-project.org/web/packages/rstan/vignettes/rstan.html

Key ingredients for a Bayesian model

  1. Process model
  2. Measurement model
  3. Parameter model

Process model

This is our model for the underlying process that we are interested in. In general this could be, for example, a linear regression model. In our context here, it will be a differential equation model.

Measurement model

We assume that when we measure the state of the system, there is some error associated with that measurement. In a Bayesian context, we want a model of how that error is distributed. How big will the error be? Are we measuring count data, where counts arise at a certain rate and a count distribution (such as the Poisson) may be more appropriate than a normal one?

Parameter model

Taking a Bayesian approach, we treat our parameters as random variables. We must specify our beliefs about these variables in the form of a prior distribution, and this prior distribution is our parameter model. Using the model and our data, we then update these beliefs and obtain samples from the posterior distribution.
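To make the three ingredients concrete, here is a minimal base-R sketch that simulates one dataset generatively. It draws parameters from a prior (parameter model), computes the noise-free curve (process model), then adds measurement error (measurement model). It borrows the logistic curve, the Cauchy(0, 2.5) priors on positive parameters and the normal noise used later in this post. The helper name `simulate_one` and the default noise level are illustrative assumptions, not part of the original analysis.

```r
# Hypothetical helper: simulate one noisy growth curve from the three model components
simulate_one <- function(ts, y0, sigma = 0.01) {
  # Parameter model: half-Cauchy(0, 2.5) draws for growth rate and carrying capacity
  theta <- abs(rcauchy(2, location = 0, scale = 2.5))
  # Process model: closed-form solution of the logistic ODE with y(0) = y0
  y <- theta[2] / (1 + (theta[2] - y0) / y0 * exp(-theta[1] * ts))
  # Measurement model: independent normal errors around the true curve
  z <- y + rnorm(length(ts), mean = 0, sd = sigma)
  list(theta = theta, y = y, z = z)
}

set.seed(1)
sim <- simulate_one(ts = seq(0, 24, by = 1/6), y0 = 0.05)
```

Repeating this many times gives a prior predictive distribution. That is a quick way to check whether the chosen priors produce plausible growth curves before any data are used.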

We consider bacterial growth, which we will model with a logistic growth model, and parameterise this model using data from the growthcurver package. This provides bacterial growth data from a 96-well plate over time. Here we convert the data to long format and plot it to see the variability.

library(rstan)
library(dplyr)
library(tidyr)
library(ggplot2)
library(growthcurver) #contains the growthdata dataset
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

long_growthdata <- growthdata %>% gather(well,absorbance,-time)
glimpse(long_growthdata)
## Observations: 13,920
## Variables: 3
## $ time       <dbl> 0.0000000, 0.1666667, 0.3333333, 0.5000000, 0.66666...
## $ well       <chr> "A1", "A1", "A1", "A1", "A1", "A1", "A1", "A1", "A1...
## $ absorbance <dbl> 0.05348585, 0.04800336, 0.05587451, 0.05131749, 0.0...
ggplot(long_growthdata,aes(time,absorbance,group=well)) +
  geom_line() + 
  theme_bw()

Plotting all the data together we can see that there is a lot of heterogeneity in the data, with the main growth phase taking place across a range of times and growing to different maximum levels.

Logistic growth model

We will start by considering a logistic growth model for the bacterial growth: \[ \frac{\text{d}y}{\text{d}t} = \theta_1 y \left( 1 - \frac{y}{\theta_2}\right). \] In the early phase, when \(y\) is small relative to \(\theta_2\), this assumes growth occurs at an approximately constant per-capita rate, \(\theta_1\) (equivalent to exponential growth). Later, the growth rate is limited by the resources available via a carrying capacity, \(\theta_2\). We assume all the measurements from different wells are governed by a single set of parameters, \(\mathbf{\theta} = [\theta_1, \theta_2]\), and we attempt to infer these parameters.
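In Stan we only write the right-hand side \(\text{d}y/\text{d}t\), and an ODE integrator does the rest. As a rough illustration of that split, the base-R sketch below integrates the same right-hand side with a hand-written fixed-step fourth-order Runge-Kutta scheme. It then checks the result against the closed-form logistic solution. Stan's `integrate_ode_rk45` uses an adaptive Runge-Kutta method rather than this fixed-step version. The function names, step size and parameter values here are assumptions chosen for illustration.

```r
# Right-hand side of the logistic ODE, applied element-wise to a vector of wells
logistic_rhs <- function(t, y, theta) theta[1] * y * (1 - y / theta[2])

# Hypothetical fixed-step RK4 integrator starting at t0 = 0
# (Stan's integrate_ode_rk45 adapts its step size instead)
rk4_solve <- function(f, y0, ts, theta, h = 0.01) {
  out <- matrix(NA_real_, nrow = length(ts), ncol = length(y0))
  y <- y0
  t <- 0
  for (k in seq_along(ts)) {
    # Take steps of at most h until we reach the next output time
    while (t < ts[k] - 1e-12) {
      step <- min(h, ts[k] - t)
      k1 <- f(t, y, theta)
      k2 <- f(t + step / 2, y + step / 2 * k1, theta)
      k3 <- f(t + step / 2, y + step / 2 * k2, theta)
      k4 <- f(t + step, y + step * k3, theta)
      y <- y + step / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
      t <- t + step
    }
    out[k, ] <- y
  }
  out
}

theta <- c(0.23, 0.46)   # illustrative growth rate and carrying capacity
y0 <- c(0.05, 0.04)      # illustrative initial absorbances for two wells
ts <- seq(1/6, 24, by = 1/6)
numeric_sol <- rk4_solve(logistic_rhs, y0, ts, theta)
# Closed-form logistic solution for comparison
exact_sol <- sapply(y0, function(a) theta[2] / (1 + (theta[2] - a) / a * exp(-theta[1] * ts)))
max(abs(numeric_sol - exact_sol))   # should be very small
```

Because each well's equation only involves its own \(y\), the right-hand side can act element-wise on the whole vector of wells. The Stan function below uses the same structure, looping over `x_i[1]` wells.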

Now we want to define our process, measurement and parameter models in the appropriate code blocks in Stan.

functions {
  real[] logisticgrowth(real t,
                  real[] y,
                  real[] theta,
                  real[] x_r,
                  int[] x_i
                  ) {
    real dydt[x_i[1]];
    for (i in 1:x_i[1]){
      dydt[i] = theta[1] * y[i] * (1-y[i]/theta[2]);
    }
    return dydt;
  }
}
data {
  int<lower=1> T;
  int<lower=1> n_wells;
  real y0[n_wells];
  real z[T,n_wells];
  real t0;
  real ts[T];
}
transformed data {
  real x_r[0];
  int x_i[1];
  x_i[1] = n_wells;
}
parameters {
  real<lower=0> theta[2];
  real<lower=0> sigma;
}
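model {
  real y_hat[T,n_wells];
  // Parameter model: priors (effectively half-Cauchy and half-normal, given the lower=0 bounds)
  theta ~ cauchy(0,2.5);
  sigma ~ normal(0,0.01);
  // Process model: solve the logistic ODE for every well at the observation times
  y_hat = integrate_ode_rk45(logisticgrowth, y0, t0, ts, theta, x_r, x_i);
  // Measurement model: normally distributed error around the ODE solution
  for (t in 1:T) {
    for (i in 1:n_wells) {
      z[t,i] ~ normal(y_hat[t,i], sigma);
    }
  }
}
generated quantities{
  // Posterior predictive quantities: noise-free trajectories and simulated noisy measurements
  real y_pred[T,n_wells];
  real z_pred[T,n_wells];
  y_pred = integrate_ode_rk45(logisticgrowth, y0, t0, ts, theta, x_r, x_i );
  for (t in 1:T) {
    for(i in 1:n_wells){
      z_pred[t,i] = y_pred[t,i] + normal_rng(0,sigma);
    }
  }
}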

The model above can handle data from any number of wells in the plate, up to all 96, assuming that the same parameters describe all of these data. The Stan code above is compiled into a Stan model object, so that sampling is carried out under the hood by efficient compiled C++ code.

Fitting the model via MCMC

We fit the model by calling the following code from R.

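# Compile the Stan program above; assumes it has been saved as "logisticgrowth.stan" (hypothetical filename)
logisticgrowth_stan <- stan_model("logisticgrowth.stan")

nSamples = nrow(growthdata) - 1 #use time=0 as initial condition, take this as fixed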
y0 = filter(growthdata,time==0) %>% select(-time) %>% unlist #initial condition
t0 = 0.0
ts = filter(growthdata,time>0) %>% select(time) %>% unlist
z = filter(growthdata,time>0) %>% select(-time)
n_wells = 9 #running on all wells can be slow
estimates <- sampling(object = logisticgrowth_stan,
                  data = list (
                    T  = nSamples,
                    n_wells = n_wells,
                    y0 = y0[1:n_wells],
                    z  = z[,1:n_wells],
                    t0 = t0,
                    ts = ts
                  ),
                  seed = 123,
                  chains = 4,
                  iter = 1000,
                  warmup = 500
)

parametersToPlot = c("theta","sigma","lp__")
print(estimates, pars = parametersToPlot)
## Inference for Stan model: 10112d5a3cfcf5e400dfa033f5f3ea6e.
## 4 chains, each with iter=1000; warmup=500; thin=1; 
## post-warmup draws per chain=500, total post-warmup draws=2000.
## 
##             mean se_mean   sd    2.5%     25%     50%     75%   97.5%
## theta[1]    0.23    0.00 0.00    0.22    0.23    0.23    0.23    0.24
## theta[2]    0.46    0.00 0.00    0.45    0.46    0.46    0.47    0.47
## sigma       0.07    0.00 0.00    0.06    0.07    0.07    0.07    0.07
## lp__     2798.74    0.04 1.22 2795.74 2798.18 2799.03 2799.65 2800.13
##          n_eff Rhat
## theta[1]   942 1.00
## theta[2]  1160 1.00
## sigma     1345 1.00
## lp__       974 1.01
## 
## Samples were drawn using NUTS(diag_e) at Sun Mar 25 13:24:21 2018.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).
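The summary above describes `Rhat` as the potential scale reduction factor on split chains. Below is a minimal base-R sketch of the classic split-\(\hat{R}\) calculation, assuming a matrix of draws with one column per chain. Software implementations differ in the details (some now use a rank-normalised version), so treat this as an illustration of the idea rather than a reproduction of the table.

```r
# Classic split-Rhat: split each chain in half, then compare between- and within-chain variance
split_rhat <- function(draws) {
  # draws: matrix with iterations in rows and chains in columns
  n <- floor(nrow(draws) / 2)
  halves <- cbind(draws[1:n, , drop = FALSE],
                  draws[(nrow(draws) - n + 1):nrow(draws), , drop = FALSE])
  chain_means <- colMeans(halves)
  chain_vars <- apply(halves, 2, var)
  B <- n * var(chain_means)            # between-chain variance
  W <- mean(chain_vars)                # within-chain variance
  var_plus <- (n - 1) / n * W + B / n  # pooled estimate of the posterior variance
  sqrt(var_plus / W)
}

set.seed(42)
well_mixed <- matrix(rnorm(500 * 4), ncol = 4)          # 4 chains sampling the same target
stuck <- well_mixed + rep(c(0, 0, 0, 3), each = 500)    # one chain offset from the others
split_rhat(well_mixed)  # close to 1
split_rhat(stuck)       # clearly above 1
```

Values close to 1, as in the table above, indicate that the chains agree with each other. A chain stuck in a different region inflates the between-chain variance and pushes \(\hat{R}\) above 1.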

Visualising results

Great, so we have managed to sample from the differential equation model and quantify the distribution of our parameters. Let’s have a look at the posterior distribution we have obtained and perform some checks to make sure everything is working as expected.

library(bayesplot)
## This is bayesplot version 1.4.0
## - Plotting theme set to bayesplot::theme_default()
## - Online documentation at mc-stan.org/bayesplot
draws <- as.array(estimates, pars=parametersToPlot)
mcmc_trace(draws)

color_scheme_set("brightblue")
mcmc_scatter(draws,pars=c('theta[1]','theta[2]'))

These results look encouraging. The Rhat values are all close to 1, and the posteriors for \(\theta_1\) and \(\theta_2\) are narrow, so the parameters appear well identified. Let’s check the posterior predictive distribution to see whether predictions made from this model actually resemble our original data.

# One row per observation, ordered well by well to match the column order of z_pred
xdata <- data.frame(absorbance = unlist(z[,1:n_wells]),
                    well = rep(1:n_wells, each = nSamples),
                    time = rep(ts, n_wells))
pred <- as.data.frame(estimates, pars = "z_pred") %>%
  gather(factor_key = TRUE) %>%
  group_by(key) %>%
  summarize(lb = quantile(value, probs = 0.05),
            median = quantile(value, probs = 0.5),
            ub = quantile(value, probs = 0.95)) %>%
  bind_cols(xdata)

p1 <- ggplot(pred, aes(x = time, y = absorbance))
p1 <- p1 + geom_point() +
  labs(x = "time (h)", y = "absorbance") +
  theme(text = element_text(size = 12), axis.text = element_text(size = 12),
        legend.position = "none", strip.text = element_text(size = 8))
p1 + geom_line(aes(x = time, y = median)) +
  geom_ribbon(aes(ymin = lb, ymax = ub), alpha = 0.25) +
  facet_wrap(~factor(well))

Uh oh. It looks like something is not quite right here. The grey band gives a 90% posterior predictive interval from the model (the 5% to 95% quantiles computed above). The observed data are shown as black circles, which appear as a thick black line here.

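  1. The data exhibit a lag phase in which little growth happens, followed by rapid growth that slows as the population approaches the carrying capacity. The logistic model grows approximately exponentially from the start, so it has no explicit lag phase.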
  2. The model is not capturing the individual variability across each well in the plate. It tries to average over all the wells in some sense to find a single set of parameters to describe all the data.
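The band in the plot above comes from the 5% and 95% quantiles of the `z_pred` draws at each time point, so it is a central 90% interval. The base-R sketch below computes the same summary on a toy matrix of draws. Rows are draws and columns are observation points, mirroring `as.data.frame(estimates, pars = "z_pred")`. The simulated draws and the helper name `interval_summary` are made up purely for illustration. Changing the probabilities to 0.025 and 0.975 would give a 95% band.

```r
set.seed(7)
# Toy posterior predictive draws: 2000 draws (rows) for 3 observation points (columns)
z_pred <- matrix(rnorm(2000 * 3, mean = c(0.1, 0.3, 0.45), sd = 0.05),
                 ncol = 3, byrow = TRUE)

# Hypothetical helper: lower bound, median and upper bound for each column of draws
interval_summary <- function(draws, lower = 0.05, upper = 0.95) {
  t(apply(draws, 2, quantile, probs = c(lower, 0.5, upper)))
}
intervals <- interval_summary(z_pred)
colnames(intervals) <- c("lb", "median", "ub")
intervals

# Fraction of draws in the first column that fall inside its interval (about 0.9)
mean(z_pred[, 1] >= intervals[1, "lb"] & z_pred[, 1] <= intervals[1, "ub"])
```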

Adding a lag

First, let’s try adding a lag to the model. The logistic growth differential equation we just fitted is simple enough to solve analytically as \[y(t) = \frac{A}{1 + B \exp(-C t)}\] for some constants \(A,B,C\), where … We can add an extra parameter to this model to describe the lag phase, using the Richards function: \[y(t) = \frac{A}{\left(1 + B \exp(-C (t-D))\right)^{1/B}}.\] In the Stan code below, \(A\), \(B\), \(C\) and \(D\) correspond to theta[1] to theta[4]. We attempt to fit this in the same way as before.
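A quick way to see how the Richards curve relates to the logistic one is to code both in base R. The sketch below uses the same form as the Stan code that follows. With \(B = 1\) it reduces to a logistic curve whose midpoint is at \(t = D\), and at \(t = D\) it equals \(A/(1+B)^{1/B}\). The parameter values are illustrative only.

```r
# Richards curve in the same form as the Stan code: A / (1 + B exp(-C (t - D)))^(1/B)
richards <- function(t, A, B, C, D) A / (1 + B * exp(-C * (t - D)))^(1 / B)

# Logistic curve with midpoint D, for comparison
logistic <- function(t, A, C, D) A / (1 + exp(-C * (t - D)))

ts <- seq(0, 24, by = 0.5)
all.equal(richards(ts, A = 0.4, B = 1, C = 1.5, D = 10),
          logistic(ts, A = 0.4, C = 1.5, D = 10))   # TRUE: B = 1 recovers the logistic

# Value at t = D for a different shape parameter
richards(10, A = 0.4, B = 5, C = 1.5, D = 10)        # equals 0.4 / 6^(1/5)
```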

data {
  int<lower=1> T;
  int<lower=1> n_wells;
  real y0[n_wells];
  real z[T,n_wells];
  real t0;
  real ts[T];
}
transformed data {
  real x_r[0];
  int x_i[1];
  x_i[1] = n_wells;
}
parameters {
  real<lower=0> theta[4];
  real<lower=0> sigma;
}
model {
  real y_hat[T,n_wells];
  theta ~ cauchy(0,2.5);
  sigma ~ normal(0,0.01);
  for (t in 1:T) {
    for (i in 1:n_wells) {
      y_hat[t,i] = theta[1] / (1 + (theta[2]) * exp(-theta[3] * (ts[t]-theta[4])))^(1/theta[2]);
    }
  }
  for (t in 1:T) {
    for (i in 1:n_wells) {
      z[t,i] ~ normal(y_hat[t,i], sigma);
    }
  }
}
generated quantities{
  real y_pred[T,n_wells];
  real z_pred[T,n_wells];
  for (t in 1:T) {
    for (i in 1:n_wells) {
      y_pred[t,i] = theta[1] / (1 + (theta[2]) * exp(-theta[3] * (ts[t]-theta[4])))^(1/theta[2]);
    }
  }
  for (t in 1:T) {
    for(i in 1:n_wells){
      z_pred[t,i] = y_pred[t,i] + normal_rng(0,sigma);
    }
  }
}
## Inference for Stan model: a29173cb61a497bb771f9bdb3b5325b0.
## 4 chains, each with iter=1000; warmup=500; thin=1; 
## post-warmup draws per chain=500, total post-warmup draws=2000.
## 
##             mean se_mean   sd    2.5%     25%     50%     75%   97.5%
## theta[1]    0.41    0.00 0.00    0.41    0.41    0.41    0.41    0.41
## theta[2]    5.08    0.04 0.92    3.61    4.42    4.93    5.60    7.22
## theta[3]    1.54    0.01 0.24    1.18    1.38    1.51    1.68    2.12
## theta[4]   10.42    0.01 0.16   10.12   10.31   10.41   10.52   10.75
## sigma       0.05    0.00 0.00    0.05    0.05    0.05    0.05    0.05
## lp__     3225.34    0.06 1.64 3221.20 3224.49 3225.70 3226.56 3227.45
##          n_eff Rhat
## theta[1]  1481 1.00
## theta[2]   455 1.01
## theta[3]   473 1.01
## theta[4]   511 1.00
## sigma     1034 1.00
## lp__       698 1.00
## 
## Samples were drawn using NUTS(diag_e) at Sun Mar 25 13:25:24 2018.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

The extra parameter seems to give a better fit to the shape of the growth curve, capturing the lag better: the estimated measurement error sigma falls from about 0.07 to 0.05. However, the model still does not capture the heterogeneity between wells. Let’s try a hierarchical version of the same model.

Hierarchical models

In a hierarchical model we assume that the process generating the data is the same in each well, but the parameter values may differ between wells. The well-level parameters are drawn from a common population distribution, whose own parameters are also estimated. This offers a powerful way to describe heterogeneity within a population. We can fit the model as before, although it is more challenging since there are many more parameters to estimate.
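To see what "drawn from a common distribution" means in practice, the base-R sketch below simulates well-level Richards parameters from population-level normal distributions. This mirrors `theta[i,1:4] ~ normal(mu_th, tau)` in the Stan model below, and the sketch then builds one curve per well. The population values `mu_th` and `tau` are made-up illustrative numbers. Negative draws are simply rejected to respect the `lower=0` constraint, which is a simplification.

```r
set.seed(3)
n_wells <- 9
mu_th <- c(0.4, 7, 2.4, 10.5)   # hypothetical population means for A, B, C, D
tau   <- c(0.05, 1, 0.3, 1)     # hypothetical population standard deviations

# Draw one positive parameter vector for a well from the population distribution
draw_well <- function(mu, tau) {
  repeat {
    th <- rnorm(length(mu), mean = mu, sd = tau)
    if (all(th > 0)) return(th)   # reject draws that violate the positivity constraint
  }
}
theta <- t(replicate(n_wells, draw_well(mu_th, tau)))   # n_wells x 4 matrix

# Richards curve for a single well, same form as in the Stan code
richards_well <- function(t, th) th[1] / (1 + th[2] * exp(-th[3] * (t - th[4])))^(1 / th[2])
ts <- seq(1/6, 24, by = 1/6)
y <- sapply(seq_len(n_wells), function(i) richards_well(ts, theta[i, ]))   # T x n_wells
```

Fitting the hierarchical model runs this logic in reverse. From the observed curves it learns both the well-level `theta` and the population-level `mu_th` and `tau` that describe how much wells vary.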

data {
  int<lower=1> T;
  int<lower=1> n_wells;
  real y0[n_wells];
  real z[T,n_wells];
  real t0;
  real ts[T];
}
parameters {
  real<lower=0> mu_th[4];
  real<lower=0> tau[4];
  real<lower=0> mu_sig;
  real<lower=0> xi;
  real<lower=0> theta[n_wells,4];
  real<lower=0> sigma[n_wells];
}
model {
  real y_hat[T,n_wells];
  mu_th ~ cauchy(0,2.5);
  mu_sig ~ normal(0,0.01); 
  tau ~ cauchy(0,2.5);
  xi ~ cauchy(0,2.5);
  for (i in 1:n_wells){
    theta[i,1:4] ~ normal(mu_th,tau);
    sigma[i] ~ normal(mu_sig,xi);
  }
  for (t in 1:T) {
    for (i in 1:n_wells) {
      y_hat[t,i] = theta[i,1] / (1 + (theta[i,2]) * exp(-theta[i,3] * (ts[t]-theta[i,4])))^(1/theta[i,2]);
    }
  }
  for (t in 1:T) {
    for (i in 1:n_wells) {
      z[t,i] ~ normal(y_hat[t,i], sigma[i]);
    }
  }
}
generated quantities{
  real y_pred[T,n_wells];
  real z_pred[T,n_wells];
  for (t in 1:T) {
    for (i in 1:n_wells) {
      y_pred[t,i] = theta[i,1] / (1 + (theta[i,2]) * exp(-theta[i,3] * (ts[t]-theta[i,4])))^(1/theta[i,2]);
    }
  }
  for (t in 1:T) {
    for(i in 1:n_wells){
      z_pred[t,i] = y_pred[t,i] + normal_rng(0,sigma[i]);
    }
  }
}
## Warning: There were 11 divergent transitions after warmup. Increasing adapt_delta above 0.99 may help. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## Warning: Examine the pairs() plot to diagnose sampling problems
## Inference for Stan model: 5715aa8b52d56fa156827b90c6988e8b.
## 4 chains, each with iter=1000; warmup=500; thin=1; 
## post-warmup draws per chain=500, total post-warmup draws=2000.
## 
##               mean se_mean   sd    2.5%     25%     50%     75%   97.5%
## theta[1,1]    0.38    0.00 0.00    0.37    0.38    0.38    0.38    0.38
## theta[1,2]    7.29    0.04 0.73    6.05    6.80    7.23    7.71    8.92
## theta[1,3]    2.45    0.01 0.23    2.06    2.30    2.43    2.58    2.99
## theta[1,4]    9.80    0.00 0.09    9.64    9.74    9.80    9.86    9.97
## theta[2,1]    0.45    0.00 0.00    0.44    0.44    0.45    0.45    0.45
## theta[2,2]    7.18    0.06 0.71    5.97    6.71    7.12    7.60    8.74
## theta[2,3]    2.41    0.01 0.22    2.02    2.25    2.39    2.53    2.90
## theta[2,4]   11.04    0.00 0.08   10.88   10.98   11.03   11.09   11.20
## theta[3,1]    0.41    0.00 0.00    0.41    0.41    0.41    0.41    0.42
## theta[3,2]    7.38    0.05 0.75    6.06    6.88    7.32    7.85    8.93
## theta[3,3]    2.31    0.01 0.22    1.91    2.17    2.30    2.45    2.77
## theta[3,4]   11.23    0.00 0.09   11.06   11.16   11.23   11.29   11.40
## theta[4,1]    0.42    0.00 0.00    0.42    0.42    0.42    0.42    0.43
## theta[4,2]    7.26    0.04 0.73    6.00    6.76    7.19    7.70    8.83
## theta[4,3]    2.41    0.01 0.22    2.02    2.26    2.39    2.54    2.91
## theta[4,4]   10.73    0.00 0.08   10.56   10.68   10.73   10.78   10.90
## theta[5,1]    0.41    0.00 0.00    0.41    0.41    0.41    0.42    0.42
## theta[5,2]    6.93    0.04 0.70    5.69    6.45    6.87    7.36    8.50
## theta[5,3]    2.59    0.02 0.24    2.16    2.42    2.57    2.73    3.14
## theta[5,4]    9.38    0.00 0.08    9.23    9.32    9.38    9.43    9.53
## theta[6,1]    0.30    0.00 0.00    0.30    0.30    0.30    0.30    0.31
## theta[6,2]    8.37    0.11 1.31    6.37    7.34    8.16    9.27   11.16
## theta[6,3]    2.00    0.02 0.30    1.55    1.76    1.96    2.21    2.62
## theta[6,4]   11.78    0.01 0.15   11.50   11.68   11.78   11.88   12.06
## theta[7,1]    0.38    0.00 0.00    0.38    0.38    0.38    0.38    0.38
## theta[7,2]    7.08    0.04 0.71    5.85    6.59    7.02    7.49    8.65
## theta[7,3]    2.57    0.01 0.24    2.16    2.40    2.55    2.70    3.09
## theta[7,4]    9.17    0.00 0.08    9.02    9.11    9.17    9.23    9.33
## theta[8,1]    0.46    0.00 0.00    0.45    0.46    0.46    0.46    0.46
## theta[8,2]    6.72    0.05 0.71    5.35    6.24    6.70    7.19    8.21
## theta[8,3]    2.64    0.02 0.26    2.16    2.46    2.63    2.80    3.19
## theta[8,4]    9.40    0.00 0.08    9.25    9.35    9.40    9.45    9.54
## theta[9,1]    0.49    0.00 0.00    0.48    0.49    0.49    0.49    0.49
## theta[9,2]    7.21    0.05 0.74    5.88    6.68    7.18    7.68    8.77
## theta[9,3]    2.27    0.01 0.22    1.88    2.12    2.26    2.41    2.72
## theta[9,4]   12.58    0.00 0.09   12.40   12.52   12.58   12.64   12.76
## sigma[1]      0.02    0.00 0.00    0.01    0.01    0.02    0.02    0.02
## sigma[2]      0.02    0.00 0.00    0.02    0.02    0.02    0.02    0.02
## sigma[3]      0.02    0.00 0.00    0.02    0.02    0.02    0.02    0.02
## sigma[4]      0.02    0.00 0.00    0.01    0.02    0.02    0.02    0.02
## sigma[5]      0.02    0.00 0.00    0.01    0.02    0.02    0.02    0.02
## sigma[6]      0.02    0.00 0.00    0.01    0.01    0.02    0.02    0.02
## sigma[7]      0.02    0.00 0.00    0.01    0.01    0.02    0.02    0.02
## sigma[8]      0.02    0.00 0.00    0.01    0.02    0.02    0.02    0.02
## sigma[9]      0.02    0.00 0.00    0.02    0.02    0.02    0.02    0.02
## lp__       4771.24    0.55 8.32 4755.63 4765.76 4770.93 4776.34 4790.33
##            n_eff Rhat
## theta[1,1]  2000 1.00
## theta[1,2]   283 1.01
## theta[1,3]   303 1.01
## theta[1,4]   590 1.00
## theta[2,1]  2000 1.00
## theta[2,2]   144 1.02
## theta[2,3]   251 1.02
## theta[2,4]   574 1.01
## theta[3,1]  2000 1.00
## theta[3,2]   230 1.03
## theta[3,3]   283 1.02
## theta[3,4]   502 1.01
## theta[4,1]  2000 1.00
## theta[4,2]   276 1.02
## theta[4,3]   281 1.02
## theta[4,4]   535 1.01
## theta[5,1]  2000 1.00
## theta[5,2]   242 1.02
## theta[5,3]   246 1.02
## theta[5,4]   526 1.01
## theta[6,1]  2000 1.00
## theta[6,2]   152 1.02
## theta[6,3]   149 1.02
## theta[6,4]   280 1.01
## theta[7,1]  2000 1.00
## theta[7,2]   287 1.02
## theta[7,3]   286 1.02
## theta[7,4]   648 1.01
## theta[8,1]  2000 1.00
## theta[8,2]   216 1.01
## theta[8,3]   216 1.01
## theta[8,4]   424 1.01
## theta[9,1]  2000 1.00
## theta[9,2]   245 1.01
## theta[9,3]   263 1.01
## theta[9,4]   479 1.01
## sigma[1]     917 1.00
## sigma[2]     754 1.01
## sigma[3]    1440 1.00
## sigma[4]    2000 1.00
## sigma[5]    1316 1.00
## sigma[6]     845 1.00
## sigma[7]    1151 1.00
## sigma[8]    1580 1.00
## sigma[9]     347 1.01
## lp__         229 1.01
## 
## Samples were drawn using NUTS(diag_e) at Sun Mar 25 13:28:40 2018.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

Let’s check the posterior predictive distribution to see whether predictions made from this model look similar to our original data.

We have learnt a separate set of parameters for each well, giving a better description of the data from each well and capturing the variability in the population. We could now make predictions about behaviour in a new experiment and quantitatively compare different populations of bacteria.
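One way to make such a prediction, sketched below under stated assumptions, is to draw a fresh parameter vector for a hypothetical new well from the population distribution, once per posterior draw of the hyperparameters `mu_th` and `tau`. In practice those draws would come from the fitted hierarchical model, for example via rstan's `as.matrix()` with `pars = c("mu_th", "tau")`. Here they are replaced by simulated stand-in values so that the example runs on its own.

```r
set.seed(11)
n_draws <- 1000
# Stand-in posterior draws of the hyperparameters (hypothetical values, not from the fit)
mu_draws  <- cbind(rnorm(n_draws, 0.41, 0.01), rnorm(n_draws, 7.2, 0.3),
                   rnorm(n_draws, 2.4, 0.1),  rnorm(n_draws, 10.6, 0.4))
tau_draws <- cbind(abs(rnorm(n_draws, 0.05, 0.01)), abs(rnorm(n_draws, 0.5, 0.1)),
                   abs(rnorm(n_draws, 0.2, 0.05)),  abs(rnorm(n_draws, 1, 0.2)))

# For each posterior draw, simulate parameters for one new well (rejecting non-positive values)
new_well <- t(sapply(seq_len(n_draws), function(s) {
  repeat {
    th <- rnorm(4, mean = mu_draws[s, ], sd = tau_draws[s, ])
    if (all(th > 0)) return(th)
  }
}))
colnames(new_well) <- c("A", "B", "C", "D")

# Predictive 5%, 50% and 95% quantiles for each parameter of the new well
apply(new_well, 2, quantile, probs = c(0.05, 0.5, 0.95))
```

Passing each row of `new_well` through the Richards function would then give a predictive band of growth curves for an unseen well. That band reflects both parameter uncertainty and between-well variability.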

Notice that Stan warned us about 11 divergent transitions after warmup. Divergences indicate that the sampler had difficulty exploring parts of the posterior distribution. As a result, we should be cautious in interpreting these results and avoid relying on the full posterior distributions, as the tails in particular may be poorly sampled. Reparameterising the model may help by altering the geometry of the posterior distribution. If you have suggestions of good ways to do so for this case, then get in touch!
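One reparameterisation commonly tried for hierarchical models with divergences is the non-centred form; it has not been tested on this model here. Instead of sampling `theta[i]` directly from `normal(mu_th, tau)`, we sample a standard normal `eta[i]` and set `theta[i] = mu_th + tau * eta[i]`. The base-R check below only illustrates that the two forms describe the same distribution. Whether the change removes the divergences here is untested, and the `lower=0` constraint on `theta` complicates the transform. The values of `mu` and `tau` are arbitrary.

```r
set.seed(5)
mu <- 7.2; tau <- 0.7; n <- 1e5   # arbitrary illustrative values

centred     <- rnorm(n, mean = mu, sd = tau)   # theta ~ normal(mu, tau)
eta         <- rnorm(n)                         # eta ~ normal(0, 1)
non_centred <- mu + tau * eta                   # theta = mu + tau * eta

c(mean(centred), mean(non_centred))   # both close to mu
c(sd(centred), sd(non_centred))       # both close to tau
```

The two forms define the same distribution, but the sampler explores a different space. With the non-centred form it works with `eta`, whose scale does not depend on `tau`. That can remove the funnel-shaped geometry that often causes divergences in hierarchical models.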