PyTorch 半监督学习训练循环示例

该示例代码演示了使用 PyTorch 训练 GCN 或 GAT 模型进行半监督学习的完整过程。

for epoch in range(args.epochs):
    t = time.time()

    # 训练模式
    model.train()
    optimizer.zero_grad()

    output = model(features, adjtensor)

    # 平均输出
    areout = output[1]

    loss_xy = 0
    loss_ncl = 0

    for k in range(len(output[0])):
        # print('k = ' + str(k))
        # print(F.nll_loss(output[0][k][idx_train], labels[idx_train]))
        # print(F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel]))
        loss_xy += F.nll_loss(output[0][k][idx_train], labels[idx_train])
        loss_ncl += F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel])

    loss_train = (1 - args.lamd) * loss_xy - args.lamd * loss_ncl
    # loss_train = (1 - args.lamd) * loss_xy + args.lamd * 1 / loss_ncl

    # loss_train = (1 - args.lamd) * loss_xy + args.lamd * (torch.exp(-loss_ncl))

    print(loss_xy)
    print(loss_ncl)

    print(torch.exp(-loss_ncl))

    print((1 - args.lamd) * loss_xy)
    print(args.lamd * (torch.exp(-loss_ncl)))

    print(epoch)
    print(loss_train)
    print('.............')
    acc_train = accuracy(areout[idx_train], labels[idx_train])

    loss_train.backward()
    optimizer.step()

    # 验证模式
    if validate:
        # print('no')
        model.eval()
        output = model(features, adjtensor)
        areout = output[1]
        vl_step = len(idx_val)
        loss_val = F.nll_loss(areout[idx_val], labels[idx_val])
        acc_val = accuracy(areout[idx_val], labels[idx_val])
        # vl_step = len(idx_train)
        # loss_val = F.nll_loss(areout[idx_train], labels[idx_train])
        # acc_val = accuracy(areout[idx_train], labels[idx_train])

        cost_val.append(loss_val)

        # 原始 GCN 的验证
        # if epoch > args.early_stopping and cost_val[-1] > torch.mean(torch.stack(cost_val[-(args.early_stopping + 1):-1])):
        #     # print('Early stopping...')
        #     print(epoch)
        #     break
        # print(epoch)
        # GAT 的验证
        if acc_val / vl_step >= vacc_mx or loss_val / vl_step <= vlss_mn:
            if acc_val / vl_step >= vacc_mx and loss_val / vl_step <= vlss_mn:
                vacc_early_model = acc_val / vl_step
                vlss_early_model = loss_val / vl_step
                torch.save(model, checkpt_file)

            vacc_mx = np.max((vacc_early_model, vacc_mx))
            vlss_mn = np.min((vlss_early_model, vlss_mn))
            curr_step = 0

        else:
            curr_step += 1

            # print(curr_step)
            if curr_step == args.early_stopping:
                # print('Early stop! Min loss: ', vlss_mn, ', Max accuracy: ', vacc_mx)
                # print('Early stop model validation loss: ', vlss_early_model, ', accuracy: ', vacc_early_model)
                break

代码说明

  • 训练循环遍历每个 epoch,并执行以下步骤:
    • 将模型设置为训练模式
    • 清空优化器的梯度
    • 使用模型进行前向传播,得到输出
    • 计算训练损失,包括分类损失 loss_xy 和非监督损失 loss_ncl
    • 计算总训练损失 loss_train
    • 反向传播并更新模型参数
  • 验证模式在每个 epoch 结束后执行,主要用于评估模型在验证集上的性能:
    • 将模型设置为评估模式
    • 进行前向传播,得到输出
    • 计算验证损失和准确率
    • 早期停止机制,如果验证损失不再下降,则停止训练
  • vlss_mn = np.min((vlss_early_model, vlss_mn)) 的作用是记录并更新验证集上的最小损失值,用于 early stopping 机制。

代码解释

  • args: 包含训练配置参数,例如 epochs (训练轮数), lamd (损失函数权重),early_stopping (早期停止阈值) 等。
  • model: 训练的 GCN/GAT 模型。
  • optimizer: 模型优化器,例如 Adam 或 SGD。
  • features: 节点特征矩阵。
  • adjtensor: 图邻接矩阵。
  • idx_train: 训练集节点索引。
  • idx_val: 验证集节点索引。
  • labels: 节点标签。
  • output: 模型的输出,包括分类概率和平均输出。
  • areout: 平均输出。
  • cost_val: 存储验证集损失的列表。
  • vacc_mx: 记录验证集上的最高准确率。
  • vlss_mn: 记录验证集上的最低损失。
  • vacc_early_model: 记录在早期停止时模型的验证集准确率。
  • vlss_early_model: 记录在早期停止时模型的验证集损失。
  • curr_step: 记录连续验证集损失没有下降的次数。
  • checkpt_file: 模型保存路径。

该代码示例提供了完整的半监督学习训练过程,可以作为学习 PyTorch 的基础,也可以作为构建更复杂的半监督学习模型的参考。

PyTorch 训练循环:基于 GCN/GAT 的半监督学习示例

原文地址: https://www.cveoy.top/t/topic/ildt 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录