TensorFlow 时尚 MNIST 分类 - 构建简单三层神经网络
以下代码展示了如何使用 TensorFlow 构建一个简单的三层神经网络,用于对时尚 MNIST 数据集进行分类。
import tensorflow as tf
import os
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
checkpoint_save_path = "./checkpoint/fashion.ckpt" # 指定保存文件的地址
if os.path.exists(checkpoint_save_path + '.index'):
print('-------------load the model-----------------')
model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, # 给出模型保存路径
save_weights_only=True, # 只保留模型参数
save_best_only=True) # 只保留最优模型
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
callbacks=[cp_callback])
model.summary()
这段代码的神经网络为什么是三层?
这段代码的神经网络是三层的,因为它包含了两个密集连接层(Dense layer)和一个扁平化层(Flatten layer)。
-
扁平化层 (Flatten layer): 这是输入层,将输入的二维图像数据转换为一维向量。在这种情况下,图像的高度和宽度被展平为一个长度为784的向量。
-
密集连接层 (Dense layer): 这是隐藏层,具有128个神经元。每个神经元都与前一层的所有神经元相连,通过权重和激活函数将输入进行加权和转换。
-
密集连接层 (Dense layer): 这是输出层,具有10个神经元,对应于数据集中的10个类别。在这种情况下,使用softmax激活函数将输出转换为每个类别的概率分布。
因此,这个神经网络的结构是输入层 -> 隐藏层 -> 输出层,共计三层。
代码解释:
- 数据加载: 使用
fashion = tf.keras.datasets.fashion_mnist加载时尚 MNIST 数据集,并将其分为训练集和测试集。 - 数据预处理: 将图像像素值除以255.0,将像素值归一化到0到1之间。
- 模型构建: 使用
tf.keras.models.Sequential创建一个顺序模型,并添加三个层:tf.keras.layers.Flatten():将二维图像数据展平成一维向量。tf.keras.layers.Dense(128, activation='relu'):第一个密集连接层,具有128个神经元,使用 ReLU 激活函数。tf.keras.layers.Dense(10, activation='softmax'):输出层,具有10个神经元,使用 softmax 激活函数。
- 模型编译: 使用
model.compile设置优化器、损失函数和评估指标。 - 模型训练: 使用
model.fit训练模型,指定训练集、测试集、批次大小、训练轮数、验证频率和回调函数。 - 模型保存: 使用
tf.keras.callbacks.ModelCheckpoint回调函数保存模型权重。 - 模型评估: 使用
model.summary()打印模型摘要。
原文地址: https://www.cveoy.top/t/topic/ccjb 著作权归作者所有。请勿转载和采集!