首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >Pytorch预训练模型以及修改

Pytorch预训练模型以及修改

作者头像
狼啸风云
修改于 2022-09-02 14:30:58
修改于 2022-09-02 14:30:58
20.9K0
举报

pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnetdensenetinceptionresnetsqueezenetvgg等常用网络结构,并且提供了预训练模型,可通过调用来读取网络结构和预训练模型(模型参数)。往往为了加快学习进度,训练的初期直接加载pretrain模型中预先训练好的参数。加载model如下所示:

代码语言:python
AI代码解释
复制
import torchvision.models as models
  1. 加载网络结构和预训练参数:resnet34 = models.resnet34(pretrained=True)
  2. 只加载网络结构,不加载预训练参数,即不需要用预训练模型的参数来初始化:
代码语言:python
AI代码解释
复制
resnet18 = models.resnet18(pretrained=False) #pretrained参数默认是False,为了代码清晰,最好还是加上参数赋值.
print resnet18 #打印网络结构
resnet18.load_state_dict(torch.load(path_params.pkl)) #其中,path_params.pkl为预训练模型参数的保存路径。

加载预先下载好的预训练参数到resnet18,用预训练模型的参数初始化resnet18的层,此时resnet18发生了改变。

调用modelload_state_dict方法用预训练的模型参数来初始化自己定义的新网络结构,这个方法就是PyTorch中通用的用一个模型的参数初始化另一个模型的层的操作。

load_state_dict方法还有一个重要的参数是strict,该参数默认是True,表示预训练模型的层和自己定义的网络结构层严格对应相等(比如层名和维度)。

当新定义的网络(model_dict)和预训练网络(pretrained_dict)的层名不严格相等时,需要先将pretrained_dict里不属于model_dict的键剔除掉 :

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict},再用预训练模型参数更新model_dict,最后用load_state_dict方法初始化自己定义的新网络结构。

代码语言:python
AI代码解释
复制
print resnet18 #打印的还是网络结构

# 注意: cnn = resnet18.load_state_dict(torch.load( path_params.pkl )) #是错误的,这样cnn将是nonetype

pre_dict = resnet18.state_dict() #按键值对将模型参数加载到pre_dict

print for k, v in pre_dict.items(): # 打印模型参数

for k, v in pre_dict.items():
  print k  #打印模型每层命名

# model是自己定义好的新网络模型,将pretrained_dict和model_dict中命名一致的层加入

# pretrained_dict(包括参数)。

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 

 预训练模型的修改(具体要求不同,则用到的修改方式不同。)

1、参数修改  对于简单的参数修改,这里以resnet预训练模型举例,resnet源代码在Github。 resnet网络最后一层分类层fc是对1000种类型进行划分,对于自己的数据集,如果只有9类,修改的代码如下:

代码语言:python
AI代码解释
复制
# coding=UTF-8 
import torchvision.models as models 
#调用模型 
model = models.resnet50(pretrained=True) 
#提取fc层中固定的参数 
fc_features = model.fc.in_features 
#修改类别为9 
model.fc = nn.Linear(fc_features, 9)

2、增减卷积层

前一种方法只适用于简单的参数修改,有时候往往要修改网络中的层次结构,这时只能用参数覆盖的方法,即自己先定义一个类似的网络,再将预训练中的参数提取到自己的网络中来。这里以resnet预训练模型举例。

3、训练特定层,冻结其它层

另一种使用预训练模型的方法是对它进行部分训练。具体做法是,将模型起始的一些层的权重保持不变,重新训练后面的层,得到新的权重。在这个过程中,可多次进行尝试,从而能够依据结果找到 frozen layers 和 retrain layers 之间的最佳搭配。 如何使用预训练模型,是由数据集大小和新旧数据集(预训练的数据集和自己要解决的数据集)之间数据的相似度来决定的。 下图表展示了在各种情况下应该如何使用预训练模型:

一、是保存整个神经网络的的结构信息和模型参数信息,save的对象是网络net:

代码语言:python
AI代码解释
复制
torch.save(model_object, 'model.pkl')  # 保存整个神经网络的结构和模型参数 

重载:

代码语言:python
AI代码解释
复制
model = torch.load('model.pkl') # 重载并初始化新的神经网络对象。 

二、是只保存神经网络的训练模型参数,save的对象是net.state_dict()

代码语言:python
AI代码解释
复制
torch.save(model_object.state_dict(), 'params.pkl')  # 只保存神经网络的模型参数

需要首先导入对应的网络,通过model_object.load_state_dict(torch.load('params.pkl'))完成模型参数的重载和初始化新定义的网络。

PyTorch中使用预训练的模型初始化网络的一部分参数:

代码语言:python
AI代码解释
复制
#首先自己新定义一个网络
class CNN(nn.Module):
  def __init__(self, block, layers, num_classes=9): 
    #自己新定义的CNN与继承的ResNet网络结构大体相同,即除了新增层,其他层的层名与ResNet的相同。
 
   self.inplanes = 64 
    super(ResNet, self).__init__() #继承ResNet网络结构
    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 
    self.bn1 = nn.BatchNorm2d(64) 
    self.relu = nn.ReLU(inplace=True) 
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 
    self.layer1 = self._make_layer(block, 64, layers[0]) 
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 
    self.avgpool = nn.AvgPool2d(7, stride=1)
   
      #新增一个反卷积层 
    self.convtranspose1 = nn.ConvTranspose2d(2048, 2048, kernel_size=3, stride=1, padding=1, output_padding=0, groups=1, bias=False, dilation=1) 
   
      #新增一个最大池化层 
    self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 
    
       #将原来的fc层改成fclass层 
    self.fclass = nn.Linear(2048, num_classes) #原来的fc层:self.fc = nn.Linear(512 * block.expansion, num_classes)
    for m in self.modules(): #
      if isinstance(m, nn.Conv2d): 
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels #
        m.weight.data.normal_(0, math.sqrt(2. / n)) 
      elif isinstance(m, nn.BatchNorm2d): 
        m.weight.data.fill_(1) 
        m.bias.data.zero_() 
  def _make_layer(self, block, planes, blocks, stride=1): 
    downsample = None 
    if stride != 1 or self.inplanes != planes * block.expansion: 
      downsample = nn.Sequential( 
      nn.Conv2d(self.inplanes, planes * block.expansion, 
          kernel_size=1, stride=stride, bias=False), 
          nn.BatchNorm2d(planes * block.expansion), 
            ) 
    layers = [ ] 
    layers.append(block(self.inplanes, planes, stride, downsample)) 
    self.inplanes = planes * block.expansion 
    for i in range(1, blocks): 
      layers.append(block(self.inplanes, planes)) 
    return nn.Sequential(*layers) 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.bn1(x) 
    x = self.relu(x) 
    x = self.maxpool(x) 
    x = self.layer1(x) 
    x = self.layer2(x) 
    x = self.layer3(x) 
    x = self.layer4(x) 
    x = self.avgpool(x) 
    #3个新加层的forward 
    x = x.view(x.size(0), -1) 
       
       #因为接下来的self.convtranspose1层的输入通道是2048
    x = self.convtranspose1(x) 
    x = self.maxpool2(x) 
    x = x.view(x.size(0), -1)  
       
       #因为接下来的self.fclass层的输入通道是2048 
    x = self.fclass(x) 
    return x
#加载model 
resnet50 = models.resnet50(pretrained=True) 
cnn = CNN(Bottleneck, [3, 4, 6, 3]) #创建一个自己新定义的网络对象cnn。

pretrained_dict = resnet50.state_dict() 用来记录预训练模型的参数:resnet50.state_dict()。若已存在 resnet50.state_dict()对应的模型参数文件 'params.pkl',则此句代码等价于:pretrained_dict =torch.load(path_params.pkl) ?其中,path_params.pkl为' params.pkl '的保存路径

代码语言:python
AI代码解释
复制
model_dict = cnn.state_dict()  #自己新定义网络的参数

pretrained_dict里不属于model_dict的键剔除掉 ,因为后面的cnn.load_state_dict()方法有个重要参数是strict,默认是True,表示预训练模型的层和自己定义的网络结构层严格对应相等(比如层名和维度)。

代码语言:python
AI代码解释
复制
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} #只能对层名一致的层进行“层名:参数”键值对赋值。
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的模型参数state_dict
cnn.load_state_dict(model_dict) # cnn.load_state_dict()方法对cnn初始化,其一个重要参数strict,默认为True,表示预训练模型(model_dict)的层和自己定义的网络结构(cnn)的层严格对应相等(比如层名和维度)。
print(cnn)

后续在此基础上继续重新进行训练,如下面即将介绍的: 选择特定的层进行finetune。

1、选择特定的层进行finetune 先使用Module.children()方法查看网络的直接子模块,将不需要调整的模块中的参数设置为param.requires_grad = False,同时用一个list收集需要调整的模块中的参数。具体代码为:

代码语言:python
AI代码解释
复制
count = 0
para_optim = []
for k in model.children():
  count += 1
  if count > 6:   # 6 should be changed properly
    for param in k.parameters():
      para_optim.append(param)
  else:
    for param in k.parameters():
      param.requires_grad = False
optimizer = optim.RMSprop(para_optim, lr)#只对特定的层的参数进行优化更新,即选择特定的层进行finetune。
到此我们实现了PyTorch中使用预训练的模型初始化网络的一部分参数。

此部分主要参考PyTorch教程的Autograd machnics部分  1.在PyTorch中,每个Variable数据含有两个flag(requires_gradvolatile)用于指示是否计算此Variable的梯度。设置requires_grad = False,或者设置volatile=True,即可指示不计算此Variable的梯度:

代码语言:python
AI代码解释
复制
for param in model.parameters():
   param.requires_grad = False

注意,在模型测试时,对input_data设置volatile=True,可以节省测试时的显存 。

2、PyTorch的Module.modules()Module.children()

参考PyTorch document和discuss 在PyTorch中,所有的neural network module都是class torch.nn.Module的子类,在Modules中可以包含其它的Modules,以一种树状结构进行嵌套。当需要返回神经网络中的各个模块时,Module.modules()方法返回网络中所有模块的一个iterator,而Module.children()方法返回所有直接子模块的一个iterator。具体而言:

代码语言:javascript
AI代码解释
复制
 list ( nn.Sequential(nn.Linear(10, 20), nn.ReLU()).modules() )
 Out[9]:
 [Sequential (
 (0): Linear (10 -> 20)
 (1): ReLU ()
 ), Linear (10 -> 20), ReLU ()]
 In [10]: list( nn.Sequential(nn.Linear(10, 20), nn.ReLU()) .children() )
 Out[10]: [Linear (10 -> 20), ReLU ()]

 举例:Faster-RCNN基于vgg19提取features,但是只使用了vgg19一部分模型提取features。

步骤:

下载vgg19的pth文件,在anaconda中直接设置pretrained=True下载一般都比较慢,在model_zoo里面有各种预训练模型的下载链接:

代码语言:javascript
AI代码解释
复制
 model_urls = {
 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth'  }

下载好的模型,可以用下面这段代码看一下模型参数,并且改一下模型。在vgg19.pth同级目录建立一个test.py。

代码语言:javascript
AI代码解释
复制
import torch
 import torch.nn as nn
 import torchvision.models as models
vgg16 = models.vgg16(pretrained=False)
#打印出预训练模型的参数
 vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))
 print('vgg16:\n', vgg16) 
modified_features = nn.Sequential(*list(vgg16.features.children())[:-1])
 # to relu5_3
 print('modified_features:\n', modified_features )#打印修改后的模型参数

修改好之后features就可以拿去做Faster-RCNN提取特征用了。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020/05/03 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
PyTorch—torchvision.models导入预训练模型—残差网络代码讲解
PyTorch框架中torchvision模块下有:torchvision.datasets、torchvision.models、torchvision.transforms这3个子包。 关于详情请参考官网: http://pytorch.org/docs/master/torchvision/index.html。 具体代码可以参考github: https://github.com/pytorch/vision/tree/master/torchvision。
全栈程序员站长
2022/09/12
1.9K0
PyTorch—torchvision.models导入预训练模型—残差网络代码讲解
PyTorch源码解读之torchvision.models「建议收藏」
PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:torchvision.datasets、torchvision.models、torchvision.transforms。这3个子包的具体介绍可以参考官网:http://pytorch.org/docs/master/torchvision/index.html。具体代码可以参考github:https://github.com/pytorch/vision/tree/master/torchvision。
全栈程序员站长
2022/09/07
1.1K0
【猫狗数据集】使用预训练的resnet18模型
链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw 提取码:2xq4
西西嘛呦
2020/08/26
3.2K0
【猫狗数据集】使用预训练的resnet18模型
PyTorch 的预训练,是时候学习一下了
前言 最近使用 PyTorch 感觉妙不可言,有种当初使用 Keras 的快感,而且速度还不慢。各种设计直接简洁,方便研究,比 tensorflow 的臃肿好多了。今天让我们来谈谈 PyTorch 的预训练,主要是自己写代码的经验以及论坛 PyTorch Forums 上的一些回答的总结整理。 直接加载预训练模型 如果我们使用的模型和原模型完全一样,那么我们可以直接加载别人训练好的模型: my_resnet = MyResNet(*args, **kwargs) my_resnet.load_st
AI研习社
2018/03/28
1.3K0
通过和resnet18和resnet50理解PyTorch的ResNet模块
resnet和resnext的框架基本相同的,这里先学习下resnet的构建,感觉高度模块化,很方便。本文算是对 PyTorch源码解读之torchvision.modelsResNet代码的详细理解,另外,强烈推荐这位大神的PyTorch的教程!
全栈程序员站长
2022/09/01
1.9K0
通过和resnet18和resnet50理解PyTorch的ResNet模块
[Pytorch][转载]resnet模型实现
本文源自Pytorch官方:https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py import torch
云未归来
2025/07/18
1760
你知道Deeplab那些事儿吗?
DeepLab系列论文一共有四篇,分别对应DeepLab V1,DeepLab V2,DeepLab V3,DeepLab V3+。
灿视学长
2021/05/28
9190
【最强ResNet改进系列】Res2Net:一种新的多尺度网络结构,性能提升显著
【导读】2020年,在各大CV顶会上又出现了许多基于ResNet改进的工作,比如:Res2Net,ResNeSt,IResNet,SCNet等等。为了更好的了解ResNet整个体系脉络的发展,我们开设了一个最强ResNet改进系列专题,主要为大家介绍2020年最新发表在顶会顶刊上基于ResNet改进的论文,这些论文的创新点很值得参考借鉴!本文是【最强ResNet改进系列】第一篇文章,本文我们将着重讲解Res2Net,该论文已被TPAMI2020录用,另外ResNeSt的论文解读见:【CV中的注意力机制】史上最强"ResNet"变体--ResNeSt,下一篇我们将直接来讲解IResNet
深度学习技术前沿公众号博主
2020/08/11
10.9K0
resnet34 pytorch_pytorch环境搭建
导师的课题需要用到图片分类;入门萌新啥也不会,只需要实现这个功能,给出初步效果,不需要花太多时间了解内部逻辑。经过一周的摸索,建好环境、pytorch,终于找到整套的代码和数据集,实现了一个小小的分类。记录一下使用方法,避免后续使用时遗忘。感谢各位大佬的开源代码和注释!
全栈程序员站长
2022/09/27
9330
resnet34 pytorch_pytorch环境搭建
ResNet详细解读
这篇文章是Deep Residual Learning for Image Recognition 的翻译,精简部分内容的同时补充了相关的概念,如有错误,敬请指正。
全栈程序员站长
2022/09/01
2.2K0
ResNet详细解读
DenseNet:比ResNet更优的CNN模型
本篇文章首先介绍DenseNet的原理以及网路架构,然后讲解DenseNet在Pytorch上的实现。
机器学习算法工程师
2018/07/27
1.8K0
DenseNet:比ResNet更优的CNN模型
深度学习算法优化系列八 | VGG,ResNet,DenseNe模型剪枝代码实战
具体原理已经讲过了,见上回的推文。深度学习算法优化系列七 | ICCV 2017的一篇模型剪枝论文,也是2019年众多开源剪枝项目的理论基础 。这篇文章是从源码实战的角度来解释模型剪枝,源码来自:https://github.com/Eric-mingjie/network-slimming 。我这里主要是结合源码来分析每个模型的具体剪枝过程,希望能给你剪枝自己的模型一些启发。
BBuf
2020/02/12
2.5K0
深度学习算法优化系列八 | VGG,ResNet,DenseNe模型剪枝代码实战
[Pytorch][转载]VGG模型实现
本文源自Pytoch官方:https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py import torch
云未归来
2025/07/18
1590
pytorch笔记:04)resnet网络&解决输入图像大小问题「建议收藏」
因为torchvision对resnet18-resnet152进行了封装实现,因而想跟踪下源码
全栈程序员站长
2022/09/01
5.1K0
pytorch笔记:04)resnet网络&解决输入图像大小问题「建议收藏」
VGG16 训练猫狗数据集
准备数据应该是一件比较麻烦的过程,所以一般都去找那种公开的数据集。在网上找到的可以用于猫狗分类的数据集有 Kaggle 的 “Dogs vs. Cats”数据集,还有牛津大学提供的 Oxford-IIIT Pet 数据集,包含猫和狗的图片,都是非常适合做猫狗分类任务的公开数据集。
繁依Fanyi
2025/03/24
3220
卷积神经网络及经典模型
虽然图片识别对于人来说是一件轻松的事情,但是对于计算机来说,由于接受的是一串数字,对于同一个物体,表示这个物体的数字可能会有很大的不同,所以使用算法来实现这一任务还是有很多挑战的,具体来说:
Here_SDUT
2022/09/19
4.7K0
卷积神经网络及经典模型
一个小改动,CNN输入固定尺寸图像改为任意尺寸图像
本文小白将和大家一起学习如何在不使用计算量很大的滑动窗口的情况下对任意尺寸的图像进行图像分类。通过修改,将ResNet-18CNN框架需要224×224尺寸的图像输入改为任意尺寸的图像输入。
小白学视觉
2020/06/01
9K0
PyTorch ImageNet 基于预训练六大常用图片分类模型的实战
在本教程中,我们将深入探讨如何对 torchvision 模型进行微调和特征提取,所有这些模型都已经预先在1000类的Imagenet数据集上训练完成。 本教程将深入介绍如何使用几个现代的CNN架构,并将直观展示如何微调任意的PyTorch模型。由于每个模型架构是有差异的,因此没有 可以在所有场景中使用的微调代码样板。然而,研究人员必须查看现有架构并对每个模型进行自定义调整。
磐创AI
2019/09/19
5.4K0
PyTorch ImageNet 基于预训练六大常用图片分类模型的实战
搞懂Vision Transformer 原理和代码,看这篇技术综述就够了(三)
本文为详细解读Vision Transformer的第三篇,主要解读了两篇关于Transformer在识别任务上的演进的文章:DeiT与VT。它们的共同特点是避免使用巨大的非公开数据集,只使用ImageNet训练Transformer。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
godweiyang
2021/04/08
6.3K0
搞懂Vision Transformer 原理和代码,看这篇技术综述就够了(三)
Pytorch模型训练实用教程学习笔记:二、模型的构建
最近在重温Pytorch基础,然而Pytorch官方文档的各种API是根据字母排列的,并不适合学习阅读。 于是在gayhub上找到了这样一份教程《Pytorch模型训练实用教程》,写得不错,特此根据它来再学习一下Pytorch。 仓库地址:https://github.com/TingsongYu/PyTorch_Tutorial
zstar
2022/09/20
6700
推荐阅读
相关推荐
PyTorch—torchvision.models导入预训练模型—残差网络代码讲解
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档