Python 代码解析:使用 RNN 生成条件样本
这是一个用于生成样本的函数,根据给定的标签生成样本。如果没有给定随机噪声'z',则会生成一个符合正态分布的随机噪声。然后,使用随机噪声和标签来初始化模型的初始状态,并使用循环神经网络生成样本。最终,将生成的样本拼接在一起并返回输出。其中,'max_len' 参数表示生成样本的最大长度。
def sample(self, labels, z=None, max_len=125):
' generates samples conditioned on the given label '
num_examples = int(labels.shape[0])
if z is None:
z = tf.random.normal(shape=(num_examples, self.decoder.rnn_units))
last_pred = tf.zeros(shape=(num_examples, 1, self.num_feats))
preds = []
last_state = tf.reshape(self.noise2hidden(z), [3, num_examples, -1])
z_emb = self.fc_z(z)#fc_z: Dense = layers.Dense(6)
z_with_time = tf.expand_dims(z_emb, axis=1)
for _ in range(max_len):
if self.z_context:
step_input = tf.concat([z_with_time, last_pred], axis=2)
else:
step_input = last_pred
last_pred, last_state = self.decoder(step_input, last_state, labels)
preds.append(last_pred)
output = tf.concat(preds, axis=1)
return output
原文地址: https://www.cveoy.top/t/topic/nh1z 著作权归作者所有。请勿转载和采集!