Skip to content

Commit

Permalink
Merge pull request #312 from stan-dev/ppc-loo-psis_object
Browse files Browse the repository at this point in the history
Allow `psis_object` argument for all ppc-loo plots
  • Loading branch information
jgabry authored Jan 22, 2024
2 parents 09b813a + 4ec0743 commit 1fe6b67
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 53 deletions.
69 changes: 45 additions & 24 deletions R/ppc-loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
#' @param ... Currently unused.
#' @param lw A matrix of (smoothed) log weights with the same dimensions as
#' `yrep`. See [loo::psis()] and the associated `weights()` method as well as
#' the **Examples** section, below.
#' the **Examples** section, below. If `lw` is not specified then
#' `psis_object` can be provided and log weights will be extracted.
#' @param psis_object If using **loo** version `2.0.0` or greater, an
#' object returned by the `psis()` function (or by the `loo()` function
#' with argument `save_psis` set to `TRUE`).
#' @param alpha,size,fatten,linewidth Arguments passed to code geoms to control plot
#' aesthetics. For `ppc_loo_pit_qq()` and `ppc_loo_pit_overlay()`, `size` and
#' `alpha` are passed to [ggplot2::geom_point()] and
Expand Down Expand Up @@ -71,7 +75,7 @@
#' log_radon ~ floor + log_uranium + floor:log_uranium
#' + (1 + floor | county),
#' data = radon,
#' iter = 1000,
#' iter = 100,
#' chains = 2,
#' cores = 2
#' )
Expand All @@ -89,6 +93,8 @@
#' ppc_loo_pit_qq(y, yrep, lw = lw)
#' ppc_loo_pit_qq(y, yrep, lw = lw, compare = "normal")
#'
#' # can use the psis object instead of lw
#' ppc_loo_pit_qq(y, yrep, psis_object = psis1)
#'
#' # loo predictive intervals vs observations
#' keep_obs <- 1:50
Expand Down Expand Up @@ -138,8 +144,9 @@ NULL
#'
ppc_loo_pit_overlay <- function(y,
yrep,
lw,
lw = NULL,
...,
psis_object = NULL,
pit = NULL,
samples = 100,
size = 0.25,
Expand All @@ -158,6 +165,7 @@ ppc_loo_pit_overlay <- function(y,
y = y,
yrep = yrep,
lw = lw,
psis_object = psis_object,
pit = pit,
samples = samples,
bw = bw,
Expand Down Expand Up @@ -253,8 +261,9 @@ ppc_loo_pit_overlay <- function(y,
ppc_loo_pit_data <-
function(y,
yrep,
lw,
lw = NULL,
...,
psis_object = NULL,
pit = NULL,
samples = 100,
bw = "nrd0",
Expand All @@ -267,6 +276,7 @@ ppc_loo_pit_data <-
suggested_package("rstantools")
y <- validate_y(y)
yrep <- validate_predictions(yrep, length(y))
lw <- .get_lw(lw, psis_object)
stopifnot(identical(dim(yrep), dim(lw)))
pit <- rstantools::loo_pit(object = yrep, y = y, lw = lw)
}
Expand Down Expand Up @@ -295,22 +305,24 @@ ppc_loo_pit_data <-
#' @export
ppc_loo_pit_qq <- function(y,
yrep,
lw,
pit,
compare = c("uniform", "normal"),
lw = NULL,
...,
psis_object = NULL,
pit = NULL,
compare = c("uniform", "normal"),
size = 2,
alpha = 1) {
check_ignored_arguments(...)

compare <- match.arg(compare)
if (!missing(pit)) {
if (!is.null(pit)) {
stopifnot(is.numeric(pit), is_vector_or_1Darray(pit))
inform("'pit' specified so ignoring 'y','yrep','lw' if specified.")
} else {
suggested_package("rstantools")
y <- validate_y(y)
yrep <- validate_predictions(yrep, length(y))
lw <- .get_lw(lw, psis_object)
stopifnot(identical(dim(yrep), dim(lw)))
pit <- rstantools::loo_pit(object = yrep, y = y, lw = lw)
}
Expand Down Expand Up @@ -352,7 +364,7 @@ ppc_loo_pit <-
function(y,
yrep,
lw,
pit,
pit = NULL,
compare = c("uniform", "normal"),
...,
size = 2,
Expand All @@ -374,18 +386,14 @@ ppc_loo_pit <-
#' @rdname PPC-loo
#' @export
#' @template args-prob-prob_outer
#' @param psis_object If using **loo** version `2.0.0` or greater, an
#' object returned by the `psis()` function (or by the `loo()` function
#' with argument `save_psis` set to `TRUE`).
#' @param intervals For `ppc_loo_intervals()` and `ppc_loo_ribbon()`,
#' optionally a matrix of precomputed LOO predictive intervals
#' that can be specified instead of `yrep` and `lw` (these are both
#' ignored if `intervals` is specified). If not specified the intervals
#' are computed internally before plotting. If specified, `intervals`
#' must be a matrix with number of rows equal to the number of data points and
#' five columns in the following order: lower outer interval, lower inner
#' interval, median (50%), upper inner interval and upper outer interval
#' (column names are ignored).
#' @param intervals For `ppc_loo_intervals()` and `ppc_loo_ribbon()`, optionally
#' a matrix of pre-computed LOO predictive intervals that can be specified
#' instead of `yrep` (ignored if `intervals` is specified). If not specified
#' the intervals are computed internally before plotting. If specified,
#' `intervals` must be a matrix with number of rows equal to the number of
#' data points and five columns in the following order: lower outer interval,
#' lower inner interval, median (50%), upper inner interval and upper outer
#' interval (column names are ignored).
#' @param order For `ppc_loo_intervals()`, a string indicating how to arrange
#' the plotted intervals. The default (`"index"`) is to plot them in the
#' order of the observations. The alternative (`"median"`) arranges them
Expand All @@ -403,9 +411,9 @@ ppc_loo_intervals <-
function(y,
yrep,
psis_object,
...,
subset = NULL,
intervals = NULL,
...,
prob = 0.5,
prob_outer = 0.9,
alpha = 0.33,
Expand Down Expand Up @@ -498,11 +506,10 @@ ppc_loo_intervals <-
ppc_loo_ribbon <-
function(y,
yrep,
lw,
psis_object,
...,
subset = NULL,
intervals = NULL,
...,
prob = 0.5,
prob_outer = 0.9,
alpha = 0.33,
Expand Down Expand Up @@ -720,3 +727,17 @@ ppc_loo_ribbon <-

list(xs = xs, unifs = bc_mat)
}

# Extract log weights from psis_object if provided
.get_lw <- function(lw = NULL, psis_object = NULL) {
if (is.null(lw) && is.null(psis_object)) {
abort("One of 'lw' and 'psis_object' must be specified.")
} else if (is.null(lw)) {
suggested_package("loo", min_version = "2.0.0")
if (!loo::is.psis(psis_object)) {
abort("If specified, 'psis_object' must be a PSIS object from the loo package.")
}
lw <- loo::weights.importance_sampling(psis_object)
}
lw
}
52 changes: 28 additions & 24 deletions man/PPC-loo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 12 additions & 5 deletions tests/testthat/test-ppc-loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ test_that("ppc_loo_pit_overlay returns ggplot object", {
skip_if_not_installed("rstanarm")
skip_if_not_installed("loo")
expect_gg(ppc_loo_pit_overlay(y, yrep, lw, samples = 25))
expect_gg(ppc_loo_pit_overlay(y, yrep, psis_object = psis1, samples = 25))
})

test_that("ppc_loo_pit_overlay warns about binary data", {
Expand Down Expand Up @@ -65,29 +66,35 @@ test_that("ppc_loo_pit_qq returns ggplot object", {
skip_if_not_installed("rstanarm")
skip_if_not_installed("loo")
expect_gg(p1 <- ppc_loo_pit_qq(y, yrep, lw))
expect_gg(p2 <- ppc_loo_pit_qq(y, yrep, psis_object = psis1))
expect_equal(p1$labels$x, "Uniform")
expect_gg(p2 <- ppc_loo_pit_qq(y, yrep, lw, compare = "normal"))
expect_equal(p2$labels$x, "Normal")
expect_equal(p1$data, p2$data)
expect_gg(p3 <- ppc_loo_pit_qq(y, yrep, lw, compare = "normal"))
expect_equal(p3$labels$x, "Normal")
})

test_that("ppc_loo_pit functions work when pit specified instead of y,yrep,lw", {
skip_if_not_installed("rstanarm")
skip_if_not_installed("loo")
expect_gg(ppc_loo_pit_qq(pit = pits))
expect_message(
ppc_loo_pit_qq(y = y, yrep = yrep, lw = lw, pit = pits),
p1 <- ppc_loo_pit_qq(y = y, yrep = yrep, lw = lw, pit = pits),
"'pit' specified so ignoring 'y','yrep','lw' if specified"
)
expect_message(
p2 <- ppc_loo_pit_qq(pit = pits)
)
expect_equal(p1$data, p2$data)

expect_gg(ppc_loo_pit_overlay(pit = pits))

expect_gg(p1 <- ppc_loo_pit_overlay(pit = pits))
expect_message(
ppc_loo_pit_overlay(y = y, yrep = yrep, lw = lw, pit = pits),
"'pit' specified so ignoring 'y','yrep','lw' if specified"
)
})



test_that("ppc_loo_intervals returns ggplot object", {
skip_if_not_installed("rstanarm")
skip_if_not_installed("loo")
Expand Down

0 comments on commit 1fe6b67

Please sign in to comment.