前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch之autograd

pytorch之autograd

作者头像
Tom2Code
发布2024-05-27 17:31:16
940
发布2024-05-27 17:31:16
举报
文章被收录于专栏:TomTom

在医院闲来无事,记录一个小参数,叫做retain_graph

先来学习两段代码,来比较其异同

代码语言:javascript
复制
import torch
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)  

y.backward(retain_graph=True)
print(w.grad)
y.backward()
print(w.grad)

输出:

代码语言:javascript
复制
tensor([5.])
tensor([10.])

第二段代码:

代码语言:javascript
复制
import torch
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

y.backward()
print(w.grad)
y.backward()
print(w.grad)

但是就会报错:

代码语言:javascript
复制
tensor([5.])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 12
     10 y.backward()
     11 print(w.grad)
---> 12 y.backward()
     13 print(w.grad)

File ~\.conda\envs\torchgpu\lib\site-packages\torch\_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    482 if has_torch_function_unary(self):
    483     return handle_torch_function(
    484         Tensor.backward,
    485         (self,),
   (...)
    490         inputs=inputs,
    491     )
--> 492 torch.autograd.backward(
    493     self, gradient, retain_graph, create_graph, inputs=inputs
    494 )

File ~\.conda\envs\torchgpu\lib\site-packages\torch\autograd\__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    246     retain_graph = create_graph
    248 # The reason we repeat the same comment below is that
    249 # some Python versions print out the first line of a multi-line function
    250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252     tensors,
    253     grad_tensors_,
    254     retain_graph,
    255     create_graph,
    256     inputs,
    257     allow_unreachable=True,
    258     accumulate_grad=True,
    259 )

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

这是为什么呢,这里就要介绍一下本次要学习的参数了:

首先看一个函数的原型:

代码语言:javascript
复制
torch.autograd.backward(
        tensors, 
        grad_tensors=None, 
        retain_graph=None, 
        create_graph=False, 
        grad_variables=None, 
        inputs=None)

这次我们来介绍 retain_graph.

我们都知道pytorch是经典的动态图,所以这个参数retain_graph是一个布尔类型的值,它的true or false直接说明了在计算过程中是否保留图

代码语言:javascript
复制
retain_graph (bool, optional) – 
是否需要保留计算图。pytorch的机制是在方向传播结束时,
计算图释放以节省内存。大家可以尝试连续使用loss.backward(),
就会报错。如果需要多次求导,则在执行backward()时,retain_graph=True。

上面我们第二段代码,恰恰是计算了两次w的倒数,所以就会出现报错,所以,如果我们要计算多次导数,就要设置这个参数为true。

因为会累加梯度,所以我们在训练模型的时候经常需要设计zero_grad()这也是为了防止梯度爆炸

下面是一个手动结算的示意图,很简单,大佬勿喷。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2024-05-24,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Tom的小院 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档