基于GAT的半监督学习:训练过程详解
基于GAT的半监督学习:训练过程详解
本文将深入探讨基于图注意力网络(GAT)的半监督学习训练过程,并解析关键代码实现,帮助读者理解GAT在半监督学习中的应用。
训练过程
以下代码展示了GAT模型的训练过程,其中包含了模型前向传播、损失计算、验证等步骤。
for epoch in range(args.epochs):
    t = time.time()
    #   for train
    model.train()
    optimizer.zero_grad()
    output = model(features, adjtensor)
    # 平均输出
    areout = output[1]
    loss_xy = 0
    loss_ncl = 0
    for k in range(len(output[0])):
        # print('k = '+str(k))
        # print(F.nll_loss(output[0][k][idx_train], labels[idx_train]))
        # print(F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel]))
        loss_xy += F.nll_loss(output[0][k][idx_train], labels[idx_train])
        loss_ncl += F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel])
    loss_train = (1-args.lamd)* loss_xy - args.lamd * loss_ncl
    # loss_train = (1 - args.lamd) * loss_xy + args.lamd * 1 / loss_ncl
    # loss_train = (1 - args.lamd) * loss_xy + args.lamd * (torch.exp(-loss_ncl))
    print(loss_xy)
    print(loss_ncl)
    print(torch.exp(-loss_ncl))
    print((1 - args.lamd) * loss_xy)
    print(args.lamd * (torch.exp(-loss_ncl)))
    print(epoch)
    print(loss_train)
    print('.............')
    acc_train = accuracy(areout[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()
    #   for val
    if validate:
        # print('no')
        model.eval()
        output = model(features, adjtensor)
        areout = output[1]
        vl_step = len(idx_val)
        loss_val = F.nll_loss(areout[idx_val], labels[idx_val])
        acc_val = accuracy(areout[idx_val], labels[idx_val])
        # vl_step = len(idx_train)
        # loss_val = F.nll_loss(areout[idx_train], labels[idx_train])
        # acc_val = accuracy(areout[idx_train], labels[idx_train])
        cost_val.append(loss_val)
        # 原始GCN的验证
        # if epoch > args.early_stopping and cost_val[-1] > torch.mean(torch.stack(cost_val[-(args.early_stopping + 1):-1])):
        #     # print('Early stopping...')
        #     print(epoch)
        #     break
        # print(epoch)
        # GAT的验证
        if acc_val/vl_step >= vacc_mx or loss_val/vl_step <= vlss_mn:
            if acc_val/vl_step >= vacc_mx and loss_val/vl_step <= vlss_mn:
                vacc_early_model = acc_val/vl_step
                vlss_early_model = loss_val/vl_step
                torch.save(model, checkpt_file)
            vacc_mx = np.max((vacc_early_model, vacc_mx))
            vlss_mn = np.min((vlss_early_model, vlss_mn))
            curr_step = 0
        else:
            curr_step += 1
            # print(curr_step)
            if curr_step == args.early_stopping:
                # print('Early stop! Min loss: ', vlss_mn, ', Max accuracy: ', vacc_mx)
                # print('Early stop model validation loss: ', vlss_early_model, ', accuracy: ', vacc_early_model)
                break
代码解析
- 
模型前向传播:
output = model(features, adjtensor)使用训练好的模型对输入的特征和邻接矩阵进行前向传播,得到模型的输出。在这个例子中,模型的输出包括两部分:第一部分是一个列表,包含了多个GAT层的输出,每个GAT层的输出都是一个大小为节点数×类别数的矩阵,表示每个节点属于每个类别的概率;第二部分是一个大小为节点数×类别数的矩阵,表示每个节点的平均输出。在这个例子中,我们只使用了第二部分的输出来计算损失和准确率。 - 
损失计算: 代码使用两个损失函数来衡量模型的性能:
loss_xy: 使用交叉熵损失来计算已标记节点的预测结果和真实标签之间的误差。loss_ncl: 使用均方误差来计算未标记节点的平均输出和已标记节点的平均输出之间的误差,用于鼓励模型学习一个一致的表示。
 - 
验证: 在每个epoch结束之后,代码会使用验证集来评估模型的性能,并记录最佳的模型参数。
 - 
早停: 当验证集的性能在连续几个epoch没有提高时,代码会提前结束训练,避免模型过拟合。
 
总结
本文详细介绍了基于GAT的半监督学习训练过程,并解析了关键代码实现,帮助读者了解GAT在半监督学习中的应用。希望本文能够为读者学习和应用GAT提供帮助。
参考资料
原文地址: https://www.cveoy.top/t/topic/ie7I 著作权归作者所有。请勿转载和采集!