前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch使用------张量的类型转换,拼接操作,索引操作,形状操作

PyTorch使用------张量的类型转换,拼接操作,索引操作,形状操作

作者头像
小言从不摸鱼
发布2024-09-10 18:50:38
580
发布2024-09-10 18:50:38
举报
文章被收录于专栏:机器学习入门

前言

学习张量的拼接、索引和形状操作在深度学习和数据处理中至关重要。

拼接操作允许我们合并不同来源或不同维度的数据,以丰富模型输入或构建复杂网络结构。

索引操作则提供了精确访问和操作张量中特定元素或子张量的能力,这对于数据预处理、特征提取和错误调试尤为关键。

形状操作如重塑、转置等,能够灵活调整张量的维度,确保数据符合算法或网络层的输入要求,从而优化计算效率和性能。

在学习张量三大操作之前,我们先来简单熟悉一下张量的类型转换

1. 张量类型转换

张量的类型转换也是经常使用的一种操作,是必须掌握的知识点。在本小节,我们主要学习如何将 numpy 数组和 PyTorch Tensor 的转化方法.

1.1 张量转换为 numpy 数组

使用 Tensor.numpy 函数可以将张量转换为 ndarray 数组,但是共享内存,可以使用 copy 函数避免共享。

代码语言:javascript
复制
# 1. 将张量转换为 numpy 数组
def test01():

    data_tensor = torch.tensor([2, 3, 4])
    # 使用张量对象中的 numpy 函数进行转换
    data_numpy = data_tensor.numpy()

    print(type(data_tensor))
    print(type(data_numpy))

    # 注意: data_tensor 和 data_numpy 共享内存
    # 修改其中的一个,另外一个也会发生改变
    # data_tensor[0] = 100
    data_numpy[0] = 100

    print(data_tensor)
    print(data_numpy)


# 2. 对象拷贝避免共享内存
def test02():

    data_tensor = torch.tensor([2, 3, 4])
    # 使用张量对象中的 numpy 函数进行转换
    data_numpy = data_tensor.numpy()

    print(type(data_tensor))
    print(type(data_numpy))

    # 注意: data_tensor 和 data_numpy 共享内存
    # 修改其中的一个,另外一个也会发生改变
    # data_tensor[0] = 100
    data_numpy[0] = 100

    print(data_tensor)
    print(data_numpy)

1.2 numpy 转换为张量

  1. 使用 from_numpy 可以将 ndarray 数组转换为 Tensor,默认共享内存,使用 copy 函数避免共享。
  2. 使用 torch.tensor 可以将 ndarray 数组转换为 Tensor,默认不共享内存。
代码语言:javascript
复制
# 1. 使用 from_numpy 函数
def test01():

    data_numpy = np.array([2, 3, 4])
    # 将 numpy 数组转换为张量类型
    # 1. from_numpy
    # 2. torch.tensor(ndarray)

    # 浅拷贝
    data_tensor = torch.from_numpy(data_numpy)

    # nunpy 和 tensor 共享内存
    # data_numpy[0] = 100
    data_tensor[0] = 100

    print(data_tensor)
    print(data_numpy)


# 2. 使用 torch.tensor 函数
def test02():

    data_numpy = np.array([2, 3, 4])

    data_tensor = torch.tensor(data_numpy)

    # nunpy 和 tensor 不共享内存
    # data_numpy[0] = 100
    data_tensor[0] = 100

    print(data_tensor)
    print(data_numpy)

1.3 标量张量和数字的转换

对于只有一个元素的张量,使用 item 方法将该值从张量中提取出来。

代码语言:javascript
复制
# 3. 标量张量和数字的转换
def test03():

    # 当张量只包含一个元素时, 可以通过 item 函数提取出该值
    data = torch.tensor([30,])
    print(data.item())

    data = torch.tensor(30)
    print(data.item())


if __name__ == '__main__':
    test03()

程序输出结果:

代码语言:javascript
复制
30
30

1.4 小节

在本小节中, 我们主要学习了 numpy 和 tensor 互相转换的规则, 以及标量张量与数值之间的转换规则。

2. 张量拼接操作

张量的拼接操作在神经网络搭建过程中是非常常用的方法,例如: 在后面将要学习到的残差网络、注意力机制中都使用到了张量拼接。

2.1 torch.cat 函数的使用

torch.cat 函数可以将两个张量根据指定的维度拼接起来.

代码语言:javascript
复制
import torch


def test():

    data1 = torch.randint(0, 10, [3, 5, 4])
    data2 = torch.randint(0, 10, [3, 5, 4])

    print(data1)
    print(data2)
    print('-' * 50)

    # 1. 按0维度拼接
    new_data = torch.cat([data1, data2], dim=0)
    print(new_data.shape)
    print('-' * 50)

    # 2. 按1维度拼接
    new_data = torch.cat([data1, data2], dim=1)
    print(new_data.shape)

    # 3. 按2维度拼接
    new_data = torch.cat([data1, data2], dim=2)
    print(new_data)


if __name__ == '__main__':
    test()

程序输出结果:

代码语言:javascript
复制
tensor([[[6, 8, 3, 5],
         [1, 1, 3, 8],
         [9, 0, 4, 4],
         [1, 4, 7, 0],
         [5, 1, 4, 8]],

        [[0, 1, 4, 4],
         [4, 1, 8, 7],
         [5, 2, 6, 6],
         [2, 6, 1, 6],
         [0, 7, 8, 9]],

        [[0, 6, 8, 8],
         [5, 4, 5, 8],
         [3, 5, 5, 9],
         [3, 5, 2, 4],
         [3, 8, 1, 1]]])
tensor([[[4, 6, 8, 1],
         [0, 1, 8, 2],
         [4, 9, 9, 8],
         [5, 1, 5, 9],
         [9, 4, 3, 0]],

        [[7, 6, 3, 3],
         [4, 3, 3, 2],
         [2, 1, 1, 1],
         [3, 0, 8, 2],
         [8, 6, 6, 5]],

        [[0, 7, 2, 4],
         [4, 3, 8, 3],
         [4, 2, 1, 9],
         [4, 2, 8, 9],
         [3, 7, 0, 8]]])
--------------------------------------------------
torch.Size([6, 5, 4])
--------------------------------------------------
torch.Size([3, 10, 4])
tensor([[[6, 8, 3, 5, 4, 6, 8, 1],
         [1, 1, 3, 8, 0, 1, 8, 2],
         [9, 0, 4, 4, 4, 9, 9, 8],
         [1, 4, 7, 0, 5, 1, 5, 9],
         [5, 1, 4, 8, 9, 4, 3, 0]],

        [[0, 1, 4, 4, 7, 6, 3, 3],
         [4, 1, 8, 7, 4, 3, 3, 2],
         [5, 2, 6, 6, 2, 1, 1, 1],
         [2, 6, 1, 6, 3, 0, 8, 2],
         [0, 7, 8, 9, 8, 6, 6, 5]],

        [[0, 6, 8, 8, 0, 7, 2, 4],
         [5, 4, 5, 8, 4, 3, 8, 3],
         [3, 5, 5, 9, 4, 2, 1, 9],
         [3, 5, 2, 4, 4, 2, 8, 9],
         [3, 8, 1, 1, 3, 7, 0, 8]]])

2.2 torch.stack 函数的使用

torch.stack 函数可以将两个张量根据指定的维度叠加起来.

代码语言:javascript
复制
import torch


def test():

    data1= torch.randint(0, 10, [2, 3])
    data2= torch.randint(0, 10, [2, 3])
    print(data1)
    print(data2)

    new_data = torch.stack([data1, data2], dim=0)
    print(new_data.shape)

    new_data = torch.stack([data1, data2], dim=1)
    print(new_data.shape)

    new_data = torch.stack([data1, data2], dim=2)
    print(new_data)


if __name__ == '__main__':
    test()

程序输出结果:

代码语言:javascript
复制
tensor([[5, 8, 7],
        [6, 0, 6]])
tensor([[5, 8, 0],
        [9, 0, 1]])
torch.Size([2, 2, 3])
torch.Size([2, 2, 3])
tensor([[[5, 5],
         [8, 8],
         [7, 0]],

        [[6, 9],
         [0, 0],
         [6, 1]]])

2.3 小节

张量的拼接操作也是在后面我们经常使用一种操作。cat 函数可以将张量按照指定的维度拼接起来,stack 函数可以将张量按照指定的维度叠加起来。

3. 张量索引操作

我们在操作张量时,经常需要去进行获取或者修改操作,掌握张量的花式索引操作是必须的一项能力。

3.1 简单行、列索引

准备数据

代码语言:javascript
复制
import torch

data = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)

程序输出结果:

代码语言:javascript
复制
tensor([[0, 7, 6, 5, 9],
        [6, 8, 3, 1, 0],
        [6, 3, 8, 7, 3],
        [4, 9, 5, 3, 1]])
--------------------------------------------------
代码语言:javascript
复制
# 1. 简单行、列索引
def test01():

    print(data[0])
    print(data[:, 0])
    print('-' * 50)

if __name__ == '__main__':
    test01()

程序输出结果:

代码语言:javascript
复制
tensor([0, 7, 6, 5, 9])
tensor([0, 6, 6, 4])
--------------------------------------------------

3.2 列表索引

代码语言:javascript
复制
# 2. 列表索引
def test02():

    # 返回 (0, 1)、(1, 2) 两个位置的元素
    print(data[[0, 1], [1, 2]])
    print('-' * 50)

    # 返回 0、1 行的 1、2 列共4个元素
    print(data[[[0], [1]], [1, 2]])
if __name__ == '__main__':
    test02()

程序输出结果:

代码语言:javascript
复制
tensor([7, 3])
--------------------------------------------------
tensor([[7, 6],
        [8, 3]])

3.3 范围索引

代码语言:javascript
复制
# 3. 范围索引
def test03():
    # 前3行的前2列数据
    print(data[:3, :2])
    # 第2行到最后的前2列数据
    print(data[2:, :2])
if __name__ == '__main__':
    test03()

程序输出结果:

代码语言:javascript
复制
tensor([[0, 7],
        [6, 8],
        [6, 3]])
tensor([[6, 3],
        [4, 9]])

3.4 布尔索引

代码语言:javascript
复制
# 布尔索引
def test():

    # 第三列大于5的行数据
    print(data[data[:, 2] > 5])
    # 第二行大于5的列数据
    print(data[:, data[1] > 5])
if __name__ == '__main__':
    test04()

程序输出结果:

代码语言:javascript
复制
tensor([[0, 7, 6, 5, 9],
        [6, 3, 8, 7, 3]])
tensor([[0, 7],
        [6, 8],
        [6, 3],
        [4, 9]])

3.5 多维索引

代码语言:javascript
复制
# 多维索引
def test05():

    data = torch.randint(0, 10, [3, 4, 5])
    print(data)
    print('-' * 50)

    print(data[0, :, :])
    print(data[:, 0, :])
    print(data[:, :, 0])


if __name__ == '__main__':
    test05()

程序输出结果:

代码语言:javascript
复制
tensor([[[2, 4, 1, 2, 3],
         [5, 5, 1, 5, 0],
         [1, 4, 5, 3, 8],
         [7, 1, 1, 9, 9]],

        [[9, 7, 5, 3, 1],
         [8, 8, 6, 0, 1],
         [6, 9, 0, 2, 1],
         [9, 7, 0, 4, 0]],

        [[0, 7, 3, 5, 6],
         [2, 4, 6, 4, 3],
         [2, 0, 3, 7, 9],
         [9, 6, 4, 4, 4]]])
--------------------------------------------------
tensor([[2, 4, 1, 2, 3],
        [5, 5, 1, 5, 0],
        [1, 4, 5, 3, 8],
        [7, 1, 1, 9, 9]])
tensor([[2, 4, 1, 2, 3],
        [9, 7, 5, 3, 1],
        [0, 7, 3, 5, 6]])
tensor([[2, 5, 1, 7],
        [9, 8, 6, 9],
        [0, 2, 2, 9]])

4. 张量形状操作

在我们后面搭建网络模型时,数据都是基于张量形式的表示,网络层与层之间很多都是以不同的 shape 的方式进行表现和运算,我们需要掌握对张量形状的操作,以便能够更好处理网络各层之间的数据连接。

4.1 reshape 函数的用法

reshape 函数可以在保证张量数据不变的前提下改变数据的维度,将其转换成指定的形状,在后面的神经网络学习时,会经常使用该函数来调节数据的形状,以适配不同网络层之间的数据传递。

代码语言:javascript
复制
import torch
import numpy as np


def test():

    data = torch.tensor([[10, 20, 30], [40, 50, 60]])

    # 1. 使用 shape 属性或者 size 方法都可以获得张量的形状
    print(data.shape, data.shape[0], data.shape[1])
    print(data.size(), data.size(0), data.size(1))

    # 2. 使用 reshape 函数修改张量形状
    new_data = data.reshape(1, 6)
    print(new_data.shape)


if __name__ == '__main__':
    test()

程序运行结果:

代码语言:javascript
复制
torch.Size([2, 3]) 2 3
torch.Size([2, 3]) 2 3
torch.Size([1, 6])

4.2 transpose 和 permute 函数的使用

transpose 函数可以实现交换张量形状的指定维度, 例如: 一个张量的形状为 (2, 3, 4) 可以通过 transpose 函数把 3 和 4 进行交换, 将张量的形状变为 (2, 4, 3)

permute 函数可以一次交换更多的维度。

代码语言:javascript
复制
import torch
import numpy as np


def test():

    data = torch.tensor(np.random.randint(0, 10, [3, 4, 5]))
    print('data shape:', data.size())

    # 1. 交换1和2维度
    new_data = torch.transpose(data, 1, 2)
    print('data shape:', new_data.size())

    # 2. 将 data 的形状修改为 (4, 5, 3)
    new_data = torch.transpose(data, 0, 1)
    new_data = torch.transpose(new_data, 1, 2)
    print('new_data shape:', new_data.size())

    # 3. 使用 permute 函数将形状修改为 (4, 5, 3)
    new_data = torch.permute(data, [1, 2, 0])
    print('new_data shape:', new_data.size())


if __name__ == '__main__':
    test()

程序运行结果:

代码语言:javascript
复制
data shape: torch.Size([3, 4, 5])
data shape: torch.Size([3, 5, 4])
new_data shape: torch.Size([4, 5, 3])
new_data shape: torch.Size([4, 5, 3])

4.3 view 和 contigous 函数的用法

view 函数也可以用于修改张量的形状,但是其用法比较局限,只能用于存储在整块内存中的张量。在 PyTorch 中,有些张量是由不同的数据块组成的,它们并没有存储在整块的内存中,view 函数无法对这样的张量进行变形处理,例如: 一个张量经过了 transpose 或者 permute 函数的处理之后,就无法使用 view 函数进行形状操作。

代码语言:javascript
复制
import torch
import numpy as np


def test():

    data = torch.tensor([[10, 20, 30], [40, 50, 60]])
    print('data shape:', data.size())

    # 1. 使用 view 函数修改形状
    new_data = data.view(3, 2)
    print('new_data shape:', new_data.shape)

    # 2. 判断张量是否使用整块内存
    print('data:', data.is_contiguous())  # True

    # 3. 使用 transpose 函数修改形状
    new_data = torch.transpose(data, 0, 1)
    print('new_data:', new_data.is_contiguous())  # False
    # new_data = new_data.view(2, 3)  # RuntimeError

    # 需要先使用 contiguous 函数转换为整块内存的张量,再使用 view 函数
    print(new_data.contiguous().is_contiguous())
    new_data = new_data.contiguous().view(2, 3)
    print('new_data shape:', new_data.shape)


if __name__ == '__main__':
    test()

程序运行结果:

代码语言:javascript
复制
data shape: torch.Size([2, 3])
new_data shape: torch.Size([3, 2])
data: True
new_data: False
True
new_data shape: torch.Size([2, 3])

4.4 squeeze 和 unsqueeze 函数的用法

squeeze 函数用删除 shape 为 1 的维度,unsqueeze 在每个维度添加 1, 以增加数据的形状。

代码语言:javascript
复制
import torch
import numpy as np


def test():

    data = torch.tensor(np.random.randint(0, 10, [1, 3, 1, 5]))
    print('data shape:', data.size())

    # 1. 去掉值为1的维度
    new_data = data.squeeze()
    print('new_data shape:', new_data.size())  # torch.Size([3, 5])

    # 2. 去掉指定位置为1的维度,注意: 如果指定位置不是1则不删除
    new_data = data.squeeze(2)
    print('new_data shape:', new_data.size())  # torch.Size([3, 5])

    # 3. 在2维度增加一个维度
    new_data = data.unsqueeze(-1)
    print('new_data shape:', new_data.size())  # torch.Size([3, 1, 5, 1])


if __name__ == '__main__':
    test()

程序运行结果:

代码语言:javascript
复制
data shape: torch.Size([1, 3, 1, 5])
new_data shape: torch.Size([3, 5])
new_data shape: torch.Size([1, 3, 5])
new_data shape: torch.Size([1, 3, 1, 5, 1])

4.5 小节

本小节带着同学们学习了经常使用的关于张量形状的操作,我们用到的主要函数有:

  1. reshape 函数可以在保证张量数据不变的前提下改变数据的维度.
  2. transpose 函数可以实现交换张量形状的指定维度, permute 可以一次交换更多的维度.
  3. view 函数也可以用于修改张量的形状, 但是它要求被转换的张量内存必须连续,所以一般配合 contiguous 函数使用.
  4. squeeze 和 unsqueeze 函数可以用来增加或者减少维度.
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-09-03,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 张量类型转换
    • 1.1 张量转换为 numpy 数组
      • 1.2 numpy 转换为张量
        • 1.3 标量张量和数字的转换
          • 1.4 小节
          • 2. 张量拼接操作
            • 2.1 torch.cat 函数的使用
              • 2.2 torch.stack 函数的使用
                • 2.3 小节
                • 3. 张量索引操作
                  • 3.1 简单行、列索引
                    • 3.2 列表索引
                      • 3.3 范围索引
                        • 3.4 布尔索引
                          • 3.5 多维索引
                          • 4. 张量形状操作
                            • 4.1 reshape 函数的用法
                              • 4.2 transpose 和 permute 函数的使用
                                • 4.3 view 和 contigous 函数的用法
                                  • 4.4 squeeze 和 unsqueeze 函数的用法
                                    • 4.5 小节
                                    领券
                                    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档