# Multi-hop inference graph-convolutional network for multi-label text classification.
class MultiInferGCNModel(torch.nn.Module):
    '''
    Multi-hop inference graph-convolutional model for multi-label text
    classification over word-pair grids.

    Pipeline: frozen general + domain word embeddings -> BiLSTM ->
    per-sample dependency graphs through three GraphSAGE layers ->
    pairwise (token x token) feature grid -> iterative multi-hop
    refinement of the classification logits.
    '''

    def __init__(self, gen_emb, domain_emb, args):
        '''
        Initialize the model.
        :param gen_emb: general pretrained embedding matrix (vocab x dim)
        :param domain_emb: domain pretrained embedding matrix (vocab x dim)
        :param args: hyper-parameters; must provide lstm_dim, class_num,
                     max_sequence_len and nhops
        '''
        super(MultiInferGCNModel, self).__init__()
        self.args = args
        # General word-embedding layer, frozen (not fine-tuned).
        self.gen_embedding = torch.nn.Embedding(gen_emb.shape[0], gen_emb.shape[1])
        self.gen_embedding.weight.data.copy_(gen_emb)
        self.gen_embedding.weight.requires_grad = False
        # Domain word-embedding layer, frozen (not fine-tuned).
        self.domain_embedding = torch.nn.Embedding(domain_emb.shape[0], domain_emb.shape[1])
        self.domain_embedding.weight.data.copy_(domain_emb)
        self.domain_embedding.weight.requires_grad = False
        # Shared dropout layer. The original assigned self.dropout twice
        # with identical Dropout(0.5) modules; the duplicate is removed.
        self.dropout = torch.nn.Dropout(0.5)
        # BiLSTM over the concatenated embeddings. The input size is derived
        # from the embedding matrices instead of the hard-coded 300 + 100.
        self.bilstm = torch.nn.LSTM(gen_emb.shape[1] + domain_emb.shape[1],
                                    args.lstm_dim, num_layers=1,
                                    batch_first=True, bidirectional=True)
        # GraphSAGE layers on the BiLSTM output, whose feature size is
        # 2 * lstm_dim. The original hard-coded 100 here, which is only
        # correct when lstm_dim == 50 (any other value crashed at runtime),
        # so this generalization is backward-compatible.
        gcn_dim = args.lstm_dim * 2
        self.gconv1 = dglnn.SAGEConv(
            in_feats=gcn_dim, out_feats=gcn_dim, aggregator_type='lstm')
        self.gconv2 = dglnn.SAGEConv(
            in_feats=gcn_dim, out_feats=gcn_dim, aggregator_type='lstm')
        self.gconv3 = dglnn.SAGEConv(
            in_feats=gcn_dim, out_feats=gcn_dim, aggregator_type='lstm')
        # Projects [pair features | hop logits | previous probs] back to the
        # pair-feature size during multi-hop inference.
        self.feature_linear = torch.nn.Linear(
            args.lstm_dim * 4 + args.class_num * 3, args.lstm_dim * 4)
        # Final tag classifier over pair features.
        self.cls_linear = torch.nn.Linear(args.lstm_dim * 4, args.class_num)
        # Kept for checkpoint compatibility; forward uses F.relu directly.
        self.activation = torch.nn.ReLU()

    def _build_batched_graph(self, embedding, src_ids, dst_ids, edge_nums):
        '''
        Build one DGL graph per sample and batch them.
        :param embedding: node features, shape (batch, max_sequence_len, feat)
        :param src_ids: padded source-node ids per sample
        :param dst_ids: padded destination-node ids per sample
        :param edge_nums: number of valid (un-padded) edges per sample
        :return: a batched DGL graph carrying node feature 'in_feat'
        '''
        graphs = []
        for i in range(embedding.shape[0]):
            # Keep only the valid edges of sample i; the rest is padding.
            n_edges = edge_nums[i]
            src, dst = src_ids[i][:n_edges], dst_ids[i][:n_edges]
            g = dgl.graph((src, dst), num_nodes=self.args.max_sequence_len,
                          idtype=torch.int32)
            g.ndata['in_feat'] = embedding[i]
            graphs.append(g)
        return dgl.batch(graphs)

    def _get_embedding(self, sentence_tokens, mask):
        '''
        Look up and combine word embeddings.
        :param sentence_tokens: token-id tensor, shape (batch, seq_len)
        :param mask: 1/0 validity mask, shape (batch, seq_len)
        :return: dropped-out, padding-masked embeddings
                 (batch, seq_len, gen_dim + domain_dim)
        '''
        gen_embed = self.gen_embedding(sentence_tokens)
        domain_embed = self.domain_embedding(sentence_tokens)
        # Concatenate general and domain views of each token.
        embedding = torch.cat([gen_embed, domain_embed], dim=2)
        embedding = self.dropout(embedding)
        # Zero out padded positions.
        embedding = embedding * mask.unsqueeze(2).float().expand_as(embedding)
        return embedding

    def _gcn_feature(self, g, embedding):
        '''
        Run three ReLU-activated GraphSAGE layers.
        :param g: batched DGL graph
        :param embedding: node input features
        :return: node features after the third layer
        '''
        h = F.relu(self.gconv1(g, embedding))
        h = F.relu(self.gconv2(g, h))
        h = F.relu(self.gconv3(g, h))
        return h

    def _lstm_feature(self, embedding):
        '''
        Encode the embeddings with the BiLSTM.
        :param embedding: (batch, seq_len, emb_dim)
        :return: contextual features (batch, seq_len, 2 * lstm_dim)
        '''
        context, _ = self.bilstm(embedding)
        return context

    def _cls_logits(self, features):
        '''
        Project pair features to per-class logits.
        :param features: (..., lstm_dim * 4)
        :return: logits (..., class_num)
        '''
        return self.cls_linear(features)

    def multi_hops(self, features, lengths, mask, k, pos):
        '''
        Iteratively refine the pair-grid logits for k hops.
        :param features: pair features, shape (batch, L, L, lstm_dim * 4)
        :param lengths: sentence lengths (unused here; kept for interface)
        :param mask: 1/0 token validity mask, shape (batch, seq_len)
        :param k: number of inference hops
        :param pos: position encoding (unused here; kept for interface)
        :return: list of logits tensors, one per hop (k + 1 entries)
        '''
        # Build an upper-triangular pair mask broadcast over classes.
        max_length = features.shape[1]
        mask = mask[:, :max_length]
        mask_a = mask.unsqueeze(1).expand([-1, max_length, -1])
        mask_b = mask.unsqueeze(2).expand([-1, -1, max_length])
        mask = mask_a * mask_b
        mask = torch.triu(mask).unsqueeze(3).expand([-1, -1, -1, self.args.class_num])
        logits_list = []
        logits = self._cls_logits(features)
        logits_list.append(logits)
        for i in range(k):
            # NOTE(review): softmax is deliberately disabled in the original
            # (`probs = torch.softmax(logits, dim=3)` commented out); raw
            # logits are propagated instead.
            probs = logits
            # Restrict to valid, upper-triangular pairs.
            logits = probs * mask
            # Row-wise / column-wise max pooling over the pair grid,
            # then element-wise max of the two.
            logits_a = torch.max(logits, dim=1)[0]
            logits_b = torch.max(logits, dim=2)[0]
            logits = torch.cat([logits_a.unsqueeze(3), logits_b.unsqueeze(3)], dim=3)
            logits = torch.max(logits, dim=3)[0]
            # Broadcast the pooled vector back to the pair grid, in both
            # orientations.
            logits = logits.unsqueeze(2).expand([-1, -1, max_length, -1])
            logits_T = logits.transpose(1, 2)
            logits = torch.cat([logits, logits_T], dim=3)
            # Fuse pair features with hop evidence and re-project.
            new_features = torch.cat([features, logits, probs], dim=3)
            features = self.feature_linear(new_features)
            logits = self._cls_logits(features)
            logits_list.append(logits)
        return logits_list

    def forward(self, sentence_tokens, lengths, mask, dependencies, src_ids, dst_ids, edge_nums, pos):
        '''
        Forward pass.
        :param sentence_tokens: token-id tensor (batch, seq_len)
        :param lengths: sentence lengths (batch,)
        :param mask: 1/0 token validity mask (batch, seq_len)
        :param dependencies: dependency matrix (unused here; kept for interface)
        :param src_ids: padded source-node ids per sample
        :param dst_ids: padded destination-node ids per sample
        :param edge_nums: number of valid edges per sample
        :param pos: position encoding (passed through to multi_hops)
        :return: single-element list with the final hop's logits
        '''
        embedding = self._get_embedding(sentence_tokens, mask)
        lstm_feature = self._lstm_feature(embedding)
        # Per-sample dependency graphs over the BiLSTM features.
        batched_graph = self._build_batched_graph(lstm_feature, src_ids, dst_ids, edge_nums)
        # GCN node features; the original round-tripped h_feat through
        # ndata['h_feat'] only to read it straight back — use it directly.
        h_feat = self._gcn_feature(batched_graph, batched_graph.ndata['in_feat'])
        gcn_feature = h_feat.view(lengths.shape[0], self.args.max_sequence_len,
                                  self.args.lstm_dim * 2)
        # NOTE(review): truncating with lengths[0] assumes the batch is
        # sorted by length in descending order — confirm against the loader.
        gcn_feature = gcn_feature[:, :lengths[0], :]
        # Build the (token x token) pair grid: each cell concatenates the
        # features of both tokens.
        gcn_feature = gcn_feature.unsqueeze(2).expand([-1, -1, lengths[0], -1])
        gcn_feature_T = gcn_feature.transpose(1, 2)
        features = torch.cat([gcn_feature, gcn_feature_T], dim=3)
        logits = self.multi_hops(features, lengths, mask, self.args.nhops, pos)
        # Only the final hop's prediction is returned.
        return [logits[-1]]
# Source: https://www.cveoy.top/t/topic/jTRl — all rights reserved by the author; do not repost or scrape.