Chapter 13 Validating models


Figure 13.1: Model validation

In Chapter 12, we covered various ways of splitting the data into subsets. In this chapter, we use these subsets to validate models, that is, to assess their performance, using:

  • a holdout set,
  • cross-validation, and
  • bootstrapping.

Load the required packages:

Code
library(tidyverse)
library(tidymodels)
library(patchwork)

and set up the parallel backend for faster processing (see Appendix 27.1 for details):

Code
library(doParallel)
cl <- makePSOCKcluster(parallel::detectCores(logical = FALSE))
registerDoParallel(cl)

13.1 Model validation using holdout set

In the following, we demonstrate how to use a holdout set to assess the performance of a regression model that predicts the fuel efficiency (mpg) of cars.

Code
# Load and preprocess the data
data <- datasets::mtcars %>% 
    as_tibble(rownames="car") %>%
    mutate(
        vs = factor(vs, labels=c("V-shaped", "straight")),
        am = factor(am, labels=c("automatic", "manual"))
    )

# Split the preprocessed data into training and test/holdout sets
set.seed(1353)
car_split <- initial_split(data)
train_data <- training(car_split)
holdout_data <- testing(car_split)

# Train a model using the training set
formula <- mpg ~ cyl + disp + hp + drat + wt + qsec + vs + am + gear + carb
model <- linear_reg() %>% 
    set_engine("lm") %>% 
    fit(formula, data=train_data)

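We now have a trained model and can use it to assess model performance on the training and holdout sets. For a numeric outcome, the metrics function from yardstick computes its default regression metrics: root mean squared error (rmse), R² (rsq), and mean absolute error (mae).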

Code
perf_train <- augment(model, new_data=train_data) %>% 
    metrics(truth=mpg, estimate=.pred) %>% 
    mutate(data='Training')
perf_holdout <- augment(model, new_data=holdout_data) %>%
    metrics(truth=mpg, estimate=.pred) %>% 
    mutate(data='Holdout')
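For reference, the three default regression metrics can be written out in a few lines of base R. This is a minimal sketch and not yardstick's implementation. The helper names my_rmse, my_mae, and my_rsq are hypothetical, and the toy vectors are made up. One detail is easy to miss: yardstick's rsq is the squared correlation between truth and prediction. The traditional 1 - SSE/SST version is available separately as rsq_trad.

```r
# Hypothetical helpers mirroring yardstick's default regression metrics
my_rmse <- function(truth, estimate) sqrt(mean((truth - estimate)^2))
my_mae  <- function(truth, estimate) mean(abs(truth - estimate))
# rsq as the squared Pearson correlation (not 1 - SSE/SST)
my_rsq  <- function(truth, estimate) cor(truth, estimate)^2

# Made-up toy values for illustration
truth    <- c(1, 2, 3, 4)
estimate <- c(1.5, 2, 2.5, 4)

c(rmse = my_rmse(truth, estimate),
  rsq  = my_rsq(truth, estimate),
  mae  = my_mae(truth, estimate))
```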

We can combine the performance metrics for the training and holdout set into a single table for comparison.

Code
bind_rows(perf_train, perf_holdout) %>%  # combine the two results
    select(data, .estimate, .metric) %>%  # select the columns of interest
    pivot_wider(names_from=.metric, values_from=.estimate) %>%  # convert to wide format
    knitr::kable(digits = 2) %>%
    kableExtra::kable_styling(full_width=FALSE)
data      rmse   rsq   mae
Training  2.06  0.90  1.62
Holdout   3.10  0.74  2.78

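All three metrics are better on the training set than on the holdout set: RMSE and MAE are lower, and R² is higher. This is expected. The model was fit to the training set, so its performance there is optimistic, while the holdout set gives a more realistic estimate of performance on new data. Keep in mind, however, that the holdout estimate is itself uncertain. With the default proportion of 3/4, initial_split puts 24 cars in the training set and only 8 in the holdout set.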
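The idea behind initial_split can be sketched in base R: shuffle the row indices, keep a proportion (rsample's default is 3/4) for training, and hold out the rest. This simplified sketch does not reproduce rsample's internals. Because the split will differ, the numbers will not match the table above. The names train_df and holdout_df are chosen so that they do not overwrite the chapter's objects.

```r
# Simplified holdout split (a sketch, not rsample's implementation)
set.seed(1353)
n <- nrow(mtcars)
train_idx  <- sample.int(n, size = floor(n * 3 / 4))  # 3/4 of the rows for training
train_df   <- mtcars[train_idx, ]
holdout_df <- mtcars[-train_idx, ]                    # the remaining rows are held out

# mpg ~ . uses the same ten predictors as the chapter's formula
fit <- lm(mpg ~ ., data = train_df)

rmse <- function(truth, estimate) sqrt(mean((truth - estimate)^2))
c(training = rmse(train_df$mpg, predict(fit, train_df)),
  holdout  = rmse(holdout_df$mpg, predict(fit, holdout_df)))
```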

13.2 Model validation using cross-validation

We now use cross-validation to assess model performance. This time, we fit a logistic regression model to the entire Universal Bank dataset and use cross-validation to estimate its performance. First, download and preprocess the dataset.

Code
# Load and preprocess the data
data <- read_csv("https://gedeck.github.io/DS-6030/datasets/UniversalBank.csv.gz")
data <- data %>% 
    select(-c(ID, `ZIP Code`)) %>%
    rename(
        Personal.Loan = `Personal Loan`,
        Securities.Account = `Securities Account`,
        CD.Account = `CD Account`
    ) %>%
    mutate(
        Personal.Loan = factor(Personal.Loan, labels=c("Yes", "No"), levels=c(1, 0)),
        Education = factor(Education, labels=c("Undergrad", "Graduate", "Advanced"))
    )

Now we set up and execute the cross-validation.

Code
# Use 10-fold cross-validation to assess model performance
set.seed(1353)
folds <- vfold_cv(data, strata=Personal.Loan)

# define the model
formula <- Personal.Loan ~ Age + Experience + Income + Family + CCAvg + Education + 
                           Mortgage + Securities.Account + CD.Account + Online + 
                           CreditCard
logreg_model <- logistic_reg() %>% 
    set_engine("glm")
  
# define and execute the cross-validation workflow
logreg_wf <- workflow() %>%
    add_model(logreg_model) %>%
    add_formula(formula)

logreg_fit_cv <- logreg_wf %>% 
    fit_resamples(folds)

This is all we need to do. We first define our resampling approach with the function vfold_cv. We pass in the dataset and the column to use for stratified sampling, here the outcome variable Personal.Loan. Stratification ensures that each fold contains roughly the same proportion of customers who accepted the loan. The default v=10 defines a 10-fold cross-validation. Next, we set up our model and define the formula to use for training. Finally, we combine the model and formula into a workflow and use the fit_resamples function to execute the cross-validation. The results are stored in the logreg_fit_cv object. Let’s have a look at it:

Code
logreg_fit_cv
## # Resampling results
## # 10-fold cross-validation using stratification 
## # A tibble: 10 × 4
##    splits             id     .metrics         .notes          
##    <list>             <chr>  <list>           <list>          
##  1 <split [4500/500]> Fold01 <tibble [3 × 4]> <tibble [0 × 3]>
##  2 <split [4500/500]> Fold02 <tibble [3 × 4]> <tibble [0 × 3]>
##  3 <split [4500/500]> Fold03 <tibble [3 × 4]> <tibble [0 × 3]>
##  4 <split [4500/500]> Fold04 <tibble [3 × 4]> <tibble [0 × 3]>
##  5 <split [4500/500]> Fold05 <tibble [3 × 4]> <tibble [0 × 3]>
##  6 <split [4500/500]> Fold06 <tibble [3 × 4]> <tibble [0 × 3]>
##  7 <split [4500/500]> Fold07 <tibble [3 × 4]> <tibble [0 × 3]>
##  8 <split [4500/500]> Fold08 <tibble [3 × 4]> <tibble [0 × 3]>
##  9 <split [4500/500]> Fold09 <tibble [3 × 4]> <tibble [0 × 3]>
## 10 <split [4500/500]> Fold10 <tibble [3 × 4]> <tibble [0 × 3]>

This is not very informative on its own. logreg_fit_cv is a tibble in which each row corresponds to the model trained for one fold (column id).⁸ The splits column records which rows were used for training and validation in that iteration. The .metrics column contains the performance on the out-of-fold validation set.
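Each row of logreg_fit_cv corresponds to one iteration of the loop below. The sketch is self-contained, so it simulates its data (x1, x2, and y are hypothetical) instead of using the Universal Bank file. The helper stratified_folds is also hypothetical. It deals the rows of each class out to the folds in turn, which mimics the intent of vfold_cv(strata=...) but not rsample's exact algorithm.

```r
set.seed(1353)
n  <- 500
df <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
# simulated binary outcome whose log-odds depend on x1 and x2
df$y <- rbinom(n, 1, plogis(-2 + 2 * df$x1 - df$x2))

# hypothetical helper: stratified fold ids (a sketch, not rsample's algorithm)
stratified_folds <- function(y, v = 10) {
  fold <- integer(length(y))
  for (cls in unique(y)) {
    idx <- which(y == cls)
    idx <- idx[sample.int(length(idx))]            # shuffle the rows of this class
    fold[idx] <- rep_len(seq_len(v), length(idx))  # deal them out to folds 1..v
  }
  fold
}
fold <- stratified_folds(df$y, v = 10)

fold_accuracy <- numeric(10)
for (k in 1:10) {
  analysis   <- df[fold != k, ]   # nine folds are used for training
  assessment <- df[fold == k, ]   # the remaining fold is used for validation
  fit  <- glm(y ~ x1 + x2, data = analysis, family = binomial)
  prob <- predict(fit, newdata = assessment, type = "response")
  fold_accuracy[k] <- mean((prob > 0.5) == (assessment$y == 1))
}
round(fold_accuracy, 3)
table(fold, df$y)   # every fold has a similar number of events (y = 1)
```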

We can now use the collect_metrics function to extract the performance metrics. By default, it summarizes each metric across the folds as a mean and a standard error.

Code
cv_metrics <- collect_metrics(logreg_fit_cv)
cv_metrics
## # A tibble: 3 × 6
##   .metric     .estimator   mean     n std_err .config             
##   <chr>       <chr>       <dbl> <int>   <dbl> <chr>               
## 1 accuracy    binary     0.958     10 0.00184 Preprocessor1_Model1
## 2 brier_class binary     0.0322    10 0.00123 Preprocessor1_Model1
## 3 roc_auc     binary     0.961     10 0.00502 Preprocessor1_Model1

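By default, fit_resamples evaluates a classification model with three metrics: accuracy, the Brier score (brier_class), and the area under the ROC curve (roc_auc). For each metric, collect_metrics reports the mean over the n = 10 folds together with its standard error (std_err), not its standard deviation.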
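The summary row can be reproduced from per-fold estimates. To our understanding, tune computes std_err as the standard deviation of the per-fold values divided by the square root of the number of folds. Check the documentation of your installed version. The ten accuracies below are made up for illustration and are not the chapter's results.

```r
# Hypothetical per-fold accuracies, made up for illustration only
fold_accuracy <- c(0.956, 0.962, 0.950, 0.964, 0.958,
                   0.960, 0.952, 0.966, 0.954, 0.958)

n_folds <- length(fold_accuracy)
summary_row <- c(
  mean    = mean(fold_accuracy),
  n       = n_folds,
  std_err = sd(fold_accuracy) / sqrt(n_folds)   # standard error of the mean
)
summary_row
```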

By default, the individual out-of-fold predictions used to calculate the performance metrics are not returned. To keep them for further analysis, we need to pass a control object to the fit_resamples call.

Code
logreg_fit_cv <- logreg_wf %>% 
    fit_resamples(folds, control=control_resamples(save_pred=TRUE))

The control_resamples function creates a control object that modifies the behavior of fit_resamples. Here, we override the default by setting save_pred=TRUE, which instructs fit_resamples to keep the out-of-fold predictions for each fold. The collect_predictions function then returns these predictions for all folds as a single tibble. The .row column identifies the corresponding row of the original dataset.

Code
cv_predictions <- collect_predictions(logreg_fit_cv)
cv_predictions
## # A tibble: 5,000 × 7
##    .pred_class .pred_Yes .pred_No id      .row Personal.Loan .config            
##    <fct>           <dbl>    <dbl> <chr>  <int> <fct>         <chr>              
##  1 No           0.0175      0.982 Fold01     7 No            Preprocessor1_Mode…
##  2 No           0.0490      0.951 Fold01    25 No            Preprocessor1_Mode…
##  3 No           0.0123      0.988 Fold01    28 No            Preprocessor1_Mode…
##  4 Yes          0.785       0.215 Fold01    30 Yes           Preprocessor1_Mode…
##  5 No           0.00122     0.999 Fold01    31 No            Preprocessor1_Mode…
##  6 No           0.00517     0.995 Fold01    47 No            Preprocessor1_Mode…
##  7 No           0.00142     0.999 Fold01    59 No            Preprocessor1_Mode…
##  8 Yes          0.654       0.346 Fold01    60 No            Preprocessor1_Mode…
##  9 No           0.0113      0.989 Fold01    74 No            Preprocessor1_Mode…
## 10 No           0.000650    0.999 Fold01    88 No            Preprocessor1_Mode…
## # ℹ 4,990 more rows
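The ten folds partition the dataset, so every row is predicted exactly once. This is why collect_predictions returns 5,000 rows. With bootstrap resampling (Section 13.3), the out-of-bag sets of different resamples overlap, so the pooled predictions contain many rows several times. The base-R sketch below uses only simulated row indices to illustrate the difference. It is not how rsample assigns folds or resamples.

```r
set.seed(1353)
n <- 5000

# 10-fold CV: every row belongs to exactly one assessment fold
fold <- sample(rep_len(1:10, n))
cv_assessed <- unlist(lapply(1:10, function(k) which(fold == k)))

# 25 bootstrap resamples: the rows not drawn (out-of-bag) are assessed
boot_assessed <- unlist(lapply(1:25, function(b) {
  in_bag <- sample.int(n, n, replace = TRUE)
  setdiff(seq_len(n), in_bag)
}))

c(cv_predictions   = length(cv_assessed),
  cv_unique_rows   = length(unique(cv_assessed)),
  boot_predictions = length(boot_assessed),
  boot_unique_rows = length(unique(boot_assessed)))
```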

We can use these predictions to calculate an ROC curve for the out-of-fold predictions of each fold.

Code
cv_predictions %>% 
    group_by(id) %>% 
    roc_curve(truth=Personal.Loan, .pred_Yes, event_level="first") %>% 
    autoplot() +
        geom_abline(lty=2) +
        theme(legend.position="none")

Figure 13.2: Individual ROC curves for cross-validation folds
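roc_curve evaluates sensitivity and specificity at every threshold of the predicted probability (.pred_Yes). Plotting sensitivity against 1 - specificity gives the ROC curve, and roc_auc is the area under it. The base-R sketch below makes this concrete. The helper names roc_points and auc_trapezoid are hypothetical, and the scores are made up. It uses the trapezoidal rule for the area, which need not match yardstick's internal computation in every detail.

```r
# Hypothetical helper: ROC points for a binary outcome (TRUE = event)
roc_points <- function(truth, score) {
  thresholds <- sort(unique(score), decreasing = TRUE)
  pos <- sum(truth)
  neg <- sum(!truth)
  sens <- sapply(thresholds, function(t) sum(score >= t & truth) / pos)
  fpr  <- sapply(thresholds, function(t) sum(score >= t & !truth) / neg)
  # start the curve at (0, 0); the lowest threshold already gives (1, 1)
  data.frame(fpr = c(0, fpr), sensitivity = c(0, sens))
}

# Area under the curve with the trapezoidal rule
auc_trapezoid <- function(roc) {
  heights <- (head(roc$sensitivity, -1) + tail(roc$sensitivity, -1)) / 2
  sum(diff(roc$fpr) * heights)
}

# Made-up example: three events and three non-events
truth <- c(TRUE, TRUE, FALSE, TRUE, FALSE, FALSE)
score <- c(0.9, 0.8, 0.7, 0.4, 0.3, 0.2)

roc <- roc_points(truth, score)
auc_trapezoid(roc)   # 8/9: 8 of the 9 event/non-event pairs are ranked correctly
```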

Instead of showing one ROC curve per fold, we can also pool the out-of-fold predictions from all folds and calculate a single ROC curve. Figure 13.3 compares this cross-validation ROC curve with the ROC curve for a model trained on the full dataset and evaluated on that same training data.

Code
# Train a model on the full dataset
full_model <- logistic_reg() %>% 
    set_engine("glm") %>% 
    fit(formula, data=data)

cv_ROC <- cv_predictions %>% 
    roc_curve(truth=Personal.Loan, .pred_Yes, event_level="first")

ontrain_ROC <- augment(full_model, new_data=data) %>% 
    roc_curve(Personal.Loan, .pred_Yes, event_level="first")

ggplot() +
    geom_path(data=cv_ROC, aes(x=1-specificity, y=sensitivity)) +
    geom_path(data=ontrain_ROC, aes(x=1-specificity, y=sensitivity), color="red") +
    geom_abline(lty=2)

Figure 13.3: Comparison of ROC curves for cross-validation predictions (black) and on-training set predictions (red)

The ROC curve for the cross-validation predictions is very similar to the ROC curve for the predictions on the training set. This indicates that the logistic regression model does not noticeably overfit the training data.

13.3 Model validation using bootstrapping

The bootstraps function from the rsample package can be used to generate bootstrap samples.⁹ We can use these to assess model performance with bootstrap resampling. This time, we train a nearest neighbor model (see Appendix 25.5) for the Universal Bank dataset. Nearest neighbor models can be used for both classification and regression, so we need to specify the mode of the model with set_mode("classification").

Code
# Use bootstrap to assess model performance
set.seed(1353)
resamples <- rsample::bootstraps(data)

# define the model
formula <- Personal.Loan ~ Age + Experience + Income + Family + CCAvg + Education + 
                           Mortgage + Securities.Account + CD.Account + Online + 
                           CreditCard
nn_model <- nearest_neighbor(neighbors=5) %>%
    set_mode("classification") %>%
    set_engine("kknn")
  
# define and execute the bootstrap workflow
nn_wf <- workflow() %>%
    add_model(nn_model) %>%
    add_formula(formula)

nn_fit_boot <- nn_wf %>% 
    fit_resamples(resamples, control=control_resamples(save_pred=TRUE))

If you compare the code for bootstrap sampling with the code for cross-validation, you will notice that the only differences are the call to bootstraps instead of vfold_cv and the use of a different model. The rest of the code is identical.

Let’s have a look at the results.

Code
nn_fit_boot
## # Resampling results
## # Bootstrap sampling 
## # A tibble: 25 × 5
##    splits              id          .metrics         .notes   .predictions
##    <list>              <chr>       <list>           <list>   <list>      
##  1 <split [5000/1846]> Bootstrap01 <tibble [3 × 4]> <tibble> <tibble>    
##  2 <split [5000/1809]> Bootstrap02 <tibble [3 × 4]> <tibble> <tibble>    
##  3 <split [5000/1874]> Bootstrap03 <tibble [3 × 4]> <tibble> <tibble>    
##  4 <split [5000/1835]> Bootstrap04 <tibble [3 × 4]> <tibble> <tibble>    
##  5 <split [5000/1845]> Bootstrap05 <tibble [3 × 4]> <tibble> <tibble>    
##  6 <split [5000/1856]> Bootstrap06 <tibble [3 × 4]> <tibble> <tibble>    
##  7 <split [5000/1839]> Bootstrap07 <tibble [3 × 4]> <tibble> <tibble>    
##  8 <split [5000/1841]> Bootstrap08 <tibble [3 × 4]> <tibble> <tibble>    
##  9 <split [5000/1859]> Bootstrap09 <tibble [3 × 4]> <tibble> <tibble>    
## 10 <split [5000/1834]> Bootstrap10 <tibble [3 × 4]> <tibble> <tibble>    
## # ℹ 15 more rows

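Again, this is not very informative. By default, bootstraps creates 25 bootstrap samples. Each one draws 5000 rows with replacement from the 5000 rows of the dataset and uses them for training (the first number in <split [5000/1846]>). The rows that were never drawn, about 37% of the data, form the out-of-bag assessment set (the second number). Because we used save_pred=TRUE in the fit_resamples call, we also have the .predictions column with the individual predictions.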
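The out-of-bag size of roughly 1840 rows follows from sampling with replacement. A given row is missed by all n draws with probability (1 - 1/n)^n, which is close to exp(-1) ≈ 0.368 for large n. The base-R sketch below simulates 25 resamples of 5000 rows. It mimics the idea of rsample::bootstraps, not its code.

```r
set.seed(1353)
n <- 5000

# out-of-bag (assessment) set size for 25 simulated bootstrap resamples
oob_size <- replicate(25, {
  in_bag <- sample.int(n, n, replace = TRUE)   # draw n rows with replacement
  n - length(unique(in_bag))                   # number of rows never drawn
})
summary(oob_size)

# expected out-of-bag size n * (1 - 1/n)^n, about 1839 for n = 5000
n * (1 - 1 / n)^n
```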

We can use collect_metrics to summarize the performance metrics across the bootstrap samples and compare the result with the cross-validation results for the logistic regression model.¹⁰

Code
boot_metrics <- collect_metrics(nn_fit_boot)

bind_rows(cv_metrics %>% mutate(model='Logistic regression'), 
          boot_metrics %>% mutate(model='Nearest neighbor')) %>%
    select(model, mean, .metric) %>%
    pivot_wider(names_from=.metric, values_from=mean) %>% 
    knitr::kable(digits = 3) %>%
    kableExtra::kable_styling(full_width=FALSE)
model                accuracy  brier_class  roc_auc
Logistic regression     0.958        0.032    0.961
Nearest neighbor        0.961        0.032    0.918

Based on accuracy, we would conclude that the nearest neighbor model is slightly better than the logistic regression model (0.961 vs. 0.958). However, the AUC of the nearest neighbor model is considerably lower (0.918 vs. 0.961).
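Accuracy and AUC can disagree because they measure different things. Accuracy only looks at the predicted class at a single cutoff (assumed to be 0.5 here). The AUC summarizes how well the predicted probabilities rank events above non-events across all cutoffs. The made-up scores below show two models with the same accuracy but very different AUC. The helpers acc_at_cutoff and auc_pairs are hypothetical, and auc_pairs uses the pairwise (Mann-Whitney) definition of the AUC.

```r
# Made-up example: two models with identical accuracy but different AUC
truth   <- c(TRUE, TRUE, FALSE, FALSE)
score_a <- c(0.9, 0.40, 0.3, 0.1)
score_b <- c(0.9, 0.05, 0.3, 0.1)

# accuracy of the class predictions obtained with a fixed cutoff
acc_at_cutoff <- function(truth, score, cutoff = 0.5) mean((score > cutoff) == truth)

# AUC as the share of event/non-event pairs ranked correctly (ties count 1/2)
auc_pairs <- function(truth, score) {
  pos <- score[truth]
  neg <- score[!truth]
  mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))
}

c(acc_a = acc_at_cutoff(truth, score_a), acc_b = acc_at_cutoff(truth, score_b),
  auc_a = auc_pairs(truth, score_a),     auc_b = auc_pairs(truth, score_b))
```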

Let’s see if this is reflected in the ROC curves of the two models. Figure 13.4 compares the ROC curves for the bootstrap predictions and the predictions on the training set.

Code
# Train a model on the full dataset
full_model <- nn_wf %>% fit(data)

boot_ROC <- collect_predictions(nn_fit_boot) %>% 
    roc_curve(truth=Personal.Loan, .pred_Yes, event_level="first")

ontrain_ROC <- augment(full_model, new_data=data) %>% 
    roc_curve(Personal.Loan, .pred_Yes, event_level="first")

ggplot() +
    geom_path(data=boot_ROC, aes(x=1-specificity, y=sensitivity)) +
    geom_path(data=ontrain_ROC, aes(x=1-specificity, y=sensitivity), color="red") +
    geom_path(data=cv_ROC, aes(x=1-specificity, y=sensitivity), color="grey") +
    geom_abline(lty=2)

Figure 13.4: Comparison of ROC curves for bootstrap predictions (black) and on-training set predictions (red) for a nearest-neighbor model. For comparison, the ROC curve for the logistic regression model is overlaid as well in grey.

Let’s first compare the ROC curves for the bootstrap predictions and the predictions on the training set. The ROC curve for the predictions of the full model on the training set (red curve) corresponds to an ideal model: every data point is predicted correctly. This is expected for a nearest neighbor model. Each training observation is its own closest neighbor, so the model essentially memorizes the training data. This observation emphasizes the importance of assessing model performance on data not used for training, for example with a holdout set, cross-validation, or bootstrapping.

The ROC curve for the bootstrap predictions (black curve) is more realistic. It is similar to the cross-validation ROC curve of the logistic regression model (grey curve). However, at low false positive rates (the left part of the plot), the ROC curve for the bootstrap predictions lies below the curve of the logistic regression model. This confirms what we’ve seen from the AUC values. Note, however, that we have not explored different numbers of neighbors. In fact, as we will see in Section 15.4, increasing the number of neighbors improves the performance of the nearest neighbor model. It ultimately results in a model with a better ROC curve than the logistic regression model.
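The perfect training-set ROC curve of the nearest neighbor model in Figure 13.4 can be reproduced in miniature. In the base-R sketch below, the data are simulated and the labels are random, so there is nothing to learn. Nevertheless, a 1-nearest-neighbor classifier evaluated on its own training data is always correct, because each point is its own nearest neighbor. Excluding the point itself (leave-one-out) reveals a much lower accuracy. The chapter's model uses k = 5 with the kknn engine. With a distance-weighted kernel, the point itself (distance 0) typically receives the largest weight. This sketch uses k = 1 to make the effect exact, so it is an illustration of the mechanism, not of kknn itself.

```r
set.seed(42)
n <- 100
x <- matrix(rnorm(2 * n), ncol = 2)                     # simulated predictors
y <- factor(sample(c("Yes", "No"), n, replace = TRUE))  # labels unrelated to x

d <- as.matrix(dist(x))                                 # pairwise Euclidean distances

# 1-nearest neighbor evaluated on the training data itself:
# the closest point to each observation is the observation itself (distance 0)
nn_train <- apply(d, 1, which.min)
train_accuracy <- mean(y[nn_train] == y)

# leave-one-out: exclude the point itself before searching for its neighbor
diag(d) <- Inf
nn_loo <- apply(d, 1, which.min)
loo_accuracy <- mean(y[nn_loo] == y)

c(training = train_accuracy, leave_one_out = loo_accuracy)
```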

13.3.1 Distribution of metrics for bootstrap samples

We can also use the bootstrap samples to assess the distribution of the performance metrics. For this purpose, it is better to increase the number of resamples. Here, we use 1000 bootstrap samples.

Code
set.seed(123)
nn_fit_boot <- nn_wf %>% 
    fit_resamples(rsample::bootstraps(data, times=1000), control=control_resamples(save_pred=TRUE))

With 1000 bootstrap samples, calling collect_metrics again returns the mean and standard error of each metric by default.

Code
nn_fit_boot %>% 
    collect_metrics() 
## # A tibble: 3 × 6
##   .metric     .estimator   mean     n  std_err .config             
##   <chr>       <chr>       <dbl> <int>    <dbl> <chr>               
## 1 accuracy    binary     0.961   1000 0.000131 Preprocessor1_Model1
## 2 brier_class binary     0.0329  1000 0.000102 Preprocessor1_Model1
## 3 roc_auc     binary     0.913   1000 0.000428 Preprocessor1_Model1
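The standard errors are now much smaller than in the 25-sample run. This is because they are standard errors of the mean, which shrink as the number of resamples grows. Assuming tune's formula sd/sqrt(n), the roc_auc row implies that individual bootstrap AUCs vary with a standard deviation of roughly 0.000428 × √1000 ≈ 0.014. The sketch below uses simulated values, not the chapter's results. It shows that the standard deviation stays roughly constant while the standard error shrinks. The helper name spread is hypothetical.

```r
set.seed(123)
# hypothetical helper: spread of per-resample estimates vs. standard error of their mean
spread <- function(estimates) {
  c(sd = sd(estimates), std_err = sd(estimates) / sqrt(length(estimates)))
}

# simulated stand-ins for per-resample metric estimates (not real results)
few  <- rnorm(25,   mean = 0.91, sd = 0.014)
many <- rnorm(1000, mean = 0.91, sd = 0.014)

rbind(B25 = spread(few), B1000 = spread(many))
```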

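The standard error, however, only describes the uncertainty of the mean. If we use the summarize=FALSE argument, we get the performance metrics calculated for each bootstrap sample, which show how much the metrics vary from sample to sample.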

Code
nn_fit_boot %>% 
    collect_metrics(summarize=FALSE) %>%
    head()
## # A tibble: 6 × 5
##   id            .metric     .estimator .estimate .config             
##   <chr>         <chr>       <chr>          <dbl> <chr>               
## 1 Bootstrap0001 accuracy    binary        0.956  Preprocessor1_Model1
## 2 Bootstrap0001 roc_auc     binary        0.917  Preprocessor1_Model1
## 3 Bootstrap0001 brier_class binary        0.0350 Preprocessor1_Model1
## 4 Bootstrap0002 accuracy    binary        0.953  Preprocessor1_Model1
## 5 Bootstrap0002 roc_auc     binary        0.906  Preprocessor1_Model1
## 6 Bootstrap0002 brier_class binary        0.0364 Preprocessor1_Model1

This allows us to examine the distribution of the performance metrics. Figure 13.5 shows the distributions of the three metrics across the bootstrap samples, together with their medians and their 2.5% and 97.5% quantiles.

Code
quantiles <- nn_fit_boot %>% 
    collect_metrics(summarize=FALSE) %>%
    group_by(.metric) %>%
    summarize(
        q0.025 = quantile(.estimate, 0.025),
        median = quantile(.estimate, 0.5),
        q0.975 = quantile(.estimate, 0.975)
    )
nn_fit_boot %>% 
    collect_metrics(summarize=FALSE) %>%
    ggplot(aes(x=.estimate)) +
        geom_histogram(bins=50) +
        facet_wrap(~.metric, scales="free") +
        geom_vline(data=quantiles, aes(xintercept=median), color="blue") +
        geom_vline(data=quantiles, aes(xintercept=q0.025), color="blue", linetype="dashed") +
        geom_vline(data=quantiles, aes(xintercept=q0.975), color="blue", linetype="dashed")

Figure 13.5: Distribution of performance metrics for bootstrap samples; the blue lines show the median and the 95% confidence interval

Further information:

Code
stopCluster(cl)
registerDoSEQ()


The code of this chapter is summarized here.

Code
knitr::opts_chunk$set(echo=TRUE, cache=TRUE, autodep=TRUE, fig.align="center")
knitr::include_graphics("images/model_workflow_validate.png")
library(tidyverse)
library(tidymodels)
library(patchwork)
library(doParallel)
cl <- makePSOCKcluster(parallel::detectCores(logical = FALSE))
registerDoParallel(cl)
# Load and preprocess the data
data <- datasets::mtcars %>% 
    as_tibble(rownames="car") %>%
    mutate(
        vs = factor(vs, labels=c("V-shaped", "straight")),
        am = factor(am, labels=c("automatic", "manual"))
    )

# Split the preprocessed data into training and test/holdout sets
set.seed(1353)
car_split <- initial_split(data)
train_data <- training(car_split)
holdout_data <- testing(car_split)

# Train a model using the training set
formula <- mpg ~ cyl + disp + hp + drat + wt + qsec + vs + am + gear + carb
model <- linear_reg() %>% 
    set_engine("lm") %>% 
    fit(formula, data=train_data)
perf_train <- augment(model, new_data=train_data) %>% 
    metrics(truth=mpg, estimate=.pred) %>% 
    mutate(data='Training')
perf_holdout <- augment(model, new_data=holdout_data) %>%
    metrics(truth=mpg, estimate=.pred) %>% 
    mutate(data='Holdout')
bind_rows(perf_train, perf_holdout) %>%  # combine the two results
    select(data, .estimate, .metric) %>%  # select the columns of interest
    pivot_wider(names_from=.metric, values_from=.estimate) %>%  # convert to wide format
    knitr::kable(digits = 2) %>%
    kableExtra::kable_styling(full_width=FALSE)
# Load and preprocess the data
data <- read_csv("https://gedeck.github.io/DS-6030/datasets/UniversalBank.csv.gz")
data <- data %>% 
    select(-c(ID, `ZIP Code`)) %>%
    rename(
        Personal.Loan = `Personal Loan`,
        Securities.Account = `Securities Account`,
        CD.Account = `CD Account`
    ) %>%
    mutate(
        Personal.Loan = factor(Personal.Loan, labels=c("Yes", "No"), levels=c(1, 0)),
        Education = factor(Education, labels=c("Undergrad", "Graduate", "Advanced"))
    )
# Use 10-fold cross-validation to assess model performance
set.seed(1353)
folds <- vfold_cv(data, strata=Personal.Loan)

# define the model
formula <- Personal.Loan ~ Age + Experience + Income + Family + CCAvg + Education + 
                           Mortgage + Securities.Account + CD.Account + Online + 
                           CreditCard
logreg_model <- logistic_reg() %>% 
    set_engine("glm")
  
# define and execute the cross-validation workflow
logreg_wf <- workflow() %>%
    add_model(logreg_model) %>%
    add_formula(formula)

logreg_fit_cv <- logreg_wf %>% 
    fit_resamples(folds)
logreg_fit_cv
cv_metrics <- collect_metrics(logreg_fit_cv)
cv_metrics
logreg_fit_cv <- logreg_wf %>% 
    fit_resamples(folds, control=control_resamples(save_pred=TRUE))
cv_predictions <- collect_predictions(logreg_fit_cv)
cv_predictions
cv_predictions %>% 
    group_by(id) %>% 
    roc_curve(truth=Personal.Loan, .pred_Yes, event_level="first") %>% 
    autoplot() +
        geom_abline(lty=2) +
        theme(legend.position="none")
# Train a model on the full dataset
full_model <- logistic_reg() %>% 
    set_engine("glm") %>% 
    fit(formula, data=data)

cv_ROC <- cv_predictions %>% 
    roc_curve(truth=Personal.Loan, .pred_Yes, event_level="first")

ontrain_ROC <- augment(full_model, new_data=data) %>% 
    roc_curve(Personal.Loan, .pred_Yes, event_level="first")

ggplot() +
    geom_path(data=cv_ROC, aes(x=1-specificity, y=sensitivity)) +
    geom_path(data=ontrain_ROC, aes(x=1-specificity, y=sensitivity), color="red") +
    geom_abline(lty=2)
# Use bootstrap to assess model performance
set.seed(1353)
resamples <- rsample::bootstraps(data)

# define the model
formula <- Personal.Loan ~ Age + Experience + Income + Family + CCAvg + Education + 
                           Mortgage + Securities.Account + CD.Account + Online + 
                           CreditCard
nn_model <- nearest_neighbor(neighbors=5) %>%
    set_mode("classification") %>%
    set_engine("kknn")
  
# define and execute the bootstrap workflow
nn_wf <- workflow() %>%
    add_model(nn_model) %>%
    add_formula(formula)

nn_fit_boot <- nn_wf %>% 
    fit_resamples(resamples, control=control_resamples(save_pred=TRUE))
nn_fit_boot
boot_metrics <- collect_metrics(nn_fit_boot)

bind_rows(cv_metrics %>% mutate(model='Logistic regression'), 
          boot_metrics %>% mutate(model='Nearest neighbor')) %>%
    select(model, mean, .metric) %>%
    pivot_wider(names_from=.metric, values_from=mean) %>% 
    knitr::kable(digits = 3) %>%
    kableExtra::kable_styling(full_width=FALSE)
# Train a model on the full dataset
full_model <- nn_wf %>% fit(data)

boot_ROC <- collect_predictions(nn_fit_boot) %>% 
    roc_curve(truth=Personal.Loan, .pred_Yes, event_level="first")

ontrain_ROC <- augment(full_model, new_data=data) %>% 
    roc_curve(Personal.Loan, .pred_Yes, event_level="first")

ggplot() +
    geom_path(data=boot_ROC, aes(x=1-specificity, y=sensitivity)) +
    geom_path(data=ontrain_ROC, aes(x=1-specificity, y=sensitivity), color="red") +
    geom_path(data=cv_ROC, aes(x=1-specificity, y=sensitivity), color="grey") +
    geom_abline(lty=2)
set.seed(123)
nn_fit_boot <- nn_wf %>% 
    fit_resamples(rsample::bootstraps(data, times=1000), control=control_resamples(save_pred=TRUE))
nn_fit_boot %>% 
    collect_metrics() 
nn_fit_boot %>% 
    collect_metrics(summarize=FALSE) %>%
    head()
quantiles <- nn_fit_boot %>% 
    collect_metrics(summarize=FALSE) %>%
    group_by(.metric) %>%
    summarize(
        q0.025 = quantile(.estimate, 0.025),
        median = quantile(.estimate, 0.5),
        q0.975 = quantile(.estimate, 0.975)
    )
nn_fit_boot %>% 
    collect_metrics(summarize=FALSE) %>%
    ggplot(aes(x=.estimate)) +
        geom_histogram(bins=50) +
        facet_wrap(~.metric, scales="free") +
        geom_vline(data=quantiles, aes(xintercept=median), color="blue") +
        geom_vline(data=quantiles, aes(xintercept=q0.025), color="blue", linetype="dashed") +
        geom_vline(data=quantiles, aes(xintercept=q0.975), color="blue", linetype="dashed")
stopCluster(cl)
registerDoSEQ()

  8. That is, the fold named in id was used as the validation set, and the remaining folds were used for training.

  9. Be careful with typos here. There is also a broom::bootstrap function, which will give you a missing argument warning.

  10. This is for demonstration only! In practice, you would use the same validation approach for both models.