Python中自定义数据集类MyDataset详解:使用PyTorch加载和预处理数据

在PyTorch中,我们可以通过继承torch.utils.data.Dataset类来创建自定义的数据集类,以便于我们加载和预处理自己的数据。本文将详细介绍如何创建一个名为MyDataset的自定义数据集类,并展示如何使用它来加载和预处理数据。

代码实现pythonimport osimport sciofrom torch.utils.data import Dataset

class MyDataset(Dataset): def init(self, root_dir, names_file, transform=None): self.root_dir = root_dir self.names_file = names_file self.transform = transform self.size = 0 self.names_list = [] if not os.path.isfile(self.names_file): print(self.names_file + 'does not exist!') file = open(self.names_file) for f in file: self.names_list.append(f) self.size += 1

def __len__(self):        return self.size

def __getitem__(self, idx):        data_path = self.root_dir + self.names_list[idx].split(' ')[0]        if not os.path.isfile(data_path):            print(data_path + 'does not exist!')            return None        rawdata = scio.loadmat(data_path)['data']  # 10000,12 uint16        rawdata = rawdata.astype(int)  # int32        data = normalize(rawdata)  # 假设normalize是一个预处理函数        label = int(self.names_list[idx].split(' ')[1])        sample = {'data': data, 'label': label}        if self.transform:            sample = self.transform(sample)        return sample

代码解析

  1. __init__(self, root_dir, names_file, transform=None): 初始化函数,接收三个参数: * root_dir: 数据集根目录 * names_file: 包含数据文件名和标签的文本文件路径 * transform: 可选参数,用于对数据进行预处理的函数2. __len__(self): 返回数据集的大小3. __getitem__(self, idx): 根据索引idx获取对应的样本,主要步骤如下: * 根据文件名和root_dir拼接出数据文件的完整路径 * 使用scio.loadmat函数加载数据文件 * 对原始数据进行类型转换和归一化处理 * 从names_file中读取对应的标签 * 将数据和标签封装成字典sample * 如果定义了transform函数,则对sample进行预处理

使用示例python# 实例化MyDataset类dataset = MyDataset(root_dir='./data/', names_file='./data/names.txt', transform=transform)

获取数据集中第10个样本sample = dataset[9]

打印样本信息print(sample)

总结

通过自定义数据集类MyDataset,我们可以方便地加载和预处理自己的数据,为后续的模型训练做好准备。需要注意的是,transform参数可以是一个自定义的函数或组合函数,用于对数据进行增强等操作,例如数据归一化、随机裁剪、水平翻转等。

Python中自定义数据集类MyDataset详解:使用PyTorch加载和预处理数据

原文地址: https://www.cveoy.top/t/topic/fPII 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录