PyTorch图像分类:训练与验证实战
import torch
from net import simpleconv
from torchvision import transforms, datasets
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from tensorboardX import SummaryWriter
import time
import os
def train(train_loader, model, criterion, optimizer, device, len_train, batch_size):
'''
训练模型的一个epoch
参数:
train_loader: 训练数据加载器
model: 模型
criterion: 损失函数
optimizer: 优化器
device: 设备
len_train: 训练集大小
batch_size: 批次大小
返回值:
train_loss: 平均训练损失
train_acc: 训练准确率
'''
num_loss = 0.0
num_corrects = 0.0
model.train()
for i, (data, target) in enumerate(train_loader):
data = data.to(device)
target = target.to(device)
optimizer.zero_grad()
outputs = model(data)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, target)
loss.backward()
optimizer.step()
num_corrects += torch.sum(preds == target).item()
num_loss = num_loss + loss.item()
train_loss = num_loss / (len_train // batch_size + 1)
train_acc = num_corrects / len_train
return train_loss, train_acc
def val(val_loader, model, criterion, device, len_val, batch_size):
'''
验证模型在一个epoch上的性能
参数:
val_loader: 验证数据加载器
model: 模型
criterion: 损失函数
device: 设备
len_val: 验证集大小
batch_size: 批次大小
返回值:
val_loss: 平均验证损失
val_acc: 验证准确率
'''
num_loss = 0.0
num_corrects = 0.0
model.eval()
with torch.no_grad():
for i, (data, target) in enumerate(val_loader):
data = data.to(device)
target = target.to(device)
outputs = model(data)
_, preds = torch.max(outputs.data, 1)
loss = criterion(outputs, target)
num_corrects += torch.sum(preds == target).item()
num_loss = num_loss + loss.item()
val_loss = num_loss / (len_val // batch_size + 1)
val_acc = num_corrects / len_val
return val_loss, val_acc
# 设置网络参数
batch_size = 100
nclass = 13
num_epochs = 364
data_dir = './testimage'
# 初始化网络模型
if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'
model = simpleconv(nclass).to(device)
# 数据预处理
train_transforms = transforms.Compose([
transforms.Resize((64, 64)),
transforms.RandomResizedCrop(48),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
val_transforms = transforms.Compose([
transforms.Resize((64, 64)),
transforms.CenterCrop(48),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# 加载数据
train_datasets = datasets.ImageFolder(os.path.join(data_dir, 'train'), train_transforms)
val_datasets = datasets.ImageFolder(os.path.join(data_dir, 'test'), val_transforms)
len_train = len(train_datasets)
len_val = len(val_datasets)
# 创建数据加载器
train_loaders = torch.utils.data.DataLoader(
dataset=train_datasets,
batch_size=batch_size,
shuffle=True,
num_workers=0
)
val_loaders = torch.utils.data.DataLoader(
dataset=val_datasets,
batch_size=batch_size,
shuffle=True,
num_workers=0
)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model.parameters(), lr=0.1, momentum=0.8)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=200, gamma=0.1)
# 使用Tensorboard记录训练过程
writer = SummaryWriter('runs')
best_acc = 0.0
# 开始训练
for epoch in range(num_epochs):
start = time.time()
train_loss, train_acc = train(train_loaders, model, criterion, optimizer_ft, device, len_train, batch_size)
exp_lr_scheduler.step()
val_loss, val_acc = val(val_loaders, model, criterion, device, len_val, batch_size)
# 保存最佳模型
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(), './best.pth')
# 记录训练信息
writer.add_scalar('trainloss', train_loss, epoch)
writer.add_scalar('trainacc', train_acc, epoch)
writer.add_scalar('valloss', val_loss, epoch)
writer.add_scalar('valacc', val_acc, epoch)
end = time.time()
print('[{}/{}]: train_loss:{:.3f}, train_acc:{:.3f}, eval_loss:{:.3f}, eval_acc:{:.3f}, time:{:.3f}'.format(
epoch + 1, num_epochs, train_loss, train_acc, val_loss, val_acc, end - start))
writer.close()
这段代码使用 PyTorch 实现了图像分类任务的训练和验证过程,并包含以下改进:
- 添加注释: 为代码的关键部分添加了详细的注释,提高代码可读性。
- 函数化: 将训练和验证逻辑封装成函数,使代码结构更加清晰。
- 参数设置: 将网络参数集中设置,方便修改和管理。
- 数据预处理: 使用
torchvision.transforms对数据进行增强处理,提高模型泛化能力。 - 学习率调整: 使用
torch.optim.lr_scheduler动态调整学习率,加快模型收敛速度。 - 可视化: 使用
tensorboardX记录训练过程中的指标变化,方便可视化分析。 - 保存最佳模型: 在验证集上取得最佳性能的模型会被保存下来。
希望这些改进能够帮助你更好地理解和使用这段代码进行图像分类任务。
原文地址: https://www.cveoy.top/t/topic/RMx 著作权归作者所有。请勿转载和采集!