MNIST 手写数字识别:神经网络模型训练与性能评估
本项目使用 PyTorch 训练一个神经网络模型,对 MNIST 手写数字数据集进行分类。模型包含 3 个全连接层,通过训练集进行训练,并使用测试集评估其性能。代码中包含以下步骤:
- 加载数据集:加载 MNIST 训练集和测试集,并创建数据加载器。
- 定义模型:定义包含 3 个全连接层的神经网络模型。
- 定义损失函数和优化器:使用交叉熵损失函数和 SGD 优化器。
- 训练模型:迭代 10 个 epoch,每个 epoch 使用训练集进行训练,并记录损失和准确率。
- 评估模型:使用测试集评估模型,计算混淆矩阵和每个类别的识别准确率。
- 可视化结果:绘制损失曲线、混淆矩阵和每个类别的识别准确率。
代码如下:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
# 设置随机种子
torch.manual_seed(123)
# 加载数据集
train_dataset = MNIST(root='./data', train=True, transform=ToTensor(), download=True)
test_dataset = MNIST(root='./data', train=False, transform=ToTensor(), download=True)
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 定义神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
def forward(self, x):
x = x.view(-1, 784)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = Net()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 训练模型
losses = []
accuracies = []
for epoch in range(10):
running_loss = 0.0
correct = 0
total = 0
for i, data in enumerate(train_loader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 计算准确率
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
running_loss += loss.item()
if i % 100 == 99:
losses.append(running_loss / 100)
accuracies.append(correct / total)
running_loss = 0.0
correct = 0
total = 0
# 绘制 Loss-acc 图
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(losses)
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.subplot(1, 2, 2)
plt.plot(accuracies)
plt.xlabel('Iterations')
plt.ylabel('Accuracy')
# 在测试集上评估模型
model.eval()
all_labels = []
all_predictions = []
with torch.no_grad():
for data in test_loader:
images, labels = data
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
all_labels.extend(labels.numpy())
all_predictions.extend(predicted.numpy())
# 计算混淆矩阵
cm = confusion_matrix(all_labels, all_predictions)
# 绘制混淆矩阵
plt.figure(figsize=(10, 8))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.colorbar()
plt.xticks(np.arange(10))
plt.yticks(np.arange(10))
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
# 分析每一类的识别准确率
class_acc = cm.diagonal() / cm.sum(axis=1)
plt.figure(figsize=(10, 5))
plt.bar(np.arange(10), class_acc)
plt.xlabel('Class')
plt.ylabel('Accuracy')
plt.show()
代码实现了 MNIST 手写数字识别任务,并通过可视化结果分析模型性能。
原文地址: https://www.cveoy.top/t/topic/b5pl 著作权归作者所有。请勿转载和采集!