PyTorch Model Training with Selective Pixel Supervision for Improved Accuracy
import time

import torch
def _selective_pixel_loss(pred, target, keep_quantile=0.9):
    """Return the mean absolute error over pixels whose error is at or below a quantile.

    The absolute difference between ``pred`` and ``target`` is flattened over
    the whole batch. The value at position ``int(n * keep_quantile)`` of the
    ascending-sorted differences is used as a threshold, and only pixels whose
    difference is <= that threshold contribute to the mean. With the default
    ``keep_quantile=0.9`` roughly the 10% of pixels with the largest errors are
    ignored.

    Args:
        pred: Model prediction; assumed to have the same shape as ``target``
            (or to broadcast against it) -- TODO confirm for all model types.
        target: Ground-truth map.
        keep_quantile: Fraction in (0, 1] selecting the threshold position.

    Returns:
        A 0-dim tensor that carries gradients back to ``pred``; zero when the
        inputs are empty.
    """
    # Compute in float32 so the threshold and the mean are not affected by
    # fp16 rounding when the prediction comes out of autocast.
    diff = torch.abs(pred.float() - target.float()).reshape(-1)
    num_pixels = diff.numel()
    if num_pixels == 0:
        # Empty batch: a 0-dim zero on the same device as the inputs.
        return diff.sum()
    # The threshold only selects pixels; it does not need a gradient.
    sorted_diff, _ = torch.sort(diff.detach())
    # Clamp so keep_quantile == 1.0 does not index past the end.
    threshold_index = min(int(num_pixels * keep_quantile), num_pixels - 1)
    threshold = sorted_diff[threshold_index]
    # A boolean mask keeps the selected differences as a flat 1-D tensor.
    return diff[diff <= threshold].mean()


def train(train_loader, model, criterion, optimizer, epoch, args, tb_writer, fp16_args):
    """Train ``model`` for one epoch with selective pixel supervision.

    On top of the losses returned by ``criterion``, an L1 term is added. It is
    averaged only over pixels whose absolute prediction error is at or below
    the 90th percentile of the batch, so the ~10% of pixels with the largest
    errors do not contribute to that term.

    Args:
        train_loader: Iterable yielding ``(fname, img, fidt_map, kpoint)``
            batches; must expose ``.dataset`` (used for the sample count).
        model: Network mapping ``img`` to a prediction comparable with
            ``fidt_map``.
        criterion: Callable ``criterion(pred, fidt_map, kpoint)`` returning a
            ``(mse_loss, ssim_loss)`` pair of scalar tensors; must also expose
            ``ssim_coefficient``.
        optimizer: Optimizer over the model parameters.
        epoch: Current epoch index (used for logging only).
        args: Config mapping; reads ``'norm_eval'``, ``'model_type'``,
            ``'norm_eval_encoder'``, ``'fp16'``, ``'gpu_id'`` and
            ``'loss_type'``.
        tb_writer: TensorBoard-style writer providing ``add_scalar``.
        fp16_args: Mapping with ``'device_type'``, ``'dtype'`` and
            ``'scaler'``; only read when ``args['fp16']`` is true.
    """
    total_losses, ssim_losses, mse_losses, batch_time, data_time = [AverageMeter() for _ in range(5)]
    norm_ssim_losses = AverageMeter()

    max_lr = max(group['lr'] for group in optimizer.param_groups)
    print('epoch %d, processed %d samples, lr %.10f' % (epoch, epoch * len(train_loader.dataset), max_lr))
    tb_writer.add_scalar('lr', max_lr, epoch)

    model.train()
    # Optionally keep normalization layers frozen (not applied to HRNet).
    if args['norm_eval'] and args['model_type'].lower() != 'hrnet':
        if args['norm_eval_encoder']:
            model.encoder = freeze_bn(model.encoder)
        else:
            model = freeze_bn(model)

    # Loop-invariant device probe.
    use_mps = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()

    end = time.time()
    for i_batch, (fname, img, fidt_map, kpoint) in enumerate(tqdm(train_loader)):
        data_time.update(time.time() - end)
        optimizer.zero_grad()

        if args['fp16']:
            # Reassign both tensors so fidt_map lives on the same device as
            # the prediction; the pixel-difference term below needs that.
            img = img.half().cuda()
            fidt_map = fidt_map.half().cuda()
            with torch.autocast(device_type=fp16_args['device_type'], dtype=fp16_args['dtype'], enabled=True):
                d6 = model(img)
                mse_loss, ssim_loss = criterion(d6, fidt_map, kpoint)
        else:
            if int(args['gpu_id']) >= 0:
                img = img.cuda()
                fidt_map = fidt_map.cuda()
            elif use_mps:
                img = img.to('mps')
                fidt_map = fidt_map.to('mps')
            d6 = model(img)
            mse_loss, ssim_loss = criterion(d6, fidt_map, kpoint)

        # Selective pixel supervision: L1 over pixels whose error is within
        # the 90th percentile of the batch; the worst ~10% are ignored.
        supervised_loss = _selective_pixel_loss(d6, fidt_map, keep_quantile=0.9)

        # mse_loss / ssim_loss are scalars (they are read with .item() below),
        # so they are added as-is; they cannot be indexed per pixel.
        if args['loss_type'] == 'MSE':
            loss = supervised_loss + mse_loss
        else:
            loss = supervised_loss + mse_loss + ssim_loss

        if args['fp16']:
            scaler = fp16_args['scaler']
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        total_losses.update(loss.item())
        mse_losses.update(mse_loss.item())
        ssim_losses.update(ssim_loss.item())
        norm_ssim_losses.update(ssim_loss.item() / criterion.ssim_coefficient)
        batch_time.update(time.time() - end)
        end = time.time()

    print('Epoch: [{0}] '
          'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
          'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
          'Loss {loss.val:.4f} ({loss.avg:.4f}) '.format(epoch, batch_time=batch_time,
                                                        data_time=data_time, loss=total_losses))
    tb_writer.add_scalar('train/total_loss', total_losses.avg, epoch)
    tb_writer.add_scalar('train/mse_loss', mse_losses.avg, epoch)
    tb_writer.add_scalar('train/ssim_loss', ssim_losses.avg, epoch)
    tb_writer.add_scalar('train/norm_ssim_loss', norm_ssim_losses.avg, epoch)
This code introduces a selective pixel supervision mechanism intended to improve model accuracy during training. It computes the absolute difference between the predicted and target maps and uses only the pixels whose difference is at or below the 90th percentile for an extra supervision term, discarding the 10% of pixels with the largest errors. This makes the extra term robust to outlier pixels, so extreme errors do not dominate its gradient, which can potentially lead to better performance.
Key steps:
- Calculate pixel difference: Compute the absolute difference between the predicted output (`d6`) and the target map (`fidt_map`).
- Sort pixel differences: Flatten the difference map into a vector and sort it to identify pixels with smaller and larger differences.
- Select supervised pixels: Take the value at the 90th-percentile position of the sorted differences as the threshold. Pixels whose difference is at or below this threshold are selected for supervision; the remaining 10% are ignored.
- Calculate supervised loss: Calculate the average of the selected pixel differences as the supervised loss.
- Combine supervised and original losses: Add the supervised loss to the original losses — the MSE loss alone when `loss_type` is `'MSE'`, otherwise both the MSE and SSIM losses — to obtain the final loss used for optimization.
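By excluding the pixels with the largest errors, this strategy prevents extreme outliers from dominating the supervised term, which can potentially lead to more stable convergence and improved accuracy. (If the goal is instead to emphasize the pixels where the model struggles most, the selection must keep differences above the threshold rather than at or below it.)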
原文地址: https://www.cveoy.top/t/topic/Sj6 著作权归作者所有。请勿转载和采集!