前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >干货 | TensorFlow 2.0 模型:Keras 训练流程及自定义组件

干货 | TensorFlow 2.0 模型:Keras 训练流程及自定义组件

作者头像
AI研习社
发布于 2019-10-22 12:06:26
发布于 2019-10-22 12:06:26
3.4K02
代码可运行
举报
文章被收录于专栏:AI研习社AI研习社
运行总次数:2
代码可运行

文 / 李锡涵,Google Developers Expert

本文节选自《简单粗暴 TensorFlow 2.0》

在上一篇文章中,我们介绍了循环神经网络的建立方式。本来接下来应该介绍 TensorFlow 中的深度强化学习的,奈何笔者有点咕,到现在还没写完,所以就让我们先来了解一下 Keras 内置的模型训练 API 和自定义组件的方法吧!本文介绍以下内容:

  • 使用 Keras 内置的 API 快速建立和训练模型,几行代码创建和训练一个模型不是梦;
  • 自定义 Keras 中的层、损失函数和评估指标,创建更加个性化的模型。

Keras Pipeline *

在之前的文章中,我们均使用了 Keras 的 Subclassing API 建立模型,即对 tf.keras.Model 类进行扩展以定义自己的新模型,同时手工编写了训练和评估模型的流程。

这种方式灵活度高,且与其他流行的深度学习框架(如 PyTorch、Chainer)共通,是本手册所推荐的方法。不过在很多时候,我们只需要建立一个结构相对简单和典型的神经网络(比如上文中的 MLP 和 CNN),并使用常规的手段进行训练。这时,Keras 也给我们提供了另一套更为简单高效的内置方法来建立、训练和评估模型。

Keras Sequential/Functional API 模式建立模型

最典型和常用的神经网络结构是将一堆层按特定顺序叠加起来,那么,我们是不是只需要提供一个层的列表,就能由 Keras 将它们自动首尾相连,形成模型呢?Keras 的 Sequential API 正是如此。通过向 tf.keras.models.Sequential() 提供一个层的列表,就能快速地建立一个 tf.keras.Model 模型并返回:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1      model = tf.keras.models.Sequential([
2            tf.keras.layers.Flatten(),
3            tf.keras.layers.Dense(100, activation=tf.nn.relu),
4            tf.keras.layers.Dense(10),
5            tf.keras.layers.Softmax()
6        ])

不过,这种层叠结构并不能表示任意的神经网络结构。为此,Keras 提供了 Functional API,帮助我们建立更为复杂的模型,例如多输入 / 输出或存在参数共享的模型。其使用方法是将层作为可调用的对象并返回张量(这点与之前章节的使用方法一致),并将输入向量和输出向量提供给 tf.keras.Modelinputsoutputs 参数,示例如下:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1       inputs = tf.keras.Input(shape=(28, 28, 1))
2        x = tf.keras.layers.Flatten()(inputs)
3        x = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)(x)
4        x = tf.keras.layers.Dense(units=10)(x)
5        outputs = tf.keras.layers.Softmax()(x)
6        model = tf.keras.Model(inputs=inputs, outputs=outputs)

使用 Keras 的内置 API 训练和评估模型

当模型建立完成后,通过 tf.keras.Modelcompile 方法配置训练过程:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1   model.compile(
2        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
3        loss=tf.keras.losses.sparse_categorical_crossentropy,
4        metrics=[tf.keras.metrics.sparse_categorical_accuracy]
5    )

tf.keras.Model.compile 接受 3 个重要的参数:

  • oplimizer :优化器,可从 tf.keras.optimizers 中选择;
  • loss :损失函数,可从 tf.keras.losses 中选择;
  • metrics :评估指标,可从 tf.keras.metrics 中选择。

接下来,可以使用 tf.keras.Modelfit 方法训练模型:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1   model.fit(data_loader.train_data, data_loader.train_label, epochs=num_epochs, batch_size=batch_size)

tf.keras.Model.fit 接受 5 个重要的参数:

  • x :训练数据;
  • y :目标数据(数据标签);
  • epochs :将训练数据迭代多少遍;
  • batch_size :批次的大小;
  • validation_data :验证数据,可用于在训练过程中监控模型的性能。

Keras 支持使用 tf.data.Dataset 进行训练,详见 tf.data

注:tf.data 链接

https://tf.wiki/zh/basic/tools.html#tfdata

最后,使用 tf.keras.Model.evaluate 评估训练效果,提供测试数据及标签即可:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1   print(model.evaluate(data_loader.test_data, data_loader.test_label))

自定义层、损失函数和评估指标 *

可能你还会问,如果现有的这些层无法满足我的要求,我需要定义自己的层怎么办?事实上,我们不仅可以如 前文的介绍 一样继承 tf.keras.Model 编写自己的模型类,也可以继承 tf.keras.layers.Layer 编写自己的层。

自定义层

自定义层需要继承 tf.keras.layers.Layer 类,并重写 __init__buildcall 三个方法,如下所示:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 1class MyLayer(tf.keras.layers.Layer):
 2    def __init__(self):
 3        super().__init__()
 4        # 初始化代码
 5
 6    def build(self, input_shape):     # input_shape 是一个 TensorShape 类型对象,提供输入的形状
 7        # 在第一次使用该层的时候调用该部分代码,在这里创建变量可以使得变量的形状自适应输入的形状
 8        # 而不需要使用者额外指定变量形状。
 9        # 如果已经可以完全确定变量的形状,也可以在__init__部分创建变量
10        self.variable_0 = self.add_weight(...)
11        self.variable_1 = self.add_weight(...)
12
13    def call(self, inputs):
14        # 模型调用的代码(处理输入并返回输出)
15        return output

例如,如果我们要自己实现一个 前文 中的全连接层( tf.keras.layers.Dense ),可以按如下方式编写。此代码在 build 方法中创建两个变量,并在 call 方法中使用创建的变量进行运算:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 1class LinearLayer(tf.keras.layers.Layer):
 2    def __init__(self, units):
 3        super().__init__()
 4        self.units = units
 5
 6    def build(self, input_shape):     # 这里 input_shape 是第一次运行call()时参数inputs的形状
 7        self.w = self.add_variable(name='w',
 8            shape=[input_shape[-1], self.units], initializer=tf.zeros_initializer())
 9        self.b = self.add_variable(name='b',
10            shape=[self.units], initializer=tf.zeros_initializer())
11
12    def call(self, inputs):
13        y_pred = tf.matmul(inputs, self.w) + self.b
14        return y_pred

在定义模型的时候,我们便可以如同 Keras 中的其他层一样,调用我们自定义的层 LinearLayer

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1class LinearModel(tf.keras.Model):
2    def __init__(self):
3        super().__init__()
4        self.layer = LinearLayer(units=1)
5
6    def call(self, inputs):
7        output = self.layer(inputs)
8        return output

自定义损失函数和评估指标

自定义损失函数需要继承 tf.keras.losses.Loss 类,重写 call 方法即可,输入真实值 y_true 和模型预测值 y_pred ,输出模型预测值和真实值之间通过自定义的损失函数计算出的损失值。下面的示例为均方差损失函数:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
1class MeanSquaredError(tf.keras.losses.Loss):
2    def call(self, y_true, y_pred):
3        return tf.reduce_mean(tf.square(y_pred - y_true))

自定义评估指标需要继承 tf.keras.metrics.Metric 类,并重写 __init__update_stateresult 三个方法。下面的示例对前面用到的 SparseCategoricalAccuracy 评估指标类做了一个简单的重实现:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 1class SparseCategoricalAccuracy(tf.keras.metrics.Metric):
 2    def __init__(self):
 3        super().__init__()
 4        self.total = self.add_weight(name='total', dtype=tf.int32, initializer=tf.zeros_initializer())
 5        self.count = self.add_weight(name='count', dtype=tf.int32, initializer=tf.zeros_initializer())
 6
 7    def update_state(self, y_true, y_pred, sample_weight=None):
 8        values = tf.cast(tf.equal(y_true, tf.argmax(y_pred, axis=-1, output_type=tf.int32)), tf.int32)
 9        self.total.assign_add(tf.shape(y_true)[0])
10        self.count.assign_add(tf.reduce_sum(values))
11
12    def result(self):
13        return self.count / self.total

福利 | 问答环节

我们知道在入门一项新的技术时有许多挑战与困难需要克服。如果您有关于 TensorFlow 的相关问题,可在本文后留言,我们的工程师和 GDE 将挑选其中具有代表性的问题在下一期进行回答~

在上一篇文章《TensorFlow 2.0 模型:循环神经网络》中,我们对于部分具有代表性的问题回答如下:

Q1:mirrorstrategy 在 1.13.1 这个版本里几乎没有任何加速效果。是在 2.0 做了修复吗? A:建议使用 2.0 的新版本试试看。在我们的测试中效果是非常显著的,可以参考下面文章进行尝试。

  • https://tf.wiki/zh/appendix/distributed.html#mirroredstrategy

Q2:能不能支持一下 mac a 卡 gpu? A:目前,AMD 的显卡也开始对 TensorFlow 提供支持,可访问博客文章查看详情。

  • https://medium.com/tensorflow/amd-rocm-gpu-support-for-tensorflow-33c78cc6a6cf

Q3:可以展示一下使用 TF2.0 建立 LSTM 回归预测模型吗?

A:可以参考示例,该示例使用了 Keras 和 LSTM 在天气数据集上进行了时间序列预测。

  • https://tensorflow.google.cn/tutorials/structured_data/time_series

Q4:应该给个例子,dataset 怎么处理大数据集。现在数据集过小。还有 keras 怎么用 subclass 的方式。这种小 demo 没啥意义。还有导出模型,这个很难弄。这些应该多写。

A:我们会在后面的连载系列中介绍高效处理大数据集的 tf.data ,以及导出模型到 SavedModel,敬请期待!

Q5:我想用现成的网络但是又想更改结构怎么弄?比如我要用现成的inception解决回归问题而不是分类,需要修改输入层和输出层。

A:TensorFlow Hub 提供了不包含最顶端全连接层的预训练模型(Headless Model),您可以使用该类型的预训练模型并添加自己的输出层,具体请参考:

  • https://tensorflow.google.cn/tutorials/images/transfer_learning_with_hub

Q6.请问正式版支持avx2吗?

A:pip 版本为了更好的通用性,默认是不支持 avx2,但是可以自己编译。

Q7.tf 团队可以支持下微软的 python-language-server 团队吗,动态导入的包特性导致 vs code 的用户无法自动补全,tf2.0 让我可望不可即

A:请参考 https://github.com/microsoft/python-language-server/issues/818。

《简单粗暴 TensorFlow 2.0 》目录

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-10-16,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI研习社 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
mysql图形化工具使用教程_mysql图形化管理工具介绍
MySQL 有许多图形化的管理工具,我们在此介绍二个官方的工具「MySQL Administrator」及「MySQL Query Browser」。MySQL Administrator 是用来管理 MySQL Server 用的,您可以查看目前系统状态、新增使用者等。而 MySQL Query Browser 可以用来查看数据库内容。
全栈程序员站长
2022/09/12
3.2K0
mysql图形化工具使用教程_mysql图形化管理工具介绍
关于FTP搭建问题
1. 匿名服务器的连接(独立的服务器) 在/etc/vsftpd.conf(或在/etc/vsftpd/vsftpd.conf)配置文件中添加如下几项: Anonymous_enable=yes (允许匿名登陆) Dirmessage_enable=yes (切换目录时,显示目录下.message的内容) Local_umask=022 (FTP上本地的文件权限,默认是077) Connect_form_port_20=yes (启用FTP数据端口的数据连接)* Xferlog_enable=yes (激活上传和下载的日志) Xferlog_std_format=yes (使用标准的日志格式) Ftpd_banner=XXXXX (欢迎信息) Pam_service_name=vsftpd (验证方式)* Listen=yes (独立的VSFTPD服务器)* 注释:以上配置只能连接FTP服务器,不能上传和下载 注:其中所有和日志欢迎信息相关连的都是可选项,打了星号的无论什么帐户都要添加,是属于FTP的基本选项
云知识Online
2018/05/03
2.6K0
vsftpd如何用PAM去认证用户
vsftpd可能是世界上最好的ftpd。它在linux世界非常流行,安全,性能高。 本文的目的是让PgSQL存储你的vsftp的虚拟用户和密码,通过一个叫做pam的来认证。 零、简述PAM原理。 如果你已经对pam有所了解,请跳过,我知道的可能还不如你多。你不感兴趣也请跳过,因为不看这个也可配置。 代码: 用户 vsftpd PAM模块 用户和密码数据库 vsftpd用了一种很聪明同时也是unix/linux规范的方法来认证用户,就是PAM。大家对于PAM,也许有些陌生,但是一直在用。所谓PAM,英文是:Pluggable Authentication Modules,可拔插认证模块(不知道这样翻译对不对)。看见plug这个关键字,就知道是很灵活的。 现在几乎所有daemon程序一般都是用PAM来进行认证的,包括telnet/sshd/imapd,甚至你的login,都是用PAM。在 fbsd 4上的朋友,你可以打ps -ax|grep pam,你会发现login了多少个控制台,就会有多少个写着pam的进程。 PAM的最大好处是灵活。它不管你的用户和密码用什么数据格式存储(数据库也好,通常用的密码文件也好),只要有相应的PAM模块就可以存储。比如说,大家不仅可以用vsftpd + PgSQL做用户登陆验证,只要你喜欢你还可以用MySQL,Oracle,LDAP数据库存储用户数据,只要有相应的PAM就可以。所有的daemon 都可以用一个后台数据库来做用户验证登陆,包括telnet/sshd等等。 pam的配置机制在不同版本的freebsd上有差异。 freebsd-4放在/etc/pam.conf,一个文件记录所有pam服务。 freebsd-5放在/etc/pam.d,/usr/local/etc/pam.d。每个pam服务由一个独立的文件记录。 本文不打算详细叙述PAM的配置。PAM的配置不是很难,毕竟,只是要你配置一些参数,不是叫你开发一个pam模块出来。而且本文的篇幅所限,偶刚刚知道的一点东西希望能够起到抛砖引玉的作用。等偶对pam再玩得深入和熟一点的时候,再写一篇关于深入一点关于pam的东东? 准备开始:提要 简单讲讲要用到的配置文件的作用。 引用: /etc/pam.conf #pam服务的配置 /etc/pam_pgsql.conf #pam_pgsql.so的配置 /usr/local/etc/vsftpd.conf #vsftpd的配置 一、安装vsftpd,PostgreSQL,pam_pgsql。 我都是使用port来安装的,请大家用port/package来安装,不要自己下载源码来编译,否则可能根据本文的方法可能无法正常使用。其中vsftpd和pam-pgsql一定要用port/package来安装。 以下是他们的port目录: 引用: /usr/ports/ftp/vsftpd /usr/ports/databases/postgresql7 /usr/ports/security/pam-pgsql 安装:只要cd进去,然后make install就OK了。 二、PostgreSQL安装(如果你已经有了PostgreSQL,不需要看这一节) 简单提提用port来装PostgreSQL的过程,因为BSD版上的装PgSQL的方法都是自己下载源码编译的。我是用port来编译安装,因为这是fbsd推荐的安装方法,而且安装的软件会根据bsd的hier(目录结构)来安装,比较便于管理。 当用port来安装好PostgreSQL,默认的数据库管理用户是pgsql(port里头的安装程序自动添加的),其他系统默认的是postgres。初始化PostgreSQL的程序如下: 1、初始数据库。请先用root登陆或者su到root。然后,打命令: 代码: # su pgsql # initdb 正常初始化的应该有以下提示: 引用:
会长君
2023/04/25
1.3K0
SELinux的基本使用
从进入了 CentOS 5.x 之后的 CentOS 版本中 (当然包括 CentOS 7),SELinux 已经是个非常完备的核心模块了!尤其 CentOS 提供了很多管理 SELinux 的指令与机制,因此在整体架构上面是单纯且容易操作管理的!所以,在没有自行开发网络服务软件以及使用其他第三方协力软件的情况下,也就是全部使用 CentOS 官方提供的软件来使用我们服务器的情况下,建议大家不要关闭 SELinux ! 让我们来仔细的玩玩这家伙吧!
小柒吃地瓜
2020/04/23
3K0
vsftpd简介及搭建配置
FTP(文件传输协议)全称是:Very Secure FTP Server。  Vsftpd是linux类操作系统上运行的ftp服务器软件。
星哥玩云
2022/07/14
5.5K0
centos部署ftp服务_文件服务器搭建
vsftpd配置文件的默认路径是 /etc/vsftpd/vsftpd.conf。
全栈程序员站长
2022/10/01
1.7K0
使用hyper backup与rsync将数据备份到unraid
整理自: https://forums.unraid.net/topic/2100-rsync/
超级大猪
2021/11/22
3.9K1
使用hyper backup与rsync将数据备份到unraid
Linux上安装配置Nginx与ftp服务
首先在Nginx官网下载稳定版本的Nginx安装包,并将安装包上传到Linux。 使用 tar -zxvf nginx-1.16.0.tar.gz 将压缩包解压。
Java阿呆
2020/11/04
4.7K0
Linux上安装配置Nginx与ftp服务
FTP文件服务器
FTP (File transfer protocol) 是TCP/IP 协议组中的协议之一。他最主要的功能是在服务器与客户端之间进行文件的传输。FTP就是实现两台计算机之间的拷贝,从远程计算机拷贝文件至自己的计算机上,称之为“下载 (download)”文件。将文件从自己计算机中拷贝至远程计算机上,则称之为“上传(upload)”文件。这个古老的协议使用的是明码传输方式,且过去有相当多的安全危机历史。为了更安全的使用 FTP 协议,我们主要介绍较为安全但功能较少的 vsftpd(very secure File transfer protocol ) 这个软件。FTP是一个C/S类型的软件,FTP监听TCP端口号为21,数据端口为20。
星哥玩云
2022/09/15
22.6K0
vsftp服务器常规参数配置大全(二)
4. IP监听与连接控制      vsftpd工作在独立模式(standalone)下的启动参数有两项:
会长君
2023/04/25
2.6K0
Apache的httpd.conf文件配置详解
Apache的基本设置主要交由httpd.conf来设定管理,我们要修改Apache的相关设定,主要还是通过修改httpd.cong来实现。下面让我们来看看httpd.conf的内容,它主要分成3大部分: Section 1:Global Environment Section 2:'Main' server configurationphpma.com Section 3:Virtual Hosts 【第一部分】 ·ServerType standalone 这表示Apache是以standalone启动
wangxl
2018/03/07
2.7K0
RHEL6.4 搭建FTP服务器
[root@ftp-server ~]# vim /etc/vsftpd/vsftpd
星哥玩云
2022/07/01
3720
如何增强Linux和Unix服务器的安全性
操作系统内部的记录文件是检测是否有网络入侵的重要线索。如果您的系统是直接连到internet,您发现有很多人对您的系统做telnet/ftp登录尝试,可以运行"#more /var/log/secure grep refused"来检查系统所受到的攻击,以便采取相应的对策,如使用ssh来替换telnet/rlogin等。
会长君
2023/04/26
9400
vsftp配置文件详解
-vsftpd.user_list文件需要与vsftpd.conf文件中的配置项结合来实现对于vsftpd.user_list文件中指定用户账号的访问控制:
孤鸿
2022/09/23
4.2K0
sshd_config详解
# 1. 关于 SSH Server 的整体设定,包含使用的 port 啦,以及使用的密码演算方式
呆呆
2021/05/26
2K0
Linux下使用Nginx+vsftpd搭建图片服务器
传统项目中,可以在web项目中添加一个文件夹,来存放上传的图片。例如在工程的根目录WebRoot下创建一个images文件夹。把图片存放在此文件夹中就可以直接使用在工程中引用。
星哥玩云
2022/07/27
1.2K0
Linux下使用Nginx+vsftpd搭建图片服务器
ubuntu 16.04 搭建ftp服务器
如果登录后出现如下错误,则在/etc/vsftpd.conf文件内添加allow_writeable_chroot=YES
全栈程序员站长
2022/09/14
1.6K0
【Linux】《how linux work》第十章 网络应用和服务(1)
This chapter explores basic network applications—the clients and servers running in user space that reside at the application layer. Because this layer is at the top of the stack, close to end users, you may find this material more accessible than the material in Chapter 9. Indeed, you interact with network client applications such as web browsers and email readers every day.
阿东
2024/04/27
1750
【Linux】《how linux work》第十章 网络应用和服务(1)
解决Centos下vsftp无法上传文件的问题,附vsftp配置详解
重量网络最近买了一个腾讯云的 VPS,一直在折腾着,偶然请我帮忙敲几行命令解决一些小问题。 这不,今天他通过 yum 在线安装了一个 vsftp 后,发现不太会用,就按照网上的教程东搞西搞。最后发现无法上传文件了,就给我操作了一把。 用 SecureCRT 远程登录后,我做了如下检查: ①、打开了 vsftp 配置文件(/etc/vsftpd/vsftpd.conf)检查 write_enable 状态,发现正常: [root@VM_72_108_centos /]# cat /etc/vsftpd/vsf
张戈
2018/03/23
6.3K0
红帽子linux 架设ftp,RedHatLinux9架设FTP服务器方法[通俗易懂]
vsftpd是目前Linux最好的FTP服务器工具之一,其中的vs就是“VerySecure”(很安全)的缩写,可见它的最大优点就是安全,除此之外,它还具有体积小,可定制强,效率高的优点。
全栈程序员站长
2022/09/14
2K0
相关推荐
mysql图形化工具使用教程_mysql图形化管理工具介绍
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档