Python自定义数据集类:深度学习数据加载指南
Python自定义数据集类:深度学习数据加载指南
在深度学习中,高效地加载和预处理数据至关重要。为了方便管理和使用数据,通常会创建一个自定义的数据集类。本指南将带你逐步创建一个Python自定义数据集类,用于加载和预处理深度学习模型所需的数据。
代码示例pythonimport osimport scipy.io as sciofrom torch.utils.data import Dataset
def normalize(data): # 在此处添加你的数据归一化逻辑 return data
class MyDataset(Dataset): def init(self, root_dir, names_file, transform=None): self.root_dir = root_dir self.names_file = names_file self.transform = transform self.size = 0 self.names_list = [] if not os.path.isfile(self.names_file): print(self.names_file + ' does not exist!') file = open(self.names_file) for f in file: self.names_list.append(f) self.size += 1
def __len__(self): return self.size
def __getitem__(self, idx): data_path = self.root_dir + self.names_list[idx].split(' ')[0] if not os.path.isfile(data_path): print(data_path + ' does not exist!') return None rawdata = scio.loadmat(data_path)['data'] # 10000,12 uint16 rawdata = rawdata.astype(int) # int32 data = normalize(rawdata) label = int(self.names_list[idx].split(' ')[1]) sample = {'data': data, 'label': label} if self.transform: sample = self.transform(sample) return sample
代码解析
这段代码定义了一个名为 MyDataset 的类,继承自 torch.utils.data.Dataset。让我们逐步分析它的功能:
-
初始化函数 (
__init__): - 接收三个参数:root_dir(数据集根目录),names_file(包含数据文件名和标签的文件路径),transform(可选的数据转换函数)。 - 初始化实例变量,包括root_dir,names_file,transform,size(数据集大小),以及names_list(存储数据文件名和标签的列表)。 - 读取names_file文件,并将文件名和标签存储到names_list中,同时更新数据集大小size。 -
__len__函数: - 返回数据集的大小,即样本数量 (self.size)。 -
__getitem__函数: - 根据索引idx获取单个样本数据。 - 首先,根据索引从names_list中获取对应的数据文件路径和标签。 - 然后,使用scipy.io.loadmat函数加载数据文件。 - 接着,对加载的原始数据进行预处理,例如转换为整数类型,并使用normalize函数进行归一化。 - 最后,将预处理后的数据和标签存储在一个字典中返回。如果定义了transform函数,则在返回之前对样本进行转换。 -
数据预处理: - 代码中包含一个简单的
normalize函数示例,用于对数据进行归一化。你可以根据实际需求修改此函数,例如进行标准化、特征缩放等操作。 -
标签处理: - 在
__getitem__函数中,从names_list中读取标签信息,并将其转换为整数类型。
总结
通过创建一个自定义的数据集类,你可以轻松地加载、预处理和管理深度学习模型所需的数据。你可以根据自己的实际需求修改和扩展此代码,例如添加更多的数据预处理步骤,或支持不同类型的数据集。
原文地址: http://www.cveoy.top/t/topic/fPIG 著作权归作者所有。请勿转载和采集!