MNIST手写数字识别:PyTorch实战指南
MNIST手写数字识别:PyTorch实战指南
在本文中,我们将使用PyTorch构建一个简单的卷积神经网络,并使用MNIST数据集训练它识别手写数字。MNIST数据集包含70,000张手写数字图像,其中60,000张用于训练,10,000张用于测试。在MNIST数据集上训练分类器可以看作是图像识别的'hello world'。
MNIST 数据集(手写数字数据集)是一个公开的公共数据集,任何人都可以免费获取它。深度学习的基础就是反向传播算法,手写数字数据集是一个经典的多分类问题,通过神经网络可以很好地解决它。
1. MINIST数据集介绍
MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员。测试集 (test set) 也是同样比例构成的。
2. 代码实现
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# 定义网络结构
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = nn.ReLU()(x)
x = self.conv2(x)
x = nn.ReLU()(x)
x = nn.MaxPool2d(2)(x)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = nn.ReLU()(x)
x = self.dropout2(x)
x = self.fc2(x)
return nn.Softmax(dim=1)(x)
# 加载数据集
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000, shuffle=True)
# 定义训练函数
def train(model, optimizer, criterion, train_loader):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
# 定义测试函数
def test(model, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
test_loss += nn.CrossEntropyLoss()(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = correct / len(test_loader.dataset)
print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
test_loss, correct, len(test_loader.dataset), accuracy * 100))
# 初始化模型和优化器
model = Net()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# 开始训练
for epoch in range(10):
train(model, optimizer, criterion, train_loader)
test(model, test_loader)
3. 代码解析
- 网络结构定义: 代码首先定义了一个名为
Net的类,该类继承自nn.Module,并包含卷积层、池化层、全连接层和激活函数等。 - 数据集加载: 代码使用
torchvision.datasets.MNIST加载MNIST数据集,并使用torch.utils.data.DataLoader将数据集封装成数据加载器。 - 训练函数: 代码定义了一个名为
train的函数,用于训练模型。函数中使用optimizer.zero_grad()清除梯度,使用model(data)计算输出,使用criterion(output, target)计算损失,使用loss.backward()反向传播梯度,使用optimizer.step()更新模型参数。 - 测试函数: 代码定义了一个名为
test的函数,用于测试模型。函数中使用model.eval()设置模型为评估模式,使用nn.CrossEntropyLoss()计算损失,并统计模型的准确率。 - 模型初始化和训练: 代码初始化了模型、优化器和损失函数,并使用循环进行训练。每个循环结束后,都会进行一次测试并输出结果。
4. 运行代码
- 确保你已经安装了PyTorch和torchvision库。
- 将代码保存为
.py文件,例如mnist_classifier.py。 - 在命令行中运行
python mnist_classifier.py,启动训练过程。
5. 结果分析
训练结束后,你将看到每个训练周期的损失和准确率,以及最终的测试准确率。一般来说,模型的准确率会随着训练周期的增加而逐渐提高,最终能够达到较高的识别准确率。
6. 总结
本文使用PyTorch实现了一个简单的MNIST手写数字识别模型,并详细介绍了代码实现过程和结果分析。通过本篇文章,你将对PyTorch的使用以及深度学习模型训练过程有一个更深入的了解。
注意: 为了运行代码,你需要将代码中的文件路径 './data' 更改为你希望保存数据集的路径。
原文地址: https://www.cveoy.top/t/topic/b0lS 著作权归作者所有。请勿转载和采集!