Skip to content

Commit 0290da6

Browse files
committed
make 'objective' and argument to xgb_train
1 parent cb08638 commit 0290da6

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

R/boost_tree.R

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,9 @@ check_args.boost_tree <- function(object) {
302302
#' training iterations without improvement before stopping. If `validation` is
303303
#' used, performance is base on the validation set; otherwise, the training set
304304
#' is used.
305+
#' @param objective A single string (or NULL) that defines the loss function that
306+
#' `xgboost` uses to create trees. See [xgboost::xgb.train()] for options. If left
307+
#' NULL, an appropriate loss function is chosen.
305308
#' @param ... Other options to pass to `xgb.train`.
306309
#' @return A fitted `xgboost` object.
307310
#' @keywords internal
@@ -310,7 +313,7 @@ xgb_train <- function(
310313
x, y,
311314
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bytree = 1,
312315
min_child_weight = 1, gamma = 0, subsample = 1, validation = 0,
313-
early_stop = NULL, ...) {
316+
early_stop = NULL, objective = NULL, ...) {
314317

315318
others <- list(...)
316319

@@ -329,14 +332,14 @@ xgb_train <- function(
329332
}
330333

331334

332-
if (!any(names(others) == "objective")) {
335+
if (is.null(objective)) {
333336
if (is.numeric(y)) {
334-
others$objective <- "reg:squarederror"
337+
objective <- "reg:squarederror"
335338
} else {
336339
if (num_class == 2) {
337-
others$objective <- "binary:logistic"
340+
objective <- "binary:logistic"
338341
} else {
339-
others$objective <- "multi:softprob"
342+
objective <- "multi:softprob"
340343
}
341344
}
342345
}
@@ -374,7 +377,8 @@ xgb_train <- function(
374377
gamma = gamma,
375378
colsample_bytree = colsample_bytree,
376379
min_child_weight = min(min_child_weight, n),
377-
subsample = subsample
380+
subsample = subsample,
381+
objective = objective
378382
)
379383

380384
main_args <- list(
@@ -413,13 +417,12 @@ xgb_pred <- function(object, newdata, ...) {
413417

414418
res <- predict(object, newdata, ...)
415419

416-
x = switch(
420+
x <- switch(
417421
object$params$objective,
418-
"reg:squarederror" = , "reg:logistic" = , "binary:logistic" = res,
419422
"binary:logitraw" = stats::binomial()$linkinv(res),
420423
"multi:softprob" = matrix(res, ncol = object$params$num_class, byrow = TRUE),
421-
res
422-
)
424+
res)
425+
423426
x
424427
}
425428

man/xgb_train.Rd

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)