diff --git a/R/checks.R b/R/checks.R
index b7111a34..76b97282 100644
--- a/R/checks.R
+++ b/R/checks.R
@@ -20,6 +20,26 @@ check_rset <- function(x) {
   if (inherits(x, "permutations")) {
     cli::cli_abort("Permutation samples are not suitable for tuning.")
   }
+
+  # Check fold weights if present
+  check_fold_weights(x)
+
+  invisible(NULL)
+}
+
+#' Check fold weights in rset objects
+#'
+#' @param x An rset object.
+#' @return `NULL` invisibly, or an error if the weights are invalid.
+#' @keywords internal
+check_fold_weights <- function(x) {
+  weights <- attr(x, ".fold_weights")
+  if (is.null(weights)) {
+    return(invisible(NULL))
+  }
+
+  .validate_fold_weights(weights, nrow(x))
+  invisible(NULL)
 }
diff --git a/R/collect.R b/R/collect.R
index b61ca8e9..114cf22d 100644
--- a/R/collect.R
+++ b/R/collect.R
@@ -551,6 +551,8 @@ estimate_tune_results <- function(x, ..., col_name = ".metrics") {
     )
   }
 
+  fold_weights <- .get_fold_weights(x)
+
   # The mapping of tuning parameters and .config.
   config_key <- .config_key_from_metrics(x)
 
@@ -590,15 +592,43 @@
 
   x <- x %>%
     tibble::as_tibble() %>%
-    vctrs::vec_slice(., .$id != "Apparent") %>%
-    dplyr::group_by(!!!rlang::syms(param_names), .metric, .estimator,
-                    !!!rlang::syms(group_cols)) %>%
-    dplyr::summarize(
-      mean = mean(.estimate, na.rm = TRUE),
-      n = sum(!is.na(.estimate)),
-      std_err = sd(.estimate, na.rm = TRUE) / sqrt(n),
-      .groups = "drop"
-    )
+    vctrs::vec_slice(., .$id != "Apparent")
+
+  # Join weights to the data if available
+  if (!is.null(fold_weights)) {
+    weight_data <- .create_weight_mapping(fold_weights, id_names, x)
+    if (!is.null(weight_data)) {
+      x <- dplyr::left_join(x, weight_data, by = id_names)
+    } else {
+      # If the weight mapping failed, fall back to unweighted aggregation
+      fold_weights <- NULL
+    }
+  }
+
+  if (!is.null(fold_weights)) {
+    # Use weighted aggregation
+    x <- x %>%
+      dplyr::group_by(!!!rlang::syms(param_names), .metric, .estimator,
+                      !!!rlang::syms(group_cols)) %>%
+      dplyr::summarize(
+        mean = .weighted_mean(.estimate, .fold_weight),
+        n = sum(!is.na(.estimate)),
+        effective_n = .effective_sample_size(.fold_weight[!is.na(.estimate)]),
+        std_err = .weighted_sd(.estimate, .fold_weight) / sqrt(pmax(effective_n, 1)),
+        .groups = "drop"
+      ) %>%
+      dplyr::select(-effective_n)
+  } else {
+    x <- x %>%
+      dplyr::group_by(!!!rlang::syms(param_names), .metric, .estimator,
+                      !!!rlang::syms(group_cols)) %>%
+      dplyr::summarize(
+        mean = mean(.estimate, na.rm = TRUE),
+        n = sum(!is.na(.estimate)),
+        std_err = sd(.estimate, na.rm = TRUE) / sqrt(n),
+        .groups = "drop"
+      )
+  }
 
   # only join when parameters are being tuned (#600)
   if (length(param_names) == 0) {
diff --git a/R/tune_grid.R b/R/tune_grid.R
index fe67c6a6..d9e337cc 100644
--- a/R/tune_grid.R
+++ b/R/tune_grid.R
@@ -396,6 +396,10 @@ pull_rset_attributes <- function(x) {
   att$class <- setdiff(class(x), class(tibble::new_tibble(list())))
   att$class <- att$class[att$class != "rset"]
 
+  if (!is.null(attr(x, ".fold_weights"))) {
+    att[[".fold_weights"]] <- attr(x, ".fold_weights")
+  }
+
   lab <- try(pretty(x), silent = TRUE)
   if (inherits(lab, "try-error")) {
     lab <- NA_character_
diff --git a/R/utils.R b/R/utils.R
index c94e70b2..79b72298 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -247,6 +247,190 @@ pretty.tune_results <- function(x, ...) {
   attr(x, "rset_info")$label
 }
 
+#' Fold weights utility functions
+#'
+#' These are internal functions for handling variable fold weights in
+#' hyperparameter tuning.
+#'
+#' @param x A tune_results object.
+#' @param weights Numeric vector of weights.
+#' @param id_names Character vector of ID column names.
+#' @param metrics_data The metrics data frame.
+#' @param w Numeric vector of weights.
+#' @param n_folds Integer number of folds.
+#'
+#' @return Various return values depending on the function.
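+#'
+#' @examples
+#' # A minimal sketch of the weighting arithmetic, on illustrative values:
+#' .weighted_mean(c(0.90, 0.80, 0.70), w = c(2, 1, 1))  # 0.825
+#' .effective_sample_size(c(2, 1, 1))                   # (2 + 1 + 1)^2 / 6 ~= 2.67
+#' .validate_fold_weights(c(3, 6, 9), n_folds = 3)      # 1/6, 1/3, 1/2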
+#' @keywords internal
+#' @name fold_weights_utils
+#' @aliases .create_weight_mapping .weighted_mean .weighted_sd .effective_sample_size .validate_fold_weights
+#' @export
+#' @rdname fold_weights_utils
+.get_fold_weights <- function(x) {
+  rset_info <- attr(x, "rset_info")
+  if (is.null(rset_info)) {
+    return(NULL)
+  }
+
+  # The weights live in the rset attributes captured by pull_rset_attributes()
+  rset_info$att[[".fold_weights"]]
+}
+
+#' @export
+#' @rdname fold_weights_utils
+.create_weight_mapping <- function(weights, id_names, metrics_data) {
+  # Get unique combinations of ID columns from the metrics data
+  unique_ids <- dplyr::distinct(metrics_data, !!!rlang::syms(id_names))
+
+  if (nrow(unique_ids) != length(weights)) {
+    cli::cli_warn(
+      "Number of weights ({length(weights)}) does not match number of resamples ({nrow(unique_ids)}). Weights will be ignored."
+    )
+    return(NULL)
+  }
+
+  # Add weights to the unique ID combinations
+  unique_ids$.fold_weight <- weights
+  unique_ids
+}
+
+#' @export
+#' @rdname fold_weights_utils
+.weighted_mean <- function(x, w) {
+  if (all(is.na(x))) {
+    return(NA_real_)
+  }
+
+  # Remove NA values and their corresponding weights
+  valid <- !is.na(x)
+  x_valid <- x[valid]
+  w_valid <- w[valid]
+
+  if (length(x_valid) == 0) {
+    return(NA_real_)
+  }
+
+  # Normalize weights so they sum to 1
+  w_valid <- w_valid / sum(w_valid)
+
+  sum(x_valid * w_valid)
+}
+
+#' @export
+#' @rdname fold_weights_utils
+.weighted_sd <- function(x, w) {
+  if (all(is.na(x))) {
+    return(NA_real_)
+  }
+
+  # Remove NA values and their corresponding weights
+  valid <- !is.na(x)
+  x_valid <- x[valid]
+  w_valid <- w[valid]
+
+  if (length(x_valid) <= 1) {
+    return(NA_real_)
+  }
+
+  # Normalize weights so they sum to 1
+  w_valid <- w_valid / sum(w_valid)
+
+  # Weighted mean of the valid values
+  weighted_mean <- sum(x_valid * w_valid)
+
+  # Weighted variance about the weighted mean
+  weighted_var <- sum(w_valid * (x_valid - weighted_mean)^2)
+
+  sqrt(weighted_var)
+}
+
+#' @export
+#' @rdname fold_weights_utils
+.effective_sample_size <- function(w) {
+  # Remove NA weights
+  w <- w[!is.na(w)]
+
+  if (length(w) == 0) {
+    return(0)
+  }
+
+  # Kish effective sample size: (sum of weights)^2 / sum of squared weights
+  sum_w <- sum(w)
+  sum_w_sq <- sum(w^2)
+
+  if (sum_w_sq == 0) {
+    return(0)
+  }
+
+  sum_w^2 / sum_w_sq
+}
+
+#' @export
+#' @rdname fold_weights_utils
+.validate_fold_weights <- function(weights, n_folds) {
+  if (is.null(weights)) {
+    return(NULL)
+  }
+
+  if (!is.numeric(weights)) {
+    cli::cli_abort("{.arg weights} must be numeric.")
+  }
+
+  if (length(weights) != n_folds) {
+    cli::cli_abort(
+      "Length of {.arg weights} ({length(weights)}) must equal number of folds ({n_folds})."
+    )
+  }
+
+  if (any(weights < 0)) {
+    cli::cli_abort("{.arg weights} must be non-negative.")
+  }
+
+  if (all(weights == 0)) {
+    cli::cli_abort("At least one weight must be positive.")
+  }
+
+  # Return the weights normalized to sum to 1
+  weights / sum(weights)
+}
+
+#' Add fold weights to an rset object
+#'
+#' @param rset An rset object.
+#' @param weights A numeric vector of weights, one per resample. Weights are
+#'   validated and normalized to sum to 1.
+#' @return The rset object with the normalized weights stored in the
+#'   `.fold_weights` attribute.
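+#'
+#' @examples
+#' # A small sketch: three unevenly weighted folds; the weights are
+#' # normalized to sum to 1 on the way in
+#' folds <- rsample::vfold_cv(mtcars, v = 3)
+#' weighted <- add_fold_weights(folds, c(1, 2, 2))
+#' attr(weighted, ".fold_weights")  # 0.2, 0.4, 0.4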
+#' @export
+add_fold_weights <- function(rset, weights) {
+  if (!inherits(rset, "rset")) {
+    cli::cli_abort("{.arg rset} must be an rset object.")
+  }
+
+  # Validate and normalize the weights
+  weights <- .validate_fold_weights(weights, nrow(rset))
+
+  # Store the weights as an attribute
+  attr(rset, ".fold_weights") <- weights
+
+  rset
+}
+
+#' Calculate fold weights from fold sizes
+#'
+#' @param rset An rset object.
+#' @return A numeric vector of weights proportional to the analysis set
+#'   sizes, normalized to sum to 1.
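+#'
+#' @examples
+#' # Sketch: weights proportional to each analysis set's size
+#' folds <- rsample::vfold_cv(mtcars, v = 3)
+#' calculate_fold_weights(folds)  # positive, sums to 1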
+#' @export
+calculate_fold_weights <- function(rset) {
+  if (!inherits(rset, "rset")) {
+    cli::cli_abort("{.arg rset} must be an rset object.")
+  }
+
+  # Calculate the size of each analysis set
+  fold_sizes <- purrr::map_int(rset$splits, ~ nrow(rsample::analysis(.x)))
+
+  # Return weights proportional to fold sizes
+  fold_sizes / sum(fold_sizes)
+}
 
 # ------------------------------------------------------------------------------
diff --git a/tests/testthat/test-checks.R b/tests/testthat/test-checks.R
index 4a88ff2e..829581f5 100644
--- a/tests/testthat/test-checks.R
+++ b/tests/testthat/test-checks.R
@@ -509,3 +509,140 @@ test_that("check parameter finalization", {
   )
   expect_true(inherits(p5, "parameters"))
 })
+
+test_that("check fold weights", {
+  folds <- rsample::vfold_cv(mtcars, v = 3)
+
+  # No weights should pass silently
+  expect_no_error(tune:::check_fold_weights(folds))
+
+  # Valid weights should pass
+  weights <- c(0.1, 0.5, 0.4)
+  weighted_folds <- add_fold_weights(folds, weights)
+  expect_no_error(tune:::check_fold_weights(weighted_folds))
+
+  # Invalid weights should error
+  expect_error(
+    add_fold_weights(folds, c("a", "b", "c")),
+    "must be numeric"
+  )
+
+  expect_error(
+    add_fold_weights(folds, c(0.5, 0.3)),
+    "must equal number of folds"
+  )
+
+  expect_error(
+    add_fold_weights(folds, c(-0.1, 0.5, 0.6)),
+    "must be non-negative"
+  )
+})
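+
+# Illustrative check of the normalization contract documented for
+# add_fold_weights()
+test_that("fold weights are normalized when stored", {
+  folds <- rsample::vfold_cv(mtcars, v = 4)
+  weighted <- add_fold_weights(folds, c(2, 2, 4, 8))
+  expect_equal(attr(weighted, ".fold_weights"), c(0.125, 0.125, 0.25, 0.5))
+})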
+
+test_that("fold weights integration test", {
+  skip_if_not_installed("rsample")
+  skip_if_not_installed("parsnip")
+  skip_if_not_installed("recipes")
+  skip_if_not_installed("workflows")
+
+  # Create simple data and resamples
+  set.seed(1234)
+  data_small <- mtcars[1:20, ]
+  folds <- rsample::vfold_cv(data_small, v = 3)
+
+  # Create a simple model and recipe
+  simple_rec <- recipes::recipe(mpg ~ wt + hp, data = data_small)
+  simple_mod <- parsnip::linear_reg() %>% parsnip::set_engine("lm")
+  simple_wflow <- workflows::workflow() %>%
+    workflows::add_recipe(simple_rec) %>%
+    workflows::add_model(simple_mod)
+
+  # Equal weights should match the unweighted results
+  equal_weights <- c(1, 1, 1)
+  weighted_folds_equal <- add_fold_weights(folds, equal_weights)
+
+  # Fit both weighted and unweighted
+  unweighted_results <- fit_resamples(
+    simple_wflow, folds,
+    control = control_resamples(save_pred = FALSE)
+  )
+  weighted_results_equal <- fit_resamples(
+    simple_wflow, weighted_folds_equal,
+    control = control_resamples(save_pred = FALSE)
+  )
+
+  # Extract metrics
+  unweighted_metrics <- collect_metrics(unweighted_results)
+  weighted_metrics_equal <- collect_metrics(weighted_results_equal)
+
+  # Should be nearly identical (allowing for small numerical differences)
+  expect_equal(unweighted_metrics$mean, weighted_metrics_equal$mean, tolerance = 1e-10)
+
+  # Unequal weights, with the most weight on the last fold
+  unequal_weights <- c(0.1, 0.3, 0.6)
+  weighted_folds_unequal <- add_fold_weights(folds, unequal_weights)
+
+  weighted_results_unequal <- fit_resamples(
+    simple_wflow, weighted_folds_unequal,
+    control = control_resamples(save_pred = FALSE)
+  )
+  weighted_metrics_unequal <- collect_metrics(weighted_results_unequal)
+
+  # Should differ from the unweighted results
+  expect_false(all(abs(unweighted_metrics$mean - weighted_metrics_unequal$mean) < 1e-10))
+
+  # Weights are stored and retrieved; these already sum to 1, so
+  # normalization leaves them unchanged
+  expect_equal(attr(weighted_folds_unequal, ".fold_weights"), unequal_weights)
+
+  # Test fold size calculation
+  calculated_weights <- calculate_fold_weights(folds)
+  expect_length(calculated_weights, nrow(folds))
+  expect_true(all(calculated_weights > 0))
+  expect_equal(sum(calculated_weights), 1)  # normalized to sum to 1
+})
+
+test_that("fold weights with tune_grid", {
+  skip_if_not_installed("rsample")
+  skip_if_not_installed("parsnip")
+  skip_if_not_installed("recipes")
+  skip_if_not_installed("workflows")
+  skip_if_not_installed("dials")
+  skip_if_not_installed("glmnet")
+
+  # Create a simple tuning scenario
+  set.seed(5678)
+  data_small <- mtcars[1:15, ]
+  folds <- rsample::vfold_cv(data_small, v = 3)
+
+  # Create a tunable workflow
+  tune_rec <- recipes::recipe(mpg ~ wt + hp, data = data_small) %>%
+    recipes::step_normalize(recipes::all_predictors())
+  tune_mod <- parsnip::linear_reg(penalty = tune()) %>%
+    parsnip::set_engine("glmnet")
+  tune_wflow <- workflows::workflow() %>%
+    workflows::add_recipe(tune_rec) %>%
+    workflows::add_model(tune_mod)
+
+  # Create a simple grid
+  simple_grid <- tibble::tibble(penalty = c(0.001, 0.01, 0.1))
+
+  # Tune with unequal weights
+  weights <- c(0.2, 0.3, 0.5)
+  weighted_folds <- add_fold_weights(folds, weights)
+
+  weighted_tune_results <- tune_grid(
+    tune_wflow, weighted_folds,
+    grid = simple_grid,
+    control = control_grid(save_pred = FALSE)
+  )
+
+  # Verify the results structure
+  expect_s3_class(weighted_tune_results, "tune_results")
+
+  # Extract metrics and verify they are computed
+  weighted_metrics <- collect_metrics(weighted_tune_results)
+  expect_true(nrow(weighted_metrics) > 0)
+  expect_true(all(c("mean", "std_err") %in% names(weighted_metrics)))
+
+  # Compare with unweighted results
+  unweighted_tune_results <- tune_grid(
+    tune_wflow, folds,
+    grid = simple_grid,
+    control = control_grid(save_pred = FALSE)
+  )
+  unweighted_metrics <- collect_metrics(unweighted_tune_results)
+
+  # Results should differ because of the weighting
+  expect_false(all(abs(weighted_metrics$mean - unweighted_metrics$mean) < 1e-10))
+})
+
+# ------------------------------------------------------------------------------
diff --git a/tests/testthat/test-weights.R b/tests/testthat/test-weights.R
new file mode 100644
index 00000000..c96387d5
--- /dev/null
+++ b/tests/testthat/test-weights.R
@@ -0,0 +1,311 @@
+# Test file for variable fold weights functionality
+if (rlang::is_installed(c("rsample", "parsnip", "yardstick", "workflows", "recipes", "kknn"))) {
+
+  # Setup test data
+  set.seed(42)
+  test_data <- data.frame(
+    x1 = rnorm(50),
+    x2 = rnorm(50),
+    x3 = rnorm(50)
+  )
+  test_data$y <- 2 * test_data$x1 + 3 * test_data$x2 + rnorm(50, sd = 0.5)
+
+  set.seed(123)
+  folds <- rsample::vfold_cv(mtcars, v = 3)
+
+  # Helper function to create a simple model
+  create_test_model <- function() {
+    parsnip::linear_reg() %>% parsnip::set_engine("lm")
+  }
+
+  test_that("add_fold_weights() validates inputs correctly", {
+    expect_error(
+      add_fold_weights("not_an_rset", c(0.5, 0.3, 0.2)),
+      "must be an rset object"
+    )
+
+    expect_error(
+      add_fold_weights(folds, c("a", "b", "c")),
+      "must be numeric"
+    )
+
+    expect_error(
+      add_fold_weights(folds, c(0.5, 0.3)),
+      "must equal number of folds"
+    )
+
+    expect_error(
+      add_fold_weights(folds, c(-0.1, 0.5, 0.6)),
+      "must be non-negative"
+    )
+
+    expect_error(
+      add_fold_weights(folds, c(0, 0, 0)),
+      "At least one weight must be positive"
+    )
+  })
+
+  test_that("add_fold_weights() adds weights correctly", {
+    weights <- c(0.1, 0.5, 0.4)
+    weighted_folds <- add_fold_weights(folds, weights)
+
+    # Weights get normalized to sum to 1
+    expected_weights <- weights / sum(weights)
+
+    expect_s3_class(weighted_folds, "rset")
+    expect_equal(attr(weighted_folds, ".fold_weights"), expected_weights)
+    expect_equal(nrow(weighted_folds), nrow(folds))
+  })
+
+  test_that("calculate_fold_weights() works correctly", {
+    auto_weights <- calculate_fold_weights(folds)
+
+    expect_type(auto_weights, "double")
+    expect_length(auto_weights, nrow(folds))
+    expect_true(all(auto_weights > 0))
+    expect_true(abs(sum(auto_weights) - 1) < 1e-10)
+  })
+
+  test_that("weights are preserved through the tuning pipeline", {
+    weights <- c(0.1, 0.5, 0.4)
+    weighted_folds <- add_fold_weights(folds, weights)
+
+    mod <- create_test_model()
+
+    # tune_grid() warns when there is nothing to tune, hence suppressWarnings()
+    suppressWarnings({
+      res <- tune_grid(
+        mod,
+        mpg ~ .,
+        resamples = weighted_folds,
+        grid = 1,
+        metrics = yardstick::metric_set(yardstick::rmse),
+        control = control_grid(verbose = FALSE)
+      )
+    })
+
+    metrics <- collect_metrics(res)
+    expect_equal(nrow(metrics), 1)
+    expect_true("mean" %in% names(metrics))
+    expect_true(is.numeric(metrics$mean))
+  })
+
+  test_that("weights affect metric aggregation", {
+    weights <- c(0.1, 0.5, 0.4)
+    weighted_folds <- add_fold_weights(folds, weights)
+
+    mod <- create_test_model()
+
+    suppressWarnings({
+      # Unweighted results
+      res_unweighted <- tune_grid(
+        mod,
+        mpg ~ .,
+        resamples = folds,
+        grid = 1,
+        metrics = yardstick::metric_set(yardstick::rmse),
+        control = control_grid(verbose = FALSE)
+      )
+
+      # Weighted results
+      res_weighted <- tune_grid(
+        mod,
+        mpg ~ .,
+        resamples = weighted_folds,
+        grid = 1,
+        metrics = yardstick::metric_set(yardstick::rmse),
+        control = control_grid(verbose = FALSE)
+      )
+    })
+
+    unweighted_rmse <- collect_metrics(res_unweighted)$mean[1]
+    weighted_rmse <- collect_metrics(res_weighted)$mean[1]
+
+    expect_true(is.numeric(unweighted_rmse))
+    expect_true(is.numeric(weighted_rmse))
+    expect_false(is.na(unweighted_rmse))
+    expect_false(is.na(weighted_rmse))
+  })
+
+  test_that("extreme weights show larger effect", {
+    skip_if_not_installed("kknn")
+
+    # Create folds for this specific test
+    set.seed(42)
+    test_folds <- rsample::vfold_cv(test_data, v = 3)
+
+    # Regular weights
+    weights <- c(0.6, 0.2, 0.2)
+    weighted_folds <- add_fold_weights(test_folds, weights)
+
+    # Extreme weights
+    extreme_weights <- c(0.95, 0.025, 0.025)
+    extreme_weighted_folds <- add_fold_weights(test_folds, extreme_weights)
+
+    # Create a model with a tuning parameter
+    knn_spec <- parsnip::nearest_neighbor(neighbors = tune()) %>%
+      parsnip::set_engine("kknn") %>%
+      parsnip::set_mode("regression")
+
+    param_grid <- data.frame(neighbors = c(3, 5))
+
+    suppressWarnings({
+      # Unweighted
+      res_unweighted <- tune_grid(
+        knn_spec,
+        y ~ .,
+        resamples = test_folds,
+        grid = param_grid,
+        metrics = yardstick::metric_set(yardstick::rmse),
+        control = control_grid(verbose = FALSE)
+      )
+
+      # Regular weights
+      res_weighted <- tune_grid(
+        knn_spec,
+        y ~ .,
+        resamples = weighted_folds,
+        grid = param_grid,
+        metrics = yardstick::metric_set(yardstick::rmse),
+        control = control_grid(verbose = FALSE)
+      )
+
+      # Extreme weights
+      res_extreme <- tune_grid(
+        knn_spec,
+        y ~ .,
+        resamples = extreme_weighted_folds,
+        grid = param_grid,
+        metrics = yardstick::metric_set(yardstick::rmse),
+        control = control_grid(verbose = FALSE)
+      )
+    })
+
+    unweighted_metrics <- collect_metrics(res_unweighted)
+    weighted_metrics <- collect_metrics(res_weighted)
+    extreme_metrics <- collect_metrics(res_extreme)
+
+    # Check that results exist and are sensible
+    expect_equal(nrow(unweighted_metrics), 2)
+    expect_equal(nrow(weighted_metrics), 2)
+    expect_equal(nrow(extreme_metrics), 2)
+
+    # Calculate differences
+    regular_diff <- max(abs(unweighted_metrics$mean - weighted_metrics$mean))
+    extreme_diff <- max(abs(unweighted_metrics$mean - extreme_metrics$mean))
+
+    expect_true(regular_diff >= 0)
+    expect_true(extreme_diff >= 0)
+    expect_true(all(is.finite(c(regular_diff, extreme_diff))))
+  })
+
+  test_that("weight normalization works correctly", {
+    expect_equal(
+      tune:::.validate_fold_weights(c(3, 6, 9), 3),
+      c(1 / 6, 1 / 3, 1 / 2)  # normalized to sum to 1
+    )
+
+    expect_equal(
+      tune:::.validate_fold_weights(c(0.2, 0.3, 0.5), 3),
+      c(0.2, 0.3, 0.5)  # already sums to 1, so unchanged
+    )
+  })
+
+  test_that("weighted statistics functions work correctly", {
+    x <- c(1, 2, 3, 4, 5)
+    w <- c(0.1, 0.2, 0.3, 0.2, 0.2)
+
+    weighted_mean <- tune:::.weighted_mean(x, w)
+    weighted_sd <- tune:::.weighted_sd(x, w)
+
+    expect_true(is.numeric(weighted_mean))
+    expect_true(is.numeric(weighted_sd))
+    expect_false(is.na(weighted_mean))
+    expect_false(is.na(weighted_sd))
+    expect_true(weighted_sd >= 0)
+
+    # Test with NA values
+    x_na <- c(1, 2, NA, 4, 5)
+    weighted_mean_na <- tune:::.weighted_mean(x_na, w)
+    weighted_sd_na <- tune:::.weighted_sd(x_na, w)
+
+    expect_true(is.numeric(weighted_mean_na))
+    expect_true(is.numeric(weighted_sd_na))
+    expect_false(is.na(weighted_mean_na))
+
+    # Test edge cases
+    expect_true(is.na(tune:::.weighted_mean(c(NA, NA), c(0.5, 0.5))))
+    expect_true(is.na(tune:::.weighted_sd(c(1), c(1))))  # single value
+  })
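+
+  # Illustrative check: the effective sample size follows the Kish formula
+  # (sum w)^2 / sum(w^2), and equal weights recover the raw fold count
+  test_that(".effective_sample_size() matches the Kish formula", {
+    expect_equal(tune:::.effective_sample_size(c(1, 1, 1)), 3)
+    expect_equal(tune:::.effective_sample_size(c(2, 1, 1)), 16 / 6)
+    expect_equal(tune:::.effective_sample_size(numeric(0)), 0)
+  })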
+
+  test_that("fold weight extraction works", {
+    weights <- c(0.1, 0.5, 0.4)
+    weighted_folds <- add_fold_weights(folds, weights)
+
+    # Weights get normalized to sum to 1
+    expected_weights <- weights / sum(weights)
+
+    mod <- create_test_model()
+
+    suppressWarnings({
+      res <- tune_grid(
+        mod,
+        mpg ~ .,
+        resamples = weighted_folds,
+        grid = 1,
+        metrics = yardstick::metric_set(yardstick::rmse),
+        control = control_grid(verbose = FALSE)
+      )
+    })
+
+    extracted_weights <- tune:::.get_fold_weights(res)
+    expect_equal(extracted_weights, expected_weights)
+  })
+
+  test_that("individual fold metrics can be collected", {
+    weights <- c(0.1, 0.5, 0.4)
+    weighted_folds <- add_fold_weights(folds, weights)
+
+    mod <- create_test_model()
+
+    suppressWarnings({
+      res <- tune_grid(
+        mod,
+        mpg ~ .,
+        resamples = weighted_folds,
+        grid = 1,
+        metrics = yardstick::metric_set(yardstick::rmse),
+        control = control_grid(verbose = FALSE)
+      )
+    })
+
+    # Collect individual fold metrics
+    individual_metrics <- collect_metrics(res, summarize = FALSE)
+
+    expect_true(nrow(individual_metrics) >= 3)  # at least one metric per fold
+    expect_true("id" %in% names(individual_metrics))
+    expect_true(".estimate" %in% names(individual_metrics))
+    expect_true(all(is.finite(individual_metrics$.estimate)))
+  })
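+
+  # Illustrative check: with equal weights, the weighted mean reduces to
+  # the ordinary mean (the invariant behind the equal-weights integration
+  # test in test-checks.R)
+  test_that(".weighted_mean() with equal weights matches mean()", {
+    x <- c(0.9, 0.8, 0.7)
+    expect_equal(tune:::.weighted_mean(x, c(1, 1, 1)), mean(x))
+  })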
+
+  test_that("backwards compatibility - no weights", {
+    mod <- create_test_model()
+
+    suppressWarnings({
+      res <- tune_grid(
+        mod,
+        mpg ~ .,
+        resamples = folds,  # no weights
+        grid = 1,
+        metrics = yardstick::metric_set(yardstick::rmse),
+        control = control_grid(verbose = FALSE)
+      )
+    })
+
+    metrics <- collect_metrics(res)
+    expect_equal(nrow(metrics), 1)
+    expect_true("mean" %in% names(metrics))
+    expect_true(is.numeric(metrics$mean))
+    expect_false(is.na(metrics$mean))
+  })
+
+}
\ No newline at end of file