这段代码定义了一个名为Dataset_M4的类,它继承自torch.utils.data.Dataset类。它用于加载M4时间序列数据集,包括训练集和测试集。具体来说,它包含以下参数:

  • root_path:数据集文件的根目录。
  • flag:数据集类型,可以是“train”或“pred”。
  • size:一个包含三个整数值的列表,分别表示序列长度、标签长度和预测长度。
  • features:用于表示数据集中包含哪些特征。
  • data_path:数据集的文件名。
  • target:要预测的目标特征。
  • scale:是否对数据进行标准化。
  • inverse:是否对数据进行反标准化。
  • timeenc:时间编码的类型,目前只支持0。
  • freq:时间序列数据的时间间隔。
  • seasonal_patterns:时间序列的季节性类型。

该类的主要方法是__read_data__()和__getitem__()。read_data()方法用于加载M4数据集。它首先通过调用M4Dataset.load()方法加载数据集,然后将数据集按照季节性类型分组,将每个时间序列放入一个列表中。getitem()方法用于获取数据集中的一个样本。它首先从时间序列中随机选择一个时间点,并将该时间点前面的数据作为输入序列,将该时间点后面的数据作为标签序列。如果输入序列长度不足,则用0进行填充。如果标签序列长度不足,则将它们放在序列的开头,并用0进行填充。最后,该方法返回输入序列、标签序列、输入序列掩码和标签序列掩码。


原文地址: http://www.cveoy.top/t/topic/bE8R 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录