以下是使用PyTorch读取FashionMNIST数据集并构建自定义Dataset的示例代码:

import torch
import torchvision
from torch.utils.data import Dataset

# 加载FashionMNIST数据集
train_data = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True)
test_data = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True)

# 自定义Dataset类
class FashionMNISTDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x, y = self.data[index]
        if self.transform:
            x = self.transform(x)
        return x, y

# 数据预处理
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,))
])

# 构建训练集和测试集
train_dataset = FashionMNISTDataset(train_data, transform=transform)
test_dataset = FashionMNISTDataset(test_data, transform=transform)

# 测试代码
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

for i, (images, labels) in enumerate(train_loader):
    print(images.shape, labels.shape)
    break

在上述代码中,我们首先使用torchvision.datasets.FashionMNIST加载FashionMNIST数据集,然后定义了一个自定义的FashionMNISTDataset类,该类继承自torch.utils.data.Dataset,并重写了__len____getitem__方法,用于返回数据集的长度和每个样本的数据和标签。接着,我们定义了一个数据预处理的转换器,用于将数据转换为张量并进行标准化。最后,我们使用FashionMNISTDataset类构建了训练集和测试集,并使用torch.utils.data.DataLoader将其转换为可迭代的数据加载器,以便于模型训练

pytorch读取FashionMNIST构建Dataset

原文地址: https://www.cveoy.top/t/topic/fTnC 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录