@@ -283,7 +283,6 @@ test_that("tune model and recipe (multi-predict)", {
283
283
# ------------------------------------------------------------------------------
284
284
285
285
test_that(" tune recipe only - failure in recipe is caught elegantly" , {
286
- skip(" test is not implemented for tune_bayes()" )
287
286
skip_if_not_installed(" splines2" )
288
287
289
288
# With tune_grid() this tests for NA values in the grid.
@@ -301,36 +300,47 @@ test_that("tune recipe only - failure in recipe is caught elegantly", {
301
300
# NA values not allowed in recipe
302
301
cars_grid <- tibble(deg_free = c(3 , NA_real_ , 4 ))
303
302
304
- # ask for predictions and extractions
305
- control <- control_bayes(
306
- save_pred = TRUE ,
307
- extract = function (x ) 1L
308
- )
309
-
310
303
suppressMessages({
311
- cars_res <- tune_bayes (
304
+ cars_init_res <- tune_grid (
312
305
model ,
313
306
preprocessor = rec ,
314
307
resamples = data_folds ,
315
- control = control
308
+ grid = cars_grid
316
309
)
317
310
})
318
311
319
- notes <- cars_res $ .notes
320
- note <- notes [[1 ]]$ note
312
+ suppressMessages({
313
+ set.seed(283 ) # <- chosen to not generate faiures
314
+ cars_bayes_res <- tune_bayes(
315
+ model ,
316
+ preprocessor = rec ,
317
+ resamples = data_folds ,
318
+ initial = cars_init_res ,
319
+ iter = 2
320
+ )
321
+ })
321
322
322
- extract <- cars_res $ .extracts [[1 ]]
323
+ exp_failures <- nrow(data_folds ) * sum(! complete.cases(cars_grid ))
324
+ obs_init_failures <- collect_notes(cars_init_res ) | >
325
+ filter(type == " error" ) | >
326
+ nrow()
327
+ obs_failures <- collect_notes(cars_bayes_res ) | >
328
+ filter(type == " error" ) | >
329
+ nrow()
323
330
324
- predictions <- cars_res $ .predictions [[ 1 ]]
325
- used_deg_free <- sort(unique( predictions $ deg_free ))
331
+ exp_init_grid_res <-
332
+ cars_grid | > tidyr :: drop_na() | > distinct( deg_free ) | > nrow( )
326
333
327
- expect_length(notes , 2L )
334
+ expect_equal(obs_init_failures , obs_failures )
335
+ expect_equal(obs_failures , exp_failures )
328
336
329
- # failing rows are not in the output
330
- expect_equal(nrow(extract ), 2L )
331
- expect_equal(extract $ deg_free , c(3 , 4 ))
337
+ all_notes <- collect_notes(cars_bayes_res )
338
+ expect_equal(nrow(all_notes ), 11L )
332
339
333
- expect_equal(used_deg_free , c(3 , 4 ))
340
+ expect_equal(
341
+ collect_metrics(cars_bayes_res ) | > distinct(deg_free ) | > nrow(),
342
+ exp_init_grid_res + 2
343
+ )
334
344
})
335
345
336
346
test_that(" tune model only - failure in recipe is caught elegantly" , {
@@ -397,38 +407,51 @@ test_that("tune model and recipe - failure in recipe is caught elegantly", {
397
407
recipes :: step_spline_b(disp , deg_free = tune())
398
408
399
409
# NA values not allowed in recipe
400
- cars_grid <- tibble(deg_free = c(NA_real_ , 10L ), cost = 0.01 )
410
+ cars_grid <- tibble(
411
+ deg_free = c(3L , NA_real_ , 10L ),
412
+ cost = c(0.1 , 0.01 , 0.001 )
413
+ )
401
414
402
415
suppressMessages({
403
- cars_res <- tune_bayes (
416
+ cars_init_res <- tune_grid (
404
417
svm_mod ,
405
418
preprocessor = rec ,
406
419
resamples = data_folds ,
407
- control = control_bayes(
408
- extract = function (x ) {
409
- 1
410
- },
411
- save_pred = TRUE
412
- )
420
+ grid = cars_grid
421
+ )
422
+ })
423
+
424
+ suppressMessages({
425
+ set.seed(283 ) # <- chosen to not generate faiures
426
+ cars_bayes_res <- tune_bayes(
427
+ svm_mod ,
428
+ preprocessor = rec ,
429
+ resamples = data_folds ,
430
+ initial = cars_init_res ,
431
+ iter = 2
413
432
)
414
433
})
415
434
416
- notes <- cars_res $ .notes
417
- note <- notes [[1 ]]$ note
435
+ exp_failures <- nrow(data_folds ) * sum(! complete.cases(cars_grid ))
436
+ obs_init_failures <- collect_notes(cars_init_res ) | >
437
+ filter(type == " error" ) | >
438
+ nrow()
439
+ obs_failures <- collect_notes(cars_bayes_res ) | >
440
+ filter(type == " error" ) | >
441
+ nrow()
418
442
419
- extract <- cars_res $ .extracts [[ 1 ]]
420
- prediction <- cars_res $ .predictions [[ 1 ]]
443
+ exp_init_grid_res <-
444
+ cars_grid | > tidyr :: drop_na() | > distinct( deg_free , cost ) | > nrow()
421
445
422
- expect_length(notes , 2L )
446
+ expect_equal(obs_init_failures , obs_failures )
447
+ expect_equal(obs_failures , exp_failures )
423
448
424
- # recipe failed half of the time, only 1 model passed
425
- expect_equal(nrow(extract ), 1L )
426
- expect_equal(extract $ deg_free , 10L )
427
- expect_equal(extract $ cost , 0.01 )
449
+ all_notes <- collect_notes(cars_bayes_res )
450
+ expect_equal(nrow(all_notes ), 6L )
428
451
429
452
expect_equal(
430
- unique( prediction [, c( " deg_free" , " cost" )] ),
431
- tibble( deg_free = 10 , cost = 0.01 )
453
+ collect_metrics( cars_bayes_res ) | > distinct( deg_free , cost ) | > nrow( ),
454
+ exp_init_grid_res + 2
432
455
)
433
456
})
434
457
0 commit comments