PyTorch 图神经网络训练代码优化:早停机制实现及代码解析
本文将详细介绍 PyTorch 图神经网络训练代码中使用 early stopping 的方法,以及代码的具体解析。
Early Stopping 简介
Early stopping 是一种用于防止模型过拟合的常用技术。在模型训练过程中,模型会不断地学习训练数据,并尝试将损失函数降到最低。但如果模型训练时间过长,模型可能会过度拟合训练数据,导致在测试数据上的表现下降。Early stopping 的作用就是通过监控验证集上的指标(例如,准确率或损失),在验证集指标不再提升时停止训练,从而防止模型过拟合。
代码解析
for epoch in range(args.epochs):
t = time.time()
# for train
model.train()
optimizer.zero_grad()
output = model(features, adjtensor)
# 平均输出
areout = output[1]
loss_xy = 0
loss_ncl = 0
for k in range(len(output[0])):
# print('k = '+str(k))
# print(F.nll_loss(output[0][k][idx_train], labels[idx_train]))
# print(F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel]))
loss_xy += F.nll_loss(output[0][k][idx_train], labels[idx_train])
loss_ncl += F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel])
loss_train = (1-args.lamd)* loss_xy - args.lamd * loss_ncl
# loss_train = (1 - args.lamd) * loss_xy + args.lamd * 1 / loss_ncl
# loss_train = (1 - args.lamd) * loss_xy + args.lamd * (torch.exp(-loss_ncl))
print(loss_xy)
print(loss_ncl)
print(torch.exp(-loss_ncl))
print((1 - args.lamd) * loss_xy)
print(args.lamd * (torch.exp(-loss_ncl)))
print(epoch)
print(loss_train)
print('.............')
acc_train = accuracy(areout[idx_train], labels[idx_train])
loss_train.backward()
optimizer.step()
# for val
if validate:
# print('no')
model.eval()
output = model(features, adjtensor)
areout = output[1]
vl_step = len(idx_val)
loss_val = F.nll_loss(areout[idx_val], labels[idx_val])
acc_val = accuracy(areout[idx_val], labels[idx_val])
# vl_step = len(idx_train)
# loss_val = F.nll_loss(areout[idx_train], labels[idx_train])
# acc_val = accuracy(areout[idx_train], labels[idx_train])
cost_val.append(loss_val)
# 原始GCN的验证
# if epoch > args.early_stopping and cost_val[-1] > torch.mean(torch.stack(cost_val[-(args.early_stopping + 1):-1])):
# # print('Early stopping...')
# print(epoch)
# break
# print(epoch)
# GAT的验证
if acc_val/vl_step >= vacc_mx or loss_val/vl_step <= vlss_mn:
if acc_val/vl_step >= vacc_mx and loss_val/vl_step <= vlss_mn:
vacc_early_model = acc_val/vl_step
vlss_early_model = loss_val/vl_step
torch.save(model, checkpt_file)
vacc_mx = np.max((vacc_early_model, vacc_mx))
vlss_mn = np.min((vlss_early_model, vlss_mn))
curr_step = 0
else:
curr_step += 1
# print(curr_step)
if curr_step == args.early_stopping:
# print('Early stop! Min loss: ', vlss_mn, ', Max accuracy: ', vacc_mx)
# print('Early stop model validation loss: ', vlss_early_model, ', accuracy: ', vacc_early_model)
break
代码解析:
-
early stopping 的实现:
- 变量
curr_step用于记录当前连续验证集指标没有提升的次数。 - 在每个 epoch 的验证阶段,判断当前模型的验证集指标(例如,准确率或损失)是否超过之前的最优值。
- 如果超过,则将
curr_step重置为 0,并更新最优指标值。 - 如果没有超过,则将
curr_step加 1。 - 当
curr_step等于args.early_stopping时,则停止训练。
- 变量
-
验证阶段:
- 使用
model.eval()将模型设置为评估模式,禁用 dropout 和 batch normalization 等操作。 - 使用验证集计算模型的验证集指标。
- 使用
cost_val.append(loss_val)保存每个 epoch 的验证集损失,用于后续的 early stopping 判断。
- 使用
总结
本代码中使用 curr_step 记录连续验证集指标没有提升的次数,并根据 curr_step 和 args.early_stopping 的值判断是否停止训练。这种 early stopping 的方法可以有效地防止模型过拟合,提高模型的泛化能力。
在实际使用中,可以根据具体任务和数据集选择合适的验证集指标和 early stopping 的参数,以获得最佳的训练效果。
原文地址: https://www.cveoy.top/t/topic/leCx 著作权归作者所有。请勿转载和采集!