Chapter 23 Visualizing decision tree models

Decision tree models are popular because they are easy to interpret and understand. This has led to the development of packages that facilitate the analysis of decision trees. Here, we focus on the ggparty package and demonstrate some of its features using the ISLR2::Carseats data set.

Load the required libraries.

Code
library(tidyverse)
library(tidymodels)
library(ggparty)

Prepare the ISLR2::Carseats data set for classification. A new factor variable High with the two levels Yes and No is derived from Sales: observations with Sales at or below the median (7.49) are labeled No, all others Yes. The Sales variable is then removed from the data set so that it cannot be used as a predictor.

Code
carseats <- tibble(ISLR2::Carseats) %>%
    mutate(
        High=factor(ifelse(Sales <= median(Sales), "No", "Yes"))
    ) %>%
    dplyr::select(-c(Sales))
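
As a quick sanity check on the recoding, the following sketch rebuilds the data set and counts the observations per class. It only assumes that the ISLR2 and dplyr packages are installed. No output is shown here because the exact counts depend on how ties at the median fall.

```r
library(dplyr)

# Rebuild the classification data set as above
carseats <- as_tibble(ISLR2::Carseats) %>%
    mutate(
        High=factor(ifelse(Sales <= median(Sales), "No", "Yes"))
    ) %>%
    select(-Sales)

# Factor levels are sorted alphabetically, so "No" is the first level
levels(carseats$High)

# Observations per class; a split at the median should be roughly balanced
class_counts <- count(carseats, High)
class_counts
```

The order of the levels matters later on: functions that report a single probability for a two-class outcome need to pick one of the two levels.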

23.1 Classification Trees

We first train a model using the default settings.

Code
model <- decision_tree(mode="classification", engine="rpart") %>%
    fit(High ~ ., data=carseats)

I’ve not been able to use the ggparty functionality with a decision tree model created using a workflow. If you use a workflow to tune a decision tree, you will need to create a separate decision tree model with the settings from the tuning.
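
A minimal sketch of that workaround is shown below. The values in best_params are hypothetical placeholders, not results from an actual tuning run; in practice they would come from the tuning results of the workflow. The names best_params, tuned_model, and tuned_party are introduced here for illustration only.

```r
library(tidymodels)

carseats <- as_tibble(ISLR2::Carseats) %>%
    mutate(
        High=factor(ifelse(Sales <= median(Sales), "No", "Yes"))
    ) %>%
    select(-Sales)

# Hypothetical tuning result (placeholder values, not tuned on this data)
best_params <- tibble(cost_complexity=0.01, tree_depth=5, min_n=10)

# Refit a standalone parsnip model, without a workflow, using these settings
tuned_model <- decision_tree(
    mode="classification", engine="rpart",
    cost_complexity=best_params$cost_complexity,
    tree_depth=best_params$tree_depth,
    min_n=best_params$min_n
) %>%
    fit(High ~ ., data=carseats)

# The underlying rpart fit can now be converted for use with ggparty
tuned_party <- partykit::as.party(tuned_model$fit)
```

The standalone model is fitted with the formula interface on the full data frame, which is the same setup as the default model above.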

23.1.1 Visualizing the tree (graph)

To use the ggparty visualization, we need to extract the rpart model and convert it into a format suitable for this package using the function partykit::as.party(model$fit).

Figure 23.1 shows the tree visualization created by the autoplot implementation of the party object.

Code
autoplot(partykit::as.party(model$fit))

Figure 23.1: Decision tree visualization of default model (autoplot)

There are many options to customize the visualization. Figure 23.2 shows the tree visualization created by the ggparty implementation of the party object. The pie charts show the distribution of the two classes and the number of training data points in the terminal nodes.

Code
ggparty::ggparty(partykit::as.party(model$fit)) +
    ggparty::geom_edge() +
    ggparty::geom_edge_label() +
    ggparty::geom_node_label(aes(label=splitvar), ids="inner") +
    ggparty::geom_node_plot(gglist=list(geom_bar(aes(x="", fill=High)),
                                        coord_polar("y"),
                                        theme_void()))

Figure 23.2: Decision tree visualization using pie charts
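
If the exact node sizes should be readable from the plot, one option is to add a text label to each terminal node. The sketch below is a variation of the plot above, not the code used for Figure 23.2. It relies on the nodesize column that ggparty adds to its plot data; the nudge_y value is an arbitrary choice and may need adjusting.

```r
library(tidymodels)
library(ggparty)

carseats <- as_tibble(ISLR2::Carseats) %>%
    mutate(
        High=factor(ifelse(Sales <= median(Sales), "No", "Yes"))
    ) %>%
    select(-Sales)

model <- decision_tree(mode="classification", engine="rpart") %>%
    fit(High ~ ., data=carseats)

p <- ggparty(partykit::as.party(model$fit)) +
    geom_edge() +
    geom_edge_label() +
    geom_node_label(aes(label=splitvar), ids="inner") +
    geom_node_plot(gglist=list(geom_bar(aes(x="", fill=High)),
                               coord_polar("y"),
                               theme_void())) +
    # Label each terminal node with its number of training observations
    geom_node_label(aes(label=paste0("n = ", nodesize)),
                    ids="terminal", nudge_y=0.01)
p
```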

23.1.2 Visualizing the tree (text)

The conversion of the rpart model into a party object also allows us to print the tree as text in a format that is clearer than the default rpart output.

Code
tree_carseats_party <- partykit::as.party(model$fit)
tree_carseats_party
## 
## Model formula:
## High ~ CompPrice + Income + Advertising + Population + Price + 
##     ShelveLoc + Age + Education + Urban + US
## 
## Fitted party:
## [1] root
## |   [2] ShelveLoc in Bad, Medium
## |   |   [3] Price >= 105.5
## |   |   |   [4] Advertising < 10.5
## |   |   |   |   [5] CompPrice < 143.5: No (n = 121, err = 11.6%)
## |   |   |   |   [6] CompPrice >= 143.5
## |   |   |   |   |   [7] Price >= 145: No (n = 7, err = 0.0%)
## |   |   |   |   |   [8] Price < 145: Yes (n = 16, err = 25.0%)
## |   |   |   [9] Advertising >= 10.5
## |   |   |   |   [10] Price >= 126.5: No (n = 32, err = 18.8%)
## |   |   |   |   [11] Price < 126.5
## |   |   |   |   |   [12] CompPrice < 121.5: No (n = 10, err = 20.0%)
## |   |   |   |   |   [13] CompPrice >= 121.5: Yes (n = 21, err = 0.0%)
## |   |   [14] Price < 105.5
## |   |   |   [15] Age >= 68.5
## |   |   |   |   [16] Price >= 91.5: No (n = 15, err = 13.3%)
## |   |   |   |   [17] Price < 91.5: Yes (n = 8, err = 37.5%)
## |   |   |   [18] Age < 68.5
## |   |   |   |   [19] Advertising < 8.5
## |   |   |   |   |   [20] CompPrice < 125
## |   |   |   |   |   |   [21] Price >= 86.5
## |   |   |   |   |   |   |   [22] Urban in Yes: No (n = 18, err = 16.7%)
## |   |   |   |   |   |   |   [23] Urban in No: Yes (n = 7, err = 28.6%)
## |   |   |   |   |   |   [24] Price < 86.5: Yes (n = 14, err = 14.3%)
## |   |   |   |   |   [25] CompPrice >= 125: Yes (n = 12, err = 0.0%)
## |   |   |   |   [26] Advertising >= 8.5: Yes (n = 34, err = 2.9%)
## |   [27] ShelveLoc in Good
## |   |   [28] Price >= 135: No (n = 17, err = 41.2%)
## |   |   [29] Price < 135: Yes (n = 68, err = 4.4%)
## 
## Number of inner nodes:    14
## Number of terminal nodes: 15

Let’s look at node [7]. The path is

  • ShelveLoc in Bad, Medium
  • Price >= 105.5
  • Advertising < 10.5
  • CompPrice >= 143.5
  • Price >= 145

The distribution at that node is:

No (n = 7, err = 0.0%)

All seven training observations in this node have High = No, so the tree predicts No with a training error of 0%.
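
The node summary can also be checked programmatically. The following sketch assigns every training observation to its terminal node with predict(..., type="node") and tabulates the classes per node. The names party_tree and node_id are introduced here for illustration. If the tree is reproduced as printed above, the row for node 7 should match the n = 7 and err = 0.0% shown in the text output.

```r
library(tidymodels)

carseats <- as_tibble(ISLR2::Carseats) %>%
    mutate(
        High=factor(ifelse(Sales <= median(Sales), "No", "Yes"))
    ) %>%
    select(-Sales)

model <- decision_tree(mode="classification", engine="rpart") %>%
    fit(High ~ ., data=carseats)
party_tree <- partykit::as.party(model$fit)

# Terminal node id (the number in square brackets) for each observation
node_id <- predict(party_tree, newdata=carseats, type="node")

# Class distribution of the observations that end up in node 7
table(carseats$High[node_id == 7])

# The same summary for all terminal nodes at once
table(node=node_id, High=carseats$High)
```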

23.1.3 Visualizing the tree (rules)

The example above shows that Price occurs twice in the decision path (Price >= 105.5 and Price >= 145). Because the second condition implies the first, we can combine the two rules into one:

  • ShelveLoc in Bad, Medium
  • Advertising < 10.5
  • CompPrice >= 143.5
  • Price >= 145

For complex trees, it can be difficult to extract and simplify the decision paths by hand. The function rpart.plot::rpart.rules from the rpart.plot package converts the tree into a set of rules and combines repeated conditions automatically.

rpart.plot::rpart.rules(model$fit, style="tallw")

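The output is abbreviated; the first rule corresponds to node [7]. The value after High is the fitted probability of the class Yes in that leaf, and the split thresholds are rounded for display (for example, 143.5 is shown as 144).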

## High is 0.00 when
##              ShelveLoc is Bad or Medium
##              Price >= 145
##              Advertising < 10.5
##              CompPrice >= 144
##
## High is 0.12 when
##              ShelveLoc is Bad or Medium
##              Price >= 106
##              Advertising < 10.5
##              CompPrice < 144
##
## High is 0.13 when
##              ShelveLoc is Bad or Medium
##              Price is 92 to 106
##              Age >= 69
....
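
The rules are returned as a data frame with one row per leaf, which makes it possible to work with them further. The sketch below uses the cover and nn arguments of rpart.rules to add the share of observations and the rpart node number to each rule. Passing roundint=FALSE is an assumption made here to avoid a warning that can occur when rpart.plot cannot retrieve the training data from a model fitted through parsnip.

```r
library(tidymodels)

carseats <- as_tibble(ISLR2::Carseats) %>%
    mutate(
        High=factor(ifelse(Sales <= median(Sales), "No", "Yes"))
    ) %>%
    select(-Sales)

model <- decision_tree(mode="classification", engine="rpart") %>%
    fit(High ~ ., data=carseats)

# cover=TRUE adds the percentage of observations covered by each rule;
# nn=TRUE adds the rpart node numbers, which use a different numbering
# scheme than the party ids shown in square brackets above
rules <- rpart.plot::rpart.rules(model$fit, cover=TRUE, nn=TRUE,
                                 roundint=FALSE)
rules

# One rule per leaf of the tree
n_leaves <- sum(model$fit$frame$var == "<leaf>")
nrow(rules) == n_leaves
```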

23.2 Regression Trees

We now train a regression model to predict Sales in the ISLR2::Carseats dataset using the default settings.

Code
model <- decision_tree(mode="regression", engine="rpart") %>%
    fit(Sales ~ ., data=ISLR2::Carseats)

23.2.1 Visualizing the tree (graph)

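Figure 23.3 shows the tree visualization created with ggparty from the party object. The tree is drawn horizontally, and each terminal node shows a histogram of the Sales values of the training observations in that node.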

Code
ggparty::ggparty(partykit::as.party(model$fit), horizontal=TRUE) +
    ggparty::geom_edge() +
    ggparty::geom_edge_label() +
    ggparty::geom_node_label(aes(label=splitvar), ids="inner") +
    ggparty::geom_node_plot(gglist=list(geom_histogram(aes(x=Sales), binwidth=3),
                                        theme(axis.title.x = element_blank(),
                                              axis.title.y = element_blank())))

Figure 23.3: Decision tree visualization of regression model
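
The histograms show the spread of Sales within each terminal node; the value a regression tree predicts for a node is the mean of the training responses in it. As a rough numeric companion to the figure, the following sketch computes the size and mean of Sales per terminal node. The names reg_model, reg_party, and node_summary are introduced here for illustration.

```r
library(tidymodels)

reg_model <- decision_tree(mode="regression", engine="rpart") %>%
    fit(Sales ~ ., data=ISLR2::Carseats)
reg_party <- partykit::as.party(reg_model$fit)

# Terminal node id for each training observation
node_id <- predict(reg_party, newdata=ISLR2::Carseats, type="node")

# Number of observations and mean Sales per terminal node
node_summary <- as_tibble(ISLR2::Carseats) %>%
    mutate(node=node_id) %>%
    group_by(node) %>%
    summarise(n=n(), mean_sales=mean(Sales))
node_summary
```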

Further information:

There are many ways of customizing the visualization and you can find plenty of examples and resources on the internet. See the ggparty wiki.

Code

The code of this chapter is summarized here.

Code
knitr::opts_chunk$set(echo=TRUE, cache=TRUE, autodep=TRUE, fig.align="center")
library(tidyverse)
library(tidymodels)
library(ggparty)
carseats <- tibble(ISLR2::Carseats) %>%
    mutate(
        High=factor(ifelse(Sales <= median(Sales), "No", "Yes"))
    ) %>%
    dplyr::select(-c(Sales))
model <- decision_tree(mode="classification", engine="rpart") %>%
    fit(High ~ ., data=carseats)
autoplot(partykit::as.party(model$fit))
ggparty::ggparty(partykit::as.party(model$fit)) +
    ggparty::geom_edge() +
    ggparty::geom_edge_label() +
    ggparty::geom_node_label(aes(label=splitvar), ids="inner") +
    ggparty::geom_node_plot(gglist=list(geom_bar(aes(x="", fill=High)),
                                        coord_polar("y"),
                                        theme_void()))
tree_carseats_party <- partykit::as.party(model$fit)
tree_carseats_party
model <- decision_tree(mode="regression", engine="rpart") %>%
    fit(Sales ~ ., data=ISLR2::Carseats)
ggparty::ggparty(partykit::as.party(model$fit), horizontal=TRUE) +
    ggparty::geom_edge() +
    ggparty::geom_edge_label() +
    ggparty::geom_node_label(aes(label=splitvar), ids="inner") +
    ggparty::geom_node_plot(gglist=list(geom_histogram(aes(x=Sales), binwidth=3),
                                        theme(axis.title.x = element_blank(),
                                              axis.title.y = element_blank())))