前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >干货 | TensorFlow 2.0 模型:Keras 训练流程及自定义组件

干货 | TensorFlow 2.0 模型:Keras 训练流程及自定义组件

作者头像
AI研习社
发布于 2019-10-22 12:06:26
发布于 2019-10-22 12:06:26
3.3K00
代码可运行
举报
文章被收录于专栏:AI研习社AI研习社
运行总次数:0
代码可运行

文 / 李锡涵,Google Developers Expert

本文节选自《简单粗暴 TensorFlow 2.0》

在上一篇文章中,我们介绍了循环神经网络的建立方式。本来接下来应该介绍 TensorFlow 中的深度强化学习的,奈何笔者有点咕,到现在还没写完,所以就让我们先来了解一下 Keras 内置的模型训练 API 和自定义组件的方法吧!本文介绍以下内容:

  • 使用 Keras 内置的 API 快速建立和训练模型,几行代码创建和训练一个模型不是梦;
  • 自定义 Keras 中的层、损失函数和评估指标,创建更加个性化的模型。

Keras Pipeline *

在之前的文章中,我们均使用了 Keras 的 Subclassing API 建立模型,即对 tf.keras.Model 类进行扩展以定义自己的新模型,同时手工编写了训练和评估模型的流程。

这种方式灵活度高,且与其他流行的深度学习框架(如 PyTorch、Chainer)共通,是本手册所推荐的方法。不过在很多时候,我们只需要建立一个结构相对简单和典型的神经网络(比如上文中的 MLP 和 CNN),并使用常规的手段进行训练。这时,Keras 也给我们提供了另一套更为简单高效的内置方法来建立、训练和评估模型。

Keras Sequential/Functional API 模式建立模型

最典型和常用的神经网络结构是将一堆层按特定顺序叠加起来,那么,我们是不是只需要提供一个层的列表,就能由 Keras 将它们自动首尾相连,形成模型呢?Keras 的 Sequential API 正是如此。通过向 tf.keras.models.Sequential() 提供一个层的列表,就能快速地建立一个 tf.keras.Model 模型并返回:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1      model = tf.keras.models.Sequential([
2            tf.keras.layers.Flatten(),
3            tf.keras.layers.Dense(100, activation=tf.nn.relu),
4            tf.keras.layers.Dense(10),
5            tf.keras.layers.Softmax()
6        ])

不过,这种层叠结构并不能表示任意的神经网络结构。为此,Keras 提供了 Functional API,帮助我们建立更为复杂的模型,例如多输入 / 输出或存在参数共享的模型。其使用方法是将层作为可调用的对象并返回张量(这点与之前章节的使用方法一致),并将输入向量和输出向量提供给 tf.keras.Modelinputsoutputs 参数,示例如下:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1       inputs = tf.keras.Input(shape=(28, 28, 1))
2        x = tf.keras.layers.Flatten()(inputs)
3        x = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)(x)
4        x = tf.keras.layers.Dense(units=10)(x)
5        outputs = tf.keras.layers.Softmax()(x)
6        model = tf.keras.Model(inputs=inputs, outputs=outputs)

使用 Keras 的内置 API 训练和评估模型

当模型建立完成后,通过 tf.keras.Modelcompile 方法配置训练过程:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1   model.compile(
2        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
3        loss=tf.keras.losses.sparse_categorical_crossentropy,
4        metrics=[tf.keras.metrics.sparse_categorical_accuracy]
5    )

tf.keras.Model.compile 接受 3 个重要的参数:

  • oplimizer :优化器,可从 tf.keras.optimizers 中选择;
  • loss :损失函数,可从 tf.keras.losses 中选择;
  • metrics :评估指标,可从 tf.keras.metrics 中选择。

接下来,可以使用 tf.keras.Modelfit 方法训练模型:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1   model.fit(data_loader.train_data, data_loader.train_label, epochs=num_epochs, batch_size=batch_size)

tf.keras.Model.fit 接受 5 个重要的参数:

  • x :训练数据;
  • y :目标数据(数据标签);
  • epochs :将训练数据迭代多少遍;
  • batch_size :批次的大小;
  • validation_data :验证数据,可用于在训练过程中监控模型的性能。

Keras 支持使用 tf.data.Dataset 进行训练,详见 tf.data

注:tf.data 链接

https://tf.wiki/zh/basic/tools.html#tfdata

最后,使用 tf.keras.Model.evaluate 评估训练效果,提供测试数据及标签即可:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1   print(model.evaluate(data_loader.test_data, data_loader.test_label))

自定义层、损失函数和评估指标 *

可能你还会问,如果现有的这些层无法满足我的要求,我需要定义自己的层怎么办?事实上,我们不仅可以如 前文的介绍 一样继承 tf.keras.Model 编写自己的模型类,也可以继承 tf.keras.layers.Layer 编写自己的层。

自定义层

自定义层需要继承 tf.keras.layers.Layer 类,并重写 __init__buildcall 三个方法,如下所示:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 1class MyLayer(tf.keras.layers.Layer):
 2    def __init__(self):
 3        super().__init__()
 4        # 初始化代码
 5
 6    def build(self, input_shape):     # input_shape 是一个 TensorShape 类型对象,提供输入的形状
 7        # 在第一次使用该层的时候调用该部分代码,在这里创建变量可以使得变量的形状自适应输入的形状
 8        # 而不需要使用者额外指定变量形状。
 9        # 如果已经可以完全确定变量的形状,也可以在__init__部分创建变量
10        self.variable_0 = self.add_weight(...)
11        self.variable_1 = self.add_weight(...)
12
13    def call(self, inputs):
14        # 模型调用的代码(处理输入并返回输出)
15        return output

例如,如果我们要自己实现一个 前文 中的全连接层( tf.keras.layers.Dense ),可以按如下方式编写。此代码在 build 方法中创建两个变量,并在 call 方法中使用创建的变量进行运算:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 1class LinearLayer(tf.keras.layers.Layer):
 2    def __init__(self, units):
 3        super().__init__()
 4        self.units = units
 5
 6    def build(self, input_shape):     # 这里 input_shape 是第一次运行call()时参数inputs的形状
 7        self.w = self.add_variable(name='w',
 8            shape=[input_shape[-1], self.units], initializer=tf.zeros_initializer())
 9        self.b = self.add_variable(name='b',
10            shape=[self.units], initializer=tf.zeros_initializer())
11
12    def call(self, inputs):
13        y_pred = tf.matmul(inputs, self.w) + self.b
14        return y_pred

在定义模型的时候,我们便可以如同 Keras 中的其他层一样,调用我们自定义的层 LinearLayer

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1class LinearModel(tf.keras.Model):
2    def __init__(self):
3        super().__init__()
4        self.layer = LinearLayer(units=1)
5
6    def call(self, inputs):
7        output = self.layer(inputs)
8        return output

自定义损失函数和评估指标

自定义损失函数需要继承 tf.keras.losses.Loss 类,重写 call 方法即可,输入真实值 y_true 和模型预测值 y_pred ,输出模型预测值和真实值之间通过自定义的损失函数计算出的损失值。下面的示例为均方差损失函数:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1class MeanSquaredError(tf.keras.losses.Loss):
2    def call(self, y_true, y_pred):
3        return tf.reduce_mean(tf.square(y_pred - y_true))

自定义评估指标需要继承 tf.keras.metrics.Metric 类,并重写 __init__update_stateresult 三个方法。下面的示例对前面用到的 SparseCategoricalAccuracy 评估指标类做了一个简单的重实现:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 1class SparseCategoricalAccuracy(tf.keras.metrics.Metric):
 2    def __init__(self):
 3        super().__init__()
 4        self.total = self.add_weight(name='total', dtype=tf.int32, initializer=tf.zeros_initializer())
 5        self.count = self.add_weight(name='count', dtype=tf.int32, initializer=tf.zeros_initializer())
 6
 7    def update_state(self, y_true, y_pred, sample_weight=None):
 8        values = tf.cast(tf.equal(y_true, tf.argmax(y_pred, axis=-1, output_type=tf.int32)), tf.int32)
 9        self.total.assign_add(tf.shape(y_true)[0])
10        self.count.assign_add(tf.reduce_sum(values))
11
12    def result(self):
13        return self.count / self.total

福利 | 问答环节

我们知道在入门一项新的技术时有许多挑战与困难需要克服。如果您有关于 TensorFlow 的相关问题,可在本文后留言,我们的工程师和 GDE 将挑选其中具有代表性的问题在下一期进行回答~

在上一篇文章《TensorFlow 2.0 模型:循环神经网络》中,我们对于部分具有代表性的问题回答如下:

Q1:mirrorstrategy 在 1.13.1 这个版本里几乎没有任何加速效果。是在 2.0 做了修复吗? A:建议使用 2.0 的新版本试试看。在我们的测试中效果是非常显著的,可以参考下面文章进行尝试。

  • https://tf.wiki/zh/appendix/distributed.html#mirroredstrategy

Q2:能不能支持一下 mac a 卡 gpu? A:目前,AMD 的显卡也开始对 TensorFlow 提供支持,可访问博客文章查看详情。

  • https://medium.com/tensorflow/amd-rocm-gpu-support-for-tensorflow-33c78cc6a6cf

Q3:可以展示一下使用 TF2.0 建立 LSTM 回归预测模型吗?

A:可以参考示例,该示例使用了 Keras 和 LSTM 在天气数据集上进行了时间序列预测。

  • https://tensorflow.google.cn/tutorials/structured_data/time_series

Q4:应该给个例子,dataset 怎么处理大数据集。现在数据集过小。还有 keras 怎么用 subclass 的方式。这种小 demo 没啥意义。还有导出模型,这个很难弄。这些应该多写。

A:我们会在后面的连载系列中介绍高效处理大数据集的 tf.data ,以及导出模型到 SavedModel,敬请期待!

Q5:我想用现成的网络但是又想更改结构怎么弄?比如我要用现成的inception解决回归问题而不是分类,需要修改输入层和输出层。

A:TensorFlow Hub 提供了不包含最顶端全连接层的预训练模型(Headless Model),您可以使用该类型的预训练模型并添加自己的输出层,具体请参考:

  • https://tensorflow.google.cn/tutorials/images/transfer_learning_with_hub

Q6.请问正式版支持avx2吗?

A:pip 版本为了更好的通用性,默认是不支持 avx2,但是可以自己编译。

Q7.tf 团队可以支持下微软的 python-language-server 团队吗,动态导入的包特性导致 vs code 的用户无法自动补全,tf2.0 让我可望不可即

A:请参考 https://github.com/microsoft/python-language-server/issues/818。

《简单粗暴 TensorFlow 2.0 》目录

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-10-16,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI研习社 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
哈希表基本概念介绍及哈希冲突的处理方法(附源码)
  哈希表(散列表),是根据关键码值(Key value)而直接进行访问的数据结构。也就是说,它通过把关键码值映射到表中一个位置来访问记录,以加快查找的速度。这个映射函数叫做哈希(散列)函数,存放记录的数组叫做哈希(散列)表。
嵌入式与Linux那些事
2021/05/20
9310
哈希表基本概念介绍及哈希冲突的处理方法(附源码)
数据结构之哈希表(HASH)
   当我们在编程过程中,往往需要对线性表进行查找操作。在顺序表中查找时,需要从表头开始,依次遍历比较a[i]与key的值是否相等,直到相等才返回索引i;在有序表中查找时,我们经常使用的是二分查找,通过比较key与a[i]的大小来折半查找,直到相等时才返回索引i。最终通过索引找到我们要找的元素。    但是,这两种方法的效率都依赖于查找中比较的次数。我们有一种想法,能不能不经过比较,而是直接通过关键字key一次得到所要的结果呢?这时,就有了散列表查找(哈希表)。
全栈程序员站长
2022/07/21
5920
数据结构之哈希表(HASH)
散列表
http://blog.csdn.net/yyxaf/article/details/7527878 搜索关键词:散列函数、散列表、哈希函数、哈希表、Hash函数、Hash表 散列方法不同于顺序查找、二分查找、二叉排序树及B-树上的查找。它不以关键字的比较为基本操作,采用直接寻址技术。在理想情况下,无须任何比较就可以找到待查关键字,查找的期望时间为O(1)。 散列表的概念 1、散列表 设所有可能出现的关键字集合记为U(简称全集)。实际发生(即实际存储)的关键字集合记为K(|K|比|U|小得多)。 散列方
用户1624346
2018/04/17
1K0
【408&数据结构】散列 (哈希)知识点集合复习&考点题目
散列查找是一种高效的查找方法,它通过散列函数将关键字映射到数组的一个位置,从而实现快速查找。这种方法的时间复杂度平均为(O(1)),但最坏情况下可能会退化到(O(n))。
苏泽
2024/09/09
2830
【408&数据结构】散列 (哈希)知识点集合复习&考点题目
重温数据结构:哈希 哈希函数 哈希表
该文介绍了计算机科学中的哈希表(Hash Table)及其在编程中的应用。哈希表是一种数据结构,可以高效地完成查找、插入、删除等操作。文章还介绍了哈希函数、哈希冲突、拉链法等概念。
张拭心 shixinzhang
2018/01/05
2.7K1
重温数据结构:哈希 哈希函数 哈希表
【数据结构】什么是哈希表(散列表)?
在正式开始深入了解哈希表之前呢, 我想带大家先回忆一下生活中咱们的这个"老朋友"。可能你会感到诧异, 我怎么会和它是"老朋友"呢? 别急, 其实你的生活中常常会出现哈希的身影,只是你没有细心观察罢了,不信你看下面几个场景对你来说是不是非常熟悉呢:
修修修也
2024/10/06
2550
【数据结构】什么是哈希表(散列表)?
《大话数据结构》 查找 以及一个简单的哈希表例子
第八章 查找 定义:查找就是根据给定的某个值,在查找表中确定一个其关键字等于给定值的数据元素(或记录)。 8.2 查找概论 查找表(Search table):是由同一类型的数据元素构成的集合。 关键字(key):是数据元素中某个数据项的值,又称为键值。 若此关键字可以唯一的标识一个记录,则称此关键字为主关键字(Primary key)。 对于那些可以识别多个数据元素的关键字,我们称为次关键字(Secondary key)。 查找表按照操作方式来分有两大种:静态查找表和动态查找表 静态查找表(Static
xcywt
2018/03/28
2.4K0
《大话数据结构》 查找 以及一个简单的哈希表例子
数据结构:查找
衡量标准:查找过程中对关键字的平均比较次数——平均查找长度ASL。设查找到第i个元素的概率为p,比较次数为c,则查找成功的ASL_{succ}=\sum^n_{i=1}p_ic_i
ttony0
2022/12/26
9870
数据结构:查找
【数据结构】哈希表
顺序结构以及平衡树中,元素关键码与其存储位置之间没有对应的关系,因此在查找一个元素时,必须要经过关键码的多次比较。顺序查找时间复杂度为 $O(N)$ ,平衡树中为树的高度,即 $O(logN)$ ,搜索的效率取决于搜索过程中元素的比较次数。
椰椰椰耶
2024/09/20
1120
【数据结构】哈希表
【经验分享】数据结构——哈希查找冲突处理方法(开放地址法-线性探测、平方探测、双散列探测、再散列,分离链接法)
的哈希表,插入一组关键字 [10, 22, 31, 4, 15, 28],并使用线性探测解决冲突。
命运之光
2024/08/17
2330
[算法] 开放寻址法解决哈希冲突方式
开放寻址法:又称开放定址法,当哈希冲突发生时,从发生冲突的那个单元起,按照一定的次序,从哈希表中寻找一个空闲的单元,然后把发生冲突的元素存入到该单元。这个空闲单元又称为开放单元或者空白单元。开放寻址法需要的表长度要大于等于所需要存放的元素数量,非常适用于装载因子较小(小于0.5)的散列表。
唯一Chat
2020/12/31
4K0
哈希表总结
之前给大家介绍了链表,栈和队列今天我们来说一种新的数据结构散列(哈希)表,散列是应用非常广泛的数据结构,在我们的刷题过程中,散列表的出场率特别高。所以我们快来一起把散列表的内些事给整明白吧,文章框架如下。
宿春磊Charles
2022/03/29
7390
哈希表总结
散列表(哈希表)
版权声明:本文为博主原创文章,转载请注明博客地址: https://blog.csdn.net/zy010101/article/details/83998492
zy010101
2019/05/25
7540
程序员必读:教你摸清哈希表的脾气
在哈希表中,记录的存储位置 = f (关键字),通过查找关键字的存储位置即可,不用进行比较。散列技术是在记录的存储位置和它的关键字之间建立一个明确的对应关系f 函数,使得每个关键字 key 对应一个存储位置 f(key) 且这个位置是唯一的。这里我们将这种对应关系 f 称为散列函数,又称为哈希(Hash)函数。采用散列技术将记录存储在一块连续的存储空间中,这块连续存储空间称为散列表或哈希表(Hash table)。
谭庆波
2018/08/10
3930
程序员必读:教你摸清哈希表的脾气
进阶 | 我实现了javascript 哈希表,并进行性能比较
前端爱好者的聚集地 javascript的对象就是一个哈希表,为了学习真正的数据结构,我们还是有必要自己重新实现一下。 基本概念 哈希表(hash table )是一种根据关键字直接访问内存存储位置的数据结构,通过哈希表,数据元素的存放位置和数据元素的关键字之间建立起某种对应关系,建立这种对应关系的函数称为哈希函数。 哈希表的构造方法 假设要存储的数据元素个数是n,设置一个长度为m(m > n)的连续存储单元,分别以每个数据元素的关键字Ki(0<=i<=n-1)为自变量,通过哈希函数hash(Ki),把
用户1097444
2022/06/29
6930
进阶 | 我实现了javascript 哈希表,并进行性能比较
数据结构 之 哈希表
哈希表(Hash table) 又称为散列表,是根据关键码值(Key value)而直接进行访问的数据结构。也就是说,它通过把关键码值映射到表中一个位置来访问记录,以加快查找的速度。这个映射函数叫做散列函数,存放记录的数组叫做哈希表。
AUGENSTERN_
2024/04/23
1.3K0
数据结构 之 哈希表
数据结构与算法之哈希表
哈希表也叫散列表。 散列表(Hash table,也叫哈希表),是根据关键码值(Key value)而直接进行访问的数据结构。也就是说,它通过把关键码值映射到表中一个位置来访问记录,以加快查找的速度。这个映射函数叫做散列函数,存放记录的数组叫做散列表。 给定表M,存在函数f(key),对任意给定的关键字值key,代入函数后若能得到包含该关键字的记录在表中的地址,则称表M为哈希(Hash)表,函数f(key)为哈希(Hash) 函数。
袁新栋-jeff.yuan
2020/08/26
7570
什么是散列表(哈希表)?
假设你们班级100个同学每个人的学号是由院系-年级-班级和编号组成,例如学号为01100168表示是1系,10级1班的68号。为了快速查找到68号的成绩信息,可以建立一张表,但是不能用学号作为下标,学号的数值实在太大。因此将学号除以1100100取余,即得到编号作为该表的下标,那么,要查找学号为01100168的成绩的时候,只要直接访问表下标为68的数据即可。这就能够在O(1)时间复杂度内完成成绩查找。
编程珠玑
2019/07/12
6580
海量数据处理
  针对海量数据的处理,可以使用的方法非常多,常见的方法有hash法、Bit-map法、Bloom filter法、数据库优化法、倒排索引法、外排序法、Trie树、堆、双层桶法以及MapReduce法。 1、hash法 hash法也成为散列法,它是一种映射关系,即给定一个元素,关键字是key,按照一个确定的散列函数计算出hash(key),把hash(key)作为关键字key对应的元素的存储地址,再进行数据元素的插入和检索操作。   散列表是具有固定大小的数组,表长应该是质数,散列函数是用于关键字和存储
Mister24
2018/05/14
2.2K0
数据结构 Hash表(哈希表)
参考链接:数据结构(严蔚敏) 文章发布很久了,具体细节已经不清晰了,不再回复各种问题 文章整理自严蔚敏公开课视频 可以参考 https://www.bilibili.com/video/av22258871/ 如果链接失效 可以自行搜索 数据结构严蔚敏视频 @2021/07/12
全栈程序员站长
2022/09/15
1.3K0
相关推荐
哈希表基本概念介绍及哈希冲突的处理方法(附源码)
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档