首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >在Keras中,使用SGD,为什么model.fit()训练得很顺利,但分步训练方法给出了爆炸性的梯度和损失

在Keras中,使用SGD,为什么model.fit()训练得很顺利,但分步训练方法给出了爆炸性的梯度和损失
EN

Stack Overflow用户
提问于 2021-08-06 12:36:10
回答 1查看 46关注 0票数 0

因为这种爆炸性的梯度和爆炸性的损失发生在网络巨大的时候,所以我不在这里张贴整个网络。但我已经尽了最大的努力,在过去的两周里,我深入研究了源代码的每个细节来监控一些权重,手动编写更新步骤来监控损失、权重、更新、梯度和超参数,以便与内部状态进行比较。我想在我问之前,我已经做了一些功课。

问题是,使用Keras API有两种训练方法,一种是model.fit(),第二种是更多的定制方法,用于更复杂的训练和网络,但是虽然我几乎所有的东西都保持不变,model.fit()没有爆炸性的损失,但自定义方法给出了爆炸性的损失。有趣的是,当我在一个小得多的网络中监控许多细节时,两种方法看起来都是一样的。

环境:

代码语言:javascript
运行
AI代码解释
复制
# tensorflow 1.14
import tensorflow as tf
from tensorflow.keras import backend as K

对于model.fit()方法:

代码语言:javascript
运行
AI代码解释
复制
# I skipped the details of the below two lines as I couldn't share the very details. but x is [10000, 32, 32, 3] image data, y is [10000, 10, 1] label. model is regular Keras model.

x_train, y_train, x_test, y_test = get_data()
model = get_keras_model()

loss_fn = tf.keras.losses.CategoricalCrossentropy()
sgd = tf.keras.optimizers.SGD(lr=.1, momentum=0.9, nesterov=True)

model.compile(loss=loss_fn, optimizer=sgd, metrics=['accuracy'])
history = model.fit(x_train, y_train, batch_size=128, epochs=100, validation_data=(x_test, y_test))

自定义方法:

代码语言:javascript
运行
AI代码解释
复制
x_train, y_train, x_test, y_test = get_data()
model = get_keras_model()

input = model.inputs[0]
y_true = tf.placeholder(dtype = tf.int32, shape = [None, 10])
y_pred = model.outputs[0]

loss_fn = tf.keras.losses.CategoricalCrossentropy()
loss = loss_fn(y_true, y_pred)
weights = model.trainable_weights
sgd = tf.keras.optimizers.SGD(lr=.1, momentum=0.9, nesterov=True)

training_updates = sgd.get_updates(loss, weights)
training_fn = K.function([y_true, input], [loss], training_updates)

num_train = 10000
steps_per_epoch = int(num_train / 128) # batch size 128
total_steps = steps_per_epoch * 100 # epoch 100

for step in total_steps:
    idx = np.random.randint(0, 10000, 128)
    input_img = x_train[idx]
    ground_true = y_train[idx]

    cur_loss = training_fn([ground_true, input_img])

简而言之,相同的模型,相同的损失函数,相同的优化器SGD,相同的图像馈送(我确实控制图像馈送顺序,尽管这里的代码是从训练数据中随机选择的)。在model.fit()的内部过程中,有什么可以防止损失或梯度爆炸的东西吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-08-29 15:09:11

在深入研究源代码后,我找到了梯度爆炸的原因,正确的代码(最小的变化如下所示):

代码语言:javascript
运行
AI代码解释
复制
x_train, y_train, x_test, y_test = get_data()
model = get_keras_model()

input = model.inputs[0]
y_true = tf.placeholder(dtype = tf.int32, shape = [None, 10])
y_pred = model.outputs[0]

loss_fn = tf.keras.losses.CategoricalCrossentropy()
loss = loss_fn(y_true, y_pred)
weights = model.trainable_weights
sgd = tf.keras.optimizers.SGD(lr=.1, momentum=0.9, nesterov=True)

training_updates = sgd.get_updates(loss, weights)

# Correct:
training_fn = K.function([y_true, input, K.symbolic_learning_phase()], [loss], training_updates)

# Before:
# training_fn = K.function([y_true, input], [loss], training_updates)

num_train = 10000
steps_per_epoch = int(num_train / 128) # batch size 128
total_steps = steps_per_epoch * 100 # epoch 100

for step in total_steps:
    idx = np.random.randint(0, 10000, 128)
    input_img = x_train[idx]
    ground_true = y_train[idx]

    # Correct:
    cur_loss = training_fn([ground_true, input_img, True])

    # Before:
    # cur_loss = training_fn([ground_true, input_img])

我对这个特殊的张量K.symbolic_learning_phase()的理解是,它有默认值设置为False (如果你在初始化时检查源代码),BatchNormalizationDropout层等在训练阶段和测试阶段表现不同。在这种情况下,BatchNormalization层是导致梯度爆炸的原因(现在有一些帖子提到他们使用BatchNormalization层进行梯度爆炸),这是因为它的两个可训练权重batch_normalization_1/gamma:0batch_normalization_1/beta:0依赖于这个张量,并且使用默认值False,它们没有学习,它们的权重在训练过程中很快就变成了nan

我注意到,使用这种training_updates方法的Keras代码并不是真正将K.symbolic_learning_phase()放在代码中,然而,这是Keras的API在幕后做的事情。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68687237

复制
相关文章
Scrapy 使用代理IP并将输出保存到 jsonline
1、使用 scrapy 中间件,您需要在 settings.py 中启用 HttpProxyMiddleware,例如:
jackcode
2023/05/30
3240
Scrapy 使用代理IP并将输出保存到 jsonline
使用 PHP 函数对变量进行比较
使用 PHP 函数对变量 $x 进行比较 表达式 gettype() empty() is_null() isset() boolean : if($x) $x = ""; string TRUE FALSE TRUE FALSE $x = null; NULL TRUE TRUE FALSE FALSE var $x; NULL TRUE TRUE FALSE FALSE $x is undefined NULL TRUE TRUE FALSE FALSE $x = array(); array TRUE
Denis
2023/04/15
1.4K0
在Scrapy中如何使用aiohttp?
当我们从一些代理IP供应商购买代理IP时,他们可能是提供一个网址供我们查询当前可用的代理IP。我们周期性访问这个网址,拿到最新的IP,再分给爬虫使用。
青南
2020/07/16
6.7K0
Scrapy输出中文保存中文
scrapy在保存json文件时容易乱码 settings.py文件改动: ITEM_PIPELINES = { 'tutorial.pipelines.TutorialPipeline': 300, } pipeline.py文件改动: import json import codecs class TutorialPipeline(object): def __init__(self, spider): self.file = codecs.open('data_cn
林清猫耳
2019/03/04
2.8K0
混合线性模型如何进行多重比较
这里,得到的LSD = 6.708889, 多重比较中,用水平的平均值的差值,与LSD比较,如果大于LSD,则认为两水平达到显著性差异。
邓飞
2019/11/04
3.7K0
如何使用Scrapy框架抓取电影数据
随着互联网的普及和电影市场的繁荣,越来越多的人开始关注电影排行榜和评分,了解电影的排行榜和评分可以帮助我们更好地了解观众的喜好和市场趋势.豆瓣电影是一个广受欢迎的电影评分和评论网站,它提供了丰富的电影信息和用户评价。因此,爬取豆瓣电影排行榜的数据对于电影从业者和电影爱好者来说都具有重要意义。
小白学大数据
2023/09/25
3930
如何对矩阵中的所有值进行比较?
需求相对比较明确,就是在矩阵中显示的值,需要进行整体比较,而不是单个字段值直接进行的比较。如图1所示,确认矩阵中最大值或者最小值。
逍遥之
2020/05/14
8.2K0
【说站】java如何进行数据的比较
首先,Java中的数据存储在JVM中,而基本类型的数据存储在JVM的局部变量表中,也可以理解为所谓的“栈”。
很酷的站长
2022/11/23
8460
【说站】java如何进行数据的比较
【说站】python比较运算如何使用
1、除数值操作外,整数型和浮点型还可以进行比较操作,即比较两个数值的大小。比较结果是布尔值。
很酷的站长
2022/11/23
5960
【说站】python比较运算如何使用
Scrapy ---- 使用步骤
python、scrapy和pycharm已经安装好,并且python和scrapy环境已经配置好。scrapy安装比较简单的方法是通过pycharm IDE进行安装。 一、创建工程 命令行输入:sc
SuperHeroes
2018/05/30
7800
Scrapy框架的使用之Scrapy入门
接下来介绍一个简单的项目,完成一遍Scrapy抓取流程。通过这个过程,我们可以对Scrapy的基本用法和原理有大体了解。 一、准备工作 本节要完成的任务如下。 创建一个Scrapy项目。 创建一个Spider来抓取站点和处理数据。 通过命令行将抓取的内容导出。 将抓取的内容保存的到MongoDB数据库。 二、准备工作 我们需要安装好Scrapy框架、MongoDB和PyMongo库。 三、创建项目 创建一个Scrapy项目,项目文件可以直接用scrapy命令生成,命令如下所示: scrapy st
崔庆才
2018/06/25
1.4K0
scrapy 进阶使用
乐百川
2018/01/09
2K0
scrapy 进阶使用
学习爬虫之Scrapy框架学习(六)–1.直接使用scrapy;使用scrapy管道;使用scrapy的媒体管道类进行猫咪图片存储。媒体管道类学习。自建媒体管道类存储图片
大家好,又见面了,我是你们的朋友全栈君。 1.引入: 先来看个小案例:使用scrapy爬取百度图片。( 目标百度图片URL: https://image.baidu.com/search/
全栈程序员站长
2022/09/13
4160
scrapy爬虫笔记(1):scrapy基本使用
之前在写爬虫时,都是自己写整个爬取过程,例如向目标网站发起请求、解析网站、提取数据、下载数据等,需要自己定义这些实现方法等
冰霜
2022/03/15
3760
scrapy爬虫笔记(1):scrapy基本使用
如何使用tsharkVM分析tshark的输出
tsharkVM这个项目旨在构建一台虚拟机,以帮助广大研究人员分析tshark的输出结果。虚拟设备是使用vagrant构建的,它可以使用预安装和预配置的ELK堆栈构建Debian 10。
FB客服
2022/11/14
1.5K0
如何使用tsharkVM分析tshark的输出
PyCharm下进行Scrapy项目的调试
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_35512245/article/details/72835653
大黄大黄大黄
2018/09/14
1.3K0
PyCharm下进行Scrapy项目的调试
【说站】python变量如何进行格式化输出
以上就是python变量进行格式化输出的方法,希望对大家有所帮助。更多Python学习指路:python基础教程
很酷的站长
2022/11/23
6950
【说站】python变量如何进行格式化输出
在windows下如何新建爬虫虚拟环境和进行Scrapy安装
Scrapy是Python开发的一个快速、高层次的屏幕抓取和web抓取框架,用于抓取web站点并从页面中提取结构化的数据。Scrapy吸引人的地方在于它是一个框架,任何人都可以根据需求方便的修改。Scrapy用途广泛,可以用于数据挖掘、监测和自动化测试。
Python进阶者
2019/03/04
4730
在windows下如何新建爬虫虚拟环境和进行Scrapy安装
如何使用 scrapy.Request.from_curl() 方法将 cURL 命令转换为 Scrapy 请求
Scrapy 是一个用 Python 编写的开源框架,用于快速、高效地抓取网页数据。Scrapy 提供了许多强大的功能,如选择器、中间件、管道、信号等,让开发者可以轻松地定制自己的爬虫程序。
jackcode
2023/08/08
4510
如何使用 scrapy.Request.from_curl() 方法将 cURL 命令转换为 Scrapy 请求
在windows下如何新建爬虫虚拟环境和进行scrapy安装
Scrapy是Python开发的一个快速、高层次的屏幕抓取和web抓取框架,用于抓取web站点并从页面中提取结构化的数据。Scrapy吸引人的地方在于它是一个框架,任何人都可以根据需求方便的修改。Scrapy用途广泛,可以用于数据挖掘、监测和自动化测试。
Python进阶者
2019/02/11
7150
在windows下如何新建爬虫虚拟环境和进行scrapy安装

相似问题

使用jq计算给定JSON结构中键值对的数目

15

Java:将键值对附加到嵌套的json对象

13

将嵌套的Json转换为单个键值对

23

将Json嵌套映射转换为键值对

336

将MySQL层次结构数据转换为JSON字符串

26
添加站长 进交流群

领取专属 10元无门槛券

AI混元助手 在线答疑

扫码加入开发者社群
关注 腾讯云开发者公众号

洞察 腾讯核心技术

剖析业界实践案例

扫码关注腾讯云开发者公众号
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档