PyTorch 数据加载和训练:输入和标签类型解析
在 PyTorch 中,图像数据加载和训练过程中,输入(input)和标签(target)是两个关键数据结构。
输入 (input)
input 是一个批次(batch)的图像数据,类型为 torch.Tensor,形状为 [batch_size, channel, height, width]。
batch_size:批次大小,表示每个批次包含的图像数量。channel:通道数,通常为 3,表示 RGB 三个通道。height和width:图像的高度和宽度。
标签 (target)
target 是一个批次的标签数据,类型为 torch.Tensor,形状为 [batch_size]。每个元素代表该图像对应的类别标签。
示例代码:
traindir = os.path.join('/private/PT4AL-new/DATA', 'train')
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
augmentation = [
transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
transforms.RandomApply(
[transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8 # not strengthened
),
transforms.RandomGrayscale(p=0.2),
transforms.RandomApply([moco.loader.GaussianBlur([0.1, 2.0])], p=0.5),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
train_dataset = datasets.ImageFolder(
traindir, moco.loader.TwoCropsTransform(transforms.Compose(augmentation))
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=32,
shuffle=True,
num_workers=2,
pin_memory=True,
# sampler=train_sampler,
drop_last=True,
)
for batch_idx, (input, target) in enumerate(train_loader):
# print(type(input))
# print(type(target))
input, target = input.to(device), target.to(device)
output = net(input)
loss = criterion(output, target)
test_loss += loss.item()
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
总结:
input 和 target 是 PyTorch 训练过程中不可或缺的数据结构,分别代表了图像数据和对应的类别标签,它们在模型训练、损失计算和预测评估中扮演着重要角色。
原文地址: https://www.cveoy.top/t/topic/lC0o 著作权归作者所有。请勿转载和采集!