diff --git a/.github/recipe/recipe.yaml b/.github/recipe/recipe.yaml index 7a678ef8..211c37ab 100644 --- a/.github/recipe/recipe.yaml +++ b/.github/recipe/recipe.yaml @@ -56,9 +56,11 @@ requirements: - r-mr.mashr - r-mvsusier - r-ncvreg + - r-nnls - r-pgenlibr - r-purrr - r-qgg + - r-quadprog - r-rcppdpr - r-readr - r-rfast @@ -98,6 +100,7 @@ requirements: - r-mr.mashr - r-mvsusier - r-ncvreg + - r-nnls - r-pgenlibr - r-purrr - r-qgg diff --git a/DESCRIPTION b/DESCRIPTION index 5e858b8c..78c47760 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -56,7 +56,9 @@ Suggests: mr.mashr, mvsusieR, ncvreg, + nnls, pgenlibr, + quadprog, qgg, qvalue, rmarkdown, diff --git a/NAMESPACE b/NAMESPACE index 18fb675b..ced6d96a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -23,9 +23,13 @@ export(compute_qtl_enrichment) export(ctwas_bimfile_loader) export(dentist) export(dentist_single_window) +export(dpr_adaptive_gibbs_weights) +export(dpr_gibbs_weights) +export(dpr_vb_weights) export(dpr_weights) export(enet_weights) export(enforce_design_full_rank) +export(ensemble_weights) export(extract_cs_info) export(extract_flatten_sumstats_from_nested) export(extract_top_pip_info) @@ -185,10 +189,12 @@ importFrom(readr,read_lines) importFrom(rlang,"!!!") importFrom(stats,as.dist) importFrom(stats,coef) +importFrom(stats,complete.cases) importFrom(stats,cor) importFrom(stats,cutree) importFrom(stats,hclust) importFrom(stats,lm.fit) +importFrom(stats,optim) importFrom(stats,pchisq) importFrom(stats,pnorm) importFrom(stats,predict) diff --git a/R/regularized_regression.R b/R/regularized_regression.R index a003d4e6..0d86c658 100644 --- a/R/regularized_regression.R +++ b/R/regularized_regression.R @@ -1208,3 +1208,15 @@ dpr_weights <- function(X, y, fitting_method = "VB", ...) { eff.wgt[keep] <- as.numeric(fit$beta + fit$alpha) return(eff.wgt) } + +#' @rdname dpr_weights +#' @export +dpr_vb_weights <- function(X, y, ...) dpr_weights(X, y, fitting_method = "VB", ...) + +#' @rdname dpr_weights +#' @export +dpr_gibbs_weights <- function(X, y, ...) dpr_weights(X, y, fitting_method = "Gibbs", ...) + +#' @rdname dpr_weights +#' @export +dpr_adaptive_gibbs_weights <- function(X, y, ...) dpr_weights(X, y, fitting_method = "Adaptive_Gibbs", ...) diff --git a/R/twas_weights.R b/R/twas_weights.R index c2bead97..336d6aef 100644 --- a/R/twas_weights.R +++ b/R/twas_weights.R @@ -1,3 +1,68 @@ +# Map short method names and presets to weight_methods lists. +# @param methods A character vector of short method names, or a preset string +# ("default" or "fast_default"). +# @return A named list suitable for the weight_methods parameter. +# @noRd +.twas_method_lookup <- function(methods) { + method_map <- list( + susie = list(fn = "susie_weights", args = list(refine = FALSE, init_L = 5, max_L = 20)), + susie_ash = list(fn = "susie_ash_weights", args = list()), + susie_inf = list(fn = "susie_inf_weights", args = list()), + mrash = list(fn = "mrash_weights", args = list(init_prior_sd = TRUE, max.iter = 100)), + enet = list(fn = "enet_weights", args = list()), + lasso = list(fn = "lasso_weights", args = list()), + bayes_r = list(fn = "bayes_r_weights", args = list()), + bayes_l = list(fn = "bayes_l_weights", args = list()), + bayes_a = list(fn = "bayes_a_weights", args = list()), + bayes_b = list(fn = "bayes_b_weights", args = list()), + bayes_c = list(fn = "bayes_c_weights", args = list()), + bayes_n = list(fn = "bayes_n_weights", args = list()), + b_lasso = list(fn = "b_lasso_weights", args = list()), + dpr_vb = list(fn = "dpr_vb_weights", args = list()), + dpr_gibbs = list(fn = "dpr_gibbs_weights", args = list()), + dpr_adaptive_gibbs = list(fn = "dpr_adaptive_gibbs_weights", args = list()), + scad = list(fn = "scad_weights", args = list()), + mcp = list(fn = "mcp_weights", args = list()), + l0learn = list(fn = "l0learn_weights", args = list()), + mvsusie = list(fn = "mvsusie_weights", args = list()), + mrmash = list(fn = "mrmash_weights", args = list()) + ) + + # Handle presets + if (length(methods) == 1) { + if (methods == "fast_default") { + methods <- c("susie", "mrash", "enet", "lasso") + } else if (methods == "default") { + methods <- c("susie", "mrash", "enet", "lasso", "bayes_r", "dpr_gibbs") + } + } + + # Build reverse map: function name -> short name, so full names are accepted too + fn_to_short <- setNames( + names(method_map), + vapply(method_map, function(x) x$fn, character(1)) + ) + # Normalize any full function names to short names + methods <- vapply(methods, function(m) { + if (m %in% names(fn_to_short)) fn_to_short[[m]] else m + }, character(1), USE.NAMES = FALSE) + + unknown <- setdiff(methods, names(method_map)) + if (length(unknown) > 0) { + stop( + "Unknown TWAS method(s): ", paste(unknown, collapse = ", "), + ". Available methods: ", paste(names(method_map), collapse = ", ") + ) + } + + result <- list() + for (m in methods) { + entry <- method_map[[m]] + result[[entry$fn]] <- entry$args + } + result +} + # Identify non-zero-variance columns of X. Returns a logical vector. #' @importFrom matrixStats colSds #' @noRd @@ -131,7 +196,7 @@ twas_weights_cv <- function(X, Y, fold = NULL, sample_partitions = NULL, weight_ } if (is.character(weight_methods)) { - weight_methods <- lapply(setNames(nm = weight_methods), function(x) list()) + weight_methods <- .twas_method_lookup(weight_methods) } if (!exists(".Random.seed")) { @@ -355,7 +420,7 @@ twas_weights <- function(X, Y, weight_methods, num_threads = 1) { } if (is.character(weight_methods)) { - weight_methods <- lapply(setNames(nm = weight_methods), function(x) list()) + weight_methods <- .twas_method_lookup(weight_methods) } # Determine number of cores to use @@ -450,6 +515,18 @@ twas_predict <- function(X, weights_list) { #' @param max_cv_variants The maximum number of variants to be included in cross-validation. Defaults to -1 which means no limit. #' @param cv_threads The number of threads to use for parallel computation in cross-validation. Defaults to 1. #' @param cv_weight_methods List of methods to use for cross-validation. If NULL, uses the same methods as weight_methods. +#' @param ensemble Logical. If TRUE and cv_folds > 1, learn ensemble combination +#' weights via stacked regression (SR-TWAS). Requires at least two individual +#' methods to have been run and to pass the R-squared cutoff. Defaults to FALSE. +#' @param ensemble_r2_threshold Minimum cross-validated R-squared for an individual method +#' to be included in the ensemble. Methods below this threshold are excluded. +#' Defaults to 0.01. +#' @param ensemble_solver Character string specifying the optimization backend +#' for ensemble learning. One of \code{"quadprog"}, \code{"nnls"}, +#' \code{"lbfgsb"}, or \code{"glmnet"}. Passed to +#' \code{\link{ensemble_weights}}. Defaults to \code{"quadprog"}. +#' @param ensemble_alpha Elastic net mixing parameter, used only when +#' \code{ensemble_solver = "glmnet"}. Defaults to 1 (lasso). #' #' @return A list containing results from the TWAS pipeline, including TWAS weights, predictions, and optionally cross-validation results. #' @export @@ -462,17 +539,18 @@ twas_weights_pipeline <- function(X, susie_fit = NULL, cv_folds = 5, sample_partition = NULL, - weight_methods = list( - enet_weights = list(), - lasso_weights = list(), - bayes_r_weights = list(), - bayes_l_weights = list(), - mrash_weights = list(init_prior_sd = TRUE, max.iter = 100), - susie_weights = list(refine = FALSE, init_L = 5, max_L = 20) - ), + weight_methods = "default", max_cv_variants = -1, cv_threads = 1, - cv_weight_methods = NULL) { + cv_weight_methods = NULL, + ensemble = FALSE, + ensemble_r2_threshold = 0.01, + ensemble_solver = "quadprog", + ensemble_alpha = 1) { + if (is.character(weight_methods)) { + weight_methods <- .twas_method_lookup(weight_methods) + } + res <- list() st <- proc.time() message("Performing TWAS weights computation for univariate analysis methods ...") @@ -521,6 +599,65 @@ twas_weights_pipeline <- function(X, num_threads = cv_threads, variants_to_keep = if (length(variants_for_cv) > 0) variants_for_cv else NULL ) + + # Ensemble learning: learn optimal method combination via stacked regression + if (ensemble) { + n_methods <- length(cv_weight_methods) + if (n_methods < 2) { + message("Ensemble TWAS requires at least 2 weight methods to be used. ", + "Only ", n_methods, " method was provided. Skipping ensemble.") + } else if (!is.null(res$twas_cv_result$performance)) { + # Extract R² for each method from CV performance table + method_rsq <- vapply(res$twas_cv_result$performance, function(perf) { + perf[1, "rsq"] + }, numeric(1)) + names(method_rsq) <- gsub("_performance$", "", names(method_rsq)) + + passing <- !is.na(method_rsq) & method_rsq >= ensemble_r2_threshold + n_passing <- sum(passing) + + if (n_passing < 2) { + passed_info <- paste0(" ", names(method_rsq), ": R² = ", + round(method_rsq, 4), + ifelse(passing, " (passed)", " (failed)")) + message("Ensemble TWAS could not be run because fewer than 2 methods ", + "passed the R² cutoff of ", ensemble_r2_threshold, ".\n", + "Method R² values:\n", + paste(passed_info, collapse = "\n")) + } else { + passing_base <- names(method_rsq)[passing] + passing_pred_names <- paste0(passing_base, "_predicted") + passing_weight_names <- paste0(passing_base, "_weights") + + # Subset cv_results predictions to passing methods + filtered_cv <- res$twas_cv_result + filtered_cv$prediction <- filtered_cv$prediction[passing_pred_names] + + # Subset twas_weights to passing methods + filtered_weights <- res$twas_weights[passing_weight_names] + + message("Computing ensemble TWAS weights via stacked regression ", + "using ", n_passing, " methods: ", + paste(passing_base, collapse = ", "), " ...") + ens_result <- ensemble_weights( + cv_results = filtered_cv, + Y = y, + twas_weight_list = filtered_weights, + solver = ensemble_solver, + alpha = ensemble_alpha + ) + + # Add ensemble weights alongside individual method weights + if (!is.null(ens_result$ensemble_twas_weights)) { + res$twas_weights$ensemble_weights <- ens_result$ensemble_twas_weights + ens_wt <- ens_result$ensemble_twas_weights + if (!is.matrix(ens_wt)) ens_wt <- matrix(ens_wt, ncol = 1) + res$twas_predictions$ensemble_predicted <- X %*% ens_wt + } + res$ensemble <- ens_result + } + } + } } res$total_time_elapsed <- proc.time() - st @@ -668,3 +805,516 @@ twas_multivariate_weights_pipeline <- function( } return(res) } + + +# Solve ensemble stacking via quadprog (constrained QP with sum-to-1 and non-negativity). +# @param P_valid Matrix of CV predictions for valid methods (n x K_valid). +# @param y_obs Observed outcome vector (n). +# @param K_valid Number of valid methods. +# @return Normalized coefficient vector of length K_valid. +# @noRd +.solve_ensemble_quadprog <- function(P_valid, y_obs, K_valid) { + if (!requireNamespace("quadprog", quietly = TRUE)) { + stop("Package 'quadprog' is required for solver='quadprog'. ", + "Install with: install.packages('quadprog')") + } + + Dmat <- crossprod(P_valid) + dvec <- as.vector(crossprod(P_valid, y_obs)) + # Ridge term for numerical stability (small relative to trace) + Dmat <- Dmat + 1e-8 * mean(diag(Dmat)) * diag(K_valid) + + # Constraint matrix: first constraint is equality (sum = 1), then K_valid + # non-negativity constraints. + Amat <- cbind(rep(1, K_valid), diag(K_valid)) + bvec <- c(1, rep(0, K_valid)) + + qp_sol <- tryCatch( + quadprog::solve.QP(Dmat = Dmat, dvec = dvec, Amat = Amat, bvec = bvec, meq = 1), + error = function(e) { + warning("QP solver failed: ", conditionMessage(e), + ". Falling back to equal weights among valid methods.") + NULL + } + ) + + if (is.null(qp_sol)) { + return(rep(1 / K_valid, K_valid)) + } + + # Numerical cleanup: clamp to non-negative and renormalize + zeta_valid <- pmax(qp_sol$solution, 0) + zeta_sum <- sum(zeta_valid) + if (zeta_sum <= 0) { + warning("QP returned all-zero solution. Falling back to equal weights.") + return(rep(1 / K_valid, K_valid)) + } + zeta_valid / zeta_sum +} + +# Solve ensemble stacking via NNLS (non-negative least squares, then normalize). +# This is the approach used by SuperLearner (Lawson-Hanson algorithm). +# @param P_valid Matrix of CV predictions for valid methods (n x K_valid). +# @param y_obs Observed outcome vector (n). +# @param K_valid Number of valid methods. +# @return Normalized coefficient vector of length K_valid. +# @noRd +.solve_ensemble_nnls <- function(P_valid, y_obs, K_valid) { + if (!requireNamespace("nnls", quietly = TRUE)) { + stop("Package 'nnls' is required for solver='nnls'. ", + "Install with: install.packages('nnls')") + } + + fit <- tryCatch( + nnls::nnls(P_valid, y_obs), + error = function(e) { + warning("NNLS solver failed: ", conditionMessage(e), + ". Falling back to equal weights.") + NULL + } + ) + + if (is.null(fit)) { + return(rep(1 / K_valid, K_valid)) + } + + zeta_valid <- fit$x + zeta_sum <- sum(zeta_valid) + if (zeta_sum <= 0) { + warning("NNLS returned all-zero solution. Falling back to equal weights.") + return(rep(1 / K_valid, K_valid)) + } + zeta_valid / zeta_sum +} + +# Solve ensemble stacking via L-BFGS-B (box-constrained optimization, then normalize). +# Uses base R optim() with analytical gradient. No extra dependencies. +# @param P_valid Matrix of CV predictions for valid methods (n x K_valid). +# @param y_obs Observed outcome vector (n). +# @param K_valid Number of valid methods. +# @return Normalized coefficient vector of length K_valid. +# @noRd +.solve_ensemble_lbfgsb <- function(P_valid, y_obs, K_valid) { + PtP <- crossprod(P_valid) + Pty <- as.vector(crossprod(P_valid, y_obs)) + + fn <- function(z) sum((y_obs - P_valid %*% z)^2) + gr <- function(z) as.vector(2 * (PtP %*% z - Pty)) + + fit <- tryCatch( + optim( + par = rep(1 / K_valid, K_valid), + fn = fn, gr = gr, + method = "L-BFGS-B", + lower = rep(0, K_valid) + ), + error = function(e) { + warning("L-BFGS-B solver failed: ", conditionMessage(e), + ". Falling back to equal weights.") + NULL + } + ) + + if (is.null(fit)) { + return(rep(1 / K_valid, K_valid)) + } + + zeta_valid <- pmax(fit$par, 0) + zeta_sum <- sum(zeta_valid) + if (zeta_sum <= 0) { + warning("L-BFGS-B returned all-zero solution. Falling back to equal weights.") + return(rep(1 / K_valid, K_valid)) + } + zeta_valid / zeta_sum +} + +# Solve ensemble stacking via glmnet (penalized regression with non-negativity). +# Uses cv.glmnet for automatic lambda selection. The alpha parameter controls +# the elastic net mixing: alpha=1 is lasso (sparse), alpha=0 is ridge. +# @param P_valid Matrix of CV predictions for valid methods (n x K_valid). +# @param y_obs Observed outcome vector (n). +# @param K_valid Number of valid methods. +# @param alpha Elastic net mixing parameter (default 1 = lasso). +# @return Normalized coefficient vector of length K_valid. +# @noRd +.solve_ensemble_glmnet <- function(P_valid, y_obs, K_valid, alpha = 1) { + if (!requireNamespace("glmnet", quietly = TRUE)) { + stop("Package 'glmnet' is required for solver='glmnet'. ", + "Install with: install.packages('glmnet')") + } + + fit <- tryCatch( + glmnet::cv.glmnet( + x = P_valid, y = y_obs, + lower.limits = 0, + alpha = alpha, + intercept = FALSE + ), + error = function(e) { + warning("glmnet solver failed: ", conditionMessage(e), + ". Falling back to equal weights.") + NULL + } + ) + + if (is.null(fit)) { + return(rep(1 / K_valid, K_valid)) + } + + zeta_valid <- as.numeric(coef(fit, s = "lambda.min"))[-1] # drop intercept + zeta_valid <- pmax(zeta_valid, 0) + zeta_sum <- sum(zeta_valid) + if (zeta_sum <= 0) { + warning("glmnet returned all-zero solution. Falling back to equal weights.") + return(rep(1 / K_valid, K_valid)) + } + zeta_valid / zeta_sum +} + + +#' Ensemble TWAS Weights via Stacked Regression +#' +#' Given cross-validated predictions from multiple TWAS weight methods, learns +#' non-negative combination coefficients (summing to 1) via constrained least +#' squares. Returns ensemble weights and per-method performance metrics. +#' +#' This implements the stacked regression approach of SR-TWAS (Dai et al., +#' Nature Communications, 2024, \doi{10.1038/s41467-024-50983-w}). The ensemble +#' provides a principled way to combine predictions from many TWAS weight +#' methods without requiring the user to pick one method a priori or pay a +#' multiple-testing penalty for running several. +#' +#' For single-dataset usage, pass one \code{twas_weights_cv()} result directly. +#' For multi-dataset ensemble (e.g., combining cell types or reference panels +#' such as CUMC1 + MIT), pass a list of \code{twas_weights_cv()} results along +#' with a list of observed Y vectors — this learns a single joint set of +#' coefficients. +#' +#' @param cv_results Output of \code{\link{twas_weights_cv}}, with \code{$prediction} +#' (named list of method -> out-of-fold prediction matrix, keys like +#' \code{"susie_predicted"}). For multi-dataset: a list of such objects. +#' @param Y Observed outcome vector or matrix (samples x contexts). For +#' multi-dataset: a list of vectors/matrices, one per dataset. +#' @param twas_weight_list Optional named list of weight matrices from +#' \code{\link{twas_weights}}, with keys like \code{"susie_weights"}. Used to +#' construct the final combined TWAS weight vector. For multi-dataset: a list +#' of such lists (the first is used as the weight template). +#' @param context_index Integer indicating which column of Y to use when Y is a +#' matrix. Default is 1 (univariate). +#' @param solver Character string specifying the optimization backend. +#' One of \code{"quadprog"} (default), \code{"nnls"}, \code{"lbfgsb"}, or +#' \code{"glmnet"}. +#' \code{"quadprog"} solves a constrained QP with sum-to-1 and non-negativity +#' constraints. \code{"nnls"} uses non-negative least squares (Lawson-Hanson +#' algorithm, as in SuperLearner) and normalizes post-hoc. \code{"lbfgsb"} +#' uses \code{optim(method = "L-BFGS-B")} with non-negativity bounds and +#' normalizes post-hoc. \code{"glmnet"} uses \code{cv.glmnet} with +#' \code{lower.limits = 0} for penalized non-negative regression, providing +#' automatic method selection via regularization. All solvers fall back to +#' equal weights on failure. +#' @param alpha Elastic net mixing parameter, used only when +#' \code{solver = "glmnet"}. \code{alpha = 1} (default) is lasso (sparse +#' method selection), \code{alpha = 0} is ridge, and intermediate values +#' give elastic net. +#' +#' @return A list with components: +#' \describe{ +#' \item{method_coef}{Named numeric vector of combination coefficients +#' (\eqn{\zeta_k}), non-negative and summing to 1. Names are method +#' base names (e.g., \code{"susie"}, \code{"enet"}).} +#' \item{ensemble_twas_weights}{Final combined weight vector +#' \eqn{w = \sum_k \zeta_k w_k}, or NULL if \code{twas_weight_list} +#' is not provided. Returned as a vector for univariate Y, matrix otherwise.} +#' \item{method_performance}{Named numeric vector of per-method R-squared +#' computed from out-of-fold CV predictions. Preserved so users can still +#' report individual method performance.} +#' } +#' +#' @details +#' The stacked regression solves: +#' \deqn{\min_{\zeta} \|y - P\zeta\|^2 \quad \text{s.t.} \quad \zeta_k \geq 0,\ \sum_k \zeta_k = 1} +#' where P is the \eqn{n \times K} matrix of out-of-fold predictions from K +#' methods. Four solver backends are available: \code{"quadprog"} enforces +#' both constraints during optimization; \code{"nnls"}, \code{"lbfgsb"}, and +#' \code{"glmnet"} enforce non-negativity only, then normalize coefficients +#' to sum to 1. The \code{"glmnet"} solver additionally applies +#' regularization, which can produce sparse solutions (method selection). +#' If any solver fails, the function falls back to equal weights with a +#' warning. +#' +#' Methods whose CV predictions have zero variance (e.g., when all weights are +#' zero) are excluded from the optimization and assigned \eqn{\zeta_k = 0}. +#' +#' Predictions and Y are aligned by sample names (rownames) when available, +#' rather than assuming positional order. +#' +#' @seealso \code{\link{twas_weights_cv}}, \code{\link{twas_weights}}, +#' \code{\link{twas_weights_pipeline}} +#' +#' @examples +#' \dontrun{ +#' # After running twas_weights_pipeline with CV: +#' res <- twas_weights_pipeline(X, y, cv_folds = 5, weight_methods = methods) +#' +#' ens <- ensemble_weights( +#' cv_results = res$twas_cv_result, +#' Y = y, +#' twas_weight_list = res$twas_weights +#' ) +#' ens$method_coef # combination weights, sum to 1 +#' +#' # Multi-dataset ensemble (e.g., CUMC1 + MIT cell types): +#' ens_multi <- ensemble_weights( +#' cv_results = list(res_cumc$twas_cv_result, res_mit$twas_cv_result), +#' Y = list(y_cumc, y_mit), +#' twas_weight_list = list(res_cumc$twas_weights, res_mit$twas_weights) +#' ) +#' } +#' +#' @importFrom stats optim coef complete.cases sd cor +#' @export +ensemble_weights <- function(cv_results, Y, twas_weight_list = NULL, + context_index = 1, + solver = c("quadprog", "nnls", "lbfgsb", "glmnet"), + alpha = 1) { + # --- Input validation --- + solver <- match.arg(solver) + if (is.null(cv_results)) { + stop("'cv_results' is required.") + } + if (is.null(Y)) { + stop("'Y' is required.") + } + if (!is.numeric(context_index) || length(context_index) != 1 || context_index < 1) { + stop("'context_index' must be a positive integer scalar.") + } + + # --- Normalize single vs multi-dataset input --- + # Single dataset: cv_results has $prediction directly (is a twas_weights_cv() output). + # Multi-dataset: cv_results is a list of such outputs. + is_single <- !is.null(cv_results$prediction) + if (is_single) { + cv_results <- list(cv_results) + Y <- list(Y) + if (!is.null(twas_weight_list)) twas_weight_list <- list(twas_weight_list) + } else { + # Multi-dataset: validate list consistency + if (!is.list(cv_results) || length(cv_results) == 0) { + stop("For multi-dataset ensemble, 'cv_results' must be a non-empty list of ", + "twas_weights_cv() outputs.") + } + if (!is.list(Y) || length(Y) != length(cv_results)) { + stop("'Y' must be a list of the same length as 'cv_results' for ", + "multi-dataset ensemble.") + } + if (!is.null(twas_weight_list)) { + if (!is.list(twas_weight_list) || length(twas_weight_list) != length(cv_results)) { + stop("'twas_weight_list' must be a list of the same length as 'cv_results'.") + } + } + for (d in seq_along(cv_results)) { + if (is.null(cv_results[[d]]$prediction)) { + stop("cv_results[[", d, "]] does not contain '$prediction'. ", + "Expected a twas_weights_cv() output.") + } + } + } + + # --- Extract and validate method names --- + pred_names <- names(cv_results[[1]]$prediction) + if (is.null(pred_names) || any(pred_names == "")) { + stop("cv_results$prediction must be a named list (output of twas_weights_cv).") + } + base_names <- gsub("_predicted$", "", pred_names) + K <- length(base_names) + + if (K < 2) { + stop("Ensemble learning requires at least 2 methods. Found: ", K, ".") + } + + # Consistency: all datasets must report the same methods in the same order + for (d in seq_along(cv_results)) { + if (!identical(names(cv_results[[d]]$prediction), pred_names)) { + stop("All cv_results must have the same method names (in $prediction) ", + "in the same order. Dataset 1 has: ", paste(pred_names, collapse = ", "), + "; dataset ", d, " has: ", + paste(names(cv_results[[d]]$prediction), collapse = ", ")) + } + } + + # --- Build stacked prediction matrix P and observed y vector --- + pred_list <- list() + y_list <- list() + + for (d in seq_along(cv_results)) { + preds_d <- cv_results[[d]]$prediction + y_raw <- Y[[d]] + + # Get sample names from predictions and Y for alignment + pred_samples <- rownames(preds_d[[pred_names[1]]]) + y_names <- if (is.matrix(y_raw) || is.data.frame(y_raw)) { + rownames(y_raw) + } else { + names(y_raw) + } + + # Determine sample alignment + if (!is.null(pred_samples) && !is.null(y_names)) { + common <- intersect(pred_samples, y_names) + if (length(common) == 0) { + stop("No common sample names between predictions and Y in dataset ", d, ".") + } + if (length(common) < length(pred_samples) || length(common) < length(y_names)) { + message("Dataset ", d, ": using ", length(common), " common samples ", + "(predictions: ", length(pred_samples), ", Y: ", length(y_names), ").") + } + # Extract y aligned to common samples + y_d <- if (is.matrix(y_raw) || is.data.frame(y_raw)) { + if (context_index > ncol(y_raw)) { + stop("context_index (", context_index, ") exceeds number of columns in Y[[", + d, "]] (", ncol(y_raw), ").") + } + as.numeric(as.matrix(y_raw)[match(common, y_names), context_index]) + } else { + as.numeric(y_raw[match(common, y_names)]) + } + pred_order <- match(common, pred_samples) + n_d <- length(common) + } else { + # No sample names available: fall back to positional alignment + y_d <- if (is.matrix(y_raw) || is.data.frame(y_raw)) { + if (context_index > ncol(y_raw)) { + stop("context_index (", context_index, ") exceeds number of columns in Y[[", + d, "]] (", ncol(y_raw), ").") + } + as.numeric(as.matrix(y_raw)[, context_index]) + } else { + as.numeric(y_raw) + } + n_d <- length(y_d) + pred_order <- seq_len(n_d) + } + + P_d <- matrix(NA_real_, nrow = n_d, ncol = K) + colnames(P_d) <- base_names + for (k in seq_along(pred_names)) { + pred_mat <- preds_d[[pred_names[k]]] + p_col <- if (is.matrix(pred_mat)) pred_mat[pred_order, context_index] else as.numeric(pred_mat)[pred_order] + if (length(p_col) != n_d) { + stop("Prediction length for method '", pred_names[k], "' in dataset ", d, + " (", length(p_col), ") does not match number of aligned samples (", n_d, ").") + } + P_d[, k] <- p_col + } + pred_list[[d]] <- P_d + y_list[[d]] <- y_d + } + + P <- do.call(rbind, pred_list) # (n_total x K) + y_obs <- unlist(y_list) # (n_total) + + # Remove rows with any NA (in P or y) + complete <- complete.cases(P, y_obs) + n_dropped <- sum(!complete) + if (n_dropped > 0) { + message("Dropping ", n_dropped, " observation(s) with NA predictions or outcomes.") + } + if (sum(complete) < K + 1) { + stop("Too few complete observations (", sum(complete), ") for ", K, + " methods. Need at least ", K + 1, ".") + } + P <- P[complete, , drop = FALSE] + y_obs <- y_obs[complete] + + # --- Identify methods with non-zero variance predictions --- + method_sds <- apply(P, 2, sd) + valid_methods <- method_sds > .Machine$double.eps + n_valid <- sum(valid_methods) + + if (n_valid < 1) { + stop("All methods have zero-variance predictions. Cannot compute ensemble. ", + "This typically means all methods returned zero weights — check that ", + "the input data has sufficient signal.") + } + + # --- Solve for combination coefficients --- + if (n_valid == 1) { + # Only one method has signal: assign it full weight + zeta <- rep(0, K) + zeta[valid_methods] <- 1 + names(zeta) <- base_names + message("Only one method ('", base_names[valid_methods], + "') has non-zero variance predictions. Assigning it full weight.") + } else { + P_valid <- P[, valid_methods, drop = FALSE] + K_valid <- ncol(P_valid) + + zeta_valid <- switch(solver, + quadprog = .solve_ensemble_quadprog(P_valid, y_obs, K_valid), + nnls = .solve_ensemble_nnls(P_valid, y_obs, K_valid), + lbfgsb = .solve_ensemble_lbfgsb(P_valid, y_obs, K_valid), + glmnet = .solve_ensemble_glmnet(P_valid, y_obs, K_valid, alpha = alpha) + ) + + zeta <- rep(0, K) + zeta[valid_methods] <- zeta_valid + names(zeta) <- base_names + } + + # --- Performance metrics --- + method_rsq <- setNames(vapply(seq_len(K), function(k) { + if (method_sds[k] > 0) cor(y_obs, P[, k])^2 else NA_real_ + }, numeric(1)), base_names) + + # --- Build ensemble TWAS weight vector (uses first dataset's weights) --- + ensemble_twas_wt <- NULL + if (!is.null(twas_weight_list)) { + wt_list <- twas_weight_list[[1]] + if (!is.list(wt_list) || length(wt_list) == 0) { + warning("twas_weight_list[[1]] is empty or not a list; skipping weight combination.") + } else { + wt_keys <- paste0(base_names, "_weights") + matched <- wt_keys %in% names(wt_list) + + if (any(matched)) { + first_wt <- wt_list[[wt_keys[which(matched)[1]]]] + if (!is.matrix(first_wt)) first_wt <- matrix(first_wt, ncol = 1) + p <- nrow(first_wt) + n_contexts <- ncol(first_wt) + + ensemble_twas_wt <- matrix(0, nrow = p, ncol = n_contexts) + rownames(ensemble_twas_wt) <- rownames(first_wt) + colnames(ensemble_twas_wt) <- colnames(first_wt) + + for (i in which(matched)) { + w_mat <- wt_list[[wt_keys[i]]] + if (!is.matrix(w_mat)) w_mat <- matrix(w_mat, ncol = 1) + if (!identical(dim(w_mat), dim(ensemble_twas_wt))) { + warning("Weight matrix for '", wt_keys[i], + "' has inconsistent dimensions; skipping.") + next + } + ensemble_twas_wt <- ensemble_twas_wt + zeta[i] * w_mat + } + + # For univariate case, return as vector + if (n_contexts == 1) { + ensemble_twas_wt <- setNames( + as.numeric(ensemble_twas_wt), + rownames(ensemble_twas_wt) + ) + } + } else { + warning("No matching weight keys found in twas_weight_list. ", + "Expected keys like: ", + paste(wt_keys[seq_len(min(3, K))], collapse = ", ")) + } + } + } + + list( + method_coef = zeta, + ensemble_twas_weights = ensemble_twas_wt, + method_performance = method_rsq + ) +} diff --git a/man/dpr_weights.Rd b/man/dpr_weights.Rd index dcac89ef..0802aee5 100644 --- a/man/dpr_weights.Rd +++ b/man/dpr_weights.Rd @@ -2,9 +2,18 @@ % Please edit documentation in R/regularized_regression.R \name{dpr_weights} \alias{dpr_weights} +\alias{dpr_vb_weights} +\alias{dpr_gibbs_weights} +\alias{dpr_adaptive_gibbs_weights} \title{Compute Weights Using Dirichlet Process Regression (RcppDPR)} \usage{ dpr_weights(X, y, fitting_method = "VB", ...) + +dpr_vb_weights(X, y, ...) + +dpr_gibbs_weights(X, y, ...) + +dpr_adaptive_gibbs_weights(X, y, ...) } \arguments{ \item{X}{A numeric matrix of predictors.} diff --git a/man/ensemble_weights.Rd b/man/ensemble_weights.Rd new file mode 100644 index 00000000..e86eb86d --- /dev/null +++ b/man/ensemble_weights.Rd @@ -0,0 +1,123 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/twas_weights.R +\name{ensemble_weights} +\alias{ensemble_weights} +\title{Ensemble TWAS Weights via Stacked Regression} +\usage{ +ensemble_weights( + cv_results, + Y, + twas_weight_list = NULL, + context_index = 1, + solver = c("quadprog", "nnls", "lbfgsb", "glmnet"), + alpha = 1 +) +} +\arguments{ +\item{cv_results}{Output of \code{\link{twas_weights_cv}}, with \code{$prediction} +(named list of method -> out-of-fold prediction matrix, keys like +\code{"susie_predicted"}). For multi-dataset: a list of such objects.} + +\item{Y}{Observed outcome vector or matrix (samples x contexts). For +multi-dataset: a list of vectors/matrices, one per dataset.} + +\item{twas_weight_list}{Optional named list of weight matrices from +\code{\link{twas_weights}}, with keys like \code{"susie_weights"}. Used to +construct the final combined TWAS weight vector. For multi-dataset: a list +of such lists (the first is used as the weight template).} + +\item{context_index}{Integer indicating which column of Y to use when Y is a +matrix. Default is 1 (univariate).} + +\item{solver}{Character string specifying the optimization backend. +One of \code{"quadprog"} (default), \code{"nnls"}, \code{"lbfgsb"}, or +\code{"glmnet"}. +\code{"quadprog"} solves a constrained QP with sum-to-1 and non-negativity +constraints. \code{"nnls"} uses non-negative least squares (Lawson-Hanson +algorithm, as in SuperLearner) and normalizes post-hoc. \code{"lbfgsb"} +uses \code{optim(method = "L-BFGS-B")} with non-negativity bounds and +normalizes post-hoc. \code{"glmnet"} uses \code{cv.glmnet} with +\code{lower.limits = 0} for penalized non-negative regression, providing +automatic method selection via regularization. All solvers fall back to +equal weights on failure.} + +\item{alpha}{Elastic net mixing parameter, used only when +\code{solver = "glmnet"}. \code{alpha = 1} (default) is lasso (sparse +method selection), \code{alpha = 0} is ridge, and intermediate values +give elastic net.} +} +\value{ +A list with components: +\describe{ + \item{method_coef}{Named numeric vector of combination coefficients + (\eqn{\zeta_k}), non-negative and summing to 1. Names are method + base names (e.g., \code{"susie"}, \code{"enet"}).} + \item{ensemble_twas_weights}{Final combined weight vector + \eqn{w = \sum_k \zeta_k w_k}, or NULL if \code{twas_weight_list} + is not provided. Returned as a vector for univariate Y, matrix otherwise.} + \item{method_performance}{Named numeric vector of per-method R-squared + computed from out-of-fold CV predictions. Preserved so users can still + report individual method performance.} +} +} +\description{ +Given cross-validated predictions from multiple TWAS weight methods, learns +non-negative combination coefficients (summing to 1) via constrained least +squares. Returns ensemble weights and per-method performance metrics. +} +\details{ +This implements the stacked regression approach of SR-TWAS (Dai et al., +Nature Communications, 2024, \doi{10.1038/s41467-024-50983-w}). The ensemble +provides a principled way to combine predictions from many TWAS weight +methods without requiring the user to pick one method a priori or pay a +multiple-testing penalty for running several. + +For single-dataset usage, pass one \code{twas_weights_cv()} result directly. +For multi-dataset ensemble (e.g., combining cell types or reference panels +such as CUMC1 + MIT), pass a list of \code{twas_weights_cv()} results along +with a list of observed Y vectors — this learns a single joint set of +coefficients. + + +The stacked regression solves: +\deqn{\min_{\zeta} \|y - P\zeta\|^2 \quad \text{s.t.} \quad \zeta_k \geq 0,\ \sum_k \zeta_k = 1} +where P is the \eqn{n \times K} matrix of out-of-fold predictions from K +methods. Four solver backends are available: \code{"quadprog"} enforces +both constraints during optimization; \code{"nnls"}, \code{"lbfgsb"}, and +\code{"glmnet"} enforce non-negativity only, then normalize coefficients +to sum to 1. The \code{"glmnet"} solver additionally applies +regularization, which can produce sparse solutions (method selection). +If any solver fails, the function falls back to equal weights with a +warning. + +Methods whose CV predictions have zero variance (e.g., when all weights are +zero) are excluded from the optimization and assigned \eqn{\zeta_k = 0}. + +Predictions and Y are aligned by sample names (rownames) when available, +rather than assuming positional order. +} +\examples{ +\dontrun{ +# After running twas_weights_pipeline with CV: +res <- twas_weights_pipeline(X, y, cv_folds = 5, weight_methods = methods) + +ens <- ensemble_weights( + cv_results = res$twas_cv_result, + Y = y, + twas_weight_list = res$twas_weights +) +ens$method_coef # combination weights, sum to 1 + +# Multi-dataset ensemble (e.g., CUMC1 + MIT cell types): +ens_multi <- ensemble_weights( + cv_results = list(res_cumc$twas_cv_result, res_mit$twas_cv_result), + Y = list(y_cumc, y_mit), + twas_weight_list = list(res_cumc$twas_weights, res_mit$twas_weights) +) +} + +} +\seealso{ +\code{\link{twas_weights_cv}}, \code{\link{twas_weights}}, + \code{\link{twas_weights_pipeline}} +} diff --git a/man/twas_weights_pipeline.Rd b/man/twas_weights_pipeline.Rd index f3e6e3ed..82ad949c 100644 --- a/man/twas_weights_pipeline.Rd +++ b/man/twas_weights_pipeline.Rd @@ -10,12 +10,14 @@ twas_weights_pipeline( susie_fit = NULL, cv_folds = 5, sample_partition = NULL, - weight_methods = list(enet_weights = list(), lasso_weights = list(), bayes_r_weights = - list(), bayes_l_weights = list(), mrash_weights = list(init_prior_sd = TRUE, max.iter - = 100), susie_weights = list(refine = FALSE, init_L = 5, max_L = 20)), + weight_methods = "default", max_cv_variants = -1, cv_threads = 1, - cv_weight_methods = NULL + cv_weight_methods = NULL, + ensemble = FALSE, + ensemble_r2_threshold = 0.01, + ensemble_solver = "quadprog", + ensemble_alpha = 1 ) } \arguments{ @@ -34,6 +36,22 @@ twas_weights_pipeline( \item{cv_threads}{The number of threads to use for parallel computation in cross-validation. Defaults to 1.} \item{cv_weight_methods}{List of methods to use for cross-validation. If NULL, uses the same methods as weight_methods.} + +\item{ensemble}{Logical. If TRUE and cv_folds > 1, learn ensemble combination +weights via stacked regression (SR-TWAS). Requires at least two individual +methods to have been run and to pass the R-squared cutoff. Defaults to FALSE.} + +\item{ensemble_r2_threshold}{Minimum cross-validated R-squared for an individual method +to be included in the ensemble. Methods below this threshold are excluded. +Defaults to 0.01.} + +\item{ensemble_solver}{Character string specifying the optimization backend +for ensemble learning. One of \code{"quadprog"}, \code{"nnls"}, +\code{"lbfgsb"}, or \code{"glmnet"}. Passed to +\code{\link{ensemble_weights}}. Defaults to \code{"quadprog"}.} + +\item{ensemble_alpha}{Elastic net mixing parameter, used only when +\code{ensemble_solver = "glmnet"}. Defaults to 1 (lasso).} } \value{ A list containing results from the TWAS pipeline, including TWAS weights, predictions, and optionally cross-validation results. diff --git a/pixi.toml b/pixi.toml index c1b1961a..d5a0fb20 100644 --- a/pixi.toml +++ b/pixi.toml @@ -74,9 +74,11 @@ r45 = {features = ["r45"]} "r-mr.mashr" = "*" "r-mvsusier" = "*" "r-ncvreg" = "*" +"r-nnls" = "*" "r-pgenlibr" = "*" "r-purrr" = "*" "r-qgg" = "*" +"r-quadprog" = "*" "r-rcppdpr" = "*" "r-readr" = "*" "r-rfast" = "*" diff --git a/tests/testthat/test_ensemble_weights.R b/tests/testthat/test_ensemble_weights.R new file mode 100644 index 00000000..2e45f0ad --- /dev/null +++ b/tests/testthat/test_ensemble_weights.R @@ -0,0 +1,655 @@ +context("ensemble_weights") + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +# Build a synthetic twas_weights_cv() output with K methods. Each method's +# prediction is a convex combination of the truth + noise, letting us control +# per-method accuracy. Returns a list shaped exactly like twas_weights_cv()'s +# output (with $prediction, $performance, $sample_partition). +make_cv_result <- function(n = 100, K = 4, seed = 1, method_quality = NULL) { + set.seed(seed) + y <- rnorm(n) + sample_names <- paste0("sample_", seq_len(n)) + + if (is.null(method_quality)) { + # Methods with decreasing quality (noise amounts) + method_quality <- seq(0.1, 0.9, length.out = K) + } + stopifnot(length(method_quality) == K) + + method_names <- paste0("method", seq_len(K)) + pred_names <- paste0(method_names, "_predicted") + + prediction <- setNames(lapply(seq_len(K), function(k) { + noise_sd <- method_quality[k] + pred <- y + rnorm(n, sd = noise_sd) + mat <- matrix(pred, ncol = 1) + rownames(mat) <- sample_names + colnames(mat) <- "outcome_1" + mat + }), pred_names) + + # Dummy performance (not used by ensemble_weights) + performance <- setNames(lapply(seq_len(K), function(k) { + m <- matrix(NA, nrow = 1, ncol = 6) + colnames(m) <- c("corr", "rsq", "adj_rsq", "pval", "RMSE", "MAE") + m + }), paste0(method_names, "_performance")) + + list( + sample_partition = data.frame(Sample = sample_names, + Fold = rep(1:5, length.out = n), + stringsAsFactors = FALSE), + prediction = prediction, + performance = performance, + time_elapsed = 0, + .y = y, + .method_names = method_names + ) +} + +# Build synthetic twas_weights() output +make_weight_list <- function(p = 20, method_names, seed = 2) { + set.seed(seed) + setNames(lapply(method_names, function(m) { + w <- matrix(rnorm(p), ncol = 1) + rownames(w) <- paste0("var_", seq_len(p)) + colnames(w) <- "outcome_1" + w + }), paste0(method_names, "_weights")) +} + +# =========================================================================== +# Input validation +# =========================================================================== + +test_that("ensemble_weights: NULL cv_results errors", { + expect_error(ensemble_weights(NULL, Y = rnorm(10)), "cv_results") +}) + +test_that("ensemble_weights: NULL Y errors", { + cv <- make_cv_result(n = 20, K = 3) + expect_error(ensemble_weights(cv, Y = NULL), "'Y' is required") +}) + +test_that("ensemble_weights: single method errors (need >= 2 for ensemble)", { + cv <- make_cv_result(n = 20, K = 1) + expect_error(ensemble_weights(cv, Y = cv$.y), + "at least 2 methods") +}) + +test_that("ensemble_weights: invalid context_index errors", { + cv <- make_cv_result(n = 20, K = 3) + expect_error(ensemble_weights(cv, Y = cv$.y, context_index = 0), + "context_index") + expect_error(ensemble_weights(cv, Y = cv$.y, context_index = "a"), + "context_index") +}) + +test_that("ensemble_weights: context_index beyond Y columns errors", { + cv <- make_cv_result(n = 20, K = 3) + Y_mat <- matrix(cv$.y, ncol = 1) + expect_error(ensemble_weights(cv, Y = Y_mat, context_index = 5), + "context_index") +}) + +test_that("ensemble_weights: multi-dataset with mismatched lengths errors", { + cv1 <- make_cv_result(n = 20, K = 3, seed = 1) + cv2 <- make_cv_result(n = 20, K = 3, seed = 2) + expect_error(ensemble_weights(list(cv1, cv2), Y = list(cv1$.y)), + "same length") +}) + +test_that("ensemble_weights: multi-dataset with different methods errors", { + cv1 <- make_cv_result(n = 20, K = 3, seed = 1) + cv2 <- make_cv_result(n = 20, K = 4, seed = 2) + expect_error( + ensemble_weights(list(cv1, cv2), Y = list(cv1$.y, cv2$.y)), + "same method names" + ) +}) + +# =========================================================================== +# Core algorithm correctness +# =========================================================================== + +test_that("ensemble_weights: coefficients are non-negative and sum to 1", { + cv <- make_cv_result(n = 100, K = 4, seed = 42) + res <- ensemble_weights(cv, Y = cv$.y) + + expect_true(all(res$method_coef >= 0)) + expect_equal(sum(res$method_coef), 1, tolerance = 1e-6) +}) + +test_that("ensemble_weights: best method receives the largest coefficient", { + # Method 1 is best (lowest noise), method K is worst + cv <- make_cv_result(n = 200, K = 4, seed = 7, + method_quality = c(0.1, 0.5, 0.8, 1.2)) + res <- ensemble_weights(cv, Y = cv$.y) + + expect_equal(names(which.max(res$method_coef)), "method1") +}) + +test_that("ensemble_weights: does not return ensemble_performance (in-sample R^2 omitted)", { + cv <- make_cv_result(n = 300, K = 5, seed = 13) + res <- ensemble_weights(cv, Y = cv$.y) + + expect_null(res$ensemble_performance) + expect_false("ensemble_performance" %in% names(res)) +}) + +test_that("ensemble_weights: per-method R^2 values are sensible (between 0 and 1)", { + cv <- make_cv_result(n = 200, K = 4, seed = 21) + res <- ensemble_weights(cv, Y = cv$.y) + + expect_true(all(res$method_performance >= 0, na.rm = TRUE)) + expect_true(all(res$method_performance <= 1, na.rm = TRUE)) + expect_equal(length(res$method_performance), 4) +}) + +test_that("ensemble_weights: method names are stripped of _predicted suffix", { + cv <- make_cv_result(n = 50, K = 3, seed = 1) + res <- ensemble_weights(cv, Y = cv$.y) + + expect_equal(names(res$method_coef), + c("method1", "method2", "method3")) + expect_equal(names(res$method_performance), + c("method1", "method2", "method3")) +}) + +# =========================================================================== +# Sample name alignment +# =========================================================================== + +test_that("ensemble_weights: aligns Y and predictions by sample name", { + cv <- make_cv_result(n = 50, K = 3, seed = 10) + + # Shuffle Y order relative to predictions + shuffled_order <- sample(50) + y_shuffled <- cv$.y[shuffled_order] + names(y_shuffled) <- paste0("sample_", shuffled_order) + + res_aligned <- ensemble_weights(cv, Y = y_shuffled) + res_original <- ensemble_weights(cv, Y = cv$.y) + + # Results should be identical regardless of Y order + expect_equal(res_aligned$method_coef, res_original$method_coef, tolerance = 1e-10) +}) + +test_that("ensemble_weights: aligns Y matrix and predictions by sample name", { + cv <- make_cv_result(n = 50, K = 3, seed = 10) + + # Create Y as a matrix with shuffled row order + shuffled_order <- sample(50) + Y_mat <- matrix(cv$.y[shuffled_order], ncol = 1) + rownames(Y_mat) <- paste0("sample_", shuffled_order) + + res_aligned <- ensemble_weights(cv, Y = Y_mat) + res_original <- ensemble_weights(cv, Y = cv$.y) + + expect_equal(res_aligned$method_coef, res_original$method_coef, tolerance = 1e-10) +}) + +test_that("ensemble_weights: errors when no common sample names", { + cv <- make_cv_result(n = 20, K = 3, seed = 1) + y_bad <- setNames(rnorm(20), paste0("other_", seq_len(20))) + + expect_error(ensemble_weights(cv, Y = y_bad), "No common sample names") +}) + +# =========================================================================== +# Zero-variance / edge cases +# =========================================================================== + +test_that("ensemble_weights: zero-variance method gets coefficient 0", { + cv <- make_cv_result(n = 100, K = 3, seed = 5) + # Force method 2 to have constant predictions + cv$prediction$method2_predicted[, 1] <- 0.5 + res <- ensemble_weights(cv, Y = cv$.y) + + expect_equal(res$method_coef["method2"], c(method2 = 0)) + expect_equal(sum(res$method_coef), 1, tolerance = 1e-6) +}) + +test_that("ensemble_weights: NA predictions in some samples are dropped", { + cv <- make_cv_result(n = 100, K = 3, seed = 5) + cv$prediction$method1_predicted[1:5, 1] <- NA + expect_message( + res <- ensemble_weights(cv, Y = cv$.y), + "Dropping" + ) + expect_equal(sum(res$method_coef), 1, tolerance = 1e-6) +}) + +test_that("ensemble_weights: all zero-variance methods errors", { + cv <- make_cv_result(n = 50, K = 2, seed = 5) + cv$prediction$method1_predicted[, 1] <- 0 + cv$prediction$method2_predicted[, 1] <- 0 + expect_error(ensemble_weights(cv, Y = cv$.y), + "zero-variance predictions") +}) + +# =========================================================================== +# Weight combination +# =========================================================================== + +test_that("ensemble_weights: ensemble_twas_weights is sum of zeta_k * w_k", { + cv <- make_cv_result(n = 100, K = 3, seed = 42) + wt <- make_weight_list(p = 10, method_names = cv$.method_names) + + res <- ensemble_weights(cv, Y = cv$.y, twas_weight_list = wt) + + expect_false(is.null(res$ensemble_twas_weights)) + + # Verify the combination is correct + expected <- matrix(0, nrow = 10, ncol = 1) + for (k in seq_along(cv$.method_names)) { + m <- cv$.method_names[k] + expected <- expected + res$method_coef[m] * wt[[paste0(m, "_weights")]] + } + expect_equal(as.numeric(res$ensemble_twas_weights), + as.numeric(expected), + tolerance = 1e-10) +}) + +test_that("ensemble_weights: NULL twas_weight_list returns NULL ensemble_twas_weights", { + cv <- make_cv_result(n = 50, K = 3, seed = 1) + res <- ensemble_weights(cv, Y = cv$.y, twas_weight_list = NULL) + expect_null(res$ensemble_twas_weights) +}) + +test_that("ensemble_weights: weights with no matching keys warns and skips", { + cv <- make_cv_result(n = 50, K = 2, seed = 1) + wt <- list(unknown_weights = matrix(1, nrow = 10, ncol = 1)) + + expect_warning( + res <- ensemble_weights(cv, Y = cv$.y, twas_weight_list = wt), + "No matching weight keys" + ) + expect_null(res$ensemble_twas_weights) +}) + +# =========================================================================== +# Multi-dataset ensemble +# =========================================================================== + +test_that("ensemble_weights: multi-dataset combines predictions correctly", { + cv1 <- make_cv_result(n = 80, K = 3, seed = 1) + cv2 <- make_cv_result(n = 80, K = 3, seed = 2) + + res <- ensemble_weights( + cv_results = list(cv1, cv2), + Y = list(cv1$.y, cv2$.y) + ) + + expect_true(all(res$method_coef >= 0)) + expect_equal(sum(res$method_coef), 1, tolerance = 1e-6) + expect_equal(length(res$method_performance), 3) +}) + +test_that("ensemble_weights: Y as matrix with context_index works", { + cv <- make_cv_result(n = 50, K = 3, seed = 1) + Y_mat <- matrix(cv$.y, ncol = 1) + colnames(Y_mat) <- "ctx1" + + res <- ensemble_weights(cv, Y = Y_mat, context_index = 1) + expect_equal(sum(res$method_coef), 1, tolerance = 1e-6) +}) + +# =========================================================================== +# End-to-end with twas_weights_cv (integration) +# =========================================================================== + +test_that("ensemble_weights: end-to-end with twas_weights_cv output", { + skip_if_not_installed("glmnet") + + set.seed(42) + n <- 100 + p <- 20 + X <- matrix(rnorm(n * p), nrow = n, ncol = p) + colnames(X) <- paste0("var_", seq_len(p)) + rownames(X) <- paste0("sample_", seq_len(n)) + + beta <- c(1.5, -1.0, 0.8, rep(0, p - 3)) + y <- as.numeric(X %*% beta + rnorm(n, sd = 0.5)) + + cv <- suppressMessages(twas_weights_cv( + X, y, fold = 3, + weight_methods = list( + lasso_weights = list(), + enet_weights = list() + ) + )) + + res <- ensemble_weights(cv, Y = y) + + expect_equal(sum(res$method_coef), 1, tolerance = 1e-6) + expect_true(all(res$method_coef >= 0)) + expect_equal(names(res$method_coef), c("lasso", "enet")) + expect_null(res$ensemble_performance) +}) + +# =========================================================================== +# twas_weights_pipeline ensemble integration +# =========================================================================== + +test_that("pipeline: ensemble=TRUE with only 1 method prints skip message", { + skip_if_not_installed("glmnet") + + set.seed(42) + n <- 100 + p <- 20 + X <- matrix(rnorm(n * p), nrow = n, ncol = p) + colnames(X) <- paste0("var_", seq_len(p)) + rownames(X) <- paste0("sample_", seq_len(n)) + + beta <- c(1.5, -1.0, 0.8, rep(0, p - 3)) + y <- as.numeric(X %*% beta + rnorm(n, sd = 0.5)) + + msgs <- testthat::capture_messages( + res <- twas_weights_pipeline( + X, y, cv_folds = 3, + weight_methods = list(lasso_weights = list()), + ensemble = TRUE + ) + ) + + expect_true(any(grepl("at least 2 weight methods", msgs))) + + # No ensemble result should be present + expect_null(res$ensemble) + expect_null(res$twas_weights$ensemble_weights) +}) + +test_that("pipeline: ensemble=TRUE skips when methods fail R^2 cutoff", { + skip_if_not_installed("glmnet") + + set.seed(42) + n <- 100 + p <- 20 + X <- matrix(rnorm(n * p), nrow = n, ncol = p) + colnames(X) <- paste0("var_", seq_len(p)) + rownames(X) <- paste0("sample_", seq_len(n)) + + # Use signal so methods produce non-zero weights, but set threshold very high + beta <- c(1.5, -1.0, 0.8, rep(0, p - 3)) + y <- as.numeric(X %*% beta + rnorm(n, sd = 0.5)) + + msgs <- testthat::capture_messages( + res <- twas_weights_pipeline( + X, y, cv_folds = 3, + weight_methods = list(lasso_weights = list(), enet_weights = list()), + ensemble = TRUE, + ensemble_r2_threshold = 0.99 # impossibly high threshold + ) + ) + + expect_true(any(grepl("fewer than 2 methods passed the R.*cutoff", msgs))) + expect_null(res$ensemble) + expect_null(res$twas_weights$ensemble_weights) +}) + +test_that("pipeline: ensemble=TRUE succeeds and adds ensemble_weights", { + skip_if_not_installed("glmnet") + + set.seed(42) + n <- 100 + p <- 20 + X <- matrix(rnorm(n * p), nrow = n, ncol = p) + colnames(X) <- paste0("var_", seq_len(p)) + rownames(X) <- paste0("sample_", seq_len(n)) + + beta <- c(1.5, -1.0, 0.8, rep(0, p - 3)) + y <- as.numeric(X %*% beta + rnorm(n, sd = 0.5)) + + msgs <- testthat::capture_messages( + res <- twas_weights_pipeline( + X, y, cv_folds = 3, + weight_methods = list(lasso_weights = list(), enet_weights = list()), + ensemble = TRUE + ) + ) + + expect_true(any(grepl("Computing ensemble TWAS weights", msgs))) + + # Ensemble weights added alongside individual methods + expect_true("ensemble_weights" %in% names(res$twas_weights)) + expect_true("lasso_weights" %in% names(res$twas_weights)) + expect_true("enet_weights" %in% names(res$twas_weights)) + + # Ensemble predictions added + expect_true("ensemble_predicted" %in% names(res$twas_predictions)) + + # Ensemble result metadata present + expect_false(is.null(res$ensemble)) + expect_true(all(res$ensemble$method_coef >= 0)) + expect_equal(sum(res$ensemble$method_coef), 1, tolerance = 1e-6) + + # Ensemble weights should have same length as individual weights + expect_equal(length(res$twas_weights$ensemble_weights), + length(res$twas_weights$lasso_weights)) +}) + +test_that("pipeline: ensemble=FALSE does not run ensemble", { + skip_if_not_installed("glmnet") + + set.seed(42) + n <- 100 + p <- 20 + X <- matrix(rnorm(n * p), nrow = n, ncol = p) + colnames(X) <- paste0("var_", seq_len(p)) + rownames(X) <- paste0("sample_", seq_len(n)) + + beta <- c(1.5, -1.0, 0.8, rep(0, p - 3)) + y <- as.numeric(X %*% beta + rnorm(n, sd = 0.5)) + + res <- suppressMessages(twas_weights_pipeline( + X, y, cv_folds = 3, + weight_methods = list(lasso_weights = list(), enet_weights = list()), + ensemble = FALSE + )) + + expect_null(res$ensemble) + expect_null(res$twas_weights$ensemble_weights) +}) + +test_that("pipeline: ensemble_r2_threshold filters methods for ensemble", { + skip_if_not_installed("glmnet") + + set.seed(42) + n <- 100 + p <- 20 + X <- matrix(rnorm(n * p), nrow = n, ncol = p) + colnames(X) <- paste0("var_", seq_len(p)) + rownames(X) <- paste0("sample_", seq_len(n)) + + beta <- c(1.5, -1.0, 0.8, rep(0, p - 3)) + y <- as.numeric(X %*% beta + rnorm(n, sd = 0.5)) + + # Run with very low threshold — both methods should pass + msgs_low <- testthat::capture_messages( + res_low <- twas_weights_pipeline( + X, y, cv_folds = 3, + weight_methods = list(lasso_weights = list(), enet_weights = list()), + ensemble = TRUE, + ensemble_r2_threshold = 0.001 + ) + ) + expect_false(is.null(res_low$ensemble)) + + # Run with very high threshold — neither should pass + msgs_high <- testthat::capture_messages( + res_high <- twas_weights_pipeline( + X, y, cv_folds = 3, + weight_methods = list(lasso_weights = list(), enet_weights = list()), + ensemble = TRUE, + ensemble_r2_threshold = 0.99 + ) + ) + expect_true(any(grepl("fewer than 2 methods passed", msgs_high))) + expect_null(res_high$ensemble) +}) + +# =========================================================================== +# Solver alternatives +# =========================================================================== + +for (slv in c("quadprog", "nnls", "lbfgsb", "glmnet")) { + test_that(paste0("ensemble_weights: solver='", slv, "' produces valid coefficients"), { + if (slv == "quadprog") skip_if_not_installed("quadprog") + if (slv == "nnls") skip_if_not_installed("nnls") + if (slv == "glmnet") skip_if_not_installed("glmnet") + + cv <- make_cv_result(n = 100, K = 4, seed = 42) + res <- ensemble_weights(cv, Y = cv$.y, solver = slv) + + expect_true(all(res$method_coef >= 0)) + expect_equal(sum(res$method_coef), 1, tolerance = 1e-6) + expect_equal(length(res$method_coef), 4) + }) + + test_that(paste0("ensemble_weights: solver='", slv, "' assigns best method largest coef"), { + if (slv == "quadprog") skip_if_not_installed("quadprog") + if (slv == "nnls") skip_if_not_installed("nnls") + if (slv == "glmnet") skip_if_not_installed("glmnet") + + cv <- make_cv_result(n = 200, K = 4, seed = 7, + method_quality = c(0.1, 0.5, 0.8, 1.2)) + res <- ensemble_weights(cv, Y = cv$.y, solver = slv) + + expect_equal(names(which.max(res$method_coef)), "method1") + }) + + test_that(paste0("ensemble_weights: solver='", slv, "' combines weights correctly"), { + if (slv == "quadprog") skip_if_not_installed("quadprog") + if (slv == "nnls") skip_if_not_installed("nnls") + if (slv == "glmnet") skip_if_not_installed("glmnet") + + cv <- make_cv_result(n = 100, K = 3, seed = 42) + wt <- make_weight_list(p = 10, method_names = cv$.method_names) + res <- ensemble_weights(cv, Y = cv$.y, twas_weight_list = wt, solver = slv) + + expect_false(is.null(res$ensemble_twas_weights)) + + expected <- matrix(0, nrow = 10, ncol = 1) + for (k in seq_along(cv$.method_names)) { + m <- cv$.method_names[k] + expected <- expected + res$method_coef[m] * wt[[paste0(m, "_weights")]] + } + expect_equal(as.numeric(res$ensemble_twas_weights), + as.numeric(expected), + tolerance = 1e-10) + }) +} + +test_that("ensemble_weights: invalid solver errors", { + cv <- make_cv_result(n = 50, K = 3, seed = 1) + expect_error(ensemble_weights(cv, Y = cv$.y, solver = "bogus"), + "arg") +}) + +test_that("pipeline: ensemble_solver='nnls' works end-to-end", { + skip_if_not_installed("glmnet") + skip_if_not_installed("nnls") + + set.seed(42) + n <- 100 + p <- 20 + X <- matrix(rnorm(n * p), nrow = n, ncol = p) + colnames(X) <- paste0("var_", seq_len(p)) + rownames(X) <- paste0("sample_", seq_len(n)) + + beta <- c(1.5, -1.0, 0.8, rep(0, p - 3)) + y <- as.numeric(X %*% beta + rnorm(n, sd = 0.5)) + + msgs <- testthat::capture_messages( + res <- twas_weights_pipeline( + X, y, cv_folds = 3, + weight_methods = list(lasso_weights = list(), enet_weights = list()), + ensemble = TRUE, + ensemble_solver = "nnls" + ) + ) + + expect_true(any(grepl("Computing ensemble TWAS weights", msgs))) + expect_true("ensemble_weights" %in% names(res$twas_weights)) + expect_true(all(res$ensemble$method_coef >= 0)) + expect_equal(sum(res$ensemble$method_coef), 1, tolerance = 1e-6) +}) + +test_that("pipeline: ensemble_solver='lbfgsb' works end-to-end", { + skip_if_not_installed("glmnet") + + set.seed(42) + n <- 100 + p <- 20 + X <- matrix(rnorm(n * p), nrow = n, ncol = p) + colnames(X) <- paste0("var_", seq_len(p)) + rownames(X) <- paste0("sample_", seq_len(n)) + + beta <- c(1.5, -1.0, 0.8, rep(0, p - 3)) + y <- as.numeric(X %*% beta + rnorm(n, sd = 0.5)) + + msgs <- testthat::capture_messages( + res <- twas_weights_pipeline( + X, y, cv_folds = 3, + weight_methods = list(lasso_weights = list(), enet_weights = list()), + ensemble = TRUE, + ensemble_solver = "lbfgsb" + ) + ) + + expect_true(any(grepl("Computing ensemble TWAS weights", msgs))) + expect_true("ensemble_weights" %in% names(res$twas_weights)) + expect_true(all(res$ensemble$method_coef >= 0)) + expect_equal(sum(res$ensemble$method_coef), 1, tolerance = 1e-6) +}) + +test_that("pipeline: ensemble_solver='glmnet' works end-to-end", { + skip_if_not_installed("glmnet") + + set.seed(42) + n <- 100 + p <- 20 + X <- matrix(rnorm(n * p), nrow = n, ncol = p) + colnames(X) <- paste0("var_", seq_len(p)) + rownames(X) <- paste0("sample_", seq_len(n)) + + beta <- c(1.5, -1.0, 0.8, rep(0, p - 3)) + y <- as.numeric(X %*% beta + rnorm(n, sd = 0.5)) + + msgs <- testthat::capture_messages( + res <- twas_weights_pipeline( + X, y, cv_folds = 3, + weight_methods = list(lasso_weights = list(), enet_weights = list()), + ensemble = TRUE, + ensemble_solver = "glmnet" + ) + ) + + expect_true(any(grepl("Computing ensemble TWAS weights", msgs))) + expect_true("ensemble_weights" %in% names(res$twas_weights)) + expect_true(all(res$ensemble$method_coef >= 0)) + expect_equal(sum(res$ensemble$method_coef), 1, tolerance = 1e-6) +}) + +test_that("ensemble_weights: solver='glmnet' respects alpha parameter", { + skip_if_not_installed("glmnet") + + cv <- make_cv_result(n = 200, K = 4, seed = 42) + + res_lasso <- ensemble_weights(cv, Y = cv$.y, solver = "glmnet", alpha = 1) + res_ridge <- ensemble_weights(cv, Y = cv$.y, solver = "glmnet", alpha = 0) + + # Both should be valid + expect_true(all(res_lasso$method_coef >= 0)) + expect_equal(sum(res_lasso$method_coef), 1, tolerance = 1e-6) + expect_true(all(res_ridge$method_coef >= 0)) + expect_equal(sum(res_ridge$method_coef), 1, tolerance = 1e-6) + + # Lasso should be at least as sparse as ridge (fewer or equal non-zero coefs) + n_nonzero_lasso <- sum(res_lasso$method_coef > 1e-8) + n_nonzero_ridge <- sum(res_ridge$method_coef > 1e-8) + expect_true(n_nonzero_lasso <= n_nonzero_ridge) +}) diff --git a/tests/testthat/test_twas_weights.R b/tests/testthat/test_twas_weights.R index 9f4a1f2f..af38a80c 100644 --- a/tests/testthat/test_twas_weights.R +++ b/tests/testthat/test_twas_weights.R @@ -26,6 +26,64 @@ make_data <- function(n = 50, p = 10, seed = 42, add_zero_var_col = FALSE) { list(X = X, Y = Y, beta = beta) } +# =========================================================================== +# +# .twas_method_lookup +# +# =========================================================================== + +test_that(".twas_method_lookup: 'default' preset returns 6 methods", { + result <- pecotmr:::.twas_method_lookup("default") + expected_names <- c( + "susie_weights", "mrash_weights", "enet_weights", + "lasso_weights", "bayes_r_weights", "dpr_gibbs_weights" + ) + expect_equal(sort(names(result)), sort(expected_names)) +}) + +test_that(".twas_method_lookup: 'fast_default' preset returns 4 methods", { + result <- pecotmr:::.twas_method_lookup("fast_default") + expected_names <- c( + "susie_weights", "mrash_weights", "enet_weights", "lasso_weights" + ) + expect_equal(sort(names(result)), sort(expected_names)) +}) + +test_that(".twas_method_lookup: custom vector of short names", { + result <- pecotmr:::.twas_method_lookup(c("susie", "enet", "dpr_vb")) + expect_equal(sort(names(result)), sort(c("susie_weights", "enet_weights", "dpr_vb_weights"))) +}) + +test_that(".twas_method_lookup: unknown method produces error", { + expect_error( + pecotmr:::.twas_method_lookup(c("susie", "nonexistent_method")), + "Unknown TWAS method" + ) +}) + +test_that(".twas_method_lookup: default args are set for susie and mrash", { + result <- pecotmr:::.twas_method_lookup("fast_default") + expect_equal(result$susie_weights$refine, FALSE) + expect_equal(result$susie_weights$init_L, 5) + expect_equal(result$susie_weights$max_L, 20) + expect_equal(result$mrash_weights$init_prior_sd, TRUE) + expect_equal(result$mrash_weights$max.iter, 100) +}) + +test_that(".twas_method_lookup: methods with no special args get empty list", { + result <- pecotmr:::.twas_method_lookup(c("enet", "lasso")) + expect_equal(result$enet_weights, list()) + expect_equal(result$lasso_weights, list()) +}) + +test_that(".twas_method_lookup: all DPR variants can coexist", { + result <- pecotmr:::.twas_method_lookup(c("dpr_vb", "dpr_gibbs", "dpr_adaptive_gibbs")) + expect_equal( + sort(names(result)), + sort(c("dpr_vb_weights", "dpr_gibbs_weights", "dpr_adaptive_gibbs_weights")) + ) +}) + # =========================================================================== # # twas_predict @@ -168,9 +226,8 @@ test_that("twas_weights: character weight_methods input is accepted", { local_mocked_bindings( lasso_weights = function(X, y, ...) rep(0, ncol(X)) ) - # Character vector should be converted internally to named list - - result <- twas_weights(d$X, d$Y, weight_methods = c("lasso_weights")) + # Short name should be resolved via .twas_method_lookup + result <- twas_weights(d$X, d$Y, weight_methods = c("lasso")) expect_true(is.list(result)) expect_equal(names(result), "lasso_weights") }) @@ -359,7 +416,7 @@ test_that("twas_weights_cv: character weight_methods are accepted", { set.seed(42) result <- twas_weights_cv( d$X, d$Y, fold = 2, - weight_methods = c("lasso_weights") + weight_methods = c("lasso") ) expect_true(is.list(result)) expect_true("prediction" %in% names(result)) @@ -673,7 +730,7 @@ test_that("twas_weights_pipeline: returns list with expected structure (mocked)" enet_weights = function(X, y, ...) rep(0.1, ncol(X)), lasso_weights = function(X, y, ...) rep(0.2, ncol(X)), bayes_r_weights = function(X, y, ...) rep(0, ncol(X)), - bayes_l_weights = function(X, y, ...) rep(0, ncol(X)), + dpr_gibbs_weights = function(X, y, ...) rep(0, ncol(X)), mrash_weights = function(X, y, ...) rep(0, ncol(X)), susie_weights = function(X, y, ...) rep(0, ncol(X)) ) @@ -701,7 +758,7 @@ test_that("twas_weights_pipeline: twas_weights contains all default methods", { enet_weights = function(X, y, ...) rep(0.1, ncol(X)), lasso_weights = function(X, y, ...) rep(0.2, ncol(X)), bayes_r_weights = function(X, y, ...) rep(0, ncol(X)), - bayes_l_weights = function(X, y, ...) rep(0, ncol(X)), + dpr_gibbs_weights = function(X, y, ...) rep(0, ncol(X)), mrash_weights = function(X, y, ...) rep(0, ncol(X)), susie_weights = function(X, y, ...) rep(0, ncol(X)) ) @@ -710,7 +767,7 @@ test_that("twas_weights_pipeline: twas_weights contains all default methods", { expected_methods <- c( "enet_weights", "lasso_weights", "bayes_r_weights", - "bayes_l_weights", "mrash_weights", "susie_weights" + "dpr_gibbs_weights", "mrash_weights", "susie_weights" ) expect_true(all(expected_methods %in% names(result$twas_weights))) }) @@ -723,7 +780,7 @@ test_that("twas_weights_pipeline: predictions have _predicted suffix", { enet_weights = function(X, y, ...) rep(0, ncol(X)), lasso_weights = function(X, y, ...) rep(0, ncol(X)), bayes_r_weights = function(X, y, ...) rep(0, ncol(X)), - bayes_l_weights = function(X, y, ...) rep(0, ncol(X)), + dpr_gibbs_weights = function(X, y, ...) rep(0, ncol(X)), mrash_weights = function(X, y, ...) rep(0, ncol(X)), susie_weights = function(X, y, ...) rep(0, ncol(X)) ) @@ -732,7 +789,7 @@ test_that("twas_weights_pipeline: predictions have _predicted suffix", { expected_pred_names <- c( "enet_predicted", "lasso_predicted", "bayes_r_predicted", - "bayes_l_predicted", "mrash_predicted", "susie_predicted" + "dpr_gibbs_predicted", "mrash_predicted", "susie_predicted" ) expect_true(all(expected_pred_names %in% names(result$twas_predictions))) }) @@ -745,7 +802,7 @@ test_that("twas_weights_pipeline: cv_folds=0 skips cross-validation", { enet_weights = function(X, y, ...) rep(0, ncol(X)), lasso_weights = function(X, y, ...) rep(0, ncol(X)), bayes_r_weights = function(X, y, ...) rep(0, ncol(X)), - bayes_l_weights = function(X, y, ...) rep(0, ncol(X)), + dpr_gibbs_weights = function(X, y, ...) rep(0, ncol(X)), mrash_weights = function(X, y, ...) rep(0, ncol(X)), susie_weights = function(X, y, ...) rep(0, ncol(X)) ) @@ -782,6 +839,43 @@ test_that("twas_weights_pipeline: custom weight_methods are respected", { expect_equal(sort(names(result$twas_weights)), sort(c("lasso_weights", "enet_weights"))) }) +test_that("twas_weights_pipeline: accepts 'fast_default' preset string", { + d <- make_data(n = 50, p = 10) + y_vec <- as.numeric(d$Y) + + local_mocked_bindings( + enet_weights = function(X, y, ...) rep(0, ncol(X)), + lasso_weights = function(X, y, ...) rep(0, ncol(X)), + mrash_weights = function(X, y, ...) rep(0, ncol(X)), + susie_weights = function(X, y, ...) rep(0, ncol(X)) + ) + + result <- twas_weights_pipeline( + d$X, y_vec, susie_fit = NULL, cv_folds = 0, + weight_methods = "fast_default" + ) + + expected_methods <- c("susie_weights", "mrash_weights", "enet_weights", "lasso_weights") + expect_equal(sort(names(result$twas_weights)), sort(expected_methods)) +}) + +test_that("twas_weights_pipeline: accepts custom short-name vector", { + d <- make_data(n = 50, p = 10) + y_vec <- as.numeric(d$Y) + + local_mocked_bindings( + lasso_weights = function(X, y, ...) rep(1, ncol(X)), + enet_weights = function(X, y, ...) rep(2, ncol(X)) + ) + + result <- twas_weights_pipeline( + d$X, y_vec, susie_fit = NULL, cv_folds = 0, + weight_methods = c("lasso", "enet") + ) + + expect_equal(sort(names(result$twas_weights)), sort(c("lasso_weights", "enet_weights"))) +}) + test_that("twas_weights_pipeline: with susie_fit stores intermediate results", { d <- make_data(n = 50, p = 10) y_vec <- as.numeric(d$Y) @@ -800,7 +894,7 @@ test_that("twas_weights_pipeline: with susie_fit stores intermediate results", { enet_weights = function(X, y, ...) rep(0, ncol(X)), lasso_weights = function(X, y, ...) rep(0, ncol(X)), bayes_r_weights = function(X, y, ...) rep(0, ncol(X)), - bayes_l_weights = function(X, y, ...) rep(0, ncol(X)), + dpr_gibbs_weights = function(X, y, ...) rep(0, ncol(X)), mrash_weights = function(X, y, ...) rep(0, ncol(X)), susie_weights = function(X, y, ...) rep(0, ncol(X)) ) @@ -832,7 +926,7 @@ test_that("twas_weights_pipeline: with susie_fit, susie_weights gets susie_fit a enet_weights = function(X, y, ...) rep(0, ncol(X)), lasso_weights = function(X, y, ...) rep(0, ncol(X)), bayes_r_weights = function(X, y, ...) rep(0, ncol(X)), - bayes_l_weights = function(X, y, ...) rep(0, ncol(X)), + dpr_gibbs_weights = function(X, y, ...) rep(0, ncol(X)), mrash_weights = function(X, y, ...) rep(0, ncol(X)), susie_weights = function(X, y, ...) { args <- list(...)