PyTorch 模型训练代码示例:记录测试准确率和损失值
以下代码展示了 PyTorch 模型训练过程中记录测试准确率和损失值的示例:
for epoch in range(50,args.epochs):
tic = time.time()
train_predict = []
train_label = []
runloss = 0
for (cnt, i) in enumerate(tqdm(train_loader)):
batch_x = i['data']
batch_y = i['label']
batch_x = torch.unsqueeze(batch_x, dim=1)
batch_x = batch_x.float()
if torch.cuda.is_available():
batch_x = batch_x.cuda()
batch_y = batch_y.cuda()
tlabels, tpredi, tloss = train(model, batch_x, batch_y, optimizer, criterion)
runloss = runloss + tloss
train_label.extend(tlabels)
train_predict.extend(tpredi)
per_epoch_train_time = time.time() - tic
train_time = per_epoch_train_time + train_time
taccuracy = accuracy_score(train_label, train_predict)
train_acc_list.append(taccuracy)
loss = runloss / batches_per_epoch
train_loss_list.append(loss)
acc_score, loss_score, feature_list, C = test(model, test_loader, criterion)
print('Epoch %d Val_accuracy %.3f Val_loss %.3f' % (epoch, acc_score, loss_score))
with open('train_log.csv', 'a+') as f:
f.write(f'{epoch}, {round(taccuracy, 4)}, {round(loss, 4)}
')
f.close()
with open('log.csv', 'a+') as f:
f.write(f'{epoch}, {round(acc_score, 4)}, {round(loss_score, 4)}
')
f.close()
test_acc_list.append(acc_score)
test_loss_list.append(loss_score)
代码中,test_acc_list.append(acc_score) 表示将每个 epoch 的测试准确率 acc_score 添加到 test_acc_list 列表中,test_loss_list.append(loss_score) 表示将每个 epoch 的测试损失值 loss_score 添加到 test_loss_list 列表中。这样可以记录每个 epoch 的测试准确率和损失值,用于后续分析和可视化。
例如,可以使用这些列表绘制训练过程中的准确率和损失值变化曲线,帮助分析模型的性能和训练过程。
代码说明
train_loader: 训练数据集的 DataLoader 对象test_loader: 测试数据集的 DataLoader 对象model: 待训练的 PyTorch 模型optimizer: 优化器,用于更新模型参数criterion: 损失函数accuracy_score: 用于计算准确率的函数train_acc_list: 用于记录每个 epoch 的训练准确率train_loss_list: 用于记录每个 epoch 的训练损失值test_acc_list: 用于记录每个 epoch 的测试准确率test_loss_list: 用于记录每个 epoch 的测试损失值
代码解析
- 训练阶段:
- 循环遍历每个 epoch
- 使用
train函数进行训练,计算训练损失和准确率,并将结果添加到相应的列表中
- 测试阶段:
- 使用
test函数进行测试,计算测试损失和准确率,并将结果添加到相应的列表中
- 使用
- 记录结果:
- 将每个 epoch 的训练和测试结果写入 CSV 文件,用于后续分析和可视化
总结
这段代码展示了如何记录 PyTorch 模型训练过程中的测试准确率和损失值,并使用这些信息进行分析和可视化。通过记录这些数据,我们可以更好地理解模型的训练过程,并调整模型参数以提高模型性能。
原文地址: https://www.cveoy.top/t/topic/fsS6 著作权归作者所有。请勿转载和采集!