pytorch读取FashionMNIST构建Dataset
以下是使用PyTorch读取FashionMNIST数据集并构建自定义Dataset的示例代码:
import torch
import torchvision
from torch.utils.data import Dataset
# 加载FashionMNIST数据集
train_data = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True)
test_data = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True)
# 自定义Dataset类
class FashionMNISTDataset(Dataset):
def __init__(self, data, transform=None):
self.data = data
self.transform = transform
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x, y = self.data[index]
if self.transform:
x = self.transform(x)
return x, y
# 数据预处理
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5,), (0.5,))
])
# 构建训练集和测试集
train_dataset = FashionMNISTDataset(train_data, transform=transform)
test_dataset = FashionMNISTDataset(test_data, transform=transform)
# 测试代码
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for i, (images, labels) in enumerate(train_loader):
print(images.shape, labels.shape)
break
在上述代码中,我们首先使用torchvision.datasets.FashionMNIST加载FashionMNIST数据集,然后定义了一个自定义的FashionMNISTDataset类,该类继承自torch.utils.data.Dataset,并重写了__len__和__getitem__方法,用于返回数据集的长度和每个样本的数据和标签。接着,我们定义了一个数据预处理的转换器,用于将数据转换为张量并进行标准化。最后,我们使用FashionMNISTDataset类构建了训练集和测试集,并使用torch.utils.data.DataLoader将其转换为可迭代的数据加载器,以便于模型训练
原文地址: https://www.cveoy.top/t/topic/fTnC 著作权归作者所有。请勿转载和采集!