PyTorch 图神经网络训练代码 - 带有早期停止和验证功能
for epoch in range(args.epochs):
t = time.time()
# for train
model.train()
optimizer.zero_grad()
output = model(features, adjtensor)
# 平均输出
areout = output[1]
loss_xy = 0
loss_ncl = 0
for k in range(len(output[0])):
# print('k = ' + str(k))
# print(F.nll_loss(output[0][k][idx_train], labels[idx_train]))
# print(F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel]))
loss_xy += F.nll_loss(output[0][k][idx_train], labels[idx_train])
loss_ncl += F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel])
loss_train = (1-args.lamd)* loss_xy - args.lamd * loss_ncl
# loss_train = (1 - args.lamd) * loss_xy + args.lamd * 1 / loss_ncl
# loss_train = (1 - args.lamd) * loss_xy + args.lamd * (torch.exp(-loss_ncl))
print(loss_xy)
print(loss_ncl)
print(torch.exp(-loss_ncl))
print((1 - args.lamd) * loss_xy)
print(args.lamd * (torch.exp(-loss_ncl)))
print(epoch)
print(loss_train)
print('.............')
acc_train = accuracy(areout[idx_train], labels[idx_train])
loss_train.backward()
optimizer.step()
# for val
if validate:
# print('no')
model.eval()
output = model(features, adjtensor)
areout = output[1]
vl_step = len(idx_val)
loss_val = F.nll_loss(areout[idx_val], labels[idx_val])
acc_val = accuracy(areout[idx_val], labels[idx_val])
# vl_step = len(idx_train)
# loss_val = F.nll_loss(areout[idx_train], labels[idx_train])
# acc_val = accuracy(areout[idx_train], labels[idx_train])
cost_val.append(loss_val)
# 原始GCN的验证
# if epoch > args.early_stopping and cost_val[-1] > torch.mean(torch.stack(cost_val[-(args.early_stopping + 1):-1])):
# # print('Early stopping...')
# print(epoch)
# break
# print(epoch)
# GAT的验证
if acc_val/vl_step >= vacc_mx or loss_val/vl_step <= vlss_mn:
if acc_val/vl_step >= vacc_mx and loss_val/vl_step <= vlss_mn:
vacc_early_model = acc_val/vl_step
vlss_early_model = loss_val/vl_step
torch.save(model, checkpt_file)
vacc_mx = np.max((vacc_early_model, vacc_mx))
vlss_mn = np.min((vlss_early_model, vlss_mn))
curr_step = 0
else:
curr_step += 1
# print(curr_step)
if curr_step == args.early_stopping:
# print('Early stop! Min loss: ', vlss_mn, ', Max accuracy: ', vacc_mx)
# print('Early stop model validation loss: ', vlss_early_model, ', accuracy: ', vacc_early_model)
break
# 在vl_step = len(idx_val)中,vl_step代表验证集样本数量,它用来计算验证集上的准确率和损失。idx_val是一个包含所有验证集样本索引的列表。len(idx_val)就是验证集样本的数量。
代码解释
这段代码展示了一个使用 PyTorch 训练图神经网络 (GNN) 的例子。它包含了以下几个关键部分:
- 训练循环: 使用
for epoch in range(args.epochs):循环进行训练,迭代args.epochs次。 - 模型训练: 在每次迭代中,首先使用
model.train()将模型设置为训练模式,并使用optimizer.zero_grad()清零梯度。然后使用模型进行预测,并计算训练损失。 - 损失计算: 代码中定义了两个损失项,
loss_xy和loss_ncl。最终的训练损失由这两项加权组合而成。 - 反向传播: 使用
loss_train.backward()计算梯度,并使用optimizer.step()更新模型参数。 - 验证: 代码在训练过程中还使用验证集进行评估,以查看模型在未见过的数据上的性能。它计算了验证集上的损失和准确率,并使用
vacc_mx和vlss_mn记录最佳的验证集准确率和最低的验证集损失。 - 早期停止: 代码使用
curr_step变量跟踪验证集损失连续上升的次数,当curr_step等于args.early_stopping时,停止训练,并保存最佳模型。
代码要点
- 代码中的
vl_step代表验证集样本数量,它用来计算验证集上的准确率和损失。 idx_val是一个包含所有验证集样本索引的列表。len(idx_val)表示验证集样本的数量。- 代码使用了早期停止策略,以防止模型过拟合,并找到在验证集上表现最佳的模型。
代码使用
这段代码可以作为训练图神经网络的一个示例,可以根据实际任务进行修改和扩展。需要根据具体的模型和数据集修改代码中的一些参数,例如 args.epochs、args.lamd、args.early_stopping 等。
原文地址: https://www.cveoy.top/t/topic/iju8 著作权归作者所有。请勿转载和采集!