Python中自定义数据集类MyDataset详解:使用PyTorch加载和预处理数据
Python中自定义数据集类MyDataset详解:使用PyTorch加载和预处理数据
在PyTorch中,我们可以通过继承torch.utils.data.Dataset类来创建自定义的数据集类,以便于我们加载和预处理自己的数据。本文将详细介绍如何创建一个名为MyDataset的自定义数据集类,并展示如何使用它来加载和预处理数据。
代码实现pythonimport osimport sciofrom torch.utils.data import Dataset
class MyDataset(Dataset): def init(self, root_dir, names_file, transform=None): self.root_dir = root_dir self.names_file = names_file self.transform = transform self.size = 0 self.names_list = [] if not os.path.isfile(self.names_file): print(self.names_file + 'does not exist!') file = open(self.names_file) for f in file: self.names_list.append(f) self.size += 1
def __len__(self): return self.size
def __getitem__(self, idx): data_path = self.root_dir + self.names_list[idx].split(' ')[0] if not os.path.isfile(data_path): print(data_path + 'does not exist!') return None rawdata = scio.loadmat(data_path)['data'] # 10000,12 uint16 rawdata = rawdata.astype(int) # int32 data = normalize(rawdata) # 假设normalize是一个预处理函数 label = int(self.names_list[idx].split(' ')[1]) sample = {'data': data, 'label': label} if self.transform: sample = self.transform(sample) return sample
代码解析
__init__(self, root_dir, names_file, transform=None): 初始化函数,接收三个参数: *root_dir: 数据集根目录 *names_file: 包含数据文件名和标签的文本文件路径 *transform: 可选参数,用于对数据进行预处理的函数2.__len__(self): 返回数据集的大小3.__getitem__(self, idx): 根据索引idx获取对应的样本,主要步骤如下: * 根据文件名和root_dir拼接出数据文件的完整路径 * 使用scio.loadmat函数加载数据文件 * 对原始数据进行类型转换和归一化处理 * 从names_file中读取对应的标签 * 将数据和标签封装成字典sample* 如果定义了transform函数,则对sample进行预处理
使用示例python# 实例化MyDataset类dataset = MyDataset(root_dir='./data/', names_file='./data/names.txt', transform=transform)
获取数据集中第10个样本sample = dataset[9]
打印样本信息print(sample)
总结
通过自定义数据集类MyDataset,我们可以方便地加载和预处理自己的数据,为后续的模型训练做好准备。需要注意的是,transform参数可以是一个自定义的函数或组合函数,用于对数据进行增强等操作,例如数据归一化、随机裁剪、水平翻转等。
原文地址: https://www.cveoy.top/t/topic/fPII 著作权归作者所有。请勿转载和采集!