forked from xuos/xiuos
APP_Framework/Framework/:update knowing framework
1. Fix some Kconfig files. 2. Add tensorflow-lite-for-mcu under the knowing framework. 3. Add the mnist application; note that it cannot be used on MCUs with less than about 500 KB of RAM. 4. The application and OS (RT-Thread) versions still need to be separated; this will be solved later by adding a transform layer.
This commit is contained in:
@@ -1,7 +1,3 @@
|
||||
# Top-level menu for the "knowing" (AI/inference) application domain.
menu "knowing app"

# Master switch; individual apps below still have their own options.
menuconfig APPLICATION_KNOWING
    bool "Using knowing apps"
    default n

source "$APP_DIR/Applications/knowing_app/mnist/Kconfig"

endmenu
|
||||
|
||||
14
APP_Framework/Applications/knowing_app/SConscript
Normal file
14
APP_Framework/Applications/knowing_app/SConscript
Normal file
@@ -0,0 +1,14 @@
|
||||
# SConscript for the knowing_app directory: aggregate the build objects of
# every sub-application that ships its own SConscript.
import os

Import('RTT_ROOT')
from building import *

cwd = GetCurrentDir()
objs = []

# fix: renamed loop variable from 'list' so it no longer shadows the builtin.
for entry in os.listdir(cwd):
    path = os.path.join(cwd, entry)
    # Only descend into subdirectories that provide their own SConscript.
    if os.path.isfile(os.path.join(path, 'SConscript')):
        objs = objs + SConscript(os.path.join(path, 'SConscript'))

Return('objs')
|
||||
2
APP_Framework/Applications/knowing_app/mnist/.gitignore
vendored
Normal file
2
APP_Framework/Applications/knowing_app/mnist/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*.h5
|
||||
*.tflite
|
||||
BIN
APP_Framework/Applications/knowing_app/mnist/K210 mnist .png
Normal file
BIN
APP_Framework/Applications/knowing_app/mnist/K210 mnist .png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 32 KiB |
4
APP_Framework/Applications/knowing_app/mnist/Kconfig
Normal file
4
APP_Framework/Applications/knowing_app/mnist/Kconfig
Normal file
@@ -0,0 +1,4 @@
|
||||
# MNIST digit-recognition demo. Per the app README it needs roughly 500 KB
# of RAM for the TFLite Micro tensor arena, so it is only practical on
# larger boards such as the K210.
config APP_MNIST
    bool "enable apps/mnist"
    depends on USING_TENSORFLOWLITEMICRO
    default n
|
||||
23
APP_Framework/Applications/knowing_app/mnist/README.md
Normal file
23
APP_Framework/Applications/knowing_app/mnist/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# MNIST 说明
|
||||
|
||||
要使用本例程,MCU RAM必须至少500K左右,所以本例程目前在K210上面验证过,stm32f407 目前在rtt上原则上只能采取dlmodule加载的方式。
|
||||
|
||||

|
||||
|
||||
## 使用
|
||||
|
||||
tools/mnist-train.py 训练生成 mnist 模型。
|
||||
|
||||
tools/mnist-inference.py 使用 mnist 模型进行推理。
|
||||
|
||||
tools/mnist-c-model.py 将 mnist 模型转换成 C 的数组保存在 model.h 中。
|
||||
|
||||
tools/mnist-c-digit.py 将 mnist 数据集中的某个数字转成数组保存在 digit.h 中。
|
||||
|
||||
## 参考资料
|
||||
|
||||
https://tensorflow.google.cn/lite/performance/post_training_quantization
|
||||
|
||||
https://tensorflow.google.cn/lite/performance/post_training_integer_quant
|
||||
|
||||
https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/hello_world/train/train_hello_world_model.ipynb
|
||||
9
APP_Framework/Applications/knowing_app/mnist/SConscript
Normal file
9
APP_Framework/Applications/knowing_app/mnist/SConscript
Normal file
@@ -0,0 +1,9 @@
|
||||
# SConscript for the mnist application: compile every C/C++ source in this
# directory when the APP_MNIST Kconfig option is enabled.
from building import *

CPPPATH = [GetCurrentDir()]
sources = Glob('*.c')
sources += Glob('*.cpp')

group = DefineGroup('Applications', sources, depend = ['APP_MNIST'], LOCAL_CPPPATH = CPPPATH)

Return('group')
|
||||
31
APP_Framework/Applications/knowing_app/mnist/digit.h
Normal file
31
APP_Framework/Applications/knowing_app/mnist/digit.h
Normal file
@@ -0,0 +1,31 @@
|
||||
/* One 28x28 grayscale MNIST test digit, pixel values normalized to [0, 1].
 * Generated by tools/mnist-c-digit.py -- do not edit by hand. */
const float mnist_digit[] = {
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.33, 0.73, 0.62, 0.59, 0.24, 0.14, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.87, 1.00, 1.00, 1.00, 1.00, 0.95, 0.78, 0.78, 0.78, 0.78, 0.78, 0.78, 0.78, 0.78, 0.67, 0.20, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.26, 0.45, 0.28, 0.45, 0.64, 0.89, 1.00, 0.88, 1.00, 1.00, 1.00, 0.98, 0.90, 1.00, 1.00, 0.55, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.07, 0.26, 0.05, 0.26, 0.26, 0.26, 0.23, 0.08, 0.93, 1.00, 0.42, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.33, 0.99, 0.82, 0.07, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.09, 0.91, 1.00, 0.33, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.51, 1.00, 0.93, 0.17, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.23, 0.98, 1.00, 0.24, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.52, 1.00, 0.73, 0.02, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.04, 0.80, 0.97, 0.23, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.49, 1.00, 0.71, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.29, 0.98, 0.94, 0.22, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.07, 0.87, 1.00, 0.65, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.01, 0.80, 1.00, 0.86, 0.14, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.15, 1.00, 1.00, 0.30, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.12, 0.88, 1.00, 0.45, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.52, 1.00, 1.00, 0.20, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.24, 0.95, 1.00, 1.00, 0.20, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.47, 1.00, 1.00, 0.86, 0.16, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.47, 1.00, 0.81, 0.07, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,
  0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00
};
/* Ground-truth label for the digit above. */
const int mnist_label = 7;
|
||||
94
APP_Framework/Applications/knowing_app/mnist/main.cpp
Normal file
94
APP_Framework/Applications/knowing_app/mnist/main.cpp
Normal file
@@ -0,0 +1,94 @@
|
||||
#include <rtthread.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "tensorflow/lite/micro/all_ops_resolver.h"
|
||||
#include "tensorflow/lite/micro/micro_error_reporter.h"
|
||||
#include "tensorflow/lite/micro/micro_interpreter.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
|
||||
#include "digit.h"
|
||||
#include "model.h"
|
||||
|
||||
namespace {
// File-local state for the TFLite Micro runtime (anonymous namespace gives
// internal linkage, the C++ equivalent of 'static').
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;
TfLiteTensor* output = nullptr;
// 110 KiB scratch arena for the interpreter's tensors.
constexpr int kTensorArenaSize = 110 * 1024;
// Allocated from the RTOS heap at run time (see mnist_app) rather than
// statically, so the RAM is only consumed while the app is running.
uint8_t *tensor_arena = nullptr;
//uint8_t tensor_arena[kTensorArenaSize];
}
|
||||
|
||||
extern "C" void mnist_app() {
|
||||
tflite::MicroErrorReporter micro_error_reporter;
|
||||
error_reporter = µ_error_reporter;
|
||||
|
||||
model = tflite::GetModel(mnist_model);
|
||||
if (model->version() != TFLITE_SCHEMA_VERSION) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter,
|
||||
"Model provided is schema version %d not equal "
|
||||
"to supported version %d.",
|
||||
model->version(), TFLITE_SCHEMA_VERSION);
|
||||
return;
|
||||
}
|
||||
|
||||
tensor_arena = (uint8_t *)rt_malloc(kTensorArenaSize);
|
||||
if (tensor_arena == nullptr) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter, "malloc for tensor_arena failed");
|
||||
return;
|
||||
}
|
||||
|
||||
tflite::AllOpsResolver resolver;
|
||||
tflite::MicroInterpreter static_interpreter(
|
||||
model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
|
||||
interpreter = &static_interpreter;
|
||||
|
||||
// Allocate memory from the tensor_arena for the model's tensors.
|
||||
TfLiteStatus allocate_status = interpreter->AllocateTensors();
|
||||
if (allocate_status != kTfLiteOk) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
|
||||
return;
|
||||
}
|
||||
|
||||
input = interpreter->input(0);
|
||||
output = interpreter->output(0);
|
||||
|
||||
printf("------- Input Digit -------\n");
|
||||
for (int i = 0; i < 28; i++) {
|
||||
for (int j = 0; j < 28; j++) {
|
||||
if (mnist_digit[i*28+j] > 0.3)
|
||||
printf("#");
|
||||
else
|
||||
printf(".");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
for (int i = 0; i < 28*28; i++) {
|
||||
input->data.f[i] = mnist_digit[i];
|
||||
}
|
||||
|
||||
TfLiteStatus invoke_status = interpreter->Invoke();
|
||||
if (invoke_status != kTfLiteOk) {
|
||||
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on x_val\n");
|
||||
return;
|
||||
}
|
||||
|
||||
// Read the predicted y value from the model's output tensor
|
||||
float max = 0.0;
|
||||
int index;
|
||||
for (int i = 0; i < 10; i++) {
|
||||
if(output->data.f[i]>max){
|
||||
max = output->data.f[i];
|
||||
index = i;
|
||||
}
|
||||
}
|
||||
printf("------- Output Result -------\n");
|
||||
printf("result is %d\n", index);
|
||||
}
|
||||
|
||||
extern "C" {
// Register mnist_app as an RT-Thread msh shell command named "mnist_app".
MSH_CMD_EXPORT(mnist_app, run mnist app);
}
|
||||
31409
APP_Framework/Applications/knowing_app/mnist/model.h
Normal file
31409
APP_Framework/Applications/knowing_app/mnist/model.h
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python3
"""Export one MNIST test digit as a C float array (digit.h) plus its label."""

import tensorflow as tf

print("TensorFlow version %s" % (tf.__version__))


def show(image):
    """Print a 28x28 image as ASCII art ('#' where the pixel > 0.3)."""
    for i in range(28):
        for j in range(28):
            if image[i][j] > 0.3:
                print('#', end='')
            else:
                print('.', end='')
        print()


digit_file_path = 'digit.h'
digit_content = '''const float mnist_digit[] = {
%s
};
const int mnist_label = %d;
'''

if __name__ == '__main__':
    mnist = tf.keras.datasets.mnist
    (_, _), (test_images, test_labels) = mnist.load_data()
    index = 0   # which test sample to export
    shape = 28  # MNIST images are 28x28
    image = test_images[index].astype('float32') / 255
    label = test_labels[index]
    print('label: %d' % label)
    # show(image)
    digit_data = (',\n').join(
        (', ').join('%.2f' % image[row][col] for col in range(shape))
        for row in range(shape))
    # fix: use a context manager so the file is closed even if write() fails
    with open(digit_file_path, 'w') as digit_file:
        digit_file.write(digit_content % (digit_data, label))
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env python3
"""Convert a .tflite flatbuffer into a C byte-array header (xxd style)."""

# 12 bytes in a line, the same with xxd
BYTES_OF_LINE = 12

MODEL_CONTENT = '''unsigned char mnist_model[] = {
%s
};
unsigned int mnist_model_len = %d;
'''


def convert(tflite_file_path='mnist.tflite', model_file_path='model.h'):
    """Read *tflite_file_path* and write it as a C array to *model_file_path*.

    Defaults match the original script's hard-coded paths, so running the
    script directly behaves exactly as before.
    """
    # fix: use context managers so handles are closed even on error
    with open(tflite_file_path, 'rb') as tflite_file:
        tflite_data = tflite_file.read()
    tflite_array = ['0x%02x' % byte for byte in tflite_data]
    model_data = (',\n').join(
        (', ').join(tflite_array[i:i + BYTES_OF_LINE])
        for i in range(0, len(tflite_array), BYTES_OF_LINE))
    with open(model_file_path, 'w') as model_file:
        model_file.write(MODEL_CONTENT % (model_data, len(tflite_array)))


if __name__ == '__main__':
    # Alternative input: 'mnist-default-quan.tflite'
    convert('mnist.tflite', 'model.h')
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env python3
# Run one MNIST test image through the default-quantized .tflite model and
# compare the predicted label with the ground-truth label.

import tensorflow as tf

print("TensorFlow version %s" % (tf.__version__))

MODEL_NAME_H5 = 'mnist.h5'
MODEL_NAME_TFLITE = 'mnist.tflite'
DEFAULT_QUAN_MODEL_NAME_TFLITE = 'mnist-default-quan.tflite'
FULL_QUAN_MODEL_NAME_TFLITE = 'mnist-full-quan.tflite'


def show(image):
    # Print a 28x28x1 image as ASCII art ('#' where the pixel > 0.3).
    for i in range(28):
        for j in range(28):
            if image[i][j][0] > 0.3:
                print('#', end = '')
            else:
                print(' ', end = '')
        print()


if __name__ == '__main__':
    mnist = tf.keras.datasets.mnist
    (_, _), (test_images, test_labels) = mnist.load_data()
    # Add the channel dimension the conv model expects: (10000, 28, 28, 1).
    test_images = test_images.reshape(10000, 28, 28, 1)
    index = 0  # which test sample to run
    input_image = test_images[index].astype('float32')/255
    target_label = test_labels[index]

    # Load the quantized model and run a single inference on the image.
    interpreter = tf.lite.Interpreter(model_path = DEFAULT_QUAN_MODEL_NAME_TFLITE)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    interpreter.set_tensor(input_details['index'], [input_image])
    interpreter.invoke()
    output = interpreter.get_tensor(output_details['index'])[0]

    show(input_image)
    print('target label: %d, predict label: %d' % (target_label, output.argmax()))
|
||||
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python3
"""Train a small CNN on MNIST and export it in several formats.

Produces, if not already present:
  * mnist.h5                  -- trained Keras model
  * mnist.tflite              -- plain TFLite conversion
  * mnist-default-quan.tflite -- default (dynamic-range) quantization
  * mnist-full-quan.tflite    -- full integer quantization (uint8 in/out)
"""

import os

import tensorflow as tf

print("TensorFlow version %s" % (tf.__version__))

MODEL_NAME_H5 = 'mnist.h5'
MODEL_NAME_TFLITE = 'mnist.tflite'
DEFAULT_QUAN_MODEL_NAME_TFLITE = 'mnist-default-quan.tflite'
FULL_QUAN_MODEL_NAME_TFLITE = 'mnist-full-quan.tflite'


def build_model(model_name):
    """Train the CNN on MNIST and save the Keras model to *model_name*."""
    print('\n>>> load mnist dataset')
    mnist = tf.keras.datasets.mnist
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    print("train images shape: ", train_images.shape)
    print("train labels shape: ", train_labels.shape)
    print("test images shape: ", test_images.shape)
    print("test labels shape: ", test_labels.shape)

    # transform label to categorical, like: 2 -> [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
    print('\n>>> transform label to categorical')
    train_labels = tf.keras.utils.to_categorical(train_labels)
    test_labels = tf.keras.utils.to_categorical(test_labels)
    print("train labels shape: ", train_labels.shape)
    print("test labels shape: ", test_labels.shape)

    # transform color like: [0, 255] -> 0.xxx
    print('\n>>> transform image color into float32')
    train_images = train_images.astype('float32') / 255
    test_images = test_images.astype('float32') / 255

    # reshape image like: (60000, 28, 28) -> (60000, 28, 28, 1)
    print('\n>>> reshape image with color channel')
    train_images = train_images.reshape((60000, 28, 28, 1))
    test_images = test_images.reshape((10000, 28, 28, 1))
    print("train images shape: ", train_images.shape)
    print("test images shape: ", test_images.shape)

    print('\n>>> build model')
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), activation=tf.nn.relu, input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation=tf.nn.relu),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation=tf.nn.relu),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation=tf.nn.relu),
        tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.summary()

    print('\n>>> train the model')
    # Stop once training loss stops improving, keeping the best weights.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='loss', min_delta=0.0005, patience=3, verbose=1, mode='auto',
        baseline=None, restore_best_weights=True
    )
    model.fit(train_images, train_labels, epochs=100, batch_size=64, callbacks=[early_stopping])

    print('\n>>> evaluate the model')
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print("lost: %f, accuracy: %f" % (test_loss, test_acc))

    print('\n>>> save the keras model as %s' % model_name)
    model.save(model_name)


if __name__ == '__main__':

    if not os.path.exists(MODEL_NAME_H5):
        build_model(MODEL_NAME_H5)

    if not os.path.exists(MODEL_NAME_TFLITE):
        print('\n>>> save the tflite model as %s' % MODEL_NAME_TFLITE)
        converter = tf.lite.TFLiteConverter.from_keras_model(tf.keras.models.load_model(MODEL_NAME_H5))
        tflite_model = converter.convert()
        with open(MODEL_NAME_TFLITE, "wb") as f:
            f.write(tflite_model)

    if not os.path.exists(DEFAULT_QUAN_MODEL_NAME_TFLITE):
        print('\n>>> save the default quantized model as %s' % DEFAULT_QUAN_MODEL_NAME_TFLITE)
        converter = tf.lite.TFLiteConverter.from_keras_model(tf.keras.models.load_model(MODEL_NAME_H5))
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        tflite_model = converter.convert()
        with open(DEFAULT_QUAN_MODEL_NAME_TFLITE, "wb") as f:
            f.write(tflite_model)

    if not os.path.exists(FULL_QUAN_MODEL_NAME_TFLITE):
        mnist = tf.keras.datasets.mnist
        (train_images, _), (_, _) = mnist.load_data()
        train_images = train_images.astype('float32') / 255
        train_images = train_images.reshape((60000, 28, 28, 1))

        def representative_data_gen():
            # 100 real samples so the converter can calibrate int8 ranges.
            for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
                yield [input_value]

        # fix: message previously printed DEFAULT_QUAN_MODEL_NAME_TFLITE here
        print('\n>>> save the full quantized model as %s' % FULL_QUAN_MODEL_NAME_TFLITE)
        converter = tf.lite.TFLiteConverter.from_keras_model(tf.keras.models.load_model(MODEL_NAME_H5))
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_data_gen
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.uint8
        tflite_model = converter.convert()
        with open(FULL_QUAN_MODEL_NAME_TFLITE, "wb") as f:
            f.write(tflite_model)
|
||||
Reference in New Issue
Block a user