PyTorch医学图像加载器详解:构建高效的数据预处理流程

这篇文章将详细解析一个用于医学图像加载的PyTorch代码,该代码定义了一个名为 get_loader 的函数,它负责加载和预处理训练和验证图像数据。pythondef get_loader(args): data_dir = args.data_dir datalist_json = os.path.join(data_dir, args.json_list) train_transform = transforms.Compose( [ transforms.LoadImaged(keys=['image', 'label']), transforms.EnsureChannelFirstd(keys=['image', 'label']), transforms.Orientationd(keys=['image', 'label'], axcodes='RAS'), transforms.Spacingd( keys=['image', 'label'], pixdim=(args.space_x, args.space_y, args.space_z), mode=('bilinear', 'nearest') ), transforms.ScaleIntensityRanged( keys=['image'], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True ), transforms.CropForegroundd(keys=['image', 'label'], source_key='image'), transforms.RandCropByPosNegLabeld( keys=['image', 'label'], label_key='label', spatial_size=(args.roi_x, args.roi_y, args.roi_z), pos=1, neg=1, num_samples=4, image_key='image', image_threshold=0, ), transforms.RandFlipd(keys=['image', 'label'], prob=args.RandFlipd_prob, spatial_axis=0), transforms.RandFlipd(keys=['image', 'label'], prob=args.RandFlipd_prob, spatial_axis=1), transforms.RandFlipd(keys=['image', 'label'], prob=args.RandFlipd_prob, spatial_axis=2), transforms.RandRotate90d(keys=['image', 'label'], prob=args.RandRotate90d_prob, max_k=3), transforms.RandScaleIntensityd(keys='image', factors=0.1, prob=args.RandScaleIntensityd_prob), transforms.RandShiftIntensityd(keys='image', offsets=0.1, prob=args.RandShiftIntensityd_prob), transforms.ToTensord(keys=['image', 'label']), ] ) val_transform = transforms.Compose( [ transforms.LoadImaged(keys=['image', 'label']), transforms.EnsureChannelFirstd(keys=['image', 'label']), transforms.Orientationd(keys=['image', 'label'], axcodes='RAS'), transforms.Spacingd( keys=['image', 'label'], pixdim=(args.space_x, args.space_y, args.space_z), mode=('bilinear', 'nearest') ), transforms.ScaleIntensityRanged( keys=['image'], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True ), transforms.CropForegroundd(keys=['image', 'label'], source_key='image'), transforms.ToTensord(keys=['image', 'label']), ] )

if args.test_mode:        test_files = load_decathlon_datalist(datalist_json, True, 'validation', base_dir=data_dir)        test_ds = data.Dataset(data=test_files, transform=val_transform)        test_sampler = Sampler(test_ds, shuffle=False) if args.distributed else None        test_loader = data.DataLoader(            test_ds,            batch_size=1,            shuffle=False,            num_workers=args.workers,            sampler=test_sampler,            pin_memory=True,            persistent_workers=True,        )        loader = test_loader    else:        datalist = load_decathlon_datalist(datalist_json, True, 'training', base_dir=data_dir)        if args.use_normal_dataset:            train_ds = data.Dataset(data=datalist, transform=train_transform)        else:            train_ds = data.CacheDataset(                data=datalist, transform=train_transform, cache_num=150, cache_rate=1.0, num_workers=args.workers            )        train_sampler = Sampler(train_ds) if args.distributed else None        train_loader = data.DataLoader(            train_ds,            batch_size=args.batch_size,            shuffle=(train_sampler is None),            num_workers=args.workers,            sampler=train_sampler,            pin_memory=True,            persistent_workers=True,        )        val_files = load_decathlon_datalist(datalist_json, True, 'validation', base_dir=data_dir)        val_ds = data.Dataset(data=val_files, transform=val_transform)        val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None        val_loader = data.DataLoader(            val_ds,            batch_size=1,            shuffle=False,            num_workers=args.workers,            sampler=val_sampler,            pin_memory=True,            persistent_workers=True,        )        loader = [train_loader, val_loader]

return loader

代码解析

  1. 参数解析: 函数首先解析传入的 args 参数,获取数据路径、预处理参数等信息。2. 数据转换: 定义了两个数据转换管道 train_transformval_transform,分别用于训练和验证数据。这两个管道包含一系列图像操作,例如: - LoadImaged: 加载图像和标签数据。 - EnsureChannelFirstd: 确保图像通道位于第一维度。 - Orientationd: 调整图像方向。 - Spacingd: 调整图像间距。 - ScaleIntensityRanged: 缩放像素强度。 - CropForegroundd: 裁剪前景区域。 - RandCropByPosNegLabeld: 随机裁剪包含正负样本的区域(仅用于训练集)。 - RandFlipd: 随机翻转图像。 - RandRotate90d: 随机旋转图像。 - RandScaleIntensityd: 随机缩放像素强度(仅用于训练集)。 - RandShiftIntensityd: 随机平移像素强度(仅用于训练集)。 - ToTensord: 将数据转换为PyTorch张量。

  2. 数据加载: - 根据 test_mode 参数判断是否为测试模式。 - 如果是测试模式,加载验证集数据,并使用 val_transform 进行预处理,创建 test_loader。 - 如果不是测试模式,加载训练集和验证集数据,分别使用 train_transformval_transform 进行预处理,创建 train_loaderval_loader。4. DataLoader: 使用 DataLoader 类创建数据加载器,用于迭代训练和验证数据。可以设置 batch_sizeshufflenum_workers 等参数来控制数据加载过程。

总结

这个代码展示了如何使用 PyTorch 构建一个完整的医学图像数据加载器,包括数据预处理、数据增强和数据加载等步骤。通过灵活地配置数据转换管道和 DataLoader 参数,可以构建高效的医学图像训练流程。

PyTorch医学图像加载器详解:构建高效的数据预处理流程

原文地址: https://www.cveoy.top/t/topic/fzTf 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录