PyTorch MNIST 手写数字识别模型训练实战
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import os
import torch
import numpy as np
BATCH_SIZE = 128
TEST_BATCH_SIZE = 1000
# 准备数据集
def get_dataloader(train, batch_size=BATCH_SIZE):
transform_fn = Compose([
ToTensor(),
Normalize(mean=(0.1307,), std=(0.3081,))
]) # mean和std的形状与通道数相同
dataset = MNIST(root='./data', train=train, transform=transform_fn)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
return data_loader
class MnistModel(nn.Module):
def __init__(self):
super(MnistModel, self).__init__() # 继承
self.fc1 = nn.Linear(1 * 28 * 28, 28) # 参数是input和output的feature
self.fc2 = nn.Linear(28, 10)
def forward(self, input):
# 1.进行形状的修改
x = input.view([-1, 1 * 28 * 28]) # -1表示根据形状自动调整,也可以改为input.size(0)
# 2.进行全连接的操作
x = self.fc1(x)
# 3.激活函数的处理
x = F.relu(x) # 形状没有变化
# 4.输出层
out = self.fc2(x)
return F.log_softmax(out, dim=-1)
model = MnistModel()
optimizer = Adam(model.parameters(), lr=0.001)
def train(epoch): # epoch表示几轮
data_loader = get_dataloader(True) # 获取数据加载器
for idx, (input, target) in enumerate(data_loader): # idx表示data_loader中的第几个数据,元组是data_loader的数据
optimizer.zero_grad() # 将梯度置0
output = model(input) # 调用模型,得到预测值
loss = F.nll_loss(output, target) # 调用损失函数,得到损失,是一个tensor
loss.backward() # 反向传播
optimizer.step() # 梯度的更新
if idx % 10 == 0:
print(epoch, idx, loss.item())
# for是每一轮中的数据进行遍历
def test():
loss_list = []
acc_list = []
test_dataloader = get_dataloader(train=False, batch_size=TEST_BATCH_SIZE) # 获取测试集
for idx, (input, target) in enumerate(test_dataloader):
with torch.no_grad(): # 不计算梯度
output = model(input)
cur_loss = F.nll_loss(output, target)
loss_list.append(cur_loss)
# 计算准确率,output大小[batch_size,10] target[batch_size] batch_size是多少组数据,10列是每个数字概率
pred = output.max(dim=-1)[-1] # 获取最大值位置
cur_acc = pred.eq(target).float().mean()
acc_list.append(cur_acc)
print('平均准确率:', np.mean(acc_list), '平均损失:', np.mean(loss_list))
if __name__ == '__main__':
test()
for i in range(3): # 训练三轮
train(i)
test()
代码示例
此代码使用 PyTorch 构建了一个简单的 MNIST 手写数字识别模型,包含以下步骤:
- 数据集准备: 使用
torchvision.datasets.MNIST加载 MNIST 数据集,并使用torchvision.transforms.Compose对数据进行预处理,包括转换为 Tensor 和标准化。 - 模型定义: 定义了一个简单的两层全连接神经网络
MnistModel,包含一个隐藏层和一个输出层,使用torch.nn.Linear和torch.nn.functional.relu实现。 - 训练: 使用
torch.optim.Adam优化器进行模型训练,并使用torch.nn.functional.nll_loss计算损失函数。 - 测试: 使用测试集对训练好的模型进行评估,计算模型的准确率和损失值。
优化建议
为了提升模型性能,可以考虑以下优化方法:
- 使用更复杂的模型结构,例如卷积神经网络 (CNN)。
- 尝试不同的优化器,例如
torch.optim.SGD或torch.optim.Adagrad。 - 使用学习率调度器,例如
torch.optim.lr_scheduler.StepLR或torch.optim.lr_scheduler.ReduceLROnPlateau,在训练过程中动态调整学习率。 - 使用正则化方法,例如
torch.nn.Dropout或 L1/L2 正则化,防止模型过拟合。 - 使用可视化工具,例如 TensorBoard,监控模型训练过程和评估结果。
更多学习资源
- PyTorch 官方文档: https://pytorch.org/
- MNIST 数据集介绍: http://yann.lecun.com/exdb/mnist/
- 深度学习入门教程: https://www.deeplearningbook.org/
原文地址: https://www.cveoy.top/t/topic/kns0 著作权归作者所有。请勿转载和采集!