LeNet网络结构优化:提高FashionMNIST数据集分类精度

本代码使用PyTorch框架对经典的LeNet网络结构进行优化,通过添加BatchNorm层、Dropout层、调整卷积核大小和步长等方法,在FashionMNIST数据集上取得了更高的分类精度。

代码实现

import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
    """Modified LeNet classifier for FashionMNIST (1x28x28 input, 10 classes).

    Changes vs. the classic LeNet: BatchNorm after every conv, max pooling
    instead of average pooling, ReLU activations, two extra 3x3 convs, and
    Dropout on the fully connected layers.

    BUG FIX: the original layer geometry collapsed the feature map to 1x1
    before conv4 (a 3x3 kernel applied to a 1x1 input), which raises
    "Calculated output size is too small" on the very first forward pass,
    and could never produce the 48*3*3 flatten size that fc1 expects.
    Padding has been added so the spatial sizes now work out:
        28 -conv1(k5)-> 24 -pool-> 12 -conv2(k7,p3)-> 12 -pool-> 6
           -conv3(k3,s2,p1)-> 3 -conv4(k3,p1)-> 3  => flatten 48*3*3
    """

    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.bn1 = nn.BatchNorm2d(6)
        self.pool1 = nn.MaxPool2d(2, 2)  # max pooling instead of LeNet's average pooling
        # padding=3 keeps the 12x12 map, leaving room for the later strided conv
        self.conv2 = nn.Conv2d(6, 16, 7, padding=3)
        self.bn2 = nn.BatchNorm2d(16)
        # stride-2 conv halves the map: 6 -> 3
        self.conv3 = nn.Conv2d(16, 24, 3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(24)
        # stride 1 + padding 1 preserves the 3x3 map that fc1 expects
        self.conv4 = nn.Conv2d(24, 48, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(48)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(48 * 3 * 3, 256)
        self.dropout = nn.Dropout(0.5)  # regularize the fully connected layers
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        self.relu = nn.ReLU()  # ReLU instead of LeNet's sigmoid/tanh

    def forward(self, x):
        """Return (batch, 10) class logits for a (batch, 1, 28, 28) input."""
        x = self.pool1(self.relu(self.bn1(self.conv1(x))))  # 28 -> 24 -> 12
        x = self.pool2(self.relu(self.bn2(self.conv2(x))))  # 12 -> 12 -> 6
        # ReLU added after bn3/bn4: the original fed unactivated BN outputs
        # into the next layer.
        out = self.relu(self.bn3(self.conv3(x)))            # 6 -> 3
        out = self.relu(self.bn4(self.conv4(out)))          # 3 -> 3
        out = out.view(x.size(0), -1)                       # flatten to 48*3*3
        out = self.dropout(self.relu(self.fc1(out)))
        out = self.dropout(self.relu(self.fc2(out)))
        return self.fc3(out)

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # use GPU when available

# Hyperparameters
batch_size = 128
learning_rate = 0.01
num_epochs = 20
# Load the FashionMNIST training and test sets (downloaded to ./data on first run;
# ToTensor scales pixel values to [0, 1])
train_dataset = datasets.FashionMNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.FashionMNIST(root='./data', train=False, transform=transforms.ToTensor())

# Shuffle only the training set; test order is irrelevant for evaluation
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Instantiate the model and move its parameters to the selected device
model = LeNet().to(device)

# Cross-entropy loss over the 10 classes; plain SGD optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# Training loop: one pass over the training set per epoch, followed by a
# full evaluation on the test set. Per-epoch metrics are collected for the
# plots drawn after training.
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
for epoch in range(num_epochs):
    # BUG FIX: switch to training mode each epoch so Dropout is active and
    # BatchNorm uses batch statistics during training.
    model.train()
    train_loss = 0.0
    train_total = 0
    train_correct = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight the mean batch loss by batch size so the epoch average is
        # exact even when the last batch is smaller.
        train_loss += loss.item() * labels.size(0)
        _, predicted = torch.max(outputs.data, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
    train_loss /= len(train_loader.dataset)
    train_accuracy = 100.0 * train_correct / train_total
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    # BUG FIX: evaluate in eval mode. The original evaluated with Dropout
    # still active and BatchNorm in training mode (which also let test
    # batches contaminate the running statistics inside no_grad), skewing
    # the reported test metrics.
    model.eval()
    test_loss = 0.0
    test_total = 0
    test_correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
    test_loss /= len(test_loader.dataset)
    test_accuracy = 100.0 * test_correct / test_total
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)
    print('Epoch [{}/{}], Train Loss: {:.4f}, Train Accuracy: {:.2f}%, Test Loss: {:.4f}, Test Accuracy: {:.2f}%'
          .format(epoch + 1, num_epochs, train_loss, train_accuracy, test_loss, test_accuracy))
# Plot the per-epoch loss and accuracy curves collected during training.
plt.figure()
plt.plot(np.arange(num_epochs), train_losses, label='Train Loss')
plt.plot(np.arange(num_epochs), test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Testing Loss')
plt.legend()
plt.figure()
plt.plot(np.arange(num_epochs), train_accuracies, label='Train Accuracy')
plt.plot(np.arange(num_epochs), test_accuracies, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Testing Accuracy')
plt.legend()
# BUG FIX: without show() the figures are never rendered when the script
# runs non-interactively (nothing was saved or displayed).
plt.show()
# Persist the trained weights, then reload them for a final evaluation —
# a round-trip check that the checkpoint on disk is usable.
torch.save(model.state_dict(), 'lenet.pth')
model.load_state_dict(torch.load('lenet.pth'))
model.eval()
with torch.no_grad():
    test_total = 0
    test_correct = 0
    # confusion_matrix[true_class][predicted_class] counts test samples
    confusion_matrix = np.zeros((10, 10))
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()
        for true_cls, pred_cls in zip(labels, predicted):
            confusion_matrix[true_cls][pred_cls] += 1
    test_accuracy = 100.0 * test_correct / test_total
    print('准确率: {:.2f}%'.format(test_accuracy))
    print('混淆矩阵:')
    print(confusion_matrix)

代码解释

  1. 网络结构优化

    • 添加BatchNorm层:BatchNorm层可以加速训练过程,并防止梯度消失和爆炸问题,提高模型的泛化能力。
    • 添加Dropout层:Dropout层可以防止过拟合,提高模型的泛化能力。
    • 调整卷积核大小和步长:根据数据特征调整卷积核大小和步长,可以提取更有效的特征,提高模型精度。
    • 修改池化层:将LeNet中的平均池化改为最大池化,可以更好地保留图像的特征信息。
  2. 训练过程

    • 使用FashionMNIST数据集进行训练和测试。
    • 使用SGD优化器进行模型训练。
    • 使用交叉熵损失函数进行模型评估。
    • 绘制训练过程中的损失函数曲线和分类正确率曲线,观察模型训练情况。
  3. 模型测试

    • 使用训练好的模型对测试集进行测试,评估模型的泛化能力。
    • 计算测试集的分类精度,并输出混淆矩阵,分析模型的错误分类情况。

结论

通过对LeNet网络结构的优化,在FashionMNIST数据集上取得了更高的分类精度,说明了BatchNorm层、Dropout层、调整卷积核大小和步长等方法的有效性。

LeNet网络结构优化:提高FashionMNIST数据集分类精度

原文地址: https://www.cveoy.top/t/topic/n9AS 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录