Multi-head attention in PyTorch, with the dimension of every variable annotated
Below is PyTorch code for the multi-head self-attention mechanism, with comments documenting each tensor's shape:
import torch
import torch.nn as nn

class MultiheadAttention(nn.Module):
    def __init__(self, n_heads, d_model):
        super(MultiheadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_dim = d_model // n_heads  # d_model must be divisible by n_heads

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        # query, key, and value have shape (batch_size, seq_len, d_model)
        batch_size = query.shape[0]

        # Project query, key, and value to n_heads * head_dim dimensions
        query = self.query(query)  # (batch_size, seq_len, d_model) -> (batch_size, seq_len, n_heads * head_dim)
        key = self.key(key)        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, n_heads * head_dim)
        value = self.value(value)  # (batch_size, seq_len, d_model) -> (batch_size, seq_len, n_heads * head_dim)

        # Split query, key, and value into n_heads heads
        query = query.view(batch_size, -1, self.n_heads, self.head_dim)  # -> (batch_size, seq_len, n_heads, head_dim)
        key = key.view(batch_size, -1, self.n_heads, self.head_dim)      # -> (batch_size, seq_len, n_heads, head_dim)
        value = value.view(batch_size, -1, self.n_heads, self.head_dim)  # -> (batch_size, seq_len, n_heads, head_dim)

        # Transpose query, key, and value to (batch_size, n_heads, seq_len, head_dim)
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # Compute dot product of query and key for each head
        scores = torch.matmul(query, key.transpose(-2, -1))  # (batch_size, n_heads, seq_len, head_dim) x (batch_size, n_heads, head_dim, seq_len) -> (batch_size, n_heads, seq_len, seq_len)
        scores = scores / (self.head_dim ** 0.5)  # Scale by sqrt(head_dim) so large dot products do not saturate the softmax

        # Apply mask (if provided); masked positions get a large negative score
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Apply softmax to obtain attention weights
        attn_weights = torch.softmax(scores, dim=-1)  # (batch_size, n_heads, seq_len, seq_len)

        # Apply attention weights to value for each head, then concatenate the heads
        attn_output = torch.matmul(attn_weights, value)  # (batch_size, n_heads, seq_len, seq_len) x (batch_size, n_heads, seq_len, head_dim) -> (batch_size, n_heads, seq_len, head_dim)
        attn_output = attn_output.transpose(1, 2)  # -> (batch_size, seq_len, n_heads, head_dim)
        attn_output = attn_output.contiguous().view(batch_size, -1, self.d_model)  # -> (batch_size, seq_len, d_model)

        # Project the concatenated output through the final linear layer
        attn_output = self.fc(attn_output)  # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)

        return attn_output, attn_weights
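Written out, the scores / softmax / weighted-sum steps above compute standard scaled dot-product attention for each head, with the final linear layer self.fc acting as the output projection $W^O$ applied to the concatenated heads:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{\text{head\_dim}}}\right)V, \qquad \mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O$$

where $h$ = n_heads and each $\mathrm{head}_i$ is computed from its own slice of the projected query, key, and value.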
Notes on each variable:

- n_heads: number of attention heads.
- d_model: model (embedding) dimension.
- head_dim: dimension of each attention head, equal to d_model // n_heads.
- query, key, value (inputs): the input query, key, and value tensors, of shape (batch_size, seq_len, d_model), where seq_len is the sequence length, batch_size is the batch size, and d_model is the model dimension.
- self.query, self.key, self.value: linear layers that project the inputs to n_heads * head_dim dimensions, giving tensors of shape (batch_size, seq_len, n_heads * head_dim).
- query, key, value (after view): the projected tensors split along the last dimension into n_heads heads, of shape (batch_size, seq_len, n_heads, head_dim).
- query, key, value (after transpose): the split tensors transposed to (batch_size, n_heads, seq_len, head_dim).
- scores: attention scores from the dot product of each head's query and key, of shape (batch_size, n_heads, seq_len, seq_len), scaled by $\sqrt{\text{head\_dim}}$ so that large dot products do not saturate the softmax.
- mask: optional masking tensor used to hide certain positions when computing attention; it must be broadcastable to (batch_size, n_heads, seq_len, seq_len), e.g. (batch_size, 1, 1, seq_len) for a padding mask or (batch_size, 1, seq_len, seq_len) for a causal mask. Masked positions have their scores set to a large negative value (-1e9) so they receive near-zero weight after the softmax.
- attn_weights: attention weights obtained by applying softmax to the scores, of shape (batch_size, n_heads, seq_len, seq_len).
- attn_output: the weighted sum of each head's values, of shape (batch_size, n_heads, seq_len, head_dim), then transposed and reshaped back to (batch_size, seq_len, d_model).
- self.fc: the final fully connected layer that projects the reshaped attention output, yielding the output of shape (batch_size, seq_len, d_model).
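For reference, here is a minimal usage sketch (not part of the original code); the dimensions and the causal mask below are illustrative assumptions, chosen only to show the expected shapes:

# Minimal usage sketch; batch_size=2, seq_len=5, d_model=512, n_heads=8 are arbitrary example values
batch_size, seq_len, d_model, n_heads = 2, 5, 512, 8

mha = MultiheadAttention(n_heads=n_heads, d_model=d_model)
x = torch.randn(batch_size, seq_len, d_model)  # self-attention: query = key = value = x

# Optional causal mask, broadcastable to (batch_size, n_heads, seq_len, seq_len):
# 1 where attention is allowed, 0 where future positions are blocked
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

out, weights = mha(x, x, x, mask=causal_mask)
print(out.shape)      # torch.Size([2, 5, 512])   -> (batch_size, seq_len, d_model)
print(weights.shape)  # torch.Size([2, 8, 5, 5])  -> (batch_size, n_heads, seq_len, seq_len)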