fashion_mnist 和 mnist 一样,都是深度学习入门用的简单数据集,两者的图片尺寸一样,都是28x28。fashion_mnist的训练集有6万张图片,测试集有1万张图片,全是衣服、鞋、包包之类的图片,共10个类别:
Label Class:
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot
下图显示的是训练集中的前25张图片:
下面的代码用于训练CNN:
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 21 18:12:16 2019
@author: Administrator
"""
import tensorflow as tf
print(tf.__version__)
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels),(test_images,test_labels) = fashion_mnist.load_data()
train_images, test_images = train_images/255.0, test_images/255.0
'''
Label
Class
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot
'''
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
print(train_images.shape)
print(len(train_labels))
'''
plt.figure(figsize=(10,10))
for i in range(50):
plt.subplot(5,10,i+1)
plt.xticks([])
plt.yticks([])
plt.imshow(train_images[i], cmap=plt.cm.binary)
plt.xlabel(class_names[train_labels[i]])
plt.show()
'''
#加入新的维度,Conv2D需要颜色chanels维度 #彩色图片数据集就不需要
train_images = train_images[..., tf.newaxis]
test_images = test_images[..., tf.newaxis]
model = keras.Sequential()
# an `input_shape` passed to the first layer
model.add(keras.layers.Conv2D(input_shape=(28,28,1),
filters=32, kernel_size=(3,3),activation='relu'))
model.add(keras.layers.Conv2D(filters=64, kernel_size=(3,3),activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=(3,3),strides=1,padding='same'))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(250, activation='relu'))
model.add(keras.layers.Dropout(0.3))
model.add(keras.layers.Dense(10, activation='softmax'))
model.compile(optimizer ='adam', loss ='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_images, train_labels, epochs =10, batch_size =200)
tess_loss, test_acc = model.evaluate(test_images,test_labels, verbose=2)
print('\nTest accuracy: ', test_acc)
model.save('my fashion_mnist mode.h5')
10个Epoch后,测试集上的准确度已达93.21%:
下面的代码用于预测一组图片(测试集):
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 21 18:12:16 2019
@author: Administrator
"""
import tensorflow as tf
print(tf.__version__)
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
fashion_mnist = keras.datasets.fashion_mnist
test_images,test_labels = fashion_mnist.load_data()[1]
test_images = test_images/255.0
test_images = test_images[..., tf.newaxis]
'''
Label
Class
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot
'''
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
model = keras.models.load_model('my fashion_mnist mode.h5')
#预测一组图片,
predictions = model.predict(test_images)
plt.figure(figsize=(10,10))
for i in range(25):
index = tf.argmax(predictions[i])
index = int(index)
print(index)
name = class_names[index]
print(name)
plt.subplot(5,5,i+1)
plt.xticks([])
plt.yticks([])
#plt.imshow(test_images[i], cmap=plt.cm.binary)
plt.imshow(np.reshape(test_images[i],(28,28)),cmap=plt.cm.binary)
plt.xlabel("True : %s"%class_names[test_labels[i]])
plt.title("prediction: %s"%name)
plt.tight_layout()
plt.show()
可以看出,测试集的前25张图片全部都能正确识别:
本文分享自 Python可视化编程机器学习OpenCV 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!
扫码关注腾讯云开发者
领取腾讯云代金券
Copyright © 2013 - 2025 Tencent Cloud. All Rights Reserved. 腾讯云 版权所有
深圳市腾讯计算机系统有限公司 ICP备案/许可证号:粤B2-20090059 深公网安备号 44030502008569
腾讯云计算(北京)有限责任公司 京ICP证150476号 | 京ICP备11018762号 | 京公网安备号11010802020287
Copyright © 2013 - 2025 Tencent Cloud.
All Rights Reserved. 腾讯云 版权所有