Eight Schools is a study of coaching effects measured in eight schools; it comes from Section 5.5 of Gelman et al. (2003) and is described in Section 2.1 ("Schools data") of ‘R2WinBUGS: A Package for Running WinBUGS from R’:
The Scholastic Aptitude Test (SAT) measures the aptitude of high-schoolers to help colleges make admissions decisions. It is divided into two parts, verbal (SAT-V) and mathematical (SAT-M). Our data come from the SAT-V (Scholastic Aptitude Test-Verbal) tests of eight different high schools, from an experiment conducted in the late 1970s. SAT-V is a standard multiple-choice test administered by the Educational Testing Service. The Service was interested in the effects of coaching programs at each of the selected schools. The study included coached and uncoached pupils, about sixty in each of the eight schools; see Rubin (1981). All of them had already taken the PSAT (Preliminary SAT), whose results were used as covariates. For each school, the estimated treatment effect and the standard error of the effect estimate are given. These are calculated by an analysis of covariance adjustment appropriate for a completely randomized experiment (Rubin 1981). This example was analyzed using a hierarchical normal model in Rubin (1981) and Gelman, Carlin, Stern, and Rubin (2003, Section 5.5).
The corresponding TensorFlow Probability Jupyter notebook can be found here.
library(greta)
library(tidyverse)
library(bayesplot)
color_scheme_set("purple")
# data
N <- letters[1:8]
treatment_effects <- c(28.39, 7.94, -2.75 , 6.82, -0.64, 0.63, 18.01, 12.16)
treatment_stddevs <- c(14.9, 10.2, 16.3, 11.0, 9.4, 11.4, 10.4, 17.6)
schools <- data.frame(N = N,
                      treatment_effects = treatment_effects,
                      treatment_stddevs = treatment_stddevs) %>%
  mutate(treatment_effects_p_stddevs = treatment_effects + treatment_stddevs,
         treatment_effects_m_stddevs = treatment_effects - treatment_stddevs)
For each of the eight schools N, we have the estimated treatment effect (treatment_effects) and the standard error of that estimate (treatment_stddevs). Below, we replicate the barplot from the TensorFlow Probability example, which shows the estimated treatment effect +/- standard error per school:
ggplot(schools, aes(x = N, y = treatment_effects)) +
geom_bar(stat = "identity", fill = "purple", alpha = 0.5) +
geom_errorbar(aes(ymin = treatment_effects_m_stddevs, ymax = treatment_effects_p_stddevs), width = 0.3) +
labs(x = "school", y = "treatment effect",
title = "Barplot of treatment effects for eight schools",
subtitle = "Error bars represent standard error")
A different way to look at the estimated effects and their standard errors is to plot their density distributions over the eight schools:
schools %>%
gather(x, y, treatment_effects, treatment_effects_p_stddevs, treatment_effects_m_stddevs) %>%
ggplot(aes(x = y, color = x)) +
geom_density(fill = "purple", alpha = 0.5) +
scale_color_brewer(palette = "Set1") +
labs(x = "treatment effect (+/- standard error)",
color = "density curve of",
title = "Density plot of treatment effects +/- standard error for eight schools")
greta
To model the data, we use the same hierarchical normal model as in the TensorFlow Probability example.
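Written out, the model built step by step below is the following non-centered parameterization (with school index i = 1, …, 8 and the known standard errors sigma_i given by treatment_stddevs):

```latex
\begin{aligned}
\mu &\sim \mathcal{N}(0, 10) && \text{(avg\_effect)} \\
\log \tau &\sim \mathcal{N}(5, 1) && \text{(avg\_stddev)} \\
\theta'_i &\sim \mathcal{N}(0, 1) && \text{(school\_effects\_standard)} \\
\theta_i &= \mu + \tau \, \theta'_i && \text{(school\_effects)} \\
y_i &\sim \mathcal{N}(\theta_i, \sigma_i) && \text{(treatment\_effects)}
\end{aligned}
```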
First, we create greta arrays that represent the variables and prior distributions in our model and create a greta array for school effect from them. We define the following (random) variables and priors:
avg_effect: normal distribution (compare dnorm) with a mean of 0 and standard deviation of 10; represents the prior average treatment effect.

avg_effect <- normal(mean = 0, sd = 10)
avg_effect
## greta array (variable following a normal distribution)
##
## [,1]
## [1,] ?
avg_stddev: normal distribution (compare dnorm) with a mean of 5 and standard deviation of 1; controls the amount of variance between schools.

avg_stddev <- normal(5, 1)
avg_stddev
## greta array (variable following a normal distribution)
##
## [,1]
## [1,] ?
school_effects_standard: normal distribution (compare dnorm) with a mean of 0, standard deviation of 1, and dimension of 8.

school_effects_standard <- normal(0, 1, dim = length(N))
school_effects_standard
## greta array (variable following a normal distribution)
##
## [,1]
## [1,] ?
## [2,] ?
## [3,] ?
## [4,] ?
## [5,] ?
## [6,] ?
## [7,] ?
## [8,] ?
school_effects: here we multiply the exponential of avg_stddev with school_effects_standard and add avg_effect.

school_effects <- avg_effect + exp(avg_stddev) * school_effects_standard
school_effects
## greta array (operation)
##
## [,1]
## [1,] ?
## [2,] ?
## [3,] ?
## [4,] ?
## [5,] ?
## [6,] ?
## [7,] ?
## [8,] ?
An alternative would be to use the lognormal() density function directly for avg_stddev and use that to calculate school_effects:
avg_stddev <- lognormal(5, 1)
school_effects <- avg_effect + avg_stddev * school_effects_standard
Next, we link the variables and priors to the observed dependent data, in this case the school estimates treatment_effects. We define the likelihood of the observed estimates treatment_effects as a normal distribution with mean school_effects and standard deviation treatment_stddevs, using the distribution() function:
distribution(treatment_effects) <- normal(school_effects, treatment_stddevs)
Now we have all the prerequisites for running Hamiltonian Monte Carlo (HMC) to sample from the posterior distribution over the model’s parameters.
We first define the model by combining the avg_effect, avg_stddev and school_effects_standard variables so that we can sample from them during modeling. The model m we define below contains all our prior distributions and thus represents the combined density of the model.
It is recommended that you check your model at this step by plotting the model graph. More information about these plots can be found here.
# defining the hierarchical model
m <- model(avg_effect, avg_stddev, school_effects_standard)
m
## greta model
plot(m)
The actual sampling from the model happens with the mcmc() function. By default, 1000 MCMC samples are drawn after warm-up. What we obtain is a set of posterior samples for the model variables.
# sampling
draws <- greta::mcmc(m, n_samples = 1000, warmup = 1000, chains = 4)
summary(draws)
##
## Iterations = 1:1000
## Thinning interval = 1
## Number of chains = 4
## Sample size per chain = 1000
##
## 1. Empirical mean and standard deviation for each variable,
## plus standard error of the mean:
##
## Mean SD Naive SE Time-series SE
## avg_stddev 13.34410 7.4121 0.11720 0.30686
## school_effects_standard[1,1] 0.67942 0.8090 0.01279 0.01550
## school_effects_standard[2,1] 0.07902 0.6941 0.01097 0.01349
## school_effects_standard[3,1] -0.23530 0.7938 0.01255 0.01442
## school_effects_standard[4,1] 0.02425 0.7215 0.01141 0.01321
## school_effects_standard[5,1] -0.33197 0.6835 0.01081 0.01237
## school_effects_standard[6,1] -0.20530 0.7410 0.01172 0.01409
## school_effects_standard[7,1] 0.52226 0.7098 0.01122 0.01274
## school_effects_standard[8,1] 0.16803 0.8303 0.01313 0.01615
## avg_effect 5.96913 5.5232 0.08733 0.11305
##
## 2. Quantiles for each variable:
##
## 2.5% 25% 50% 75% 97.5%
## avg_stddev 3.8937 8.32668 11.85203 16.6972 30.9743
## school_effects_standard[1,1] -0.9832 0.15694 0.69172 1.2197 2.2386
## school_effects_standard[2,1] -1.2758 -0.36769 0.07515 0.5149 1.4459
## school_effects_standard[3,1] -1.8624 -0.74132 -0.21911 0.2866 1.2835
## school_effects_standard[4,1] -1.3963 -0.42649 0.02580 0.4916 1.4576
## school_effects_standard[5,1] -1.6972 -0.75841 -0.31379 0.1215 0.9957
## school_effects_standard[6,1] -1.6735 -0.68559 -0.21268 0.2821 1.2588
## school_effects_standard[7,1] -0.8777 0.07559 0.52240 0.9760 1.9762
## school_effects_standard[8,1] -1.4581 -0.38927 0.17397 0.7153 1.7875
## avg_effect -5.1900 2.38110 5.92440 9.6904 16.8460
mcmc_trace(draws, facet_args = list(ncol = 3))
mcmc_intervals(draws)
mcmc_acf_bar(draws)
mcmc_hist(draws, facet_args = list(ncol = 3))
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
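Beyond the bayesplot visualizations, numeric convergence diagnostics are easy to obtain because greta returns the draws as a coda mcmc.list. A quick sketch (not part of the original analysis), using coda's standard diagnostics:

```r
library(coda)

# Potential scale reduction factors (Rhat) across the four chains;
# values close to 1 indicate convergence
gelman.diag(draws, multivariate = FALSE)

# Effective sample size per parameter, pooled over chains
effectiveSize(draws)
```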
calculate() for transforming estimates to the natural scale
The calculate() function can be used with the transformation applied when building the model to get the school-specific posterior chains. This function is also how you would get posterior predictive values.
# Calculate school effects on original scale
school_effects <- avg_effect + avg_stddev * school_effects_standard
posterior_school_effects <- calculate(school_effects, draws)
As a sanity check that we parameterized our model correctly, we can compare the back-transformed school-specific estimates to the results from the Edward2 approach in the TensorFlow Probability documentation. The results are very similar.
# Posterior means via Edward2
edward2_school_means <-
data.frame(tool = "Edward2",
school = N,
#mean_school_effects_standard = c(0.61157268, 0.06430732, -0.25459746,
# 0.04828103, -0.36940941, -0.23154463,
# 0.49402338, 0.13042814),
mean = c(14.93237686, 7.50939941, 3.07602358, 7.21652555,
2.0329783, 3.41213799, 12.92509365, 8.36702347),
sd = 0)
edward2_pop_mean <- data.frame(tool = "Edward2", 'mean' = 6.48866844177, 'sd' = 0)
# hmc_mean_avg_stddev <- 2.46163249016
posterior_school_effects <- as.data.frame(as.matrix(posterior_school_effects))
# Relabel school measures
colnames(posterior_school_effects) <- N
# Summarise and combine all chains of interest for plotting
posterior_summaries <-
posterior_school_effects %>%
gather(key = school, value = value) %>%
group_by(school) %>%
summarise_all(funs(mean, sd))
school_summaries <-
posterior_summaries %>%
mutate(tool = "greta") %>%
rbind(edward2_school_means)
population_parameters <-
as.data.frame(as.matrix(draws)) %>%
select(avg_effect) %>%
summarise_all(funs(mean, sd)) %>%
mutate(tool = "greta") %>%
rbind(edward2_pop_mean)
ggplot(school_summaries, aes(x = school, y = mean, color = tool, shape = tool)) +
geom_errorbar(aes(ymin = mean - sd, ymax = mean + sd), width = 0.2) +
geom_point() +
geom_hline(data = population_parameters,
aes(yintercept = mean, linetype = 'Population mean', color = tool)) +
scale_linetype_manual(name = "", values = c(2, 2))
sessionInfo()
## R version 3.5.1 (2018-07-02)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS Sierra 10.12.6
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_GB.UTF-8/en_GB.UTF-8/en_GB.UTF-8/C/en_GB.UTF-8/en_GB.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] bindrcpp_0.2.2 bayesplot_1.6.0 forcats_0.3.0 stringr_1.3.1
## [5] dplyr_0.7.6 purrr_0.2.5 readr_1.1.1 tidyr_0.8.1
## [9] tibble_1.4.2 ggplot2_3.0.0 tidyverse_1.2.1 greta_0.2.5
##
## loaded via a namespace (and not attached):
## [1] nlme_3.1-137 fs_1.2.5 lubridate_1.7.4
## [4] RColorBrewer_1.1-2 progress_1.2.0 httr_1.3.1
## [7] rprojroot_1.3-2 tools_3.5.1 backports_1.1.2
## [10] R6_2.2.2 lazyeval_0.2.1 colorspace_1.3-2
## [13] withr_2.1.2 tidyselect_0.2.4 gridExtra_2.3
## [16] prettyunits_1.0.2 compiler_3.5.1 cli_1.0.0
## [19] rvest_0.3.2 xml2_1.2.0 influenceR_0.1.0
## [22] desc_1.2.0 labeling_0.3 scales_1.0.0
## [25] ggridges_0.5.1 tfruns_1.4 pkgdown_1.1.0.9000
## [28] commonmark_1.5 digest_0.6.17 rmarkdown_1.10.13
## [31] base64enc_0.1-3 pkgconfig_2.0.1 htmltools_0.3.6
## [34] htmlwidgets_1.2 rlang_0.2.2 readxl_1.1.0
## [37] htmldeps_0.1.1 rstudioapi_0.8 visNetwork_2.0.4
## [40] bindr_0.1.1 jsonlite_1.5 tensorflow_1.9
## [43] rgexf_0.15.3 magrittr_1.5 Matrix_1.2-14
## [46] Rcpp_0.12.18 munsell_0.5.0 reticulate_1.10
## [49] viridis_0.5.1 stringi_1.2.4 whisker_0.3-2
## [52] yaml_2.2.0 MASS_7.3-50 plyr_1.8.4
## [55] grid_3.5.1 parallel_3.5.1 listenv_0.7.0
## [58] crayon_1.3.4 lattice_0.20-35 haven_1.1.2
## [61] hms_0.4.2 knitr_1.20 pillar_1.3.0
## [64] igraph_1.2.2 reshape2_1.4.3 codetools_0.2-15
## [67] XML_3.98-1.16 glue_1.3.0 evaluate_0.11
## [70] downloader_0.4 modelr_0.1.2 cellranger_1.1.0
## [73] gtable_0.2.0 future_1.9.0 assertthat_0.2.0
## [76] broom_0.5.0 coda_0.19-1 roxygen2_6.1.0
## [79] viridisLite_0.3.0 memoise_1.1.0 Rook_1.1-1
## [82] DiagrammeR_1.0.0 globals_0.12.2 brew_1.0-6