Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions R/colocboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,32 @@ colocboost_validate_input_data <- function(X = NULL, Y = NULL,
} else {
if (is.data.frame(X_ref)) X_ref <- as.matrix(X_ref)
if (is.matrix(X_ref)) X_ref <- list(X_ref)

# Trim X_ref supersets before any expensive reference processing.
# ColocBoost later matches each sumstat by variant name, so columns
# absent from every mapped sumstat cannot affect the fit.
get_needed_sumstat_variants <- function(ref_idx) {
if (length(X_ref) == 1) {
ss_idx <- seq_along(sumstat)
} else if (length(X_ref) == length(sumstat)) {
ss_idx <- ref_idx
} else if (!is.null(dict_sumstatLD) && !is.null(dim(dict_sumstatLD))) {
ss_idx <- unique(dict_sumstatLD[dict_sumstatLD[, 2] == ref_idx, 1])
} else {
return(NULL)
}
unique(unlist(lapply(sumstat[ss_idx], function(ss) ss$variant), use.names = FALSE))
}
for (idx in seq_along(X_ref)) {
xref_variants <- colnames(X_ref[[idx]])
if (is.null(xref_variants)) next
needed_variants <- get_needed_sumstat_variants(idx)
if (is.null(needed_variants) || length(needed_variants) == 0) next
keep_variants <- intersect(xref_variants, needed_variants)
if (length(keep_variants) > 0 && length(keep_variants) < ncol(X_ref[[idx]])) {
X_ref[[idx]] <- X_ref[[idx]][, keep_variants, drop = FALSE]
}
}

# When N_ref >= P, precompute LD (avoids repeated O(N_ref*P) in boosting loop)
# When N_ref < P, keep X_ref for on-the-fly computation (avoids P*P memory)
Expand Down
2 changes: 1 addition & 1 deletion R/colocboost_inference.R
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ get_between_purity <- function(pos1, pos2, X = NULL, Xcorr = NULL, miss_idx = NU
X_sub2 <- scale(X[, pos2, drop = FALSE], center = T, scale = F)
value <- abs(get_matrix_mult(X_sub1, X_sub2))
} else {
if (identical(ref_label, "No_ref") || sum(Xcorr) == 1) {
if (identical(ref_label, "No_ref") || (length(Xcorr) == 1 && Xcorr == 1)) {
value <- 0
} else {
if (length(miss_idx) != 0) {
Expand Down
19 changes: 14 additions & 5 deletions R/colocboost_init.R
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,16 @@ colocboost_init_model <- function(cb_data,
"ld_jk" = c(),
"jk" = c(),
"scaling_factor" = if (!is.null(cb_data$data[[i]]$N)) (cb_data$data[[i]]$N - 1) else 1,
"beta_scaling" = if (!is.null(cb_data$data[[i]]$N)) 1 else 100
"beta_scaling" = if (!is.null(cb_data$data[[i]]$N)) 1 else 100,
"XtX_beta_cache" = NULL
)

data_each <- cb_data$data[[i]]
X_dict <- cb_data$dict[i]
if (!is.null(cb_data$data[[X_dict]]$XtX)) {
tmp$XtX_beta_cache <- rep(0, P - length(data_each$variable_miss))
}

# - calculate change of loglikelihood for data
tmp$change_loglike <- estimate_change_profile(
X = cb_data$data[[X_dict]]$X, Y = data_each[["Y"]], N = data_each$N,
Expand All @@ -184,6 +189,7 @@ colocboost_init_model <- function(cb_data,
XtX = cb_data$data[[X_dict]]$XtX,
beta_k = tmp$beta,
miss_idx = data_each$variable_miss,
XtX_beta_cache = tmp$XtX_beta_cache,
ref_label = cb_data$data[[X_dict]]$ref_label
)
# - initial z-score between X and residual based on correlation
Expand Down Expand Up @@ -240,7 +246,8 @@ colocboost_init_para <- function(cb_data, cb_model, tau = 0.01,
jk_equiv_corr = 0.8,
jk_equiv_loglik = 1,
func_compare = "min_max",
coloc_thresh = 0.1) {
coloc_thresh = 0.1,
ld_mismatch = "none") {
################# initialization #######################################
# - sample size
N <- sapply(cb_data$data, function(dt) dt$N)
Expand Down Expand Up @@ -312,7 +319,8 @@ colocboost_init_para <- function(cb_data, cb_model, tau = 0.01,
"func_multi_test" = func_multi_test,
"multi_test_thresh" = multi_test_thresh,
"multi_test_max" = multi_test_max,
"model_used" = "original"
"model_used" = "original",
"ld_mismatch" = ld_mismatch
)
class(cb_model_para) <- "colocboost"

Expand Down Expand Up @@ -391,7 +399,7 @@ get_correlation <- function(X = NULL, res = NULL, XtY = NULL, N = NULL,
}
if (identical(ref_label, "No_ref")) {
var_r <- YtY - 2 * sum(beta_k * XtY) + sum(beta_k^2)
} else if (!is.null(XtX_beta_cache)) {
} else if (!is.null(XtX_beta_cache) && length(XtX_beta_cache) == length(beta_k)) {
var_r <- YtY - 2 * sum(beta_k * XtY) + sum(XtX_beta_cache * beta_k)
} else {
XtX_beta_val <- compute_XtX_product(XtX, beta_k, ref_label)
Expand Down Expand Up @@ -591,7 +599,8 @@ process_sumstat <- function(Z, N, Var_y, SeBhat,
"XtY" = NULL,
"YtY" = NULL,
"N" = N[[i]],
"variable_miss" = NULL
"variable_miss" = NULL,
"R_finite_B" = NULL
)

# Get current status
Expand Down
15 changes: 10 additions & 5 deletions R/colocboost_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,10 @@ plot_initial <- function(cb_plot_input, y = "log10p",
# - set data and x-lab and y-lab
if (y == "log10p") {
plot_data <- lapply(cb_plot_input$Zscores, function(z) {
-log10(2 * pnorm(-abs(z)))
y_val <- -log10(2 * pnorm(-abs(z)))
max_finite <- max(y_val[is.finite(y_val)], na.rm = TRUE)
y_val[!is.finite(y_val)] <- max_finite
return(y_val)
})
ylab <- "-log10(p)"
} else if (y == "z_original") {
Expand Down Expand Up @@ -765,7 +768,11 @@ plot_initial <- function(cb_plot_input, y = "log10p",
ymin <- rep(args$ylim[1], length(args$y))
} else {
ymax <- NULL
ymin <- rep(0, length(args$y))
if (y %in% c("z_original", "coef")) {
ymin <- sapply(plot_data, function(p) min(p[is.finite(p)], na.rm = TRUE) * 1.05)
} else {
ymin <- rep(0, length(args$y))
}

# Check if ylim_each is FALSE but no ylim is provided
if (!ylim_each) {
Expand All @@ -784,9 +791,7 @@ plot_initial <- function(cb_plot_input, y = "log10p",
}
return(ymax)
})
if (y == "coef") {
ymin <- sapply(plot_data, function(p) min(p) * 1.05)
}

}
args$ymax <- ymax
args$ymin <- ymin
Expand Down
24 changes: 18 additions & 6 deletions R/colocboost_update.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,27 @@ colocboost_update <- function(cb_model, cb_model_para, cb_data) {
ref_label_i <- cb_data$data[[X_dict]]$ref_label
cb_model[[i]]$res <- rep(0, cb_model_para$P)
if (length(cb_data$data[[i]]$variable_miss) != 0) {
beta <- cb_model[[i]]$beta[-cb_data$data[[i]]$variable_miss] / beta_scaling
xty <- cb_data$data[[i]]$XtY[-cb_data$data[[i]]$variable_miss]
XtX_beta <- compute_XtX_product(xtx, beta, ref_label_i)
cb_model[[i]]$res[-cb_data$data[[i]]$variable_miss] <- xty - scaling_factor * XtX_beta

obs_idx <- setdiff(seq_len(cb_model_para$P), cb_data$data[[i]]$variable_miss)
beta <- cb_model[[i]]$beta[obs_idx] / beta_scaling
xty <- cb_data$data[[i]]$XtY[obs_idx]
delta_beta <- step1 * beta_grad[obs_idx] / beta_scaling
XtX_beta_cache <- cb_model[[i]]$XtX_beta_cache
if (!is.null(XtX_beta_cache) && length(XtX_beta_cache) == length(beta)) {
XtX_beta <- XtX_beta_cache + compute_XtX_product(xtx, delta_beta, ref_label_i)
} else {
XtX_beta <- compute_XtX_product(xtx, beta, ref_label_i)
}
cb_model[[i]]$res[obs_idx] <- xty - scaling_factor * XtX_beta
} else {
beta <- cb_model[[i]]$beta / beta_scaling
xty <- cb_data$data[[i]]$XtY
XtX_beta <- compute_XtX_product(xtx, beta, ref_label_i)
delta_beta <- step1 * beta_grad / beta_scaling
XtX_beta_cache <- cb_model[[i]]$XtX_beta_cache
if (!is.null(XtX_beta_cache) && length(XtX_beta_cache) == length(beta)) {
XtX_beta <- XtX_beta_cache + compute_XtX_product(xtx, delta_beta, ref_label_i)
} else {
XtX_beta <- compute_XtX_product(xtx, beta, ref_label_i)
}
cb_model[[i]]$res <- xty - scaling_factor * XtX_beta
}
# - cache XtX %*% beta for reuse in get_correlation (avoids redundant O(P^2) computation)
Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/test_Xref.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,27 @@ test_that("X_ref with N_ref >= P precomputes LD and produces valid results", {
expect_equal(length(result_xref$data_info$variables), 30)
})

test_that("X_ref superset is trimmed before LD precomputation", {
test_data <- generate_xref_test_data(n_ref = 80, p = 20)
extra_ref <- matrix(rnorm(nrow(test_data$X_ref) * 35), nrow(test_data$X_ref), 35)
colnames(extra_ref) <- paste0("EXTRA", seq_len(ncol(extra_ref)))
X_ref_superset <- cbind(test_data$X_ref, extra_ref)

validated <- suppressMessages(
colocboost_validate_input_data(
sumstat = test_data$sumstat,
X_ref = X_ref_superset
)
)

expected_variants <- unique(unlist(lapply(test_data$sumstat, function(ss) ss$variant)))
expect_equal(validated$ref_label, "LD")
expect_null(validated$X_ref)
expect_equal(ncol(validated$LD[[1]]), length(expected_variants))
expect_setequal(colnames(validated$LD[[1]]), expected_variants)
expect_false(any(grepl("^EXTRA", colnames(validated$LD[[1]]))))
})


# ============================================================================
# Test 3: X_ref with N_ref < P keeps X_ref for on-the-fly computation
Expand Down