Python MNIST 数据集图片展示:完整代码及注释
以下是 Python 代码,其中包括加载 MNIST 数据集并显示一张图片的完整注释:
import os # 导入 os 模块,用于操作文件和目录
import struct # 导入 struct 模块,用于解析二进制数据
import numpy as np # 导入 NumPy 库,用于对矩阵进行操作
import matplotlib.pyplot as plt # 导入 matplotlib 库,用于绘制图像
# 定义函数,用于读取 MNIST 数据集
def load_mnist(path, kind='train'):
'Load MNIST data from `path`'
labels_path = os.path.join(path, '{}-labels-idx1-ubyte'.format(kind))
images_path = os.path.join(path, '{}-images-idx3-ubyte'.format(kind))
with open(labels_path, 'rb') as lbpath:
magic, n = struct.unpack('>II', lbpath.read(8))
labels = np.fromfile(lbpath, dtype=np.uint8)
with open(images_path, 'rb') as imgpath:
magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)
# 将像素值缩放到 0 到 1 之间
images = ((images / 255.) - .5) * 2
return images, labels
# 加载 MNIST 数据集
X_train, y_train = load_mnist('D:\dataset\MNIST', kind='train')
X_test, y_test = load_mnist('D:\dataset\MNIST', kind='t10k')
# 显示第一张图片
plt.imshow(X_train[0].reshape(28, 28), cmap='gray')
plt.show()
代码解释:
- 首先我们导入了需要用到的模块,包括 os、struct、numpy 和 matplotlib.pyplot。
- 接着我们定义了一个函数 load_mnist,用于读取 MNIST 数据集。该函数接收两个参数,一个是数据集的路径,另一个是数据集的类型(训练集或测试集)。
- 在 load_mnist 函数中,我们分别读取了标签文件和图像文件。标签文件中包含了每张图像的标签(即对应的数字),而图像文件中包含了每张图像的像素值。
- 我们将像素值缩放到 0 到 1 之间,以便在显示图像时更加美观。
- 最后我们调用 load_mnist 函数,加载 MNIST 数据集,并使用 matplotlib.pyplot 库显示第一张图像。我们使用 cmap='gray' 参数指定了灰度图像。
原文地址: https://www.cveoy.top/t/topic/npkg 著作权归作者所有。请勿转载和采集!