@@ -302,6 +302,9 @@ check_args.boost_tree <- function(object) {
302
302
# ' training iterations without improvement before stopping. If `validation` is
303
303
# ' used, performance is base on the validation set; otherwise, the training set
304
304
# ' is used.
305
+ # ' @param objective A single string (or NULL) that defines the loss function that
306
+ # ' `xgboost` uses to create trees. See [xgboost::xgb.train()] for options. If left
307
+ # ' NULL, an appropriate loss function is chosen.
305
308
# ' @param ... Other options to pass to `xgb.train`.
306
309
# ' @return A fitted `xgboost` object.
307
310
# ' @keywords internal
@@ -310,7 +313,7 @@ xgb_train <- function(
310
313
x , y ,
311
314
max_depth = 6 , nrounds = 15 , eta = 0.3 , colsample_bytree = 1 ,
312
315
min_child_weight = 1 , gamma = 0 , subsample = 1 , validation = 0 ,
313
- early_stop = NULL , ... ) {
316
+ early_stop = NULL , objective = NULL , ... ) {
314
317
315
318
others <- list (... )
316
319
@@ -329,14 +332,14 @@ xgb_train <- function(
329
332
}
330
333
331
334
332
- if (! any(names( others ) == " objective" )) {
335
+ if (is.null( objective )) {
333
336
if (is.numeric(y )) {
334
- others $ objective <- " reg:squarederror"
337
+ objective <- " reg:squarederror"
335
338
} else {
336
339
if (num_class == 2 ) {
337
- others $ objective <- " binary:logistic"
340
+ objective <- " binary:logistic"
338
341
} else {
339
- others $ objective <- " multi:softprob"
342
+ objective <- " multi:softprob"
340
343
}
341
344
}
342
345
}
@@ -374,7 +377,8 @@ xgb_train <- function(
374
377
gamma = gamma ,
375
378
colsample_bytree = colsample_bytree ,
376
379
min_child_weight = min(min_child_weight , n ),
377
- subsample = subsample
380
+ subsample = subsample ,
381
+ objective = objective
378
382
)
379
383
380
384
main_args <- list (
@@ -413,13 +417,12 @@ xgb_pred <- function(object, newdata, ...) {
413
417
414
418
res <- predict(object , newdata , ... )
415
419
416
- x = switch (
420
+ x <- switch (
417
421
object $ params $ objective ,
418
- " reg:squarederror" = , " reg:logistic" = , " binary:logistic" = res ,
419
422
" binary:logitraw" = stats :: binomial()$ linkinv(res ),
420
423
" multi:softprob" = matrix (res , ncol = object $ params $ num_class , byrow = TRUE ),
421
- res
422
- )
424
+ res )
425
+
423
426
x
424
427
}
425
428
0 commit comments