Skip to content

Commit e0a238d

Browse files
topeposimonpcouch
andauthored
Add a model for bagged neural networks (#815)
* initial function * update engine docs for new model and engine * update param docs * add pkgdown topic * add model description * Apply suggestions from code review Co-authored-by: Simon P. Couch <[email protected]> * update docs * skip on windows due to different snapshot results Running the tests in 'tests/testthat.R' failed. 111 Last 13 lines of output: 112 113 * Run `snapshot_accept('boost_tree_xgboost')` to accept the change 114 * Run `snapshot_review('boost_tree_xgboost')` to interactively review the change 115 ── Failure (test_boost_tree_xgboost.R:674:3): interface to param arguments ───── 116 Snapshot of `fit_6 <- spec_6 %>% fit(mpg ~ ., data = mtcars)` has changed: 117 old vs new 118 - "! The arguments `watchlist` and `data` are guarded by parsnip and will not be passed to `xgb.train()`." 119 + "! The arguments `watchlist` and `data` are guarded by parsnip and will not be passed to `xgb.train()`." Co-authored-by: Simon P. Couch <[email protected]>
1 parent 3c0a0e2 commit e0a238d

File tree

93 files changed

+712
-358
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

93 files changed

+712
-358
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ S3method(type_sum,model_fit)
9292
S3method(type_sum,model_spec)
9393
S3method(update,C5_rules)
9494
S3method(update,bag_mars)
95+
S3method(update,bag_mlp)
9596
S3method(update,bag_tree)
9697
S3method(update,bart)
9798
S3method(update,boost_tree)
@@ -146,6 +147,7 @@ export(augment)
146147
export(auto_ml)
147148
export(autoplot)
148149
export(bag_mars)
150+
export(bag_mlp)
149151
export(bag_tree)
150152
export(bart)
151153
export(bartMachine_interval_calc)

R/bag_mlp.R

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#' Ensembles of neural networks
2+
#'
3+
#' @description
4+
#'
5+
#' `bag_mlp()` defines an ensemble of single layer, feed-forward neural networks.
6+
#' This function can fit classification and regression models.
7+
#'
8+
#' \Sexpr[stage=render,results=rd]{parsnip:::make_engine_list("bag_mlp")}
9+
#'
10+
#' More information on how \pkg{parsnip} is used for modeling is at
11+
#' \url{https://www.tidymodels.org/}.
12+
#'
13+
#' @inheritParams mlp
14+
#'
15+
#' @template spec-details
16+
#'
17+
#' @template spec-references
18+
#'
19+
#' @seealso \Sexpr[stage=render,results=rd]{parsnip:::make_seealso_list("bag_mlp")}
20+
#' @export
21+
bag_mlp <-
22+
function(mode = "unknown",
23+
hidden_units = NULL,
24+
penalty = NULL,
25+
epochs = NULL,
26+
engine = "nnet") {
27+
args <- list(
28+
hidden_units = enquo(hidden_units),
29+
penalty = enquo(penalty),
30+
epochs = enquo(epochs)
31+
)
32+
33+
new_model_spec(
34+
"bag_mlp",
35+
args = args,
36+
eng_args = NULL,
37+
mode = mode,
38+
user_specified_mode = !missing(mode),
39+
method = NULL,
40+
engine = engine,
41+
user_specified_engine = !missing(engine)
42+
)
43+
}
44+
45+
# ------------------------------------------------------------------------------
46+
47+
#' @method update bag_mlp
48+
#' @rdname parsnip_update
49+
#' @inheritParams mars
50+
#' @export
51+
update.bag_mlp <-
52+
function(object,
53+
parameters = NULL,
54+
hidden_units = NULL, penalty = NULL, epochs = NULL,
55+
fresh = FALSE, ...) {
56+
57+
args <- list(
58+
hidden_units = enquo(hidden_units),
59+
penalty = enquo(penalty),
60+
epochs = enquo(epochs)
61+
)
62+
63+
update_spec(
64+
object = object,
65+
parameters = parameters,
66+
args_enquo_list = args,
67+
fresh = fresh,
68+
cls = "bag_mlp",
69+
...
70+
)
71+
}
72+
73+
# ------------------------------------------------------------------------------
74+
75+
set_new_model("bag_mlp")
76+
set_model_mode("bag_mlp", "classification")
77+
set_model_mode("bag_mlp", "regression")

R/bag_mlp_nnet.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#' Bagged neural networks via nnet
2+
#'
3+
#' [baguette::bagger()] creates a collection of neural networks forming an
4+
#' ensemble. All trees in the ensemble are combined to produce a final prediction.
5+
#'
6+
#' @includeRmd man/rmd/bag_mlp_nnet.md details
7+
#'
8+
#' @name details_bag_mlp_nnet
9+
#' @keywords internal
10+
NULL
11+
12+
# See inst/README-DOCS.md for a description of how these files are processed

R/print.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ model_descs <- tibble::tribble(
4040
~cls, ~desc,
4141
"auto_ml", "Automatic Machine Learning",
4242
"bag_mars", "Bagged MARS",
43+
"bag_mlp", "Bagged Neural Network",
4344
"bag_tree", "Bagged Decision Tree",
4445
"bart", "BART",
4546
"boost_tree", "Boosted Tree",

_pkgdown.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ reference:
3535
contents:
3636
- auto_ml
3737
- bag_mars
38+
- bag_mlp
3839
- bag_tree
3940
- bart
4041
- boost_tree

inst/models.tsv

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
"auto_ml" "regression" "h2o" "agua"
44
"bag_mars" "classification" "earth" "baguette"
55
"bag_mars" "regression" "earth" "baguette"
6+
"bag_mlp" "classification" "nnet" "baguette"
7+
"bag_mlp" "regression" "nnet" "baguette"
68
"bag_tree" "censored regression" "rpart" "censored"
79
"bag_tree" "classification" "C5.0" "baguette"
810
"bag_tree" "classification" "rpart" "baguette"
@@ -12,10 +14,12 @@
1214
"boost_tree" "censored regression" "mboost" "censored"
1315
"boost_tree" "classification" "C5.0" NA
1416
"boost_tree" "classification" "h2o" "agua"
17+
"boost_tree" "classification" "h2o_gbm" "agua"
1518
"boost_tree" "classification" "lightgbm" "bonsai"
1619
"boost_tree" "classification" "spark" NA
1720
"boost_tree" "classification" "xgboost" NA
1821
"boost_tree" "regression" "h2o" "agua"
22+
"boost_tree" "regression" "h2o_gbm" "agua"
1923
"boost_tree" "regression" "lightgbm" "bonsai"
2024
"boost_tree" "regression" "spark" NA
2125
"boost_tree" "regression" "xgboost" NA
@@ -43,6 +47,7 @@
4347
"linear_reg" "regression" "brulee" NA
4448
"linear_reg" "regression" "gee" "multilevelmod"
4549
"linear_reg" "regression" "glm" NA
50+
"linear_reg" "regression" "glmer" "multilevelmod"
4651
"linear_reg" "regression" "glmnet" NA
4752
"linear_reg" "regression" "gls" "multilevelmod"
4853
"linear_reg" "regression" "h2o" "agua"

man/bag_mlp.Rd

Lines changed: 53 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_C5_rules_C5.0.Rd

Lines changed: 4 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_auto_ml_h2o.Rd

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_bag_mars_earth.Rd

Lines changed: 2 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)