前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >CIFAR10数据集实战-ResNet网络构建(上)

CIFAR10数据集实战-ResNet网络构建(上)

作者头像
用户6719124
发布于 2020-01-14 02:51:06
发布于 2020-01-14 02:51:06
1.1K00
代码可运行
举报
运行总次数:0
代码可运行

本部分介绍如何采用ResNet解决CIFAR10分类问题。

之前讲到过,ResNet包含了短接模块(short cut)。本节主要介绍如何实现这个模块。

先建立resnet.py文件。

如图

先引入相关包

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch
import torch.nn as nn

准备构建resnet单元

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class ResBlk(nn.Module):
    # 与上节一样,同样resnet的block单元,继承nn模块
    def __init__(self):
        super(ResBlk, self).__init__()
        # 完成初始化

由ResNet特点可知,需要传入channel_in和channel_out才能进行运算,因此在定义中需要加入两个变量。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def __init__(self, ch_in, ch_out):

接下来像之前一样,写入其原先的卷积层。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(ch_out)
# 进行正则化处理,以使train过程更快更稳定
self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(ch_out)

Resnet 模块的左侧的部分写好了,

先不急着写右侧,先写左侧的forward代码

先引入工具包

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch.nn.functional as F

书写代码

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def forward(self, x):
    # 这里输入的是[b, ch, h, w]
    out = F.relu(self.bn1(self.conv1(x)))
    out = F.relu(self.bn2(self.conv2(out)))

下面开始写short cut代码

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
out = x + out
# 这便是element.wise add,实现了[b, ch_in, h, w][b, ch_out, h, w]两个的相加

同时要考虑,若两元素中的ch_in和ch_out不匹配,则运行时会报错。因此需要在前面指定添加if函数

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
if ch_out != ch_in:
    self.extra = nn.Sequential(
        nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
        nn.BatchNorm2d(ch_out),
    )

这段代码的意思即为实现[b, ch_in, h, w] => [b, ch_out, h, w]的转化

写好后,将element.wise add部分的x替换

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
out = self.extra(x) + out

这里也要考虑若ch_in和ch_out原先就相匹配的情况,则需要先进行定义。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
self.extra = nn.Sequential()

最后在定义后,返回结果out

至此resnet block模块构建完毕

现代码为

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlk(nn.Module):
    # 与上节一样,同样resnet的block单元,继承nn模块
    def __init__(self, ch_in, ch_out):
        super(ResBlk, self).__init__()
        # 完成初始化

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # 进行正则化处理,以使train过程更快更稳定
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()

        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out),
            )



    def forward(self, x):
        # 这里输入的是[b, ch, h, w]
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))


        out = self.extra(x) + out
        # 这便是element.wise add,实现了[b, ch_in, h, w][b, ch_out, h, w]两个的相加

        return out
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-01-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
深度学习实战之手写签名识别(100%准确率、语音播报)
在完成了上述的环境搭建后,即可进入到准备阶段了。这里准备的有数据集的准备、以及相关代码的主备。
陶陶name
2022/05/13
1.7K0
深度学习实战之垃圾分类
垃圾分类,指按一定规定或标准将垃圾分类储存、分类投放和分类搬运,从而转变成公共资源的一系列活动的总称。分类的目的是提高垃圾的资源价值和经济价值,力争物尽其用;然而我们在日常生活中认为对垃圾分类还是有些不知所措的,对干垃圾、湿垃圾……分的不是很清楚,由此我们就想到了使用深度学习的方法进行分类。简介 本篇博文主要会带领大家进行数据的预处理、网络搭建、模型训练、模型测试 1. 获取数据集 这里笔者已经为大家提供了一个比较完整的数据集,所以大家不必再自己去收集数据了 数据集链接:https://pan.baidu
陶陶name
2022/05/13
6420
CIFAR10数据集实战-ResNet网络构建(中)
再定义一个ResNet网络 我们本次准备构建ResNet-18层结构 class ResNet(nn.Module): def __init__(self): super(ResNet, self).__init__() self.conv1 = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64
用户6719124
2020/02/24
6820
​​​​【动手学深度学习】残差网络(ResNet)的研究详情
启动jupyter notebook,使用新增的pytorch环境新建ipynb文件,为了检查环境配置是否合理,输入import torch以及torch.cuda.is_available() ,若返回TRUE则说明研究环境配置正确,若返回False但可以正确导入torch则说明pytorch配置成功,但研究运行是在CPU进行的,结果如下:
SarPro
2024/06/14
2990
​​​​【动手学深度学习】残差网络(ResNet)的研究详情
Pytorch-ResNet(残差网络)-下
在左图(准确率)的比较中,从AlexNet到GoogleNet再到ResNet,准确率逐渐提高。20层结构是很多网络结构性能提升的分水岭,在20层之前,模型性能提升较容易。但在20层之后,继续添加层数对性能的提升不是很明显。但ResNet很好地解决了高层数带来的误差叠加问题,因此性能也随着层数的增加而提升。
用户6719124
2019/12/04
1.1K0
CIFAR-10 数据集实战——构建ResNet18神经网络
Block中进行了正则化处理,以使train过程更快更稳定。同时要考虑,如果两元素的ch_in和ch_out不匹配,进行加法时会报错,因此需要判断一下,如果不想等,就用1×1的卷积调整一下
mathor
2020/01/22
1.7K0
基于Pytorch构建DenseNet网络对cifar-10进行分类
DenseNet是指Densely connected convolutional networks(密集卷积网络)。它的优点主要包括有效缓解梯度消失、特征传递更加有效、计算量更小、参数量更小、性能比ResNet更好。它的缺点主要是较大的内存占用。
python与大数据分析
2023/09/03
4940
基于Pytorch构建DenseNet网络对cifar-10进行分类
【深度学习实验】卷积神经网络(七):实现深度残差神经网络ResNet
输入数据通过上述序列模块self.b1、self.b2、self.b3、self.b4、self.b5和self.head进行处理,最终输出分类结果。
Qomolangma
2024/07/30
4500
【深度学习实验】卷积神经网络(七):实现深度残差神经网络ResNet
Pytorch 基于ResNet-18的物体分类(使用CIFAR-10数据集)
✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。 🍎个人主页:小嗷犬的博客 🍊个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。 🥭本文内容:Pytorch 基于ResNet-18的物体分类(使用CIFAR-10数据集) 更多内容请见👇 Pytorch 基于VGG-16的服饰识别(使用Fashion-MNIST数据集) Pytorch 基于NiN的服饰识别(使用Fashion-MNIST数据集) Pytorch 基于ResNet-18的服饰识别(使用
小嗷犬
2022/11/15
8040
Pytorch 基于ResNet-18的物体分类(使用CIFAR-10数据集)
ResNet18复现「建议收藏」
首先将网络分为四层(layers),每层有两个模块组成,除了第一层是两个普通的残差块组成,其它三层有一个普通的残差块和下采样的卷积块组成。输入图像为3x224x224格式,经过卷积池化后为64x112x112格式进入主网络架构。
全栈程序员站长
2022/09/01
4360
ResNet18复现「建议收藏」
Yolov8-pose关键点检测:模型轻量化设计 | 模型压缩率从6842降低到1018,GFLOPs从9.6降低至2.2
轻量化模型设计:模型压缩率从6842降低到1018,GFLOPs从9.6降低至2.2, mAP50从0.921变为0.92(几乎不变)
AI小怪兽
2023/12/08
1.8K1
CIFAR10数据集实战-ResNet网络构建(下)
这里注意到由[2, 64, 32, 32]到[2, 128, 32, 32],channel数量翻倍,而长和宽没有变化。这样势必会导致x的维度会越来越大。
用户6719124
2020/02/24
9880
PyTorch实现ResNet18
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/141287.html原文链接:https://javaforall.cn
全栈程序员站长
2022/09/01
9891
PyTorch实现ResNet18
LetNet、AlexNet、ResNet网络模型实现手写数字识别
本篇文章主要分享使用 LetNet、AlexNet、ResNet网络模型实现手写数字识别,文章是基于《MNIST手写数字识别》这篇文章的拓展:
不去幼儿园
2024/12/03
1810
LetNet、AlexNet、ResNet网络模型实现手写数字识别
Resnet 18网络模型[通俗易懂]
让我们聚焦于神经网络局部:如图左侧所示,假设我们的原始输入为x,而希望学出的理想映射为f(x)(作为上方激活函数的输入)。左图虚线框中的部分需要直接拟合出该映射f(x),而右图虚线框中的部分则需要拟合出残差映射f(x)−x。 残差映射在现实中往往更容易优化。 以本节开头提到的恒等映射作为我们希望学出的理想映射f(x),我们只需将右图虚线框内上方的加权运算(如仿射)的权重和偏置参数设成0,那么f(x)即为恒等映射。 实际中,当理想映射f(x)极接近于恒等映射时,残差映射也易于捕捉恒等映射的细微波动。右图是ResNet的基础架构–残差块(residual block)。 在残差块中,输入可通过跨层数据线路更快地向前传播
全栈程序员站长
2022/09/01
8.9K0
Resnet 18网络模型[通俗易懂]
ResNet+FPN实现+白嫖代码「建议收藏」
===========================================================
全栈程序员站长
2022/08/20
1.2K0
ResNet+FPN实现+白嫖代码「建议收藏」
PyTorch建立resnet34和resnet101代码[通俗易懂]
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
全栈程序员站长
2022/11/10
7610
resnet18模型
发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/141235.html原文链接:https://javaforall.cn
全栈程序员站长
2022/09/01
3690
Transfer Learning
通过网络上收集宝可梦的图片,制作图像分类数据集。我收集了5种宝可梦,分别是皮卡丘,超梦,杰尼龟,小火龙,妙蛙种子
mathor
2020/02/17
4560
pytorch笔记:04)resnet网络&解决输入图像大小问题「建议收藏」
因为torchvision对resnet18-resnet152进行了封装实现,因而想跟踪下源码
全栈程序员站长
2022/09/01
4.8K0
pytorch笔记:04)resnet网络&解决输入图像大小问题「建议收藏」
推荐阅读
相关推荐
深度学习实战之手写签名识别(100%准确率、语音播报)
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验