gUNet 模型训练教程

本教程将指导您使用 PyTorch 训练 gUNet 模型,这是一个用于图像处理的深度学习模型。

1. 环境配置

确保您已安装以下库:

  • PyTorch
  • torchvision
  • numpy
  • scikit-learn
  • tqdm
  • tensorboard

2. 代码示例

import os
import argparse
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils import AverageMeter, CosineScheduler, pad_img
from datasets import PairLoader
from models import *

## 定义命令行参数
parser = argparse.ArgumentParser()
parser.add_argument('--num_workers', default=16, type=int, help='number of workers')#工作进程数
parser.add_argument('--use_mp', action='store_true', default=False, help='use Mixed Precision')# 是否使用混合精度训练(混合精度训练就是f为32位或16位)
parser.add_argument('--use_ddp', action='store_true', default=False, help='use Distributed Data Parallel')# 是否使用分布式训练
parser.add_argument('--save_dir', default='./saved_models/', type=str, help='path to models saving')# 模型保存路径
parser.add_argument('--data_dir', default='./data/', type=str, help='path to dataset')# 数据集路径
parser.add_argument('--log_dir', default='./logs/', type=str, help='path to logs')# 训练日志路径
parser.add_argument('--train_set', default='SOTS-OUT', type=str, help='train dataset name') # 训练集名称
parser.add_argument('--val_set', default='SOTS-IN/SOTS-IN', type=str, help='valid dataset name')# 验证集名称
parser.add_argument('--exp', default='reside-in', type=str, help='experiment setting')# 实验设置
args = parser.parse_args()


# training environment训练环境
if args.use_ddp:
    torch.distributed.init_process_group(backend='nccl', init_method='env://')# 初始化分布式训练
    world_size=16# 分布式训练时设备总数
    local_rank = dist.get_rank()# 获取当前进程的rank
    torch.cuda.set_device(local_rank)# 设置当前进程使用的GPU设备
    if local_rank == 0: print('==> Using DDP.')# 仅进程0输出信息
else:
    world_size=16 # 设备总数

# training config加载配置文件
with open(os.path.join('configs', args.exp, 'base.json'), 'r') as f:
    b_setup = json.load(f)# 读取基本配置

variant = 'UNet'
config_name = 'model_'+variant+'.json'
## 模型配置文件名称
with open(os.path.join('configs', args.exp, config_name), 'r') as f:
    m_setup = json.load(f) # 读取模型配置
print(m_setup)

def reduce_mean(tensor, nprocs):
    rt = tensor.clone()#返回tensor的拷贝,返回的新tensor和原来的tensor具有同样的大小和数据类型。
    # print('--------------------------------------------------------------------------------------------')
    # print(type(rt))
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)#求和, 操作的输入Tensor,同时也会将归约结果返回至此 Tensor 中
    rt /= nprocs# 求平均
    return rt


def train(train_loader, network, criterion, optimizer, scaler, frozen_bn=False):#损失函数criterion优化器optimizer
    losses = AverageMeter()# 损失函数值的平均值
    #train函数是模型训练的入口。首先一些变量的更新采用自定义的AverageMeter类来管理。然后model.train()是设置为训练模式。
    #AverageMeter可以记录当前的输出,累加到某个变量之中,然后根据需要可以打印出历史上的平均

    torch.cuda.empty_cache()
    # PyTorch是有缓存区的设置的,意思就是一个Tensor就算被释放了,进程也不会把空闲出来的显存还给GPU,而是等待下一个Tensor来填入这一片被释放的空间。
    # 所以我们用torch.cuda.empty_cache可以清空缓冲区


    network.eval() if frozen_bn else network.train()# 设置网络模型的训练状态?
    #
    # simplified implementation that other modules may be affect
    #简化了其他模块可能受到影响的实现

    for batch in train_loader:# 遍历训练集
        source_img = batch['source'].cuda() # 输入图像
        target_img = batch['target'].cuda()# 目标图像
#模型和相应的数据进行.cuda()处理。就可以将内存中的数据复制到GPU的显存中去
        with autocast(args.use_mp):# 自动混合精度训练
            output = network(source_img)# 模型输出
            loss = criterion(output, target_img)#计算损失,利用传进来的损失函数

        optimizer.zero_grad() # 优化器梯度归零
        scaler.scale(loss).backward()# 反向传播
        scaler.step(optimizer)# 参数更新
        scaler.update()# 更新梯度缩放器

        if args.use_ddp: loss = reduce_mean(loss, dist.get_world_size()) # 求平均损失函数值
        losses.update(loss.item()) # 更新平均损失函数值

    return losses.avg# 返回平均损失函数值

# 验证函数
def valid(val_loader, network):
    PSNR = AverageMeter()# PSNR的平均值

    torch.cuda.empty_cache()# 清空缓存

    network.eval()# 设置网络模型的验证状态

    for batch in val_loader:# 遍历验证集
        source_img = batch['source'].cuda()# 输入图像
        target_img = batch['target'].cuda()# 目标图像

        with torch.no_grad():# 不进行梯度计算
            H, W = source_img.shape[2:]
            source_img = pad_img(source_img, network.module.patch_size if hasattr(network.module, 'patch_size') else 16)
            output = network(source_img).clamp_(-1, 1)
            output = output[:, :, :H, :W]

        mse_loss = F.mse_loss(output * 0.5 + 0.5, target_img * 0.5 + 0.5, reduction='none').mean((1, 2, 3))
        psnr = 10 * torch.log10(1 / mse_loss).mean()# 计算PSNR
        # if args.use_ddp: psnr = reduce_mean(psnr, dist.get_world_size())		# comment this line for more accurate validation
        
        PSNR.update(psnr.item(), source_img.size(0))# 更新平均PSNR值

    return PSNR.avg# 返回平均PSNR值


def main():#定义主函数
    # define network, and use DDP for faster training
    # 定义网络模型
    network = eval('gUNet')()
    network.cuda()# 将模型移动到GPU上

    if args.use_ddp:
        network = DistributedDataParallel(network, device_ids=[local_rank], output_device=local_rank)# 分布式训练
        if m_setup['batch_size'] // world_size < 16:
            if local_rank == 0: print('==> Using SyncBN because of too small norm-batch-size.')
            nn.SyncBatchNorm.convert_sync_batchnorm(network)# 将批量归一化(BN)转换为同步批量归一化(SyncBN)
    else:
        network = DataParallel(network)# 多GPU训练
        if m_setup['batch_size'] // torch.cuda.device_count() < 16:
            print('==> Using SyncBN because of too small norm-batch-size.')
            convert_model(network)# 转换模型

    # define loss function定义损失函数
    criterion = nn.L1Loss()
    
    # define optimizer定义优化器
    optimizer = torch.optim.AdamW(network.parameters(), lr=m_setup['lr'], weight_decay=b_setup['weight_decay'])
    # 使用AdamW优化器,设置学习率和权重衰减

    # 定义学习率和权重衰减的调整方式
    lr_scheduler = CosineScheduler(optimizer, param_name='lr', t_max=b_setup['epochs'], value_min=m_setup['lr'] * 1e-2, 
                                   warmup_t=b_setup['warmup_epochs'], const_t=b_setup['const_epochs'])
    wd_scheduler = CosineScheduler(optimizer, param_name='weight_decay', t_max=b_setup['epochs'])	# seems not to work

    # 使用GradScaler进行梯度缩放
    scaler = GradScaler()

    # load saved model载入已保存好的模型
    save_dir = os.path.join(args.save_dir, args.exp)# 设置保存模型的文件夹路径
    os.makedirs(save_dir, exist_ok=True)# 创建保存模型的文件夹
    if not os.path.exists(os.path.join(save_dir, 'gUNet'+'.pth')):# 如果该模型不存在
        best_psnr = 0# 初始最好的PSNR为0
        cur_epoch = 0# 当前epoch为0
    else:
        if not args.use_ddp or local_rank == 0: print('==> Loaded existing trained model.')# 如果不是DDP模式或者当前进程是主进程,输出已加载的模型提示信息
        model_info = torch.load(os.path.join(save_dir, 'gUNet'+'.pth'), map_location='cpu')# 加载模型
        network.load_state_dict(model_info['state_dict'])# 载入模型参数
        optimizer.load_state_dict(model_info['optimizer'])# 载入优化器参数
        lr_scheduler.load_state_dict(model_info['lr_scheduler'])# 载入学习率调整器参数
        wd_scheduler.load_state_dict(model_info['wd_scheduler'])# 载入权重衰减调整器参数
        scaler.load_state_dict(model_info['scaler'])# 载入梯度缩放器参数
        cur_epoch = model_info['cur_epoch']# 当前epoch为已保存的epoch
        best_psnr = model_info['best_psnr']# 最好的PSNR为已保存的PSNR

    # define dataset定义数据集
    train_dataset = PairLoader(os.path.join(args.data_dir, args.train_set), 'train', 
                            b_setup['t_patch_size'], 
                            b_setup['edge_decay'], 
                            b_setup['data_augment'], 
                            b_setup['cache_memory'])# 训练集
    train_loader = DataLoader(train_dataset,
                            batch_size=m_setup['batch_size'] // world_size,# 设置batch_size
                            sampler=RandomSampler(train_dataset, num_samples=b_setup['num_iter'] // world_size), # 随机采样
                            num_workers=args.num_workers // world_size,# 设置使用的进程数
                            pin_memory=True,# 是否将数据保存在固定的内存区域中加速内存的读取
                            drop_last=True,# 是否丢弃最后未满batch_size的数据
                            persistent_workers=True)	# comment this line for cache_memory

    val_dataset = PairLoader(os.path.join(args.data_dir, args.val_set), b_setup['valid_mode'], 
                            b_setup['v_patch_size'])# 验证集
    val_loader = DataLoader(val_dataset,
                            batch_size=max(int(m_setup['batch_size'] * b_setup['v_batch_ratio'] // world_size), 1),
                            # sampler=DistributedSampler(val_dataset, shuffle=False),		# comment this line for more accurate validation
                            num_workers=args.num_workers // world_size,
                            pin_memory=True)

    # start training开始训练
    if not args.use_ddp or local_rank == 0:
        print('==> Start training, current model name: ' + 'gUNet')# 如果不是DDP模式或者当前进程是主进程,输出开始训练提示信息
        writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.exp, 'gUNet'))# 设置日志记录器

    for epoch in tqdm(range(cur_epoch, b_setup['epochs'] + 1)):
        frozen_bn = epoch > (b_setup['epochs'] - b_setup['frozen_epochs']) # 是否冻结BN层
        
        loss = train(train_loader, network, criterion, optimizer, scaler, frozen_bn)# 训练模型
        lr_scheduler.step(epoch + 1)# 更新学习率
        wd_scheduler.step(epoch + 1) # 更新权重衰减系数

        if not args.use_ddp or local_rank == 0:
            writer.add_scalar('train_loss', loss, epoch) # 如果不是DDP模式或者当前进程是主进程,记录训练损失

        if epoch % b_setup['eval_freq'] == 0:
            avg_psnr = valid(val_loader, network)# 对模型进行验证
            
            if not args.use_ddp or local_rank == 0:
                if avg_psnr > best_psnr:# 如果当前PSNR比最好的PSNR好
                    best_psnr = avg_psnr# 更新最好的PSNR
                    torch.save({'cur_epoch': epoch + 1,
                                'best_psnr': best_psnr,
                                'state_dict': network.state_dict(),
                                'optimizer' : optimizer.state_dict(),
                                'lr_scheduler' : lr_scheduler.state_dict(),
                                'wd_scheduler' : wd_scheduler.state_dict(),
                                'scaler' : scaler.state_dict()}, os.path.join(save_dir, 'gUNet'+'.pth'))# 保存模型参数
                
                writer.add_scalar('valid_psnr', avg_psnr, epoch)# 如果不是DDP模式或者当前进程是主进程,记录验证集的PSNR
                writer.add_scalar('best_psnr', best_psnr, epoch)# 如果不是DDP模式或者当前进程是主进程,记录最好的PSNR
        
            if args.use_ddp: dist.barrier()# 如果是DDP模式,等待所有进程都完成验证
        

if __name__ == '__main__':
    main()

3. 训练步骤

  1. 准备数据集: 确保您已准备好用于训练和验证的图像数据集。
  2. 定义模型: 使用 models.gUNet 创建 gUNet 模型。
  3. 配置训练参数: 使用命令行参数或配置文件设置训练参数,例如学习率、批次大小、迭代次数等。
  4. 训练模型: 调用 main() 函数开始训练,其中包含训练循环、损失计算、优化器更新、验证等步骤。
  5. 保存模型: 在训练过程中,保存模型的最佳性能参数和模型状态。
  6. 评估模型: 使用验证集评估训练好的模型的性能。

4. 参数调整

您可以通过调整以下参数来优化模型性能:

  • 学习率: 尝试不同的学习率值,例如 1e-4、1e-3 等。
  • 批次大小: 根据您的硬件资源调整批次大小。
  • 迭代次数: 增加迭代次数可以提高模型的准确性,但也会增加训练时间。
  • 优化器: 尝试使用不同的优化器,例如 AdamW、SGD 等。
  • 损失函数: 选择合适的损失函数,例如 L1Loss、MSELoss 等。

5. 总结

本教程提供了使用 PyTorch 训练 gUNet 模型的完整流程,帮助您了解模型训练的基本原理和步骤。 您可以根据自己的需求对代码进行调整和扩展。

6. 附加信息

  • 模型架构: 您可以在 models.py 文件中找到 gUNet 模型的定义。
  • 数据集加载: datasets.py 文件包含了数据集加载和预处理的代码。
  • 训练参数: configs 文件夹包含了各种训练参数设置。

7. 注意事项

  • 本教程使用的是默认参数,您可能需要根据自己的数据集和任务调整参数。
  • 建议使用 GPU 进行训练,以加速训练过程。
  • 注意模型的过拟合问题,并采取适当的措施防止过拟合,例如使用 dropout、正则化等技术。

原文地址: https://www.cveoy.top/t/topic/j0kX 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录