MNIST Dataset : 손으로 쓴 글씨 데이터이다. 우체국에서 손으로 쓴 글씨를 읽게 하기 위해 사용한다.
X 는 전체 픽셀 수로 지정(28 * 28), Y는 0~9까지 10개의 숫자 갯수 10으로 지정한다.
from tensorflow.examples.tutorials.mnist import input_data
# Check out https://www.tensorflow.org/get_started/mnist/beginners for
# more information about the mnist dataset
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
nb_classes = 10
X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, nb_classes])
W = tf.Variable(tf.random_normal([784, nb_classes])) # 입력 출력
b = tf.Variable(tf.random_normal([nb_classes]) # Y
hypothesis = tf.nn.softmax(tf.matmul(X, W) + b)
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(hypothesis), axis=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(cost)
#Test model
is_correct = tf.equal(tf.arg_max(hypothesis, 1), tf.arg_max(Y, 1))
# Calculate accuracy
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
#Training epoch / batch ... 한번에 몇개씩만 Training을 돌린다(데이터가 많으니까)
# epoch : 전체 데이터셋을 한 번 돌릴 거를 one epoch 으로 본다. ex. 전체 1000개 데이터면 500개씩
# 2 epoch 돌리면 된다.
# Parameters
training_epochs = 15
batch_size = 100
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# Training cycle
for epoch in range(training_epochs):
avg_cost = 0
total_batch = int(mnist.train.num_examples / batch_size) # 전체 데이터의 개수를 epoch 사이즈로 나눠준다.
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
c, _ = sess.run([cost, optimizer], feed_dict={X: batch_xs, Y: batch_ys})
avg_cost += c / total_batch
print('Epoch:', '%04d' % (epoch + 1), 'cost= ', '{:.9f'.format(avg_cost))
# Report results on test dataset
# Session.run() 으로 할수도 있고 아래 소스처럼 .eval을 써서 할 수도 있다.
print("Accuracy: ", accuracy.eval(session=sess, feed_dict={X: mnist.test.images, Y: mnist.test.labels}))
랜덤한 글자 이미지를 하나 찾아서 Test 해보는 소스
import matplotlib.pyplot as plt
import random
r = random.randint(0, mnist.test.num_examples - 1)
print("Label:", sess.run(tf.argmax(mnist.test.labels[r:r+1], 1)))
print("Prediction: ", sess.run(tf.argmax(hypothesis, 1), feed_dict={X: mnist.test.images[r: r+1]}))
plt.imshow(mnist.test.images[r: r+1].reshape(28, 28), cmap='Greys', interpolation='nearest')
plt.show()