前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch-ResNet(残差网络)-下

Pytorch-ResNet(残差网络)-下

作者头像
用户6719124
发布2019-12-04 11:48:10
1.1K0
发布2019-12-04 11:48:10
举报
文章被收录于专栏:python pytorch AI机器学习实践

ResNet具有诸多优异性能,如下所示

在左图(准确率)的比较中,从AlexNet到GoogleNet再到ResNet,准确率逐渐提高。20层结构是很多网络结构性能提升的分水岭,在20层之前,模型性能提升较容易。但在20层之后,继续添加层数对性能的提升不是很明显。但ResNet很好地解决了高层数带来的误差叠加问题,因此性能也随着层数的增加而提升。

而在右图计算量对比图中,性能最完美的是ResNet-101、Inception-v4等,计算量不大且性能很好。而VGG的运算量较大、AlexNet虽然计算量较小,但性能不佳。

那么在具体代码中,卷积层是如何实现的?

如图我们想构建一个如下图所示得神经网络

首先要明确ResNet本质上是由多个基本单元堆叠实现的,写法与之前所讲的类似。

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.nn.functional as F

# 先明确ResNet是由conv1+bn+ReLU+conv2+bn+ReLU构成

class ResBlk(nn.Module):
    def __init__(self, ch_in, ch_out):
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # 依次对kernel_size、stride、padding进行定义
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()

        if ch_out != ch_in:
            # 该处即为short cut结构,若input_channel与该单元输出channel不一致
            # 即将ch_in作为输入、ch_out作为输出
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out)
        )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.extra(x) + out
        return out

由此我们看出ResNet本质上是在每一层结构上都加了一个short cut。

若将该思路扩展,在中间的每一层均让其可能与之前层接触。这样就成了连接很密集的DenseNet。

如下所示

Densenet是各个channel上的累加,有时会使后面的计算量contact的很大。因此在DenseNet上的channel选择必须要非常的精妙。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-11-29,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
灰盒安全测试
腾讯知识图谱(Tencent Knowledge Graph,TKG)是一个集成图数据库、图计算引擎和图可视化分析的一站式平台。支持抽取和融合异构数据,支持千亿级节点关系的存储和计算,支持规则匹配、机器学习、图嵌入等图数据挖掘算法,拥有丰富的图数据渲染和展现的可视化方案。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档