Write self-attention code in TensorFlow
Below is an example implementation of self-attention using TensorFlow:
import tensorflow as tf

class SelfAttention(tf.keras.layers.Layer):
    def __init__(self, num_heads, d_model):
        super(SelfAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads  # dimension per head

        # projection layers for queries, keys, values, and the output
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]

        # project the inputs to queries, keys, and values
        q = self.wq(inputs)
        k = self.wk(inputs)
        v = self.wv(inputs)

        # split each projection into multiple heads
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # scaled dot-product attention
        scaled_attention_logits = tf.matmul(q, k, transpose_b=True)
        scaled_attention_logits = scaled_attention_logits / tf.math.sqrt(tf.cast(self.depth, tf.float32))
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)

        # merge the heads back together and apply the final projection
        output = tf.transpose(output, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(output, (batch_size, -1, self.d_model))
        output = self.dense(concat_attention)
        return output, attention_weights
In the code above, we define a SelfAttention class that inherits from tf.keras.layers.Layer. The constructor takes the number of attention heads and the model dimension, and asserts that the latter is divisible by the former. In call, three dense layers first project the inputs to queries (q), keys (k), and values (v). These are then split into multiple heads, the attention weights are computed via scaled dot-product attention, and the values are weighted by those attention weights. Finally, the heads are concatenated back together and passed through one more dense layer to produce the final output.
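The core of call is the scaled dot-product step; the same math can be sketched in plain NumPy for a single head (a minimal illustration mirroring the code above, assuming no mask):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """Single-head scaled dot-product attention, NumPy version."""
    depth = q.shape[-1]
    # (batch, seq, seq): similarity of every query with every key
    logits = q @ k.transpose(0, 2, 1) / np.sqrt(depth)
    # numerically stable softmax over the last axis
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights = exp / exp.sum(axis=-1, keepdims=True)
    return weights @ v, weights

q = np.random.rand(1, 4, 8)  # (batch, seq_len, depth)
out, w = scaled_dot_product_attention(q, q, q)
print(out.shape)       # (1, 4, 8)
print(w.sum(axis=-1))  # each row of attention weights sums to 1
```

Note that for self-attention the queries, keys, and values all come from the same input, which is why the example passes q three times.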
Example code using this SelfAttention layer:
import numpy as np

num_heads = 8
d_model = 512
self_attention = SelfAttention(num_heads, d_model)

batch_size = 2
seq_len = 10
embedding_dim = d_model  # the input feature size must equal d_model
# cast to float32 to match the layer's default compute dtype
inputs = np.random.rand(batch_size, seq_len, embedding_dim).astype(np.float32)
outputs, attention_weights = self_attention(inputs)
print(outputs.shape)            # (2, 10, 512)
print(attention_weights.shape)  # (2, 8, 10, 10)
In this example, we first create a SelfAttention layer with the desired number of heads and model dimension. We then generate a random input tensor and pass it through the layer. Finally, we print the shapes of the output tensor and the attention-weights tensor: the output has shape (batch_size, seq_len, d_model), and the attention weights have shape (batch_size, num_heads, seq_len, seq_len).
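In practice, self-attention is usually combined with a padding mask so that attention ignores padded positions. The layer above does not take a mask, but the usual trick is to add a large negative value to the masked logits before the softmax. A minimal NumPy sketch of that idea (the mask convention shown here, 1 = padded, is an assumption for illustration):

```python
import numpy as np

seq_len = 5
logits = np.random.rand(1, seq_len, seq_len)         # raw attention logits
pad_mask = np.array([[0, 0, 0, 1, 1]])               # 1 marks padded key positions

# broadcast the mask over the query axis and push masked logits to -inf-like values
logits = logits + (pad_mask[:, None, :] * -1e9)

# softmax: masked positions receive (near-)zero attention weight
exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
weights = exp / exp.sum(axis=-1, keepdims=True)
print(weights[0, 0])  # last two entries are ~0
```

The same mask addition could be applied to scaled_attention_logits inside call before tf.nn.softmax.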