class MultiInferGCNModel(torch.nn.Module):
    '''
    基于多跳推理图卷积神经网络的多标签文本分类模型
    '''
    def __init__(self, gen_emb, domain_emb, args):
        '''
        模型初始化
        :param gen_emb: 通用词嵌入矩阵
        :param domain_emb: 领域词嵌入矩阵
        :param args: 模型参数
        '''
        super(MultiInferGCNModel, self).__init__()
        self.args = args
        # 通用词嵌入层
        self.gen_embedding = torch.nn.Embedding(gen_emb.shape[0], gen_emb.shape[1])
        self.gen_embedding.weight.data.copy_(gen_emb)
        self.gen_embedding.weight.requires_grad = False
        # 领域词嵌入层
        self.domain_embedding = torch.nn.Embedding(domain_emb.shape[0], domain_emb.shape[1])
        self.domain_embedding.weight.data.copy_(domain_emb)
        self.domain_embedding.weight.requires_grad = False

        # Dropout层
        self.dropout = torch.nn.Dropout(0.5)

        # 双向LSTM层
        self.bilstm = torch.nn.LSTM(300 + 100, args.lstm_dim,
                                    num_layers=1, batch_first=True, bidirectional=True)
        self.dropout = torch.nn.Dropout(0.5)

        # 图卷积层
        self.gconv1 = dglnn.SAGEConv(
            in_feats=100, out_feats=100, aggregator_type='lstm')
        self.gconv2 = dglnn.SAGEConv(
            in_feats=100, out_feats=100, aggregator_type='lstm')
        self.gconv3 = dglnn.SAGEConv(
            in_feats=100, out_feats=100, aggregator_type='lstm')

        # 特征线性变换层
        self.feature_linear = torch.nn.Linear(args.lstm_dim * 4 + args.class_num * 3, args.lstm_dim * 4)
        # 分类线性变换层
        self.cls_linear = torch.nn.Linear(args.lstm_dim * 4, args.class_num)

        # 激活函数
        self.activation = torch.nn.ReLU()

    def _build_batched_graph(self, embedding, src_ids, dst_ids, edge_nums):
        '''
        构建批处理图
        :param embedding: 节点嵌入
        :param src_ids: 源节点id
        :param dst_ids: 目标节点id
        :param edge_nums: 每条边的数量
        :return: 批处理图
        '''
        batched_graph = []
        batch_size = embedding.shape[0]

        # 遍历每个样本
        for i in range(batch_size):
            # 获取边的数量
            edge_num = edge_nums[i]
            # 获取源节点和目标节点的id
            src_id, dst_id = src_ids[i][:edge_num], dst_ids[i][:edge_num]
            # 创建图
            graph = dgl.graph((src_id, dst_id), num_nodes=self.args.max_sequence_len, idtype=torch.int32)
            # 设置节点特征
            graph.ndata['in_feat'] = embedding[i]
            # 添加到批处理图中
            batched_graph.append(graph)

        # 返回批处理图
        return dgl.batch(batched_graph)

    def _get_embedding(self, sentence_tokens, mask):
        '''
        获取词嵌入
        :param sentence_tokens: 句子token
        :param mask: mask矩阵
        :return: 词嵌入
        '''
        # 获取通用词嵌入
        gen_embed = self.gen_embedding(sentence_tokens)
        # 获取领域词嵌入
        domain_embed = self.domain_embedding(sentence_tokens)
        # 将通用词嵌入和领域词嵌入拼接在一起
        embedding = torch.cat([gen_embed, domain_embed], dim=2)
        # 加上dropout
        embedding = self.dropout(embedding)
        # 将padding的部分用0进行mask
        embedding = embedding * mask.unsqueeze(2).float().expand_as(embedding)
        return embedding

    def _gcn_feature(self, g, embedding):
        '''
        获取GCN特征
        :param g: 图
        :param embedding: 节点嵌入
        :return: GCN特征
        '''
        # 第一层GCN
        h = F.relu(self.gconv1(g, embedding))
        # 第二层GCN
        h = F.relu(self.gconv2(g, h))
        # 第三层GCN
        h = F.relu(self.gconv3(g, h))
        return h

    def _lstm_feature(self, embedding):
        '''
        获取LSTM特征
        :param embedding: 词嵌入
        :return: LSTM特征
        '''
        # 双向LSTM
        context, _ = self.bilstm(embedding)
        return context

    def _cls_logits(self, features):
        '''
        获取分类logits
        :param features: 特征
        :return: 分类logits
        '''
        # 线性变换
        tags = self.cls_linear(features)
        return tags

    def multi_hops(self, features, lengths, mask, k, pos):
        '''
        多跳推理
        :param features: 特征
        :param lengths: 句子长度
        :param mask: mask矩阵
        :param k: 跳数
        :param pos: 位置编码
        :return: 多跳推理结果
        '''
        # 生成mask
        max_length = features.shape[1]
        mask = mask[:, :max_length]
        mask_a = mask.unsqueeze(1).expand([-1, max_length, -1])
        mask_b = mask.unsqueeze(2).expand([-1, -1, max_length])
        mask = mask_a * mask_b
        mask = torch.triu(mask).unsqueeze(3).expand([-1, -1, -1, self.args.class_num])

        # 保存所有logits
        logits_list = []
        logits = self._cls_logits(features)
        logits_list.append(logits)

        # 多跳推理
        for i in range(k):
            # probs = torch.softmax(logits, dim=3)
            probs = logits
            # 乘上mask
            logits = probs * mask
            # 获取最大值
            logits_a = torch.max(logits, dim=1)[0]
            logits_b = torch.max(logits, dim=2)[0]
            logits = torch.cat([logits_a.unsqueeze(3), logits_b.unsqueeze(3)], dim=3)
            logits = torch.max(logits, dim=3)[0]
            # 维度变换
            logits = logits.unsqueeze(2).expand([-1, -1, max_length, -1])
            logits_T = logits.transpose(1, 2)
            logits = torch.cat([logits, logits_T], dim=3)
            # 特征拼接
            new_features = torch.cat([features, logits, probs], dim=3)
            # 线性变换
            features = self.feature_linear(new_features)
            logits = self._cls_logits(features)
            logits_list.append(logits)
        # 返回多跳推理结果
        return logits_list

    def forward(self, sentence_tokens, lengths, mask, dependencies, src_ids, dst_ids, edge_nums, pos):
        '''
        模型前向传播
        :param sentence_tokens: 句子token
        :param lengths: 句子长度
        :param mask: mask矩阵
        :param dependencies: 依存关系矩阵
        :param src_ids: 源节点id
        :param dst_ids: 目标节点id
        :param edge_nums: 每条边的数量
        :param pos: 位置编码
        :return: 分类结果
        '''
        # 获取词嵌入
        embedding = self._get_embedding(sentence_tokens, mask)
        # 获取LSTM特征
        lstm_feature = self._lstm_feature(embedding)

        # 构建DGL图
        batched_graph = self._build_batched_graph(lstm_feature, src_ids, dst_ids, edge_nums)

        # 获取GCN特征
        in_feat = batched_graph.ndata['in_feat']
        h_feat = self._gcn_feature(batched_graph, in_feat)
        batched_graph.ndata['h_feat'] = h_feat

        # 获取GCN特征
        h_feat = batched_graph.ndata['h_feat']
        gcn_feature = h_feat.view(lengths.shape[0], self.args.max_sequence_len, 100)
        gcn_feature = gcn_feature[:, :lengths[0], :]

        # 特征拼接
        gcn_feature = gcn_feature.unsqueeze(2).expand([-1, -1, lengths[0], -1])
        gcn_feature_T = gcn_feature.transpose(1, 2)
        features = torch.cat([gcn_feature, gcn_feature_T], dim=3)

        # 多次迭代
        logits = self.multi_hops(features, lengths, mask, self.args.nhops, pos)
        # 返回最后一次迭代的结果
        return [logits[-1]]
基于多跳推理图卷积神经网络的多标签文本分类模型

原文地址: https://www.cveoy.top/t/topic/jTRl 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录