前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >理解 PyTorch 中的 gather 函数

理解 PyTorch 中的 gather 函数

作者头像
Alan Lee
发布2021-12-07 15:49:03
1.9K0
发布2021-12-07 15:49:03
举报
文章被收录于专栏:Small Code

好久没更新博客了,最近一直在忙,既有生活上的也有工作上的。道阻且长啊。

今天来水一文,说一说最近工作上遇到的一个函数:torch.gather()

文字理解

我遇到的代码是 NLP 相关的,代码中用 torch.gather() 来将一个 tensor 的 shape 从 (batch_size, seq_length, hidden_size) 转为 (batch_size, labels_length, hidden_size) ,其中 seq_length >= labels_length

torch.gather() 的官方解释是

Gathers values along an axis specified by dim.

就是在指定维度上 gather value。那么怎么 gather、gather 哪些 value 呢?这就要看其参数了。

torch.gather() 的必填也是最常用的参数有三个,下面引用官方解释:

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

所以一句话概括 gather 操作就是:根据 index ,在 inputdim 维度上收集 value

具体来说,input 就是源 tensor,等会我们要在这个 tensor 上执行 gather 操作。如果 input 是一个一维数组,即 flat 列表,那么我们就可以直接根据 indexinput 上取了,就像正常的列表/数组索引一样。但是由于 input 可能含有多个维度,是 N 维数组,所以我们需要知道在哪个维度上进行 gather,这就是 dim 的作用。

对于 dim 参数,一种更为具体的理解方式是替换法。假设 inputindex 均为三维数组,那么输出 tensor 每个位置的索引是列表 [i, j, k] ,正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可,但是由于 dim 的存在以及 input.shape 可能不等于 index.shape ,所以直接取值可能就会报 IndexError 。所以我们是将索引列表的相应位置替换为 dim ,再去 input 取值。如果 dim=0 ,我们就替换索引列表第 0 个值,即 [dim, j, k] ,依此类推。Pytorch 的官方文档的写法其实也是这个意思,但是看这么多个方括号可能会有点懵:

代码语言:javascript
复制
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

但是可能你还有点迷糊,没关系接着看下面的直观理解部分,然后再回来看这段话,结合着看,相信你很快能明白。

由于我们是按照 index 来取值的,所以最终得到的 tensor 的 shape 也是和 index 一样的,就像我们在列表上按索引取值,得到的输出列表长度和索引相等一样。

直观理解

为便于理解,我们以一个具体例子来说明。我们使用反推法,根据 input 和输出推参数。这应该也是我们平常自己写代码的时候遇到比较多的情况。

假设 input 和我们想要的输出 output 如下:

代码语言:javascript
复制
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
>>> output_tensor  # shape: (2, 2, 4)
tensor([[[ 0,  1,  2,  3],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [20, 21, 22, 23]]])

即,我们想让 shape 为 (2, 3, 4)input_tensor 变成 shape 为 (2, 2, 4)output_tensor ,丢弃维度 1 的第 2 个元素,即 [ 4, 5, 6, 7][16, 17, 18, 19]

我们应用替换法,重点是找出来 dimindex 的值。始终记住 indexoutput_tensor 的 shape 是一样的。

output_tensor 的第一个位置开始,由于 output_tensor[0, 0, :] = input_tensor[0, 0, :] ,所以此时 [i, j, k] 是一样的,我们看不出来 dim 应该是多少。

下一行 output_tensor[0, 1, 0] = input_tensor[0, 2, 0] ,这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim 应该是 1,而 index 应为 2, index_tensor[0, 1, 0]=2

此时 dim 已经明确。同理,output_tensor[0, 1, 1] = input_tensor[0, 2, 1]index_tensor[0, 1, 1]=2 ,依此类推,得到 index_tensor[0, 1, :] = 2 。同时也可以明确 index_tensor[0, 0, :] = 0

所以

代码语言:javascript
复制
>>> dim = 0
>>> index_tensor
tensor([[[0, 0, 0, 0],
         [2, 2, 2, 2]],

        [[0, 0, 0, 0],
         [2, 2, 2, 2]]])

简单可描述如下图:

为描述方便,假如我们把输入看作是 6 行,从上到下依次是 0-5。那么从事后诸葛亮的角度讲,输出相当于是把第 1 和第 4 行“抽掉”。如果输出和输入一样,那么原本的 index_tensor 就是如下:

代码语言:javascript
复制
tensor([[[0, 0, 0, 0],
         [1, 1, 1, 1],
         [2, 2, 2, 2]],

        [[0, 0, 0, 0],
         [1, 1, 1, 1],
         [2, 2, 2, 2]]])

“抽掉”后, index_tensor 也相应“抽掉”,那么就得到我们想要的结果了。而且由于这个“抽掉”的操作是在维度 1 上进行的,那么 dim 自然是 1。

numpy.take()tf.gather 貌似也是同样功能,就不细说了。

Reference

END

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2021/08/29 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 文字理解
  • 直观理解
  • Reference
  • END
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档