from torchvisiondatasets import MNISTfrom torchvisiontransforms import Compose ToTensor Normalizefrom torchutilsdata import DataLoaderimport torchnn as nnimport torchnnfunctional as Ffrom torchoptim i
要识别自己提供的数据集,需要将数据集准备成与MNIST数据集相同的格式,即将数据集转换为torchvision.datasets.ImageFolder类的格式。具体步骤如下:
- 创建一个文件夹,将所有的图片放在该文件夹下,并按照类别分好子文件夹。
- 使用torchvision.transforms.Compose创建一个数据预处理的管道,包括将图片转换为tensor、归一化等操作。
- 使用torchvision.datasets.ImageFolder创建一个数据集对象,将文件夹路径和数据预处理管道传入,即可得到一个数据集对象。
- 使用torch.utils.data.DataLoader将数据集包装成一个数据加载器,可以设置batch_size等参数。
- 在训练和测试时,使用该数据加载器加载数据即可。
具体代码如下:
from torchvision.datasets import ImageFolder from torchvision.transforms import Compose, ToTensor, Normalize from torch.utils.data import DataLoader
BATCH_SIZE = 128 TEST_BATCH_SIZE = 1000
准备数据集
def get_dataloader(train, batch_size=BATCH_SIZE): transform_fn = Compose([ ToTensor(), Normalize(mean=(0.1307,), std=(0.3081,)) ]) # mean和std的形状与通道数相同
if train:
data_folder = 'train_folder'
else:
data_folder = 'test_folder'
dataset = ImageFolder(root=data_folder, transform=transform_fn)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
return data_loader
训练和测试代码不变,可以直接使用
原文地址: https://www.cveoy.top/t/topic/c1NQ 著作权归作者所有。请勿转载和采集!