Recognizing Handwritten Digits with TensorFlow

Author: 用户3577892
Published: 2020-06-12 09:11:41
Included in the column: 数据科学CLUB

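Data Preparation

# Requires TensorFlow 1.x: the tensorflow.examples.tutorials module is not shipped with TensorFlow 2.x.
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

# Download MNIST into MNIST_data/ (if needed) and load it with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)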

Output:

WARNING:tensorflow:From <ipython-input-1-6bfbaa60ed82>:3: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
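
With one_hot=True, each label is stored as a 10-element vector with a single 1 at the index of the digit. The following is a minimal pure-Python sketch of that encoding. The helper names one_hot and from_one_hot are made up for illustration and are not part of TensorFlow.

```python
def one_hot(digit, num_classes=10):
    """Return a list of num_classes floats with 1.0 at index `digit` (illustrative helper)."""
    if not 0 <= digit < num_classes:
        raise ValueError("digit out of range")
    return [1.0 if i == digit else 0.0 for i in range(num_classes)]


def from_one_hot(vector):
    """Recover the class index, i.e. the position of the largest entry."""
    return max(range(len(vector)), key=lambda i: vector[i])


encoded = one_hot(7)
print(encoded)                # a 1.0 in position 7, zeros elsewhere
print(from_one_hot(encoded))  # back to the digit
```

This round trip is the same idea the later code relies on when it calls argmax on the labels.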

Shared Helper Functions

Define the weight function

def weight(shape):
    # Initialize the weights from a truncated normal distribution with stddev 0.1.
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='W')

Define the bias function

def bias(shape):
    # Initialize every bias to the constant 0.1.
    return tf.Variable(tf.constant(0.1, shape=shape), name='b')

Define the conv2d function

def conv2d(x, W):
    return tf.nn.conv2d(x, W,
                        strides=[1, 1, 1, 1],  # the filter moves one step at a time, left to right and top to bottom
                        padding='SAME')

Define the pooling function

def max_pool_2x2(x):
    return tf.nn.max_pool(x,
                          ksize=[1, 2, 2, 1],    # sampling window size: height=2, width=2
                          strides=[1, 2, 2, 1],  # stride: two steps at a time, left to right and top to bottom
                          padding='SAME')
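
Both helpers use padding='SAME', under which the output size along a dimension is ceil(input / stride) and zeros are added as needed to make that possible. The sketch below works through that arithmetic for the two helpers. The function name same_padding is hypothetical and only mirrors the rule, it is not a TensorFlow call.

```python
import math


def same_padding(size, kernel, stride):
    """Return (output_size, pad_before, pad_after) along one dimension for 'SAME' padding."""
    out = math.ceil(size / stride)
    pad_total = max((out - 1) * stride + kernel - size, 0)
    pad_before = pad_total // 2
    return out, pad_before, pad_total - pad_before


# 5x5 convolution with stride 1 on a 28-pixel side: the size is preserved.
print(same_padding(28, kernel=5, stride=1))
# 2x2 max pooling with stride 2 on a 28-pixel side: the size is halved.
print(same_padding(28, kernel=2, stride=2))
# A second pooling step on a 14-pixel side.
print(same_padding(14, kernel=2, stride=2))
```

So a convolution in this model keeps 28*28 images at 28*28, and each pooling step halves the side length.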

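Building the Model

Input Layer

Notes on the shape passed to x_image

  • The first dimension is -1: the number of examples fed in through the placeholder is not fixed, so it is set to -1.
  • The second and third dimensions are 28 and 28: each input digit image is 28*28 pixels.
  • The fourth dimension is 1 because the images are single-channel (grayscale); for color images it would be 3.

with tf.name_scope('Input_Layer'):  # name of this part of the computation graph
    x = tf.placeholder("float", shape=[None, 784], name="x")
    x_image = tf.reshape(x, [-1, 28, 28, 1])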
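
To make the role of -1 concrete, here is a small pure-Python sketch of how a reshape can infer the one unknown dimension from the total number of elements. The function infer_shape is an illustrative stand-in, not the TensorFlow implementation, and the batch of 100 images is just an example value.

```python
def infer_shape(total_elements, shape):
    """Replace a single -1 in `shape` with the size implied by total_elements."""
    known = 1
    unknown_index = None
    for i, dim in enumerate(shape):
        if dim == -1:
            if unknown_index is not None:
                raise ValueError("only one dimension may be -1")
            unknown_index = i
        else:
            known *= dim
    resolved = list(shape)
    if unknown_index is None:
        if known != total_elements:
            raise ValueError("shape does not match the number of elements")
        return resolved
    if total_elements % known != 0:
        raise ValueError("cannot infer the missing dimension")
    resolved[unknown_index] = total_elements // known
    return resolved


# A batch of 100 flattened images holds 100 * 784 values.
print(infer_shape(100 * 784, [-1, 28, 28, 1]))
```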
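Convolutional Layer 1

Notes on the shape of W1

  • The first and second dimensions are both 5: the filter size is 5*5.
  • The third dimension is 1: the number of input channels, 1 for grayscale and 3 for color.
  • The fourth dimension is 16: the layer produces 16 output images (feature maps).

with tf.name_scope('C1_Conv'):
    W1 = weight([5, 5, 1, 16])
    b1 = bias([16])
    Conv1 = conv2d(x_image, W1) + b1
    C1_Conv = tf.nn.relu(Conv1)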

Output:

WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
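
Each of the 16 filters slides over the image and produces one output image of the same size. As a rough illustration of what a single filter does with stride 1 and zero padding, here is a pure-Python sketch. The function conv2d_same is hypothetical, handles one channel and one odd-sized filter only, and is far simpler than tf.nn.conv2d.

```python
def conv2d_same(image, kernel):
    """Slide one odd-sized kernel over a 2-D image with stride 1 and zero padding."""
    height, width = len(image), len(image[0])
    k_height, k_width = len(kernel), len(kernel[0])
    pad_top, pad_left = (k_height - 1) // 2, (k_width - 1) // 2
    output = []
    for r in range(height):
        row = []
        for c in range(width):
            acc = 0.0
            for i in range(k_height):
                for j in range(k_width):
                    rr, cc = r + i - pad_top, c + j - pad_left
                    # Positions outside the image count as zeros.
                    if 0 <= rr < height and 0 <= cc < width:
                        acc += image[rr][cc] * kernel[i][j]
            row.append(acc)
        output.append(row)
    return output


ones_image = [[1.0] * 3 for _ in range(3)]
ones_kernel = [[1.0] * 3 for _ in range(3)]
for row in conv2d_same(ones_image, ones_kernel):
    print(row)
```

The output has the same 3*3 size as the input; border values are smaller because part of the window falls on the zero padding.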

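Benefits of adding a pooling layer

  • It reduces the number of data points that need to be processed.
  • It makes the model less sensitive to small differences in image position.
  • It lowers the number of parameters and the amount of computation.

with tf.name_scope('C1_Pool'):
    C1_Pool = max_pool_2x2(C1_Conv)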
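
The effect of 2x2 max pooling with stride 2 is easy to see on a small example. The sketch below is a pure-Python illustration on a nested list. The name pool_2x2 is made up here, and the function assumes the input height and width are even.

```python
def pool_2x2(image):
    """Keep the maximum of each non-overlapping 2x2 window (even height and width assumed)."""
    pooled = []
    for r in range(0, len(image), 2):
        row = []
        for c in range(0, len(image[0]), 2):
            window = [image[r][c], image[r][c + 1],
                      image[r + 1][c], image[r + 1][c + 1]]
            row.append(max(window))
        pooled.append(row)
    return pooled


image = [
    [1, 3, 2, 0],
    [4, 2, 1, 1],
    [0, 0, 5, 6],
    [1, 2, 7, 8],
]
print(pool_2x2(image))  # [[4, 2], [2, 8]]
```

Sixteen values become four, which is the reduction in data points described in the list above.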
Convolutional Layer 2

with tf.name_scope('C2_Conv'):
    W2 = weight([5, 5, 16, 36])  # turn the 16 input images into 36 output images
    b2 = bias([36])
    Conv2 = conv2d(C1_Pool, W2) + b2
    C2_Conv = tf.nn.relu(Conv2)

with tf.name_scope('C2_Pool'):
    C2_Pool = max_pool_2x2(C2_Conv)
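Fully Connected Layer

Notes on the arguments used for D_Flat

  • C2_Pool: the tensor to be reshaped.
  • First element of the list, -1: the number of training examples passed in is not fixed.
  • Second element of the list, 1764: the incoming tensor holds 36 images of 7*7 each (36*7*7 = 1764).

with tf.name_scope('D_Flat'):
    D_Flat = tf.reshape(C2_Pool, [-1, 1764])

with tf.name_scope('D_Hidden_Layer'):
    W3 = weight([1764, 128])  # the hidden layer has 128 neurons
    b3 = bias([128])
    D_Hidden = tf.nn.relu(tf.matmul(D_Flat, W3) + b3)
    # keep_prob is the fraction of neurons to keep. It is hard-coded here,
    # so dropout also stays active when the model is evaluated or used for prediction.
    D_Hidden_Dropout = tf.nn.dropout(D_Hidden, keep_prob=0.8)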

Output:

WARNING:tensorflow:From <ipython-input-12-b635345e166c>:7: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.

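
The warning relates keep_prob to rate. As a rough pure-Python sketch of what dropout does with keep_prob=0.8: each activation is kept with probability 0.8 and scaled by 1/0.8, otherwise it is set to zero, so the expected value stays the same. The function below and the seeded random generator are illustrative assumptions, not the TensorFlow implementation.

```python
import random


def dropout(values, keep_prob, rng):
    """Zero each value with probability 1 - keep_prob and scale the survivors by 1 / keep_prob."""
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]


keep_prob = 0.8
rate = 1 - keep_prob          # the value the newer `rate` argument would take
rng = random.Random(0)        # fixed seed so the sketch is repeatable

activations = [1.0] * 10000
dropped = dropout(activations, keep_prob, rng)

kept_fraction = sum(1 for v in dropped if v != 0.0) / len(dropped)
mean_after = sum(dropped) / len(dropped)
print("rate:", round(rate, 2))
print("fraction kept:", kept_fraction)
print("mean after dropout:", mean_after)
```

The fraction kept should land close to 0.8 and the mean close to the original 1.0.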
Output Layer

with tf.name_scope('Output_Layer'):
    W4 = weight([128, 10])
    b4 = bias([10])
    y_predict = tf.nn.softmax(tf.matmul(D_Hidden_Dropout, W4) + b4)
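
With all four weight and bias shapes defined, the number of trainable parameters can be tallied by hand. The sketch below does that arithmetic in plain Python using the shapes passed to weight() and bias() above. The helper count_parameters and the dictionary layout are only for illustration.

```python
def count_parameters(weight_shape, bias_shape):
    """Number of scalar values in one weight tensor plus its bias vector."""
    weights = 1
    for dim in weight_shape:
        weights *= dim
    biases = 1
    for dim in bias_shape:
        biases *= dim
    return weights + biases


# Shapes copied from the layer definitions in this article.
layers = {
    "C1_Conv": ([5, 5, 1, 16], [16]),
    "C2_Conv": ([5, 5, 16, 36], [36]),
    "D_Hidden_Layer": ([1764, 128], [128]),
    "Output_Layer": ([128, 10], [10]),
}

counts = {name: count_parameters(w, b) for name, (w, b) in layers.items()}
total = sum(counts.values())
for name, n in counts.items():
    print(name, n)
print("total", total)
```

Most of the parameters sit in the fully connected hidden layer, not in the convolutional layers.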

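Defining the Optimization Step

with tf.name_scope("optimizer"):
    y_label = tf.placeholder("float", shape=[None, 10], name="y_label")
    # Note: y_predict is already a softmax output, and this op applies softmax
    # to its `logits` argument again. The op normally expects the pre-softmax
    # values, i.e. tf.matmul(D_Hidden_Dropout, W4) + b4.
    loss_function = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_predict,
                                                   labels=y_label))
    optimizer = tf.train.AdamOptimizer(learning_rate=0.0001) \
                    .minimize(loss_function)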
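
Because the loss op receives values that have already been through softmax, every "logit" lies between 0 and 1, and the loss cannot approach zero. The pure-Python sketch below computes the smallest loss reachable in that situation for 10 classes. The functions softmax and cross_entropy are simplified stand-ins written for this illustration. The result may explain why the training log in this article levels off near 1.48 instead of heading toward 0.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of numbers."""
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probabilities, one_hot_label):
    """Cross entropy between a probability vector and a one-hot label."""
    return -sum(t * math.log(p)
                for p, t in zip(probabilities, one_hot_label) if t > 0)


label = [0.0] * 7 + [1.0] + [0.0] * 2

# Best case for the first softmax: all probability on the correct class.
perfect_prediction = [0.0] * 7 + [1.0] + [0.0] * 2

# Applying softmax a second time, as happens when probabilities are passed as logits.
loss_double = cross_entropy(softmax(perfect_prediction), label)
print(loss_double)                # about 1.4612
print(math.log(1 + 9 / math.e))   # closed form of the same bound
```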

Defining the Evaluation Metric

with tf.name_scope("evaluate_model"):
    # A prediction is correct when the most probable class matches the label.
    correct_prediction = tf.equal(tf.argmax(y_predict, 1),
                                  tf.argmax(y_label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
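
The same accuracy computation can be written out in plain Python for a handful of examples: take the argmax of each prediction and each one-hot label, compare them, and average the result. The helper names and the toy three-class data are invented for this sketch.

```python
def argmax(values):
    """Index of the largest entry."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(predictions, labels):
    """Fraction of rows where the predicted class equals the labeled class."""
    correct = [1.0 if argmax(p) == argmax(t) else 0.0
               for p, t in zip(predictions, labels)]
    return sum(correct) / len(correct)


# Toy data with three classes (made up for illustration).
predictions = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.3, 0.3, 0.4], [0.2, 0.5, 0.3]]
labels = [[0, 1, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0]]
print(accuracy(predictions, labels))  # 0.75
```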

Training the Model

from time import time

trainEpochs = 10
batchSize = 100
totalBatchs = int(mnist.train.num_examples / batchSize)
epoch_list = []; accuracy_list = []; loss_list = []
startTime = time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

for epoch in range(trainEpochs):
    # Run one optimization step per mini-batch.
    for i in range(totalBatchs):
        batch_x, batch_y = mnist.train.next_batch(batchSize)
        sess.run(optimizer, feed_dict={x: batch_x,
                                       y_label: batch_y})
    # After each epoch, measure loss and accuracy on the validation set.
    loss, acc = sess.run([loss_function, accuracy],
                         feed_dict={x: mnist.validation.images,
                                    y_label: mnist.validation.labels})
    epoch_list.append(epoch)
    loss_list.append(loss); accuracy_list.append(acc)
    print("Train Epoch:", '%02d' % (epoch + 1),
          "Loss=", "{:.9f}".format(loss), " Accuracy=", acc)

duration = time() - startTime
print("Train Finished takes:", duration)

Output:

Train Epoch: 01 Loss= 1.656932473  Accuracy= 0.827
Train Epoch: 02 Loss= 1.613922596  Accuracy= 0.8558
Train Epoch: 03 Loss= 1.598174453  Accuracy= 0.8692
Train Epoch: 04 Loss= 1.510785699  Accuracy= 0.9574
Train Epoch: 05 Loss= 1.500687838  Accuracy= 0.9658
Train Epoch: 06 Loss= 1.495839953  Accuracy= 0.9684
Train Epoch: 07 Loss= 1.491830468  Accuracy= 0.9726
Train Epoch: 08 Loss= 1.489337087  Accuracy= 0.9742
Train Epoch: 09 Loss= 1.486868739  Accuracy= 0.9774
Train Epoch: 10 Loss= 1.484916449  Accuracy= 0.9792
Train Finished takes: 720.7906568050385

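
The loop structure above (a fixed number of full batches per epoch) can be sketched without TensorFlow. The example below assumes a training set of 55,000 examples, which is an assumption about what mnist.train.num_examples returns with the default settings; the generator iterate_batches is a hypothetical helper and, unlike next_batch, does no shuffling.

```python
def iterate_batches(num_examples, batch_size):
    """Yield (start, end) index pairs for each full batch in one epoch."""
    total_batches = int(num_examples / batch_size)
    for i in range(total_batches):
        yield i * batch_size, (i + 1) * batch_size


num_examples = 55000   # assumed training-set size
batch_size = 100
train_epochs = 10

batches = list(iterate_batches(num_examples, batch_size))
print("batches per epoch:", len(batches))
print("optimizer steps in total:", len(batches) * train_epochs)
print("first batch:", batches[0], "last batch:", batches[-1])
```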
%matplotlib inline
import matplotlib.pyplot as plt

# Plot the validation loss per epoch.
fig = plt.gcf()
fig.set_size_inches(4, 2)
plt.plot(epoch_list, loss_list, label='loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'], loc='upper left')

Output:

<matplotlib.legend.Legend at 0x7fd9d2ab8b38>

# Plot the validation accuracy per epoch.
plt.plot(epoch_list, accuracy_list, label="accuracy")
fig = plt.gcf()
fig.set_size_inches(4, 2)
plt.ylim(0.8, 1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()

Evaluating Model Accuracy

len(mnist.test.images)

Output:

10000

# Accuracy on the full test set.
print("Accuracy:",
      sess.run(accuracy, feed_dict={x: mnist.test.images,
                                    y_label: mnist.test.labels}))

Output:

Accuracy: 0.9792

# Accuracy on the first 5000 test images.
print("Accuracy:",
      sess.run(accuracy, feed_dict={x: mnist.test.images[:5000],
                                    y_label: mnist.test.labels[:5000]}))

Output:

Accuracy: 0.968

# Accuracy on the last 5000 test images.
print("Accuracy:",
      sess.run(accuracy, feed_dict={x: mnist.test.images[5000:],
                                    y_label: mnist.test.labels[5000:]}))

Output:

Accuracy: 0.9886
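
As a quick consistency check, the accuracies of the two halves can be combined into an overall figure by weighting each with its number of examples. The sketch below does that with the numbers printed above; combined_accuracy is an illustrative helper. The combined value differs slightly from the 0.9792 reported for the full test set. One possible reason is that the dropout layer uses a fixed keep_prob=0.8 and therefore stays active in every sess.run call, so repeated evaluations are not deterministic.

```python
def combined_accuracy(parts):
    """Weighted average of (accuracy, num_examples) pairs."""
    total = sum(n for _, n in parts)
    return sum(a * n for a, n in parts) / total


# Values taken from the two half-set evaluations above.
combined = combined_accuracy([(0.968, 5000), (0.9886, 5000)])
print(round(combined, 4))  # 0.9783
```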

Predicted Probabilities

# Store the result under a new name so that the y_predict tensor
# is not overwritten by a NumPy array.
predict_proba = sess.run(y_predict,
                         feed_dict={x: mnist.test.images[:5000]})

predict_proba[:5]

Output:

array([[4.05578522e-12, 6.15486123e-14, 5.71559293e-12, 1.74847949e-11,
        2.71332728e-17, 8.90746643e-11, 5.53451119e-21, 1.00000000e+00,
        1.10875556e-13, 8.30471913e-10],
       [9.93732328e-07, 4.50552989e-06, 9.99993682e-01, 8.04418278e-07,
        2.64564185e-14, 1.46194583e-14, 1.56614929e-10, 6.01911912e-14,
        3.10939221e-08, 2.34203085e-15],
       [1.31195605e-08, 9.99897718e-01, 4.33765905e-07, 2.02467453e-11,
        9.89620676e-05, 6.53400056e-10, 6.65772149e-08, 2.67320161e-06,
        5.65030227e-08, 4.94121100e-09],
       [9.99993682e-01, 6.49839280e-11, 3.86714616e-09, 2.97008674e-13,
        8.59991689e-10, 1.07891083e-11, 6.35852575e-06, 1.65313943e-10,
        2.73128520e-10, 5.31917976e-08],
       [1.19434844e-06, 9.74953984e-09, 2.05678519e-09, 2.03244167e-14,
        9.99985814e-01, 1.02013356e-10, 7.86321621e-08, 3.65643515e-08,
        6.86227242e-10, 1.28641732e-05]], dtype=float32)

Prediction Results

prediction_result = sess.run(tf.argmax(y_predict, 1),
                             feed_dict={x: mnist.test.images,
                                        y_label: mnist.test.labels})

prediction_result[:10]

Output:

array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])

import numpy as np

def show_images_labels_predict(images, labels, prediction_result):
    # Show the first 10 images with their true label and predicted digit.
    fig = plt.gcf()
    fig.set_size_inches(8, 10)
    for i in range(0, 10):
        ax = plt.subplot(5, 5, 1 + i)
        ax.imshow(np.reshape(images[i], (28, 28)),
                  cmap='binary')
        ax.set_title("label=" + str(np.argmax(labels[i])) +
                     ",predict=" + str(prediction_result[i]),
                     fontsize=9)
    plt.show()

show_images_labels_predict(mnist.test.images, mnist.test.labels, prediction_result)

Finding Misclassified Examples

# Print every mismatch between label and prediction among the first 500 test images.
for i in range(500):
    if prediction_result[i] != np.argmax(mnist.test.labels[i]):
        print("i=" + str(i) +
              "   label=", np.argmax(mnist.test.labels[i]),
              "predict=", prediction_result[i])

Output:

i=247   label= 4 predict= 2
i=259   label= 6 predict= 0
i=290   label= 8 predict= 4
i=320   label= 9 predict= 1
i=321   label= 2 predict= 7
i=340   label= 5 predict= 3
i=445   label= 6 predict= 0
i=495   label= 8 predict= 0

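
The list of mismatches can be summarized further, for example as an error rate and a tally of which label was confused with which prediction. The sketch below hard-codes the eight (index, label, prediction) triples printed above and uses only the standard library; the variable names are chosen for illustration.

```python
from collections import Counter

# (index, label, prediction) triples copied from the output above.
errors = [
    (247, 4, 2), (259, 6, 0), (290, 8, 4), (320, 9, 1),
    (321, 2, 7), (340, 5, 3), (445, 6, 0), (495, 8, 0),
]

error_rate = len(errors) / 500
by_pair = Counter((label, predicted) for _, label, predicted in errors)
by_label = Counter(label for _, label, _ in errors)

print("error rate in the first 500 images:", error_rate)
print("most frequent confusion:", by_pair.most_common(1))
print("errors per true label:", dict(by_label))
```

In this small sample, a 6 predicted as 0 is the only confusion that occurs more than once.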
def show_images_labels_predict_error(images, labels, prediction_result):
    # Show the first 10 misclassified images (j = index, l = label, p = prediction).
    fig = plt.gcf()
    fig.set_size_inches(8, 10)
    i = 0; j = 0
    while i < 10:
        if prediction_result[j] != np.argmax(labels[j]):
            ax = plt.subplot(5, 5, 1 + i)
            ax.imshow(np.reshape(images[j], (28, 28)),
                      cmap='binary')
            ax.set_title("j=" + str(j) +
                         ",l=" + str(np.argmax(labels[j])) +
                         ",p=" + str(prediction_result[j]),
                         fontsize=9)
            i = i + 1
        j = j + 1
    plt.show()

show_images_labels_predict_error(mnist.test.images, mnist.test.labels, prediction_result)

Saving the Model

saver = tf.train.Saver()

save_path = saver.save(sess, "saveModel/CNN_model1")

print("Model saved in file: %s" % save_path)

Output:

Model saved in file: saveModel/CNN_model1

Launching TensorBoard

  • Run tensorboard --logdir=c:\pyhonwork\log\CNN
  • Open http://localhost:6006/ in a browser

# Write the computation graph to log/CNN so that TensorBoard can display it.
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter('log/CNN', sess.graph)

sess.close()
Originally published on 2019-05-29; shared from the 数据科学CLUB WeChat official account.
