PyTorch 自定义数据集加载:深度解析 MyDataset 类
PyTorch 自定义数据集加载:深度解析 MyDataset 类
这篇文章将逐行解析一段 Python 代码,该代码定义了一个名为 MyDataset 的自定义数据集类,用于加载和预处理机器学习任务的数据。pythonimport numpy as npimport osimport scipy.io as sciofrom torch.utils.data import Dataset
def normalize(data): # 归一化到0-255 rawdata_max = max(map(max, data)) rawdata_min = min(map(min, data)) for i in range(data.shape[0]): for j in range(data.shape[1]): data[i][j] = round(((255 - 0) * (data[i][j] - rawdata_min) / (rawdata_max - rawdata_min)) + 0) return data
class MyDataset(Dataset):
def __init__(self, root_dir, names_file, transform=None): self.root_dir = root_dir self.names_file = names_file self.transform = transform self.size = 0 self.names_list = [] if not os.path.isfile(self.names_file): print(self.names_file + 'does not exist!') file = open(self.names_file) for f in file: self.names_list.append(f) self.size += 1
def __len__(self): return self.size
def __getitem__(self, idx): data_path = self.root_dir + self.names_list[idx].split(' ')[0] if not os.path.isfile(data_path): print(data_path + 'does not exist!') return None rawdata = scio.loadmat(data_path)['data'] # 10000,12 uint16 rawdata = rawdata.astype(int) # int32 data = normalize(rawdata) label = int(self.names_list[idx].split(' ')[1]) sample = {'data': data, 'label': label} if self.transform: sample = self.transform(sample) return sample
代码解析:
-
导入必要库: 代码首先导入了必要的库,包括
numpy,os,scipy.io和torch.utils.data.Dataset。 -
normalize(data)函数: - 该函数接收一个数据数组data作为输入。 - 它通过找到数据的最大值和最小值,将数据线性归一化到 0 到 255 的范围内。 - 循环遍历数据数组,并将每个元素替换为其归一化值。 - 最后返回归一化后的数据。 -
MyDataset类: - 继承自torch.utils.data.Dataset类,这是 PyTorch 中用于创建自定义数据集的基类。 -__init__(self, root_dir, names_file, transform=None): - 构造函数,用于初始化MyDataset类的实例。 -root_dir: 数据存储的根目录。 -names_file: 包含数据文件名及其对应标签的文件路径。 -transform: 可选参数,用于对数据进行转换的函数。 - 初始化实例变量,包括root_dir,names_file,transform,size和names_list。 - 读取names_file文件,并将文件名和标签存储在names_list中,同时更新数据集大小size。 -__len__(self): - 返回数据集的大小,即names_file中的样本数量。 -__getitem__(self, idx): - 根据给定的索引idx获取对应的数据样本。 - 使用root_dir和names_list构建数据文件的完整路径。 - 检查文件是否存在,如果不存在则打印错误信息并返回None。 - 使用scipy.io.loadmat函数加载数据文件。 - 将数据类型转换为int。 - 使用normalize函数对数据进行归一化。 - 从names_list中获取对应的标签。 - 将数据和标签存储在字典sample中。 - 如果定义了transform函数,则对sample进行转换。 - 最后返回处理后的数据样本sample。
总结
这段代码定义了一个名为 MyDataset 的自定义数据集类,用于加载和预处理数据。它使用 names_file 文件来管理数据文件及其标签,并提供了一种使用 transform 参数对数据进行自定义转换的方法。通过继承 torch.utils.data.Dataset 类,MyDataset 可以轻松地与 PyTorch 的数据加载器一起使用,从而简化机器学习模型的训练过程。
原文地址: http://www.cveoy.top/t/topic/fPIz 著作权归作者所有。请勿转载和采集!