PyTorch深度学习:计算模型预测准确率
使用PyTorch计算模型预测准确率
本文将介绍如何使用PyTorch加载训练好的模型,并使用测试数据计算模型的预测准确率。
以下是完整的Python代码示例:
import torch
import torch.nn.functional as F
import torch.nn as nn
# 定义网络结构
class MyNetwork(nn.Module):
def __init__(self):
super(MyNetwork, self).__init__()
self.flatten = nn.Flatten()
self.fc = nn.Linear(339 * 256, 4)
def forward(self, x):
x = self.flatten(x)
x = self.fc(x)
return x
# 加载保存的模型参数
network = MyNetwork()
network.load_state_dict(torch.load('final_model.pt'))
# 读取标签数据
tensor_list = []
with open('biaoqian.txt', 'r') as file:
lines = file.readlines()
for line in lines:
line = line.strip()
numbers = line.split()
tensor = torch.tensor([float(num) for num in numbers])
tensor_list.append(tensor)
# 读取测试数据
test_data = torch.load('test_data.pt')
# 将测试数据转换为张量
test_tensors = [torch.tensor(tensor) for tensor in test_data]
# 在测试集上计算准确率
network.eval()
correct_total = 0
total = 0
with torch.no_grad():
for j, test_input_tensor in enumerate(test_tensors):
output = network(test_input_tensor.unsqueeze(0))
# 使用cosine similarity计算预测标签与真实标签的相似度
target_similarity = F.cosine_similarity(output, tensor_list[j].unsqueeze(0), dim=1)
# 定义可能的标签列表
label_list = [torch.tensor([1, 0, 0, 0]), torch.tensor([0, 1, 0, 0]), torch.tensor([0, 0, 1, 0]),
torch.tensor([1, 1, 1, 1])]
# 获取除真实标签外的其他标签
other_list = [label_tensor for label_tensor in label_list if not torch.all(torch.eq(tensor_list[j], label_tensor))]
# 判断预测结果是否与真实标签最相似
if target_similarity > torch.max(torch.stack([F.cosine_similarity(output, other.unsqueeze(0), dim=1) for other in other_list]), dim=0).values:
correct_total += 1
total += 1
# 计算准确率
accuracy = correct_total / total
# 打印准确率
print('Accuracy: %.2f%%' % (accuracy * 100))
代码解释:
- 定义网络结构: 定义了一个名为
MyNetwork的简单神经网络,包括Flatten层和Linear层。 - 加载模型参数: 使用
torch.load()加载训练好的模型参数。 - 读取数据: 从 'biaoqian.txt' 文件读取标签数据,从 'test_data.pt' 文件读取测试数据。
- 数据预处理: 将测试数据转换为PyTorch张量。
- 计算准确率: 遍历测试数据,进行预测,并根据预测结果与真实标签计算准确率。
- 输出结果: 打印模型在测试集上的准确率。
注意: 这段代码使用了cosine similarity来比较预测结果与真实标签的相似度。你需要根据实际情况修改代码,例如调整网络结构、数据读取方式、准确率计算方法等。
原文地址: https://www.cveoy.top/t/topic/3Hd 著作权归作者所有。请勿转载和采集!