1. 下载 TensorFlow 源代码
1
| git clone https://github.com/tensorflow/tensorflow.git
|
下文假设 TensorFlow 源代码的根目录为 TENSORFLOW_ROOT。
2.训练模型
训练采用的是最简单的模型（单层 softmax 回归），训练脚本如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
"""Train a softmax-regression MNIST classifier and export it as a frozen GraphDef.

The exported file ./model/beginner-export/beginner-graph.pb contains:
  * a float32 placeholder named "input" with shape [None, 784], and
  * a softmax node named "output" computing input @ W + b over 10 classes,
with the trained W and b baked in as constants ("constant_W", "constant_b"),
so the graph can be loaded without a checkpoint.
"""
import tensorflow as tf
import os.path
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("data/", one_hot=True)

# --- Training graph: trainable variables, loss and optimizer. ---
g = tf.Graph()
with g.as_default():
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]), name="variable_W")
    b = tf.Variable(tf.zeros([10]), name="variable_b")
    logits = tf.matmul(x, W) + b
    y = tf.nn.softmax(logits)
    y_ = tf.placeholder(tf.float32, [None, 10])

    # Computing cross-entropy from the logits avoids log(0) = -inf (and the
    # resulting NaN gradients) that -sum(y_ * log(softmax)) hits once a
    # prediction saturates. The sum over the batch keeps the same loss scale
    # as the naive formula, so the 0.01 learning rate is unchanged.
    cross_entropy = tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # The context manager closes the session even if training raises.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(1000):
            batch_xs, batch_ys = mnist.train.next_batch(100)
            train_step.run({x: batch_xs, y_: batch_ys}, sess)
        print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}, sess))
        # Pull the trained parameters out as arrays for the export graph.
        _W = W.eval(sess)
        _b = b.eval(sess)

# --- Export graph: the same model with the weights frozen as constants. ---
g_2 = tf.Graph()
with g_2.as_default():
    # "input" / "output" must match INPUT_NAME / OUTPUT_NAME on the Android side.
    x_2 = tf.placeholder(tf.float32, [None, 784], name="input")
    W_2 = tf.constant(_W, name="constant_W")
    b_2 = tf.constant(_b, name="constant_b")
    y_2 = tf.nn.softmax(tf.matmul(x_2, W_2) + b_2, name="output")

    # No session is needed here: the graph holds no variables, and
    # write_graph only serializes the GraphDef.
    graph_def = g_2.as_graph_def()
    tf.train.write_graph(graph_def, './model/beginner-export', 'beginner-graph.pb', as_text=False)
|
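训练完成之后，会在 ./model/beginner-export/ 目录下得到模型文件 beginner-graph.pb。导出时训练好的权重已作为常量写入图中，输入节点名为 input，输出节点名为 output，后面修改 Android 代码时会用到这两个名字。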
3.替换模型
Android 示例位于 TENSORFLOW_ROOT/tensorflow/examples/android 目录下。
将上一步生成的模型文件 beginner-graph.pb 拷贝到该示例的 assets 目录下。
然后在 assets 目录下新建一个标签文件 mnist_labels.txt。
文件内容如下（每行一个标签，共 10 行；第 i 行对应模型输出的第 i 个类别，即数字 i）:
0
1
2
3
4
5
6
7
8
9
4.修改源码
首先修改 org.tensorflow.demo 包下面的 ClassifierActivity.java 文件
修改前:
1 2 3 4 5 6 7 8 9
| private static final int NUM_CLASSES = 1001; private static final int INPUT_SIZE = 224; private static final int IMAGE_MEAN = 117; private static final float IMAGE_STD = 1; private static final String INPUT_NAME = "input:0"; private static final String OUTPUT_NAME = "output:0"; private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb"; private static final String LABEL_FILE = "file:///android_asset/imagenet_comp_graph_label_strings.txt";
|
修改后:
1 2 3 4 5 6 7 8 9
| private static final int NUM_CLASSES = 10; private static final int INPUT_SIZE = 28; private static final int IMAGE_MEAN = 117; private static final float IMAGE_STD = 1; private static final String INPUT_NAME = "input"; private static final String OUTPUT_NAME = "output"; private static final String MODEL_FILE = "file:///android_asset/beginner-graph.pb"; private static final String LABEL_FILE = "file:///android_asset/mnist_labels.txt";
|
接着修改 TensorFlowImageClassifier.java 文件中的以下两个函数:
函数1:
1
| public int initializeTensorFlow(){}
|
修改前:
1 2 3 4 5
| outputNames = new String[] {outputName}; intValues = new int[inputSize * inputSize]; floatValues = new float[inputSize * inputSize * 3]; outputs = new float[numClasses];
|
修改后（模型输入改为单通道灰度图，因此 floatValues 的长度由 inputSize * inputSize * 3 改为 inputSize * inputSize）:
1 2 3 4 5
| outputNames = new String[] {outputName}; intValues = new int[inputSize * inputSize]; floatValues = new float[inputSize * inputSize]; outputs = new float[numClasses];
|
函数2:
1
| public List<Recognition> recognizeImage(final Bitmap bitmap) {}
|
修改前:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| Trace.beginSection("preprocessBitmap"); bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); for (int i = 0; i < intValues.length; ++i) { final int val = intValues[i]; floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd; floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd; floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd; } Trace.endSection(); Trace.beginSection("fillNodeFloat"); inferenceInterface.fillNodeFloat(inputName, new int[] {1, inputSize, inputSize, 3}, floatValues); Trace.endSection();
|
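修改后（先把每个像素按 0.3R + 0.59G + 0.11B 转成灰度，再用 1 减去归一化后的灰度值做反色。这样白底黑字的图像会变成与 MNIST 训练数据一致的黑底白字、取值 0~1 的形式；不大于 0.2 的值置为 0，以抑制浅色背景带来的杂值。同时输入张量的形状改为 {1, inputSize*inputSize}，与模型中形状为 [None, 784] 的 input 占位符对应）: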
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| Trace.beginSection("preprocessBitmap"); bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); for (int i = 0; i < intValues.length; ++i) { final int val = intValues[i]; int R = (val >> 16) & 0xFF; int G = (val >> 8) & 0xFF; int B = val & 0xFF; float Y = (float)(1-(0.3*R + 0.59*G + 0.11*B)/255); floatValues[i] = Y>0.2?Y:0; } Trace.endSection(); Trace.beginSection("fillNodeFloat"); inferenceInterface.fillNodeFloat(inputName, new int[] {1, inputSize*inputSize}, floatValues); Trace.endSection();
|
5.编译运行
编译apk
1
| bazel build //tensorflow/examples/android:tensorflow_demo
|
安装apk
1
| adb install -r bazel-bin/tensorflow/examples/android/tensorflow_demo.apk
|
安装完成后，在手机上运行 TF Classify 程序。