PyTorch 代码中计算 TN (真负例) 指标
PyTorch 代码中计算 TN (真负例) 指标
在下面的 PyTorch 代码中,我们补全了 TN (真负例) 的计算。
with torch.no_grad():
for i, b in enumerate(batch(dataset, batch_size)):
imgs = np.array([k[0] for k in b]).astype(np.float32)
true_masks = np.array([k[1] for k in b])
imgs = torch.from_numpy(imgs)
imgs = imgs.unsqueeze(1)
true_masks = torch.from_numpy(true_masks)
pre_masks_eval = torch.zeros(true_masks.shape[0], 14, 256, 256)
true_masks_eval = torch.zeros(true_masks.shape[0], 14, 256, 256)
batchshape = true_masks.shape[0]
batch_dice = torch.zeros(14).cuda()
if gpu:
imgs = imgs.cuda()
true_masks = true_masks.cuda()
net.cuda()
output_img = net(imgs)
input = output_img.cuda()
pre_masks = input.max(1)[1].float() # 索引代表像素所属类别的数字
for ak in range(14):
if ak == 0:
continue
pre_masks_eval[:, ak] = (pre_masks == ak)
true_masks_eval[:, ak] = (true_masks == ak)
premasks = pre_masks_eval[:, ak].view(true_masks.shape[0], -1)
truemasks = true_masks_eval[:, ak].view(true_masks.shape[0], -1)
intersection = premasks * truemasks
TP = intersection.sum(1)
FP = premasks.sum(1) - TP
FN = truemasks.sum(1) - TP
TN = torch.zeros(14).cuda()
TN[ak] = (premasks * truemasks).sum(1).eq(0).sum()
for bk in range(true_masks.shape[0]):
if TP[bk] == 0 and FP[bk] == 0 and FN[bk] == 0:
NE[ak] += 1
JNE[ak] += 1
代码解释:
-
TN 计算:
- 在每个类别的循环中,我们将预测的掩膜和真实的掩膜进行比较,将不是当前类别的像素置为 1。
- 然后,我们计算与真实掩膜的交集。如果交集为 0,则表示该像素被正确地分类为非当前类别,即为 TN。
- 最后,我们将所有类别的 TN 累加得到总的 TN。
-
代码逻辑:
(premasks * truemasks).sum(1)计算每个样本中预测和真实掩膜的交集。.eq(0)判断交集是否为 0。.sum()对每个样本中 TN 的数量进行计数,并将其累加得到总的 TN。
通过以上代码,您可以轻松地在 PyTorch 中计算 TN 指标,并将其用于评估模型的性能。
原文地址: https://www.cveoy.top/t/topic/fRUb 著作权归作者所有。请勿转载和采集!