Chapter 19 Clustering

Tidymodels provides a framework for clustering with the tidyclust package. It currently supports the following clustering algorithms:

  • k-means clustering (k_means())
  • hierarchical clustering (hier_clust())

Clustering methods are defined similarly to predictive models in tidymodels (parsnip). This means that each method can use different engines and that we can combine a clustering model with a preprocessing recipe in a workflow.

Load the packages we need for this chapter.

Code
library(tidymodels)
library(tidyclust)
library(kableExtra)
library(patchwork)
library(GGally)
library(doParallel)

Because tuning requires training many models, we also enable parallel computing.

Code
cl <- makePSOCKcluster(parallel::detectCores(logical = FALSE))
registerDoParallel(cl)

19.1 k-means clustering

The k_means() function is a wrapper around four different packages (engines). Here is an example using the default engine, stats, which calls stats::kmeans(), to cluster the penguins dataset into three clusters.

Code
# k-means clustering uses a random starting point,
# so we set a seed for reproducibility
set.seed(123)
penguins <- modeldata::penguins %>% drop_na()

formula <- ~ bill_length_mm + bill_depth_mm + flipper_length_mm + body_mass_g
rec_penguins <- recipe(formula, data=penguins) %>%
    step_normalize(all_predictors())
kmeans_penguins <- k_means(num_clusters=3) %>%
    set_engine("stats") %>%
    set_mode("partition")
kmeans_wf <- workflow() %>%
    add_recipe(rec_penguins) %>%
    add_model(kmeans_penguins)

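Note that the preprocessing includes a normalization step. k-means assigns each point to the nearest cluster center by Euclidean distance, so without normalization the variable with the largest numeric range (here body_mass_g, measured in grams) would dominate the distances. Normalizing gives every predictor mean 0 and standard deviation 1, so all predictors are on the same scale. The kmeans_wf object can then be used to fit the model.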
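To make the scale argument concrete, the following base-R sketch takes the first six penguins listed in the augment() output of this chapter and computes how much each variable contributes to the squared Euclidean distance between the first two penguins, before and after scaling. scale() centers each column at its mean and divides by its standard deviation, which is the transformation step_normalize() estimates from the training data; distance_share() is a hypothetical helper written for this illustration.

```r
# Sketch: why k-means needs normalized predictors.
# Values: first six penguins of modeldata::penguins (as shown in the augment() output).
x <- cbind(
  bill_length_mm    = c(39.1, 39.5, 40.3, 36.7, 39.3, 38.9),
  bill_depth_mm     = c(18.7, 17.4, 18.0, 19.3, 20.6, 17.8),
  flipper_length_mm = c(181, 186, 195, 193, 190, 181),
  body_mass_g       = c(3750, 3800, 3250, 3450, 3650, 3625)
)

# Hypothetical helper: share of each variable in the squared
# Euclidean distance between rows i and j
distance_share <- function(m, i = 1, j = 2) {
  contrib <- (m[i, ] - m[j, ])^2
  round(contrib / sum(contrib), 3)
}

distance_share(x)        # raw scale: body_mass_g accounts for about 99% of the distance
x_norm <- scale(x)       # mean 0, sd 1 per column (like step_normalize())
distance_share(x_norm)   # each variable now counts relative to its own spread, not its unit
```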

Code
kmeans_model <- kmeans_wf %>% fit(data=penguins)

The tidy function gives us a concise overview of the results.

Code
tidy(kmeans_model) %>% knitr::kable() %>% scroll_box(width = "100%")
bill_length_mm  bill_depth_mm  flipper_length_mm  body_mass_g  size  withinss  cluster
    -1.0452359      0.4858944         -0.8803701   -0.7616078   129  120.7030        1
     0.6710153      0.8040534         -0.2889118   -0.3835267    85  109.4813        2
     0.6537742     -1.1010497          1.1607163    1.0995561   119  139.4684        3

The resulting table has a row for each cluster and contains columns for the cluster centers, the number of observations in each cluster (size), and the within-cluster sum of squares (withinss). The cluster center coordinates are based on the normalized data and are therefore not directly comparable to the original data.
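Because normalization is a simple linear transformation, the centers can be mapped back to the original units by inverting it, x = z * sd + mean. The sketch below shows this with stats::kmeans() (the function behind the stats engine) on the built-in iris data as a stand-in for the penguins; with a prepped recipe, the means and standard deviations used by step_normalize() would take the place of the scale() attributes. Because k-means centers are cluster means, the back-transformed centers equal the per-cluster means of the raw data.

```r
# Sketch: converting k-means centers from the normalized scale back to original units.
x <- as.matrix(iris[, 1:4])
x_norm <- scale(x)                      # z = (x - mean) / sd, column-wise
mu <- attr(x_norm, "scaled:center")     # column means used for centering
s  <- attr(x_norm, "scaled:scale")      # column standard deviations

set.seed(123)
km <- kmeans(x_norm, centers = 3, nstart = 10)

# Invert the normalization: multiply by sd, then add the mean (column-wise)
centers_orig <- sweep(sweep(km$centers, 2, s, "*"), 2, mu, "+")

# Cross-check: per-cluster means of the raw data
group_means <- apply(x, 2, function(col) tapply(col, km$cluster, mean))
all.equal(centers_orig, group_means)
```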

Useful to know:

It is important to set a random seed for reproducibility. Without one, the numbering of the clusters, as well as the assignment of data points that lie at roughly equal distance from several cluster centers, can change each time you run the code.
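The following base-R sketch illustrates both points with stats::kmeans() on the built-in iris data: the same seed reproduces the same result, while cluster numbers are arbitrary labels. same_partition() is a hypothetical helper that checks whether two labelings describe the same grouping regardless of which number each cluster received.

```r
# Sketch: seeds make k-means reproducible, but cluster labels are arbitrary.
x <- scale(iris[, 1:4])

set.seed(123); a <- kmeans(x, centers = 3)$cluster
set.seed(123); b <- kmeans(x, centers = 3)$cluster
identical(a, b)                 # same seed, same random starts, same result

# Hypothetical helper: same grouping up to renumbering of the clusters?
same_partition <- function(u, v) {
  tab <- table(u, v)
  all(rowSums(tab > 0) == 1) && all(colSums(tab > 0) == 1)
}

relabelled <- c(3L, 1L, 2L)[a]  # identical clusters, different numbering
identical(a, relabelled)        # FALSE: the labels differ
same_partition(a, relabelled)   # TRUE: the grouping is unchanged
```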

We can also look at the cluster center coordinates using a parallel coordinate plot (see Figure 19.1).

Code
tidy(kmeans_model) %>%
    pivot_longer(cols=c("bill_length_mm", "bill_depth_mm",
                        "flipper_length_mm", "body_mass_g")) %>%
    ggplot(aes(x=name, y=value, group=cluster, color=cluster)) +
    geom_point() +
    geom_line() +
    labs(x="", y="Value at cluster center")

Figure 19.1: Cluster center values for each variable

Clusters 1 and 2 have similar characteristics and differ mainly in bill length: cluster 1 represents penguins with shorter bills than the penguins in clusters 2 and 3. Cluster 3 is clearly different from the other two clusters and represents penguins with a larger body mass, longer flippers, and a smaller bill depth.

We can use the augment() function to add the cluster assignments to the original or new data.

Code
cl_penguins <- augment(kmeans_model, new_data=penguins)
cl_penguins %>% head() %>% knitr::kable() %>% scroll_box(width = "100%")
.pred_cluster  species  island     bill_length_mm  bill_depth_mm  flipper_length_mm  body_mass_g  sex
Cluster_1      Adelie   Torgersen            39.1           18.7                181         3750  male
Cluster_1      Adelie   Torgersen            39.5           17.4                186         3800  female
Cluster_1      Adelie   Torgersen            40.3           18.0                195         3250  female
Cluster_1      Adelie   Torgersen            36.7           19.3                193         3450  female
Cluster_1      Adelie   Torgersen            39.3           20.6                190         3650  male
Cluster_1      Adelie   Torgersen            38.9           17.8                181         3625  female

Figure 19.2 shows a scatterplot matrix of the penguin data with the cluster assignments indicated by color. Such a plot can be created with the ggpairs() function from the GGally package. The clusters are clearly separated in the scatterplot matrix.

Code
cl_penguins %>%
    select(-c(species, island, sex)) %>%
    ggpairs(aes(color=.pred_cluster))

Figure 19.2: k-means clustering of penguins

It is also interesting to compare how the categorical variables that were not used for clustering (species, island, and sex) are distributed across the clusters.

Code
plot_distribution <- function(variable) {
    ggplot(cl_penguins, aes(fill=.data[[variable]], x=.pred_cluster)) +
        geom_bar() +
        theme(legend.position="inside", legend.position.inside=c(0.74, 0.83)) +
        scale_y_continuous(limits=c(0, 200)) +
        labs(y="", x="")
}
g1 <- plot_distribution("species")
g2 <- plot_distribution("island")
g3 <- plot_distribution("sex")

g1 + g2 + g3

Figure 19.3: Distribution of other variables in the clusters

Figure 19.3 shows that the clusters separate the species well. The clusters also show some discrimination of islands: cluster 3 contains only penguins from Biscoe, and cluster 2 mostly penguins from Dream. Sex is not well separated by the clusters.
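The same comparison can be expressed as a contingency table, which gives exact counts instead of bar heights. For the penguins, something like count(cl_penguins, .pred_cluster, species) would produce it; the self-contained base-R sketch below uses stats::kmeans() on the built-in iris data as a stand-in and cross-tabulates the clusters against the known species.

```r
# Sketch: cluster composition as a contingency table (iris as a stand-in dataset).
x <- scale(iris[, 1:4])
set.seed(123)
km <- kmeans(x, centers = 3, nstart = 10)

tab <- table(cluster = km$cluster, species = iris$Species)
tab                                     # counts per cluster and species
round(prop.table(tab, margin = 1), 2)   # composition of each cluster (rows sum to 1)
```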

Useful to know:

k-means clustering requires numerical data. If your dataset contains factors, tidyclust will automatically convert these to indicator variables.[13]
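As an illustration of what indicator variables are, base R's model.matrix() expands a factor into one 0/1 column per level. This sketch uses the built-in iris data and is not necessarily the exact encoding tidyclust applies; in a recipe, the conversion can be made explicit with step_dummy(all_nominal_predictors()).

```r
# Sketch: expanding a factor into indicator (0/1) columns with base R.
# "0 +" drops the intercept, so every level of Species gets its own column.
mm <- model.matrix(~ 0 + Species + Sepal.Length, data = iris)
colnames(mm)
head(mm)
```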

19.2 Hierarchical clustering

The tidyclust package also provides hierarchical clustering. In hierarchical clustering, points are combined based on their distances, so it is again recommended to normalize the data. At the time of writing, hierarchical clustering did not work as part of a workflow.[14] We therefore first preprocess the penguin data with the recipe from the previous section and then fit the hierarchical clustering to the normalized data.

Code
formula <- ~ bill_length_mm + bill_depth_mm + flipper_length_mm + body_mass_g
rec_penguins <- recipe(formula, data=penguins) %>%
    step_normalize(all_predictors())

norm_penguins <- rec_penguins %>%
    prep() %>%
    bake(new_data=penguins)

hier_penguins <- hier_clust(linkage_method="complete", num_clusters=3) %>%
    set_engine("stats") %>%
    set_mode("partition")

hier_model <- hier_penguins %>% fit(formula, data=norm_penguins)

We specify the number of clusters (num_clusters) in the hier_clust() function. Alternatively, we could define the height at which to cut the dendrogram using the cut_height argument.
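The two options are closely related. Cutting into k clusters undoes the last k - 1 merges of the dendrogram, so any cut height between the heights of merges n - k and n - k + 1 yields the same k clusters. The base-R sketch below checks this with stats::hclust() and cutree() on the scaled iris data; in tidyclust, the corresponding specification would use hier_clust(linkage_method="complete", cut_height=h).

```r
# Sketch: num_clusters vs. cut height with stats::hclust (iris as a stand-in dataset).
x <- scale(iris[, 1:4])
hc <- hclust(dist(x), method = "complete")

n <- nrow(x)
k <- 3
# Merge heights are sorted for complete linkage; pick a height between
# merge n - k (kept) and merge n - k + 1 (undone)
h <- mean(hc$height[c(n - k, n - k + 1)])

by_k <- cutree(hc, k = k)
by_h <- cutree(hc, h = h)
identical(by_k, by_h)   # same partition
```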

At the time of writing, hier_clust() only provides a wrapper around the stats::hclust() function. The dendrogram can be plotted from the underlying hclust object stored in hier_model$fit.

Code
hier_model$fit %>% plot()

Figure 19.4: Cluster dendrogram for a hierarchical clustering of the penguins dataset using complete linkage

Figure 19.4 shows the dendrogram of the complete linkage clustering.

Because we specified the number of clusters, we can extract the resulting cluster assignments and cluster centroids.

Code
cluster_assignment <- hier_model %>% extract_cluster_assignment()
centroids <- hier_model %>% extract_centroids()
Code
centroids %>%
    pivot_longer(cols=c("bill_length_mm", "bill_depth_mm",
                        "flipper_length_mm", "body_mass_g")) %>%
    ggplot(aes(x=name, y=value, group=.cluster, color=.cluster)) +
    geom_point() +
    geom_line() +
    labs(x="", y="Value at cluster center")

Figure 19.5: Cluster center values for each variable from hierarchical clustering

Figure 19.5 visualizes the coordinates of the cluster centroids. The results are comparable to the k-means clustering shown in Figure 19.1.

The resulting hierarchical clustering can also predict cluster membership for new data. It assigns each data point to the cluster with the nearest centroid. In contrast to k-means clustering, you can specify the number of clusters or the cut height in the predict() function.

Code
pred_class <- hier_model %>% predict(new_data=norm_penguins, num_clusters=4)
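The nearest-centroid rule can be sketched in a few lines of base R. predict_nearest_centroid() is a hypothetical helper, not tidyclust's implementation: it cuts the stats::hclust() tree into the requested number of clusters, computes each cluster's centroid, and assigns every new point to the closest one. New data must be on the same normalized scale as the training data. Note that for training points, the nearest-centroid cluster can differ from the label the dendrogram cut assigns.

```r
# Sketch: nearest-centroid prediction for a hierarchical clustering.
predict_nearest_centroid <- function(hc, train, new_data, num_clusters) {
  labels <- cutree(hc, k = num_clusters)
  # Centroid of each cluster: column means of its members (rows ordered by cluster id)
  centroids <- rowsum(train, labels) / as.vector(table(labels))
  # Squared Euclidean distances (rows: new points, columns: clusters)
  d2 <- outer(rowSums(new_data^2), rowSums(centroids^2), "+") -
    2 * new_data %*% t(centroids)
  # Index of the closest centroid for each new point
  max.col(-d2, ties.method = "first")
}

x <- scale(iris[, 1:4])                 # iris as a stand-in for norm_penguins
hc <- hclust(dist(x), method = "complete")
pred4 <- predict_nearest_centroid(hc, x, x, num_clusters = 4)
table(pred4)
```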

We can visualize the resulting cluster assignments in a pairs plot.

Code
penguins %>%
    bind_cols(pred_class) %>%  # reuse the predictions computed above
    select(-c(species, island, sex)) %>%
    ggpairs(aes(color=.pred_cluster))

Figure 19.6: Hierarchical clustering of penguins

As can be seen in Figure 19.6, the additional split leads to the formation of the blue and purple clusters (compare with the three k-means clusters in Figure 19.2).

19.3 Determine the number of clusters

Code
set.seed(123)
penguins <- modeldata::penguins %>% drop_na()

formula <- ~ bill_length_mm + bill_depth_mm + flipper_length_mm + body_mass_g
rec_penguins <- recipe(formula, data=penguins) %>%
    step_normalize(all_predictors())
kmeans_penguins <- k_means(num_clusters=tune()) %>%
    set_engine("stats") %>%
    set_mode("partition")
kmeans_wf <- workflow() %>%
    add_recipe(rec_penguins) %>%
    add_model(kmeans_penguins)

To tune the number of clusters, we set num_clusters=tune(). We can then run the tune_cluster() function with a grid search over 1 to 10 clusters, evaluating each value with 2-fold cross-validation.

Code
registerDoSEQ()  # currently, tune_cluster fails when running in parallel; temporarily disable it

set.seed(4400)
folds <- vfold_cv(penguins, v = 2)
grid <- tibble(num_clusters=1:10)
result <- tune_cluster(kmeans_wf, resamples=folds, grid=grid,
    metrics=cluster_metric_set(sse_within_total, silhouette_avg))

registerDoParallel(cl)

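We can use the collect_metrics() function to retrieve the cluster metrics for the different numbers of clusters. By default, tidyclust computes the total sum of squares (sse_total) and the within-cluster sum of squares (sse_within_total); here, we explicitly requested sse_within_total and the average silhouette width (silhouette_avg) with cluster_metric_set(). The silhouette is undefined for a single cluster, which is why its mean is NaN for num_clusters = 1.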

Code
collect_metrics(result) %>% head()
## # A tibble: 6 × 7
##   num_clusters .metric          .estimator    mean     n std_err .config        
##          <int> <chr>            <chr>        <dbl> <int>   <dbl> <chr>          
## 1            1 silhouette_avg   standard   NaN         0 NA      Preprocessor1_…
## 2            1 sse_within_total standard   662         2  2.00   Preprocessor1_…
## 3            2 silhouette_avg   standard     0.729     2  0.0123 Preprocessor1_…
## 4            2 sse_within_total standard   278.        2 18.4    Preprocessor1_…
## 5            3 silhouette_avg   standard     0.661     2  0.0207 Preprocessor1_…
## 6            3 sse_within_total standard   182.        2  1.65   Preprocessor1_…
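The SSE metrics are related: the total sum of squares around the overall mean splits exactly into a within-cluster part and a between-cluster part. The sketch below checks this decomposition by hand for a stats::kmeans() fit on the scaled iris data (a stand-in dataset). Identifying tidyclust's sse_total and sse_within_total with the first two quantities is our reading of the metric names.

```r
# Sketch: total SSE = within-cluster SSE + between-cluster SSE.
x <- scale(iris[, 1:4])
set.seed(123)
km <- kmeans(x, centers = 3, nstart = 10)

grand <- colMeans(x)
sse_total   <- sum(sweep(x, 2, grand)^2)                 # distances to the overall mean
sse_within  <- sum((x - km$centers[km$cluster, ])^2)     # distances to own cluster center
sse_between <- sum(km$size * rowSums(sweep(km$centers, 2, grand)^2))

c(total = sse_total, within = sse_within, between = sse_between)
all.equal(sse_total, sse_within + sse_between)
```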

The result of tune_cluster() also supports the autoplot() function, which visualizes how the cluster metrics vary with the tuning parameter.

Code
autoplot(result)

Figure 19.7: Cluster metrics as a function of number of clusters

The metric curves are interpreted as follows to determine the optimal number of clusters. For the sse_within_total curve, we look for an elbow, a point where the slope of the curve changes visibly. In our case, the elbow is at either 2 or 3. For the silhouette_avg curve, the optimal number of clusters corresponds to the maximum of the curve, here 2.
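tune_cluster() computes these metrics on held-out folds. To make the elbow and maximum rules concrete, here is a simplified sketch that computes the same two quantities on the full data with stats::kmeans() and silhouette() from the cluster package (a recommended package shipped with R), using the built-in iris data as a stand-in. The numbers will therefore differ from Figure 19.7.

```r
# Sketch: within-cluster SSE and average silhouette width for k = 1..10.
library(cluster)   # provides silhouette()

x <- scale(iris[, 1:4])
d <- dist(x)
ks <- 1:10

set.seed(123)
fits <- lapply(ks, function(k) kmeans(x, centers = k, nstart = 10, iter.max = 50))

sse_within <- sapply(fits, `[[`, "tot.withinss")
sil_avg <- sapply(fits, function(f) {
  if (length(unique(f$cluster)) < 2) return(NaN)   # undefined for a single cluster
  mean(silhouette(f$cluster, d)[, "sil_width"])
})

# Look for the elbow in sse_within_total and the maximum of silhouette_avg
data.frame(num_clusters = ks, sse_within_total = sse_within, silhouette_avg = sil_avg)
```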

Useful to know:

While these metrics can act as a guideline, the decision is somewhat subjective and often depends more on the use case.

Once we have decided on the number of clusters, we can insert it into the model or workflow using finalize_model_tidyclust() or finalize_workflow_tidyclust(), respectively. The finalized model or workflow still has to be fitted to the data.

Code
best_params <- data.frame(num_clusters=3)
model <- finalize_workflow_tidyclust(kmeans_wf, best_params)
model
## ══ Workflow ═════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: k_means()
## 
## ── Preprocessor ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
## 1 Recipe Step
## 
## • step_normalize()
## 
## ── Model ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
## K Means Cluster Specification (partition)
## 
## Main Arguments:
##   num_clusters = 3
## 
## Computational engine: stats

https://stackoverflow.com/a/75756506

Further information:

Code
stopCluster(cl)
registerDoSEQ()


The code of this chapter is summarized here.

Code
knitr::opts_chunk$set(echo=TRUE, cache=TRUE, autodep=TRUE, fig.align="center")
library(tidymodels)
library(tidyclust)
library(kableExtra)
library(patchwork)
library(GGally)
library(doParallel)
cl <- makePSOCKcluster(parallel::detectCores(logical = FALSE))
registerDoParallel(cl)
# k-means clustering uses a random starting point,
# so we set a seed for reproducibility
set.seed(123)
penguins <- modeldata::penguins %>% drop_na()

formula <- ~ bill_length_mm + bill_depth_mm + flipper_length_mm + body_mass_g
rec_penguins <- recipe(formula, data=penguins) %>%
    step_normalize(all_predictors())
kmeans_penguins <- k_means(num_clusters=3) %>%
    set_engine("stats") %>%
    set_mode("partition")
kmeans_wf <- workflow() %>%
    add_recipe(rec_penguins) %>%
    add_model(kmeans_penguins)
kmeans_model <- kmeans_wf %>% fit(data=penguins)
tidy(kmeans_model) %>% knitr::kable() %>% scroll_box(width = "100%")
tidy(kmeans_model) %>%
    pivot_longer(cols=c("bill_length_mm", "bill_depth_mm",
                        "flipper_length_mm", "body_mass_g")) %>%
    ggplot(aes(x=name, y=value, group=cluster, color=cluster)) +
    geom_point() +
    geom_line() +
    labs(x="", y="Value at cluster center")
cl_penguins <- augment(kmeans_model, new_data=penguins)
cl_penguins %>% head() %>% knitr::kable() %>% scroll_box(width = "100%")
cl_penguins %>%
    select(-c(species, island, sex)) %>%
    ggpairs(aes(color=.pred_cluster))
plot_distribution <- function(variable) {
    ggplot(cl_penguins, aes(fill=.data[[variable]], x=.pred_cluster)) +
        geom_bar() +
        theme(legend.position="inside", legend.position.inside=c(0.74, 0.83)) +
        scale_y_continuous(limits=c(0, 200)) +
        labs(y="", x="")
}
g1 <- plot_distribution("species")
g2 <- plot_distribution("island")
g3 <- plot_distribution("sex")

g1 + g2 + g3
formula <- ~ bill_length_mm + bill_depth_mm + flipper_length_mm + body_mass_g
rec_penguins <- recipe(formula, data=penguins) %>%
    step_normalize(all_predictors())

norm_penguins <- rec_penguins %>%
    prep() %>%
    bake(new_data=penguins)

hier_penguins <- hier_clust(linkage_method="complete", num_clusters=3) %>%
    set_engine("stats") %>%
    set_mode("partition")

hier_model <- hier_penguins %>% fit(formula, data=norm_penguins)
hier_model$fit %>% plot()
cluster_assignment <- hier_model %>% extract_cluster_assignment()
centroids <- hier_model %>% extract_centroids()
centroids %>%
    pivot_longer(cols=c("bill_length_mm", "bill_depth_mm",
                        "flipper_length_mm", "body_mass_g")) %>%
    ggplot(aes(x=name, y=value, group=.cluster, color=.cluster)) +
    geom_point() +
    geom_line() +
    labs(x="", y="Value at cluster center")
pred_class <- hier_model %>% predict(new_data=norm_penguins, num_clusters=4)
penguins %>%
    bind_cols(pred_class) %>%
    select(-c(species, island, sex)) %>%
    ggpairs(aes(color=.pred_cluster))
set.seed(123)
penguins <- modeldata::penguins %>% drop_na()

formula <- ~ bill_length_mm + bill_depth_mm + flipper_length_mm + body_mass_g
rec_penguins <- recipe(formula, data=penguins) %>%
    step_normalize(all_predictors())
kmeans_penguins <- k_means(num_clusters=tune()) %>%
    set_engine("stats") %>%
    set_mode("partition")
kmeans_wf <- workflow() %>%
    add_recipe(rec_penguins) %>%
    add_model(kmeans_penguins)
registerDoSEQ()  # currently, tune_cluster fails when running in parallel; temporarily disable it

set.seed(4400)
folds <- vfold_cv(penguins, v = 2)
grid <- tibble(num_clusters=1:10)
result <- tune_cluster(kmeans_wf, resamples=folds, grid=grid,
    metrics=cluster_metric_set(sse_within_total, silhouette_avg))

registerDoParallel(cl)
collect_metrics(result) %>% head()
autoplot(result)
best_params <- data.frame(num_clusters=3)
model <- finalize_workflow_tidyclust(kmeans_wf, best_params)
model
stopCluster(cl)
registerDoSEQ()

  13. I couldn’t get this to work in a workflow.

  14. If you get it to work, let me know.