在 PyTorch 中,图像数据加载和训练过程中,输入(input)和标签(target)是两个关键数据结构。

输入 (input)

input 是一个批次(batch)的图像数据,类型为 torch.Tensor,形状为 [batch_size, channel, height, width]

  • batch_size:批次大小,表示每个批次包含的图像数量。
  • channel:通道数,通常为 3,表示 RGB 三个通道。
  • heightwidth:图像的高度和宽度。

标签 (target)

target 是一个批次的标签数据,类型为 torch.Tensor,形状为 [batch_size]。每个元素代表该图像对应的类别标签。

示例代码:

traindir = os.path.join('/private/PT4AL-new/DATA', 'train')
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

augmentation = [
            transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
            transforms.RandomApply(
                [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8  # not strengthened
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([moco.loader.GaussianBlur([0.1, 2.0])], p=0.5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,

        ]

train_dataset = datasets.ImageFolder(
        traindir, moco.loader.TwoCropsTransform(transforms.Compose(augmentation))
    )
train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        # sampler=train_sampler,
        drop_last=True,

    )
        for batch_idx, (input, target) in enumerate(train_loader):
            # print(type(input))
            # print(type(target))
            input, target = input.to(device), target.to(device)
            output = net(input)
            loss = criterion(output, target)
            test_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

总结:

inputtarget 是 PyTorch 训练过程中不可或缺的数据结构,分别代表了图像数据和对应的类别标签,它们在模型训练、损失计算和预测评估中扮演着重要角色。

PyTorch 数据加载和训练:输入和标签类型解析

原文地址: https://www.cveoy.top/t/topic/lC0o 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录