Visualizing a confusion matrix with ggplot2

Author: 一只羊
Published: 2022-11-30 14:27:21
Column: 生信了

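This post shows how to visualize a confusion matrix with ggplot2.

A confusion matrix is often used to show how much two sets of labels overlap. In a classification task, for example, it shows how well the predicted labels agree with the true labels.

R already has many packages that can draw a confusion matrix. This post instead presents an implementation built directly on ggplot2; its advantage is flexibility, since users are free to modify or extend it. A later post will cover drawing a confusion matrix with a heatmap, which is simpler to use but correspondingly leaves less room for customization.

We have wrapped the required ggplot2 functionality in a function named plot_conf_mtx (the full code is at the end of this post).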

#' @param dat A dataframe containing 3 columns: `x`, `y`, and `n`. See
#'   details in \code{Details}.
#' @param xlab A string that is the label of x axis of the table.
#' @param ylab A string that is the label of y axis of the table.
#' @param x_order A character vector of names of the x-axis objects that
#'   defines their orders.
#' @param y_order A character vector of names of the y-axis objects that
#'   defines their orders.
#' @param normalize A bool whether the `n` values should be scaled to make
#'   the sum of rows or columns to be 1, when setting background colors.
#' @param normalize_by_row A bool whether the normalization should be
#'   performed by row or by column.
#' @return A ggplot2 object designed for confusion matrix visualization.

plot_conf_mtx <- function(
  dat, 
  xlab, 
  ylab, 
  x_order = NULL,
  y_order = NULL,
  normalize = TRUE,
  normalize_by_row = TRUE
)  # function body omitted here; the full definition is at the end of the post

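All the user has to prepare is a dataframe holding the full contents of the confusion matrix. It should have 3 columns, x, y, and n: the column name, the row name, and the number of elements shared by that column and row. For example: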

d <- data.frame(
  x = c("A1", "A1", "A1", "A2", "A2", "A2", "A3", "A3", "A3"),
  y = c("B1", "B2", "B3", "B1", "B2", "B3", "B1", "B2", "B3"),
  n = c(88, 3, 1, 22, 112, 5, 2, 2, 108)
)

> d
   x  y   n
1 A1 B1  88
2 A1 B2   3
3 A1 B3   1
4 A2 B1  22
5 A2 B2 112
6 A2 B3   5
7 A3 B1   2
8 A3 B2   2
9 A3 B3 108
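
If you start from two label vectors rather than a ready-made count table, one possible way to get the same x / y / n layout is base R's table() followed by as.data.frame(). The sketch below is not part of the original post; the vectors `truth` and `pred` and the name `d2` are made-up illustration data.

```r
# Hypothetical label vectors, for illustration only.
truth <- c("B1", "B1", "B2", "B2", "B2", "B3")
pred  <- c("A1", "A2", "A2", "A2", "A3", "A3")

# Cross-tabulate the two vectors. Every x/y combination gets a count,
# including combinations that never occur (count 0).
tab <- table(x = pred, y = truth)

# Convert to long format with the three columns described above.
# stringsAsFactors = FALSE keeps x and y as plain character columns.
d2 <- as.data.frame(tab, responseName = "n", stringsAsFactors = FALSE)
d2
```

Because table() fills in zero counts, the resulting dataframe has one row per cell of the matrix, so no tile would be missing from the plot.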

By default, the background color of each cell is set from its value after normalization within each row. You can instead normalize within each column (pass normalize = TRUE, normalize_by_row = FALSE to plot_conf_mtx), or skip normalization altogether (pass normalize = FALSE).
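
To make the two normalization modes concrete, here is a small worked example on the dataframe `d` from above. It is only a sketch: it uses base R's ave() instead of the dplyr group_by / summarise / left_join steps in the full function, and the column names `freq_by_row` and `freq_by_col` are chosen here for illustration.

```r
d <- data.frame(
  x = c("A1", "A1", "A1", "A2", "A2", "A2", "A3", "A3", "A3"),
  y = c("B1", "B2", "B3", "B1", "B2", "B3", "B1", "B2", "B3"),
  n = c(88, 3, 1, 22, 112, 5, 2, 2, 108)
)

# Row-wise: divide each count by the total of its row (all cells sharing y).
d$freq_by_row <- d$n / ave(d$n, d$y, FUN = sum)

# Column-wise: divide each count by the total of its column (all cells sharing x).
d$freq_by_col <- d$n / ave(d$n, d$x, FUN = sum)

# The A1/B1 cell is 88 / (88 + 22 + 2) by row and 88 / (88 + 3 + 1) by column.
d[d$x == "A1" & d$y == "B1", ]
```

The same count of 88 therefore maps to roughly 0.79 under row normalization and roughly 0.96 under column normalization, which is why the two settings can shade the same cell differently.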

Let's first look at the confusion matrix under the default settings (normalized by row):

p1 <- plot_conf_mtx(
  d, 
  xlab = "Class A", 
  ylab = "Class B", 
  y_order = c("B3", "B2", "B1")
)
p1

Next, the confusion matrix normalized by column:

p2 <- plot_conf_mtx(
  d, 
  xlab = "Class A", 
  ylab = "Class B", 
  y_order = c("B3", "B2", "B1"),
  normalize_by_row = FALSE
)
p2

Then the confusion matrix without normalization:

p3 <- plot_conf_mtx(
  d, 
  xlab = "Class A", 
  ylab = "Class B", 
  y_order = c("B3", "B2", "B1"),
  normalize = FALSE
)
p3

Finally, because plot_conf_mtx returns a ggplot2 object, you can add features to it or adjust its details, such as the font size or the background colors. The following example adds a title to the confusion matrix drawn above:

p4 <- p1 +
  labs(title = "Default Confusion Matrix") +
  theme(plot.title = element_text(hjust = 0.5))
p4
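
Other adjustments follow the same pattern of adding layers to the returned object. The sketch below is an illustration rather than part of the original post: to stay self-contained it builds a simplified stand-in plot `p` (an assumption, not the real output of plot_conf_mtx) and then changes the axis title font and the legend position; the name `p5` is likewise made up.

```r
library(ggplot2)

d <- data.frame(
  x = c("A1", "A1", "A1", "A2", "A2", "A2", "A3", "A3", "A3"),
  y = c("B1", "B2", "B3", "B1", "B2", "B3", "B1", "B2", "B3"),
  n = c(88, 3, 1, 22, 112, 5, 2, 2, 108)
)

# Simplified stand-in for the object returned by plot_conf_mtx():
# one tile per cell, with the count printed on top.
p <- ggplot(d, aes(x = x, y = y, fill = n)) +
  geom_tile() +
  geom_text(aes(label = n))

# Enlarge and embolden the axis titles, and move the colour bar below the plot.
p5 <- p +
  theme(
    axis.title = element_text(size = 12, face = "bold"),
    legend.position = "bottom"
  )
p5
```

With the real plot, replacing the color palette works the same way, for example by adding scale_fill_distiller(palette = "Greens", direction = 1). Since the plot already has a fill scale, ggplot2 prints a message that the existing scale is being replaced; the message is harmless.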

Complete code

# reference: https://stackoverflow.com/questions/37897252/plot-confusion-matrix-in-r-using-ggplot

library(dplyr)
library(ggplot2)

#' Plot Confusion Matrix
#' 
#' @param dat A dataframe containing 3 columns: `x`, `y`, and `n`. See
#'   details in \code{Details}.
#' @param xlab A string that is the label of x axis of the table.
#' @param ylab A string that is the label of y axis of the table.
#' @param x_order A character vector of names of the x-axis objects that
#'   defines their orders.
#' @param y_order A character vector of names of the y-axis objects that
#'   defines their orders.
#' @param normalize A bool whether the `n` values should be scaled to make
#'   the sum of rows or columns to be 1, when setting background colors.
#' @param normalize_by_row A bool whether the normalization should be
#'   performed by row or by column.
#' @return A ggplot2 object designed for confusion matrix visualization.
#'   
#' @section Details:
#' The `x` column is the names of the x-axis objects; The `y` column is the 
#' names of the y-axis objects; The `n` is the number of matched elements 
#' between `x` and `y`.
#' 
#' @examples
#' d <- data.frame(
#'   x = c("A1", "A1", "A1", "A2", "A2", "A2", "A3", "A3", "A3"),
#'   y = c("B1", "B2", "B3", "B1", "B2", "B3", "B1", "B2", "B3"),
#'   n = c(88, 3, 1, 22, 112, 5, 2, 2, 108)
#' )
#' p <- plot_conf_mtx(d, "Class A", "Class B")
#' 
#' # library(ggplot2)
#' # ggsave("out_conf_mtx.jpg", p, width = 8, height = 6, units = "cm",
#' #        dpi = 300)
plot_conf_mtx <- function(
  dat, 
  xlab, 
  ylab, 
  x_order = NULL,
  y_order = NULL,
  normalize = TRUE,
  normalize_by_row = TRUE
) {
  if (normalize) {
    if (normalize_by_row) {
      dstat <- dat %>%
        group_by(y) %>%
        summarise(m = sum(n)) %>%
        ungroup()
      dat <- dat %>%
        left_join(dstat, by = "y")
    } else {
      dstat <- dat %>%
        group_by(x) %>%
        summarise(m = sum(n)) %>%
        ungroup()
      dat <- dat %>%
        left_join(dstat, by = "x")      
    }
    dat <- dat %>%
      mutate(freq = n / m)
  } else {
    dat <- dat %>%
      mutate(freq = n)
  }

  freq_cutoff <- (max(dat$freq) + min(dat$freq)) / 2
  dat <- dat %>%
    mutate(font_color = ifelse(freq > freq_cutoff, "white", "black"))
    
  if (is.null(x_order))
    x_order <- sort(unique(dat$x))
  
  if (is.null(y_order))
    y_order <- sort(unique(dat$y))

  p <- 
    # main settings
    ggplot(dat, aes(x = x, y = y, fill = freq)) +
    geom_tile() + 
    geom_text(aes(label = n, color = font_color), size = 4) +
    scale_color_manual(values = c("white" = "white", "black" = "black")) +
    scale_fill_distiller(palette = "Blues", direction = 1) +
    labs(x = xlab, y = ylab) +
    scale_x_discrete(limits = x_order) +
    scale_y_discrete(limits = y_order) +
    guides(color = "none") +
    guides(fill = guide_colourbar(title = NULL, barwidth = .5,
                                  barheight = 8)) +
    
    # use minimum background & theme
    theme_bw() +
    theme(panel.grid.major = element_blank(),
          panel.grid.minor = element_blank()) +
    theme(panel.background = element_blank()) +
    theme(panel.border = element_blank()) +
    theme(axis.ticks = element_blank()) +
    theme(axis.text.x = element_text(vjust = 4, size = 9)) +
    theme(axis.text.y = element_text(angle = 90, hjust = .5, vjust = -3, 
                                     size = 9))
  
  return(p)
}

d <- data.frame(
  x = c("A1", "A1", "A1", "A2", "A2", "A2", "A3", "A3", "A3"),
  y = c("B1", "B2", "B3", "B1", "B2", "B3", "B1", "B2", "B3"),
  n = c(88, 3, 1, 22, 112, 5, 2, 2, 108)
)

p1 <- plot_conf_mtx(
  d, 
  xlab = "Class A", 
  ylab = "Class B", 
  y_order = c("B3", "B2", "B1")
)
p1

p2 <- plot_conf_mtx(
  d, 
  xlab = "Class A", 
  ylab = "Class B", 
  y_order = c("B3", "B2", "B1"),
  normalize_by_row = FALSE
)
p2

p3 <- plot_conf_mtx(
  d, 
  xlab = "Class A", 
  ylab = "Class B", 
  y_order = c("B3", "B2", "B1"),
  normalize = FALSE
)
p3

p4 <- p1 +
  labs(title = "Default Confusion Matrix") +
  theme(plot.title = element_text(hjust = 0.5))
p4

生信了, November 2022

Originally published on 2022-11-05 on the 生信了 WeChat official account.
