PyTorch 图神经网络训练代码:早期停止策略实现
该代码片段展示了使用 PyTorch 训练图神经网络时,如何使用早期停止策略来优化模型训练过程。
for epoch in range(args.epochs):
t = time.time()
# for train
model.train()
optimizer.zero_grad()
output = model(features, adjtensor)
# 平均输出
areout = output[1]
loss_xy = 0
loss_ncl = 0
for k in range(len(output[0])):
# print('k = '+str(k))
# print(F.nll_loss(output[0][k][idx_train], labels[idx_train]))
# print(F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel]))
loss_xy += F.nll_loss(output[0][k][idx_train], labels[idx_train])
loss_ncl += F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel])
loss_train = (1-args.lamd)* loss_xy - args.lamd * loss_ncl
# loss_train = (1 - args.lamd) * loss_xy + args.lamd * 1 / loss_ncl
# loss_train = (1 - args.lamd) * loss_xy + args.lamd * (torch.exp(-loss_ncl))
print(loss_xy)
print(loss_ncl)
print(torch.exp(-loss_ncl))
print((1 - args.lamd) * loss_xy)
print(args.lamd * (torch.exp(-loss_ncl)))
print(epoch)
print(loss_train)
print('.............')
acc_train = accuracy(areout[idx_train], labels[idx_train])
loss_train.backward()
optimizer.step()
# for val
if validate:
# print('no')
model.eval()
output = model(features, adjtensor)
areout = output[1]
vl_step = len(idx_val)
loss_val = F.nll_loss(areout[idx_val], labels[idx_val])
acc_val = accuracy(areout[idx_val], labels[idx_val])
# vl_step = len(idx_train)
# loss_val = F.nll_loss(areout[idx_train], labels[idx_train])
# acc_val = accuracy(areout[idx_train], labels[idx_train])
cost_val.append(loss_val)
# 原始GCN的验证
# if epoch > args.early_stopping and cost_val[-1] > torch.mean(torch.stack(cost_val[-(args.early_stopping + 1):-1])):
# # print('Early stopping...')
# print(epoch)
# break
# print(epoch)
# GAT的验证
if acc_val/vl_step >= vacc_mx or loss_val/vl_step <= vlss_mn:
if acc_val/vl_step >= vacc_mx and loss_val/vl_step <= vlss_mn:
vacc_early_model = acc_val/vl_step
vlss_early_model = loss_val/vl_step
torch.save(model, checkpt_file)
vacc_mx = np.max((vacc_early_model, vacc_mx))
vlss_mn = np.min((vlss_early_model, vlss_mn))
curr_step = 0
else:
curr_step += 1
# print(curr_step)
if curr_step == args.early_stopping:
# print('Early stop! Min loss: ', vlss_mn, ', Max accuracy: ', vacc_mx)
# print('Early stop model validation loss: ', vlss_early_model, ', accuracy: ', vacc_early_model)
break
代码解释:
-
模型训练 (for train):
- 将模型设置为训练模式 (
model.train()) - 清零优化器梯度 (
optimizer.zero_grad()) - 将特征和邻接矩阵输入模型,获得模型输出 (
output = model(features, adjtensor)) - 计算训练集上的损失,并进行反向传播和优化器更新
- 将模型设置为训练模式 (
-
模型验证 (for val):
- 将模型设置为评估模式 (
model.eval()) - 将特征和邻接矩阵输入模型,获得模型输出 (
output = model(features, adjtensor)) - 计算验证集上的损失和准确率
- 将模型设置为评估模式 (
-
早期停止 (Early Stopping):
- 使用
acc_val和loss_val评估模型性能 - 当验证集上的准确率不再提升或损失不再下降时,停止训练
- 使用
代码中 acc_val = accuracy(areout[idx_val], labels[idx_val]) 的作用:
该行代码的作用是计算模型在验证集上的准确率。其中,areout 是模型的输出结果,idx_val 是验证集的索引,labels 是节点的真实标签。通过计算模型输出结果和真实标签的准确率,可以评估模型在验证集上的性能。在早期停止策略中,如果模型的验证集准确率不再提高或者验证集损失不再降低,就会停止训练,以避免过拟合。
早期停止的优点:
- 可以避免过拟合,提高模型在未知数据上的泛化能力
- 可以节省训练时间
总结:
本代码片段展示了使用 PyTorch 训练图神经网络时,如何使用早期停止策略来优化模型训练过程。代码中包括模型训练、验证以及基于准确率和损失的早期停止逻辑。早期停止策略是一种重要的技术,可以帮助我们训练出更强大、泛化能力更强的模型。
原文地址: https://www.cveoy.top/t/topic/leBs 著作权归作者所有。请勿转载和采集!