Python Fashion-MNIST 数据集读取与可视化 - 从 CSV 文件构建 Dataset 和 Dataloader
以下是用 Python 代码从本地文件夹中读取 Fashion-MNIST 数据(以 CSV 格式读取,训练集和测试集分别读取),并构建 Dataset 和 Dataloader,以及遍历 Dataloader 绘制前 10 个样本的示例:
import torch
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
# 定义 Fashion-MNIST 数据集类
class FashionMNISTDataset(Dataset):
def __init__(self, csv_file):
self.data = pd.read_csv(csv_file)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
label = self.data.iloc[idx, 0]
img = torch.tensor(self.data.iloc[idx, 1:].values, dtype=torch.float32)
return img, label
# 读取训练集和测试集
train_dataset = FashionMNISTDataset('fashion-mnist_train.csv')
test_dataset = FashionMNISTDataset('fashion-mnist_test.csv')
# 构建训练集和测试集的 Dataloader
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)
# 定义类别名称
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 遍历测试集 Dataloader,并绘制前 10 个样本
for i, (images, labels) in enumerate(test_dataloader):
if i == 2:
break
fig, axs = plt.subplots(1, 10, figsize=(20, 2))
for j in range(10):
axs[j].imshow(images[j].reshape(28, 28), cmap='gray')
axs[j].set_title(class_names[labels[j].item()])
plt.show()
运行结果如下:

可以看到,每次遍历测试集 Dataloader 时,都绘制了前 10 个样本,并标出了对应的类别名称。
本示例展示了如何使用 Python 从 CSV 文件读取 Fashion-MNIST 数据集,构建 Dataset 和 Dataloader,以及如何遍历 Dataloader 并可视化前 10 个样本。这只是一个简单的示例,您可以根据自己的需要进行修改和扩展。例如,您可以尝试使用不同的数据集,或使用更复杂的方法来进行可视化。
原文地址: https://www.cveoy.top/t/topic/nIAZ 著作权归作者所有。请勿转载和采集!