这段代码定义了一个自定义的数据集类MyDataset,用于加载数据和标签。

  • normalize函数用于将数据归一化到0-255的范围。
  • MyDataset类继承自torch.utils.data.Dataset类,重写了__init____len____getitem__方法。
    • __init__方法初始化数据集的根目录、文件名列表和数据集的大小。
    • __len__方法返回数据集的大小。
    • __getitem__方法根据索引idx加载对应的数据和标签,并返回一个样本。
  • 数据集的根目录和文件名列表通过构造函数的参数传入。
  • __getitem__方法中,首先根据索引idx获取对应的数据文件路径,然后使用scio.loadmat函数加载.mat文件中的数据,将数据转换为int类型,并调用normalize函数进行归一化处理。最后将数据和标签封装成一个字典样本,并应用transform函数(如果有的话)对样本进行变换。
  • normalize函数将数据归一化到0-255的范围,通过计算最大值和最小值,将原始数据映射到0-255的范围内,并返回归一化后的数据。

代码逐行解释:

import numpy as np
import os
import scipy.io as scio
from torch.utils.data import Dataset


def normalize(data):  # 归一化到0-255
    rawdata_max = max(map(max, data))
    rawdata_min = min(map(min, data))
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            data[i][j] = round(((255 - 0) * (data[i][j] - rawdata_min) / (rawdata_max - rawdata_min)) + 0)
    return data


class MyDataset(Dataset):

    def __init__(self, root_dir, names_file, transform=None):
        self.root_dir = root_dir
        self.names_file = names_file
        self.transform = transform
        self.size = 0
        self.names_list = []
        if not os.path.isfile(self.names_file):
            print(self.names_file + 'does not exist!')
        file = open(self.names_file)
        for f in file:
            self.names_list.append(f)
            self.size += 1

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        data_path = self.root_dir + self.names_list[idx].split(' ')[0]
        if not os.path.isfile(data_path):
            print(data_path + 'does not exist!')
            return None
        rawdata = scio.loadmat(data_path)['data']  # 10000,12 uint16
        rawdata = rawdata.astype(int)  # int32
        data = normalize(rawdata)
        label = int(self.names_list[idx].split(' ')[1])
        sample = {'data': data, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample

总结:

这段代码定义了一个用于加载和预处理 .mat 文件的自定义数据集类 MyDataset,并使用 normalize 函数对数据进行归一化处理。MyDataset 类继承自 torch.utils.data.Dataset 类,并重写了必要的方法,以支持 PyTorch 数据加载器。该代码可用于构建基于 .mat 文件的机器学习或深度学习模型。

PyTorch 自定义数据集类:加载和预处理 .mat 文件

原文地址: https://www.cveoy.top/t/topic/fPIC 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录