PyTorch医学图像加载器详解:构建高效的数据预处理流程
PyTorch医学图像加载器详解:构建高效的数据预处理流程
这篇文章将详细解析一个用于医学图像加载的PyTorch代码,该代码定义了一个名为 get_loader 的函数,它负责加载和预处理训练和验证图像数据。pythondef get_loader(args): data_dir = args.data_dir datalist_json = os.path.join(data_dir, args.json_list) train_transform = transforms.Compose( [ transforms.LoadImaged(keys=['image', 'label']), transforms.EnsureChannelFirstd(keys=['image', 'label']), transforms.Orientationd(keys=['image', 'label'], axcodes='RAS'), transforms.Spacingd( keys=['image', 'label'], pixdim=(args.space_x, args.space_y, args.space_z), mode=('bilinear', 'nearest') ), transforms.ScaleIntensityRanged( keys=['image'], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True ), transforms.CropForegroundd(keys=['image', 'label'], source_key='image'), transforms.RandCropByPosNegLabeld( keys=['image', 'label'], label_key='label', spatial_size=(args.roi_x, args.roi_y, args.roi_z), pos=1, neg=1, num_samples=4, image_key='image', image_threshold=0, ), transforms.RandFlipd(keys=['image', 'label'], prob=args.RandFlipd_prob, spatial_axis=0), transforms.RandFlipd(keys=['image', 'label'], prob=args.RandFlipd_prob, spatial_axis=1), transforms.RandFlipd(keys=['image', 'label'], prob=args.RandFlipd_prob, spatial_axis=2), transforms.RandRotate90d(keys=['image', 'label'], prob=args.RandRotate90d_prob, max_k=3), transforms.RandScaleIntensityd(keys='image', factors=0.1, prob=args.RandScaleIntensityd_prob), transforms.RandShiftIntensityd(keys='image', offsets=0.1, prob=args.RandShiftIntensityd_prob), transforms.ToTensord(keys=['image', 'label']), ] ) val_transform = transforms.Compose( [ transforms.LoadImaged(keys=['image', 'label']), transforms.EnsureChannelFirstd(keys=['image', 'label']), transforms.Orientationd(keys=['image', 'label'], axcodes='RAS'), transforms.Spacingd( keys=['image', 'label'], pixdim=(args.space_x, args.space_y, args.space_z), mode=('bilinear', 'nearest') ), transforms.ScaleIntensityRanged( keys=['image'], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True ), transforms.CropForegroundd(keys=['image', 'label'], source_key='image'), transforms.ToTensord(keys=['image', 'label']), ] )
if args.test_mode: test_files = load_decathlon_datalist(datalist_json, True, 'validation', base_dir=data_dir) test_ds = data.Dataset(data=test_files, transform=val_transform) test_sampler = Sampler(test_ds, shuffle=False) if args.distributed else None test_loader = data.DataLoader( test_ds, batch_size=1, shuffle=False, num_workers=args.workers, sampler=test_sampler, pin_memory=True, persistent_workers=True, ) loader = test_loader else: datalist = load_decathlon_datalist(datalist_json, True, 'training', base_dir=data_dir) if args.use_normal_dataset: train_ds = data.Dataset(data=datalist, transform=train_transform) else: train_ds = data.CacheDataset( data=datalist, transform=train_transform, cache_num=150, cache_rate=1.0, num_workers=args.workers ) train_sampler = Sampler(train_ds) if args.distributed else None train_loader = data.DataLoader( train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, sampler=train_sampler, pin_memory=True, persistent_workers=True, ) val_files = load_decathlon_datalist(datalist_json, True, 'validation', base_dir=data_dir) val_ds = data.Dataset(data=val_files, transform=val_transform) val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None val_loader = data.DataLoader( val_ds, batch_size=1, shuffle=False, num_workers=args.workers, sampler=val_sampler, pin_memory=True, persistent_workers=True, ) loader = [train_loader, val_loader]
return loader
代码解析
-
参数解析: 函数首先解析传入的
args参数,获取数据路径、预处理参数等信息。2. 数据转换: 定义了两个数据转换管道train_transform和val_transform,分别用于训练和验证数据。这两个管道包含一系列图像操作,例如: -LoadImaged: 加载图像和标签数据。 -EnsureChannelFirstd: 确保图像通道位于第一维度。 -Orientationd: 调整图像方向。 -Spacingd: 调整图像间距。 -ScaleIntensityRanged: 缩放像素强度。 -CropForegroundd: 裁剪前景区域。 -RandCropByPosNegLabeld: 随机裁剪包含正负样本的区域(仅用于训练集)。 -RandFlipd: 随机翻转图像。 -RandRotate90d: 随机旋转图像。 -RandScaleIntensityd: 随机缩放像素强度(仅用于训练集)。 -RandShiftIntensityd: 随机平移像素强度(仅用于训练集)。 -ToTensord: 将数据转换为PyTorch张量。 -
数据加载: - 根据
test_mode参数判断是否为测试模式。 - 如果是测试模式,加载验证集数据,并使用val_transform进行预处理,创建test_loader。 - 如果不是测试模式,加载训练集和验证集数据,分别使用train_transform和val_transform进行预处理,创建train_loader和val_loader。4. DataLoader: 使用DataLoader类创建数据加载器,用于迭代训练和验证数据。可以设置batch_size,shuffle,num_workers等参数来控制数据加载过程。
总结
这个代码展示了如何使用 PyTorch 构建一个完整的医学图像数据加载器,包括数据预处理、数据增强和数据加载等步骤。通过灵活地配置数据转换管道和 DataLoader 参数,可以构建高效的医学图像训练流程。
原文地址: https://www.cveoy.top/t/topic/fzTf 著作权归作者所有。请勿转载和采集!