{"title":"MNIST 数据集可视化 - 使用 PyTorch 展示样本图像","description":"本代码使用 PyTorch 加载 MNIST 数据集并可视化部分样本图像,展示了如何使用 PyTorch 进行数据处理和可视化操作。","keywords":"MNIST, PyTorch, 数据集, 可视化, 图像处理, 深度学习, 机器学习","content":"import torch\nfrom torchvision import datasets, transforms\nimport numpy as np\nfrom torch.utils.data import DataLoader\nimport matplotlib.pyplot as plt\n\nn_samples = 100\n\nX_train = datasets.MNIST(root='./data', train=True, download=True,\n transform=transforms.Compose([transforms.ToTensor()]))\nidx = np.append(np.where(X_train.targets == 0)[0][:n_samples], \n np.where(X_train.targets == 1)[0][:n_samples])\n\nX_train.data = X_train.data[idx]\nX_train.targets = X_train.targets[idx]\n\ntrain_loader = DataLoader(X_train, batch_size=1, shuffle=True)\n\nn_samples_show = 6\n\ndata_iter = iter(train_loader)\nfig, axes = plt.subplots(nrows=1, ncols=n_samples_show, figsize=(10, 3))\n\nwhile n_samples_show > 0:\n images, targets = data_iter.next()\n\n axes[n_samples_show - 1].imshow(images[0].numpy().squeeze(), cmap='gray')\n axes[n_samples_show - 1].set_xticks([])\n axes[n_samples_show - 1].set_yticks([])\n axes[n_samples_show - 1].set_title("Labeled: {}".format(targets.item()))\n \n n_samples_show -= 1\n\nplt.show()"}


原文地址: http://www.cveoy.top/t/topic/p8LZ 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录