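This post shows how to visualize a confusion matrix with ggplot2.
A confusion matrix is commonly used to show how much two sets of categories overlap. In a classification task, for example, a confusion matrix shows to what extent the predicted labels agree with the true labels.
R already has many packages that can draw confusion matrices, but this post presents an implementation based on ggplot2. Its advantage is flexibility: users have more freedom to modify or extend individual features. In a later post we will show how to draw a confusion matrix with a heatmap, which is simpler to use but correspondingly leaves less room for customization.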
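We have wrapped the required ggplot2 code into a function named plot_conf_mtx (the full definition is given at the end of this post). Its documentation and signature are: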
#' @param dat A dataframe with 3 columns: `x`, `y`, and `n`. See the
#'   Details section.
#' @param xlab A string used as the x-axis label.
#' @param ylab A string used as the y-axis label.
#' @param x_order A character vector of x-axis category names, giving their
#'   display order (the first element is drawn leftmost).
#' @param y_order A character vector of y-axis category names, giving their
#'   display order (the first element is drawn at the bottom).
#' @param normalize A logical indicating whether the `n` values should be
#'   scaled so that each row (or column) sums to 1 when setting the
#'   background colors.
#' @param normalize_by_row A logical indicating whether normalization is
#'   performed by row (`TRUE`) or by column (`FALSE`).
#' @return A ggplot2 object designed for confusion matrix visualization.
plot_conf_mtx <- function(
dat,
xlab,
ylab,
x_order = NULL,
y_order = NULL,
normalize = TRUE,
normalize_by_row = TRUE
)
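All the user needs to prepare is a dataframe holding all the information of the confusion matrix. It should have 3 columns, x, y, and n: the column names of the confusion matrix, its row names, and the number of elements shared by each column/row pair. For example: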
d <- data.frame(
x = c("A1", "A1", "A1", "A2", "A2", "A2", "A3", "A3", "A3"),
y = c("B1", "B2", "B3", "B1", "B2", "B3", "B1", "B2", "B3"),
n = c(88, 3, 1, 22, 112, 5, 2, 2, 108)
)
> d
x y n
1 A1 B1 88
2 A1 B2 3
3 A1 B3 1
4 A2 B1 22
5 A2 B2 112
6 A2 B3 5
7 A3 B1 2
8 A3 B2 2
9 A3 B3 108
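When starting from raw classification results, this dataframe can be built with base R. A minimal sketch, assuming you have two character vectors of true and predicted labels (`truth` and `pred` below are hypothetical toy data, not from this post): `table()` cross-tabulates them, and `as.data.frame()` with `responseName = "n"` turns the table into the x/y/n layout, including combinations whose count is zero.

```r
# Hypothetical toy labels, named to mirror the A*/B* example above
truth <- c("A1", "A1", "A2", "A2", "A2", "A3")
pred  <- c("B1", "B1", "B2", "B1", "B2", "B3")

# Naming the table() arguments x and y makes them the dimnames names,
# which as.data.frame() then uses as column names.
tab <- table(x = truth, y = pred)

# One row per (x, y) combination, count column named "n",
# x and y kept as character rather than factor
d_from_labels <- as.data.frame(tab, responseName = "n",
                               stringsAsFactors = FALSE)
d_from_labels
```

Which vector is passed as `x` and which as `y` decides which labels end up on which axis of the plot.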
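By default, background colors are set from the values normalized within each row, i.e. each n divided by the total of its row (all cells with the same y). Users can instead normalize by column (specify normalize = TRUE, normalize_by_row = FALSE in plot_conf_mtx) or skip normalization altogether (specify normalize = FALSE). In every case the numbers printed in the cells are the raw counts n; normalization only changes the fill colors.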
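The normalization itself is just a grouped division. Below is a base-R sketch of the arithmetic; it uses `ave()` instead of the `group_by()`/`summarise()`/`left_join()` chain inside plot_conf_mtx, which should produce the same values. The column names `freq_row`, `freq_col`, and `freq_raw` are illustrative only.

```r
d <- data.frame(
  x = c("A1", "A1", "A1", "A2", "A2", "A2", "A3", "A3", "A3"),
  y = c("B1", "B2", "B3", "B1", "B2", "B3", "B1", "B2", "B3"),
  n = c(88, 3, 1, 22, 112, 5, 2, 2, 108)
)

# Row-wise: divide each count by the total of its row (all cells sharing y)
d$freq_row <- d$n / ave(d$n, d$y, FUN = sum)

# Column-wise: divide each count by the total of its column (cells sharing x)
d$freq_col <- d$n / ave(d$n, d$x, FUN = sum)

# Without normalization the fill value is simply the raw count
d$freq_raw <- d$n

d
```

For example, row B1 sums to 88 + 22 + 2 = 112, so the A1/B1 cell gets a row-normalized value of 88 / 112 ≈ 0.786, while column A1 sums to 88 + 3 + 1 = 92, giving 88 / 92 ≈ 0.957 when normalizing by column.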
First, the confusion matrix with the default settings (normalized by row):
p1 <- plot_conf_mtx(
d,
xlab = "Class A",
ylab = "Class B",
y_order = c("B3", "B2", "B1")
)
p1
Next, the confusion matrix normalized by column:
p2 <- plot_conf_mtx(
d,
xlab = "Class A",
ylab = "Class B",
y_order = c("B3", "B2", "B1"),
normalize_by_row = FALSE
)
p2
Then, the confusion matrix without normalization:
p3 <- plot_conf_mtx(
d,
xlab = "Class A",
ylab = "Class B",
y_order = c("B3", "B2", "B1"),
normalize = FALSE
)
p3
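Finally, because plot_conf_mtx returns a ggplot2 object, users can add layers or adjust details on top of it, for example changing the font size or the background colors. The example below adds a title to the finished confusion matrix: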
p4 <- p1 +
labs(title = "Default Confusion Matrix") +
theme(plot.title = element_text(hjust = 0.5))
p4
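Other tweaks follow the same `+` pattern. To keep the sketch self-contained, it uses a bare-bones tile plot (`base_plot`, a hypothetical stand-in for the output of plot_conf_mtx); the same calls can be appended to `p1`. Adding a second fill scale replaces the existing one, and later `theme()` settings override earlier ones.

```r
library(ggplot2)

d <- data.frame(
  x = c("A1", "A1", "A1", "A2", "A2", "A2", "A3", "A3", "A3"),
  y = c("B1", "B2", "B3", "B1", "B2", "B3", "B1", "B2", "B3"),
  n = c(88, 3, 1, 22, 112, 5, 2, 2, 108)
)

# Stand-in for plot_conf_mtx(d, ...): a plain ggplot object with a Blues fill
base_plot <- ggplot(d, aes(x = x, y = y, fill = n)) +
  geom_tile() +
  geom_text(aes(label = n)) +
  scale_fill_distiller(palette = "Blues", direction = 1)

# Swap the palette (ggplot2 prints a message that the fill scale is replaced)
# and enlarge the axis text / bold the axis titles via theme()
custom_plot <- base_plot +
  scale_fill_distiller(palette = "Greens", direction = 1) +
  theme(axis.text = element_text(size = 12),
        axis.title = element_text(face = "bold"))

custom_plot
```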
Appendix: the full definition of plot_conf_mtx.
# reference: https://stackoverflow.com/questions/37897252/plot-confusion-matrix-in-r-using-ggplot
library(dplyr)
library(ggplot2)
#' Plot Confusion Matrix
#'
#' @param dat A dataframe with 3 columns: `x`, `y`, and `n`. See the
#'   Details section.
#' @param xlab A string used as the x-axis label.
#' @param ylab A string used as the y-axis label.
#' @param x_order A character vector of x-axis category names, giving their
#'   display order (the first element is drawn leftmost).
#' @param y_order A character vector of y-axis category names, giving their
#'   display order (the first element is drawn at the bottom).
#' @param normalize A logical indicating whether the `n` values should be
#'   scaled so that each row (or column) sums to 1 when setting the
#'   background colors.
#' @param normalize_by_row A logical indicating whether normalization is
#'   performed by row (`TRUE`) or by column (`FALSE`).
#' @return A ggplot2 object designed for confusion matrix visualization.
#'
#' @details
#' The `x` column holds the names of the x-axis objects, the `y` column holds
#' the names of the y-axis objects, and `n` is the number of elements shared
#' by each `x`/`y` pair.
#'
#' @examples
#' d <- data.frame(
#' x = c("A1", "A1", "A1", "A2", "A2", "A2", "A3", "A3", "A3"),
#' y = c("B1", "B2", "B3", "B1", "B2", "B3", "B1", "B2", "B3"),
#' n = c(88, 3, 1, 22, 112, 5, 2, 2, 108)
#' )
#' p <- plot_conf_mtx(d, "Class A", "Class B")
#'
#' # library(ggplot2)
#' # ggsave("out_conf_mtx.jpg", p, width = 8, height = 6, units = "cm",
#' #        dpi = 300)
plot_conf_mtx <- function(
dat,
xlab,
ylab,
x_order = NULL,
y_order = NULL,
normalize = TRUE,
normalize_by_row = TRUE
) {
if (normalize) {
if (normalize_by_row) {
dstat <- dat %>%
group_by(y) %>%
summarise(m = sum(n)) %>%
ungroup()
dat <- dat %>%
left_join(dstat, by = "y")
} else {
dstat <- dat %>%
group_by(x) %>%
summarise(m = sum(n)) %>%
ungroup()
dat <- dat %>%
left_join(dstat, by = "x")
}
dat <- dat %>%
mutate(freq = n / m)
} else {
dat <- dat %>%
mutate(freq = n)
}
freq_cutoff <- (max(dat$freq) + min(dat$freq)) / 2
dat <- dat %>%
mutate(font_color = ifelse(freq > freq_cutoff, "white", "black"))
if (is.null(x_order))
x_order <- sort(unique(dat$x))
if (is.null(y_order))
y_order <- sort(unique(dat$y))
p <-
# main settings
ggplot(dat, aes(x = x, y = y, fill = freq)) +
geom_tile() +
geom_text(aes(label = n, color = font_color), size = 4) +
scale_color_manual(values = c("white" = "white", "black" = "black")) +
scale_fill_distiller(palette = "Blues", direction = 1) +
labs(x = xlab, y = ylab) +
scale_x_discrete(limits = x_order) +
scale_y_discrete(limits = y_order) +
guides(color = "none") +
guides(fill = guide_colourbar(title = NULL, barwidth = .5,
barheight = 8)) +
# use minimum background & theme
theme_bw() +
theme(panel.grid.major = element_blank(),
panel.grid.minor = element_blank()) +
theme(panel.background = element_blank()) +
theme(panel.border = element_blank()) +
theme(axis.ticks = element_blank()) +
theme(axis.text.x = element_text(vjust = 4, size = 9)) +
theme(axis.text.y = element_text(angle = 90, hjust = .5, vjust = -3,
size = 9))
return(p)
}
《生信了》, November 2022