吾爱破解 - 52pojie.cn

 找回密码
 注册[Register]

QQ登录

只需一步,快速开始

查看: 2008|回复: 14
上一主题 下一主题
收起左侧

[Python 原创] 神经网络中前向传播与反向传播意义及其参数的更新方式

  [复制链接]
跳转到指定楼层
楼主
RiiiickSandes 发表于 2023-2-24 12:54 回帖奖励
本帖最后由 RiiiickSandes 于 2023-2-24 12:59 编辑

前向传播与反向传播意义及其参数的更新方式

一、前言

因为本身非科班出身,数学又学的很差,一直都是傻瓜式地用tensorflow和pytorch搭网络。前一段时间竞赛的时候尝试着用简单神经网络做了个题,同学突然问起反向传播的具体原理,一时语塞,遂下决心把这个问题搞明白。这篇学习笔记将以我的认知顺序也就是由浅至深的顺序叙述,里面可能涉及到一些神经网络的基础知识,比如学习率、激活函数、损失函数等,详情可以看看这里,本文不再赘述

写文章的时候查阅了一些资料,感觉写得最好的是这篇文章,我的一些思路也有所参考,推荐去看看,记得给大佬点star  : )

二、前反向传播的作用

这个问题应该大部分接触过神经网络的人都有所了解,我最开始的认知也就停留在这一步

前向传播,也叫正向传播,其实就是参数在神经网络中从输入层到输出层传输过程

反向传播,其实就是根据输出层的输出实际值的差距,更新神经网络中参数的过程

而一次正向传播加上一次反向传播就是一次网络的学习

话虽如此,参数在网络中到底是如何变化的呢

三、前向传播

首先我们来看一个神经网络,这个神经网络是如此的简单,这种简单结构的网络可以使我们更好地理解神经网络的工作方式。

所谓前向传播,其实就是将神经网络的上一层作为下一层的输入,并计算下一层的输出,一直到输出层位置

如上图,假如输入层输入x,那么参数前向传播到隐藏层其实就是输入x权重矩阵相乘加上偏置项之和再通过激活函数,假设我们使用的激活函数为


此时输入层的输出就是

当参数继续向前传播,通过隐藏层的输出到输出层,其值为

上面的式子的值其实就是神经网络的输出了,这样两个算式描述了一次前向传播的全部过程

四、反向传播

由于反向传播涉及到导数运算,而我的数学能力已经退化到小学水平了,所以这里我们直接使用一个1 1 1的 “神经网络” 来做演示

这里我们的损失函数选择使用最常见的均方误差(MSE),即定义损失值为预测值与实际值的差的平方除以样本数,这个损失函数对异常值比较敏感,适用于回归问题


而更新参数的依据,就是使最后预测的结果朝着损失函数值减小的方向移动,故我们用损失函数对每一个参数求偏导,让各个参数往损失函数减小的方向变化。假设我们这里的激活函数为

损失函数对各参数求偏导的结果如下

反向传播算法建立在梯度下降法的基础上,已经算出各参数偏导的情况下,需要使用梯度下降法进行参数更新,我们以学习率为μ为例,各参数的更新如下

为什么这里要引入学习率的概念呢,有一篇博客非常形象的说明了这个问题,感兴趣的可以看看原文,省流量的可以看下面这个表格,这个表格说明了当学习率等于1的时候可能遇到的困境

轮数 当前轮参数值 梯度x学习率 更新后参数值
1 5 2x5x1=10 5-10=-5
2 -5 2x-5x1=-10 -5-(-10)=5
3 5 2x5x1=10 5-10=-5

很明显,这里参数没有更新,输出结果就像大禹治水,三过家门而不入,训练也就毫无意义

代码

自己懒得写了,在网上找了一个,出处:CSDN

import numpy as np
import matplotlib.pyplot as plt

# 激活函数
def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# 向前传递
def forward(X, W1, W2, W3, b1, b2, b3):
    # 隐藏层1
    Z1 = np.dot(W1.T,X)+b1  # X=n*m ,W1.T=h1*n,b1=h1*1,Z1=h1*m
    A1 = sigmoid(Z1)  # A1=h1*m
    # 隐藏层2
    Z2 = np.dot(W2.T, A1) + b2  # W2.T=h2*h1,b2=h2*1,Z2=h2*m
    A2 = sigmoid(Z2)  # A2=h2*m
    # 输出层
    Z3=np.dot(W3.T,A2)+b3  # W3.T=(h3=1)*h2,b3=(h3=1)*1,Z3=1*m
    A3=sigmoid(Z3)  # A3=1*m

    return Z1,Z2,Z3,A1,A2,A3

# 反向传播
def backward(Y,X,A3,A2,A1,Z3,Z2,Z1,W3,W2,W1):
    n,m = np.shape(X)
    dZ3 = A3-Y # dZ3=1*m
    dW3 = 1/m *np.dot(A2,dZ3.T) # dW3=h2*1
    db3 = 1/m *np.sum(dZ3,axis=1,keepdims=True) # db3=1*1

    dZ2 = np.dot(W3,dZ3)*A2*(1-A2) # dZ2=h2*m
    dW2 = 1/m*np.dot(A1,dZ2.T) #dw2=h1*h2
    db2 = 1/m*np.sum(dZ2,axis=1,keepdims=True) #db2=h2*1

    dZ1 = np.dot(W2, dZ2) * A1 * (1 - A1) # dZ1=h1*m
    dW1 = 1 / m * np.dot(X, dZ1.T)  # dW1=n*h
    db1 = 1 / m * np.sum(dZ1,axis=1,keepdims=True)  # db1=h*m

    return dZ3,dZ2,dZ1,dW3,dW2,dW1,db3,db2,db1

def costfunction(Y,A3):
    m, n = np.shape(Y)
    J=np.sum(Y*np.log(A3)+(1-Y)*np.log(1-A3))/m
    # J = (np.dot(y, np.log(A2.T)) + np.dot((1 - y).T, np.log(1 - A2))) / m
    return -J

# Data = np.loadtxt("gua2.txt")
# X = Data[:, 0:-1]
# X = X.T
# Y = Data[:, -1]
# Y=np.reshape(1,m)
X=np.random.rand(100,200)
n,m=np.shape(X)
Y=np.random.rand(1,m)
n_x=n
n_y=1
n_h1=5
n_h2=4
W1=np.random.rand(n_x,n_h1)*0.01
W2=np.random.rand(n_h1,n_h2)*0.01
W3=np.random.rand(n_h2,n_y)*0.01
b1=np.zeros((n_h1,1))
b2=np.zeros((n_h2,1))
b3=np.zeros((n_y,1))
alpha=0.1
number=10000
for i in range(0,number):
    Z1,Z2,Z3,A1,A2,A3=forward(X,W1,W2,W3,b1,b2,b3)
    dZ3, dZ2, dZ1, dW3, dW2, dW1, db3, db2, db1=backward(Y,X,A3,A2,A1,Z3,Z2,Z1,W3,W2,W1)
    W1=W1-alpha*dW1
    W2=W2-alpha*dW2
    W3=W3-alpha*dW3
    b1=b1-alpha*db1
    b2=b2-alpha*db2
    b3=b3-alpha*db3
    J=costfunction(Y,A3)
    if (i%100==0):
        print(i)
    plt.plot(i,J,'ro')
plt.show()

免费评分

参与人数 4威望 +1 吾爱币 +19 热心值 +3 收起 理由
苏紫方璇 + 1 + 15 + 1 感谢发布原创作品,吾爱破解论坛因你更精彩!
junjia215 + 1 + 1 用心讨论,共获提升!
夫子点灯 + 1 我很赞同!
话痨司机啊 + 2 + 1 谢谢@Thanks!

查看全部评分

发帖前要善用论坛搜索功能,那里可能会有你要找的答案或者已经有人发布过相同内容了,请勿重复发帖。

推荐
starkcccc 发表于 2023-4-2 10:46
写得不错,其实入门推导一下就行。目前我对深度学习的理解,现在就是一堆矩阵+激活函数的变换,对数据最后一个维度进行变来变去,个人觉得只要明白其本质,代码就很好写,尤其结合爱因斯坦求和,各种操作都可以做出来了。
沙发
 楼主| RiiiickSandes 发表于 2023-2-24 13:01 |楼主
发完才发现吾爱的md编辑器好像没有对公式的支持,只好用截图代替,大佬们请见谅

点评

可以用简书或知乎的公式接口,返回的应该是svg图片  详情 回复 发表于 2023-2-24 17:34
3#
 楼主| RiiiickSandes 发表于 2023-2-24 13:02 |楼主
因为原来接触这个领域不多,文章肯定有写的不好或者不对的地方,敬请各位不吝赐教
4#
甜萝 发表于 2023-2-24 16:59
人工智能的话 感觉对数学要求比较高 科不科班其实也没那么重要
5#
 楼主| RiiiickSandes 发表于 2023-2-24 17:12 |楼主
paypojie 发表于 2023-2-24 16:59
人工智能的话 感觉对数学要求比较高 科不科班其实也没那么重要

所以对我这种数学不好又不科班的就很不友好
6#
甜萝 发表于 2023-2-24 17:15
RiiiickSandes 发表于 2023-2-24 17:12
所以对我这种数学不好又不科班的就很不友好

没事 楼主恶补数学就好 加油
7#
侃遍天下无二人 发表于 2023-2-24 17:34
RiiiickSandes 发表于 2023-2-24 13:01
发完才发现吾爱的md编辑器好像没有对公式的支持,只好用截图代替,大佬们请见谅

可以用简书或知乎的公式接口,返回的应该是svg图片
8#
 楼主| RiiiickSandes 发表于 2023-2-25 09:22 |楼主
侃遍天下无二人 发表于 2023-2-24 17:34
可以用简书或知乎的公式接口,返回的应该是svg图片

长知识了,这就去试试
9#
 楼主| RiiiickSandes 发表于 2023-2-25 09:23 |楼主
paypojie 发表于 2023-2-24 17:15
没事 楼主恶补数学就好 加油

马上猛学!
10#
12member 发表于 2023-2-27 07:43
学习中,谢谢分享
您需要登录后才可以回帖 登录 | 注册[Register]

本版积分规则

返回列表

RSS订阅|小黑屋|处罚记录|联系我们|吾爱破解 - LCG - LSG ( 京ICP备16042023号 | 京公网安备 11010502030087号 )

GMT+8, 2025-1-11 14:57

Powered by Discuz!

Copyright © 2001-2020, Tencent Cloud.

快速回复 返回顶部 返回列表