Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

subsampling LOO estimates with diff-est-srs-wor start #496

Open
wants to merge 106 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 91 commits
Commits
Show all changes
106 commits
Select commit Hold shift + click to select a range
bba6adf
subsampling LOO estimates with diff-est-srs-wor start
avehtari Apr 10, 2024
781e331
put back unnecessarily removed weights back in kfold
avehtari Apr 10, 2024
d50f4bd
Revert "put back unnecessarily removed weights back in kfold"
avehtari Apr 10, 2024
4de50b7
subsampling LOO for acc and pctcor
avehtari Apr 10, 2024
c50bdbf
ignore nloo if validate_search=FALSE
avehtari Apr 10, 2024
649f0ad
fix tests
avehtari Apr 11, 2024
aba8670
fix mse interval for delta=TRUE
avehtari Apr 12, 2024
81b5dc3
don't stop due to repeated arguments
avehtari Apr 12, 2024
0ed8391
normal approximation for mse, rmse, and R2
avehtari Apr 15, 2024
a36a3d8
rename internal function select -> .select
avehtari Apr 16, 2024
2c846a3
with delta and mse/rmse/R2/acc/pctcorr/auc, plot values in orig scale
avehtari Apr 16, 2024
fbda70e
don't warn about subsampling
avehtari Apr 16, 2024
1fa7fcd
improve messages
avehtari Apr 17, 2024
7e7fc7a
if available, use progressr for parallel progress bar
avehtari Apr 17, 2024
09223c6
verbosity improvements
avehtari Apr 25, 2024
45e22a6
fix
avehtari Apr 25, 2024
1cb78fc
use do_call instead of do.call
avehtari Apr 25, 2024
98757a9
add progress and progressr to Suggests
avehtari Apr 25, 2024
bf5af29
Merge branch 'master' into fix-subsampling
avehtari Jun 11, 2024
01f3bed
remove unneeded code
avehtari Jun 27, 2024
a5c5103
remove unnecessary sum
avehtari Jun 27, 2024
a223725
revert the addition of correct_baseline
avehtari Jun 28, 2024
24fc370
remove unneeded code
avehtari Jun 28, 2024
36f3543
document deltas=TRUE change
avehtari Jun 28, 2024
eeef49a
wcv -> wobs in summary_funs
avehtari Jun 28, 2024
1fc4669
newline in startup message to make it more readable
avehtari Jun 28, 2024
9139ce0
Merge remote-tracking branch 'upstream/master' into fix-subsampling
fweber144 Jun 30, 2024
10ab731
rename remaining occurrences of `wcv` to `wobs`
fweber144 Jun 30, 2024
c23fa78
re-add a comment
fweber144 Jun 30, 2024
b81e401
Merge remote-tracking branch 'upstream/master' into fix-subsampling
fweber144 Jul 4, 2024
71a2198
progressr: remove code that is part of the end-user's API (see
fweber144 Jul 11, 2024
b41ad1e
use `use_progressr` consistently
fweber144 Jul 11, 2024
81b4bf0
package `progress` is no longer needed in the "Suggests" dependencies
fweber144 Jul 11, 2024
497a245
add function `get_use_progressr()` to avoid redundancies;
fweber144 Jul 11, 2024
5ac4041
rename `p` to `progressor_obj` to identify it more clearly and
fweber144 Jul 11, 2024
825fc3d
use argument `steps` of `progressr::progressor()` explicitly and
fweber144 Jul 11, 2024
1655c84
remove unnecessary `""` in the `progressor_obj()` call
fweber144 Jul 11, 2024
1633137
use a simpler solution for identifying whether `progressr` should be …
fweber144 Jul 11, 2024
0996e7c
add the possibility to use `progressr` at the remaining occurrences o…
fweber144 Jul 11, 2024
8488e3e
fix a bug (`could not find function "do_call"`) when using the `doFut…
fweber144 Jul 12, 2024
118838c
remove `.select <- .select` (the issue does not occur when installing
fweber144 Jul 15, 2024
d663df8
fix a bug when checking arguments in `cv_varsel.vsel()`
fweber144 Jul 17, 2024
b6fe949
minor cleaning for consistency
fweber144 Jul 17, 2024
d9fdbaf
don't include the argument content in the message as the argument con…
fweber144 Jul 17, 2024
5b2f283
add functionality for option deltas='mixed'
avehtari Aug 11, 2024
17be148
remove option `baseline = "best"`
avehtari Aug 11, 2024
6b6520a
attempt to fix vsel.summary
avehtari Aug 11, 2024
fc8c665
move code for deltas='mixed' to plot.vsel
avehtari Aug 12, 2024
99e9c79
fixes
avehtari Aug 13, 2024
cd4b248
docs: fix minor typos
fweber144 Jul 18, 2024
7176dfe
avoid `object` within `cv_varsel.refmodel()` (for consistency; I don'…
fweber144 Jul 18, 2024
4109aae
fix a verbose message (at `?projpred::cv_varsel`, the documentation for
fweber144 Jul 18, 2024
7616128
mention thinning in the verbose message which gives information about
fweber144 Jul 18, 2024
48a6af1
minor cleaning
fweber144 Aug 18, 2024
c0593c6
fix usage of argument `summaries_fast` (at that place, `sel_cv$summar…
fweber144 Aug 18, 2024
0c85eb6
use argument `summaries_fast` as it was probably intended to
fweber144 Aug 18, 2024
4f893f9
fixup! use argument `summaries_fast` as it was probably intended to
fweber144 Aug 18, 2024
d40acd8
fixup! fixup! use argument `summaries_fast` as it was probably intend…
fweber144 Aug 18, 2024
d1d37bd
fix input for argument `search_path_fulldata` when running fast LOO-C…
fweber144 Aug 18, 2024
b9f8368
for argument `verbose`, default to a new global option:
fweber144 Aug 19, 2024
a9ce55f
argument `summaries_fast` should not change either (when calling `cv_…
fweber144 Aug 20, 2024
23b4a2a
remove unused object `n_arg_nms_internal_used`
fweber144 Aug 20, 2024
cc26cf7
define `arg_nms_internal_used` more straightforwardly
fweber144 Aug 20, 2024
66f2e34
minor enhancements
fweber144 Aug 20, 2024
e61f916
fix a verbose message (at `?projpred::cv_varsel`, the documentation for
fweber144 Aug 20, 2024
750c81a
fix a message when using standard importance sampling (SIS)
fweber144 Aug 20, 2024
f84fe55
remove fragment `verb_txt_start <-`
fweber144 Aug 22, 2024
8dde0bc
fix verbose message
fweber144 Aug 22, 2024
9c4d1a4
docs: abbreviate the performance statistics appropriately
fweber144 Aug 23, 2024
ea18e9a
UNFINISHED: move out the new "mixed deltas" variant of `plot.vsel()`,…
fweber144 Aug 23, 2024
aea7f08
Revert "UNFINISHED: move out the new "mixed deltas" variant of `plot.…
fweber144 Aug 23, 2024
867f29f
in `.onAttach()`, keep the temporary "NOTE" in separate lines (to
fweber144 Aug 23, 2024
2d9a652
add comments in `summary_funs.R`
fweber144 Aug 23, 2024
35cc542
simplify `summaries_fast_sub <- varsel$summaries_fast$sub` and `summa…
fweber144 Aug 23, 2024
19948e3
`loo_inds` as stored in `vsel` objects was unused so far
fweber144 Aug 23, 2024
f26e4fb
in `get_stat()`, the `is.null(summaries_fast)` checks are not necessary
fweber144 Aug 23, 2024
319c2b7
avoid object name `n` at more places
fweber144 Aug 23, 2024
dfcc58c
the definition of `loo_ref_oscale` does not make sense to be placed a…
fweber144 Aug 23, 2024
f53bc48
simplify an SRS-WOR `value` computation (if `mu_baseline` is `NULL`, …
fweber144 Aug 23, 2024
59afc97
simplify initialization of `est_list`
fweber144 Aug 23, 2024
2727cff
avoid redundant computations by moving `sqrt(srs_diffe$v_y_hat + srs_…
fweber144 Aug 23, 2024
52fd9c3
add an early error for `!validate_search && nloo < refmodel[["nobs"]]`
fweber144 Aug 23, 2024
81ec495
add a comment and a check in `loo_varsel()` for `!validate_search && …
fweber144 Aug 23, 2024
348827c
Revert "avoid redundant computations by moving `sqrt(srs_diffe$v_y_ha…
fweber144 Aug 26, 2024
f4f9760
simplify definitions of `mu_baseline` (possible because
fweber144 Aug 26, 2024
519dac2
fixup! `loo_inds` as stored in `vsel` objects was unused so far
fweber144 Aug 26, 2024
a7458b9
move out the new "mixed deltas" variant of `plot.vsel()`, the
fweber144 Aug 23, 2024
43878e1
use a consistent order of the `if` cases differentiating between
fweber144 Aug 29, 2024
925f2cd
remove unused `var_mse_e` definition
fweber144 Aug 29, 2024
040d05e
there was only one use of `var_mse_e` and since `value_se`
fweber144 Aug 29, 2024
0d73c8e
remove unused `mu_baseline` (only unused in case of
fweber144 Aug 29, 2024
288d948
Merge remote-tracking branch 'upstream/master' into fix-subsampling
fweber144 Sep 15, 2024
6a85b41
re-document
fweber144 Sep 18, 2024
ef33da6
add a placeholder for the documentation of argument `summaries_fast` …
fweber144 Sep 18, 2024
cdaf384
avoid partial argument matching of 'w' to 'wobs' in `.srs_diff_est_w(…
fweber144 Sep 18, 2024
8aacd5d
fixup! remove unused `mu_baseline` (only unused in case of
fweber144 Sep 18, 2024
9e79883
Merge remote-tracking branch 'upstream/master' into fix-subsampling
fweber144 Sep 25, 2024
1127412
`vsel_obj$nloo` can be `NULL` (for `vsel_obj` created by `varsel()`), so
fweber144 Sep 25, 2024
c331cb4
fix the `get_stat()` call for the reference model statistics (`loo_inds`
fweber144 Sep 25, 2024
8389d39
Tests: Subsampled PSIS-LOO-CV is not supported for `validate_search =…
fweber144 Sep 25, 2024
a246dd3
fix the `stat %in% c("acc", "pctcorr", "auc")` case in `get_stat()`
fweber144 Sep 26, 2024
2360630
fix `.tabulate_stats()` (`catmaxprb()` also needs to be
fweber144 Sep 26, 2024
7469aea
fix `.tabulate_stats()` (several steps in case of the latent projection
fweber144 Sep 26, 2024
0dcfcf2
Since `summaries_fast` is created by a call to `loo_varsel()` with
fweber144 Oct 9, 2024
6151b0c
Revert changes that are unrelated to subsampled LOO-CV (to find out
fweber144 Oct 16, 2024
360354a
Adapt the existing tests to work with the new implementation of subsa…
fweber144 Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ Suggests:
doParallel,
future,
future.callr,
doFuture
doFuture,
progressr
LinkingTo: Rcpp, RcppArmadillo
Additional_repositories:
https://mc-stan.org/r-packages/
Expand Down
380 changes: 188 additions & 192 deletions R/cv_varsel.R

Large diffs are not rendered by default.

18 changes: 16 additions & 2 deletions R/divergence_minimizers.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,25 @@ divmin <- function(
if (!requireNamespace("iterators", quietly = TRUE)) {
stop("Please install the 'iterators' package.")
}
if (verbose_divmin && use_progressr()) {
progressor_obj <- progressr::progressor(length(formulas))
} else {
progressor_obj <- NULL
}
dot_args <- list(...)
`%do_projpred%` <- foreach::`%dopar%`
outdmin <- foreach::foreach(
formula_s = formulas,
projpred_var_s = iterators::iter(projpred_var, by = "column"),
projpred_formula_no_random_s = projpred_formulas_no_random,
.export = c("sdivmin", "projpred_random", "dot_args"),
.packages = c("projpred"),
.export = c("sdivmin", "projpred_random", "dot_args", "progressor_obj"),
.noexport = c(
"object", "p_sel", "search_path", "p_ref", "refmodel", "formulas",
"projpred_var", "projpred_ws_aug", "projpred_formulas_no_random"
)
) %do_projpred% {
if (!is.null(progressor_obj)) progressor_obj()
mssgs_warns_capt <- capt_mssgs_warns(
soutdmin <- do.call(
sdivmin,
Expand Down Expand Up @@ -649,19 +656,26 @@ divmin_augdat <- function(
if (!requireNamespace("iterators", quietly = TRUE)) {
stop("Please install the 'iterators' package.")
}
if (verbose_divmin && use_progressr()) {
progressor_obj <- progressr::progressor(ncol(projpred_ws_aug))
} else {
progressor_obj <- NULL
}
dot_args <- list(...)
`%do_projpred%` <- foreach::`%dopar%`
outdmin <- foreach::foreach(
projpred_w_aug_s = iterators::iter(projpred_ws_aug, by = "column"),
.packages = c("projpred"),
.export = c(
"sdivmin", "formula", "data", "family", "projpred_formula_no_random",
"projpred_random", "dot_args"
"projpred_random", "dot_args", "progressor_obj"
),
.noexport = c(
"object", "p_sel", "search_path", "p_ref", "refmodel", "projpred_var",
"projpred_ws_aug", "linkobjs"
)
) %do_projpred% {
if (!is.null(progressor_obj)) progressor_obj()
mssgs_warns_capt <- capt_mssgs_warns(
soutdmin <- do.call(
sdivmin,
Expand Down
2 changes: 1 addition & 1 deletion R/glmfun.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ standardization <- function(x, center = TRUE, scale = TRUE, weights = NULL) {
mx <- rep(0, ncol(x))
}
if (scale) {
sx <- apply(x, 2, weighted.sd, w)
sx <- apply(x, 2, .weighted_sd, w)
fweber144 marked this conversation as resolved.
Show resolved Hide resolved
} else {
sx <- rep(1, ncol(x))
}
Expand Down
24 changes: 12 additions & 12 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ plot.vsel <- function(
# Parse input:
object <- x
validate_vsel_object_stats(object, stats, resp_oscale = resp_oscale)
baseline <- validate_baseline(object$refmodel, baseline, deltas)
baseline <- validate_baseline(object, baseline, deltas)
if (!is.null(ranking_repel) && !requireNamespace("ggrepel", quietly = TRUE)) {
warning("Package 'ggrepel' is needed for a non-`NULL` argument ",
"`ranking_repel`, but could not be found. Setting `ranking_repel` ",
Expand Down Expand Up @@ -1065,11 +1065,11 @@ plot.vsel <- function(
# direction = 1)
###
}
if (all(stats %in% c("rmse", "auc"))) {
if (all(stats %in% c("auc"))) {
ci_type <- "bootstrap "
} else if (all(stats %in% c("gmpd"))) {
ci_type <- "exponentiated normal-approximation "
} else if (all(!stats %in% c("rmse", "auc", "gmpd"))) {
} else if (all(!stats %in% c("auc", "gmpd"))) {
ci_type <- "normal-approximation "
} else {
ci_type <- ""
Expand Down Expand Up @@ -1158,23 +1158,23 @@ plot.vsel <- function(
#' are again all observations because the test set is the same as the training
#' set). Available statistics are:
#' * `"elpd"`: expected log (pointwise) predictive density (for a new
#' dataset). Estimated by the sum of the observation-specific log predictive
#' density values (with each of these predictive density values being
#' a---possibly weighted---average across the parameter draws).
#' * `"mlpd"`: mean log predictive density, that is, `"elpd"` divided by the
#' number of observations.
#' dataset) (ELPD). Estimated by the sum of the observation-specific log
#' predictive density values (with each of these predictive density values
#' being a---possibly weighted---average across the parameter draws).
#' * `"mlpd"`: mean log predictive density (MLPD), that is, the ELPD divided
#' by the number of observations.
#' * `"gmpd"`: geometric mean predictive density (GMPD), that is, [exp()] of
#' `"mlpd"`. The GMPD is especially helpful for discrete response families
#' the MLPD. The GMPD is especially helpful for discrete response families
#' (because there, the GMPD is bounded by zero and one). For the corresponding
#' standard error, the delta method is used. The corresponding confidence
#' interval type is "exponentiated normal approximation" because the
#' confidence interval bounds are the exponentiated confidence interval bounds
#' of the `"mlpd"`.
#' of the MLPD.
#' * `"mse"`: mean squared error (only available in the situations mentioned
#' in section "Details" below).
#' * `"rmse"`: root mean squared error (only available in the situations
#' mentioned in section "Details" below). For the corresponding standard error
#' and lower and upper confidence interval bounds, bootstrapping is used.
#' and lower and upper confidence interval bounds, the delta method is used.
#' * `"acc"` (or its alias, `"pctcorr"`): classification accuracy (only
#' available in the situations mentioned in section "Details" below). By
#' "classification accuracy", we mean the proportion of correctly classified
Expand Down Expand Up @@ -1283,7 +1283,7 @@ summary.vsel <- function(
...
) {
validate_vsel_object_stats(object, stats, resp_oscale = resp_oscale)
baseline <- validate_baseline(object$refmodel, baseline, deltas)
baseline <- validate_baseline(object, baseline, deltas)

# Initialize output:
out <- c(
Expand Down
32 changes: 21 additions & 11 deletions R/misc.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
.onAttach <- function(...) {
ver <- utils::packageVersion("projpred")
msg <- paste0("This is projpred version ", ver, ".")
msg <- paste0(msg, " ", "NOTE: In projpred 2.7.0, the default search method ",
"was set to \"forward\" (for all kinds of models).")
msg <- paste0(msg, "\n", "NOTE: In projpred 2.7.0, the default search ",
"method was set to \"forward\" (for all kinds of models).")
packageStartupMessage(msg)
}

Expand All @@ -14,7 +14,7 @@ nms_y_wobs_test <- function(wobs_nm = "wobs") {
c("y", "y_oscale", wobs_nm)
}

weighted.sd <- function(x, w, na.rm = FALSE) {
.weighted_sd <- function(x, w, na.rm = FALSE) {
if (na.rm) {
ind <- !is.na(w) & !is.na(x)
n <- sum(ind)
Expand Down Expand Up @@ -63,10 +63,10 @@ ilinkfun_raw <- function(x, link_nm) {
return(basic_ilink(x))
}

auc <- function(x) {
.auc <- function(x) {
fweber144 marked this conversation as resolved.
Show resolved Hide resolved
resp <- x[, 1]
pred <- x[, 2]
wcv <- x[, 3]
wobs <- x[, 3]

# Make it explicit that `x` should not be used anymore (due to the possibility
# of `NA`s, but also due to the re-ordering):
Expand All @@ -77,9 +77,9 @@ auc <- function(x) {

resp <- resp[ord]
pred <- pred[ord]
wcv <- wcv[ord]
wobs <- wobs[ord]

w0 <- w1 <- wcv
w0 <- w1 <- wobs
# CAUTION: The following check also ensures that `resp` does not have `NA`s:
stopifnot(all(resp %in% c(0, 1)))
w0[resp == 1] <- 0 # for calculating the false positive rate (fpr)
Expand Down Expand Up @@ -152,8 +152,8 @@ validate_vsel_object_stats <- function(object, stats, resp_oscale = TRUE) {
}
resp_oscale <- object$refmodel$family$for_latent && resp_oscale

trad_stats <- c("elpd", "mlpd", "gmpd", "mse", "rmse", "acc", "pctcorr",
"auc")
trad_stats <- c("elpd", "mlpd", "gmpd", "mse", "rmse", "R2",
"acc", "pctcorr", "auc")
trad_stats_binom_only <- c("acc", "pctcorr", "auc")
augdat_stats <- c("elpd", "mlpd", "gmpd", "acc", "pctcorr")
resp_oscale_stats_fac <- augdat_stats
Expand Down Expand Up @@ -196,17 +196,22 @@ validate_vsel_object_stats <- function(object, stats, resp_oscale = TRUE) {
return(invisible(TRUE))
}

validate_baseline <- function(refmodel, baseline, deltas) {
validate_baseline <- function(vsel_obj, baseline, deltas) {
stopifnot(!is.null(baseline))
if (!(baseline %in% c("ref", "best"))) {
stop("Argument 'baseline' must be either 'ref' or 'best'.")
}
if (baseline == "ref" && deltas == TRUE && inherits(refmodel, "datafit")) {
if (baseline == "ref" && deltas == TRUE &&
inherits(vsel_obj$refmodel, "datafit")) {
# no reference model (or the results missing for some other reason),
# so cannot compute differences (or ratios) vs. the reference model
stop("Cannot use deltas = TRUE and baseline = 'ref' when there is no ",
"reference model.")
}
if (baseline == "best" && vsel_obj$cv_method == "LOO" &&
vsel_obj$nloo < vsel_obj$refmodel$nobs) {
stop("Cannot use `baseline = \"best\"` in case of subsampled LOO-CV.")
}
return(baseline)
}

Expand Down Expand Up @@ -705,3 +710,8 @@ element_unq <- function(list_obj, nm) {
}
return(el_unq)
}

use_progressr <- function() {
getOption("projpred.use_progressr",
requireNamespace("progressr", quietly = TRUE) && interactive())
}
Loading