Skip to content

Commit 8854d7c

Browse files
authored
Merge pull request #411 from tidymodels/xgb-objective
allow 'objective' for xgboost to be passed as an engine argument
2 parents 5a10ff6 + 2203c21 commit 8854d7c

File tree

4 files changed

+45
-16
lines changed

4 files changed

+45
-16
lines changed

NEWS.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# parsnip (development version)
22

3-
* An RStudio add-in is available that makes writing multiple `parsnip` model specifications to the source window easier. It can be accessed via the IDE addin menus or by calling `parsnip_addin()`.
4-
3+
* An RStudio add-in is available that makes writing multiple `parsnip` model specifications to the source window easier. It can be accessed via the IDE addin menus or by calling `parsnip_addin()`.
4+
5+
* For `xgboost` models, users can now pass `objective` to `set_engine("xgboost")`.
6+
57
# parsnip 0.1.4
68

79
* `show_engines()` will provide information on the current set for a model.

R/boost_tree.R

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,9 @@ check_args.boost_tree <- function(object) {
302302
#' training iterations without improvement before stopping. If `validation` is
303303
#' used, performance is based on the validation set; otherwise, the training set
304304
#' is used.
305+
#' @param objective A single string (or NULL) that defines the loss function that
306+
#' `xgboost` uses to create trees. See [xgboost::xgb.train()] for options. If left
307+
#' NULL, an appropriate loss function is chosen.
305308
#' @param ... Other options to pass to `xgb.train`.
306309
#' @return A fitted `xgboost` object.
307310
#' @keywords internal
@@ -310,7 +313,9 @@ xgb_train <- function(
310313
x, y,
311314
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bytree = 1,
312315
min_child_weight = 1, gamma = 0, subsample = 1, validation = 0,
313-
early_stop = NULL, ...) {
316+
early_stop = NULL, objective = NULL, ...) {
317+
318+
others <- list(...)
314319

315320
num_class <- length(levels(y))
316321

@@ -327,13 +332,15 @@ xgb_train <- function(
327332
}
328333

329334

330-
if (is.numeric(y)) {
331-
loss <- "reg:squarederror"
332-
} else {
333-
if (num_class == 2) {
334-
loss <- "binary:logistic"
335+
if (is.null(objective)) {
336+
if (is.numeric(y)) {
337+
objective <- "reg:squarederror"
335338
} else {
336-
loss <- "multi:softprob"
339+
if (num_class == 2) {
340+
objective <- "binary:logistic"
341+
} else {
342+
objective <- "multi:softprob"
343+
}
337344
}
338345
}
339346

@@ -370,15 +377,15 @@ xgb_train <- function(
370377
gamma = gamma,
371378
colsample_bytree = colsample_bytree,
372379
min_child_weight = min(min_child_weight, n),
373-
subsample = subsample
380+
subsample = subsample,
381+
objective = objective
374382
)
375383

376384
main_args <- list(
377385
data = quote(x$data),
378386
watchlist = quote(x$watchlist),
379387
params = arg_list,
380388
nrounds = nrounds,
381-
objective = loss,
382389
early_stopping_rounds = early_stop
383390
)
384391
if (!is.null(num_class) && num_class > 2) {
@@ -388,7 +395,7 @@ xgb_train <- function(
388395
call <- make_call(fun = "xgb.train", ns = "xgboost", main_args)
389396

390397
# override or add some other args
391-
others <- list(...)
398+
392399
others <-
393400
others[!(names(others) %in% c("data", "weights", "nrounds", "num_class", names(arg_list)))]
394401
if (!(any(names(others) == "verbose"))) {
@@ -410,13 +417,12 @@ xgb_pred <- function(object, newdata, ...) {
410417

411418
res <- predict(object, newdata, ...)
412419

413-
x = switch(
420+
x <- switch(
414421
object$params$objective,
415-
"reg:squarederror" = , "reg:logistic" = , "binary:logistic" = res,
416422
"binary:logitraw" = stats::binomial()$linkinv(res),
417423
"multi:softprob" = matrix(res, ncol = object$params$num_class, byrow = TRUE),
418-
res
419-
)
424+
res)
425+
420426
x
421427
}
422428

man/xgb_train.Rd

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,26 @@ test_that('xgboost regression prediction', {
159159

160160
form_pred <- predict(form_fit$fit, newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1])))
161161
expect_equal(form_pred, predict(form_fit, new_data = mtcars[1:8, -1])$.pred)
162+
163+
expect_equal(form_fit$fit$params$objective, "reg:squarederror")
164+
162165
})
163166

164167

165168

169+
test_that('xgboost alternate objective', {
170+
skip_if_not_installed("xgboost")
171+
172+
spec <-
173+
boost_tree() %>%
174+
set_engine("xgboost", objective = "reg:pseudohubererror") %>%
175+
set_mode("regression")
176+
177+
xgb_fit <- spec %>% fit(mpg ~ ., data = mtcars)
178+
expect_equal(xgb_fit$fit$params$objective, "reg:pseudohubererror")
179+
})
180+
181+
166182
test_that('submodel prediction', {
167183

168184
skip_if_not_installed("xgboost")

0 commit comments

Comments
 (0)