Skip to content

Commit 70d0cf8

Browse files
authored
Merge pull request #530 from tidymodels/check-engine-mode-model
better engine and mode checking code
2 parents 6f63bed + 5d7c5ae commit 70d0cf8

File tree

10 files changed

+183
-51
lines changed

10 files changed

+183
-51
lines changed

NEWS.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
# parsnip (development version)
22

3+
## Model Specification Changes
4+
5+
* A model function (`gen_additive_mod()`) was added for generalized additive models.
6+
37
* Each model now has a default engine that is used when the model is defined. The default for each model is listed in the help documents. This also adds functionality to declare an engine in the model specification function. `set_engine()` is still required if engine-specific arguments need to be added. (#513)
48

9+
* parsnip now checks for a valid combination of engine and mode (#529)
10+
511
* The default engine for `multinom_reg()` was changed to `nnet`.
612

13+
## Other Changes
14+
715
* The helper functions `.convert_form_to_xy_fit()`, `.convert_form_to_xy_new()`, `.convert_xy_to_form_fit()`, and `.convert_xy_to_form_new()` for converting between formula and matrix interface are now exported for developer use (#508).
816

917
* Fix bug in `augment()` when non-predictor, non-outcome variables are included in data (#510).
1018

11-
* A model function (`gen_additive_mod()`) was added for generalized additive models.
1219

1320
# parsnip 0.1.6
1421

R/aaa_models.R

Lines changed: 106 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Initialize model environments
22

3+
all_modes <- c("classification", "regression", "censored regression")
4+
35
# ------------------------------------------------------------------------------
46

57
## Rules about model-related information
@@ -23,10 +25,9 @@
2325

2426
# ------------------------------------------------------------------------------
2527

26-
2728
parsnip <- rlang::new_environment()
2829
parsnip$models <- NULL
29-
parsnip$modes <- c("regression", "classification", "unknown")
30+
parsnip$modes <- c(all_modes, "unknown")
3031

3132
# ------------------------------------------------------------------------------
3233

@@ -134,25 +135,119 @@ check_mode_val <- function(mode) {
134135
}
135136

136137

137-
stop_incompatible_mode <- function(spec_modes) {
138+
stop_incompatible_mode <- function(spec_modes, eng = NULL, cls = NULL) {
139+
if (is.null(eng) & is.null(cls)) {
140+
msg <- "Available modes are: "
141+
}
142+
if (!is.null(eng) & is.null(cls)) {
143+
msg <- glue::glue("Available modes for engine {eng} are: ")
144+
}
145+
if (is.null(eng) & !is.null(cls)) {
146+
msg <- glue::glue("Available modes for model type {cls} are: ")
147+
}
148+
if (!is.null(eng) & !is.null(cls)) {
149+
msg <- glue::glue("Available modes for model type {cls} with engine {eng} are: ")
150+
}
151+
138152
msg <- glue::glue(
139-
"Available modes are: ",
153+
msg,
140154
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
141155
)
142156
rlang::abort(msg)
143157
}
144158

145-
# check if class and mode are compatible
146-
check_spec_mode_val <- function(cls, mode) {
147-
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
159+
stop_incompatible_engine <- function(spec_engs, mode) {
160+
msg <- glue::glue(
161+
"Available engines for mode {mode} are: ",
162+
glue::glue_collapse(glue::glue("'{spec_engs}'"), sep = ", ")
163+
)
164+
rlang::abort(msg)
165+
}
166+
167+
stop_missing_engine <- function(cls) {
168+
info <-
169+
get_from_env(cls) %>%
170+
dplyr::group_by(mode) %>%
171+
dplyr::summarize(msg = paste0(unique(mode), " {",
172+
paste0(unique(engine), collapse = ", "),
173+
"}"),
174+
.groups = "drop")
175+
if (nrow(info) == 0) {
176+
rlang::abort(paste0("No known engines for `", cls, "()`."))
177+
}
178+
msg <- paste0(info$msg, collapse = ", ")
179+
msg <- paste("Missing engine. Possible mode/engine combinations are:", msg)
180+
rlang::abort(msg)
181+
}
182+
183+
184+
# check if class and mode and engine are compatible
185+
check_spec_mode_engine_val <- function(cls, eng, mode) {
186+
all_modes <- c("unknown", all_modes)
187+
if (!(mode %in% all_modes)) {
188+
rlang::abort(paste0("'", mode, "' is not a known mode."))
189+
}
190+
191+
model_info <- rlang::env_get(get_model_env(), cls)
192+
193+
# Cases where the model definition is in parsnip but all of the engines
194+
# are contained in a different package
195+
if (nrow(model_info) == 0) {
196+
check_mode_with_no_engine(cls, mode)
197+
return(invisible(NULL))
198+
}
199+
200+
# ------------------------------------------------------------------------------
201+
# First check engine against any mode for the given model class
202+
203+
spec_engs <- model_info$engine
204+
# engine is allowed to be NULL
205+
if (!is.null(eng) && !(eng %in% spec_engs)) {
206+
rlang::abort(
207+
paste0(
208+
"Engine '", eng, "' is not supported for `", cls, "()`. See ",
209+
"`show_engines('", cls, "')`."
210+
)
211+
)
212+
}
213+
214+
# ----------------------------------------------------------------------------
215+
# Check modes based on model and engine
216+
217+
spec_modes <- model_info$mode
218+
if (!is.null(eng)) {
219+
spec_modes <- spec_modes[model_info$engine == eng]
220+
}
221+
spec_modes <- unique(c("unknown", spec_modes))
222+
148223
if (is.null(mode) || length(mode) > 1) {
149-
stop_incompatible_mode(spec_modes)
224+
stop_incompatible_mode(spec_modes, eng)
150225
} else if (!(mode %in% spec_modes)) {
151-
stop_incompatible_mode(spec_modes)
226+
stop_incompatible_mode(spec_modes, eng)
152227
}
228+
229+
# ----------------------------------------------------------------------------
230+
# Check engine based on model and model
231+
232+
# How check for compatibility with the chosen mode (if any)
233+
if (!is.null(mode) && mode != "unknown") {
234+
spec_engs <- spec_engs[model_info$mode == mode]
235+
}
236+
spec_engs <- unique(spec_engs)
237+
if (!is.null(eng) && !(eng %in% spec_engs)) {
238+
stop_incompatible_engine(spec_engs, mode)
239+
}
240+
153241
invisible(NULL)
154242
}
155243

244+
check_mode_with_no_engine <- function(cls, mode) {
245+
spec_modes <- get_from_env(paste0(cls, "_modes"))
246+
if (!(mode %in% spec_modes)) {
247+
stop_incompatible_mode(spec_modes, cls = cls)
248+
}
249+
}
250+
156251
check_engine_val <- function(eng) {
157252
if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng))
158253
rlang::abort("Please supply a character string for an engine (e.g. `'lm'`).")
@@ -625,8 +720,7 @@ get_dependency <- function(model) {
625720
set_fit <- function(model, mode, eng, value) {
626721
check_model_exists(model)
627722
check_eng_val(eng)
628-
check_mode_val(mode)
629-
check_engine_val(eng)
723+
check_spec_mode_engine_val(model, eng, mode)
630724
check_fit_info(value)
631725

632726
current <- get_model_env()
@@ -692,8 +786,7 @@ get_fit <- function(model) {
692786
set_pred <- function(model, mode, eng, type, value) {
693787
check_model_exists(model)
694788
check_eng_val(eng)
695-
check_mode_val(mode)
696-
check_engine_val(eng)
789+
check_spec_mode_engine_val(model, eng, mode)
697790
check_pred_info(value, type)
698791

699792
current <- get_model_env()

R/arguments.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ set_mode <- function(object, mode) {
7979
cls <- class(object)[1]
8080
if (rlang::is_missing(mode)) {
8181
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
82-
stop_incompatible_mode(spec_modes)
82+
stop_incompatible_mode(spec_modes, cls = cls)
8383
}
84-
check_spec_mode_val(cls, mode)
84+
check_spec_mode_engine_val(cls, object$engine, mode)
8585
object$mode <- mode
8686
object
8787
}

R/engines.R

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,6 @@ possible_engines <- function(object, ...) {
1010
unique(engs$engine)
1111
}
1212

13-
stop_incompatible_engine <- function(avail_eng) {
14-
msg <- glue::glue(
15-
"Available engines are: ",
16-
glue::glue_collapse(glue::glue("'{avail_eng}'"), sep = ", ")
17-
)
18-
rlang::abort(msg)
19-
}
20-
21-
check_engine <- function(object) {
22-
avail_eng <- possible_engines(object)
23-
eng <- object$engine
24-
if (is.null(eng) || length(eng) > 1) {
25-
stop_incompatible_engine(avail_eng)
26-
} else if (!(eng %in% avail_eng)) {
27-
stop_incompatible_engine(avail_eng)
28-
}
29-
object
30-
}
31-
3213
# ------------------------------------------------------------------------------
3314

3415
shhhh <- function(x)
@@ -90,16 +71,16 @@ load_libs <- function(x, quiet, attach = FALSE) {
9071
#' translate(mod, engine = "glmnet")
9172
#' @export
9273
set_engine <- function(object, engine, ...) {
74+
mod_type <- class(object)[1]
9375
if (!inherits(object, "model_spec")) {
9476
rlang::abort("`object` should have class 'model_spec'.")
9577
}
9678

9779
if (rlang::is_missing(engine)) {
98-
avail_eng <- possible_engines(object)
99-
stop_incompatible_engine(avail_eng)
80+
stop_missing_engine(mod_type)
10081
}
10182
object$engine <- engine
102-
object <- check_engine(object)
83+
check_spec_mode_engine_val(mod_type, object$engine, object$mode)
10384

10485
if (object$engine == "liquidSVM") {
10586
lifecycle::deprecate_soft(
@@ -109,7 +90,7 @@ set_engine <- function(object, engine, ...) {
10990
}
11091

11192
new_model_spec(
112-
cls = class(object)[1],
93+
cls = mod_type,
11394
args = object$args,
11495
eng_args = enquos(...),
11596
mode = object$mode,

R/misc.R

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@ check_empty_ellipse <- function (...) {
2323
terms
2424
}
2525

26-
all_modes <- c("classification", "regression", "censored regression")
27-
28-
2926
deparserizer <- function(x, limit = options()$width - 10) {
3027
x <- deparse(x, width.cutoff = limit)
3128
x <- gsub("^ ", "", x)
@@ -192,7 +189,7 @@ update_dot_check <- function(...) {
192189
#' @rdname add_on_exports
193190
new_model_spec <- function(cls, args, eng_args, mode, method, engine) {
194191

195-
check_spec_mode_val(cls, mode)
192+
check_spec_mode_engine_val(cls, engine, mode)
196193

197194
out <- list(args = args, eng_args = eng_args,
198195
mode = mode, method = method, engine = engine)

R/translate.R

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,15 @@ translate.default <- function(x, engine = x$engine, ...) {
5959
mod_name <- specific_model(x)
6060

6161
x$engine <- engine
62-
x <- check_engine(x)
63-
6462
if (x$mode == "unknown") {
6563
rlang::abort("Model code depends on the mode; please specify one.")
6664
}
6765

68-
if (is.null(x$method))
66+
check_spec_mode_engine_val(class(x)[1], x$engine, x$mode)
67+
68+
if (is.null(x$method)) {
6969
x$method <- get_model_spec(mod_name, x$mode, engine)
70+
}
7071

7172
arg_key <- get_args(mod_name, engine)
7273

@@ -174,7 +175,7 @@ deharmonize <- function(args, key) {
174175

175176
add_methods <- function(x, engine) {
176177
x$engine <- engine
177-
x <- check_engine(x)
178+
check_spec_mode_engine_val(class(x)[1], x$engine, x$mode)
178179
x$method <- get_model_spec(specific_model(x), x$mode, x$engine)
179180
x
180181
}

man/details_gen_additive_mod_mgcv.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rmd/boost_tree_C5.0.Rmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ defaults <-
1313
param <-
1414
boost_tree() %>%
1515
set_engine("C5.0") %>%
16-
set_mode("regression") %>%
16+
set_mode("classification") %>%
1717
tunable() %>%
1818
dplyr::select(-source, -component, -component_id, parsnip = name) %>%
1919
dplyr::mutate(

man/rmd/decision_tree_C5.0.Rmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ defaults <-
1313
param <-
1414
decision_tree() %>%
1515
set_engine("C5.0") %>%
16-
set_mode("regression") %>%
16+
set_mode("classification") %>%
1717
tunable() %>%
1818
dplyr::select(-source, -component, -component_id, parsnip = name) %>%
1919
dplyr::mutate(

tests/testthat/test_args_and_modes.R

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,59 @@ test_that('pipe engine', {
4949
test_that("can't set a mode that isn't allowed by the model spec", {
5050
expect_error(
5151
set_mode(linear_reg(), "classification"),
52-
"Available modes are:"
52+
"Available modes"
5353
)
5454
})
55+
56+
57+
58+
test_that("unavailable modes for an engine and vice-versa", {
59+
expect_error(
60+
decision_tree() %>%
61+
set_mode("regression") %>%
62+
set_engine("C5.0"),
63+
"Available modes for engine C5"
64+
)
65+
expect_error(
66+
decision_tree() %>%
67+
set_engine("C5.0") %>%
68+
set_mode("regression"),
69+
"Available modes for engine C5"
70+
)
71+
72+
expect_error(
73+
decision_tree(engine = NULL) %>%
74+
set_engine("C5.0") %>%
75+
set_mode("regression"),
76+
"Available modes for engine C5"
77+
)
78+
79+
expect_error(
80+
decision_tree(engine = NULL)%>%
81+
set_mode("regression") %>%
82+
set_engine("C5.0"),
83+
"Available modes for engine C5"
84+
)
85+
86+
expect_error(
87+
proportional_hazards() %>% set_mode("regression"),
88+
"Available modes for model type proportional_hazards"
89+
)
90+
91+
expect_error(
92+
linear_reg() %>% set_mode(),
93+
"Available modes for model type linear_reg"
94+
)
95+
96+
expect_error(
97+
linear_reg() %>% set_engine(),
98+
"Missing engine"
99+
)
100+
101+
expect_error(
102+
proportional_hazards() %>% set_engine(),
103+
"No known engines for"
104+
)
105+
})
106+
107+

0 commit comments

Comments
 (0)