前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch维度变换-补充知识

pytorch维度变换-补充知识

作者头像
用户6719124
发布2019-11-17 23:08:28
2K0
发布2019-11-17 23:08:28
举报
文章被收录于专栏:python pytorch AI机器学习实践

补充介绍一下转置操作

先建立矩阵a,分别输出a和a的转置矩阵

代码语言:javascript
复制
a = torch.randn(3, 4)
print(a)
print(a.t())

代码语言:javascript
复制
tensor([[-0.4018, -1.4217,  0.5778, -1.0832],
        [ 0.9451,  0.2730,  0.2420,  1.3747],
        [-1.3293,  1.5332, -1.1212,  0.8263]])
tensor([[-0.4018,  0.9451, -1.3293],
        [-1.4217,  0.2730,  1.5332],
        [ 0.5778,  0.2420, -1.1212],
        [-1.0832,  1.3747,  0.8263]])

需要注意的是转置功能只适用于2D的矩阵,而不适用于3D或4D的矩阵。

代码语言:javascript
复制
a = torch.randn([3, 4, 3, 3])
print(a.t())

此时输出会报错

代码语言:javascript
复制
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D

除了.t()方法外,还可以使用.transpose(d1, d2)函数。在使用时需要给d1

d2赋值,以给出调换的位置。

代码语言:javascript
复制
a = torch.randn([4, 3, 28, 28])
b = a.transpose(1, 3).view(4, 3*28*28).view(4, 3, 28, 28)
# 将原来的[b, c, h, w]=>[b, w, h, c]后,再将后面三个维度连在一起来理解,再展开成[b, c, w, h]
print(b)

此时输出会报错

代码语言:javascript
复制
RuntimeError: invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view(). at ..\aten\src\TH/generic/THTensor.cpp:203

报错原因在于view函数会破坏原来的元素顺序,展开时channel元素跑到了前面。因此在使用transpose和view函数时,要格外注意数据的维度顺序和存储顺序需保持一致。

这里可以使用.contiguous函数,将数据重新变成连续。

代码语言:javascript
复制
b = a.transpose(1, 3).contiguous().view(4, 3*28*28).view(4, 3, 28, 28)
print('b=', b.shape)

输出

代码语言:javascript
复制
b= torch.Size([4, 3, 28, 28])

但这里有一个问题,通过以上转换,矩阵经历了[b,c,h,w]=>[b,w,h,c]=>[b,c,w,h],这样虽然数据连续了,但这种转换方式会造成数据污染。

这里再介绍法2

代码语言:javascript
复制
c = a.transpose(1, 3).contiguous().view(4, 3*28*28).view(4, 28, 28, 3).transpose(1, 3)
# 以上经历了[b,c,h,w]=>[b,w,h,c]=>[b,w,h,c]=>[b,c,h,w]
print('c=', c.shape)

输出

代码语言:javascript
复制
c= torch.Size([4, 3, 28, 28])

以上两种方法虽然输出均为一致,但为验证有没有数据污染,使用torch.eq函数进行分析

代码语言:javascript
复制
print(torch.all(torch.eq(a, b)))
# 添加torch.all,确保所有数据均一致
print(torch.all(torch.eq(a, c)))

输出

代码语言:javascript
复制
tensor(0, dtype=torch.uint8)
tensor(1, dtype=torch.uint8)

返回0说明数据不一致,返回1说明数据一致。由此看出b虽然各数据维度与a相同,但已造成了数据污染,而c没有数据污染。

下面介绍一种更加方便的转置API: permute

与transpose每次只能两两交换不同的是,permute可以一次性给出四个维度上的位置。

.permute(d1, d2, d3, d4) 通过输入d1、d2、d3、d4的顺序即可完成。

如原[b,c,h,w]想要变成[b,h,w,c],只要输入.permute(0, 2, 3, 1)即可实现,而.transpose需要重复两两调换好几次。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-10-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档