Eight Schools is a study of coaching effects in eight schools; it comes from Section 5.5 of Gelman et al. (2003) and is described in Section 2.1 (Schools data) of ‘R2WinBUGS: A Package for Running WinBUGS from R’:
The Scholastic Aptitude Test (SAT) measures the aptitude of high-schoolers in order to help colleges to make admissions decisions. It is divided into two parts, verbal (SAT-V) and mathematical (SAT-M). Our data comes from the SAT-V (Scholastic Aptitude Test-Verbal) on eight different high schools, from an experiment conducted in the late 1970s. SAT-V is a standard multiple choice test administered by the Educational Testing Service. This Service was interested in the effects of coaching programs for each of the selected schools. The study included coached and uncoached pupils, about sixty in each of the eight different schools; see Rubin (1981). All of them had already taken the PSAT (Preliminary SAT) which results were used as covariates. For each school, the estimated treatment effect and the standard error of the effect estimate are given. These are calculated by an analysis of covariance adjustment appropriate for a completely randomized experiment (Rubin 1981). This example was analysed using a hierarchical normal model in Rubin (1981) and Gelman, Carlin, Stern, and Rubin (2003, Section 5.5).
The corresponding TensorFlow Probability Jupyter notebook can be found here.
library(greta)
library(tidyverse)
library(bayesplot)
color_scheme_set("purple")
# data
N <- letters[1:8]
treatment_effects <- c(28.39, 7.94, -2.75 , 6.82, -0.64, 0.63, 18.01, 12.16)
treatment_stddevs <- c(14.9, 10.2, 16.3, 11.0, 9.4, 11.4, 10.4, 17.6)
schools <- data.frame(N = N,
                      treatment_effects = treatment_effects,
                      treatment_stddevs = treatment_stddevs) %>%
  mutate(treatment_effects_p_stddevs = treatment_effects + treatment_stddevs,
         treatment_effects_m_stddevs = treatment_effects - treatment_stddevs)
For each of the eight schools N we have the estimated treatment effect (treatment_effects) and its standard error (treatment_stddevs). Below, we replicate the barplot from the TensorFlow Probability example that shows the estimated treatment effects +/- standard error per school:
ggplot(schools, aes(x = N, y = treatment_effects)) +
geom_bar(stat = "identity", fill = "purple", alpha = 0.5) +
geom_errorbar(aes(ymin = treatment_effects_m_stddevs, ymax = treatment_effects_p_stddevs), width = 0.3) +
labs(x = "school", y = "treatment effect",
title = "Barplot of treatment effects for eight schools",
subtitle = "Error bars represent standard error")
A different way to plot the estimated effects and their standard errors is to show their density distributions across the eight schools:
schools %>%
  gather(x, y, treatment_effects, treatment_effects_p_stddevs, treatment_effects_m_stddevs) %>%
ggplot(aes(x = y, color = x)) +
geom_density(fill = "purple", alpha = 0.5) +
scale_color_brewer(palette = "Set1") +
labs(x = "treatment effect (+/- standard error)",
color = "density curve of",
title = "Density plot of treatment effects +/- standard error for eight schools")
greta
To model the data, we use the same hierarchical normal model as in the TensorFlow Probability example.
First, we create greta arrays that represent the variables and prior distributions in our model, and then derive a greta array for the school effects from them. We define the following (random) variables and priors:
avg_effect: a normal distribution (greta’s normal()) with a mean of 0 and a standard deviation of 10; represents the prior average treatment effect.
avg_effect <- normal(mean = 0, sd = 10)
avg_effect
## greta array (variable following a normal distribution)
##
## [,1]
## [1,] ?
avg_stddev: a normal distribution (greta’s normal()) with a mean of 5 and a standard deviation of 1; controls the amount of variance between schools (on the log scale, since we exponentiate it below before using it as a standard deviation).
avg_stddev <- normal(5, 1)
avg_stddev
## greta array (variable following a normal distribution)
##
## [,1]
## [1,] ?
school_effects_standard: a standard normal distribution (mean 0, standard deviation 1) with dimension 8; one standardised effect per school.
school_effects_standard <- normal(0, 1, dim = length(N))
school_effects_standard
## greta array (variable following a normal distribution)
##
## [,1]
## [1,] ?
## [2,] ?
## [3,] ?
## [4,] ?
## [5,] ?
## [6,] ?
## [7,] ?
## [8,] ?
school_effects: here we multiply the exponential of avg_stddev with school_effects_standard and add avg_effect (see the note after the output below).
school_effects <- avg_effect + exp(avg_stddev) * school_effects_standard
school_effects
## greta array (operation)
##
## [,1]
## [1,] ?
## [2,] ?
## [3,] ?
## [4,] ?
## [5,] ?
## [6,] ?
## [7,] ?
## [8,] ?
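Sampling standardised effects and rescaling them like this is a non-centred parameterisation, which typically gives HMC an easier posterior geometry in hierarchical models than sampling each school’s effect directly. Purely for comparison, and not used anywhere below, a centred version might look like the following sketch (the variable name is ours):
# hypothetical centred parameterisation, for comparison only:
# each school's effect is drawn directly around the average effect,
# with exp(avg_stddev) as the between-school standard deviation
school_effects_centred <- normal(avg_effect, exp(avg_stddev), dim = length(N))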
An alternative would be to give avg_stddev a lognormal() distribution directly and use that to calculate school_effects. Because these lines overwrite avg_stddev and school_effects, it is this lognormal parameterisation that goes into the model below:
avg_stddev <- lognormal(5, 1)
school_effects <- avg_effect + avg_stddev * school_effects_standard
Next, we want to link the variables and priors to the observed data - in this case, the estimated treatment effects treatment_effects. With the distribution() function we state that the observed treatment_effects follow a normal distribution with mean school_effects and standard deviation treatment_stddevs; this defines the likelihood of our model:
distribution(treatment_effects) <- normal(school_effects, treatment_stddevs)
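As an optional check (not part of the original analysis), calculate() can also simulate from the priors before any sampling has happened, by passing an nsim argument instead of posterior draws; the object name below is ours:
# simulate 100 draws of the eight school effects implied by the priors alone;
# the result is a named list of arrays, one per target greta array
prior_sims <- calculate(school_effects, nsim = 100)
str(prior_sims)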
Now we have all the prerequisites for running Hamiltonian Monte Carlo (HMC) to estimate the posterior distribution over the model’s parameters.
We first define the model by combining the avg_effect, avg_stddev and school_effects_standard greta arrays, so that we can sample from them during modelling. The model m we define below contains all our prior distributions and thus represents the joint density of the model.
It is recommended that you check your model at this step by plotting the model graph. More information about these plots can be found here.
# defining the hierarchical model
m <- model(avg_effect, avg_stddev, school_effects_standard)
m
## greta model
plot(m)
The actual sampling from the model happens with the mcmc() function. By default, 1000 MCMC samples are drawn per chain after warm-up. What we obtain are draws from the posterior distribution of the model variables.
# sampling
draws <- greta::mcmc(m, n_samples = 1000, warmup = 1000, chains = 4)
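mcmc() uses a Hamiltonian Monte Carlo sampler by default; if the defaults struggle, its tuning parameters can be set explicitly via the sampler argument. A sketch with illustrative values, not used for the results below:
# illustrative only: the same model sampled with explicit bounds on the
# number of HMC leapfrog steps per iteration
draws_tuned <- greta::mcmc(m,
                           sampler = hmc(Lmin = 15, Lmax = 25),
                           n_samples = 1000, warmup = 1000, chains = 4)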
summary(draws)
##
## Iterations = 1:1000
## Thinning interval = 1
## Number of chains = 4
## Sample size per chain = 1000
##
## 1. Empirical mean and standard deviation for each variable,
## plus standard error of the mean:
##
## Mean SD Naive SE Time-series SE
## avg_effect 5.66639 5.4058 0.08547 0.11406
## avg_stddev 13.23355 7.0904 0.11211 0.24471
## school_effects_standard[1,1] 0.65643 0.7890 0.01248 0.01573
## school_effects_standard[2,1] 0.10173 0.7145 0.01130 0.01418
## school_effects_standard[3,1] -0.21187 0.8164 0.01291 0.01618
## school_effects_standard[4,1] 0.04093 0.7438 0.01176 0.01369
## school_effects_standard[5,1] -0.29814 0.6720 0.01063 0.01319
## school_effects_standard[6,1] -0.19047 0.7308 0.01155 0.01433
## school_effects_standard[7,1] 0.52154 0.6945 0.01098 0.01326
## school_effects_standard[8,1] 0.13518 0.8489 0.01342 0.01640
##
## 2. Quantiles for each variable:
##
## 2.5% 25% 50% 75% 97.5%
## avg_effect -5.3445 2.28764 5.77921 9.2812 15.6988
## avg_stddev 4.1594 8.32819 11.71733 16.5100 33.1951
## school_effects_standard[1,1] -0.9430 0.15193 0.66124 1.1683 2.2007
## school_effects_standard[2,1] -1.3107 -0.34421 0.10414 0.5532 1.5196
## school_effects_standard[3,1] -1.8459 -0.73635 -0.20636 0.3352 1.3145
## school_effects_standard[4,1] -1.4767 -0.43877 0.05423 0.5166 1.5597
## school_effects_standard[5,1] -1.6740 -0.72760 -0.28705 0.1321 0.9997
## school_effects_standard[6,1] -1.6760 -0.65693 -0.17677 0.2787 1.2420
## school_effects_standard[7,1] -0.8987 0.08377 0.53535 0.9619 1.9026
## school_effects_standard[8,1] -1.5778 -0.42636 0.14659 0.6951 1.8485
mcmc_trace(draws, facet_args = list(ncol = 3))
mcmc_intervals(draws)
mcmc_acf_bar(draws)
mcmc_hist(draws, facet_args = list(ncol = 3))
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
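Beyond the visual checks above, draws behaves like a coda mcmc.list, so standard numerical convergence diagnostics can be applied. A minimal sketch (not part of the original analysis):
# potential scale reduction factors; values close to 1 suggest the
# four chains have converged to the same distribution
coda::gelman.diag(draws, multivariate = FALSE)
# effective sample sizes, pooled across chains
coda::effectiveSize(draws)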
calculate() for transforming estimates to the natural scale
The calculate() function can be used with the same transformation we used when building the model to get the school-specific posterior chains. This function is also how you would obtain posterior predictive values.
# Calculate school effects on original scale
school_effects <- avg_effect + avg_stddev * school_effects_standard
posterior_school_effects <- calculate(school_effects, values = draws)
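calculate() is not limited to the transformation used in the model: any deterministic function of the model’s parameters can be traced through the posterior draws. For example (a sketch; the quantity and object names are ours, not part of the original analysis), the posterior of the difference between two schools’ effects:
# posterior draws of how much larger school a's effect is than school c's
effect_diff <- school_effects[1] - school_effects[3]
posterior_effect_diff <- calculate(effect_diff, values = draws)
summary(posterior_effect_diff)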
As a sanity check that we parameterized our model correctly, we can compare the back-transformed school-specific estimates to the results from the Edward2 approach in the TensorFlow Probability documentation. The results are very similar.
# Posterior means via Edward2
edward2_school_means <-
  data.frame(tool = "Edward2",
             school = N,
             #mean_school_effects_standard = c(0.61157268, 0.06430732, -0.25459746,
             #                                 0.04828103, -0.36940941, -0.23154463,
             #                                 0.49402338, 0.13042814),
             mean = c(14.93237686, 7.50939941, 3.07602358, 7.21652555,
                      2.0329783, 3.41213799, 12.92509365, 8.36702347),
             sd = 0)
edward2_pop_mean <- data.frame(tool = "Edward2", 'mean' = 6.48866844177, 'sd' = 0)
# hmc_mean_avg_stddev <- 2.46163249016
posterior_school_effects <- as.data.frame(as.matrix(posterior_school_effects))
# Relabel school measures
colnames(posterior_school_effects) <- N
# Summarise and combine all chains of interest for plotting
posterior_summaries <-
  posterior_school_effects %>%
  gather(key = school, value = value) %>%
  group_by(school) %>%
  summarise_all(funs(mean, sd))
## Warning: `funs()` was deprecated in dplyr 0.8.0.
## Please use a list of either functions or lambdas:
##
## # Simple named list:
## list(mean = mean, median = median)
##
## # Auto named with `tibble::lst()`:
## tibble::lst(mean, median)
##
## # Using lambdas
## list(~ mean(., trim = .2), ~ median(., na.rm = TRUE))
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
school_summaries <-
  posterior_summaries %>%
  mutate(tool = "greta") %>%
  rbind(edward2_school_means)
population_parameters <-
  as.data.frame(as.matrix(draws)) %>%
  select(avg_effect) %>%
  summarise_all(funs(mean, sd)) %>%
  mutate(tool = "greta") %>%
  rbind(edward2_pop_mean)
ggplot(school_summaries, aes(x = school, y = mean, color = tool, shape = tool)) +
geom_errorbar(aes(ymin = mean - sd, ymax = mean + sd), width = 0.2) +
geom_point() +
geom_hline(data = population_parameters,
aes(yintercept = mean, linetype = 'Population mean', color = tool)) +
scale_linetype_manual(name = "", values = c(2, 2))
sessionInfo()
## R version 4.1.2 (2021-11-01)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_AU.UTF-8/en_AU.UTF-8/en_AU.UTF-8/C/en_AU.UTF-8/en_AU.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] bayesplot_1.8.1 forcats_0.5.1 stringr_1.4.0 dplyr_1.0.7
## [5] purrr_0.3.4 readr_2.1.1 tidyr_1.1.4 tibble_3.1.6
## [9] ggplot2_3.3.5 tidyverse_1.3.1 greta_0.4.0
##
## loaded via a namespace (and not attached):
## [1] fs_1.5.1 lubridate_1.8.0 progress_1.2.2 httr_1.4.2
## [5] rprojroot_2.0.2 tools_4.1.2 backports_1.4.0 bslib_0.3.1
## [9] utf8_1.2.2 R6_2.5.1 DBI_1.1.1 colorspace_2.0-2
## [13] withr_2.4.3 tidyselect_1.1.1 prettyunits_1.1.1 processx_3.5.2
## [17] compiler_4.1.2 cli_3.1.0 rvest_1.0.2 xml2_1.3.3
## [21] desc_1.4.0 labeling_0.4.2 sass_0.4.0 scales_1.1.1
## [25] ggridges_0.5.3 callr_3.7.0 tfruns_1.5.0 pkgdown_1.1.0.9000
## [29] digest_0.6.29 rmarkdown_2.11 base64enc_0.1-3 pkgconfig_2.0.3
## [33] htmltools_0.5.2 parallelly_1.29.0 highr_0.9 dbplyr_2.1.1
## [37] fastmap_1.1.0 rlang_0.4.12 readxl_1.3.1 rstudioapi_0.13
## [41] farver_2.1.0 jquerylib_0.1.4 generics_0.1.1 jsonlite_1.7.2
## [45] tensorflow_2.7.0 magrittr_2.0.1 Matrix_1.3-4 Rcpp_1.0.7
## [49] munsell_0.5.0 fansi_0.5.0 abind_1.4-5 reticulate_1.22
## [53] lifecycle_1.0.1 stringi_1.7.6 whisker_0.4 yaml_2.2.1
## [57] MASS_7.3-54 plyr_1.8.6 grid_4.1.2 parallel_4.1.2
## [61] listenv_0.8.0 crayon_1.4.2 lattice_0.20-45 haven_2.4.3
## [65] hms_1.1.1 knitr_1.36 ps_1.6.0 pillar_1.6.4
## [69] codetools_0.2-18 reprex_2.0.1 glue_1.5.1 evaluate_0.14
## [73] modelr_0.1.8 png_0.1-7 vctrs_0.3.8 tzdb_0.2.0
## [77] cellranger_1.1.0 gtable_0.3.0 future_1.23.0 assertthat_0.2.1
## [81] cachem_1.0.6 xfun_0.28 broom_0.7.10 coda_0.19-4
## [85] roxygen2_7.1.2 memoise_2.0.1 globals_0.14.0 ellipsis_0.3.2
## [89] here_1.0.1