简要介绍
本文主要介绍,用tensorflow训练的mnist数据库,如何保存模型和使用保存的模型,以及如何对一张测试图片进行识别
模型保存
模型训练与保存代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
| from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) import tensorflow as tf import os x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x, W) + b) y_ = tf.placeholder(tf.float32, [None, 10]) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) init = tf.global_variables_initializer() sess = tf.Session() sess.run(init) saver = tf.train.Saver() for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_:batch_ys}) print("训练完成!") model_dir = "mnist" model_name = "ckp" if not os.path.exists(model_dir): os.mkdir(model_dir) saver.save(sess, os.path.join(model_dir, model_name)) print("保存模型成功!")
|
模型使用
模型恢复与使用代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
| from tensorflow.examples.tutorials.mnist import input_data mnist=input_data.read_data_sets("MNIST_data",one_hot=True) import tensorflow as tf sess = tf.Session() x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x, W) + b) saver = tf.train.Saver([W, b]) saver.restore(sess, "mnist/ckp") print("恢复模型成功!") idx=0 img = mnist.test.images[idx] ret = sess.run(y, feed_dict = {x : img.reshape(1, 784)}) print("计算模型结果成功!") print("预测结果:%d"%(ret.argmax())) print("实际结果:%d"%(mnist.test.labels[idx].argmax()))
|
附件:
训练脚本:mnist_train.py
测试脚本:mnist_test.py