使用 PyTorch 识别自定义数据集 - MNIST 模型示例

本示例展示如何使用 PyTorch 的 MNIST 模型识别自定义数据集。我们首先介绍如何将自定义数据集转换为适合 PyTorch 的格式,并提供详细步骤,最后展示如何使用训练和测试代码识别自定义数据集。

准备自定义数据集

要识别自己提供的数据集,需要将数据集准备成与 MNIST 数据集相同的格式,即将数据集转换为 torchvision.datasets.ImageFolder 类的格式。具体步骤如下:

  1. 创建一个文件夹,将所有的图片放在该文件夹下,并按照类别分好子文件夹。例如,你可以创建一个名为 'train_folder' 的文件夹,其中包含 'cat' 和 'dog' 两个子文件夹,分别存放猫和狗的图片。
  2. 使用 torchvision.transforms.Compose 创建一个数据预处理的管道,包括将图片转换为 tensor、归一化等操作。
  3. 使用 torchvision.datasets.ImageFolder 创建一个数据集对象,将文件夹路径和数据预处理管道传入,即可得到一个数据集对象。
  4. 使用 torch.utils.data.DataLoader 将数据集包装成一个数据加载器,可以设置 batch_size 等参数。
  5. 在训练和测试时,使用该数据加载器加载数据即可。

代码示例

from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader

BATCH_SIZE = 128
TEST_BATCH_SIZE = 1000

# 准备数据集
def get_dataloader(train, batch_size=BATCH_SIZE):
    transform_fn = Compose([
        ToTensor(),
        Normalize(mean=(0.1307,), std=(0.3081,))
    ])  # mean和std的形状与通道数相同

    if train:
        data_folder = 'train_folder'
    else:
        data_folder = 'test_folder'
    dataset = ImageFolder(root=data_folder, transform=transform_fn)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return data_loader

# 训练和测试代码不变,可以直接使用。

完整代码

from torchvision.datasets import MNIST, ImageFolder
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import os
import torch
import numpy as np

BATCH_SIZE = 128
TEST_BATCH_SIZE = 1000

# 准备数据集
def get_dataloader(train, batch_size=BATCH_SIZE):
    transform_fn = Compose([
        ToTensor(),
        Normalize(mean=(0.1307,), std=(0.3081,))
    ])  # mean和std的形状与通道数相同

    if train:
        data_folder = 'train_folder'
    else:
        data_folder = 'test_folder'
    dataset = ImageFolder(root=data_folder, transform=transform_fn)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return data_loader


class MnistModel(nn.Module):
    def __init__(self):
        super(MnistModel, self).__init__()  # 继承
        self.fc1 = nn.Linear(1 * 28 * 28, 28)  # 参数是input和output的feature
        self.fc2 = nn.Linear(28, 10)

    def forward(self, input):
        # 1.进行形状的修改
        x = input.view([-1, 1 * 28 * 28])  # -1表示根据形状自动调整,也可以改为input.size(0)
        # 2.进行全连接的操作
        x = self.fc1(x)
        # 3.激活函数的处理
        x = F.relu(x)  # 形状没有变化
        # 4.输出层
        out = self.fc2(x)
        return F.log_softmax(out, dim=-1)


model = MnistModel()
optimizer = Adam(model.parameters(), lr=0.001)


def train(epoch):  # epoch表示几轮
    data_loader = get_dataloader(True)  # 获取数据加载器
    for idx, (input, target) in enumerate(data_loader):  # idx表示data_loader中的第几个数据,元组是data_loader的数据
        optimizer.zero_grad()  # 将梯度置0
        output = model(input)  # 调用模型,得到预测值
        loss = F.nll_loss(output, target)  # 调用损失函数,得到损失,是一个tensor
        loss.backward()  # 反向传播
        optimizer.step()  # 梯度的更新
        if idx % 10 == 0:
            print(epoch, idx, loss.item())
    # for是每一轮中的数据进行遍历


def test():
    loss_list = []
    acc_list = []
    test_dataloader = get_dataloader(train=False, batch_size=TEST_BATCH_SIZE)  # 获取测试集
    for idx, (input, target) in enumerate(test_dataloader):
        with torch.no_grad():  # 不计算梯度
            output = model(input)
            cur_loss = F.nll_loss(output, target)
            loss_list.append(cur_loss)
            # 计算准确率,output大小[batch_size,10] target[batch_size] batch_size是多少组数据,10列是每个数字概率
            pred = output.max(dim=-1)[-1]  # 获取最大值位置
            cur_acc = pred.eq(target).float().mean()
            acc_list.append(cur_acc)
    print('平均准确率:', np.mean(acc_list), '平均损失:', np.mean(loss_list))


test()
for i in range(1):  # 训练三轮
    train(i)
test()

注意事项

  • 自定义数据集中的图片尺寸应该与 MNIST 数据集的图片尺寸相同,即 28x28 像素。
  • 确保自定义数据集的类别数量与 MNIST 数据集的类别数量相同,即 10 类。
  • 自定义数据集的图片格式应该与 MNIST 数据集的图片格式相同,即灰度图像。

总结

本示例展示了如何使用 PyTorch 的 MNIST 模型识别自定义数据集。通过将自定义数据集转换为适合 PyTorch 的格式,并使用训练和测试代码,可以轻松地使用 MNIST 模型识别自定义数据集。

使用 PyTorch 识别自定义数据集 - MNIST 模型示例

原文地址: https://www.cveoy.top/t/topic/kppn 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录