Neural Networks: Handwritten Digit Recognition (Study Notes)
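In my earlier machine-learning work I mostly dealt with supervised learning and paid little attention to unsupervised learning. Recently I wanted to move into NLP (natural language processing) and came across the paper "A Neural Conversational Model". That makes seq2seq and LSTM unavoidable, so I'm building a foundation in deep learning first. I've spent the past two days on TensorFlow, so here is my MNIST handwritten-digit recognition code for the record.
Test accuracy reaches about 0.98. The code is still far from polished: for example, it has too many hidden layers and slightly too few training iterations. Let's learn from it together.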
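```python
# Written for TensorFlow 1.x: tf.placeholder, tf.Session and the
# tensorflow.examples.tutorials.mnist helper are not in TensorFlow 2's default API.
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
batch_size = 100
n_batch = mnist.train.num_examples // batch_size  # number of batches per epoch

x = tf.placeholder(tf.float32, [None, 784])  # flattened 28x28 images
y = tf.placeholder(tf.float32, [None, 10])   # one-hot labels for digits 0-9
keep_prob = tf.placeholder(tf.float32)  # fraction of neurons kept by dropout, e.g. 1.0 means 100%
# Dropout helps suppress overfitting, i.e. keeps test accuracy close to training accuracy
```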
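What keep_prob controls can be seen without TensorFlow. Below is a minimal plain-Python sketch of inverted dropout, the behavior documented for TF 1.x `tf.nn.dropout`: each value is kept with probability keep_prob and, if kept, scaled by 1/keep_prob so its expected value is unchanged; with keep_prob = 1.0 nothing is dropped. The function name `dropout` and the use of Python's `random` module are illustrative assumptions, not TensorFlow's implementation.

```python
import random


def dropout(activations, keep_prob, rng):
    """Inverted dropout over a list of floats (illustrative sketch, not TensorFlow code)."""
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError("keep_prob must be in (0, 1]")
    out = []
    for a in activations:
        # Keep each unit with probability keep_prob and rescale survivors
        # by 1 / keep_prob; dropped units become 0.0.
        if rng.random() < keep_prob:
            out.append(a / keep_prob)
        else:
            out.append(0.0)
    return out


acts = [0.5] * 10000
dropped = dropout(acts, 0.7, random.Random(0))
zero_fraction = sum(1 for v in dropped if v == 0.0) / len(dropped)
mean = sum(dropped) / len(dropped)
print(round(zero_fraction, 2), round(mean, 2))  # roughly 0.3 zeroed, mean close to 0.5
print(dropout([0.2, -0.4], 1.0, random.Random(0)))  # [0.2, -0.4]: keep_prob=1.0 is a no-op
```

Because the surviving activations are already rescaled during training, evaluation can simply feed keep_prob = 1.0, as the script does for its accuracy checks.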
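```python
# Hidden-layer widths are not given in this copy of the code;
# the values below are illustrative placeholders.
H1, H2, H3 = 500, 300, 100

# Layer 1: 784 -> H1, tanh activation, then dropout
Weight_L1 = tf.Variable(tf.truncated_normal([784, H1], stddev=0.1))
biases_L1 = tf.Variable(tf.zeros([H1]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, Weight_L1) + biases_L1)
L1_drop = tf.nn.dropout(L1, keep_prob)
# Learning rate as a non-trainable variable so it can be decayed during training
lr = tf.Variable(0.001, trainable=False)
# Layer 2: H1 -> H2
Weight_L2 = tf.Variable(tf.truncated_normal([H1, H2], stddev=0.1))
biases_L2 = tf.Variable(tf.zeros([H2]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, Weight_L2) + biases_L2)
L2_drop = tf.nn.dropout(L2, keep_prob)
# Layer 3: H2 -> H3
Weight_L3 = tf.Variable(tf.truncated_normal([H2, H3], stddev=0.1))
biases_L3 = tf.Variable(tf.zeros([H3]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop, Weight_L3) + biases_L3)
L3_drop = tf.nn.dropout(L3, keep_prob)
# Output layer: H3 -> 10 classes
Weight_L4 = tf.Variable(tf.truncated_normal([H3, 10], stddev=0.1))
biases_L4 = tf.Variable(tf.zeros([10]) + 0.1)
```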
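Every weight matrix above starts from tf.truncated_normal(..., stddev=0.1). According to the TensorFlow documentation, values more than two standard deviations from the mean are dropped and re-picked, so with stddev=0.1 no initial weight exceeds 0.2 in magnitude. A rough plain-Python sketch of that rejection-sampling idea follows; the helper name and seed are illustrative, and this is not how TensorFlow generates the numbers internally.

```python
import random


def truncated_normal(rows, cols, stddev=0.1, mean=0.0, seed=None):
    """Return a rows x cols matrix (list of lists) of truncated-normal samples."""
    rng = random.Random(seed)
    matrix = []
    for _ in range(rows):
        row = []
        while len(row) < cols:
            v = rng.gauss(mean, stddev)
            # Re-draw any value more than 2 standard deviations from the mean
            if abs(v - mean) <= 2 * stddev:
                row.append(v)
        matrix.append(row)
    return matrix


W = truncated_normal(784, 10, stddev=0.1, seed=42)
print(len(W), len(W[0]))  # 784 10
print(max(abs(v) for row in W for v in row) <= 0.2)  # True
```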
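```python
logits = tf.matmul(L3_drop, Weight_L4) + biases_L4
prediction = tf.nn.softmax(logits)  # multi-class probabilities
# Loss function
# loss = tf.reduce_mean(tf.square(y - prediction))  # quadratic cost
# Cross-entropy: the _v2 op applies softmax itself, so pass the raw logits, not prediction
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))
```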
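Why the loss must receive raw logits: softmax_cross_entropy_with_logits_v2 computes -sum(y_i * log(softmax(z)_i)) itself, so feeding it values that have already been through softmax applies softmax twice. The plain-Python sketch below (function names are illustrative; the log-sum-exp form is used for numerical stability) compares the two for one confident, correct prediction.

```python
import math


def softmax(z):
    m = max(z)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_with_logits(labels, logits):
    """-sum(y * log(softmax(logits))), computed via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return -sum(y * (z - log_sum_exp) for y, z in zip(labels, logits))


labels = [0, 0, 1]        # one-hot: the true class is index 2
logits = [1.0, 2.0, 8.0]  # confident, correct raw scores

correct = cross_entropy_with_logits(labels, logits)
double = cross_entropy_with_logits(labels, softmax(logits))  # softmax applied twice
print(round(correct, 4), round(double, 4))

# With 10 classes, inputs squeezed into [0, 1] can give the true class
# a probability of at most e / (e + 9), so the loss has a floor:
best_double = -math.log(math.e / (math.e + 9))
print(round(best_double, 2))  # 1.46
```

Since probabilities lie in [0, 1], the second softmax can never make any class much more likely than the others, so for 10 classes a double-softmaxed loss can never fall below about 1.46, however good the prediction.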
```python
# Optimizer
# train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)  # gradient descent
train_step = tf.train.AdamOptimizer(lr).minimize(loss)
init = tf.global_variables_initializer()
# Correct when the most probable class equals the index of the 1 in the one-hot label
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # cast bool to float, then average
```
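The commented-out line would train with plain gradient descent at a fixed rate of 0.2; the active line uses Adam with the decaying rate lr. To show what Adam adds, here is a plain-Python sketch of the textbook Adam update (running averages of the gradient and of its square, with bias correction) minimizing the toy function f(theta) = (theta - 3)^2. The toy objective, the step count and lr=0.1 are assumptions for the example; beta1=0.9, beta2=0.999, eps=1e-8 match tf.train.AdamOptimizer's defaults, although TensorFlow folds the bias correction into the step size, which only changes where eps enters.

```python
import math


def adam_minimize(grad, theta, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=1000):
    """Minimize a 1-D function with Adam, given its gradient function."""
    m = 0.0  # running average of gradients (first moment)
    v = 0.0  # running average of squared gradients (second moment)
    for t in range(1, steps + 1):
        g = grad(theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        # Bias correction: m and v start at 0, so early averages are too small
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta


# f(theta) = (theta - 3)^2 has gradient 2 * (theta - 3) and its minimum at theta = 3
theta = adam_minimize(lambda th: 2 * (th - 3), theta=0.0)
print(round(theta, 3))
```

Dividing by sqrt(v_hat) normalizes the step: on the very first step m_hat = g and v_hat = g^2, so the parameter moves by lr (up to eps) whatever the gradient's magnitude.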
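```python
# Build the learning-rate update op once; calling tf.assign inside the loop
# would add a new op to the graph every epoch.
new_lr = tf.placeholder(tf.float32)
update_lr = tf.assign(lr, new_lr)

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(31):
        # Decay the learning rate each epoch: 0.001 * 0.95^epoch
        sess.run(update_lr, feed_dict={new_lr: 0.001 * (0.95 ** epoch)})
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # keep_prob=1.0 means dropout is effectively off during training;
            # feed a value below 1.0 (e.g. 0.7) to enable it.
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 1.0})
        # Always evaluate with keep_prob=1.0 (no dropout)
        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images,
                                                 y: mnist.test.labels,
                                                 keep_prob: 1.0})
        train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images,
                                                  y: mnist.train.labels,
                                                  keep_prob: 1.0})
        learning_rate = sess.run(lr)
        print("Iter " + str(epoch) +
              ", Testing Accuracy " + str(test_acc) +
              ", Training Accuracy " + str(train_acc) +
              ", Learning Rate " + str(learning_rate))
```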
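At the start of every epoch the loop sets lr to 0.001 × 0.95^epoch, i.e. exponential decay by 5% per epoch. A quick plain-Python check (the helper name `decayed_lr` is illustrative) reproduces the Learning Rate column of the output that follows, up to float32 rounding:

```python
def decayed_lr(epoch, base_lr=0.001, decay=0.95):
    """Exponentially decayed learning rate: base_lr * decay ** epoch."""
    return base_lr * decay ** epoch


# range(31) in the training loop covers epochs 0..30
for epoch in (0, 1, 2, 30):
    print(epoch, "%.8g" % decayed_lr(epoch))
```

After the last epoch the rate is about 21% of its starting value.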
Here are the results of each iteration:
Iter 0, Testing Accuracy 0.9502, Training Accuracy 0.95552725, Learning Rate 0.001
Iter 1, Testing Accuracy 0.96, Training Accuracy 0.9675636, Learning Rate 0.00095
Iter 2, Testing Accuracy 0.9679, Training Accuracy 0.97810906, Learning Rate 0.0009025
Iter 3, Testing Accuracy 0.9707, Training Accuracy 0.98196363, Learning Rate 0.000857375
Iter 4, Testing Accuracy 0.9716, Training Accuracy 0.9841273, Learning Rate 0.00081450626
Iter 5, Testing Accuracy 0.9749, Training Accuracy 0.98634547, Learning Rate 0.0007737809
Iter 6, Testing Accuracy 0.9747, Training Accuracy 0.9884909, Learning Rate 0.0007350919
Iter 7, Testing Accuracy 0.9767, Training Accuracy 0.98965454, Learning Rate 0.0006983373
Iter 8, Testing Accuracy 0.9756, Training Accuracy 0.99114543, Learning Rate 0.0006634204
Iter 9, Testing Accuracy 0.9792, Training Accuracy 0.99325454, Learning Rate 0.0006302494
Iter 10, Testing Accuracy 0.9774, Training Accuracy 0.99332726, Learning Rate 0.0005987369
Iter 11, Testing Accuracy 0.9765, Training Accuracy 0.99332726, Learning Rate 0.0005688001
Iter 12, Testing Accuracy 0.9802, Training Accuracy 0.9935273, Learning Rate 0.0005403601
Iter 13, Testing Accuracy 0.9815, Training Accuracy 0.9952, Learning Rate 0.0005133421
Iter 14, Testing Accuracy 0.9777, Training Accuracy 0.99514544, Learning Rate 0.000487675
Iter 15, Testing Accuracy 0.9804, Training Accuracy 0.9954, Learning Rate 0.00046329122
Iter 16, Testing Accuracy 0.9813, Training Accuracy 0.9958909, Learning Rate 0.00044012666
Iter 17, Testing Accuracy 0.9802, Training Accuracy 0.9961636, Learning Rate 0.00041812033
Iter 18, Testing Accuracy 0.9766, Training Accuracy 0.99463636, Learning Rate 0.00039721432
Iter 19, Testing Accuracy 0.9809, Training Accuracy 0.99652725, Learning Rate 0.0003773536
Iter 20, Testing Accuracy 0.9758, Training Accuracy 0.99563634, Learning Rate 0.00035848594
Iter 21, Testing Accuracy 0.9807, Training Accuracy 0.99667275, Learning Rate 0.00034056162
Iter 22, Testing Accuracy 0.979, Training Accuracy 0.9961636, Learning Rate 0.00032353355
Iter 23, Testing Accuracy 0.9813, Training Accuracy 0.99692726, Learning Rate 0.00030735688
Iter 24, Testing Accuracy 0.9806, Training Accuracy 0.997, Learning Rate 0.000291989
Iter 25, Testing Accuracy 0.9799, Training Accuracy 0.9967818, Learning Rate 0.00027738957
Iter 26, Testing Accuracy 0.981, Training Accuracy 0.99725455, Learning Rate 0.0002635201
Iter 27, Testing Accuracy 0.9813, Training Accuracy 0.9972909, Learning Rate 0.00025034408
Iter 28, Testing Accuracy 0.9806, Training Accuracy 0.9973636, Learning Rate 0.00023782688
Iter 29, Testing Accuracy 0.981, Training Accuracy 0.9973818, Learning Rate 0.00022593554
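Iter 30, Testing Accuracy 0.9812, Training Accuracy 0.99743634, Learning Rate 0.00021463877

hjsm (posted 2019-6-30 16:06): Could anyone recommend something more basic?
Reply to hjsm: Handwritten digit recognition already counts as the "hello world".
Reply: Study hard and make progress every day...
Reply: A bit long, but I learned something.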