Skip to content

Commit 374dbf9

Browse files
committed
changes for #461
1 parent adc9d77 commit 374dbf9

File tree

5 files changed

+109
-49
lines changed

5 files changed

+109
-49
lines changed

NEWS.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
* The `liquidSVM` engine for `svm_rbf()` was deprecated due to that package's removal from CRAN. (#425)
1313

14-
* The xgboost engine for boosted trees was translating `mtry` to xgboost's `colsample_bytree`. We now map `mtry` to `colsample_bynode` since that is more consistent with how random forest works. `colsample_bytree` can still be optimized by passing it in as an engine argument. (#495)
14+
* The xgboost engine for boosted trees was translating `mtry` to xgboost's `colsample_bytree`. We now map `mtry` to `colsample_bynode` since that is more consistent with how random forest works. `colsample_bytree` can still be optimized by passing it in as an engine argument. `colsample_bynode` was added to xgboost after the `parsnip` package code was written. (#495)
15+
16+
* For xgboost boosting, `mtry` and `colsample_bytree` can be passed as integer counts or proportions while `subsample` and `validation` should be proportions. `xgb_train()` now has a new option `counts` for state what scale `mtry` and `colsample_bytree` are being used. (#461)
1517

1618
## Other Changes
1719

@@ -21,12 +23,8 @@
2123

2224
* Re-organized model documentation for `update` methods (#479).
2325

24-
25-
2626
* `generics::required_pkgs()` was extended for `parsnip` objects.
2727

28-
29-
3028
# parsnip 0.1.5
3129

3230
* An RStudio add-in is available that makes writing multiple `parsnip` model specifications to the source window. It can be accessed via the IDE addin menus or by calling `parsnip_addin()`.

R/boost_tree.R

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -265,21 +265,25 @@ check_args.boost_tree <- function(object) {
265265
#' @param nrounds An integer for the number of boosting iterations.
266266
#' @param eta A numeric value between zero and one to control the learning rate.
267267
#' @param colsample_bytree Subsampling proportion of columns for each tree.
268+
#' See the `counts` argument below. The default uses all columns.
268269
#' @param colsample_bynode Subsampling proportion of columns for each node
269-
#' within each tree.
270+
#' within each tree. See the `counts` argument below. The default uses all
271+
#' columns.
270272
#' @param min_child_weight A numeric value for the minimum sum of instance
271273
#' weights needed in a child to continue to split.
272274
#' @param gamma A number for the minimum loss reduction required to make a
273275
#' further partition on a leaf node of the tree
274-
#' @param subsample Subsampling proportion of rows.
275-
#' @param validation A positive number. If on `[0, 1)` the value, `validation`
276-
#' is a random proportion of data in `x` and `y` that are used for performance
277-
#' assessment and potential early stopping. If 1 or greater, it is the _number_
278-
#' of training set samples use for these purposes.
276+
#' @param subsample Subsampling proportion of rows. By default, all of the
277+
#' training data are used.
278+
#' @param validation The _proportion_ of the data that are used for performance
279+
#' assessment and potential early stopping.
279280
#' @param early_stop An integer or `NULL`. If not `NULL`, it is the number of
280281
#' training iterations without improvement before stopping. If `validation` is
281282
#' used, performance is base on the validation set; otherwise, the training set
282283
#' is used.
284+
#' @param counts A logical. If `FALSE`, `colsample_bynode` and
285+
#' `colsample_bytree` are both assumed to be _proportions_ of the proportion of
286+
#' columns affects (instead of counts).
283287
#' @param objective A single string (or NULL) that defines the loss function that
284288
#' `xgboost` uses to create trees. See [xgboost::xgb.train()] for options. If left
285289
#' NULL, an appropriate loss function is chosen.
@@ -292,11 +296,10 @@ check_args.boost_tree <- function(object) {
292296
#' @export
293297
xgb_train <- function(
294298
x, y,
295-
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = 1,
296-
colsample_bytree = 1, min_child_weight = 1, gamma = 0, subsample = 1, validation = 0,
297-
early_stop = NULL, objective = NULL,
298-
event_level = c("first", "second"),
299-
...) {
299+
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL,
300+
colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
301+
validation = 0, early_stop = NULL, objective = NULL, counts = TRUE,
302+
event_level = c("first", "second"), ...) {
300303

301304
event_level <- rlang::arg_match(event_level, c("first", "second"))
302305
others <- list(...)
@@ -306,6 +309,7 @@ xgb_train <- function(
306309
if (!is.numeric(validation) || validation < 0 || validation >= 1) {
307310
rlang::abort("`validation` should be on [0, 1).")
308311
}
312+
309313
if (!is.null(early_stop)) {
310314
if (early_stop <= 1) {
311315
rlang::abort(paste0("`early_stop` should be on [2, ", nrounds, ")."))
@@ -315,7 +319,6 @@ xgb_train <- function(
315319
}
316320
}
317321

318-
319322
if (is.null(objective)) {
320323
if (is.numeric(y)) {
321324
objective <- "reg:squarederror"
@@ -333,26 +336,21 @@ xgb_train <- function(
333336

334337
x <- as_xgb_data(x, y, validation, event_level)
335338

336-
# translate `subsample` and `colsample_bytree` to be on (0, 1] if not
337-
if (subsample > 1) {
338-
subsample <- subsample/n
339-
}
340-
if (subsample > 1) {
341-
subsample <- 1
342-
}
343339

344-
if (colsample_bytree > 1) {
345-
colsample_bytree <- colsample_bytree/p
346-
}
347-
if (colsample_bytree > 1) {
348-
colsample_bytree <- 1
340+
if (!is.numeric(subsample) || subsample < 0 || subsample > 1) {
341+
rlang::abort("`subsample` should be on [0, 1].")
349342
}
350343

351-
if (colsample_bynode > 1) {
352-
colsample_bynode <- colsample_bynode/p
344+
# initialize
345+
if (is.null(colsample_bytree)) {
346+
colsample_bytree <- 1
347+
} else {
348+
colsample_bytree <- recalc_param(colsample_bytree, counts, p)
353349
}
354-
if (colsample_bynode > 1) {
350+
if (is.null(colsample_bynode)) {
355351
colsample_bynode <- 1
352+
} else {
353+
colsample_bynode <- recalc_param(colsample_bynode, counts, p)
356354
}
357355

358356
if (min_child_weight > n) {
@@ -400,6 +398,30 @@ xgb_train <- function(
400398
eval_tidy(call, env = current_env())
401399
}
402400

401+
recalc_param <- function(x, counts, denom) {
402+
nm <- as.character(match.call()$x)
403+
if (is.null(x)) {
404+
x <- 1
405+
} else {
406+
if (counts) {
407+
maybe_proportion(x, nm)
408+
x <- min(denom, x)/denom
409+
}
410+
}
411+
x
412+
}
413+
414+
maybe_proportion <- function(x, nm) {
415+
if (x < 1) {
416+
msg <- paste0(
417+
"The option `counts = TRUE` was used but parameter `", nm,
418+
"` was given as ", signif(x, 3), ". Please use a value >= 1 or use ",
419+
"`counts = FALSE`."
420+
)
421+
rlang::abort(msg)
422+
}
423+
}
424+
403425
#' @importFrom stats binomial
404426
xgb_pred <- function(object, newdata, ...) {
405427
if (!inherits(newdata, "xgb.DMatrix")) {
@@ -442,7 +464,8 @@ as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {
442464

443465
if (!inherits(x, "xgb.DMatrix")) {
444466
if (validation > 0) {
445-
trn_index <- sample(1:n, size = floor(n * (1 - validation)) + 1)
467+
m <- floor(n * (1 - validation)) + 1
468+
trn_index <- sample(1:n, size = max(m, 2))
446469
wlist <-
447470
list(validation = xgboost::xgb.DMatrix(x[-trn_index, ], label = y[-trn_index], missing = NA))
448471
dat <- xgboost::xgb.DMatrix(x[trn_index, ], label = y[trn_index], missing = NA)

man/boost_tree.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/xgb_train.Rd

Lines changed: 15 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -483,15 +483,48 @@ test_that("fit and prediction with `event_level`", {
483483

484484
})
485485

486-
test_that("mtry parameters", {
486+
test_that("count/proportion parameters", {
487487
skip_if_not_installed("xgboost")
488-
fit <-
489-
boost_tree(mtry = .7, trees = 4) %>%
488+
fit1 <-
489+
boost_tree(mtry = 7, trees = 4) %>%
490490
set_engine("xgboost") %>%
491491
set_mode("regression") %>%
492492
fit(mpg ~ ., data = mtcars)
493-
expect_equal(fit$fit$params$colsample_bytree, 1)
494-
expect_equal(fit$fit$params$colsample_bynode, 0.7)
493+
expect_equal(fit1$fit$params$colsample_bytree, 1)
494+
expect_equal(fit1$fit$params$colsample_bynode, 7/(ncol(mtcars) - 1))
495+
496+
fit2 <-
497+
boost_tree(mtry = 7, trees = 4) %>%
498+
set_engine("xgboost", colsample_bytree = 4) %>%
499+
set_mode("regression") %>%
500+
fit(mpg ~ ., data = mtcars)
501+
expect_equal(fit2$fit$params$colsample_bytree, 4/(ncol(mtcars) - 1))
502+
expect_equal(fit2$fit$params$colsample_bynode, 7/(ncol(mtcars) - 1))
503+
504+
fit3 <-
505+
boost_tree(trees = 4) %>%
506+
set_engine("xgboost") %>%
507+
set_mode("regression") %>%
508+
fit(mpg ~ ., data = mtcars)
509+
expect_equal(fit3$fit$params$colsample_bytree, 1)
510+
expect_equal(fit3$fit$params$colsample_bynode, 1)
511+
512+
fit4 <-
513+
boost_tree(mtry = .9, trees = 4) %>%
514+
set_engine("xgboost", colsample_bytree = .1, counts = FALSE) %>%
515+
set_mode("regression") %>%
516+
fit(mpg ~ ., data = mtcars)
517+
expect_equal(fit4$fit$params$colsample_bytree, .1)
518+
expect_equal(fit4$fit$params$colsample_bynode, .9)
519+
520+
expect_error(
521+
boost_tree(mtry = .9, trees = 4) %>%
522+
set_engine("xgboost") %>%
523+
set_mode("regression") %>%
524+
fit(mpg ~ ., data = mtcars),
525+
"was given as 0.9"
526+
)
527+
495528
})
496529

497530

0 commit comments

Comments
 (0)