PyTorch图片分类代码示例 - 从入门到进阶
使用PyTorch构建图片分类器:从入门到进阶
这篇博客将带你学习如何使用PyTorch编写一个简单的图片分类器。我们将使用CIFAR-10数据集进行训练和测试,并逐步讲解每一步代码的含义。
1. 准备工作
首先,我们需要导入必要的库:pythonimport torchimport torch.nn as nnimport torch.optim as optimimport torchvisionimport torchvision.transforms as transforms
- torch: PyTorch的核心库,包含了张量操作和自动求导等功能。- torch.nn: 包含了神经网络模块和损失函数等。- torch.optim: 包含了各种优化算法,例如SGD、Adam等。- torchvision: 包含了常用的数据集、模型架构和图像处理工具。- torchvision.transforms: 用于对图像数据进行预处理。
2. 数据加载和预处理
接下来,我们加载CIFAR-10数据集并进行预处理:python# 设置随机种子torch.manual_seed(42)
定义数据预处理和加载器transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
- transforms.Compose: 将多个图像变换操作组合在一起。- transforms.ToTensor(): 将图像数据转换为PyTorch张量。- transforms.Normalize(): 对图像数据进行标准化处理。- torchvision.datasets.CIFAR10(): 加载CIFAR-10数据集。- torch.utils.data.DataLoader(): 创建数据加载器,用于迭代加载数据。
3. 定义模型结构
我们定义一个简单的卷积神经网络模型:python# 定义模型结构class Net(nn.Module): def init(self): super(Net, self).init() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10)
def forward(self, x): x = self.pool(torch.relu(self.conv1(x))) x = self.pool(torch.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return x
net = Net()
- nn.Module: 所有神经网络模块的基类。- nn.Conv2d: 2D卷积层。- nn.MaxPool2d: 2D最大池化层。- nn.Linear: 全连接层。- torch.relu: ReLU激活函数。
4. 定义损失函数和优化器python# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
- nn.CrossEntropyLoss(): 交叉熵损失函数,常用于多分类问题。- optim.SGD(): 随机梯度下降优化器。
5. 训练模型python# 训练模型for epoch in range(10): # 进行10个epoch的训练 running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step()
running_loss += loss.item() if i % 2000 == 1999: # 每2000个小批量数据打印一次损失 print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000)) running_loss = 0.0
print('Finished Training')
- epoch: 训练过程中遍历数据集的次数。- optimizer.zero_grad(): 清空梯度缓存。- loss.backward(): 反向传播计算梯度。- optimizer.step(): 更新模型参数。
6. 评估模型python# 使用测试集评估模型correct = 0total = 0with torch.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %.2f %%' % ( 100 * correct / total))
- torch.no_grad(): 表示在评估过程中不需要计算梯度。- torch.max(): 返回输入张量中每一行的最大值及其索引。
7. 总结
这只是一个简单的PyTorch图片分类代码示例,你可以根据自己的需求进行修改和扩展。例如,你可以尝试使用更复杂的模型架构、不同的优化算法、数据增强技巧等,以提高模型的性能。
原文地址: https://www.cveoy.top/t/topic/bqjS 著作权归作者所有。请勿转载和采集!