以下是 Python 代码,其中包括加载 MNIST 数据集并显示一张图片的完整注释:

import os  # 导入 os 模块,用于操作文件和目录
import struct  # 导入 struct 模块,用于解析二进制数据
import numpy as np  # 导入 NumPy 库,用于对矩阵进行操作
import matplotlib.pyplot as plt  # 导入 matplotlib 库,用于绘制图像

# 定义函数,用于读取 MNIST 数据集
def load_mnist(path, kind='train'):
    'Load MNIST data from `path`'
    labels_path = os.path.join(path, '{}-labels-idx1-ubyte'.format(kind))
    images_path = os.path.join(path, '{}-images-idx3-ubyte'.format(kind))
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)
        # 将像素值缩放到 0 到 1 之间
        images = ((images / 255.) - .5) * 2
    return images, labels

# 加载 MNIST 数据集
X_train, y_train = load_mnist('D:\dataset\MNIST', kind='train')
X_test, y_test = load_mnist('D:\dataset\MNIST', kind='t10k')

# 显示第一张图片
plt.imshow(X_train[0].reshape(28, 28), cmap='gray')
plt.show()

代码解释:

  1. 首先我们导入了需要用到的模块,包括 os、struct、numpy 和 matplotlib.pyplot。
  2. 接着我们定义了一个函数 load_mnist,用于读取 MNIST 数据集。该函数接收两个参数,一个是数据集的路径,另一个是数据集的类型(训练集或测试集)。
  3. 在 load_mnist 函数中,我们分别读取了标签文件和图像文件。标签文件中包含了每张图像的标签(即对应的数字),而图像文件中包含了每张图像的像素值。
  4. 我们将像素值缩放到 0 到 1 之间,以便在显示图像时更加美观。
  5. 最后我们调用 load_mnist 函数,加载 MNIST 数据集,并使用 matplotlib.pyplot 库显示第一张图像。我们使用 cmap='gray' 参数指定了灰度图像。
Python MNIST 数据集图片展示:完整代码及注释

原文地址: https://www.cveoy.top/t/topic/npkg 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录