Please add a comment to each line of the following code:

# define optimizer
optimizer = torch.optim.AdamW(network.parameters(), lr=m_setup['lr'], weight_decay=b_setup['weight_decay'])  # AdamW optimizer with the configured learning rate and weight decay

# define how the learning rate and weight decay are adjusted during training
lr_scheduler = CosineScheduler(optimizer, param_name='lr', t_max=b_setup['epochs'], value_min=m_setup['lr'] * 1e-2,
                               warmup_t=b_setup['warmup_epochs'], const_t=b_setup['const_epochs'])  # cosine learning-rate schedule with warmup and a constant phase
wd_scheduler = CosineScheduler(optimizer, param_name='weight_decay', t_max=b_setup['epochs'])  # cosine weight-decay schedule; seems not to work

# use GradScaler for gradient scaling (mixed-precision training)
scaler = GradScaler()
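CosineScheduler is not a built-in PyTorch class but a project-specific helper. As a rough orientation only, here is a minimal sketch (an assumption, not the original implementation) of a scheduler with linear warmup, an optional constant phase, and cosine decay down to value_min, stepped once per epoch:

import math

class CosineScheduler:  # hypothetical sketch; the real class in the codebase may differ
    def __init__(self, optimizer, param_name='lr', t_max=100, value_min=0.0, warmup_t=0, const_t=0):
        self.optimizer = optimizer
        self.param_name = param_name    # which param-group entry to schedule ('lr' or 'weight_decay')
        self.t_max = t_max              # total number of epochs
        self.value_min = value_min      # final value reached by the cosine decay
        self.warmup_t = warmup_t        # epochs of linear warmup
        self.const_t = const_t          # epochs held at the base value after warmup
        self.base_values = [g[param_name] for g in optimizer.param_groups]

    def _value(self, base, t):
        if t < self.warmup_t:                       # linear warmup up to the base value
            return base * (t + 1) / self.warmup_t
        if t < self.warmup_t + self.const_t:        # constant phase
            return base
        t_cos = t - self.warmup_t - self.const_t    # cosine decay over the remaining epochs
        t_len = max(self.t_max - self.warmup_t - self.const_t, 1)
        return self.value_min + 0.5 * (base - self.value_min) * (1 + math.cos(math.pi * min(t_cos, t_len) / t_len))

    def step(self, t):
        for group, base in zip(self.optimizer.param_groups, self.base_values):
            group[self.param_name] = self._value(base, t)

    def state_dict(self):
        return {'base_values': self.base_values}

    def load_state_dict(self, state):
        self.base_values = state['base_values']
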
# load saved model (resume from a checkpoint if one exists)
save_dir = os.path.join(args.save_dir, args.exp)  # directory where checkpoints are saved
os.makedirs(save_dir, exist_ok=True)  # create the checkpoint directory
if not os.path.exists(os.path.join(save_dir, args.model+'.pth')):  # no checkpoint for this model yet
    best_psnr = 0  # best PSNR so far starts at 0
    cur_epoch = 0  # start training from epoch 0
else:
    if not args.use_ddp or local_rank == 0:
        print('==> Loaded existing trained model.')  # only the main process (or a non-DDP run) prints this message
    model_info = torch.load(os.path.join(save_dir, args.model+'.pth'), map_location='cpu')  # load the checkpoint onto CPU
    network.load_state_dict(model_info['state_dict'])  # restore model weights
    optimizer.load_state_dict(model_info['optimizer'])  # restore optimizer state
    lr_scheduler.load_state_dict(model_info['lr_scheduler'])  # restore learning-rate scheduler state
    wd_scheduler.load_state_dict(model_info['wd_scheduler'])  # restore weight-decay scheduler state
    scaler.load_state_dict(model_info['scaler'])  # restore gradient-scaler state
    cur_epoch = model_info['cur_epoch']  # resume from the saved epoch
    best_psnr = model_info['best_psnr']  # resume the saved best PSNR
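Since the checkpoint bundles the full training state in a single dict, restoring only the weights for later inference just needs the 'state_dict' entry, for example:

ckpt = torch.load(os.path.join(save_dir, args.model + '.pth'), map_location='cpu')
network.load_state_dict(ckpt['state_dict'])  # weights only; optimizer/scheduler states are ignored
network.eval()
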
# define datasets and data loaders
train_dataset = PairLoader(os.path.join(args.data_dir, args.train_set), 'train',
                           b_setup['t_patch_size'], b_setup['edge_decay'], b_setup['data_augment'], b_setup['cache_memory'])  # training set
train_loader = DataLoader(train_dataset,
                          batch_size=m_setup['batch_size'] // world_size,  # per-process batch size
                          sampler=RandomSampler(train_dataset, num_samples=b_setup['num_iter'] // world_size),  # random sampling with a fixed number of samples per epoch
                          num_workers=args.num_workers // world_size,  # worker processes per data loader
                          pin_memory=True,  # keep batches in pinned memory for faster host-to-GPU transfer
                          drop_last=True,  # drop the last incomplete batch
                          persistent_workers=True)  # comment this line for cache_memory

val_dataset = PairLoader(os.path.join(args.data_dir, args.val_set), b_setup['valid_mode'], b_setup['v_patch_size'])  # validation set
val_loader = DataLoader(val_dataset,
                        batch_size=max(int(m_setup['batch_size'] * b_setup['v_batch_ratio'] // world_size), 1),  # validation batch size
                        # sampler=DistributedSampler(val_dataset, shuffle=False),  # comment this line for more accurate validation
                        num_workers=args.num_workers // world_size,
                        pin_memory=True)
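For orientation, the batch size and sample count here are split across DDP processes: with hypothetical values m_setup['batch_size'] = 16, b_setup['num_iter'] = 4000 and world_size = 4, each process loads batches of 4 and its RandomSampler draws 1000 samples per epoch, so the global batch size (16) and the total samples per epoch (4000) match the single-GPU settings.
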
# start training
if not args.use_ddp or local_rank == 0:
    print('==> Start training, current model name: ' + args.model)  # only the main process (or a non-DDP run) prints this message
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.exp, args.model))  # TensorBoard logger

for epoch in tqdm(range(cur_epoch, b_setup['epochs'] + 1)):
    frozen_bn = epoch > (b_setup['epochs'] - b_setup['frozen_epochs'])  # freeze BN layers during the last frozen_epochs epochs

    loss = train(train_loader, network, criterion, optimizer, scaler, frozen_bn)  # train for one epoch
    lr_scheduler.step(epoch + 1)  # update the learning rate
    wd_scheduler.step(epoch + 1)  # update the weight-decay coefficient

    if not args.use_ddp or local_rank == 0:
        writer.add_scalar('train_loss', loss, epoch)  # log the training loss (main process only)

    if epoch % b_setup['eval_freq'] == 0:
        avg_psnr = valid(val_loader, network)  # validate the model

        if not args.use_ddp or local_rank == 0:
            if avg_psnr > best_psnr:  # new best PSNR
                best_psnr = avg_psnr  # update the best PSNR
                torch.save({'cur_epoch': epoch + 1,
                            'best_psnr': best_psnr,
                            'state_dict': network.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'lr_scheduler': lr_scheduler.state_dict(),
                            'wd_scheduler': wd_scheduler.state_dict(),
                            'scaler': scaler.state_dict()},
                           os.path.join(save_dir, args.model+'.pth'))  # save the checkpoint

            writer.add_scalar('valid_psnr', avg_psnr, epoch)  # log the validation PSNR (main process only)
            writer.add_scalar('best_psnr', best_psnr, epoch)  # log the best PSNR (main process only)

        if args.use_ddp:
            dist.barrier()  # in DDP mode, wait for all processes to finish validation
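train() and valid() are defined elsewhere in the script. As a rough sketch only, train() presumably follows the standard torch.cuda.amp pattern around the GradScaler created above; the body below, including the assumption that each batch is a dict with 'source' and 'target' tensors, is illustrative rather than the original implementation:

import torch
from torch.cuda.amp import autocast

def train(train_loader, network, criterion, optimizer, scaler, frozen_bn=False):  # hypothetical sketch
    network.train()
    if frozen_bn:  # keep BatchNorm statistics fixed during the final epochs
        for m in network.modules():
            if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
                m.eval()

    total_loss = 0.0
    for batch in train_loader:
        source = batch['source'].cuda(non_blocking=True)  # assumed dict keys
        target = batch['target'].cuda(non_blocking=True)

        optimizer.zero_grad()
        with autocast():                   # mixed-precision forward pass
            output = network(source)
            loss = criterion(output, target)

        scaler.scale(loss).backward()      # scale the loss to avoid fp16 gradient underflow
        scaler.step(optimizer)             # unscale gradients and apply the optimizer step
        scaler.update()                    # adjust the scale factor for the next iteration

        total_loss += loss.item()

    return total_loss / len(train_loader)  # average training loss over the epoch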