Loading [MathJax]/jax/output/CommonHTML/config.js
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >BatchNorm2d原理、作用及其pytorch中BatchNorm2d函数的参数讲解

BatchNorm2d原理、作用及其pytorch中BatchNorm2d函数的参数讲解

作者头像
全栈程序员站长
发布于 2022-06-28 06:26:12
发布于 2022-06-28 06:26:12
3.3K00
代码可运行
举报
运行总次数:0
代码可运行

大家好,又见面了,我是你们的朋友全栈君。

BN原理、作用:

函数参数讲解:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

1.num_features:一般输入参数为batch_sizenum_featuresheight*width,即为其中特征的数量,即为输入BN层的通道数; 2.eps:分母中添加的一个值,目的是为了计算的稳定性,默认为:1e-5,避免分母为0; 3.momentum:一个用于运行过程中均值和方差的一个估计参数(我的理解是一个稳定系数,类似于SGD中的momentum的系数); 4.affine:当设为true时,会给定可以学习的系数矩阵gamma和beta 一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。 同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainning,affine,track_running_stats。 其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False则γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True。 trainning和track_running_stats,track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。 如果BatchNorm2d的参数track_running_stats设置False,那么加载预训练后每次模型测试测试集的结果时都不一样;track_running_stats设置为True时,每次得到的结果都一样。 running_mean和running_var参数是根据输入的batch的统计特性计算的,严格来说不算是“学习”到的参数,不过对于整个计算是很重要的。BN层中的running_mean和running_var的更新是在forward操作中进行的,而不是在optimizer.step()中进行的,因此如果处于训练中泰,就算不进行手动step(),BN的统计特性也会变化。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
model.train() #处于训练状态
for data , label in self.dataloader:
    pred =model(data)  #在这里会更新model中的BN统计特性参数,running_mean,running_var
    loss=self.loss(pred,label)
    #就算不进行下列三行,BN的统计特性参数也会变化
    opt.zero_grad()
    loss.backward()
    opt.step()

这个时候,要用model.eval()转到测试阶段,才能固定住running_mean和running_var,有时候如果是先预训练模型然后加载模型,重新跑测试数据的时候,结果不同,有一点性能上的损失,这个时候基本上是training和track_running_stats设置的不对。 如果使用两个模型进行联合训练,为了收敛更容易控制,先预训练好模型model_A,并且model_A内还有若干BN层,后续需要将model_A作为一个inference推理模型和model_B联合训练,此时希望model_A中的BN的统计特性量running_mean和running_var不会乱变化,因此就需要将model_A.eval()设置到测试模型,否则在trainning模式下,就算是不去更新模型的参数,其BN都会变化,这将导致和预期不同的结果。

发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/132951.html原文链接:https://javaforall.cn

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022年6月1,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
batchnorm pytorch_Pytorch中的BatchNorm
本文主要介绍在pytorch中的Batch Normalization的使用以及在其中容易出现的各种小问题,本来此文应该归属于[1]中的,但是考虑到此文的篇幅可能会比较大,因此独立成篇,希望能够帮助到各位读者。如有谬误,请联系指出,如需转载,请注明出处,谢谢。
全栈程序员站长
2022/09/05
7690
pytorch BatchNorm参数详解,计算过程
网络训练时和网络评估时,BatchNorm模块的计算方式不同。如果一个网络里包含了BatchNorm,则在训练时需要先调用train(),使网络里的BatchNorm模块的training=True(默认是True),在网络评估时,需要先调用eval(),使网络里的BatchNorm模块的training=False。
全栈程序员站长
2022/09/01
1.5K0
PyTorch 源码解读之 BN & SyncBN:BN 与 多卡同步 BN 详解
BatchNorm 最早在全连接网络中被提出,对每个神经元的输入做归一化。扩展到 CNN 中,就是对每个卷积核的输入做归一化,或者说在 channel 之外的所有维度做归一化。 BN 带来的好处有很多,这里简单列举几个:
OpenMMLab 官方账号
2022/02/21
2K0
【pytorch】bn
bn接口定义: torch.nn.BatchNorm2d: def init(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True) args: momentum: 默认为 0.1 。 要freeze的时候就设置为0.0(和 tf 里面是反着来的,tf是设置为1.0才能freeze)。 rack_running_stats: 计算running_mean和running
JNingWei
2021/12/06
5680
【PyTorch】详解pytorch中nn模块的BatchNorm2d()函数
在卷积神经网络的卷积层之后总会添加BatchNorm2d进行数据的归一化处理,这使得数据在进行Relu之前不会因为数据过大而导致网络性能的不稳定,BatchNorm2d()函数数学原理如下:
全栈程序员站长
2022/07/02
1.7K0
【PyTorch】详解pytorch中nn模块的BatchNorm2d()函数
PyTorch模型微调实例
from torch.utils.data import DataLoader, Dataset
代码医生工作室
2020/02/19
1.8K0
Pytorch-BN层详细解读
机器学习领域有个很重要的假设:独立同分布假设,即假设训练数据和测试数据是满足相同分布的。我们知道:神经网络的训练实际上就是在拟合训练数据的分布。如果不满足独立同分布假设,那么训练得到的模型的泛化能力肯定不好。
全栈程序员站长
2022/11/04
9540
Pytorch搭建DenseNet[通俗易懂]
首先回顾一下DenseNet的结构,DenseNet的每一层都都与前面层相连,实现了特征重用。
全栈程序员站长
2022/11/10
5520
Pytorch搭建DenseNet[通俗易懂]
[深度应用]·实战掌握PyTorch图片分类简明教程
深度学习的比赛中,图片分类是很常见的比赛,同时也是很难取得特别高名次的比赛,因为图片分类已经被大家研究的很透彻,一些开源的网络很容易取得高分。如果大家还掌握不了使用开源的网络进行训练,再慢慢去模型调优,很难取得较好的成绩。
小宋是呢
2019/06/27
5400
[深度应用]·实战掌握PyTorch图片分类简明教程
BN、LN、IN、GN、SN归一化
内容包含:BatchNormalization、LayerNormalization、InstanceNorm、GroupNorm、SwitchableNorm
zenRRan
2019/08/19
2.2K0
BN、LN、IN、GN、SN归一化
实战:掌握PyTorch图片分类的简明教程 | 附完整代码
深度学习的比赛中,图片分类是很常见的比赛,同时也是很难取得特别高名次的比赛,因为图片分类已经被大家研究的很透彻,一些开源的网络很容易取得高分。如果大家还掌握不了使用开源的网络进行训练,再慢慢去模型调优,很难取得较好的成绩。
AI科技大本营
2019/06/18
1K0
【猫狗数据集】使用预训练的resnet18模型
链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw 提取码:2xq4
西西嘛呦
2020/08/26
3.1K0
【猫狗数据集】使用预训练的resnet18模型
深度学习算法优化系列二 | 基于Pytorch的模型剪枝代码实战
昨天讲了一篇ICLR 2017《Pruning Filters for Efficient ConvNets》 ,相信大家对模型剪枝有一定的了解了。今天我就剪一个简单的网络,体会一下模型剪枝的魅力。本文的代码均放在我的github工程,我是克隆了一个原始的pytorch模型压缩工程,然后我最近会公开一些在这个基础上新增的自测结果,一些经典的网络压缩benchmark,一些有趣的实验。欢迎关注,github地址见文后。最后申明一下,本人处于初学阶段,肯定了解的知识很浅并且会犯很多错误,有错误之处欢迎大家指出并和我交流讨论。
BBuf
2019/12/24
3.7K0
【DL】规范化:你确定了解我吗?
Batch Normalization(以下简称 BN)出自 2015 年的一篇论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》,是最近几年来 DL 领域中非常重要的成功,并且已经被广泛证明其有效性和重要性。本篇文章将对此进行详细介绍,帮助大家更加深入的理解 BN。
阿泽 Crz
2020/07/21
1.1K0
【DL】规范化:你确定了解我吗?
Pytorch转Msnhnet模型思路分享
注意上面出现了一行if "num_batches_tracked" not in name:,这一行是Pytorch的一个坑点,在pytorch 0.4.1及后面的版本里,BatchNorm层新增了num_batches_tracked参数,用来统计训练时的forward过的batch数目,源码如下(pytorch0.4.1):
BBuf
2020/09/27
6430
实践教程|Grad-CAM的详细介绍和Pytorch代码实现
Grad-CAM (Gradient-weighted Class Activation Mapping) 是一种可视化深度神经网络中哪些部分对于预测结果贡献最大的技术。它能够定位到特定的图像区域,从而使得神经网络的决策过程更加可解释和可视化。
一点人工一点智能
2023/08/25
1.7K0
实践教程|Grad-CAM的详细介绍和Pytorch代码实现
【猫狗数据集】定义模型并进行训练模型
发现数据集没有完整的上传到谷歌的colab上去,我说怎么计算出来的step不对劲。
西西嘛呦
2020/08/26
7580
【猫狗数据集】定义模型并进行训练模型
逃不过呀!不论是训练还是部署都会让你踩坑的Batch Normalization
BN是2015年论文Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift提出的一种**数据归一化方法**。现在也是大多数神经网络结构的**标配**,我们可能已经**熟悉的不能再熟悉了**。
老潘
2021/08/07
3.4K0
PyTorch学习之归一化层(BatchNorm、LayerNorm、InstanceNorm、GroupNorm)[通俗易懂]
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) torch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
全栈程序员站长
2022/06/29
2.8K0
PyTorch学习之归一化层(BatchNorm、LayerNorm、InstanceNorm、GroupNorm)[通俗易懂]
基于Pytorch构建ResNet网络对cifar-10进行分类
何凯明等人在2015年提出的ResNet,在ImageNet比赛classification任务上获得第一名,获评CVPR2016最佳论文。
python与大数据分析
2023/09/03
7380
基于Pytorch构建ResNet网络对cifar-10进行分类
推荐阅读
相关推荐
batchnorm pytorch_Pytorch中的BatchNorm
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验