前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >Pytorch 拷贝数据

Pytorch 拷贝数据

作者头像
为为为什么
发布2022-08-06 10:14:06
发布2022-08-06 10:14:06
1K00
代码可运行
举报
文章被收录于专栏:又见苍岚又见苍岚
运行总次数:0
代码可运行

本文记录Pytorch拷贝数据的几种方法,clone(), detach(), new_tensor(), copy_()。

1、clone()

clone()函数返回一个和源张量同shape、dtype和device的张量,与源张量不共享数据内存,但提供梯度的回溯。

代码语言:javascript
代码运行次数:0
复制
import torch

a = torch.tensor(1.0, requires_grad=True)
y = a ** 2
a_ = a.clone()
z = a_ * 3
y.backward()
print(a.grad)  # 2
z.backward()
print(a_.grad)  # None, 无grad
print(a.grad)  # 5. a_的梯度会传递回给a,因此2+3=5


a = a + 1
print(a_) # 1

梯度回溯:进行的运算梯度会加在a(叶子节点)的梯度上。

代码语言:javascript
代码运行次数:0
复制
import torch

a = torch.tensor(1.0, requires_grad=True)
a_ = a.clone()
a_.add_(torch.tensor(1.0))
print(a)    # tensor(1., requires_grad=True)
print(a_)   # tensor(2., grad_fn=<AddBackward0>)
a_.backward()
print(a_.grad)  # None
print(a.grad)   # tensor(1.)
可以发现不共享内存空间。

2、detach()

detach()函数返回一个和源张量同shape、dtype和device的张量,并且与源张量共享数据内存,但不提供梯度的回溯。

代码语言:javascript
代码运行次数:0
复制
import torch

a = torch.tensor(1.0, requires_grad=True)
y = a ** 2
a_ = a.detach()
print(a_.grad)  # None,requires_grad=False
a_.requires_grad_()  # set a_.requires_grad = True
z = a_ * 3
y.backward()
z.backward()
print(a_.grad)  # tensor(3.)
print(a.grad)  # tensor(2.)

可见,a_即使重新定义requires_grad=True,也与a的梯度没有关系。

代码语言:javascript
代码运行次数:0
复制
import torch

a = torch.tensor(1.0, requires_grad=True)
a_ = a.detach()
a_.add_(torch.tensor(1.0))
print(a)    # tensor(2., requires_grad=True)
print(a_)   # tensor(2.)
# a_.backward()
# print(a.grad) # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

容易验证a与a_共享内存空间

3、torch.clone.detach() (建议赋值方法)

clone()提供了非数据内存共享的梯度追溯功能,而detach又“舍弃”了梯度回溯功能,因此clone.detach()只做简单的数据复制,既不数据共享,也不梯度共享,从此两个张量无关联。 置于是先clone还是先detach,其返回值一样,一般采用sourcetensor.clone().detach()。

4、new_tensor()

new_tensor()可以将源张量中的数据复制到目标张量(数据不共享),同时提供了更细致的属性控制:

代码语言:javascript
代码运行次数:0
复制
import torch

a = torch.tensor(1.0, device="cuda", dtype=torch.float32, requires_grad=True)
b = a.clone()
c = a.detach()
d = a.clone().detach()
e = a.new_tensor(a)  # more attributions could be setted
f = a.new_tensor(a, device="cpu", dtype=torch.float64, requires_grad=False)
print(a)  # tensor(1., device='cuda:0', requires_grad=True)
print(b)  # tensor(1., device='cuda:0', grad_fn=<CloneBackward>)
print(c)  # tensor(1., device='cuda:0')
print(d)  # tensor(1., device='cuda:0')
print(e)  # tensor(1., device='cuda:0')
print(f)  # tensor(1., dtype=torch.float64)

5、copy_()

代码语言:javascript
代码运行次数:0
复制
import torch

a = torch.tensor(1.0, device="cpu", requires_grad=False)
b = torch.tensor(2.0, device="cuda", requires_grad=True)
print(a)  # tensor(1.)
print(b)  # tensor(2., device='cuda:0', requires_grad=True)
a.copy_(b)
print(a)  # tensor(2., grad_fn=<CopyBackwards>)
print(a.device)  # cpu
print(a.requires_grad)  # True

copy_()会将b复制给a,同时改变 a 的 requires_grad 属性,但不改变 device 属性。(当a.requires_grad=False是copy_()方法会报错)

参考资料

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2021年8月18日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1、clone()
  • 2、detach()
  • 3、torch.clone.detach() (建议赋值方法)
  • 4、new_tensor()
  • 5、copy_()
  • 参考资料
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档