PyTorch 自定义数据集加载:深度解析 MyDataset 类

这篇文章将逐行解析一段 Python 代码,该代码定义了一个名为 MyDataset 的自定义数据集类,用于加载和预处理机器学习任务的数据。pythonimport numpy as npimport osimport scipy.io as sciofrom torch.utils.data import Dataset

def normalize(data): # 归一化到0-255 rawdata_max = max(map(max, data)) rawdata_min = min(map(min, data)) for i in range(data.shape[0]): for j in range(data.shape[1]): data[i][j] = round(((255 - 0) * (data[i][j] - rawdata_min) / (rawdata_max - rawdata_min)) + 0) return data

class MyDataset(Dataset):

def __init__(self, root_dir, names_file, transform=None):        self.root_dir = root_dir        self.names_file = names_file        self.transform = transform        self.size = 0        self.names_list = []        if not os.path.isfile(self.names_file):            print(self.names_file + 'does not exist!')        file = open(self.names_file)        for f in file:            self.names_list.append(f)            self.size += 1

def __len__(self):        return self.size

def __getitem__(self, idx):        data_path = self.root_dir + self.names_list[idx].split(' ')[0]        if not os.path.isfile(data_path):            print(data_path + 'does not exist!')            return None        rawdata = scio.loadmat(data_path)['data']  # 10000,12 uint16        rawdata = rawdata.astype(int)  # int32        data = normalize(rawdata)        label = int(self.names_list[idx].split(' ')[1])        sample = {'data': data, 'label': label}        if self.transform:            sample = self.transform(sample)        return sample

代码解析:

  1. 导入必要库: 代码首先导入了必要的库,包括 numpy, os, scipy.iotorch.utils.data.Dataset

  2. normalize(data) 函数: - 该函数接收一个数据数组 data 作为输入。 - 它通过找到数据的最大值和最小值,将数据线性归一化到 0 到 255 的范围内。 - 循环遍历数据数组,并将每个元素替换为其归一化值。 - 最后返回归一化后的数据。

  3. MyDataset: - 继承自 torch.utils.data.Dataset 类,这是 PyTorch 中用于创建自定义数据集的基类。 - __init__(self, root_dir, names_file, transform=None): - 构造函数,用于初始化 MyDataset 类的实例。 - root_dir: 数据存储的根目录。 - names_file: 包含数据文件名及其对应标签的文件路径。 - transform: 可选参数,用于对数据进行转换的函数。 - 初始化实例变量,包括 root_dir, names_file, transform, sizenames_list。 - 读取 names_file 文件,并将文件名和标签存储在 names_list 中,同时更新数据集大小 size。 - __len__(self): - 返回数据集的大小,即 names_file 中的样本数量。 - __getitem__(self, idx): - 根据给定的索引 idx 获取对应的数据样本。 - 使用 root_dirnames_list 构建数据文件的完整路径。 - 检查文件是否存在,如果不存在则打印错误信息并返回 None。 - 使用 scipy.io.loadmat 函数加载数据文件。 - 将数据类型转换为 int。 - 使用 normalize 函数对数据进行归一化。 - 从 names_list 中获取对应的标签。 - 将数据和标签存储在字典 sample 中。 - 如果定义了 transform 函数,则对 sample 进行转换。 - 最后返回处理后的数据样本 sample

总结

这段代码定义了一个名为 MyDataset 的自定义数据集类,用于加载和预处理数据。它使用 names_file 文件来管理数据文件及其标签,并提供了一种使用 transform 参数对数据进行自定义转换的方法。通过继承 torch.utils.data.Dataset 类,MyDataset 可以轻松地与 PyTorch 的数据加载器一起使用,从而简化机器学习模型的训练过程。

PyTorch 自定义数据集加载:深度解析 MyDataset 类

原文地址: http://www.cveoy.top/t/topic/fPIz 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录