PyTorch 自定义数据集类:加载和预处理 .mat 文件
这段代码定义了一个自定义的数据集类MyDataset,用于加载数据和标签。
normalize函数用于将数据归一化到0-255的范围。MyDataset类继承自torch.utils.data.Dataset类,重写了__init__、__len__和__getitem__方法。__init__方法初始化数据集的根目录、文件名列表和数据集的大小。__len__方法返回数据集的大小。__getitem__方法根据索引idx加载对应的数据和标签,并返回一个样本。
- 数据集的根目录和文件名列表通过构造函数的参数传入。
__getitem__方法中,首先根据索引idx获取对应的数据文件路径,然后使用scio.loadmat函数加载.mat文件中的数据,将数据转换为int类型,并调用normalize函数进行归一化处理。最后将数据和标签封装成一个字典样本,并应用transform函数(如果有的话)对样本进行变换。normalize函数将数据归一化到0-255的范围,通过计算最大值和最小值,将原始数据映射到0-255的范围内,并返回归一化后的数据。
代码逐行解释:
import numpy as np
import os
import scipy.io as scio
from torch.utils.data import Dataset
def normalize(data): # 归一化到0-255
rawdata_max = max(map(max, data))
rawdata_min = min(map(min, data))
for i in range(data.shape[0]):
for j in range(data.shape[1]):
data[i][j] = round(((255 - 0) * (data[i][j] - rawdata_min) / (rawdata_max - rawdata_min)) + 0)
return data
class MyDataset(Dataset):
def __init__(self, root_dir, names_file, transform=None):
self.root_dir = root_dir
self.names_file = names_file
self.transform = transform
self.size = 0
self.names_list = []
if not os.path.isfile(self.names_file):
print(self.names_file + 'does not exist!')
file = open(self.names_file)
for f in file:
self.names_list.append(f)
self.size += 1
def __len__(self):
return self.size
def __getitem__(self, idx):
data_path = self.root_dir + self.names_list[idx].split(' ')[0]
if not os.path.isfile(data_path):
print(data_path + 'does not exist!')
return None
rawdata = scio.loadmat(data_path)['data'] # 10000,12 uint16
rawdata = rawdata.astype(int) # int32
data = normalize(rawdata)
label = int(self.names_list[idx].split(' ')[1])
sample = {'data': data, 'label': label}
if self.transform:
sample = self.transform(sample)
return sample
总结:
这段代码定义了一个用于加载和预处理 .mat 文件的自定义数据集类 MyDataset,并使用 normalize 函数对数据进行归一化处理。MyDataset 类继承自 torch.utils.data.Dataset 类,并重写了必要的方法,以支持 PyTorch 数据加载器。该代码可用于构建基于 .mat 文件的机器学习或深度学习模型。
原文地址: https://www.cveoy.top/t/topic/fPIC 著作权归作者所有。请勿转载和采集!