# 导入必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from easydict import EasyDict
from typing import Tuple


# 定义HMCAN类,继承自nn.Module
class HMCAN(nn.Module):

    # 初始化方法
    def __init__(self, configs, alpha):
        super(HMCAN, self).__init__()
        self.word_length = configs.max_word_length # 每个词的最大长度
        self.alpha = alpha # 用于加权融合的系数

        # 实例化两个文本图像转换器
        self.contextual_transform = TextImage_Transformer(
            configs.contextual_transform, configs.contextual_transform.output_dim)

        self.contextual_transform2 = TextImage_Transformer(
            configs.contextual_transform, configs.contextual_transform.output_dim)


        # 1x1卷积层,将输入特征图的通道数从2048降到768
        self.conv = nn.Conv2d(2048, 768, 1)
        # 批归一化层,对卷积层输出进行归一化
        self.bn = nn.BatchNorm2d(768)

        # 定义分类器,包含多个全连接层、激活函数和批归一化层
        self.classifier = nn.Sequential(nn.Linear(768*6, 256),
                                        nn.ReLU(True),
                                        nn.BatchNorm1d(256),
                                        nn.Linear(256, 2)
                                        )


    # 前向传播方法
    def forward(self, e, f):
        cap_lengths = len(e) # 获取文本序列的长度

        # 创建两个掩码矩阵,用于指示有效元素的位置
        e_f_mask = torch.ones(cap_lengths, self.word_length).cuda()
        f_e_mask = torch.ones(cap_lengths, 16).cuda()

        # 将e的维度1去掉,e的维度变为[batch_size, 40, 768]
        e = torch.squeeze(e, dim=1)
        # 取e的前self.word_length维度
        e1 = e[:, :self.word_length, :]
        # 取e的[self.word_length, self.word_length*2)维度
        e2 = e[:, self.word_length: self.word_length*2, :]
        # 取e的[self.word_length*2, end)维度
        e3 = e[:, self.word_length*2:, :]
        # e = self.fc(e) # [batch_size, 40, 64]

        # 对f进行卷积操作,然后进行ReLU激活函数,最后进行批归一化操作
        f = F.relu(self.bn(self.conv(f)))
        # 将f的维度变形为[batch_size, 768, 16]
        f = f.view(f.shape[0], f.shape[1], -1)
        # 对f的维度进行转置,变为[batch_size, 16, 768]
        f = f.permute([0, 2, 1])

        # 将e1、e_f_mask和f作为输入,通过上下文转换器进行转换
        c1_e1_f = self.contextual_transform(e1, e_f_mask, f)
        # 将f、f_e_mask和e1作为输入,通过上下文转换器进行转换
        c1_f_e1 = self.contextual_transform2(f, f_e_mask, e1)
        a = self.alpha

        # 根据参数a进行加权融合
        c1 = a * c1_e1_f + (1 - a) * c1_f_e1

        # 将e2、e_f_mask和f作为输入,通过上下文转换器进行转换
        c2_e2_f = self.contextual_transform(e2, e_f_mask, f)
        # 将f、f_e_mask和e2作为输入,通过上下文转换器进行转换
        c2_f_e2 = self.contextual_transform2(f, f_e_mask, e2)

        # 根据参数a进行加权融合
        c2 = a * c2_e2_f + (1 - a) * c2_f_e2

        # 将e3、e_f_mask和f作为输入,通过上下文转换器进行转换
        c3_e3_f = self.contextual_transform(e3, e_f_mask, f)
        # 将f、f_e_mask和e3作为输入,通过上下文转换器进行转换
        c3_f_e3 = self.contextual_transform2(f, f_e_mask, e3)

        # 根据参数a进行加权融合
        c3 = a * c3_e3_f + (1 - a) * c3_f_e3

        # 将c1、c2、c3在维度1上进行拼接
        x = torch.cat((c1, c2, c3), dim=1)
        # 将x通过分类器进行分类
        x = self.classifier(x)

        # 返回分类结果
        return x


# 定义层归一化类,继承自nn.Module
class LayerNormalization(nn.Module):
    def __init__(self, features_count, epsilon=1e-6):
        super().__init__()
        # 初始化增益参数为1,并可训练
        self.gain = nn.Parameter(
            torch.ones(features_count), requires_grad=True)
        # 初始化偏差参数为0,并可训练
        self.bias = nn.Parameter(
            torch.zeros(features_count), requires_grad=True)
        self.epsilon = epsilon

    # 前向传播方法
    def forward(self, x):
        # 沿着最后一个维度计算均值
        mean = x.mean(dim=-1, keepdim=True)
        # 沿着最后一个维度计算标准差
        std = x.std(dim=-1, keepdim=True)
        # 对x进行归一化操作,然后乘以增益参数,并加上偏差参数
        return self.gain * (x - mean) / (std + self.epsilon) + self.bias


# 定义文本图像转换器类,继承自nn.Module
class TextImage_Transformer(nn.Module):
    def __init__(self, ct: EasyDict, feature_dim: int):
        super().__init__()

        # 特征归一化层
        self.input_norm = LayerNormalization(feature_dim)
        input_dim = feature_dim
        # 位置编码层
        self.embedding = PositionalEncoding(
            input_dim, ct.dropout, max_len=1000)

        # Transformer编码层
        self.tf = TransformerEncoder(
            ct.num_layers, input_dim, ct.num_heads, input_dim,
            ct.dropout)

        self.use_context = ct.use_context
        if self.use_context:
            # 上下文编码层
            self.tf_context = TransformerEncoder(
                ct.atn_ct_num_layers, input_dim, ct.atn_ct_num_heads,
                input_dim, ct.dropout)

        # 初始化网络权重
        init_network(self, 0.01)

    # 前向传播方法
    def forward(self, features, mask, hidden_state):
        # 特征归一化
        features = self.input_norm(features)
        # 位置编码
        features = self.embedding(features)
        features
        # Transformer编码
        features = self.tf(features, features, features, mask)
        # 获取上下文编码结果
        add_after_pool = None
        if self.use_context:
            # 对隐藏状态进行上下文编码
            ctx = self.tf_context(
                hidden_state, features, features, mask)
            # 获取上下文编码结果
            add_after_pool = ctx    # ctx.squeeze(1)

        # 对特征进行池化操作
        pooled = torch.mean(features, dim=1)
        # 对上下文编码结果进行池化操作
        add_after_pool = torch.mean(add_after_pool, dim=1)
        # 将特征和上下文编码结果进行拼接
        if add_after_pool is not None:
            pooled = torch.cat([pooled, add_after_pool], dim=-1)
        # 返回拼接结果
        return pooled

# 定义位置编码类,继承自nn.Module
class PositionalEncoding(nn.Module):
    def __init__(self, dim, dropout_prob=0., max_len=1000):
        super().__init__()
        # 初始化位置编码矩阵
        pe = torch.zeros(max_len, dim).float()
        # 创建位置索引
        position = torch.arange(0, max_len).unsqueeze(1).float()
        # 创建维度索引
        dimension = torch.arange(0, dim).float()
        # 计算缩放因子
        div_term = 10000 ** (2 * dimension / dim)
        # 计算位置编码矩阵的奇数列
        pe[:, 0::2] = torch.sin(position / div_term[0::2])
        # 计算位置编码矩阵的偶数列
        pe[:, 1::2] = torch.cos(position / div_term[1::2])
        # 将位置编码矩阵注册为缓冲区
        self.register_buffer('pe', pe)
        # 创建Dropout层
        self.dropout = nn.Dropout(p=dropout_prob)
        # 记录维度
        self.dim = dim

    # 前向传播方法
    def forward(self, x, step=None):
        # 如果没有指定步长,则将位置编码矩阵添加到输入中
        if step is None:
            x = x + self.pe[:x.size(1), :]
        # 否则,将指定步长对应的行添加到输入中
        else:
            x = x + self.pe[:, step]
        # 使用Dropout层
        x = self.dropout(x)
        # 返回结果
        return x


# 定义Transformer编码器类,继承自nn.Module
class TransformerEncoder(nn.Module):
    def __init__(self, layers_count, d_model, heads_count, d_ff, dropout_prob):
        super().__init__()
        self.d_model = d_model # 模型维度
        assert layers_count > 0 # 断言层数大于0
        # 创建一个包含多个Transformer编码层的列表
        self.encoder_layers = nn.ModuleList(
            [TransformerEncoderLayer(
                d_model, heads_count, d_ff, dropout_prob)
                for _ in range(layers_count)])

    # 前向传播方法
    def forward(self, query, key, value, mask):
        # 获取输入的维度信息
        batch_size, query_len, embed_dim = query.shape
        batch_size, key_len, embed_dim = key.shape
        # 创建掩码矩阵,用于指示有效元素的位置
        mask = (1 - mask.unsqueeze(1).expand(batch_size, query_len, key_len))
        mask = mask == 1
        sources = None
        # 对每个Transformer编码层进行前向传播
        for encoder_layer in self.encoder_layers:
            sources = encoder_layer(query, key, value, mask)
        # 返回编码结果
        return sources


# 定义Transformer编码器层类,继承自nn.Module
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, heads_count, d_ff, dropout_prob):
        super(TransformerEncoderLayer, self).__init__()
        # 创建自注意力层
        self.self_attention_layer = Sublayer(
            MultiHeadAttention(heads_count, d_model, dropout_prob), d_model)
        # 创建点式前馈网络层
        self.pointwise_feedforward_layer = Sublayer(
            PointwiseFeedForwardNetwork(d_ff, d_model, dropout_prob), d_model)
        # 创建Dropout层
        self.dropout = nn.Dropout(dropout_prob)

    # 前向传播方法
    def forward(self, query, key, value, sources_mask):
        # 自注意力层
        sources = self.self_attention_layer(query, key, value, sources_mask)
        # Dropout层
        sources = self.dropout(sources)
        # 点式前馈网络层
        sources = self.pointwise_feedforward_layer(sources)
        # 返回结果
        return sources


# 定义子层类,继承自nn.Module
class Sublayer(nn.Module):
    def __init__(self, sublayer, d_model):
        super(Sublayer, self).__init__()
        # 子层
        self.sublayer = sublayer
        # 层归一化
        self.layer_normalization = LayerNormalization(d_model)

    # 前向传播方法
    def forward(self, *args):
        # 获取第一个参数
        x = args[0]
        # 对子层进行前向传播
        x = self.sublayer(*args) + x
        # 层归一化
        return self.layer_normalization(x)


# 定义多头注意力类,继承自nn.Module
class MultiHeadAttention(nn.Module):
    def __init__(self, heads_count, d_model, dropout_prob):
        super().__init__()
        assert d_model % heads_count == 0,
            f'model dim {d_model} not divisible by {heads_count} heads'
        # 计算每个头的维度
        self.d_head = d_model // heads_count
        # 头的数量
        self.heads_count = heads_count
        # 创建查询投影层
        self.query_projection = nn.Linear(d_model, heads_count * self.d_head)
        # 创建键投影层
        self.key_projection = nn.Linear(d_model, heads_count * self.d_head)
        # 创建值投影层
        self.value_projection = nn.Linear(d_model, heads_count * self.d_head)
        # 创建最终投影层
        self.final_projection = nn.Linear(d_model, heads_count * self.d_head)
        # 创建Dropout层
        self.dropout = nn.Dropout(dropout_prob)
        # 创建Softmax层
        self.softmax = nn.Softmax(dim=3)
        self.attention = None

    # 前向传播方法
    def forward(self, query, key, value, mask=None):
        # 获取输入的维度信息
        batch_size, query_len, d_model = query.size()
        # 计算每个头的维度
        d_head = d_model // self.heads_count
        # 对查询进行投影
        query_projected = self.query_projection(query)
        # 对键进行投影
        key_projected = self.key_projection(key)
        # 对值进行投影
        value_projected = self.value_projection(value)
        # 获取投影后的维度信息
        batch_size, key_len, d_model = key_projected.size()
        batch_size, value_len, d_model = value_projected.size()
        # 将查询、键和值reshape为多头形式
        query_heads = query_projected.view(
            batch_size, query_len, self.heads_count, d_head).transpose(1, 2)
        key_heads = key_projected.view(
            batch_size, key_len, self.heads_count, d_head).transpose(1, 2)
        value_heads = value_projected.view(
            batch_size, value_len, self.heads_count, d_head).transpose(1, 2)
        # 计算缩放点积
        attention_weights = self.scaled_dot_product(
            query_heads, key_heads)
        # 如果存在掩码,则将掩码应用于注意力权重
        if mask is not None:
            mask_expanded = mask.unsqueeze(1).expand_as(attention_weights)
            attention_weights = attention_weights.masked_fill(
                mask_expanded, -1e18)
        # 计算注意力矩阵
        attention = self.softmax(attention_weights)
        # 对注意力矩阵进行Dropout
        attention_dropped = self.dropout(attention)
        # 计算上下文
        context_heads = torch.matmul(
            attention_dropped, value_heads)
        # 将上下文reshape为序列形式
        context_sequence = context_heads.transpose(1, 2)
        # 将上下文reshape为原始形式
        context = context_sequence.reshape(
            batch_size, query_len, d_model)
        # 对上下文进行最终投影
        final_output = self.final_projection(context)
        # 返回结果
        return final_output

    # 计算缩放点积
    def scaled_dot_product(self, query_heads, key_heads):
        # 对键进行转置
        key_heads_transposed = key_heads.transpose(2, 3)
        # 计算点积
        dot_product = torch.matmul(
            query_heads, key_heads_transposed)
        # 计算注意力权重
        attention_weights = dot_product / np.sqrt(self.d_head)
        # 返回结果
        return attention_weights


# 定义点式前馈网络类,继承自nn.Module
class PointwiseFeedForwardNetwork(nn.Module):
    def __init__(self, d_ff, d_model, dropout_prob):
        super(PointwiseFeedForwardNetwork, self).__init__()
        # 创建点式前馈网络,包含多个线性层、激活函数和Dropout层
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.Dropout(dropout_prob),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout_prob))

    # 前向传播方法
    def forward(self, x):
        # 对输入进行前馈传播
        return self.feed_forward(x)

# 定义截断正态分布填充函数
def truncated_normal_fill(
        shape: Tuple[int], mean: float = 0, std: float = 1,
        limit: float = 2) -> torch.Tensor:
    # 生成随机数
    num_examples = 8
    tmp = torch.empty(shape + (num_examples,)).normal_()
    # 判断随机数是否在范围内
    valid = (tmp < limit) & (tmp > -limit)
    # 获取最大值及其索引
    _, ind = valid.max(-1, keepdim=True)
    # 使用索引获取有效随机数并进行缩放和偏移
    return tmp.gather(-1, ind).squeeze(-1).mul_(std).add_(mean)

# 定义权重初始化函数
def init_weight_(w, init_gain=1):

    # 使用截断正态分布填充权重
    w.copy_(truncated_normal_fill(w.shape, std=init_gain))


# 定义网络初始化函数
def init_network(net: nn.Module, init_std: float):

    # 遍历网络的所有参数
    for key, val in net.named_parameters():
        # 如果参数名为'weight'或'bias',则使用截断正态分布初始化
        if 'weight' in key or 'bias' in key:
            init_weight_(val.data, init_std)

原文地址: https://www.cveoy.top/t/topic/qtAl 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录