Skip to content

Commit

Permalink
Merge pull request #230 from yannmclatchie/order-stat-warning
Browse files Browse the repository at this point in the history
Add order statistic warning
  • Loading branch information
jgabry authored Nov 15, 2023
2 parents fc0cde9 + 3526fbf commit 426c2d8
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 1 deletion.
73 changes: 72 additions & 1 deletion R/loo_compare.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,18 @@
#' distribution, a practice derived for Gaussian linear models or
#' asymptotically, and which only applies to nested models in any case.
#'
#' If more than \eqn{11} models are compared, we internally recompute the model
#' differences using the median model by ELPD as the baseline model. We then
#' estimate whether the differences in predictive performance are potentially
#' due to chance as described by McLatchie and Vehtari (2023). This will flag
#' a warning if it is deemed that there is a risk of over-fitting due to the
#' selection process. In that case users are recommended to avoid model
#' selection based on LOO-CV, and instead to favor model averaging/stacking or
#' projection predictive inference.
#' @seealso
#' * The [FAQ page](https://mc-stan.org/loo/articles/online-only/faq.html) on
#' the __loo__ website for answers to frequently asked questions.
#' @template loo-and-psis-references
#' @template loo-and-compare-references
#'
#' @examples
#' # very artificial example, just for demonstration!
Expand Down Expand Up @@ -108,6 +116,9 @@ loo_compare.default <- function(x, ...) {
comp <- cbind(elpd_diff = elpd_diff, se_diff = se_diff, comp)
rownames(comp) <- rnms

# run order statistics-based checks on models
loo_order_stat_check(loos, ord)

class(comp) <- c("compare.loo", class(comp))
return(comp)
}
Expand Down Expand Up @@ -270,3 +281,63 @@ loo_compare_order <- function(loos){
ord <- order(tmp[grep("^elpd", rnms), ], decreasing = TRUE)
ord
}

#' Perform checks on `"loo"` objects __after__ comparison
#'
#' Re-expresses the summed ELPD of every model relative to the model sitting
#' at the middle position of `ord`, estimates a scale from the upper half of
#' those differences, and warns when the largest difference does not exceed
#' the threshold returned by `order_stat_heuristic()`.
#'
#' @noRd
#' @keywords internal
#' @param loos List of `"loo"` objects.
#' @param ord Vector of indices into `loos` giving the order in which the
#'   models are compared.
#' @return `NULL`, invisibly. Called for its side effect of possibly throwing
#'   a warning.
loo_order_stat_check <- function(loos, ord) {

  ## breaks

  if (length(loos) <= 11L) {
    # the check is only run when more than 11 models are compared
    return(invisible(NULL))
  }

  ## warnings

  # compute the summed elpd differences from the baseline (middle) model;
  # vapply() keeps this well-defined even when elpd_diffs() returns a single
  # value per model, where mapply() would simplify to a dimensionless vector
  baseline <- loos[[ord[middle_idx(ord)]]]
  elpd_diff <- vapply(
    loos[ord],
    function(loo) sum(elpd_diffs(baseline, loo)),
    FUN.VALUE = numeric(1)
  )

  # estimate the standard deviation of the upper-half-normal: root mean
  # square of the differences lying at or above their median
  diff_median <- stats::median(elpd_diff)
  elpd_diff_trunc <- elpd_diff[elpd_diff >= diff_median]
  n_models <- sum(!is.na(elpd_diff_trunc))
  candidate_sd <- sqrt(1 / n_models * sum(elpd_diff_trunc^2, na.rm = TRUE))

  # estimate expected best diff under null hypothesis
  K <- length(loos) - 1
  order_stat <- order_stat_heuristic(K, candidate_sd)

  if (max(elpd_diff) <= order_stat) {
    # flag warning if we suspect no model is theoretically better than the
    # baseline; the two sentences are separated by an explicit space because
    # warning() pastes its arguments together without a separator
    warning("Difference in performance potentially due to chance. ",
            "See McLatchie and Vehtari (2023) for details.",
            call. = FALSE)
  }

  invisible(NULL)
}

#' Locate the halfway position of a vector
#'
#' The position is half the length of `vec`, rounded down, so a vector of
#' length 5 gives 2 and a vector of length 0 or 1 gives 0.
#'
#' @noRd
#' @keywords internal
#' @param vec A vector.
#' @return A length-one numeric (double) holding the index value.
middle_idx <- function(vec) {
  n_elements <- length(vec)
  # integer division by a double keeps the result a double, like floor(n / 2)
  n_elements %/% 2
}

#' Computes maximum order statistic from K Gaussians
#'
#' The value is obtained as the `1 - 1 / (2 * K)` quantile of a Gaussian with
#' mean zero and standard deviation `c`.
#'
#' @noRd
#' @keywords internal
#' @param K Number of Gaussians.
#' @param c Scaling of the order statistic, used as the standard deviation of
#'   the Gaussian.
#' @return Numeric expected maximum from K samples from a Gaussian with mean
#' zero and scale `"c"`
order_stat_heuristic <- function(K, c) {
  # namespace-qualified so the helper does not depend on `qnorm` being
  # imported into the package namespace
  stats::qnorm(p = 1 - 1 / (K * 2), mean = 0, sd = c)
}
14 changes: 14 additions & 0 deletions man-roxygen/loo-and-compare-references.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#' @references
#' Vehtari, A., Gelman, A., and Gabry, J. (2017). Practical Bayesian model
#' evaluation using leave-one-out cross-validation and WAIC.
#' *Statistics and Computing*. 27(5), 1413--1432. doi:10.1007/s11222-016-9696-4
#' ([journal version](https://link.springer.com/article/10.1007/s11222-016-9696-4),
#' [preprint arXiv:1507.04544](https://arxiv.org/abs/1507.04544)).
#'
#' Vehtari, A., Simpson, D., Gelman, A., Yao, Y., and Gabry, J. (2019).
#' Pareto smoothed importance sampling.
#' [preprint arXiv:1507.02646](https://arxiv.org/abs/1507.02646)
#'
#' McLatchie, Y., and Vehtari, A. (2023).
#' Efficient estimation and correction of selection-induced bias with order statistics.
#' [preprint arXiv:2309.03742](https://arxiv.org/abs/2309.03742)
14 changes: 14 additions & 0 deletions man/loo_compare.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions tests/testthat/test_compare.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ test_that("loo_compare throws appropriate warnings", {
attr(w3, "yhash") <- "a"
attr(w4, "yhash") <- "b"
expect_warning(loo_compare(w3, w4), "Not all models have the same y variable")

set.seed(123)
w_list <- lapply(1:25, function(x) SW(waic(LLarr + rnorm(1, 0, 0.1))))
expect_warning(loo_compare(w_list),
"Difference in performance potentially due to chance")

w_list_short <- lapply(1:4, function(x) SW(waic(LLarr + rnorm(1, 0, 0.1))))
expect_no_warning(loo_compare(w_list_short))
})


Expand Down

0 comments on commit 426c2d8

Please sign in to comment.