Chapter 23 Visualizing decision tree models
Decision tree models are popular because they are easy to interpret and understand. This has led to the development of packages that facilitate the analysis of decision trees. Here, we focus on the ggparty package and demonstrate some of its features using the ISLR2::Carseats data set.
Load required libraries
Prepare the ISLR2::Carseats data set for classification. A new factor variable High with the two levels Yes and No is created by comparing Sales with its median value (7.49): stores with Sales above the median are labeled Yes, all others No. The Sales variable is then removed from the data set.
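The preparation step can be sketched as follows. This is a minimal version that assumes the dplyr and ISLR2 packages are installed. It computes the median threshold explicitly and prints the resulting class balance. Stores whose Sales equal the median are assigned to No because the comparison uses `<=`.

```r
# Sketch: derive the binary outcome High from Sales (assumes dplyr and ISLR2 are installed)
library(dplyr)

# Threshold used to define High: the median of Sales
sales_median <- median(ISLR2::Carseats$Sales)

carseats <- tibble::as_tibble(ISLR2::Carseats) %>%
  # Sales at or below the median -> "No", above the median -> "Yes"
  mutate(High = factor(ifelse(Sales <= sales_median, "No", "Yes"))) %>%
  # Drop Sales so that it cannot be used to predict High
  select(-Sales)

sales_median
table(carseats$High)
```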
23.1 Classification Trees
We first train a model using the default settings.
I have not been able to use the ggparty functionality with a decision tree model fitted inside a workflow. If you use a workflow to tune a decision tree, you will need to fit a separate decision tree model using the settings found during tuning.
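Here is a minimal sketch of that approach, assuming the tidymodels and ISLR2 packages are installed. The tuned values are plugged into the specification with `tune::finalize_model()`, and the model is refit outside the workflow. The values in `best_params` are hypothetical placeholders. In practice they would come from `select_best()` applied to your tuning results.

```r
library(tidymodels)

carseats <- tibble::as_tibble(ISLR2::Carseats) %>%
  mutate(High = factor(ifelse(Sales <= median(Sales), "No", "Yes"))) %>%
  select(-Sales)

# Tunable specification, as it might be used inside a tuning workflow
tree_spec <- decision_tree(mode = "classification", engine = "rpart",
                           cost_complexity = tune(), tree_depth = tune())

# Hypothetical tuning result; in practice this row would come from select_best()
best_params <- tibble(cost_complexity = 0.005, tree_depth = 5)

# Plug the tuned values into the specification and fit it outside the workflow
final_spec <- finalize_model(tree_spec, best_params)
final_fit <- fit(final_spec, High ~ ., data = carseats)

# The underlying rpart object can now be converted for ggparty
tree_party <- partykit::as.party(final_fit$fit)
```

Fitting with a formula, as in the rest of the chapter, keeps the original variable names in the converted tree.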
23.1.1 Visualizing the tree (graph)
To use the ggparty visualization, we need to extract the rpart model and convert it into a format suitable for this package using the function partykit::as.party(model$fit).
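The conversion can be sketched as follows, assuming the tidymodels, partykit and ISLR2 packages are installed. `model$fit` is the rpart object stored inside the parsnip `model_fit`, and parsnip's `extract_fit_engine()` returns the same object. The last lines use partykit helpers to count the nodes of the converted tree.

```r
library(tidymodels)

carseats <- tibble::as_tibble(ISLR2::Carseats) %>%
  mutate(High = factor(ifelse(Sales <= median(Sales), "No", "Yes"))) %>%
  select(-Sales)

model <- decision_tree(mode = "classification", engine = "rpart") %>%
  fit(High ~ ., data = carseats)

# The engine-level rpart object; extract_fit_engine() is an accessor for model$fit
rpart_fit <- extract_fit_engine(model)

# Convert to a partykit object that ggparty can plot
tree_party <- partykit::as.party(rpart_fit)
class(tree_party)

# Total number of nodes and the IDs of the terminal nodes
length(partykit::nodeids(tree_party))
partykit::nodeids(tree_party, terminal = TRUE)
```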
Figure 23.1 shows the tree visualization created by the autoplot method for the party object.
There are many options to customize the visualization. Figure 23.2 shows the tree visualization created by ggparty::ggparty for the party object. The pie charts show the distribution of the two classes and the number of training data points in the terminal nodes.
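As one example of such customization, the following sketch replaces the pie charts with stacked bars of class proportions. It also labels each terminal node with its number of training observations. It assumes that the ggparty plotting data provide a `nodesize` variable. The value `terminal_space = 0.3`, which leaves more room for the node plots, is an arbitrary choice.

```r
library(tidymodels)
library(ggparty)

carseats <- tibble::as_tibble(ISLR2::Carseats) %>%
  mutate(High = factor(ifelse(Sales <= median(Sales), "No", "Yes"))) %>%
  select(-Sales)

model <- decision_tree(mode = "classification", engine = "rpart") %>%
  fit(High ~ ., data = carseats)
tree_party <- partykit::as.party(model$fit)

p <- ggparty(tree_party, terminal_space = 0.3) +
  geom_edge() +
  geom_edge_label() +
  geom_node_label(aes(label = splitvar), ids = "inner") +
  # Bars scaled to 1 show the class proportions in each terminal node
  geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = High), position = "fill"),
                               xlab(NULL), ylab("Proportion"))) +
  # nodesize is the number of training observations in each node
  geom_node_label(aes(label = paste0("n = ", nodesize)),
                  ids = "terminal", nudge_y = 0.01)
p
```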
23.1.2 Visualizing the tree (text)
The conversion of the rpart model into a party object also allows us to print the tree as text, in a format that is clearer than the default rpart output.
##
## Model formula:
## High ~ CompPrice + Income + Advertising + Population + Price +
## ShelveLoc + Age + Education + Urban + US
##
## Fitted party:
## [1] root
## | [2] ShelveLoc in Bad, Medium
## | | [3] Price >= 105.5
## | | | [4] Advertising < 10.5
## | | | | [5] CompPrice < 143.5: No (n = 121, err = 11.6%)
## | | | | [6] CompPrice >= 143.5
## | | | | | [7] Price >= 145: No (n = 7, err = 0.0%)
## | | | | | [8] Price < 145: Yes (n = 16, err = 25.0%)
## | | | [9] Advertising >= 10.5
## | | | | [10] Price >= 126.5: No (n = 32, err = 18.8%)
## | | | | [11] Price < 126.5
## | | | | | [12] CompPrice < 121.5: No (n = 10, err = 20.0%)
## | | | | | [13] CompPrice >= 121.5: Yes (n = 21, err = 0.0%)
## | | [14] Price < 105.5
## | | | [15] Age >= 68.5
## | | | | [16] Price >= 91.5: No (n = 15, err = 13.3%)
## | | | | [17] Price < 91.5: Yes (n = 8, err = 37.5%)
## | | | [18] Age < 68.5
## | | | | [19] Advertising < 8.5
## | | | | | [20] CompPrice < 125
## | | | | | | [21] Price >= 86.5
## | | | | | | | [22] Urban in Yes: No (n = 18, err = 16.7%)
## | | | | | | | [23] Urban in No: Yes (n = 7, err = 28.6%)
## | | | | | | [24] Price < 86.5: Yes (n = 14, err = 14.3%)
## | | | | | [25] CompPrice >= 125: Yes (n = 12, err = 0.0%)
## | | | | [26] Advertising >= 8.5: Yes (n = 34, err = 2.9%)
## | [27] ShelveLoc in Good
## | | [28] Price >= 135: No (n = 17, err = 41.2%)
## | | [29] Price < 135: Yes (n = 68, err = 4.4%)
##
## Number of inner nodes: 14
## Number of terminal nodes: 15
Let’s look at node [7]. The path is:
- ShelveLoc in Bad, Medium
- Price >= 105.5
- Advertising < 10.5
- CompPrice >= 143.5
- Price >= 145
The distribution at that node is No (n = 7, err = 0.0%): all seven training observations in this node have High = No. We therefore predict that High is No, with a 0% error rate on the training data.
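Terminal-node membership can also be checked programmatically. The sketch below assumes the same data preparation and model as in this chapter. It uses partykit's `predict()` method with `type = "node"` to find the terminal node reached by each training observation. It then tabulates the observed classes in node [7] and in all terminal nodes.

```r
library(tidymodels)

carseats <- tibble::as_tibble(ISLR2::Carseats) %>%
  mutate(High = factor(ifelse(Sales <= median(Sales), "No", "Yes"))) %>%
  select(-Sales)

model <- decision_tree(mode = "classification", engine = "rpart") %>%
  fit(High ~ ., data = carseats)
tree_party <- partykit::as.party(model$fit)

# Terminal node reached by each training observation (no newdata = training data)
node_ids <- predict(tree_party, type = "node")

# Observed classes in node [7]; compare with n and err in the printed tree
table(carseats$High[node_ids == 7])

# Class counts for every terminal node at once
table(node = node_ids, High = carseats$High)
```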
23.1.3 Visualizing the tree (rules)
The example above shows that Price occurs twice in the decision path. Because Price >= 145 implies Price >= 105.5, we can combine the two conditions into one rule:
- ShelveLoc in Bad, Medium
- Advertising < 10.5
- CompPrice >= 143.5
- Price >= 145
For complex trees, it can be difficult to extract and simplify decision paths by hand. The function rpart.plot::rpart.rules from the rpart.plot package converts the tree into a set of rules.
rpart.plot::rpart.rules(model$fit, style="tallw")
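Abbreviated output; the first rule corresponds to node [7]. For a two-class tree, the value after "High is" is the fitted probability of the second class level, Yes. So 0.00 means that none of the training observations in node [7] have High = Yes. Split points are rounded in this output; for example, CompPrice >= 143.5 is shown as CompPrice >= 144.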
## High is 0.00 when
## ShelveLoc is Bad or Medium
## Price >= 145
## Advertising < 10.5
## CompPrice >= 144
##
## High is 0.12 when
## ShelveLoc is Bad or Medium
## Price >= 106
## Advertising < 10.5
## CompPrice < 144
##
## High is 0.13 when
## ShelveLoc is Bad or Medium
## Price is 92 to 106
## Age >= 69
....
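As a check on the simplified rule for node [7], the following dplyr sketch keeps only the stores that satisfy the four combined conditions. It assumes the same data preparation as above. Because Price >= 145 implies Price >= 105.5, this filter is equivalent to the full decision path. It should therefore select the same 7 stores, all with High = No.

```r
library(dplyr)

carseats <- tibble::as_tibble(ISLR2::Carseats) %>%
  mutate(High = factor(ifelse(Sales <= median(Sales), "No", "Yes"))) %>%
  select(-Sales)

# Simplified rule for node [7]: the redundant Price >= 105.5 condition is dropped
node7 <- carseats %>%
  filter(ShelveLoc %in% c("Bad", "Medium"),
         Advertising < 10.5,
         CompPrice >= 143.5,
         Price >= 145)

nrow(node7)
table(node7$High)
```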
23.2 Regression Trees
We now train a regression model to predict Sales in the ISLR2::Carseats data set using the default settings.
23.2.1 Visualizing the tree (graph)
Figure 23.3 shows the regression tree visualized with ggparty in a horizontal layout. The terminal nodes show histograms of Sales for the training observations in each leaf.
ggparty::ggparty(partykit::as.party(model$fit), horizontal=TRUE) +
  ggparty::geom_edge() +
  ggparty::geom_edge_label() +
  ggparty::geom_node_label(aes(label=splitvar), ids="inner") +
  ggparty::geom_node_plot(gglist=list(geom_histogram(aes(x=Sales), binwidth=3),
                                      theme(axis.title.x = element_blank(),
                                            axis.title.y = element_blank())))
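Each terminal node of a regression tree predicts the mean Sales of the training observations that fall into it. This is a useful reference point when reading the histograms. The sketch below checks this by comparing the rpart predictions on the training data with leaf means. The leaf means are computed from partykit's node assignments (`predict(..., type = "node")`). It assumes the tidymodels, partykit and ISLR2 packages are installed.

```r
library(tidymodels)

model <- decision_tree(mode = "regression", engine = "rpart") %>%
  fit(Sales ~ ., data = ISLR2::Carseats)
tree_party <- partykit::as.party(model$fit)

# Terminal node reached by each training observation
node_ids <- predict(tree_party, type = "node")

# Mean Sales of the training observations in each leaf, repeated per observation
leaf_means <- ave(ISLR2::Carseats$Sales, node_ids)

# rpart predictions for the training data
pred <- predict(model$fit, newdata = ISLR2::Carseats)

# TRUE if every prediction equals the mean Sales of its leaf
all.equal(unname(pred), unname(leaf_means))
```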
Further information:
There are many ways to customize the visualization, and plenty of examples and resources are available online. See the ggparty wiki.
The code of this chapter is summarized here.
knitr::opts_chunk$set(echo=TRUE, cache=TRUE, autodep=TRUE, fig.align="center")
library(tidyverse)
library(tidymodels)
library(ggparty)

carseats = tibble(ISLR2::Carseats) %>%
  mutate(
    High=factor(ifelse(Sales <= median(Sales), "No", "Yes"))
  ) %>%
  dplyr::select(-c(Sales))

model <- decision_tree(mode="classification", engine="rpart") %>%
  fit(High ~ ., data=carseats)

autoplot(partykit::as.party(model$fit))

ggparty::ggparty(partykit::as.party(model$fit)) +
  ggparty::geom_edge() +
  ggparty::geom_edge_label() +
  ggparty::geom_node_label(aes(label=splitvar), ids="inner") +
  ggparty::geom_node_plot(gglist=list(geom_bar(aes(x="",fill=High)),
                                      coord_polar("y"),
                                      theme_void()))

tree.carseats.party = partykit::as.party(model$fit)
tree.carseats.party

model <- decision_tree(mode="regression", engine="rpart") %>%
  fit(Sales ~ ., data=ISLR2::Carseats)

ggparty::ggparty(partykit::as.party(model$fit), horizontal=TRUE) +
  ggparty::geom_edge() +
  ggparty::geom_edge_label() +
  ggparty::geom_node_label(aes(label=splitvar), ids="inner") +
  ggparty::geom_node_plot(gglist=list(geom_histogram(aes(x=Sales), binwidth=3),
                                      theme(axis.title.x = element_blank(),
                                            axis.title.y = element_blank())))