从零开始训练图像超分辨率模型:基于GUnet的实现

本教程将指导您从零开始训练一个图像超分辨率模型,使用GUnet网络架构。我们将深入探讨模型构建、训练过程和评估方法。

1. 准备工作

1.1 安装必要的库

import os
import argparse
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils import AverageMeter, CosineScheduler, pad_img
from datasets import PairLoader
from models import *

1.2 定义命令行参数

## 定义命令行参数
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='gunet_t', type=str, help='model name')# 模型名称
parser.add_argument('--num_workers', default=16, type=int, help='number of workers')# 工作进程数
parser.add_argument('--use_mp', action='store_true', default=False, help='use Mixed Precision')# 是否使用混合精度训练(混合精度训练就是f为32位或16位)
parser.add_argument('--save_dir', default='./saved_models/', type=str, help='path to models saving')# 模型保存路径
parser.add_argument('--data_dir', default='./data/', type=str, help='path to dataset')# 数据集路径
parser.add_argument('--log_dir', default='./logs/', type=str, help='path to logs')# 训练日志路径
parser.add_argument('--train_set', default='SOTS-OUT', type=str, help='train dataset name') # 训练集名称
parser.add_argument('--val_set', default='SOTS-IN/SOTS-IN', type=str, help='valid dataset name')# 验证集名称
parser.add_argument('--exp', default='reside-in', type=str, help='experiment setting')# 实验设置
args = parser.parse_args()

1.3 设置训练环境

# training environment训练环境
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

1.4 加载配置文件

# training config加载配置文件
with open(os.path.join('configs', args.exp, 'base.json'), 'r') as f:
    b_setup = json.load(f)# 读取基本配置

variant = args.model.split('_')[-1]
config_name = 'model_'+variant+'.json' if variant in ['t', 's', 'b', 'd'] else 'default.json'    # default.json as baselines' configuration file
## 模型配置文件名称
with open(os.path.join('configs', args.exp, config_name), 'r') as f:
    m_setup = json.load(f) # 读取模型配置
print(m_setup)

2. 定义训练函数

def train(train_loader, network, criterion, optimizer, scaler, frozen_bn=False):#损失函数criterion优化器optimizer
    losses = AverageMeter()# 损失函数值的平均值

    torch.cuda.empty_cache()
    # PyTorch是有缓存区的设置的,意思

    network.eval() if frozen_bn else network.train()# 设置网络模型的训练状态?
    # 
    # simplified implementation that other modules may be affect
    #简化了其他模块可能受到影响的实现

    for batch in train_loader:# 遍历训练集
        source_img = batch['source'].cuda() # 输入图像
        target_img = batch['target'].cuda()# 目标图像
#模型和相应的数据进行.cuda()处理。就可以将内存中的数据复制到GPU的显存中去
        with autocast(args.use_mp):# 自动混合精度训练
            output = network(source_img)# 模型输出
            loss = criterion(output, target_img)#计算损失,利用传进来的损失函数

        optimizer.zero_grad() # 优化器梯度归零
        scaler.scale(loss).backward()# 反向传播
        scaler.step(optimizer)# 参数更新
        scaler.update()# 更新梯度缩放器

        losses.update(loss.item()) # 更新平均损失函数值

    return losses.avg# 返回平均损失函数值

3. 定义验证函数

# 验证函数
def valid(val_loader, network):
    PSNR = AverageMeter()# PSNR的平均值

    torch.cuda.empty_cache()# 清空缓存

    network.eval()# 设置网络模型的验证状态

    for batch in val_loader:# 遍历验证集
        source_img = batch['source'].cuda()# 输入图像
        target_img = batch['target'].cuda()# 目标图像

        with torch.no_grad():# 不进行梯度计算
            H, W = source_img.shape[2:]
            source_img = pad_img(source_img, network.module.patch_size if hasattr(network.module, 'patch_size') else 16)
            output = network(source_img).clamp_(-1, 1)
            output = output[:, :, :H, :W]

        mse_loss = F.mse_loss(output * 0.5 + 0.5, target_img * 0.5 + 0.5, reduction='none').mean((1, 2, 3))
        psnr = 10 * torch.log10(1 / mse_loss).mean()# 计算PSNR
        # if args.use_ddp: psnr = reduce_mean(psnr, dist.get_world_size())		# comment this line for more accurate validation
        
        PSNR.update(psnr.item(), source_img.size(0))# 更新平均PSNR值

    return PSNR.avg# 返回平均PSNR值

4. 定义主函数

def main():#定义主函数
    # define network, and use DDP for faster training
    # 定义网络模型
    network = eval(args.model)()# 根据模型名称创建模型
    network.cuda()# 将模型移动到GPU上

    # define loss function定义损失函数
    criterion = nn.L1Loss()
    
    # define optimizer定义优化器
    optimizer = torch.optim.AdamW(network.parameters(), lr=m_setup['lr'], weight_decay=b_setup['weight_decay'])
    # 使用AdamW优化器,设置学习率和权重衰减

    # 定义学习率和权重衰减的调整方式
    lr_scheduler = CosineScheduler(optimizer, param_name='lr', t_max=b_setup['epochs'], value_min=m_setup['lr'] * 1e-2, 
                                   warmup_t=b_setup['warmup_epochs'], const_t=b_setup['const_epochs'])
    wd_scheduler = CosineScheduler(optimizer, param_name='weight_decay', t_max=b_setup['epochs'])	# seems not to work

    # 使用GradScaler进行梯度缩放
    scaler = GradScaler()

    # load saved model载入已保存好的模型
    save_dir = os.path.join(args.save_dir, args.exp)# 设置保存模型的文件夹路径
    os.makedirs(save_dir, exist_ok=True)# 创建保存模型的文件夹
    if not os.path.exists(os.path.join(save_dir, args.model+'.pth')):# 如果该模型不存在
        best_psnr = 0# 初始最好的PSNR为0
        cur_epoch = 0# 当前epoch为0
    else:
        if not args.use_ddp or local_rank == 0: print('==> Loaded existing trained model.')# 如果不是DDP模式或者当前进程是主进程,输出已加载的模型提示信息
        model_info = torch.load(os.path.join(save_dir, args.model+'.pth'), map_location='cpu')# 加载模型
        network.load_state_dict(model_info['state_dict'])# 载入模型参数
        optimizer.load_state_dict(model_info['optimizer'])# 载入优化器参数
        lr_scheduler.load_state_dict(model_info['lr_scheduler'])# 载入学习率调整器参数
        wd_scheduler.load_state_dict(model_info['wd_scheduler'])# 载入权重衰减调整器参数
        scaler.load_state_dict(model_info['scaler'])# 载入梯度缩放器参数
        cur_epoch = model_info['cur_epoch']# 当前epoch为已保存的epoch
        best_psnr = model_info['best_psnr']# 最好的PSNR为已保存的PSNR

    # define dataset定义数据集
    train_dataset = PairLoader(os.path.join(args.data_dir, args.train_set), 'train', 
                            b_setup['t_patch_size'], 
                            b_setup['edge_decay'], 
                            b_setup['data_augment'], 
                            b_setup['cache_memory'])# 训练集
    train_loader = DataLoader(train_dataset,
                            batch_size=m_setup['batch_size'] // world_size,# 设置batch_size
                            sampler=RandomSampler(train_dataset, num_samples=b_setup['num_iter'] // world_size), # 随机采样
                            num_workers=args.num_workers // world_size,# 设置使用的进程数
                            pin_memory=True,# 是否将数据保存在固定的内存区域中加速内存的读取
                            drop_last=True,# 是否丢弃最后未满batch_size的数据
                            persistent_workers=True)	# comment this line for cache_memory

    val_dataset = PairLoader(os.path.join(args.data_dir, args.val_set), b_setup['valid_mode'], 
                            b_setup['v_patch_size'])# 验证集
    val_loader = DataLoader(val_dataset,
                            batch_size=max(int(m_setup['batch_size'] * b_setup['v_batch_ratio'] // world_size), 1),
                            # sampler=DistributedSampler(val_dataset, shuffle=False),		# comment this line for more accurate validation
                            num_workers=args.num_workers // world_size,
                            pin_memory=True)

    # start training开始训练
    if not args.use_ddp or local_rank == 0:
        print('==> Start training, current model name: ' + args.model)# 如果不是DDP模式或者当前进程是主进程,输出开始训练提示信息
        writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.exp, args.model))# 设置日志记录器

    for epoch in tqdm(range(cur_epoch, b_setup['epochs'] + 1)):
        frozen_bn = epoch > (b_setup['epochs'] - b_setup['frozen_epochs']) # 是否冻结BN层
        
        loss = train(train_loader, network, criterion, optimizer, scaler, frozen_bn)# 训练模型
        lr_scheduler.step(epoch + 1)# 更新学习率
        wd_scheduler.step(epoch + 1) # 更新权重衰减系数

        if not args.use_ddp or local_rank == 0:
            writer.add_scalar('train_loss', loss, epoch) # 如果不是DDP模式或者当前进程是主进程,记录训练损失

        if epoch % b_setup['eval_freq'] == 0:
            avg_psnr = valid(val_loader, network)# 对模型进行验证
            
            if not args.use_ddp or local_rank == 0:
                if avg_psnr > best_psnr:# 如果当前PSNR比最好的PSNR好
                    best_psnr = avg_psnr# 更新最好的PSNR
                    torch.save({'cur_epoch': epoch + 1,
                                'best_psnr': best_psnr,
                                'state_dict': network.state_dict(),
                                'optimizer' : optimizer.state_dict(),
                                'lr_scheduler' : lr_scheduler.state_dict(),
                                'wd_scheduler' : wd_scheduler.state_dict(),
                                'scaler' : scaler.state_dict()}, 
                                os.path.join(save_dir, args.model+'.pth'))# 保存模型参数
                
                writer.add_scalar('valid_psnr', avg_psnr, epoch)# 如果不是DDP模式或者当前进程是主进程,记录验证集的PSNR
                writer.add_scalar('best_psnr', best_psnr, epoch)# 如果不是DDP模式或者当前进程是主进程,记录最好的PSNR
        
            if args.use_ddp: dist.barrier()# 如果是DDP模式,等待所有进程都完成验证
        

if __name__ == '__main__':
    main()

5. 总结

本教程演示了如何从零开始训练一个GUnet图像超分辨率模型。您可以根据自己的需要修改代码和配置参数。希望本教程对您有所帮助!

从零开始训练图像超分辨率模型:基于GUnet的实现

原文地址: https://www.cveoy.top/t/topic/j0ey 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录