PyTorch 代码:计算 TN 指标
def eval_net(net, dataset, slicetotal, batch_size=12, gpu=True):
'''Evaluation without the densecrf with the dice coefficient'''
net.eval()
start = time.time()
dice_ = torch.zeros(14).cuda()
jac_ = torch.zeros(14).cuda()
NE = torch.zeros(14).cuda()
JNE = torch.zeros(14).cuda()
accuracy_ = torch.zeros(14).cuda()
precision_ = torch.zeros(14).cuda()
recall_ = torch.zeros(14).cuda()
specificity_ = torch.zeros(14).cuda()
TN = torch.zeros(14).cuda() # 补充代码: 初始化 TN 数组
print(1)
with torch.no_grad():
for i, b in enumerate(batch(dataset, batch_size)):
imgs = np.array([k[0] for k in b]).astype(np.float32)
true_masks = np.array([k[1] for k in b])
imgs = torch.from_numpy(imgs)
imgs = imgs.unsqueeze(1)
true_masks = torch.from_numpy(true_masks)
pre_masks_eval = torch.zeros(true_masks.shape[0],14,256,256)
true_masks_eval = torch.zeros(true_masks.shape[0],14,256,256)
batchshape = true_masks.shape[0]
batch_dice = torch.zeros(14).cuda()
if gpu:
imgs = imgs.cuda()
true_masks = true_masks.cuda()
net.cuda()
output_img = net(imgs)
input = output_img.cuda()
pre_masks = input.max(1)[1].float() #索引代表像素所属类别的数>字
for ak in range(14):
if ak == 0:
continue
pre_masks_eval[:,ak] = (pre_masks==ak)
true_masks_eval[:,ak] = (true_masks==ak)
premasks = pre_masks_eval[:,ak].view(true_masks.shape[0],-1)
truemasks = true_masks_eval[:,ak].view(true_masks.shape[0],-1)
intersection = premasks * truemasks
TP = intersection.sum(1)
FP = premasks.sum(1) - TP
FN = truemasks.sum(1) - TP
# 补充代码: 计算 TN
TN[ak] = (premasks * truemasks).sum(1).eq(0).sum()
for bk in range(true_masks.shape[0]):
if TP[bk] == 0 and FP[bk] == 0 and FN[bk] == 0:
NE[ak] += 1
JNE[ak] += 1
else:
batch_dice[ak] = batch_dice[ak] + 2*TP[bk] / (2*TP[bk] + FP[bk] + FN[bk])
jac_[ak] = jac_[ak] + TP[bk] / (TP[bk] + FP[bk] + FN[bk])
dice_ = dice_ + batch_dice
for knum in range(14):
dice_[knum] = dice_[knum] / (slicetotal - NE[knum])
jac_[knum] = jac_[knum] / (slicetotal - JNE[knum])
end = time.time()
print('time used:',end - start)
return dice_,jac_
代码解释:
- 初始化 TN 数组: 在循环开始之前,添加了一行代码
TN = torch.zeros(14).cuda()用于初始化 TN 数组,用于存储每个类别的 TN 值。 - 计算 TN: 在每个类别循环中,添加了一行代码
TN[ak] = (premasks * truemasks).sum(1).eq(0).sum()来计算 TN 值。(premasks * truemasks)计算了预测为负样本且实际为负样本的像素数量。.sum(1)沿着样本维度进行求和,计算每个样本的 TN 值。.eq(0).sum()统计每个样本中 TN 值为 0 的数量,得到该类别的 TN 值。
代码修改说明:
将提供的代码段插入到 for ak in range(14): 循环内,紧接在计算 TP、FP 和 FN 之后。
其他指标:
除了 TN 指标,代码还计算了其他指标,例如:
- 准确率 (Accuracy)
- 精确率 (Precision)
- 召回率 (Recall)
- 特异度 (Specificity)
- Dice 系数
- Jaccard 系数
注意:
- 代码中的
14代表了分割任务中类别的数量。 - 代码假设每个样本的维度都是
256x256。 - 代码中使用的是 GPU 计算,因此需要确保 GPU 可用。
原文地址: https://www.cveoy.top/t/topic/fRS6 著作权归作者所有。请勿转载和采集!