Skip to content

Commit 5a10ff6

Browse files
authored
Merge pull request #408 from tidymodels/add-in
RStudio addin
2 parents 7b81378 + 3d30926 commit 5a10ff6

File tree

10 files changed

+312
-0
lines changed

10 files changed

+312
-0
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ export(new_model_spec)
145145
export(null_model)
146146
export(null_value)
147147
export(nullmodel)
148+
export(parsnip_addin)
148149
export(pred_value_template)
149150
export(predict.model_fit)
150151
export(predict_class.model_fit)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* An RStudio add-in is availble that makes writing multiple `parsnip` model specifications to the source window. It can be accessed via the IDE addin menus or by calling `parsnip_addin()`.
4+
35
# parsnip 0.1.4
46

57
* `show_engines()` will provide information on the current set for a model.

R/add_in.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#' Start an RStudio Addin that can write model specifications
2+
#'
3+
#' `parsnip_addin()` starts a process in the RStudio IDE Viewer window
4+
#' that allows users to write code for `parsnip` model specifications from
5+
#' various R packages. The new code is written to the current document at the
6+
#' location of the cursor.
7+
#'
8+
#' @export
9+
parsnip_addin <- function() {
10+
sys.source(
11+
system.file("add-in", "gadget.R", package = "parsnip", mustWork = TRUE),
12+
envir = rlang::new_environment(parent = rlang::global_env()),
13+
keep.source = FALSE
14+
)
15+
}

R/data.R

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#' parsnip model specification database
2+
#'
3+
#' This is used in the RStudio add-in and captures information about mode
4+
#' specifications in various R packages.
5+
#'
6+
#' @name model_db
7+
#' @aliases model_db
8+
#' @docType data
9+
#' @return \item{model_db}{a data frame}
10+
#' @keywords datasets
11+
#' @examples
12+
#' data(model_db)
13+
NULL
14+

data/model_db.rda

1.98 KB
Binary file not shown.

inst/add-in/gadget.R

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
parsnip_spec_add_in <- function() {
2+
# ------------------------------------------------------------------------------
3+
# check installs
4+
5+
libs <- c("shiny", "miniUI", "rstudioapi")
6+
is_inst <- rlang::is_installed(libs)
7+
if (any(!is_inst)) {
8+
missing_pkg <- libs[!is_inst]
9+
missing_pkg <- paste0(missing_pkg, collapse = ", ")
10+
rlang::abort(
11+
glue::glue(
12+
"The add-in requires some CRAN package installs: ",
13+
glue::glue_collapse(glue::glue("'{missing_pkg}'"), sep = ", ")
14+
)
15+
)
16+
}
17+
18+
library(shiny)
19+
library(miniUI)
20+
library(rstudioapi)
21+
22+
data(model_db, package = "parsnip")
23+
24+
# ------------------------------------------------------------------------------
25+
26+
make_spec <- function(x, tune_args) {
27+
if (tune_args) {
28+
nms <- x$parameters[[1]]$parameter
29+
args <- purrr::map(nms, ~ rlang::call2("tune"))
30+
names(args) <- nms
31+
} else {
32+
args <- NULL
33+
}
34+
35+
if (x$package != "parsnip") {
36+
pkg <- x$package
37+
} else {
38+
pkg <- NULL
39+
}
40+
41+
if (length(args) > 0) {
42+
cl_1 <- rlang::call2(.ns = pkg, .fn = x$model, !!!args)
43+
} else {
44+
cl_1 <- rlang::call2(.ns = pkg, .fn = x$model)
45+
}
46+
47+
obj_nm <- paste0(x$model,"_", x$engine, "_spec")
48+
chr_1 <- rlang::expr_text(cl_1, width = 500)
49+
chr_1 <- paste0(chr_1, collapse = " ")
50+
chr_1 <- paste(obj_nm, "<-\n ", chr_1)
51+
chr_2 <- paste0("set_engine('", x$engine, "')")
52+
53+
res <- paste0(chr_1, " %>%\n ", chr_2)
54+
55+
if (!x$single_mode) {
56+
chr_3 <- paste0("set_mode('", x$mode, "')")
57+
res <- paste0(res, " %>%\n ", chr_3)
58+
}
59+
60+
res
61+
}
62+
63+
ui <-
64+
miniPage(
65+
gadgetTitleBar("Write out model specifications"),
66+
miniContentPanel(
67+
fillRow(
68+
fillCol(
69+
radioButtons(
70+
"model_mode",
71+
label = h3("Type of Model"),
72+
choices = c("Classification", "Regression")
73+
),
74+
checkboxInput(
75+
"tune_args",
76+
label = "Tag parameters for tuning (if any)?",
77+
value = TRUE
78+
),
79+
textInput(
80+
"pattern",
81+
label = "Match on (regex)"
82+
)
83+
),
84+
fillRow(
85+
miniContentPanel(uiOutput("model_choices"))
86+
)
87+
)
88+
),
89+
miniButtonBlock(
90+
actionButton("write", "Write specification code", class = "btn-success")
91+
)
92+
)
93+
94+
95+
server <-
96+
function(input, output) {
97+
get_models <- reactive({
98+
req(input$model_mode)
99+
100+
models <- model_db[model_db$mode == tolower(input$model_mode),]
101+
if (nchar(input$pattern) > 0) {
102+
incld <- grepl(input$pattern, models$model) | grepl(input$pattern, models$engine)
103+
models <- models[incld,]
104+
105+
}
106+
models
107+
}) # get_models
108+
109+
output$model_choices <- renderUI({
110+
111+
model_list <- get_models()
112+
if (nrow(model_list) > 0) {
113+
114+
choices <- paste0(model_list$model, " (", model_list$engine, ")")
115+
choices <- unique(choices)
116+
} else {
117+
choices <- NULL
118+
}
119+
120+
checkboxGroupInput(
121+
inputId = "model_name",
122+
label = "",
123+
choices = choices
124+
)
125+
}) # model_choices
126+
127+
create_code <- reactive({
128+
129+
req(input$model_name)
130+
req(input$model_mode)
131+
132+
model_mode <- tolower(input$model_mode)
133+
selected <- model_db[model_db$label %in% input$model_name,]
134+
selected <- selected[selected$mode %in% model_mode,]
135+
136+
res <- purrr::map_chr(1:nrow(selected),
137+
~ make_spec(selected[.x,], tune_args = input$tune_args))
138+
139+
paste0(res, sep = "\n\n")
140+
141+
}) # create_code
142+
143+
observeEvent(input$write, {
144+
res <- create_code()
145+
for (txt in res) {
146+
rstudioapi::insertText(txt)
147+
}
148+
})
149+
150+
observeEvent(input$done, {
151+
stopApp()
152+
})
153+
}
154+
155+
viewer <- paneViewer(300)
156+
runGadget(ui, server, viewer = viewer)
157+
}
158+
159+
parsnip_spec_add_in()
160+

inst/add-in/parsnip_model_db.R

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# ------------------------------------------------------------------------------
2+
# code to make the parsnip model database used by the RStudio addin
3+
4+
# ------------------------------------------------------------------------------
5+
6+
library(tidymodels)
7+
library(usethis)
8+
9+
# also requires installation of:
10+
packages <- c("parsnip", "discrim", "plsmod", "rules", "baguette", "poissonreg", "modeltime", "modeltime.gluonts")
11+
12+
# ------------------------------------------------------------------------------
13+
14+
# Detects model specifications via their print methods
15+
print_methods <- function(x) {
16+
require(x, character.only = TRUE)
17+
ns <- asNamespace(ns = x)
18+
mthds <- ls(envir = ns, pattern = "^print\\.")
19+
mthds <- gsub("^print\\.", "", mthds)
20+
purrr::map_dfr(mthds, get_engines) %>% dplyr::mutate(package = x)
21+
}
22+
get_engines <- function(x) {
23+
eng <- try(parsnip::show_engines(x), silent = TRUE)
24+
if (inherits(eng, "try-error")) {
25+
eng <- tibble::tibble(engine = NA_character_, mode = NA_character_, model = x)
26+
} else {
27+
eng$model <- x
28+
}
29+
eng
30+
}
31+
get_tunable_param <- function(mode, package, model, engine) {
32+
cl <- rlang::call2(.ns = package, .fn = model)
33+
obj <- rlang::eval_tidy(cl)
34+
obj <- parsnip::set_engine(obj, engine)
35+
obj <- parsnip::set_mode(obj, mode)
36+
res <-
37+
tune::tunable(obj) %>%
38+
dplyr::select(parameter = name)
39+
40+
# ------------------------------------------------------------------------------
41+
# Edit some model parameters
42+
43+
if (model == "rand_forest") {
44+
res <- res[res$parameter != "trees",]
45+
}
46+
if (model == "mars") {
47+
res <- res[res$parameter == "prod_degree",]
48+
}
49+
if (engine %in% c("rule_fit", "xgboost")) {
50+
res <- res[res$parameter != "mtry",]
51+
}
52+
if (model %in% c("bag_tree", "bag_mars")) {
53+
res <- res[0,]
54+
}
55+
if (engine %in% c("rpart")) {
56+
res <- res[res$parameter != "tree-depth",]
57+
}
58+
res
59+
60+
}
61+
62+
# ------------------------------------------------------------------------------
63+
64+
model_db <-
65+
purrr::map_dfr(packages, print_methods) %>%
66+
dplyr::filter(!is.na(engine)) %>%
67+
dplyr::mutate(label = paste0(model, " (", engine, ")")) %>%
68+
dplyr::arrange(model, engine, mode)
69+
70+
num_modes <-
71+
model_db %>%
72+
dplyr::group_by(package, model, engine) %>%
73+
dplyr::count() %>%
74+
dplyr::ungroup() %>%
75+
dplyr::mutate(single_mode = n == 1) %>%
76+
dplyr::select(package, model, engine, single_mode)
77+
78+
model_db <-
79+
dplyr::left_join(model_db, num_modes, by = c("package", "model", "engine")) %>%
80+
dplyr::filter(engine != "spark") %>%
81+
dplyr::mutate(parameters = purrr::pmap(list(mode, package, model, engine), get_tunable_param))
82+
83+
usethis::use_data(model_db, overwrite = TRUE)
84+

inst/rstudio/addins.dcf

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
Name: Generate parsnip model specifications
3+
Description: Automatically generate code for multiple parsnip model specifications.
4+
Binding: write_parsnip_specs
5+
Interactive: true

man/model_db.Rd

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/parsnip_addin.Rd

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)