PyTorch 图神经网络训练代码:基于交叉熵和均方误差的半监督学习
这段代码展示了使用 PyTorch 训练图神经网络的示例,它利用交叉熵损失和均方误差损失函数来实现半监督学习。
for epoch in range(args.epochs):
t = time.time()
# for train
model.train()
optimizer.zero_grad()
output = model(features, adjtensor)
# 平均输出
areout = output[1]
loss_xy = 0
loss_ncl = 0
for k in range(len(output[0])):
# print('k = ' + str(k))
# print(F.nll_loss(output[0][k][idx_train], labels[idx_train]))
# print(F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel]))
loss_xy += F.nll_loss(output[0][k][idx_train], labels[idx_train])
loss_ncl += F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel])
loss_train = (1-args.lamd)* loss_xy - args.lamd * loss_ncl
# loss_train = (1 - args.lamd) * loss_xy + args.lamd * 1 / loss_ncl
# loss_train = (1 - args.lamd) * loss_xy + args.lamd * (torch.exp(-loss_ncl))
print(loss_xy)
print(loss_ncl)
print(torch.exp(-loss_ncl))
print((1 - args.lamd) * loss_xy)
print(args.lamd * (torch.exp(-loss_ncl)))
print(epoch)
print(loss_train)
print('.............')
acc_train = accuracy(areout[idx_train], labels[idx_train])
loss_train.backward()
optimizer.step()
# for val
if validate:
# print('no')
model.eval()
output = model(features, adjtensor)
areout = output[1]
vl_step = len(idx_val)
loss_val = F.nll_loss(areout[idx_val], labels[idx_val])
acc_val = accuracy(areout[idx_val], labels[idx_val])
# vl_step = len(idx_train)
# loss_val = F.nll_loss(areout[idx_train], labels[idx_train])
# acc_val = accuracy(areout[idx_train], labels[idx_train])
cost_val.append(loss_val)
# 原始GCN的验证
# if epoch > args.early_stopping and cost_val[-1] > torch.mean(torch.stack(cost_val[-(args.early_stopping + 1):-1])):
# # print('Early stopping...')
# print(epoch)
# break
# print(epoch)
# GAT的验证
if acc_val/vl_step >= vacc_mx or loss_val/vl_step <= vlss_mn:
if acc_val/vl_step >= vacc_mx and loss_val/vl_step <= vlss_mn:
vacc_early_model = acc_val/vl_step
vlss_early_model = loss_val/vl_step
torch.save(model, checkpt_file)
vacc_mx = np.max((vacc_early_model, vacc_mx))
vlss_mn = np.min((vlss_early_model, vlss_mn))
curr_step = 0
else:
curr_step += 1
# print(curr_step)
if curr_step == args.early_stopping:
# print('Early stop! Min loss: ', vlss_mn, ', Max accuracy: ', vacc_mx)
# print('Early stop model validation loss: ', vlss_early_model, ', accuracy: ', vacc_early_model)
break
这段代码主要包含以下部分:
-
训练循环: 迭代训练 epochs 次,每次迭代包括以下步骤:
- 将模型设置为训练模式
- 清零优化器梯度
- 使用模型预测输出
- 计算交叉熵损失 (loss_xy) 和均方误差损失 (loss_ncl)
- 计算总训练损失 (loss_train)
- 反向传播并更新模型参数
-
验证循环: 在每个 epoch 后进行验证,步骤包括:
- 将模型设置为验证模式
- 使用模型预测输出
- 计算验证损失和准确率
- 根据验证结果更新早停机制的计数器
- 当满足早停条件时,停止训练
-
早停机制: 用于防止过拟合,当验证损失不再降低或准确率不再提高时,停止训练。
这段代码中的 acc_test = accuracy(areout[idx_test], labels[idx_test]) 用于计算测试集上的分类准确率。areout 是模型在测试集上的输出,idx_test 是测试集的节点索引,labels 是测试集上的真实标签。accuracy 函数是一个自定义的函数,用于计算分类准确率。
代码注释:
args.epochs表示训练的总次数args.lamd表示交叉熵损失和均方误差损失的权重idx_train、idx_val、idx_test分别表示训练集、验证集和测试集的节点索引features表示节点特征adjtensor表示图的邻接矩阵model表示图神经网络模型optimizer表示优化器cost_val用于记录验证损失vacc_mx和vlss_mn分别记录验证准确率的最高值和验证损失的最低值curr_step用于计数早停机制checkpt_file表示模型保存路径accuracy函数用于计算分类准确率
代码说明:
这段代码展示了如何在 PyTorch 中训练图神经网络,并使用了交叉熵损失和均方误差损失来实现半监督学习。它还包含了早停机制,以防止过拟合。您可以根据您的具体任务进行修改,例如更改损失函数、优化器、模型结构等。
原文地址: http://www.cveoy.top/t/topic/ilGX 著作权归作者所有。请勿转载和采集!