The TFT architecture defined in the paper Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting 其中的GRN模块的具体输入、输出可否结合代码详细说明一下呢。
GRN模块的输入是来自TFT模型的输出,即每个时间步长的表示和历史目标值。具体来说,输入包括以下内容:
-
encoder_outputs: 编码器的输出表示,包括历史时间步的表示和当前时间步的表示。shape为(batch_size, num_encoder_steps, hidden_size)。
-
target: 历史目标值,即前面时间步的真实值。shape为(batch_size, num_decoder_steps, output_size)。
-
static_input: 静态输入,包括时序特征和静态特征。shape为(batch_size, static_input_size)。
-
decoder_length: 解码器的长度,即预测的时间步数。
GRN模块的输出是门控调整后的表示,即对输入的编码器输出和历史目标值进行加权,生成一个新的表示。具体来说,输出包括以下内容:
-
gated_encoder_outputs: 经过门控调整后的编码器输出表示,shape为(batch_size, num_encoder_steps, hidden_size)。
-
gated_target: 经过门控调整后的历史目标值表示,shape为(batch_size, num_decoder_steps, output_size)。
下面是GRN模块的代码实现:
class GatedResidualNetwork(tf.keras.Model):
def __init__(self, hidden_size, dropout_rate, use_time_distributed=False):
super(GatedResidualNetwork, self).__init__()
self.hidden_size = hidden_size
self.dropout_rate = dropout_rate
self.use_time_distributed = use_time_distributed
# 定义门控网络
self.gate_network = tf.keras.Sequential([
tf.keras.layers.Dense(hidden_size, activation="sigmoid"),
tf.keras.layers.Dropout(dropout_rate)
])
# 定义调整网络
self.adjustment_network = tf.keras.Sequential([
tf.keras.layers.Dense(hidden_size, activation="relu"),
tf.keras.layers.Dropout(dropout_rate),
tf.keras.layers.Dense(hidden_size)
])
def call(self, inputs):
encoder_outputs, target, static_input, decoder_length = inputs
# 计算门控分数
gate_scores = self.gate_network(tf.concat([encoder_outputs, target], axis=-1))
# 计算调整分数
adjustment_scores = self.adjustment_network(tf.concat([encoder_outputs, target, static_input], axis=-1))
# 应用门控调整
if self.use_time_distributed:
gate_scores = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.hidden_size, activation="sigmoid"))(gate_scores)
adjustment_scores = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.hidden_size))(adjustment_scores)
gated_encoder_outputs = gate_scores * encoder_outputs
gated_target = gate_scores * target
adjustment = tf.expand_dims(tf.matmul(gated_encoder_outputs, adjustment_scores, transpose_b=True), axis=2)
output = gated_encoder_outputs + adjustment
return output, gated_encoder_outputs, gated_target
原文地址: http://www.cveoy.top/t/topic/CEU 著作权归作者所有。请勿转载和采集!