吾爱破解 - 52pojie.cn

 找回密码
 注册[Register]

QQ登录

只需一步,快速开始

查看: 581|回复: 9
收起左侧

[Python 原创] BP实战minist数据集

[复制链接]
2045976511 发表于 2024-10-16 19:26

BP实战minist数据集

前言

在当今人工智能与机器学习飞速发展的时代,神经网络作为一种强大的工具,在图像识别、自然语言处理等众多领域都展现出了卓越的性能。其中,BP(Back Propagation)神经网络作为一种经典的前馈神经网络,以其简单的结构和高效的学习能力,一直备受研究者和开发者的青睐。

本次实战,我们将目光聚焦于著名的 MNIST 数据集。MNIST 数据集由手写数字的图像组成,它具有规模适中、问题清晰等特点,非常适合作为神经网络的入门实战案例。

通过使用 BP 神经网络对 MNIST 数据集中的手写数字进行识别,我们将深入了解神经网络的工作原理、训练过程以及在实际问题中的应用。 在这个过程中,我们将逐步探索如何构建 BP 神经网络模型、如何加载和预处理数据集、如何进行模型的训练和优化,以及如何评估模型的性能。


一、MNIST数据集介绍和加载

1.MNIST数据集介绍

MNIST 数据集是机器学习领域中广泛使用的一个基准数据集,主要用于图像识别和数字分类任务。

MNIST 数据集由手写数字的图像组成,这些数字是从 0 到 9 的整数。它包含了 70,000 张灰度图像,其中 60,000 张用于训练,10,000 张用于测试。每一张图像都是 28×28 像素的,呈现出不同人书写的数字形态,具有一定的多样性和复杂性。

该数据集的图像是灰度的且数字居中,这在一定程度上减少了预处理的工作量并加快了模型的运行速度。其简洁明了的特点使得 MNIST 成为初学者进入机器学习和深度学习领域的理想选择,许多经典的算法和模型都首先在这个数据集上进行验证和优化。 MNIST 数据集的广泛应用推动了图像识别技术的发展,研究人员通过在这个数据集上不断尝试新的算法和改进模型结构,为更复杂的图像识别任务奠定了基础。

2.加载数据集MNIST数据集

# MNIST 包含 70,000 张手写数字图像: 60,000 张用于训练,10,000 张用于测试。
# 图像是灰度的,28×28 像素的,并且居中的,以减少预处理和加快运行。
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# 使用 torchvision 读取数据
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 使用 DataLoader 加载数据
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

首先定义了一个数据转换 transform,包括将图像转换为张量并进行归一化处理。然后使用 torchvision.datasets.MNIST 加载 MNIST 数据集,分别设置 train=True 和 train=False 来获取训练集和测试集。最后使用 torch.utils.data.DataLoader 将数据集包装成数据加载器,设置了批量大小为 64,训练集进行随机打乱,测试集不打乱。


二、构建 BP 网络模型

# 第 1 步:构建 BP 网络模型
class BPNetwork(torch.nn.Module):

    def __init__(self):
        super(BPNetwork, self).__init__()

        """
        定义第一个线性层,
        输入为图片(28x28),
        输出为第一个隐层的输入,大小为 128。
        """
        self.linear1 = torch.nn.Linear(28 * 28, 128)
        # 在第一个隐层使用 ReLU 激活函数
        self.relu1 = torch.nn.ReLU()
        """
        定义第二个线性层,
        输入是第一个隐层的输出,
        输出为第二个隐层的输入,大小为 64。
        """
        self.linear2 = torch.nn.Linear(128, 64)
        # 在第二个隐层使用 ReLU 激活函数
        self.relu2 = torch.nn.ReLU()
        """
        定义第三个线性层,
        输入是第二个隐层的输出,
        输出为输出层,大小为 10
        """
        self.linear3 = torch.nn.Linear(64, 10)
        # 最终的输出经过 softmax 进行归一化
        self.softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, x):
        """
        定义神经网络的前向传播
        x: 图片数据, shape 为(64, 1, 28, 28)
        """
        # 首先将 x 的 shape 转为(64, 784)
        x = x.view(x.shape[0], -1)

        # 接下来进行前向传播
        x = self.linear1(x)
        x = self.relu1(x)
        x = self.linear2(x)
        x = self.relu2(x)
        x = self.linear3(x)
        x = self.softmax(x)

        # 上述一串,可以直接使用 x = self.model(x) 代替。

        return x

1.神经网络结构图示

层名 输入大小 输出大小
输入层(展平后的图片) 784 -
第一个隐藏层 784 128
第二个隐藏层 128 64
输出层 64 10

2.BP 网络模型代码解释

定义了一个名为 BPNetwork 的类,继承自 torch.nn.Module,用于构建一个三层的神经网络模型。在 init 方法中定义了三个线性层和两个 ReLU 激活函数以及一个对数 softmax 函数用于输出层的归一化。在 forward 方法中定义了神经网络的前向传播过程,首先将输入的图片数据展平为一维向量,然后依次通过三个线性层和激活函数,最后经过 softmax 归一化得到输出。


三、定义和训练BP 网络模型

model = BPNetwork()
# criterion = torch.nn.MSELoss()
criterion = torch.nn.NLLLoss()                                            # 定义 loss 函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)   # 定义优化器

epochs = 15                               # 一共训练 15 轮
for i in range(epochs):
    running_loss = 0                     # 本轮的损失值
    for images, labels in trainloader:
        # 前向传播获取预测值
        output = model(images)
        # 计算损失
        loss = criterion(output, labels)
        # 进行反向传播
        loss.backward()
        # 更新权重
        optimizer.step()
        # 清空梯度
        optimizer.zero_grad()
        # 累加损失
        running_loss += loss.item()

    # 一轮循环结束后打印本轮的损失函数
    print("Epoch {} - Training loss: {}".format(i, running_loss / len(trainloader)))

这里首先创建了一个 BPNetwork 模型实例,然后定义了损失函数为负对数似然损失(torch.nn.NLLLoss),优化器为随机梯度下降(torch.optim.SGD),设置了学习率为 0.003 和动量为 0.9。接着设置了训练轮数为 15。在训练循环中,遍历训练数据加载器,进行前向传播得到预测值,计算损失,然后进行反向传播、更新权重和清空梯度。最后打印每一轮的训练损失。


四、测试模型

examples = enumerate(testloader)
batch_idx, (imgs, labels) = next(examples)

fig = plt.figure()
for i in range(64):

    logps = model(imgs[i])                    # 通过模型进行预测

    probab = list(logps.detach().numpy()[0])  # 将预测结果转为概率列表。[0]是取第一张照片的 10 个数字的概率列表(因为一次只预测一张照片)
    pred_label = probab.index(max(probab))    # 取最大的 index 作为预测结果

    img = torch.squeeze(imgs[i])
    img = img.numpy()

    plt.subplot(8, 8, i + 1)
    plt.tight_layout()
    plt.imshow(img, cmap='gray', interpolation='none')
    plt.title("预测值: {}".format(pred_label))
    plt.xticks([])
    plt.yticks([])

plt.show()

首先从测试数据加载器中获取一批数据,然后创建一个 matplotlib 的图形对象。接着在一个循环中,对这批数据中的前 64 张图像进行预测,将预测结果转换为概率列表,取最大概率的索引作为预测标签。同时将图像数据转换为 numpy 数组并进行展示,在图像上标注预测值。最后显示绘制的图形,展示测试结果。


五、训练结果

见附件

总结

用 BP 神经网络对 MNIST 数据集进行实战。首先构建了一个包含两个隐藏层的三层神经网络模型,使用全连接层和 ReLU 激活函数,输出层经 softmax 归一化。接着加载 MNIST 数据集并预处理,用数据加载器进行高效加载。然后定义损失函数和优化器进行模型训练,通过前向传播、计算损失、反向传播等步骤更新权重。最后在测试集上进行预测,展示图像及预测结果。此实战有助于理解神经网络原理和训练过程,为深入学习提供基础和经验。

联想截图_20241008083937.png
联想截图_20241008080448.png

免费评分

参与人数 2威望 +1 吾爱币 +21 热心值 +2 收起 理由
苏紫方璇 + 1 + 20 + 1 感谢发布原创作品,吾爱破解论坛因你更精彩!
restart19 + 1 + 1 谢谢@Thanks!

查看全部评分

本帖被以下淘专辑推荐:

发帖前要善用论坛搜索功能,那里可能会有你要找的答案或者已经有人发布过相同内容了,请勿重复发帖。

侃遍天下无二人 发表于 2024-10-16 19:44
这篇也是你写的吗,是的话在文中补个csdn链接和登录截图

BP实战minist数据集_bp神经网络实战
https://blog.csdn.net/shiguang521314/article/details/137123457
 楼主| 2045976511 发表于 2024-10-16 21:05
侃遍天下无二人 发表于 2024-10-16 19:44
这篇也是你写的吗,是的话在文中补个csdn链接和登录截图

BP实战minist数据集_bp神经网络实战

BP实战minist数据集_bp神经网络实战-CSDN博客
联想截图_20241016203318.png
WWQ052911 发表于 2024-10-16 22:22
wudavid33 发表于 2024-10-16 23:18
学习了,谢谢楼主分享
sifan785622020 发表于 2024-10-16 23:19
可以写的挺清楚的
三滑稽甲苯 发表于 2024-10-17 08:40
很好的入门教程
tomjin 发表于 2024-10-18 20:00
666,虽然没看懂
yoga2joker 发表于 2024-10-21 08:43
谢谢分享
knife5719 发表于 2024-11-21 10:13
学习一下
您需要登录后才可以回帖 登录 | 注册[Register]

本版积分规则

返回列表

RSS订阅|小黑屋|处罚记录|联系我们|吾爱破解 - LCG - LSG ( 京ICP备16042023号 | 京公网安备 11010502030087号 )

GMT+8, 2024-11-24 09:30

Powered by Discuz!

Copyright © 2001-2020, Tencent Cloud.

快速回复 返回顶部 返回列表