forked from xuos/xiuos
37 lines
964 B
Python
37 lines
964 B
Python
#!/usr/bin/env python3
|
|
|
|
import tensorflow as tf
|
|
|
|
# Report which TensorFlow build is in use before doing any work.
print("TensorFlow version %s" % tf.__version__)
|
|
|
|
def show(image, threshold=0.3):
    """Print a 2-D image to stdout as ASCII art.

    Each pixel strictly greater than ``threshold`` is drawn as ``#`` and
    every other pixel as ``.``; one text line is printed per image row.

    Args:
        image: A sequence of rows, each row a sequence of numeric pixel
            values (for example a 28x28 array). The dimensions are taken
            from the image itself, so any rectangular or ragged size works
            and an empty image prints nothing.
        threshold: Cut-off above which a pixel counts as "on". Defaults to
            0.3.
    """
    for row in image:
        # Build the whole line first so each row costs a single print call.
        print(''.join('#' if pixel > threshold else '.' for pixel in row))
|
|
|
|
# Name of the C header this script generates; the relative path is resolved
# against the current working directory when the file is opened.
digit_file_path = 'digit.h'

# Template for the generated header: %s receives the comma-separated pixel
# values and %d receives the integer label of the exported digit.
digit_content = '''const float mnist_digit[] = {
%s
};
const int mnist_label = %d;
'''
|
|
|
|
if __name__ == '__main__':
    # Only the test split is used; the training split is discarded.
    mnist = tf.keras.datasets.mnist
    (_, _), (test_images, test_labels) = mnist.load_data()

    index = 0   # which test sample to export
    shape = 28  # number of rows and columns formatted below

    # Convert to float32 and divide by 255 (presumably mapping 8-bit pixel
    # values into the [0, 1] range -- confirm against the dataset docs).
    image = test_images[index].astype('float32') / 255
    label = test_labels[index]
    print('label: %d' % label)
    # show(image)  # uncomment to preview the exported digit as ASCII art

    # One text line per image row, every value formatted with two decimals.
    digit_data = ',\n '.join(
        ', '.join('%.2f' % image[row][col] for col in range(shape))
        for row in range(shape)
    )

    # The context manager closes the file even if formatting or writing
    # raises, which the manual open()/close() pair did not guarantee.
    with open(digit_file_path, 'w') as digit_file:
        digit_file.write(digit_content % (digit_data, label))
|
|
|