Chapter 23 Visualizing decision tree models
Decision tree models are popular because they are easy to interpret and understand. This has led to the development of packages that facilitate the analysis of decision trees. Here, we will focus on the ggparty package and demonstrate some of its features using the ISLR2::Carseats data set.
Load required libraries
Prepare the ISLR2::Carseats data set for classification. A new variable High is created as a factor with two levels, Yes and No: High is Yes if Sales is above the median value of Sales (7.49) and No otherwise. The Sales variable is then removed from the data set.
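A minimal sketch of this preparation step, mirroring the code summarized at the end of the chapter. The last two lines are added sanity checks and are not part of the original chapter code.

```r
library(dplyr)

# Create the binary response High from the median of Sales, then drop Sales
carseats <- tibble::as_tibble(ISLR2::Carseats) %>%
  mutate(High = factor(ifelse(Sales <= median(Sales), "No", "Yes"))) %>%
  select(-Sales)

# Added sanity checks (not in the original chapter)
median(ISLR2::Carseats$Sales)  # the text reports 7.49
table(carseats$High)           # class counts after the median split
```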
23.1 Classification Trees
We first train a model using the default settings.
I’ve not been able to use the ggparty functionality with a decision tree model created using a workflow. If you use a workflow to tune a decision tree, you will need to create a separate decision tree model with the settings from the tuning.
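The workaround described above might look like the following sketch. The values in `best_params` are hypothetical placeholders for the result of a tuning run (for example, the row returned by `tune::select_best()`); they are not results from this chapter.

```r
library(tidymodels)

carseats <- tibble::as_tibble(ISLR2::Carseats) %>%
  mutate(High = factor(ifelse(Sales <= median(Sales), "No", "Yes"))) %>%
  dplyr::select(-Sales)

# Hypothetical tuned hyperparameters; replace with the values found by tuning
best_params <- tibble(cost_complexity = 0.01, tree_depth = 5, min_n = 10)

# Refit a stand-alone model (no workflow) using the tuned settings
tuned_model <- decision_tree(
  mode = "classification", engine = "rpart",
  cost_complexity = best_params$cost_complexity,
  tree_depth = best_params$tree_depth,
  min_n = best_params$min_n
) %>%
  fit(High ~ ., data = carseats)

# The stand-alone fit can be converted for use with ggparty
tuned_party <- partykit::as.party(tuned_model$fit)
```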
23.1.1 Visualizing the tree (graph)
To use the ggparty visualization, we need to extract the rpart model and convert it into a format suitable for this package using the function partykit::as.party(model$fit).
Figure 23.1 shows the tree visualization created by the autoplot implementation of the party object.
Figure 23.1: Decision tree visualization of default model (autoplot)
There are many options to customize the visualization. Figure 23.2 shows the tree visualization created by the ggparty implementation of the party object. The pie charts show the distribution of the two classes and the number of training data points in the terminal nodes.
Figure 23.2: Decision tree visualization using pie charts
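As one example of such a customization, the sketch below labels each terminal node with its id and the number of training observations instead of drawing pie charts. It assumes that the plot data created by ggparty exposes the columns `id` and `nodesize`, as used in the package's own examples.

```r
library(tidymodels)
library(ggparty)

carseats <- tibble::as_tibble(ISLR2::Carseats) %>%
  mutate(High = factor(ifelse(Sales <= median(Sales), "No", "Yes"))) %>%
  dplyr::select(-Sales)

model <- decision_tree(mode = "classification", engine = "rpart") %>%
  fit(High ~ ., data = carseats)
tree_carseats_party <- partykit::as.party(model$fit)

p <- ggparty(tree_carseats_party) +
  geom_edge() +
  geom_edge_label() +
  # inner nodes: name of the split variable
  geom_node_label(aes(label = splitvar), ids = "inner") +
  # terminal nodes: node id and number of training observations
  geom_node_label(aes(label = paste0("Node ", id, "\nn = ", nodesize)),
                  ids = "terminal")
p
```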
23.1.2 Visualizing the tree (text)
The conversion of the rpart model into a party object also allows us to print the tree as text in a format that is clearer than the default rpart output.
##
## Model formula:
## High ~ CompPrice + Income + Advertising + Population + Price +
## ShelveLoc + Age + Education + Urban + US
##
## Fitted party:
## [1] root
## | [2] ShelveLoc in Bad, Medium
## | | [3] Price >= 105.5
## | | | [4] Advertising < 10.5
## | | | | [5] CompPrice < 143.5: No (n = 121, err = 11.6%)
## | | | | [6] CompPrice >= 143.5
## | | | | | [7] Price >= 145: No (n = 7, err = 0.0%)
## | | | | | [8] Price < 145: Yes (n = 16, err = 25.0%)
## | | | [9] Advertising >= 10.5
## | | | | [10] Price >= 126.5: No (n = 32, err = 18.8%)
## | | | | [11] Price < 126.5
## | | | | | [12] CompPrice < 121.5: No (n = 10, err = 20.0%)
## | | | | | [13] CompPrice >= 121.5: Yes (n = 21, err = 0.0%)
## | | [14] Price < 105.5
## | | | [15] Age >= 68.5
## | | | | [16] Price >= 91.5: No (n = 15, err = 13.3%)
## | | | | [17] Price < 91.5: Yes (n = 8, err = 37.5%)
## | | | [18] Age < 68.5
## | | | | [19] Advertising < 8.5
## | | | | | [20] CompPrice < 125
## | | | | | | [21] Price >= 86.5
## | | | | | | | [22] Urban in Yes: No (n = 18, err = 16.7%)
## | | | | | | | [23] Urban in No: Yes (n = 7, err = 28.6%)
## | | | | | | [24] Price < 86.5: Yes (n = 14, err = 14.3%)
## | | | | | [25] CompPrice >= 125: Yes (n = 12, err = 0.0%)
## | | | | [26] Advertising >= 8.5: Yes (n = 34, err = 2.9%)
## | [27] ShelveLoc in Good
## | | [28] Price >= 135: No (n = 17, err = 41.2%)
## | | [29] Price < 135: Yes (n = 68, err = 4.4%)
##
## Number of inner nodes: 14
## Number of terminal nodes: 15
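The node counts reported at the end of the printout can also be queried programmatically. A short sketch using accessor functions from partykit; the object names follow the chapter's code.

```r
library(tidymodels)

carseats <- tibble::as_tibble(ISLR2::Carseats) %>%
  mutate(High = factor(ifelse(Sales <= median(Sales), "No", "Yes"))) %>%
  dplyr::select(-Sales)

model <- decision_tree(mode = "classification", engine = "rpart") %>%
  fit(High ~ ., data = carseats)
tree_carseats_party <- partykit::as.party(model$fit)

n_nodes    <- length(tree_carseats_party)           # total number of nodes
n_terminal <- partykit::width(tree_carseats_party)  # number of terminal nodes
n_inner    <- n_nodes - n_terminal                  # number of inner nodes

# Ids of the terminal nodes, matching the [id] labels in the printout
terminal_ids <- partykit::nodeids(tree_carseats_party, terminal = TRUE)

c(inner = n_inner, terminal = n_terminal)
terminal_ids
```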
Let’s look at node [7]. The path is
- ShelveLoc in Bad, Medium
- Price >= 105.5
- Advertising < 10.5
- CompPrice >= 143.5
- Price >= 145
The distribution at that node is:
No (n = 7, err = 0.0%)
We predict that High is No. All 7 training observations in this node are No, so the training error at this node is 0%.
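The content of a single node can be checked against the training data. The sketch below assigns every training observation to its terminal node with `predict(..., type = "node")` and tabulates the response for node 7; given the printout above, the table should contain the 7 observations of class No.

```r
library(tidymodels)
library(partykit)

carseats <- tibble::as_tibble(ISLR2::Carseats) %>%
  mutate(High = factor(ifelse(Sales <= median(Sales), "No", "Yes"))) %>%
  dplyr::select(-Sales)

model <- decision_tree(mode = "classification", engine = "rpart") %>%
  fit(High ~ ., data = carseats)
tree_carseats_party <- as.party(model$fit)

# Terminal node id for each training observation
node_id <- predict(tree_carseats_party, newdata = carseats, type = "node")

# Class distribution of the observations that end up in node 7
node7_classes <- table(carseats$High[node_id == 7])
node7_classes
```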
23.1.3 Visualizing the tree (rules)
The example above shows that Price occurs twice in the decision path: Price >= 105.5 and Price >= 145. Because the second condition implies the first, we can combine the two rules into one rule:
- ShelveLoc in Bad, Medium
- Advertising < 10.5
- CompPrice >= 143.5
- Price >= 145
It can be difficult to extract the decision path for complex trees. The rpart.plot package has a function, rpart.plot::rpart.rules, that converts the tree into a set of rules.
rpart.plot::rpart.rules(model$fit, style="tallw")
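The output is abbreviated; the first rule corresponds to node [7]. The value reported for High appears to be the fitted probability of the class Yes, and the split thresholds are rounded (for example, CompPrice >= 144 instead of CompPrice >= 143.5).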
## High is 0.00 when
## ShelveLoc is Bad or Medium
## Price >= 145
## Advertising < 10.5
## CompPrice >= 144
##
## High is 0.12 when
## ShelveLoc is Bad or Medium
## Price >= 106
## Advertising < 10.5
## CompPrice < 144
##
## High is 0.13 when
## ShelveLoc is Bad or Medium
## Price is 92 to 106
## Age >= 69
....
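The rules can also be stored and inspected as a table. The sketch below uses the default style and adds `cover = TRUE`, which, to my understanding, adds the percentage of training observations covered by each rule. It assumes that rpart.rules returns one row per terminal node.

```r
library(tidymodels)

carseats <- tibble::as_tibble(ISLR2::Carseats) %>%
  mutate(High = factor(ifelse(Sales <= median(Sales), "No", "Yes"))) %>%
  dplyr::select(-Sales)

model <- decision_tree(mode = "classification", engine = "rpart") %>%
  fit(High ~ ., data = carseats)

# One rule per terminal node; cover = TRUE adds the share of observations
rules <- rpart.plot::rpart.rules(model$fit, cover = TRUE)
rules

# Number of leaves in the underlying rpart tree, for comparison
n_leaves <- sum(model$fit$frame$var == "<leaf>")
n_leaves
```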
23.2 Regression Trees
We now train a regression model to predict Sales in the ISLR2::Carseats dataset using the default settings.
23.2.1 Visualizing the tree (graph)
Figure 23.3 shows the tree visualization created by the ggparty implementation of the party object. The histograms show the distribution of Sales in the terminal nodes.
ggparty::ggparty(partykit::as.party(model$fit), horizontal=TRUE) +
  ggparty::geom_edge() +
  ggparty::geom_edge_label() +
  ggparty::geom_node_label(aes(label=splitvar), ids="inner") +
  ggparty::geom_node_plot(gglist=list(geom_histogram(aes(x=Sales), binwidth=3),
                                      theme(axis.title.x = element_blank(),
                                            axis.title.y = element_blank())))
Figure 23.3: Decision tree visualization of regression model
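The histograms in the figure can be complemented by a numeric summary of the terminal nodes. A minimal sketch, assuming the same default regression model: each training observation is assigned to its terminal node, and the node size and mean Sales are computed per node. The names `reg_model`, `reg_party`, and `node_summary` are introduced here for illustration only.

```r
library(tidymodels)
library(partykit)

reg_model <- decision_tree(mode = "regression", engine = "rpart") %>%
  fit(Sales ~ ., data = ISLR2::Carseats)
reg_party <- as.party(reg_model$fit)

# Terminal node id for each training observation
node_id <- predict(reg_party, newdata = ISLR2::Carseats, type = "node")

# Size and mean response of each terminal node
node_summary <- ISLR2::Carseats %>%
  mutate(node = as.integer(node_id)) %>%
  group_by(node) %>%
  summarise(n = n(), mean_sales = mean(Sales))
node_summary
```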
Further information:
There are many ways of customizing the visualization and you can find plenty of examples and resources on the internet. See the ggparty wiki.
Code
The code of this chapter is summarized here.
knitr::opts_chunk$set(echo=TRUE, cache=TRUE, autodep=TRUE, fig.align="center")
library(tidyverse)
library(tidymodels)
library(ggparty)

# Prepare the data: binary response High, drop Sales
carseats <- tibble(ISLR2::Carseats) %>%
  mutate(
    High=factor(ifelse(Sales <= median(Sales), "No", "Yes"))
  ) %>%
  dplyr::select(-c(Sales))

# Classification tree with default settings
model <- decision_tree(mode="classification", engine="rpart") %>%
  fit(High ~ ., data=carseats)

# Tree visualization (autoplot)
autoplot(partykit::as.party(model$fit))

# Tree visualization with pie charts in the terminal nodes
ggparty::ggparty(partykit::as.party(model$fit)) +
  ggparty::geom_edge() +
  ggparty::geom_edge_label() +
  ggparty::geom_node_label(aes(label=splitvar), ids="inner") +
  ggparty::geom_node_plot(gglist=list(geom_bar(aes(x="", fill=High)),
                                      coord_polar("y"),
                                      theme_void()))

# Print the tree as text
tree_carseats_party <- partykit::as.party(model$fit)
tree_carseats_party

# Regression tree with default settings
model <- decision_tree(mode="regression", engine="rpart") %>%
  fit(Sales ~ ., data=ISLR2::Carseats)

# Tree visualization with histograms in the terminal nodes
ggparty::ggparty(partykit::as.party(model$fit), horizontal=TRUE) +
  ggparty::geom_edge() +
  ggparty::geom_edge_label() +
  ggparty::geom_node_label(aes(label=splitvar), ids="inner") +
  ggparty::geom_node_plot(gglist=list(geom_histogram(aes(x=Sales), binwidth=3),
                                      theme(axis.title.x = element_blank(),
                                            axis.title.y = element_blank())))