PyTorch GCN/GAT模型训练代码解析:损失函数与Early Stopping
这段代码展示了使用PyTorch训练GCN或GAT模型的典型流程。以下是代码的详细解析:
for epoch in range(args.epochs):
t = time.time()
# for train
model.train()
optimizer.zero_grad()
output = model(features, adjtensor)
# 平均输出
areout = output[1]
loss_xy = 0
loss_ncl = 0
for k in range(len(output[0])):
# print('k = ' + str(k))
# print(F.nll_loss(output[0][k][idx_train], labels[idx_train]))
# print(F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel]))
loss_xy += F.nll_loss(output[0][k][idx_train], labels[idx_train])
loss_ncl += F.mse_loss(output[0][k][idx_unlabel], areout[idx_unlabel])
loss_train = (1-args.lamd)* loss_xy - args.lamd * loss_ncl
# loss_train = (1 - args.lamd) * loss_xy + args.lamd * 1 / loss_ncl
# loss_train = (1 - args.lamd) * loss_xy + args.lamd * (torch.exp(-loss_ncl))
print(loss_xy)
print(loss_ncl)
print(torch.exp(-loss_ncl))
print((1 - args.lamd) * loss_xy)
print(args.lamd * (torch.exp(-loss_ncl)))
print(epoch)
print(loss_train)
print('.............')
acc_train = accuracy(areout[idx_train], labels[idx_train])
loss_train.backward()
optimizer.step()
# for val
if validate:
# print('no')
model.eval()
output = model(features, adjtensor)
areout = output[1]
vl_step = len(idx_val)
loss_val = F.nll_loss(areout[idx_val], labels[idx_val])
acc_val = accuracy(areout[idx_val], labels[idx_val])
# vl_step = len(idx_train)
# loss_val = F.nll_loss(areout[idx_train], labels[idx_train])
# acc_val = accuracy(areout[idx_train], labels[idx_train])
cost_val.append(loss_val)
# 原始GCN的验证
# if epoch > args.early_stopping and cost_val[-1] > torch.mean(torch.stack(cost_val[-(args.early_stopping + 1):-1])):
# # print('Early stopping...')
# print(epoch)
# break
# print(epoch)
# GAT的验证
if acc_val/vl_step >= vacc_mx or loss_val/vl_step <= vlss_mn:
if acc_val/vl_step >= vacc_mx and loss_val/vl_step <= vlss_mn:
vacc_early_model = acc_val/vl_step
vlss_early_model = loss_val/vl_step
torch.save(model, checkpt_file)
vacc_mx = np.max((vacc_early_model, vacc_mx))
vlss_mn = np.min((vlss_early_model, vlss_mn))
curr_step = 0
else:
curr_step += 1
# print(curr_step)
if curr_step == args.early_stopping:
# print('Early stop! Min loss: ', vlss_mn, ', Max accuracy: ', vacc_mx)
# print('Early stop model validation loss: ', vlss_early_model, ', accuracy: ', vacc_early_model)
break
代码解析:
-
训练循环: 代码使用
for epoch in range(args.epochs)循环进行模型训练,args.epochs代表训练的总轮数。 -
训练阶段: 代码中的
model.train()将模型设置为训练模式,optimizer.zero_grad()清空模型参数的梯度。output = model(features, adjtensor)通过模型预测输出,areout = output[1]获取模型的平均输出。 -
损失函数计算: 代码使用两个损失函数:
loss_xy表示对训练节点的分类损失,使用F.nll_loss计算负对数似然损失。loss_ncl表示对未标记节点的无监督损失,使用F.mse_loss计算均方误差损失。 -
总损失计算: 代码使用
loss_train = (1-args.lamd)* loss_xy - args.lamd * loss_ncl计算总损失,其中args.lamd是控制两种损失权重的超参数。 -
反向传播: 代码使用
loss_train.backward()计算梯度,optimizer.step()更新模型参数。 -
验证阶段: 代码使用
model.eval()将模型设置为评估模式。计算验证集上的损失loss_val和准确率acc_val,并将损失记录到cost_val列表中。 -
Early Stopping: 代码使用Early Stopping机制来避免过拟合。当验证集上的准确率达到最大值或者损失达到最小值时,将模型保存到
checkpt_file文件中。如果连续args.early_stopping轮验证集上的指标没有改善,则停止训练。
总结: 这段代码演示了如何使用PyTorch训练GCN/GAT模型,包含了损失函数的计算、Early Stopping机制的实现以及模型的保存。该代码结构清晰,注释详细,方便理解模型训练的原理和实践。
代码中loss_xy += F.nll_loss(output[0][k][idx_train], labels[idx_train])的作用:
这行代码计算了模型对训练节点的分类损失,使用的是负对数似然损失函数(Negative Log Likelihood Loss)。该损失函数的作用是根据模型的输出和真实标签计算出一个概率分布,然后将真实标签所对应的概率取负对数作为损失值,用于反向传播更新模型参数。损失函数的值越小,说明模型对训练数据的拟合效果越好。
原文地址: https://www.cveoy.top/t/topic/igxK 著作权归作者所有。请勿转载和采集!