PyTorch 训练循环:基于 GCN/GAT 的半监督学习示例
PyTorch 半监督学习训练循环示例
该示例代码演示了使用 PyTorch 训练 GCN 或 GAT 模型进行半监督学习的完整过程。
for epoch in range(args.epochs):
    t = time.time()
    # 训练模式
    model.train()
    optimizer.zero_grad()
    output = model(features, adjtensor)
    # 平均输出
    areout = output[1]
    loss_xy = 0
    loss_ncl = 0
    for k in range(len(output[0])):
        # print('k = ' + str(k))
        # print(F.nll_loss(output[0][k][idx_train], labels[idx_train]))
        # print(F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel]))
        loss_xy += F.nll_loss(output[0][k][idx_train], labels[idx_train])
        loss_ncl += F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel])
    loss_train = (1 - args.lamd) * loss_xy - args.lamd * loss_ncl
    # loss_train = (1 - args.lamd) * loss_xy + args.lamd * 1 / loss_ncl
    # loss_train = (1 - args.lamd) * loss_xy + args.lamd * (torch.exp(-loss_ncl))
    print(loss_xy)
    print(loss_ncl)
    print(torch.exp(-loss_ncl))
    print((1 - args.lamd) * loss_xy)
    print(args.lamd * (torch.exp(-loss_ncl)))
    print(epoch)
    print(loss_train)
    print('.............')
    acc_train = accuracy(areout[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()
    # 验证模式
    if validate:
        # print('no')
        model.eval()
        output = model(features, adjtensor)
        areout = output[1]
        vl_step = len(idx_val)
        loss_val = F.nll_loss(areout[idx_val], labels[idx_val])
        acc_val = accuracy(areout[idx_val], labels[idx_val])
        # vl_step = len(idx_train)
        # loss_val = F.nll_loss(areout[idx_train], labels[idx_train])
        # acc_val = accuracy(areout[idx_train], labels[idx_train])
        cost_val.append(loss_val)
        # 原始 GCN 的验证
        # if epoch > args.early_stopping and cost_val[-1] > torch.mean(torch.stack(cost_val[-(args.early_stopping + 1):-1])):
        #     # print('Early stopping...')
        #     print(epoch)
        #     break
        # print(epoch)
        # GAT 的验证
        if acc_val / vl_step >= vacc_mx or loss_val / vl_step <= vlss_mn:
            if acc_val / vl_step >= vacc_mx and loss_val / vl_step <= vlss_mn:
                vacc_early_model = acc_val / vl_step
                vlss_early_model = loss_val / vl_step
                torch.save(model, checkpt_file)
            vacc_mx = np.max((vacc_early_model, vacc_mx))
            vlss_mn = np.min((vlss_early_model, vlss_mn))
            curr_step = 0
        else:
            curr_step += 1
            # print(curr_step)
            if curr_step == args.early_stopping:
                # print('Early stop! Min loss: ', vlss_mn, ', Max accuracy: ', vacc_mx)
                # print('Early stop model validation loss: ', vlss_early_model, ', accuracy: ', vacc_early_model)
                break
代码说明
- 训练循环遍历每个 epoch,并执行以下步骤:
- 将模型设置为训练模式
 - 清空优化器的梯度
 - 使用模型进行前向传播,得到输出
 - 计算训练损失,包括分类损失 
loss_xy和非监督损失loss_ncl - 计算总训练损失 
loss_train - 反向传播并更新模型参数
 
 - 验证模式在每个 epoch 结束后执行,主要用于评估模型在验证集上的性能:
- 将模型设置为评估模式
 - 进行前向传播,得到输出
 - 计算验证损失和准确率
 - 早期停止机制,如果验证损失不再下降,则停止训练
 
 vlss_mn = np.min((vlss_early_model, vlss_mn))的作用是记录并更新验证集上的最小损失值,用于 early stopping 机制。
代码解释
args: 包含训练配置参数,例如epochs(训练轮数),lamd(损失函数权重),early_stopping(早期停止阈值) 等。model: 训练的 GCN/GAT 模型。optimizer: 模型优化器,例如 Adam 或 SGD。features: 节点特征矩阵。adjtensor: 图邻接矩阵。idx_train: 训练集节点索引。idx_val: 验证集节点索引。labels: 节点标签。output: 模型的输出,包括分类概率和平均输出。areout: 平均输出。cost_val: 存储验证集损失的列表。vacc_mx: 记录验证集上的最高准确率。vlss_mn: 记录验证集上的最低损失。vacc_early_model: 记录在早期停止时模型的验证集准确率。vlss_early_model: 记录在早期停止时模型的验证集损失。curr_step: 记录连续验证集损失没有下降的次数。checkpt_file: 模型保存路径。
该代码示例提供了完整的半监督学习训练过程,可以作为学习 PyTorch 的基础,也可以作为构建更复杂的半监督学习模型的参考。
原文地址: https://www.cveoy.top/t/topic/ildt 著作权归作者所有。请勿转载和采集!