前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用tensorflow构建一个卷积神经网络

使用tensorflow构建一个卷积神经网络

作者头像
生信修炼手册
发布2021-07-06 16:38:00
7640
发布2021-07-06 16:38:00
举报
文章被收录于专栏:生信修炼手册

欢迎关注”生信修炼手册”!

本文是对tensforflow官方入门教程的学习和翻译,展示了创建一个基础的卷积神经网络模型来解决图像分类问题的过程。具体步骤如下

1. 加载数据集

tensorflow集成了keras这个框架,提供了CIFAR10数据集,该数据集包含了10个类别共6万张彩色图片,加载方式如下

代码语言:javascript
复制
>>> import tensorflow as tf
>>> from tensorflow.keras import datasets,layers, models
>>> import matplotlib.pyplot as plt
>>> (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 33s 0us/step
>>> train_images, test_images = train_images / 255.0, test_images / 255.0

可以通过如下代码来查看部分图片

代码语言:javascript
复制
>>> for i in range(25):
...     plt.subplot(5, 5, i + 1)
...     plt.xticks([])
...     plt.yticks([])
...     plt.grid(False)
...     plt.imshow(train_images[i], cmap = plt.cm.binary)
...     plt.xlabel(class_names[train_labels[i][0]])
...
>>> plt.show()

可视化效果如下

2. 构建卷积神经网络

通过keras的Sequential API来构建卷积神经网络,依次添加卷积层,池化层,全连接层,代码如下

代码语言:javascript
复制
>>> model = models.Sequential()
>>> model.add(layers.Conv2D(32, (3, 3), activation = "relu", input_shape = (32, 32, 3)))
>>> model.add(layers.MaxPooling2D((2, 2)))
>>> model.add(layers.Conv2D(64, (3,3), activation = "relu"))
>>> model.add(layers.MaxPooling2D((2, 2)))
>>> model.add(layers.Conv2D(64, (3, 3), activation = "relu"))
>>> model.add(layers.Flatten())
>>> model.add(layers.Dense(64, activation = "relu"))
>>> model.add(layers.Dense(10))
>>> model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 30, 30, 32)        896
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 15, 15, 32)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 13, 13, 64)        18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 6, 6, 64)          0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 4, 4, 64)          36928
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0
_________________________________________________________________
dense_1 (Dense)              (None, 64)                65600
_________________________________________________________________
dense_2 (Dense)              (None, 10)                650
=================================================================
Total params: 122,570
Trainable params: 122,570
Non-trainable params: 0
_________________________________________________________________

3. 编译模型

模型在训练之前,必须对其进行编译,主要是确定损失函数,优化器以及评估分类效果好坏的指标,代码如下

代码语言:javascript
复制
>>> model.compile(optimizer = 'adam', loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics = ['accuracy'])

4. 训练模型

使用训练集训练模型,代码如下

代码语言:javascript
复制
>>> history = model.fit(train_images, train_labels, epochs = 10, validation_data = (test_images, test_labels))
2021-06-23 10:59:43.386592: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/10
1563/1563 [==============================] - 412s 203ms/step - loss: 1.5396 - accuracy: 0.4380 - val_loss: 1.2760 - val_accuracy: 0.5413
Epoch 2/10
1563/1563 [==============================] - 94s 60ms/step - loss: 1.1637 - accuracy: 0.5850 - val_loss: 1.1193 - val_accuracy: 0.6084
Epoch 3/10
1563/1563 [==============================] - 95s 61ms/step - loss: 1.0210 - accuracy: 0.6398 - val_loss: 0.9900 - val_accuracy: 0.6556
Epoch 4/10
1563/1563 [==============================] - 88s 56ms/step - loss: 0.9186 - accuracy: 0.6781 - val_loss: 0.9399 - val_accuracy: 0.6687
Epoch 5/10
1563/1563 [==============================] - 95s 61ms/step - loss: 0.8472 - accuracy: 0.7023 - val_loss: 0.8984 - val_accuracy: 0.6868
Epoch 6/10
1563/1563 [==============================] - 85s 55ms/step - loss: 0.7917 - accuracy: 0.7220 - val_loss: 0.8896 - val_accuracy: 0.6888
Epoch 7/10
1563/1563 [==============================] - 88s 56ms/step - loss: 0.7450 - accuracy: 0.7381 - val_loss: 0.8843 - val_accuracy: 0.6974
Epoch 8/10
1563/1563 [==============================] - 87s 55ms/step - loss: 0.7024 - accuracy: 0.7530 - val_loss: 0.8403 - val_accuracy: 0.7089
Epoch 9/10
1563/1563 [==============================] - 92s 59ms/step - loss: 0.6600 - accuracy: 0.7676 - val_loss: 0.8512 - val_accuracy: 0.7095
Epoch 10/10
1563/1563 [==============================] - 91s 58ms/step - loss: 0.6240 - accuracy: 0.7790 - val_loss: 0.8483 - val_accuracy: 0.7119

通过比较训练集和验证集的准确率曲线,可以判断模型训练是否有过拟合等问题,代码如下

代码语言:javascript
复制
>>> plt.plot(history.history['accuracy'], label='accuracy')
[<matplotlib.lines.Line2D object at 0x000001AAC62A7B08>]
>>> plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
[<matplotlib.lines.Line2D object at 0x000001AAC28F8988>]
>>> plt.xlabel('Epoch')
Text(0.5, 0, 'Epoch')
>>> plt.ylabel('Accuracy')
Text(0, 0.5, 'Accuracy')
>>> plt.ylim([0.5, 1])
(0.5, 1.0)
>>> plt.legend(loc='lower right')
<matplotlib.legend.Legend object at 0x000001AAC62A7688>
>>> plt.show()

结果如下

当模型过拟合时,会看到accuracy非常高,而val_accuracy较低,两条线明显偏离。从上图中看到,两个准确率比较接近,没有明显的分离现象,而且值都比较低,模型存在欠拟合的问题。

5. 评估模型

用测试集评估模型效果,结果如下

代码语言:javascript
复制
>>> test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)
313/313 - 7s - loss: 0.8483 - accuracy: 0.7119

>>> print(test_acc)
0.711899995803833

准确率达到了70%,对于一个由几行代码快速构建的初步卷积神经网络模型而言,这个效果还可以接受。后续可以考虑数据增强,模型改进,调整学习率等方式,来提高模型的准确率。

·end·

—如果喜欢,快分享给你的朋友们吧—

原创不易,欢迎收藏,点赞,转发!生信知识浩瀚如海,在生信学习的道路上,让我们一起并肩作战!

本公众号深耕耘生信领域多年,具有丰富的数据分析经验,致力于提供真正有价值的数据分析服务,擅长个性化分析,欢迎有需要的老师和同学前来咨询。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-06-23,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 生信修炼手册 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档