PyTorch语义分割代码实战:计算TN指标
PyTorch语义分割代码实战:计算TN指标
在语义分割任务中,除了常用的Dice系数等指标外,真负率(True Negative Rate, TNR)也是一个重要的评估指标,它反映了模型正确识别负样本的能力。本文将介绍如何在PyTorch语义分割代码中计算TN指标。
代码示例
以下代码片段展示了如何在PyTorch语义分割代码中计算TN指标:pythonwith torch.no_grad(): for i, b in enumerate(batch(dataset, batch_size)):
imgs = np.array([k[0] for k in b]).astype(np.float32) true_masks = np.array([k[1] for k in b])
imgs = torch.from_numpy(imgs) imgs = imgs.unsqueeze(1) true_masks = torch.from_numpy(true_masks)
pre_masks_eval = torch.zeros(true_masks.shape[0],14,256,256) true_masks_eval = torch.zeros(true_masks.shape[0],14,256,256) batchshape = true_masks.shape[0]
batch_dice = torch.zeros(14).cuda() if gpu: imgs = imgs.cuda() true_masks = true_masks.cuda() net.cuda()
output_img = net(imgs) input = output_img.cuda() pre_masks = input.max(1)[1].float() #索引代表像素所属类别的数字 for ak in range(14): if ak == 0: continue pre_masks_eval[:,ak] = (pre_masks==ak) true_masks_eval[:,ak] = (true_masks==ak) premasks = pre_masks_eval[:,ak].view(true_masks.shape[0],-1) truemasks = true_masks_eval[:,ak].view(true_masks.shape[0],-1)
intersection = premasks * truemasks TP = intersection.sum(1) FP = premasks.sum(1) - TP FN = truemasks.sum(1) - TP
# 计算TN TN = (premasks == 0) & (truemasks == 0) TN = TN.sum(1)
for bk in range(true_masks.shape[0]): if TP[bk] == 0 and FP[bk] == 0 and FN[bk] == 0: NE[ak] += 1 JNE[ak] += 1
# 在这里添加TN的统计 TN = (TP == 0) & (FP == 0) & (FN == 0) TN = TN.sum()
代码解释
-
计算TN:
python TN = (premasks == 0) & (truemasks == 0) TN = TN.sum(1)这段代码首先使用逻辑与运算符(&)找到预测为负样本且实际也为负样本的像素。然后,使用sum(1)函数对每个样本的TN进行求和。 -
统计TN:
python TN = (TP == 0) & (FP == 0) & (FN == 0) TN = TN.sum()这段代码在原有的if语句中添加了TN的统计逻辑,用于统计整个batch的TN数量。
总结
通过以上代码修改,我们可以在PyTorch语义分割代码中计算和统计TN指标,从而更全面地评估模型的性能。需要注意的是,上述代码仅供参考,实际应用中需要根据具体情况进行调整。
原文地址: https://www.cveoy.top/t/topic/fRT3 著作权归作者所有。请勿转载和采集!