PyTorch DataLoader 性能测试:MNIST 数据集多线程加载
该代码用于测试在不同的CPU核心数下,使用PyTorch的DataLoader载入MNIST数据集的训练效率。
首先,通过导入必要的库,包括multiprocessing库(用于并行处理)、torch库(PyTorch库)、torchvision库(PyTorch的图像数据集处理库)等。其中,transforms.Compose()方法用于组合多个图像变换操作,trainset用于载入MNIST数据集并进行预处理。
然后,通过循环测试不同的CPU核心数(从2开始,每次增加2),使用torch.utils.data.DataLoader()方法载入训练数据集trainset,其中包括了shuffle(是否打乱数据集)、num_workers(使用的CPU核心数)、batch_size(每次载入的数据量)、pin_memory(是否使用固定内存)等参数。
最后,循环训练数据集trainset,记录训练时间并输出。
代码示例:
from time import time
import multiprocessing as mp
import torch
import torchvision
from torchvision import transforms
transform = transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
trainset = torchvision.datasets.MNIST(
root='dataset/',
train=True, #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
transform=transform
)
print(f'num of CPU: {mp.cpu_count()}')
for num_workers in range(2, mp.cpu_count(), 2):
train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True)
start = time()
for epoch in range(1, 3):
for i, data in enumerate(train_loader, 0):
pass
end = time()
print('Finish with:{} second, num_workers={}'.format(end - start, num_workers))
解释:
from time import time:导入time模块,用于计时。import multiprocessing as mp:导入multiprocessing模块,用于多进程处理。import torch:导入torch模块,PyTorch 的核心库。import torchvision:导入torchvision模块,用于处理图像数据集。from torchvision import transforms:从torchvision模块中导入transforms类,用于对图像进行预处理。transform = transforms.Compose(...):创建一个图像变换组合,包括将图像转换为张量和归一化。trainset = torchvision.datasets.MNIST(...):载入 MNIST 数据集,并应用预处理。print(f'num of CPU: {mp.cpu_count()}'):打印 CPU 核心数。for num_workers in range(2, mp.cpu_count(), 2):循环遍历不同的 CPU 核心数,从 2 开始,每次增加 2。train_loader = torch.utils.data.DataLoader(...):创建 DataLoader,用于加载 MNIST 数据集,并指定参数:shuffle=True:打乱数据集。num_workers=num_workers:使用指定的 CPU 核心数进行数据加载。batch_size=64:每次加载 64 个样本。pin_memory=True:使用固定内存来提高数据加载速度。
start = time():记录训练开始时间。for epoch in range(1, 3):循环训练 2 个 epoch。for i, data in enumerate(train_loader, 0):循环遍历 DataLoader 加载的数据。end = time():记录训练结束时间。print('Finish with:{} second, num_workers={}'.format(end - start, num_workers)):打印训练时间和使用的 CPU 核心数。
总结:
该代码通过循环测试不同的 CPU 核心数,使用 PyTorch 的 DataLoader 加载 MNIST 数据集,并记录训练时间,从而评估多线程加载数据集的性能表现。
原文地址: https://www.cveoy.top/t/topic/nr9x 著作权归作者所有。请勿转载和采集!