TFT模型是一种用于多步时间序列预测的Transformer模型,其中GRN模块是模型的关键组成部分。GRN模块的输入包括:

  1. 时间嵌入(Time Embedding):用于表示时间信息的向量。这个向量的维度通常与其他特征的维度相同,用于将时间信息融合到模型中。

  2. 历史数据(History Data):用于训练模型的历史时间序列数据。这个输入通常是一个三维张量,包括时间步、特征和样本数量三个维度。

  3. 静态特征(Static Features):与时间无关的特征数据。这个输入通常是一个二维张量,包括特征和样本数量两个维度。

GRN模块的代码表达如下:

class GatedResidualNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers, dropout):
        super(GatedResidualNetwork, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout = dropout

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(nn.Sequential(
                nn.Linear(input_size, hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, output_size),
                nn.Sigmoid()
            ))

    def forward(self, x):
        out = x
        for i in range(self.num_layers):
            residual = out
            out = self.layers[i](out)
            out = out * residual
        return out

在代码中,GRN模块首先定义了一些超参数,包括输入大小、隐藏层大小、输出大小、层数和dropout率。然后通过一个循环添加了多个GRU单元构成了GRN模块。

在forward函数中,GRN模块接收一个输入张量x,然后通过循环依次对每个GRU单元进行计算,最后输出一个张量。在计算每个GRU单元的输出时,模块首先保存当前输入作为残差,然后将输入传递给一个全连接层,使用ReLU激活函数并应用dropout。接着,模块将全连接层的输出传递给另一个全连接层,使用Sigmoid激活函数。最后,模块将Sigmoid激活函数的输出与残差相乘得到最终的输出。

TFT Temporal Fusion Transformers 一种针对多步预测任务的Transformer模型中对于其中GRN模块的输入分别是什么及其维度以及代码表达帮我详细分析一下。

原文地址: https://www.cveoy.top/t/topic/Cky 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录