PyTorch 深度学习模型训练和验证过程代码详解

本文将详细解析 PyTorch 深度学习模型的训练和验证过程代码。代码示例中,我们使用自定义损失函数custom_loss,并通过余弦相似度进行分类预测。

训练阶段

num_epochs = 100
for epoch in range(num_epochs):
    running_loss = 0.0
    correct = 0
    total = 0

    # 训练阶段
    network.train()
    correct_total = 0
    total = 0
    for i, input_tensor in enumerate(train_tensors):
        optimizer.zero_grad()

        output = network(input_tensor)

        loss = custom_loss(output, tensor_list[i])

        loss.backward()
        optimizer.step()

        # 统计准确率
        target_similarity = F.cosine_similarity(output, tensor_list[i].unsqueeze(0), dim=1)
        label_list = [torch.tensor([1, 0, 0, 0]), torch.tensor([0, 1, 0, 0]), torch.tensor([0, 0, 1, 0]),
                      torch.tensor([1, 1, 1, 1])]
        other_list = []
        for label_tensor in label_list:
            if not torch.all(torch.eq(tensor_list[i], label_tensor)):
                other_list.append(label_tensor)

        if target_similarity > torch.max(torch.stack([F.cosine_similarity(output, other.unsqueeze(0), dim=1) for other in other_list]), dim=0).values:
            correct_total += 1

        total += 1

    # 计算最终的正确率
    accuracy = correct_total / total
    running_loss += loss.item()

    # 打印训练信息
    # print('Epoch: %d, Loss: %.3f, Training Accuracy: %.2f%%' % (epoch+1, running_loss, 100 * correct / total))
    print('Final Accuracy: %.2f%%' % (100 * accuracy))
    print('Epoch: %d, Loss: %.3f' % (epoch + 1, running_loss))

    network.eval()
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for j, val_input_tensor in enumerate(val_tensors):
            val_output = network(val_input_tensor)

            # 计算相似度
            val_target_similarity = F.cosine_similarity(val_output, tensor_list[j].unsqueeze(0), dim=1)

            val_other_similarities = []
            for k, tensor in enumerate(tensor_list):
                if k != j:
                    similarity = F.cosine_similarity(val_output, tensor.unsqueeze(0), dim=1)
                    val_other_similarities.append(similarity)

            val_other_similarities = torch.cat(val_other_similarities)

            val_labels = [torch.tensor([1, 0, 0, 0]), torch.tensor([0, 1, 0, 0]), torch.tensor([0, 0, 1, 0]),
                          torch.tensor([1, 1, 1, 1])]
            val_label_index = torch.argmax(tensor_list[j])
            val_label = val_labels[val_label_index]

            if val_target_similarity > torch.max(val_other_similarities):
                val_predicted_index = torch.argmax(val_output)
                if torch.all(torch.eq(val_label, val_labels[val_predicted_index])):
                    val_correct += 1

            val_total += 1

    # 打印验证信息
    print('Validation Accuracy: %.2f%%' % (100 * val_correct / val_total))

验证过程优化

在验证阶段,除了计算准确率外,我们还可以计算验证集上的平均损失,以更全面地评估模型性能。以下是对验证过程代码的优化:

network.eval()
val_loss = 0.0
val_correct = 0
val_total = 0

with torch.no_grad():
    for j, val_input_tensor in enumerate(val_tensors):
        val_output = network(val_input_tensor)

        # 计算相似度
        val_target_similarity = F.cosine_similarity(val_output, tensor_list[j].unsqueeze(0), dim=1)

        val_other_similarities = []
        for k, tensor in enumerate(tensor_list):
            if k != j:
                similarity = F.cosine_similarity(val_output, tensor.unsqueeze(0), dim=1)
                val_other_similarities.append(similarity)

        val_other_similarities = torch.cat(val_other_similarities)

        val_labels = [torch.tensor([1, 0, 0, 0]), torch.tensor([0, 1, 0, 0]), torch.tensor([0, 0, 1, 0]),
                      torch.tensor([1, 1, 1, 1])]
        val_label_index = torch.argmax(tensor_list[j])
        val_label = val_labels[val_label_index]

        if val_target_similarity > torch.max(val_other_similarities):
            val_predicted_index = torch.argmax(val_output)
            if torch.all(torch.eq(val_label, val_labels[val_predicted_index])):
                val_correct += 1

        val_total += 1

        val_loss += custom_loss(val_output, tensor_list[j]).item()

# 计算验证集上的平均损失和准确率
val_loss /= val_total
val_accuracy = 100 * val_correct / val_total

# 打印验证信息
print('Validation Loss: %.3f, Accuracy: %.2f%%' % (val_loss, val_accuracy))

代码解析

  1. 训练阶段

    • 使用network.train()将模型设置为训练模式,并使用optimizer.zero_grad()将梯度清零;
    • 使用network(input_tensor)进行前向传播;
    • 使用custom_loss(output, tensor_list[i])计算损失;
    • 使用loss.backward()计算梯度;
    • 使用optimizer.step()更新模型参数;
    • 计算准确率并打印训练信息。
  2. 验证阶段

    • 使用network.eval()将模型设置为评估模式,并使用torch.no_grad()禁止梯度计算;
    • 使用network(val_input_tensor)进行前向传播;
    • 计算相似度并判断预测是否正确,统计正确预测的数量和总样本数量;
    • 计算验证集上的平均损失和准确率,并打印验证信息。

总结

本文详细解析了 PyTorch 深度学习模型的训练和验证过程代码,并对验证过程进行了优化,添加了损失计算,使得评估结果更加全面。希望本文能帮助您更好地理解 PyTorch 模型训练和验证过程,并能应用于您的深度学习项目。

PyTorch 深度学习模型训练和验证过程代码详解

原文地址: https://www.cveoy.top/t/topic/Nt1 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录