首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >如果您在网络中间分离一个nn.module,那么在该模块之前是否所有模块都没有得到它们的梯度计算?

如果您在网络中间分离一个nn.module,那么在该模块之前是否所有模块都没有得到它们的梯度计算?
EN

Stack Overflow用户
提问于 2022-06-17 22:43:44
回答 1查看 210关注 0票数 0

假设我有一个输入和一个由A,net B和net C组成的序贯网络,如果我分离了B网,然后把X放到A->B->C中,因为B被分离了,我会失去A的梯度信息吗?我会假设没有?我假设它只是把B当作一个常数加到A的输出中,而不是一些可微性的东西。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-06-18 09:18:33

TLDR;B上防止梯度计算不会停止对上游网络A的梯度计算。

我认为你认为“脱模”是有一些混淆的。在我看来,这类事情有三件事要记住:

  • 你可以用一个张量有效地把它从计算图中分离出来,也就是说,如果这个张量被用来计算另一个需要梯度的张量,反向传播步骤就不会通过这个“分离的”张量。

  • 以您描述“分离模型”的方式,您可以通过将requires_grad切换到其参数上的False来禁用在给定的网络层上的梯度计算。这可以用nn.Module.requires_grad_在模块级别的一行中完成。因此,在您的示例中,执行B.requires_grad_(False)将冻结B的参数,使其无法更新。换句话说,不会计算B参数的梯度,但是用于传播到A will的中间梯度!这里有一个极小的example:A = nn.Linear(10,10) >>> B= nn.Linear(10,10) >>> C= nn.Linear(10,10) #nn.Linear(10,10)#禁用梯度计算的B >>> B.requires_grad_(False) #虚拟输入、推理和反向传播>>> x= torch.rand(1,10,requires_grad=True) >>> C(B(A(X).mean().backward()

现在我们可以检查CA的梯度是否确实被正确地填充了:

A.weight.grad.sum()张量(0.3281) >>> C.weight.grad.sum()张量(-1.6335)

但是,当然,None.返回B.weight.grad

最后,

  • 的另一种行为是在使用no_grad上下文管理器时。这有效地杀死了梯度。如果使用like:yA = A(x) >>>和torch.no_grad():.yB = B(yA) >>> yC = C(yB)

在这里,network.已经脱离了yC

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72665429

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档