Below is a PyTorch implementation of the multi-head self-attention mechanism, with explanatory comments:

import torch
import torch.nn as nn

class MultiheadAttention(nn.Module):
    def __init__(self, n_heads, d_model):
        super(MultiheadAttention, self).__init__()
        # d_model must split evenly across the heads
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_dim = d_model // n_heads
        
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        
        self.fc = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value, mask=None):
        # query, key, and value have shape (batch_size, seq_len, d_model)
        batch_size = query.shape[0]
        
        # Project query, key, and value to n_heads * head_dim dimensions
        query = self.query(query)  # (batch_size, seq_len, d_model) -> (batch_size, seq_len, n_heads * head_dim)
        key = self.key(key)        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, n_heads * head_dim)
        value = self.value(value)  # (batch_size, seq_len, d_model) -> (batch_size, seq_len, n_heads * head_dim)
        
        # Split query, key, and value into n_heads
        query = query.view(batch_size, -1, self.n_heads, self.head_dim)    # (batch_size, seq_len, n_heads * head_dim) -> (batch_size, seq_len, n_heads, head_dim)
        key = key.view(batch_size, -1, self.n_heads, self.head_dim)        # (batch_size, seq_len, n_heads * head_dim) -> (batch_size, seq_len, n_heads, head_dim)
        value = value.view(batch_size, -1, self.n_heads, self.head_dim)    # (batch_size, seq_len, n_heads * head_dim) -> (batch_size, seq_len, n_heads, head_dim)
        
        # Transpose query, key, and value to (batch_size, n_heads, seq_len, head_dim)
        query = query.transpose(1, 2)    # (batch_size, seq_len, n_heads, head_dim) -> (batch_size, n_heads, seq_len, head_dim)
        key = key.transpose(1, 2)        # (batch_size, seq_len, n_heads, head_dim) -> (batch_size, n_heads, seq_len, head_dim)
        value = value.transpose(1, 2)    # (batch_size, seq_len, n_heads, head_dim) -> (batch_size, n_heads, seq_len, head_dim)
        
        # Compute dot product of query and key for each head
        scores = torch.matmul(query, key.transpose(-2, -1))    # (batch_size, n_heads, seq_len, head_dim) * (batch_size, n_heads, head_dim, seq_len) -> (batch_size, n_heads, seq_len, seq_len)
        scores = scores / (self.head_dim ** 0.5)                # Scale by sqrt(head_dim) so large dot products do not saturate the softmax
        
        # Apply mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Apply softmax activation function to obtain attention weights
        attn_weights = torch.softmax(scores, dim=-1)    # (batch_size, n_heads, seq_len, seq_len)
        
        # Apply attention weights to value for each head and concatenate
        attn_output = torch.matmul(attn_weights, value)    # (batch_size, n_heads, seq_len, seq_len) * (batch_size, n_heads, seq_len, head_dim) -> (batch_size, n_heads, seq_len, head_dim)
        attn_output = attn_output.transpose(1, 2)          # (batch_size, n_heads, seq_len, head_dim) -> (batch_size, seq_len, n_heads, head_dim)
        attn_output = attn_output.contiguous().view(batch_size, -1, self.d_model)  # (batch_size, seq_len, n_heads, head_dim) -> (batch_size, seq_len, d_model)
        
        # Project concatenated output through fully connected layer
        attn_output = self.fc(attn_output)  # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)
        
        return attn_output, attn_weights
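
A minimal usage sketch of the module above (the sizes and random inputs here are illustrative assumptions, not part of the original post):

# Quick shape check for the MultiheadAttention module (illustrative values)
batch_size, seq_len, d_model, n_heads = 2, 10, 512, 8
mha = MultiheadAttention(n_heads=n_heads, d_model=d_model)
x = torch.randn(batch_size, seq_len, d_model)

# Self-attention: pass the same tensor as query, key, and value
out, attn = mha(x, x, x)
print(out.shape)   # torch.Size([2, 10, 512])
print(attn.shape)  # torch.Size([2, 8, 10, 10])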

Notes on the code:

  • n_heads: the number of attention heads.
  • d_model: the model dimensionality.
  • head_dim: the dimensionality of each attention head, equal to d_model // n_heads.
  • query, key, value: the input query, key, and value tensors, of shape (batch_size, seq_len, d_model), where seq_len is the sequence length, batch_size is the batch size, and d_model is the model dimensionality.
  • batch_size: the batch size.
  • self.query, self.key, self.value: linear layers that project the input query, key, and value tensors to n_heads * head_dim dimensions.
  • query, key, value: the projected query, key, and value tensors after the linear layers, of shape (batch_size, seq_len, n_heads * head_dim).
  • query, key, value: the projected tensors split along the last dimension into n_heads heads, of shape (batch_size, seq_len, n_heads, head_dim).
  • query, key, value: the per-head tensors transposed to (batch_size, n_heads, seq_len, head_dim).
  • scores: the attention scores obtained from the dot product of each head's query and key vectors, of shape (batch_size, n_heads, seq_len, seq_len).
  • scores: the attention scores scaled by dividing by $\sqrt{\text{head\_dim}}$, so that large dot products do not saturate the softmax and shrink its gradients.
  • mask: an optional mask tensor used to hide certain positions when computing attention scores; its shape must broadcast against (batch_size, n_heads, seq_len, seq_len), e.g. (batch_size, 1, 1, seq_len) for a padding mask or (batch_size, 1, seq_len, seq_len) for a combined mask (see the sketch after this list).
  • scores: the scores after applying the mask, with masked positions set to a large negative value so the softmax assigns them near-zero weight.
  • attn_weights: the attention weights obtained by applying the softmax function to the scores, of shape (batch_size, n_heads, seq_len, seq_len).
  • attn_output: the weighted sum of each head's value vectors under the attention weights, of shape (batch_size, n_heads, seq_len, head_dim).
  • attn_output: the per-head outputs transposed, made contiguous, and reshaped (i.e. concatenated across heads) into (batch_size, seq_len, d_model).
  • self.fc: the fully connected layer that projects the reshaped attention output to the final output, of shape (batch_size, seq_len, d_model).
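
As referenced in the mask entry above, here is a minimal sketch of building the two common mask types; the lengths and sizes below are illustrative assumptions:

import torch

# Masks must broadcast against scores of shape (batch_size, n_heads, seq_len, seq_len)
batch_size, seq_len = 2, 5

# Padding mask: True for real tokens, False for padding -> (batch_size, 1, 1, seq_len)
lengths = torch.tensor([5, 3])  # true length of each sequence (illustrative)
padding_mask = torch.arange(seq_len).unsqueeze(0) < lengths.unsqueeze(1)  # (batch_size, seq_len)
padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)                     # (batch_size, 1, 1, seq_len)

# Causal mask: position i may attend only to positions <= i -> (1, 1, seq_len, seq_len)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(0).unsqueeze(0)

# Combined mask, broadcastable to (batch_size, n_heads, seq_len, seq_len);
# forward() then applies scores.masked_fill(mask == 0, -1e9) at the masked positions
mask = padding_mask & causal_mask  # (batch_size, 1, seq_len, seq_len)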
