Loading [MathJax]/jax/input/TeX/jax.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >PyTorch入门笔记-gather选择函数

PyTorch入门笔记-gather选择函数

作者头像
触摸壹缕阳光
修改于 2021-01-19 07:39:15
修改于 2021-01-19 07:39:15
3.8K0
举报

gather

torch.gather(*input,dim,index,sparse_grad=False, out=None*) 函数沿着指定的轴 dim 上的索引 index 采集输入张量 input 中的元素值,函数的参数有:

  • input (Tensor) - 输入张量
  • dim (int) - 需要进行索引的轴
  • index (LongTensor) - 要采集元素的索引
  • sparse_grad (bool, optional) - 如果为 True,输入张量 input 会变成离散张量
  • out (Tensor, optional) - 指定输出的张量。比如执行 torch.zeros(2, 2, out = tensor_a),相当于执行 tensor_a = torch.zeros(2, 2)

除了 sparse_grad 和 out 两个可选参数,其余三个参数都是必选参数。为了方便这里只考虑必选参数,即 torch.gather(input, dim, index)。

简单介绍完 gather 函数之后,来看一个简单的小例子:一次将下面 2D 张量中所有红色的元素采集出来。

2D 张量可以看成矩阵,2D 张量的第一个维度为矩阵的行 (dim = 0),2D 张量的第二个维度为矩阵的列 (dim = 1),从左向右依次看三个红色元素在矩阵中的具体位置:

  • 6: 第 2 行的第 0 列
  • 1: 第 0 行的第 1 列
  • 5: 第 1 行的第 2 列

通过红色元素的具体位置可以看出,三个红色元素的列索引号是有规律的:从 0 到 2 逐渐递增。假设此时列索引的规律是已知并且固定的,我们只需要给出这些红色元素在行上的索引号就可以将这些红色元素全部采集出来。

至此,对于这个 2D 张量的小例子,已知了输入张量和指定行上的索引号。回顾 torch.gather(input, dim, index) 函数沿着指定轴上的索引采集输入张量的元素值,貌似现在已知的条件和 gather 函数中所需要的参数有些谋和。下面我们来尝试一下使用 gather 函数来采集红色元素。

代码语言:txt
AI代码解释
复制
>>> import torch
>>> x = torch.arange(9).view(3, 3)
>>> print(x)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

>>> index = torch.tensor([[2, 0, 1]])
>>> # dim=0: 行上的索引
>>> out = torch.gather(x, dim = 0, index = index)
>>> print(out)

tensor([[6, 1, 5]])

gather 函数的输出结果和我们在小例子中分析的结果一致。

如果按照从上到下来看三个红色元素,采集元素的顺序和从前面从左向右看的时候不同,此时采集元素的顺序为 1, 5, 6,现在看看此时这三个红色元素在矩阵中的具体位置:

  • 1: 第 0 行的第 1 列
  • 5: 第 1 行的第 2 列
  • 6: 第 2 行的第 0 列

现在行索引号是有规律的:从 0 到 2 逐渐递增。现在假设此时行索引的规律是已知并且固定的,我们只需要给出这些红色元素在列上的索引号就可以将这些红色元素全部采集出来了。

代码语言:txt
AI代码解释
复制
>>> import torch
>>> x = torch.arange(9).view(3, 3)
>>> print(x)

tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

>>> index = torch.tensor([[1, 2, 0]]).t()
>>> # dim=1: 在列方向上索引
>>> out = torch.gather(x, dim = 1, index = index)
>>> print(out)

tensor([[1],
        [5],
        [6]])

在不同轴上 (行或列) 进行索引传入的 index 参数的张量形状不同,在 gather 函数中规定:

  • 传入 index 的张量维度数要和输入张量 input 的维度数相同;
  • 输出张量的形状和传入 index 张量的形状相同;
  • 如果沿着轴的每个维度采集 N 个元素,则 index 对应轴上的长度为 N(N1)。比如对于前面的 2D 张量,对行索引且每一行只采集一个元素,则 index 在行上的长度为 1,index 的形状即为 (1 x 3);

接下来使用一个形状为 (3 x 5) 2D 张量来详细的分析 gather 函数的原理。

2D 张量有两个轴,假定现在只采集一个元素:

  • dim = 0

dim = 0 表示在行上索引,此时假定已知且固定了在列上的索引,即 (其中 ? 为待采集元素在行上的索引号):

  • 在 ? 行的第 0 列
  • 在 ? 行的第 1 列
  • 在 ? 行的第 2 列
  • 在 ? 行的第 3 列
  • 在 ? 行的第 4 列

如果想要使用 gather 函数采集元素,需要在 index 中指定 5 个行索引号,而每列只索引一个元素且在行上索引 (dim = 0),因此最终我们需要传入 index 张量的形状为 (1, 5),其中的元素值为待采集元素的行索引号。

  • dim = 1

dim = 1 表示在列上索引,此时假定已知且固定了在行上的索引,即 (其中 ? 为待采集元素在列上的索引号):

  • 在 0 行的第 ? 列
  • 在 1 行的第 ? 列
  • 在 2 行的第 ? 列

如果想要使用 gather 函数采集元素,需要在 index 中指定 3 个列索引号,而每行只索引一个元素且在列上索引 (dim = 1),因此最终我们需要传入 index 张量的形状为 (1, 3),其中的元素值为待采集元素的列索引号。

最后来看看如何使用 gather 函数每行采集两个元素:

代码语言:txt
AI代码解释
复制
>>> import torch
>>> x = torch.arange(15).view(3, 5)
>>> index = torch.LongTensor([[0, 1], [2, 3], [1, 2]])
>>> out = torch.gather(x, dim = 1, index = index)
>>> print(out)

tensor([[ 0,  1],
        [ 7,  8],
        [11, 12]])

传入 index 的张量形状为 (3 x 2),因此最终输出张量的形状也为 (3 x 2)。dim = 1 表示在列上索引,此时假定已知且固定了在行上的索引:

  • 在 0 行的第 0 列,在 0 行的第 1 列
  • 在 1 行的第 2 列,在 1 行的第 3 列
  • 在 2 行的第 1 列,在 2 行的第 2 列
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-12-04,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
PyTorch入门笔记-index_select选择函数
torch.index_select(input,dim,index,out=None) 函数返回的是沿着输入张量的指定维度的指定索引号进行索引的张量子集,其中输入张量、指定维度和指定索引号就是 torch.index_select(input,dim,index,out=None) 函数的三个关键参数,函数参数有:
触摸壹缕阳光
2020/12/02
6.6K1
PyTorch入门笔记-index_select选择函数
PyTorch入门笔记-nonzero选择函数
前面已经介绍了 index_select 和 mask_select 两个选择函数,这两个函数通过一定的索引规则从输入张量中筛选出满足条件的元素值,只不过 index_select 函数使用索引 index 的索引规则,而 mask_select 函数使用布尔掩码 mask 的索引规则。
触摸壹缕阳光
2020/12/02
6.4K0
PyTorch入门笔记-nonzero选择函数
PyTorch入门笔记-masked_select选择函数
torch.masked_select(input,mask,out=None) 函数返回一个根据布尔掩码 (boolean mask) 索引输入张量的 1D 张量,其中布尔掩码和输入张量就是 torch.masked_select(input, mask, out = None) 函数的两个关键参数,函数的参数有:
触摸壹缕阳光
2020/12/02
4.5K0
PyTorch入门笔记-masked_select选择函数
pytorch基础知识-高阶代码
API: torch.where(condition, x, y) => Tensor
用户6719124
2019/11/17
9380
PyTorch中Tensor的操作手册
默认下,Tensor为‘torch.FloatTensor’类型,若要改为double类型的,则需要执行
孔西皮
2023/10/18
6600
PyTorch中Tensor的操作手册
PyTorch入门笔记-复制数据repeat函数
前面提到过 input.expand(*sizes) 函数能够实现 input 输入张量中单维度(singleton dimension)上数据的复制操作。「对于非单维度上的复制操作,expand 函数就无能为力了,此时就需要使用 input.repeat(*sizes)。」
触摸壹缕阳光
2021/01/28
6.1K0
PyTorch入门笔记-复制数据repeat函数
PyTorch入门笔记-分割chunk函数
torch.chunk(input, chunks, dim = 0) 函数会将输入张量(input)沿着指定维度(dim)均匀的分割成特定数量的张量块(chunks),并返回元素为张量块的元组。torch.chunk 函数有三个参数:
触摸壹缕阳光
2021/02/26
7.2K0
PyTorch入门笔记-分割chunk函数
PyTorch入门笔记-索引和切片
切片其实也是索引操作,所以切片经常被称为切片索引,为了更方便叙述,本文将切片称为切片索引。索引和切片操作可以帮助我们快速提取张量中的部分数据。
触摸壹缕阳光
2020/12/02
3.7K0
PyTorch入门笔记-索引和切片
pytorch新手需要注意的隐晦操作Tensor,max,gather
先看官方的介绍: 如果input是一个n维的tensor,size为 (x0,x1…,xi−1,xi,xi+1,…,xn−1),dim为i,然后index必须也为n维tensor,size为 (x0,x1,…,xi−1,y,xi+1,…,xn−1),其中y >= 1,最后输出的out与index的size是一样的。 意思就是按照一个指定的轴(维数)收集值 对于一个三维向量来说:
老潘
2018/06/21
4.4K0
pytorch新手需要注意的隐晦操作Tensor,max,gather
【深度学习】Pytorch 教程(十二):PyTorch数据结构:4、张量操作(3):张量修改操作(拆分、拓展、修改)
  Tensor(张量)是PyTorch中用于表示多维数据的主要数据结构,类似于多维数组,可以存储和操作数字数据。
Qomolangma
2024/07/30
2810
【深度学习】Pytorch 教程(十二):PyTorch数据结构:4、张量操作(3):张量修改操作(拆分、拓展、修改)
PyTorch从入门到放弃之张量模块
张量(Tensor)是PyTorch最基本的操作对象。在几何定义中,张量是基于标量、向量和矩阵概念的眼神。通俗理解,可以讲标量视为0维张量,向量视为1维张量,矩阵视为2维张量。在深度学习领域,可以将张量视为一个数据的水桶,当水桶中只放一滴水时就是0维张量,多滴水排成一排就是1维张量,联排成面就是2维张量,以此类推,扩展到n维向量。
愷龍
2024/09/03
2310
PyTorch从入门到放弃之张量模块
PyTorch入门笔记-创建已知分布的张量
正态分布(Normal Distribution)和均匀分布(Uniform Distribution)是最常见的分布之一,创建采样自这 2 个分布的张量非常有用,「比如在卷积神经网络中,卷积核张量
触摸壹缕阳光
2020/11/20
3.6K0
PyTorch入门笔记-创建已知分布的张量
Pytorch中张量的高级选择操作
在某些情况下,我们需要用Pytorch做一些高级的索引/选择,所以在这篇文章中,我们将介绍这类任务的三种最常见的方法:torch.index_select, torch.gather and torch.take
deephub
2024/03/11
3610
Pytorch中张量的高级选择操作
【PyTorch入门】 常用统计函数【二】
torch.prod()用于计算张量 a 中所有元素的乘积。返回一个张量,表示输入张量所有元素的累积乘积。如果输入是一个多维张量,则默认计算所有元素的乘积。
机器学习司猫白
2025/01/21
1640
「Deep Learning」PyTorch初步认识
torch.where(condition, x, y): 按照条件从x和y中选出满足条件的元素组成新的tensor。
曼亚灿
2023/07/05
6300
「Deep Learning」PyTorch初步认识
D2L学习笔记00:Pytorch操作
张量表示由一个数值组成的数组,这个数组可能有多个维度。具有一个轴的张量对应数学上的向量(vector);具有两个轴的张量对应数学上的矩阵(matrix);具有两个轴以上的张量没有特殊的数学名称。
Hsinyan
2022/08/30
1.7K0
PyTorch入门笔记-分割split函数
torch.split(input, split_size_or_sections, dim = 0) 函数会将输入张量(input)沿着指定维度(dim)分割成特定数量的张量块,并返回元素为张量块的元素。简单来说,可以将 torch.split 函数看成是 torch.chunk 函数的进阶版,因为 torch.split 不仅能够指定块数均匀分割(torch.chunk 只能指定块数均匀分割),而且能够指定分割每一块的长度。 torch.split 函数有三个参数:
触摸壹缕阳光
2021/02/26
8.4K0
PyTorch入门笔记-分割split函数
我对torch中的gather函数的一点理解
假设输入与上同;index=B;输出为C B中每个元素分别为b(0,0)=0,b(0,1)=0 b(1,0)=1,b(1,1)=0
树枝990
2020/08/20
1K0
PyTorch入门笔记-复制数据expand函数
当通过增加维度操作插入新维度后,可能希望在新维度上面复制若干份数据,满足后续算法的格式要求。考虑 Y = X@W + b 的例子,偏置 b 插入样本数的新维度后,需要在新维度上复制 Batch Size 份数据,将 shape 变为与 X@W 一致后,才能完成张量相加运算。
触摸壹缕阳光
2021/01/28
7.1K0
PyTorch入门笔记-复制数据expand函数
强的离谱,16个Pytorch核心操作!!
当然在 PyTorch 中,转换函数的主要意义主要是用于对进行数据的预处理和数据增强,使其适用于深度学习模型的训练和推理。
Python编程爱好者
2023/12/26
3840
强的离谱,16个Pytorch核心操作!!
相关推荐
PyTorch入门笔记-index_select选择函数
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档