Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >神经网络使用梯度下降的原因(摘自我写的书)

神经网络使用梯度下降的原因(摘自我写的书)

作者头像
黄鸿波
发布于 2020-04-14 09:34:37
发布于 2020-04-14 09:34:37
1.9K0
举报
文章被收录于专栏:AI的那些事儿AI的那些事儿
自上次我把我2018年出版的《TensorFlow进阶指南 基础、算法与应用》一书的部分内容放上来之后,收到了很多后台留言,说能不能再放一些,后来想了想,今天又放出来了其中的一节内容。

目前深度神经网络模型的优化方法主要是梯度下降。我们使用梯度下降的方法来进行误差的反向传播,不断地调整模型参数,以降低模型所产生的误差,使模型更好实现从输入到输出的映射。目前因为各种因素,神经网络可以做的层数更深,神经元更多。相较于以前得到了性能上较大的提升。

由于许多非线性层的作用,模型容量得到了较大的提高,使模型可以完成更加复杂的任务,模型很庞大,参数空间也非常复杂,我们使用的梯度下降算法是目前最有效的优化算法,但是这样深层的神经网络在误差反向传播过程中,很容易遭遇梯度消散和梯度爆炸的问题。

我们进行反向传播过程会把误差由输出层一层一层地往前面传播,这与神经网络的层数有一定的关系(如果是循环神经网络还会时间步有关)。我们的误差是由链式法则一层一层地传播的,假设神经网络模型中的参数为W,则在链式法则中,需要多次乘以W,可以理解为Wn次方,假设W有特征值分解,则

V是权重参数W矩阵特征向量构成的矩阵,是由权重参数W矩阵的特征值λ构成的对角矩阵。

λ>1,W容易产生一个极大的数值,导致梯度爆炸。

λ<1,W容易接近于0,导致梯度消失。

梯度消失

每一次梯度更新的公式如下:

其中w为模型参数,α表示学习率,

就是目标函数对参数W的导数。

如果产生梯度消散的问题,每一次的梯度更新,

就会等于零,那么这样的梯度更新是没有意义的,这样意味着已经无法进行学习了。

由链式法则可以知道,这样的问题经常出现在深层神经网络模型的较浅的层中,出现这个问题时,较浅的层往往还没有掌握最好的学习技巧和提取特征的能力,对于后续神经层以及整一个模型的效果都会产生较大的影响。如果只能对后面的神经层进行训练,前面较浅的层不再能继续训练了,则不利于模型在参数空间中寻找最优点。

在循环神经网络中出现这样的问题时,可以理解为模型失去了对较早时间步的记忆,无法做到长期依赖,更加注重当前几步的信息,造成了时序信息的丢失。

梯度爆炸

深度神经网络模型很大,有许多非线性神经元,模型会呈现出高度非线性,所以参数空间也会很复杂。我们对这么复杂的参数空间可以理解成一个地形非常复杂的地方,往往伴随着“悬崖”地形,在“悬崖”处的梯度是极大的,正因为这个极大的梯度容易导致梯度爆炸的问题,如图所示。

在我们进行梯度更新时,根据公式(梯度消散部分说到的公式)可以知道学习速率乘以一个极大的梯度会导致参数更新时更新的幅度非常的大,离开了当前的区域,进入了另外一个较远的区域,使之前更新的步骤都成了“无用功”,极大地影响了优化的性能。

循环神经网络中出现梯度爆炸的情况少一些,它更多的问题是梯度消散,梯度爆炸会更多的出现在深度前馈神经网络中。

解决梯度消散和梯度爆炸问题的方法

选择合适的激活函数

在误差反向传播过程中,需要对激活函数进行多次求导,此时,激活函数的导数大小可以直接影响梯度下降的效果,过小容易产生梯度消散,过大容易产生梯度爆炸,如果激活函数的导数是1,则这是最理想的情况,所以我们更多地建议使用relu系列的激活函数,如Relu、elu、leakyrelu,Relu函数图像如图6-8左图所示,其导数图像如图。

不建议大家使用sigmoid和tanh等激活函数,因为它们的导数在大部分区域都是非常小的,容易导致梯度消散的问题。如图所示。

选择合适的参数初始化方法

使用权重参数正则化

使用权重参数正则化可以减少梯度爆炸发生的概率,常用的正则化方式就是L1或者L2正则化。对模型参数进行L1正则化时,参数会倾向于0和1的稀疏结构(假设参数为Laplace先验分布),对模型参数进行L2正则化时,参数会倾向于被“压缩”到一个很小的接近于0的数字(假设参数为标准高斯先验分布)。我们通过在目标函数中添加惩罚项来达到这样的效果,减小了模型复杂度的同时,也减小了发生梯度爆炸的概率,但是却增加了梯度消散的概率。

正则化对模型的梯度影响有限,这不是最主要的。

使用BatchNormalization

BatchNormalization目前已经在深度神经网络模型中得到了广泛的应用,主要通过规范化操作将输出信号x规范化到均值为0、方差为1保证网络的稳定性,可以加大神经网络训练的速度,提高训练的稳定性,也可以缓解梯度爆炸和梯度消散的问题。

在误差反向传播过程中,经过每一层的梯度都会乘以该层的权重参数,举个简单的例子:

正向传播中:

那么反向传播中:

反向传播式子中有w的存在,所以 w 的大小影响了梯度的消失和爆炸,batchnormalization就是通过对每一层的输出做规模和偏移的方法,通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布控制在接近均值为0、方差为1的分布,把偏离的分布强制拉回到一个比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域。这样输入的小变化就会导致损失函数较大的变化,使得让梯度变大,避免梯度消失,而且梯度变大意味着学习收敛速度快,能大大加快训练速度,同时也能在一定程度防止梯度爆炸的问题。

使用残差结构

说起残差结构,我们一定要提起resdiual network,这个结构的提出极大地提高了神经网络的深度,因为这种结构在很大程度上解决了梯度消散的问题,允许我们可以训练很深层的神经网络,充分利用了神经网络里面的每一个神经元,模型更大更复杂,也可以通过梯度下降的方式进行训练。自从提出残差结构后,现在的模型基本都离不开resdiual network的身影。

相比之前的几层、几十层的深度网络,在残差网络面前都不值一提,残差可以很轻松地构建几百层,即便一千多层的网络也不用担心梯度消散过快的问题,原因就在于残差的捷径(high way)部分,其中残差单元如图。

使用梯度裁剪

之前在讲解梯度爆炸产生的原因时,提到了参数空间有很多“悬崖”地形,导致了梯度下降的困难,如图所示,“悬崖”处的参数梯度是极大的,梯度下降时可以把参数抛出很远,使之前的努力都荒废了。我们解决这个问题的方法是进行梯度裁剪,梯度裁剪就是用来限制梯度大小的,若梯度大小超出了梯度范数的上界,则强制令梯度大小为梯度范数的上界的大小,来避免梯度过大的情况,在使用这样的方法进行梯度裁剪时,只是改变了这个梯度的大小,仍然保持了梯度的方向。

公式如下:

其中v是梯度范数的上界,g用来更新参数的梯度。

我们要控制“悬崖”处梯度的大小,使用一个尽量小一点的梯度,避免穿越向上的曲面,使参数保持在一个合适的区域内。使用了梯度截断的梯度下降对“悬崖”处的反应更加温和,当参数更新到了“悬崖”截面处时,由于梯度大小收到了控制,不会那么容易被“抛出”到比较远的参数空间中去,导致“前功尽弃”。如图所示。

既然使用梯度裁剪的方式来处理梯度爆炸,同样的,梯度消散可不可以使用梯度扩张的方式来解决呢?其实这个问题并没有那么简单,梯度过小的时候,有两种可能:一种是梯度消散,一种是到达局部最优或者鞍点。如果不能准确区分这两类情况,单纯扩张梯度有可能导致系统不收敛。而且梯度太小时,方向其实也是很难确定的,或者说很有可能是不准确的,我们不可以随意地在这个方向上放大梯度,进行梯度下降。

END

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-03-31,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI的那些事儿 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
2019-05-31 ojdbc6 安装到本地maven仓库
https://www.oracle.com/technetwork/cn/database/enterprise-edition/jdbc-112010-094555-zhs.html
Albert陈凯
2019/06/02
8330
Mybatis-Plus实践学习(二十五)
由于版权原因,我们不能直接通过maven的中央仓库下载oracle数据库的jdbc驱动包,所以我们需要将驱动包安装到本地仓库。
用户1289394
2024/02/26
1420
Mybatis-Plus实践学习(二十五)
Maven 菜鸟教程 4 常用dos命令
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
全栈程序员站长
2022/10/04
4130
maven上传就私库
mvn deploy:deploy-file -DgroupId=com.cmos -DartifactId=itframe-boot-base -Dversion=1.0.1-SNAPSHOT -Dpackaging=jar -Dfile=D:\s\itframe-boot-base-1.0.1-SNAPSHOT.jar -Durl=http://10.97.85.11:38081/repository/maven-snapshots -DrepositoryId=nexus
全栈程序员站长
2022/09/09
3300
JAR包安装报错requires a project to execute but there is no POM的解决
  本文介绍在Windows中,通过Maven的mvn install:install-file命令安装JAR包时,提示The goal you specified requires a project to execute but there is no POM in this directory错误的解决方法。
疯狂学习GIS
2025/03/10
4270
JAR包安装报错requires a project to execute but there is no POM的解决
Hudi数据湖技术引领大数据新风口(二)编译安装
(1)上传apache-maven-3.6.1-bin.tar.gz到/opt/software目录,并解压更名
Maynor
2023/07/28
5840
Hudi数据湖技术引领大数据新风口(二)编译安装
Mavan 引入本地Jar
通常情况下,我们都是通过 Maven 从中央仓库或者阿里仓库直接拉取依赖的 JAR 包来构建我们的项目。然而,在实际工作中,有时候会遇到一些特殊情况,比如对接三方平台时,对方提供的是一个直接下载链接的 JAR 包,而不是通过 Maven 仓库管理的方式提供依赖。
刺槐儿
2024/01/17
3640
maven向本土仓库导入jar包(处理官网没有的jar包)
1、将pinyin4j-2.5.0.jar文件放在“D:\JAR_LIB”目录下(该目录任意)
EltonZheng
2021/01/26
1.2K0
maven 安装alipay-sdk包到本地及远程仓库
一、安装到本地: mvn install:install-file -DgroupId=com.alipay -DartifactId=sdk-Java -Dversion=*** -Dpackaging=jar -Dfile=alipay-sdk-java*.jar 二、安装到远程仓库: maven配置: <!-- Another sample, using keys to authenticate. <server> <id>siteServer</id> <privateKey>/path/t
WindWant
2020/09/11
1K0
本地私服仓库nexus3.3.1使用手册
私服架构 私服是指私有服务器,是架设在局域网的一种特殊的远程仓库,目的是代理远程仓库及部署第三方构建。有了私服之后,当 Maven 需要下载构件时,直接请求私服,私服上存在则下载到本地仓库;否则,私服
小柒2012
2018/04/16
8.5K0
本地私服仓库nexus3.3.1使用手册
SpringBoot项目Oracle报AbstractMethodError
java.lang.AbstractMethodError: oracle.jdbc.driver.T4CConnection.isValid(I)Z
I Teach You 我教你
2023/07/18
3790
SpringBoot项目Oracle报AbstractMethodError
maven命令大全
Maven常用命令: 创建Maven的普通Java项目: mvn archetype:create -DgroupId=packageName -DartifactId=projectName 创建Maven的Web项目: mvn archetype:create -DgroupId=packageName -DartifactId=webappName-DarchetypeArtifactId=maven-archetype-webapp 编译源代码: mvn compile 编译测试代码:mvn t
小柒2012
2018/04/16
1.6K0
如何使用 Java 生成二维码?
 QRCode生成二维码网址:http://swetake.com/qrcode/index-e.html
芋道源码
2019/10/24
2.2K0
jar包导入到项目中、本地maven仓库、私库
配置Jar包的dependency,包括groupId,artifactId,version三个属性,同时还要包含scope和systemPath属性;
ha_lydms
2023/08/09
2.5K0
jar包导入到项目中、本地maven仓库、私库
数据库文档生成工具- screw
在企业级开发中、我们经常会有编写数据库表结构文档的时间付出,从业以来,待过几家企业,关于数据库表结构文档状态:要么没有、要么有、但都是手写、后期运维开发,需要手动进行维护到文档中,很是繁琐、如果忘记一次维护、就会给以后工作造成很多困扰、无形中制造了很多坑留给自己和后人,于是需要一个插件工具 screw 来维护。
Remember_Ray
2020/09/15
1.4K0
数据库文档生成工具- screw
Maven 常用命令
mvn deploy:deploy-file -DgroupId=com.sun.pdfview -DartifactId=pdf -Dversion=1.0 -Dpackaging=jar -Dfile=/home/homer/Desktop/pdf.jar -Durl=http://172.27.9.104:8081/nexus/content/repositories/thirdparty/ -DrepositoryId=thirdparty
阳光岛主
2019/02/19
7550
maven打包本地jar到本地仓库
该命令会使用E:\dev\maven\xinao\apache-maven-3.6.0\bin\mvn路径下的settings.xml文件
九转成圣
2024/04/10
1980
将下载到本地的JAR包手动添加到Maven仓库(转)
常用Maven仓库网址: http://mvnrepository.com/ http://search.maven.org/ http://repository.sonatype.org/content/groups/public/ http://people.apache.org/repo/m2-snapshot-repository/ http://people.apache.org/repo/m2-incubating-repository/
HUC思梦
2020/09/03
2.1K0
将下载到本地的JAR包手动添加到Maven仓库(转)
Maven安装本地jar
以Oracle数据库的驱动为例 oracle驱动安装 下载驱动这里 安装在本地maven库 mvn install:install-file -Dfile=ojdbc8路径 -DgroupId=com.oracle -DartifactId=ojdbc8 -Dversion=版本号 -Dpackaging=jar
DH镔
2019/12/20
9440
快速学习Maven-Nexus把第三方jar包放入本地仓库或私服
需要在 maven 软件的核心配置文件 settings.xml 中配置第三方仓库的 server 信息
cwl_java
2019/12/25
1.8K0
相关推荐
2019-05-31 ojdbc6 安装到本地maven仓库
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档