Loading [MathJax]/jax/output/CommonHTML/config.js
前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >Tensorflow2.X中使用自编码器图像重构实战---文末送书

Tensorflow2.X中使用自编码器图像重构实战---文末送书

作者头像
Color Space
发布于 2022-09-26 02:19:18
发布于 2022-09-26 02:19:18
57600
代码可运行
举报
运行总次数:0
代码可运行

图像重构是计算机视觉领域里一种经典的图像处理技术,而自编码器算法便是实现该技术的核心算法之一。在了解了自编码器的基本原理之后,本节就通过实例讲解如何利用Tensorflow2.X来一步步地搭建出一个自编码器并将其应用于MNIST手写图像数据的重构当中。

01 编译器模块搭建

在本节中,使用MNIST手写数据集来进行自编码器模型的训练。首先需要搭建的是编码器网络,如前面所述,它的作用是使网络中的输入数据不断地降维变成低维度的隐变量。

首先导入相关的第三方库:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt

加载MNIST数据集并对其进行预处理:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
(x_train,y_train),(x_test,y_test) = tf.keras.datasets.mnist.load_data()   #加载MNIST数据集
x_train = x_train.reshape(-1,784).astype('float32')/255   #训练集图像打平并归一化 
x_test = x_test.reshape(-1,784).astype('float32')/255  

在这里使用3层全连接层作为编码器的网络结构,即输入维度为784的图像会不断地经过3层网络并降维变成512,256和60。其中每一层网络都使用ReLu作为激活函数并对神经元权重进行正态分布初始化:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 编码器网络
Encoder = tf.keras.models.Sequential([                       
    layers.Dense(512, kernel_initializer = 'normal', activation = 'relu'),
    layers.Dense(256, kernel_initializer = 'normal', activation = 'relu'),
    layers.Dense(60, kernel_initializer = 'normal', activation = 'relu')
])

02 解码器模块搭建

解码器网络实质上就是对编码器输出的隐变量进行一次次的上采样,最后输出再还原成和原输入数据相同维度的数据。在这里将之前得到的维度为60的数据再依次升维到256,512和784。同样地,对每一层网络的神经元权重进行正态分布初始化,并将最后一层激活函数换成Sigmoid函数以便于将输出转为像素值:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 解码器网络
Decoder = tf.keras.models.Sequential([
    layers.Dense(256, kernel_initializer = 'normal', activation = 'relu'),
    layers.Dense(512, kernel_initializer = 'normal', activation = 'relu'),
    layers.Dense(784, kernel_initializer = 'normal', activation = 'sigmoid')
])

03 自编码器模型

将上述的编码器和解码器进行结合便可得到完整的自编码器模型。整个自编码器网络结构如图1所示。

图1 自编码器网络

为了方便,可以将编码器和解码器代码封装成类,并将传播过程实现在call函数当中:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
class Autoencoder(tf.keras.Model):
    def __init__(self):
        super(Autoencoder,self).__init__()
        self.Encoder = tf.keras.models.Sequential([   #编码器网络
            layers.Dense(512, kernel_initializer = 'normal', activation = 'relu'),
            layers.Dense(256, kernel_initializer = 'normal', activation = 'relu'),
            layers.Dense(60, kernel_initializer = 'normal', activation = 'relu')
        ])
        self.Decoder = tf.keras.models.Sequential([   #解码器网络
            layers.Dense(256, kernel_initializer = 'normal', activation = 'relu'),
            layers.Dense(512, kernel_initializer = 'normal', activation = 'relu'),
            layers.Dense(784, kernel_initializer = 'normal', activation = 'sigmoid')
        ])
    def call(self,input_features,training = None):   #前向传播
        code = self.Encoder(input_features)   #数据编码
        reconstructed = self.Decoder(code)   #数据解码
        return reconstructed

搭建好自编码器的网络模型之后,下一步便是对该网络进行训练。在训练之前,首先需要配置训练过程所需的优化器及损失函数等参数。在这里选择了Adam优化器以及使用经典的二元交叉熵作为模型的损失函数:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
model = Autoencoder() 
model.compile(optimizer = 'adam', loss = 'binary_crossentropy') 

配置好所需参数之后,可以正式开始训练已搭建好的模型。这里选用测试集的前4000张作为验证集,而其余作为测试集,由于自编码器模型为无监督训练模型,因此这里输入数据的标签等于输入自身:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
model.fit(x_train,x_train, epochs = 10, batch_size = 256, shuffle = True, validation_data = (x
_test[:4000], x_test[:4000]))  

输出的结果为:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
Train on 60000 samples, validate on 4000 samples
……….
Epoch 5/10
60000/60000 [==============================] - 11s 184us/sample - loss: 0.0907 - val_loss: 0.0896
Epoch 6/10
60000/60000 [==============================] - 10s 170us/sample - loss: 0.0877 - val_loss: 0.0862
Epoch 7/10
60000/60000 [==============================] - 10s 171us/sample - loss: 0.0855 - val_loss: 0.0845
Epoch 8/10
60000/60000 [==============================] - 11s 175us/sample - loss: 0.0838 - val_loss: 0.0832
Epoch 9/10
60000/60000 [==============================] - 11s 190us/sample - loss: 0.0823 - val_loss: 0.0821
Epoch 10/10
60000/60000 [==============================] - 11s 182us/sample - loss: 0.0811 - val_loss: 0.0814

在使用训练集训练好模型之后,还需要对其进行进一步的测试。测试集的后6000张图像被应用于测试图像的重构效果,之后再构建可视化函数对其显示:

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
#模型测试
decoded_imgs = model.predict(x_test[4000:])
#原图像与重构后的图像对比
plt.figure(figsize = (20, 4))
n = 10
for i in range(n):
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(tf.reshape(x_test[4000+i], [28, 28]))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()

图像重构的结果如图2所示,图中分别展示了训练1和10个epoch的模型测试效果:

图2 1,10个epoch训练模型的图像重构效果对比

本文选自水利水电出版社的《深度学习实战:基于TensorFlow2.X的计算机视觉开发应用 》一书,略有修改,经出版社授权刊登于此。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-08-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV与AI深度学习 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
Python 运算符与数据类型
运算符用于执行程序代码运算,会针对一个以上操作数项目来进行运算,在Python中运算符大致可以分为7种类型:算术运算符、比较运算符、赋值运算符、逻辑运算符、位运算等,下面的例子将依次介绍这几种运算符的使用技巧.
王瑞MVP
2022/12/28
1.9K0
条件语句/变量和基本数据类型
在32位机器上,整数的位数为32位,取值范围为-2**31~2**31-1,即-2147483648~2147483647   在64位系统上,整数的位数为64位,取值范围为-2**63~2**63-1,即-9223372036854775808~9223372036854775807
py3study
2020/01/17
2K0
Python基础之数据类型详解
数字类型与其他编程语言类似,这里不再具体讲解。作为Python中最重要的基础知识,下面主要梳理下字符串、列表、元组、字典、集合的核心知识点。
吾非同
2020/10/13
1K0
python初级:基础知识学习-变量、数据类型、运算符、选择结构
变量是程序中临时存储数据的容器。 变量的赋值:向变量中存储数据 语法:变量名称 = 数据 python代码中,出现了等号~通常情况就是向左边的变量中存储数据 变量作为一个容器,对于数据的操作一般只有四种:增加、删除、修改、查询
全栈程序员站长
2021/09/26
5900
基本数据类型(二)
  列表是 Python 最常用的数据类型,它是有序元素的集合,元素之间以逗号分隔,用中括号括起来,可以是任何数据类型。同时它也是一种序列,支持索引、切片、加、乘和成员检查等。
py3study
2020/01/16
6780
Python--4 基本数据类型
  字符串str是在Python编写程序过程中,最常见的一种基本数据类型。字符串是许多单个子串组成的序列,其主要是用来表示文本。字符串是不可变数据类型,也就是说你要改变原字符串内的元素,只能是新建另一个字符串。
py3study
2020/01/19
9540
Python--4 基本数据类型
python数据类型,格式话输出
 一.程序交互 name = input(“你的名字是:”) #用户输入,输入的任何东西都存储成str(字符串类型)的形式 二.注释的重要性   以后动辄几千行代码的时候,回过头再去看的时候,发现自己都看不懂了,在工作中还会大家一起合作完成代码,不写注释的话,更难以交流了。 单行注释直接在句首写上#就好了 多行注释可用快捷键ctrl+/,或者用三个引号括起来''' 99999999                          12345789                      
py3study
2020/01/19
1.3K0
Python知识点(史上最全)
type()不会认为子类是一种父类类型。 isinstance()会认为子类是一种父类类型
全栈程序员站长
2022/08/27
8360
Python知识点(史上最全)
python数据类型
代码注释分单行和多行注释, 单行注释用#,多行注释可以用三对双引号"""  """
py3study
2020/01/19
5690
【Python】从基础变量类型到各种容器(列表、字典、元组、集合、字符串)
反向索引:从-1开始,-1代表最后一个,-2代表倒数第二个,以此类推,第一个是-len(s)。
杨丝儿
2022/02/17
2.4K0
【Python】从基础变量类型到各种容器(列表、字典、元组、集合、字符串)
02 . Python之数据类型
变量存储在内存中的值。这就意味着在创建变量时会在内存中开辟一个空间。基于变量的数据类型,解释器会分配指定内存,并决定什么数据可以被存储在内存中。 因此,变量可以指定不同的数据类型,这些变量可以存储整数,小数或字符.
iginkgo18
2020/09/27
1.8K0
02 . Python之数据类型
python 基础 数据类型
1、变      量:变量是计算机内存中的一块儿区域,变量可以存储规定范围内的值,而且值可以改变。
py3study
2020/01/07
6670
基本数据类型、输入输出、运算符
数据类型值是变量值的类型,变量值之所以区分类型,是因为变量值是用来记录事物状态的,而事物的状态有不同的种类,对应着,也必须使用不同类型的值去记录它们。
py3study
2020/01/17
5910
python_列表_元组_字典
insert(index, object) 在指定位置index前插入元素object
以某
2023/03/07
2.4K0
python_列表_元组_字典
​Python数据类型
序列是Python中最基本的数据结构。序列中的每个元素都分配一个数字 - 它的位置,或索引,第一个索引是0,第二个索引是1,依此类推。
PayneWu
2020/12/18
7480
Python基础(三) | Python的组合数据类型
d.get(key,default) 从字典d中获取键key对应的值,如果没有这个键,则返回default
timerring
2022/09/27
2.7K0
Python基础(三) | Python的组合数据类型
python3_03.数据类型
  Python中的字符串用单引号(')或双引号(")括起来,同时使用反斜杠(\)转义特殊字符。
py3study
2020/01/03
5960
python基础--数据类型
在Python3中有六个标准的数据类型:Number(数字)、String(字符串)、List(列表)、Tuple(元组)、Set(集合)、Dictionary(字典),
ypoint
2019/08/15
1.6K1
Python入门(三):数据结构
切换list[begin:end],获取切换list内元素,从begin开始,至end结束,不包含end
披头
2019/12/26
1.1K0
第一章 python入门
                                        2.获取用户名跟密码,如果用户名是:root  密码是:root 提示正确登录,否则登录失败
py3study
2020/01/17
6270
相关推荐
Python 运算符与数据类型
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验