Skip to content

Commit

Permalink
Revert changes that are unrelated to subsampled LOO-CV (to find out
Browse files Browse the repository at this point in the history
why some test snapshots changed unexpectedly).
  • Loading branch information
fweber144 committed Oct 16, 2024
1 parent 0dcfcf2 commit 6151b0c
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 196 deletions.
3 changes: 1 addition & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ Suggests:
doParallel,
future,
future.callr,
doFuture,
progressr
doFuture
LinkingTo: Rcpp, RcppArmadillo
Additional_repositories:
https://mc-stan.org/r-packages/
Expand Down
226 changes: 96 additions & 130 deletions R/cv_varsel.R

Large diffs are not rendered by default.

18 changes: 2 additions & 16 deletions R/divergence_minimizers.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,25 +91,18 @@ divmin <- function(
if (!requireNamespace("iterators", quietly = TRUE)) {
stop("Please install the 'iterators' package.")
}
if (verbose_divmin && use_progressr()) {
progressor_obj <- progressr::progressor(length(formulas))
} else {
progressor_obj <- NULL
}
dot_args <- list(...)
`%do_projpred%` <- foreach::`%dopar%`
outdmin <- foreach::foreach(
formula_s = formulas,
projpred_var_s = iterators::iter(projpred_var, by = "column"),
projpred_formula_no_random_s = projpred_formulas_no_random,
.packages = c("projpred"),
.export = c("sdivmin", "projpred_random", "dot_args", "progressor_obj"),
.export = c("sdivmin", "projpred_random", "dot_args"),
.noexport = c(
"object", "p_sel", "search_path", "p_ref", "refmodel", "formulas",
"projpred_var", "projpred_ws_aug", "projpred_formulas_no_random"
)
) %do_projpred% {
if (!is.null(progressor_obj)) progressor_obj()
mssgs_warns_capt <- capt_mssgs_warns(
soutdmin <- do.call(
sdivmin,
Expand Down Expand Up @@ -656,26 +649,19 @@ divmin_augdat <- function(
if (!requireNamespace("iterators", quietly = TRUE)) {
stop("Please install the 'iterators' package.")
}
if (verbose_divmin && use_progressr()) {
progressor_obj <- progressr::progressor(ncol(projpred_ws_aug))
} else {
progressor_obj <- NULL
}
dot_args <- list(...)
`%do_projpred%` <- foreach::`%dopar%`
outdmin <- foreach::foreach(
projpred_w_aug_s = iterators::iter(projpred_ws_aug, by = "column"),
.packages = c("projpred"),
.export = c(
"sdivmin", "formula", "data", "family", "projpred_formula_no_random",
"projpred_random", "dot_args", "progressor_obj"
"projpred_random", "dot_args"
),
.noexport = c(
"object", "p_sel", "search_path", "p_ref", "refmodel", "projpred_var",
"projpred_ws_aug", "linkobjs"
)
) %do_projpred% {
if (!is.null(progressor_obj)) progressor_obj()
mssgs_warns_capt <- capt_mssgs_warns(
soutdmin <- do.call(
sdivmin,
Expand Down
2 changes: 1 addition & 1 deletion R/glmfun.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ standardization <- function(x, center = TRUE, scale = TRUE, weights = NULL) {
mx <- rep(0, ncol(x))
}
if (scale) {
sx <- apply(x, 2, .weighted_sd, w)
sx <- apply(x, 2, weighted.sd, w)
} else {
sx <- rep(1, ncol(x))
}
Expand Down
13 changes: 4 additions & 9 deletions R/misc.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
.onAttach <- function(...) {
ver <- utils::packageVersion("projpred")
msg <- paste0("This is projpred version ", ver, ".")
msg <- paste0(msg, "\n", "NOTE: In projpred 2.7.0, the default search ",
"method was set to \"forward\" (for all kinds of models).")
msg <- paste0(msg, " ", "NOTE: In projpred 2.7.0, the default search method ",
"was set to \"forward\" (for all kinds of models).")
packageStartupMessage(msg)
}

Expand All @@ -14,7 +14,7 @@ nms_y_wobs_test <- function(wobs_nm = "wobs") {
c("y", "y_oscale", wobs_nm)
}

.weighted_sd <- function(x, w, na.rm = FALSE) {
weighted.sd <- function(x, w, na.rm = FALSE) {
if (na.rm) {
ind <- !is.na(w) & !is.na(x)
n <- sum(ind)
Expand Down Expand Up @@ -63,7 +63,7 @@ ilinkfun_raw <- function(x, link_nm) {
return(basic_ilink(x))
}

.auc <- function(x) {
auc <- function(x) {
resp <- x[, 1]
pred <- x[, 2]
wobs <- x[, 3]
Expand Down Expand Up @@ -710,8 +710,3 @@ element_unq <- function(list_obj, nm) {
}
return(el_unq)
}

use_progressr <- function() {
getOption("projpred.use_progressr",
requireNamespace("progressr", quietly = TRUE) && interactive())
}
18 changes: 9 additions & 9 deletions R/summary_funs.R
Original file line number Diff line number Diff line change
Expand Up @@ -334,14 +334,14 @@ get_stat <- function(summaries, summaries_baseline = NULL,
} else {
# full LOO estimator
value <- mean(wobs * (mu - y)^2)
value_se <- .weighted_sd((mu - y)^2, wobs) / sqrt(n_full)
value_se <- weighted.sd((mu - y)^2, wobs) / sqrt(n_full)
}
# store for later calculations
mse_e <- value
if (!is.null(summaries_baseline)) {
# delta=TRUE, variance of difference of two normally distributed
mse_b <- mean(wobs * (mu_baseline - y)^2)
var_mse_b <- .weighted_sd((mu_baseline - y)^2, wobs)^2 / n_full
var_mse_b <- weighted.sd((mu_baseline - y)^2, wobs)^2 / n_full
if (n_loo < n_full) {
mse_e_fast <- mean(wobs * (summaries_fast$mu - y)^2)
srs_diffe <-
Expand Down Expand Up @@ -370,7 +370,7 @@ get_stat <- function(summaries, summaries_baseline = NULL,
mse_y <- mean(wobs * (mean(y) - y)^2)
value <- 1 - mse_e / mse_y - ifelse(is.null(summaries_baseline), 0, 1 - mse_b / mse_y)
# the first-order Taylor approximation of the variance
var_mse_y <- .weighted_sd((mean(y) - y)^2, wobs)^2 / n_full
var_mse_y <- weighted.sd((mean(y) - y)^2, wobs)^2 / n_full
if (n_loo < n_full) {
mse_e_fast <- mean(wobs * (summaries_fast$mu - y)^2)
if (is.null(summaries_baseline)) {
Expand Down Expand Up @@ -495,7 +495,7 @@ get_stat <- function(summaries, summaries_baseline = NULL,
} else {
# full LOO estimator
value <- mean(wobs * correct) - mean(wobs * correct_baseline)
value_se <- .weighted_sd(correct - correct_baseline, wobs) / sqrt(n_full)
value_se <- weighted.sd(correct - correct_baseline, wobs) / sqrt(n_full)
}
} else if (stat == "auc") {
if (n_loo < n_full) {
Expand All @@ -505,15 +505,15 @@ get_stat <- function(summaries, summaries_baseline = NULL,
if (!is.null(mu_baseline)) {
auc_data <- cbind(y, mu, wobs)
auc_data_baseline <- cbind(y, mu_baseline, wobs)
value <- .auc(auc_data) - .auc(auc_data_baseline)
value <- auc(auc_data) - auc(auc_data_baseline)
idxs_cols <- seq_len(ncol(auc_data))
idxs_cols_bs <- setdiff(seq_len(ncol(auc_data) + ncol(auc_data_baseline)),
idxs_cols)
diffvalue.bootstrap <- bootstrap(
cbind(auc_data, auc_data_baseline),
function(x) {
.auc(x[, idxs_cols, drop = FALSE]) -
.auc(x[, idxs_cols_bs, drop = FALSE])
auc(x[, idxs_cols, drop = FALSE]) -
auc(x[, idxs_cols_bs, drop = FALSE])
},
...
)
Expand All @@ -523,8 +523,8 @@ get_stat <- function(summaries, summaries_baseline = NULL,
names = FALSE, na.rm = TRUE)
} else {
auc_data <- cbind(y, mu, wobs)
value <- .auc(auc_data)
value.bootstrap <- bootstrap(auc_data, .auc, ...)
value <- auc(auc_data)
value.bootstrap <- bootstrap(auc_data, auc, ...)
value_se <- sd(value.bootstrap, na.rm = TRUE)
lq_uq <- quantile(value.bootstrap,
probs = c(alpha_half, one_minus_alpha_half),
Expand Down
36 changes: 12 additions & 24 deletions R/varsel.R
Original file line number Diff line number Diff line change
Expand Up @@ -248,27 +248,15 @@ varsel.vsel <- function(object, ...) {

#' @rdname varsel
#' @export
varsel.refmodel <- function(
object,
d_test = NULL,
method = "forward",
ndraws = NULL,
nclusters = 20,
ndraws_pred = 400,
nclusters_pred = NULL,
refit_prj = !inherits(object, "datafit"),
nterms_max = NULL,
verbose = getOption("projpred.verbose", interactive()),
search_control = NULL,
lambda_min_ratio = 1e-5,
nlambda = 150,
thresh = 1e-6,
penalty = NULL,
search_terms = NULL,
search_out = NULL,
seed = NA,
...
) {
varsel.refmodel <- function(object, d_test = NULL, method = "forward",
ndraws = NULL, nclusters = 20, ndraws_pred = 400,
nclusters_pred = NULL,
refit_prj = !inherits(object, "datafit"),
nterms_max = NULL, verbose = TRUE,
search_control = NULL, lambda_min_ratio = 1e-5,
nlambda = 150, thresh = 1e-6, penalty = NULL,
search_terms = NULL, search_out = NULL, seed = NA,
...) {
if (!missing(lambda_min_ratio)) {
warning("Argument `lambda_min_ratio` is deprecated. Please specify ",
"control arguments for the search via argument `search_control`. ",
Expand Down Expand Up @@ -388,7 +376,7 @@ varsel.refmodel <- function(
search_path <- search_out[["search_path"]]
} else {
verb_out("-----\nRunning the search ...", verbose = verbose)
search_path <- .select(
search_path <- select(
refmodel = refmodel, ndraws = ndraws, nclusters = nclusters,
method = method, nterms_max = nterms_max, penalty = penalty,
verbose = verbose, search_control = search_control,
Expand Down Expand Up @@ -523,8 +511,8 @@ varsel.refmodel <- function(
# `outdmins` (the submodel fits along the predictor ranking, with the number
# of fits per model size being equal to the number of projected draws), and
# `p_sel` (the output from get_refdist() for the search).
.select <- function(refmodel, ndraws, nclusters, reweighting_args = NULL,
method, nterms_max, penalty, verbose, search_control, ...) {
select <- function(refmodel, ndraws, nclusters, reweighting_args = NULL, method,
nterms_max, penalty, verbose, search_control, ...) {
if (is.null(reweighting_args)) {
p_sel <- get_refdist(refmodel, ndraws = ndraws, nclusters = nclusters)
} else {
Expand Down
2 changes: 1 addition & 1 deletion man/cv_varsel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/varsel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions tests/testthat/helpers/testers.R
Original file line number Diff line number Diff line change
Expand Up @@ -1976,8 +1976,7 @@ vsel_tester <- function(
search_control_expected = NULL,
extra_tol = 1.1,
info_str = ""
) {

) {
# Preparations:
if (with_cv) {
if (is.null(cv_method_expected)) {
Expand Down Expand Up @@ -2305,7 +2304,6 @@ vsel_tester <- function(
}
return(invisible(TRUE))
}

for (j in seq_along(vs$summaries$sub)) {
smmrs_sub_j_tester(vs$summaries$sub[[j]])
if (vs$refmodel$family$for_latent) {
Expand Down

0 comments on commit 6151b0c

Please sign in to comment.