吾爱破解 - 52pojie.cn

 找回密码
 注册[Register]

QQ登录

只需一步,快速开始

查看: 4694|回复: 6
收起左侧

[Python 转载] 神经网络-手写数字识别学习笔记

[复制链接]
ConstantinChiae 发表于 2019-3-18 22:27
之前弄机器学习一直跟有监督学习打交道 ,没怎么看重无监督学习,最近想要往nlp自然语言处理走一走,发现一篇文章《A Neural Conversational Model

所以自然绕不开seq2seq与LSTM,所以准备打下深度学习的基础,这两天正在看tensorflow,遂mark一下MNIST的手写识别代码
准确率到了0.98,代码还很不完善,比如隐藏层过多,迭代次数略微少点等,大家一起学习下吧
[Python] 纯文本查看 复制代码
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot = True)

batch_size = 100
n_batch = mnist.train.num_examples // batch_size

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32) #参加计算神经元占比,eg. 1.0表示100%
                                       #有效抑制过拟合:测试数据与训练数据差别不大             
Weight_L1 = tf.Variable(tf.truncated_normal([784, 500], stddev = 0.1))
biases_L1 = tf.Variable(tf.zeros([500]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, Weight_L1) + biases_L1)
L1_drop = tf.nn.dropout(L1, keep_prob)
lr = tf.Variable(0.001)

Weight_L2 = tf.Variable(tf.truncated_normal([500, 300], stddev = 0.1))
biases_L2 = tf.Variable(tf.zeros([300]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, Weight_L2) + biases_L2)
L2_drop = tf.nn.dropout(L2, keep_prob)

Weight_L3 = tf.Variable(tf.truncated_normal([300, 100], stddev = 0.1))
biases_L3 = tf.Variable(tf.zeros([100]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop, Weight_L3) + biases_L3)
L3_drop = tf.nn.dropout(L3, keep_prob)

Weight_L4 = tf.Variable(tf.truncated_normal([100, 10], stddev = 0.1))
biases_L4 = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(L3_drop, Weight_L4) + biases_L4) #多分类

#损失函数
# loss = tf.reduce_mean(tf.square(y - prediction)) #二次代价函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels = y, logits = prediction)) #交叉熵
#优化器
# train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #梯度下降
train_step = tf.train.AdamOptimizer(lr).minimize(loss)

init = tf.global_variables_initializer()
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) #cast布尔转数值

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(31):
        sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch))) #迭代降低学习率
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict = {x:batch_xs, 
                                              y:batch_ys,
                                              keep_prob:1.0})
        test_acc = sess.run(accuracy, feed_dict = {x:mnist.test.images, 
                                              y:mnist.test.labels,
                                              keep_prob:1.0})
        train_acc = sess.run(accuracy, feed_dict = {x:mnist.train.images, 
                                              y:mnist.train.labels,
                                              keep_prob:1.0})
        learning_rate = sess.run(lr)
        print("Iter " + str(epoch) + 
              ", Testing Accuracy " + str(test_acc) + 
              ", Training Accuracy " + str(train_acc) + 
              ", Learning Rate " + str(learning_rate))


以下是迭代结果
[Python] 纯文本查看 复制代码
Iter 0, Testing Accuracy 0.9502, Training Accuracy 0.95552725, Learning Rate 0.001
Iter 1, Testing Accuracy 0.96, Training Accuracy 0.9675636, Learning Rate 0.00095
Iter 2, Testing Accuracy 0.9679, Training Accuracy 0.97810906, Learning Rate 0.0009025
Iter 3, Testing Accuracy 0.9707, Training Accuracy 0.98196363, Learning Rate 0.000857375
Iter 4, Testing Accuracy 0.9716, Training Accuracy 0.9841273, Learning Rate 0.00081450626
Iter 5, Testing Accuracy 0.9749, Training Accuracy 0.98634547, Learning Rate 0.0007737809
Iter 6, Testing Accuracy 0.9747, Training Accuracy 0.9884909, Learning Rate 0.0007350919
Iter 7, Testing Accuracy 0.9767, Training Accuracy 0.98965454, Learning Rate 0.0006983373
Iter 8, Testing Accuracy 0.9756, Training Accuracy 0.99114543, Learning Rate 0.0006634204
Iter 9, Testing Accuracy 0.9792, Training Accuracy 0.99325454, Learning Rate 0.0006302494
Iter 10, Testing Accuracy 0.9774, Training Accuracy 0.99332726, Learning Rate 0.0005987369
Iter 11, Testing Accuracy 0.9765, Training Accuracy 0.99332726, Learning Rate 0.0005688001
Iter 12, Testing Accuracy 0.9802, Training Accuracy 0.9935273, Learning Rate 0.0005403601
Iter 13, Testing Accuracy 0.9815, Training Accuracy 0.9952, Learning Rate 0.0005133421
Iter 14, Testing Accuracy 0.9777, Training Accuracy 0.99514544, Learning Rate 0.000487675
Iter 15, Testing Accuracy 0.9804, Training Accuracy 0.9954, Learning Rate 0.00046329122
Iter 16, Testing Accuracy 0.9813, Training Accuracy 0.9958909, Learning Rate 0.00044012666
Iter 17, Testing Accuracy 0.9802, Training Accuracy 0.9961636, Learning Rate 0.00041812033
Iter 18, Testing Accuracy 0.9766, Training Accuracy 0.99463636, Learning Rate 0.00039721432
Iter 19, Testing Accuracy 0.9809, Training Accuracy 0.99652725, Learning Rate 0.0003773536
Iter 20, Testing Accuracy 0.9758, Training Accuracy 0.99563634, Learning Rate 0.00035848594
Iter 21, Testing Accuracy 0.9807, Training Accuracy 0.99667275, Learning Rate 0.00034056162
Iter 22, Testing Accuracy 0.979, Training Accuracy 0.9961636, Learning Rate 0.00032353355
Iter 23, Testing Accuracy 0.9813, Training Accuracy 0.99692726, Learning Rate 0.00030735688
Iter 24, Testing Accuracy 0.9806, Training Accuracy 0.997, Learning Rate 0.000291989
Iter 25, Testing Accuracy 0.9799, Training Accuracy 0.9967818, Learning Rate 0.00027738957
Iter 26, Testing Accuracy 0.981, Training Accuracy 0.99725455, Learning Rate 0.0002635201
Iter 27, Testing Accuracy 0.9813, Training Accuracy 0.9972909, Learning Rate 0.00025034408
Iter 28, Testing Accuracy 0.9806, Training Accuracy 0.9973636, Learning Rate 0.00023782688
Iter 29, Testing Accuracy 0.981, Training Accuracy 0.9973818, Learning Rate 0.00022593554
Iter 30, Testing Accuracy 0.9812, Training Accuracy 0.99743634, Learning Rate 0.00021463877

发帖前要善用论坛搜索功能,那里可能会有你要找的答案或者已经有人发布过相同内容了,请勿重复发帖。

hjsm 发表于 2019-6-30 16:06
有没有比较基础的东西推荐一下啊
 楼主| ConstantinChiae 发表于 2019-7-2 08:23
hjsm 发表于 2019-6-30 16:06
有没有比较基础的东西推荐一下啊

手写数字识别算是hello world了
du198683 发表于 2019-7-2 23:27
yctx999 发表于 2019-7-8 20:51
有点长,学习了
您需要登录后才可以回帖 登录 | 注册[Register]

本版积分规则

返回列表

RSS订阅|小黑屋|处罚记录|联系我们|吾爱破解 - LCG - LSG ( 京ICP备16042023号 | 京公网安备 11010502030087号 )

GMT+8, 2024-11-26 01:50

Powered by Discuz!

Copyright © 2001-2020, Tencent Cloud.

快速回复 返回顶部 返回列表