## Data

Eight Schools is a study of coaching effects in eight schools. It comes from Section 5.5 of Gelman et al. (2003) and is described in Section 2.1 (Schools data) of ‘R2WinBUGS: A Package for Running WinBUGS from R’:

The Scholastic Aptitude Test (SAT) measures the aptitude of high-schoolers in order to help colleges to make admissions decisions. It is divided into two parts, verbal (SAT-V) and mathematical (SAT-M). Our data comes from the SAT-V (Scholastic Aptitude Test-Verbal) on eight different high schools, from an experiment conducted in the late 1970s. SAT-V is a standard multiple choice test administered by the Educational Testing Service. This Service was interested in the effects of coaching programs for each of the selected schools. The study included coached and uncoached pupils, about sixty in each of the eight different schools; see Rubin (1981). All of them had already taken the PSAT (Preliminary SAT) which results were used as covariates. For each school, the estimated treatment effect and the standard error of the effect estimate are given. These are calculated by an analysis of covariance adjustment appropriate for a completely randomized experiment (Rubin 1981). This example was analysed using a hierarchical normal model in Rubin (1981) and Gelman, Carlin, Stern, and Rubin (2003, Section 5.5).

The corresponding TensorFlow Probability Jupyter notebook can be found here.

```r
library(greta)
library(tidyverse)
library(bayesplot)
color_scheme_set("purple")

# data
N <- letters[1:8]
treatment_effects <- c(28.39, 7.94, -2.75, 6.82, -0.64, 0.63, 18.01, 12.16)
treatment_stddevs <- c(14.9, 10.2, 16.3, 11.0, 9.4, 11.4, 10.4, 17.6)

schools <- data.frame(N = N,
                      treatment_effects = treatment_effects,
                      treatment_stddevs = treatment_stddevs) %>%
  mutate(treatment_effects_p_stddevs = treatment_effects + treatment_stddevs,
         treatment_effects_m_stddevs = treatment_effects - treatment_stddevs)
```

For each of the eight schools N, we have the estimated treatment effect (treatment_effects) and its standard error (treatment_stddevs). Below, we replicate the bar plot from the TensorFlow Probability example, which shows the estimated treatment effect +/- one standard error for each school:

```r
ggplot(schools, aes(x = N, y = treatment_effects)) +
  geom_col(fill = "purple", alpha = 0.5) +
  geom_errorbar(aes(ymin = treatment_effects_m_stddevs,
                    ymax = treatment_effects_p_stddevs),
                width = 0.3) +
  labs(x = "school", y = "treatment effect",
       title = "Barplot of treatment effects for eight schools",
       subtitle = "Error bars represent standard error")
```

Another way to visualise the estimated effects and their standard errors is to plot the densities of the treatment effects, and of the effects plus and minus one standard error, across the eight schools:

```r
schools %>%
  pivot_longer(c(treatment_effects, treatment_effects_p_stddevs, treatment_effects_m_stddevs),
               names_to = "x", values_to = "y") %>%
  ggplot(aes(x = y, color = x)) +
  geom_density(fill = "purple", alpha = 0.5) +
  scale_color_brewer(palette = "Set1") +
  labs(x = "treatment effect (+/- standard error)",
       color = "density curve of",
       title = "Density plot of treatment effects +/- standard error for eight schools")
```

## Modelling with greta

To model the data, we use the same hierarchical normal model as in the TensorFlow Probability example.

### Variables and priors

First, we create greta arrays that represent the variables and prior distributions in our model, and from them a greta array for the school effects. We define the following (random) variables and priors:

• avg_effect: a normal distribution (normal()) with mean 0 and standard deviation 10; the prior for the average treatment effect across schools.
```r
avg_effect <- normal(mean = 0, sd = 10)
avg_effect
## greta array (variable following a normal distribution)
##
##      [,1]
## [1,]  ?
```
• avg_stddev: a normal distribution with mean 5 and standard deviation 1. This is the prior on the log of the between-school standard deviation: exp(avg_stddev) controls how much the school effects vary around avg_effect.
```r
avg_stddev <- normal(5, 1)
avg_stddev
## greta array (variable following a normal distribution)
##
##      [,1]
## [1,]  ?
```
• school_effects_standard: a vector of 8 independent standard normal variables (mean 0, standard deviation 1), one per school.
```r
school_effects_standard <- normal(0, 1, dim = length(N))
school_effects_standard
## greta array (variable following a normal distribution)
##
##      [,1]
## [1,]  ?
## [2,]  ?
## [3,]  ?
## [4,]  ?
## [5,]  ?
## [6,]  ?
## [7,]  ?
## [8,]  ?
```
• school_effects: the school-specific effects, obtained by multiplying the exponential of avg_stddev by school_effects_standard and adding avg_effect.
```r
school_effects <- avg_effect + exp(avg_stddev) * school_effects_standard
school_effects
## greta array (operation)
##
##      [,1]
## [1,]  ?
## [2,]  ?
## [3,]  ?
## [4,]  ?
## [5,]  ?
## [6,]  ?
## [7,]  ?
## [8,]  ?
```
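The construction of school_effects above is a non-centred parameterisation: if z ~ N(0, 1), then mu + tau * z ~ N(mu, tau). The base-R simulation below (no greta involved) checks this numerically for one fixed pair of values; the numbers 5 and 2.5 are arbitrary stand-ins for a single draw of avg_effect and avg_stddev, not estimates.

```r
# Illustrative check of the non-centred parameterisation used for school_effects.
# The fixed values below are arbitrary assumptions chosen for the demo.
set.seed(42)
avg_effect_fixed <- 5     # stand-in for one draw of avg_effect
log_stddev_fixed <- 2.5   # stand-in for one draw of avg_stddev (log scale)

n_draws <- 1e5
z <- rnorm(n_draws, mean = 0, sd = 1)   # plays the role of school_effects_standard

# Non-centred: shift and scale standard normal draws
theta_noncentred <- avg_effect_fixed + exp(log_stddev_fixed) * z

# Centred: draw directly from N(avg_effect, exp(avg_stddev))
theta_centred <- rnorm(n_draws, mean = avg_effect_fixed, sd = exp(log_stddev_fixed))

# Both should have mean close to 5 and sd close to exp(2.5), about 12.18
round(c(mean(theta_noncentred), sd(theta_noncentred)), 2)
round(c(mean(theta_centred), sd(theta_centred)), 2)
```

Both forms describe the same distribution; the non-centred one is usually preferred for HMC in hierarchical models because the sampled variables (z) are a priori independent of the scale parameter.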

An alternative is to give avg_stddev a lognormal() prior directly. Since lognormal(5, 1) is the distribution of the exponential of a normal(5, 1) variable, the exp() is then no longer needed when calculating school_effects:

```r
avg_stddev <- lognormal(5, 1)
school_effects <- avg_effect + avg_stddev * school_effects_standard
```
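The equivalence of the two formulations follows from the change-of-variables formula: if X ~ N(5, 1) and Y = exp(X), the density of Y is dnorm(log(y), 5, 1) / y, which is exactly the lognormal density. A quick base-R check on an arbitrary grid of positive values:

```r
# Change of variables: f_Y(y) = f_X(log(y)) * |d log(y) / dy| = dnorm(log(y), 5, 1) / y
y_grid <- c(1, 10, 50, 148.4, 500, 2000)   # arbitrary positive values

dens_lognormal  <- dlnorm(y_grid, meanlog = 5, sdlog = 1)
dens_via_normal <- dnorm(log(y_grid), mean = 5, sd = 1) / y_grid

all.equal(dens_lognormal, dens_via_normal)
## [1] TRUE
```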

### Likelihood

Next, we link the variables and priors to the observed data, in this case the estimated school effects treatment_effects. Using the distribution() function, we state that each observed estimate is drawn from a normal distribution with mean school_effects and (known) standard deviation treatment_stddevs. This defines the likelihood of the observed data given the model parameters:

```r
distribution(treatment_effects) <- normal(school_effects, treatment_stddevs)
```
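Written out, the likelihood is a product of independent normal densities, one per school, with the standard errors treated as known. A base-R sketch of its logarithm, using a hypothetical helper log_lik_schools() (not a greta function) evaluated at an arbitrary illustrative vector of school effects:

```r
# Observed data (as defined earlier in this example)
treatment_effects <- c(28.39, 7.94, -2.75, 6.82, -0.64, 0.63, 18.01, 12.16)
treatment_stddevs <- c(14.9, 10.2, 16.3, 11.0, 9.4, 11.4, 10.4, 17.6)

# Hypothetical helper: log-likelihood of the observed estimates given
# school-level effects theta, with the standard errors treated as known.
log_lik_schools <- function(theta, y = treatment_effects, sigma = treatment_stddevs) {
  sum(dnorm(y, mean = theta, sd = sigma, log = TRUE))
}

# Arbitrary illustrative value: every school shares an effect of 8
log_lik_schools(rep(8, 8))
```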

### Bayesian inference model

Now we have all the prerequisites for building the model and using Hamiltonian Monte Carlo (HMC) to sample from the posterior distribution of the model's parameters.

We first define the model by passing the avg_effect, avg_stddev and school_effects_standard variables to model(), so that these are the variables sampled during inference. The resulting model m combines the prior distributions with the likelihood defined above, and thus represents the joint density of the model.

It is recommended that you check your model at this step by plotting the model graph. More information about these plots can be found here.

```r
# defining the hierarchical model
m <- model(avg_effect, avg_stddev, school_effects_standard)
m
## greta model
plot(m)
```
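The joint density that m represents can be written out by hand. The sketch below uses a hypothetical helper, log_joint() (not a greta function), that adds the log prior densities and the log-likelihood. It follows the lognormal() variant of avg_stddev, because that is the definition in scope when m is built above; greta's internal handling of the positivity constraint on avg_stddev during HMC is not reproduced here.

```r
treatment_effects <- c(28.39, 7.94, -2.75, 6.82, -0.64, 0.63, 18.01, 12.16)
treatment_stddevs <- c(14.9, 10.2, 16.3, 11.0, 9.4, 11.4, 10.4, 17.6)

# Hypothetical helper: unnormalised log joint density of the eight schools
# model (priors + likelihood), using the lognormal() variant of avg_stddev.
log_joint <- function(avg_effect, avg_stddev, school_effects_standard) {
  if (avg_stddev <= 0) return(-Inf)   # lognormal support is (0, Inf)

  # Log prior densities
  lp <- dnorm(avg_effect, mean = 0, sd = 10, log = TRUE) +
    dlnorm(avg_stddev, meanlog = 5, sdlog = 1, log = TRUE) +
    sum(dnorm(school_effects_standard, mean = 0, sd = 1, log = TRUE))

  # Deterministic transformation to school effects
  school_effects <- avg_effect + avg_stddev * school_effects_standard

  # Add the log-likelihood of the observed estimates
  lp + sum(dnorm(treatment_effects, mean = school_effects,
                 sd = treatment_stddevs, log = TRUE))
}

# Arbitrary illustrative parameter values
log_joint(avg_effect = 5, avg_stddev = 10, school_effects_standard = rep(0, 8))
```

HMC only needs this function (and its gradient) up to an additive constant, which is why the normalising constant of the posterior never has to be computed.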

The actual sampling happens with the mcmc() function. By default, 1000 MCMC samples are drawn after warm-up; below we state the number of samples, warm-up iterations and chains explicitly. The result is a coda mcmc.list of posterior draws, one element per chain, which approximates the joint posterior distribution of the model variables.

```r
# sampling
draws <- greta::mcmc(m, n_samples = 1000, warmup = 1000, chains = 4)
summary(draws)
##
## Iterations = 1:1000
## Thinning interval = 1
## Number of chains = 4
## Sample size per chain = 1000
##
## 1. Empirical mean and standard deviation for each variable,
##    plus standard error of the mean:
##
##                                  Mean     SD Naive SE Time-series SE
## avg_effect                    5.66639 5.4058  0.08547        0.11406
## avg_stddev                   13.23355 7.0904  0.11211        0.24471
## school_effects_standard[1,1]  0.65643 0.7890  0.01248        0.01573
## school_effects_standard[2,1]  0.10173 0.7145  0.01130        0.01418
## school_effects_standard[3,1] -0.21187 0.8164  0.01291        0.01618
## school_effects_standard[4,1]  0.04093 0.7438  0.01176        0.01369
## school_effects_standard[5,1] -0.29814 0.6720  0.01063        0.01319
## school_effects_standard[6,1] -0.19047 0.7308  0.01155        0.01433
## school_effects_standard[7,1]  0.52154 0.6945  0.01098        0.01326
## school_effects_standard[8,1]  0.13518 0.8489  0.01342        0.01640
##
## 2. Quantiles for each variable:
##
##                                 2.5%      25%      50%     75%   97.5%
## avg_effect                   -5.3445  2.28764  5.77921  9.2812 15.6988
## avg_stddev                    4.1594  8.32819 11.71733 16.5100 33.1951
## school_effects_standard[1,1] -0.9430  0.15193  0.66124  1.1683  2.2007
## school_effects_standard[2,1] -1.3107 -0.34421  0.10414  0.5532  1.5196
## school_effects_standard[3,1] -1.8459 -0.73635 -0.20636  0.3352  1.3145
## school_effects_standard[4,1] -1.4767 -0.43877  0.05423  0.5166  1.5597
## school_effects_standard[5,1] -1.6740 -0.72760 -0.28705  0.1321  0.9997
## school_effects_standard[6,1] -1.6760 -0.65693 -0.17677  0.2787  1.2420
## school_effects_standard[7,1] -0.8987  0.08377  0.53535  0.9619  1.9026
## school_effects_standard[8,1] -1.5778 -0.42636  0.14659  0.6951  1.8485
```
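In this coda summary, the "Naive SE" column is the posterior standard deviation divided by the square root of the total number of draws, ignoring autocorrelation; coda's "Time-series SE" additionally accounts for autocorrelation within the chains, which is why it is larger. Reproducing the naive SE of avg_effect from the numbers reported above (4 chains of 1000 draws):

```r
# Numbers copied from the summary(draws) output above
sd_avg_effect <- 5.4058       # posterior SD of avg_effect
n_total_draws <- 4 * 1000     # 4 chains x 1000 post-warm-up samples

# Naive SE = SD / sqrt(total number of draws), ignoring autocorrelation
naive_se <- sd_avg_effect / sqrt(n_total_draws)
round(naive_se, 5)
## [1] 0.08547
```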
```r
mcmc_trace(draws, facet_args = list(ncol = 3))

mcmc_intervals(draws)

mcmc_acf_bar(draws)

mcmc_hist(draws, facet_args = list(ncol = 3))
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
```

#### Use of calculate() for transforming estimates to natural scale

The calculate() function can be combined with the same transformation used to build the model to obtain posterior chains for the school-specific effects on their natural scale. The same function is also how you would obtain posterior predictive values.

```r
# Calculate school effects on the original scale
# (avg_stddev is the lognormal() version here, so no exp() is needed)
school_effects <- avg_effect + avg_stddev * school_effects_standard
posterior_school_effects <- calculate(school_effects, values = draws)
```
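Conceptually, calculate() evaluates the deterministic expression for school_effects once per posterior draw. The base-R sketch below mimics that on a matrix of draws. Because the real draws are not reproduced here, it uses made-up stand-in draws (their parameters loosely echo the summary above but they are not posterior output), so the numbers are illustrative only. It also sketches posterior predictive replicates by simulating one set of eight estimates per draw. The layout (rows = draws, columns = schools) is an assumption for illustration, not greta's internal format.

```r
set.seed(1)
n_draws   <- 200   # hypothetical number of posterior draws
n_schools <- 8
treatment_stddevs <- c(14.9, 10.2, 16.3, 11.0, 9.4, 11.4, 10.4, 17.6)

# Stand-in "posterior" draws (NOT real output from mcmc())
avg_effect_draws <- rnorm(n_draws, mean = 5.7, sd = 5.4)
avg_stddev_draws <- rlnorm(n_draws, meanlog = log(11.7), sdlog = 0.5)
z_draws <- matrix(rnorm(n_draws * n_schools), nrow = n_draws)

# Apply the transformation draw by draw: row i uses draw i of every parameter.
# The length-n_draws vectors are recycled down each column of z_draws.
school_effects_draws <- avg_effect_draws + avg_stddev_draws * z_draws
colnames(school_effects_draws) <- letters[1:n_schools]

# Posterior predictive replicates: one simulated estimate per school per draw,
# using each school's known standard error
y_rep <- matrix(rnorm(n_draws * n_schools,
                      mean = school_effects_draws,
                      sd = rep(treatment_stddevs, each = n_draws)),
                nrow = n_draws)

colMeans(school_effects_draws)
```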

#### Comparison with Edward2 HMC

As a sanity check that we parameterized our model correctly, we can compare the back-transformed school-specific estimates to the results from the Edward2 approach in the TensorFlow Probability documentation. The results are very similar.

```r
# Posterior means via Edward2
edward2_school_means <-
  data.frame(tool = "Edward2",
             school = N,
             # mean_school_effects_standard = c(0.61157268, 0.06430732, -0.25459746,
             #                                  0.04828103, -0.36940941, -0.23154463,
             #                                  0.49402338,  0.13042814),
             mean = c(14.93237686, 7.50939941, 3.07602358, 7.21652555,
                      2.0329783, 3.41213799, 12.92509365, 8.36702347),
             sd = 0)

edward2_pop_mean <- data.frame(tool = "Edward2", mean = 6.48866844177, sd = 0)
# hmc_mean_avg_stddev <- 2.46163249016
```

```r
posterior_school_effects <- as.data.frame(as.matrix(posterior_school_effects))

# Relabel school measures
colnames(posterior_school_effects) <- N

# Summarise and combine all chains of interest for plotting
posterior_summaries <-
  posterior_school_effects %>%
  pivot_longer(everything(), names_to = "school", values_to = "value") %>%
  group_by(school) %>%
  summarise(mean = mean(value), sd = sd(value))

school_summaries <-
  posterior_summaries %>%
  mutate(tool = "greta") %>%
  bind_rows(edward2_school_means)

population_parameters <-
  as.data.frame(as.matrix(draws)) %>%
  summarise(mean = mean(avg_effect), sd = sd(avg_effect)) %>%
  mutate(tool = "greta") %>%
  bind_rows(edward2_pop_mean)
```

```r
ggplot(school_summaries, aes(x = school, y = mean, color = tool, shape = tool)) +
  geom_errorbar(aes(ymin = mean - sd, ymax = mean + sd), width = 0.2) +
  geom_point() +
  geom_hline(data = population_parameters,
             aes(yintercept = mean, linetype = "Population mean", color = tool)) +
  scale_linetype_manual(name = "", values = c(2, 2))
```

## Session information

```r
sessionInfo()
## R version 4.1.2 (2021-11-01)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
##
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_AU.UTF-8/en_AU.UTF-8/en_AU.UTF-8/C/en_AU.UTF-8/en_AU.UTF-8
##
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base
##
## other attached packages:
##  [1] bayesplot_1.8.1 forcats_0.5.1   stringr_1.4.0   dplyr_1.0.7
##  [5] purrr_0.3.4     readr_2.1.1     tidyr_1.1.4     tibble_3.1.6
##  [9] ggplot2_3.3.5   tidyverse_1.3.1 greta_0.4.0
##
## loaded via a namespace (and not attached):
##  [1] fs_1.5.1           lubridate_1.8.0    progress_1.2.2     httr_1.4.2
##  [5] rprojroot_2.0.2    tools_4.1.2        backports_1.4.0    bslib_0.3.1
##  [9] utf8_1.2.2         R6_2.5.1           DBI_1.1.1          colorspace_2.0-2
## [13] withr_2.4.3        tidyselect_1.1.1   prettyunits_1.1.1  processx_3.5.2
## [17] compiler_4.1.2     cli_3.1.0          rvest_1.0.2        xml2_1.3.3
## [21] desc_1.4.0         labeling_0.4.2     sass_0.4.0         scales_1.1.1
## [25] ggridges_0.5.3     callr_3.7.0        tfruns_1.5.0       pkgdown_1.1.0.9000
## [29] digest_0.6.29      rmarkdown_2.11     base64enc_0.1-3    pkgconfig_2.0.3
## [33] htmltools_0.5.2    parallelly_1.29.0  highr_0.9          dbplyr_2.1.1
## [37] fastmap_1.1.0      rlang_0.4.12       readxl_1.3.1       rstudioapi_0.13
## [41] farver_2.1.0       jquerylib_0.1.4    generics_0.1.1     jsonlite_1.7.2
## [45] tensorflow_2.7.0   magrittr_2.0.1     Matrix_1.3-4       Rcpp_1.0.7
## [49] munsell_0.5.0      fansi_0.5.0        abind_1.4-5        reticulate_1.22
## [53] lifecycle_1.0.1    stringi_1.7.6      whisker_0.4        yaml_2.2.1
## [57] MASS_7.3-54        plyr_1.8.6         grid_4.1.2         parallel_4.1.2
## [61] listenv_0.8.0      crayon_1.4.2       lattice_0.20-45    haven_2.4.3
## [65] hms_1.1.1          knitr_1.36         ps_1.6.0           pillar_1.6.4
## [69] codetools_0.2-18   reprex_2.0.1       glue_1.5.1         evaluate_0.14
## [73] modelr_0.1.8       png_0.1-7          vctrs_0.3.8        tzdb_0.2.0
## [77] cellranger_1.1.0   gtable_0.3.0       future_1.23.0      assertthat_0.2.1
## [81] cachem_1.0.6       xfun_0.28          broom_0.7.10       coda_0.19-4
## [85] roxygen2_7.1.2     memoise_2.0.1      globals_0.14.0     ellipsis_0.3.2
## [89] here_1.0.1
```