Skip to content

Commit bdc2854

Browse files
mattwarkentinhfricktopepo
authored
Adds support for flexsurvspline engine for survival_reg model spec (#831)
* Adds support for flexsurvspline engine for survival_reg model spec * Add PR number to NEWS * leave deprecated `surv_reg()` as is * make `k` tunable the arg name of `flexsurvspline()` is `k`, not `num_knots` the method does not need conditional registration because it's new and was never registered in tune * update `inst/models.tsv` so that `uses_extension()` in the engine docs works * update engine docs and knit * document() * update news, add contributor * import also the generic * remove engine arg from template * notes about case weights * render docs Co-authored-by: Hannah Frick <[email protected]> Co-authored-by: topepo <[email protected]>
1 parent 624dabc commit bdc2854

13 files changed

+250
-2
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ S3method(translate,survival_reg)
8989
S3method(translate,svm_linear)
9090
S3method(translate,svm_poly)
9191
S3method(translate,svm_rbf)
92+
S3method(tunable,survival_reg)
9293
S3method(type_sum,model_fit)
9394
S3method(type_sum,model_spec)
9495
S3method(update,C5_rules)
@@ -316,6 +317,7 @@ importFrom(generics,fit_xy)
316317
importFrom(generics,glance)
317318
importFrom(generics,required_pkgs)
318319
importFrom(generics,tidy)
320+
importFrom(generics,tunable)
319321
importFrom(generics,varying_args)
320322
importFrom(ggplot2,autoplot)
321323
importFrom(glue,glue_collapse)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* Adds documentation and tuning infrastructure for the new `flexsurvspline` engine for the `survival_reg()` model specification from the `censored` package (@mattwarkentin, #831).
4+
35
* The matrix interface for fitting `fit_xy()` now works for the `"censored regression"` mode (#829).
46

57
* The `num_leaves` argument of `boost_tree()`s `lightgbm` engine (via the bonsai package) is now tunable.

R/parsnip-package.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
## usethis namespace: start
1111
#' @importFrom dplyr arrange bind_cols bind_rows collect full_join group_by
1212
#' @importFrom dplyr mutate pull rename select starts_with summarise tally
13-
#' @importFrom generics varying_args
13+
#' @importFrom generics tunable varying_args
1414
#' @importFrom glue glue_collapse
1515
#' @importFrom pillar type_sum
1616
#' @importFrom purrr as_vector imap imap_lgl map map_chr map_dbl map_df map_dfr

R/survival_reg_flexsurvspline.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#' Flexible parametric survival regression
2+
#'
3+
#' [flexsurv::flexsurvspline()] fits a flexible parametric survival model.
4+
#'
5+
#' @includeRmd man/rmd/survival_reg_flexsurvspline.md details
6+
#'
7+
#' @name details_survival_reg_flexsurvspline
8+
#' @keywords internal
9+
NULL
10+
11+
# See inst/README-DOCS.md for a description of how these files are processed

R/tunable.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,17 @@ brulee_multinomial_engine_args <-
203203
brulee_mlp_engine_args %>%
204204
dplyr::filter(name %in% c("momentum", "batch_size", "stop_iter", "class_weights"))
205205

206+
flexsurvspline_engine_args <-
207+
tibble::tibble(
208+
name = c("k"),
209+
call_info = list(
210+
list(pkg = "dials", fun = "num_knots")
211+
),
212+
source = "model_spec",
213+
component = "survival_reg",
214+
component_id = "engine"
215+
)
216+
206217
# ------------------------------------------------------------------------------
207218

208219
# Lazily registered in .onLoad()
@@ -324,5 +335,14 @@ tunable_mlp <- function(x, ...) {
324335
res
325336
}
326337

338+
#' @export
339+
tunable.survival_reg <- function(x, ...) {
340+
res <- NextMethod()
341+
if (x$engine == "flexsurvspline") {
342+
res <- add_engine_parameters(res, flexsurvspline_engine_args)
343+
}
344+
res
345+
}
346+
327347
# nocov end
328348

inst/models.tsv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
"surv_reg" "regression" "flexsurv" NA
124124
"surv_reg" "regression" "survival" NA
125125
"survival_reg" "censored regression" "flexsurv" "censored"
126+
"survival_reg" "censored regression" "flexsurvspline" "censored"
126127
"survival_reg" "censored regression" "survival" "censored"
127128
"svm_linear" "classification" "kernlab" NA
128129
"svm_linear" "classification" "LiblineaR" NA

man/details_auto_ml_h2o.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_survival_reg_flexsurv.Rd

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/details_survival_reg_flexsurvspline.Rd

Lines changed: 80 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rmd/survival_reg_flexsurv.Rmd

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ For this engine, stratification cannot be specified via [`strata()`], please see
4545
```{r child = "template-survival-mean.Rmd"}
4646
```
4747

48+
## Case weights
49+
50+
```{r child = "template-uses-case-weights.Rmd"}
51+
```
52+
4853
## Saving fitted model objects
4954

5055
```{r child = "template-butcher.Rmd"}

0 commit comments

Comments
 (0)