从零开始训练图像超分辨率模型:基于GUnet的实现
从零开始训练图像超分辨率模型:基于GUnet的实现
本教程将指导您从零开始训练一个图像超分辨率模型,使用GUnet网络架构。我们将深入探讨模型构建、训练过程和评估方法。
1. 准备工作
1.1 安装必要的库
import os
import argparse
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils import AverageMeter, CosineScheduler, pad_img
from datasets import PairLoader
from models import *
1.2 定义命令行参数
## 定义命令行参数
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='gunet_t', type=str, help='model name')# 模型名称
parser.add_argument('--num_workers', default=16, type=int, help='number of workers')# 工作进程数
parser.add_argument('--use_mp', action='store_true', default=False, help='use Mixed Precision')# 是否使用混合精度训练(混合精度训练就是f为32位或16位)
parser.add_argument('--save_dir', default='./saved_models/', type=str, help='path to models saving')# 模型保存路径
parser.add_argument('--data_dir', default='./data/', type=str, help='path to dataset')# 数据集路径
parser.add_argument('--log_dir', default='./logs/', type=str, help='path to logs')# 训练日志路径
parser.add_argument('--train_set', default='SOTS-OUT', type=str, help='train dataset name') # 训练集名称
parser.add_argument('--val_set', default='SOTS-IN/SOTS-IN', type=str, help='valid dataset name')# 验证集名称
parser.add_argument('--exp', default='reside-in', type=str, help='experiment setting')# 实验设置
args = parser.parse_args()
1.3 设置训练环境
# training environment训练环境
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
1.4 加载配置文件
# training config加载配置文件
with open(os.path.join('configs', args.exp, 'base.json'), 'r') as f:
b_setup = json.load(f)# 读取基本配置
variant = args.model.split('_')[-1]
config_name = 'model_'+variant+'.json' if variant in ['t', 's', 'b', 'd'] else 'default.json' # default.json as baselines' configuration file
## 模型配置文件名称
with open(os.path.join('configs', args.exp, config_name), 'r') as f:
m_setup = json.load(f) # 读取模型配置
print(m_setup)
2. 定义训练函数
def train(train_loader, network, criterion, optimizer, scaler, frozen_bn=False):#损失函数criterion优化器optimizer
losses = AverageMeter()# 损失函数值的平均值
torch.cuda.empty_cache()
# PyTorch是有缓存区的设置的,意思
network.eval() if frozen_bn else network.train()# 设置网络模型的训练状态?
#
# simplified implementation that other modules may be affect
#简化了其他模块可能受到影响的实现
for batch in train_loader:# 遍历训练集
source_img = batch['source'].cuda() # 输入图像
target_img = batch['target'].cuda()# 目标图像
#模型和相应的数据进行.cuda()处理。就可以将内存中的数据复制到GPU的显存中去
with autocast(args.use_mp):# 自动混合精度训练
output = network(source_img)# 模型输出
loss = criterion(output, target_img)#计算损失,利用传进来的损失函数
optimizer.zero_grad() # 优化器梯度归零
scaler.scale(loss).backward()# 反向传播
scaler.step(optimizer)# 参数更新
scaler.update()# 更新梯度缩放器
losses.update(loss.item()) # 更新平均损失函数值
return losses.avg# 返回平均损失函数值
3. 定义验证函数
# 验证函数
def valid(val_loader, network):
PSNR = AverageMeter()# PSNR的平均值
torch.cuda.empty_cache()# 清空缓存
network.eval()# 设置网络模型的验证状态
for batch in val_loader:# 遍历验证集
source_img = batch['source'].cuda()# 输入图像
target_img = batch['target'].cuda()# 目标图像
with torch.no_grad():# 不进行梯度计算
H, W = source_img.shape[2:]
source_img = pad_img(source_img, network.module.patch_size if hasattr(network.module, 'patch_size') else 16)
output = network(source_img).clamp_(-1, 1)
output = output[:, :, :H, :W]
mse_loss = F.mse_loss(output * 0.5 + 0.5, target_img * 0.5 + 0.5, reduction='none').mean((1, 2, 3))
psnr = 10 * torch.log10(1 / mse_loss).mean()# 计算PSNR
# if args.use_ddp: psnr = reduce_mean(psnr, dist.get_world_size()) # comment this line for more accurate validation
PSNR.update(psnr.item(), source_img.size(0))# 更新平均PSNR值
return PSNR.avg# 返回平均PSNR值
4. 定义主函数
def main():#定义主函数
# define network, and use DDP for faster training
# 定义网络模型
network = eval(args.model)()# 根据模型名称创建模型
network.cuda()# 将模型移动到GPU上
# define loss function定义损失函数
criterion = nn.L1Loss()
# define optimizer定义优化器
optimizer = torch.optim.AdamW(network.parameters(), lr=m_setup['lr'], weight_decay=b_setup['weight_decay'])
# 使用AdamW优化器,设置学习率和权重衰减
# 定义学习率和权重衰减的调整方式
lr_scheduler = CosineScheduler(optimizer, param_name='lr', t_max=b_setup['epochs'], value_min=m_setup['lr'] * 1e-2,
warmup_t=b_setup['warmup_epochs'], const_t=b_setup['const_epochs'])
wd_scheduler = CosineScheduler(optimizer, param_name='weight_decay', t_max=b_setup['epochs']) # seems not to work
# 使用GradScaler进行梯度缩放
scaler = GradScaler()
# load saved model载入已保存好的模型
save_dir = os.path.join(args.save_dir, args.exp)# 设置保存模型的文件夹路径
os.makedirs(save_dir, exist_ok=True)# 创建保存模型的文件夹
if not os.path.exists(os.path.join(save_dir, args.model+'.pth')):# 如果该模型不存在
best_psnr = 0# 初始最好的PSNR为0
cur_epoch = 0# 当前epoch为0
else:
if not args.use_ddp or local_rank == 0: print('==> Loaded existing trained model.')# 如果不是DDP模式或者当前进程是主进程,输出已加载的模型提示信息
model_info = torch.load(os.path.join(save_dir, args.model+'.pth'), map_location='cpu')# 加载模型
network.load_state_dict(model_info['state_dict'])# 载入模型参数
optimizer.load_state_dict(model_info['optimizer'])# 载入优化器参数
lr_scheduler.load_state_dict(model_info['lr_scheduler'])# 载入学习率调整器参数
wd_scheduler.load_state_dict(model_info['wd_scheduler'])# 载入权重衰减调整器参数
scaler.load_state_dict(model_info['scaler'])# 载入梯度缩放器参数
cur_epoch = model_info['cur_epoch']# 当前epoch为已保存的epoch
best_psnr = model_info['best_psnr']# 最好的PSNR为已保存的PSNR
# define dataset定义数据集
train_dataset = PairLoader(os.path.join(args.data_dir, args.train_set), 'train',
b_setup['t_patch_size'],
b_setup['edge_decay'],
b_setup['data_augment'],
b_setup['cache_memory'])# 训练集
train_loader = DataLoader(train_dataset,
batch_size=m_setup['batch_size'] // world_size,# 设置batch_size
sampler=RandomSampler(train_dataset, num_samples=b_setup['num_iter'] // world_size), # 随机采样
num_workers=args.num_workers // world_size,# 设置使用的进程数
pin_memory=True,# 是否将数据保存在固定的内存区域中加速内存的读取
drop_last=True,# 是否丢弃最后未满batch_size的数据
persistent_workers=True) # comment this line for cache_memory
val_dataset = PairLoader(os.path.join(args.data_dir, args.val_set), b_setup['valid_mode'],
b_setup['v_patch_size'])# 验证集
val_loader = DataLoader(val_dataset,
batch_size=max(int(m_setup['batch_size'] * b_setup['v_batch_ratio'] // world_size), 1),
# sampler=DistributedSampler(val_dataset, shuffle=False), # comment this line for more accurate validation
num_workers=args.num_workers // world_size,
pin_memory=True)
# start training开始训练
if not args.use_ddp or local_rank == 0:
print('==> Start training, current model name: ' + args.model)# 如果不是DDP模式或者当前进程是主进程,输出开始训练提示信息
writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.exp, args.model))# 设置日志记录器
for epoch in tqdm(range(cur_epoch, b_setup['epochs'] + 1)):
frozen_bn = epoch > (b_setup['epochs'] - b_setup['frozen_epochs']) # 是否冻结BN层
loss = train(train_loader, network, criterion, optimizer, scaler, frozen_bn)# 训练模型
lr_scheduler.step(epoch + 1)# 更新学习率
wd_scheduler.step(epoch + 1) # 更新权重衰减系数
if not args.use_ddp or local_rank == 0:
writer.add_scalar('train_loss', loss, epoch) # 如果不是DDP模式或者当前进程是主进程,记录训练损失
if epoch % b_setup['eval_freq'] == 0:
avg_psnr = valid(val_loader, network)# 对模型进行验证
if not args.use_ddp or local_rank == 0:
if avg_psnr > best_psnr:# 如果当前PSNR比最好的PSNR好
best_psnr = avg_psnr# 更新最好的PSNR
torch.save({'cur_epoch': epoch + 1,
'best_psnr': best_psnr,
'state_dict': network.state_dict(),
'optimizer' : optimizer.state_dict(),
'lr_scheduler' : lr_scheduler.state_dict(),
'wd_scheduler' : wd_scheduler.state_dict(),
'scaler' : scaler.state_dict()},
os.path.join(save_dir, args.model+'.pth'))# 保存模型参数
writer.add_scalar('valid_psnr', avg_psnr, epoch)# 如果不是DDP模式或者当前进程是主进程,记录验证集的PSNR
writer.add_scalar('best_psnr', best_psnr, epoch)# 如果不是DDP模式或者当前进程是主进程,记录最好的PSNR
if args.use_ddp: dist.barrier()# 如果是DDP模式,等待所有进程都完成验证
if __name__ == '__main__':
main()
5. 总结
本教程演示了如何从零开始训练一个GUnet图像超分辨率模型。您可以根据自己的需要修改代码和配置参数。希望本教程对您有所帮助!
原文地址: https://www.cveoy.top/t/topic/j0ey 著作权归作者所有。请勿转载和采集!